From 8b2addbe0e5c4b59c67e046dc1c5ac5a7055cb56 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 27 Nov 2025 15:50:58 +0100 Subject: [PATCH 1/3] Allow linking upstream accounts to matching users without confirmation This reworks the link flow to handle many edge cases better. One major functionality change is that when we had a new upstream account with no user linked, but the localpart matching an existing user, if `on_conflict` was set to `add`, we prompt the user to link the existing account. This prompt is now skipped and the user is linked automatically. --- crates/handlers/src/upstream_oauth2/link.rs | 479 ++++++++++---------- 1 file changed, 236 insertions(+), 243 deletions(-) diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 96d1b0180..c794eff3e 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -386,7 +386,7 @@ pub(crate) async fn get( .await? .ok_or(RouteError::ProviderNotFound(link.provider_id))?; - let ctx = UpstreamRegister::new(link.clone(), provider.clone()); + let mut ctx = UpstreamRegister::new(link.clone(), provider.clone()); let env = environment(); @@ -403,8 +403,8 @@ pub(crate) async fn get( } let context = context.build(); - let ctx = if provider.claims_imports.displayname.ignore() { - ctx + let displayname = if provider.claims_imports.displayname.ignore() { + None } else { let template = provider .claims_imports @@ -413,22 +413,23 @@ pub(crate) async fn get( .as_deref() .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE); - match render_attribute_template( + render_attribute_template( &env, template, &context, provider.claims_imports.displayname.is_required(), - )? { - Some(value) => ctx.with_display_name( - value, - provider.claims_imports.displayname.is_forced_or_required(), - ), - None => ctx, - } + )? }; - let ctx = if provider.claims_imports.email.ignore() { - ctx + if let Some(displayname) = displayname { + ctx = ctx.with_display_name( + displayname, + provider.claims_imports.displayname.is_forced_or_required(), + ); + } + + let email = if provider.claims_imports.email.ignore() { + None } else { let template = provider .claims_imports @@ -437,22 +438,29 @@ pub(crate) async fn get( .as_deref() .unwrap_or(DEFAULT_EMAIL_TEMPLATE); - match render_attribute_template( + render_attribute_template( &env, template, &context, provider.claims_imports.email.is_required(), - )? { - Some(value) => { - ctx.with_email(value, provider.claims_imports.email.is_forced_or_required()) - } - None => ctx, - } + )? }; - let ctx = if provider.claims_imports.localpart.ignore() { - ctx - } else { + if let Some(ref email) = email { + ctx = ctx.with_email( + email.clone(), + provider.claims_imports.email.is_forced_or_required(), + ); + } + + // We do a bunch of checks for the localpart. Instead of using nested ifs all + // the way, we use a labelled block, and use `break` for 'exitting' early when + // needed + let localpart = 'localpart: { + if provider.claims_imports.localpart.ignore() { + break 'localpart None; + } + let template = provider .claims_imports .localpart @@ -460,101 +468,98 @@ pub(crate) async fn get( .as_deref() .unwrap_or(DEFAULT_LOCALPART_TEMPLATE); - match render_attribute_template( + let Some(localpart) = render_attribute_template( &env, template, &context, provider.claims_imports.localpart.is_required(), - )? { - Some(localpart) => { - // We could run policy & existing user checks when the user submits the - // form, but this lead to poor UX. This is why we do - // it ahead of time here. - let maybe_existing_user = repo.user().find_by_username(&localpart).await?; - let is_available = homeserver - .is_localpart_available(&localpart) - .await - .map_err(RouteError::HomeserverConnection)?; + )? + else { + break 'localpart None; + }; - if let Some(existing_user) = maybe_existing_user { - // The mapper returned a username which already exists, but isn't - // linked to this upstream user. - let on_conflict = provider.claims_imports.localpart.on_conflict; + let forced_or_required = provider.claims_imports.localpart.is_forced_or_required(); - match on_conflict { - UpstreamOAuthProviderOnConflict::Fail => { - // TODO: translate - let ctx = ErrorContext::new() - .with_code("User exists") - .with_description(format!( - r"Upstream account provider returned {localpart:?} as username, - which is not linked to that upstream account. Your homeserver does not allow - linking an upstream account to an existing account" - )) - .with_language(&locale); + // We've got a localpart from the template. Let's run the policy + // engine on this registration and react early to a problem on + // the username + let res = policy + .evaluate_register(mas_policy::RegisterInput { + registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2, + username: &localpart, + email: email.as_deref(), + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + user_agent: user_agent.clone(), + }, + }) + .await?; - return Ok(( - cookie_jar, - Html(templates.render_error(&ctx)?).into_response(), - )); - } - UpstreamOAuthProviderOnConflict::Add => { - // new oauth link is allowed - let ctx = UpstreamExistingLinkContext::new(existing_user) - .with_csrf(csrf_token.form_value()) - .with_language(locale); + // We don't do a full policy check at this point, only look for violations on + // the username + if res + .violations + .iter() + .any(|violation| violation.field.as_deref() == Some("username")) + { + if !forced_or_required { + tracing::warn!( + upstream_oauth_provider.id = %provider.id, + upstream_oauth_link.id = %link.id, + "Upstream provider returned a localpart {localpart:?} which was denied by the policy ({res}). As the username is just a suggestion, it was ignored." + ); + break 'localpart None; + } - return Ok(( - cookie_jar, - Html(templates.render_upstream_oauth2_login_link(&ctx)?) - .into_response(), - )); - } - } - } + // If the username policy check fails, we display an error message. + // TODO: translate + let ctx = ErrorContext::new() + .with_code("Policy error") + .with_description(format!( + r"Upstream account provider returned {localpart:?} as username, + which does not pass the policy check: {res}" + )) + .with_language(&locale); + + return Ok(( + cookie_jar, + Html(templates.render_error(&ctx)?).into_response(), + )); + } + + // We got a localpart from the template. We need to check if it's + // available, and if it's not apply the conflict resolution setup in + // the config + let maybe_existing_user = repo.user().find_by_username(&localpart).await?; + if let Some(existing_user) = maybe_existing_user { + if !forced_or_required { + tracing::warn!( + upstream_oauth_provider.id = %provider.id, + upstream_oauth_link.id = %link.id, + user.id = %existing_user.id, + "Upstream provider returned a localpart {localpart:?} which is already used by another user. As the username is just a suggestion, it was ignored." + ); + break 'localpart None; + } + + match provider.claims_imports.localpart.on_conflict { + // We matched an existing user, but the server doesn't allow us to link to + // existing users automatically. In this case, we error out + UpstreamOAuthProviderOnConflict::Fail => { + tracing::warn!( + upstream_oauth_provider.id = %provider.id, + upstream_oauth_link.id = %link.id, + user.id = %existing_user.id, + "Upstream provider returned a localpart {localpart:?} which is already used by another user. Configuration doesn't allow for automatic linking of existing users." + ); - if !is_available { // TODO: translate let ctx = ErrorContext::new() - .with_code("Localpart not available") - .with_description(format!( - r"Localpart {localpart:?} is not available on this homeserver" - )) - .with_language(&locale); - - return Ok(( - cookie_jar, - Html(templates.render_error(&ctx)?).into_response(), - )); - } - - let res = policy - .evaluate_register(mas_policy::RegisterInput { - registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2, - username: &localpart, - email: None, - requester: mas_policy::Requester { - ip_address: activity_tracker.ip(), - user_agent: user_agent.clone(), - }, - }) - .await?; - - if res.valid() { - // The username passes the policy check, add it to the context - ctx.with_localpart( - localpart, - provider.claims_imports.localpart.is_forced_or_required(), - ) - } else if provider.claims_imports.localpart.is_forced_or_required() { - // If the username claim is 'forced' but doesn't pass the policy check, - // we display an error message. - // TODO: translate - let ctx = ErrorContext::new() - .with_code("Policy error") + .with_code("User exists") .with_description(format!( r"Upstream account provider returned {localpart:?} as username, - which does not pass the policy check: {res}" + which is not linked to that upstream account. Your homeserver does not allow + linking an upstream account to an existing account" )) .with_language(&locale); @@ -562,15 +567,125 @@ pub(crate) async fn get( cookie_jar, Html(templates.render_error(&ctx)?).into_response(), )); - } else { - // Else, we just ignore it when it doesn't pass the policy check. - ctx + } + + // We matched an existing user and the conflict resolution is to add the + // link to the existing user. In this case, we add the link + UpstreamOAuthProviderOnConflict::Add => { + tracing::info!( + user.id = %existing_user.id, + upstream_oauth_provider.id = %provider.id, + upstream_oauth_link.id = %link.id, + upstream_oauth_link.subject = link.subject, + "Upstream account mapped localpart {localpart:?} matched an existing user, linking" + ); + + // Add link to the user + repo.upstream_oauth_link() + .associate_to_user(&link, &existing_user) + .await?; } } - None => ctx, + + // Now that we've resolved the conflict, log in that existing user + + // Check that the user is not locked or deactivated + if existing_user.deactivated_at.is_some() { + // The account is deactivated, show the 'account deactivated' fallback + let ctx = AccountInactiveContext::new(existing_user) + .with_csrf(csrf_token.form_value()) + .with_language(locale); + let fallback = templates.render_account_deactivated(&ctx)?; + return Ok((cookie_jar, Html(fallback).into_response())); + } + + if existing_user.locked_at.is_some() { + // The account is locked, show the 'account locked' fallback + let ctx = AccountInactiveContext::new(existing_user) + .with_csrf(csrf_token.form_value()) + .with_language(locale); + let fallback = templates.render_account_locked(&ctx)?; + return Ok((cookie_jar, Html(fallback).into_response())); + } + + let session = repo + .browser_session() + .add(&mut rng, &clock, &existing_user, user_agent) + .await?; + + let upstream_session = repo + .upstream_oauth_session() + .consume(&clock, upstream_session) + .await?; + + repo.browser_session() + .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session) + .await?; + + let cookie_jar = sessions_cookie + .consume_link(link_id)? + .save(cookie_jar, &clock) + .set_session(&session); + + repo.save().await?; + + // Count this 'on-the-fly' linking as a login + LOGIN_COUNTER.add( + 1, + &[KeyValue::new( + PROVIDER, + upstream_session.provider_id.to_string(), + )], + ); + + return Ok(( + cookie_jar, + post_auth_action.go_next(&url_builder).into_response(), + )); } + + // Now let's check if the localpart is allowed by the homeserver. It's possible + // that it's plain invalid (although that should have been caught by the + // policy), or just reserved by an application service + let is_available = homeserver + .is_localpart_available(&localpart) + .await + .map_err(RouteError::HomeserverConnection)?; + + if !is_available { + if !forced_or_required { + tracing::warn!( + upstream_oauth_provider.id = %provider.id, + upstream_oauth_link.id = %link.id, + "Upstream provider returned a localpart {localpart:?} which isn't available on the homeserver. As the username is just a suggestion, it was ignored." + ); + break 'localpart None; + } + + // TODO: translate + let ctx = ErrorContext::new() + .with_code("Localpart not available") + .with_description(format!( + r"Localpart {localpart:?} is not available on this homeserver" + )) + .with_language(&locale); + + return Ok(( + cookie_jar, + Html(templates.render_error(&ctx)?).into_response(), + )); + } + + Some(localpart) }; + if let Some(localpart) = localpart { + ctx = ctx.with_localpart( + localpart, + provider.claims_imports.localpart.is_forced_or_required(), + ); + } + let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale); Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response() @@ -667,104 +782,6 @@ pub(crate) async fn post( Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response()) } - (None, None, FormData::Link) => { - // There is an existing user with the same username, but no link. - // If the configuration allows it, the user is prompted to link the - // existing account. Note that we cannot trust the user input here, - // which is why we have to re-calculate the localpart, instead of - // passing it through form data. - - let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?; - - let provider = repo - .upstream_oauth_provider() - .lookup(link.provider_id) - .await? - .ok_or(RouteError::ProviderNotFound(link.provider_id))?; - - let env = environment(); - - let mut context = AttributeMappingContext::new(); - if let Some(id_token) = id_token { - let (_, payload) = id_token.into_parts(); - context = context.with_id_token_claims(payload); - } - if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() { - context = context.with_extra_callback_parameters(extra_callback_parameters.clone()); - } - if let Some(userinfo) = upstream_session.userinfo() { - context = context.with_userinfo_claims(userinfo.clone()); - } - let context = context.build(); - - if !provider.claims_imports.localpart.is_forced_or_required() { - //Claims import for `localpart` should be `require` or `force` at this stage - return Err(RouteError::InvalidFormAction); - } - - let template = provider - .claims_imports - .localpart - .template - .as_deref() - .unwrap_or(DEFAULT_LOCALPART_TEMPLATE); - - let Some(localpart) = render_attribute_template(&env, template, &context, true)? else { - // This should never be the case at this point - return Err(RouteError::InvalidFormAction); - }; - - let maybe_user = repo.user().find_by_username(&localpart).await?; - - let Some(user) = maybe_user else { - // user cannot be None at this stage - return Err(RouteError::InvalidFormAction); - }; - - let on_conflict = provider.claims_imports.localpart.on_conflict; - - match on_conflict { - UpstreamOAuthProviderOnConflict::Fail => { - //OnConflict can not be equals to Fail at this stage - return Err(RouteError::InvalidFormAction); - } - UpstreamOAuthProviderOnConflict::Add => { - // Add link to the user - repo.upstream_oauth_link() - .associate_to_user(&link, &user) - .await?; - - // And sign in the user - let session = repo - .browser_session() - .add(&mut rng, &clock, &user, user_agent) - .await?; - - let upstream_session = repo - .upstream_oauth_session() - .consume(&clock, upstream_session) - .await?; - - repo.browser_session() - .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session) - .await?; - - let post_auth_action = OptionalPostAuthAction { - post_auth_action: post_auth_action.cloned(), - }; - - let cookie_jar = sessions_cookie - .consume_link(link_id)? - .save(cookie_jar, &clock); - let cookie_jar = cookie_jar.set_session(&session); - - repo.save().await?; - - Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response()) - } - } - } - ( None, None, @@ -810,7 +827,7 @@ pub(crate) async fn post( let context = context.build(); // Create a template context in case we need to re-render because of an error - let ctx = UpstreamRegister::new(link.clone(), provider.clone()); + let mut ctx = UpstreamRegister::new(link.clone(), provider.clone()); let display_name = if provider .claims_imports @@ -834,14 +851,12 @@ pub(crate) async fn post( None }; - let ctx = if let Some(ref display_name) = display_name { - ctx.with_display_name( + if let Some(ref display_name) = display_name { + ctx = ctx.with_display_name( display_name.clone(), provider.claims_imports.email.is_forced_or_required(), - ) - } else { - ctx - }; + ); + } let email = if provider.claims_imports.email.should_import(import_email) { let template = provider @@ -861,14 +876,12 @@ pub(crate) async fn post( None }; - let ctx = if let Some(ref email) = email { - ctx.with_email( + if let Some(ref email) = email { + ctx = ctx.with_email( email.clone(), provider.claims_imports.email.is_forced_or_required(), - ) - } else { - ctx - }; + ); + } let username = if provider.claims_imports.localpart.is_forced_or_required() { let template = provider @@ -885,7 +898,7 @@ pub(crate) async fn post( } .unwrap_or_default(); - let ctx = ctx.with_localpart( + ctx = ctx.with_localpart( username.clone(), provider.claims_imports.localpart.is_forced_or_required(), ); @@ -1299,6 +1312,8 @@ mod tests { localpart: UpstreamOAuthProviderLocalpartPreference { action: mas_data_model::UpstreamOAuthProviderImportAction::Require, template: None, + // This is the important bit: this will automatically link + // existing accounts if the localpart matches on_conflict: mas_data_model::UpstreamOAuthProviderOnConflict::Add, }, email: UpstreamOAuthProviderImportPreference { @@ -1387,28 +1402,6 @@ mod tests { let request = cookies.with_cookies(request); let response = state.request(request).await; cookies.save_cookies(&response); - response.assert_status(StatusCode::OK); - response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); - - // Extract the CSRF token from the response body - let csrf_token = response - .body() - .split("name=\"csrf\" value=\"") - .nth(1) - .unwrap() - .split('\"') - .next() - .unwrap(); - - let request = Request::post(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).form( - serde_json::json!({ - "csrf": csrf_token, - "action": "link" - }), - ); - let request = cookies.with_cookies(request); - let response = state.request(request).await; - cookies.save_cookies(&response); response.assert_status(StatusCode::SEE_OTHER); // Check that the existing user has the oidc link From e90f11b8f83fa6645780fcda55b94bb8bc81d292 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 27 Nov 2025 15:58:39 +0100 Subject: [PATCH 2/3] Remove unused login_link.html template --- crates/templates/src/context.rs | 2 +- crates/templates/src/lib.rs | 3 -- .../pages/upstream_oauth2/login_link.html | 31 ------------------- translations/en.json | 14 --------- 4 files changed, 1 insertion(+), 49 deletions(-) delete mode 100644 templates/pages/upstream_oauth2/login_link.html diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index 4ed09c3e1..f836d7c4b 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -1456,7 +1456,7 @@ impl TemplateContext for RecoveryFinishContext { } } -/// Context used by the `pages/upstream_oauth2/{link_mismatch,login_link}.html` +/// Context used by the `pages/upstream_oauth2/link_mismatch.html` /// templates #[derive(Serialize)] pub struct UpstreamExistingLinkContext { diff --git a/crates/templates/src/lib.rs b/crates/templates/src/lib.rs index 32a41e8b2..1f9aa3337 100644 --- a/crates/templates/src/lib.rs +++ b/crates/templates/src/lib.rs @@ -442,9 +442,6 @@ register_templates! { /// Render the upstream link mismatch message pub fn render_upstream_oauth2_link_mismatch(WithLanguage>>) { "pages/upstream_oauth2/link_mismatch.html" } - /// Render the upstream link match - pub fn render_upstream_oauth2_login_link(WithLanguage>) { "pages/upstream_oauth2/login_link.html" } - /// Render the upstream suggest link message pub fn render_upstream_oauth2_suggest_link(WithLanguage>>) { "pages/upstream_oauth2/suggest_link.html" } diff --git a/templates/pages/upstream_oauth2/login_link.html b/templates/pages/upstream_oauth2/login_link.html deleted file mode 100644 index cdde102b2..000000000 --- a/templates/pages/upstream_oauth2/login_link.html +++ /dev/null @@ -1,31 +0,0 @@ -{# -Copyright 2025 New Vector Ltd. - -SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial -Please see LICENSE in the repository root for full details. --#} - -{% extends "base.html" %} - -{% block content %} -
-
- {{ icon.link() }} -
- -
-

{{ _("mas.upstream_oauth2.login_link.heading") }}

-
-
-
- {{ _("mas.upstream_oauth2.login_link.description", username=linked_user.username) }} - -
- - - - {{ button.button(text=_("mas.upstream_oauth2.login_link.action")) }} -
- -
-{% endblock content %} diff --git a/translations/en.json b/translations/en.json index cdf2df82d..c935542f3 100644 --- a/translations/en.json +++ b/translations/en.json @@ -697,20 +697,6 @@ "description": "Page shown when the user tries to link an upstream account that is already linked to another account" } }, - "login_link": { - "action": "Continue", - "@action": { - "context": "pages/upstream_oauth2/login_link.html:27:28-70" - }, - "description": "An account exists for this username (%(username)s), it will be linked to this upstream account.", - "@description": { - "context": "pages/upstream_oauth2/login_link.html:21:7-85" - }, - "heading": "Link to your existing account", - "@heading": { - "context": "pages/upstream_oauth2/login_link.html:17:27-70" - } - }, "register": { "choose_username": { "description": "This cannot be changed later.", From 5e1100d22f7862cd533284a7d80ffd450cc8d603 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 27 Nov 2025 17:23:10 +0100 Subject: [PATCH 3/3] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- crates/handlers/src/upstream_oauth2/link.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index c794eff3e..c6c1de95d 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -454,7 +454,7 @@ pub(crate) async fn get( } // We do a bunch of checks for the localpart. Instead of using nested ifs all - // the way, we use a labelled block, and use `break` for 'exitting' early when + // the way, we use a labelled block, and use `break` for 'exiting' early when // needed let localpart = 'localpart: { if provider.claims_imports.localpart.ignore() { @@ -854,7 +854,7 @@ pub(crate) async fn post( if let Some(ref display_name) = display_name { ctx = ctx.with_display_name( display_name.clone(), - provider.claims_imports.email.is_forced_or_required(), + provider.claims_imports.displayname.is_forced_or_required(), ); }