Unify registrations for local passwords and upstream OAuth registrations (#5281)

This commit is contained in:
Quentin Gliech
2025-11-27 16:13:03 +01:00
committed by GitHub
15 changed files with 568 additions and 145 deletions

View File

@@ -272,6 +272,7 @@ pub struct UserRegistration {
pub email_authentication_id: Option<Ulid>,
pub user_registration_token_id: Option<Ulid>,
pub password: Option<UserRegistrationPassword>,
pub upstream_oauth_authorization_session_id: Option<Ulid>,
pub post_auth_action: Option<serde_json::Value>,
pub ip_address: Option<IpAddr>,
pub user_agent: Option<String>,

View File

@@ -817,7 +817,7 @@ impl UserEmailMutations {
let authentication = repo
.user_email()
.complete_authentication(&clock, authentication, &code)
.complete_authentication_with_code(&clock, authentication, &code)
.await?;
// Check the email is not already in use by anyone, including the current user

View File

@@ -26,7 +26,6 @@ use mas_policy::Policy;
use mas_router::UrlBuilder;
use mas_storage::{
BoxRepository, RepositoryAccess,
queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
user::{BrowserSessionRepository, UserEmailRepository, UserRepository},
};
@@ -46,7 +45,7 @@ use super::{
};
use crate::{
BoundActivityTracker, METER, PreferredLanguage, SiteConfig, impl_from_error_for_route,
views::shared::OptionalPostAuthAction,
views::{register::UserRegistrationSessionsCookie, shared::OptionalPostAuthAction},
};
static LOGIN_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
@@ -610,10 +609,6 @@ pub(crate) async fn post(
.lookup_link(link_id)
.map_err(|_| RouteError::MissingCookie)?;
let post_auth_action = OptionalPostAuthAction {
post_auth_action: post_auth_action.cloned(),
};
let link = repo
.upstream_oauth_link()
.lookup(link_id)
@@ -641,7 +636,7 @@ pub(crate) async fn post(
let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
let form_state = form.to_form_state();
let session = match (maybe_user_session, link.user_id, form) {
match (maybe_user_session, link.user_id, form) {
(Some(session), None, FormData::Link) => {
// The user is already logged in, the link is not linked to any user, and the
// user asked to link their account.
@@ -649,7 +644,27 @@ pub(crate) async fn post(
.associate_to_user(&link, &session.user)
.await?;
session
let upstream_session = repo
.upstream_oauth_session()
.consume(&clock, upstream_session)
.await?;
repo.browser_session()
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
.await?;
let post_auth_action = OptionalPostAuthAction {
post_auth_action: post_auth_action.cloned(),
};
let cookie_jar = sessions_cookie
.consume_link(link_id)?
.save(cookie_jar, &clock);
let cookie_jar = cookie_jar.set_session(&session);
repo.save().await?;
Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
}
(None, None, FormData::Link) => {
@@ -714,14 +729,38 @@ pub(crate) async fn post(
return Err(RouteError::InvalidFormAction);
}
UpstreamOAuthProviderOnConflict::Add => {
//add link to the user
// Add link to the user
repo.upstream_oauth_link()
.associate_to_user(&link, &user)
.await?;
repo.browser_session()
// And sign in the user
let session = repo
.browser_session()
.add(&mut rng, &clock, &user, user_agent)
.await?
.await?;
let upstream_session = repo
.upstream_oauth_session()
.consume(&clock, upstream_session)
.await?;
repo.browser_session()
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
.await?;
let post_auth_action = OptionalPostAuthAction {
post_auth_action: post_auth_action.cloned(),
};
let cookie_jar = sessions_cookie
.consume_link(link_id)?
.save(cookie_jar, &clock);
let cookie_jar = cookie_jar.set_session(&session);
repo.save().await?;
Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
}
}
}
@@ -950,61 +989,84 @@ pub(crate) async fn post(
REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]);
// Now we can create the user
let user = repo.user().add(&mut rng, &clock, username).await?;
if let Some(terms_url) = &site_config.tos_uri {
repo.user_terms()
.accept_terms(&mut rng, &clock, &user, terms_url.clone())
.await?;
}
// And schedule the job to provision it
let mut job = ProvisionUserJob::new(&user);
// If we have a display name, set it during provisioning
if let Some(name) = display_name {
job = job.set_display_name(name);
}
repo.queue_job().schedule_job(&mut rng, &clock, job).await?;
// If we have an email, add it to the user
if let Some(email) = email {
repo.user_email()
.add(&mut rng, &clock, &user, email)
.await?;
}
repo.upstream_oauth_link()
.associate_to_user(&link, &user)
let mut registration = repo
.user_registration()
.add(
&mut rng,
&clock,
username,
activity_tracker.ip(),
user_agent,
post_auth_action.map(|action| serde_json::json!(action)),
)
.await?;
repo.browser_session()
.add(&mut rng, &clock, &user, user_agent)
.await?
if let Some(terms_url) = &site_config.tos_uri {
registration = repo
.user_registration()
.set_terms_url(registration, terms_url.clone())
.await?;
}
// If we have an email, add an email authentication and complete it
if let Some(email) = email {
let authentication = repo
.user_email()
.add_authentication_for_registration(&mut rng, &clock, email, &registration)
.await?;
let authentication = repo
.user_email()
.complete_authentication_with_upstream(
&clock,
authentication,
&upstream_session,
)
.await?;
registration = repo
.user_registration()
.set_email_authentication(registration, &authentication)
.await?;
}
// If we have a display name, add it to the registration
if let Some(name) = display_name {
registration = repo
.user_registration()
.set_display_name(registration, name)
.await?;
}
let registration = repo
.user_registration()
.set_upstream_oauth_authorization_session(registration, &upstream_session)
.await?;
repo.upstream_oauth_session()
.consume(&clock, upstream_session)
.await?;
let registrations = UserRegistrationSessionsCookie::load(&cookie_jar);
let cookie_jar = sessions_cookie
.consume_link(link_id)?
.save(cookie_jar, &clock);
let cookie_jar = registrations.add(&registration).save(cookie_jar, &clock);
repo.save().await?;
// Redirect to the user registration flow, in case we have any other step to
// finish
Ok((
cookie_jar,
url_builder.redirect(&mas_router::RegisterFinish::new(registration.id)),
)
.into_response())
}
_ => return Err(RouteError::InvalidFormAction),
};
let upstream_session = repo
.upstream_oauth_session()
.consume(&clock, upstream_session)
.await?;
repo.browser_session()
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
.await?;
let cookie_jar = sessions_cookie
.consume_link(link_id)?
.save(cookie_jar, &clock);
let cookie_jar = cookie_jar.set_session(&session);
repo.save().await?;
Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
_ => Err(RouteError::InvalidFormAction),
}
}
#[cfg(test)]
@@ -1013,20 +1075,18 @@ mod tests {
use mas_data_model::{
UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProviderClaimsImports,
UpstreamOAuthProviderImportPreference, UpstreamOAuthProviderLocalpartPreference,
UpstreamOAuthProviderTokenAuthMethod,
UpstreamOAuthProviderTokenAuthMethod, UserEmailAuthentication, UserRegistration,
};
use mas_iana::jose::JsonWebSignatureAlg;
use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
use mas_keystore::Keystore;
use mas_router::Route;
use mas_storage::{
Pagination, Repository, RepositoryError, upstream_oauth2::UpstreamOAuthProviderParams,
user::UserEmailFilter,
};
use mas_storage::{Repository, RepositoryError, upstream_oauth2::UpstreamOAuthProviderParams};
use oauth2_types::scope::{OPENID, Scope};
use rand_chacha::ChaChaRng;
use serde_json::Value;
use sqlx::PgPool;
use ulid::Ulid;
use super::UpstreamSessionsCookie;
use crate::test_utils::{CookieHelper, RequestBuilderExt, ResponseExt, TestState, setup};
@@ -1188,33 +1248,41 @@ mod tests {
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::SEE_OTHER);
let location = response.headers().get(hyper::header::LOCATION).unwrap();
// Grab the registration ID from the redirected URL:
// /register/steps/{id}/finish
let registration_id: Ulid = str::from_utf8(location.as_bytes())
.unwrap()
.rsplit('/')
.nth(1)
.expect("Location to have two slashes")
.parse()
.expect("last segment of location to be a ULID");
// Check that we have a registered user, with the email imported
let mut repo = state.repository().await.unwrap();
let user = repo
.user()
.find_by_username("john")
let registration: UserRegistration = repo
.user_registration()
.lookup(registration_id)
.await
.unwrap()
.expect("user exists");
.expect("user registration exists");
let link = repo
.upstream_oauth_link()
.find_by_subject(&provider, "subject")
.await
.unwrap()
.expect("link exists");
assert_eq!(registration.password, None);
assert_eq!(registration.completed_at, None);
assert_eq!(registration.username, "john");
assert_eq!(link.user_id, Some(user.id));
let page = repo
let email_auth_id = registration
.email_authentication_id
.expect("registration should have an email authentication");
let email_auth: UserEmailAuthentication = repo
.user_email()
.list(UserEmailFilter::new().for_user(&user), Pagination::first(1))
.lookup_authentication(email_auth_id)
.await
.unwrap();
let edge = page.edges.first().expect("email exists");
assert_eq!(edge.node.email, "john@example.com");
.unwrap()
.expect("email authentication should exist");
assert_eq!(email_auth.email, "john@example.com");
assert!(email_auth.completed_at.is_some());
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]

View File

@@ -21,6 +21,8 @@ mod cookie;
pub(crate) mod password;
pub(crate) mod steps;
pub use self::cookie::UserRegistrationSessions as UserRegistrationSessionsCookie;
#[tracing::instrument(name = "handlers.views.register.get", skip_all)]
pub(crate) async fn get(
mut rng: BoxRng,

View File

@@ -154,56 +154,90 @@ pub(crate) async fn get(
// If there is an email authentication, we need to check that the email
// address was verified. If there is no email authentication attached, we
// need to make sure the server doesn't require it
let email_authentication = if let Some(email_authentication_id) =
registration.email_authentication_id
let email_authentication =
if let Some(email_authentication_id) = registration.email_authentication_id {
let email_authentication = repo
.user_email()
.lookup_authentication(email_authentication_id)
.await?
.context("Could not load the email authentication")
.map_err(InternalError::from_anyhow)?;
// Check that the email authentication has been completed
if email_authentication.completed_at.is_none() {
return Ok((
cookie_jar,
url_builder.redirect(&mas_router::RegisterVerifyEmail::new(id)),
)
.into_response());
}
// Check that the email address isn't already used
// It is important to do that here, as we we're not checking during the
// registration, because we don't want to disclose whether an email is
// already being used or not before we verified it
if repo
.user_email()
.count(UserEmailFilter::new().for_email(&email_authentication.email))
.await?
> 0
{
let action = registration
.post_auth_action
.map(serde_json::from_value)
.transpose()?;
let ctx = RegisterStepsEmailInUseContext::new(email_authentication.email, action)
.with_language(lang);
return Ok((
cookie_jar,
Html(templates.render_register_steps_email_in_use(&ctx)?),
)
.into_response());
}
Some(email_authentication)
} else {
None
};
// If this registration was created from an upstream OAuth session, check
// it is still valid and wasn't linked to a user in the meantime
let upstream_oauth = if let Some(upstream_oauth_authorization_session_id) =
registration.upstream_oauth_authorization_session_id
{
let email_authentication = repo
.user_email()
.lookup_authentication(email_authentication_id)
let upstream_oauth_authorization_session = repo
.upstream_oauth_session()
.lookup(upstream_oauth_authorization_session_id)
.await?
.context("Could not load the email authentication")
.context("Could not load the upstream OAuth authorization session")
.map_err(InternalError::from_anyhow)?;
// Check that the email authentication has been completed
if email_authentication.completed_at.is_none() {
return Ok((
cookie_jar,
url_builder.redirect(&mas_router::RegisterVerifyEmail::new(id)),
)
.into_response());
}
let link_id = upstream_oauth_authorization_session
.link_id()
// This should not happen, the session is associated with the user
// registration once the link was already created
.context("Authorization session has no upstream link associated with it")
.map_err(InternalError::from_anyhow)?;
// Check that the email address isn't already used
// It is important to do that here, as we we're not checking during the
// registration, because we don't want to disclose whether an email is
// already being used or not before we verified it
if repo
.user_email()
.count(UserEmailFilter::new().for_email(&email_authentication.email))
let upstream_oauth_link = repo
.upstream_oauth_link()
.lookup(link_id)
.await?
> 0
{
let action = registration
.post_auth_action
.map(serde_json::from_value)
.transpose()?;
.context("Could not load the upstream OAuth link")
.map_err(InternalError::from_anyhow)?;
let ctx = RegisterStepsEmailInUseContext::new(email_authentication.email, action)
.with_language(lang);
return Ok((
cookie_jar,
Html(templates.render_register_steps_email_in_use(&ctx)?),
)
.into_response());
if upstream_oauth_link.user_id.is_some() {
// This means the link was already associated to a user. This could
// in theory happen if the same user registers concurrently, but
// this is not going to happen often enough to have a dedicated page
return Err(InternalError::from_anyhow(anyhow::anyhow!(
"The upstream identity was already linked to a user. Try logging in again"
)));
}
Some(email_authentication)
} else if site_config.password_registration_email_required {
// This could only happen in theory during a configuration change
return Err(InternalError::from_anyhow(anyhow::anyhow!(
"Server requires an email address to complete the registration, but no email authentication was attached to the user registration"
)));
Some((upstream_oauth_authorization_session, upstream_oauth_link))
} else {
None
};
@@ -272,6 +306,16 @@ pub(crate) async fn get(
PASSWORD_REGISTER_COUNTER.add(1, &[]);
}
if let Some((upstream_session, upstream_link)) = upstream_oauth {
repo.upstream_oauth_link()
.associate_to_user(&upstream_link, &user)
.await?;
repo.browser_session()
.authenticate_with_upstream(&mut rng, &clock, &user_session, &upstream_session)
.await?;
}
if let Some(terms_url) = registration.terms_url {
repo.user_terms()
.accept_terms(&mut rng, &clock, &user, terms_url)

View File

@@ -200,7 +200,7 @@ pub(crate) async fn post(
};
repo.user_email()
.complete_authentication(&clock, email_authentication, &code)
.complete_authentication_with_code(&clock, email_authentication, &code)
.await?;
repo.save().await?;

View File

@@ -0,0 +1,15 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE user_registrations\n SET upstream_oauth_authorization_session_id = $2\n WHERE user_registration_id = $1 AND completed_at IS NULL\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid",
"Uuid"
]
},
"nullable": []
},
"hash": "4c37988dacca5a83c8b64209042d5f1a8ec44ec8ccccad2d7fce9ac855209883"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT user_registration_id\n , ip_address as \"ip_address: IpAddr\"\n , user_agent\n , post_auth_action\n , username\n , display_name\n , terms_url\n , email_authentication_id\n , user_registration_token_id\n , hashed_password\n , hashed_password_version\n , created_at\n , completed_at\n FROM user_registrations\n WHERE user_registration_id = $1\n ",
"query": "\n SELECT user_registration_id\n , ip_address as \"ip_address: IpAddr\"\n , user_agent\n , post_auth_action\n , username\n , display_name\n , terms_url\n , email_authentication_id\n , user_registration_token_id\n , hashed_password\n , hashed_password_version\n , upstream_oauth_authorization_session_id\n , created_at\n , completed_at\n FROM user_registrations\n WHERE user_registration_id = $1\n ",
"describe": {
"columns": [
{
@@ -60,11 +60,16 @@
},
{
"ordinal": 11,
"name": "upstream_oauth_authorization_session_id",
"type_info": "Uuid"
},
{
"ordinal": 12,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 12,
"ordinal": 13,
"name": "completed_at",
"type_info": "Timestamptz"
}
@@ -86,9 +91,10 @@
true,
true,
true,
true,
false,
true
]
},
"hash": "5bb3ad7486365e0798e103b072514e66b5b69a347dce91135e158a5eba1d1426"
"hash": "b91cc2458e1a530e7cadbd1ca3e2eaf93e1c44108b6770a24c9a24ac29db37d3"
}

View File

@@ -0,0 +1,10 @@
-- Copyright 2025 Element Creations Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
-- Please see LICENSE in the repository root for full details.
-- Track what upstream OAuth session to associate during user registration
ALTER TABLE user_registrations
ADD COLUMN upstream_oauth_authorization_session_id UUID
REFERENCES upstream_oauth_authorization_sessions (upstream_oauth_authorization_session_id)
ON DELETE SET NULL;

View File

@@ -0,0 +1,9 @@
-- no-transaction
-- Copyright 2025 New Vector Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
-- Please see LICENSE in the repository root for full details.
-- Index on the new foreign key added by the previous migration
CREATE INDEX CONCURRENTLY user_registrations_upstream_oauth_session_id_idx
ON user_registrations (upstream_oauth_authorization_session_id);

View File

@@ -7,8 +7,8 @@
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
BrowserSession, Clock, User, UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode,
UserRegistration,
BrowserSession, Clock, UpstreamOAuthAuthorizationSession, User, UserEmail,
UserEmailAuthentication, UserEmailAuthenticationCode, UserRegistration,
};
use mas_storage::{
Page, Pagination,
@@ -668,7 +668,7 @@ impl UserEmailRepository for PgUserEmailRepository<'_> {
}
#[tracing::instrument(
name = "db.user_email.complete_email_authentication",
name = "db.user_email.complete_email_authentication_with_code",
skip_all,
fields(
db.query.text,
@@ -679,7 +679,7 @@ impl UserEmailRepository for PgUserEmailRepository<'_> {
),
err,
)]
async fn complete_authentication(
async fn complete_authentication_with_code(
&mut self,
clock: &dyn Clock,
mut user_email_authentication: UserEmailAuthentication,
@@ -712,4 +712,49 @@ impl UserEmailRepository for PgUserEmailRepository<'_> {
user_email_authentication.completed_at = Some(completed_at);
Ok(user_email_authentication)
}
#[tracing::instrument(
name = "db.user_email.complete_email_authentication_with_upstream",
skip_all,
fields(
db.query.text,
%user_email_authentication.id,
%user_email_authentication.email,
%upstream_oauth_authorization_session.id,
),
err,
)]
async fn complete_authentication_with_upstream(
&mut self,
clock: &dyn Clock,
mut user_email_authentication: UserEmailAuthentication,
upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
) -> Result<UserEmailAuthentication, Self::Error> {
// We technically don't use the upstream_oauth_authorization_session here (other
// than recording it in the span), but this is to make sure the caller
// has fetched one before calling this
let completed_at = clock.now();
// We'll assume the caller has checked that completed_at is None, so in case
// they haven't, the update will not affect any rows, which will raise
// an error
let res = sqlx::query!(
r#"
UPDATE user_email_authentications
SET completed_at = $2
WHERE user_email_authentication_id = $1
AND completed_at IS NULL
"#,
Uuid::from(user_email_authentication.id),
completed_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
user_email_authentication.completed_at = Some(completed_at);
Ok(user_email_authentication)
}
}

View File

@@ -8,8 +8,8 @@ use std::net::IpAddr;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
Clock, UserEmailAuthentication, UserRegistration, UserRegistrationPassword,
UserRegistrationToken,
Clock, UpstreamOAuthAuthorizationSession, UserEmailAuthentication, UserRegistration,
UserRegistrationPassword, UserRegistrationToken,
};
use mas_storage::user::UserRegistrationRepository;
use rand::RngCore;
@@ -46,6 +46,7 @@ struct UserRegistrationLookup {
user_registration_token_id: Option<Uuid>,
hashed_password: Option<String>,
hashed_password_version: Option<i32>,
upstream_oauth_authorization_session_id: Option<Uuid>,
created_at: DateTime<Utc>,
completed_at: Option<DateTime<Utc>>,
}
@@ -100,6 +101,9 @@ impl TryFrom<UserRegistrationLookup> for UserRegistration {
email_authentication_id: value.email_authentication_id.map(Ulid::from),
user_registration_token_id: value.user_registration_token_id.map(Ulid::from),
password,
upstream_oauth_authorization_session_id: value
.upstream_oauth_authorization_session_id
.map(Ulid::from),
created_at: value.created_at,
completed_at: value.completed_at,
})
@@ -134,6 +138,7 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
, user_registration_token_id
, hashed_password
, hashed_password_version
, upstream_oauth_authorization_session_id
, created_at
, completed_at
FROM user_registrations
@@ -208,6 +213,7 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
email_authentication_id: None,
user_registration_token_id: None,
password: None,
upstream_oauth_authorization_session_id: None,
})
}
@@ -393,6 +399,42 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
Ok(user_registration)
}
#[tracing::instrument(
name = "db.user_registration.set_upstream_oauth_authorization_session",
skip_all,
fields(
db.query.text,
%user_registration.id,
%upstream_oauth_authorization_session.id,
),
err,
)]
async fn set_upstream_oauth_authorization_session(
&mut self,
mut user_registration: UserRegistration,
upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
) -> Result<UserRegistration, Self::Error> {
let res = sqlx::query!(
r#"
UPDATE user_registrations
SET upstream_oauth_authorization_session_id = $2
WHERE user_registration_id = $1 AND completed_at IS NULL
"#,
Uuid::from(user_registration.id),
Uuid::from(upstream_oauth_authorization_session.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
user_registration.upstream_oauth_authorization_session_id =
Some(upstream_oauth_authorization_session.id);
Ok(user_registration)
}
#[tracing::instrument(
name = "db.user_registration.complete",
skip_all,
@@ -433,7 +475,14 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
mod tests {
use std::net::{IpAddr, Ipv4Addr};
use mas_data_model::{Clock, UserRegistrationPassword, clock::MockClock};
use mas_data_model::{
Clock, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
UpstreamOAuthProviderOnBackchannelLogout, UpstreamOAuthProviderPkceMode,
UpstreamOAuthProviderTokenAuthMethod, UserRegistrationPassword, clock::MockClock,
};
use mas_iana::jose::JsonWebSignatureAlg;
use mas_storage::upstream_oauth2::UpstreamOAuthProviderParams;
use oauth2_types::scope::Scope;
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
use sqlx::PgPool;
@@ -851,4 +900,120 @@ mod tests {
.await;
assert!(res.is_err());
}
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_set_upstream_oauth_session(pool: PgPool) {
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
let registration = repo
.user_registration()
.add(&mut rng, &clock, "alice".to_owned(), None, None, None)
.await
.unwrap();
assert_eq!(registration.upstream_oauth_authorization_session_id, None);
let provider = repo
.upstream_oauth_provider()
.add(
&mut rng,
&clock,
UpstreamOAuthProviderParams {
issuer: Some("https://example.com/".to_owned()),
human_name: Some("Example Ltd.".to_owned()),
brand_name: None,
scope: Scope::from_iter([oauth2_types::scope::OPENID]),
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
token_endpoint_signing_alg: None,
id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
client_id: "client".to_owned(),
encrypted_client_secret: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
fetch_userinfo: false,
userinfo_signed_response_alg: None,
jwks_uri_override: None,
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
response_mode: None,
additional_authorization_parameters: Vec::new(),
forward_login_hint: false,
ui_order: 0,
on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
},
)
.await
.unwrap();
let session = repo
.upstream_oauth_session()
.add(&mut rng, &clock, &provider, "state".to_owned(), None, None)
.await
.unwrap();
let registration = repo
.user_registration()
.set_upstream_oauth_authorization_session(registration, &session)
.await
.unwrap();
assert_eq!(
registration.upstream_oauth_authorization_session_id,
Some(session.id)
);
let lookup = repo
.user_registration()
.lookup(registration.id)
.await
.unwrap()
.unwrap();
assert_eq!(
lookup.upstream_oauth_authorization_session_id,
registration.upstream_oauth_authorization_session_id
);
// Setting it again should work
let registration = repo
.user_registration()
.set_upstream_oauth_authorization_session(registration, &session)
.await
.unwrap();
assert_eq!(
registration.upstream_oauth_authorization_session_id,
Some(session.id)
);
let lookup = repo
.user_registration()
.lookup(registration.id)
.await
.unwrap()
.unwrap();
assert_eq!(
lookup.upstream_oauth_authorization_session_id,
registration.upstream_oauth_authorization_session_id
);
// Can't set it once completed
let registration = repo
.user_registration()
.complete(&clock, registration)
.await
.unwrap();
let res = repo
.user_registration()
.set_upstream_oauth_authorization_session(registration, &session)
.await;
assert!(res.is_err());
}
}

View File

@@ -488,7 +488,7 @@ async fn test_user_email_repo_authentications(pool: PgPool) {
// Complete the authentication
let authentication = repo
.user_email()
.complete_authentication(&clock, authentication, &code)
.complete_authentication_with_code(&clock, authentication, &code)
.await
.unwrap();
@@ -514,7 +514,7 @@ async fn test_user_email_repo_authentications(pool: PgPool) {
// Completing a second time should fail
let res = repo
.user_email()
.complete_authentication(&clock, authentication, &code)
.complete_authentication_with_code(&clock, authentication, &code)
.await;
assert!(res.is_err());
}

View File

@@ -6,8 +6,8 @@
use async_trait::async_trait;
use mas_data_model::{
BrowserSession, Clock, User, UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode,
UserRegistration,
BrowserSession, Clock, UpstreamOAuthAuthorizationSession, User, UserEmail,
UserEmailAuthentication, UserEmailAuthenticationCode, UserRegistration,
};
use rand_core::RngCore;
use ulid::Ulid;
@@ -306,12 +306,34 @@ pub trait UserEmailRepository: Send + Sync {
/// # Errors
///
/// Returns an error if the underlying repository fails
async fn complete_authentication(
async fn complete_authentication_with_code(
&mut self,
clock: &dyn Clock,
authentication: UserEmailAuthentication,
code: &UserEmailAuthenticationCode,
) -> Result<UserEmailAuthentication, Self::Error>;
/// Complete a [`UserEmailAuthentication`] by using the given upstream oauth
/// authorization session
///
/// Returns the completed [`UserEmailAuthentication`]
///
/// # Parameters
///
/// * `clock`: The clock to use to generate timestamps
/// * `authentication`: The [`UserEmailAuthentication`] to complete
/// * `upstream_oauth_authorization_session`: The
/// [`UpstreamOAuthAuthorizationSession`] to use
///
/// # Errors
///
/// Returns an error if the underlying repository fails
async fn complete_authentication_with_upstream(
&mut self,
clock: &dyn Clock,
authentication: UserEmailAuthentication,
upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
) -> Result<UserEmailAuthentication, Self::Error>;
}
repository_impl!(UserEmailRepository:
@@ -374,10 +396,17 @@ repository_impl!(UserEmailRepository:
code: &str,
) -> Result<Option<UserEmailAuthenticationCode>, Self::Error>;
async fn complete_authentication(
async fn complete_authentication_with_code(
&mut self,
clock: &dyn Clock,
authentication: UserEmailAuthentication,
code: &UserEmailAuthenticationCode,
) -> Result<UserEmailAuthentication, Self::Error>;
async fn complete_authentication_with_upstream(
&mut self,
clock: &dyn Clock,
authentication: UserEmailAuthentication,
upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
) -> Result<UserEmailAuthentication, Self::Error>;
);

View File

@@ -6,7 +6,10 @@
use std::net::IpAddr;
use async_trait::async_trait;
use mas_data_model::{Clock, UserEmailAuthentication, UserRegistration, UserRegistrationToken};
use mas_data_model::{
Clock, UpstreamOAuthAuthorizationSession, UserEmailAuthentication, UserRegistration,
UserRegistrationToken,
};
use rand_core::RngCore;
use ulid::Ulid;
use url::Url;
@@ -157,6 +160,27 @@ pub trait UserRegistrationRepository: Send + Sync {
user_registration_token: &UserRegistrationToken,
) -> Result<UserRegistration, Self::Error>;
/// Set an [`UpstreamOAuthAuthorizationSession`] to associate with a
/// [`UserRegistration`]
///
/// Returns the updated [`UserRegistration`]
///
/// # Parameters
///
/// * `user_registration`: The [`UserRegistration`] to update
/// * `upstream_oauth_authorization_session`: The
/// [`UpstreamOAuthAuthorizationSession`] to set
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails or if the
/// registration is already completed
async fn set_upstream_oauth_authorization_session(
&mut self,
user_registration: UserRegistration,
upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
) -> Result<UserRegistration, Self::Error>;
/// Complete a [`UserRegistration`]
///
/// Returns the updated [`UserRegistration`]
@@ -214,6 +238,11 @@ repository_impl!(UserRegistrationRepository:
user_registration: UserRegistration,
user_registration_token: &UserRegistrationToken,
) -> Result<UserRegistration, Self::Error>;
async fn set_upstream_oauth_authorization_session(
&mut self,
user_registration: UserRegistration,
upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
) -> Result<UserRegistration, Self::Error>;
async fn complete(
&mut self,
clock: &dyn Clock,