From 7f3aa061539095e82b925a3d31fba54054ffd9f0 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 28 Jul 2023 18:25:54 +0200 Subject: [PATCH 1/8] Add a way to lock users --- crates/cli/src/commands/manage.rs | 52 +++++++++++++ crates/data-model/src/users.rs | 14 +++- crates/handlers/src/compat/login.rs | 2 + crates/handlers/src/oauth2/introspection.rs | 4 + crates/handlers/src/upstream_oauth2/link.rs | 6 +- crates/handlers/src/views/login.rs | 1 + ...e561e6521c45ce07d3a42411984c9a6b75fdc.json | 14 ++++ ...13b91fbccfe5fbdbead8c4868d52a61a0f9d.json} | 16 +++- ...4599f6374c96bb4a6827d400acb22fb0fd39.json} | 12 ++- ...a75d18e914f823902587b63c9f295407144b1.json | 15 ++++ ...e1a6ac868c95bfaee3a6960df1cf484d53da.json} | 12 ++- .../migrations/20230728154304_user_lock.sql | 19 +++++ crates/storage-pg/src/iden.rs | 2 + crates/storage-pg/src/user/mod.rs | 77 ++++++++++++++++++- crates/storage-pg/src/user/session.rs | 14 ++++ crates/storage/src/user/mod.rs | 29 +++++++ 16 files changed, 277 insertions(+), 12 deletions(-) create mode 100644 crates/storage-pg/.sqlx/query-22896e8f2a002f307089c3e0f9ee561e6521c45ce07d3a42411984c9a6b75fdc.json rename crates/storage-pg/.sqlx/{query-25d61a373560556deafe056c8cd2982ac472f5ec2fab08b0b5275c4b78c11a7e.json => query-73fe61f03a41778c6273b1c2dbdb13b91fbccfe5fbdbead8c4868d52a61a0f9d.json} (66%) rename crates/storage-pg/.sqlx/{query-836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c.json => query-bfa5eaeaa5b4574bb083c86711eb4599f6374c96bb4a6827d400acb22fb0fd39.json} (67%) create mode 100644 crates/storage-pg/.sqlx/query-c29fa41743811a6ac3a9b952b6ea75d18e914f823902587b63c9f295407144b1.json rename crates/storage-pg/.sqlx/{query-08d7df347c806ef14b6d0fb031cab041d79ba48528420160e23286369db7af35.json => query-e0ea7d93ab3f565828b2faab4cc5e1a6ac868c95bfaee3a6960df1cf484d53da.json} (67%) create mode 100644 crates/storage-pg/migrations/20230728154304_user_lock.sql diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index c2c41e287..192b46c4e 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -69,6 +69,18 @@ enum Subcommand { #[arg(long)] dry_run: bool, }, + + /// Lock a user + LockUser { + /// User to lock + username: String, + }, + + /// Unlock a user + UnlockUser { + /// User to unlock + username: String, + }, } impl Options { @@ -330,6 +342,46 @@ impl Options { Ok(()) } + + SC::LockUser { username } => { + let _span = info_span!("cli.manage.lock_user", user.username = username).entered(); + let config: DatabaseConfig = root.load_config()?; + let pool = database_from_config(&config).await?; + let mut repo = PgRepository::from_pool(&pool).await?.boxed(); + + let user = repo + .user() + .find_by_username(&username) + .await? + .context("User not found")?; + + info!(%user.id, "Locking user"); + + repo.user().lock(&clock, user).await?; + repo.save().await?; + + Ok(()) + } + + SC::UnlockUser { username } => { + let _span = info_span!("cli.manage.lock_user", user.username = username).entered(); + let config: DatabaseConfig = root.load_config()?; + let pool = database_from_config(&config).await?; + let mut repo = PgRepository::from_pool(&pool).await?.boxed(); + + let user = repo + .user() + .find_by_username(&username) + .await? + .context("User not found")?; + + info!(%user.id, "Unlocking user"); + + repo.user().unlock(user).await?; + repo.save().await?; + + Ok(()) + } } } } diff --git a/crates/data-model/src/users.rs b/crates/data-model/src/users.rs index d53a688ba..9858356fe 100644 --- a/crates/data-model/src/users.rs +++ b/crates/data-model/src/users.rs @@ -25,6 +25,16 @@ pub struct User { pub username: String, pub sub: String, pub primary_user_email_id: Option, + pub created_at: DateTime, + pub locked_at: Option>, +} + +impl User { + /// Returns `true` unless the user is locked. + #[must_use] + pub fn is_valid(&self) -> bool { + self.locked_at.is_none() + } } impl User { @@ -35,6 +45,8 @@ impl User { username: "john".to_owned(), sub: "123-456".to_owned(), primary_user_email_id: None, + created_at: now, + locked_at: None, }] } } @@ -65,7 +77,7 @@ pub struct BrowserSession { impl BrowserSession { #[must_use] pub fn active(&self) -> bool { - self.finished_at.is_none() + self.finished_at.is_none() && self.user.is_valid() } } diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index 2541ad6af..cbf365ab2 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -335,6 +335,7 @@ async fn token_login( .user() .lookup(session.user_id) .await? + .filter(mas_data_model::User::is_valid) .ok_or(RouteError::UserNotFound)?; repo.compat_sso_login().exchange(clock, login).await?; @@ -355,6 +356,7 @@ async fn user_password_login( .user() .find_by_username(&username) .await? + .filter(mas_data_model::User::is_valid) .ok_or(RouteError::UserNotFound)?; // Lookup its password diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 730498514..6dc58233e 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -188,6 +188,7 @@ pub(crate) async fn post( .browser_session() .lookup(session.user_session_id) .await? + .filter(|b| b.user.is_valid()) // XXX: is that the right error to bubble up? .ok_or(RouteError::UnknownToken)?; @@ -227,6 +228,7 @@ pub(crate) async fn post( .browser_session() .lookup(session.user_session_id) .await? + .filter(|b| b.user.is_valid()) // XXX: is that the right error to bubble up? .ok_or(RouteError::UnknownToken)?; @@ -265,6 +267,7 @@ pub(crate) async fn post( .user() .lookup(session.user_id) .await? + .filter(mas_data_model::User::is_valid) // XXX: is that the right error to bubble up? .ok_or(RouteError::UnknownToken)?; @@ -311,6 +314,7 @@ pub(crate) async fn post( .user() .lookup(session.user_id) .await? + .filter(mas_data_model::User::is_valid) // XXX: is that the right error to bubble up? .ok_or(RouteError::UnknownToken)?; diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index b5d97578d..6320adf83 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -23,7 +23,7 @@ use mas_axum_utils::{ csrf::{CsrfExt, ProtectedForm}, SessionInfoExt, }; -use mas_data_model::UpstreamOAuthProviderImportPreference; +use mas_data_model::{UpstreamOAuthProviderImportPreference, User}; use mas_jose::jwt::Jwt; use mas_keystore::Encrypter; use mas_storage::{ @@ -239,6 +239,8 @@ pub(crate) async fn get( .user() .lookup(user_id) .await? + // XXX: is that right? + .filter(User::is_valid) .ok_or(RouteError::UserNotFound)?; let ctx = UpstreamExistingLinkContext::new(user) @@ -263,6 +265,7 @@ pub(crate) async fn get( .user() .lookup(user_id) .await? + .filter(mas_data_model::User::is_valid) .ok_or(RouteError::UserNotFound)?; let ctx = UpstreamExistingLinkContext::new(user).with_csrf(csrf_token.form_value()); @@ -390,6 +393,7 @@ pub(crate) async fn post( .user() .lookup(user_id) .await? + .filter(mas_data_model::User::is_valid) .ok_or(RouteError::UserNotFound)?; repo.browser_session().add(&mut rng, &clock, &user).await? diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index fd4a07e29..f9168003d 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -202,6 +202,7 @@ async fn login( .find_by_username(username) .await .map_err(|_e| FormError::Internal)? + .filter(mas_data_model::User::is_valid) .ok_or(FormError::InvalidCredentials)?; // And its password diff --git a/crates/storage-pg/.sqlx/query-22896e8f2a002f307089c3e0f9ee561e6521c45ce07d3a42411984c9a6b75fdc.json b/crates/storage-pg/.sqlx/query-22896e8f2a002f307089c3e0f9ee561e6521c45ce07d3a42411984c9a6b75fdc.json new file mode 100644 index 000000000..8d1577563 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-22896e8f2a002f307089c3e0f9ee561e6521c45ce07d3a42411984c9a6b75fdc.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE users\n SET locked_at = NULL\n WHERE user_id = $1\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "22896e8f2a002f307089c3e0f9ee561e6521c45ce07d3a42411984c9a6b75fdc" +} diff --git a/crates/storage-pg/.sqlx/query-25d61a373560556deafe056c8cd2982ac472f5ec2fab08b0b5275c4b78c11a7e.json b/crates/storage-pg/.sqlx/query-73fe61f03a41778c6273b1c2dbdb13b91fbccfe5fbdbead8c4868d52a61a0f9d.json similarity index 66% rename from crates/storage-pg/.sqlx/query-25d61a373560556deafe056c8cd2982ac472f5ec2fab08b0b5275c4b78c11a7e.json rename to crates/storage-pg/.sqlx/query-73fe61f03a41778c6273b1c2dbdb13b91fbccfe5fbdbead8c4868d52a61a0f9d.json index 6ea5b01e3..146624950 100644 --- a/crates/storage-pg/.sqlx/query-25d61a373560556deafe056c8cd2982ac472f5ec2fab08b0b5275c4b78c11a7e.json +++ b/crates/storage-pg/.sqlx/query-73fe61f03a41778c6273b1c2dbdb13b91fbccfe5fbdbead8c4868d52a61a0f9d.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT s.user_session_id\n , s.created_at AS \"user_session_created_at\"\n , s.finished_at AS \"user_session_finished_at\"\n , u.user_id\n , u.username AS \"user_username\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n WHERE s.user_session_id = $1\n ", + "query": "\n SELECT s.user_session_id\n , s.created_at AS \"user_session_created_at\"\n , s.finished_at AS \"user_session_finished_at\"\n , u.user_id\n , u.username AS \"user_username\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n , u.created_at AS \"user_created_at\"\n , u.locked_at AS \"user_locked_at\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n WHERE s.user_session_id = $1\n ", "describe": { "columns": [ { @@ -32,6 +32,16 @@ "ordinal": 5, "name": "user_primary_user_email_id", "type_info": "Uuid" + }, + { + "ordinal": 6, + "name": "user_created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 7, + "name": "user_locked_at", + "type_info": "Timestamptz" } ], "parameters": { @@ -45,8 +55,10 @@ true, false, false, + true, + false, true ] }, - "hash": "25d61a373560556deafe056c8cd2982ac472f5ec2fab08b0b5275c4b78c11a7e" + "hash": "73fe61f03a41778c6273b1c2dbdb13b91fbccfe5fbdbead8c4868d52a61a0f9d" } diff --git a/crates/storage-pg/.sqlx/query-836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c.json b/crates/storage-pg/.sqlx/query-bfa5eaeaa5b4574bb083c86711eb4599f6374c96bb4a6827d400acb22fb0fd39.json similarity index 67% rename from crates/storage-pg/.sqlx/query-836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c.json rename to crates/storage-pg/.sqlx/query-bfa5eaeaa5b4574bb083c86711eb4599f6374c96bb4a6827d400acb22fb0fd39.json index aa1a90fe1..f70906f26 100644 --- a/crates/storage-pg/.sqlx/query-836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c.json +++ b/crates/storage-pg/.sqlx/query-bfa5eaeaa5b4574bb083c86711eb4599f6374c96bb4a6827d400acb22fb0fd39.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n FROM users\n WHERE username = $1\n ", + "query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n , locked_at\n FROM users\n WHERE username = $1\n ", "describe": { "columns": [ { @@ -22,6 +22,11 @@ "ordinal": 3, "name": "created_at", "type_info": "Timestamptz" + }, + { + "ordinal": 4, + "name": "locked_at", + "type_info": "Timestamptz" } ], "parameters": { @@ -33,8 +38,9 @@ false, false, true, - false + false, + true ] }, - "hash": "836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c" + "hash": "bfa5eaeaa5b4574bb083c86711eb4599f6374c96bb4a6827d400acb22fb0fd39" } diff --git a/crates/storage-pg/.sqlx/query-c29fa41743811a6ac3a9b952b6ea75d18e914f823902587b63c9f295407144b1.json b/crates/storage-pg/.sqlx/query-c29fa41743811a6ac3a9b952b6ea75d18e914f823902587b63c9f295407144b1.json new file mode 100644 index 000000000..3ba8612b4 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-c29fa41743811a6ac3a9b952b6ea75d18e914f823902587b63c9f295407144b1.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE users\n SET locked_at = $1\n WHERE user_id = $2\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Timestamptz", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "c29fa41743811a6ac3a9b952b6ea75d18e914f823902587b63c9f295407144b1" +} diff --git a/crates/storage-pg/.sqlx/query-08d7df347c806ef14b6d0fb031cab041d79ba48528420160e23286369db7af35.json b/crates/storage-pg/.sqlx/query-e0ea7d93ab3f565828b2faab4cc5e1a6ac868c95bfaee3a6960df1cf484d53da.json similarity index 67% rename from crates/storage-pg/.sqlx/query-08d7df347c806ef14b6d0fb031cab041d79ba48528420160e23286369db7af35.json rename to crates/storage-pg/.sqlx/query-e0ea7d93ab3f565828b2faab4cc5e1a6ac868c95bfaee3a6960df1cf484d53da.json index 9a30425ae..c38eb7dbf 100644 --- a/crates/storage-pg/.sqlx/query-08d7df347c806ef14b6d0fb031cab041d79ba48528420160e23286369db7af35.json +++ b/crates/storage-pg/.sqlx/query-e0ea7d93ab3f565828b2faab4cc5e1a6ac868c95bfaee3a6960df1cf484d53da.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n FROM users\n WHERE user_id = $1\n ", + "query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n , locked_at\n FROM users\n WHERE user_id = $1\n ", "describe": { "columns": [ { @@ -22,6 +22,11 @@ "ordinal": 3, "name": "created_at", "type_info": "Timestamptz" + }, + { + "ordinal": 4, + "name": "locked_at", + "type_info": "Timestamptz" } ], "parameters": { @@ -33,8 +38,9 @@ false, false, true, - false + false, + true ] }, - "hash": "08d7df347c806ef14b6d0fb031cab041d79ba48528420160e23286369db7af35" + "hash": "e0ea7d93ab3f565828b2faab4cc5e1a6ac868c95bfaee3a6960df1cf484d53da" } diff --git a/crates/storage-pg/migrations/20230728154304_user_lock.sql b/crates/storage-pg/migrations/20230728154304_user_lock.sql new file mode 100644 index 000000000..a015c7522 --- /dev/null +++ b/crates/storage-pg/migrations/20230728154304_user_lock.sql @@ -0,0 +1,19 @@ +-- Copyright 2023 The Matrix.org Foundation C.I.C. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- Add a new column in on the `users` to record when an account gets locked +ALTER TABLE "users" + ADD COLUMN "locked_at" + TIMESTAMP WITH TIME ZONE + DEFAULT NULL; \ No newline at end of file diff --git a/crates/storage-pg/src/iden.rs b/crates/storage-pg/src/iden.rs index 29978ab8a..c5175f9f8 100644 --- a/crates/storage-pg/src/iden.rs +++ b/crates/storage-pg/src/iden.rs @@ -29,6 +29,8 @@ pub enum Users { UserId, Username, PrimaryUserEmailId, + CreatedAt, + LockedAt, } #[derive(sea_query::Iden)] diff --git a/crates/storage-pg/src/user/mod.rs b/crates/storage-pg/src/user/mod.rs index 468736e67..f5cf37d8c 100644 --- a/crates/storage-pg/src/user/mod.rs +++ b/crates/storage-pg/src/user/mod.rs @@ -55,9 +55,8 @@ struct UserLookup { user_id: Uuid, username: String, primary_user_email_id: Option, - - #[allow(dead_code)] created_at: DateTime, + locked_at: Option>, } impl From for User { @@ -68,6 +67,8 @@ impl From for User { username: value.username, sub: id.to_string(), primary_user_email_id: value.primary_user_email_id.map(Into::into), + created_at: value.created_at, + locked_at: value.locked_at, } } } @@ -93,6 +94,7 @@ impl<'c> UserRepository for PgUserRepository<'c> { , username , primary_user_email_id , created_at + , locked_at FROM users WHERE user_id = $1 "#, @@ -124,6 +126,7 @@ impl<'c> UserRepository for PgUserRepository<'c> { , username , primary_user_email_id , created_at + , locked_at FROM users WHERE username = $1 "#, @@ -176,6 +179,8 @@ impl<'c> UserRepository for PgUserRepository<'c> { username, sub: id.to_string(), primary_user_email_id: None, + created_at, + locked_at: None, }) } @@ -203,4 +208,72 @@ impl<'c> UserRepository for PgUserRepository<'c> { Ok(exists) } + + #[tracing::instrument( + name = "db.user.lock", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result { + if user.locked_at.is_some() { + return Ok(user); + } + + let locked_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE users + SET locked_at = $1 + WHERE user_id = $2 + "#, + locked_at, + Uuid::from(user.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + user.locked_at = Some(locked_at); + + Ok(user) + } + + #[tracing::instrument( + name = "db.user.unlock", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn unlock(&mut self, mut user: User) -> Result { + if user.locked_at.is_none() { + return Ok(user); + } + + let res = sqlx::query!( + r#" + UPDATE users + SET locked_at = NULL + WHERE user_id = $1 + "#, + Uuid::from(user.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + user.locked_at = None; + + Ok(user) + } } diff --git a/crates/storage-pg/src/user/session.rs b/crates/storage-pg/src/user/session.rs index c0afdc637..3b094c483 100644 --- a/crates/storage-pg/src/user/session.rs +++ b/crates/storage-pg/src/user/session.rs @@ -53,6 +53,8 @@ struct SessionLookup { user_id: Uuid, user_username: String, user_primary_user_email_id: Option, + user_created_at: DateTime, + user_locked_at: Option>, } impl TryFrom for BrowserSession { @@ -65,6 +67,8 @@ impl TryFrom for BrowserSession { username: value.user_username, sub: id.to_string(), primary_user_email_id: value.user_primary_user_email_id.map(Into::into), + created_at: value.user_created_at, + locked_at: value.user_locked_at, }; Ok(BrowserSession { @@ -99,6 +103,8 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { , u.user_id , u.username AS "user_username" , u.primary_user_email_id AS "user_primary_user_email_id" + , u.created_at AS "user_created_at" + , u.locked_at AS "user_locked_at" FROM user_sessions s INNER JOIN users u USING (user_id) @@ -232,6 +238,14 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { Expr::col((Users::Table, Users::PrimaryUserEmailId)), SessionLookupIden::UserPrimaryUserEmailId, ) + .expr_as( + Expr::col((Users::Table, Users::CreatedAt)), + SessionLookupIden::UserCreatedAt, + ) + .expr_as( + Expr::col((Users::Table, Users::LockedAt)), + SessionLookupIden::UserLockedAt, + ) .from(UserSessions::Table) .inner_join( Users::Table, diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 65709ecf3..82ad656ef 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -96,6 +96,33 @@ pub trait UserRepository: Send + Sync { /// /// Returns [`Self::Error`] if the underlying repository fails async fn exists(&mut self, username: &str) -> Result; + + /// Lock a [`User`] + /// + /// Returns the locked [`User`] + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `user`: The [`User`] to lock + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lock(&mut self, clock: &dyn Clock, user: User) -> Result; + + /// Unlock a [`User`] + /// + /// Returns the unlocked [`User`] + /// + /// # Parameters + /// + /// * `user`: The [`User`] to unlock + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn unlock(&mut self, user: User) -> Result; } repository_impl!(UserRepository: @@ -108,4 +135,6 @@ repository_impl!(UserRepository: username: String, ) -> Result; async fn exists(&mut self, username: &str) -> Result; + async fn lock(&mut self, clock: &dyn Clock, user: User) -> Result; + async fn unlock(&mut self, user: User) -> Result; ); From e4c3b9fd9fe4235cb045370fd08ddc4cc2d9e574 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 31 Jul 2023 17:42:12 +0200 Subject: [PATCH 2/8] storage-pg: add tests for user locking --- ...e92abf78dfbdf1a25e58a2bc9c14be8035f0.json} | 4 +-- crates/storage-pg/src/user/mod.rs | 7 ++++- crates/storage-pg/src/user/tests.rs | 26 +++++++++++++++++++ 3 files changed, 34 insertions(+), 3 deletions(-) rename crates/storage-pg/.sqlx/{query-b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537.json => query-7f4c4634ada4dc2745530dcca8eee92abf78dfbdf1a25e58a2bc9c14be8035f0.json} (70%) diff --git a/crates/storage-pg/.sqlx/query-b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537.json b/crates/storage-pg/.sqlx/query-7f4c4634ada4dc2745530dcca8eee92abf78dfbdf1a25e58a2bc9c14be8035f0.json similarity index 70% rename from crates/storage-pg/.sqlx/query-b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537.json rename to crates/storage-pg/.sqlx/query-7f4c4634ada4dc2745530dcca8eee92abf78dfbdf1a25e58a2bc9c14be8035f0.json index accc70f5e..e9af498f4 100644 --- a/crates/storage-pg/.sqlx/query-b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537.json +++ b/crates/storage-pg/.sqlx/query-7f4c4634ada4dc2745530dcca8eee92abf78dfbdf1a25e58a2bc9c14be8035f0.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n INSERT INTO users (user_id, username, created_at)\n VALUES ($1, $2, $3)\n ", + "query": "\n INSERT INTO users (user_id, username, created_at)\n VALUES ($1, $2, $3)\n ON CONFLICT (username) DO NOTHING\n ", "describe": { "columns": [], "parameters": { @@ -12,5 +12,5 @@ }, "nullable": [] }, - "hash": "b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537" + "hash": "7f4c4634ada4dc2745530dcca8eee92abf78dfbdf1a25e58a2bc9c14be8035f0" } diff --git a/crates/storage-pg/src/user/mod.rs b/crates/storage-pg/src/user/mod.rs index f5cf37d8c..7dbf666d2 100644 --- a/crates/storage-pg/src/user/mod.rs +++ b/crates/storage-pg/src/user/mod.rs @@ -161,10 +161,11 @@ impl<'c> UserRepository for PgUserRepository<'c> { let id = Ulid::from_datetime_with_source(created_at.into(), rng); tracing::Span::current().record("user.id", tracing::field::display(id)); - sqlx::query!( + let res = sqlx::query!( r#" INSERT INTO users (user_id, username, created_at) VALUES ($1, $2, $3) + ON CONFLICT (username) DO NOTHING "#, Uuid::from(id), username, @@ -174,6 +175,10 @@ impl<'c> UserRepository for PgUserRepository<'c> { .execute(&mut *self.conn) .await?; + // If the user already exists, want to return an error but not poison the + // transaction + DatabaseError::ensure_affected_rows(&res, 1)?; + Ok(User { id, username, diff --git a/crates/storage-pg/src/user/tests.rs b/crates/storage-pg/src/user/tests.rs index 7036ecbdb..cf46bfa62 100644 --- a/crates/storage-pg/src/user/tests.rs +++ b/crates/storage-pg/src/user/tests.rs @@ -63,12 +63,38 @@ async fn test_user_repo(pool: PgPool) { assert!(repo.user().lookup(user.id).await.unwrap().is_some()); // Adding a second time should give a conflict + // It should not poison the transaction though assert!(repo .user() .add(&mut rng, &clock, USERNAME.to_owned()) .await .is_err()); + // Try locking a user + assert!(user.is_valid()); + let user = repo.user().lock(&clock, user).await.unwrap(); + assert!(!user.is_valid()); + + // Check that the property is retrieved on lookup + let user = repo.user().lookup(user.id).await.unwrap().unwrap(); + assert!(!user.is_valid()); + + // Locking a second time should not fail + let user = repo.user().lock(&clock, user).await.unwrap(); + assert!(!user.is_valid()); + + // Try unlocking a user + let user = repo.user().unlock(user).await.unwrap(); + assert!(user.is_valid()); + + // Check that the property is retrieved on lookup + let user = repo.user().lookup(user.id).await.unwrap().unwrap(); + assert!(user.is_valid()); + + // Unlocking a second time should not fail + let user = repo.user().unlock(user).await.unwrap(); + assert!(user.is_valid()); + repo.save().await.unwrap(); } From ef29dd15f547c775202db839ee23030f358cc95a Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 1 Aug 2023 17:47:04 +0200 Subject: [PATCH 3/8] Implement a mocked HomeserverConnection which keeps state around --- Cargo.lock | 2 + crates/handlers/src/test_utils.rs | 40 +------- crates/matrix/Cargo.toml | 2 + crates/matrix/src/lib.rs | 10 +- crates/matrix/src/mock.rs | 155 ++++++++++++++++++++++++++++++ 5 files changed, 168 insertions(+), 41 deletions(-) create mode 100644 crates/matrix/src/mock.rs diff --git a/Cargo.lock b/Cargo.lock index 579b9c8b5..fd42f90ac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3362,9 +3362,11 @@ dependencies = [ name = "mas-matrix" version = "0.1.0" dependencies = [ + "anyhow", "async-trait", "http", "serde", + "tokio", "url", ] diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index edcdaebb2..720d2deee 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -23,7 +23,7 @@ use headers::{Authorization, ContentType, HeaderMapExt, HeaderName}; use hyper::{header::CONTENT_TYPE, Request, Response, StatusCode}; use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey}; -use mas_matrix::{HomeserverConnection, MatrixUser, ProvisionRequest}; +use mas_matrix::{HomeserverConnection, MatrixUser, MockHomeserverConnection, ProvisionRequest}; use mas_policy::PolicyFactory; use mas_router::{SimpleRoute, UrlBuilder}; use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository}; @@ -69,40 +69,6 @@ pub(crate) struct TestState { pub rng: Arc>, } -/// A Mock implementation of a [`HomeserverConnection`], which never fails and -/// doesn't do anything. -struct MockHomeserverConnection { - homeserver: String, -} - -#[async_trait] -impl HomeserverConnection for MockHomeserverConnection { - type Error = anyhow::Error; - - fn homeserver(&self) -> &str { - &self.homeserver - } - - async fn query_user(&self, _mxid: &str) -> Result { - Ok(MatrixUser { - displayname: None, - avatar_url: None, - }) - } - - async fn provision_user(&self, _request: &ProvisionRequest) -> Result { - Ok(false) - } - - async fn create_device(&self, _mxid: &str, _device_id: &str) -> Result<(), Self::Error> { - Ok(()) - } - - async fn delete_device(&self, _mxid: &str, _device_id: &str) -> Result<(), Self::Error> { - Ok(()) - } -} - impl TestState { /// Create a new test state from the given database pool pub async fn from_pool(pool: PgPool) -> Result { @@ -145,9 +111,7 @@ impl TestState { ) .await?; - let homeserver_connection = MockHomeserverConnection { - homeserver: "example.com".to_owned(), - }; + let homeserver_connection = MockHomeserverConnection::new("example.com"); let policy_factory = Arc::new(policy_factory); diff --git a/crates/matrix/Cargo.toml b/crates/matrix/Cargo.toml index ef9777b69..bb162caa0 100644 --- a/crates/matrix/Cargo.toml +++ b/crates/matrix/Cargo.toml @@ -6,7 +6,9 @@ edition = "2021" license = "Apache-2.0" [dependencies] +anyhow = "1.0.71" serde = { version = "1.0.177", features = ["derive"] } async-trait = "0.1.72" http = "0.2.9" +tokio = { version = "1.28.2", features = ["sync", "macros", "rt"] } url = "2.4.0" diff --git a/crates/matrix/src/lib.rs b/crates/matrix/src/lib.rs index df814452d..b1e3e4e58 100644 --- a/crates/matrix/src/lib.rs +++ b/crates/matrix/src/lib.rs @@ -16,6 +16,10 @@ #![deny(clippy::all, clippy::str_to_string, rustdoc::broken_intra_doc_links)] #![warn(clippy::pedantic)] +mod mock; + +pub use self::mock::MockHomeserverConnection; + #[derive(Debug)] pub struct MatrixUser { pub displayname: Option, @@ -40,10 +44,10 @@ pub struct ProvisionRequest { impl ProvisionRequest { #[must_use] - pub fn new(mxid: String, sub: String) -> Self { + pub fn new(mxid: impl Into, sub: impl Into) -> Self { Self { - mxid, - sub, + mxid: mxid.into(), + sub: sub.into(), displayname: FieldAction::DoNothing, avatar_url: FieldAction::DoNothing, emails: FieldAction::DoNothing, diff --git a/crates/matrix/src/mock.rs b/crates/matrix/src/mock.rs new file mode 100644 index 000000000..b13aa0bec --- /dev/null +++ b/crates/matrix/src/mock.rs @@ -0,0 +1,155 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::{HashMap, HashSet}; + +use anyhow::Context; +use async_trait::async_trait; +use tokio::sync::RwLock; + +use crate::{HomeserverConnection, MatrixUser, ProvisionRequest}; + +struct MockUser { + sub: String, + avatar_url: Option, + displayname: Option, + devices: HashSet, + emails: Option>, +} + +/// A Mock implementation of a [`HomeserverConnection`], which never fails and +/// doesn't do anything. +pub struct MockHomeserverConnection { + homeserver: String, + users: RwLock>, +} + +impl MockHomeserverConnection { + /// Create a new [`MockHomeserverConnection`]. + pub fn new(homeserver: H) -> Self + where + H: Into, + { + Self { + homeserver: homeserver.into(), + users: RwLock::new(HashMap::new()), + } + } +} + +#[async_trait] +impl HomeserverConnection for MockHomeserverConnection { + type Error = anyhow::Error; + + fn homeserver(&self) -> &str { + &self.homeserver + } + + async fn query_user(&self, mxid: &str) -> Result { + let users = self.users.read().await; + let user = users.get(mxid).context("User not found")?; + Ok(MatrixUser { + displayname: user.displayname.clone(), + avatar_url: user.avatar_url.clone(), + }) + } + + async fn provision_user(&self, request: &ProvisionRequest) -> Result { + let mut users = self.users.write().await; + let inserted = !users.contains_key(request.mxid()); + let user = users.entry(request.mxid().to_owned()).or_insert(MockUser { + sub: request.sub().to_owned(), + avatar_url: None, + displayname: None, + devices: HashSet::new(), + emails: None, + }); + + anyhow::ensure!( + user.sub == request.sub(), + "User already provisioned with different sub" + ); + + request.on_emails(|emails| { + user.emails = emails.map(ToOwned::to_owned); + }); + + request.on_displayname(|displayname| { + user.displayname = displayname.map(ToOwned::to_owned); + }); + + request.on_avatar_url(|avatar_url| { + user.avatar_url = avatar_url.map(ToOwned::to_owned); + }); + + Ok(inserted) + } + + async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { + let mut users = self.users.write().await; + let user = users.get_mut(mxid).context("User not found")?; + user.devices.insert(device_id.to_owned()); + Ok(()) + } + + async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { + let mut users = self.users.write().await; + let user = users.get_mut(mxid).context("User not found")?; + user.devices.remove(device_id); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_mock_connection() { + let conn = MockHomeserverConnection::new("example.org"); + + let mxid = "@test:example.org"; + let device = "test"; + assert_eq!(conn.homeserver(), "example.org"); + assert_eq!(conn.mxid("test"), mxid); + + assert!(conn.query_user(mxid).await.is_err()); + assert!(conn.create_device(mxid, device).await.is_err()); + assert!(conn.delete_device(mxid, device).await.is_err()); + + let request = ProvisionRequest::new("@test:example.org", "test") + .set_displayname("Test User".into()) + .set_avatar_url("mxc://example.org/1234567890".into()) + .set_emails(vec!["test@example.org".to_owned()]); + + let inserted = conn.provision_user(&request).await.unwrap(); + assert!(inserted); + + let user = conn.query_user("@test:example.org").await.unwrap(); + assert_eq!(user.displayname, Some("Test User".into())); + assert_eq!(user.avatar_url, Some("mxc://example.org/1234567890".into())); + + // Deleting a non-existent device should not fail + assert!(conn.delete_device(mxid, device).await.is_ok()); + + // Create the device + assert!(conn.create_device(mxid, device).await.is_ok()); + // Create the same device again + assert!(conn.create_device(mxid, device).await.is_ok()); + + // XXX: there is no API to query devices yet in the trait + // Delete the device + assert!(conn.delete_device(mxid, device).await.is_ok()); + } +} From 30cd9f611397ddca25854eb60d0f8d4034bd8d11 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 3 Aug 2023 12:14:44 +0200 Subject: [PATCH 4/8] Show and log the policy violations better --- .../src/oauth2/authorization/complete.rs | 72 ++++++++++--------- .../handlers/src/oauth2/authorization/mod.rs | 46 ++++++------ crates/storage-pg/src/repository.rs | 15 +++- 3 files changed, 79 insertions(+), 54 deletions(-) diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index b8a03159e..6a67cd141 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -16,23 +16,24 @@ use std::sync::Arc; use axum::{ extract::{Path, State}, - response::{IntoResponse, Response}, + response::{Html, IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; use hyper::StatusCode; -use mas_axum_utils::SessionInfoExt; +use mas_axum_utils::{csrf::CsrfExt, SessionInfoExt}; use mas_data_model::{AuthorizationGrant, BrowserSession, Client, Device}; use mas_keystore::{Encrypter, Keystore}; -use mas_policy::PolicyFactory; +use mas_policy::{EvaluationResult, PolicyFactory}; use mas_router::{PostAuthAction, Route, UrlBuilder}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository}, user::BrowserSessionRepository, - BoxClock, BoxRepository, BoxRng, RepositoryAccess, + BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, }; -use mas_templates::Templates; +use mas_templates::{PolicyViolationContext, TemplateContext, Templates}; use oauth2_types::requests::AuthorizationResponse; use thiserror::Error; +use tracing::warn; use ulid::Ulid; use super::callback::CallbackDestination; @@ -74,6 +75,7 @@ impl IntoResponse for RouteError { } impl_from_error_for_route!(mas_storage::RepositoryError); +impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); @@ -87,7 +89,7 @@ impl_from_error_for_route!(super::callback::CallbackDestinationError); err, )] pub(crate) async fn get( - rng: BoxRng, + mut rng: BoxRng, clock: BoxClock, State(policy_factory): State>, State(templates): State, @@ -123,15 +125,15 @@ pub(crate) async fn get( .ok_or(RouteError::NoSuchClient)?; match complete( - rng, - clock, + &mut rng, + &clock, repo, key_store, &policy_factory, url_builder, grant, - client, - session, + &client, + &session, ) .await { @@ -144,10 +146,22 @@ pub(crate) async fn get( mas_router::Reauth::and_then(continue_grant).go(), ) .into_response()), - Err(GrantCompletionError::RequiresConsent | GrantCompletionError::PolicyViolation) => { + Err(GrantCompletionError::RequiresConsent) => { let next = mas_router::Consent(grant_id); Ok((cookie_jar, next.go()).into_response()) } + Err(GrantCompletionError::PolicyViolation(grant, res)) => { + warn!(violation = ?res, "Authorization grant for client {} denied by policy", client.id); + + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); + let ctx = PolicyViolationContext::new(grant, client) + .with_session(session) + .with_csrf(csrf_token.form_value()); + + let content = templates.render_policy_violation(&ctx).await?; + + Ok((cookie_jar, Html(content)).into_response()) + } Err(GrantCompletionError::NotPending) => Err(RouteError::NotPending), Err(GrantCompletionError::Internal(e)) => Err(RouteError::Internal(e)), } @@ -168,7 +182,7 @@ pub enum GrantCompletionError { RequiresConsent, #[error("denied by the policy")] - PolicyViolation, + PolicyViolation(AuthorizationGrant, EvaluationResult), } impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError); @@ -179,15 +193,15 @@ impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError); impl_from_error_for_route!(GrantCompletionError: super::super::IdTokenSignatureError); pub(crate) async fn complete( - mut rng: BoxRng, - clock: BoxClock, + rng: &mut (impl rand::RngCore + rand::CryptoRng + Send), + clock: &impl Clock, mut repo: BoxRepository, key_store: Keystore, policy_factory: &PolicyFactory, url_builder: UrlBuilder, grant: AuthorizationGrant, - client: Client, - browser_session: BrowserSession, + client: &Client, + browser_session: &BrowserSession, ) -> Result { // Verify that the grant is in a pending stage if !grant.stage.is_pending() { @@ -197,7 +211,7 @@ pub(crate) async fn complete( // Check if the authentication is fresh enough let authentication = repo .browser_session() - .get_last_authentication(&browser_session) + .get_last_authentication(browser_session) .await?; let authentication = authentication.filter(|auth| auth.created_at > grant.max_auth_time()); @@ -209,16 +223,16 @@ pub(crate) async fn complete( // Run through the policy let mut policy = policy_factory.instantiate().await?; let res = policy - .evaluate_authorization_grant(&grant, &client, &browser_session.user) + .evaluate_authorization_grant(&grant, client, &browser_session.user) .await?; if !res.valid() { - return Err(GrantCompletionError::PolicyViolation); + return Err(GrantCompletionError::PolicyViolation(grant, res)); } let current_consent = repo .oauth2_client() - .get_consent_for_user(&client, &browser_session.user) + .get_consent_for_user(client, &browser_session.user) .await?; let lacks_consent = grant @@ -236,18 +250,12 @@ pub(crate) async fn complete( // All good, let's start the session let session = repo .oauth2_session() - .add( - &mut rng, - &clock, - &client, - &browser_session, - grant.scope.clone(), - ) + .add(rng, clock, client, browser_session, grant.scope.clone()) .await?; let grant = repo .oauth2_authorization_grant() - .fulfill(&clock, &session, grant) + .fulfill(clock, &session, grant) .await?; // Yep! Let's complete the auth now @@ -256,13 +264,13 @@ pub(crate) async fn complete( // Did they request an ID token? if grant.response_type_id_token { params.id_token = Some(generate_id_token( - &mut rng, - &clock, + rng, + clock, &url_builder, &key_store, - &client, + client, &grant, - &browser_session, + browser_session, None, Some(&valid_authentication), )?); diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 1230362dd..1fd4b52cc 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -16,11 +16,11 @@ use std::sync::Arc; use axum::{ extract::{Form, State}, - response::{IntoResponse, Response}, + response::{Html, IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; use hyper::StatusCode; -use mas_axum_utils::SessionInfoExt; +use mas_axum_utils::{csrf::CsrfExt, SessionInfoExt}; use mas_data_model::{AuthorizationCode, Pkce}; use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; @@ -29,7 +29,7 @@ use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, BoxClock, BoxRepository, BoxRng, }; -use mas_templates::Templates; +use mas_templates::{PolicyViolationContext, TemplateContext, Templates}; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, pkce, @@ -39,6 +39,7 @@ use oauth2_types::{ use rand::{distributions::Alphanumeric, Rng}; use serde::Deserialize; use thiserror::Error; +use tracing::warn; use self::{callback::CallbackDestination, complete::GrantCompletionError}; use crate::impl_from_error_for_route; @@ -91,6 +92,7 @@ impl IntoResponse for RouteError { } impl_from_error_for_route!(mas_storage::RepositoryError); +impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(self::callback::CallbackDestinationError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); @@ -170,6 +172,7 @@ pub(crate) async fn get( // Get the session info from the cookie let (session_info, cookie_jar) = cookie_jar.session_info(); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); // One day, we will have try blocks let res: Result = ({ @@ -340,15 +343,15 @@ pub(crate) async fn get( Some(user_session) if prompt.contains(&Prompt::None) => { // With prompt=none, we should get back to the client immediately match self::complete::complete( - rng, - clock, + &mut rng, + &clock, repo, key_store, &policy_factory, url_builder, grant, - client, - user_session, + &client, + &user_session, ) .await { @@ -369,7 +372,7 @@ pub(crate) async fn get( ) .await? } - Err(GrantCompletionError::PolicyViolation) => { + Err(GrantCompletionError::PolicyViolation(_grant, _res)) => { callback_destination .go(&templates, ClientError::from(ClientErrorCode::AccessDenied)) .await? @@ -387,29 +390,32 @@ pub(crate) async fn get( let grant_id = grant.id; // Else, we show the relevant reauth/consent page if necessary match self::complete::complete( - rng, - clock, + &mut rng, + &clock, repo, key_store, &policy_factory, url_builder, grant, - client, - user_session, + &client, + &user_session, ) .await { Ok(params) => callback_destination.go(&templates, params).await?, - Err( - GrantCompletionError::RequiresConsent - | GrantCompletionError::PolicyViolation, - ) => { - // We're redirecting to the consent URI in both 'consent required' and - // 'policy violation' cases, because we reevaluate the policy on this - // page, and show the error accordingly - // XXX: is this the right approach? + Err(GrantCompletionError::RequiresConsent) => { mas_router::Consent(grant_id).go().into_response() } + Err(GrantCompletionError::PolicyViolation(grant, res)) => { + warn!(violation = ?res, "Authorization grant for client {} denied by policy", client.id); + + let ctx = PolicyViolationContext::new(grant, client) + .with_session(user_session) + .with_csrf(csrf_token.form_value()); + + let content = templates.render_policy_violation(&ctx).await?; + Html(content).into_response() + } Err(GrantCompletionError::RequiresReauth) => { mas_router::Reauth::and_then(continue_grant) .go() diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index d5a42792e..3e749d97a 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -31,6 +31,7 @@ use mas_storage::{ Repository, RepositoryAccess, RepositoryTransaction, }; use sqlx::{PgPool, Postgres, Transaction}; +use tracing::Instrument; use crate::{ compat::{ @@ -78,11 +79,21 @@ impl RepositoryTransaction for PgRepository { type Error = DatabaseError; fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { - self.txn.commit().map_err(DatabaseError::from).boxed() + let span = tracing::info_span!("db.save"); + self.txn + .commit() + .map_err(DatabaseError::from) + .instrument(span) + .boxed() } fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { - self.txn.rollback().map_err(DatabaseError::from).boxed() + let span = tracing::info_span!("db.cancel"); + self.txn + .rollback() + .map_err(DatabaseError::from) + .instrument(span) + .boxed() } } From c821e3de54149b21135c2aab907aa477dcd4dc45 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 3 Aug 2023 14:03:07 +0200 Subject: [PATCH 5/8] Properly trace the `cleanup-expired-tokens` job --- crates/tasks/src/database.rs | 8 +++++++- crates/tasks/src/email.rs | 1 + crates/tasks/src/utils.rs | 33 ++++++++++++++++++++++++--------- 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/crates/tasks/src/database.rs b/crates/tasks/src/database.rs index 6c639eafe..d8f8eccc1 100644 --- a/crates/tasks/src/database.rs +++ b/crates/tasks/src/database.rs @@ -29,7 +29,10 @@ use chrono::{DateTime, Utc}; use mas_storage::{oauth2::OAuth2AccessTokenRepository, RepositoryAccess}; use tracing::{debug, info}; -use crate::{utils::metrics_layer, JobContextExt, State}; +use crate::{ + utils::{metrics_layer, trace_layer, TracedJob}, + JobContextExt, State, +}; #[derive(Default, Clone)] pub struct CleanupExpiredTokensJob { @@ -46,6 +49,8 @@ impl Job for CleanupExpiredTokensJob { const NAME: &'static str = "cleanup-expired-tokens"; } +impl TracedJob for CleanupExpiredTokensJob {} + pub async fn cleanup_expired_tokens( job: CleanupExpiredTokensJob, ctx: JobContext, @@ -79,6 +84,7 @@ pub(crate) fn register( .stream(CronStream::new(schedule).timer(TokioTimer).to_stream()) .layer(state.inject()) .layer(metrics_layer()) + .layer(trace_layer()) .build_fn(cleanup_expired_tokens); monitor.register(worker) diff --git a/crates/tasks/src/email.rs b/crates/tasks/src/email.rs index e1caf7128..8646ff0fa 100644 --- a/crates/tasks/src/email.rs +++ b/crates/tasks/src/email.rs @@ -109,5 +109,6 @@ pub(crate) fn register( c.fetch_interval(std::time::Duration::from_secs(1)) }) .build_fn(verify_email); + monitor.register(worker) } diff --git a/crates/tasks/src/utils.rs b/crates/tasks/src/utils.rs index 804658b74..a607c6af1 100644 --- a/crates/tasks/src/utils.rs +++ b/crates/tasks/src/utils.rs @@ -18,14 +18,32 @@ use mas_tower::{ make_span_fn, DurationRecorderLayer, FnWrapper, IdentityLayer, InFlightCounterLayer, TraceLayer, KV, }; -use opentelemetry::{Key, KeyValue}; +use opentelemetry::{trace::SpanContext, Key, KeyValue}; use tracing::info_span; use tracing_opentelemetry::OpenTelemetrySpanExt; const JOB_NAME: Key = Key::from_static_str("job.name"); const JOB_STATUS: Key = Key::from_static_str("job.status"); -fn make_span_for_job_request(req: &JobRequest>) -> tracing::Span +/// Represents a job that can may have a span context attached to it. +pub trait TracedJob: Job { + /// Returns the span context for this job, if any. + /// + /// The default implementation returns `None`. + fn span_context(&self) -> Option { + None + } +} + +/// Implements [`TracedJob`] for any job with the [`JobWithSpanContext`] +/// wrapper. +impl TracedJob for JobWithSpanContext { + fn span_context(&self) -> Option { + JobWithSpanContext::span_context(self) + } +} + +fn make_span_for_job_request(req: &JobRequest) -> tracing::Span where J: Job, { @@ -45,18 +63,15 @@ where span } -type TraceLayerForJob = TraceLayer< - FnWrapper>) -> tracing::Span>, - KV<&'static str>, - KV<&'static str>, ->; +type TraceLayerForJob = + TraceLayer) -> tracing::Span>, KV<&'static str>, KV<&'static str>>; pub(crate) fn trace_layer() -> TraceLayerForJob where - J: Job, + J: TracedJob, { TraceLayer::new(make_span_fn( - make_span_for_job_request:: as fn(&JobRequest>) -> tracing::Span, + make_span_for_job_request:: as fn(&JobRequest) -> tracing::Span, )) .on_response(KV("otel.status_code", "OK")) .on_error(KV("otel.status_code", "ERROR")) From bea8e4eff43e091d39615cd6a991955c17567df0 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 3 Aug 2023 14:05:10 +0200 Subject: [PATCH 6/8] Call the homeserver for user deactivation --- Cargo.lock | 1 + crates/cli/src/commands/manage.rs | 25 +++++-- crates/matrix-synapse/Cargo.toml | 3 +- crates/matrix-synapse/src/lib.rs | 75 ++++++++++++++++++++ crates/matrix/src/lib.rs | 112 ++++++++++++++++++++++++++++++ crates/matrix/src/mock.rs | 13 ++++ crates/storage/src/job.rs | 52 +++++++++++++- crates/tasks/src/lib.rs | 2 + crates/tasks/src/user.rs | 96 +++++++++++++++++++++++++ 9 files changed, 373 insertions(+), 6 deletions(-) create mode 100644 crates/tasks/src/user.rs diff --git a/Cargo.lock b/Cargo.lock index fd42f90ac..a943b0152 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3382,6 +3382,7 @@ dependencies = [ "mas-matrix", "serde", "tower", + "tracing", "url", ] diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 192b46c4e..4d41ee829 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -18,14 +18,14 @@ use mas_config::{DatabaseConfig, PasswordsConfig}; use mas_data_model::{Device, TokenType}; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, - job::{DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob}, + job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob}, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, Repository, RepositoryAccess, SystemClock, }; use mas_storage_pg::PgRepository; use rand::SeedableRng; use sqlx::types::Uuid; -use tracing::{info, info_span}; +use tracing::{info, info_span, warn}; use crate::util::{database_from_config, password_manager_from_config}; @@ -74,6 +74,10 @@ enum Subcommand { LockUser { /// User to lock username: String, + + /// Whether to deactivate the user + #[arg(long)] + deactivate: bool, }, /// Unlock a user @@ -343,7 +347,10 @@ impl Options { Ok(()) } - SC::LockUser { username } => { + SC::LockUser { + username, + deactivate, + } => { let _span = info_span!("cli.manage.lock_user", user.username = username).entered(); let config: DatabaseConfig = root.load_config()?; let pool = database_from_config(&config).await?; @@ -357,7 +364,17 @@ impl Options { info!(%user.id, "Locking user"); - repo.user().lock(&clock, user).await?; + // Even though the deactivation job will lock the user, we lock it here in case + // the worker is not running, as we don't have a good way to run a job + // synchronously yet. + let user = repo.user().lock(&clock, user).await?; + + if deactivate { + warn!(%user.id, "Scheduling user deactivation"); + repo.job() + .schedule_job(DeactivateUserJob::new(&user, false)) + .await?; + } repo.save().await?; Ok(()) diff --git a/crates/matrix-synapse/Cargo.toml b/crates/matrix-synapse/Cargo.toml index 64bb19f7c..f0a47b976 100644 --- a/crates/matrix-synapse/Cargo.toml +++ b/crates/matrix-synapse/Cargo.toml @@ -9,9 +9,10 @@ license = "Apache-2.0" anyhow = "1.0.72" async-trait = "0.1.72" http = "0.2.9" -url = "2.4.0" serde = { version = "1.0.177", features = ["derive"] } tower = { version = "0.4.13", features = ["util"] } +tracing = "0.1.37" +url = "2.4.0" mas-axum-utils = { path = "../axum-utils" } mas-http = { path = "../http" } diff --git a/crates/matrix-synapse/src/lib.rs b/crates/matrix-synapse/src/lib.rs index 1620faaca..3b888e616 100644 --- a/crates/matrix-synapse/src/lib.rs +++ b/crates/matrix-synapse/src/lib.rs @@ -124,6 +124,11 @@ struct SynapseDevice { device_id: String, } +#[derive(Serialize)] +struct SynapseDeactivateUserRequest { + erase: bool, +} + #[async_trait::async_trait] impl HomeserverConnection for SynapseConnection { type Error = anyhow::Error; @@ -132,6 +137,15 @@ impl HomeserverConnection for SynapseConnection { &self.homeserver } + #[tracing::instrument( + name = "homeserver.query_user", + skip_all, + fields( + matrix.homeserver = self.homeserver, + matrix.mxid = mxid, + ), + err(Display), + )] async fn query_user(&self, mxid: &str) -> Result { let mut client = self .http_client_factory @@ -158,6 +172,16 @@ impl HomeserverConnection for SynapseConnection { }) } + #[tracing::instrument( + name = "homeserver.provision_user", + skip_all, + fields( + matrix.homeserver = self.homeserver, + matrix.mxid = request.mxid(), + user.id = request.sub(), + ), + err(Display), + )] async fn provision_user(&self, request: &ProvisionRequest) -> Result { let mut body = SynapseUser { external_ids: Some(vec![ExternalID { @@ -213,6 +237,16 @@ impl HomeserverConnection for SynapseConnection { } } + #[tracing::instrument( + name = "homeserver.create_device", + skip_all, + fields( + matrix.homeserver = self.homeserver, + matrix.mxid = mxid, + matrix.device_id = device_id, + ), + err(Display), + )] async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { let mut client = self .http_client_factory @@ -236,6 +270,16 @@ impl HomeserverConnection for SynapseConnection { Ok(()) } + #[tracing::instrument( + name = "homeserver.delete_device", + skip_all, + fields( + matrix.homeserver = self.homeserver, + matrix.mxid = mxid, + matrix.device_id = device_id, + ), + err(Display), + )] async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { let mut client = self.http_client_factory.client().await?; @@ -253,4 +297,35 @@ impl HomeserverConnection for SynapseConnection { Ok(()) } + + #[tracing::instrument( + name = "homeserver.delete_user", + skip_all, + fields( + matrix.homeserver = self.homeserver, + matrix.mxid = mxid, + erase = erase, + ), + err(Display), + )] + async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> { + let mut client = self + .http_client_factory + .client() + .await? + .request_bytes_to_body() + .json_request(); + + let request = self + .post(&format!("_synapse/admin/v1/deactivate/{mxid}")) + .body(SynapseDeactivateUserRequest { erase })?; + + let response = client.ready().await?.call(request).await?; + + if response.status() != StatusCode::OK { + return Err(anyhow::anyhow!("Failed to delete user in Synapse")); + } + + Ok(()) + } } diff --git a/crates/matrix/src/lib.rs b/crates/matrix/src/lib.rs index b1e3e4e58..b5f799d1c 100644 --- a/crates/matrix/src/lib.rs +++ b/crates/matrix/src/lib.rs @@ -43,6 +43,12 @@ pub struct ProvisionRequest { } impl ProvisionRequest { + /// Create a new [`ProvisionRequest`]. + /// + /// # Parameters + /// + /// * `mxid` - The Matrix ID to provision. + /// * `sub` - The `sub` of the user, aka the internal ID. #[must_use] pub fn new(mxid: impl Into, sub: impl Into) -> Self { Self { @@ -54,28 +60,41 @@ impl ProvisionRequest { } } + /// Get the `sub` of the user to provision, aka the internal ID. #[must_use] pub fn sub(&self) -> &str { &self.sub } + /// Get the Matrix ID to provision. #[must_use] pub fn mxid(&self) -> &str { &self.mxid } + /// Ask to set the displayname of the user. + /// + /// # Parameters + /// + /// * `displayname` - The displayname to set. #[must_use] pub fn set_displayname(mut self, displayname: String) -> Self { self.displayname = FieldAction::Set(displayname); self } + /// Ask to unset the displayname of the user. #[must_use] pub fn unset_displayname(mut self) -> Self { self.displayname = FieldAction::Unset; self } + /// Call the given callback if the displayname should be set or unset. + /// + /// # Parameters + /// + /// * `callback` - The callback to call. pub fn on_displayname(&self, callback: impl FnOnce(Option<&str>)) -> &Self { match &self.displayname { FieldAction::Unset => callback(None), @@ -86,18 +105,29 @@ impl ProvisionRequest { self } + /// Ask to set the avatar URL of the user. + /// + /// # Parameters + /// + /// * `avatar_url` - The avatar URL to set. #[must_use] pub fn set_avatar_url(mut self, avatar_url: String) -> Self { self.avatar_url = FieldAction::Set(avatar_url); self } + /// Ask to unset the avatar URL of the user. #[must_use] pub fn unset_avatar_url(mut self) -> Self { self.avatar_url = FieldAction::Unset; self } + /// Call the given callback if the avatar URL should be set or unset. + /// + /// # Parameters + /// + /// * `callback` - The callback to call. pub fn on_avatar_url(&self, callback: impl FnOnce(Option<&str>)) -> &Self { match &self.avatar_url { FieldAction::Unset => callback(None), @@ -108,18 +138,29 @@ impl ProvisionRequest { self } + /// Ask to set the emails of the user. + /// + /// # Parameters + /// + /// * `emails` - The list of emails to set. #[must_use] pub fn set_emails(mut self, emails: Vec) -> Self { self.emails = FieldAction::Set(emails); self } + /// Ask to unset the emails of the user. #[must_use] pub fn unset_emails(mut self) -> Self { self.emails = FieldAction::Unset; self } + /// Call the given callback if the emails should be set or unset. + /// + /// # Parameters + /// + /// * `callback` - The callback to call. pub fn on_emails(&self, callback: impl FnOnce(Option<&[String]>)) -> &Self { match &self.emails { FieldAction::Unset => callback(None), @@ -133,17 +174,84 @@ impl ProvisionRequest { #[async_trait::async_trait] pub trait HomeserverConnection: Send + Sync { + /// The error type returned by all methods. type Error; + /// Get the homeserver URL. fn homeserver(&self) -> &str; + + /// Get the Matrix ID of the user with the given localpart. + /// + /// # Parameters + /// + /// * `localpart` - The localpart of the user. fn mxid(&self, localpart: &str) -> String { format!("@{}:{}", localpart, self.homeserver()) } + /// Query the state of a user on the homeserver. + /// + /// # Parameters + /// + /// * `mxid` - The Matrix ID of the user to query. + /// + /// # Errors + /// + /// Returns an error if the homeserver is unreachable or the user does not + /// exist. async fn query_user(&self, mxid: &str) -> Result; + + /// Provision a user on the homeserver. + /// + /// # Parameters + /// + /// * `request` - a [`ProvisionRequest`] containing the details of the user + /// to provision. + /// + /// # Errors + /// + /// Returns an error if the homeserver is unreachable or the user could not + /// be provisioned. async fn provision_user(&self, request: &ProvisionRequest) -> Result; + + /// Create a device for a user on the homeserver. + /// + /// # Parameters + /// + /// * `mxid` - The Matrix ID of the user to create a device for. + /// * `device_id` - The device ID to create. + /// + /// # Errors + /// + /// Returns an error if the homeserver is unreachable or the device could + /// not be created. async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error>; + + /// Delete a device for a user on the homeserver. + /// + /// # Parameters + /// + /// * `mxid` - The Matrix ID of the user to delete a device for. + /// * `device_id` - The device ID to delete. + /// + /// # Errors + /// + /// Returns an error if the homeserver is unreachable or the device could + /// not be deleted. async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error>; + + /// Delete a user on the homeserver. + /// + /// # Parameters + /// + /// * `mxid` - The Matrix ID of the user to delete. + /// * `erase` - Whether to ask the homeserver to erase the user's data. + /// + /// # Errors + /// + /// Returns an error if the homeserver is unreachable or the user could not + /// be deleted. + async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error>; } #[async_trait::async_trait] @@ -169,4 +277,8 @@ impl HomeserverConnection for &T async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { (**self).delete_device(mxid, device_id).await } + + async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> { + (**self).delete_user(mxid, erase).await + } } diff --git a/crates/matrix/src/mock.rs b/crates/matrix/src/mock.rs index b13aa0bec..33d1e0d4b 100644 --- a/crates/matrix/src/mock.rs +++ b/crates/matrix/src/mock.rs @@ -109,6 +109,19 @@ impl HomeserverConnection for MockHomeserverConnection { user.devices.remove(device_id); Ok(()) } + + async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> { + let mut users = self.users.write().await; + let user = users.get_mut(mxid).context("User not found")?; + user.devices.clear(); + user.emails = None; + if erase { + user.avatar_url = None; + user.displayname = None; + } + + Ok(()) + } } #[cfg(test)] diff --git a/crates/storage/src/job.rs b/crates/storage/src/job.rs index d42b21380..2d2f85fc9 100644 --- a/crates/storage/src/job.rs +++ b/crates/storage/src/job.rs @@ -73,6 +73,15 @@ pub struct JobWithSpanContext { payload: T, } +impl From for JobWithSpanContext { + fn from(payload: J) -> Self { + Self { + span_context: None, + payload, + } + } +} + impl Job for JobWithSpanContext { const NAME: &'static str = J::NAME; } @@ -369,6 +378,47 @@ mod jobs { impl Job for DeleteDeviceJob { const NAME: &'static str = "delete-device"; } + + /// A job to deactivate and lock a user + #[derive(Serialize, Deserialize, Debug, Clone)] + pub struct DeactivateUserJob { + user_id: Ulid, + hs_erase: bool, + } + + impl DeactivateUserJob { + /// Create a new job to deactivate and lock a user + /// + /// # Parameters + /// + /// * `user` - The user to deactivate + /// * `hs_erase` - Whether to erase the user from the homeserver + #[must_use] + pub fn new(user: &User, hs_erase: bool) -> Self { + Self { + user_id: user.id, + hs_erase, + } + } + + /// The ID of the user to deactivate + #[must_use] + pub fn user_id(&self) -> Ulid { + self.user_id + } + + /// Whether to erase the user from the homeserver + #[must_use] + pub fn hs_erase(&self) -> bool { + self.hs_erase + } + } + + impl Job for DeactivateUserJob { + const NAME: &'static str = "deactivate-user"; + } } -pub use self::jobs::{DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, VerifyEmailJob}; +pub use self::jobs::{ + DeactivateUserJob, DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, VerifyEmailJob, +}; diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index da5996874..1019bc4fd 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -33,6 +33,7 @@ mod database; mod email; mod matrix; mod storage; +mod user; mod utils; #[derive(Clone)] @@ -128,6 +129,7 @@ pub async fn init( let monitor = self::database::register(name, monitor, &state); let monitor = self::email::register(name, monitor, &state, &factory); let monitor = self::matrix::register(name, monitor, &state, &factory); + let monitor = self::user::register(name, monitor, &state, &factory); // TODO: we might want to grab the join handle here factory.listen().await?; debug!(?monitor, "workers registered"); diff --git a/crates/tasks/src/user.rs b/crates/tasks/src/user.rs new file mode 100644 index 000000000..398f8d542 --- /dev/null +++ b/crates/tasks/src/user.rs @@ -0,0 +1,96 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::time::Duration; + +use anyhow::Context; +use apalis_core::{ + builder::{WorkerBuilder, WorkerFactoryFn}, + context::JobContext, + executor::TokioExecutor, + job::Job, + monitor::Monitor, + storage::builder::WithStorage, +}; +use mas_storage::{ + job::{DeactivateUserJob, DeleteDeviceJob, JobWithSpanContext}, + user::UserRepository, + RepositoryAccess, +}; +use tracing::info; + +use crate::{ + storage::PostgresStorageFactory, + utils::{metrics_layer, trace_layer}, + JobContextExt, State, +}; + +/// Job to deactivate a user, both locally and on the Matrix homeserver. +#[tracing::instrument( + name = "job.deactivate_user" + fields(user.id = %job.user_id(), erase = %job.hs_erase()), + skip_all, + err(Debug), +)] +async fn deactivate_user( + job: JobWithSpanContext, + ctx: JobContext, +) -> Result<(), anyhow::Error> { + let state = ctx.state(); + let clock = state.clock(); + let matrix = state.matrix_connection(); + let mut repo = state.repository().await?; + + let user = repo + .user() + .lookup(job.user_id()) + .await? + .context("User not found")?; + + // Let's first lock the user + let user = repo + .user() + .lock(&clock, user) + .await + .context("Failed to lock user")?; + + // TODO: delete the sessions & access tokens + + // Before calling back to the homeserver, commit the changes to the database + repo.save().await?; + + let mxid = matrix.mxid(&user.username); + info!("Deactivating user {} on homeserver", mxid); + matrix.delete_user(&mxid, job.hs_erase()).await?; + + Ok(()) +} + +pub(crate) fn register( + suffix: &str, + monitor: Monitor, + state: &State, + storage_factory: &PostgresStorageFactory, +) -> Monitor { + let storage = storage_factory.build(); + let worker_name = format!("{job}-{suffix}", job = DeleteDeviceJob::NAME); + let deactivate_user_worker = WorkerBuilder::new(worker_name) + .layer(state.inject()) + .layer(trace_layer()) + .layer(metrics_layer()) + .with_storage_config(storage, |c| c.fetch_interval(Duration::from_secs(1))) + .build_fn(deactivate_user); + + monitor.register(deactivate_user_worker) +} From 649d86c1cf642c1570ef46296891ae90a3a93caa Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 3 Aug 2023 15:02:40 +0200 Subject: [PATCH 7/8] mas-tasks: refactor worker building behind a macro --- crates/tasks/src/email.rs | 29 ++++------------------ crates/tasks/src/lib.rs | 26 ++++++++++++++++++++ crates/tasks/src/matrix.rs | 49 +++++++------------------------------- crates/tasks/src/user.rs | 29 ++++------------------ 4 files changed, 44 insertions(+), 89 deletions(-) diff --git a/crates/tasks/src/email.rs b/crates/tasks/src/email.rs index 8646ff0fa..32a9f5eb9 100644 --- a/crates/tasks/src/email.rs +++ b/crates/tasks/src/email.rs @@ -13,25 +13,14 @@ // limitations under the License. use anyhow::Context; -use apalis_core::{ - builder::{WorkerBuilder, WorkerFactoryFn}, - context::JobContext, - executor::TokioExecutor, - job::Job, - monitor::Monitor, - storage::builder::WithStorage, -}; +use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor}; use chrono::Duration; use mas_email::{Address, EmailVerificationContext, Mailbox}; use mas_storage::job::{JobWithSpanContext, VerifyEmailJob}; use rand::{distributions::Uniform, Rng}; use tracing::info; -use crate::{ - storage::PostgresStorageFactory, - utils::{metrics_layer, trace_layer}, - JobContextExt, State, -}; +use crate::{storage::PostgresStorageFactory, JobContextExt, State}; #[tracing::instrument( name = "job.verify_email", @@ -99,16 +88,8 @@ pub(crate) fn register( state: &State, storage_factory: &PostgresStorageFactory, ) -> Monitor { - let storage = storage_factory.build(); - let worker_name = format!("{job}-{suffix}", job = VerifyEmailJob::NAME); - let worker = WorkerBuilder::new(worker_name) - .layer(state.inject()) - .layer(trace_layer()) - .layer(metrics_layer()) - .with_storage_config(storage, |c| { - c.fetch_interval(std::time::Duration::from_secs(1)) - }) - .build_fn(verify_email); + let verify_email_worker = + crate::build!(VerifyEmailJob => verify_email, suffix, state, storage_factory); - monitor.register(worker) + monitor.register(verify_email_worker) } diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index 1019bc4fd..ebf0d783a 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -107,6 +107,32 @@ impl JobContextExt for apalis_core::context::JobContext { } } +/// Helper macro to build a storage-backed worker. +macro_rules! build { + ($job:ty => $fn:ident, $suffix:expr, $state:expr, $factory:expr) => {{ + let storage = $factory.build(); + let worker_name = format!( + "{job}-{suffix}", + job = <$job as ::apalis_core::job::Job>::NAME, + suffix = $suffix + ); + + let builder = ::apalis_core::builder::WorkerBuilder::new(worker_name) + .layer($state.inject()) + .layer(crate::utils::trace_layer()) + .layer(crate::utils::metrics_layer()); + + let builder = ::apalis_core::storage::builder::WithStorage::with_storage_config( + builder, + storage, + |c| c.fetch_interval(std::time::Duration::from_secs(1)), + ); + ::apalis_core::builder::WorkerFactory::build(builder, ::apalis_core::job_fn::job_fn($fn)) + }}; +} + +pub(crate) use build; + /// Initialise the workers. /// /// # Errors diff --git a/crates/tasks/src/matrix.rs b/crates/tasks/src/matrix.rs index 395dccbb2..a3c993094 100644 --- a/crates/tasks/src/matrix.rs +++ b/crates/tasks/src/matrix.rs @@ -12,17 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::time::Duration; - use anyhow::Context; -use apalis_core::{ - builder::{WorkerBuilder, WorkerFactoryFn}, - context::JobContext, - executor::TokioExecutor, - job::Job, - monitor::Monitor, - storage::builder::WithStorage, -}; +use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor}; use mas_matrix::ProvisionRequest; use mas_storage::{ job::{DeleteDeviceJob, JobWithSpanContext, ProvisionDeviceJob, ProvisionUserJob}, @@ -31,11 +22,7 @@ use mas_storage::{ }; use tracing::info; -use crate::{ - storage::PostgresStorageFactory, - utils::{metrics_layer, trace_layer}, - JobContextExt, State, -}; +use crate::{storage::PostgresStorageFactory, JobContextExt, State}; /// Job to provision a user on the Matrix homeserver. /// This works by doing a PUT request to the /_synapse/admin/v2/users/{user_id} @@ -163,32 +150,12 @@ pub(crate) fn register( state: &State, storage_factory: &PostgresStorageFactory, ) -> Monitor { - let storage = storage_factory.build(); - let worker_name = format!("{job}-{suffix}", job = ProvisionUserJob::NAME); - let provision_user_worker = WorkerBuilder::new(worker_name) - .layer(state.inject()) - .layer(trace_layer()) - .layer(metrics_layer()) - .with_storage_config(storage, |c| c.fetch_interval(Duration::from_secs(1))) - .build_fn(provision_user); - - let storage = storage_factory.build(); - let worker_name = format!("{job}-{suffix}", job = ProvisionDeviceJob::NAME); - let provision_device_worker = WorkerBuilder::new(worker_name) - .layer(state.inject()) - .layer(trace_layer()) - .layer(metrics_layer()) - .with_storage_config(storage, |c| c.fetch_interval(Duration::from_secs(1))) - .build_fn(provision_device); - - let storage = storage_factory.build(); - let worker_name = format!("{job}-{suffix}", job = DeleteDeviceJob::NAME); - let delete_device_worker = WorkerBuilder::new(worker_name) - .layer(state.inject()) - .layer(trace_layer()) - .layer(metrics_layer()) - .with_storage_config(storage, |c| c.fetch_interval(Duration::from_secs(1))) - .build_fn(delete_device); + let provision_user_worker = + crate::build!(ProvisionUserJob => provision_user, suffix, state, storage_factory); + let provision_device_worker = + crate::build!(ProvisionDeviceJob => provision_device, suffix, state, storage_factory); + let delete_device_worker = + crate::build!(DeleteDeviceJob => delete_device, suffix, state, storage_factory); monitor .register(provision_user_worker) diff --git a/crates/tasks/src/user.rs b/crates/tasks/src/user.rs index 398f8d542..2eb2d3a23 100644 --- a/crates/tasks/src/user.rs +++ b/crates/tasks/src/user.rs @@ -12,29 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::time::Duration; - use anyhow::Context; -use apalis_core::{ - builder::{WorkerBuilder, WorkerFactoryFn}, - context::JobContext, - executor::TokioExecutor, - job::Job, - monitor::Monitor, - storage::builder::WithStorage, -}; +use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor}; use mas_storage::{ - job::{DeactivateUserJob, DeleteDeviceJob, JobWithSpanContext}, + job::{DeactivateUserJob, JobWithSpanContext}, user::UserRepository, RepositoryAccess, }; use tracing::info; -use crate::{ - storage::PostgresStorageFactory, - utils::{metrics_layer, trace_layer}, - JobContextExt, State, -}; +use crate::{storage::PostgresStorageFactory, JobContextExt, State}; /// Job to deactivate a user, both locally and on the Matrix homeserver. #[tracing::instrument( @@ -83,14 +70,8 @@ pub(crate) fn register( state: &State, storage_factory: &PostgresStorageFactory, ) -> Monitor { - let storage = storage_factory.build(); - let worker_name = format!("{job}-{suffix}", job = DeleteDeviceJob::NAME); - let deactivate_user_worker = WorkerBuilder::new(worker_name) - .layer(state.inject()) - .layer(trace_layer()) - .layer(metrics_layer()) - .with_storage_config(storage, |c| c.fetch_interval(Duration::from_secs(1))) - .build_fn(deactivate_user); + let deactivate_user_worker = + crate::build!(DeactivateUserJob => deactivate_user, suffix, state, storage_factory); monitor.register(deactivate_user_worker) } From 07ec4b5a340e5452da9c5cfa8f58350adf3c0a35 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 3 Aug 2023 15:06:45 +0200 Subject: [PATCH 8/8] mas-matrix: fix clippy warnings --- crates/handlers/src/test_utils.rs | 2 +- crates/matrix/src/lib.rs | 2 +- crates/matrix/src/mock.rs | 15 ++++++++------- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index 720d2deee..8b2a4c518 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -23,7 +23,7 @@ use headers::{Authorization, ContentType, HeaderMapExt, HeaderName}; use hyper::{header::CONTENT_TYPE, Request, Response, StatusCode}; use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey}; -use mas_matrix::{HomeserverConnection, MatrixUser, MockHomeserverConnection, ProvisionRequest}; +use mas_matrix::MockHomeserverConnection; use mas_policy::PolicyFactory; use mas_router::{SimpleRoute, UrlBuilder}; use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository}; diff --git a/crates/matrix/src/lib.rs b/crates/matrix/src/lib.rs index b5f799d1c..144801846 100644 --- a/crates/matrix/src/lib.rs +++ b/crates/matrix/src/lib.rs @@ -18,7 +18,7 @@ mod mock; -pub use self::mock::MockHomeserverConnection; +pub use self::mock::HomeserverConnection as MockHomeserverConnection; #[derive(Debug)] pub struct MatrixUser { diff --git a/crates/matrix/src/mock.rs b/crates/matrix/src/mock.rs index 33d1e0d4b..53bfbb0a0 100644 --- a/crates/matrix/src/mock.rs +++ b/crates/matrix/src/mock.rs @@ -18,7 +18,7 @@ use anyhow::Context; use async_trait::async_trait; use tokio::sync::RwLock; -use crate::{HomeserverConnection, MatrixUser, ProvisionRequest}; +use crate::{MatrixUser, ProvisionRequest}; struct MockUser { sub: String, @@ -28,15 +28,15 @@ struct MockUser { emails: Option>, } -/// A Mock implementation of a [`HomeserverConnection`], which never fails and +/// A mock implementation of a [`HomeserverConnection`], which never fails and /// doesn't do anything. -pub struct MockHomeserverConnection { +pub struct HomeserverConnection { homeserver: String, users: RwLock>, } -impl MockHomeserverConnection { - /// Create a new [`MockHomeserverConnection`]. +impl HomeserverConnection { + /// Create a new mock connection. pub fn new(homeserver: H) -> Self where H: Into, @@ -49,7 +49,7 @@ impl MockHomeserverConnection { } #[async_trait] -impl HomeserverConnection for MockHomeserverConnection { +impl crate::HomeserverConnection for HomeserverConnection { type Error = anyhow::Error; fn homeserver(&self) -> &str { @@ -127,10 +127,11 @@ impl HomeserverConnection for MockHomeserverConnection { #[cfg(test)] mod tests { use super::*; + use crate::HomeserverConnection as _; #[tokio::test] async fn test_mock_connection() { - let conn = MockHomeserverConnection::new("example.org"); + let conn = HomeserverConnection::new("example.org"); let mxid = "@test:example.org"; let device = "test";