Merge pull request #1413 from matrix-org/quenting/user-lock

Add a way to lock and deprovision users
This commit is contained in:
Quentin Gliech
2023-08-03 15:52:53 +02:00
committed by GitHub
35 changed files with 982 additions and 189 deletions

3
Cargo.lock generated
View File

@@ -3371,9 +3371,11 @@ dependencies = [
name = "mas-matrix"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"http",
"serde",
"tokio",
"url",
]
@@ -3389,6 +3391,7 @@ dependencies = [
"mas-matrix",
"serde",
"tower",
"tracing",
"url",
]

View File

@@ -18,14 +18,14 @@ use mas_config::{DatabaseConfig, PasswordsConfig};
use mas_data_model::{Device, TokenType};
use mas_storage::{
compat::{CompatAccessTokenRepository, CompatSessionRepository},
job::{DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob},
job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob},
user::{UserEmailRepository, UserPasswordRepository, UserRepository},
Repository, RepositoryAccess, SystemClock,
};
use mas_storage_pg::PgRepository;
use rand::SeedableRng;
use sqlx::types::Uuid;
use tracing::{info, info_span};
use tracing::{info, info_span, warn};
use crate::util::{database_from_config, password_manager_from_config};
@@ -69,6 +69,22 @@ enum Subcommand {
#[arg(long)]
dry_run: bool,
},
/// Lock a user
LockUser {
/// User to lock
username: String,
/// Whether to deactivate the user
#[arg(long)]
deactivate: bool,
},
/// Unlock a user
UnlockUser {
/// User to unlock
username: String,
},
}
impl Options {
@@ -330,6 +346,59 @@ impl Options {
Ok(())
}
SC::LockUser {
username,
deactivate,
} => {
let _span = info_span!("cli.manage.lock_user", user.username = username).entered();
let config: DatabaseConfig = root.load_config()?;
let pool = database_from_config(&config).await?;
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
let user = repo
.user()
.find_by_username(&username)
.await?
.context("User not found")?;
info!(%user.id, "Locking user");
// Even though the deactivation job will lock the user, we lock it here in case
// the worker is not running, as we don't have a good way to run a job
// synchronously yet.
let user = repo.user().lock(&clock, user).await?;
if deactivate {
warn!(%user.id, "Scheduling user deactivation");
repo.job()
.schedule_job(DeactivateUserJob::new(&user, false))
.await?;
}
repo.save().await?;
Ok(())
}
SC::UnlockUser { username } => {
let _span = info_span!("cli.manage.lock_user", user.username = username).entered();
let config: DatabaseConfig = root.load_config()?;
let pool = database_from_config(&config).await?;
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
let user = repo
.user()
.find_by_username(&username)
.await?
.context("User not found")?;
info!(%user.id, "Unlocking user");
repo.user().unlock(user).await?;
repo.save().await?;
Ok(())
}
}
}
}

View File

@@ -25,6 +25,16 @@ pub struct User {
pub username: String,
pub sub: String,
pub primary_user_email_id: Option<Ulid>,
pub created_at: DateTime<Utc>,
pub locked_at: Option<DateTime<Utc>>,
}
impl User {
/// Returns `true` unless the user is locked.
#[must_use]
pub fn is_valid(&self) -> bool {
self.locked_at.is_none()
}
}
impl User {
@@ -35,6 +45,8 @@ impl User {
username: "john".to_owned(),
sub: "123-456".to_owned(),
primary_user_email_id: None,
created_at: now,
locked_at: None,
}]
}
}
@@ -65,7 +77,7 @@ pub struct BrowserSession {
impl BrowserSession {
#[must_use]
pub fn active(&self) -> bool {
self.finished_at.is_none()
self.finished_at.is_none() && self.user.is_valid()
}
}

View File

@@ -335,6 +335,7 @@ async fn token_login(
.user()
.lookup(session.user_id)
.await?
.filter(mas_data_model::User::is_valid)
.ok_or(RouteError::UserNotFound)?;
repo.compat_sso_login().exchange(clock, login).await?;
@@ -355,6 +356,7 @@ async fn user_password_login(
.user()
.find_by_username(&username)
.await?
.filter(mas_data_model::User::is_valid)
.ok_or(RouteError::UserNotFound)?;
// Lookup its password

View File

@@ -16,23 +16,24 @@ use std::sync::Arc;
use axum::{
extract::{Path, State},
response::{IntoResponse, Response},
response::{Html, IntoResponse, Response},
};
use axum_extra::extract::PrivateCookieJar;
use hyper::StatusCode;
use mas_axum_utils::SessionInfoExt;
use mas_axum_utils::{csrf::CsrfExt, SessionInfoExt};
use mas_data_model::{AuthorizationGrant, BrowserSession, Client, Device};
use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory;
use mas_policy::{EvaluationResult, PolicyFactory};
use mas_router::{PostAuthAction, Route, UrlBuilder};
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository},
user::BrowserSessionRepository,
BoxClock, BoxRepository, BoxRng, RepositoryAccess,
BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
};
use mas_templates::Templates;
use mas_templates::{PolicyViolationContext, TemplateContext, Templates};
use oauth2_types::requests::AuthorizationResponse;
use thiserror::Error;
use tracing::warn;
use ulid::Ulid;
use super::callback::CallbackDestination;
@@ -74,6 +75,7 @@ impl IntoResponse for RouteError {
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
@@ -87,7 +89,7 @@ impl_from_error_for_route!(super::callback::CallbackDestinationError);
err,
)]
pub(crate) async fn get(
rng: BoxRng,
mut rng: BoxRng,
clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>,
@@ -123,15 +125,15 @@ pub(crate) async fn get(
.ok_or(RouteError::NoSuchClient)?;
match complete(
rng,
clock,
&mut rng,
&clock,
repo,
key_store,
&policy_factory,
url_builder,
grant,
client,
session,
&client,
&session,
)
.await
{
@@ -144,10 +146,22 @@ pub(crate) async fn get(
mas_router::Reauth::and_then(continue_grant).go(),
)
.into_response()),
Err(GrantCompletionError::RequiresConsent | GrantCompletionError::PolicyViolation) => {
Err(GrantCompletionError::RequiresConsent) => {
let next = mas_router::Consent(grant_id);
Ok((cookie_jar, next.go()).into_response())
}
Err(GrantCompletionError::PolicyViolation(grant, res)) => {
warn!(violation = ?res, "Authorization grant for client {} denied by policy", client.id);
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let ctx = PolicyViolationContext::new(grant, client)
.with_session(session)
.with_csrf(csrf_token.form_value());
let content = templates.render_policy_violation(&ctx).await?;
Ok((cookie_jar, Html(content)).into_response())
}
Err(GrantCompletionError::NotPending) => Err(RouteError::NotPending),
Err(GrantCompletionError::Internal(e)) => Err(RouteError::Internal(e)),
}
@@ -168,7 +182,7 @@ pub enum GrantCompletionError {
RequiresConsent,
#[error("denied by the policy")]
PolicyViolation,
PolicyViolation(AuthorizationGrant, EvaluationResult),
}
impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError);
@@ -179,15 +193,15 @@ impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError);
impl_from_error_for_route!(GrantCompletionError: super::super::IdTokenSignatureError);
pub(crate) async fn complete(
mut rng: BoxRng,
clock: BoxClock,
rng: &mut (impl rand::RngCore + rand::CryptoRng + Send),
clock: &impl Clock,
mut repo: BoxRepository,
key_store: Keystore,
policy_factory: &PolicyFactory,
url_builder: UrlBuilder,
grant: AuthorizationGrant,
client: Client,
browser_session: BrowserSession,
client: &Client,
browser_session: &BrowserSession,
) -> Result<AuthorizationResponse, GrantCompletionError> {
// Verify that the grant is in a pending stage
if !grant.stage.is_pending() {
@@ -197,7 +211,7 @@ pub(crate) async fn complete(
// Check if the authentication is fresh enough
let authentication = repo
.browser_session()
.get_last_authentication(&browser_session)
.get_last_authentication(browser_session)
.await?;
let authentication = authentication.filter(|auth| auth.created_at > grant.max_auth_time());
@@ -209,16 +223,16 @@ pub(crate) async fn complete(
// Run through the policy
let mut policy = policy_factory.instantiate().await?;
let res = policy
.evaluate_authorization_grant(&grant, &client, &browser_session.user)
.evaluate_authorization_grant(&grant, client, &browser_session.user)
.await?;
if !res.valid() {
return Err(GrantCompletionError::PolicyViolation);
return Err(GrantCompletionError::PolicyViolation(grant, res));
}
let current_consent = repo
.oauth2_client()
.get_consent_for_user(&client, &browser_session.user)
.get_consent_for_user(client, &browser_session.user)
.await?;
let lacks_consent = grant
@@ -236,18 +250,12 @@ pub(crate) async fn complete(
// All good, let's start the session
let session = repo
.oauth2_session()
.add(
&mut rng,
&clock,
&client,
&browser_session,
grant.scope.clone(),
)
.add(rng, clock, client, browser_session, grant.scope.clone())
.await?;
let grant = repo
.oauth2_authorization_grant()
.fulfill(&clock, &session, grant)
.fulfill(clock, &session, grant)
.await?;
// Yep! Let's complete the auth now
@@ -256,13 +264,13 @@ pub(crate) async fn complete(
// Did they request an ID token?
if grant.response_type_id_token {
params.id_token = Some(generate_id_token(
&mut rng,
&clock,
rng,
clock,
&url_builder,
&key_store,
&client,
client,
&grant,
&browser_session,
browser_session,
None,
Some(&valid_authentication),
)?);

View File

@@ -16,11 +16,11 @@ use std::sync::Arc;
use axum::{
extract::{Form, State},
response::{IntoResponse, Response},
response::{Html, IntoResponse, Response},
};
use axum_extra::extract::PrivateCookieJar;
use hyper::StatusCode;
use mas_axum_utils::SessionInfoExt;
use mas_axum_utils::{csrf::CsrfExt, SessionInfoExt};
use mas_data_model::{AuthorizationCode, Pkce};
use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory;
@@ -29,7 +29,7 @@ use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
BoxClock, BoxRepository, BoxRng,
};
use mas_templates::Templates;
use mas_templates::{PolicyViolationContext, TemplateContext, Templates};
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
pkce,
@@ -39,6 +39,7 @@ use oauth2_types::{
use rand::{distributions::Alphanumeric, Rng};
use serde::Deserialize;
use thiserror::Error;
use tracing::warn;
use self::{callback::CallbackDestination, complete::GrantCompletionError};
use crate::impl_from_error_for_route;
@@ -91,6 +92,7 @@ impl IntoResponse for RouteError {
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(self::callback::CallbackDestinationError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);
@@ -170,6 +172,7 @@ pub(crate) async fn get(
// Get the session info from the cookie
let (session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
// One day, we will have try blocks
let res: Result<Response, RouteError> = ({
@@ -340,15 +343,15 @@ pub(crate) async fn get(
Some(user_session) if prompt.contains(&Prompt::None) => {
// With prompt=none, we should get back to the client immediately
match self::complete::complete(
rng,
clock,
&mut rng,
&clock,
repo,
key_store,
&policy_factory,
url_builder,
grant,
client,
user_session,
&client,
&user_session,
)
.await
{
@@ -369,7 +372,7 @@ pub(crate) async fn get(
)
.await?
}
Err(GrantCompletionError::PolicyViolation) => {
Err(GrantCompletionError::PolicyViolation(_grant, _res)) => {
callback_destination
.go(&templates, ClientError::from(ClientErrorCode::AccessDenied))
.await?
@@ -387,29 +390,32 @@ pub(crate) async fn get(
let grant_id = grant.id;
// Else, we show the relevant reauth/consent page if necessary
match self::complete::complete(
rng,
clock,
&mut rng,
&clock,
repo,
key_store,
&policy_factory,
url_builder,
grant,
client,
user_session,
&client,
&user_session,
)
.await
{
Ok(params) => callback_destination.go(&templates, params).await?,
Err(
GrantCompletionError::RequiresConsent
| GrantCompletionError::PolicyViolation,
) => {
// We're redirecting to the consent URI in both 'consent required' and
// 'policy violation' cases, because we reevaluate the policy on this
// page, and show the error accordingly
// XXX: is this the right approach?
Err(GrantCompletionError::RequiresConsent) => {
mas_router::Consent(grant_id).go().into_response()
}
Err(GrantCompletionError::PolicyViolation(grant, res)) => {
warn!(violation = ?res, "Authorization grant for client {} denied by policy", client.id);
let ctx = PolicyViolationContext::new(grant, client)
.with_session(user_session)
.with_csrf(csrf_token.form_value());
let content = templates.render_policy_violation(&ctx).await?;
Html(content).into_response()
}
Err(GrantCompletionError::RequiresReauth) => {
mas_router::Reauth::and_then(continue_grant)
.go()

View File

@@ -188,6 +188,7 @@ pub(crate) async fn post(
.browser_session()
.lookup(session.user_session_id)
.await?
.filter(|b| b.user.is_valid())
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;
@@ -227,6 +228,7 @@ pub(crate) async fn post(
.browser_session()
.lookup(session.user_session_id)
.await?
.filter(|b| b.user.is_valid())
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;
@@ -265,6 +267,7 @@ pub(crate) async fn post(
.user()
.lookup(session.user_id)
.await?
.filter(mas_data_model::User::is_valid)
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;
@@ -311,6 +314,7 @@ pub(crate) async fn post(
.user()
.lookup(session.user_id)
.await?
.filter(mas_data_model::User::is_valid)
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;

View File

@@ -23,7 +23,7 @@ use headers::{Authorization, ContentType, HeaderMapExt, HeaderName};
use hyper::{header::CONTENT_TYPE, Request, Response, StatusCode};
use mas_axum_utils::http_client_factory::HttpClientFactory;
use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
use mas_matrix::{HomeserverConnection, MatrixUser, ProvisionRequest};
use mas_matrix::MockHomeserverConnection;
use mas_policy::PolicyFactory;
use mas_router::{SimpleRoute, UrlBuilder};
use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository};
@@ -69,40 +69,6 @@ pub(crate) struct TestState {
pub rng: Arc<Mutex<ChaChaRng>>,
}
/// A Mock implementation of a [`HomeserverConnection`], which never fails and
/// doesn't do anything.
struct MockHomeserverConnection {
homeserver: String,
}
#[async_trait]
impl HomeserverConnection for MockHomeserverConnection {
type Error = anyhow::Error;
fn homeserver(&self) -> &str {
&self.homeserver
}
async fn query_user(&self, _mxid: &str) -> Result<MatrixUser, Self::Error> {
Ok(MatrixUser {
displayname: None,
avatar_url: None,
})
}
async fn provision_user(&self, _request: &ProvisionRequest) -> Result<bool, Self::Error> {
Ok(false)
}
async fn create_device(&self, _mxid: &str, _device_id: &str) -> Result<(), Self::Error> {
Ok(())
}
async fn delete_device(&self, _mxid: &str, _device_id: &str) -> Result<(), Self::Error> {
Ok(())
}
}
impl TestState {
/// Create a new test state from the given database pool
pub async fn from_pool(pool: PgPool) -> Result<Self, anyhow::Error> {
@@ -145,9 +111,7 @@ impl TestState {
)
.await?;
let homeserver_connection = MockHomeserverConnection {
homeserver: "example.com".to_owned(),
};
let homeserver_connection = MockHomeserverConnection::new("example.com");
let policy_factory = Arc::new(policy_factory);

View File

@@ -23,7 +23,7 @@ use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm},
SessionInfoExt,
};
use mas_data_model::UpstreamOAuthProviderImportPreference;
use mas_data_model::{UpstreamOAuthProviderImportPreference, User};
use mas_jose::jwt::Jwt;
use mas_keystore::Encrypter;
use mas_storage::{
@@ -239,6 +239,8 @@ pub(crate) async fn get(
.user()
.lookup(user_id)
.await?
// XXX: is that right?
.filter(User::is_valid)
.ok_or(RouteError::UserNotFound)?;
let ctx = UpstreamExistingLinkContext::new(user)
@@ -263,6 +265,7 @@ pub(crate) async fn get(
.user()
.lookup(user_id)
.await?
.filter(mas_data_model::User::is_valid)
.ok_or(RouteError::UserNotFound)?;
let ctx = UpstreamExistingLinkContext::new(user).with_csrf(csrf_token.form_value());
@@ -390,6 +393,7 @@ pub(crate) async fn post(
.user()
.lookup(user_id)
.await?
.filter(mas_data_model::User::is_valid)
.ok_or(RouteError::UserNotFound)?;
repo.browser_session().add(&mut rng, &clock, &user).await?

View File

@@ -202,6 +202,7 @@ async fn login(
.find_by_username(username)
.await
.map_err(|_e| FormError::Internal)?
.filter(mas_data_model::User::is_valid)
.ok_or(FormError::InvalidCredentials)?;
// And its password

View File

@@ -9,9 +9,10 @@ license = "Apache-2.0"
anyhow = "1.0.72"
async-trait = "0.1.72"
http = "0.2.9"
url = "2.4.0"
serde = { version = "1.0.180", features = ["derive"] }
tower = { version = "0.4.13", features = ["util"] }
tracing = "0.1.37"
url = "2.4.0"
mas-axum-utils = { path = "../axum-utils" }
mas-http = { path = "../http" }

View File

@@ -124,6 +124,11 @@ struct SynapseDevice {
device_id: String,
}
#[derive(Serialize)]
struct SynapseDeactivateUserRequest {
erase: bool,
}
#[async_trait::async_trait]
impl HomeserverConnection for SynapseConnection {
type Error = anyhow::Error;
@@ -132,6 +137,15 @@ impl HomeserverConnection for SynapseConnection {
&self.homeserver
}
#[tracing::instrument(
name = "homeserver.query_user",
skip_all,
fields(
matrix.homeserver = self.homeserver,
matrix.mxid = mxid,
),
err(Display),
)]
async fn query_user(&self, mxid: &str) -> Result<MatrixUser, Self::Error> {
let mut client = self
.http_client_factory
@@ -158,6 +172,16 @@ impl HomeserverConnection for SynapseConnection {
})
}
#[tracing::instrument(
name = "homeserver.provision_user",
skip_all,
fields(
matrix.homeserver = self.homeserver,
matrix.mxid = request.mxid(),
user.id = request.sub(),
),
err(Display),
)]
async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, Self::Error> {
let mut body = SynapseUser {
external_ids: Some(vec![ExternalID {
@@ -213,6 +237,16 @@ impl HomeserverConnection for SynapseConnection {
}
}
#[tracing::instrument(
name = "homeserver.create_device",
skip_all,
fields(
matrix.homeserver = self.homeserver,
matrix.mxid = mxid,
matrix.device_id = device_id,
),
err(Display),
)]
async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
let mut client = self
.http_client_factory
@@ -236,6 +270,16 @@ impl HomeserverConnection for SynapseConnection {
Ok(())
}
#[tracing::instrument(
name = "homeserver.delete_device",
skip_all,
fields(
matrix.homeserver = self.homeserver,
matrix.mxid = mxid,
matrix.device_id = device_id,
),
err(Display),
)]
async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
let mut client = self.http_client_factory.client().await?;
@@ -253,4 +297,35 @@ impl HomeserverConnection for SynapseConnection {
Ok(())
}
#[tracing::instrument(
name = "homeserver.delete_user",
skip_all,
fields(
matrix.homeserver = self.homeserver,
matrix.mxid = mxid,
erase = erase,
),
err(Display),
)]
async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> {
let mut client = self
.http_client_factory
.client()
.await?
.request_bytes_to_body()
.json_request();
let request = self
.post(&format!("_synapse/admin/v1/deactivate/{mxid}"))
.body(SynapseDeactivateUserRequest { erase })?;
let response = client.ready().await?.call(request).await?;
if response.status() != StatusCode::OK {
return Err(anyhow::anyhow!("Failed to delete user in Synapse"));
}
Ok(())
}
}

View File

@@ -6,7 +6,9 @@ edition = "2021"
license = "Apache-2.0"
[dependencies]
anyhow = "1.0.71"
serde = { version = "1.0.180", features = ["derive"] }
async-trait = "0.1.72"
http = "0.2.9"
tokio = { version = "1.28.2", features = ["sync", "macros", "rt"] }
url = "2.4.0"

View File

@@ -16,6 +16,10 @@
#![deny(clippy::all, clippy::str_to_string, rustdoc::broken_intra_doc_links)]
#![warn(clippy::pedantic)]
mod mock;
pub use self::mock::HomeserverConnection as MockHomeserverConnection;
#[derive(Debug)]
pub struct MatrixUser {
pub displayname: Option<String>,
@@ -39,39 +43,58 @@ pub struct ProvisionRequest {
}
impl ProvisionRequest {
/// Create a new [`ProvisionRequest`].
///
/// # Parameters
///
/// * `mxid` - The Matrix ID to provision.
/// * `sub` - The `sub` of the user, aka the internal ID.
#[must_use]
pub fn new(mxid: String, sub: String) -> Self {
pub fn new(mxid: impl Into<String>, sub: impl Into<String>) -> Self {
Self {
mxid,
sub,
mxid: mxid.into(),
sub: sub.into(),
displayname: FieldAction::DoNothing,
avatar_url: FieldAction::DoNothing,
emails: FieldAction::DoNothing,
}
}
/// Get the `sub` of the user to provision, aka the internal ID.
#[must_use]
pub fn sub(&self) -> &str {
&self.sub
}
/// Get the Matrix ID to provision.
#[must_use]
pub fn mxid(&self) -> &str {
&self.mxid
}
/// Ask to set the displayname of the user.
///
/// # Parameters
///
/// * `displayname` - The displayname to set.
#[must_use]
pub fn set_displayname(mut self, displayname: String) -> Self {
self.displayname = FieldAction::Set(displayname);
self
}
/// Ask to unset the displayname of the user.
#[must_use]
pub fn unset_displayname(mut self) -> Self {
self.displayname = FieldAction::Unset;
self
}
/// Call the given callback if the displayname should be set or unset.
///
/// # Parameters
///
/// * `callback` - The callback to call.
pub fn on_displayname(&self, callback: impl FnOnce(Option<&str>)) -> &Self {
match &self.displayname {
FieldAction::Unset => callback(None),
@@ -82,18 +105,29 @@ impl ProvisionRequest {
self
}
/// Ask to set the avatar URL of the user.
///
/// # Parameters
///
/// * `avatar_url` - The avatar URL to set.
#[must_use]
pub fn set_avatar_url(mut self, avatar_url: String) -> Self {
self.avatar_url = FieldAction::Set(avatar_url);
self
}
/// Ask to unset the avatar URL of the user.
#[must_use]
pub fn unset_avatar_url(mut self) -> Self {
self.avatar_url = FieldAction::Unset;
self
}
/// Call the given callback if the avatar URL should be set or unset.
///
/// # Parameters
///
/// * `callback` - The callback to call.
pub fn on_avatar_url(&self, callback: impl FnOnce(Option<&str>)) -> &Self {
match &self.avatar_url {
FieldAction::Unset => callback(None),
@@ -104,18 +138,29 @@ impl ProvisionRequest {
self
}
/// Ask to set the emails of the user.
///
/// # Parameters
///
/// * `emails` - The list of emails to set.
#[must_use]
pub fn set_emails(mut self, emails: Vec<String>) -> Self {
self.emails = FieldAction::Set(emails);
self
}
/// Ask to unset the emails of the user.
#[must_use]
pub fn unset_emails(mut self) -> Self {
self.emails = FieldAction::Unset;
self
}
/// Call the given callback if the emails should be set or unset.
///
/// # Parameters
///
/// * `callback` - The callback to call.
pub fn on_emails(&self, callback: impl FnOnce(Option<&[String]>)) -> &Self {
match &self.emails {
FieldAction::Unset => callback(None),
@@ -129,17 +174,84 @@ impl ProvisionRequest {
#[async_trait::async_trait]
pub trait HomeserverConnection: Send + Sync {
/// The error type returned by all methods.
type Error;
/// Get the homeserver URL.
fn homeserver(&self) -> &str;
/// Get the Matrix ID of the user with the given localpart.
///
/// # Parameters
///
/// * `localpart` - The localpart of the user.
fn mxid(&self, localpart: &str) -> String {
format!("@{}:{}", localpart, self.homeserver())
}
/// Query the state of a user on the homeserver.
///
/// # Parameters
///
/// * `mxid` - The Matrix ID of the user to query.
///
/// # Errors
///
/// Returns an error if the homeserver is unreachable or the user does not
/// exist.
async fn query_user(&self, mxid: &str) -> Result<MatrixUser, Self::Error>;
/// Provision a user on the homeserver.
///
/// # Parameters
///
/// * `request` - a [`ProvisionRequest`] containing the details of the user
/// to provision.
///
/// # Errors
///
/// Returns an error if the homeserver is unreachable or the user could not
/// be provisioned.
async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, Self::Error>;
/// Create a device for a user on the homeserver.
///
/// # Parameters
///
/// * `mxid` - The Matrix ID of the user to create a device for.
/// * `device_id` - The device ID to create.
///
/// # Errors
///
/// Returns an error if the homeserver is unreachable or the device could
/// not be created.
async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error>;
/// Delete a device for a user on the homeserver.
///
/// # Parameters
///
/// * `mxid` - The Matrix ID of the user to delete a device for.
/// * `device_id` - The device ID to delete.
///
/// # Errors
///
/// Returns an error if the homeserver is unreachable or the device could
/// not be deleted.
async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error>;
/// Delete a user on the homeserver.
///
/// # Parameters
///
/// * `mxid` - The Matrix ID of the user to delete.
/// * `erase` - Whether to ask the homeserver to erase the user's data.
///
/// # Errors
///
/// Returns an error if the homeserver is unreachable or the user could not
/// be deleted.
async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error>;
}
#[async_trait::async_trait]
@@ -165,4 +277,8 @@ impl<T: HomeserverConnection + Send + Sync + ?Sized> HomeserverConnection for &T
async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
(**self).delete_device(mxid, device_id).await
}
async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> {
(**self).delete_user(mxid, erase).await
}
}

169
crates/matrix/src/mock.rs Normal file
View File

@@ -0,0 +1,169 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::{HashMap, HashSet};
use anyhow::Context;
use async_trait::async_trait;
use tokio::sync::RwLock;
use crate::{MatrixUser, ProvisionRequest};
struct MockUser {
sub: String,
avatar_url: Option<String>,
displayname: Option<String>,
devices: HashSet<String>,
emails: Option<Vec<String>>,
}
/// A mock implementation of a [`HomeserverConnection`], which never fails and
/// doesn't do anything.
pub struct HomeserverConnection {
homeserver: String,
users: RwLock<HashMap<String, MockUser>>,
}
impl HomeserverConnection {
/// Create a new mock connection.
pub fn new<H>(homeserver: H) -> Self
where
H: Into<String>,
{
Self {
homeserver: homeserver.into(),
users: RwLock::new(HashMap::new()),
}
}
}
#[async_trait]
impl crate::HomeserverConnection for HomeserverConnection {
type Error = anyhow::Error;
fn homeserver(&self) -> &str {
&self.homeserver
}
async fn query_user(&self, mxid: &str) -> Result<MatrixUser, Self::Error> {
let users = self.users.read().await;
let user = users.get(mxid).context("User not found")?;
Ok(MatrixUser {
displayname: user.displayname.clone(),
avatar_url: user.avatar_url.clone(),
})
}
async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, Self::Error> {
let mut users = self.users.write().await;
let inserted = !users.contains_key(request.mxid());
let user = users.entry(request.mxid().to_owned()).or_insert(MockUser {
sub: request.sub().to_owned(),
avatar_url: None,
displayname: None,
devices: HashSet::new(),
emails: None,
});
anyhow::ensure!(
user.sub == request.sub(),
"User already provisioned with different sub"
);
request.on_emails(|emails| {
user.emails = emails.map(ToOwned::to_owned);
});
request.on_displayname(|displayname| {
user.displayname = displayname.map(ToOwned::to_owned);
});
request.on_avatar_url(|avatar_url| {
user.avatar_url = avatar_url.map(ToOwned::to_owned);
});
Ok(inserted)
}
async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
let mut users = self.users.write().await;
let user = users.get_mut(mxid).context("User not found")?;
user.devices.insert(device_id.to_owned());
Ok(())
}
async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
let mut users = self.users.write().await;
let user = users.get_mut(mxid).context("User not found")?;
user.devices.remove(device_id);
Ok(())
}
async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> {
let mut users = self.users.write().await;
let user = users.get_mut(mxid).context("User not found")?;
user.devices.clear();
user.emails = None;
if erase {
user.avatar_url = None;
user.displayname = None;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::HomeserverConnection as _;
#[tokio::test]
async fn test_mock_connection() {
let conn = HomeserverConnection::new("example.org");
let mxid = "@test:example.org";
let device = "test";
assert_eq!(conn.homeserver(), "example.org");
assert_eq!(conn.mxid("test"), mxid);
assert!(conn.query_user(mxid).await.is_err());
assert!(conn.create_device(mxid, device).await.is_err());
assert!(conn.delete_device(mxid, device).await.is_err());
let request = ProvisionRequest::new("@test:example.org", "test")
.set_displayname("Test User".into())
.set_avatar_url("mxc://example.org/1234567890".into())
.set_emails(vec!["test@example.org".to_owned()]);
let inserted = conn.provision_user(&request).await.unwrap();
assert!(inserted);
let user = conn.query_user("@test:example.org").await.unwrap();
assert_eq!(user.displayname, Some("Test User".into()));
assert_eq!(user.avatar_url, Some("mxc://example.org/1234567890".into()));
// Deleting a non-existent device should not fail
assert!(conn.delete_device(mxid, device).await.is_ok());
// Create the device
assert!(conn.create_device(mxid, device).await.is_ok());
// Create the same device again
assert!(conn.create_device(mxid, device).await.is_ok());
// XXX: there is no API to query devices yet in the trait
// Delete the device
assert!(conn.delete_device(mxid, device).await.is_ok());
}
}

View File

@@ -0,0 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE users\n SET locked_at = NULL\n WHERE user_id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": []
},
"hash": "22896e8f2a002f307089c3e0f9ee561e6521c45ce07d3a42411984c9a6b75fdc"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT s.user_session_id\n , s.created_at AS \"user_session_created_at\"\n , s.finished_at AS \"user_session_finished_at\"\n , u.user_id\n , u.username AS \"user_username\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n WHERE s.user_session_id = $1\n ",
"query": "\n SELECT s.user_session_id\n , s.created_at AS \"user_session_created_at\"\n , s.finished_at AS \"user_session_finished_at\"\n , u.user_id\n , u.username AS \"user_username\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n , u.created_at AS \"user_created_at\"\n , u.locked_at AS \"user_locked_at\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n WHERE s.user_session_id = $1\n ",
"describe": {
"columns": [
{
@@ -32,6 +32,16 @@
"ordinal": 5,
"name": "user_primary_user_email_id",
"type_info": "Uuid"
},
{
"ordinal": 6,
"name": "user_created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 7,
"name": "user_locked_at",
"type_info": "Timestamptz"
}
],
"parameters": {
@@ -45,8 +55,10 @@
true,
false,
false,
true,
false,
true
]
},
"hash": "25d61a373560556deafe056c8cd2982ac472f5ec2fab08b0b5275c4b78c11a7e"
"hash": "73fe61f03a41778c6273b1c2dbdb13b91fbccfe5fbdbead8c4868d52a61a0f9d"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO users (user_id, username, created_at)\n VALUES ($1, $2, $3)\n ",
"query": "\n INSERT INTO users (user_id, username, created_at)\n VALUES ($1, $2, $3)\n ON CONFLICT (username) DO NOTHING\n ",
"describe": {
"columns": [],
"parameters": {
@@ -12,5 +12,5 @@
},
"nullable": []
},
"hash": "b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537"
"hash": "7f4c4634ada4dc2745530dcca8eee92abf78dfbdf1a25e58a2bc9c14be8035f0"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n FROM users\n WHERE username = $1\n ",
"query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n , locked_at\n FROM users\n WHERE username = $1\n ",
"describe": {
"columns": [
{
@@ -22,6 +22,11 @@
"ordinal": 3,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 4,
"name": "locked_at",
"type_info": "Timestamptz"
}
],
"parameters": {
@@ -33,8 +38,9 @@
false,
false,
true,
false
false,
true
]
},
"hash": "836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c"
"hash": "bfa5eaeaa5b4574bb083c86711eb4599f6374c96bb4a6827d400acb22fb0fd39"
}

View File

@@ -0,0 +1,15 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE users\n SET locked_at = $1\n WHERE user_id = $2\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Timestamptz",
"Uuid"
]
},
"nullable": []
},
"hash": "c29fa41743811a6ac3a9b952b6ea75d18e914f823902587b63c9f295407144b1"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n FROM users\n WHERE user_id = $1\n ",
"query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n , locked_at\n FROM users\n WHERE user_id = $1\n ",
"describe": {
"columns": [
{
@@ -22,6 +22,11 @@
"ordinal": 3,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 4,
"name": "locked_at",
"type_info": "Timestamptz"
}
],
"parameters": {
@@ -33,8 +38,9 @@
false,
false,
true,
false
false,
true
]
},
"hash": "08d7df347c806ef14b6d0fb031cab041d79ba48528420160e23286369db7af35"
"hash": "e0ea7d93ab3f565828b2faab4cc5e1a6ac868c95bfaee3a6960df1cf484d53da"
}

View File

@@ -0,0 +1,19 @@
-- Copyright 2023 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- Add a new column in on the `users` to record when an account gets locked
ALTER TABLE "users"
ADD COLUMN "locked_at"
TIMESTAMP WITH TIME ZONE
DEFAULT NULL;

View File

@@ -29,6 +29,8 @@ pub enum Users {
UserId,
Username,
PrimaryUserEmailId,
CreatedAt,
LockedAt,
}
#[derive(sea_query::Iden)]

View File

@@ -31,6 +31,7 @@ use mas_storage::{
Repository, RepositoryAccess, RepositoryTransaction,
};
use sqlx::{PgPool, Postgres, Transaction};
use tracing::Instrument;
use crate::{
compat::{
@@ -78,11 +79,21 @@ impl RepositoryTransaction for PgRepository {
type Error = DatabaseError;
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
self.txn.commit().map_err(DatabaseError::from).boxed()
let span = tracing::info_span!("db.save");
self.txn
.commit()
.map_err(DatabaseError::from)
.instrument(span)
.boxed()
}
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
self.txn.rollback().map_err(DatabaseError::from).boxed()
let span = tracing::info_span!("db.cancel");
self.txn
.rollback()
.map_err(DatabaseError::from)
.instrument(span)
.boxed()
}
}

View File

@@ -55,9 +55,8 @@ struct UserLookup {
user_id: Uuid,
username: String,
primary_user_email_id: Option<Uuid>,
#[allow(dead_code)]
created_at: DateTime<Utc>,
locked_at: Option<DateTime<Utc>>,
}
impl From<UserLookup> for User {
@@ -68,6 +67,8 @@ impl From<UserLookup> for User {
username: value.username,
sub: id.to_string(),
primary_user_email_id: value.primary_user_email_id.map(Into::into),
created_at: value.created_at,
locked_at: value.locked_at,
}
}
}
@@ -93,6 +94,7 @@ impl<'c> UserRepository for PgUserRepository<'c> {
, username
, primary_user_email_id
, created_at
, locked_at
FROM users
WHERE user_id = $1
"#,
@@ -124,6 +126,7 @@ impl<'c> UserRepository for PgUserRepository<'c> {
, username
, primary_user_email_id
, created_at
, locked_at
FROM users
WHERE username = $1
"#,
@@ -158,10 +161,11 @@ impl<'c> UserRepository for PgUserRepository<'c> {
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("user.id", tracing::field::display(id));
sqlx::query!(
let res = sqlx::query!(
r#"
INSERT INTO users (user_id, username, created_at)
VALUES ($1, $2, $3)
ON CONFLICT (username) DO NOTHING
"#,
Uuid::from(id),
username,
@@ -171,11 +175,17 @@ impl<'c> UserRepository for PgUserRepository<'c> {
.execute(&mut *self.conn)
.await?;
// If the user already exists, want to return an error but not poison the
// transaction
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(User {
id,
username,
sub: id.to_string(),
primary_user_email_id: None,
created_at,
locked_at: None,
})
}
@@ -203,4 +213,72 @@ impl<'c> UserRepository for PgUserRepository<'c> {
Ok(exists)
}
#[tracing::instrument(
name = "db.user.lock",
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
if user.locked_at.is_some() {
return Ok(user);
}
let locked_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE users
SET locked_at = $1
WHERE user_id = $2
"#,
locked_at,
Uuid::from(user.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
user.locked_at = Some(locked_at);
Ok(user)
}
#[tracing::instrument(
name = "db.user.unlock",
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
if user.locked_at.is_none() {
return Ok(user);
}
let res = sqlx::query!(
r#"
UPDATE users
SET locked_at = NULL
WHERE user_id = $1
"#,
Uuid::from(user.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
user.locked_at = None;
Ok(user)
}
}

View File

@@ -53,6 +53,8 @@ struct SessionLookup {
user_id: Uuid,
user_username: String,
user_primary_user_email_id: Option<Uuid>,
user_created_at: DateTime<Utc>,
user_locked_at: Option<DateTime<Utc>>,
}
impl TryFrom<SessionLookup> for BrowserSession {
@@ -65,6 +67,8 @@ impl TryFrom<SessionLookup> for BrowserSession {
username: value.user_username,
sub: id.to_string(),
primary_user_email_id: value.user_primary_user_email_id.map(Into::into),
created_at: value.user_created_at,
locked_at: value.user_locked_at,
};
Ok(BrowserSession {
@@ -99,6 +103,8 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
, u.user_id
, u.username AS "user_username"
, u.primary_user_email_id AS "user_primary_user_email_id"
, u.created_at AS "user_created_at"
, u.locked_at AS "user_locked_at"
FROM user_sessions s
INNER JOIN users u
USING (user_id)
@@ -232,6 +238,14 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
Expr::col((Users::Table, Users::PrimaryUserEmailId)),
SessionLookupIden::UserPrimaryUserEmailId,
)
.expr_as(
Expr::col((Users::Table, Users::CreatedAt)),
SessionLookupIden::UserCreatedAt,
)
.expr_as(
Expr::col((Users::Table, Users::LockedAt)),
SessionLookupIden::UserLockedAt,
)
.from(UserSessions::Table)
.inner_join(
Users::Table,

View File

@@ -63,12 +63,38 @@ async fn test_user_repo(pool: PgPool) {
assert!(repo.user().lookup(user.id).await.unwrap().is_some());
// Adding a second time should give a conflict
// It should not poison the transaction though
assert!(repo
.user()
.add(&mut rng, &clock, USERNAME.to_owned())
.await
.is_err());
// Try locking a user
assert!(user.is_valid());
let user = repo.user().lock(&clock, user).await.unwrap();
assert!(!user.is_valid());
// Check that the property is retrieved on lookup
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
assert!(!user.is_valid());
// Locking a second time should not fail
let user = repo.user().lock(&clock, user).await.unwrap();
assert!(!user.is_valid());
// Try unlocking a user
let user = repo.user().unlock(user).await.unwrap();
assert!(user.is_valid());
// Check that the property is retrieved on lookup
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
assert!(user.is_valid());
// Unlocking a second time should not fail
let user = repo.user().unlock(user).await.unwrap();
assert!(user.is_valid());
repo.save().await.unwrap();
}

View File

@@ -73,6 +73,15 @@ pub struct JobWithSpanContext<T> {
payload: T,
}
impl<J> From<J> for JobWithSpanContext<J> {
fn from(payload: J) -> Self {
Self {
span_context: None,
payload,
}
}
}
impl<J: Job> Job for JobWithSpanContext<J> {
const NAME: &'static str = J::NAME;
}
@@ -369,6 +378,47 @@ mod jobs {
impl Job for DeleteDeviceJob {
const NAME: &'static str = "delete-device";
}
/// A job to deactivate and lock a user
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DeactivateUserJob {
user_id: Ulid,
hs_erase: bool,
}
impl DeactivateUserJob {
/// Create a new job to deactivate and lock a user
///
/// # Parameters
///
/// * `user` - The user to deactivate
/// * `hs_erase` - Whether to erase the user from the homeserver
#[must_use]
pub fn new(user: &User, hs_erase: bool) -> Self {
Self {
user_id: user.id,
hs_erase,
}
}
/// The ID of the user to deactivate
#[must_use]
pub fn user_id(&self) -> Ulid {
self.user_id
}
/// Whether to erase the user from the homeserver
#[must_use]
pub fn hs_erase(&self) -> bool {
self.hs_erase
}
}
impl Job for DeactivateUserJob {
const NAME: &'static str = "deactivate-user";
}
}
pub use self::jobs::{DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, VerifyEmailJob};
pub use self::jobs::{
DeactivateUserJob, DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, VerifyEmailJob,
};

View File

@@ -96,6 +96,33 @@ pub trait UserRepository: Send + Sync {
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn exists(&mut self, username: &str) -> Result<bool, Self::Error>;
/// Lock a [`User`]
///
/// Returns the locked [`User`]
///
/// # Parameters
///
/// * `clock`: The clock used to generate timestamps
/// * `user`: The [`User`] to lock
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn lock(&mut self, clock: &dyn Clock, user: User) -> Result<User, Self::Error>;
/// Unlock a [`User`]
///
/// Returns the unlocked [`User`]
///
/// # Parameters
///
/// * `user`: The [`User`] to unlock
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn unlock(&mut self, user: User) -> Result<User, Self::Error>;
}
repository_impl!(UserRepository:
@@ -108,4 +135,6 @@ repository_impl!(UserRepository:
username: String,
) -> Result<User, Self::Error>;
async fn exists(&mut self, username: &str) -> Result<bool, Self::Error>;
async fn lock(&mut self, clock: &dyn Clock, user: User) -> Result<User, Self::Error>;
async fn unlock(&mut self, user: User) -> Result<User, Self::Error>;
);

View File

@@ -29,7 +29,10 @@ use chrono::{DateTime, Utc};
use mas_storage::{oauth2::OAuth2AccessTokenRepository, RepositoryAccess};
use tracing::{debug, info};
use crate::{utils::metrics_layer, JobContextExt, State};
use crate::{
utils::{metrics_layer, trace_layer, TracedJob},
JobContextExt, State,
};
#[derive(Default, Clone)]
pub struct CleanupExpiredTokensJob {
@@ -46,6 +49,8 @@ impl Job for CleanupExpiredTokensJob {
const NAME: &'static str = "cleanup-expired-tokens";
}
impl TracedJob for CleanupExpiredTokensJob {}
pub async fn cleanup_expired_tokens(
job: CleanupExpiredTokensJob,
ctx: JobContext,
@@ -79,6 +84,7 @@ pub(crate) fn register(
.stream(CronStream::new(schedule).timer(TokioTimer).to_stream())
.layer(state.inject())
.layer(metrics_layer())
.layer(trace_layer())
.build_fn(cleanup_expired_tokens);
monitor.register(worker)

View File

@@ -13,25 +13,14 @@
// limitations under the License.
use anyhow::Context;
use apalis_core::{
builder::{WorkerBuilder, WorkerFactoryFn},
context::JobContext,
executor::TokioExecutor,
job::Job,
monitor::Monitor,
storage::builder::WithStorage,
};
use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
use chrono::Duration;
use mas_email::{Address, EmailVerificationContext, Mailbox};
use mas_storage::job::{JobWithSpanContext, VerifyEmailJob};
use rand::{distributions::Uniform, Rng};
use tracing::info;
use crate::{
storage::PostgresStorageFactory,
utils::{metrics_layer, trace_layer},
JobContextExt, State,
};
use crate::{storage::PostgresStorageFactory, JobContextExt, State};
#[tracing::instrument(
name = "job.verify_email",
@@ -99,15 +88,8 @@ pub(crate) fn register(
state: &State,
storage_factory: &PostgresStorageFactory,
) -> Monitor<TokioExecutor> {
let storage = storage_factory.build();
let worker_name = format!("{job}-{suffix}", job = VerifyEmailJob::NAME);
let worker = WorkerBuilder::new(worker_name)
.layer(state.inject())
.layer(trace_layer())
.layer(metrics_layer())
.with_storage_config(storage, |c| {
c.fetch_interval(std::time::Duration::from_secs(1))
})
.build_fn(verify_email);
monitor.register(worker)
let verify_email_worker =
crate::build!(VerifyEmailJob => verify_email, suffix, state, storage_factory);
monitor.register(verify_email_worker)
}

View File

@@ -33,6 +33,7 @@ mod database;
mod email;
mod matrix;
mod storage;
mod user;
mod utils;
#[derive(Clone)]
@@ -106,6 +107,32 @@ impl JobContextExt for apalis_core::context::JobContext {
}
}
/// Helper macro to build a storage-backed worker.
macro_rules! build {
($job:ty => $fn:ident, $suffix:expr, $state:expr, $factory:expr) => {{
let storage = $factory.build();
let worker_name = format!(
"{job}-{suffix}",
job = <$job as ::apalis_core::job::Job>::NAME,
suffix = $suffix
);
let builder = ::apalis_core::builder::WorkerBuilder::new(worker_name)
.layer($state.inject())
.layer(crate::utils::trace_layer())
.layer(crate::utils::metrics_layer());
let builder = ::apalis_core::storage::builder::WithStorage::with_storage_config(
builder,
storage,
|c| c.fetch_interval(std::time::Duration::from_secs(1)),
);
::apalis_core::builder::WorkerFactory::build(builder, ::apalis_core::job_fn::job_fn($fn))
}};
}
pub(crate) use build;
/// Initialise the workers.
///
/// # Errors
@@ -128,6 +155,7 @@ pub async fn init(
let monitor = self::database::register(name, monitor, &state);
let monitor = self::email::register(name, monitor, &state, &factory);
let monitor = self::matrix::register(name, monitor, &state, &factory);
let monitor = self::user::register(name, monitor, &state, &factory);
// TODO: we might want to grab the join handle here
factory.listen().await?;
debug!(?monitor, "workers registered");

View File

@@ -12,17 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::time::Duration;
use anyhow::Context;
use apalis_core::{
builder::{WorkerBuilder, WorkerFactoryFn},
context::JobContext,
executor::TokioExecutor,
job::Job,
monitor::Monitor,
storage::builder::WithStorage,
};
use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
use mas_matrix::ProvisionRequest;
use mas_storage::{
job::{DeleteDeviceJob, JobWithSpanContext, ProvisionDeviceJob, ProvisionUserJob},
@@ -31,11 +22,7 @@ use mas_storage::{
};
use tracing::info;
use crate::{
storage::PostgresStorageFactory,
utils::{metrics_layer, trace_layer},
JobContextExt, State,
};
use crate::{storage::PostgresStorageFactory, JobContextExt, State};
/// Job to provision a user on the Matrix homeserver.
/// This works by doing a PUT request to the /_synapse/admin/v2/users/{user_id}
@@ -163,32 +150,12 @@ pub(crate) fn register(
state: &State,
storage_factory: &PostgresStorageFactory,
) -> Monitor<TokioExecutor> {
let storage = storage_factory.build();
let worker_name = format!("{job}-{suffix}", job = ProvisionUserJob::NAME);
let provision_user_worker = WorkerBuilder::new(worker_name)
.layer(state.inject())
.layer(trace_layer())
.layer(metrics_layer())
.with_storage_config(storage, |c| c.fetch_interval(Duration::from_secs(1)))
.build_fn(provision_user);
let storage = storage_factory.build();
let worker_name = format!("{job}-{suffix}", job = ProvisionDeviceJob::NAME);
let provision_device_worker = WorkerBuilder::new(worker_name)
.layer(state.inject())
.layer(trace_layer())
.layer(metrics_layer())
.with_storage_config(storage, |c| c.fetch_interval(Duration::from_secs(1)))
.build_fn(provision_device);
let storage = storage_factory.build();
let worker_name = format!("{job}-{suffix}", job = DeleteDeviceJob::NAME);
let delete_device_worker = WorkerBuilder::new(worker_name)
.layer(state.inject())
.layer(trace_layer())
.layer(metrics_layer())
.with_storage_config(storage, |c| c.fetch_interval(Duration::from_secs(1)))
.build_fn(delete_device);
let provision_user_worker =
crate::build!(ProvisionUserJob => provision_user, suffix, state, storage_factory);
let provision_device_worker =
crate::build!(ProvisionDeviceJob => provision_device, suffix, state, storage_factory);
let delete_device_worker =
crate::build!(DeleteDeviceJob => delete_device, suffix, state, storage_factory);
monitor
.register(provision_user_worker)

77
crates/tasks/src/user.rs Normal file
View File

@@ -0,0 +1,77 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
use mas_storage::{
job::{DeactivateUserJob, JobWithSpanContext},
user::UserRepository,
RepositoryAccess,
};
use tracing::info;
use crate::{storage::PostgresStorageFactory, JobContextExt, State};
/// Job to deactivate a user, both locally and on the Matrix homeserver.
#[tracing::instrument(
name = "job.deactivate_user"
fields(user.id = %job.user_id(), erase = %job.hs_erase()),
skip_all,
err(Debug),
)]
async fn deactivate_user(
job: JobWithSpanContext<DeactivateUserJob>,
ctx: JobContext,
) -> Result<(), anyhow::Error> {
let state = ctx.state();
let clock = state.clock();
let matrix = state.matrix_connection();
let mut repo = state.repository().await?;
let user = repo
.user()
.lookup(job.user_id())
.await?
.context("User not found")?;
// Let's first lock the user
let user = repo
.user()
.lock(&clock, user)
.await
.context("Failed to lock user")?;
// TODO: delete the sessions & access tokens
// Before calling back to the homeserver, commit the changes to the database
repo.save().await?;
let mxid = matrix.mxid(&user.username);
info!("Deactivating user {} on homeserver", mxid);
matrix.delete_user(&mxid, job.hs_erase()).await?;
Ok(())
}
pub(crate) fn register(
suffix: &str,
monitor: Monitor<TokioExecutor>,
state: &State,
storage_factory: &PostgresStorageFactory,
) -> Monitor<TokioExecutor> {
let deactivate_user_worker =
crate::build!(DeactivateUserJob => deactivate_user, suffix, state, storage_factory);
monitor.register(deactivate_user_worker)
}

View File

@@ -18,14 +18,32 @@ use mas_tower::{
make_span_fn, DurationRecorderLayer, FnWrapper, IdentityLayer, InFlightCounterLayer,
TraceLayer, KV,
};
use opentelemetry::{Key, KeyValue};
use opentelemetry::{trace::SpanContext, Key, KeyValue};
use tracing::info_span;
use tracing_opentelemetry::OpenTelemetrySpanExt;
const JOB_NAME: Key = Key::from_static_str("job.name");
const JOB_STATUS: Key = Key::from_static_str("job.status");
fn make_span_for_job_request<J>(req: &JobRequest<JobWithSpanContext<J>>) -> tracing::Span
/// Represents a job that can may have a span context attached to it.
pub trait TracedJob: Job {
/// Returns the span context for this job, if any.
///
/// The default implementation returns `None`.
fn span_context(&self) -> Option<SpanContext> {
None
}
}
/// Implements [`TracedJob`] for any job with the [`JobWithSpanContext`]
/// wrapper.
impl<J: Job> TracedJob for JobWithSpanContext<J> {
fn span_context(&self) -> Option<SpanContext> {
JobWithSpanContext::span_context(self)
}
}
fn make_span_for_job_request<J: TracedJob>(req: &JobRequest<J>) -> tracing::Span
where
J: Job,
{
@@ -45,18 +63,15 @@ where
span
}
type TraceLayerForJob<J> = TraceLayer<
FnWrapper<fn(&JobRequest<JobWithSpanContext<J>>) -> tracing::Span>,
KV<&'static str>,
KV<&'static str>,
>;
type TraceLayerForJob<J> =
TraceLayer<FnWrapper<fn(&JobRequest<J>) -> tracing::Span>, KV<&'static str>, KV<&'static str>>;
pub(crate) fn trace_layer<J>() -> TraceLayerForJob<J>
where
J: Job,
J: TracedJob,
{
TraceLayer::new(make_span_fn(
make_span_for_job_request::<J> as fn(&JobRequest<JobWithSpanContext<J>>) -> tracing::Span,
make_span_for_job_request::<J> as fn(&JobRequest<J>) -> tracing::Span,
))
.on_response(KV("otel.status_code", "OK"))
.on_error(KV("otel.status_code", "ERROR"))