Option to skip confirmation when registering through an upstream OAuth provider (#5296)

This commit is contained in:
Quentin Gliech
2025-12-03 13:46:31 +01:00
committed by GitHub
6 changed files with 373 additions and 74 deletions

View File

@@ -64,6 +64,7 @@ fn map_claims_imports(
subject: mas_data_model::UpstreamOAuthProviderSubjectPreference {
template: config.subject.template.clone(),
},
skip_confirmation: config.skip_confirmation,
localpart: mas_data_model::UpstreamOAuthProviderLocalpartPreference {
action: map_import_action(config.localpart.action),
template: config.localpart.template.clone(),

View File

@@ -118,6 +118,26 @@ impl ConfigurationSection for UpstreamOAuth2Config {
}
}
if provider.claims_imports.skip_confirmation {
if provider.claims_imports.localpart.action != ImportAction::Require {
return Err(annotate(figment::Error::custom(
"The field `action` must be `require` when `skip_confirmation` is set to `true`",
)).with_path("claims_imports.localpart").into());
}
if provider.claims_imports.email.action == ImportAction::Suggest {
return Err(annotate(figment::Error::custom(
"The field `action` must not be `suggest` when `skip_confirmation` is set to `true`",
)).with_path("claims_imports.email").into());
}
if provider.claims_imports.displayname.action == ImportAction::Suggest {
return Err(annotate(figment::Error::custom(
"The field `action` must not be `suggest` when `skip_confirmation` is set to `true`",
)).with_path("claims_imports.displayname").into());
}
}
if matches!(
provider.claims_imports.localpart.on_conflict,
OnConflict::Add | OnConflict::Replace | OnConflict::Set
@@ -333,6 +353,13 @@ pub struct ClaimsImports {
#[serde(default, skip_serializing_if = "SubjectImportPreference::is_default")]
pub subject: SubjectImportPreference,
/// Whether to skip the interactive screen prompting the user to confirm the
/// attributes that are being imported. This requires `localpart.action` to
/// be `require` and other attribute actions to be either `ignore`, `force`
/// or `require`
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub skip_confirmation: bool,
/// Import the localpart of the MXID
#[serde(default, skip_serializing_if = "LocalpartImportPreference::is_default")]
pub localpart: LocalpartImportPreference,
@@ -344,8 +371,7 @@ pub struct ClaimsImports {
)]
pub displayname: DisplaynameImportPreference,
/// Import the email address of the user based on the `email` and
/// `email_verified` claims
/// Import the email address of the user
#[serde(default, skip_serializing_if = "EmailImportPreference::is_default")]
pub email: EmailImportPreference,
@@ -361,8 +387,10 @@ impl ClaimsImports {
const fn is_default(&self) -> bool {
self.subject.is_default()
&& self.localpart.is_default()
&& !self.skip_confirmation
&& self.displayname.is_default()
&& self.email.is_default()
&& self.account_name.is_default()
}
}

View File

@@ -312,6 +312,9 @@ pub struct ClaimsImports {
#[serde(default)]
pub subject: SubjectPreference,
#[serde(default)]
pub skip_confirmation: bool,
#[serde(default)]
pub localpart: LocalpartPreference,

View File

@@ -4,7 +4,10 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::sync::{Arc, LazyLock};
use std::{
net::IpAddr,
sync::{Arc, LazyLock},
};
use axum::{
Form,
@@ -19,7 +22,10 @@ use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm},
record_error,
};
use mas_data_model::{BoxClock, BoxRng, UpstreamOAuthProviderOnConflict};
use mas_data_model::{
BoxClock, BoxRng, UpstreamOAuthAuthorizationSession, UpstreamOAuthProviderOnConflict,
UserRegistration,
};
use mas_jose::jwt::Jwt;
use mas_matrix::HomeserverConnection;
use mas_policy::Policy;
@@ -240,10 +246,6 @@ pub(crate) async fn get(
.lookup_link(link_id)
.map_err(|_| RouteError::MissingCookie)?;
let post_auth_action = OptionalPostAuthAction {
post_auth_action: post_auth_action.cloned(),
};
let link = repo
.upstream_oauth_link()
.lookup(link_id)
@@ -287,6 +289,10 @@ pub(crate) async fn get(
repo.save().await?;
let post_auth_action = OptionalPostAuthAction {
post_auth_action: post_auth_action.cloned(),
};
post_auth_action.go_next(&url_builder).into_response()
}
@@ -359,6 +365,10 @@ pub(crate) async fn get(
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
.await?;
let post_auth_action = OptionalPostAuthAction {
post_auth_action: post_auth_action.cloned(),
};
cookie_jar = sessions_cookie
.consume_link(link_id)?
.save(cookie_jar, &clock);
@@ -388,8 +398,6 @@ pub(crate) async fn get(
.await?
.ok_or(RouteError::ProviderNotFound(link.provider_id))?;
let mut ctx = UpstreamRegister::new(link.clone(), provider.clone());
let env = environment();
let mut context = AttributeMappingContext::new();
@@ -423,13 +431,6 @@ pub(crate) async fn get(
)?
};
if let Some(displayname) = displayname {
ctx = ctx.with_display_name(
displayname,
provider.claims_imports.displayname.is_forced_or_required(),
);
}
let email = if provider.claims_imports.email.ignore() {
None
} else {
@@ -448,13 +449,6 @@ pub(crate) async fn get(
)?
};
if let Some(ref email) = email {
ctx = ctx.with_email(
email.clone(),
provider.claims_imports.email.is_forced_or_required(),
);
}
// We do a bunch of checks for the localpart. Instead of using nested ifs all
// the way, we use a labelled block, and use `break` for 'exiting' early when
// needed
@@ -710,6 +704,10 @@ pub(crate) async fn get(
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
.await?;
let post_auth_action = OptionalPostAuthAction {
post_auth_action: post_auth_action.cloned(),
};
let cookie_jar = sessions_cookie
.consume_link(link_id)?
.save(cookie_jar, &clock)
@@ -767,6 +765,53 @@ pub(crate) async fn get(
Some(localpart)
};
if provider.claims_imports.skip_confirmation {
let Some(localpart) = localpart else {
return Err(RouteError::Internal(
"No localpart available even though the provider is configured to skip confirmation, this is a bug!".into()
));
};
// Register on the fly
REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]);
let registration = prepare_user_registration(
&mut rng,
&clock,
&mut repo,
upstream_session,
localpart,
displayname,
email,
activity_tracker.ip(),
user_agent,
post_auth_action.map(|action| serde_json::json!(action)),
)
.await?;
let registrations = UserRegistrationSessionsCookie::load(&cookie_jar);
let cookie_jar = sessions_cookie
.consume_link(link_id)?
.save(cookie_jar, &clock);
let cookie_jar = registrations.add(&registration).save(cookie_jar, &clock);
repo.save().await?;
// Redirect to the user registration flow, in case we have any other step to
// finish
return Ok((
cookie_jar,
url_builder
.redirect(&mas_router::RegisterFinish::new(registration.id))
.into_response(),
));
}
// Else we show the upstream registration screen
let mut ctx = UpstreamRegister::new(link.clone(), provider.clone());
if let Some(localpart) = localpart {
ctx = ctx.with_localpart(
localpart,
@@ -774,6 +819,17 @@ pub(crate) async fn get(
);
}
if let Some(displayname) = displayname {
ctx = ctx.with_display_name(
displayname,
provider.claims_imports.displayname.is_forced_or_required(),
);
}
if let Some(email) = email {
ctx = ctx.with_email(email, provider.claims_imports.email.is_forced_or_required());
}
let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);
Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response()
@@ -1090,17 +1146,19 @@ pub(crate) async fn post(
REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]);
let mut registration = repo
.user_registration()
.add(
&mut rng,
&clock,
username,
activity_tracker.ip(),
user_agent,
post_auth_action.map(|action| serde_json::json!(action)),
)
.await?;
let mut registration = prepare_user_registration(
&mut rng,
&clock,
&mut repo,
upstream_session,
username,
display_name,
email,
activity_tracker.ip(),
user_agent,
post_auth_action.map(|action| serde_json::json!(action)),
)
.await?;
if let Some(terms_url) = &site_config.tos_uri {
registration = repo
@@ -1109,44 +1167,6 @@ pub(crate) async fn post(
.await?;
}
// If we have an email, add an email authentication and complete it
if let Some(email) = email {
let authentication = repo
.user_email()
.add_authentication_for_registration(&mut rng, &clock, email, &registration)
.await?;
let authentication = repo
.user_email()
.complete_authentication_with_upstream(
&clock,
authentication,
&upstream_session,
)
.await?;
registration = repo
.user_registration()
.set_email_authentication(registration, &authentication)
.await?;
}
// If we have a display name, add it to the registration
if let Some(name) = display_name {
registration = repo
.user_registration()
.set_display_name(registration, name)
.await?;
}
let registration = repo
.user_registration()
.set_upstream_oauth_authorization_session(registration, &upstream_session)
.await?;
repo.upstream_oauth_session()
.consume(&clock, upstream_session)
.await?;
let registrations = UserRegistrationSessionsCookie::load(&cookie_jar);
let cookie_jar = sessions_cookie
@@ -1170,6 +1190,69 @@ pub(crate) async fn post(
}
}
/// Create a user registration using attributes got from the upstream
/// authorization session
async fn prepare_user_registration(
rng: &mut BoxRng,
clock: &BoxClock,
repo: &mut BoxRepository,
upstream_session: UpstreamOAuthAuthorizationSession,
localpart: String,
displayname: Option<String>,
email: Option<String>,
ip_address: Option<IpAddr>,
user_agent: Option<String>,
post_auth_action: Option<serde_json::Value>,
) -> Result<UserRegistration, RouteError> {
let mut registration = repo
.user_registration()
.add(
rng,
clock,
localpart,
ip_address,
user_agent,
post_auth_action,
)
.await?;
// If we have an email, add an email authentication and complete it
if let Some(email) = email {
let authentication = repo
.user_email()
.add_authentication_for_registration(rng, clock, email, &registration)
.await?;
let authentication = repo
.user_email()
.complete_authentication_with_upstream(clock, authentication, &upstream_session)
.await?;
registration = repo
.user_registration()
.set_email_authentication(registration, &authentication)
.await?;
}
// If we have a display name, add it to the registration
if let Some(name) = displayname {
registration = repo
.user_registration()
.set_display_name(registration, name)
.await?;
}
let registration = repo
.user_registration()
.set_upstream_oauth_authorization_session(registration, &upstream_session)
.await?;
repo.upstream_oauth_session()
.consume(clock, upstream_session)
.await?;
Ok(registration)
}
#[cfg(test)]
mod tests {
use hyper::{Request, StatusCode, header::CONTENT_TYPE};
@@ -1386,6 +1469,178 @@ mod tests {
assert!(email_auth.completed_at.is_some());
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_register_skip_confirmation(pool: PgPool) {
// Same test as test_register, but checks that we get straight to the
// registration flow skipping the confirmation
setup();
let state = TestState::from_pool(pool).await.unwrap();
let mut rng = state.rng();
let cookies = CookieHelper::new();
let claims_imports = UpstreamOAuthProviderClaimsImports {
skip_confirmation: true,
localpart: UpstreamOAuthProviderLocalpartPreference {
action: mas_data_model::UpstreamOAuthProviderImportAction::Require,
template: None,
on_conflict: mas_data_model::UpstreamOAuthProviderOnConflict::default(),
},
email: UpstreamOAuthProviderImportPreference {
action: mas_data_model::UpstreamOAuthProviderImportAction::Force,
template: None,
},
..UpstreamOAuthProviderClaimsImports::default()
};
let id_token_claims = serde_json::json!({
"preferred_username": "john",
"email": "john@example.com",
"email_verified": true,
});
// Grab a key to sign the id_token
// We could generate a key on the fly, but because we have one available here,
// why not use it?
let key = state
.key_store
.signing_key_for_algorithm(&JsonWebSignatureAlg::Rs256)
.unwrap();
let signer = key
.params()
.signing_key_for_alg(&JsonWebSignatureAlg::Rs256)
.unwrap();
let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Rs256);
let id_token =
Jwt::sign_with_rng(&mut rng, header, id_token_claims.clone(), &signer).unwrap();
// Provision a provider and a link
let mut repo = state.repository().await.unwrap();
let provider = repo
.upstream_oauth_provider()
.add(
&mut rng,
&state.clock,
UpstreamOAuthProviderParams {
issuer: Some("https://example.com/".to_owned()),
human_name: Some("Example Ltd.".to_owned()),
brand_name: None,
scope: Scope::from_iter([OPENID]),
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
token_endpoint_signing_alg: None,
id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
client_id: "client".to_owned(),
encrypted_client_secret: None,
claims_imports,
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
fetch_userinfo: false,
userinfo_signed_response_alg: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
response_mode: None,
additional_authorization_parameters: Vec::new(),
forward_login_hint: false,
ui_order: 0,
on_backchannel_logout:
mas_data_model::UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
},
)
.await
.unwrap();
let session = repo
.upstream_oauth_session()
.add(
&mut rng,
&state.clock,
&provider,
"state".to_owned(),
None,
None,
)
.await
.unwrap();
let link = repo
.upstream_oauth_link()
.add(
&mut rng,
&state.clock,
&provider,
"subject".to_owned(),
None,
)
.await
.unwrap();
let session = repo
.upstream_oauth_session()
.complete_with_link(
&state.clock,
session,
&link,
Some(id_token.into_string()),
Some(id_token_claims),
None,
None,
)
.await
.unwrap();
repo.save().await.unwrap();
let cookie_jar = state.cookie_jar();
let upstream_sessions = UpstreamSessionsCookie::default()
.add(session.id, provider.id, "state".to_owned(), None)
.add_link_to_session(session.id, link.id)
.unwrap();
let cookie_jar = upstream_sessions.save(cookie_jar, &state.clock);
cookies.import(cookie_jar);
let request = Request::get(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).empty();
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
let location = response.headers().get(hyper::header::LOCATION).unwrap();
// Grab the registration ID from the redirected URL:
// /register/steps/{id}/finish
let registration_id: Ulid = str::from_utf8(location.as_bytes())
.unwrap()
.rsplit('/')
.nth(1)
.expect("Location to have two slashes")
.parse()
.expect("last segment of location to be a ULID");
// Check that we have a registered user, with the email imported
let mut repo = state.repository().await.unwrap();
let registration: UserRegistration = repo
.user_registration()
.lookup(registration_id)
.await
.unwrap()
.expect("user registration exists");
assert_eq!(registration.password, None);
assert_eq!(registration.completed_at, None);
assert_eq!(registration.username, "john");
let email_auth_id = registration
.email_authentication_id
.expect("registration should have an email authentication");
let email_auth: UserEmailAuthentication = repo
.user_email()
.lookup_authentication(email_auth_id)
.await
.unwrap()
.expect("email authentication should exist");
assert_eq!(email_auth.email, "john@example.com");
assert!(email_auth.completed_at.is_some());
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_link_existing_account(pool: PgPool) {
let existing_username = "john";

View File

@@ -2471,6 +2471,10 @@
}
]
},
"skip_confirmation": {
"description": "Whether to skip the interactive screen prompting the user to confirm the\n attributes that are being imported. This requires `localpart.action` to\n be `require` and other attribute actions to be either `ignore`, `force`\n or `require`",
"type": "boolean"
},
"localpart": {
"description": "Import the localpart of the MXID",
"allOf": [
@@ -2488,7 +2492,7 @@
]
},
"email": {
"description": "Import the email address of the user based on the `email` and\n `email_verified` claims",
"description": "Import the email address of the user",
"allOf": [
{
"$ref": "#/definitions/EmailImportPreference"

View File

@@ -771,6 +771,14 @@ upstream_oauth2:
subject:
#template: "{{ user.sub }}"
# By default, new users will see a screen confirming the attributes they
# are about to have on their account.
#
# Setting this to `true` allows skipping this screen, but requires the
# `localpart.action` to be set to `require` and the other attributes
# actions to be set to `ignore`, `force` or `require`.
#skip_confirmation: false
# The localpart is the local part of the user's Matrix ID.
# For example, on the `example.com` server, if the localpart is `alice`,
# the user's Matrix ID will be `@alice:example.com`.