diff --git a/crates/handlers/src/graphql/mutations/user_email.rs b/crates/handlers/src/graphql/mutations/user_email.rs index 63b825566..34fb54050 100644 --- a/crates/handlers/src/graphql/mutations/user_email.rs +++ b/crates/handlers/src/graphql/mutations/user_email.rs @@ -817,7 +817,7 @@ impl UserEmailMutations { let authentication = repo .user_email() - .complete_authentication(&clock, authentication, &code) + .complete_authentication_with_code(&clock, authentication, &code) .await?; // Check the email is not already in use by anyone, including the current user diff --git a/crates/handlers/src/views/register/steps/verify_email.rs b/crates/handlers/src/views/register/steps/verify_email.rs index 9b85626e1..d1312c951 100644 --- a/crates/handlers/src/views/register/steps/verify_email.rs +++ b/crates/handlers/src/views/register/steps/verify_email.rs @@ -200,7 +200,7 @@ pub(crate) async fn post( }; repo.user_email() - .complete_authentication(&clock, email_authentication, &code) + .complete_authentication_with_code(&clock, email_authentication, &code) .await?; repo.save().await?; diff --git a/crates/storage-pg/src/user/email.rs b/crates/storage-pg/src/user/email.rs index 0f998e55f..05122ac7a 100644 --- a/crates/storage-pg/src/user/email.rs +++ b/crates/storage-pg/src/user/email.rs @@ -7,8 +7,8 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::{ - BrowserSession, Clock, User, UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode, - UserRegistration, + BrowserSession, Clock, UpstreamOAuthAuthorizationSession, User, UserEmail, + UserEmailAuthentication, UserEmailAuthenticationCode, UserRegistration, }; use mas_storage::{ Page, Pagination, @@ -668,7 +668,7 @@ impl UserEmailRepository for PgUserEmailRepository<'_> { } #[tracing::instrument( - name = "db.user_email.complete_email_authentication", + name = "db.user_email.complete_email_authentication_with_code", skip_all, fields( db.query.text, @@ -679,7 +679,7 @@ impl UserEmailRepository for PgUserEmailRepository<'_> { ), err, )] - async fn complete_authentication( + async fn complete_authentication_with_code( &mut self, clock: &dyn Clock, mut user_email_authentication: UserEmailAuthentication, @@ -712,4 +712,49 @@ impl UserEmailRepository for PgUserEmailRepository<'_> { user_email_authentication.completed_at = Some(completed_at); Ok(user_email_authentication) } + + #[tracing::instrument( + name = "db.user_email.complete_email_authentication_with_upstream", + skip_all, + fields( + db.query.text, + %user_email_authentication.id, + %user_email_authentication.email, + %upstream_oauth_authorization_session.id, + ), + err, + )] + async fn complete_authentication_with_upstream( + &mut self, + clock: &dyn Clock, + mut user_email_authentication: UserEmailAuthentication, + upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession, + ) -> Result { + // We technically don't use the upstream_oauth_authorization_session here (other + // than recording it in the span), but this is to make sure the caller + // has fetched one before calling this + let completed_at = clock.now(); + + // We'll assume the caller has checked that completed_at is None, so in case + // they haven't, the update will not affect any rows, which will raise + // an error + let res = sqlx::query!( + r#" + UPDATE user_email_authentications + SET completed_at = $2 + WHERE user_email_authentication_id = $1 + AND completed_at IS NULL + "#, + Uuid::from(user_email_authentication.id), + completed_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + user_email_authentication.completed_at = Some(completed_at); + Ok(user_email_authentication) + } } diff --git a/crates/storage-pg/src/user/tests.rs b/crates/storage-pg/src/user/tests.rs index 98489d68d..aa8c9dd07 100644 --- a/crates/storage-pg/src/user/tests.rs +++ b/crates/storage-pg/src/user/tests.rs @@ -488,7 +488,7 @@ async fn test_user_email_repo_authentications(pool: PgPool) { // Complete the authentication let authentication = repo .user_email() - .complete_authentication(&clock, authentication, &code) + .complete_authentication_with_code(&clock, authentication, &code) .await .unwrap(); @@ -514,7 +514,7 @@ async fn test_user_email_repo_authentications(pool: PgPool) { // Completing a second time should fail let res = repo .user_email() - .complete_authentication(&clock, authentication, &code) + .complete_authentication_with_code(&clock, authentication, &code) .await; assert!(res.is_err()); } diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs index 7e973510a..f73414130 100644 --- a/crates/storage/src/user/email.rs +++ b/crates/storage/src/user/email.rs @@ -6,8 +6,8 @@ use async_trait::async_trait; use mas_data_model::{ - BrowserSession, Clock, User, UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode, - UserRegistration, + BrowserSession, Clock, UpstreamOAuthAuthorizationSession, User, UserEmail, + UserEmailAuthentication, UserEmailAuthenticationCode, UserRegistration, }; use rand_core::RngCore; use ulid::Ulid; @@ -306,12 +306,34 @@ pub trait UserEmailRepository: Send + Sync { /// # Errors /// /// Returns an error if the underlying repository fails - async fn complete_authentication( + async fn complete_authentication_with_code( &mut self, clock: &dyn Clock, authentication: UserEmailAuthentication, code: &UserEmailAuthenticationCode, ) -> Result; + + /// Complete a [`UserEmailAuthentication`] by using the given upstream oauth + /// authorization session + /// + /// Returns the completed [`UserEmailAuthentication`] + /// + /// # Parameters + /// + /// * `clock`: The clock to use to generate timestamps + /// * `authentication`: The [`UserEmailAuthentication`] to complete + /// * `upstream_oauth_authorization_session`: The + /// [`UpstreamOAuthAuthorizationSession`] to use + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails + async fn complete_authentication_with_upstream( + &mut self, + clock: &dyn Clock, + authentication: UserEmailAuthentication, + upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession, + ) -> Result; } repository_impl!(UserEmailRepository: @@ -374,10 +396,17 @@ repository_impl!(UserEmailRepository: code: &str, ) -> Result, Self::Error>; - async fn complete_authentication( + async fn complete_authentication_with_code( &mut self, clock: &dyn Clock, authentication: UserEmailAuthentication, code: &UserEmailAuthenticationCode, ) -> Result; + + async fn complete_authentication_with_upstream( + &mut self, + clock: &dyn Clock, + authentication: UserEmailAuthentication, + upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession, + ) -> Result; ); diff --git a/crates/storage/src/user/registration.rs b/crates/storage/src/user/registration.rs index 0d32684d4..77c85b932 100644 --- a/crates/storage/src/user/registration.rs +++ b/crates/storage/src/user/registration.rs @@ -6,7 +6,10 @@ use std::net::IpAddr; use async_trait::async_trait; -use mas_data_model::{Clock, UserEmailAuthentication, UserRegistration, UserRegistrationToken}; +use mas_data_model::{ + Clock, UpstreamOAuthAuthorizationSession, UserEmailAuthentication, UserRegistration, + UserRegistrationToken, +}; use rand_core::RngCore; use ulid::Ulid; use url::Url; @@ -157,6 +160,27 @@ pub trait UserRegistrationRepository: Send + Sync { user_registration_token: &UserRegistrationToken, ) -> Result; + /// Set an [`UpstreamOAuthAuthorizationSession`] to associate with a + /// [`UserRegistration`] + /// + /// Returns the updated [`UserRegistration`] + /// + /// # Parameters + /// + /// * `user_registration`: The [`UserRegistration`] to update + /// * `upstream_oauth_authorization_session`: The + /// [`UpstreamOAuthAuthorizationSession`] to set + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails or if the + /// registration is already completed + async fn set_upstream_oauth_authorization_session( + &mut self, + user_registration: UserRegistration, + upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession, + ) -> Result; + /// Complete a [`UserRegistration`] /// /// Returns the updated [`UserRegistration`] @@ -214,6 +238,11 @@ repository_impl!(UserRegistrationRepository: user_registration: UserRegistration, user_registration_token: &UserRegistrationToken, ) -> Result; + async fn set_upstream_oauth_authorization_session( + &mut self, + user_registration: UserRegistration, + upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession, + ) -> Result; async fn complete( &mut self, clock: &dyn Clock,