Unify registrations for local passwords and upstream OAuth registrations (#5281)
This commit is contained in:
@@ -272,6 +272,7 @@ pub struct UserRegistration {
|
||||
pub email_authentication_id: Option<Ulid>,
|
||||
pub user_registration_token_id: Option<Ulid>,
|
||||
pub password: Option<UserRegistrationPassword>,
|
||||
pub upstream_oauth_authorization_session_id: Option<Ulid>,
|
||||
pub post_auth_action: Option<serde_json::Value>,
|
||||
pub ip_address: Option<IpAddr>,
|
||||
pub user_agent: Option<String>,
|
||||
|
||||
@@ -817,7 +817,7 @@ impl UserEmailMutations {
|
||||
|
||||
let authentication = repo
|
||||
.user_email()
|
||||
.complete_authentication(&clock, authentication, &code)
|
||||
.complete_authentication_with_code(&clock, authentication, &code)
|
||||
.await?;
|
||||
|
||||
// Check the email is not already in use by anyone, including the current user
|
||||
|
||||
@@ -26,7 +26,6 @@ use mas_policy::Policy;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::{
|
||||
BoxRepository, RepositoryAccess,
|
||||
queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
|
||||
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
|
||||
user::{BrowserSessionRepository, UserEmailRepository, UserRepository},
|
||||
};
|
||||
@@ -46,7 +45,7 @@ use super::{
|
||||
};
|
||||
use crate::{
|
||||
BoundActivityTracker, METER, PreferredLanguage, SiteConfig, impl_from_error_for_route,
|
||||
views::shared::OptionalPostAuthAction,
|
||||
views::{register::UserRegistrationSessionsCookie, shared::OptionalPostAuthAction},
|
||||
};
|
||||
|
||||
static LOGIN_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
|
||||
@@ -610,10 +609,6 @@ pub(crate) async fn post(
|
||||
.lookup_link(link_id)
|
||||
.map_err(|_| RouteError::MissingCookie)?;
|
||||
|
||||
let post_auth_action = OptionalPostAuthAction {
|
||||
post_auth_action: post_auth_action.cloned(),
|
||||
};
|
||||
|
||||
let link = repo
|
||||
.upstream_oauth_link()
|
||||
.lookup(link_id)
|
||||
@@ -641,7 +636,7 @@ pub(crate) async fn post(
|
||||
let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
|
||||
let form_state = form.to_form_state();
|
||||
|
||||
let session = match (maybe_user_session, link.user_id, form) {
|
||||
match (maybe_user_session, link.user_id, form) {
|
||||
(Some(session), None, FormData::Link) => {
|
||||
// The user is already logged in, the link is not linked to any user, and the
|
||||
// user asked to link their account.
|
||||
@@ -649,7 +644,27 @@ pub(crate) async fn post(
|
||||
.associate_to_user(&link, &session.user)
|
||||
.await?;
|
||||
|
||||
session
|
||||
let upstream_session = repo
|
||||
.upstream_oauth_session()
|
||||
.consume(&clock, upstream_session)
|
||||
.await?;
|
||||
|
||||
repo.browser_session()
|
||||
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
|
||||
.await?;
|
||||
|
||||
let post_auth_action = OptionalPostAuthAction {
|
||||
post_auth_action: post_auth_action.cloned(),
|
||||
};
|
||||
|
||||
let cookie_jar = sessions_cookie
|
||||
.consume_link(link_id)?
|
||||
.save(cookie_jar, &clock);
|
||||
let cookie_jar = cookie_jar.set_session(&session);
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
|
||||
}
|
||||
|
||||
(None, None, FormData::Link) => {
|
||||
@@ -714,14 +729,38 @@ pub(crate) async fn post(
|
||||
return Err(RouteError::InvalidFormAction);
|
||||
}
|
||||
UpstreamOAuthProviderOnConflict::Add => {
|
||||
//add link to the user
|
||||
// Add link to the user
|
||||
repo.upstream_oauth_link()
|
||||
.associate_to_user(&link, &user)
|
||||
.await?;
|
||||
|
||||
repo.browser_session()
|
||||
// And sign in the user
|
||||
let session = repo
|
||||
.browser_session()
|
||||
.add(&mut rng, &clock, &user, user_agent)
|
||||
.await?
|
||||
.await?;
|
||||
|
||||
let upstream_session = repo
|
||||
.upstream_oauth_session()
|
||||
.consume(&clock, upstream_session)
|
||||
.await?;
|
||||
|
||||
repo.browser_session()
|
||||
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
|
||||
.await?;
|
||||
|
||||
let post_auth_action = OptionalPostAuthAction {
|
||||
post_auth_action: post_auth_action.cloned(),
|
||||
};
|
||||
|
||||
let cookie_jar = sessions_cookie
|
||||
.consume_link(link_id)?
|
||||
.save(cookie_jar, &clock);
|
||||
let cookie_jar = cookie_jar.set_session(&session);
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -950,61 +989,84 @@ pub(crate) async fn post(
|
||||
|
||||
REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]);
|
||||
|
||||
// Now we can create the user
|
||||
let user = repo.user().add(&mut rng, &clock, username).await?;
|
||||
|
||||
if let Some(terms_url) = &site_config.tos_uri {
|
||||
repo.user_terms()
|
||||
.accept_terms(&mut rng, &clock, &user, terms_url.clone())
|
||||
.await?;
|
||||
}
|
||||
|
||||
// And schedule the job to provision it
|
||||
let mut job = ProvisionUserJob::new(&user);
|
||||
|
||||
// If we have a display name, set it during provisioning
|
||||
if let Some(name) = display_name {
|
||||
job = job.set_display_name(name);
|
||||
}
|
||||
|
||||
repo.queue_job().schedule_job(&mut rng, &clock, job).await?;
|
||||
|
||||
// If we have an email, add it to the user
|
||||
if let Some(email) = email {
|
||||
repo.user_email()
|
||||
.add(&mut rng, &clock, &user, email)
|
||||
.await?;
|
||||
}
|
||||
|
||||
repo.upstream_oauth_link()
|
||||
.associate_to_user(&link, &user)
|
||||
let mut registration = repo
|
||||
.user_registration()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
username,
|
||||
activity_tracker.ip(),
|
||||
user_agent,
|
||||
post_auth_action.map(|action| serde_json::json!(action)),
|
||||
)
|
||||
.await?;
|
||||
|
||||
repo.browser_session()
|
||||
.add(&mut rng, &clock, &user, user_agent)
|
||||
.await?
|
||||
if let Some(terms_url) = &site_config.tos_uri {
|
||||
registration = repo
|
||||
.user_registration()
|
||||
.set_terms_url(registration, terms_url.clone())
|
||||
.await?;
|
||||
}
|
||||
|
||||
// If we have an email, add an email authentication and complete it
|
||||
if let Some(email) = email {
|
||||
let authentication = repo
|
||||
.user_email()
|
||||
.add_authentication_for_registration(&mut rng, &clock, email, ®istration)
|
||||
.await?;
|
||||
let authentication = repo
|
||||
.user_email()
|
||||
.complete_authentication_with_upstream(
|
||||
&clock,
|
||||
authentication,
|
||||
&upstream_session,
|
||||
)
|
||||
.await?;
|
||||
|
||||
registration = repo
|
||||
.user_registration()
|
||||
.set_email_authentication(registration, &authentication)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// If we have a display name, add it to the registration
|
||||
if let Some(name) = display_name {
|
||||
registration = repo
|
||||
.user_registration()
|
||||
.set_display_name(registration, name)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let registration = repo
|
||||
.user_registration()
|
||||
.set_upstream_oauth_authorization_session(registration, &upstream_session)
|
||||
.await?;
|
||||
|
||||
repo.upstream_oauth_session()
|
||||
.consume(&clock, upstream_session)
|
||||
.await?;
|
||||
|
||||
let registrations = UserRegistrationSessionsCookie::load(&cookie_jar);
|
||||
|
||||
let cookie_jar = sessions_cookie
|
||||
.consume_link(link_id)?
|
||||
.save(cookie_jar, &clock);
|
||||
|
||||
let cookie_jar = registrations.add(®istration).save(cookie_jar, &clock);
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
// Redirect to the user registration flow, in case we have any other step to
|
||||
// finish
|
||||
Ok((
|
||||
cookie_jar,
|
||||
url_builder.redirect(&mas_router::RegisterFinish::new(registration.id)),
|
||||
)
|
||||
.into_response())
|
||||
}
|
||||
|
||||
_ => return Err(RouteError::InvalidFormAction),
|
||||
};
|
||||
|
||||
let upstream_session = repo
|
||||
.upstream_oauth_session()
|
||||
.consume(&clock, upstream_session)
|
||||
.await?;
|
||||
|
||||
repo.browser_session()
|
||||
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
|
||||
.await?;
|
||||
|
||||
let cookie_jar = sessions_cookie
|
||||
.consume_link(link_id)?
|
||||
.save(cookie_jar, &clock);
|
||||
let cookie_jar = cookie_jar.set_session(&session);
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
|
||||
_ => Err(RouteError::InvalidFormAction),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -1013,20 +1075,18 @@ mod tests {
|
||||
use mas_data_model::{
|
||||
UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProviderClaimsImports,
|
||||
UpstreamOAuthProviderImportPreference, UpstreamOAuthProviderLocalpartPreference,
|
||||
UpstreamOAuthProviderTokenAuthMethod,
|
||||
UpstreamOAuthProviderTokenAuthMethod, UserEmailAuthentication, UserRegistration,
|
||||
};
|
||||
use mas_iana::jose::JsonWebSignatureAlg;
|
||||
use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
|
||||
use mas_keystore::Keystore;
|
||||
use mas_router::Route;
|
||||
use mas_storage::{
|
||||
Pagination, Repository, RepositoryError, upstream_oauth2::UpstreamOAuthProviderParams,
|
||||
user::UserEmailFilter,
|
||||
};
|
||||
use mas_storage::{Repository, RepositoryError, upstream_oauth2::UpstreamOAuthProviderParams};
|
||||
use oauth2_types::scope::{OPENID, Scope};
|
||||
use rand_chacha::ChaChaRng;
|
||||
use serde_json::Value;
|
||||
use sqlx::PgPool;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::UpstreamSessionsCookie;
|
||||
use crate::test_utils::{CookieHelper, RequestBuilderExt, ResponseExt, TestState, setup};
|
||||
@@ -1188,33 +1248,41 @@ mod tests {
|
||||
let response = state.request(request).await;
|
||||
cookies.save_cookies(&response);
|
||||
response.assert_status(StatusCode::SEE_OTHER);
|
||||
let location = response.headers().get(hyper::header::LOCATION).unwrap();
|
||||
// Grab the registration ID from the redirected URL:
|
||||
// /register/steps/{id}/finish
|
||||
let registration_id: Ulid = str::from_utf8(location.as_bytes())
|
||||
.unwrap()
|
||||
.rsplit('/')
|
||||
.nth(1)
|
||||
.expect("Location to have two slashes")
|
||||
.parse()
|
||||
.expect("last segment of location to be a ULID");
|
||||
|
||||
// Check that we have a registered user, with the email imported
|
||||
let mut repo = state.repository().await.unwrap();
|
||||
let user = repo
|
||||
.user()
|
||||
.find_by_username("john")
|
||||
let registration: UserRegistration = repo
|
||||
.user_registration()
|
||||
.lookup(registration_id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("user exists");
|
||||
.expect("user registration exists");
|
||||
|
||||
let link = repo
|
||||
.upstream_oauth_link()
|
||||
.find_by_subject(&provider, "subject")
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("link exists");
|
||||
assert_eq!(registration.password, None);
|
||||
assert_eq!(registration.completed_at, None);
|
||||
assert_eq!(registration.username, "john");
|
||||
|
||||
assert_eq!(link.user_id, Some(user.id));
|
||||
|
||||
let page = repo
|
||||
let email_auth_id = registration
|
||||
.email_authentication_id
|
||||
.expect("registration should have an email authentication");
|
||||
let email_auth: UserEmailAuthentication = repo
|
||||
.user_email()
|
||||
.list(UserEmailFilter::new().for_user(&user), Pagination::first(1))
|
||||
.lookup_authentication(email_auth_id)
|
||||
.await
|
||||
.unwrap();
|
||||
let edge = page.edges.first().expect("email exists");
|
||||
|
||||
assert_eq!(edge.node.email, "john@example.com");
|
||||
.unwrap()
|
||||
.expect("email authentication should exist");
|
||||
assert_eq!(email_auth.email, "john@example.com");
|
||||
assert!(email_auth.completed_at.is_some());
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||
|
||||
@@ -21,6 +21,8 @@ mod cookie;
|
||||
pub(crate) mod password;
|
||||
pub(crate) mod steps;
|
||||
|
||||
pub use self::cookie::UserRegistrationSessions as UserRegistrationSessionsCookie;
|
||||
|
||||
#[tracing::instrument(name = "handlers.views.register.get", skip_all)]
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
|
||||
@@ -154,56 +154,90 @@ pub(crate) async fn get(
|
||||
// If there is an email authentication, we need to check that the email
|
||||
// address was verified. If there is no email authentication attached, we
|
||||
// need to make sure the server doesn't require it
|
||||
let email_authentication = if let Some(email_authentication_id) =
|
||||
registration.email_authentication_id
|
||||
let email_authentication =
|
||||
if let Some(email_authentication_id) = registration.email_authentication_id {
|
||||
let email_authentication = repo
|
||||
.user_email()
|
||||
.lookup_authentication(email_authentication_id)
|
||||
.await?
|
||||
.context("Could not load the email authentication")
|
||||
.map_err(InternalError::from_anyhow)?;
|
||||
|
||||
// Check that the email authentication has been completed
|
||||
if email_authentication.completed_at.is_none() {
|
||||
return Ok((
|
||||
cookie_jar,
|
||||
url_builder.redirect(&mas_router::RegisterVerifyEmail::new(id)),
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
|
||||
// Check that the email address isn't already used
|
||||
// It is important to do that here, as we we're not checking during the
|
||||
// registration, because we don't want to disclose whether an email is
|
||||
// already being used or not before we verified it
|
||||
if repo
|
||||
.user_email()
|
||||
.count(UserEmailFilter::new().for_email(&email_authentication.email))
|
||||
.await?
|
||||
> 0
|
||||
{
|
||||
let action = registration
|
||||
.post_auth_action
|
||||
.map(serde_json::from_value)
|
||||
.transpose()?;
|
||||
|
||||
let ctx = RegisterStepsEmailInUseContext::new(email_authentication.email, action)
|
||||
.with_language(lang);
|
||||
|
||||
return Ok((
|
||||
cookie_jar,
|
||||
Html(templates.render_register_steps_email_in_use(&ctx)?),
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
|
||||
Some(email_authentication)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// If this registration was created from an upstream OAuth session, check
|
||||
// it is still valid and wasn't linked to a user in the meantime
|
||||
let upstream_oauth = if let Some(upstream_oauth_authorization_session_id) =
|
||||
registration.upstream_oauth_authorization_session_id
|
||||
{
|
||||
let email_authentication = repo
|
||||
.user_email()
|
||||
.lookup_authentication(email_authentication_id)
|
||||
let upstream_oauth_authorization_session = repo
|
||||
.upstream_oauth_session()
|
||||
.lookup(upstream_oauth_authorization_session_id)
|
||||
.await?
|
||||
.context("Could not load the email authentication")
|
||||
.context("Could not load the upstream OAuth authorization session")
|
||||
.map_err(InternalError::from_anyhow)?;
|
||||
|
||||
// Check that the email authentication has been completed
|
||||
if email_authentication.completed_at.is_none() {
|
||||
return Ok((
|
||||
cookie_jar,
|
||||
url_builder.redirect(&mas_router::RegisterVerifyEmail::new(id)),
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
let link_id = upstream_oauth_authorization_session
|
||||
.link_id()
|
||||
// This should not happen, the session is associated with the user
|
||||
// registration once the link was already created
|
||||
.context("Authorization session has no upstream link associated with it")
|
||||
.map_err(InternalError::from_anyhow)?;
|
||||
|
||||
// Check that the email address isn't already used
|
||||
// It is important to do that here, as we we're not checking during the
|
||||
// registration, because we don't want to disclose whether an email is
|
||||
// already being used or not before we verified it
|
||||
if repo
|
||||
.user_email()
|
||||
.count(UserEmailFilter::new().for_email(&email_authentication.email))
|
||||
let upstream_oauth_link = repo
|
||||
.upstream_oauth_link()
|
||||
.lookup(link_id)
|
||||
.await?
|
||||
> 0
|
||||
{
|
||||
let action = registration
|
||||
.post_auth_action
|
||||
.map(serde_json::from_value)
|
||||
.transpose()?;
|
||||
.context("Could not load the upstream OAuth link")
|
||||
.map_err(InternalError::from_anyhow)?;
|
||||
|
||||
let ctx = RegisterStepsEmailInUseContext::new(email_authentication.email, action)
|
||||
.with_language(lang);
|
||||
|
||||
return Ok((
|
||||
cookie_jar,
|
||||
Html(templates.render_register_steps_email_in_use(&ctx)?),
|
||||
)
|
||||
.into_response());
|
||||
if upstream_oauth_link.user_id.is_some() {
|
||||
// This means the link was already associated to a user. This could
|
||||
// in theory happen if the same user registers concurrently, but
|
||||
// this is not going to happen often enough to have a dedicated page
|
||||
return Err(InternalError::from_anyhow(anyhow::anyhow!(
|
||||
"The upstream identity was already linked to a user. Try logging in again"
|
||||
)));
|
||||
}
|
||||
|
||||
Some(email_authentication)
|
||||
} else if site_config.password_registration_email_required {
|
||||
// This could only happen in theory during a configuration change
|
||||
return Err(InternalError::from_anyhow(anyhow::anyhow!(
|
||||
"Server requires an email address to complete the registration, but no email authentication was attached to the user registration"
|
||||
)));
|
||||
Some((upstream_oauth_authorization_session, upstream_oauth_link))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -272,6 +306,16 @@ pub(crate) async fn get(
|
||||
PASSWORD_REGISTER_COUNTER.add(1, &[]);
|
||||
}
|
||||
|
||||
if let Some((upstream_session, upstream_link)) = upstream_oauth {
|
||||
repo.upstream_oauth_link()
|
||||
.associate_to_user(&upstream_link, &user)
|
||||
.await?;
|
||||
|
||||
repo.browser_session()
|
||||
.authenticate_with_upstream(&mut rng, &clock, &user_session, &upstream_session)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if let Some(terms_url) = registration.terms_url {
|
||||
repo.user_terms()
|
||||
.accept_terms(&mut rng, &clock, &user, terms_url)
|
||||
|
||||
@@ -200,7 +200,7 @@ pub(crate) async fn post(
|
||||
};
|
||||
|
||||
repo.user_email()
|
||||
.complete_authentication(&clock, email_authentication, &code)
|
||||
.complete_authentication_with_code(&clock, email_authentication, &code)
|
||||
.await?;
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
15
crates/storage-pg/.sqlx/query-4c37988dacca5a83c8b64209042d5f1a8ec44ec8ccccad2d7fce9ac855209883.json
generated
Normal file
15
crates/storage-pg/.sqlx/query-4c37988dacca5a83c8b64209042d5f1a8ec44ec8ccccad2d7fce9ac855209883.json
generated
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE user_registrations\n SET upstream_oauth_authorization_session_id = $2\n WHERE user_registration_id = $1 AND completed_at IS NULL\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "4c37988dacca5a83c8b64209042d5f1a8ec44ec8ccccad2d7fce9ac855209883"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT user_registration_id\n , ip_address as \"ip_address: IpAddr\"\n , user_agent\n , post_auth_action\n , username\n , display_name\n , terms_url\n , email_authentication_id\n , user_registration_token_id\n , hashed_password\n , hashed_password_version\n , created_at\n , completed_at\n FROM user_registrations\n WHERE user_registration_id = $1\n ",
|
||||
"query": "\n SELECT user_registration_id\n , ip_address as \"ip_address: IpAddr\"\n , user_agent\n , post_auth_action\n , username\n , display_name\n , terms_url\n , email_authentication_id\n , user_registration_token_id\n , hashed_password\n , hashed_password_version\n , upstream_oauth_authorization_session_id\n , created_at\n , completed_at\n FROM user_registrations\n WHERE user_registration_id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -60,11 +60,16 @@
|
||||
},
|
||||
{
|
||||
"ordinal": 11,
|
||||
"name": "upstream_oauth_authorization_session_id",
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"ordinal": 12,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 12,
|
||||
"ordinal": 13,
|
||||
"name": "completed_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
@@ -86,9 +91,10 @@
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "5bb3ad7486365e0798e103b072514e66b5b69a347dce91135e158a5eba1d1426"
|
||||
"hash": "b91cc2458e1a530e7cadbd1ca3e2eaf93e1c44108b6770a24c9a24ac29db37d3"
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
-- Copyright 2025 Element Creations Ltd.
|
||||
--
|
||||
-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
-- Please see LICENSE in the repository root for full details.
|
||||
|
||||
-- Track what upstream OAuth session to associate during user registration
|
||||
ALTER TABLE user_registrations
|
||||
ADD COLUMN upstream_oauth_authorization_session_id UUID
|
||||
REFERENCES upstream_oauth_authorization_sessions (upstream_oauth_authorization_session_id)
|
||||
ON DELETE SET NULL;
|
||||
@@ -0,0 +1,9 @@
|
||||
-- no-transaction
|
||||
-- Copyright 2025 New Vector Ltd.
|
||||
--
|
||||
-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
-- Please see LICENSE in the repository root for full details.
|
||||
|
||||
-- Index on the new foreign key added by the previous migration
|
||||
CREATE INDEX CONCURRENTLY user_registrations_upstream_oauth_session_id_idx
|
||||
ON user_registrations (upstream_oauth_authorization_session_id);
|
||||
@@ -7,8 +7,8 @@
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{
|
||||
BrowserSession, Clock, User, UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode,
|
||||
UserRegistration,
|
||||
BrowserSession, Clock, UpstreamOAuthAuthorizationSession, User, UserEmail,
|
||||
UserEmailAuthentication, UserEmailAuthenticationCode, UserRegistration,
|
||||
};
|
||||
use mas_storage::{
|
||||
Page, Pagination,
|
||||
@@ -668,7 +668,7 @@ impl UserEmailRepository for PgUserEmailRepository<'_> {
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user_email.complete_email_authentication",
|
||||
name = "db.user_email.complete_email_authentication_with_code",
|
||||
skip_all,
|
||||
fields(
|
||||
db.query.text,
|
||||
@@ -679,7 +679,7 @@ impl UserEmailRepository for PgUserEmailRepository<'_> {
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn complete_authentication(
|
||||
async fn complete_authentication_with_code(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
mut user_email_authentication: UserEmailAuthentication,
|
||||
@@ -712,4 +712,49 @@ impl UserEmailRepository for PgUserEmailRepository<'_> {
|
||||
user_email_authentication.completed_at = Some(completed_at);
|
||||
Ok(user_email_authentication)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user_email.complete_email_authentication_with_upstream",
|
||||
skip_all,
|
||||
fields(
|
||||
db.query.text,
|
||||
%user_email_authentication.id,
|
||||
%user_email_authentication.email,
|
||||
%upstream_oauth_authorization_session.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn complete_authentication_with_upstream(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
mut user_email_authentication: UserEmailAuthentication,
|
||||
upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
|
||||
) -> Result<UserEmailAuthentication, Self::Error> {
|
||||
// We technically don't use the upstream_oauth_authorization_session here (other
|
||||
// than recording it in the span), but this is to make sure the caller
|
||||
// has fetched one before calling this
|
||||
let completed_at = clock.now();
|
||||
|
||||
// We'll assume the caller has checked that completed_at is None, so in case
|
||||
// they haven't, the update will not affect any rows, which will raise
|
||||
// an error
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE user_email_authentications
|
||||
SET completed_at = $2
|
||||
WHERE user_email_authentication_id = $1
|
||||
AND completed_at IS NULL
|
||||
"#,
|
||||
Uuid::from(user_email_authentication.id),
|
||||
completed_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
user_email_authentication.completed_at = Some(completed_at);
|
||||
Ok(user_email_authentication)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ use std::net::IpAddr;
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{
|
||||
Clock, UserEmailAuthentication, UserRegistration, UserRegistrationPassword,
|
||||
UserRegistrationToken,
|
||||
Clock, UpstreamOAuthAuthorizationSession, UserEmailAuthentication, UserRegistration,
|
||||
UserRegistrationPassword, UserRegistrationToken,
|
||||
};
|
||||
use mas_storage::user::UserRegistrationRepository;
|
||||
use rand::RngCore;
|
||||
@@ -46,6 +46,7 @@ struct UserRegistrationLookup {
|
||||
user_registration_token_id: Option<Uuid>,
|
||||
hashed_password: Option<String>,
|
||||
hashed_password_version: Option<i32>,
|
||||
upstream_oauth_authorization_session_id: Option<Uuid>,
|
||||
created_at: DateTime<Utc>,
|
||||
completed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
@@ -100,6 +101,9 @@ impl TryFrom<UserRegistrationLookup> for UserRegistration {
|
||||
email_authentication_id: value.email_authentication_id.map(Ulid::from),
|
||||
user_registration_token_id: value.user_registration_token_id.map(Ulid::from),
|
||||
password,
|
||||
upstream_oauth_authorization_session_id: value
|
||||
.upstream_oauth_authorization_session_id
|
||||
.map(Ulid::from),
|
||||
created_at: value.created_at,
|
||||
completed_at: value.completed_at,
|
||||
})
|
||||
@@ -134,6 +138,7 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
|
||||
, user_registration_token_id
|
||||
, hashed_password
|
||||
, hashed_password_version
|
||||
, upstream_oauth_authorization_session_id
|
||||
, created_at
|
||||
, completed_at
|
||||
FROM user_registrations
|
||||
@@ -208,6 +213,7 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
|
||||
email_authentication_id: None,
|
||||
user_registration_token_id: None,
|
||||
password: None,
|
||||
upstream_oauth_authorization_session_id: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -393,6 +399,42 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
|
||||
Ok(user_registration)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user_registration.set_upstream_oauth_authorization_session",
|
||||
skip_all,
|
||||
fields(
|
||||
db.query.text,
|
||||
%user_registration.id,
|
||||
%upstream_oauth_authorization_session.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn set_upstream_oauth_authorization_session(
|
||||
&mut self,
|
||||
mut user_registration: UserRegistration,
|
||||
upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
|
||||
) -> Result<UserRegistration, Self::Error> {
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE user_registrations
|
||||
SET upstream_oauth_authorization_session_id = $2
|
||||
WHERE user_registration_id = $1 AND completed_at IS NULL
|
||||
"#,
|
||||
Uuid::from(user_registration.id),
|
||||
Uuid::from(upstream_oauth_authorization_session.id),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
user_registration.upstream_oauth_authorization_session_id =
|
||||
Some(upstream_oauth_authorization_session.id);
|
||||
|
||||
Ok(user_registration)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user_registration.complete",
|
||||
skip_all,
|
||||
@@ -433,7 +475,14 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
|
||||
mod tests {
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
|
||||
use mas_data_model::{Clock, UserRegistrationPassword, clock::MockClock};
|
||||
use mas_data_model::{
|
||||
Clock, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
|
||||
UpstreamOAuthProviderOnBackchannelLogout, UpstreamOAuthProviderPkceMode,
|
||||
UpstreamOAuthProviderTokenAuthMethod, UserRegistrationPassword, clock::MockClock,
|
||||
};
|
||||
use mas_iana::jose::JsonWebSignatureAlg;
|
||||
use mas_storage::upstream_oauth2::UpstreamOAuthProviderParams;
|
||||
use oauth2_types::scope::Scope;
|
||||
use rand::SeedableRng;
|
||||
use rand_chacha::ChaChaRng;
|
||||
use sqlx::PgPool;
|
||||
@@ -851,4 +900,120 @@ mod tests {
|
||||
.await;
|
||||
assert!(res.is_err());
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_set_upstream_oauth_session(pool: PgPool) {
|
||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||
let clock = MockClock::default();
|
||||
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
|
||||
|
||||
let registration = repo
|
||||
.user_registration()
|
||||
.add(&mut rng, &clock, "alice".to_owned(), None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(registration.upstream_oauth_authorization_session_id, None);
|
||||
|
||||
let provider = repo
|
||||
.upstream_oauth_provider()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
UpstreamOAuthProviderParams {
|
||||
issuer: Some("https://example.com/".to_owned()),
|
||||
human_name: Some("Example Ltd.".to_owned()),
|
||||
brand_name: None,
|
||||
scope: Scope::from_iter([oauth2_types::scope::OPENID]),
|
||||
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
|
||||
token_endpoint_signing_alg: None,
|
||||
id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
|
||||
client_id: "client".to_owned(),
|
||||
encrypted_client_secret: None,
|
||||
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
|
||||
authorization_endpoint_override: None,
|
||||
token_endpoint_override: None,
|
||||
userinfo_endpoint_override: None,
|
||||
fetch_userinfo: false,
|
||||
userinfo_signed_response_alg: None,
|
||||
jwks_uri_override: None,
|
||||
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
|
||||
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
|
||||
response_mode: None,
|
||||
additional_authorization_parameters: Vec::new(),
|
||||
forward_login_hint: false,
|
||||
ui_order: 0,
|
||||
on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.add(&mut rng, &clock, &provider, "state".to_owned(), None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let registration = repo
|
||||
.user_registration()
|
||||
.set_upstream_oauth_authorization_session(registration, &session)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
registration.upstream_oauth_authorization_session_id,
|
||||
Some(session.id)
|
||||
);
|
||||
|
||||
let lookup = repo
|
||||
.user_registration()
|
||||
.lookup(registration.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
lookup.upstream_oauth_authorization_session_id,
|
||||
registration.upstream_oauth_authorization_session_id
|
||||
);
|
||||
|
||||
// Setting it again should work
|
||||
let registration = repo
|
||||
.user_registration()
|
||||
.set_upstream_oauth_authorization_session(registration, &session)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
registration.upstream_oauth_authorization_session_id,
|
||||
Some(session.id)
|
||||
);
|
||||
|
||||
let lookup = repo
|
||||
.user_registration()
|
||||
.lookup(registration.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
lookup.upstream_oauth_authorization_session_id,
|
||||
registration.upstream_oauth_authorization_session_id
|
||||
);
|
||||
|
||||
// Can't set it once completed
|
||||
let registration = repo
|
||||
.user_registration()
|
||||
.complete(&clock, registration)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let res = repo
|
||||
.user_registration()
|
||||
.set_upstream_oauth_authorization_session(registration, &session)
|
||||
.await;
|
||||
assert!(res.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -488,7 +488,7 @@ async fn test_user_email_repo_authentications(pool: PgPool) {
|
||||
// Complete the authentication
|
||||
let authentication = repo
|
||||
.user_email()
|
||||
.complete_authentication(&clock, authentication, &code)
|
||||
.complete_authentication_with_code(&clock, authentication, &code)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -514,7 +514,7 @@ async fn test_user_email_repo_authentications(pool: PgPool) {
|
||||
// Completing a second time should fail
|
||||
let res = repo
|
||||
.user_email()
|
||||
.complete_authentication(&clock, authentication, &code)
|
||||
.complete_authentication_with_code(&clock, authentication, &code)
|
||||
.await;
|
||||
assert!(res.is_err());
|
||||
}
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mas_data_model::{
|
||||
BrowserSession, Clock, User, UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode,
|
||||
UserRegistration,
|
||||
BrowserSession, Clock, UpstreamOAuthAuthorizationSession, User, UserEmail,
|
||||
UserEmailAuthentication, UserEmailAuthenticationCode, UserRegistration,
|
||||
};
|
||||
use rand_core::RngCore;
|
||||
use ulid::Ulid;
|
||||
@@ -306,12 +306,34 @@ pub trait UserEmailRepository: Send + Sync {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the underlying repository fails
|
||||
async fn complete_authentication(
|
||||
async fn complete_authentication_with_code(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
authentication: UserEmailAuthentication,
|
||||
code: &UserEmailAuthenticationCode,
|
||||
) -> Result<UserEmailAuthentication, Self::Error>;
|
||||
|
||||
/// Complete a [`UserEmailAuthentication`] by using the given upstream oauth
|
||||
/// authorization session
|
||||
///
|
||||
/// Returns the completed [`UserEmailAuthentication`]
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `clock`: The clock to use to generate timestamps
|
||||
/// * `authentication`: The [`UserEmailAuthentication`] to complete
|
||||
/// * `upstream_oauth_authorization_session`: The
|
||||
/// [`UpstreamOAuthAuthorizationSession`] to use
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the underlying repository fails
|
||||
async fn complete_authentication_with_upstream(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
authentication: UserEmailAuthentication,
|
||||
upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
|
||||
) -> Result<UserEmailAuthentication, Self::Error>;
|
||||
}
|
||||
|
||||
repository_impl!(UserEmailRepository:
|
||||
@@ -374,10 +396,17 @@ repository_impl!(UserEmailRepository:
|
||||
code: &str,
|
||||
) -> Result<Option<UserEmailAuthenticationCode>, Self::Error>;
|
||||
|
||||
async fn complete_authentication(
|
||||
async fn complete_authentication_with_code(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
authentication: UserEmailAuthentication,
|
||||
code: &UserEmailAuthenticationCode,
|
||||
) -> Result<UserEmailAuthentication, Self::Error>;
|
||||
|
||||
async fn complete_authentication_with_upstream(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
authentication: UserEmailAuthentication,
|
||||
upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
|
||||
) -> Result<UserEmailAuthentication, Self::Error>;
|
||||
);
|
||||
|
||||
@@ -6,7 +6,10 @@
|
||||
use std::net::IpAddr;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mas_data_model::{Clock, UserEmailAuthentication, UserRegistration, UserRegistrationToken};
|
||||
use mas_data_model::{
|
||||
Clock, UpstreamOAuthAuthorizationSession, UserEmailAuthentication, UserRegistration,
|
||||
UserRegistrationToken,
|
||||
};
|
||||
use rand_core::RngCore;
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
@@ -157,6 +160,27 @@ pub trait UserRegistrationRepository: Send + Sync {
|
||||
user_registration_token: &UserRegistrationToken,
|
||||
) -> Result<UserRegistration, Self::Error>;
|
||||
|
||||
/// Set an [`UpstreamOAuthAuthorizationSession`] to associate with a
|
||||
/// [`UserRegistration`]
|
||||
///
|
||||
/// Returns the updated [`UserRegistration`]
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `user_registration`: The [`UserRegistration`] to update
|
||||
/// * `upstream_oauth_authorization_session`: The
|
||||
/// [`UpstreamOAuthAuthorizationSession`] to set
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`Self::Error`] if the underlying repository fails or if the
|
||||
/// registration is already completed
|
||||
async fn set_upstream_oauth_authorization_session(
|
||||
&mut self,
|
||||
user_registration: UserRegistration,
|
||||
upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
|
||||
) -> Result<UserRegistration, Self::Error>;
|
||||
|
||||
/// Complete a [`UserRegistration`]
|
||||
///
|
||||
/// Returns the updated [`UserRegistration`]
|
||||
@@ -214,6 +238,11 @@ repository_impl!(UserRegistrationRepository:
|
||||
user_registration: UserRegistration,
|
||||
user_registration_token: &UserRegistrationToken,
|
||||
) -> Result<UserRegistration, Self::Error>;
|
||||
async fn set_upstream_oauth_authorization_session(
|
||||
&mut self,
|
||||
user_registration: UserRegistration,
|
||||
upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
|
||||
) -> Result<UserRegistration, Self::Error>;
|
||||
async fn complete(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
|
||||
Reference in New Issue
Block a user