static oauth2 client from config

This commit is contained in:
Quentin Gliech
2021-07-01 14:56:27 +02:00
parent 4422b63dfd
commit 0e30a1fb0c
6 changed files with 116 additions and 34 deletions

34
Cargo.lock generated
View File

@@ -774,6 +774,12 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "212d0f5754cb6769937f4501cc0e67f4f4483c8d2c3e1e922ee9edbe4ab4c7c0"
[[package]]
name = "dtoa"
version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56899898ce76aaf4a0f24d914c97ea6ed976d42fec6ad33fcbb0a1103e07b2b0"
[[package]]
name = "event-listener"
version = "2.5.1"
@@ -820,6 +826,7 @@ dependencies = [
"atomic",
"pear",
"serde",
"serde_yaml",
"uncased",
"version_check",
]
@@ -1201,6 +1208,12 @@ version = "0.2.96"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5600b4e6efc5421841a2138a6b082e07fe12f9aaa12783d50e5d13325b26b4fc"
[[package]]
name = "linked-hash-map"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fb9b38af92608140b86b693604b9ffcc5824240a484d1ecd4795bacb2fe88f3"
[[package]]
name = "log"
version = "0.4.14"
@@ -1761,6 +1774,18 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_yaml"
version = "0.8.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15654ed4ab61726bf918a39cb8d98a2e2995b002387807fa6ba58fdf7f59bb23"
dependencies = [
"dtoa",
"linked-hash-map",
"serde",
"yaml-rust",
]
[[package]]
name = "sha-1"
version = "0.8.2"
@@ -2502,6 +2527,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "yaml-rust"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56c1936c4cc7a1c9ab21a1ebb602eb942ba868cbd44a99cb7cdc5892335e1c85"
dependencies = [
"linked-hash-map",
]
[[package]]
name = "yansi"
version = "0.5.0"

View File

@@ -10,7 +10,7 @@ async-std = { version = "1.8.0", features = ["attributes"] }
tide = "0.16.0"
tracing = "0.1.26"
tracing-subscriber = "0.2.18"
figment = { version = "0.10.5", features = ["env"] }
figment = { version = "0.10.5", features = ["env", "yaml"] }
url = "2.2.2"
oauth2-types = { path = "../oauth2-types" }
thiserror = "1.0.25"

View File

@@ -1,44 +1,71 @@
use figment::{error::Error as FigmentError, providers::Env, Figment};
use figment::{
error::Error as FigmentError,
providers::{Env, Format, Yaml},
Figment,
};
use serde::Deserialize;
use url::Url;
#[derive(Debug, Deserialize)]
pub struct OAuth2 {
pub issuer: Url,
pub struct OAuth2ClientConfig {
pub client_id: String,
#[serde(default)]
pub redirect_uris: Option<Vec<Url>>,
}
impl Default for OAuth2 {
fn default_oauth2_issuer() -> Url {
"http://[::]:8080".parse().unwrap()
}
#[derive(Debug, Deserialize)]
pub struct OAuth2Config {
#[serde(default = "default_oauth2_issuer")]
pub issuer: Url,
#[serde(default)]
pub clients: Vec<OAuth2ClientConfig>,
}
impl Default for OAuth2Config {
fn default() -> Self {
Self {
issuer: "http://[::]:8080".parse().unwrap(),
issuer: default_oauth2_issuer(),
clients: Default::default(),
}
}
}
fn default_listener_address() -> String {
"[::]:8080".into()
}
#[derive(Debug, Deserialize)]
pub struct Listener {
pub struct ListenerConfig {
#[serde(default = "default_listener_address")]
pub address: String,
}
impl Default for Listener {
impl Default for ListenerConfig {
fn default() -> Self {
Listener {
address: "[::]:8080".into(),
ListenerConfig {
address: default_listener_address(),
}
}
}
#[derive(Debug, Default, Deserialize)]
#[serde(default)]
pub struct Config {
pub oauth2: OAuth2,
pub listener: Listener,
pub struct RootConfig {
pub oauth2: OAuth2Config,
pub listener: ListenerConfig,
}
impl Config {
pub fn load() -> Result<Config, FigmentError> {
impl RootConfig {
pub fn load() -> Result<RootConfig, FigmentError> {
Figment::new()
.merge(Env::prefixed("MAS_").split("_"))
.merge(Yaml::file("config.yaml"))
.extract()
}
}

View File

@@ -7,7 +7,7 @@ mod state;
mod storage;
mod templates;
use self::config::Config;
use self::config::RootConfig;
use self::state::State;
#[async_std::main]
@@ -20,16 +20,20 @@ async fn main() -> tide::Result<()> {
subscriber.try_init()?;
// Loading the config
let config = Config::load()?;
let address = config.listener.address.clone();
let config = RootConfig::load()?;
// Load and compile the templates
let templates = self::templates::load()?;
// Create the shared state
let state = State::new(config, templates);
state
.storage()
.load_static_clients(&state.config().oauth2.clients)
.await;
// Start the server
let address = state.config().listener.address.clone();
let mut app = tide::with_state(state);
app.with(tide_tracing::TraceMiddleware::new());
self::handlers::install(&mut app);

View File

@@ -9,11 +9,11 @@ use tide::{
};
use url::Url;
use crate::{config::Config, storage::Storage};
use crate::{config::RootConfig, storage::Storage};
#[derive(Clone)]
pub struct State {
config: Arc<Config>,
config: Arc<RootConfig>,
templates: Arc<Tera>,
storage: Arc<Storage>,
session_store: Arc<MemoryStore>,
@@ -27,7 +27,7 @@ impl std::fmt::Debug for State {
}
impl State {
pub fn new(config: Config, templates: Tera) -> Self {
pub fn new(config: RootConfig, templates: Tera) -> Self {
Self {
config: Arc::new(config),
templates: Arc::new(templates),
@@ -39,6 +39,10 @@ impl State {
}
}
pub fn config(&self) -> &RootConfig {
&self.config
}
pub fn storage(&self) -> &Storage {
&self.storage
}

View File

@@ -5,6 +5,8 @@ use serde::Serialize;
use thiserror::Error;
use url::Url;
use crate::config::OAuth2ClientConfig;
#[derive(Debug, Default)]
pub struct Storage {
clients: RwLock<HashMap<String, Client>>,
@@ -69,23 +71,34 @@ impl User {
}
impl Storage {
pub async fn lookup_client(&self, client_id: &str) -> Result<Client, ClientLookupError> {
// First lookup for an existing client
let clients = self.clients.upgradable_read().await;
if let Some(client) = clients.get(client_id) {
Ok(client.clone())
} else {
// If it does not exist, insert a new client
let mut clients = RwLockUpgradableReadGuard::upgrade(clients).await;
let new_client = Client {
client_id: client_id.to_string(),
redirect_uris: None,
pub async fn load_static_clients(&self, clients: &[OAuth2ClientConfig]) {
let mut storage = self.clients.write().await;
for config in clients {
let redirect_uris = config
.redirect_uris
.as_ref()
.map(|uris| uris.iter().cloned().collect());
let client_id = config.client_id.clone();
let client = Client {
client_id: client_id.clone(),
redirect_uris,
};
clients.insert(client_id.to_string(), new_client.clone());
Ok(new_client)
// TODO: we could warn about duplicate clients here
storage.insert(client_id, client);
}
}
pub async fn lookup_client(&self, client_id: &str) -> Result<Client, ClientLookupError> {
self.clients
.read()
.await
.get(client_id)
.cloned()
.ok_or(ClientLookupError)
}
pub async fn login(&self, name: &str, password: &str) -> Result<User, UserLoginError> {
// Hardcoded bad password to test login failures
if password == "bad" {