Add support for locking to the mock homeserver and use in tests

This commit is contained in:
Olivier 'reivilibre
2026-03-17 11:19:34 +00:00
parent c33880d54f
commit fe5284a3ee
3 changed files with 83 additions and 9 deletions

View File

@@ -102,7 +102,11 @@ mod tests {
use chrono::Duration; use chrono::Duration;
use hyper::{Request, StatusCode}; use hyper::{Request, StatusCode};
use mas_data_model::Clock; use mas_data_model::Clock;
use mas_storage::{RepositoryAccess, user::UserRepository}; use mas_storage::{
RepositoryAccess,
queue::{ProvisionUserJob, QueueJobRepositoryExt},
user::UserRepository,
};
use sqlx::PgPool; use sqlx::PgPool;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup}; use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
@@ -119,8 +123,25 @@ mod tests {
.add(&mut state.rng(), &state.clock, "alice".to_owned()) .add(&mut state.rng(), &state.clock, "alice".to_owned())
.await .await
.unwrap(); .unwrap();
repo.queue_job()
.schedule_job(&mut state.rng(), &state.clock, ProvisionUserJob::new(&user))
.await
.unwrap();
repo.save().await.unwrap(); repo.save().await.unwrap();
state.run_jobs_in_queue().await;
assert!(
!state
.homeserver_connection
.query_user_raw("alice")
.await
.unwrap()
.locked,
"User should not be locked at start of test"
);
let request = Request::post(format!("/api/admin/v1/users/{}/lock", user.id)) let request = Request::post(format!("/api/admin/v1/users/{}/lock", user.id))
.bearer(&token) .bearer(&token)
.empty(); .empty();
@@ -133,6 +154,17 @@ mod tests {
body["data"]["attributes"]["locked_at"], body["data"]["attributes"]["locked_at"],
serde_json::json!(state.clock.now()) serde_json::json!(state.clock.now())
); );
state.run_jobs_in_queue().await;
assert!(
state
.homeserver_connection
.query_user_raw("alice")
.await
.unwrap()
.locked,
"User should be locked"
);
} }
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]

View File

@@ -102,7 +102,11 @@ mod tests {
use hyper::{Request, StatusCode}; use hyper::{Request, StatusCode};
use mas_data_model::Clock; use mas_data_model::Clock;
use mas_matrix::{HomeserverConnection, ProvisionRequest}; use mas_matrix::{HomeserverConnection, ProvisionRequest};
use mas_storage::{RepositoryAccess, user::UserRepository}; use mas_storage::{
RepositoryAccess,
queue::{ProvisionUserJob, QueueJobRepositoryExt},
user::UserRepository,
};
use sqlx::PgPool; use sqlx::PgPool;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup}; use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
@@ -120,16 +124,27 @@ mod tests {
.await .await
.unwrap(); .unwrap();
let user = repo.user().lock(&state.clock, user).await.unwrap(); let user = repo.user().lock(&state.clock, user).await.unwrap();
repo.save().await.unwrap();
// Also provision the user on the homeserver, because this endpoint will try to // Also provision the user on the homeserver, because this endpoint will try to
// reactivate it // reactivate it
state repo.queue_job()
.homeserver_connection .schedule_job(&mut state.rng(), &state.clock, ProvisionUserJob::new(&user))
.provision_user(&ProvisionRequest::new(&user.username, &user.sub, false))
.await .await
.unwrap(); .unwrap();
repo.save().await.unwrap();
state.run_jobs_in_queue().await;
assert!(
state
.homeserver_connection
.query_user_raw("alice")
.await
.unwrap()
.locked,
"User should be locked at start of test"
);
let request = Request::post(format!("/api/admin/v1/users/{}/unlock", user.id)) let request = Request::post(format!("/api/admin/v1/users/{}/unlock", user.id))
.bearer(&token) .bearer(&token)
.empty(); .empty();
@@ -141,6 +156,17 @@ mod tests {
body["data"]["attributes"]["locked_at"], body["data"]["attributes"]["locked_at"],
serde_json::Value::Null serde_json::Value::Null
); );
state.run_jobs_in_queue().await;
assert!(
!state
.homeserver_connection
.query_user_raw("alice")
.await
.unwrap()
.locked,
"User should not be locked"
);
} }
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]

View File

@@ -10,9 +10,10 @@ use anyhow::Context;
use async_trait::async_trait; use async_trait::async_trait;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use crate::{MatrixUser, ProvisionRequest}; use crate::{HomeserverConnection as _, MatrixUser, ProvisionRequest};
struct MockUser { #[derive(Clone)]
pub struct MockUser {
sub: String, sub: String,
avatar_url: Option<String>, avatar_url: Option<String>,
displayname: Option<String>, displayname: Option<String>,
@@ -20,6 +21,7 @@ struct MockUser {
emails: Option<Vec<String>>, emails: Option<Vec<String>>,
cross_signing_reset_allowed: bool, cross_signing_reset_allowed: bool,
deactivated: bool, deactivated: bool,
pub locked: bool,
} }
/// A mock implementation of a [`HomeserverConnection`], which never fails and /// A mock implementation of a [`HomeserverConnection`], which never fails and
@@ -50,6 +52,18 @@ impl HomeserverConnection {
pub async fn reserve_localpart(&self, localpart: &'static str) { pub async fn reserve_localpart(&self, localpart: &'static str) {
self.reserved_localparts.write().await.insert(localpart); self.reserved_localparts.write().await.insert(localpart);
} }
/// Like `query_user` but get the raw test state of the user.
///
/// # Errors
///
/// Will fail if the user doesn't exist.
pub async fn query_user_raw(&self, localpart: &str) -> Result<MockUser, anyhow::Error> {
let mxid = self.mxid(localpart);
let users = self.users.read().await;
let user = users.get(&mxid).context("User not found")?;
Ok(user.clone())
}
} }
#[async_trait] #[async_trait]
@@ -85,6 +99,7 @@ impl crate::HomeserverConnection for HomeserverConnection {
emails: None, emails: None,
cross_signing_reset_allowed: false, cross_signing_reset_allowed: false,
deactivated: false, deactivated: false,
locked: false,
}); });
anyhow::ensure!( anyhow::ensure!(
@@ -104,6 +119,8 @@ impl crate::HomeserverConnection for HomeserverConnection {
user.avatar_url = avatar_url.map(ToOwned::to_owned); user.avatar_url = avatar_url.map(ToOwned::to_owned);
}); });
user.locked = request.locked();
Ok(inserted) Ok(inserted)
} }
@@ -219,7 +236,6 @@ impl crate::HomeserverConnection for HomeserverConnection {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::HomeserverConnection as _;
#[tokio::test] #[tokio::test]
async fn test_mock_connection() { async fn test_mock_connection() {