Add experimental and preliminary policy-driven session limiting when logging in compatibility sessions. (#5287)
This commit is contained in:
@@ -145,6 +145,7 @@ pub async fn policy_factory_from_config(
|
||||
register: config.register_entrypoint.clone(),
|
||||
client_registration: config.client_registration_entrypoint.clone(),
|
||||
authorization_grant: config.authorization_grant_entrypoint.clone(),
|
||||
compat_login: config.compat_login_entrypoint.clone(),
|
||||
email: config.email_entrypoint.clone(),
|
||||
};
|
||||
|
||||
|
||||
@@ -62,6 +62,14 @@ fn is_default_password_entrypoint(value: &String) -> bool {
|
||||
*value == default_password_entrypoint()
|
||||
}
|
||||
|
||||
fn default_compat_login_entrypoint() -> String {
|
||||
"compat_login/violation".to_owned()
|
||||
}
|
||||
|
||||
fn is_default_compat_login_entrypoint(value: &String) -> bool {
|
||||
*value == default_compat_login_entrypoint()
|
||||
}
|
||||
|
||||
fn default_email_entrypoint() -> String {
|
||||
"email/violation".to_owned()
|
||||
}
|
||||
@@ -111,6 +119,13 @@ pub struct PolicyConfig {
|
||||
)]
|
||||
pub authorization_grant_entrypoint: String,
|
||||
|
||||
/// Entrypoint to use when evaluating compatibility logins
|
||||
#[serde(
|
||||
default = "default_compat_login_entrypoint",
|
||||
skip_serializing_if = "is_default_compat_login_entrypoint"
|
||||
)]
|
||||
pub compat_login_entrypoint: String,
|
||||
|
||||
/// Entrypoint to use when changing password
|
||||
#[serde(
|
||||
default = "default_password_entrypoint",
|
||||
@@ -137,6 +152,7 @@ impl Default for PolicyConfig {
|
||||
client_registration_entrypoint: default_client_registration_entrypoint(),
|
||||
register_entrypoint: default_register_entrypoint(),
|
||||
authorization_grant_entrypoint: default_authorization_grant_entrypoint(),
|
||||
compat_login_entrypoint: default_compat_login_entrypoint(),
|
||||
password_entrypoint: default_password_entrypoint(),
|
||||
email_entrypoint: default_email_entrypoint(),
|
||||
data: default_data(),
|
||||
|
||||
@@ -16,6 +16,7 @@ use mas_data_model::{
|
||||
User,
|
||||
};
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use mas_policy::{Policy, Requester, ViolationCode, model::CompatLogin};
|
||||
use mas_storage::{
|
||||
BoxRepository, BoxRepositoryFactory, RepositoryAccess,
|
||||
compat::{
|
||||
@@ -37,6 +38,7 @@ use crate::{
|
||||
BoundActivityTracker, Limiter, METER, RequesterFingerprint, impl_from_error_for_route,
|
||||
passwords::{PasswordManager, PasswordVerificationResult},
|
||||
rate_limit::PasswordCheckLimitedError,
|
||||
session::count_user_sessions_for_limiting,
|
||||
};
|
||||
|
||||
static LOGIN_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
|
||||
@@ -213,9 +215,16 @@ pub enum RouteError {
|
||||
|
||||
#[error("failed to provision device")]
|
||||
ProvisionDeviceFailed(#[source] anyhow::Error),
|
||||
|
||||
#[error("login rejected by policy")]
|
||||
PolicyRejected,
|
||||
|
||||
#[error("login rejected by policy (hard session limit reached)")]
|
||||
PolicyHardSessionLimitReached,
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
|
||||
impl From<anyhow::Error> for RouteError {
|
||||
fn from(err: anyhow::Error) -> Self {
|
||||
@@ -274,6 +283,16 @@ impl IntoResponse for RouteError {
|
||||
error: "User account has been locked",
|
||||
status: StatusCode::UNAUTHORIZED,
|
||||
},
|
||||
Self::PolicyRejected => MatrixError {
|
||||
errcode: "M_FORBIDDEN",
|
||||
error: "Login denied by the policy enforced by this service",
|
||||
status: StatusCode::FORBIDDEN,
|
||||
},
|
||||
Self::PolicyHardSessionLimitReached => MatrixError {
|
||||
errcode: "M_FORBIDDEN",
|
||||
error: "You have reached your hard device limit. Please visit your account page to sign some out.",
|
||||
status: StatusCode::FORBIDDEN,
|
||||
},
|
||||
};
|
||||
|
||||
(sentry_event_id, response).into_response()
|
||||
@@ -290,6 +309,7 @@ pub(crate) async fn post(
|
||||
State(homeserver): State<Arc<dyn HomeserverConnection>>,
|
||||
State(site_config): State<SiteConfig>,
|
||||
State(limiter): State<Limiter>,
|
||||
mut policy: Policy,
|
||||
requester: RequesterFingerprint,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
MatrixJsonBody(input): MatrixJsonBody<RequestBody>,
|
||||
@@ -329,6 +349,11 @@ pub(crate) async fn post(
|
||||
&limiter,
|
||||
requester,
|
||||
&mut repo,
|
||||
&mut policy,
|
||||
Requester {
|
||||
ip_address: activity_tracker.ip(),
|
||||
user_agent: user_agent.clone(),
|
||||
},
|
||||
username,
|
||||
password,
|
||||
input.device_id, // TODO check for validity
|
||||
@@ -342,6 +367,11 @@ pub(crate) async fn post(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&mut repo,
|
||||
&mut policy,
|
||||
Requester {
|
||||
ip_address: activity_tracker.ip(),
|
||||
user_agent: user_agent.clone(),
|
||||
},
|
||||
&token,
|
||||
input.device_id,
|
||||
input.initial_device_display_name,
|
||||
@@ -459,6 +489,8 @@ async fn token_login(
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
repo: &mut BoxRepository,
|
||||
policy: &mut Policy,
|
||||
requester: Requester,
|
||||
token: &str,
|
||||
requested_device_id: Option<String>,
|
||||
initial_device_display_name: Option<String>,
|
||||
@@ -544,10 +576,38 @@ async fn token_login(
|
||||
Device::generate(rng)
|
||||
};
|
||||
|
||||
repo.app_session()
|
||||
let session_replaced = repo
|
||||
.app_session()
|
||||
.finish_sessions_to_replace_device(clock, &browser_session.user, &device)
|
||||
.await?;
|
||||
|
||||
let session_counts = count_user_sessions_for_limiting(repo, &browser_session.user).await?;
|
||||
|
||||
let res = policy
|
||||
.evaluate_compat_login(mas_policy::CompatLoginInput {
|
||||
user: &browser_session.user,
|
||||
login: CompatLogin::Token,
|
||||
session_replaced,
|
||||
session_counts,
|
||||
requester,
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
// If the only violation is that we have too many sessions, then handle that
|
||||
// separately.
|
||||
// In the future, we intend to evict some sessions automatically instead. We
|
||||
// don't trigger this if there was some other violation anyway, since that means
|
||||
// that removing a session wouldn't actually unblock the login.
|
||||
if res.violations.len() == 1 {
|
||||
let violation = &res.violations[0];
|
||||
if violation.code == Some(ViolationCode::TooManySessions) {
|
||||
// The only violation is having reached the session limit.
|
||||
return Err(RouteError::PolicyHardSessionLimitReached);
|
||||
}
|
||||
}
|
||||
return Err(RouteError::PolicyRejected);
|
||||
}
|
||||
|
||||
// We first create the session in the database, commit the transaction, then
|
||||
// create it on the homeserver, scheduling a device sync job afterwards to
|
||||
// make sure we don't end up in an inconsistent state.
|
||||
@@ -578,6 +638,8 @@ async fn user_password_login(
|
||||
limiter: &Limiter,
|
||||
requester: RequesterFingerprint,
|
||||
repo: &mut BoxRepository,
|
||||
policy: &mut Policy,
|
||||
policy_requester: Requester,
|
||||
username: &str,
|
||||
password: String,
|
||||
requested_device_id: Option<String>,
|
||||
@@ -647,10 +709,38 @@ async fn user_password_login(
|
||||
Device::generate(&mut rng)
|
||||
};
|
||||
|
||||
repo.app_session()
|
||||
let session_replaced = repo
|
||||
.app_session()
|
||||
.finish_sessions_to_replace_device(clock, &user, &device)
|
||||
.await?;
|
||||
|
||||
let session_counts = count_user_sessions_for_limiting(repo, &user).await?;
|
||||
|
||||
let res = policy
|
||||
.evaluate_compat_login(mas_policy::CompatLoginInput {
|
||||
user: &user,
|
||||
login: CompatLogin::Password,
|
||||
session_replaced,
|
||||
session_counts,
|
||||
requester: policy_requester,
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
// If the only violation is that we have too many sessions, then handle that
|
||||
// separately.
|
||||
// In the future, we intend to evict some sessions automatically instead. We
|
||||
// don't trigger this if there was some other violation anyway, since that means
|
||||
// that removing a session wouldn't actually unblock the login.
|
||||
if res.violations.len() == 1 {
|
||||
let violation = &res.violations[0];
|
||||
if violation.code == Some(ViolationCode::TooManySessions) {
|
||||
// The only violation is having reached the session limit.
|
||||
return Err(RouteError::PolicyHardSessionLimitReached);
|
||||
}
|
||||
}
|
||||
return Err(RouteError::PolicyRejected);
|
||||
}
|
||||
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.add(
|
||||
|
||||
@@ -11,23 +11,27 @@ use axum::{
|
||||
extract::{Form, Path, State},
|
||||
response::{Html, IntoResponse, Redirect, Response},
|
||||
};
|
||||
use axum_extra::extract::Query;
|
||||
use axum_extra::{TypedHeader, extract::Query};
|
||||
use chrono::Duration;
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::{
|
||||
InternalError,
|
||||
cookies::CookieJar,
|
||||
csrf::{CsrfExt, ProtectedForm},
|
||||
};
|
||||
use mas_data_model::{BoxClock, BoxRng, Clock};
|
||||
use mas_policy::{Policy, model::CompatLogin};
|
||||
use mas_router::{CompatLoginSsoAction, UrlBuilder};
|
||||
use mas_storage::{BoxRepository, RepositoryAccess, compat::CompatSsoLoginRepository};
|
||||
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
|
||||
use mas_templates::{
|
||||
CompatLoginPolicyViolationContext, CompatSsoContext, ErrorContext, TemplateContext, Templates,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{
|
||||
PreferredLanguage,
|
||||
session::{SessionOrFallback, load_session_or_fallback},
|
||||
BoundActivityTracker, PreferredLanguage,
|
||||
session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
|
||||
};
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -56,10 +60,15 @@ pub async fn get(
|
||||
mut repo: BoxRepository,
|
||||
State(templates): State<Templates>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
mut policy: Policy,
|
||||
activity_tracker: BoundActivityTracker,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
cookie_jar: CookieJar,
|
||||
Path(id): Path<Ulid>,
|
||||
Query(params): Query<Params>,
|
||||
) -> Result<Response, InternalError> {
|
||||
let user_agent = user_agent.map(|ua| ua.to_string());
|
||||
|
||||
let (cookie_jar, maybe_session) = match load_session_or_fallback(
|
||||
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
|
||||
)
|
||||
@@ -107,6 +116,35 @@ pub async fn get(
|
||||
return Ok((cookie_jar, Html(content)).into_response());
|
||||
}
|
||||
|
||||
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
|
||||
|
||||
let res = policy
|
||||
.evaluate_compat_login(mas_policy::CompatLoginInput {
|
||||
user: &session.user,
|
||||
login: CompatLogin::Sso {
|
||||
redirect_uri: login.redirect_uri.to_string(),
|
||||
},
|
||||
// We don't know if there's going to be a replacement until we received the device ID,
|
||||
// which happens too late.
|
||||
session_replaced: false,
|
||||
session_counts,
|
||||
requester: mas_policy::Requester {
|
||||
ip_address: activity_tracker.ip(),
|
||||
user_agent,
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
let ctx = CompatLoginPolicyViolationContext::for_violations(res.violations)
|
||||
.with_session(session)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
.with_language(locale);
|
||||
|
||||
let content = templates.render_compat_login_policy_violation(&ctx)?;
|
||||
|
||||
return Ok((StatusCode::FORBIDDEN, cookie_jar, Html(content)).into_response());
|
||||
}
|
||||
|
||||
let ctx = CompatSsoContext::new(login)
|
||||
.with_session(session)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
@@ -129,11 +167,16 @@ pub async fn post(
|
||||
PreferredLanguage(locale): PreferredLanguage,
|
||||
State(templates): State<Templates>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
mut policy: Policy,
|
||||
activity_tracker: BoundActivityTracker,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
cookie_jar: CookieJar,
|
||||
Path(id): Path<Ulid>,
|
||||
Query(params): Query<Params>,
|
||||
Form(form): Form<ProtectedForm<()>>,
|
||||
) -> Result<Response, InternalError> {
|
||||
let user_agent = user_agent.map(|ua| ua.to_string());
|
||||
|
||||
let (cookie_jar, maybe_session) = match load_session_or_fallback(
|
||||
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
|
||||
)
|
||||
@@ -200,6 +243,37 @@ pub async fn post(
|
||||
redirect_uri
|
||||
};
|
||||
|
||||
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
|
||||
|
||||
let res = policy
|
||||
.evaluate_compat_login(mas_policy::CompatLoginInput {
|
||||
user: &session.user,
|
||||
login: CompatLogin::Sso {
|
||||
redirect_uri: login.redirect_uri.to_string(),
|
||||
},
|
||||
session_counts,
|
||||
// We don't know if there's going to be a replacement until we received the device ID,
|
||||
// which happens too late.
|
||||
session_replaced: false,
|
||||
requester: mas_policy::Requester {
|
||||
ip_address: activity_tracker.ip(),
|
||||
user_agent,
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
|
||||
if !res.valid() {
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
let ctx = CompatLoginPolicyViolationContext::for_violations(res.violations)
|
||||
.with_session(session)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
.with_language(locale);
|
||||
|
||||
let content = templates.render_compat_login_policy_violation(&ctx)?;
|
||||
|
||||
return Ok((StatusCode::FORBIDDEN, cookie_jar, Html(content)).into_response());
|
||||
}
|
||||
|
||||
// Note that if the login is not Pending,
|
||||
// this fails and aborts the transaction.
|
||||
repo.compat_sso_login()
|
||||
|
||||
@@ -272,6 +272,7 @@ where
|
||||
BoxRepository: FromRequestParts<S>,
|
||||
BoxClock: FromRequestParts<S>,
|
||||
BoxRng: FromRequestParts<S>,
|
||||
Policy: FromRequestParts<S>,
|
||||
{
|
||||
// A sub-router for human-facing routes with error handling
|
||||
let human_router = Router::new()
|
||||
|
||||
@@ -82,6 +82,7 @@ pub(crate) async fn policy_factory(
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
compat_login: "compat_login/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
};
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use mas_policy::model::{
|
||||
AuthorizationGrantInput, ClientRegistrationInput, EmailInput, RegisterInput,
|
||||
AuthorizationGrantInput, ClientRegistrationInput, CompatLoginInput, EmailInput, RegisterInput,
|
||||
};
|
||||
use schemars::{JsonSchema, generate::SchemaSettings};
|
||||
|
||||
@@ -42,5 +42,6 @@ fn main() {
|
||||
write_schema::<RegisterInput>(output_root, "register_input.json");
|
||||
write_schema::<ClientRegistrationInput>(output_root, "client_registration_input.json");
|
||||
write_schema::<AuthorizationGrantInput>(output_root, "authorization_grant_input.json");
|
||||
write_schema::<CompatLoginInput>(output_root, "compat_login_input.json");
|
||||
write_schema::<EmailInput>(output_root, "email_input.json");
|
||||
}
|
||||
|
||||
@@ -19,8 +19,9 @@ use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
|
||||
pub use self::model::{
|
||||
AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, EmailInput,
|
||||
EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
|
||||
AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, CompatLoginInput,
|
||||
EmailInput, EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester,
|
||||
Violation,
|
||||
};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
@@ -72,15 +73,17 @@ pub struct Entrypoints {
|
||||
pub register: String,
|
||||
pub client_registration: String,
|
||||
pub authorization_grant: String,
|
||||
pub compat_login: String,
|
||||
pub email: String,
|
||||
}
|
||||
|
||||
impl Entrypoints {
|
||||
fn all(&self) -> [&str; 4] {
|
||||
fn all(&self) -> [&str; 5] {
|
||||
[
|
||||
self.register.as_str(),
|
||||
self.client_registration.as_str(),
|
||||
self.authorization_grant.as_str(),
|
||||
self.compat_login.as_str(),
|
||||
self.email.as_str(),
|
||||
]
|
||||
}
|
||||
@@ -459,6 +462,30 @@ impl Policy {
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Evaluate the `compat_login` entrypoint.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the policy engine fails to evaluate the entrypoint.
|
||||
#[tracing::instrument(
|
||||
name = "policy.evaluate.compat_login",
|
||||
skip_all,
|
||||
fields(
|
||||
%input.user.id,
|
||||
),
|
||||
)]
|
||||
pub async fn evaluate_compat_login(
|
||||
&mut self,
|
||||
input: CompatLoginInput<'_>,
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(&mut self.store, &self.entrypoints.compat_login, &input)
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -468,6 +495,16 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
fn make_entrypoints() -> Entrypoints {
|
||||
Entrypoints {
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
compat_login: "compat_login/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register() {
|
||||
let data = Data::new("example.com".to_owned(), None).with_rest(serde_json::json!({
|
||||
@@ -484,14 +521,9 @@ mod tests {
|
||||
|
||||
let file = tokio::fs::File::open(path).await.unwrap();
|
||||
|
||||
let entrypoints = Entrypoints {
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
};
|
||||
|
||||
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
|
||||
let factory = PolicyFactory::load(file, data, make_entrypoints())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut policy = factory.instantiate().await.unwrap();
|
||||
|
||||
@@ -551,14 +583,9 @@ mod tests {
|
||||
|
||||
let file = tokio::fs::File::open(path).await.unwrap();
|
||||
|
||||
let entrypoints = Entrypoints {
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
};
|
||||
|
||||
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
|
||||
let factory = PolicyFactory::load(file, data, make_entrypoints())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut policy = factory.instantiate().await.unwrap();
|
||||
|
||||
@@ -620,14 +647,9 @@ mod tests {
|
||||
|
||||
let file = tokio::fs::File::open(path).await.unwrap();
|
||||
|
||||
let entrypoints = Entrypoints {
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
};
|
||||
|
||||
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
|
||||
let factory = PolicyFactory::load(file, data, make_entrypoints())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// That is around 1 MB of JSON data. Each element is a 5-digit string, so 8
|
||||
// characters including the quotes and a comma.
|
||||
|
||||
@@ -17,7 +17,7 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A well-known policy code.
|
||||
#[derive(Deserialize, Debug, Clone, Copy, JsonSchema)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Copy, JsonSchema, PartialEq, Eq)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum Code {
|
||||
/// The username is too short.
|
||||
@@ -75,7 +75,7 @@ impl Code {
|
||||
}
|
||||
|
||||
/// A single violation of a policy.
|
||||
#[derive(Deserialize, Debug, JsonSchema)]
|
||||
#[derive(Serialize, Deserialize, Debug, JsonSchema)]
|
||||
pub struct Violation {
|
||||
pub msg: String,
|
||||
pub redirect_uri: Option<String>,
|
||||
@@ -187,6 +187,42 @@ pub struct AuthorizationGrantInput<'a> {
|
||||
pub requester: Requester,
|
||||
}
|
||||
|
||||
/// Input for the compatibility login policy.
|
||||
#[derive(Serialize, Debug, JsonSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct CompatLoginInput<'a> {
|
||||
#[schemars(with = "std::collections::HashMap<String, serde_json::Value>")]
|
||||
pub user: &'a User,
|
||||
|
||||
/// How many sessions the user has.
|
||||
pub session_counts: SessionCounts,
|
||||
|
||||
/// Whether a session will be replaced by this login
|
||||
pub session_replaced: bool,
|
||||
|
||||
/// What type of login is being performed.
|
||||
/// This also determines whether the login is interactive.
|
||||
pub login: CompatLogin,
|
||||
|
||||
pub requester: Requester,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, JsonSchema)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum CompatLogin {
|
||||
/// Used as the interactive part of SSO login.
|
||||
#[serde(rename = "m.login.sso")]
|
||||
Sso { redirect_uri: String },
|
||||
|
||||
/// Used as the final (non-interactive) stage of SSO login.
|
||||
#[serde(rename = "m.login.token")]
|
||||
Token,
|
||||
|
||||
/// Non-interactive password-over-the-API login.
|
||||
#[serde(rename = "m.login.password")]
|
||||
Password,
|
||||
}
|
||||
|
||||
/// Information about how many sessions the user has
|
||||
#[derive(Serialize, Debug, JsonSchema)]
|
||||
pub struct SessionCounts {
|
||||
|
||||
@@ -487,14 +487,15 @@ impl AppSessionRepository for PgAppSessionRepository<'_> {
|
||||
clock: &dyn Clock,
|
||||
user: &User,
|
||||
device: &Device,
|
||||
) -> Result<(), Self::Error> {
|
||||
) -> Result<bool, Self::Error> {
|
||||
let mut affected = false;
|
||||
// TODO need to invoke this from all the oauth2 login sites
|
||||
let span = tracing::info_span!(
|
||||
"db.app_session.finish_sessions_to_replace_device.compat_sessions",
|
||||
{ DB_QUERY_TEXT } = tracing::field::Empty,
|
||||
);
|
||||
let finished_at = clock.now();
|
||||
sqlx::query!(
|
||||
let compat_affected = sqlx::query!(
|
||||
"
|
||||
UPDATE compat_sessions SET finished_at = $3 WHERE user_id = $1 AND device_id = $2 AND finished_at IS NULL
|
||||
",
|
||||
@@ -505,7 +506,9 @@ impl AppSessionRepository for PgAppSessionRepository<'_> {
|
||||
.record(&span)
|
||||
.execute(&mut *self.conn)
|
||||
.instrument(span)
|
||||
.await?;
|
||||
.await?
|
||||
.rows_affected();
|
||||
affected |= compat_affected > 0;
|
||||
|
||||
if let Ok([stable_device_as_scope_token, unstable_device_as_scope_token]) =
|
||||
device.to_scope_token()
|
||||
@@ -514,7 +517,7 @@ impl AppSessionRepository for PgAppSessionRepository<'_> {
|
||||
"db.app_session.finish_sessions_to_replace_device.oauth2_sessions",
|
||||
{ DB_QUERY_TEXT } = tracing::field::Empty,
|
||||
);
|
||||
sqlx::query!(
|
||||
let oauth2_affected = sqlx::query!(
|
||||
"
|
||||
UPDATE oauth2_sessions
|
||||
SET finished_at = $4
|
||||
@@ -530,10 +533,12 @@ impl AppSessionRepository for PgAppSessionRepository<'_> {
|
||||
.record(&span)
|
||||
.execute(&mut *self.conn)
|
||||
.instrument(span)
|
||||
.await?;
|
||||
.await?
|
||||
.rows_affected();
|
||||
affected |= oauth2_affected > 0;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(affected)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -196,12 +196,14 @@ pub trait AppSessionRepository: Send + Sync {
|
||||
/// replacing a device).
|
||||
///
|
||||
/// Should be called *before* creating a new session for the device.
|
||||
///
|
||||
/// Returns true if a session was finished.
|
||||
async fn finish_sessions_to_replace_device(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
user: &User,
|
||||
device: &Device,
|
||||
) -> Result<(), Self::Error>;
|
||||
) -> Result<bool, Self::Error>;
|
||||
}
|
||||
|
||||
repository_impl!(AppSessionRepository:
|
||||
@@ -218,5 +220,5 @@ repository_impl!(AppSessionRepository:
|
||||
clock: &dyn Clock,
|
||||
user: &User,
|
||||
device: &Device,
|
||||
) -> Result<(), Self::Error>;
|
||||
) -> Result<bool, Self::Error>;
|
||||
);
|
||||
|
||||
@@ -41,6 +41,7 @@ oauth2-types.workspace = true
|
||||
mas-data-model.workspace = true
|
||||
mas-i18n.workspace = true
|
||||
mas-iana.workspace = true
|
||||
mas-policy.workspace = true
|
||||
mas-router.workspace = true
|
||||
mas-spa.workspace = true
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ use mas_data_model::{
|
||||
};
|
||||
use mas_i18n::DataLocale;
|
||||
use mas_iana::jose::JsonWebSignatureAlg;
|
||||
use mas_policy::{Violation, ViolationCode};
|
||||
use mas_router::{Account, GraphQL, PostAuthAction, UrlBuilder};
|
||||
use oauth2_types::scope::{OPENID, Scope};
|
||||
use rand::{
|
||||
@@ -860,6 +861,44 @@ impl PolicyViolationContext {
|
||||
}
|
||||
}
|
||||
|
||||
/// Context used by the `compat_login_policy_violation.html` template
|
||||
#[derive(Serialize)]
|
||||
pub struct CompatLoginPolicyViolationContext {
|
||||
violations: Vec<Violation>,
|
||||
}
|
||||
|
||||
impl TemplateContext for CompatLoginPolicyViolationContext {
|
||||
fn sample<R: Rng>(
|
||||
_now: chrono::DateTime<Utc>,
|
||||
_rng: &mut R,
|
||||
_locales: &[DataLocale],
|
||||
) -> BTreeMap<SampleIdentifier, Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
sample_list(vec![
|
||||
CompatLoginPolicyViolationContext { violations: vec![] },
|
||||
CompatLoginPolicyViolationContext {
|
||||
violations: vec![Violation {
|
||||
msg: "user has too many active sessions".to_owned(),
|
||||
redirect_uri: None,
|
||||
field: None,
|
||||
code: Some(ViolationCode::TooManySessions),
|
||||
}],
|
||||
},
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
impl CompatLoginPolicyViolationContext {
|
||||
/// Constructs a context for the compatibility login policy violation page
|
||||
/// given the list of violations
|
||||
#[must_use]
|
||||
pub const fn for_violations(violations: Vec<Violation>) -> Self {
|
||||
Self { violations }
|
||||
}
|
||||
}
|
||||
|
||||
/// Context used by the `sso.html` template
|
||||
#[derive(Serialize)]
|
||||
pub struct CompatSsoContext {
|
||||
|
||||
@@ -37,14 +37,15 @@ mod macros;
|
||||
|
||||
pub use self::{
|
||||
context::{
|
||||
AccountInactiveContext, ApiDocContext, AppContext, CompatSsoContext, ConsentContext,
|
||||
DeviceConsentContext, DeviceLinkContext, DeviceLinkFormField, DeviceNameContext,
|
||||
EmailRecoveryContext, EmailVerificationContext, EmptyContext, ErrorContext,
|
||||
FormPostContext, IndexContext, LoginContext, LoginFormField, NotFoundContext,
|
||||
PasswordRegisterContext, PolicyViolationContext, PostAuthContext, PostAuthContextInner,
|
||||
RecoveryExpiredContext, RecoveryFinishContext, RecoveryFinishFormField,
|
||||
RecoveryProgressContext, RecoveryStartContext, RecoveryStartFormField, RegisterContext,
|
||||
RegisterFormField, RegisterStepsDisplayNameContext, RegisterStepsDisplayNameFormField,
|
||||
AccountInactiveContext, ApiDocContext, AppContext, CompatLoginPolicyViolationContext,
|
||||
CompatSsoContext, ConsentContext, DeviceConsentContext, DeviceLinkContext,
|
||||
DeviceLinkFormField, DeviceNameContext, EmailRecoveryContext, EmailVerificationContext,
|
||||
EmptyContext, ErrorContext, FormPostContext, IndexContext, LoginContext, LoginFormField,
|
||||
NotFoundContext, PasswordRegisterContext, PolicyViolationContext, PostAuthContext,
|
||||
PostAuthContextInner, RecoveryExpiredContext, RecoveryFinishContext,
|
||||
RecoveryFinishFormField, RecoveryProgressContext, RecoveryStartContext,
|
||||
RecoveryStartFormField, RegisterContext, RegisterFormField,
|
||||
RegisterStepsDisplayNameContext, RegisterStepsDisplayNameFormField,
|
||||
RegisterStepsEmailInUseContext, RegisterStepsRegistrationTokenContext,
|
||||
RegisterStepsRegistrationTokenFormField, RegisterStepsVerifyEmailContext,
|
||||
RegisterStepsVerifyEmailFormField, SiteBranding, SiteConfigExt, SiteFeatures,
|
||||
@@ -391,6 +392,9 @@ register_templates! {
|
||||
/// Render the policy violation page
|
||||
pub fn render_policy_violation(WithLanguage<WithCsrf<WithSession<PolicyViolationContext>>>) { "pages/policy_violation.html" }
|
||||
|
||||
/// Render the compatibility login policy violation page
|
||||
pub fn render_compat_login_policy_violation(WithLanguage<WithCsrf<WithSession<CompatLoginPolicyViolationContext>>>) { "pages/compat_login_policy_violation.html" }
|
||||
|
||||
/// Render the legacy SSO login consent page
|
||||
pub fn render_sso_login(WithLanguage<WithCsrf<WithSession<CompatSsoContext>>>) { "pages/sso.html" }
|
||||
|
||||
|
||||
Reference in New Issue
Block a user