Enforce policy on compat login

This commit is contained in:
Olivier 'reivilibre
2025-11-25 18:20:14 +00:00
parent 31c3fe2b39
commit 985ea0b30a
4 changed files with 157 additions and 5 deletions

View File

@@ -16,6 +16,7 @@ use mas_data_model::{
User,
};
use mas_matrix::HomeserverConnection;
use mas_policy::{Policy, Requester, ViolationCode, model::CompatLoginType};
use mas_storage::{
BoxRepository, BoxRepositoryFactory, RepositoryAccess,
compat::{
@@ -37,6 +38,7 @@ use crate::{
BoundActivityTracker, Limiter, METER, RequesterFingerprint, impl_from_error_for_route,
passwords::{PasswordManager, PasswordVerificationResult},
rate_limit::PasswordCheckLimitedError,
session::count_user_sessions_for_limiting,
};
static LOGIN_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
@@ -213,9 +215,16 @@ pub enum RouteError {
#[error("failed to provision device")]
ProvisionDeviceFailed(#[source] anyhow::Error),
#[error("login rejected by policy")]
PolicyRejected,
#[error("login rejected by policy (hard session limit reached)")]
PolicyHardSessionLimitReached,
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl From<anyhow::Error> for RouteError {
fn from(err: anyhow::Error) -> Self {
@@ -274,6 +283,16 @@ impl IntoResponse for RouteError {
error: "User account has been locked",
status: StatusCode::UNAUTHORIZED,
},
Self::PolicyRejected => MatrixError {
errcode: "M_FORBIDDEN",
error: "Login denied by the policy enforced by this service",
status: StatusCode::FORBIDDEN,
},
Self::PolicyHardSessionLimitReached => MatrixError {
errcode: "M_FORBIDDEN",
error: "You have reached your hard device limit. Please visit your account page to sign some out.",
status: StatusCode::FORBIDDEN,
},
};
(sentry_event_id, response).into_response()
@@ -290,6 +309,7 @@ pub(crate) async fn post(
State(homeserver): State<Arc<dyn HomeserverConnection>>,
State(site_config): State<SiteConfig>,
State(limiter): State<Limiter>,
mut policy: Policy,
requester: RequesterFingerprint,
user_agent: Option<TypedHeader<headers::UserAgent>>,
MatrixJsonBody(input): MatrixJsonBody<RequestBody>,
@@ -329,6 +349,11 @@ pub(crate) async fn post(
&limiter,
requester,
&mut repo,
&mut policy,
Requester {
ip_address: activity_tracker.ip(),
user_agent: user_agent.clone(),
},
username,
password,
input.device_id, // TODO check for validity
@@ -342,6 +367,11 @@ pub(crate) async fn post(
&mut rng,
&clock,
&mut repo,
&mut policy,
Requester {
ip_address: activity_tracker.ip(),
user_agent: user_agent.clone(),
},
&token,
input.device_id,
input.initial_device_display_name,
@@ -459,6 +489,8 @@ async fn token_login(
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
repo: &mut BoxRepository,
policy: &mut Policy,
requester: Requester,
token: &str,
requested_device_id: Option<String>,
initial_device_display_name: Option<String>,
@@ -548,6 +580,27 @@ async fn token_login(
.finish_sessions_to_replace_device(clock, &browser_session.user, &device)
.await?;
let session_counts = count_user_sessions_for_limiting(repo, &browser_session.user).await?;
let res = policy
.evaluate_compat_login(mas_policy::CompatLoginInput {
user: &browser_session.user,
login_type: CompatLoginType::WebSso,
session_counts,
requester,
})
.await?;
if !res.valid() {
if res.violations.len() == 1 {
let violation = &res.violations[0];
if violation.code == Some(ViolationCode::TooManySessions) {
// The only violation is having reached the session limit.
return Err(RouteError::PolicyHardSessionLimitReached);
}
}
return Err(RouteError::PolicyRejected);
}
// We first create the session in the database, commit the transaction, then
// create it on the homeserver, scheduling a device sync job afterwards to
// make sure we don't end up in an inconsistent state.
@@ -578,6 +631,8 @@ async fn user_password_login(
limiter: &Limiter,
requester: RequesterFingerprint,
repo: &mut BoxRepository,
policy: &mut Policy,
policy_requester: Requester,
username: &str,
password: String,
requested_device_id: Option<String>,
@@ -651,6 +706,27 @@ async fn user_password_login(
.finish_sessions_to_replace_device(clock, &user, &device)
.await?;
let session_counts = count_user_sessions_for_limiting(repo, &user).await?;
let res = policy
.evaluate_compat_login(mas_policy::CompatLoginInput {
user: &user,
login_type: CompatLoginType::Password,
session_counts,
requester: policy_requester,
})
.await?;
if !res.valid() {
if res.violations.len() == 1 {
let violation = &res.violations[0];
if violation.code == Some(ViolationCode::TooManySessions) {
// The only violation is having reached the session limit.
return Err(RouteError::PolicyHardSessionLimitReached);
}
}
return Err(RouteError::PolicyRejected);
}
let session = repo
.compat_session()
.add(

View File

@@ -11,23 +11,28 @@ use axum::{
extract::{Form, Path, State},
response::{Html, IntoResponse, Redirect, Response},
};
use axum_extra::extract::Query;
use axum_extra::{TypedHeader, extract::Query};
use chrono::Duration;
use hyper::StatusCode;
use mas_axum_utils::{
InternalError,
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
};
use mas_data_model::{BoxClock, BoxRng, Clock};
use mas_policy::{Policy, ViolationCode, model::CompatLoginType};
use mas_router::{CompatLoginSsoAction, UrlBuilder};
use mas_storage::{BoxRepository, RepositoryAccess, compat::CompatSsoLoginRepository};
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
use mas_templates::{
CompatLoginPolicyViolationContext, CompatSsoContext, EmptyContext, ErrorContext,
PolicyViolationContext, TemplateContext, Templates,
};
use serde::{Deserialize, Serialize};
use ulid::Ulid;
use crate::{
PreferredLanguage,
session::{SessionOrFallback, load_session_or_fallback},
BoundActivityTracker, PreferredLanguage,
session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
};
#[derive(Serialize)]
@@ -56,10 +61,15 @@ pub async fn get(
mut repo: BoxRepository,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
mut policy: Policy,
activity_tracker: BoundActivityTracker,
user_agent: Option<TypedHeader<headers::UserAgent>>,
cookie_jar: CookieJar,
Path(id): Path<Ulid>,
Query(params): Query<Params>,
) -> Result<Response, InternalError> {
let user_agent = user_agent.map(|ua| ua.to_string());
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
@@ -107,6 +117,35 @@ pub async fn get(
return Ok((cookie_jar, Html(content)).into_response());
}
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
let res = policy
.evaluate_compat_login(mas_policy::CompatLoginInput {
user: &session.user,
login_type: CompatLoginType::WebSso,
session_counts,
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
user_agent,
},
})
.await?;
if !res.valid() {
let ctx = CompatLoginPolicyViolationContext::for_violations(
res.violations
.into_iter()
.filter_map(|v| Some(v.code?.as_str()))
.collect(),
)
.with_session(session)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let content = templates.render_compat_login_policy_violation(&ctx)?;
return Ok((StatusCode::FORBIDDEN, cookie_jar, Html(content)).into_response());
}
let ctx = CompatSsoContext::new(login)
.with_session(session)
.with_csrf(csrf_token.form_value())
@@ -129,11 +168,16 @@ pub async fn post(
PreferredLanguage(locale): PreferredLanguage,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
mut policy: Policy,
activity_tracker: BoundActivityTracker,
user_agent: Option<TypedHeader<headers::UserAgent>>,
cookie_jar: CookieJar,
Path(id): Path<Ulid>,
Query(params): Query<Params>,
Form(form): Form<ProtectedForm<()>>,
) -> Result<Response, InternalError> {
let user_agent = user_agent.map(|ua| ua.to_string());
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
@@ -200,6 +244,37 @@ pub async fn post(
redirect_uri
};
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
let res = policy
.evaluate_compat_login(mas_policy::CompatLoginInput {
user: &session.user,
login_type: CompatLoginType::WebSso,
session_counts,
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
user_agent,
},
})
.await?;
if !res.valid() {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let ctx = CompatLoginPolicyViolationContext::for_violations(
res.violations
.into_iter()
.filter_map(|v| Some(v.code?.as_str()))
.collect(),
)
.with_session(session)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let content = templates.render_compat_login_policy_violation(&ctx)?;
return Ok((StatusCode::FORBIDDEN, cookie_jar, Html(content)).into_response());
}
// Note that if the login is not Pending,
// this fails and aborts the transaction.
repo.compat_sso_login()

View File

@@ -272,6 +272,7 @@ where
BoxRepository: FromRequestParts<S>,
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,
Policy: FromRequestParts<S>,
{
// A sub-router for human-facing routes with error handling
let human_router = Router::new()

View File

@@ -17,7 +17,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
/// A well-known policy code.
#[derive(Deserialize, Debug, Clone, Copy, JsonSchema)]
#[derive(Deserialize, Debug, Clone, Copy, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum Code {
/// The username is too short.