From c8090d8ed4554ff730e6f5ced70569f861fdd91e Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 29 Jul 2021 14:56:33 +0200 Subject: [PATCH] WIP: migrate to warp, part 2 --- Cargo.lock | 137 +++++++++++++++-- matrix-authentication-service/Cargo.toml | 2 + .../src/config/csrf.rs | 23 +-- matrix-authentication-service/src/csrf.rs | 13 +- .../src/filters/csrf.rs | 139 ++++++++++-------- .../src/filters/mod.rs | 4 +- .../src/handlers/mod.rs | 40 +++-- .../src/handlers/views/index.rs | 25 +++- .../src/handlers/views/login.rs | 38 +++-- .../src/handlers/views/mod.rs | 2 +- matrix-authentication-service/src/main.rs | 2 +- .../src/storage/mod.rs | 2 +- .../src/templates.rs | 43 +++++- 13 files changed, 338 insertions(+), 132 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9eed82681..d63227242 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,12 +57,48 @@ dependencies = [ "password-hash", ] +[[package]] +name = "arrayref" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4c527152e37cf757a3f78aae5a06fbeefdb07ccc535c980a3208ee3060dd544" + [[package]] name = "arrayvec" version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" +[[package]] +name = "async-lock" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6a8ea61bf9947a1007c5cada31e647dbc77b103c679858150003ba697ea798b" +dependencies = [ + "event-listener", +] + +[[package]] +name = "async-session" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07da4ce523b4e2ebaaf330746761df23a465b951a83d84bbce4233dabedae630" +dependencies = [ + "anyhow", + "async-lock", + "async-trait", + "base64", + "bincode", + "blake3", + "chrono", + "hmac 0.11.0", + "log", + "rand 0.8.4", + "serde", + "serde_json", + "sha2", +] + [[package]] name = "async-trait" version = "0.1.50" @@ -165,6 +201,21 @@ dependencies = [ "opaque-debug 0.3.0", ] +[[package]] +name = "blake3" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b64485778c4f16a6a5a9d335e80d449ac6c70cdd6a06d2af18a6f6f775a125b3" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if 0.1.10", + "constant_time_eq", + "crypto-mac 0.8.0", + "digest 0.9.0", +] + [[package]] name = "block-buffer" version = "0.7.3" @@ -250,6 +301,12 @@ version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e70cc2f62c6ce1868963827bd677764c62d07c3d9a3e1fb1177ee1a9ab199eb2" +[[package]] +name = "cfg-if" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" + [[package]] name = "cfg-if" version = "1.0.0" @@ -262,7 +319,7 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea8756167ea0aca10e066cdbe7813bd71d2f24e69b0bc7b50509590cef2ce0b9" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "cipher", "cpufeatures", "zeroize", @@ -352,6 +409,12 @@ version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f92cfa0fd5690b3cf8c1ef2cabbd9b7ef22fa53cf5e1f92b05103f6d5d1cf6e7" +[[package]] +name = "constant_time_eq" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" + [[package]] name = "cookie" version = "0.15.1" @@ -386,7 +449,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06ed27e177f16d65f0f0c22a213e17c696ace5dd64b14258b52f9417ccb52db4" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "crossbeam-utils", ] @@ -396,7 +459,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b10ddc024425c88c2ad148c1b0fd53f4c6d38db9697c9f1588381212fa657c9" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "crossbeam-utils", ] @@ -406,7 +469,7 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d82cfc11ce7f2c3faef78d8a684447b40d503d9681acebed6cb728d45940c4db" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "lazy_static", ] @@ -430,6 +493,16 @@ dependencies = [ "subtle", ] +[[package]] +name = "crypto-mac" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1d1a86f49236c215f271d40892d5fc950490551400b02ef360692c29815c714" +dependencies = [ + "generic-array 0.14.4", + "subtle", +] + [[package]] name = "darling" version = "0.13.0" @@ -545,6 +618,12 @@ version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" +[[package]] +name = "event-listener" +version = "2.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7531096570974c3a9dcf9e4b8e1cede1ec26cf5046219fb3b9d897503b9be59" + [[package]] name = "fake-simd" version = "0.1.2" @@ -708,7 +787,7 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "libc", "wasi 0.9.0+wasi-snapshot-preview1", ] @@ -719,7 +798,7 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcd999463524c52659517fe2cea98493cfe485d10565e7b0fb07dbba7ad2753" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "libc", "wasi 0.10.0+wasi-snapshot-preview1", ] @@ -844,6 +923,16 @@ dependencies = [ "digest 0.9.0", ] +[[package]] +name = "hmac" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a2a2320eb7ec0ebe8da8f744d7812d9fc4cb4d09344ac01898dbcb6a20ae69b" +dependencies = [ + "crypto-mac 0.11.1", + "digest 0.9.0", +] + [[package]] name = "http" version = "0.2.4" @@ -983,7 +1072,7 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bee0328b1209d157ef001c94dd85b4f8f64139adb0eac2659f4b08382b2f474d" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", ] [[package]] @@ -1024,7 +1113,7 @@ checksum = "6607c62aa161d23d17a9072cc5da0be67cdfc89d3afb1e8d9c842bebc2525ffe" dependencies = [ "arrayvec", "bitflags", - "cfg-if", + "cfg-if 1.0.0", "ryu", "static_assertions", ] @@ -1056,7 +1145,7 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", ] [[package]] @@ -1094,6 +1183,7 @@ dependencies = [ "cookie", "data-encoding", "figment", + "headers", "mime", "oauth2-types", "password-hash", @@ -1110,6 +1200,7 @@ dependencies = [ "tracing-subscriber", "url", "warp", + "warp-sessions", ] [[package]] @@ -1291,7 +1382,7 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fa7a782938e745763fe6907fc6ba86946d72f49fe7e21de074e08128a99fb018" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "instant", "libc", "redox_syscall", @@ -1917,7 +2008,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a0c8611594e2ab4ebbf06ec7cbbf0a99450b8570e96cbf5188b5d5f6ef18d81" dependencies = [ "block-buffer 0.9.0", - "cfg-if", + "cfg-if 1.0.0", "cpufeatures", "digest 0.9.0", "opaque-debug 0.3.0", @@ -1936,7 +2027,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b362ae5752fd2137731f9fa25fd4d9058af34666ca1966fb969119cc35719f12" dependencies = [ "block-buffer 0.9.0", - "cfg-if", + "cfg-if 1.0.0", "cpufeatures", "digest 0.9.0", "opaque-debug 0.3.0", @@ -2044,7 +2135,7 @@ dependencies = [ "futures-util", "hashlink", "hex", - "hmac", + "hmac 0.10.1", "itoa", "libc", "log", @@ -2234,7 +2325,7 @@ version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dac1c663cfc93810f88aed9b8941d48cabf856a1b111c29a40439018d870eb22" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "libc", "rand 0.8.4", "redox_syscall", @@ -2467,7 +2558,7 @@ version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09adeb8c97449311ccd28a427f96fb563e7fd31aabf994189879d9da2394b89d" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "log", "pin-project-lite", "tracing-attributes", @@ -2796,6 +2887,20 @@ dependencies = [ "tracing", ] +[[package]] +name = "warp-sessions" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c193bbfd203c4fa0b1ce64dce5d3ce9eb78c03f39723e080771e160fd416c145" +dependencies = [ + "async-session", + "async-trait", + "http", + "serde", + "tokio", + "warp", +] + [[package]] name = "wasi" version = "0.9.0+wasi-snapshot-preview1" @@ -2814,7 +2919,7 @@ version = "0.2.74" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d54ee1d4ed486f78874278e63e4069fc1ab9f6a18ca492076ffb90c5eb2997fd" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "wasm-bindgen-macro", ] diff --git a/matrix-authentication-service/Cargo.toml b/matrix-authentication-service/Cargo.toml index bdac8f9f1..529054bd1 100644 --- a/matrix-authentication-service/Cargo.toml +++ b/matrix-authentication-service/Cargo.toml @@ -50,3 +50,5 @@ cookie = "0.15.1" chacha20poly1305 = { version = "0.8.1", features = ["std"] } oauth2-types = { path = "../oauth2-types" } +headers = "0.3.4" +warp-sessions = "1.0.15" diff --git a/matrix-authentication-service/src/config/csrf.rs b/matrix-authentication-service/src/config/csrf.rs index 081fd1686..b6b0840b9 100644 --- a/matrix-authentication-service/src/config/csrf.rs +++ b/matrix-authentication-service/src/config/csrf.rs @@ -16,6 +16,9 @@ use chrono::Duration; use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema}; use serde::{Deserialize, Serialize}; use serde_with::serde_as; +use warp::filters::BoxedFilter; + +use crate::filters::{csrf::extract_or_generate, CsrfToken}; use super::ConfigurationSection; @@ -40,10 +43,10 @@ fn ttl_schema(gen: &mut SchemaGenerator) -> Schema { pub struct CsrfConfig { #[schemars(schema_with = "key_schema")] #[serde_as(as = "serde_with::hex::Hex")] - key: [u8; 32], + pub key: [u8; 32], #[serde(default = "default_cookie_name")] - cookie_name: String, + pub cookie_name: String, #[schemars(schema_with = "ttl_schema")] #[serde(default = "default_ttl")] @@ -51,14 +54,14 @@ pub struct CsrfConfig { ttl: Duration, } -// impl CsrfConfig { -// pub fn into_middleware(self) -> impl Middleware { -// let ttl = self.ttl; -// let cookie_name = self.cookie_name.clone(); -// let protection = self.key; -// CsrfMiddleware::new(protection, cookie_name, ttl) -// } -// } +impl CsrfConfig { + pub fn into_extract_filter(self) -> BoxedFilter<(CsrfToken,)> { + let ttl = self.ttl; + // TODO: we should probably not leak here + let cookie_name = Box::leak(Box::new(self.cookie_name)); + extract_or_generate(self.key, cookie_name, ttl) + } +} impl ConfigurationSection<'_> for CsrfConfig { fn path() -> &'static str { diff --git a/matrix-authentication-service/src/csrf.rs b/matrix-authentication-service/src/csrf.rs index 80b1fe900..866e91f68 100644 --- a/matrix-authentication-service/src/csrf.rs +++ b/matrix-authentication-service/src/csrf.rs @@ -14,7 +14,7 @@ use serde::Deserialize; -use crate::middlewares::CsrfToken; +use crate::filters::CsrfToken; /// A CSRF-protected form #[derive(Deserialize)] @@ -26,16 +26,9 @@ pub struct CsrfForm { } impl CsrfForm { - pub fn verify_csrf(self, request: &tide::Request) -> tide::Result - where - State: Clone + Send + Sync + 'static, - { + pub fn verify_csrf(self, token: &CsrfToken) -> anyhow::Result { // Verify CSRF from request - let csrf_token: &CsrfToken = request - .ext() - .ok_or_else(|| anyhow::anyhow!("missing csrf cookie"))?; // TODO: proper error - - csrf_token.verify_form_value(&self.csrf)?; + token.verify_form_value(&self.csrf)?; Ok(self.inner) } } diff --git a/matrix-authentication-service/src/filters/csrf.rs b/matrix-authentication-service/src/filters/csrf.rs index ed0b9927e..a236a0273 100644 --- a/matrix-authentication-service/src/filters/csrf.rs +++ b/matrix-authentication-service/src/filters/csrf.rs @@ -24,9 +24,12 @@ use chacha20poly1305::{ use chrono::{DateTime, Duration, Utc}; use cookie::{Cookie, CookieBuilder, SameSite}; use data_encoding::BASE64URL_NOPAD; +use headers::{Header, HeaderMapExt, HeaderValue, SetCookie}; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, TimestampSeconds}; -use warp::filters::BoxedFilter; +use warp::{filters::BoxedFilter, Filter, Rejection, Reply}; + +use crate::errors::WrapError; #[serde_as] #[derive(Serialize, Deserialize)] @@ -100,17 +103,16 @@ impl UnencryptedToken { name: &'n str, key: &[u8; 32], ) -> anyhow::Result> { - let value = self.encrypt(key)?.to_cookie_value()?; // Converting expiration time from `chrono` to `time` via native `SystemTime` let expires: SystemTime = self.expiration.into(); - Ok(Cookie::build(name, value) - .expires(expires) - .http_only(true) - .same_site(SameSite::Strict)) + Ok(self + .encrypt(key)? + .to_cookie_builder(name)? + .expires(Some(expires.into()))) } - fn from_cookie(cookie: &Cookie, key: &[u8; 32]) -> anyhow::Result { - let encrypted = EncryptedToken::from_cookie_value(cookie.value())?; + fn from_cookie_value(value: &str, key: &[u8; 32]) -> anyhow::Result { + let encrypted = EncryptedToken::from_cookie_value(value)?; let token = encrypted.decrypt(key)?; Ok(token) } @@ -147,75 +149,82 @@ impl EncryptedToken { let content = bincode::deserialize(&raw)?; Ok(content) } -} -#[derive(Debug, Clone)] -pub struct Middleware { - key: [u8; 32], - ttl: Duration, - cookie_name: String, -} - -impl Middleware { - /// Create a new CSRF protection middleware from a key, cookie name and TTL - pub fn new(key: [u8; 32], cookie_name: String, ttl: Duration) -> Self { - Self { - key, - ttl, - cookie_name, - } + fn to_cookie_builder<'c, 'n: 'c>(&self, name: &'n str) -> anyhow::Result> { + let value = self.to_cookie_value()?; + Ok(Cookie::build(name, value) + .http_only(true) + .same_site(SameSite::Strict)) } } pub fn extract_or_generate( key: [u8; 32], - cookie_name: String, + cookie_name: &'static str, ttl: Duration, ) -> BoxedFilter<(UnencryptedToken,)> { - warp::cookie::optional(cookie_name) + warp::any() + .map(move || (key, ttl)) + .untuple_one() + .and(warp::cookie::optional(cookie_name)) + .and_then(|key, ttl, maybe_cookie: Option| async move { + // Explicitely specify the "Error" type here to have the `?` operation working + Ok::<_, Rejection>( + maybe_cookie + // Try decrypting the cookie + .map(|cookie| UnencryptedToken::from_cookie_value(&cookie, &key)) + // If there was an error decrypting it, bail out here + .transpose() + .wrap_error()? + // Verify its TTL (but do not hard-error if it expired) + .and_then(|token| token.verify_expiration().ok()) + .map_or_else( + // Generate a new token if no valid one were found + || UnencryptedToken::generate(ttl), + // Else, refresh the expiration of the token + |token| token.refresh(ttl), + ), + ) + }) + .boxed() } -/* +pub struct WithTypedHeader { + reply: R, + header: H, +} -#[async_trait] -impl tide::Middleware for Middleware +impl Reply for WithTypedHeader where - State: Clone + Send + Sync + 'static, + R: Reply, + H: Header + Send, { - async fn handle( - &self, - mut request: tide::Request, - next: tide::Next<'_, State>, - ) -> tide::Result { - let csrf_token = request - // Get the CSRF cookie - .cookie(&self.cookie_name) - // Try decrypting it - .map(|cookie| UnencryptedToken::from_cookie(&cookie, &self.key)) - // If there was an error decrypting it, bail out here - .transpose()? - // Verify it's TTL (but do not hard-error if it expired) - .and_then(|token| token.verify_expiration().ok()) - .map_or_else( - // Generate a new token if no valid one were found - || UnencryptedToken::generate(self.ttl), - // Else, refresh the expiration of the cookie - |token| token.refresh(self.ttl), - ); - - // Build the cookie before calling the next stage since the owned csrf_token has - // to be passed as a request extension - let cookie = csrf_token - .to_cookie_builder(&self.cookie_name, &self.key)? - .finish() - .into_owned(); - - request.set_ext(csrf_token); - - let mut response = next.run(request).await; - response.insert_cookie(cookie); - - Ok(response) + fn into_response(self) -> warp::reply::Response { + let mut res = self.reply.into_response(); + res.headers_mut().typed_insert(self.header); + res + } +} + +pub fn with_csrf( + key: [u8; 32], + cookie_name: &'static str, +) -> impl Fn(F) -> BoxedFilter<(WithTypedHeader,)> +where + F: Filter + Clone + Send + Sync + 'static, +{ + move |f: F| { + f.and_then(move |token: UnencryptedToken, reply: R| async move { + let cookie = token + .to_cookie_builder(cookie_name, &key) + .wrap_error()? + .finish() + .to_string(); + let header = + SetCookie::decode(&mut [HeaderValue::from_str(&cookie).wrap_error()?].iter()) + .wrap_error()?; + Ok::<_, Rejection>(WithTypedHeader { reply, header }) + }) + .boxed() } } -*/ diff --git a/matrix-authentication-service/src/filters/mod.rs b/matrix-authentication-service/src/filters/mod.rs index de5b2a38a..761d80efc 100644 --- a/matrix-authentication-service/src/filters/mod.rs +++ b/matrix-authentication-service/src/filters/mod.rs @@ -12,5 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod csrf; +pub mod csrf; // mod errors; + +pub use csrf::UnencryptedToken as CsrfToken; diff --git a/matrix-authentication-service/src/handlers/mod.rs b/matrix-authentication-service/src/handlers/mod.rs index 13229f8e7..aae31d0eb 100644 --- a/matrix-authentication-service/src/handlers/mod.rs +++ b/matrix-authentication-service/src/handlers/mod.rs @@ -12,32 +12,41 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; +use std::{convert::Infallible, sync::Arc}; use sqlx::PgPool; use tera::Tera; -use warp::{filters::BoxedFilter, Filter}; +use warp::{filters::BoxedFilter, wrap_fn, Filter, Rejection, Reply}; -use crate::config::RootConfig; +use crate::{config::RootConfig, filters::csrf::with_csrf}; mod health; mod oauth2; mod views; +async fn display_error(err: Rejection) -> Result { + let ret = format!("{:?}", err); + Ok(ret) +} + pub fn root( pool: PgPool, templates: Tera, config: &RootConfig, ) -> BoxedFilter<(impl warp::Reply,)> { let templates = Arc::new(templates); - let with_pool = move || pool.clone(); - let with_templates = move || templates.clone(); + let with_csrf_token = config.csrf.clone().into_extract_filter(); + let with_pool = warp::any().map(move || pool.clone()); + let with_templates = warp::any().map(move || templates.clone()); + + // TODO: this is ugly and leaks + let csrf_cookie_name = Box::leak(Box::new(config.csrf.cookie_name.clone())); let cors = warp::cors().allow_any_origin(); let health = warp::path("health") .and(warp::get()) - .map(with_pool) + .and(with_pool.clone()) .and_then(self::health::get) .boxed(); @@ -48,10 +57,23 @@ pub fn root( let index = warp::path::end() .and(warp::get()) - .map(with_templates) - .and_then(self::views::index::get); + .and(with_templates.clone()) + .and(with_csrf_token.clone()) + .and(with_pool.clone()) + .and_then(self::views::index::get) + .untuple_one() + .with(wrap_fn(with_csrf(config.csrf.key, csrf_cookie_name))); - health.or(index).or(metadata).boxed() + let login = warp::path("login") + .and(warp::get()) + .and(with_templates) + .and(with_csrf_token) + .and(with_pool) + .and_then(self::views::login::get) + .untuple_one() + .with(wrap_fn(with_csrf(config.csrf.key, csrf_cookie_name))); + + health.or(index).or(login).or(metadata).boxed() // app.at("/").nest({ // let mut views = tide::with_state(state.clone()); diff --git a/matrix-authentication-service/src/handlers/views/index.rs b/matrix-authentication-service/src/handlers/views/index.rs index b4cd86214..f6f5ecb23 100644 --- a/matrix-authentication-service/src/handlers/views/index.rs +++ b/matrix-authentication-service/src/handlers/views/index.rs @@ -14,13 +14,28 @@ use std::sync::Arc; -use tera::{Context, Tera}; +use sqlx::PgPool; +use tera::Tera; use warp::{reply::with_header, Rejection, Reply}; -use crate::errors::WrapError; +use crate::{errors::WrapError, filters::CsrfToken, templates::CommonContext}; + +pub async fn get( + templates: Arc, + csrf_token: CsrfToken, + db: PgPool, +) -> Result<(CsrfToken, impl Reply), Rejection> { + let ctx = CommonContext::default() + .with_csrf_token(&csrf_token) + .with_session(&db) + .await + .wrap_error()? + .finish() + .wrap_error()?; -pub async fn get(templates: Arc) -> Result { - let ctx = Context::new(); let content = templates.render("index.html", &ctx).wrap_error()?; - Ok(with_header(content, "Content-Type", "text/html")) + Ok(( + csrf_token, + with_header(content, "Content-Type", "text/html"), + )) } diff --git a/matrix-authentication-service/src/handlers/views/login.rs b/matrix-authentication-service/src/handlers/views/login.rs index 79fce0d2b..5573882f8 100644 --- a/matrix-authentication-service/src/handlers/views/login.rs +++ b/matrix-authentication-service/src/handlers/views/login.rs @@ -12,10 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -use serde::Deserialize; -use tide::{Redirect, Request, Response}; +use std::sync::Arc; -use crate::{csrf::CsrfForm, state::State, templates::common_context}; +use serde::Deserialize; +use sqlx::PgPool; +use tera::Tera; +use warp::{reply::with_header, Rejection, Reply}; + +use crate::{errors::WrapError, filters::CsrfToken, templates::CommonContext}; #[derive(Deserialize)] struct LoginForm { @@ -23,19 +27,28 @@ struct LoginForm { password: String, } -pub async fn get(req: Request) -> tide::Result { - let state = req.state(); - let ctx = common_context(&req).await?; +pub async fn get( + templates: Arc, + csrf_token: CsrfToken, + db: PgPool, +) -> Result<(CsrfToken, impl Reply), Rejection> { + let ctx = CommonContext::default() + .with_csrf_token(&csrf_token) + .with_session(&db) + .await + .wrap_error()? + .finish() + .wrap_error()?; // TODO: check if there is an existing session - let content = state.templates().render("login.html", &ctx)?; - let body = Response::builder(200) - .body(content) - .content_type("text/html") - .into(); - Ok(body) + let content = templates.render("login.html", &ctx).wrap_error()?; + Ok(( + csrf_token, + with_header(content, "Content-Type", "text/html"), + )) } +/* pub async fn post(mut req: Request) -> tide::Result { let form: CsrfForm = req.body_form().await?; let form = form.verify_csrf(&req)?; @@ -51,3 +64,4 @@ pub async fn post(mut req: Request) -> tide::Result { Ok(Redirect::new("/").into()) } +*/ diff --git a/matrix-authentication-service/src/handlers/views/mod.rs b/matrix-authentication-service/src/handlers/views/mod.rs index 6dd6c6a43..eda6f9e6a 100644 --- a/matrix-authentication-service/src/handlers/views/mod.rs +++ b/matrix-authentication-service/src/handlers/views/mod.rs @@ -13,6 +13,6 @@ // limitations under the License. pub(super) mod index; -// pub(super) mod login; +pub(super) mod login; // pub(super) mod logout; // pub(super) mod reauth; diff --git a/matrix-authentication-service/src/main.rs b/matrix-authentication-service/src/main.rs index 3777fe2de..f92adbce1 100644 --- a/matrix-authentication-service/src/main.rs +++ b/matrix-authentication-service/src/main.rs @@ -23,8 +23,8 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilte mod cli; mod config; +mod csrf; mod errors; -// mod csrf; mod filters; mod handlers; mod state; diff --git a/matrix-authentication-service/src/storage/mod.rs b/matrix-authentication-service/src/storage/mod.rs index d536b884d..329db0398 100644 --- a/matrix-authentication-service/src/storage/mod.rs +++ b/matrix-authentication-service/src/storage/mod.rs @@ -22,7 +22,7 @@ mod user; pub use self::{ client::{Client, ClientLookupError, InvalidRedirectUriError}, - user::User, + user::{lookup_session, SessionInfo, User}, }; pub static MIGRATOR: Migrator = sqlx::migrate!(); diff --git a/matrix-authentication-service/src/templates.rs b/matrix-authentication-service/src/templates.rs index e2ce61a94..a7a5bd578 100644 --- a/matrix-authentication-service/src/templates.rs +++ b/matrix-authentication-service/src/templates.rs @@ -12,15 +12,56 @@ // See the License for the specific language governing permissions and // limitations under the License. -use tera::Tera; +use anyhow::Context as _; +use serde::Serialize; +use sqlx::{Executor, Postgres}; +use tera::{Context, Tera}; use tracing::info; +use crate::{ + filters::CsrfToken, + storage::{lookup_session, SessionInfo}, +}; + pub fn load() -> Result { let path = format!("{}/templates/**/*.{{html,txt}}", env!("CARGO_MANIFEST_DIR")); info!(%path, "Loading templates"); Tera::new(&path) } +#[derive(Serialize, Default)] +pub struct CommonContext { + csrf_token: Option, + session: Option, +} + +impl CommonContext { + pub fn with_csrf_token(self, token: &CsrfToken) -> Self { + Self { + csrf_token: Some(token.form_value()), + ..self + } + } + + pub async fn with_session<'e>( + self, + _executor: impl Executor<'e, Database = Postgres>, + ) -> anyhow::Result { + Ok(self) + /* + let session = lookup_session(executor, 1).await?; + Ok(Self { + session: Some(session), + ..self + }) + */ + } + + pub fn finish(self) -> anyhow::Result { + Context::from_serialize(&self).context("could not serialize common context for templates") + } +} + // pub async fn common_context(req: &Request) -> Result { // let state = req.state(); // let session = req.session();