List and count methods on the UserRegistrationTokenRepository

This commit is contained in:
Quentin Gliech
2025-06-03 09:47:27 +02:00
parent 0760b4e9bc
commit 8a6fd1d6b2
4 changed files with 578 additions and 5 deletions

View File

@@ -1,4 +1,4 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
@@ -139,3 +139,16 @@ pub enum UpstreamOAuthLinks {
HumanAccountName,
CreatedAt,
}
#[derive(sea_query::Iden)]
pub enum UserRegistrationTokens {
Table,
UserRegistrationTokenId,
Token,
UsageLimit,
TimesUsed,
CreatedAt,
LastUsedAt,
ExpiresAt,
RevokedAt,
}

View File

@@ -6,12 +6,25 @@
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::UserRegistrationToken;
use mas_storage::{Clock, user::UserRegistrationTokenRepository};
use mas_storage::{
Clock, Page, Pagination,
user::{UserRegistrationTokenFilter, UserRegistrationTokenRepository},
};
use rand::RngCore;
use sqlx::{PgConnection, types::Uuid};
use sea_query::{Condition, Expr, PostgresQueryBuilder, Query, enum_def};
use sea_query_binder::SqlxBinder;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{DatabaseInconsistencyError, errors::DatabaseError, tracing::ExecuteExt};
use crate::{
DatabaseInconsistencyError,
errors::DatabaseError,
filter::{Filter, StatementExt},
iden::UserRegistrationTokens,
pagination::QueryBuilderExt,
tracing::ExecuteExt,
};
/// An implementation of [`mas_storage::user::UserRegistrationTokenRepository`]
/// for a PostgreSQL connection
@@ -27,6 +40,8 @@ impl<'c> PgUserRegistrationTokenRepository<'c> {
}
}
#[derive(Debug, Clone, sqlx::FromRow)]
#[enum_def]
struct UserRegistrationTokenLookup {
user_registration_token_id: Uuid,
token: String,
@@ -38,6 +53,130 @@ struct UserRegistrationTokenLookup {
revoked_at: Option<DateTime<Utc>>,
}
impl Filter for UserRegistrationTokenFilter {
#[expect(clippy::too_many_lines)]
fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
sea_query::Condition::all()
.add_option(self.has_been_used().map(|has_been_used| {
if has_been_used {
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::TimesUsed,
))
.gt(0)
} else {
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::TimesUsed,
))
.eq(0)
}
}))
.add_option(self.is_revoked().map(|is_revoked| {
if is_revoked {
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::RevokedAt,
))
.is_not_null()
} else {
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::RevokedAt,
))
.is_null()
}
}))
.add_option(self.is_expired().map(|is_expired| {
if is_expired {
Condition::all()
.add(
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::ExpiresAt,
))
.is_not_null(),
)
.add(
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::ExpiresAt,
))
.lt(Expr::val(self.now())),
)
} else {
Condition::any()
.add(
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::ExpiresAt,
))
.is_null(),
)
.add(
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::ExpiresAt,
))
.gte(Expr::val(self.now())),
)
}
}))
.add_option(self.is_valid().map(|is_valid| {
let valid = Condition::all()
// Has not reached its usage limit
.add(
Condition::any()
.add(
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::UsageLimit,
))
.is_null(),
)
.add(
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::TimesUsed,
))
.lt(Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::UsageLimit,
))),
),
)
// Has not been revoked
.add(
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::RevokedAt,
))
.is_null(),
)
// Has not expired
.add(
Condition::any()
.add(
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::ExpiresAt,
))
.is_null(),
)
.add(
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::ExpiresAt,
))
.gte(Expr::val(self.now())),
),
);
if is_valid { valid } else { valid.not() }
}))
}
}
impl TryFrom<UserRegistrationTokenLookup> for UserRegistrationToken {
type Error = DatabaseInconsistencyError;
@@ -79,6 +218,129 @@ impl TryFrom<UserRegistrationTokenLookup> for UserRegistrationToken {
impl UserRegistrationTokenRepository for PgUserRegistrationTokenRepository<'_> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.user_registration_token.list",
skip_all,
fields(
db.query.text,
),
err,
)]
async fn list(
&mut self,
filter: UserRegistrationTokenFilter,
pagination: Pagination,
) -> Result<Page<UserRegistrationToken>, Self::Error> {
let (sql, values) = Query::select()
.expr_as(
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::UserRegistrationTokenId,
)),
UserRegistrationTokenLookupIden::UserRegistrationTokenId,
)
.expr_as(
Expr::col((UserRegistrationTokens::Table, UserRegistrationTokens::Token)),
UserRegistrationTokenLookupIden::Token,
)
.expr_as(
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::UsageLimit,
)),
UserRegistrationTokenLookupIden::UsageLimit,
)
.expr_as(
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::TimesUsed,
)),
UserRegistrationTokenLookupIden::TimesUsed,
)
.expr_as(
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::CreatedAt,
)),
UserRegistrationTokenLookupIden::CreatedAt,
)
.expr_as(
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::LastUsedAt,
)),
UserRegistrationTokenLookupIden::LastUsedAt,
)
.expr_as(
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::ExpiresAt,
)),
UserRegistrationTokenLookupIden::ExpiresAt,
)
.expr_as(
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::RevokedAt,
)),
UserRegistrationTokenLookupIden::RevokedAt,
)
.from(UserRegistrationTokens::Table)
.apply_filter(filter)
.generate_pagination(
(
UserRegistrationTokens::Table,
UserRegistrationTokens::UserRegistrationTokenId,
),
pagination,
)
.build_sqlx(PostgresQueryBuilder);
let tokens = sqlx::query_as_with::<_, UserRegistrationTokenLookup, _>(&sql, values)
.traced()
.fetch_all(&mut *self.conn)
.await?
.into_iter()
.map(TryInto::try_into)
.collect::<Result<Vec<_>, _>>()?;
let page = pagination.process(tokens);
Ok(page)
}
#[tracing::instrument(
name = "db.user_registration_token.count",
skip_all,
fields(
db.query.text,
user_registration_token.filter = ?filter,
),
err,
)]
async fn count(&mut self, filter: UserRegistrationTokenFilter) -> Result<usize, Self::Error> {
let (sql, values) = Query::select()
.expr(
Expr::col((
UserRegistrationTokens::Table,
UserRegistrationTokens::UserRegistrationTokenId,
))
.count(),
)
.from(UserRegistrationTokens::Table)
.apply_filter(filter)
.build_sqlx(PostgresQueryBuilder);
let count: i64 = sqlx::query_scalar_with(&sql, values)
.traced()
.fetch_one(&mut *self.conn)
.await?;
count
.try_into()
.map_err(DatabaseError::to_invalid_operation)
}
#[tracing::instrument(
name = "db.user_registration_token.lookup",
skip_all,
@@ -285,3 +547,173 @@ impl UserRegistrationTokenRepository for PgUserRegistrationTokenRepository<'_> {
Ok(token)
}
}
#[cfg(test)]
mod tests {
use chrono::Duration;
use mas_storage::{
Clock as _, Pagination, clock::MockClock, user::UserRegistrationTokenFilter,
};
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
use sqlx::PgPool;
use crate::PgRepository;
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_list_and_count(pool: PgPool) {
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Create different types of tokens
// 1. A regular token
let _token1 = repo
.user_registration_token()
.add(&mut rng, &clock, "token1".to_owned(), None, None)
.await
.unwrap();
// 2. A token that has been used
let token2 = repo
.user_registration_token()
.add(&mut rng, &clock, "token2".to_owned(), None, None)
.await
.unwrap();
let token2 = repo
.user_registration_token()
.use_token(&clock, token2)
.await
.unwrap();
// 3. A token that is expired
let past_time = clock.now() - Duration::days(1);
let token3 = repo
.user_registration_token()
.add(&mut rng, &clock, "token3".to_owned(), None, Some(past_time))
.await
.unwrap();
// 4. A token that is revoked
let token4 = repo
.user_registration_token()
.add(&mut rng, &clock, "token4".to_owned(), None, None)
.await
.unwrap();
let token4 = repo
.user_registration_token()
.revoke(&clock, token4)
.await
.unwrap();
// Test list with empty filter
let empty_filter = UserRegistrationTokenFilter::new(clock.now());
let page = repo
.user_registration_token()
.list(empty_filter, Pagination::first(10))
.await
.unwrap();
assert_eq!(page.edges.len(), 4);
// Test count with empty filter
let count = repo
.user_registration_token()
.count(empty_filter)
.await
.unwrap();
assert_eq!(count, 4);
// Test has_been_used filter
let used_filter = UserRegistrationTokenFilter::new(clock.now()).with_been_used(true);
let page = repo
.user_registration_token()
.list(used_filter, Pagination::first(10))
.await
.unwrap();
assert_eq!(page.edges.len(), 1);
assert_eq!(page.edges[0].id, token2.id);
// Test unused filter
let unused_filter = UserRegistrationTokenFilter::new(clock.now()).with_been_used(false);
let page = repo
.user_registration_token()
.list(unused_filter, Pagination::first(10))
.await
.unwrap();
assert_eq!(page.edges.len(), 3);
// Test is_expired filter
let expired_filter = UserRegistrationTokenFilter::new(clock.now()).with_expired(true);
let page = repo
.user_registration_token()
.list(expired_filter, Pagination::first(10))
.await
.unwrap();
assert_eq!(page.edges.len(), 1);
assert_eq!(page.edges[0].id, token3.id);
let not_expired_filter = UserRegistrationTokenFilter::new(clock.now()).with_expired(false);
let page = repo
.user_registration_token()
.list(not_expired_filter, Pagination::first(10))
.await
.unwrap();
assert_eq!(page.edges.len(), 3);
// Test is_revoked filter
let revoked_filter = UserRegistrationTokenFilter::new(clock.now()).with_revoked(true);
let page = repo
.user_registration_token()
.list(revoked_filter, Pagination::first(10))
.await
.unwrap();
assert_eq!(page.edges.len(), 1);
assert_eq!(page.edges[0].id, token4.id);
let not_revoked_filter = UserRegistrationTokenFilter::new(clock.now()).with_revoked(false);
let page = repo
.user_registration_token()
.list(not_revoked_filter, Pagination::first(10))
.await
.unwrap();
assert_eq!(page.edges.len(), 3);
// Test is_valid filter
let valid_filter = UserRegistrationTokenFilter::new(clock.now()).with_valid(true);
let page = repo
.user_registration_token()
.list(valid_filter, Pagination::first(10))
.await
.unwrap();
assert_eq!(page.edges.len(), 2);
let invalid_filter = UserRegistrationTokenFilter::new(clock.now()).with_valid(false);
let page = repo
.user_registration_token()
.list(invalid_filter, Pagination::first(10))
.await
.unwrap();
assert_eq!(page.edges.len(), 2);
// Test combined filters
let combined_filter = UserRegistrationTokenFilter::new(clock.now())
.with_been_used(false)
.with_revoked(true);
let page = repo
.user_registration_token()
.list(combined_filter, Pagination::first(10))
.await
.unwrap();
assert_eq!(page.edges.len(), 1);
assert_eq!(page.edges[0].id, token4.id);
// Test pagination
let page = repo
.user_registration_token()
.list(empty_filter, Pagination::first(2))
.await
.unwrap();
assert_eq!(page.edges.len(), 2);
}
}

View File

@@ -26,7 +26,7 @@ pub use self::{
password::UserPasswordRepository,
recovery::UserRecoveryRepository,
registration::UserRegistrationRepository,
registration_token::UserRegistrationTokenRepository,
registration_token::{UserRegistrationTokenFilter, UserRegistrationTokenRepository},
session::{BrowserSessionFilter, BrowserSessionRepository},
terms::UserTermsRepository,
};

View File

@@ -11,6 +11,97 @@ use ulid::Ulid;
use crate::{Clock, repository_impl};
/// A filter to apply when listing [`UserRegistrationToken`]s
#[derive(Debug, Clone, Copy)]
pub struct UserRegistrationTokenFilter {
now: DateTime<Utc>,
has_been_used: Option<bool>,
is_revoked: Option<bool>,
is_expired: Option<bool>,
is_valid: Option<bool>,
}
impl UserRegistrationTokenFilter {
/// Create a new empty filter
#[must_use]
pub fn new(now: DateTime<Utc>) -> Self {
Self {
now,
has_been_used: None,
is_revoked: None,
is_expired: None,
is_valid: None,
}
}
/// Filter by whether the token has been used at least once
#[must_use]
pub fn with_been_used(mut self, has_been_used: bool) -> Self {
self.has_been_used = Some(has_been_used);
self
}
/// Filter by revoked status
#[must_use]
pub fn with_revoked(mut self, is_revoked: bool) -> Self {
self.is_revoked = Some(is_revoked);
self
}
/// Filter by expired status
#[must_use]
pub fn with_expired(mut self, is_expired: bool) -> Self {
self.is_expired = Some(is_expired);
self
}
/// Filter by valid status (meaning: not expired, not revoked, and still
/// with uses left)
#[must_use]
pub fn with_valid(mut self, is_valid: bool) -> Self {
self.is_valid = Some(is_valid);
self
}
/// Get the used status filter
///
/// Returns [`None`] if no used status filter was set
#[must_use]
pub fn has_been_used(&self) -> Option<bool> {
self.has_been_used
}
/// Get the revoked status filter
///
/// Returns [`None`] if no revoked status filter was set
#[must_use]
pub fn is_revoked(&self) -> Option<bool> {
self.is_revoked
}
/// Get the expired status filter
///
/// Returns [`None`] if no expired status filter was set
#[must_use]
pub fn is_expired(&self) -> Option<bool> {
self.is_expired
}
/// Get the valid status filter
///
/// Returns [`None`] if no valid status filter was set
#[must_use]
pub fn is_valid(&self) -> Option<bool> {
self.is_valid
}
/// Get the current time for this filter evaluation
#[must_use]
pub fn now(&self) -> DateTime<Utc> {
self.now
}
}
/// A [`UserRegistrationTokenRepository`] helps interacting with
/// [`UserRegistrationToken`] saved in the storage backend
#[async_trait]
@@ -104,6 +195,37 @@ pub trait UserRegistrationTokenRepository: Send + Sync {
clock: &dyn Clock,
token: UserRegistrationToken,
) -> Result<UserRegistrationToken, Self::Error>;
/// List [`UserRegistrationToken`]s based on the provided filter
///
/// Returns a list of matching [`UserRegistrationToken`]s
///
/// # Parameters
///
/// * `filter`: The filter to apply
/// * `pagination`: The pagination parameters
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn list(
&mut self,
filter: UserRegistrationTokenFilter,
pagination: crate::Pagination,
) -> Result<crate::Page<UserRegistrationToken>, Self::Error>;
/// Count [`UserRegistrationToken`]s based on the provided filter
///
/// Returns the number of matching [`UserRegistrationToken`]s
///
/// # Parameters
///
/// * `filter`: The filter to apply
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn count(&mut self, filter: UserRegistrationTokenFilter) -> Result<usize, Self::Error>;
}
repository_impl!(UserRegistrationTokenRepository:
@@ -127,4 +249,10 @@ repository_impl!(UserRegistrationTokenRepository:
clock: &dyn Clock,
token: UserRegistrationToken,
) -> Result<UserRegistrationToken, Self::Error>;
async fn list(
&mut self,
filter: UserRegistrationTokenFilter,
pagination: crate::Pagination,
) -> Result<crate::Page<UserRegistrationToken>, Self::Error>;
async fn count(&mut self, filter: UserRegistrationTokenFilter) -> Result<usize, Self::Error>;
);