diff --git a/crates/axum-utils/src/cookies.rs b/crates/axum-utils/src/cookies.rs index 9371d5422..1c9e0eb09 100644 --- a/crates/axum-utils/src/cookies.rs +++ b/crates/axum-utils/src/cookies.rs @@ -138,6 +138,13 @@ impl CookieJar { self } + /// Remove a cookie from the jar + #[must_use] + pub fn remove(mut self, key: &str) -> Self { + self.inner = self.inner.remove(key.to_owned()); + self + } + /// Load and deserialize a cookie from the jar /// /// Returns `None` if the cookie is not present diff --git a/crates/handlers/src/views/register/cookie.rs b/crates/handlers/src/views/register/cookie.rs new file mode 100644 index 000000000..7e3eb8173 --- /dev/null +++ b/crates/handlers/src/views/register/cookie.rs @@ -0,0 +1,103 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +// TODO: move that to a standalone cookie manager + +use std::collections::BTreeSet; + +use chrono::{DateTime, Duration, Utc}; +use mas_axum_utils::cookies::CookieJar; +use mas_data_model::UserRegistration; +use mas_storage::Clock; +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use ulid::Ulid; + +/// Name of the cookie +static COOKIE_NAME: &str = "user-registration-sessions"; + +/// Sessions expire after an hour +static SESSION_MAX_TIME: Duration = Duration::hours(1); + +/// The content of the cookie, which stores a list of user registration IDs +#[derive(Serialize, Deserialize, Default, Debug)] +pub struct UserRegistrationSessions(BTreeSet); + +#[derive(Debug, Error, PartialEq, Eq)] +#[error("user registration session not found")] +pub struct UserRegistrationSessionNotFound; + +impl UserRegistrationSessions { + /// Load the user registration sessions cookie + pub fn load(cookie_jar: &CookieJar) -> Self { + match cookie_jar.load(COOKIE_NAME) { + Ok(Some(sessions)) => sessions, + Ok(None) => Self::default(), + Err(e) => { + tracing::warn!( + error = &e as &dyn std::error::Error, + "Invalid upstream sessions cookie" + ); + Self::default() + } + } + } + + /// Returns true if the cookie is empty + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Save the user registration sessions to the cookie jar + pub fn save(self, cookie_jar: CookieJar, clock: &C) -> CookieJar + where + C: Clock, + { + let this = self.expire(clock.now()); + + if this.is_empty() { + cookie_jar.remove(COOKIE_NAME) + } else { + cookie_jar.save(COOKIE_NAME, &this, false) + } + } + + fn expire(mut self, now: DateTime) -> Self { + self.0.retain(|id| { + let Ok(ts) = id.timestamp_ms().try_into() else { + return false; + }; + let Some(when) = DateTime::from_timestamp_millis(ts) else { + return false; + }; + now - when < SESSION_MAX_TIME + }); + + self + } + + /// Add a new session, for a provider and a random state + pub fn add(mut self, user_registration: &UserRegistration) -> Self { + self.0.insert(user_registration.id); + self + } + + /// Check if the session is in the list + pub fn contains(&self, user_registration: &UserRegistration) -> bool { + self.0.contains(&user_registration.id) + } + + /// Mark a link as consumed to avoid replay + pub fn consume_session( + mut self, + user_registration: &UserRegistration, + ) -> Result { + if !self.0.remove(&user_registration.id) { + return Err(UserRegistrationSessionNotFound); + } + + Ok(self) + } +} diff --git a/crates/handlers/src/views/register/mod.rs b/crates/handlers/src/views/register/mod.rs index dbd8a25ed..ea8cb40ce 100644 --- a/crates/handlers/src/views/register/mod.rs +++ b/crates/handlers/src/views/register/mod.rs @@ -17,6 +17,7 @@ use mas_templates::{RegisterContext, TemplateContext, Templates}; use super::shared::OptionalPostAuthAction; use crate::{BoundActivityTracker, PreferredLanguage}; +mod cookie; pub(crate) mod password; pub(crate) mod steps; diff --git a/crates/handlers/src/views/register/password.rs b/crates/handlers/src/views/register/password.rs index 37cfd861b..aebcca7c5 100644 --- a/crates/handlers/src/views/register/password.rs +++ b/crates/handlers/src/views/register/password.rs @@ -35,6 +35,7 @@ use mas_templates::{ use serde::{Deserialize, Serialize}; use zeroize::Zeroizing; +use super::cookie::UserRegistrationSessions; use crate::{ captcha::Form as CaptchaForm, passwords::PasswordManager, views::shared::OptionalPostAuthAction, BoundActivityTracker, Limiter, PreferredLanguage, @@ -361,8 +362,14 @@ pub(crate) async fn post( repo.save().await?; - Ok(url_builder - .redirect(&mas_router::RegisterFinish::new(registration.id)) + let cookie_jar = UserRegistrationSessions::load(&cookie_jar) + .add(®istration) + .save(cookie_jar, &clock); + + Ok(( + cookie_jar, + url_builder.redirect(&mas_router::RegisterFinish::new(registration.id)), + ) .into_response()) } diff --git a/crates/handlers/src/views/register/steps/finish.rs b/crates/handlers/src/views/register/steps/finish.rs index 2c4679a28..eaae8b3f1 100644 --- a/crates/handlers/src/views/register/steps/finish.rs +++ b/crates/handlers/src/views/register/steps/finish.rs @@ -19,6 +19,7 @@ use mas_storage::{ }; use ulid::Ulid; +use super::super::cookie::UserRegistrationSessions; use crate::{views::shared::OptionalPostAuthAction, BoundActivityTracker}; #[tracing::instrument( @@ -59,6 +60,14 @@ pub(crate) async fn get( )); } + // Check that this registration belongs to this browser + let registrations = UserRegistrationSessions::load(&cookie_jar); + if !registrations.contains(®istration) { + return Err(FancyError::from(anyhow::anyhow!( + "Could not find the registration in the browser cookies" + ))); + } + // Let's perform last minute checks on the registration, especially to avoid // race conditions where multiple users register with the same username or email // address @@ -116,6 +125,11 @@ pub(crate) async fn get( .complete(&clock, registration) .await?; + // Consume the registration session + let cookie_jar = registrations + .consume_session(®istration)? + .save(cookie_jar, &clock); + // Now we can start the user creation let user = repo .user()