diff --git a/crates/cli/src/sync.rs b/crates/cli/src/sync.rs index 9d810a7f3..c4aeb9a9c 100644 --- a/crates/cli/src/sync.rs +++ b/crates/cli/src/sync.rs @@ -64,6 +64,7 @@ fn map_claims_imports( subject: mas_data_model::UpstreamOAuthProviderSubjectPreference { template: config.subject.template.clone(), }, + skip_confirmation: config.skip_confirmation, localpart: mas_data_model::UpstreamOAuthProviderLocalpartPreference { action: map_import_action(config.localpart.action), template: config.localpart.template.clone(), diff --git a/crates/config/src/sections/upstream_oauth2.rs b/crates/config/src/sections/upstream_oauth2.rs index caee97294..40591b004 100644 --- a/crates/config/src/sections/upstream_oauth2.rs +++ b/crates/config/src/sections/upstream_oauth2.rs @@ -118,6 +118,26 @@ impl ConfigurationSection for UpstreamOAuth2Config { } } + if provider.claims_imports.skip_confirmation { + if provider.claims_imports.localpart.action != ImportAction::Require { + return Err(annotate(figment::Error::custom( + "The field `action` must be `require` when `skip_confirmation` is set to `true`", + )).with_path("claims_imports.localpart").into()); + } + + if provider.claims_imports.email.action == ImportAction::Suggest { + return Err(annotate(figment::Error::custom( + "The field `action` must not be `suggest` when `skip_confirmation` is set to `true`", + )).with_path("claims_imports.email").into()); + } + + if provider.claims_imports.displayname.action == ImportAction::Suggest { + return Err(annotate(figment::Error::custom( + "The field `action` must not be `suggest` when `skip_confirmation` is set to `true`", + )).with_path("claims_imports.displayname").into()); + } + } + if matches!( provider.claims_imports.localpart.on_conflict, OnConflict::Add | OnConflict::Replace | OnConflict::Set @@ -333,6 +353,13 @@ pub struct ClaimsImports { #[serde(default, skip_serializing_if = "SubjectImportPreference::is_default")] pub subject: SubjectImportPreference, + /// Whether to skip the interactive screen prompting the user to confirm the + /// attributes that are being imported. This requires `localpart.action` to + /// be `require` and other attribute actions to be either `ignore`, `force` + /// or `require` + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + pub skip_confirmation: bool, + /// Import the localpart of the MXID #[serde(default, skip_serializing_if = "LocalpartImportPreference::is_default")] pub localpart: LocalpartImportPreference, @@ -344,8 +371,7 @@ pub struct ClaimsImports { )] pub displayname: DisplaynameImportPreference, - /// Import the email address of the user based on the `email` and - /// `email_verified` claims + /// Import the email address of the user #[serde(default, skip_serializing_if = "EmailImportPreference::is_default")] pub email: EmailImportPreference, @@ -361,8 +387,10 @@ impl ClaimsImports { const fn is_default(&self) -> bool { self.subject.is_default() && self.localpart.is_default() + && !self.skip_confirmation && self.displayname.is_default() && self.email.is_default() + && self.account_name.is_default() } } diff --git a/crates/data-model/src/upstream_oauth2/provider.rs b/crates/data-model/src/upstream_oauth2/provider.rs index 671393910..94f6c2e51 100644 --- a/crates/data-model/src/upstream_oauth2/provider.rs +++ b/crates/data-model/src/upstream_oauth2/provider.rs @@ -312,6 +312,9 @@ pub struct ClaimsImports { #[serde(default)] pub subject: SubjectPreference, + #[serde(default)] + pub skip_confirmation: bool, + #[serde(default)] pub localpart: LocalpartPreference, diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 15c065be3..ba24ed311 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -4,7 +4,10 @@ // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // Please see LICENSE files in the repository root for full details. -use std::sync::{Arc, LazyLock}; +use std::{ + net::IpAddr, + sync::{Arc, LazyLock}, +}; use axum::{ Form, @@ -19,7 +22,10 @@ use mas_axum_utils::{ csrf::{CsrfExt, ProtectedForm}, record_error, }; -use mas_data_model::{BoxClock, BoxRng, UpstreamOAuthProviderOnConflict}; +use mas_data_model::{ + BoxClock, BoxRng, UpstreamOAuthAuthorizationSession, UpstreamOAuthProviderOnConflict, + UserRegistration, +}; use mas_jose::jwt::Jwt; use mas_matrix::HomeserverConnection; use mas_policy::Policy; @@ -240,10 +246,6 @@ pub(crate) async fn get( .lookup_link(link_id) .map_err(|_| RouteError::MissingCookie)?; - let post_auth_action = OptionalPostAuthAction { - post_auth_action: post_auth_action.cloned(), - }; - let link = repo .upstream_oauth_link() .lookup(link_id) @@ -287,6 +289,10 @@ pub(crate) async fn get( repo.save().await?; + let post_auth_action = OptionalPostAuthAction { + post_auth_action: post_auth_action.cloned(), + }; + post_auth_action.go_next(&url_builder).into_response() } @@ -359,6 +365,10 @@ pub(crate) async fn get( .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session) .await?; + let post_auth_action = OptionalPostAuthAction { + post_auth_action: post_auth_action.cloned(), + }; + cookie_jar = sessions_cookie .consume_link(link_id)? .save(cookie_jar, &clock); @@ -388,8 +398,6 @@ pub(crate) async fn get( .await? .ok_or(RouteError::ProviderNotFound(link.provider_id))?; - let mut ctx = UpstreamRegister::new(link.clone(), provider.clone()); - let env = environment(); let mut context = AttributeMappingContext::new(); @@ -423,13 +431,6 @@ pub(crate) async fn get( )? }; - if let Some(displayname) = displayname { - ctx = ctx.with_display_name( - displayname, - provider.claims_imports.displayname.is_forced_or_required(), - ); - } - let email = if provider.claims_imports.email.ignore() { None } else { @@ -448,13 +449,6 @@ pub(crate) async fn get( )? }; - if let Some(ref email) = email { - ctx = ctx.with_email( - email.clone(), - provider.claims_imports.email.is_forced_or_required(), - ); - } - // We do a bunch of checks for the localpart. Instead of using nested ifs all // the way, we use a labelled block, and use `break` for 'exiting' early when // needed @@ -710,6 +704,10 @@ pub(crate) async fn get( .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session) .await?; + let post_auth_action = OptionalPostAuthAction { + post_auth_action: post_auth_action.cloned(), + }; + let cookie_jar = sessions_cookie .consume_link(link_id)? .save(cookie_jar, &clock) @@ -767,6 +765,53 @@ pub(crate) async fn get( Some(localpart) }; + if provider.claims_imports.skip_confirmation { + let Some(localpart) = localpart else { + return Err(RouteError::Internal( + "No localpart available even though the provider is configured to skip confirmation, this is a bug!".into() + )); + }; + + // Register on the fly + REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]); + + let registration = prepare_user_registration( + &mut rng, + &clock, + &mut repo, + upstream_session, + localpart, + displayname, + email, + activity_tracker.ip(), + user_agent, + post_auth_action.map(|action| serde_json::json!(action)), + ) + .await?; + + let registrations = UserRegistrationSessionsCookie::load(&cookie_jar); + + let cookie_jar = sessions_cookie + .consume_link(link_id)? + .save(cookie_jar, &clock); + + let cookie_jar = registrations.add(®istration).save(cookie_jar, &clock); + + repo.save().await?; + + // Redirect to the user registration flow, in case we have any other step to + // finish + return Ok(( + cookie_jar, + url_builder + .redirect(&mas_router::RegisterFinish::new(registration.id)) + .into_response(), + )); + } + + // Else we show the upstream registration screen + let mut ctx = UpstreamRegister::new(link.clone(), provider.clone()); + if let Some(localpart) = localpart { ctx = ctx.with_localpart( localpart, @@ -774,6 +819,17 @@ pub(crate) async fn get( ); } + if let Some(displayname) = displayname { + ctx = ctx.with_display_name( + displayname, + provider.claims_imports.displayname.is_forced_or_required(), + ); + } + + if let Some(email) = email { + ctx = ctx.with_email(email, provider.claims_imports.email.is_forced_or_required()); + } + let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale); Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response() @@ -1090,17 +1146,19 @@ pub(crate) async fn post( REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]); - let mut registration = repo - .user_registration() - .add( - &mut rng, - &clock, - username, - activity_tracker.ip(), - user_agent, - post_auth_action.map(|action| serde_json::json!(action)), - ) - .await?; + let mut registration = prepare_user_registration( + &mut rng, + &clock, + &mut repo, + upstream_session, + username, + display_name, + email, + activity_tracker.ip(), + user_agent, + post_auth_action.map(|action| serde_json::json!(action)), + ) + .await?; if let Some(terms_url) = &site_config.tos_uri { registration = repo @@ -1109,44 +1167,6 @@ pub(crate) async fn post( .await?; } - // If we have an email, add an email authentication and complete it - if let Some(email) = email { - let authentication = repo - .user_email() - .add_authentication_for_registration(&mut rng, &clock, email, ®istration) - .await?; - let authentication = repo - .user_email() - .complete_authentication_with_upstream( - &clock, - authentication, - &upstream_session, - ) - .await?; - - registration = repo - .user_registration() - .set_email_authentication(registration, &authentication) - .await?; - } - - // If we have a display name, add it to the registration - if let Some(name) = display_name { - registration = repo - .user_registration() - .set_display_name(registration, name) - .await?; - } - - let registration = repo - .user_registration() - .set_upstream_oauth_authorization_session(registration, &upstream_session) - .await?; - - repo.upstream_oauth_session() - .consume(&clock, upstream_session) - .await?; - let registrations = UserRegistrationSessionsCookie::load(&cookie_jar); let cookie_jar = sessions_cookie @@ -1170,6 +1190,69 @@ pub(crate) async fn post( } } +/// Create a user registration using attributes got from the upstream +/// authorization session +async fn prepare_user_registration( + rng: &mut BoxRng, + clock: &BoxClock, + repo: &mut BoxRepository, + upstream_session: UpstreamOAuthAuthorizationSession, + localpart: String, + displayname: Option, + email: Option, + ip_address: Option, + user_agent: Option, + post_auth_action: Option, +) -> Result { + let mut registration = repo + .user_registration() + .add( + rng, + clock, + localpart, + ip_address, + user_agent, + post_auth_action, + ) + .await?; + + // If we have an email, add an email authentication and complete it + if let Some(email) = email { + let authentication = repo + .user_email() + .add_authentication_for_registration(rng, clock, email, ®istration) + .await?; + let authentication = repo + .user_email() + .complete_authentication_with_upstream(clock, authentication, &upstream_session) + .await?; + + registration = repo + .user_registration() + .set_email_authentication(registration, &authentication) + .await?; + } + + // If we have a display name, add it to the registration + if let Some(name) = displayname { + registration = repo + .user_registration() + .set_display_name(registration, name) + .await?; + } + + let registration = repo + .user_registration() + .set_upstream_oauth_authorization_session(registration, &upstream_session) + .await?; + + repo.upstream_oauth_session() + .consume(clock, upstream_session) + .await?; + + Ok(registration) +} + #[cfg(test)] mod tests { use hyper::{Request, StatusCode, header::CONTENT_TYPE}; @@ -1386,6 +1469,178 @@ mod tests { assert!(email_auth.completed_at.is_some()); } + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_register_skip_confirmation(pool: PgPool) { + // Same test as test_register, but checks that we get straight to the + // registration flow skipping the confirmation + setup(); + let state = TestState::from_pool(pool).await.unwrap(); + let mut rng = state.rng(); + let cookies = CookieHelper::new(); + + let claims_imports = UpstreamOAuthProviderClaimsImports { + skip_confirmation: true, + localpart: UpstreamOAuthProviderLocalpartPreference { + action: mas_data_model::UpstreamOAuthProviderImportAction::Require, + template: None, + on_conflict: mas_data_model::UpstreamOAuthProviderOnConflict::default(), + }, + email: UpstreamOAuthProviderImportPreference { + action: mas_data_model::UpstreamOAuthProviderImportAction::Force, + template: None, + }, + ..UpstreamOAuthProviderClaimsImports::default() + }; + + let id_token_claims = serde_json::json!({ + "preferred_username": "john", + "email": "john@example.com", + "email_verified": true, + }); + + // Grab a key to sign the id_token + // We could generate a key on the fly, but because we have one available here, + // why not use it? + let key = state + .key_store + .signing_key_for_algorithm(&JsonWebSignatureAlg::Rs256) + .unwrap(); + + let signer = key + .params() + .signing_key_for_alg(&JsonWebSignatureAlg::Rs256) + .unwrap(); + let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Rs256); + let id_token = + Jwt::sign_with_rng(&mut rng, header, id_token_claims.clone(), &signer).unwrap(); + + // Provision a provider and a link + let mut repo = state.repository().await.unwrap(); + let provider = repo + .upstream_oauth_provider() + .add( + &mut rng, + &state.clock, + UpstreamOAuthProviderParams { + issuer: Some("https://example.com/".to_owned()), + human_name: Some("Example Ltd.".to_owned()), + brand_name: None, + scope: Scope::from_iter([OPENID]), + token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None, + token_endpoint_signing_alg: None, + id_token_signed_response_alg: JsonWebSignatureAlg::Rs256, + client_id: "client".to_owned(), + encrypted_client_secret: None, + claims_imports, + authorization_endpoint_override: None, + token_endpoint_override: None, + userinfo_endpoint_override: None, + fetch_userinfo: false, + userinfo_signed_response_alg: None, + jwks_uri_override: None, + discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, + pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, + response_mode: None, + additional_authorization_parameters: Vec::new(), + forward_login_hint: false, + ui_order: 0, + on_backchannel_logout: + mas_data_model::UpstreamOAuthProviderOnBackchannelLogout::DoNothing, + }, + ) + .await + .unwrap(); + + let session = repo + .upstream_oauth_session() + .add( + &mut rng, + &state.clock, + &provider, + "state".to_owned(), + None, + None, + ) + .await + .unwrap(); + + let link = repo + .upstream_oauth_link() + .add( + &mut rng, + &state.clock, + &provider, + "subject".to_owned(), + None, + ) + .await + .unwrap(); + + let session = repo + .upstream_oauth_session() + .complete_with_link( + &state.clock, + session, + &link, + Some(id_token.into_string()), + Some(id_token_claims), + None, + None, + ) + .await + .unwrap(); + + repo.save().await.unwrap(); + + let cookie_jar = state.cookie_jar(); + let upstream_sessions = UpstreamSessionsCookie::default() + .add(session.id, provider.id, "state".to_owned(), None) + .add_link_to_session(session.id, link.id) + .unwrap(); + let cookie_jar = upstream_sessions.save(cookie_jar, &state.clock); + cookies.import(cookie_jar); + + let request = Request::get(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).empty(); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + let location = response.headers().get(hyper::header::LOCATION).unwrap(); + // Grab the registration ID from the redirected URL: + // /register/steps/{id}/finish + let registration_id: Ulid = str::from_utf8(location.as_bytes()) + .unwrap() + .rsplit('/') + .nth(1) + .expect("Location to have two slashes") + .parse() + .expect("last segment of location to be a ULID"); + + // Check that we have a registered user, with the email imported + let mut repo = state.repository().await.unwrap(); + let registration: UserRegistration = repo + .user_registration() + .lookup(registration_id) + .await + .unwrap() + .expect("user registration exists"); + + assert_eq!(registration.password, None); + assert_eq!(registration.completed_at, None); + assert_eq!(registration.username, "john"); + + let email_auth_id = registration + .email_authentication_id + .expect("registration should have an email authentication"); + let email_auth: UserEmailAuthentication = repo + .user_email() + .lookup_authentication(email_auth_id) + .await + .unwrap() + .expect("email authentication should exist"); + assert_eq!(email_auth.email, "john@example.com"); + assert!(email_auth.completed_at.is_some()); + } + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_link_existing_account(pool: PgPool) { let existing_username = "john"; diff --git a/docs/config.schema.json b/docs/config.schema.json index dd9a95b46..f6d947e48 100644 --- a/docs/config.schema.json +++ b/docs/config.schema.json @@ -2471,6 +2471,10 @@ } ] }, + "skip_confirmation": { + "description": "Whether to skip the interactive screen prompting the user to confirm the\n attributes that are being imported. This requires `localpart.action` to\n be `require` and other attribute actions to be either `ignore`, `force`\n or `require`", + "type": "boolean" + }, "localpart": { "description": "Import the localpart of the MXID", "allOf": [ @@ -2488,7 +2492,7 @@ ] }, "email": { - "description": "Import the email address of the user based on the `email` and\n `email_verified` claims", + "description": "Import the email address of the user", "allOf": [ { "$ref": "#/definitions/EmailImportPreference" diff --git a/docs/reference/configuration.md b/docs/reference/configuration.md index 5c1572ac1..62a86db68 100644 --- a/docs/reference/configuration.md +++ b/docs/reference/configuration.md @@ -771,6 +771,14 @@ upstream_oauth2: subject: #template: "{{ user.sub }}" + # By default, new users will see a screen confirming the attributes they + # are about to have on their account. + # + # Setting this to `true` allows skipping this screen, but requires the + # `localpart.action` to be set to `require` and the other attributes + # actions to be set to `ignore`, `force` or `require`. + #skip_confirmation: false + # The localpart is the local part of the user's Matrix ID. # For example, on the `example.com` server, if the localpart is `alice`, # the user's Matrix ID will be `@alice:example.com`.