diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index c6c1de95d..5125f24ce 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -4,7 +4,10 @@ // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // Please see LICENSE files in the repository root for full details. -use std::sync::{Arc, LazyLock}; +use std::{ + net::IpAddr, + sync::{Arc, LazyLock}, +}; use axum::{ Form, @@ -19,7 +22,10 @@ use mas_axum_utils::{ csrf::{CsrfExt, ProtectedForm}, record_error, }; -use mas_data_model::{BoxClock, BoxRng, UpstreamOAuthProviderOnConflict}; +use mas_data_model::{ + BoxClock, BoxRng, UpstreamOAuthAuthorizationSession, UpstreamOAuthProviderOnConflict, + UserRegistration, +}; use mas_jose::jwt::Jwt; use mas_matrix::HomeserverConnection; use mas_policy::Policy; @@ -238,10 +244,6 @@ pub(crate) async fn get( .lookup_link(link_id) .map_err(|_| RouteError::MissingCookie)?; - let post_auth_action = OptionalPostAuthAction { - post_auth_action: post_auth_action.cloned(), - }; - let link = repo .upstream_oauth_link() .lookup(link_id) @@ -285,6 +287,10 @@ pub(crate) async fn get( repo.save().await?; + let post_auth_action = OptionalPostAuthAction { + post_auth_action: post_auth_action.cloned(), + }; + post_auth_action.go_next(&url_builder).into_response() } @@ -357,6 +363,10 @@ pub(crate) async fn get( .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session) .await?; + let post_auth_action = OptionalPostAuthAction { + post_auth_action: post_auth_action.cloned(), + }; + cookie_jar = sessions_cookie .consume_link(link_id)? .save(cookie_jar, &clock); @@ -386,8 +396,6 @@ pub(crate) async fn get( .await? .ok_or(RouteError::ProviderNotFound(link.provider_id))?; - let mut ctx = UpstreamRegister::new(link.clone(), provider.clone()); - let env = environment(); let mut context = AttributeMappingContext::new(); @@ -421,13 +429,6 @@ pub(crate) async fn get( )? }; - if let Some(displayname) = displayname { - ctx = ctx.with_display_name( - displayname, - provider.claims_imports.displayname.is_forced_or_required(), - ); - } - let email = if provider.claims_imports.email.ignore() { None } else { @@ -446,13 +447,6 @@ pub(crate) async fn get( )? }; - if let Some(ref email) = email { - ctx = ctx.with_email( - email.clone(), - provider.claims_imports.email.is_forced_or_required(), - ); - } - // We do a bunch of checks for the localpart. Instead of using nested ifs all // the way, we use a labelled block, and use `break` for 'exiting' early when // needed @@ -622,6 +616,10 @@ pub(crate) async fn get( .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session) .await?; + let post_auth_action = OptionalPostAuthAction { + post_auth_action: post_auth_action.cloned(), + }; + let cookie_jar = sessions_cookie .consume_link(link_id)? .save(cookie_jar, &clock) @@ -679,6 +677,53 @@ pub(crate) async fn get( Some(localpart) }; + if provider.claims_imports.skip_confirmation { + let Some(localpart) = localpart else { + return Err(RouteError::Internal( + "No localpart available even though the provider is configured to skip confirmation, this is a bug!".into() + )); + }; + + // Register on the fly + REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]); + + let registration = prepare_user_registration( + &mut rng, + &clock, + &mut repo, + upstream_session, + localpart, + displayname, + email, + activity_tracker.ip(), + user_agent, + post_auth_action.map(|action| serde_json::json!(action)), + ) + .await?; + + let registrations = UserRegistrationSessionsCookie::load(&cookie_jar); + + let cookie_jar = sessions_cookie + .consume_link(link_id)? + .save(cookie_jar, &clock); + + let cookie_jar = registrations.add(®istration).save(cookie_jar, &clock); + + repo.save().await?; + + // Redirect to the user registration flow, in case we have any other step to + // finish + return Ok(( + cookie_jar, + url_builder + .redirect(&mas_router::RegisterFinish::new(registration.id)) + .into_response(), + )); + } + + // Else we show the upstream registration screen + let mut ctx = UpstreamRegister::new(link.clone(), provider.clone()); + if let Some(localpart) = localpart { ctx = ctx.with_localpart( localpart, @@ -686,6 +731,17 @@ pub(crate) async fn get( ); } + if let Some(displayname) = displayname { + ctx = ctx.with_display_name( + displayname, + provider.claims_imports.displayname.is_forced_or_required(), + ); + } + + if let Some(email) = email { + ctx = ctx.with_email(email, provider.claims_imports.email.is_forced_or_required()); + } + let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale); Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response() @@ -1002,17 +1058,19 @@ pub(crate) async fn post( REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]); - let mut registration = repo - .user_registration() - .add( - &mut rng, - &clock, - username, - activity_tracker.ip(), - user_agent, - post_auth_action.map(|action| serde_json::json!(action)), - ) - .await?; + let mut registration = prepare_user_registration( + &mut rng, + &clock, + &mut repo, + upstream_session, + username, + display_name, + email, + activity_tracker.ip(), + user_agent, + post_auth_action.map(|action| serde_json::json!(action)), + ) + .await?; if let Some(terms_url) = &site_config.tos_uri { registration = repo @@ -1021,44 +1079,6 @@ pub(crate) async fn post( .await?; } - // If we have an email, add an email authentication and complete it - if let Some(email) = email { - let authentication = repo - .user_email() - .add_authentication_for_registration(&mut rng, &clock, email, ®istration) - .await?; - let authentication = repo - .user_email() - .complete_authentication_with_upstream( - &clock, - authentication, - &upstream_session, - ) - .await?; - - registration = repo - .user_registration() - .set_email_authentication(registration, &authentication) - .await?; - } - - // If we have a display name, add it to the registration - if let Some(name) = display_name { - registration = repo - .user_registration() - .set_display_name(registration, name) - .await?; - } - - let registration = repo - .user_registration() - .set_upstream_oauth_authorization_session(registration, &upstream_session) - .await?; - - repo.upstream_oauth_session() - .consume(&clock, upstream_session) - .await?; - let registrations = UserRegistrationSessionsCookie::load(&cookie_jar); let cookie_jar = sessions_cookie @@ -1082,6 +1102,69 @@ pub(crate) async fn post( } } +/// Create a user registration using attributes got from the upstream +/// authorization session +async fn prepare_user_registration( + rng: &mut BoxRng, + clock: &BoxClock, + repo: &mut BoxRepository, + upstream_session: UpstreamOAuthAuthorizationSession, + localpart: String, + displayname: Option, + email: Option, + ip_address: Option, + user_agent: Option, + post_auth_action: Option, +) -> Result { + let mut registration = repo + .user_registration() + .add( + rng, + clock, + localpart, + ip_address, + user_agent, + post_auth_action, + ) + .await?; + + // If we have an email, add an email authentication and complete it + if let Some(email) = email { + let authentication = repo + .user_email() + .add_authentication_for_registration(rng, clock, email, ®istration) + .await?; + let authentication = repo + .user_email() + .complete_authentication_with_upstream(clock, authentication, &upstream_session) + .await?; + + registration = repo + .user_registration() + .set_email_authentication(registration, &authentication) + .await?; + } + + // If we have a display name, add it to the registration + if let Some(name) = displayname { + registration = repo + .user_registration() + .set_display_name(registration, name) + .await?; + } + + let registration = repo + .user_registration() + .set_upstream_oauth_authorization_session(registration, &upstream_session) + .await?; + + repo.upstream_oauth_session() + .consume(clock, upstream_session) + .await?; + + Ok(registration) +} + #[cfg(test)] mod tests { use hyper::{Request, StatusCode, header::CONTENT_TYPE};