From 846a4ee14ab029986576bf7d1ef1a205751ef241 Mon Sep 17 00:00:00 2001 From: Tonkku Date: Mon, 7 Oct 2024 21:50:18 +0300 Subject: [PATCH] Implement login_hint --- crates/data-model/Cargo.toml | 2 + crates/data-model/src/lib.rs | 2 +- .../src/oauth2/authorization_grant.rs | 124 ++++++++++++++++++ crates/data-model/src/oauth2/mod.rs | 2 +- .../handlers/src/oauth2/authorization/mod.rs | 1 + crates/handlers/src/oauth2/token.rs | 2 + crates/handlers/src/views/login.rs | 30 ++++- .../20241007160050_oidc_login_hint.sql | 3 + .../src/oauth2/authorization_grant.rs | 10 +- crates/storage-pg/src/oauth2/mod.rs | 1 + .../storage/src/oauth2/authorization_grant.rs | 3 + crates/templates/src/context.rs | 5 + crates/templates/src/forms.rs | 10 ++ 13 files changed, 189 insertions(+), 6 deletions(-) create mode 100644 crates/storage-pg/migrations/20241007160050_oidc_login_hint.sql diff --git a/crates/data-model/Cargo.toml b/crates/data-model/Cargo.toml index 081e09ae9..8e555a536 100644 --- a/crates/data-model/Cargo.toml +++ b/crates/data-model/Cargo.toml @@ -22,7 +22,9 @@ rand.workspace = true rand_chacha = "0.3.1" regex = "1.11.1" woothee = "0.13.0" +ruma-common = "0.13.0" mas-iana.workspace = true mas-jose.workspace = true +mas-matrix.workspace = true oauth2-types.workspace = true diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 939e904f3..c0b39792a 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -9,7 +9,7 @@ use thiserror::Error; pub(crate) mod compat; -pub(crate) mod oauth2; +pub mod oauth2; mod site_config; pub(crate) mod tokens; pub(crate) mod upstream_oauth2; diff --git a/crates/data-model/src/oauth2/authorization_grant.rs b/crates/data-model/src/oauth2/authorization_grant.rs index c995a3623..2ddfb14ed 100644 --- a/crates/data-model/src/oauth2/authorization_grant.rs +++ b/crates/data-model/src/oauth2/authorization_grant.rs @@ -8,6 +8,7 @@ use std::num::NonZeroU32; use chrono::{DateTime, Duration, Utc}; use mas_iana::oauth::PkceCodeChallengeMethod; +use mas_matrix::BoxHomeserverConnection; use oauth2_types::{ pkce::{CodeChallengeError, CodeChallengeMethodExt}, requests::ResponseMode, @@ -17,6 +18,7 @@ use rand::{ distributions::{Alphanumeric, DistString}, RngCore, }; +use ruma_common::{OwnedUserId, UserId}; use serde::Serialize; use ulid::Ulid; use url::Url; @@ -141,6 +143,11 @@ impl AuthorizationGrantStage { } } +pub enum LoginHint { + MXID(OwnedUserId), + None, +} + #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct AuthorizationGrant { pub id: Ulid, @@ -157,6 +164,7 @@ pub struct AuthorizationGrant { pub response_type_id_token: bool, pub created_at: DateTime, pub requires_consent: bool, + pub login_hint: Option, } impl std::ops::Deref for AuthorizationGrant { @@ -179,6 +187,35 @@ impl AuthorizationGrant { self.created_at - max_age } + pub fn parse_login_hint(&self, homeserver: BoxHomeserverConnection) -> LoginHint { + let Some(login_hint) = &self.login_hint else { + return LoginHint::None; + }; + + // Return none if the format is incorrect + let Some((prefix, value)) = login_hint.split_once(":") else { + return LoginHint::None; + }; + + match prefix { + "mxid" => { + // Instead of erroring just return none + let Ok(mxid) = UserId::parse(value) else { + return LoginHint::None; + }; + + // Only handle MXIDs for current homeserver + if mxid.server_name() != homeserver.homeserver() { + return LoginHint::None; + } + + LoginHint::MXID(mxid) + }, + // Unknown hint type, treat as none + _ => LoginHint::None + } + } + /// Mark the authorization grant as exchanged. /// /// # Errors @@ -242,6 +279,93 @@ impl AuthorizationGrant { response_type_id_token: false, created_at: now, requires_consent: false, + login_hint: Some(String::from("mxid:@example-user:example.com")) } } } + +#[cfg(test)] +mod tests { + use mas_matrix::MockHomeserverConnection; + use rand::thread_rng; + use super::*; + + fn get_homeserver() -> BoxHomeserverConnection { + Box::new(MockHomeserverConnection::new("example.com")) + } + + #[test] + fn no_login_hint() { + #[allow(clippy::disallowed_methods)] + let mut rng = thread_rng(); + + let grant = AuthorizationGrant { + login_hint: None, + ..AuthorizationGrant::sample(Utc::now(), &mut rng) + }; + + let hint = grant.parse_login_hint(get_homeserver()); + + assert!(matches!(hint, LoginHint::None)); + } + + #[test] + fn valid_login_hint() { + #[allow(clippy::disallowed_methods)] + let mut rng = thread_rng(); + + let grant = AuthorizationGrant { + login_hint: Some(String::from("mxid:@example-user:example.com")), + ..AuthorizationGrant::sample(Utc::now(), &mut rng) + }; + + let hint = grant.parse_login_hint(get_homeserver()); + + assert!(matches!(hint, LoginHint::MXID(mxid) if mxid.localpart() == "example-user")); + } + + #[test] + fn invalid_login_hint() { + #[allow(clippy::disallowed_methods)] + let mut rng = thread_rng(); + + let grant = AuthorizationGrant { + login_hint: Some(String::from("example-user")), + ..AuthorizationGrant::sample(Utc::now(), &mut rng) + }; + + let hint = grant.parse_login_hint(get_homeserver()); + + assert!(matches!(hint, LoginHint::None)); + } + + #[test] + fn valid_login_hint_for_wrong_homeserver() { + #[allow(clippy::disallowed_methods)] + let mut rng = thread_rng(); + + let grant = AuthorizationGrant { + login_hint: Some(String::from("mxid:@example-user:matrix.org")), + ..AuthorizationGrant::sample(Utc::now(), &mut rng) + }; + + let hint = grant.parse_login_hint(get_homeserver()); + + assert!(matches!(hint, LoginHint::None)); + } + + #[test] + fn unknown_login_hint_type() { + #[allow(clippy::disallowed_methods)] + let mut rng = thread_rng(); + + let grant = AuthorizationGrant { + login_hint: Some(String::from("something:anything")), + ..AuthorizationGrant::sample(Utc::now(), &mut rng) + }; + + let hint = grant.parse_login_hint(get_homeserver()); + + assert!(matches!(hint, LoginHint::None)); + } +} diff --git a/crates/data-model/src/oauth2/mod.rs b/crates/data-model/src/oauth2/mod.rs index 75fd04126..d4d019634 100644 --- a/crates/data-model/src/oauth2/mod.rs +++ b/crates/data-model/src/oauth2/mod.rs @@ -10,7 +10,7 @@ mod device_code_grant; mod session; pub use self::{ - authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce}, + authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce, LoginHint}, client::{Client, InvalidRedirectUriError, JwksOrJwksUri}, device_code_grant::{DeviceCodeGrant, DeviceCodeGrantState}, session::{Session, SessionState}, diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index fc7b477c9..236e8a795 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -291,6 +291,7 @@ pub(crate) async fn get( response_mode, response_type.has_id_token(), requires_consent, + params.auth.login_hint, ) .await?; let continue_grant = PostAuthAction::continue_grant(grant.id); diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 6d6d36849..a228f746b 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -854,6 +854,7 @@ mod tests { ResponseMode::Query, false, false, + None, ) .await .unwrap(); @@ -954,6 +955,7 @@ mod tests { ResponseMode::Query, false, false, + None, ) .await .unwrap(); diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 3a65c3284..51b49bb7a 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -15,8 +15,9 @@ use mas_axum_utils::{ csrf::{CsrfExt, CsrfToken, ProtectedForm}, FancyError, SessionInfoExt, }; -use mas_data_model::{BrowserSession, UserAgent}; +use mas_data_model::{BrowserSession, UserAgent, oauth2::LoginHint}; use mas_i18n::DataLocale; +use mas_matrix::BoxHomeserverConnection; use mas_router::{UpstreamOAuth2Authorize, UrlBuilder}; use mas_storage::{ upstream_oauth2::UpstreamOAuthProviderRepository, @@ -24,7 +25,7 @@ use mas_storage::{ BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, }; use mas_templates::{ - FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, + FieldError, FormError, LoginContext, LoginFormField, PostAuthContext, PostAuthContextInner, TemplateContext, Templates, ToFormState }; use rand::{CryptoRng, Rng}; use serde::{Deserialize, Serialize}; @@ -54,6 +55,7 @@ pub(crate) async fn get( State(templates): State, State(url_builder): State, State(site_config): State, + State(homeserver): State, mut repo: BoxRepository, activity_tracker: BoundActivityTracker, Query(query): Query, @@ -96,6 +98,7 @@ pub(crate) async fn get( csrf_token, &mut repo, &templates, + homeserver, ) .await?; @@ -112,6 +115,7 @@ pub(crate) async fn post( State(templates): State, State(url_builder): State, State(limiter): State, + State(homeserver): State, mut repo: BoxRepository, activity_tracker: BoundActivityTracker, requester: RequesterFingerprint, @@ -156,6 +160,7 @@ pub(crate) async fn post( csrf_token, &mut repo, &templates, + homeserver, ) .await?; @@ -196,6 +201,7 @@ pub(crate) async fn post( csrf_token, &mut repo, &templates, + homeserver, ) .await?; @@ -286,16 +292,34 @@ async fn login( Ok(user_session) } +fn handle_login_hint(ctx: &mut LoginContext, next: &PostAuthContext, homeserver: BoxHomeserverConnection) { + let form_state = ctx.form_state_mut(); + + // Do not override username if coming from a failed login attempt + if form_state.has_value(LoginFormField::Username) { return; } + + if let PostAuthContextInner::ContinueAuthorizationGrant { ref grant } = next.ctx { + let value = match grant.parse_login_hint(homeserver) { + LoginHint::MXID(mxid) => Some(mxid.localpart().to_string()), + LoginHint::None => None, + }; + form_state.set_value(LoginFormField::Username, value); + } +} + async fn render( locale: DataLocale, - ctx: LoginContext, + mut ctx: LoginContext, action: OptionalPostAuthAction, csrf_token: CsrfToken, repo: &mut impl RepositoryAccess, templates: &Templates, + homeserver: BoxHomeserverConnection, ) -> Result { let next = action.load_context(repo).await?; let ctx = if let Some(next) = next { + handle_login_hint(&mut ctx, &next, homeserver); + ctx.with_post_action(next) } else { ctx diff --git a/crates/storage-pg/migrations/20241007160050_oidc_login_hint.sql b/crates/storage-pg/migrations/20241007160050_oidc_login_hint.sql new file mode 100644 index 000000000..b18932ec7 --- /dev/null +++ b/crates/storage-pg/migrations/20241007160050_oidc_login_hint.sql @@ -0,0 +1,3 @@ +-- Add login_hint to oauth2_authorization_grants +ALTER TABLE "oauth2_authorization_grants" + ADD COLUMN "login_hint" TEXT; diff --git a/crates/storage-pg/src/oauth2/authorization_grant.rs b/crates/storage-pg/src/oauth2/authorization_grant.rs index 9034a042d..12450ca4b 100644 --- a/crates/storage-pg/src/oauth2/authorization_grant.rs +++ b/crates/storage-pg/src/oauth2/authorization_grant.rs @@ -55,6 +55,7 @@ struct GrantLookup { code_challenge: Option, code_challenge_method: Option, requires_consent: bool, + login_hint: Option, oauth2_client_id: Uuid, oauth2_session_id: Option, } @@ -185,6 +186,7 @@ impl TryFrom for AuthorizationGrant { created_at: value.created_at, response_type_id_token: value.response_type_id_token, requires_consent: value.requires_consent, + login_hint: value.login_hint, }) } } @@ -218,6 +220,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi response_mode: ResponseMode, response_type_id_token: bool, requires_consent: bool, + login_hint: Option, ) -> Result { let code_challenge = code .as_ref() @@ -252,10 +255,11 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi response_type_id_token, authorization_code, requires_consent, + login_hint, created_at ) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) "#, Uuid::from(id), Uuid::from(client.id), @@ -271,6 +275,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi response_type_id_token, code_str, requires_consent, + login_hint, created_at, ) .traced() @@ -291,6 +296,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi created_at, response_type_id_token, requires_consent, + login_hint, }) } @@ -325,6 +331,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi , code_challenge , code_challenge_method , requires_consent + , login_hint , oauth2_session_id FROM oauth2_authorization_grants @@ -375,6 +382,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi , code_challenge , code_challenge_method , requires_consent + , login_hint , oauth2_session_id FROM oauth2_authorization_grants diff --git a/crates/storage-pg/src/oauth2/mod.rs b/crates/storage-pg/src/oauth2/mod.rs index 2225b1c83..c55b4d70e 100644 --- a/crates/storage-pg/src/oauth2/mod.rs +++ b/crates/storage-pg/src/oauth2/mod.rs @@ -138,6 +138,7 @@ mod tests { ResponseMode::Query, true, false, + None, ) .await .unwrap(); diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index 3313c7bcd..ea18087ea 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -43,6 +43,7 @@ pub trait OAuth2AuthorizationGrantRepository: Send + Sync { /// * `response_type_id_token`: Whether the `id_token` `response_type` was /// requested /// * `requires_consent`: Whether the client explicitly requested consent + /// * `login_hint`: The login_hint the client sent, if set /// /// # Errors /// @@ -62,6 +63,7 @@ pub trait OAuth2AuthorizationGrantRepository: Send + Sync { response_mode: ResponseMode, response_type_id_token: bool, requires_consent: bool, + login_hint: Option, ) -> Result; /// Lookup an authorization grant by its ID @@ -162,6 +164,7 @@ repository_impl!(OAuth2AuthorizationGrantRepository: response_mode: ResponseMode, response_type_id_token: bool, requires_consent: bool, + login_hint: Option, ) -> Result; async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index 045450373..762930b64 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -482,6 +482,11 @@ impl LoginContext { Self { form, ..self } } + /// Mutably borrow the form state + pub fn form_state_mut(&mut self) -> &mut FormState { + &mut self.form + } + /// Set the upstream OAuth 2.0 providers #[must_use] pub fn with_upstream_providers(self, providers: Vec) -> Self { diff --git a/crates/templates/src/forms.rs b/crates/templates/src/forms.rs index 2fcb7e3e9..4539cf54e 100644 --- a/crates/templates/src/forms.rs +++ b/crates/templates/src/forms.rs @@ -166,6 +166,16 @@ impl FormState { self } + /// Set a value on the form + pub fn set_value(&mut self, field: K, value: Option) { + self.fields.entry(field).or_default().value = value; + } + + /// Checks if a field contains a value + pub fn has_value(&self, field: K) -> bool { + self.fields.get(&field).map(|f| f.value.is_some()).unwrap_or(false) + } + /// Returns `true` if the form has no error attached to it #[must_use] pub fn is_valid(&self) -> bool {