diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index d3c7c979f..f0be20ca9 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -16,6 +16,7 @@ use mas_data_model::{ User, }; use mas_matrix::HomeserverConnection; +use mas_policy::{Policy, Requester, ViolationCode, model::CompatLoginType}; use mas_storage::{ BoxRepository, BoxRepositoryFactory, RepositoryAccess, compat::{ @@ -37,6 +38,7 @@ use crate::{ BoundActivityTracker, Limiter, METER, RequesterFingerprint, impl_from_error_for_route, passwords::{PasswordManager, PasswordVerificationResult}, rate_limit::PasswordCheckLimitedError, + session::count_user_sessions_for_limiting, }; static LOGIN_COUNTER: LazyLock> = LazyLock::new(|| { @@ -213,9 +215,16 @@ pub enum RouteError { #[error("failed to provision device")] ProvisionDeviceFailed(#[source] anyhow::Error), + + #[error("login rejected by policy")] + PolicyRejected, + + #[error("login rejected by policy (hard session limit reached)")] + PolicyHardSessionLimitReached, } impl_from_error_for_route!(mas_storage::RepositoryError); +impl_from_error_for_route!(mas_policy::EvaluationError); impl From for RouteError { fn from(err: anyhow::Error) -> Self { @@ -274,6 +283,16 @@ impl IntoResponse for RouteError { error: "User account has been locked", status: StatusCode::UNAUTHORIZED, }, + Self::PolicyRejected => MatrixError { + errcode: "M_FORBIDDEN", + error: "Login denied by the policy enforced by this service", + status: StatusCode::FORBIDDEN, + }, + Self::PolicyHardSessionLimitReached => MatrixError { + errcode: "M_FORBIDDEN", + error: "You have reached your hard device limit. Please visit your account page to sign some out.", + status: StatusCode::FORBIDDEN, + }, }; (sentry_event_id, response).into_response() @@ -290,6 +309,7 @@ pub(crate) async fn post( State(homeserver): State>, State(site_config): State, State(limiter): State, + mut policy: Policy, requester: RequesterFingerprint, user_agent: Option>, MatrixJsonBody(input): MatrixJsonBody, @@ -329,6 +349,11 @@ pub(crate) async fn post( &limiter, requester, &mut repo, + &mut policy, + Requester { + ip_address: activity_tracker.ip(), + user_agent: user_agent.clone(), + }, username, password, input.device_id, // TODO check for validity @@ -342,6 +367,11 @@ pub(crate) async fn post( &mut rng, &clock, &mut repo, + &mut policy, + Requester { + ip_address: activity_tracker.ip(), + user_agent: user_agent.clone(), + }, &token, input.device_id, input.initial_device_display_name, @@ -459,6 +489,8 @@ async fn token_login( rng: &mut (dyn RngCore + Send), clock: &dyn Clock, repo: &mut BoxRepository, + policy: &mut Policy, + requester: Requester, token: &str, requested_device_id: Option, initial_device_display_name: Option, @@ -548,6 +580,27 @@ async fn token_login( .finish_sessions_to_replace_device(clock, &browser_session.user, &device) .await?; + let session_counts = count_user_sessions_for_limiting(repo, &browser_session.user).await?; + + let res = policy + .evaluate_compat_login(mas_policy::CompatLoginInput { + user: &browser_session.user, + login_type: CompatLoginType::WebSso, + session_counts, + requester, + }) + .await?; + if !res.valid() { + if res.violations.len() == 1 { + let violation = &res.violations[0]; + if violation.code == Some(ViolationCode::TooManySessions) { + // The only violation is having reached the session limit. + return Err(RouteError::PolicyHardSessionLimitReached); + } + } + return Err(RouteError::PolicyRejected); + } + // We first create the session in the database, commit the transaction, then // create it on the homeserver, scheduling a device sync job afterwards to // make sure we don't end up in an inconsistent state. @@ -578,6 +631,8 @@ async fn user_password_login( limiter: &Limiter, requester: RequesterFingerprint, repo: &mut BoxRepository, + policy: &mut Policy, + policy_requester: Requester, username: &str, password: String, requested_device_id: Option, @@ -651,6 +706,27 @@ async fn user_password_login( .finish_sessions_to_replace_device(clock, &user, &device) .await?; + let session_counts = count_user_sessions_for_limiting(repo, &user).await?; + + let res = policy + .evaluate_compat_login(mas_policy::CompatLoginInput { + user: &user, + login_type: CompatLoginType::Password, + session_counts, + requester: policy_requester, + }) + .await?; + if !res.valid() { + if res.violations.len() == 1 { + let violation = &res.violations[0]; + if violation.code == Some(ViolationCode::TooManySessions) { + // The only violation is having reached the session limit. + return Err(RouteError::PolicyHardSessionLimitReached); + } + } + return Err(RouteError::PolicyRejected); + } + let session = repo .compat_session() .add( diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index a4fbb24fb..3c33b5d46 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -11,23 +11,28 @@ use axum::{ extract::{Form, Path, State}, response::{Html, IntoResponse, Redirect, Response}, }; -use axum_extra::extract::Query; +use axum_extra::{TypedHeader, extract::Query}; use chrono::Duration; +use hyper::StatusCode; use mas_axum_utils::{ InternalError, cookies::CookieJar, csrf::{CsrfExt, ProtectedForm}, }; use mas_data_model::{BoxClock, BoxRng, Clock}; +use mas_policy::{Policy, ViolationCode, model::CompatLoginType}; use mas_router::{CompatLoginSsoAction, UrlBuilder}; use mas_storage::{BoxRepository, RepositoryAccess, compat::CompatSsoLoginRepository}; -use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; +use mas_templates::{ + CompatLoginPolicyViolationContext, CompatSsoContext, EmptyContext, ErrorContext, + PolicyViolationContext, TemplateContext, Templates, +}; use serde::{Deserialize, Serialize}; use ulid::Ulid; use crate::{ - PreferredLanguage, - session::{SessionOrFallback, load_session_or_fallback}, + BoundActivityTracker, PreferredLanguage, + session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback}, }; #[derive(Serialize)] @@ -56,10 +61,15 @@ pub async fn get( mut repo: BoxRepository, State(templates): State, State(url_builder): State, + mut policy: Policy, + activity_tracker: BoundActivityTracker, + user_agent: Option>, cookie_jar: CookieJar, Path(id): Path, Query(params): Query, ) -> Result { + let user_agent = user_agent.map(|ua| ua.to_string()); + let (cookie_jar, maybe_session) = match load_session_or_fallback( cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo, ) @@ -107,6 +117,35 @@ pub async fn get( return Ok((cookie_jar, Html(content)).into_response()); } + let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?; + + let res = policy + .evaluate_compat_login(mas_policy::CompatLoginInput { + user: &session.user, + login_type: CompatLoginType::WebSso, + session_counts, + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + user_agent, + }, + }) + .await?; + if !res.valid() { + let ctx = CompatLoginPolicyViolationContext::for_violations( + res.violations + .into_iter() + .filter_map(|v| Some(v.code?.as_str())) + .collect(), + ) + .with_session(session) + .with_csrf(csrf_token.form_value()) + .with_language(locale); + + let content = templates.render_compat_login_policy_violation(&ctx)?; + + return Ok((StatusCode::FORBIDDEN, cookie_jar, Html(content)).into_response()); + } + let ctx = CompatSsoContext::new(login) .with_session(session) .with_csrf(csrf_token.form_value()) @@ -129,11 +168,16 @@ pub async fn post( PreferredLanguage(locale): PreferredLanguage, State(templates): State, State(url_builder): State, + mut policy: Policy, + activity_tracker: BoundActivityTracker, + user_agent: Option>, cookie_jar: CookieJar, Path(id): Path, Query(params): Query, Form(form): Form>, ) -> Result { + let user_agent = user_agent.map(|ua| ua.to_string()); + let (cookie_jar, maybe_session) = match load_session_or_fallback( cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo, ) @@ -200,6 +244,37 @@ pub async fn post( redirect_uri }; + let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?; + + let res = policy + .evaluate_compat_login(mas_policy::CompatLoginInput { + user: &session.user, + login_type: CompatLoginType::WebSso, + session_counts, + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + user_agent, + }, + }) + .await?; + + if !res.valid() { + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); + let ctx = CompatLoginPolicyViolationContext::for_violations( + res.violations + .into_iter() + .filter_map(|v| Some(v.code?.as_str())) + .collect(), + ) + .with_session(session) + .with_csrf(csrf_token.form_value()) + .with_language(locale); + + let content = templates.render_compat_login_policy_violation(&ctx)?; + + return Ok((StatusCode::FORBIDDEN, cookie_jar, Html(content)).into_response()); + } + // Note that if the login is not Pending, // this fails and aborts the transaction. repo.compat_sso_login() diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 65a75f550..ebd223e4a 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -272,6 +272,7 @@ where BoxRepository: FromRequestParts, BoxClock: FromRequestParts, BoxRng: FromRequestParts, + Policy: FromRequestParts, { // A sub-router for human-facing routes with error handling let human_router = Router::new() diff --git a/crates/policy/src/model.rs b/crates/policy/src/model.rs index 85b05d317..81c37a6e5 100644 --- a/crates/policy/src/model.rs +++ b/crates/policy/src/model.rs @@ -17,7 +17,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; /// A well-known policy code. -#[derive(Deserialize, Debug, Clone, Copy, JsonSchema)] +#[derive(Deserialize, Debug, Clone, Copy, JsonSchema, PartialEq, Eq)] #[serde(rename_all = "kebab-case")] pub enum Code { /// The username is too short.