Merge pull request #1413 from matrix-org/quenting/user-lock
Add a way to lock and deprovision users
This commit is contained in:
3
Cargo.lock
generated
3
Cargo.lock
generated
@@ -3371,9 +3371,11 @@ dependencies = [
|
||||
name = "mas-matrix"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"http",
|
||||
"serde",
|
||||
"tokio",
|
||||
"url",
|
||||
]
|
||||
|
||||
@@ -3389,6 +3391,7 @@ dependencies = [
|
||||
"mas-matrix",
|
||||
"serde",
|
||||
"tower",
|
||||
"tracing",
|
||||
"url",
|
||||
]
|
||||
|
||||
|
||||
@@ -18,14 +18,14 @@ use mas_config::{DatabaseConfig, PasswordsConfig};
|
||||
use mas_data_model::{Device, TokenType};
|
||||
use mas_storage::{
|
||||
compat::{CompatAccessTokenRepository, CompatSessionRepository},
|
||||
job::{DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob},
|
||||
job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob},
|
||||
user::{UserEmailRepository, UserPasswordRepository, UserRepository},
|
||||
Repository, RepositoryAccess, SystemClock,
|
||||
};
|
||||
use mas_storage_pg::PgRepository;
|
||||
use rand::SeedableRng;
|
||||
use sqlx::types::Uuid;
|
||||
use tracing::{info, info_span};
|
||||
use tracing::{info, info_span, warn};
|
||||
|
||||
use crate::util::{database_from_config, password_manager_from_config};
|
||||
|
||||
@@ -69,6 +69,22 @@ enum Subcommand {
|
||||
#[arg(long)]
|
||||
dry_run: bool,
|
||||
},
|
||||
|
||||
/// Lock a user
|
||||
LockUser {
|
||||
/// User to lock
|
||||
username: String,
|
||||
|
||||
/// Whether to deactivate the user
|
||||
#[arg(long)]
|
||||
deactivate: bool,
|
||||
},
|
||||
|
||||
/// Unlock a user
|
||||
UnlockUser {
|
||||
/// User to unlock
|
||||
username: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl Options {
|
||||
@@ -330,6 +346,59 @@ impl Options {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
SC::LockUser {
|
||||
username,
|
||||
deactivate,
|
||||
} => {
|
||||
let _span = info_span!("cli.manage.lock_user", user.username = username).entered();
|
||||
let config: DatabaseConfig = root.load_config()?;
|
||||
let pool = database_from_config(&config).await?;
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.find_by_username(&username)
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
|
||||
info!(%user.id, "Locking user");
|
||||
|
||||
// Even though the deactivation job will lock the user, we lock it here in case
|
||||
// the worker is not running, as we don't have a good way to run a job
|
||||
// synchronously yet.
|
||||
let user = repo.user().lock(&clock, user).await?;
|
||||
|
||||
if deactivate {
|
||||
warn!(%user.id, "Scheduling user deactivation");
|
||||
repo.job()
|
||||
.schedule_job(DeactivateUserJob::new(&user, false))
|
||||
.await?;
|
||||
}
|
||||
repo.save().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
SC::UnlockUser { username } => {
|
||||
let _span = info_span!("cli.manage.lock_user", user.username = username).entered();
|
||||
let config: DatabaseConfig = root.load_config()?;
|
||||
let pool = database_from_config(&config).await?;
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.find_by_username(&username)
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
|
||||
info!(%user.id, "Unlocking user");
|
||||
|
||||
repo.user().unlock(user).await?;
|
||||
repo.save().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,16 @@ pub struct User {
|
||||
pub username: String,
|
||||
pub sub: String,
|
||||
pub primary_user_email_id: Option<Ulid>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub locked_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl User {
|
||||
/// Returns `true` unless the user is locked.
|
||||
#[must_use]
|
||||
pub fn is_valid(&self) -> bool {
|
||||
self.locked_at.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
impl User {
|
||||
@@ -35,6 +45,8 @@ impl User {
|
||||
username: "john".to_owned(),
|
||||
sub: "123-456".to_owned(),
|
||||
primary_user_email_id: None,
|
||||
created_at: now,
|
||||
locked_at: None,
|
||||
}]
|
||||
}
|
||||
}
|
||||
@@ -65,7 +77,7 @@ pub struct BrowserSession {
|
||||
impl BrowserSession {
|
||||
#[must_use]
|
||||
pub fn active(&self) -> bool {
|
||||
self.finished_at.is_none()
|
||||
self.finished_at.is_none() && self.user.is_valid()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -335,6 +335,7 @@ async fn token_login(
|
||||
.user()
|
||||
.lookup(session.user_id)
|
||||
.await?
|
||||
.filter(mas_data_model::User::is_valid)
|
||||
.ok_or(RouteError::UserNotFound)?;
|
||||
|
||||
repo.compat_sso_login().exchange(clock, login).await?;
|
||||
@@ -355,6 +356,7 @@ async fn user_password_login(
|
||||
.user()
|
||||
.find_by_username(&username)
|
||||
.await?
|
||||
.filter(mas_data_model::User::is_valid)
|
||||
.ok_or(RouteError::UserNotFound)?;
|
||||
|
||||
// Lookup its password
|
||||
|
||||
@@ -16,23 +16,24 @@ use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::{IntoResponse, Response},
|
||||
response::{Html, IntoResponse, Response},
|
||||
};
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::SessionInfoExt;
|
||||
use mas_axum_utils::{csrf::CsrfExt, SessionInfoExt};
|
||||
use mas_data_model::{AuthorizationGrant, BrowserSession, Client, Device};
|
||||
use mas_keystore::{Encrypter, Keystore};
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_policy::{EvaluationResult, PolicyFactory};
|
||||
use mas_router::{PostAuthAction, Route, UrlBuilder};
|
||||
use mas_storage::{
|
||||
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository},
|
||||
user::BrowserSessionRepository,
|
||||
BoxClock, BoxRepository, BoxRng, RepositoryAccess,
|
||||
BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
|
||||
};
|
||||
use mas_templates::Templates;
|
||||
use mas_templates::{PolicyViolationContext, TemplateContext, Templates};
|
||||
use oauth2_types::requests::AuthorizationResponse;
|
||||
use thiserror::Error;
|
||||
use tracing::warn;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::callback::CallbackDestination;
|
||||
@@ -74,6 +75,7 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_templates::TemplateError);
|
||||
impl_from_error_for_route!(mas_policy::LoadError);
|
||||
impl_from_error_for_route!(mas_policy::InstanciateError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
@@ -87,7 +89,7 @@ impl_from_error_for_route!(super::callback::CallbackDestinationError);
|
||||
err,
|
||||
)]
|
||||
pub(crate) async fn get(
|
||||
rng: BoxRng,
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||
State(templates): State<Templates>,
|
||||
@@ -123,15 +125,15 @@ pub(crate) async fn get(
|
||||
.ok_or(RouteError::NoSuchClient)?;
|
||||
|
||||
match complete(
|
||||
rng,
|
||||
clock,
|
||||
&mut rng,
|
||||
&clock,
|
||||
repo,
|
||||
key_store,
|
||||
&policy_factory,
|
||||
url_builder,
|
||||
grant,
|
||||
client,
|
||||
session,
|
||||
&client,
|
||||
&session,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -144,10 +146,22 @@ pub(crate) async fn get(
|
||||
mas_router::Reauth::and_then(continue_grant).go(),
|
||||
)
|
||||
.into_response()),
|
||||
Err(GrantCompletionError::RequiresConsent | GrantCompletionError::PolicyViolation) => {
|
||||
Err(GrantCompletionError::RequiresConsent) => {
|
||||
let next = mas_router::Consent(grant_id);
|
||||
Ok((cookie_jar, next.go()).into_response())
|
||||
}
|
||||
Err(GrantCompletionError::PolicyViolation(grant, res)) => {
|
||||
warn!(violation = ?res, "Authorization grant for client {} denied by policy", client.id);
|
||||
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
let ctx = PolicyViolationContext::new(grant, client)
|
||||
.with_session(session)
|
||||
.with_csrf(csrf_token.form_value());
|
||||
|
||||
let content = templates.render_policy_violation(&ctx).await?;
|
||||
|
||||
Ok((cookie_jar, Html(content)).into_response())
|
||||
}
|
||||
Err(GrantCompletionError::NotPending) => Err(RouteError::NotPending),
|
||||
Err(GrantCompletionError::Internal(e)) => Err(RouteError::Internal(e)),
|
||||
}
|
||||
@@ -168,7 +182,7 @@ pub enum GrantCompletionError {
|
||||
RequiresConsent,
|
||||
|
||||
#[error("denied by the policy")]
|
||||
PolicyViolation,
|
||||
PolicyViolation(AuthorizationGrant, EvaluationResult),
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError);
|
||||
@@ -179,15 +193,15 @@ impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError);
|
||||
impl_from_error_for_route!(GrantCompletionError: super::super::IdTokenSignatureError);
|
||||
|
||||
pub(crate) async fn complete(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
rng: &mut (impl rand::RngCore + rand::CryptoRng + Send),
|
||||
clock: &impl Clock,
|
||||
mut repo: BoxRepository,
|
||||
key_store: Keystore,
|
||||
policy_factory: &PolicyFactory,
|
||||
url_builder: UrlBuilder,
|
||||
grant: AuthorizationGrant,
|
||||
client: Client,
|
||||
browser_session: BrowserSession,
|
||||
client: &Client,
|
||||
browser_session: &BrowserSession,
|
||||
) -> Result<AuthorizationResponse, GrantCompletionError> {
|
||||
// Verify that the grant is in a pending stage
|
||||
if !grant.stage.is_pending() {
|
||||
@@ -197,7 +211,7 @@ pub(crate) async fn complete(
|
||||
// Check if the authentication is fresh enough
|
||||
let authentication = repo
|
||||
.browser_session()
|
||||
.get_last_authentication(&browser_session)
|
||||
.get_last_authentication(browser_session)
|
||||
.await?;
|
||||
let authentication = authentication.filter(|auth| auth.created_at > grant.max_auth_time());
|
||||
|
||||
@@ -209,16 +223,16 @@ pub(crate) async fn complete(
|
||||
// Run through the policy
|
||||
let mut policy = policy_factory.instantiate().await?;
|
||||
let res = policy
|
||||
.evaluate_authorization_grant(&grant, &client, &browser_session.user)
|
||||
.evaluate_authorization_grant(&grant, client, &browser_session.user)
|
||||
.await?;
|
||||
|
||||
if !res.valid() {
|
||||
return Err(GrantCompletionError::PolicyViolation);
|
||||
return Err(GrantCompletionError::PolicyViolation(grant, res));
|
||||
}
|
||||
|
||||
let current_consent = repo
|
||||
.oauth2_client()
|
||||
.get_consent_for_user(&client, &browser_session.user)
|
||||
.get_consent_for_user(client, &browser_session.user)
|
||||
.await?;
|
||||
|
||||
let lacks_consent = grant
|
||||
@@ -236,18 +250,12 @@ pub(crate) async fn complete(
|
||||
// All good, let's start the session
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&client,
|
||||
&browser_session,
|
||||
grant.scope.clone(),
|
||||
)
|
||||
.add(rng, clock, client, browser_session, grant.scope.clone())
|
||||
.await?;
|
||||
|
||||
let grant = repo
|
||||
.oauth2_authorization_grant()
|
||||
.fulfill(&clock, &session, grant)
|
||||
.fulfill(clock, &session, grant)
|
||||
.await?;
|
||||
|
||||
// Yep! Let's complete the auth now
|
||||
@@ -256,13 +264,13 @@ pub(crate) async fn complete(
|
||||
// Did they request an ID token?
|
||||
if grant.response_type_id_token {
|
||||
params.id_token = Some(generate_id_token(
|
||||
&mut rng,
|
||||
&clock,
|
||||
rng,
|
||||
clock,
|
||||
&url_builder,
|
||||
&key_store,
|
||||
&client,
|
||||
client,
|
||||
&grant,
|
||||
&browser_session,
|
||||
browser_session,
|
||||
None,
|
||||
Some(&valid_authentication),
|
||||
)?);
|
||||
|
||||
@@ -16,11 +16,11 @@ use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
extract::{Form, State},
|
||||
response::{IntoResponse, Response},
|
||||
response::{Html, IntoResponse, Response},
|
||||
};
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::SessionInfoExt;
|
||||
use mas_axum_utils::{csrf::CsrfExt, SessionInfoExt};
|
||||
use mas_data_model::{AuthorizationCode, Pkce};
|
||||
use mas_keystore::{Encrypter, Keystore};
|
||||
use mas_policy::PolicyFactory;
|
||||
@@ -29,7 +29,7 @@ use mas_storage::{
|
||||
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
|
||||
BoxClock, BoxRepository, BoxRng,
|
||||
};
|
||||
use mas_templates::Templates;
|
||||
use mas_templates::{PolicyViolationContext, TemplateContext, Templates};
|
||||
use oauth2_types::{
|
||||
errors::{ClientError, ClientErrorCode},
|
||||
pkce,
|
||||
@@ -39,6 +39,7 @@ use oauth2_types::{
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use serde::Deserialize;
|
||||
use thiserror::Error;
|
||||
use tracing::warn;
|
||||
|
||||
use self::{callback::CallbackDestination, complete::GrantCompletionError};
|
||||
use crate::impl_from_error_for_route;
|
||||
@@ -91,6 +92,7 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_templates::TemplateError);
|
||||
impl_from_error_for_route!(self::callback::CallbackDestinationError);
|
||||
impl_from_error_for_route!(mas_policy::LoadError);
|
||||
impl_from_error_for_route!(mas_policy::InstanciateError);
|
||||
@@ -170,6 +172,7 @@ pub(crate) async fn get(
|
||||
|
||||
// Get the session info from the cookie
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
|
||||
// One day, we will have try blocks
|
||||
let res: Result<Response, RouteError> = ({
|
||||
@@ -340,15 +343,15 @@ pub(crate) async fn get(
|
||||
Some(user_session) if prompt.contains(&Prompt::None) => {
|
||||
// With prompt=none, we should get back to the client immediately
|
||||
match self::complete::complete(
|
||||
rng,
|
||||
clock,
|
||||
&mut rng,
|
||||
&clock,
|
||||
repo,
|
||||
key_store,
|
||||
&policy_factory,
|
||||
url_builder,
|
||||
grant,
|
||||
client,
|
||||
user_session,
|
||||
&client,
|
||||
&user_session,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -369,7 +372,7 @@ pub(crate) async fn get(
|
||||
)
|
||||
.await?
|
||||
}
|
||||
Err(GrantCompletionError::PolicyViolation) => {
|
||||
Err(GrantCompletionError::PolicyViolation(_grant, _res)) => {
|
||||
callback_destination
|
||||
.go(&templates, ClientError::from(ClientErrorCode::AccessDenied))
|
||||
.await?
|
||||
@@ -387,29 +390,32 @@ pub(crate) async fn get(
|
||||
let grant_id = grant.id;
|
||||
// Else, we show the relevant reauth/consent page if necessary
|
||||
match self::complete::complete(
|
||||
rng,
|
||||
clock,
|
||||
&mut rng,
|
||||
&clock,
|
||||
repo,
|
||||
key_store,
|
||||
&policy_factory,
|
||||
url_builder,
|
||||
grant,
|
||||
client,
|
||||
user_session,
|
||||
&client,
|
||||
&user_session,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(params) => callback_destination.go(&templates, params).await?,
|
||||
Err(
|
||||
GrantCompletionError::RequiresConsent
|
||||
| GrantCompletionError::PolicyViolation,
|
||||
) => {
|
||||
// We're redirecting to the consent URI in both 'consent required' and
|
||||
// 'policy violation' cases, because we reevaluate the policy on this
|
||||
// page, and show the error accordingly
|
||||
// XXX: is this the right approach?
|
||||
Err(GrantCompletionError::RequiresConsent) => {
|
||||
mas_router::Consent(grant_id).go().into_response()
|
||||
}
|
||||
Err(GrantCompletionError::PolicyViolation(grant, res)) => {
|
||||
warn!(violation = ?res, "Authorization grant for client {} denied by policy", client.id);
|
||||
|
||||
let ctx = PolicyViolationContext::new(grant, client)
|
||||
.with_session(user_session)
|
||||
.with_csrf(csrf_token.form_value());
|
||||
|
||||
let content = templates.render_policy_violation(&ctx).await?;
|
||||
Html(content).into_response()
|
||||
}
|
||||
Err(GrantCompletionError::RequiresReauth) => {
|
||||
mas_router::Reauth::and_then(continue_grant)
|
||||
.go()
|
||||
|
||||
@@ -188,6 +188,7 @@ pub(crate) async fn post(
|
||||
.browser_session()
|
||||
.lookup(session.user_session_id)
|
||||
.await?
|
||||
.filter(|b| b.user.is_valid())
|
||||
// XXX: is that the right error to bubble up?
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
@@ -227,6 +228,7 @@ pub(crate) async fn post(
|
||||
.browser_session()
|
||||
.lookup(session.user_session_id)
|
||||
.await?
|
||||
.filter(|b| b.user.is_valid())
|
||||
// XXX: is that the right error to bubble up?
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
@@ -265,6 +267,7 @@ pub(crate) async fn post(
|
||||
.user()
|
||||
.lookup(session.user_id)
|
||||
.await?
|
||||
.filter(mas_data_model::User::is_valid)
|
||||
// XXX: is that the right error to bubble up?
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
@@ -311,6 +314,7 @@ pub(crate) async fn post(
|
||||
.user()
|
||||
.lookup(session.user_id)
|
||||
.await?
|
||||
.filter(mas_data_model::User::is_valid)
|
||||
// XXX: is that the right error to bubble up?
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ use headers::{Authorization, ContentType, HeaderMapExt, HeaderName};
|
||||
use hyper::{header::CONTENT_TYPE, Request, Response, StatusCode};
|
||||
use mas_axum_utils::http_client_factory::HttpClientFactory;
|
||||
use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
|
||||
use mas_matrix::{HomeserverConnection, MatrixUser, ProvisionRequest};
|
||||
use mas_matrix::MockHomeserverConnection;
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::{SimpleRoute, UrlBuilder};
|
||||
use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository};
|
||||
@@ -69,40 +69,6 @@ pub(crate) struct TestState {
|
||||
pub rng: Arc<Mutex<ChaChaRng>>,
|
||||
}
|
||||
|
||||
/// A Mock implementation of a [`HomeserverConnection`], which never fails and
|
||||
/// doesn't do anything.
|
||||
struct MockHomeserverConnection {
|
||||
homeserver: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HomeserverConnection for MockHomeserverConnection {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn homeserver(&self) -> &str {
|
||||
&self.homeserver
|
||||
}
|
||||
|
||||
async fn query_user(&self, _mxid: &str) -> Result<MatrixUser, Self::Error> {
|
||||
Ok(MatrixUser {
|
||||
displayname: None,
|
||||
avatar_url: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn provision_user(&self, _request: &ProvisionRequest) -> Result<bool, Self::Error> {
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn create_device(&self, _mxid: &str, _device_id: &str) -> Result<(), Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete_device(&self, _mxid: &str, _device_id: &str) -> Result<(), Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl TestState {
|
||||
/// Create a new test state from the given database pool
|
||||
pub async fn from_pool(pool: PgPool) -> Result<Self, anyhow::Error> {
|
||||
@@ -145,9 +111,7 @@ impl TestState {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let homeserver_connection = MockHomeserverConnection {
|
||||
homeserver: "example.com".to_owned(),
|
||||
};
|
||||
let homeserver_connection = MockHomeserverConnection::new("example.com");
|
||||
|
||||
let policy_factory = Arc::new(policy_factory);
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ use mas_axum_utils::{
|
||||
csrf::{CsrfExt, ProtectedForm},
|
||||
SessionInfoExt,
|
||||
};
|
||||
use mas_data_model::UpstreamOAuthProviderImportPreference;
|
||||
use mas_data_model::{UpstreamOAuthProviderImportPreference, User};
|
||||
use mas_jose::jwt::Jwt;
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_storage::{
|
||||
@@ -239,6 +239,8 @@ pub(crate) async fn get(
|
||||
.user()
|
||||
.lookup(user_id)
|
||||
.await?
|
||||
// XXX: is that right?
|
||||
.filter(User::is_valid)
|
||||
.ok_or(RouteError::UserNotFound)?;
|
||||
|
||||
let ctx = UpstreamExistingLinkContext::new(user)
|
||||
@@ -263,6 +265,7 @@ pub(crate) async fn get(
|
||||
.user()
|
||||
.lookup(user_id)
|
||||
.await?
|
||||
.filter(mas_data_model::User::is_valid)
|
||||
.ok_or(RouteError::UserNotFound)?;
|
||||
|
||||
let ctx = UpstreamExistingLinkContext::new(user).with_csrf(csrf_token.form_value());
|
||||
@@ -390,6 +393,7 @@ pub(crate) async fn post(
|
||||
.user()
|
||||
.lookup(user_id)
|
||||
.await?
|
||||
.filter(mas_data_model::User::is_valid)
|
||||
.ok_or(RouteError::UserNotFound)?;
|
||||
|
||||
repo.browser_session().add(&mut rng, &clock, &user).await?
|
||||
|
||||
@@ -202,6 +202,7 @@ async fn login(
|
||||
.find_by_username(username)
|
||||
.await
|
||||
.map_err(|_e| FormError::Internal)?
|
||||
.filter(mas_data_model::User::is_valid)
|
||||
.ok_or(FormError::InvalidCredentials)?;
|
||||
|
||||
// And its password
|
||||
|
||||
@@ -9,9 +9,10 @@ license = "Apache-2.0"
|
||||
anyhow = "1.0.72"
|
||||
async-trait = "0.1.72"
|
||||
http = "0.2.9"
|
||||
url = "2.4.0"
|
||||
serde = { version = "1.0.180", features = ["derive"] }
|
||||
tower = { version = "0.4.13", features = ["util"] }
|
||||
tracing = "0.1.37"
|
||||
url = "2.4.0"
|
||||
|
||||
mas-axum-utils = { path = "../axum-utils" }
|
||||
mas-http = { path = "../http" }
|
||||
|
||||
@@ -124,6 +124,11 @@ struct SynapseDevice {
|
||||
device_id: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SynapseDeactivateUserRequest {
|
||||
erase: bool,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl HomeserverConnection for SynapseConnection {
|
||||
type Error = anyhow::Error;
|
||||
@@ -132,6 +137,15 @@ impl HomeserverConnection for SynapseConnection {
|
||||
&self.homeserver
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "homeserver.query_user",
|
||||
skip_all,
|
||||
fields(
|
||||
matrix.homeserver = self.homeserver,
|
||||
matrix.mxid = mxid,
|
||||
),
|
||||
err(Display),
|
||||
)]
|
||||
async fn query_user(&self, mxid: &str) -> Result<MatrixUser, Self::Error> {
|
||||
let mut client = self
|
||||
.http_client_factory
|
||||
@@ -158,6 +172,16 @@ impl HomeserverConnection for SynapseConnection {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "homeserver.provision_user",
|
||||
skip_all,
|
||||
fields(
|
||||
matrix.homeserver = self.homeserver,
|
||||
matrix.mxid = request.mxid(),
|
||||
user.id = request.sub(),
|
||||
),
|
||||
err(Display),
|
||||
)]
|
||||
async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, Self::Error> {
|
||||
let mut body = SynapseUser {
|
||||
external_ids: Some(vec![ExternalID {
|
||||
@@ -213,6 +237,16 @@ impl HomeserverConnection for SynapseConnection {
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "homeserver.create_device",
|
||||
skip_all,
|
||||
fields(
|
||||
matrix.homeserver = self.homeserver,
|
||||
matrix.mxid = mxid,
|
||||
matrix.device_id = device_id,
|
||||
),
|
||||
err(Display),
|
||||
)]
|
||||
async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
|
||||
let mut client = self
|
||||
.http_client_factory
|
||||
@@ -236,6 +270,16 @@ impl HomeserverConnection for SynapseConnection {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "homeserver.delete_device",
|
||||
skip_all,
|
||||
fields(
|
||||
matrix.homeserver = self.homeserver,
|
||||
matrix.mxid = mxid,
|
||||
matrix.device_id = device_id,
|
||||
),
|
||||
err(Display),
|
||||
)]
|
||||
async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
|
||||
let mut client = self.http_client_factory.client().await?;
|
||||
|
||||
@@ -253,4 +297,35 @@ impl HomeserverConnection for SynapseConnection {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "homeserver.delete_user",
|
||||
skip_all,
|
||||
fields(
|
||||
matrix.homeserver = self.homeserver,
|
||||
matrix.mxid = mxid,
|
||||
erase = erase,
|
||||
),
|
||||
err(Display),
|
||||
)]
|
||||
async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> {
|
||||
let mut client = self
|
||||
.http_client_factory
|
||||
.client()
|
||||
.await?
|
||||
.request_bytes_to_body()
|
||||
.json_request();
|
||||
|
||||
let request = self
|
||||
.post(&format!("_synapse/admin/v1/deactivate/{mxid}"))
|
||||
.body(SynapseDeactivateUserRequest { erase })?;
|
||||
|
||||
let response = client.ready().await?.call(request).await?;
|
||||
|
||||
if response.status() != StatusCode::OK {
|
||||
return Err(anyhow::anyhow!("Failed to delete user in Synapse"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,9 @@ edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.71"
|
||||
serde = { version = "1.0.180", features = ["derive"] }
|
||||
async-trait = "0.1.72"
|
||||
http = "0.2.9"
|
||||
tokio = { version = "1.28.2", features = ["sync", "macros", "rt"] }
|
||||
url = "2.4.0"
|
||||
|
||||
@@ -16,6 +16,10 @@
|
||||
#![deny(clippy::all, clippy::str_to_string, rustdoc::broken_intra_doc_links)]
|
||||
#![warn(clippy::pedantic)]
|
||||
|
||||
mod mock;
|
||||
|
||||
pub use self::mock::HomeserverConnection as MockHomeserverConnection;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MatrixUser {
|
||||
pub displayname: Option<String>,
|
||||
@@ -39,39 +43,58 @@ pub struct ProvisionRequest {
|
||||
}
|
||||
|
||||
impl ProvisionRequest {
|
||||
/// Create a new [`ProvisionRequest`].
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `mxid` - The Matrix ID to provision.
|
||||
/// * `sub` - The `sub` of the user, aka the internal ID.
|
||||
#[must_use]
|
||||
pub fn new(mxid: String, sub: String) -> Self {
|
||||
pub fn new(mxid: impl Into<String>, sub: impl Into<String>) -> Self {
|
||||
Self {
|
||||
mxid,
|
||||
sub,
|
||||
mxid: mxid.into(),
|
||||
sub: sub.into(),
|
||||
displayname: FieldAction::DoNothing,
|
||||
avatar_url: FieldAction::DoNothing,
|
||||
emails: FieldAction::DoNothing,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the `sub` of the user to provision, aka the internal ID.
|
||||
#[must_use]
|
||||
pub fn sub(&self) -> &str {
|
||||
&self.sub
|
||||
}
|
||||
|
||||
/// Get the Matrix ID to provision.
|
||||
#[must_use]
|
||||
pub fn mxid(&self) -> &str {
|
||||
&self.mxid
|
||||
}
|
||||
|
||||
/// Ask to set the displayname of the user.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `displayname` - The displayname to set.
|
||||
#[must_use]
|
||||
pub fn set_displayname(mut self, displayname: String) -> Self {
|
||||
self.displayname = FieldAction::Set(displayname);
|
||||
self
|
||||
}
|
||||
|
||||
/// Ask to unset the displayname of the user.
|
||||
#[must_use]
|
||||
pub fn unset_displayname(mut self) -> Self {
|
||||
self.displayname = FieldAction::Unset;
|
||||
self
|
||||
}
|
||||
|
||||
/// Call the given callback if the displayname should be set or unset.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `callback` - The callback to call.
|
||||
pub fn on_displayname(&self, callback: impl FnOnce(Option<&str>)) -> &Self {
|
||||
match &self.displayname {
|
||||
FieldAction::Unset => callback(None),
|
||||
@@ -82,18 +105,29 @@ impl ProvisionRequest {
|
||||
self
|
||||
}
|
||||
|
||||
/// Ask to set the avatar URL of the user.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `avatar_url` - The avatar URL to set.
|
||||
#[must_use]
|
||||
pub fn set_avatar_url(mut self, avatar_url: String) -> Self {
|
||||
self.avatar_url = FieldAction::Set(avatar_url);
|
||||
self
|
||||
}
|
||||
|
||||
/// Ask to unset the avatar URL of the user.
|
||||
#[must_use]
|
||||
pub fn unset_avatar_url(mut self) -> Self {
|
||||
self.avatar_url = FieldAction::Unset;
|
||||
self
|
||||
}
|
||||
|
||||
/// Call the given callback if the avatar URL should be set or unset.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `callback` - The callback to call.
|
||||
pub fn on_avatar_url(&self, callback: impl FnOnce(Option<&str>)) -> &Self {
|
||||
match &self.avatar_url {
|
||||
FieldAction::Unset => callback(None),
|
||||
@@ -104,18 +138,29 @@ impl ProvisionRequest {
|
||||
self
|
||||
}
|
||||
|
||||
/// Ask to set the emails of the user.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `emails` - The list of emails to set.
|
||||
#[must_use]
|
||||
pub fn set_emails(mut self, emails: Vec<String>) -> Self {
|
||||
self.emails = FieldAction::Set(emails);
|
||||
self
|
||||
}
|
||||
|
||||
/// Ask to unset the emails of the user.
|
||||
#[must_use]
|
||||
pub fn unset_emails(mut self) -> Self {
|
||||
self.emails = FieldAction::Unset;
|
||||
self
|
||||
}
|
||||
|
||||
/// Call the given callback if the emails should be set or unset.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `callback` - The callback to call.
|
||||
pub fn on_emails(&self, callback: impl FnOnce(Option<&[String]>)) -> &Self {
|
||||
match &self.emails {
|
||||
FieldAction::Unset => callback(None),
|
||||
@@ -129,17 +174,84 @@ impl ProvisionRequest {
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait HomeserverConnection: Send + Sync {
|
||||
/// The error type returned by all methods.
|
||||
type Error;
|
||||
|
||||
/// Get the homeserver URL.
|
||||
fn homeserver(&self) -> &str;
|
||||
|
||||
/// Get the Matrix ID of the user with the given localpart.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `localpart` - The localpart of the user.
|
||||
fn mxid(&self, localpart: &str) -> String {
|
||||
format!("@{}:{}", localpart, self.homeserver())
|
||||
}
|
||||
|
||||
/// Query the state of a user on the homeserver.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `mxid` - The Matrix ID of the user to query.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the homeserver is unreachable or the user does not
|
||||
/// exist.
|
||||
async fn query_user(&self, mxid: &str) -> Result<MatrixUser, Self::Error>;
|
||||
|
||||
/// Provision a user on the homeserver.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `request` - a [`ProvisionRequest`] containing the details of the user
|
||||
/// to provision.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the homeserver is unreachable or the user could not
|
||||
/// be provisioned.
|
||||
async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, Self::Error>;
|
||||
|
||||
/// Create a device for a user on the homeserver.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `mxid` - The Matrix ID of the user to create a device for.
|
||||
/// * `device_id` - The device ID to create.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the homeserver is unreachable or the device could
|
||||
/// not be created.
|
||||
async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error>;
|
||||
|
||||
/// Delete a device for a user on the homeserver.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `mxid` - The Matrix ID of the user to delete a device for.
|
||||
/// * `device_id` - The device ID to delete.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the homeserver is unreachable or the device could
|
||||
/// not be deleted.
|
||||
async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error>;
|
||||
|
||||
/// Delete a user on the homeserver.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `mxid` - The Matrix ID of the user to delete.
|
||||
/// * `erase` - Whether to ask the homeserver to erase the user's data.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the homeserver is unreachable or the user could not
|
||||
/// be deleted.
|
||||
async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error>;
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -165,4 +277,8 @@ impl<T: HomeserverConnection + Send + Sync + ?Sized> HomeserverConnection for &T
|
||||
async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
|
||||
(**self).delete_device(mxid, device_id).await
|
||||
}
|
||||
|
||||
async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> {
|
||||
(**self).delete_user(mxid, erase).await
|
||||
}
|
||||
}
|
||||
|
||||
169
crates/matrix/src/mock.rs
Normal file
169
crates/matrix/src/mock.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::{MatrixUser, ProvisionRequest};
|
||||
|
||||
struct MockUser {
|
||||
sub: String,
|
||||
avatar_url: Option<String>,
|
||||
displayname: Option<String>,
|
||||
devices: HashSet<String>,
|
||||
emails: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// A mock implementation of a [`HomeserverConnection`], which never fails and
|
||||
/// doesn't do anything.
|
||||
pub struct HomeserverConnection {
|
||||
homeserver: String,
|
||||
users: RwLock<HashMap<String, MockUser>>,
|
||||
}
|
||||
|
||||
impl HomeserverConnection {
|
||||
/// Create a new mock connection.
|
||||
pub fn new<H>(homeserver: H) -> Self
|
||||
where
|
||||
H: Into<String>,
|
||||
{
|
||||
Self {
|
||||
homeserver: homeserver.into(),
|
||||
users: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl crate::HomeserverConnection for HomeserverConnection {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn homeserver(&self) -> &str {
|
||||
&self.homeserver
|
||||
}
|
||||
|
||||
async fn query_user(&self, mxid: &str) -> Result<MatrixUser, Self::Error> {
|
||||
let users = self.users.read().await;
|
||||
let user = users.get(mxid).context("User not found")?;
|
||||
Ok(MatrixUser {
|
||||
displayname: user.displayname.clone(),
|
||||
avatar_url: user.avatar_url.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, Self::Error> {
|
||||
let mut users = self.users.write().await;
|
||||
let inserted = !users.contains_key(request.mxid());
|
||||
let user = users.entry(request.mxid().to_owned()).or_insert(MockUser {
|
||||
sub: request.sub().to_owned(),
|
||||
avatar_url: None,
|
||||
displayname: None,
|
||||
devices: HashSet::new(),
|
||||
emails: None,
|
||||
});
|
||||
|
||||
anyhow::ensure!(
|
||||
user.sub == request.sub(),
|
||||
"User already provisioned with different sub"
|
||||
);
|
||||
|
||||
request.on_emails(|emails| {
|
||||
user.emails = emails.map(ToOwned::to_owned);
|
||||
});
|
||||
|
||||
request.on_displayname(|displayname| {
|
||||
user.displayname = displayname.map(ToOwned::to_owned);
|
||||
});
|
||||
|
||||
request.on_avatar_url(|avatar_url| {
|
||||
user.avatar_url = avatar_url.map(ToOwned::to_owned);
|
||||
});
|
||||
|
||||
Ok(inserted)
|
||||
}
|
||||
|
||||
async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
|
||||
let mut users = self.users.write().await;
|
||||
let user = users.get_mut(mxid).context("User not found")?;
|
||||
user.devices.insert(device_id.to_owned());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
|
||||
let mut users = self.users.write().await;
|
||||
let user = users.get_mut(mxid).context("User not found")?;
|
||||
user.devices.remove(device_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> {
|
||||
let mut users = self.users.write().await;
|
||||
let user = users.get_mut(mxid).context("User not found")?;
|
||||
user.devices.clear();
|
||||
user.emails = None;
|
||||
if erase {
|
||||
user.avatar_url = None;
|
||||
user.displayname = None;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::HomeserverConnection as _;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mock_connection() {
|
||||
let conn = HomeserverConnection::new("example.org");
|
||||
|
||||
let mxid = "@test:example.org";
|
||||
let device = "test";
|
||||
assert_eq!(conn.homeserver(), "example.org");
|
||||
assert_eq!(conn.mxid("test"), mxid);
|
||||
|
||||
assert!(conn.query_user(mxid).await.is_err());
|
||||
assert!(conn.create_device(mxid, device).await.is_err());
|
||||
assert!(conn.delete_device(mxid, device).await.is_err());
|
||||
|
||||
let request = ProvisionRequest::new("@test:example.org", "test")
|
||||
.set_displayname("Test User".into())
|
||||
.set_avatar_url("mxc://example.org/1234567890".into())
|
||||
.set_emails(vec!["test@example.org".to_owned()]);
|
||||
|
||||
let inserted = conn.provision_user(&request).await.unwrap();
|
||||
assert!(inserted);
|
||||
|
||||
let user = conn.query_user("@test:example.org").await.unwrap();
|
||||
assert_eq!(user.displayname, Some("Test User".into()));
|
||||
assert_eq!(user.avatar_url, Some("mxc://example.org/1234567890".into()));
|
||||
|
||||
// Deleting a non-existent device should not fail
|
||||
assert!(conn.delete_device(mxid, device).await.is_ok());
|
||||
|
||||
// Create the device
|
||||
assert!(conn.create_device(mxid, device).await.is_ok());
|
||||
// Create the same device again
|
||||
assert!(conn.create_device(mxid, device).await.is_ok());
|
||||
|
||||
// XXX: there is no API to query devices yet in the trait
|
||||
// Delete the device
|
||||
assert!(conn.delete_device(mxid, device).await.is_ok());
|
||||
}
|
||||
}
|
||||
14
crates/storage-pg/.sqlx/query-22896e8f2a002f307089c3e0f9ee561e6521c45ce07d3a42411984c9a6b75fdc.json
generated
Normal file
14
crates/storage-pg/.sqlx/query-22896e8f2a002f307089c3e0f9ee561e6521c45ce07d3a42411984c9a6b75fdc.json
generated
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE users\n SET locked_at = NULL\n WHERE user_id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "22896e8f2a002f307089c3e0f9ee561e6521c45ce07d3a42411984c9a6b75fdc"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT s.user_session_id\n , s.created_at AS \"user_session_created_at\"\n , s.finished_at AS \"user_session_finished_at\"\n , u.user_id\n , u.username AS \"user_username\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n WHERE s.user_session_id = $1\n ",
|
||||
"query": "\n SELECT s.user_session_id\n , s.created_at AS \"user_session_created_at\"\n , s.finished_at AS \"user_session_finished_at\"\n , u.user_id\n , u.username AS \"user_username\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n , u.created_at AS \"user_created_at\"\n , u.locked_at AS \"user_locked_at\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n WHERE s.user_session_id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -32,6 +32,16 @@
|
||||
"ordinal": 5,
|
||||
"name": "user_primary_user_email_id",
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "user_created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 7,
|
||||
"name": "user_locked_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
@@ -45,8 +55,10 @@
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "25d61a373560556deafe056c8cd2982ac472f5ec2fab08b0b5275c4b78c11a7e"
|
||||
"hash": "73fe61f03a41778c6273b1c2dbdb13b91fbccfe5fbdbead8c4868d52a61a0f9d"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO users (user_id, username, created_at)\n VALUES ($1, $2, $3)\n ",
|
||||
"query": "\n INSERT INTO users (user_id, username, created_at)\n VALUES ($1, $2, $3)\n ON CONFLICT (username) DO NOTHING\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
@@ -12,5 +12,5 @@
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537"
|
||||
"hash": "7f4c4634ada4dc2745530dcca8eee92abf78dfbdf1a25e58a2bc9c14be8035f0"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n FROM users\n WHERE username = $1\n ",
|
||||
"query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n , locked_at\n FROM users\n WHERE username = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -22,6 +22,11 @@
|
||||
"ordinal": 3,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "locked_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
@@ -33,8 +38,9 @@
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false
|
||||
false,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c"
|
||||
"hash": "bfa5eaeaa5b4574bb083c86711eb4599f6374c96bb4a6827d400acb22fb0fd39"
|
||||
}
|
||||
15
crates/storage-pg/.sqlx/query-c29fa41743811a6ac3a9b952b6ea75d18e914f823902587b63c9f295407144b1.json
generated
Normal file
15
crates/storage-pg/.sqlx/query-c29fa41743811a6ac3a9b952b6ea75d18e914f823902587b63c9f295407144b1.json
generated
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE users\n SET locked_at = $1\n WHERE user_id = $2\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Timestamptz",
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "c29fa41743811a6ac3a9b952b6ea75d18e914f823902587b63c9f295407144b1"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n FROM users\n WHERE user_id = $1\n ",
|
||||
"query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n , locked_at\n FROM users\n WHERE user_id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -22,6 +22,11 @@
|
||||
"ordinal": 3,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "locked_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
@@ -33,8 +38,9 @@
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false
|
||||
false,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "08d7df347c806ef14b6d0fb031cab041d79ba48528420160e23286369db7af35"
|
||||
"hash": "e0ea7d93ab3f565828b2faab4cc5e1a6ac868c95bfaee3a6960df1cf484d53da"
|
||||
}
|
||||
19
crates/storage-pg/migrations/20230728154304_user_lock.sql
Normal file
19
crates/storage-pg/migrations/20230728154304_user_lock.sql
Normal file
@@ -0,0 +1,19 @@
|
||||
-- Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
-- Add a new column in on the `users` to record when an account gets locked
|
||||
ALTER TABLE "users"
|
||||
ADD COLUMN "locked_at"
|
||||
TIMESTAMP WITH TIME ZONE
|
||||
DEFAULT NULL;
|
||||
@@ -29,6 +29,8 @@ pub enum Users {
|
||||
UserId,
|
||||
Username,
|
||||
PrimaryUserEmailId,
|
||||
CreatedAt,
|
||||
LockedAt,
|
||||
}
|
||||
|
||||
#[derive(sea_query::Iden)]
|
||||
|
||||
@@ -31,6 +31,7 @@ use mas_storage::{
|
||||
Repository, RepositoryAccess, RepositoryTransaction,
|
||||
};
|
||||
use sqlx::{PgPool, Postgres, Transaction};
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
compat::{
|
||||
@@ -78,11 +79,21 @@ impl RepositoryTransaction for PgRepository {
|
||||
type Error = DatabaseError;
|
||||
|
||||
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
|
||||
self.txn.commit().map_err(DatabaseError::from).boxed()
|
||||
let span = tracing::info_span!("db.save");
|
||||
self.txn
|
||||
.commit()
|
||||
.map_err(DatabaseError::from)
|
||||
.instrument(span)
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
|
||||
self.txn.rollback().map_err(DatabaseError::from).boxed()
|
||||
let span = tracing::info_span!("db.cancel");
|
||||
self.txn
|
||||
.rollback()
|
||||
.map_err(DatabaseError::from)
|
||||
.instrument(span)
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -55,9 +55,8 @@ struct UserLookup {
|
||||
user_id: Uuid,
|
||||
username: String,
|
||||
primary_user_email_id: Option<Uuid>,
|
||||
|
||||
#[allow(dead_code)]
|
||||
created_at: DateTime<Utc>,
|
||||
locked_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl From<UserLookup> for User {
|
||||
@@ -68,6 +67,8 @@ impl From<UserLookup> for User {
|
||||
username: value.username,
|
||||
sub: id.to_string(),
|
||||
primary_user_email_id: value.primary_user_email_id.map(Into::into),
|
||||
created_at: value.created_at,
|
||||
locked_at: value.locked_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -93,6 +94,7 @@ impl<'c> UserRepository for PgUserRepository<'c> {
|
||||
, username
|
||||
, primary_user_email_id
|
||||
, created_at
|
||||
, locked_at
|
||||
FROM users
|
||||
WHERE user_id = $1
|
||||
"#,
|
||||
@@ -124,6 +126,7 @@ impl<'c> UserRepository for PgUserRepository<'c> {
|
||||
, username
|
||||
, primary_user_email_id
|
||||
, created_at
|
||||
, locked_at
|
||||
FROM users
|
||||
WHERE username = $1
|
||||
"#,
|
||||
@@ -158,10 +161,11 @@ impl<'c> UserRepository for PgUserRepository<'c> {
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("user.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO users (user_id, username, created_at)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (username) DO NOTHING
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
username,
|
||||
@@ -171,11 +175,17 @@ impl<'c> UserRepository for PgUserRepository<'c> {
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
// If the user already exists, want to return an error but not poison the
|
||||
// transaction
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
Ok(User {
|
||||
id,
|
||||
username,
|
||||
sub: id.to_string(),
|
||||
primary_user_email_id: None,
|
||||
created_at,
|
||||
locked_at: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -203,4 +213,72 @@ impl<'c> UserRepository for PgUserRepository<'c> {
|
||||
|
||||
Ok(exists)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user.lock",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
|
||||
if user.locked_at.is_some() {
|
||||
return Ok(user);
|
||||
}
|
||||
|
||||
let locked_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE users
|
||||
SET locked_at = $1
|
||||
WHERE user_id = $2
|
||||
"#,
|
||||
locked_at,
|
||||
Uuid::from(user.id),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
user.locked_at = Some(locked_at);
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user.unlock",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
|
||||
if user.locked_at.is_none() {
|
||||
return Ok(user);
|
||||
}
|
||||
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE users
|
||||
SET locked_at = NULL
|
||||
WHERE user_id = $1
|
||||
"#,
|
||||
Uuid::from(user.id),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
user.locked_at = None;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,6 +53,8 @@ struct SessionLookup {
|
||||
user_id: Uuid,
|
||||
user_username: String,
|
||||
user_primary_user_email_id: Option<Uuid>,
|
||||
user_created_at: DateTime<Utc>,
|
||||
user_locked_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl TryFrom<SessionLookup> for BrowserSession {
|
||||
@@ -65,6 +67,8 @@ impl TryFrom<SessionLookup> for BrowserSession {
|
||||
username: value.user_username,
|
||||
sub: id.to_string(),
|
||||
primary_user_email_id: value.user_primary_user_email_id.map(Into::into),
|
||||
created_at: value.user_created_at,
|
||||
locked_at: value.user_locked_at,
|
||||
};
|
||||
|
||||
Ok(BrowserSession {
|
||||
@@ -99,6 +103,8 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
|
||||
, u.user_id
|
||||
, u.username AS "user_username"
|
||||
, u.primary_user_email_id AS "user_primary_user_email_id"
|
||||
, u.created_at AS "user_created_at"
|
||||
, u.locked_at AS "user_locked_at"
|
||||
FROM user_sessions s
|
||||
INNER JOIN users u
|
||||
USING (user_id)
|
||||
@@ -232,6 +238,14 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
|
||||
Expr::col((Users::Table, Users::PrimaryUserEmailId)),
|
||||
SessionLookupIden::UserPrimaryUserEmailId,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((Users::Table, Users::CreatedAt)),
|
||||
SessionLookupIden::UserCreatedAt,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((Users::Table, Users::LockedAt)),
|
||||
SessionLookupIden::UserLockedAt,
|
||||
)
|
||||
.from(UserSessions::Table)
|
||||
.inner_join(
|
||||
Users::Table,
|
||||
|
||||
@@ -63,12 +63,38 @@ async fn test_user_repo(pool: PgPool) {
|
||||
assert!(repo.user().lookup(user.id).await.unwrap().is_some());
|
||||
|
||||
// Adding a second time should give a conflict
|
||||
// It should not poison the transaction though
|
||||
assert!(repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, USERNAME.to_owned())
|
||||
.await
|
||||
.is_err());
|
||||
|
||||
// Try locking a user
|
||||
assert!(user.is_valid());
|
||||
let user = repo.user().lock(&clock, user).await.unwrap();
|
||||
assert!(!user.is_valid());
|
||||
|
||||
// Check that the property is retrieved on lookup
|
||||
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
|
||||
assert!(!user.is_valid());
|
||||
|
||||
// Locking a second time should not fail
|
||||
let user = repo.user().lock(&clock, user).await.unwrap();
|
||||
assert!(!user.is_valid());
|
||||
|
||||
// Try unlocking a user
|
||||
let user = repo.user().unlock(user).await.unwrap();
|
||||
assert!(user.is_valid());
|
||||
|
||||
// Check that the property is retrieved on lookup
|
||||
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
|
||||
assert!(user.is_valid());
|
||||
|
||||
// Unlocking a second time should not fail
|
||||
let user = repo.user().unlock(user).await.unwrap();
|
||||
assert!(user.is_valid());
|
||||
|
||||
repo.save().await.unwrap();
|
||||
}
|
||||
|
||||
|
||||
@@ -73,6 +73,15 @@ pub struct JobWithSpanContext<T> {
|
||||
payload: T,
|
||||
}
|
||||
|
||||
impl<J> From<J> for JobWithSpanContext<J> {
|
||||
fn from(payload: J) -> Self {
|
||||
Self {
|
||||
span_context: None,
|
||||
payload,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<J: Job> Job for JobWithSpanContext<J> {
|
||||
const NAME: &'static str = J::NAME;
|
||||
}
|
||||
@@ -369,6 +378,47 @@ mod jobs {
|
||||
impl Job for DeleteDeviceJob {
|
||||
const NAME: &'static str = "delete-device";
|
||||
}
|
||||
|
||||
/// A job to deactivate and lock a user
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct DeactivateUserJob {
|
||||
user_id: Ulid,
|
||||
hs_erase: bool,
|
||||
}
|
||||
|
||||
impl DeactivateUserJob {
|
||||
/// Create a new job to deactivate and lock a user
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `user` - The user to deactivate
|
||||
/// * `hs_erase` - Whether to erase the user from the homeserver
|
||||
#[must_use]
|
||||
pub fn new(user: &User, hs_erase: bool) -> Self {
|
||||
Self {
|
||||
user_id: user.id,
|
||||
hs_erase,
|
||||
}
|
||||
}
|
||||
|
||||
/// The ID of the user to deactivate
|
||||
#[must_use]
|
||||
pub fn user_id(&self) -> Ulid {
|
||||
self.user_id
|
||||
}
|
||||
|
||||
/// Whether to erase the user from the homeserver
|
||||
#[must_use]
|
||||
pub fn hs_erase(&self) -> bool {
|
||||
self.hs_erase
|
||||
}
|
||||
}
|
||||
|
||||
impl Job for DeactivateUserJob {
|
||||
const NAME: &'static str = "deactivate-user";
|
||||
}
|
||||
}
|
||||
|
||||
pub use self::jobs::{DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, VerifyEmailJob};
|
||||
pub use self::jobs::{
|
||||
DeactivateUserJob, DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, VerifyEmailJob,
|
||||
};
|
||||
|
||||
@@ -96,6 +96,33 @@ pub trait UserRepository: Send + Sync {
|
||||
///
|
||||
/// Returns [`Self::Error`] if the underlying repository fails
|
||||
async fn exists(&mut self, username: &str) -> Result<bool, Self::Error>;
|
||||
|
||||
/// Lock a [`User`]
|
||||
///
|
||||
/// Returns the locked [`User`]
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `clock`: The clock used to generate timestamps
|
||||
/// * `user`: The [`User`] to lock
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`Self::Error`] if the underlying repository fails
|
||||
async fn lock(&mut self, clock: &dyn Clock, user: User) -> Result<User, Self::Error>;
|
||||
|
||||
/// Unlock a [`User`]
|
||||
///
|
||||
/// Returns the unlocked [`User`]
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `user`: The [`User`] to unlock
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`Self::Error`] if the underlying repository fails
|
||||
async fn unlock(&mut self, user: User) -> Result<User, Self::Error>;
|
||||
}
|
||||
|
||||
repository_impl!(UserRepository:
|
||||
@@ -108,4 +135,6 @@ repository_impl!(UserRepository:
|
||||
username: String,
|
||||
) -> Result<User, Self::Error>;
|
||||
async fn exists(&mut self, username: &str) -> Result<bool, Self::Error>;
|
||||
async fn lock(&mut self, clock: &dyn Clock, user: User) -> Result<User, Self::Error>;
|
||||
async fn unlock(&mut self, user: User) -> Result<User, Self::Error>;
|
||||
);
|
||||
|
||||
@@ -29,7 +29,10 @@ use chrono::{DateTime, Utc};
|
||||
use mas_storage::{oauth2::OAuth2AccessTokenRepository, RepositoryAccess};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::{utils::metrics_layer, JobContextExt, State};
|
||||
use crate::{
|
||||
utils::{metrics_layer, trace_layer, TracedJob},
|
||||
JobContextExt, State,
|
||||
};
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
pub struct CleanupExpiredTokensJob {
|
||||
@@ -46,6 +49,8 @@ impl Job for CleanupExpiredTokensJob {
|
||||
const NAME: &'static str = "cleanup-expired-tokens";
|
||||
}
|
||||
|
||||
impl TracedJob for CleanupExpiredTokensJob {}
|
||||
|
||||
pub async fn cleanup_expired_tokens(
|
||||
job: CleanupExpiredTokensJob,
|
||||
ctx: JobContext,
|
||||
@@ -79,6 +84,7 @@ pub(crate) fn register(
|
||||
.stream(CronStream::new(schedule).timer(TokioTimer).to_stream())
|
||||
.layer(state.inject())
|
||||
.layer(metrics_layer())
|
||||
.layer(trace_layer())
|
||||
.build_fn(cleanup_expired_tokens);
|
||||
|
||||
monitor.register(worker)
|
||||
|
||||
@@ -13,25 +13,14 @@
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::Context;
|
||||
use apalis_core::{
|
||||
builder::{WorkerBuilder, WorkerFactoryFn},
|
||||
context::JobContext,
|
||||
executor::TokioExecutor,
|
||||
job::Job,
|
||||
monitor::Monitor,
|
||||
storage::builder::WithStorage,
|
||||
};
|
||||
use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
|
||||
use chrono::Duration;
|
||||
use mas_email::{Address, EmailVerificationContext, Mailbox};
|
||||
use mas_storage::job::{JobWithSpanContext, VerifyEmailJob};
|
||||
use rand::{distributions::Uniform, Rng};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
storage::PostgresStorageFactory,
|
||||
utils::{metrics_layer, trace_layer},
|
||||
JobContextExt, State,
|
||||
};
|
||||
use crate::{storage::PostgresStorageFactory, JobContextExt, State};
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "job.verify_email",
|
||||
@@ -99,15 +88,8 @@ pub(crate) fn register(
|
||||
state: &State,
|
||||
storage_factory: &PostgresStorageFactory,
|
||||
) -> Monitor<TokioExecutor> {
|
||||
let storage = storage_factory.build();
|
||||
let worker_name = format!("{job}-{suffix}", job = VerifyEmailJob::NAME);
|
||||
let worker = WorkerBuilder::new(worker_name)
|
||||
.layer(state.inject())
|
||||
.layer(trace_layer())
|
||||
.layer(metrics_layer())
|
||||
.with_storage_config(storage, |c| {
|
||||
c.fetch_interval(std::time::Duration::from_secs(1))
|
||||
})
|
||||
.build_fn(verify_email);
|
||||
monitor.register(worker)
|
||||
let verify_email_worker =
|
||||
crate::build!(VerifyEmailJob => verify_email, suffix, state, storage_factory);
|
||||
|
||||
monitor.register(verify_email_worker)
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ mod database;
|
||||
mod email;
|
||||
mod matrix;
|
||||
mod storage;
|
||||
mod user;
|
||||
mod utils;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -106,6 +107,32 @@ impl JobContextExt for apalis_core::context::JobContext {
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper macro to build a storage-backed worker.
|
||||
macro_rules! build {
|
||||
($job:ty => $fn:ident, $suffix:expr, $state:expr, $factory:expr) => {{
|
||||
let storage = $factory.build();
|
||||
let worker_name = format!(
|
||||
"{job}-{suffix}",
|
||||
job = <$job as ::apalis_core::job::Job>::NAME,
|
||||
suffix = $suffix
|
||||
);
|
||||
|
||||
let builder = ::apalis_core::builder::WorkerBuilder::new(worker_name)
|
||||
.layer($state.inject())
|
||||
.layer(crate::utils::trace_layer())
|
||||
.layer(crate::utils::metrics_layer());
|
||||
|
||||
let builder = ::apalis_core::storage::builder::WithStorage::with_storage_config(
|
||||
builder,
|
||||
storage,
|
||||
|c| c.fetch_interval(std::time::Duration::from_secs(1)),
|
||||
);
|
||||
::apalis_core::builder::WorkerFactory::build(builder, ::apalis_core::job_fn::job_fn($fn))
|
||||
}};
|
||||
}
|
||||
|
||||
pub(crate) use build;
|
||||
|
||||
/// Initialise the workers.
|
||||
///
|
||||
/// # Errors
|
||||
@@ -128,6 +155,7 @@ pub async fn init(
|
||||
let monitor = self::database::register(name, monitor, &state);
|
||||
let monitor = self::email::register(name, monitor, &state, &factory);
|
||||
let monitor = self::matrix::register(name, monitor, &state, &factory);
|
||||
let monitor = self::user::register(name, monitor, &state, &factory);
|
||||
// TODO: we might want to grab the join handle here
|
||||
factory.listen().await?;
|
||||
debug!(?monitor, "workers registered");
|
||||
|
||||
@@ -12,17 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
use apalis_core::{
|
||||
builder::{WorkerBuilder, WorkerFactoryFn},
|
||||
context::JobContext,
|
||||
executor::TokioExecutor,
|
||||
job::Job,
|
||||
monitor::Monitor,
|
||||
storage::builder::WithStorage,
|
||||
};
|
||||
use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
|
||||
use mas_matrix::ProvisionRequest;
|
||||
use mas_storage::{
|
||||
job::{DeleteDeviceJob, JobWithSpanContext, ProvisionDeviceJob, ProvisionUserJob},
|
||||
@@ -31,11 +22,7 @@ use mas_storage::{
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
storage::PostgresStorageFactory,
|
||||
utils::{metrics_layer, trace_layer},
|
||||
JobContextExt, State,
|
||||
};
|
||||
use crate::{storage::PostgresStorageFactory, JobContextExt, State};
|
||||
|
||||
/// Job to provision a user on the Matrix homeserver.
|
||||
/// This works by doing a PUT request to the /_synapse/admin/v2/users/{user_id}
|
||||
@@ -163,32 +150,12 @@ pub(crate) fn register(
|
||||
state: &State,
|
||||
storage_factory: &PostgresStorageFactory,
|
||||
) -> Monitor<TokioExecutor> {
|
||||
let storage = storage_factory.build();
|
||||
let worker_name = format!("{job}-{suffix}", job = ProvisionUserJob::NAME);
|
||||
let provision_user_worker = WorkerBuilder::new(worker_name)
|
||||
.layer(state.inject())
|
||||
.layer(trace_layer())
|
||||
.layer(metrics_layer())
|
||||
.with_storage_config(storage, |c| c.fetch_interval(Duration::from_secs(1)))
|
||||
.build_fn(provision_user);
|
||||
|
||||
let storage = storage_factory.build();
|
||||
let worker_name = format!("{job}-{suffix}", job = ProvisionDeviceJob::NAME);
|
||||
let provision_device_worker = WorkerBuilder::new(worker_name)
|
||||
.layer(state.inject())
|
||||
.layer(trace_layer())
|
||||
.layer(metrics_layer())
|
||||
.with_storage_config(storage, |c| c.fetch_interval(Duration::from_secs(1)))
|
||||
.build_fn(provision_device);
|
||||
|
||||
let storage = storage_factory.build();
|
||||
let worker_name = format!("{job}-{suffix}", job = DeleteDeviceJob::NAME);
|
||||
let delete_device_worker = WorkerBuilder::new(worker_name)
|
||||
.layer(state.inject())
|
||||
.layer(trace_layer())
|
||||
.layer(metrics_layer())
|
||||
.with_storage_config(storage, |c| c.fetch_interval(Duration::from_secs(1)))
|
||||
.build_fn(delete_device);
|
||||
let provision_user_worker =
|
||||
crate::build!(ProvisionUserJob => provision_user, suffix, state, storage_factory);
|
||||
let provision_device_worker =
|
||||
crate::build!(ProvisionDeviceJob => provision_device, suffix, state, storage_factory);
|
||||
let delete_device_worker =
|
||||
crate::build!(DeleteDeviceJob => delete_device, suffix, state, storage_factory);
|
||||
|
||||
monitor
|
||||
.register(provision_user_worker)
|
||||
|
||||
77
crates/tasks/src/user.rs
Normal file
77
crates/tasks/src/user.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::Context;
|
||||
use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
|
||||
use mas_storage::{
|
||||
job::{DeactivateUserJob, JobWithSpanContext},
|
||||
user::UserRepository,
|
||||
RepositoryAccess,
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{storage::PostgresStorageFactory, JobContextExt, State};
|
||||
|
||||
/// Job to deactivate a user, both locally and on the Matrix homeserver.
|
||||
#[tracing::instrument(
|
||||
name = "job.deactivate_user"
|
||||
fields(user.id = %job.user_id(), erase = %job.hs_erase()),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
async fn deactivate_user(
|
||||
job: JobWithSpanContext<DeactivateUserJob>,
|
||||
ctx: JobContext,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let state = ctx.state();
|
||||
let clock = state.clock();
|
||||
let matrix = state.matrix_connection();
|
||||
let mut repo = state.repository().await?;
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(job.user_id())
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
|
||||
// Let's first lock the user
|
||||
let user = repo
|
||||
.user()
|
||||
.lock(&clock, user)
|
||||
.await
|
||||
.context("Failed to lock user")?;
|
||||
|
||||
// TODO: delete the sessions & access tokens
|
||||
|
||||
// Before calling back to the homeserver, commit the changes to the database
|
||||
repo.save().await?;
|
||||
|
||||
let mxid = matrix.mxid(&user.username);
|
||||
info!("Deactivating user {} on homeserver", mxid);
|
||||
matrix.delete_user(&mxid, job.hs_erase()).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn register(
|
||||
suffix: &str,
|
||||
monitor: Monitor<TokioExecutor>,
|
||||
state: &State,
|
||||
storage_factory: &PostgresStorageFactory,
|
||||
) -> Monitor<TokioExecutor> {
|
||||
let deactivate_user_worker =
|
||||
crate::build!(DeactivateUserJob => deactivate_user, suffix, state, storage_factory);
|
||||
|
||||
monitor.register(deactivate_user_worker)
|
||||
}
|
||||
@@ -18,14 +18,32 @@ use mas_tower::{
|
||||
make_span_fn, DurationRecorderLayer, FnWrapper, IdentityLayer, InFlightCounterLayer,
|
||||
TraceLayer, KV,
|
||||
};
|
||||
use opentelemetry::{Key, KeyValue};
|
||||
use opentelemetry::{trace::SpanContext, Key, KeyValue};
|
||||
use tracing::info_span;
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
|
||||
const JOB_NAME: Key = Key::from_static_str("job.name");
|
||||
const JOB_STATUS: Key = Key::from_static_str("job.status");
|
||||
|
||||
fn make_span_for_job_request<J>(req: &JobRequest<JobWithSpanContext<J>>) -> tracing::Span
|
||||
/// Represents a job that can may have a span context attached to it.
|
||||
pub trait TracedJob: Job {
|
||||
/// Returns the span context for this job, if any.
|
||||
///
|
||||
/// The default implementation returns `None`.
|
||||
fn span_context(&self) -> Option<SpanContext> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Implements [`TracedJob`] for any job with the [`JobWithSpanContext`]
|
||||
/// wrapper.
|
||||
impl<J: Job> TracedJob for JobWithSpanContext<J> {
|
||||
fn span_context(&self) -> Option<SpanContext> {
|
||||
JobWithSpanContext::span_context(self)
|
||||
}
|
||||
}
|
||||
|
||||
fn make_span_for_job_request<J: TracedJob>(req: &JobRequest<J>) -> tracing::Span
|
||||
where
|
||||
J: Job,
|
||||
{
|
||||
@@ -45,18 +63,15 @@ where
|
||||
span
|
||||
}
|
||||
|
||||
type TraceLayerForJob<J> = TraceLayer<
|
||||
FnWrapper<fn(&JobRequest<JobWithSpanContext<J>>) -> tracing::Span>,
|
||||
KV<&'static str>,
|
||||
KV<&'static str>,
|
||||
>;
|
||||
type TraceLayerForJob<J> =
|
||||
TraceLayer<FnWrapper<fn(&JobRequest<J>) -> tracing::Span>, KV<&'static str>, KV<&'static str>>;
|
||||
|
||||
pub(crate) fn trace_layer<J>() -> TraceLayerForJob<J>
|
||||
where
|
||||
J: Job,
|
||||
J: TracedJob,
|
||||
{
|
||||
TraceLayer::new(make_span_fn(
|
||||
make_span_for_job_request::<J> as fn(&JobRequest<JobWithSpanContext<J>>) -> tracing::Span,
|
||||
make_span_for_job_request::<J> as fn(&JobRequest<J>) -> tracing::Span,
|
||||
))
|
||||
.on_response(KV("otel.status_code", "OK"))
|
||||
.on_error(KV("otel.status_code", "ERROR"))
|
||||
|
||||
Reference in New Issue
Block a user