Skip the attributes confirmation screen if configured to do so
This commit is contained in:
@@ -4,7 +4,10 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use std::{
|
||||
net::IpAddr,
|
||||
sync::{Arc, LazyLock},
|
||||
};
|
||||
|
||||
use axum::{
|
||||
Form,
|
||||
@@ -19,7 +22,10 @@ use mas_axum_utils::{
|
||||
csrf::{CsrfExt, ProtectedForm},
|
||||
record_error,
|
||||
};
|
||||
use mas_data_model::{BoxClock, BoxRng, UpstreamOAuthProviderOnConflict};
|
||||
use mas_data_model::{
|
||||
BoxClock, BoxRng, UpstreamOAuthAuthorizationSession, UpstreamOAuthProviderOnConflict,
|
||||
UserRegistration,
|
||||
};
|
||||
use mas_jose::jwt::Jwt;
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use mas_policy::Policy;
|
||||
@@ -238,10 +244,6 @@ pub(crate) async fn get(
|
||||
.lookup_link(link_id)
|
||||
.map_err(|_| RouteError::MissingCookie)?;
|
||||
|
||||
let post_auth_action = OptionalPostAuthAction {
|
||||
post_auth_action: post_auth_action.cloned(),
|
||||
};
|
||||
|
||||
let link = repo
|
||||
.upstream_oauth_link()
|
||||
.lookup(link_id)
|
||||
@@ -285,6 +287,10 @@ pub(crate) async fn get(
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
let post_auth_action = OptionalPostAuthAction {
|
||||
post_auth_action: post_auth_action.cloned(),
|
||||
};
|
||||
|
||||
post_auth_action.go_next(&url_builder).into_response()
|
||||
}
|
||||
|
||||
@@ -357,6 +363,10 @@ pub(crate) async fn get(
|
||||
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
|
||||
.await?;
|
||||
|
||||
let post_auth_action = OptionalPostAuthAction {
|
||||
post_auth_action: post_auth_action.cloned(),
|
||||
};
|
||||
|
||||
cookie_jar = sessions_cookie
|
||||
.consume_link(link_id)?
|
||||
.save(cookie_jar, &clock);
|
||||
@@ -386,8 +396,6 @@ pub(crate) async fn get(
|
||||
.await?
|
||||
.ok_or(RouteError::ProviderNotFound(link.provider_id))?;
|
||||
|
||||
let mut ctx = UpstreamRegister::new(link.clone(), provider.clone());
|
||||
|
||||
let env = environment();
|
||||
|
||||
let mut context = AttributeMappingContext::new();
|
||||
@@ -421,13 +429,6 @@ pub(crate) async fn get(
|
||||
)?
|
||||
};
|
||||
|
||||
if let Some(displayname) = displayname {
|
||||
ctx = ctx.with_display_name(
|
||||
displayname,
|
||||
provider.claims_imports.displayname.is_forced_or_required(),
|
||||
);
|
||||
}
|
||||
|
||||
let email = if provider.claims_imports.email.ignore() {
|
||||
None
|
||||
} else {
|
||||
@@ -446,13 +447,6 @@ pub(crate) async fn get(
|
||||
)?
|
||||
};
|
||||
|
||||
if let Some(ref email) = email {
|
||||
ctx = ctx.with_email(
|
||||
email.clone(),
|
||||
provider.claims_imports.email.is_forced_or_required(),
|
||||
);
|
||||
}
|
||||
|
||||
// We do a bunch of checks for the localpart. Instead of using nested ifs all
|
||||
// the way, we use a labelled block, and use `break` for 'exiting' early when
|
||||
// needed
|
||||
@@ -622,6 +616,10 @@ pub(crate) async fn get(
|
||||
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
|
||||
.await?;
|
||||
|
||||
let post_auth_action = OptionalPostAuthAction {
|
||||
post_auth_action: post_auth_action.cloned(),
|
||||
};
|
||||
|
||||
let cookie_jar = sessions_cookie
|
||||
.consume_link(link_id)?
|
||||
.save(cookie_jar, &clock)
|
||||
@@ -679,6 +677,53 @@ pub(crate) async fn get(
|
||||
Some(localpart)
|
||||
};
|
||||
|
||||
if provider.claims_imports.skip_confirmation {
|
||||
let Some(localpart) = localpart else {
|
||||
return Err(RouteError::Internal(
|
||||
"No localpart available even though the provider is configured to skip confirmation, this is a bug!".into()
|
||||
));
|
||||
};
|
||||
|
||||
// Register on the fly
|
||||
REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]);
|
||||
|
||||
let registration = prepare_user_registration(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&mut repo,
|
||||
upstream_session,
|
||||
localpart,
|
||||
displayname,
|
||||
email,
|
||||
activity_tracker.ip(),
|
||||
user_agent,
|
||||
post_auth_action.map(|action| serde_json::json!(action)),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let registrations = UserRegistrationSessionsCookie::load(&cookie_jar);
|
||||
|
||||
let cookie_jar = sessions_cookie
|
||||
.consume_link(link_id)?
|
||||
.save(cookie_jar, &clock);
|
||||
|
||||
let cookie_jar = registrations.add(®istration).save(cookie_jar, &clock);
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
// Redirect to the user registration flow, in case we have any other step to
|
||||
// finish
|
||||
return Ok((
|
||||
cookie_jar,
|
||||
url_builder
|
||||
.redirect(&mas_router::RegisterFinish::new(registration.id))
|
||||
.into_response(),
|
||||
));
|
||||
}
|
||||
|
||||
// Else we show the upstream registration screen
|
||||
let mut ctx = UpstreamRegister::new(link.clone(), provider.clone());
|
||||
|
||||
if let Some(localpart) = localpart {
|
||||
ctx = ctx.with_localpart(
|
||||
localpart,
|
||||
@@ -686,6 +731,17 @@ pub(crate) async fn get(
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(displayname) = displayname {
|
||||
ctx = ctx.with_display_name(
|
||||
displayname,
|
||||
provider.claims_imports.displayname.is_forced_or_required(),
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(email) = email {
|
||||
ctx = ctx.with_email(email, provider.claims_imports.email.is_forced_or_required());
|
||||
}
|
||||
|
||||
let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);
|
||||
|
||||
Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response()
|
||||
@@ -1002,17 +1058,19 @@ pub(crate) async fn post(
|
||||
|
||||
REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]);
|
||||
|
||||
let mut registration = repo
|
||||
.user_registration()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
username,
|
||||
activity_tracker.ip(),
|
||||
user_agent,
|
||||
post_auth_action.map(|action| serde_json::json!(action)),
|
||||
)
|
||||
.await?;
|
||||
let mut registration = prepare_user_registration(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&mut repo,
|
||||
upstream_session,
|
||||
username,
|
||||
display_name,
|
||||
email,
|
||||
activity_tracker.ip(),
|
||||
user_agent,
|
||||
post_auth_action.map(|action| serde_json::json!(action)),
|
||||
)
|
||||
.await?;
|
||||
|
||||
if let Some(terms_url) = &site_config.tos_uri {
|
||||
registration = repo
|
||||
@@ -1021,44 +1079,6 @@ pub(crate) async fn post(
|
||||
.await?;
|
||||
}
|
||||
|
||||
// If we have an email, add an email authentication and complete it
|
||||
if let Some(email) = email {
|
||||
let authentication = repo
|
||||
.user_email()
|
||||
.add_authentication_for_registration(&mut rng, &clock, email, ®istration)
|
||||
.await?;
|
||||
let authentication = repo
|
||||
.user_email()
|
||||
.complete_authentication_with_upstream(
|
||||
&clock,
|
||||
authentication,
|
||||
&upstream_session,
|
||||
)
|
||||
.await?;
|
||||
|
||||
registration = repo
|
||||
.user_registration()
|
||||
.set_email_authentication(registration, &authentication)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// If we have a display name, add it to the registration
|
||||
if let Some(name) = display_name {
|
||||
registration = repo
|
||||
.user_registration()
|
||||
.set_display_name(registration, name)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let registration = repo
|
||||
.user_registration()
|
||||
.set_upstream_oauth_authorization_session(registration, &upstream_session)
|
||||
.await?;
|
||||
|
||||
repo.upstream_oauth_session()
|
||||
.consume(&clock, upstream_session)
|
||||
.await?;
|
||||
|
||||
let registrations = UserRegistrationSessionsCookie::load(&cookie_jar);
|
||||
|
||||
let cookie_jar = sessions_cookie
|
||||
@@ -1082,6 +1102,69 @@ pub(crate) async fn post(
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a user registration using attributes got from the upstream
|
||||
/// authorization session
|
||||
async fn prepare_user_registration(
|
||||
rng: &mut BoxRng,
|
||||
clock: &BoxClock,
|
||||
repo: &mut BoxRepository,
|
||||
upstream_session: UpstreamOAuthAuthorizationSession,
|
||||
localpart: String,
|
||||
displayname: Option<String>,
|
||||
email: Option<String>,
|
||||
ip_address: Option<IpAddr>,
|
||||
user_agent: Option<String>,
|
||||
post_auth_action: Option<serde_json::Value>,
|
||||
) -> Result<UserRegistration, RouteError> {
|
||||
let mut registration = repo
|
||||
.user_registration()
|
||||
.add(
|
||||
rng,
|
||||
clock,
|
||||
localpart,
|
||||
ip_address,
|
||||
user_agent,
|
||||
post_auth_action,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// If we have an email, add an email authentication and complete it
|
||||
if let Some(email) = email {
|
||||
let authentication = repo
|
||||
.user_email()
|
||||
.add_authentication_for_registration(rng, clock, email, ®istration)
|
||||
.await?;
|
||||
let authentication = repo
|
||||
.user_email()
|
||||
.complete_authentication_with_upstream(clock, authentication, &upstream_session)
|
||||
.await?;
|
||||
|
||||
registration = repo
|
||||
.user_registration()
|
||||
.set_email_authentication(registration, &authentication)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// If we have a display name, add it to the registration
|
||||
if let Some(name) = displayname {
|
||||
registration = repo
|
||||
.user_registration()
|
||||
.set_display_name(registration, name)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let registration = repo
|
||||
.user_registration()
|
||||
.set_upstream_oauth_authorization_session(registration, &upstream_session)
|
||||
.await?;
|
||||
|
||||
repo.upstream_oauth_session()
|
||||
.consume(clock, upstream_session)
|
||||
.await?;
|
||||
|
||||
Ok(registration)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use hyper::{Request, StatusCode, header::CONTENT_TYPE};
|
||||
|
||||
Reference in New Issue
Block a user