diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 8477222c5..af4d0be37 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -50,6 +50,6 @@ pub use self::{ users::{ Authentication, AuthenticationMethod, BrowserSession, Password, User, UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode, UserRecoverySession, - UserRecoveryTicket, UserRegistration, UserRegistrationPassword, + UserRecoveryTicket, UserRegistration, UserRegistrationPassword, UserRegistrationToken, }, }; diff --git a/crates/data-model/src/users.rs b/crates/data-model/src/users.rs index 7e40f4df2..fc6ed2695 100644 --- a/crates/data-model/src/users.rs +++ b/crates/data-model/src/users.rs @@ -201,6 +201,52 @@ pub struct UserRegistrationPassword { pub version: u16, } +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct UserRegistrationToken { + pub id: Ulid, + pub token: String, + pub usage_limit: Option, + pub times_used: u32, + pub created_at: DateTime, + pub last_used_at: Option>, + pub expires_at: Option>, + pub revoked_at: Option>, +} + +impl UserRegistrationToken { + /// Returns `true` if the token is still valid and can be used + #[must_use] + pub fn is_valid(&self, now: DateTime) -> bool { + // Check if revoked + if self.revoked_at.is_some() { + return false; + } + + // Check if expired + if let Some(expires_at) = self.expires_at { + if now >= expires_at { + return false; + } + } + + // Check if usage limit exceeded + if let Some(usage_limit) = self.usage_limit { + if self.times_used >= usage_limit { + return false; + } + } + + true + } + + /// Returns `true` if the token can still be used (not expired and under + /// usage limit) + #[must_use] + pub fn can_be_used(&self, now: DateTime) -> bool { + self.is_valid(now) + } +} + #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct UserRegistration { pub id: Ulid, @@ -208,6 +254,7 @@ pub struct UserRegistration { pub display_name: Option, pub terms_url: Option, pub email_authentication_id: Option, + pub user_registration_token_id: Option, pub password: Option, pub post_auth_action: Option, pub ip_address: Option, diff --git a/crates/storage-pg/.sqlx/query-5133f9c5ba06201433be4ec784034d222975d084d0a9ebe7f1b6b865ab2e09ef.json b/crates/storage-pg/.sqlx/query-5133f9c5ba06201433be4ec784034d222975d084d0a9ebe7f1b6b865ab2e09ef.json new file mode 100644 index 000000000..227398475 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-5133f9c5ba06201433be4ec784034d222975d084d0a9ebe7f1b6b865ab2e09ef.json @@ -0,0 +1,23 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE user_registration_tokens\n SET times_used = times_used + 1,\n last_used_at = $2\n WHERE user_registration_token_id = $1 AND revoked_at IS NULL\n RETURNING times_used\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "times_used", + "type_info": "Int4" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + }, + "nullable": [ + false + ] + }, + "hash": "5133f9c5ba06201433be4ec784034d222975d084d0a9ebe7f1b6b865ab2e09ef" +} diff --git a/crates/storage-pg/.sqlx/query-6772b17585f26365e70ec3e342100c6890d2d63f54f1306e1bb95ca6ca123777.json b/crates/storage-pg/.sqlx/query-5bb3ad7486365e0798e103b072514e66b5b69a347dce91135e158a5eba1d1426.json similarity index 77% rename from crates/storage-pg/.sqlx/query-6772b17585f26365e70ec3e342100c6890d2d63f54f1306e1bb95ca6ca123777.json rename to crates/storage-pg/.sqlx/query-5bb3ad7486365e0798e103b072514e66b5b69a347dce91135e158a5eba1d1426.json index 6ee03e2d7..bad355b81 100644 --- a/crates/storage-pg/.sqlx/query-6772b17585f26365e70ec3e342100c6890d2d63f54f1306e1bb95ca6ca123777.json +++ b/crates/storage-pg/.sqlx/query-5bb3ad7486365e0798e103b072514e66b5b69a347dce91135e158a5eba1d1426.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT user_registration_id\n , ip_address as \"ip_address: IpAddr\"\n , user_agent\n , post_auth_action\n , username\n , display_name\n , terms_url\n , email_authentication_id\n , hashed_password\n , hashed_password_version\n , created_at\n , completed_at\n FROM user_registrations\n WHERE user_registration_id = $1\n ", + "query": "\n SELECT user_registration_id\n , ip_address as \"ip_address: IpAddr\"\n , user_agent\n , post_auth_action\n , username\n , display_name\n , terms_url\n , email_authentication_id\n , user_registration_token_id\n , hashed_password\n , hashed_password_version\n , created_at\n , completed_at\n FROM user_registrations\n WHERE user_registration_id = $1\n ", "describe": { "columns": [ { @@ -45,21 +45,26 @@ }, { "ordinal": 8, + "name": "user_registration_token_id", + "type_info": "Uuid" + }, + { + "ordinal": 9, "name": "hashed_password", "type_info": "Text" }, { - "ordinal": 9, + "ordinal": 10, "name": "hashed_password_version", "type_info": "Int4" }, { - "ordinal": 10, + "ordinal": 11, "name": "created_at", "type_info": "Timestamptz" }, { - "ordinal": 11, + "ordinal": 12, "name": "completed_at", "type_info": "Timestamptz" } @@ -80,9 +85,10 @@ true, true, true, + true, false, true ] }, - "hash": "6772b17585f26365e70ec3e342100c6890d2d63f54f1306e1bb95ca6ca123777" + "hash": "5bb3ad7486365e0798e103b072514e66b5b69a347dce91135e158a5eba1d1426" } diff --git a/crates/storage-pg/.sqlx/query-860e01cd660b450439d63c5ee31ade59f478b0b096b4bc90c89fb9c26b467dd2.json b/crates/storage-pg/.sqlx/query-860e01cd660b450439d63c5ee31ade59f478b0b096b4bc90c89fb9c26b467dd2.json new file mode 100644 index 000000000..a8909316d --- /dev/null +++ b/crates/storage-pg/.sqlx/query-860e01cd660b450439d63c5ee31ade59f478b0b096b4bc90c89fb9c26b467dd2.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE user_registrations\n SET user_registration_token_id = $2\n WHERE user_registration_id = $1 AND completed_at IS NULL\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "860e01cd660b450439d63c5ee31ade59f478b0b096b4bc90c89fb9c26b467dd2" +} diff --git a/crates/storage-pg/.sqlx/query-89edaec8661e435c3b71bb9b995cd711eb78a4d39608e897432d6124cd135938.json b/crates/storage-pg/.sqlx/query-89edaec8661e435c3b71bb9b995cd711eb78a4d39608e897432d6124cd135938.json new file mode 100644 index 000000000..f04a39a7f --- /dev/null +++ b/crates/storage-pg/.sqlx/query-89edaec8661e435c3b71bb9b995cd711eb78a4d39608e897432d6124cd135938.json @@ -0,0 +1,18 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO user_registration_tokens\n (user_registration_token_id, token, usage_limit, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Int4", + "Timestamptz", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "89edaec8661e435c3b71bb9b995cd711eb78a4d39608e897432d6124cd135938" +} diff --git a/crates/storage-pg/.sqlx/query-b3568613352efae1125a88565d886157d96866f7ef9b09b03a45ba4322664bd0.json b/crates/storage-pg/.sqlx/query-b3568613352efae1125a88565d886157d96866f7ef9b09b03a45ba4322664bd0.json new file mode 100644 index 000000000..9acc3f81a --- /dev/null +++ b/crates/storage-pg/.sqlx/query-b3568613352efae1125a88565d886157d96866f7ef9b09b03a45ba4322664bd0.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE user_registration_tokens\n SET revoked_at = $2\n WHERE user_registration_token_id = $1\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "b3568613352efae1125a88565d886157d96866f7ef9b09b03a45ba4322664bd0" +} diff --git a/crates/storage-pg/.sqlx/query-d0355d4e98bec6120f17d8cf81ac8c30ed19e9cebd0c8e7c7918b1c3ca0e3cba.json b/crates/storage-pg/.sqlx/query-d0355d4e98bec6120f17d8cf81ac8c30ed19e9cebd0c8e7c7918b1c3ca0e3cba.json new file mode 100644 index 000000000..2e20ac5ca --- /dev/null +++ b/crates/storage-pg/.sqlx/query-d0355d4e98bec6120f17d8cf81ac8c30ed19e9cebd0c8e7c7918b1c3ca0e3cba.json @@ -0,0 +1,64 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT user_registration_token_id,\n token,\n usage_limit,\n times_used,\n created_at,\n last_used_at,\n expires_at,\n revoked_at\n FROM user_registration_tokens\n WHERE user_registration_token_id = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "user_registration_token_id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "token", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "usage_limit", + "type_info": "Int4" + }, + { + "ordinal": 3, + "name": "times_used", + "type_info": "Int4" + }, + { + "ordinal": 4, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "last_used_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 6, + "name": "expires_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 7, + "name": "revoked_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false, + false, + true, + false, + false, + true, + true, + true + ] + }, + "hash": "d0355d4e98bec6120f17d8cf81ac8c30ed19e9cebd0c8e7c7918b1c3ca0e3cba" +} diff --git a/crates/storage-pg/.sqlx/query-fca331753aeccddbad96d06fc9d066dcefebe978a7af477bb6b55faa1d31e9b1.json b/crates/storage-pg/.sqlx/query-fca331753aeccddbad96d06fc9d066dcefebe978a7af477bb6b55faa1d31e9b1.json new file mode 100644 index 000000000..c5e2c6953 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-fca331753aeccddbad96d06fc9d066dcefebe978a7af477bb6b55faa1d31e9b1.json @@ -0,0 +1,64 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT user_registration_token_id,\n token,\n usage_limit,\n times_used,\n created_at,\n last_used_at,\n expires_at,\n revoked_at\n FROM user_registration_tokens\n WHERE token = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "user_registration_token_id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "token", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "usage_limit", + "type_info": "Int4" + }, + { + "ordinal": 3, + "name": "times_used", + "type_info": "Int4" + }, + { + "ordinal": 4, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "last_used_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 6, + "name": "expires_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 7, + "name": "revoked_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false, + true, + false, + false, + true, + true, + true + ] + }, + "hash": "fca331753aeccddbad96d06fc9d066dcefebe978a7af477bb6b55faa1d31e9b1" +} diff --git a/crates/storage-pg/migrations/20250602212100_user_registration_tokens.sql b/crates/storage-pg/migrations/20250602212100_user_registration_tokens.sql new file mode 100644 index 000000000..2f9ec3cd2 --- /dev/null +++ b/crates/storage-pg/migrations/20250602212100_user_registration_tokens.sql @@ -0,0 +1,57 @@ +-- Copyright 2025 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only +-- Please see LICENSE in the repository root for full details. + +-- Add a table for storing user registration tokens +CREATE TABLE "user_registration_tokens" ( + "user_registration_token_id" UUID PRIMARY KEY, + + -- The token string that users need to provide during registration + "token" TEXT NOT NULL UNIQUE, + + -- Optional limit on how many times this token can be used + "usage_limit" INTEGER, + + -- How many times this token has been used + "times_used" INTEGER NOT NULL DEFAULT 0, + + -- When the token was created + "created_at" TIMESTAMP WITH TIME ZONE NOT NULL, + + -- When the token was last used + "last_used_at" TIMESTAMP WITH TIME ZONE, + + -- Optional expiration time for the token + "expires_at" TIMESTAMP WITH TIME ZONE, + + -- When the token was revoked + "revoked_at" TIMESTAMP WITH TIME ZONE +); + +-- Create a few indices on the table, as we use those for filtering +-- They are safe to create non-concurrently, as the table is empty at this point +CREATE INDEX "user_registration_tokens_usage_limit_idx" + ON "user_registration_tokens" ("usage_limit"); + +CREATE INDEX "user_registration_tokens_times_used_idx" + ON "user_registration_tokens" ("times_used"); + +CREATE INDEX "user_registration_tokens_created_at_idx" + ON "user_registration_tokens" ("created_at"); + +CREATE INDEX "user_registration_tokens_last_used_at_idx" + ON "user_registration_tokens" ("last_used_at"); + +CREATE INDEX "user_registration_tokens_expires_at_idx" + ON "user_registration_tokens" ("expires_at"); + +CREATE INDEX "user_registration_tokens_revoked_at_idx" + ON "user_registration_tokens" ("revoked_at"); + +-- Add foreign key reference to registration tokens in user registrations +-- A second migration will add the index for this foreign key +ALTER TABLE "user_registrations" + ADD COLUMN "user_registration_token_id" UUID + REFERENCES "user_registration_tokens" ("user_registration_token_id") + ON DELETE SET NULL; \ No newline at end of file diff --git a/crates/storage-pg/migrations/20250602212101_idx_user_registration_token.sql b/crates/storage-pg/migrations/20250602212101_idx_user_registration_token.sql new file mode 100644 index 000000000..a25d6358a --- /dev/null +++ b/crates/storage-pg/migrations/20250602212101_idx_user_registration_token.sql @@ -0,0 +1,9 @@ +-- no-transaction +-- Copyright 2025 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only +-- Please see LICENSE in the repository root for full details. + +CREATE INDEX CONCURRENTLY + user_registrations_user_registration_token_id_fk + ON user_registrations (user_registration_token_id); \ No newline at end of file diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index c6668c2e4..8dc02b9bb 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -26,7 +26,11 @@ use mas_storage::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, }, - user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, + user::{ + BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, + UserRecoveryRepository, UserRegistrationRepository, UserRegistrationTokenRepository, + UserRepository, UserTermsRepository, + }, }; use sqlx::{PgConnection, PgPool, Postgres, Transaction}; use tracing::Instrument; @@ -55,8 +59,8 @@ use crate::{ }, user::{ PgBrowserSessionRepository, PgUserEmailRepository, PgUserPasswordRepository, - PgUserRecoveryRepository, PgUserRegistrationRepository, PgUserRepository, - PgUserTermsRepository, + PgUserRecoveryRepository, PgUserRegistrationRepository, PgUserRegistrationTokenRepository, + PgUserRepository, PgUserTermsRepository, }, }; @@ -232,22 +236,26 @@ where fn user_recovery<'c>( &'c mut self, - ) -> Box + 'c> { + ) -> Box + 'c> { Box::new(PgUserRecoveryRepository::new(self.conn.as_mut())) } - fn user_terms<'c>( - &'c mut self, - ) -> Box + 'c> { + fn user_terms<'c>(&'c mut self) -> Box + 'c> { Box::new(PgUserTermsRepository::new(self.conn.as_mut())) } fn user_registration<'c>( &'c mut self, - ) -> Box + 'c> { + ) -> Box + 'c> { Box::new(PgUserRegistrationRepository::new(self.conn.as_mut())) } + fn user_registration_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgUserRegistrationTokenRepository::new(self.conn.as_mut())) + } + fn browser_session<'c>( &'c mut self, ) -> Box + 'c> { diff --git a/crates/storage-pg/src/user/mod.rs b/crates/storage-pg/src/user/mod.rs index 659a2172a..8e755188d 100644 --- a/crates/storage-pg/src/user/mod.rs +++ b/crates/storage-pg/src/user/mod.rs @@ -32,6 +32,7 @@ mod email; mod password; mod recovery; mod registration; +mod registration_token; mod session; mod terms; @@ -41,7 +42,8 @@ mod tests; pub use self::{ email::PgUserEmailRepository, password::PgUserPasswordRepository, recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository, - session::PgBrowserSessionRepository, terms::PgUserTermsRepository, + registration_token::PgUserRegistrationTokenRepository, session::PgBrowserSessionRepository, + terms::PgUserTermsRepository, }; /// An implementation of [`UserRepository`] for a PostgreSQL connection diff --git a/crates/storage-pg/src/user/registration.rs b/crates/storage-pg/src/user/registration.rs index 5d578ab79..7f123b361 100644 --- a/crates/storage-pg/src/user/registration.rs +++ b/crates/storage-pg/src/user/registration.rs @@ -7,7 +7,9 @@ use std::net::IpAddr; use async_trait::async_trait; use chrono::{DateTime, Utc}; -use mas_data_model::{UserEmailAuthentication, UserRegistration, UserRegistrationPassword}; +use mas_data_model::{ + UserEmailAuthentication, UserRegistration, UserRegistrationPassword, UserRegistrationToken, +}; use mas_storage::{Clock, user::UserRegistrationRepository}; use rand::RngCore; use sqlx::PgConnection; @@ -40,6 +42,7 @@ struct UserRegistrationLookup { display_name: Option, terms_url: Option, email_authentication_id: Option, + user_registration_token_id: Option, hashed_password: Option, hashed_password_version: Option, created_at: DateTime, @@ -94,6 +97,7 @@ impl TryFrom for UserRegistration { display_name: value.display_name, terms_url, email_authentication_id: value.email_authentication_id.map(Ulid::from), + user_registration_token_id: value.user_registration_token_id.map(Ulid::from), password, created_at: value.created_at, completed_at: value.completed_at, @@ -126,6 +130,7 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> { , display_name , terms_url , email_authentication_id + , user_registration_token_id , hashed_password , hashed_password_version , created_at @@ -200,6 +205,7 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> { display_name: None, terms_url: None, email_authentication_id: None, + user_registration_token_id: None, password: None, }) } @@ -351,6 +357,41 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> { Ok(user_registration) } + #[tracing::instrument( + name = "db.user_registration.set_registration_token", + skip_all, + fields( + db.query.text, + %user_registration.id, + %user_registration_token.id, + ), + err, + )] + async fn set_registration_token( + &mut self, + mut user_registration: UserRegistration, + user_registration_token: &UserRegistrationToken, + ) -> Result { + let res = sqlx::query!( + r#" + UPDATE user_registrations + SET user_registration_token_id = $2 + WHERE user_registration_id = $1 AND completed_at IS NULL + "#, + Uuid::from(user_registration.id), + Uuid::from(user_registration_token.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + user_registration.user_registration_token_id = Some(user_registration_token.id); + + Ok(user_registration) + } + #[tracing::instrument( name = "db.user_registration.complete", skip_all, diff --git a/crates/storage-pg/src/user/registration_token.rs b/crates/storage-pg/src/user/registration_token.rs new file mode 100644 index 000000000..2f71fc83c --- /dev/null +++ b/crates/storage-pg/src/user/registration_token.rs @@ -0,0 +1,287 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::UserRegistrationToken; +use mas_storage::{Clock, user::UserRegistrationTokenRepository}; +use rand::RngCore; +use sqlx::{PgConnection, types::Uuid}; +use ulid::Ulid; + +use crate::{DatabaseInconsistencyError, errors::DatabaseError, tracing::ExecuteExt}; + +/// An implementation of [`mas_storage::user::UserRegistrationTokenRepository`] +/// for a PostgreSQL connection +pub struct PgUserRegistrationTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUserRegistrationTokenRepository<'c> { + /// Create a new [`PgUserRegistrationTokenRepository`] from an active + /// PostgreSQL connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct UserRegistrationTokenLookup { + user_registration_token_id: Uuid, + token: String, + usage_limit: Option, + times_used: i32, + created_at: DateTime, + last_used_at: Option>, + expires_at: Option>, + revoked_at: Option>, +} + +impl TryFrom for UserRegistrationToken { + type Error = DatabaseInconsistencyError; + + fn try_from(res: UserRegistrationTokenLookup) -> Result { + let id = Ulid::from(res.user_registration_token_id); + + let usage_limit = res + .usage_limit + .map(u32::try_from) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("user_registration_tokens") + .column("usage_limit") + .row(id) + .source(e) + })?; + + let times_used = res.times_used.try_into().map_err(|e| { + DatabaseInconsistencyError::on("user_registration_tokens") + .column("times_used") + .row(id) + .source(e) + })?; + + Ok(UserRegistrationToken { + id, + token: res.token, + usage_limit, + times_used, + created_at: res.created_at, + last_used_at: res.last_used_at, + expires_at: res.expires_at, + revoked_at: res.revoked_at, + }) + } +} + +#[async_trait] +impl UserRegistrationTokenRepository for PgUserRegistrationTokenRepository<'_> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.user_registration_token.lookup", + skip_all, + fields( + db.query.text, + user_registration_token.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserRegistrationTokenLookup, + r#" + SELECT user_registration_token_id, + token, + usage_limit, + times_used, + created_at, + last_used_at, + expires_at, + revoked_at + FROM user_registration_tokens + WHERE user_registration_token_id = $1 + "#, + Uuid::from(id) + ) + .traced() + .fetch_optional(&mut *self.conn) + .await?; + + let Some(res) = res else { + return Ok(None); + }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.user_registration_token.find_by_token", + skip_all, + fields( + db.query.text, + token = %token, + ), + err, + )] + async fn find_by_token( + &mut self, + token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserRegistrationTokenLookup, + r#" + SELECT user_registration_token_id, + token, + usage_limit, + times_used, + created_at, + last_used_at, + expires_at, + revoked_at + FROM user_registration_tokens + WHERE token = $1 + "#, + token + ) + .traced() + .fetch_optional(&mut *self.conn) + .await?; + + let Some(res) = res else { + return Ok(None); + }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.user_registration_token.add", + skip_all, + fields( + db.query.text, + user_registration_token.token = %token, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn mas_storage::Clock, + token: String, + usage_limit: Option, + expires_at: Option>, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + + let usage_limit_i32 = usage_limit + .map(i32::try_from) + .transpose() + .map_err(DatabaseError::to_invalid_operation)?; + + sqlx::query!( + r#" + INSERT INTO user_registration_tokens + (user_registration_token_id, token, usage_limit, created_at, expires_at) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + &token, + usage_limit_i32, + created_at, + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(UserRegistrationToken { + id, + token, + usage_limit, + times_used: 0, + created_at, + last_used_at: None, + expires_at, + revoked_at: None, + }) + } + + #[tracing::instrument( + name = "db.user_registration_token.use_token", + skip_all, + fields( + db.query.text, + user_registration_token.id = %token.id, + ), + err, + )] + async fn use_token( + &mut self, + clock: &dyn Clock, + token: UserRegistrationToken, + ) -> Result { + let now = clock.now(); + let new_times_used = sqlx::query_scalar!( + r#" + UPDATE user_registration_tokens + SET times_used = times_used + 1, + last_used_at = $2 + WHERE user_registration_token_id = $1 AND revoked_at IS NULL + RETURNING times_used + "#, + Uuid::from(token.id), + now, + ) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + let new_times_used = new_times_used + .try_into() + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(UserRegistrationToken { + times_used: new_times_used, + last_used_at: Some(now), + ..token + }) + } + + #[tracing::instrument( + name = "db.user_registration_token.revoke", + skip_all, + fields( + db.query.text, + user_registration_token.id = %token.id, + ), + err, + )] + async fn revoke( + &mut self, + clock: &dyn Clock, + mut token: UserRegistrationToken, + ) -> Result { + let revoked_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE user_registration_tokens + SET revoked_at = $2 + WHERE user_registration_token_id = $1 + "#, + Uuid::from(token.id), + revoked_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + token.revoked_at = Some(revoked_at); + + Ok(token) + } +} diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 93c43d469..a02edb4ad 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -26,7 +26,8 @@ use crate::{ }, user::{ BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, - UserRecoveryRepository, UserRegistrationRepository, UserRepository, UserTermsRepository, + UserRecoveryRepository, UserRegistrationRepository, UserRegistrationTokenRepository, + UserRepository, UserTermsRepository, }, }; @@ -148,6 +149,11 @@ pub trait RepositoryAccess: Send { &'c mut self, ) -> Box + 'c>; + /// Get an [`UserRegistrationTokenRepository`] + fn user_registration_token<'c>( + &'c mut self, + ) -> Box + 'c>; + /// Get an [`UserTermsRepository`] fn user_terms<'c>(&'c mut self) -> Box + 'c>; @@ -249,7 +255,8 @@ mod impls { }, user::{ BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, - UserRegistrationRepository, UserRepository, UserTermsRepository, + UserRegistrationRepository, UserRegistrationTokenRepository, UserRepository, + UserTermsRepository, }, }; @@ -348,6 +355,15 @@ mod impls { )) } + fn user_registration_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.user_registration_token(), + &mut self.mapper, + )) + } + fn user_terms<'c>(&'c mut self) -> Box + 'c> { Box::new(MapErr::new(self.inner.user_terms(), &mut self.mapper)) } @@ -512,6 +528,12 @@ mod impls { (**self).user_registration() } + fn user_registration_token<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).user_registration_token() + } + fn user_terms<'c>(&'c mut self) -> Box + 'c> { (**self).user_terms() } diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 395c6e615..6a9bdc4c5 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -17,6 +17,7 @@ mod email; mod password; mod recovery; mod registration; +mod registration_token; mod session; mod terms; @@ -25,6 +26,7 @@ pub use self::{ password::UserPasswordRepository, recovery::UserRecoveryRepository, registration::UserRegistrationRepository, + registration_token::UserRegistrationTokenRepository, session::{BrowserSessionFilter, BrowserSessionRepository}, terms::UserTermsRepository, }; diff --git a/crates/storage/src/user/registration.rs b/crates/storage/src/user/registration.rs index 3932db622..49fd01fdc 100644 --- a/crates/storage/src/user/registration.rs +++ b/crates/storage/src/user/registration.rs @@ -6,7 +6,7 @@ use std::net::IpAddr; use async_trait::async_trait; -use mas_data_model::{UserEmailAuthentication, UserRegistration}; +use mas_data_model::{UserEmailAuthentication, UserRegistration, UserRegistrationToken}; use rand_core::RngCore; use ulid::Ulid; use url::Url; @@ -138,6 +138,25 @@ pub trait UserRegistrationRepository: Send + Sync { version: u16, ) -> Result; + /// Set the registration token of a [`UserRegistration`] + /// + /// Returns the updated [`UserRegistration`] + /// + /// # Parameters + /// + /// * `user_registration`: The [`UserRegistration`] to update + /// * `user_registration_token`: The [`UserRegistrationToken`] to set + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails or if the + /// registration is already completed + async fn set_registration_token( + &mut self, + user_registration: UserRegistration, + user_registration_token: &UserRegistrationToken, + ) -> Result; + /// Complete a [`UserRegistration`] /// /// Returns the updated [`UserRegistration`] @@ -190,6 +209,11 @@ repository_impl!(UserRegistrationRepository: hashed_password: String, version: u16, ) -> Result; + async fn set_registration_token( + &mut self, + user_registration: UserRegistration, + user_registration_token: &UserRegistrationToken, + ) -> Result; async fn complete( &mut self, clock: &dyn Clock, diff --git a/crates/storage/src/user/registration_token.rs b/crates/storage/src/user/registration_token.rs new file mode 100644 index 000000000..91b0584b4 --- /dev/null +++ b/crates/storage/src/user/registration_token.rs @@ -0,0 +1,130 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::UserRegistrationToken; +use rand_core::RngCore; +use ulid::Ulid; + +use crate::{Clock, repository_impl}; + +/// A [`UserRegistrationTokenRepository`] helps interacting with +/// [`UserRegistrationToken`] saved in the storage backend +#[async_trait] +pub trait UserRegistrationTokenRepository: Send + Sync { + /// The error type returned by the repository + type Error; + + /// Lookup a [`UserRegistrationToken`] by its ID + /// + /// Returns `None` if no [`UserRegistrationToken`] was found + /// + /// # Parameters + /// + /// * `id`: The ID of the [`UserRegistrationToken`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Lookup a [`UserRegistrationToken`] by its token string + /// + /// Returns `None` if no [`UserRegistrationToken`] was found + /// + /// # Parameters + /// + /// * `token`: The token string to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn find_by_token( + &mut self, + token: &str, + ) -> Result, Self::Error>; + + /// Create a new [`UserRegistrationToken`] + /// + /// Returns the newly created [`UserRegistrationToken`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `token`: The token string + /// * `usage_limit`: Optional limit on how many times the token can be used + /// * `expires_at`: Optional expiration time for the token + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + token: String, + usage_limit: Option, + expires_at: Option>, + ) -> Result; + + /// Increment the usage count of a [`UserRegistrationToken`] + /// + /// Returns the updated [`UserRegistrationToken`] + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `token`: The [`UserRegistrationToken`] to update + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn use_token( + &mut self, + clock: &dyn Clock, + token: UserRegistrationToken, + ) -> Result; + + /// Revoke a [`UserRegistrationToken`] + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `token`: The [`UserRegistrationToken`] to delete + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn revoke( + &mut self, + clock: &dyn Clock, + token: UserRegistrationToken, + ) -> Result; +} + +repository_impl!(UserRegistrationTokenRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + async fn find_by_token(&mut self, token: &str) -> Result, Self::Error>; + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + token: String, + usage_limit: Option, + expires_at: Option>, + ) -> Result; + async fn use_token( + &mut self, + clock: &dyn Clock, + token: UserRegistrationToken, + ) -> Result; + async fn revoke( + &mut self, + clock: &dyn Clock, + token: UserRegistrationToken, + ) -> Result; +);