diff --git a/crates/storage-pg/src/iden.rs b/crates/storage-pg/src/iden.rs index 76067b2fa..6692c7a75 100644 --- a/crates/storage-pg/src/iden.rs +++ b/crates/storage-pg/src/iden.rs @@ -1,4 +1,4 @@ -// Copyright 2024 New Vector Ltd. +// Copyright 2024, 2025 New Vector Ltd. // Copyright 2023, 2024 The Matrix.org Foundation C.I.C. // // SPDX-License-Identifier: AGPL-3.0-only @@ -139,3 +139,16 @@ pub enum UpstreamOAuthLinks { HumanAccountName, CreatedAt, } + +#[derive(sea_query::Iden)] +pub enum UserRegistrationTokens { + Table, + UserRegistrationTokenId, + Token, + UsageLimit, + TimesUsed, + CreatedAt, + LastUsedAt, + ExpiresAt, + RevokedAt, +} diff --git a/crates/storage-pg/src/user/registration_token.rs b/crates/storage-pg/src/user/registration_token.rs index 2f71fc83c..02b03038d 100644 --- a/crates/storage-pg/src/user/registration_token.rs +++ b/crates/storage-pg/src/user/registration_token.rs @@ -6,12 +6,25 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::UserRegistrationToken; -use mas_storage::{Clock, user::UserRegistrationTokenRepository}; +use mas_storage::{ + Clock, Page, Pagination, + user::{UserRegistrationTokenFilter, UserRegistrationTokenRepository}, +}; use rand::RngCore; -use sqlx::{PgConnection, types::Uuid}; +use sea_query::{Condition, Expr, PostgresQueryBuilder, Query, enum_def}; +use sea_query_binder::SqlxBinder; +use sqlx::PgConnection; use ulid::Ulid; +use uuid::Uuid; -use crate::{DatabaseInconsistencyError, errors::DatabaseError, tracing::ExecuteExt}; +use crate::{ + DatabaseInconsistencyError, + errors::DatabaseError, + filter::{Filter, StatementExt}, + iden::UserRegistrationTokens, + pagination::QueryBuilderExt, + tracing::ExecuteExt, +}; /// An implementation of [`mas_storage::user::UserRegistrationTokenRepository`] /// for a PostgreSQL connection @@ -27,6 +40,8 @@ impl<'c> PgUserRegistrationTokenRepository<'c> { } } +#[derive(Debug, Clone, sqlx::FromRow)] +#[enum_def] struct UserRegistrationTokenLookup { user_registration_token_id: Uuid, token: String, @@ -38,6 +53,130 @@ struct UserRegistrationTokenLookup { revoked_at: Option>, } +impl Filter for UserRegistrationTokenFilter { + #[expect(clippy::too_many_lines)] + fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition { + sea_query::Condition::all() + .add_option(self.has_been_used().map(|has_been_used| { + if has_been_used { + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::TimesUsed, + )) + .gt(0) + } else { + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::TimesUsed, + )) + .eq(0) + } + })) + .add_option(self.is_revoked().map(|is_revoked| { + if is_revoked { + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::RevokedAt, + )) + .is_not_null() + } else { + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::RevokedAt, + )) + .is_null() + } + })) + .add_option(self.is_expired().map(|is_expired| { + if is_expired { + Condition::all() + .add( + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::ExpiresAt, + )) + .is_not_null(), + ) + .add( + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::ExpiresAt, + )) + .lt(Expr::val(self.now())), + ) + } else { + Condition::any() + .add( + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::ExpiresAt, + )) + .is_null(), + ) + .add( + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::ExpiresAt, + )) + .gte(Expr::val(self.now())), + ) + } + })) + .add_option(self.is_valid().map(|is_valid| { + let valid = Condition::all() + // Has not reached its usage limit + .add( + Condition::any() + .add( + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::UsageLimit, + )) + .is_null(), + ) + .add( + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::TimesUsed, + )) + .lt(Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::UsageLimit, + ))), + ), + ) + // Has not been revoked + .add( + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::RevokedAt, + )) + .is_null(), + ) + // Has not expired + .add( + Condition::any() + .add( + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::ExpiresAt, + )) + .is_null(), + ) + .add( + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::ExpiresAt, + )) + .gte(Expr::val(self.now())), + ), + ); + + if is_valid { valid } else { valid.not() } + })) + } +} + impl TryFrom for UserRegistrationToken { type Error = DatabaseInconsistencyError; @@ -79,6 +218,129 @@ impl TryFrom for UserRegistrationToken { impl UserRegistrationTokenRepository for PgUserRegistrationTokenRepository<'_> { type Error = DatabaseError; + #[tracing::instrument( + name = "db.user_registration_token.list", + skip_all, + fields( + db.query.text, + ), + err, + )] + async fn list( + &mut self, + filter: UserRegistrationTokenFilter, + pagination: Pagination, + ) -> Result, Self::Error> { + let (sql, values) = Query::select() + .expr_as( + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::UserRegistrationTokenId, + )), + UserRegistrationTokenLookupIden::UserRegistrationTokenId, + ) + .expr_as( + Expr::col((UserRegistrationTokens::Table, UserRegistrationTokens::Token)), + UserRegistrationTokenLookupIden::Token, + ) + .expr_as( + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::UsageLimit, + )), + UserRegistrationTokenLookupIden::UsageLimit, + ) + .expr_as( + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::TimesUsed, + )), + UserRegistrationTokenLookupIden::TimesUsed, + ) + .expr_as( + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::CreatedAt, + )), + UserRegistrationTokenLookupIden::CreatedAt, + ) + .expr_as( + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::LastUsedAt, + )), + UserRegistrationTokenLookupIden::LastUsedAt, + ) + .expr_as( + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::ExpiresAt, + )), + UserRegistrationTokenLookupIden::ExpiresAt, + ) + .expr_as( + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::RevokedAt, + )), + UserRegistrationTokenLookupIden::RevokedAt, + ) + .from(UserRegistrationTokens::Table) + .apply_filter(filter) + .generate_pagination( + ( + UserRegistrationTokens::Table, + UserRegistrationTokens::UserRegistrationTokenId, + ), + pagination, + ) + .build_sqlx(PostgresQueryBuilder); + + let tokens = sqlx::query_as_with::<_, UserRegistrationTokenLookup, _>(&sql, values) + .traced() + .fetch_all(&mut *self.conn) + .await? + .into_iter() + .map(TryInto::try_into) + .collect::, _>>()?; + + let page = pagination.process(tokens); + + Ok(page) + } + + #[tracing::instrument( + name = "db.user_registration_token.count", + skip_all, + fields( + db.query.text, + user_registration_token.filter = ?filter, + ), + err, + )] + async fn count(&mut self, filter: UserRegistrationTokenFilter) -> Result { + let (sql, values) = Query::select() + .expr( + Expr::col(( + UserRegistrationTokens::Table, + UserRegistrationTokens::UserRegistrationTokenId, + )) + .count(), + ) + .from(UserRegistrationTokens::Table) + .apply_filter(filter) + .build_sqlx(PostgresQueryBuilder); + + let count: i64 = sqlx::query_scalar_with(&sql, values) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + count + .try_into() + .map_err(DatabaseError::to_invalid_operation) + } + #[tracing::instrument( name = "db.user_registration_token.lookup", skip_all, @@ -285,3 +547,173 @@ impl UserRegistrationTokenRepository for PgUserRegistrationTokenRepository<'_> { Ok(token) } } + +#[cfg(test)] +mod tests { + use chrono::Duration; + use mas_storage::{ + Clock as _, Pagination, clock::MockClock, user::UserRegistrationTokenFilter, + }; + use rand::SeedableRng; + use rand_chacha::ChaChaRng; + use sqlx::PgPool; + + use crate::PgRepository; + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_list_and_count(pool: PgPool) { + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); + + // Create different types of tokens + // 1. A regular token + let _token1 = repo + .user_registration_token() + .add(&mut rng, &clock, "token1".to_owned(), None, None) + .await + .unwrap(); + + // 2. A token that has been used + let token2 = repo + .user_registration_token() + .add(&mut rng, &clock, "token2".to_owned(), None, None) + .await + .unwrap(); + let token2 = repo + .user_registration_token() + .use_token(&clock, token2) + .await + .unwrap(); + + // 3. A token that is expired + let past_time = clock.now() - Duration::days(1); + let token3 = repo + .user_registration_token() + .add(&mut rng, &clock, "token3".to_owned(), None, Some(past_time)) + .await + .unwrap(); + + // 4. A token that is revoked + let token4 = repo + .user_registration_token() + .add(&mut rng, &clock, "token4".to_owned(), None, None) + .await + .unwrap(); + let token4 = repo + .user_registration_token() + .revoke(&clock, token4) + .await + .unwrap(); + + // Test list with empty filter + let empty_filter = UserRegistrationTokenFilter::new(clock.now()); + let page = repo + .user_registration_token() + .list(empty_filter, Pagination::first(10)) + .await + .unwrap(); + assert_eq!(page.edges.len(), 4); + + // Test count with empty filter + let count = repo + .user_registration_token() + .count(empty_filter) + .await + .unwrap(); + assert_eq!(count, 4); + + // Test has_been_used filter + let used_filter = UserRegistrationTokenFilter::new(clock.now()).with_been_used(true); + let page = repo + .user_registration_token() + .list(used_filter, Pagination::first(10)) + .await + .unwrap(); + assert_eq!(page.edges.len(), 1); + assert_eq!(page.edges[0].id, token2.id); + + // Test unused filter + let unused_filter = UserRegistrationTokenFilter::new(clock.now()).with_been_used(false); + let page = repo + .user_registration_token() + .list(unused_filter, Pagination::first(10)) + .await + .unwrap(); + assert_eq!(page.edges.len(), 3); + + // Test is_expired filter + let expired_filter = UserRegistrationTokenFilter::new(clock.now()).with_expired(true); + let page = repo + .user_registration_token() + .list(expired_filter, Pagination::first(10)) + .await + .unwrap(); + assert_eq!(page.edges.len(), 1); + assert_eq!(page.edges[0].id, token3.id); + + let not_expired_filter = UserRegistrationTokenFilter::new(clock.now()).with_expired(false); + let page = repo + .user_registration_token() + .list(not_expired_filter, Pagination::first(10)) + .await + .unwrap(); + assert_eq!(page.edges.len(), 3); + + // Test is_revoked filter + let revoked_filter = UserRegistrationTokenFilter::new(clock.now()).with_revoked(true); + let page = repo + .user_registration_token() + .list(revoked_filter, Pagination::first(10)) + .await + .unwrap(); + assert_eq!(page.edges.len(), 1); + assert_eq!(page.edges[0].id, token4.id); + + let not_revoked_filter = UserRegistrationTokenFilter::new(clock.now()).with_revoked(false); + let page = repo + .user_registration_token() + .list(not_revoked_filter, Pagination::first(10)) + .await + .unwrap(); + assert_eq!(page.edges.len(), 3); + + // Test is_valid filter + let valid_filter = UserRegistrationTokenFilter::new(clock.now()).with_valid(true); + let page = repo + .user_registration_token() + .list(valid_filter, Pagination::first(10)) + .await + .unwrap(); + assert_eq!(page.edges.len(), 2); + + let invalid_filter = UserRegistrationTokenFilter::new(clock.now()).with_valid(false); + let page = repo + .user_registration_token() + .list(invalid_filter, Pagination::first(10)) + .await + .unwrap(); + assert_eq!(page.edges.len(), 2); + + // Test combined filters + let combined_filter = UserRegistrationTokenFilter::new(clock.now()) + .with_been_used(false) + .with_revoked(true); + let page = repo + .user_registration_token() + .list(combined_filter, Pagination::first(10)) + .await + .unwrap(); + assert_eq!(page.edges.len(), 1); + assert_eq!(page.edges[0].id, token4.id); + + // Test pagination + let page = repo + .user_registration_token() + .list(empty_filter, Pagination::first(2)) + .await + .unwrap(); + assert_eq!(page.edges.len(), 2); + } +} diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 6a9bdc4c5..17852f0e9 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -26,7 +26,7 @@ pub use self::{ password::UserPasswordRepository, recovery::UserRecoveryRepository, registration::UserRegistrationRepository, - registration_token::UserRegistrationTokenRepository, + registration_token::{UserRegistrationTokenFilter, UserRegistrationTokenRepository}, session::{BrowserSessionFilter, BrowserSessionRepository}, terms::UserTermsRepository, }; diff --git a/crates/storage/src/user/registration_token.rs b/crates/storage/src/user/registration_token.rs index 91b0584b4..60f65a73f 100644 --- a/crates/storage/src/user/registration_token.rs +++ b/crates/storage/src/user/registration_token.rs @@ -11,6 +11,97 @@ use ulid::Ulid; use crate::{Clock, repository_impl}; +/// A filter to apply when listing [`UserRegistrationToken`]s +#[derive(Debug, Clone, Copy)] +pub struct UserRegistrationTokenFilter { + now: DateTime, + has_been_used: Option, + is_revoked: Option, + is_expired: Option, + is_valid: Option, +} + +impl UserRegistrationTokenFilter { + /// Create a new empty filter + #[must_use] + pub fn new(now: DateTime) -> Self { + Self { + now, + has_been_used: None, + is_revoked: None, + is_expired: None, + is_valid: None, + } + } + + /// Filter by whether the token has been used at least once + #[must_use] + pub fn with_been_used(mut self, has_been_used: bool) -> Self { + self.has_been_used = Some(has_been_used); + self + } + + /// Filter by revoked status + #[must_use] + pub fn with_revoked(mut self, is_revoked: bool) -> Self { + self.is_revoked = Some(is_revoked); + self + } + + /// Filter by expired status + #[must_use] + pub fn with_expired(mut self, is_expired: bool) -> Self { + self.is_expired = Some(is_expired); + self + } + + /// Filter by valid status (meaning: not expired, not revoked, and still + /// with uses left) + #[must_use] + pub fn with_valid(mut self, is_valid: bool) -> Self { + self.is_valid = Some(is_valid); + self + } + + /// Get the used status filter + /// + /// Returns [`None`] if no used status filter was set + #[must_use] + pub fn has_been_used(&self) -> Option { + self.has_been_used + } + + /// Get the revoked status filter + /// + /// Returns [`None`] if no revoked status filter was set + #[must_use] + pub fn is_revoked(&self) -> Option { + self.is_revoked + } + + /// Get the expired status filter + /// + /// Returns [`None`] if no expired status filter was set + #[must_use] + pub fn is_expired(&self) -> Option { + self.is_expired + } + + /// Get the valid status filter + /// + /// Returns [`None`] if no valid status filter was set + #[must_use] + pub fn is_valid(&self) -> Option { + self.is_valid + } + + /// Get the current time for this filter evaluation + #[must_use] + pub fn now(&self) -> DateTime { + self.now + } +} + /// A [`UserRegistrationTokenRepository`] helps interacting with /// [`UserRegistrationToken`] saved in the storage backend #[async_trait] @@ -104,6 +195,37 @@ pub trait UserRegistrationTokenRepository: Send + Sync { clock: &dyn Clock, token: UserRegistrationToken, ) -> Result; + + /// List [`UserRegistrationToken`]s based on the provided filter + /// + /// Returns a list of matching [`UserRegistrationToken`]s + /// + /// # Parameters + /// + /// * `filter`: The filter to apply + /// * `pagination`: The pagination parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn list( + &mut self, + filter: UserRegistrationTokenFilter, + pagination: crate::Pagination, + ) -> Result, Self::Error>; + + /// Count [`UserRegistrationToken`]s based on the provided filter + /// + /// Returns the number of matching [`UserRegistrationToken`]s + /// + /// # Parameters + /// + /// * `filter`: The filter to apply + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn count(&mut self, filter: UserRegistrationTokenFilter) -> Result; } repository_impl!(UserRegistrationTokenRepository: @@ -127,4 +249,10 @@ repository_impl!(UserRegistrationTokenRepository: clock: &dyn Clock, token: UserRegistrationToken, ) -> Result; + async fn list( + &mut self, + filter: UserRegistrationTokenFilter, + pagination: crate::Pagination, + ) -> Result, Self::Error>; + async fn count(&mut self, filter: UserRegistrationTokenFilter) -> Result; );