static oauth2 client from config
This commit is contained in:
34
Cargo.lock
generated
34
Cargo.lock
generated
@@ -774,6 +774,12 @@ version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "212d0f5754cb6769937f4501cc0e67f4f4483c8d2c3e1e922ee9edbe4ab4c7c0"
|
||||
|
||||
[[package]]
|
||||
name = "dtoa"
|
||||
version = "0.4.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56899898ce76aaf4a0f24d914c97ea6ed976d42fec6ad33fcbb0a1103e07b2b0"
|
||||
|
||||
[[package]]
|
||||
name = "event-listener"
|
||||
version = "2.5.1"
|
||||
@@ -820,6 +826,7 @@ dependencies = [
|
||||
"atomic",
|
||||
"pear",
|
||||
"serde",
|
||||
"serde_yaml",
|
||||
"uncased",
|
||||
"version_check",
|
||||
]
|
||||
@@ -1201,6 +1208,12 @@ version = "0.2.96"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5600b4e6efc5421841a2138a6b082e07fe12f9aaa12783d50e5d13325b26b4fc"
|
||||
|
||||
[[package]]
|
||||
name = "linked-hash-map"
|
||||
version = "0.5.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7fb9b38af92608140b86b693604b9ffcc5824240a484d1ecd4795bacb2fe88f3"
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.14"
|
||||
@@ -1761,6 +1774,18 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_yaml"
|
||||
version = "0.8.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "15654ed4ab61726bf918a39cb8d98a2e2995b002387807fa6ba58fdf7f59bb23"
|
||||
dependencies = [
|
||||
"dtoa",
|
||||
"linked-hash-map",
|
||||
"serde",
|
||||
"yaml-rust",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha-1"
|
||||
version = "0.8.2"
|
||||
@@ -2502,6 +2527,15 @@ version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||
|
||||
[[package]]
|
||||
name = "yaml-rust"
|
||||
version = "0.4.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56c1936c4cc7a1c9ab21a1ebb602eb942ba868cbd44a99cb7cdc5892335e1c85"
|
||||
dependencies = [
|
||||
"linked-hash-map",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yansi"
|
||||
version = "0.5.0"
|
||||
|
||||
@@ -10,7 +10,7 @@ async-std = { version = "1.8.0", features = ["attributes"] }
|
||||
tide = "0.16.0"
|
||||
tracing = "0.1.26"
|
||||
tracing-subscriber = "0.2.18"
|
||||
figment = { version = "0.10.5", features = ["env"] }
|
||||
figment = { version = "0.10.5", features = ["env", "yaml"] }
|
||||
url = "2.2.2"
|
||||
oauth2-types = { path = "../oauth2-types" }
|
||||
thiserror = "1.0.25"
|
||||
|
||||
@@ -1,44 +1,71 @@
|
||||
use figment::{error::Error as FigmentError, providers::Env, Figment};
|
||||
use figment::{
|
||||
error::Error as FigmentError,
|
||||
providers::{Env, Format, Yaml},
|
||||
Figment,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OAuth2 {
|
||||
pub issuer: Url,
|
||||
pub struct OAuth2ClientConfig {
|
||||
pub client_id: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub redirect_uris: Option<Vec<Url>>,
|
||||
}
|
||||
|
||||
impl Default for OAuth2 {
|
||||
fn default_oauth2_issuer() -> Url {
|
||||
"http://[::]:8080".parse().unwrap()
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OAuth2Config {
|
||||
#[serde(default = "default_oauth2_issuer")]
|
||||
pub issuer: Url,
|
||||
|
||||
#[serde(default)]
|
||||
pub clients: Vec<OAuth2ClientConfig>,
|
||||
}
|
||||
|
||||
impl Default for OAuth2Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
issuer: "http://[::]:8080".parse().unwrap(),
|
||||
issuer: default_oauth2_issuer(),
|
||||
clients: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_listener_address() -> String {
|
||||
"[::]:8080".into()
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Listener {
|
||||
pub struct ListenerConfig {
|
||||
#[serde(default = "default_listener_address")]
|
||||
pub address: String,
|
||||
}
|
||||
|
||||
impl Default for Listener {
|
||||
impl Default for ListenerConfig {
|
||||
fn default() -> Self {
|
||||
Listener {
|
||||
address: "[::]:8080".into(),
|
||||
ListenerConfig {
|
||||
address: default_listener_address(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct Config {
|
||||
pub oauth2: OAuth2,
|
||||
pub listener: Listener,
|
||||
pub struct RootConfig {
|
||||
pub oauth2: OAuth2Config,
|
||||
pub listener: ListenerConfig,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn load() -> Result<Config, FigmentError> {
|
||||
impl RootConfig {
|
||||
pub fn load() -> Result<RootConfig, FigmentError> {
|
||||
Figment::new()
|
||||
.merge(Env::prefixed("MAS_").split("_"))
|
||||
.merge(Yaml::file("config.yaml"))
|
||||
.extract()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ mod state;
|
||||
mod storage;
|
||||
mod templates;
|
||||
|
||||
use self::config::Config;
|
||||
use self::config::RootConfig;
|
||||
use self::state::State;
|
||||
|
||||
#[async_std::main]
|
||||
@@ -20,16 +20,20 @@ async fn main() -> tide::Result<()> {
|
||||
subscriber.try_init()?;
|
||||
|
||||
// Loading the config
|
||||
let config = Config::load()?;
|
||||
let address = config.listener.address.clone();
|
||||
let config = RootConfig::load()?;
|
||||
|
||||
// Load and compile the templates
|
||||
let templates = self::templates::load()?;
|
||||
|
||||
// Create the shared state
|
||||
let state = State::new(config, templates);
|
||||
state
|
||||
.storage()
|
||||
.load_static_clients(&state.config().oauth2.clients)
|
||||
.await;
|
||||
|
||||
// Start the server
|
||||
let address = state.config().listener.address.clone();
|
||||
let mut app = tide::with_state(state);
|
||||
app.with(tide_tracing::TraceMiddleware::new());
|
||||
self::handlers::install(&mut app);
|
||||
|
||||
@@ -9,11 +9,11 @@ use tide::{
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
use crate::{config::Config, storage::Storage};
|
||||
use crate::{config::RootConfig, storage::Storage};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct State {
|
||||
config: Arc<Config>,
|
||||
config: Arc<RootConfig>,
|
||||
templates: Arc<Tera>,
|
||||
storage: Arc<Storage>,
|
||||
session_store: Arc<MemoryStore>,
|
||||
@@ -27,7 +27,7 @@ impl std::fmt::Debug for State {
|
||||
}
|
||||
|
||||
impl State {
|
||||
pub fn new(config: Config, templates: Tera) -> Self {
|
||||
pub fn new(config: RootConfig, templates: Tera) -> Self {
|
||||
Self {
|
||||
config: Arc::new(config),
|
||||
templates: Arc::new(templates),
|
||||
@@ -39,6 +39,10 @@ impl State {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &RootConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
pub fn storage(&self) -> &Storage {
|
||||
&self.storage
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use url::Url;
|
||||
|
||||
use crate::config::OAuth2ClientConfig;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Storage {
|
||||
clients: RwLock<HashMap<String, Client>>,
|
||||
@@ -69,23 +71,34 @@ impl User {
|
||||
}
|
||||
|
||||
impl Storage {
|
||||
pub async fn lookup_client(&self, client_id: &str) -> Result<Client, ClientLookupError> {
|
||||
// First lookup for an existing client
|
||||
let clients = self.clients.upgradable_read().await;
|
||||
if let Some(client) = clients.get(client_id) {
|
||||
Ok(client.clone())
|
||||
} else {
|
||||
// If it does not exist, insert a new client
|
||||
let mut clients = RwLockUpgradableReadGuard::upgrade(clients).await;
|
||||
let new_client = Client {
|
||||
client_id: client_id.to_string(),
|
||||
redirect_uris: None,
|
||||
pub async fn load_static_clients(&self, clients: &[OAuth2ClientConfig]) {
|
||||
let mut storage = self.clients.write().await;
|
||||
for config in clients {
|
||||
let redirect_uris = config
|
||||
.redirect_uris
|
||||
.as_ref()
|
||||
.map(|uris| uris.iter().cloned().collect());
|
||||
let client_id = config.client_id.clone();
|
||||
|
||||
let client = Client {
|
||||
client_id: client_id.clone(),
|
||||
redirect_uris,
|
||||
};
|
||||
clients.insert(client_id.to_string(), new_client.clone());
|
||||
Ok(new_client)
|
||||
|
||||
// TODO: we could warn about duplicate clients here
|
||||
storage.insert(client_id, client);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn lookup_client(&self, client_id: &str) -> Result<Client, ClientLookupError> {
|
||||
self.clients
|
||||
.read()
|
||||
.await
|
||||
.get(client_id)
|
||||
.cloned()
|
||||
.ok_or(ClientLookupError)
|
||||
}
|
||||
|
||||
pub async fn login(&self, name: &str, password: &str) -> Result<User, UserLoginError> {
|
||||
// Hardcoded bad password to test login failures
|
||||
if password == "bad" {
|
||||
|
||||
Reference in New Issue
Block a user