diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 41b9a11f7..97c019175 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -542,7 +542,7 @@ impl Options { warn!(%user.id, "User scheduling user reactivation"); repo.queue_job() - .schedule_job(&mut rng, &clock, ReactivateUserJob::new(&user, true)) + .schedule_job(&mut rng, &clock, ReactivateUserJob::new(&user)) .await?; repo.into_inner().commit().await?; diff --git a/crates/handlers/src/admin/v1/users/deactivate.rs b/crates/handlers/src/admin/v1/users/deactivate.rs index 7a6bd8e4e..316b882be 100644 --- a/crates/handlers/src/admin/v1/users/deactivate.rs +++ b/crates/handlers/src/admin/v1/users/deactivate.rs @@ -12,6 +12,8 @@ use mas_storage::{ BoxRng, queue::{DeactivateUserJob, QueueJobRepositoryExt as _}, }; +use schemars::JsonSchema; +use serde::Deserialize; use tracing::info; use ulid::Ulid; @@ -49,7 +51,25 @@ impl IntoResponse for RouteError { } } -pub fn doc(operation: TransformOperation) -> TransformOperation { +/// # JSON payload for the `POST /api/admin/v1/users/:id/deactivate` endpoint +#[derive(Default, Deserialize, JsonSchema)] +#[serde(rename = "DeactivateUserRequest")] +pub struct Request { + /// Whether to skip locking the user before deactivation. + #[serde(default)] + skip_lock: bool, +} + +pub fn doc(mut operation: TransformOperation) -> TransformOperation { + operation + .inner_mut() + .request_body + .as_mut() + .unwrap() + .as_item_mut() + .unwrap() + .required = false; + operation .id("deactivateUser") .summary("Deactivate a user") @@ -76,7 +96,9 @@ pub async fn handler( }: CallContext, NoApi(mut rng): NoApi, id: UlidPathParam, + body: Option>, ) -> Result>, RouteError> { + let Json(params) = body.unwrap_or_default(); let id = *id; let mut user = repo .user() @@ -84,7 +106,7 @@ pub async fn handler( .await? .ok_or(RouteError::NotFound(id))?; - if user.locked_at.is_none() { + if !params.skip_lock && user.locked_at.is_none() { user = repo.user().lock(&clock, user).await?; } @@ -105,14 +127,13 @@ pub async fn handler( mod tests { use chrono::Duration; use hyper::{Request, StatusCode}; - use insta::assert_json_snapshot; + use insta::{allow_duplicates, assert_json_snapshot}; use mas_storage::{Clock, RepositoryAccess, user::UserRepository}; use sqlx::PgPool; use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup}; - #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] - async fn test_deactivate_user(pool: PgPool) { + async fn test_deactivate_user_helper(pool: PgPool, skip_lock: Option) { setup(); let mut state = TestState::from_pool(pool.clone()).await.unwrap(); let token = state.token_with_scope("urn:mas:admin").await; @@ -125,19 +146,27 @@ mod tests { .unwrap(); repo.save().await.unwrap(); - let request = Request::post(format!("/api/admin/v1/users/{}/deactivate", user.id)) - .bearer(&token) - .empty(); + let request = + Request::post(format!("/api/admin/v1/users/{}/deactivate", user.id)).bearer(&token); + let request = match skip_lock { + None => request.empty(), + Some(skip_lock) => request.json(serde_json::json!({ + "skip_lock": skip_lock, + })), + }; let response = state.request(request).await; response.assert_status(StatusCode::OK); let body: serde_json::Value = response.json(); - // The locked_at timestamp should be the same as the current time + // The locked_at timestamp should be the same as the current time, or null if not locked assert_eq!( body["data"]["attributes"]["locked_at"], - serde_json::json!(state.clock.now()) + if !skip_lock.unwrap_or(false) { + serde_json::json!(state.clock.now()) + } else { + serde_json::Value::Null + } ); - // TODO: have test coverage on deactivated_at timestamp // Make sure to run the jobs in the queue state.run_jobs_in_queue().await; @@ -149,7 +178,7 @@ mod tests { response.assert_status(StatusCode::OK); let body: serde_json::Value = response.json(); - assert_json_snapshot!(body, @r#" + allow_duplicates!(assert_json_snapshot!(body, @r#" { "data": { "type": "user", @@ -169,7 +198,17 @@ mod tests { "self": "/api/admin/v1/users/01FSHN9AG0MZAA6S4AF7CTV32E" } } - "#); + "#)); + } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_deactivate_user(pool: PgPool) { + test_deactivate_user_helper(pool, Option::None).await; + } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_deactivate_user_skip_lock(pool: PgPool) { + test_deactivate_user_helper(pool, Option::Some(true)).await; } #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] @@ -206,7 +245,6 @@ mod tests { body["data"]["attributes"]["locked_at"], serde_json::Value::Null ); - // TODO: have test coverage on deactivated_at timestamp // Make sure to run the jobs in the queue state.run_jobs_in_queue().await; diff --git a/crates/handlers/src/admin/v1/users/reactivate.rs b/crates/handlers/src/admin/v1/users/reactivate.rs index 44c5ae88c..ad73c4dba 100644 --- a/crates/handlers/src/admin/v1/users/reactivate.rs +++ b/crates/handlers/src/admin/v1/users/reactivate.rs @@ -3,15 +3,13 @@ // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // Please see LICENSE files in the repository root for full details. -use aide::{NoApi, OperationIo, transform::TransformOperation}; -use axum::{Json, response::IntoResponse}; +use std::sync::Arc; + +use aide::{OperationIo, transform::TransformOperation}; +use axum::{Json, extract::State, response::IntoResponse}; use hyper::StatusCode; use mas_axum_utils::record_error; -use mas_storage::{ - BoxRng, - queue::{QueueJobRepositoryExt as _, ReactivateUserJob}, -}; -use tracing::info; +use mas_matrix::HomeserverConnection; use ulid::Ulid; use crate::{ @@ -30,6 +28,9 @@ pub enum RouteError { #[error(transparent)] Internal(Box), + #[error(transparent)] + Homeserver(anyhow::Error), + #[error("User ID {0} not found")] NotFound(Ulid), } @@ -39,9 +40,9 @@ impl_from_error_for_route!(mas_storage::RepositoryError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { let error = ErrorResponse::from_error(&self); - let sentry_event_id = record_error!(self, Self::Internal(_)); + let sentry_event_id = record_error!(self, Self::Internal(_) | Self::Homeserver(_)); let status = match self { - Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::Internal(_) | Self::Homeserver(_) => StatusCode::INTERNAL_SERVER_ERROR, Self::NotFound(_) => StatusCode::NOT_FOUND, }; (status, sentry_event_id, Json(error)).into_response() @@ -69,10 +70,8 @@ pub fn doc(operation: TransformOperation) -> TransformOperation { #[tracing::instrument(name = "handler.admin.v1.users.reactivate", skip_all)] pub async fn handler( - CallContext { - mut repo, clock, .. - }: CallContext, - NoApi(mut rng): NoApi, + CallContext { mut repo, .. }: CallContext, + State(homeserver): State>, id: UlidPathParam, ) -> Result>, RouteError> { let id = *id; @@ -82,10 +81,15 @@ pub async fn handler( .await? .ok_or(RouteError::NotFound(id))?; - info!(%user.id, "Scheduling reactivation of user"); - repo.queue_job() - .schedule_job(&mut rng, &clock, ReactivateUserJob::new(&user, false)) - .await?; + // Call the homeserver synchronously to reactivate the user + let mxid = homeserver.mxid(&user.username); + homeserver + .reactivate_user(&mxid) + .await + .map_err(RouteError::Homeserver)?; + + // Now reactivate the user in our database + let user = repo.user().reactivate(user).await?; repo.save().await?; @@ -100,7 +104,7 @@ mod tests { use hyper::{Request, StatusCode}; use mas_matrix::{HomeserverConnection, ProvisionRequest}; use mas_storage::{Clock, RepositoryAccess, user::UserRepository}; - use sqlx::{PgPool, types::Json}; + use sqlx::PgPool; use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup}; @@ -150,18 +154,10 @@ mod tests { body["data"]["attributes"]["locked_at"], serde_json::json!(state.clock.now()) ); - // TODO: have test coverage on deactivated_at timestamp - - // It should have scheduled a reactivation job for the user - // XXX: we don't have a good way to look for the reactivation job - let job: Json = sqlx::query_scalar( - "SELECT payload FROM queue_jobs WHERE queue_name = 'reactivate-user'", - ) - .fetch_one(&pool) - .await - .expect("Reactivation job to be scheduled"); - assert_eq!(job["user_id"], serde_json::json!(user.id)); - assert_eq!(job["unlock"], serde_json::Value::Bool(false)); + assert_eq!( + body["data"]["attributes"]["deactivated_at"], + serde_json::Value::Null, + ); } #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] @@ -178,6 +174,14 @@ mod tests { .unwrap(); repo.save().await.unwrap(); + // Provision the user on the homeserver + let mxid = state.homeserver_connection.mxid(&user.username); + state + .homeserver_connection + .provision_user(&ProvisionRequest::new(&mxid, &user.sub)) + .await + .unwrap(); + let request = Request::post(format!("/api/admin/v1/users/{}/reactivate", user.id)) .bearer(&token) .empty(); @@ -189,18 +193,10 @@ mod tests { body["data"]["attributes"]["locked_at"], serde_json::Value::Null ); - // TODO: have test coverage on deactivated_at timestamp - - // It should have scheduled a reactivation job for the user - // XXX: we don't have a good way to look for the reactivation job - let job: Json = sqlx::query_scalar( - "SELECT payload FROM queue_jobs WHERE queue_name = 'reactivate-user'", - ) - .fetch_one(&pool) - .await - .expect("Reactivation job to be scheduled"); - assert_eq!(job["user_id"], serde_json::json!(user.id)); - assert_eq!(job["unlock"], serde_json::Value::Bool(false)); + assert_eq!( + body["data"]["attributes"]["deactivated_at"], + serde_json::Value::Null + ); } #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] diff --git a/crates/handlers/src/admin/v1/users/unlock.rs b/crates/handlers/src/admin/v1/users/unlock.rs index e74d80aea..224d6a81b 100644 --- a/crates/handlers/src/admin/v1/users/unlock.rs +++ b/crates/handlers/src/admin/v1/users/unlock.rs @@ -11,6 +11,8 @@ use axum::{Json, extract::State, response::IntoResponse}; use hyper::StatusCode; use mas_axum_utils::record_error; use mas_matrix::HomeserverConnection; +use schemars::JsonSchema; +use serde::Deserialize; use ulid::Ulid; use crate::{ @@ -50,7 +52,25 @@ impl IntoResponse for RouteError { } } -pub fn doc(operation: TransformOperation) -> TransformOperation { +/// # JSON payload for the `POST /api/admin/v1/users/:id/unlock` endpoint +#[derive(Default, Deserialize, JsonSchema)] +#[serde(rename = "UnlockUserRequest")] +pub struct Request { + /// Whether to skip ensuring the user is active upon being unlocked. + #[serde(default)] + skip_reactivate: bool, +} + +pub fn doc(mut operation: TransformOperation) -> TransformOperation { + operation + .inner_mut() + .request_body + .as_mut() + .unwrap() + .as_item_mut() + .unwrap() + .required = false; + operation .id("unlockUser") .summary("Unlock a user") @@ -73,7 +93,9 @@ pub async fn handler( CallContext { mut repo, .. }: CallContext, State(homeserver): State>, id: UlidPathParam, + body: Option>, ) -> Result>, RouteError> { + let Json(params) = body.unwrap_or_default(); let id = *id; let user = repo .user() @@ -81,15 +103,17 @@ pub async fn handler( .await? .ok_or(RouteError::NotFound(id))?; - // Call the homeserver synchronously to unlock the user - let mxid = homeserver.mxid(&user.username); - homeserver - .reactivate_user(&mxid) - .await - .map_err(RouteError::Homeserver)?; - - // Now unlock the user in our database - let user = repo.user().unlock(user).await?; + let user = if !params.skip_reactivate { + // Call the homeserver synchronously to reactivate the user + let mxid = homeserver.mxid(&user.username); + homeserver + .reactivate_user(&mxid) + .await + .map_err(RouteError::Homeserver)?; + repo.user().reactivate_and_unlock(user).await? + } else { + repo.user().unlock(user).await? + }; repo.save().await?; @@ -103,7 +127,7 @@ pub async fn handler( mod tests { use hyper::{Request, StatusCode}; use mas_matrix::{HomeserverConnection, ProvisionRequest}; - use mas_storage::{RepositoryAccess, user::UserRepository}; + use mas_storage::{user::UserRepository, Clock, RepositoryAccess}; use sqlx::PgPool; use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup}; @@ -145,8 +169,7 @@ mod tests { ); } - #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] - async fn test_unlock_deactivated_user(pool: PgPool) { + async fn test_unlock_deactivated_user_helper(pool: PgPool, skip_reactivate: Option) { setup(); let mut state = TestState::from_pool(pool).await.unwrap(); let token = state.token_with_scope("urn:mas:admin").await; @@ -179,9 +202,13 @@ mod tests { let mx_user = state.homeserver_connection.query_user(&mxid).await.unwrap(); assert!(mx_user.deactivated); - let request = Request::post(format!("/api/admin/v1/users/{}/unlock", user.id)) - .bearer(&token) - .empty(); + let request = Request::post(format!("/api/admin/v1/users/{}/unlock", user.id)).bearer(&token); + let request = match skip_reactivate { + None => request.empty(), + Some(skip_reactivate) => request.json(serde_json::json!({ + "skip_reactivate": skip_reactivate, + })), + }; let response = state.request(request).await; response.assert_status(StatusCode::OK); let body: serde_json::Value = response.json(); @@ -190,11 +217,30 @@ mod tests { body["data"]["attributes"]["locked_at"], serde_json::Value::Null ); - // TODO: have test coverage on deactivated_at timestamp - // The user should be reactivated on the homeserver + let skip_reactivate = skip_reactivate.unwrap_or(false); + assert_eq!( + body["data"]["attributes"]["deactivated_at"], + if !skip_reactivate { + serde_json::Value::Null + } else { + serde_json::json!(state.clock.now()) + } + ); + + // Check whether the user should be reactivated on the homeserver let mx_user = state.homeserver_connection.query_user(&mxid).await.unwrap(); - assert!(!mx_user.deactivated); + assert_eq!(mx_user.deactivated, skip_reactivate); + } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_unlock_deactivated_user(pool: PgPool) { + test_unlock_deactivated_user_helper(pool, Option::None).await; + } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_unlock_deactivated_user_skip_reactivate(pool: PgPool) { + test_unlock_deactivated_user_helper(pool, Option::Some(true)).await; } #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] diff --git a/crates/handlers/src/graphql/mutations/user.rs b/crates/handlers/src/graphql/mutations/user.rs index a403d95ce..a5d7e0fc2 100644 --- a/crates/handlers/src/graphql/mutations/user.rs +++ b/crates/handlers/src/graphql/mutations/user.rs @@ -590,7 +590,7 @@ impl UserMutations { matrix.reactivate_user(&mxid).await?; // Now unlock the user in our database - let user = repo.user().unlock(user).await?; + let user = repo.user().reactivate_and_unlock(user).await?; repo.save().await?; diff --git a/crates/storage-pg/.sqlx/query-3e2d1ce1c7aba2952ed9c659972a18ded5613186104695524e85df9b6641ea4e.json b/crates/storage-pg/.sqlx/query-3e2d1ce1c7aba2952ed9c659972a18ded5613186104695524e85df9b6641ea4e.json new file mode 100644 index 000000000..738adae1c --- /dev/null +++ b/crates/storage-pg/.sqlx/query-3e2d1ce1c7aba2952ed9c659972a18ded5613186104695524e85df9b6641ea4e.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE users\n SET deactivated_at = NULL, locked_at = NULL\n WHERE user_id = $1\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "3e2d1ce1c7aba2952ed9c659972a18ded5613186104695524e85df9b6641ea4e" +} diff --git a/crates/storage-pg/src/user/mod.rs b/crates/storage-pg/src/user/mod.rs index 6abc29d9a..a1f321e7c 100644 --- a/crates/storage-pg/src/user/mod.rs +++ b/crates/storage-pg/src/user/mod.rs @@ -417,6 +417,40 @@ impl UserRepository for PgUserRepository<'_> { Ok(user) } + #[tracing::instrument( + name = "db.user.reactivate_and_unlock", + skip_all, + fields( + db.query.text, + %user.id, + ), + err, + )] + async fn reactivate_and_unlock(&mut self, mut user: User) -> Result { + if user.deactivated_at.is_none() && user.locked_at.is_none() { + return Ok(user); + } + + let res = sqlx::query!( + r#" + UPDATE users + SET deactivated_at = NULL, locked_at = NULL + WHERE user_id = $1 + "#, + Uuid::from(user.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + user.deactivated_at = None; + user.locked_at = None; + + Ok(user) + } + #[tracing::instrument( name = "db.user.set_can_request_admin", skip_all, diff --git a/crates/storage/src/queue/tasks.rs b/crates/storage/src/queue/tasks.rs index 87fb41486..f59971ba4 100644 --- a/crates/storage/src/queue/tasks.rs +++ b/crates/storage/src/queue/tasks.rs @@ -257,26 +257,21 @@ impl InsertableJob for DeactivateUserJob { const QUEUE_NAME: &'static str = "deactivate-user"; } -/// A job to reactivate and optionally unlock a user +/// A job to reactivate and unlock a user #[derive(Serialize, Deserialize, Debug, Clone)] pub struct ReactivateUserJob { user_id: Ulid, - unlock: bool, } impl ReactivateUserJob { - /// Create a new job to reactivate a user + /// Create a new job to reactivate and unlock a user /// /// # Parameters /// /// * `user` - The user to reactivate - /// * `unlock` - Whether the user should be unlocked on reactivation #[must_use] - pub fn new(user: &User, unlock: bool) -> Self { - Self { - user_id: user.id, - unlock, - } + pub fn new(user: &User) -> Self { + Self { user_id: user.id } } /// The ID of the user to reactivate @@ -284,12 +279,6 @@ impl ReactivateUserJob { pub fn user_id(&self) -> Ulid { self.user_id } - - /// Whether the user should be unlocked on reactivation - #[must_use] - pub fn unlock(&self) -> bool { - self.unlock - } } impl InsertableJob for ReactivateUserJob { diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index f864157b1..f990af3e2 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -257,6 +257,19 @@ pub trait UserRepository: Send + Sync { /// Returns [`Self::Error`] if the underlying repository fails async fn reactivate(&mut self, user: User) -> Result; + /// Reactivate and unlock a [`User`] + /// + /// Returns the reactivated and unlocked [`User`] + /// + /// # Parameters + /// + /// * `user`: The [`User`] to reactivate and unlock + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn reactivate_and_unlock(&mut self, user: User) -> Result; + /// Set whether a [`User`] can request admin /// /// Returns the [`User`] with the new `can_request_admin` value @@ -329,6 +342,7 @@ repository_impl!(UserRepository: async fn unlock(&mut self, user: User) -> Result; async fn deactivate(&mut self, clock: &dyn Clock, user: User) -> Result; async fn reactivate(&mut self, user: User) -> Result; + async fn reactivate_and_unlock(&mut self, user: User) -> Result; async fn set_can_request_admin( &mut self, user: User, diff --git a/crates/tasks/src/user.rs b/crates/tasks/src/user.rs index 290dab28f..01864764a 100644 --- a/crates/tasks/src/user.rs +++ b/crates/tasks/src/user.rs @@ -137,25 +137,9 @@ impl RunnableJob for ReactivateUserJob { .await .map_err(JobError::retry)?; - // Now reactivate the user in our database - let user = repo - .user() - .reactivate(user) - .await - .context("Failed to reactivate user") - .map_err(JobError::retry)?; - - if self.unlock() { - // We want to unlock the user from our side only once it has been reactivated on - // the homeserver - let _user = repo - .user() - .unlock(user) - .await - .context("Failed to unlock user") - .map_err(JobError::retry)?; - } - + // We want to unlock the user from our side only once it has been reactivated on + // the homeserver + let _user = repo.user().reactivate_and_unlock(user).await.map_err(JobError::retry)?; repo.save().await.map_err(JobError::retry)?; Ok(()) diff --git a/docs/api/spec.json b/docs/api/spec.json index b4d07e84d..28d394254 100644 --- a/docs/api/spec.json +++ b/docs/api/spec.json @@ -1359,6 +1359,15 @@ "style": "simple" } ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DeactivateUserRequest" + } + } + } + }, "responses": { "200": { "description": "User was deactivated", @@ -1568,6 +1577,15 @@ "style": "simple" } ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UnlockUserRequest" + } + } + } + }, "responses": { "200": { "description": "User was unlocked", @@ -3942,6 +3960,28 @@ } } }, + "DeactivateUserRequest": { + "title": "JSON payload for the `POST /api/admin/v1/users/:id/deactivate` endpoint", + "type": "object", + "properties": { + "skip_lock": { + "description": "Whether to skip locking the user before deactivation.", + "default": false, + "type": "boolean" + } + } + }, + "UnlockUserRequest": { + "title": "JSON payload for the `POST /api/admin/v1/users/:id/unlock` endpoint", + "type": "object", + "properties": { + "skip_reactivate": { + "description": "Whether to skip ensuring the user is active upon being unlocked.", + "default": false, + "type": "boolean" + } + } + }, "UserEmailFilter": { "type": "object", "properties": {