Merge branch 'main' into rei/reapply_5297
This commit is contained in:
@@ -12,10 +12,9 @@ use clap::Parser;
|
||||
use figment::Figment;
|
||||
use mas_config::{ConfigurationSection, RootConfig, SyncConfig};
|
||||
use mas_data_model::{Clock as _, SystemClock};
|
||||
use mas_storage_pg::MIGRATOR;
|
||||
use rand::SeedableRng;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tracing::{Instrument, info, info_span};
|
||||
use tracing::{info, info_span};
|
||||
|
||||
use crate::util::database_connection_from_config;
|
||||
|
||||
@@ -129,9 +128,7 @@ impl Options {
|
||||
// Grab a connection to the database
|
||||
let mut conn = database_connection_from_config(&config.database).await?;
|
||||
|
||||
MIGRATOR
|
||||
.run(&mut conn)
|
||||
.instrument(info_span!("db.migrate"))
|
||||
mas_storage_pg::migrate(&mut conn)
|
||||
.await
|
||||
.context("could not run migrations")?;
|
||||
|
||||
|
||||
@@ -10,8 +10,7 @@ use anyhow::Context;
|
||||
use clap::Parser;
|
||||
use figment::Figment;
|
||||
use mas_config::{ConfigurationSectionExt, DatabaseConfig};
|
||||
use mas_storage_pg::MIGRATOR;
|
||||
use tracing::{Instrument, info_span};
|
||||
use tracing::info_span;
|
||||
|
||||
use crate::util::database_connection_from_config;
|
||||
|
||||
@@ -35,9 +34,7 @@ impl Options {
|
||||
let mut conn = database_connection_from_config(&config).await?;
|
||||
|
||||
// Run pending migrations
|
||||
MIGRATOR
|
||||
.run(&mut conn)
|
||||
.instrument(info_span!("db.migrate"))
|
||||
mas_storage_pg::migrate(&mut conn)
|
||||
.await
|
||||
.context("could not run migrations")?;
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
use std::{collections::BTreeSet, process::ExitCode, sync::Arc, time::Duration};
|
||||
use std::{process::ExitCode, sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::Context;
|
||||
use clap::Parser;
|
||||
@@ -18,9 +18,8 @@ use mas_data_model::SystemClock;
|
||||
use mas_handlers::{ActivityTracker, CookieManager, Limiter, MetadataCache};
|
||||
use mas_listener::server::Server;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage_pg::{MIGRATOR, PgRepositoryFactory};
|
||||
use sqlx::migrate::Migrate;
|
||||
use tracing::{Instrument, info, info_span, warn};
|
||||
use mas_storage_pg::PgRepositoryFactory;
|
||||
use tracing::{info, info_span, warn};
|
||||
|
||||
use crate::{
|
||||
app_state::AppState,
|
||||
@@ -73,24 +72,20 @@ impl Options {
|
||||
let pool = database_pool_from_config(&config.database).await?;
|
||||
|
||||
if self.no_migrate {
|
||||
// Check that we applied all the migrations
|
||||
let mut conn = pool.acquire().await?;
|
||||
let applied = conn.list_applied_migrations().await?;
|
||||
let applied: BTreeSet<_> = applied.into_iter().map(|m| m.version).collect();
|
||||
let has_missing_migrations = MIGRATOR.iter().any(|m| !applied.contains(&m.version));
|
||||
if has_missing_migrations {
|
||||
let pending_migrations = mas_storage_pg::pending_migrations(&mut conn).await?;
|
||||
if !pending_migrations.is_empty() {
|
||||
// Refuse to start if there are pending migrations
|
||||
return Err(anyhow::anyhow!(
|
||||
"The server is running with `--no-migrate` but there are pending. Please run them first with `mas-cli database migrate`, or omit the `--no-migrate` flag to apply them automatically on startup."
|
||||
"The server is running with `--no-migrate` but there are pending migrations. Please run them first with `mas-cli database migrate`, or omit the `--no-migrate` flag to apply them automatically on startup."
|
||||
));
|
||||
}
|
||||
} else {
|
||||
info!("Running pending database migrations");
|
||||
MIGRATOR
|
||||
.run(&pool)
|
||||
.instrument(info_span!("db.migrate"))
|
||||
let mut conn = pool.acquire().await?;
|
||||
mas_storage_pg::migrate(&mut conn)
|
||||
.await
|
||||
.context("could not run database migrations")?;
|
||||
.context("could not run migrations")?;
|
||||
}
|
||||
|
||||
let encrypter = config.secrets.encrypter().await?;
|
||||
|
||||
@@ -14,13 +14,12 @@ use mas_config::{
|
||||
UpstreamOAuth2Config,
|
||||
};
|
||||
use mas_data_model::SystemClock;
|
||||
use mas_storage_pg::MIGRATOR;
|
||||
use rand::thread_rng;
|
||||
use sqlx::{Connection, Either, PgConnection, postgres::PgConnectOptions, types::Uuid};
|
||||
use syn2mas::{
|
||||
LockedMasDatabase, MasWriter, Progress, ProgressStage, SynapseReader, synapse_config,
|
||||
};
|
||||
use tracing::{Instrument, error, info, info_span};
|
||||
use tracing::{Instrument, error, info};
|
||||
|
||||
use crate::util::{DatabaseConnectOptions, database_connection_from_config_with_options};
|
||||
|
||||
@@ -122,9 +121,7 @@ impl Options {
|
||||
)
|
||||
.await?;
|
||||
|
||||
MIGRATOR
|
||||
.run(&mut mas_connection)
|
||||
.instrument(info_span!("db.migrate"))
|
||||
mas_storage_pg::migrate(&mut mas_connection)
|
||||
.await
|
||||
.context("could not run migrations")?;
|
||||
|
||||
|
||||
@@ -45,6 +45,12 @@ fn map_import_on_conflict(
|
||||
mas_config::UpstreamOAuth2OnConflict::Add => {
|
||||
mas_data_model::UpstreamOAuthProviderOnConflict::Add
|
||||
}
|
||||
mas_config::UpstreamOAuth2OnConflict::Replace => {
|
||||
mas_data_model::UpstreamOAuthProviderOnConflict::Replace
|
||||
}
|
||||
mas_config::UpstreamOAuth2OnConflict::Set => {
|
||||
mas_data_model::UpstreamOAuthProviderOnConflict::Set
|
||||
}
|
||||
mas_config::UpstreamOAuth2OnConflict::Fail => {
|
||||
mas_data_model::UpstreamOAuthProviderOnConflict::Fail
|
||||
}
|
||||
@@ -58,6 +64,7 @@ fn map_claims_imports(
|
||||
subject: mas_data_model::UpstreamOAuthProviderSubjectPreference {
|
||||
template: config.subject.template.clone(),
|
||||
},
|
||||
skip_confirmation: config.skip_confirmation,
|
||||
localpart: mas_data_model::UpstreamOAuthProviderLocalpartPreference {
|
||||
action: map_import_action(config.localpart.action),
|
||||
template: config.localpart.template.clone(),
|
||||
|
||||
@@ -145,6 +145,7 @@ pub async fn policy_factory_from_config(
|
||||
register: config.register_entrypoint.clone(),
|
||||
client_registration: config.client_registration_entrypoint.clone(),
|
||||
authorization_grant: config.authorization_grant_entrypoint.clone(),
|
||||
compat_login: config.compat_login_entrypoint.clone(),
|
||||
email: config.email_entrypoint.clone(),
|
||||
};
|
||||
|
||||
|
||||
@@ -62,6 +62,14 @@ fn is_default_password_entrypoint(value: &String) -> bool {
|
||||
*value == default_password_entrypoint()
|
||||
}
|
||||
|
||||
fn default_compat_login_entrypoint() -> String {
|
||||
"compat_login/violation".to_owned()
|
||||
}
|
||||
|
||||
fn is_default_compat_login_entrypoint(value: &String) -> bool {
|
||||
*value == default_compat_login_entrypoint()
|
||||
}
|
||||
|
||||
fn default_email_entrypoint() -> String {
|
||||
"email/violation".to_owned()
|
||||
}
|
||||
@@ -111,6 +119,13 @@ pub struct PolicyConfig {
|
||||
)]
|
||||
pub authorization_grant_entrypoint: String,
|
||||
|
||||
/// Entrypoint to use when evaluating compatibility logins
|
||||
#[serde(
|
||||
default = "default_compat_login_entrypoint",
|
||||
skip_serializing_if = "is_default_compat_login_entrypoint"
|
||||
)]
|
||||
pub compat_login_entrypoint: String,
|
||||
|
||||
/// Entrypoint to use when changing password
|
||||
#[serde(
|
||||
default = "default_password_entrypoint",
|
||||
@@ -137,6 +152,7 @@ impl Default for PolicyConfig {
|
||||
client_registration_entrypoint: default_client_registration_entrypoint(),
|
||||
register_entrypoint: default_register_entrypoint(),
|
||||
authorization_grant_entrypoint: default_authorization_grant_entrypoint(),
|
||||
compat_login_entrypoint: default_compat_login_entrypoint(),
|
||||
password_entrypoint: default_password_entrypoint(),
|
||||
email_entrypoint: default_email_entrypoint(),
|
||||
data: default_data(),
|
||||
|
||||
@@ -118,16 +118,36 @@ impl ConfigurationSection for UpstreamOAuth2Config {
|
||||
}
|
||||
}
|
||||
|
||||
if provider.claims_imports.skip_confirmation {
|
||||
if provider.claims_imports.localpart.action != ImportAction::Require {
|
||||
return Err(annotate(figment::Error::custom(
|
||||
"The field `action` must be `require` when `skip_confirmation` is set to `true`",
|
||||
)).with_path("claims_imports.localpart").into());
|
||||
}
|
||||
|
||||
if provider.claims_imports.email.action == ImportAction::Suggest {
|
||||
return Err(annotate(figment::Error::custom(
|
||||
"The field `action` must not be `suggest` when `skip_confirmation` is set to `true`",
|
||||
)).with_path("claims_imports.email").into());
|
||||
}
|
||||
|
||||
if provider.claims_imports.displayname.action == ImportAction::Suggest {
|
||||
return Err(annotate(figment::Error::custom(
|
||||
"The field `action` must not be `suggest` when `skip_confirmation` is set to `true`",
|
||||
)).with_path("claims_imports.displayname").into());
|
||||
}
|
||||
}
|
||||
|
||||
if matches!(
|
||||
provider.claims_imports.localpart.on_conflict,
|
||||
OnConflict::Add
|
||||
OnConflict::Add | OnConflict::Replace | OnConflict::Set
|
||||
) && !matches!(
|
||||
provider.claims_imports.localpart.action,
|
||||
ImportAction::Force | ImportAction::Require
|
||||
) {
|
||||
return Err(annotate(figment::Error::custom(
|
||||
"The field `action` must be either `force` or `require` when `on_conflict` is set to `add`",
|
||||
)).into());
|
||||
"The field `action` must be either `force` or `require` when `on_conflict` is set to `add`, `replace` or `set`",
|
||||
)).with_path("claims_imports.localpart").into());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,13 +226,20 @@ impl ImportAction {
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum OnConflict {
|
||||
/// Fails the sso login on conflict
|
||||
/// Fails the upstream OAuth 2.0 login on conflict
|
||||
#[default]
|
||||
Fail,
|
||||
|
||||
/// Adds the oauth identity link, regardless of whether there is an existing
|
||||
/// link or not
|
||||
/// Adds the upstream OAuth 2.0 identity link, regardless of whether there
|
||||
/// is an existing link or not
|
||||
Add,
|
||||
|
||||
/// Replace any existing upstream OAuth 2.0 identity link
|
||||
Replace,
|
||||
|
||||
/// Adds the upstream OAuth 2.0 identity link *only* if there is no existing
|
||||
/// link for this provider on the matching user
|
||||
Set,
|
||||
}
|
||||
|
||||
impl OnConflict {
|
||||
@@ -326,6 +353,13 @@ pub struct ClaimsImports {
|
||||
#[serde(default, skip_serializing_if = "SubjectImportPreference::is_default")]
|
||||
pub subject: SubjectImportPreference,
|
||||
|
||||
/// Whether to skip the interactive screen prompting the user to confirm the
|
||||
/// attributes that are being imported. This requires `localpart.action` to
|
||||
/// be `require` and other attribute actions to be either `ignore`, `force`
|
||||
/// or `require`
|
||||
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
|
||||
pub skip_confirmation: bool,
|
||||
|
||||
/// Import the localpart of the MXID
|
||||
#[serde(default, skip_serializing_if = "LocalpartImportPreference::is_default")]
|
||||
pub localpart: LocalpartImportPreference,
|
||||
@@ -337,8 +371,7 @@ pub struct ClaimsImports {
|
||||
)]
|
||||
pub displayname: DisplaynameImportPreference,
|
||||
|
||||
/// Import the email address of the user based on the `email` and
|
||||
/// `email_verified` claims
|
||||
/// Import the email address of the user
|
||||
#[serde(default, skip_serializing_if = "EmailImportPreference::is_default")]
|
||||
pub email: EmailImportPreference,
|
||||
|
||||
@@ -354,8 +387,10 @@ impl ClaimsImports {
|
||||
const fn is_default(&self) -> bool {
|
||||
self.subject.is_default()
|
||||
&& self.localpart.is_default()
|
||||
&& !self.skip_confirmation
|
||||
&& self.displayname.is_default()
|
||||
&& self.email.is_default()
|
||||
&& self.account_name.is_default()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -56,8 +56,8 @@ pub use self::{
|
||||
},
|
||||
user_agent::{DeviceType, UserAgent},
|
||||
users::{
|
||||
Authentication, AuthenticationMethod, BrowserSession, Password, User, UserEmail,
|
||||
UserEmailAuthentication, UserEmailAuthenticationCode, UserRecoverySession,
|
||||
Authentication, AuthenticationMethod, BrowserSession, MatrixUser, Password, User,
|
||||
UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode, UserRecoverySession,
|
||||
UserRecoveryTicket, UserRegistration, UserRegistrationPassword, UserRegistrationToken,
|
||||
},
|
||||
utils::{BoxClock, BoxRng},
|
||||
|
||||
@@ -312,6 +312,9 @@ pub struct ClaimsImports {
|
||||
#[serde(default)]
|
||||
pub subject: SubjectPreference,
|
||||
|
||||
#[serde(default)]
|
||||
pub skip_confirmation: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub localpart: LocalpartPreference,
|
||||
|
||||
@@ -415,11 +418,18 @@ impl ImportAction {
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum OnConflict {
|
||||
/// Fails the upstream OAuth 2.0 login
|
||||
/// Fails the upstream OAuth 2.0 login on conflict
|
||||
#[default]
|
||||
Fail,
|
||||
|
||||
/// Adds the upstream account link, regardless of whether there is an
|
||||
/// existing link or not
|
||||
/// Adds the upstream OAuth 2.0 identity link, regardless of whether there
|
||||
/// is an existing link or not
|
||||
Add,
|
||||
|
||||
/// Replace any existing upstream OAuth 2.0 identity link
|
||||
Replace,
|
||||
|
||||
/// Adds the upstream OAuth 2.0 identity link *only* if there is no existing
|
||||
/// link for this provider on the matching user
|
||||
Set,
|
||||
}
|
||||
|
||||
@@ -12,6 +12,12 @@ use serde::Serialize;
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct MatrixUser {
|
||||
pub mxid: String,
|
||||
pub display_name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct User {
|
||||
pub id: Ulid,
|
||||
|
||||
@@ -16,6 +16,7 @@ use mas_data_model::{
|
||||
User,
|
||||
};
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use mas_policy::{Policy, Requester, ViolationCode, model::CompatLogin};
|
||||
use mas_storage::{
|
||||
BoxRepository, BoxRepositoryFactory, RepositoryAccess,
|
||||
compat::{
|
||||
@@ -37,6 +38,7 @@ use crate::{
|
||||
BoundActivityTracker, Limiter, METER, RequesterFingerprint, impl_from_error_for_route,
|
||||
passwords::{PasswordManager, PasswordVerificationResult},
|
||||
rate_limit::PasswordCheckLimitedError,
|
||||
session::count_user_sessions_for_limiting,
|
||||
};
|
||||
|
||||
static LOGIN_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
|
||||
@@ -213,9 +215,16 @@ pub enum RouteError {
|
||||
|
||||
#[error("failed to provision device")]
|
||||
ProvisionDeviceFailed(#[source] anyhow::Error),
|
||||
|
||||
#[error("login rejected by policy")]
|
||||
PolicyRejected,
|
||||
|
||||
#[error("login rejected by policy (hard session limit reached)")]
|
||||
PolicyHardSessionLimitReached,
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
|
||||
impl From<anyhow::Error> for RouteError {
|
||||
fn from(err: anyhow::Error) -> Self {
|
||||
@@ -274,6 +283,16 @@ impl IntoResponse for RouteError {
|
||||
error: "User account has been locked",
|
||||
status: StatusCode::UNAUTHORIZED,
|
||||
},
|
||||
Self::PolicyRejected => MatrixError {
|
||||
errcode: "M_FORBIDDEN",
|
||||
error: "Login denied by the policy enforced by this service",
|
||||
status: StatusCode::FORBIDDEN,
|
||||
},
|
||||
Self::PolicyHardSessionLimitReached => MatrixError {
|
||||
errcode: "M_FORBIDDEN",
|
||||
error: "You have reached your hard device limit. Please visit your account page to sign some out.",
|
||||
status: StatusCode::FORBIDDEN,
|
||||
},
|
||||
};
|
||||
|
||||
(sentry_event_id, response).into_response()
|
||||
@@ -290,6 +309,7 @@ pub(crate) async fn post(
|
||||
State(homeserver): State<Arc<dyn HomeserverConnection>>,
|
||||
State(site_config): State<SiteConfig>,
|
||||
State(limiter): State<Limiter>,
|
||||
mut policy: Policy,
|
||||
requester: RequesterFingerprint,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
MatrixJsonBody(input): MatrixJsonBody<RequestBody>,
|
||||
@@ -329,6 +349,11 @@ pub(crate) async fn post(
|
||||
&limiter,
|
||||
requester,
|
||||
&mut repo,
|
||||
&mut policy,
|
||||
Requester {
|
||||
ip_address: activity_tracker.ip(),
|
||||
user_agent: user_agent.clone(),
|
||||
},
|
||||
username,
|
||||
password,
|
||||
input.device_id, // TODO check for validity
|
||||
@@ -342,6 +367,11 @@ pub(crate) async fn post(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&mut repo,
|
||||
&mut policy,
|
||||
Requester {
|
||||
ip_address: activity_tracker.ip(),
|
||||
user_agent: user_agent.clone(),
|
||||
},
|
||||
&token,
|
||||
input.device_id,
|
||||
input.initial_device_display_name,
|
||||
@@ -459,6 +489,8 @@ async fn token_login(
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
repo: &mut BoxRepository,
|
||||
policy: &mut Policy,
|
||||
requester: Requester,
|
||||
token: &str,
|
||||
requested_device_id: Option<String>,
|
||||
initial_device_display_name: Option<String>,
|
||||
@@ -544,10 +576,38 @@ async fn token_login(
|
||||
Device::generate(rng)
|
||||
};
|
||||
|
||||
repo.app_session()
|
||||
let session_replaced = repo
|
||||
.app_session()
|
||||
.finish_sessions_to_replace_device(clock, &browser_session.user, &device)
|
||||
.await?;
|
||||
|
||||
let session_counts = count_user_sessions_for_limiting(repo, &browser_session.user).await?;
|
||||
|
||||
let res = policy
|
||||
.evaluate_compat_login(mas_policy::CompatLoginInput {
|
||||
user: &browser_session.user,
|
||||
login: CompatLogin::Token,
|
||||
session_replaced,
|
||||
session_counts,
|
||||
requester,
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
// If the only violation is that we have too many sessions, then handle that
|
||||
// separately.
|
||||
// In the future, we intend to evict some sessions automatically instead. We
|
||||
// don't trigger this if there was some other violation anyway, since that means
|
||||
// that removing a session wouldn't actually unblock the login.
|
||||
if res.violations.len() == 1 {
|
||||
let violation = &res.violations[0];
|
||||
if violation.code == Some(ViolationCode::TooManySessions) {
|
||||
// The only violation is having reached the session limit.
|
||||
return Err(RouteError::PolicyHardSessionLimitReached);
|
||||
}
|
||||
}
|
||||
return Err(RouteError::PolicyRejected);
|
||||
}
|
||||
|
||||
// We first create the session in the database, commit the transaction, then
|
||||
// create it on the homeserver, scheduling a device sync job afterwards to
|
||||
// make sure we don't end up in an inconsistent state.
|
||||
@@ -578,6 +638,8 @@ async fn user_password_login(
|
||||
limiter: &Limiter,
|
||||
requester: RequesterFingerprint,
|
||||
repo: &mut BoxRepository,
|
||||
policy: &mut Policy,
|
||||
policy_requester: Requester,
|
||||
username: &str,
|
||||
password: String,
|
||||
requested_device_id: Option<String>,
|
||||
@@ -647,10 +709,38 @@ async fn user_password_login(
|
||||
Device::generate(&mut rng)
|
||||
};
|
||||
|
||||
repo.app_session()
|
||||
let session_replaced = repo
|
||||
.app_session()
|
||||
.finish_sessions_to_replace_device(clock, &user, &device)
|
||||
.await?;
|
||||
|
||||
let session_counts = count_user_sessions_for_limiting(repo, &user).await?;
|
||||
|
||||
let res = policy
|
||||
.evaluate_compat_login(mas_policy::CompatLoginInput {
|
||||
user: &user,
|
||||
login: CompatLogin::Password,
|
||||
session_replaced,
|
||||
session_counts,
|
||||
requester: policy_requester,
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
// If the only violation is that we have too many sessions, then handle that
|
||||
// separately.
|
||||
// In the future, we intend to evict some sessions automatically instead. We
|
||||
// don't trigger this if there was some other violation anyway, since that means
|
||||
// that removing a session wouldn't actually unblock the login.
|
||||
if res.violations.len() == 1 {
|
||||
let violation = &res.violations[0];
|
||||
if violation.code == Some(ViolationCode::TooManySessions) {
|
||||
// The only violation is having reached the session limit.
|
||||
return Err(RouteError::PolicyHardSessionLimitReached);
|
||||
}
|
||||
}
|
||||
return Err(RouteError::PolicyRejected);
|
||||
}
|
||||
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.add(
|
||||
|
||||
@@ -4,30 +4,35 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use anyhow::Context;
|
||||
use axum::{
|
||||
extract::{Form, Path, State},
|
||||
response::{Html, IntoResponse, Redirect, Response},
|
||||
};
|
||||
use axum_extra::extract::Query;
|
||||
use axum_extra::{TypedHeader, extract::Query};
|
||||
use chrono::Duration;
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::{
|
||||
InternalError,
|
||||
cookies::CookieJar,
|
||||
csrf::{CsrfExt, ProtectedForm},
|
||||
};
|
||||
use mas_data_model::{BoxClock, BoxRng, Clock};
|
||||
use mas_data_model::{BoxClock, BoxRng, Clock, MatrixUser};
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use mas_policy::{Policy, model::CompatLogin};
|
||||
use mas_router::{CompatLoginSsoAction, UrlBuilder};
|
||||
use mas_storage::{BoxRepository, RepositoryAccess, compat::CompatSsoLoginRepository};
|
||||
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
|
||||
use mas_templates::{
|
||||
CompatLoginPolicyViolationContext, CompatSsoContext, ErrorContext, TemplateContext, Templates,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{
|
||||
PreferredLanguage,
|
||||
session::{SessionOrFallback, load_session_or_fallback},
|
||||
BoundActivityTracker, PreferredLanguage,
|
||||
session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
|
||||
};
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -56,10 +61,16 @@ pub async fn get(
|
||||
mut repo: BoxRepository,
|
||||
State(templates): State<Templates>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
State(homeserver): State<Arc<dyn HomeserverConnection>>,
|
||||
mut policy: Policy,
|
||||
activity_tracker: BoundActivityTracker,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
cookie_jar: CookieJar,
|
||||
Path(id): Path<Ulid>,
|
||||
Query(params): Query<Params>,
|
||||
) -> Result<Response, InternalError> {
|
||||
let user_agent = user_agent.map(|ua| ua.to_string());
|
||||
|
||||
let (cookie_jar, maybe_session) = match load_session_or_fallback(
|
||||
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
|
||||
)
|
||||
@@ -107,7 +118,69 @@ pub async fn get(
|
||||
return Ok((cookie_jar, Html(content)).into_response());
|
||||
}
|
||||
|
||||
let ctx = CompatSsoContext::new(login)
|
||||
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
|
||||
|
||||
// We can close the repository early, we don't need it at this point
|
||||
repo.save().await?;
|
||||
|
||||
let res = policy
|
||||
.evaluate_compat_login(mas_policy::CompatLoginInput {
|
||||
user: &session.user,
|
||||
login: CompatLogin::Sso {
|
||||
redirect_uri: login.redirect_uri.to_string(),
|
||||
},
|
||||
// We don't know if there's going to be a replacement until we received the device ID,
|
||||
// which happens too late.
|
||||
session_replaced: false,
|
||||
session_counts,
|
||||
requester: mas_policy::Requester {
|
||||
ip_address: activity_tracker.ip(),
|
||||
user_agent,
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
let ctx = CompatLoginPolicyViolationContext::for_violations(res.violations)
|
||||
.with_session(session)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
.with_language(locale);
|
||||
|
||||
let content = templates.render_compat_login_policy_violation(&ctx)?;
|
||||
|
||||
return Ok((StatusCode::FORBIDDEN, cookie_jar, Html(content)).into_response());
|
||||
}
|
||||
|
||||
// Fetch informations about the user. This is purely cosmetic, so we let it
|
||||
// fail and put a 1s timeout to it in case we fail to query it
|
||||
// XXX: we're likely to need this in other places
|
||||
let localpart = &session.user.username;
|
||||
let display_name = match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(1),
|
||||
homeserver.query_user(localpart),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(user)) => user.displayname,
|
||||
Ok(Err(err)) => {
|
||||
tracing::warn!(
|
||||
error = &*err as &dyn std::error::Error,
|
||||
localpart,
|
||||
"Failed to query user"
|
||||
);
|
||||
None
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::warn!(localpart, "Timed out while querying user");
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let matrix_user = MatrixUser {
|
||||
mxid: homeserver.mxid(localpart),
|
||||
display_name,
|
||||
};
|
||||
|
||||
let ctx = CompatSsoContext::new(login, matrix_user)
|
||||
.with_session(session)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
.with_language(locale);
|
||||
@@ -129,11 +202,16 @@ pub async fn post(
|
||||
PreferredLanguage(locale): PreferredLanguage,
|
||||
State(templates): State<Templates>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
mut policy: Policy,
|
||||
activity_tracker: BoundActivityTracker,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
cookie_jar: CookieJar,
|
||||
Path(id): Path<Ulid>,
|
||||
Query(params): Query<Params>,
|
||||
Form(form): Form<ProtectedForm<()>>,
|
||||
) -> Result<Response, InternalError> {
|
||||
let user_agent = user_agent.map(|ua| ua.to_string());
|
||||
|
||||
let (cookie_jar, maybe_session) = match load_session_or_fallback(
|
||||
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
|
||||
)
|
||||
@@ -200,6 +278,37 @@ pub async fn post(
|
||||
redirect_uri
|
||||
};
|
||||
|
||||
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
|
||||
|
||||
let res = policy
|
||||
.evaluate_compat_login(mas_policy::CompatLoginInput {
|
||||
user: &session.user,
|
||||
login: CompatLogin::Sso {
|
||||
redirect_uri: login.redirect_uri.to_string(),
|
||||
},
|
||||
session_counts,
|
||||
// We don't know if there's going to be a replacement until we received the device ID,
|
||||
// which happens too late.
|
||||
session_replaced: false,
|
||||
requester: mas_policy::Requester {
|
||||
ip_address: activity_tracker.ip(),
|
||||
user_agent,
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
|
||||
if !res.valid() {
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
let ctx = CompatLoginPolicyViolationContext::for_violations(res.violations)
|
||||
.with_session(session)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
.with_language(locale);
|
||||
|
||||
let content = templates.render_compat_login_policy_violation(&ctx)?;
|
||||
|
||||
return Ok((StatusCode::FORBIDDEN, cookie_jar, Html(content)).into_response());
|
||||
}
|
||||
|
||||
// Note that if the login is not Pending,
|
||||
// this fails and aborts the transaction.
|
||||
repo.compat_sso_login()
|
||||
|
||||
@@ -272,6 +272,7 @@ where
|
||||
BoxRepository: FromRequestParts<S>,
|
||||
BoxClock: FromRequestParts<S>,
|
||||
BoxRng: FromRequestParts<S>,
|
||||
Policy: FromRequestParts<S>,
|
||||
{
|
||||
// A sub-router for human-facing routes with error handling
|
||||
let human_router = Router::new()
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use axum::{
|
||||
extract::{Form, Path, State},
|
||||
response::{Html, IntoResponse, Response},
|
||||
@@ -15,8 +17,9 @@ use mas_axum_utils::{
|
||||
cookies::CookieJar,
|
||||
csrf::{CsrfExt, ProtectedForm},
|
||||
};
|
||||
use mas_data_model::{AuthorizationGrantStage, BoxClock, BoxRng};
|
||||
use mas_data_model::{AuthorizationGrantStage, BoxClock, BoxRng, MatrixUser};
|
||||
use mas_keystore::Keystore;
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use mas_policy::Policy;
|
||||
use mas_router::{PostAuthAction, UrlBuilder};
|
||||
use mas_storage::{
|
||||
@@ -87,6 +90,7 @@ pub(crate) async fn get(
|
||||
PreferredLanguage(locale): PreferredLanguage,
|
||||
State(templates): State<Templates>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
State(homeserver): State<Arc<dyn HomeserverConnection>>,
|
||||
mut policy: Policy,
|
||||
mut repo: BoxRepository,
|
||||
activity_tracker: BoundActivityTracker,
|
||||
@@ -138,6 +142,9 @@ pub(crate) async fn get(
|
||||
|
||||
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
|
||||
|
||||
// We can close the repository early, we don't need it at this point
|
||||
repo.save().await?;
|
||||
|
||||
let res = policy
|
||||
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
|
||||
user: Some(&session.user),
|
||||
@@ -162,7 +169,37 @@ pub(crate) async fn get(
|
||||
return Ok((cookie_jar, Html(content)).into_response());
|
||||
}
|
||||
|
||||
let ctx = ConsentContext::new(grant, client)
|
||||
// Fetch informations about the user. This is purely cosmetic, so we let it
|
||||
// fail and put a 1s timeout to it in case we fail to query it
|
||||
// XXX: we're likely to need this in other places
|
||||
let localpart = &session.user.username;
|
||||
let display_name = match tokio::time::timeout(
|
||||
Duration::from_secs(1),
|
||||
homeserver.query_user(localpart),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(user)) => user.displayname,
|
||||
Ok(Err(err)) => {
|
||||
tracing::warn!(
|
||||
error = &*err as &dyn std::error::Error,
|
||||
localpart,
|
||||
"Failed to query user"
|
||||
);
|
||||
None
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::warn!(localpart, "Timed out while querying user");
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let matrix_user = MatrixUser {
|
||||
mxid: homeserver.mxid(localpart),
|
||||
display_name,
|
||||
};
|
||||
|
||||
let ctx = ConsentContext::new(grant, client, matrix_user)
|
||||
.with_session(session)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
.with_language(locale);
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::Context;
|
||||
use axum::{
|
||||
Form,
|
||||
@@ -16,7 +18,8 @@ use mas_axum_utils::{
|
||||
cookies::CookieJar,
|
||||
csrf::{CsrfExt, ProtectedForm},
|
||||
};
|
||||
use mas_data_model::{BoxClock, BoxRng};
|
||||
use mas_data_model::{BoxClock, BoxRng, MatrixUser};
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use mas_policy::Policy;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::BoxRepository;
|
||||
@@ -49,6 +52,7 @@ pub(crate) async fn get(
|
||||
PreferredLanguage(locale): PreferredLanguage,
|
||||
State(templates): State<Templates>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
State(homeserver): State<Arc<dyn HomeserverConnection>>,
|
||||
mut repo: BoxRepository,
|
||||
mut policy: Policy,
|
||||
activity_tracker: BoundActivityTracker,
|
||||
@@ -105,6 +109,9 @@ pub(crate) async fn get(
|
||||
|
||||
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
|
||||
|
||||
// We can close the repository early, we don't need it at this point
|
||||
repo.save().await?;
|
||||
|
||||
// Evaluate the policy
|
||||
let res = policy
|
||||
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
|
||||
@@ -133,7 +140,37 @@ pub(crate) async fn get(
|
||||
return Ok((cookie_jar, Html(content)).into_response());
|
||||
}
|
||||
|
||||
let ctx = DeviceConsentContext::new(grant, client)
|
||||
// Fetch informations about the user. This is purely cosmetic, so we let it
|
||||
// fail and put a 1s timeout to it in case we fail to query it
|
||||
// XXX: we're likely to need this in other places
|
||||
let localpart = &session.user.username;
|
||||
let display_name = match tokio::time::timeout(
|
||||
Duration::from_secs(1),
|
||||
homeserver.query_user(localpart),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(user)) => user.displayname,
|
||||
Ok(Err(err)) => {
|
||||
tracing::warn!(
|
||||
error = &*err as &dyn std::error::Error,
|
||||
localpart,
|
||||
"Failed to query user"
|
||||
);
|
||||
None
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::warn!(localpart, "Timed out while querying user");
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let matrix_user = MatrixUser {
|
||||
mxid: homeserver.mxid(localpart),
|
||||
display_name,
|
||||
};
|
||||
|
||||
let ctx = DeviceConsentContext::new(grant, client, matrix_user)
|
||||
.with_session(session)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
.with_language(locale);
|
||||
@@ -153,6 +190,7 @@ pub(crate) async fn post(
|
||||
PreferredLanguage(locale): PreferredLanguage,
|
||||
State(templates): State<Templates>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
State(homeserver): State<Arc<dyn HomeserverConnection>>,
|
||||
mut repo: BoxRepository,
|
||||
mut policy: Policy,
|
||||
activity_tracker: BoundActivityTracker,
|
||||
@@ -265,7 +303,37 @@ pub(crate) async fn post(
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
let ctx = DeviceConsentContext::new(grant, client)
|
||||
// Fetch informations about the user. This is purely cosmetic, so we let it
|
||||
// fail and put a 1s timeout to it in case we fail to query it
|
||||
// XXX: we're likely to need this in other places
|
||||
let localpart = &session.user.username;
|
||||
let display_name = match tokio::time::timeout(
|
||||
Duration::from_secs(1),
|
||||
homeserver.query_user(localpart),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(user)) => user.displayname,
|
||||
Ok(Err(err)) => {
|
||||
tracing::warn!(
|
||||
error = &*err as &dyn std::error::Error,
|
||||
localpart,
|
||||
"Failed to query user"
|
||||
);
|
||||
None
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::warn!(localpart, "Timed out while querying user");
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let matrix_user = MatrixUser {
|
||||
mxid: homeserver.mxid(localpart),
|
||||
display_name,
|
||||
};
|
||||
|
||||
let ctx = DeviceConsentContext::new(grant, client, matrix_user)
|
||||
.with_session(session)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
.with_language(locale);
|
||||
|
||||
@@ -82,6 +82,7 @@ pub(crate) async fn policy_factory(
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
compat_login: "compat_login/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
};
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,7 +12,7 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use mas_policy::model::{
|
||||
AuthorizationGrantInput, ClientRegistrationInput, EmailInput, RegisterInput,
|
||||
AuthorizationGrantInput, ClientRegistrationInput, CompatLoginInput, EmailInput, RegisterInput,
|
||||
};
|
||||
use schemars::{JsonSchema, generate::SchemaSettings};
|
||||
|
||||
@@ -42,5 +42,6 @@ fn main() {
|
||||
write_schema::<RegisterInput>(output_root, "register_input.json");
|
||||
write_schema::<ClientRegistrationInput>(output_root, "client_registration_input.json");
|
||||
write_schema::<AuthorizationGrantInput>(output_root, "authorization_grant_input.json");
|
||||
write_schema::<CompatLoginInput>(output_root, "compat_login_input.json");
|
||||
write_schema::<EmailInput>(output_root, "email_input.json");
|
||||
}
|
||||
|
||||
@@ -19,8 +19,9 @@ use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
|
||||
pub use self::model::{
|
||||
AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, EmailInput,
|
||||
EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
|
||||
AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, CompatLoginInput,
|
||||
EmailInput, EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester,
|
||||
Violation,
|
||||
};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
@@ -72,15 +73,17 @@ pub struct Entrypoints {
|
||||
pub register: String,
|
||||
pub client_registration: String,
|
||||
pub authorization_grant: String,
|
||||
pub compat_login: String,
|
||||
pub email: String,
|
||||
}
|
||||
|
||||
impl Entrypoints {
|
||||
fn all(&self) -> [&str; 4] {
|
||||
fn all(&self) -> [&str; 5] {
|
||||
[
|
||||
self.register.as_str(),
|
||||
self.client_registration.as_str(),
|
||||
self.authorization_grant.as_str(),
|
||||
self.compat_login.as_str(),
|
||||
self.email.as_str(),
|
||||
]
|
||||
}
|
||||
@@ -459,6 +462,30 @@ impl Policy {
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Evaluate the `compat_login` entrypoint.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the policy engine fails to evaluate the entrypoint.
|
||||
#[tracing::instrument(
|
||||
name = "policy.evaluate.compat_login",
|
||||
skip_all,
|
||||
fields(
|
||||
%input.user.id,
|
||||
),
|
||||
)]
|
||||
pub async fn evaluate_compat_login(
|
||||
&mut self,
|
||||
input: CompatLoginInput<'_>,
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(&mut self.store, &self.entrypoints.compat_login, &input)
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -468,6 +495,16 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
fn make_entrypoints() -> Entrypoints {
|
||||
Entrypoints {
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
compat_login: "compat_login/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register() {
|
||||
let data = Data::new("example.com".to_owned(), None).with_rest(serde_json::json!({
|
||||
@@ -484,14 +521,9 @@ mod tests {
|
||||
|
||||
let file = tokio::fs::File::open(path).await.unwrap();
|
||||
|
||||
let entrypoints = Entrypoints {
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
};
|
||||
|
||||
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
|
||||
let factory = PolicyFactory::load(file, data, make_entrypoints())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut policy = factory.instantiate().await.unwrap();
|
||||
|
||||
@@ -551,14 +583,9 @@ mod tests {
|
||||
|
||||
let file = tokio::fs::File::open(path).await.unwrap();
|
||||
|
||||
let entrypoints = Entrypoints {
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
};
|
||||
|
||||
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
|
||||
let factory = PolicyFactory::load(file, data, make_entrypoints())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut policy = factory.instantiate().await.unwrap();
|
||||
|
||||
@@ -620,14 +647,9 @@ mod tests {
|
||||
|
||||
let file = tokio::fs::File::open(path).await.unwrap();
|
||||
|
||||
let entrypoints = Entrypoints {
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
};
|
||||
|
||||
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
|
||||
let factory = PolicyFactory::load(file, data, make_entrypoints())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// That is around 1 MB of JSON data. Each element is a 5-digit string, so 8
|
||||
// characters including the quotes and a comma.
|
||||
|
||||
@@ -17,7 +17,7 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A well-known policy code.
|
||||
#[derive(Deserialize, Debug, Clone, Copy, JsonSchema)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Copy, JsonSchema, PartialEq, Eq)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum Code {
|
||||
/// The username is too short.
|
||||
@@ -75,7 +75,7 @@ impl Code {
|
||||
}
|
||||
|
||||
/// A single violation of a policy.
|
||||
#[derive(Deserialize, Debug, JsonSchema)]
|
||||
#[derive(Serialize, Deserialize, Debug, JsonSchema)]
|
||||
pub struct Violation {
|
||||
pub msg: String,
|
||||
pub redirect_uri: Option<String>,
|
||||
@@ -187,6 +187,42 @@ pub struct AuthorizationGrantInput<'a> {
|
||||
pub requester: Requester,
|
||||
}
|
||||
|
||||
/// Input for the compatibility login policy.
|
||||
#[derive(Serialize, Debug, JsonSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct CompatLoginInput<'a> {
|
||||
#[schemars(with = "std::collections::HashMap<String, serde_json::Value>")]
|
||||
pub user: &'a User,
|
||||
|
||||
/// How many sessions the user has.
|
||||
pub session_counts: SessionCounts,
|
||||
|
||||
/// Whether a session will be replaced by this login
|
||||
pub session_replaced: bool,
|
||||
|
||||
/// What type of login is being performed.
|
||||
/// This also determines whether the login is interactive.
|
||||
pub login: CompatLogin,
|
||||
|
||||
pub requester: Requester,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, JsonSchema)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum CompatLogin {
|
||||
/// Used as the interactive part of SSO login.
|
||||
#[serde(rename = "m.login.sso")]
|
||||
Sso { redirect_uri: String },
|
||||
|
||||
/// Used as the final (non-interactive) stage of SSO login.
|
||||
#[serde(rename = "m.login.token")]
|
||||
Token,
|
||||
|
||||
/// Non-interactive password-over-the-API login.
|
||||
#[serde(rename = "m.login.password")]
|
||||
Password,
|
||||
}
|
||||
|
||||
/// Information about how many sessions the user has
|
||||
#[derive(Serialize, Debug, JsonSchema)]
|
||||
pub struct SessionCounts {
|
||||
|
||||
20
crates/storage-pg/.sqlx/query-2f66991d7b9ba58f011d9aef0eb6a38f3b244c2f46444c0ab345de7feff54aba.json
generated
Normal file
20
crates/storage-pg/.sqlx/query-2f66991d7b9ba58f011d9aef0eb6a38f3b244c2f46444c0ab345de7feff54aba.json
generated
Normal file
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT current_database() as \"current_database!\"",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "current_database!",
|
||||
"type_info": "Name"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": []
|
||||
},
|
||||
"nullable": [
|
||||
null
|
||||
]
|
||||
},
|
||||
"hash": "2f66991d7b9ba58f011d9aef0eb6a38f3b244c2f46444c0ab345de7feff54aba"
|
||||
}
|
||||
20
crates/storage-pg/.sqlx/query-fbf926f630df5d588df4f1c9c0dc0f594332be5829d5d7c6b66183ac25b3d166.json
generated
Normal file
20
crates/storage-pg/.sqlx/query-fbf926f630df5d588df4f1c9c0dc0f594332be5829d5d7c6b66183ac25b3d166.json
generated
Normal file
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT EXISTS (\n SELECT 1\n FROM information_schema.tables\n WHERE table_name = '_sqlx_migrations'\n ) AS \"exists!\"\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "exists!",
|
||||
"type_info": "Bool"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": []
|
||||
},
|
||||
"nullable": [
|
||||
null
|
||||
]
|
||||
},
|
||||
"hash": "fbf926f630df5d588df4f1c9c0dc0f594332be5829d5d7c6b66183ac25b3d166"
|
||||
}
|
||||
@@ -19,6 +19,7 @@ workspace = true
|
||||
[dependencies]
|
||||
async-trait.workspace = true
|
||||
chrono.workspace = true
|
||||
crc.workspace = true
|
||||
futures-util.workspace = true
|
||||
opentelemetry-semantic-conventions.workspace = true
|
||||
opentelemetry.workspace = true
|
||||
@@ -31,6 +32,7 @@ sha2.workspace = true
|
||||
sqlx.workspace = true
|
||||
thiserror.workspace = true
|
||||
tracing.workspace = true
|
||||
tokio.workspace = true
|
||||
ulid.workspace = true
|
||||
url.workspace = true
|
||||
uuid.workspace = true
|
||||
|
||||
@@ -487,14 +487,15 @@ impl AppSessionRepository for PgAppSessionRepository<'_> {
|
||||
clock: &dyn Clock,
|
||||
user: &User,
|
||||
device: &Device,
|
||||
) -> Result<(), Self::Error> {
|
||||
) -> Result<bool, Self::Error> {
|
||||
let mut affected = false;
|
||||
// TODO need to invoke this from all the oauth2 login sites
|
||||
let span = tracing::info_span!(
|
||||
"db.app_session.finish_sessions_to_replace_device.compat_sessions",
|
||||
{ DB_QUERY_TEXT } = tracing::field::Empty,
|
||||
);
|
||||
let finished_at = clock.now();
|
||||
sqlx::query!(
|
||||
let compat_affected = sqlx::query!(
|
||||
"
|
||||
UPDATE compat_sessions SET finished_at = $3 WHERE user_id = $1 AND device_id = $2 AND finished_at IS NULL
|
||||
",
|
||||
@@ -505,7 +506,9 @@ impl AppSessionRepository for PgAppSessionRepository<'_> {
|
||||
.record(&span)
|
||||
.execute(&mut *self.conn)
|
||||
.instrument(span)
|
||||
.await?;
|
||||
.await?
|
||||
.rows_affected();
|
||||
affected |= compat_affected > 0;
|
||||
|
||||
if let Ok([stable_device_as_scope_token, unstable_device_as_scope_token]) =
|
||||
device.to_scope_token()
|
||||
@@ -514,7 +517,7 @@ impl AppSessionRepository for PgAppSessionRepository<'_> {
|
||||
"db.app_session.finish_sessions_to_replace_device.oauth2_sessions",
|
||||
{ DB_QUERY_TEXT } = tracing::field::Empty,
|
||||
);
|
||||
sqlx::query!(
|
||||
let oauth2_affected = sqlx::query!(
|
||||
"
|
||||
UPDATE oauth2_sessions
|
||||
SET finished_at = $4
|
||||
@@ -530,10 +533,12 @@ impl AppSessionRepository for PgAppSessionRepository<'_> {
|
||||
.record(&span)
|
||||
.execute(&mut *self.conn)
|
||||
.instrument(span)
|
||||
.await?;
|
||||
.await?
|
||||
.rows_affected();
|
||||
affected |= oauth2_affected > 0;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(affected)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -160,7 +160,15 @@
|
||||
#![deny(clippy::future_not_send, missing_docs)]
|
||||
#![allow(clippy::module_name_repetitions, clippy::blocks_in_conditions)]
|
||||
|
||||
use sqlx::migrate::Migrator;
|
||||
use std::collections::{BTreeMap, BTreeSet, HashSet};
|
||||
|
||||
use ::tracing::{Instrument, debug, info, info_span, warn};
|
||||
use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
|
||||
use sqlx::{
|
||||
Either, PgConnection,
|
||||
migrate::{AppliedMigration, Migrate, MigrateError, Migration, Migrator},
|
||||
postgres::{PgAdvisoryLock, PgAdvisoryLockKey},
|
||||
};
|
||||
|
||||
pub mod app_session;
|
||||
pub mod compat;
|
||||
@@ -186,14 +194,290 @@ pub use self::{
|
||||
tracing::ExecuteExt,
|
||||
};
|
||||
|
||||
/// Embedded migrations, allowing them to run on startup
|
||||
pub static MIGRATOR: Migrator = {
|
||||
// XXX: The macro does not let us ignore missing migrations, so we have to do it
|
||||
// like this. See https://github.com/launchbadge/sqlx/issues/1788
|
||||
let mut m = sqlx::migrate!();
|
||||
/// Embedded migrations in the binary
|
||||
pub static MIGRATOR: Migrator = sqlx::migrate!();
|
||||
|
||||
// We manually removed some migrations because they made us depend on the
|
||||
// `pgcrypto` extension. See: https://github.com/matrix-org/matrix-authentication-service/issues/1557
|
||||
m.ignore_missing = true;
|
||||
m
|
||||
};
|
||||
fn available_migrations() -> BTreeMap<i64, &'static Migration> {
|
||||
MIGRATOR.iter().map(|m| (m.version, m)).collect()
|
||||
}
|
||||
|
||||
/// This is the list of migrations we've removed from the migration history but
|
||||
/// might have been applied in the past
|
||||
#[allow(clippy::inconsistent_digit_grouping)]
|
||||
const ALLOWED_MISSING_MIGRATIONS: &[i64] = &[
|
||||
// https://github.com/matrix-org/matrix-authentication-service/pull/1585
|
||||
20220709_210445,
|
||||
20230330_210841,
|
||||
20230408_110421,
|
||||
];
|
||||
|
||||
fn allowed_missing_migrations() -> BTreeSet<i64> {
|
||||
ALLOWED_MISSING_MIGRATIONS.iter().copied().collect()
|
||||
}
|
||||
|
||||
/// This is a list of possible additional checksums from previous versions of
|
||||
/// migrations. The checksum we store in the database is 48 bytes long. We're
|
||||
/// not really concerned with partial hash collisions, and to avoid this file to
|
||||
/// be completely unreadable, we only store the upper 16 bytes of that hash.
|
||||
#[allow(clippy::inconsistent_digit_grouping)]
|
||||
const ALLOWED_ALTERNATE_CHECKSUMS: &[(i64, u128)] = &[
|
||||
// https://github.com/element-hq/matrix-authentication-service/pull/5300
|
||||
(20250410_000000, 0x8811_c3ef_dbee_8c00_5b49_25da_5d55_9c3f),
|
||||
(20250410_000001, 0x7990_37b3_2193_8a5d_c72f_bccd_95fd_82e5),
|
||||
(20250410_000002, 0xf2b8_f120_deae_27e7_60d0_79a3_0b77_eea3),
|
||||
(20250410_000003, 0x06be_fc2b_cedc_acf4_b981_02c7_b40c_c469),
|
||||
(20250410_000004, 0x0a90_9c6a_dba7_545c_10d9_60eb_6d30_2f50),
|
||||
(20250410_000006, 0xcc7f_5152_6497_5729_d94b_be0d_9c95_8316),
|
||||
(20250410_000007, 0x12e7_cfab_a017_a5a5_4f2c_18fa_541c_ce62),
|
||||
(20250410_000008, 0x171d_62e5_ee1a_f0d9_3639_6c5a_277c_54cd),
|
||||
(20250410_000009, 0xb1a0_93c7_6645_92ad_df45_b395_57bb_a281),
|
||||
(20250410_000010, 0x8089_86ac_7cff_8d86_2850_d287_cdb1_2b57),
|
||||
(20250410_000011, 0x8d9d_3fae_02c9_3d3f_81e4_6242_2b39_b5b8),
|
||||
(20250410_000012, 0x9805_1372_41aa_d5b0_ebe1_ba9d_28c7_faf6),
|
||||
(20250410_000013, 0x7291_9a97_e4d1_0d45_1791_6e8c_3f2d_e34d),
|
||||
(20250410_000014, 0x811d_f965_8127_e168_4aa2_f177_a4e6_f077),
|
||||
(20250410_000015, 0xa639_0780_aab7_d60d_5fcb_771d_13ed_73ee),
|
||||
(20250410_000016, 0x22b6_e909_6de4_39e3_b2b9_c684_7417_fe07),
|
||||
(20250410_000017, 0x9dfe_b6d3_89e4_e509_651b_2793_8d8d_cd32),
|
||||
(20250410_000018, 0x638f_bdbc_2276_5094_020b_cec1_ab95_c07f),
|
||||
(20250410_000019, 0xa283_84bc_5fd5_7cbd_b5fb_b5fe_0255_6845),
|
||||
(20250410_000020, 0x17d1_54b1_7c6e_fc48_61dd_da3d_f8a5_9546),
|
||||
(20250410_000022, 0xbc36_af82_994a_6f93_8aca_a46b_fc3c_ffde),
|
||||
(20250410_000023, 0x54ec_3b07_ac79_443b_9e18_a2b3_2d17_5ab9),
|
||||
(20250410_000024, 0x8ab4_4f80_00b6_58b2_d757_c40f_bc72_3d87),
|
||||
(20250410_000025, 0x5dc4_2ff3_3042_2f45_046d_10af_ab3a_b583),
|
||||
(20250410_000026, 0x5263_c547_0b64_6425_5729_48b2_ce84_7cad),
|
||||
(20250410_000027, 0x0aad_cb50_1d6a_7794_9017_d24d_55e7_1b9d),
|
||||
(20250410_000028, 0x8fc1_92f8_68df_ca4e_3e2b_cddf_bc12_cffe),
|
||||
(20250410_000029, 0x416c_9446_b6a3_1b49_2940_a8ac_c1c2_665a),
|
||||
(20250410_000030, 0x83a5_e51e_25a6_77fb_2b79_6ea5_db1e_364f),
|
||||
(20250410_000031, 0xfa18_a707_9438_dbc7_2cde_b5f1_ee21_5c7e),
|
||||
(20250410_000032, 0xd669_662e_8930_838a_b142_c3fa_7b39_d2a0),
|
||||
(20250410_000033, 0x4019_1053_cabc_191c_c02e_9aa9_407c_0de5),
|
||||
(20250410_000034, 0xdd59_e595_24e6_4dad_c5f7_fef2_90b8_df57),
|
||||
(20250410_000035, 0x09b4_ea53_2da4_9c39_eb10_db33_6a6d_608b),
|
||||
(20250410_000036, 0x3ca5_9c78_8480_e342_d729_907c_d293_2049),
|
||||
(20250410_000037, 0xc857_2a10_450b_0612_822c_2b86_535a_ea7d),
|
||||
(20250410_000038, 0x1642_39da_9c3b_d9fd_b1e1_72b1_db78_b978),
|
||||
(20250410_000039, 0xdd70_b211_6016_bb84_0d84_f04e_eb8a_59d9),
|
||||
(20250410_000040, 0xe435_ead6_c363_a0b6_e048_dd85_0ecb_9499),
|
||||
(20250410_000041, 0xe9f3_122f_70d4_9839_c818_4b18_0192_ae26),
|
||||
(20250410_000043, 0xec5e_1400_483d_c4bf_6014_aba4_ffc3_6236),
|
||||
(20250410_000044, 0x4750_5eba_4095_6664_78d0_27f9_64bf_64f4),
|
||||
(20250410_000045, 0x9a53_bd70_4cad_2bf1_61d4_f143_0c82_681d),
|
||||
(20250410_121612, 0x25f0_9d20_a897_df18_162d_1c47_b68e_81bd),
|
||||
(20250602_212101, 0xd1a8_782c_b3f0_5045_3f46_49a0_bab0_822b),
|
||||
(20250708_155857, 0xb78e_6957_a588_c16a_d292_a0c7_cae9_f290),
|
||||
(20250915_092635, 0x6854_d58b_99d7_3ac5_82f8_25e5_b1c3_cc0b),
|
||||
(20251127_145951, 0x3bcd_d92e_8391_2a2c_8a18_1d76_354f_96c6),
|
||||
];
|
||||
|
||||
fn alternate_checksums_map() -> BTreeMap<i64, HashSet<u128>> {
|
||||
let mut map = BTreeMap::new();
|
||||
for (version, checksum) in ALLOWED_ALTERNATE_CHECKSUMS {
|
||||
map.entry(*version)
|
||||
.or_insert_with(HashSet::new)
|
||||
.insert(*checksum);
|
||||
}
|
||||
map
|
||||
}
|
||||
|
||||
/// Load the list of applied migrations into a map.
|
||||
///
|
||||
/// It's important to use a [`BTreeMap`] so that the migrations are naturally
|
||||
/// ordered by version.
|
||||
async fn applied_migrations_map(
|
||||
conn: &mut PgConnection,
|
||||
) -> Result<BTreeMap<i64, AppliedMigration>, MigrateError> {
|
||||
let applied_migrations = conn
|
||||
.list_applied_migrations()
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|m| (m.version, m))
|
||||
.collect();
|
||||
|
||||
Ok(applied_migrations)
|
||||
}
|
||||
|
||||
/// Checks if the migration table exists
|
||||
async fn migration_table_exists(conn: &mut PgConnection) -> Result<bool, sqlx::Error> {
|
||||
sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.tables
|
||||
WHERE table_name = '_sqlx_migrations'
|
||||
) AS "exists!"
|
||||
"#,
|
||||
)
|
||||
.fetch_one(conn)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Run the migrations on the given connection
|
||||
///
|
||||
/// This function acquires an advisory lock on the database to ensure that only
|
||||
/// one migrator is running at a time.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// This function returns an error if the migration fails.
|
||||
#[::tracing::instrument(name = "db.migrate", skip_all, err)]
|
||||
pub async fn migrate(conn: &mut PgConnection) -> Result<(), MigrateError> {
|
||||
// Get the database name and use it to derive an advisory lock key. This
|
||||
// is the same lock key used by SQLx default migrator, so that it works even
|
||||
// with older versions of MAS, and when running through `cargo sqlx migrate run`
|
||||
let database_name = sqlx::query_scalar!(r#"SELECT current_database() as "current_database!""#)
|
||||
.fetch_one(&mut *conn)
|
||||
.await
|
||||
.map_err(MigrateError::from)?;
|
||||
|
||||
let lock =
|
||||
PgAdvisoryLock::with_key(PgAdvisoryLockKey::BigInt(generate_lock_id(&database_name)));
|
||||
|
||||
// Try to acquire the migration lock in a loop.
|
||||
//
|
||||
// The reason we do that with a `try_acquire` is because in Postgres, `CREATE
|
||||
// INDEX CONCURRENTLY` will *not* complete whilst an advisory lock is being
|
||||
// acquired on another connection. This then means that if we run two
|
||||
// migration process at the same time, one of them will go through and block
|
||||
// on concurrent index creations, because the other will get stuck trying to
|
||||
// acquire this lock.
|
||||
//
|
||||
// To avoid this, we use `try_acquire`/`pg_advisory_lock_try` in a loop, which
|
||||
// will fail immediately if the lock is held by another connection, allowing
|
||||
// potential 'CREATE INDEX CONCURRENTLY' statements to complete.
|
||||
let mut backoff = std::time::Duration::from_millis(250);
|
||||
let mut conn = conn;
|
||||
let mut locked_connection = loop {
|
||||
match lock.try_acquire(conn).await? {
|
||||
Either::Left(guard) => break guard,
|
||||
Either::Right(conn_) => {
|
||||
warn!(
|
||||
"Another process is already running migrations on the database, waiting {duration}s and trying again…",
|
||||
duration = backoff.as_secs_f32()
|
||||
);
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = std::cmp::min(backoff * 2, std::time::Duration::from_secs(5));
|
||||
conn = conn_;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Creates the migration table if missing
|
||||
// We check if the table exists before calling `ensure_migrations_table` to
|
||||
// avoid the pesky 'relation "_sqlx_migrations" already exists, skipping' notice
|
||||
if !migration_table_exists(locked_connection.as_mut()).await? {
|
||||
locked_connection.as_mut().ensure_migrations_table().await?;
|
||||
}
|
||||
|
||||
for migration in pending_migrations(locked_connection.as_mut()).await? {
|
||||
info!(
|
||||
"Applying migration {version}: {description}",
|
||||
version = migration.version,
|
||||
description = migration.description
|
||||
);
|
||||
locked_connection
|
||||
.as_mut()
|
||||
.apply(migration)
|
||||
.instrument(info_span!(
|
||||
"db.migrate.run_migration",
|
||||
db.migration.version = migration.version,
|
||||
db.migration.description = &*migration.description,
|
||||
{ DB_QUERY_TEXT } = &*migration.sql,
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
|
||||
locked_connection.release_now().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the list of pending migrations
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// This function returns an error if there is a problem checking the applied
|
||||
/// migrations
|
||||
pub async fn pending_migrations(
|
||||
conn: &mut PgConnection,
|
||||
) -> Result<Vec<&'static Migration>, MigrateError> {
|
||||
// Load the maps of available migrations, applied migrations, migrations that
|
||||
// are allowed to be missing, alternate checksums for migrations that changed
|
||||
let available_migrations = available_migrations();
|
||||
let allowed_missing = allowed_missing_migrations();
|
||||
let alternate_checksums = alternate_checksums_map();
|
||||
let applied_migrations = if migration_table_exists(&mut *conn).await? {
|
||||
applied_migrations_map(&mut *conn).await?
|
||||
} else {
|
||||
BTreeMap::new()
|
||||
};
|
||||
|
||||
// Check that all applied migrations are still valid
|
||||
for applied_migration in applied_migrations.values() {
|
||||
// Check that we know about the applied migration
|
||||
if let Some(migration) = available_migrations.get(&applied_migration.version) {
|
||||
// Check the migration checksum
|
||||
if applied_migration.checksum != migration.checksum {
|
||||
// The checksum we have in the database doesn't match the one we
|
||||
// have embedded. This might be because a migration was
|
||||
// intentionally changed, so we check the alternate checksums
|
||||
if let Some(alternates) = alternate_checksums.get(&applied_migration.version) {
|
||||
// This converts the first 16 bytes of the checksum into a u128
|
||||
let Some(applied_checksum_prefix) = applied_migration
|
||||
.checksum
|
||||
.get(..16)
|
||||
.and_then(|bytes| bytes.try_into().ok())
|
||||
.map(u128::from_be_bytes)
|
||||
else {
|
||||
return Err(MigrateError::ExecuteMigration(
|
||||
sqlx::Error::InvalidArgument(
|
||||
"checksum stored in database is invalid".to_owned(),
|
||||
),
|
||||
applied_migration.version,
|
||||
));
|
||||
};
|
||||
|
||||
if !alternates.contains(&applied_checksum_prefix) {
|
||||
warn!(
|
||||
"The database has a migration applied ({version}) which has known alternative checksums {alternates:x?}, but none of them matched {applied_checksum_prefix:x}",
|
||||
version = applied_migration.version,
|
||||
);
|
||||
return Err(MigrateError::VersionMismatch(applied_migration.version));
|
||||
}
|
||||
} else {
|
||||
return Err(MigrateError::VersionMismatch(applied_migration.version));
|
||||
}
|
||||
}
|
||||
} else if allowed_missing.contains(&applied_migration.version) {
|
||||
// The migration is missing, but allowed to be missing
|
||||
debug!(
|
||||
"The database has a migration applied ({version}) that doesn't exist anymore, but it was intentionally removed",
|
||||
version = applied_migration.version
|
||||
);
|
||||
} else {
|
||||
// The migration is missing, warn about it
|
||||
warn!(
|
||||
"The database has a migration applied ({version}) that doesn't exist anymore! This should not happen, unless rolling back to an older version of MAS.",
|
||||
version = applied_migration.version
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(available_migrations
|
||||
.values()
|
||||
.copied()
|
||||
.filter(|migration| {
|
||||
!migration.migration_type.is_down_migration()
|
||||
&& !applied_migrations.contains_key(&migration.version)
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
// Copied from the sqlx source code, so that we generate the same lock ID
|
||||
fn generate_lock_id(database_name: &str) -> i64 {
|
||||
const CRC_IEEE: crc::Crc<u32> = crc::Crc::<u32>::new(&crc::CRC_32_ISO_HDLC);
|
||||
// 0x3d32ad9e chosen by fair dice roll
|
||||
0x3d32_ad9e * i64::from(CRC_IEEE.checksum(database_name.as_bytes()))
|
||||
}
|
||||
|
||||
@@ -196,12 +196,14 @@ pub trait AppSessionRepository: Send + Sync {
|
||||
/// replacing a device).
|
||||
///
|
||||
/// Should be called *before* creating a new session for the device.
|
||||
///
|
||||
/// Returns true if a session was finished.
|
||||
async fn finish_sessions_to_replace_device(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
user: &User,
|
||||
device: &Device,
|
||||
) -> Result<(), Self::Error>;
|
||||
) -> Result<bool, Self::Error>;
|
||||
}
|
||||
|
||||
repository_impl!(AppSessionRepository:
|
||||
@@ -218,5 +220,5 @@ repository_impl!(AppSessionRepository:
|
||||
clock: &dyn Clock,
|
||||
user: &User,
|
||||
device: &Device,
|
||||
) -> Result<(), Self::Error>;
|
||||
) -> Result<bool, Self::Error>;
|
||||
);
|
||||
|
||||
@@ -41,6 +41,7 @@ oauth2-types.workspace = true
|
||||
mas-data-model.workspace = true
|
||||
mas-i18n.workspace = true
|
||||
mas-iana.workspace = true
|
||||
mas-policy.workspace = true
|
||||
mas-router.workspace = true
|
||||
mas-spa.workspace = true
|
||||
|
||||
|
||||
@@ -21,13 +21,15 @@ use chrono::{DateTime, Duration, Utc};
|
||||
use http::{Method, Uri, Version};
|
||||
use mas_data_model::{
|
||||
AuthorizationGrant, BrowserSession, Client, CompatSsoLogin, CompatSsoLoginState,
|
||||
DeviceCodeGrant, UpstreamOAuthLink, UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports,
|
||||
UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderOnBackchannelLogout,
|
||||
UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderTokenAuthMethod, User,
|
||||
UserEmailAuthentication, UserEmailAuthenticationCode, UserRecoverySession, UserRegistration,
|
||||
DeviceCodeGrant, MatrixUser, UpstreamOAuthLink, UpstreamOAuthProvider,
|
||||
UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
|
||||
UpstreamOAuthProviderOnBackchannelLogout, UpstreamOAuthProviderPkceMode,
|
||||
UpstreamOAuthProviderTokenAuthMethod, User, UserEmailAuthentication,
|
||||
UserEmailAuthenticationCode, UserRecoverySession, UserRegistration,
|
||||
};
|
||||
use mas_i18n::DataLocale;
|
||||
use mas_iana::jose::JsonWebSignatureAlg;
|
||||
use mas_policy::{Violation, ViolationCode};
|
||||
use mas_router::{Account, GraphQL, PostAuthAction, UrlBuilder};
|
||||
use oauth2_types::scope::{OPENID, Scope};
|
||||
use rand::{
|
||||
@@ -732,6 +734,7 @@ pub struct ConsentContext {
|
||||
grant: AuthorizationGrant,
|
||||
client: Client,
|
||||
action: PostAuthAction,
|
||||
matrix_user: MatrixUser,
|
||||
}
|
||||
|
||||
impl TemplateContext for ConsentContext {
|
||||
@@ -755,6 +758,10 @@ impl TemplateContext for ConsentContext {
|
||||
grant,
|
||||
client,
|
||||
action,
|
||||
matrix_user: MatrixUser {
|
||||
mxid: "@alice:example.com".to_owned(),
|
||||
display_name: Some("Alice".to_owned()),
|
||||
},
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
@@ -765,12 +772,13 @@ impl TemplateContext for ConsentContext {
|
||||
impl ConsentContext {
|
||||
/// Constructs a context for the client consent page
|
||||
#[must_use]
|
||||
pub fn new(grant: AuthorizationGrant, client: Client) -> Self {
|
||||
pub fn new(grant: AuthorizationGrant, client: Client, matrix_user: MatrixUser) -> Self {
|
||||
let action = PostAuthAction::continue_grant(grant.id);
|
||||
Self {
|
||||
grant,
|
||||
client,
|
||||
action,
|
||||
matrix_user,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -860,11 +868,50 @@ impl PolicyViolationContext {
|
||||
}
|
||||
}
|
||||
|
||||
/// Context used by the `compat_login_policy_violation.html` template
|
||||
#[derive(Serialize)]
|
||||
pub struct CompatLoginPolicyViolationContext {
|
||||
violations: Vec<Violation>,
|
||||
}
|
||||
|
||||
impl TemplateContext for CompatLoginPolicyViolationContext {
|
||||
fn sample<R: Rng>(
|
||||
_now: chrono::DateTime<Utc>,
|
||||
_rng: &mut R,
|
||||
_locales: &[DataLocale],
|
||||
) -> BTreeMap<SampleIdentifier, Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
sample_list(vec![
|
||||
CompatLoginPolicyViolationContext { violations: vec![] },
|
||||
CompatLoginPolicyViolationContext {
|
||||
violations: vec![Violation {
|
||||
msg: "user has too many active sessions".to_owned(),
|
||||
redirect_uri: None,
|
||||
field: None,
|
||||
code: Some(ViolationCode::TooManySessions),
|
||||
}],
|
||||
},
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
impl CompatLoginPolicyViolationContext {
|
||||
/// Constructs a context for the compatibility login policy violation page
|
||||
/// given the list of violations
|
||||
#[must_use]
|
||||
pub const fn for_violations(violations: Vec<Violation>) -> Self {
|
||||
Self { violations }
|
||||
}
|
||||
}
|
||||
|
||||
/// Context used by the `sso.html` template
|
||||
#[derive(Serialize)]
|
||||
pub struct CompatSsoContext {
|
||||
login: CompatSsoLogin,
|
||||
action: PostAuthAction,
|
||||
matrix_user: MatrixUser,
|
||||
}
|
||||
|
||||
impl TemplateContext for CompatSsoContext {
|
||||
@@ -877,23 +924,33 @@ impl TemplateContext for CompatSsoContext {
|
||||
Self: Sized,
|
||||
{
|
||||
let id = Ulid::from_datetime_with_source(now.into(), rng);
|
||||
sample_list(vec![CompatSsoContext::new(CompatSsoLogin {
|
||||
id,
|
||||
redirect_uri: Url::parse("https://app.element.io/").unwrap(),
|
||||
login_token: "abcdefghijklmnopqrstuvwxyz012345".into(),
|
||||
created_at: now,
|
||||
state: CompatSsoLoginState::Pending,
|
||||
})])
|
||||
sample_list(vec![CompatSsoContext::new(
|
||||
CompatSsoLogin {
|
||||
id,
|
||||
redirect_uri: Url::parse("https://app.element.io/").unwrap(),
|
||||
login_token: "abcdefghijklmnopqrstuvwxyz012345".into(),
|
||||
created_at: now,
|
||||
state: CompatSsoLoginState::Pending,
|
||||
},
|
||||
MatrixUser {
|
||||
mxid: "@alice:example.com".to_owned(),
|
||||
display_name: Some("Alice".to_owned()),
|
||||
},
|
||||
)])
|
||||
}
|
||||
}
|
||||
|
||||
impl CompatSsoContext {
|
||||
/// Constructs a context for the legacy SSO login page
|
||||
#[must_use]
|
||||
pub fn new(login: CompatSsoLogin) -> Self
|
||||
pub fn new(login: CompatSsoLogin, matrix_user: MatrixUser) -> Self
|
||||
where {
|
||||
let action = PostAuthAction::continue_compat_sso_login(login.id);
|
||||
Self { login, action }
|
||||
Self {
|
||||
login,
|
||||
action,
|
||||
matrix_user,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1748,13 +1805,18 @@ impl TemplateContext for DeviceLinkContext {
|
||||
pub struct DeviceConsentContext {
|
||||
grant: DeviceCodeGrant,
|
||||
client: Client,
|
||||
matrix_user: MatrixUser,
|
||||
}
|
||||
|
||||
impl DeviceConsentContext {
|
||||
/// Constructs a new context with an existing linked user
|
||||
#[must_use]
|
||||
pub fn new(grant: DeviceCodeGrant, client: Client) -> Self {
|
||||
Self { grant, client }
|
||||
pub fn new(grant: DeviceCodeGrant, client: Client, matrix_user: MatrixUser) -> Self {
|
||||
Self {
|
||||
grant,
|
||||
client,
|
||||
matrix_user,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1782,7 +1844,14 @@ impl TemplateContext for DeviceConsentContext {
|
||||
ip_address: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
|
||||
user_agent: Some("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/93.0.0.0 Safari/537.36".to_owned()),
|
||||
};
|
||||
Self { grant, client }
|
||||
Self {
|
||||
grant,
|
||||
client,
|
||||
matrix_user: MatrixUser {
|
||||
mxid: "@alice:example.com".to_owned(),
|
||||
display_name: Some("Alice".to_owned()),
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ pub fn register(
|
||||
env.add_filter("simplify_url", filter_simplify_url);
|
||||
env.add_filter("add_slashes", filter_add_slashes);
|
||||
env.add_filter("parse_user_agent", filter_parse_user_agent);
|
||||
env.add_filter("id_color_hash", filter_id_color_hash);
|
||||
env.add_function("add_params_to_url", function_add_params_to_url);
|
||||
env.add_function("counter", || Ok(Value::from_object(Counter::default())));
|
||||
if let Some(vite_manifest) = vite_manifest {
|
||||
@@ -138,6 +139,12 @@ fn filter_simplify_url(url: &str, kwargs: Kwargs) -> Result<String, minijinja::E
|
||||
}
|
||||
}
|
||||
|
||||
/// Filter which computes a hash between 1 and 6 of an input string, identitical
|
||||
/// to compound-web's `useIdColorHash`
|
||||
fn filter_id_color_hash(input: &str) -> u32 {
|
||||
input.chars().fold(0, |hash, c| hash + c as u32) % 6 + 1
|
||||
}
|
||||
|
||||
/// Filter which parses a user-agent string
|
||||
fn filter_parse_user_agent(user_agent: String) -> Value {
|
||||
let user_agent = mas_data_model::UserAgent::parse(user_agent);
|
||||
|
||||
@@ -37,14 +37,15 @@ mod macros;
|
||||
|
||||
pub use self::{
|
||||
context::{
|
||||
AccountInactiveContext, ApiDocContext, AppContext, CompatSsoContext, ConsentContext,
|
||||
DeviceConsentContext, DeviceLinkContext, DeviceLinkFormField, DeviceNameContext,
|
||||
EmailRecoveryContext, EmailVerificationContext, EmptyContext, ErrorContext,
|
||||
FormPostContext, IndexContext, LoginContext, LoginFormField, NotFoundContext,
|
||||
PasswordRegisterContext, PolicyViolationContext, PostAuthContext, PostAuthContextInner,
|
||||
RecoveryExpiredContext, RecoveryFinishContext, RecoveryFinishFormField,
|
||||
RecoveryProgressContext, RecoveryStartContext, RecoveryStartFormField, RegisterContext,
|
||||
RegisterFormField, RegisterStepsDisplayNameContext, RegisterStepsDisplayNameFormField,
|
||||
AccountInactiveContext, ApiDocContext, AppContext, CompatLoginPolicyViolationContext,
|
||||
CompatSsoContext, ConsentContext, DeviceConsentContext, DeviceLinkContext,
|
||||
DeviceLinkFormField, DeviceNameContext, EmailRecoveryContext, EmailVerificationContext,
|
||||
EmptyContext, ErrorContext, FormPostContext, IndexContext, LoginContext, LoginFormField,
|
||||
NotFoundContext, PasswordRegisterContext, PolicyViolationContext, PostAuthContext,
|
||||
PostAuthContextInner, RecoveryExpiredContext, RecoveryFinishContext,
|
||||
RecoveryFinishFormField, RecoveryProgressContext, RecoveryStartContext,
|
||||
RecoveryStartFormField, RegisterContext, RegisterFormField,
|
||||
RegisterStepsDisplayNameContext, RegisterStepsDisplayNameFormField,
|
||||
RegisterStepsEmailInUseContext, RegisterStepsRegistrationTokenContext,
|
||||
RegisterStepsRegistrationTokenFormField, RegisterStepsVerifyEmailContext,
|
||||
RegisterStepsVerifyEmailFormField, SiteBranding, SiteConfigExt, SiteFeatures,
|
||||
@@ -391,6 +392,9 @@ register_templates! {
|
||||
/// Render the policy violation page
|
||||
pub fn render_policy_violation(WithLanguage<WithCsrf<WithSession<PolicyViolationContext>>>) { "pages/policy_violation.html" }
|
||||
|
||||
/// Render the compatibility login policy violation page
|
||||
pub fn render_compat_login_policy_violation(WithLanguage<WithCsrf<WithSession<CompatLoginPolicyViolationContext>>>) { "pages/compat_login_policy_violation.html" }
|
||||
|
||||
/// Render the legacy SSO login consent page
|
||||
pub fn render_sso_login(WithLanguage<WithCsrf<WithSession<CompatSsoContext>>>) { "pages/sso.html" }
|
||||
|
||||
|
||||
Reference in New Issue
Block a user