Enforce policy on compat login
This commit is contained in:
@@ -16,6 +16,7 @@ use mas_data_model::{
|
||||
User,
|
||||
};
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use mas_policy::{Policy, Requester, ViolationCode, model::CompatLoginType};
|
||||
use mas_storage::{
|
||||
BoxRepository, BoxRepositoryFactory, RepositoryAccess,
|
||||
compat::{
|
||||
@@ -37,6 +38,7 @@ use crate::{
|
||||
BoundActivityTracker, Limiter, METER, RequesterFingerprint, impl_from_error_for_route,
|
||||
passwords::{PasswordManager, PasswordVerificationResult},
|
||||
rate_limit::PasswordCheckLimitedError,
|
||||
session::count_user_sessions_for_limiting,
|
||||
};
|
||||
|
||||
static LOGIN_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
|
||||
@@ -213,9 +215,16 @@ pub enum RouteError {
|
||||
|
||||
#[error("failed to provision device")]
|
||||
ProvisionDeviceFailed(#[source] anyhow::Error),
|
||||
|
||||
#[error("login rejected by policy")]
|
||||
PolicyRejected,
|
||||
|
||||
#[error("login rejected by policy (hard session limit reached)")]
|
||||
PolicyHardSessionLimitReached,
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
|
||||
impl From<anyhow::Error> for RouteError {
|
||||
fn from(err: anyhow::Error) -> Self {
|
||||
@@ -274,6 +283,16 @@ impl IntoResponse for RouteError {
|
||||
error: "User account has been locked",
|
||||
status: StatusCode::UNAUTHORIZED,
|
||||
},
|
||||
Self::PolicyRejected => MatrixError {
|
||||
errcode: "M_FORBIDDEN",
|
||||
error: "Login denied by the policy enforced by this service",
|
||||
status: StatusCode::FORBIDDEN,
|
||||
},
|
||||
Self::PolicyHardSessionLimitReached => MatrixError {
|
||||
errcode: "M_FORBIDDEN",
|
||||
error: "You have reached your hard device limit. Please visit your account page to sign some out.",
|
||||
status: StatusCode::FORBIDDEN,
|
||||
},
|
||||
};
|
||||
|
||||
(sentry_event_id, response).into_response()
|
||||
@@ -290,6 +309,7 @@ pub(crate) async fn post(
|
||||
State(homeserver): State<Arc<dyn HomeserverConnection>>,
|
||||
State(site_config): State<SiteConfig>,
|
||||
State(limiter): State<Limiter>,
|
||||
mut policy: Policy,
|
||||
requester: RequesterFingerprint,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
MatrixJsonBody(input): MatrixJsonBody<RequestBody>,
|
||||
@@ -329,6 +349,11 @@ pub(crate) async fn post(
|
||||
&limiter,
|
||||
requester,
|
||||
&mut repo,
|
||||
&mut policy,
|
||||
Requester {
|
||||
ip_address: activity_tracker.ip(),
|
||||
user_agent: user_agent.clone(),
|
||||
},
|
||||
username,
|
||||
password,
|
||||
input.device_id, // TODO check for validity
|
||||
@@ -342,6 +367,11 @@ pub(crate) async fn post(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&mut repo,
|
||||
&mut policy,
|
||||
Requester {
|
||||
ip_address: activity_tracker.ip(),
|
||||
user_agent: user_agent.clone(),
|
||||
},
|
||||
&token,
|
||||
input.device_id,
|
||||
input.initial_device_display_name,
|
||||
@@ -459,6 +489,8 @@ async fn token_login(
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
repo: &mut BoxRepository,
|
||||
policy: &mut Policy,
|
||||
requester: Requester,
|
||||
token: &str,
|
||||
requested_device_id: Option<String>,
|
||||
initial_device_display_name: Option<String>,
|
||||
@@ -548,6 +580,27 @@ async fn token_login(
|
||||
.finish_sessions_to_replace_device(clock, &browser_session.user, &device)
|
||||
.await?;
|
||||
|
||||
let session_counts = count_user_sessions_for_limiting(repo, &browser_session.user).await?;
|
||||
|
||||
let res = policy
|
||||
.evaluate_compat_login(mas_policy::CompatLoginInput {
|
||||
user: &browser_session.user,
|
||||
login_type: CompatLoginType::WebSso,
|
||||
session_counts,
|
||||
requester,
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
if res.violations.len() == 1 {
|
||||
let violation = &res.violations[0];
|
||||
if violation.code == Some(ViolationCode::TooManySessions) {
|
||||
// The only violation is having reached the session limit.
|
||||
return Err(RouteError::PolicyHardSessionLimitReached);
|
||||
}
|
||||
}
|
||||
return Err(RouteError::PolicyRejected);
|
||||
}
|
||||
|
||||
// We first create the session in the database, commit the transaction, then
|
||||
// create it on the homeserver, scheduling a device sync job afterwards to
|
||||
// make sure we don't end up in an inconsistent state.
|
||||
@@ -578,6 +631,8 @@ async fn user_password_login(
|
||||
limiter: &Limiter,
|
||||
requester: RequesterFingerprint,
|
||||
repo: &mut BoxRepository,
|
||||
policy: &mut Policy,
|
||||
policy_requester: Requester,
|
||||
username: &str,
|
||||
password: String,
|
||||
requested_device_id: Option<String>,
|
||||
@@ -651,6 +706,27 @@ async fn user_password_login(
|
||||
.finish_sessions_to_replace_device(clock, &user, &device)
|
||||
.await?;
|
||||
|
||||
let session_counts = count_user_sessions_for_limiting(repo, &user).await?;
|
||||
|
||||
let res = policy
|
||||
.evaluate_compat_login(mas_policy::CompatLoginInput {
|
||||
user: &user,
|
||||
login_type: CompatLoginType::Password,
|
||||
session_counts,
|
||||
requester: policy_requester,
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
if res.violations.len() == 1 {
|
||||
let violation = &res.violations[0];
|
||||
if violation.code == Some(ViolationCode::TooManySessions) {
|
||||
// The only violation is having reached the session limit.
|
||||
return Err(RouteError::PolicyHardSessionLimitReached);
|
||||
}
|
||||
}
|
||||
return Err(RouteError::PolicyRejected);
|
||||
}
|
||||
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.add(
|
||||
|
||||
@@ -11,23 +11,28 @@ use axum::{
|
||||
extract::{Form, Path, State},
|
||||
response::{Html, IntoResponse, Redirect, Response},
|
||||
};
|
||||
use axum_extra::extract::Query;
|
||||
use axum_extra::{TypedHeader, extract::Query};
|
||||
use chrono::Duration;
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::{
|
||||
InternalError,
|
||||
cookies::CookieJar,
|
||||
csrf::{CsrfExt, ProtectedForm},
|
||||
};
|
||||
use mas_data_model::{BoxClock, BoxRng, Clock};
|
||||
use mas_policy::{Policy, ViolationCode, model::CompatLoginType};
|
||||
use mas_router::{CompatLoginSsoAction, UrlBuilder};
|
||||
use mas_storage::{BoxRepository, RepositoryAccess, compat::CompatSsoLoginRepository};
|
||||
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
|
||||
use mas_templates::{
|
||||
CompatLoginPolicyViolationContext, CompatSsoContext, EmptyContext, ErrorContext,
|
||||
PolicyViolationContext, TemplateContext, Templates,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{
|
||||
PreferredLanguage,
|
||||
session::{SessionOrFallback, load_session_or_fallback},
|
||||
BoundActivityTracker, PreferredLanguage,
|
||||
session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
|
||||
};
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -56,10 +61,15 @@ pub async fn get(
|
||||
mut repo: BoxRepository,
|
||||
State(templates): State<Templates>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
mut policy: Policy,
|
||||
activity_tracker: BoundActivityTracker,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
cookie_jar: CookieJar,
|
||||
Path(id): Path<Ulid>,
|
||||
Query(params): Query<Params>,
|
||||
) -> Result<Response, InternalError> {
|
||||
let user_agent = user_agent.map(|ua| ua.to_string());
|
||||
|
||||
let (cookie_jar, maybe_session) = match load_session_or_fallback(
|
||||
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
|
||||
)
|
||||
@@ -107,6 +117,35 @@ pub async fn get(
|
||||
return Ok((cookie_jar, Html(content)).into_response());
|
||||
}
|
||||
|
||||
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
|
||||
|
||||
let res = policy
|
||||
.evaluate_compat_login(mas_policy::CompatLoginInput {
|
||||
user: &session.user,
|
||||
login_type: CompatLoginType::WebSso,
|
||||
session_counts,
|
||||
requester: mas_policy::Requester {
|
||||
ip_address: activity_tracker.ip(),
|
||||
user_agent,
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
let ctx = CompatLoginPolicyViolationContext::for_violations(
|
||||
res.violations
|
||||
.into_iter()
|
||||
.filter_map(|v| Some(v.code?.as_str()))
|
||||
.collect(),
|
||||
)
|
||||
.with_session(session)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
.with_language(locale);
|
||||
|
||||
let content = templates.render_compat_login_policy_violation(&ctx)?;
|
||||
|
||||
return Ok((StatusCode::FORBIDDEN, cookie_jar, Html(content)).into_response());
|
||||
}
|
||||
|
||||
let ctx = CompatSsoContext::new(login)
|
||||
.with_session(session)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
@@ -129,11 +168,16 @@ pub async fn post(
|
||||
PreferredLanguage(locale): PreferredLanguage,
|
||||
State(templates): State<Templates>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
mut policy: Policy,
|
||||
activity_tracker: BoundActivityTracker,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
cookie_jar: CookieJar,
|
||||
Path(id): Path<Ulid>,
|
||||
Query(params): Query<Params>,
|
||||
Form(form): Form<ProtectedForm<()>>,
|
||||
) -> Result<Response, InternalError> {
|
||||
let user_agent = user_agent.map(|ua| ua.to_string());
|
||||
|
||||
let (cookie_jar, maybe_session) = match load_session_or_fallback(
|
||||
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
|
||||
)
|
||||
@@ -200,6 +244,37 @@ pub async fn post(
|
||||
redirect_uri
|
||||
};
|
||||
|
||||
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
|
||||
|
||||
let res = policy
|
||||
.evaluate_compat_login(mas_policy::CompatLoginInput {
|
||||
user: &session.user,
|
||||
login_type: CompatLoginType::WebSso,
|
||||
session_counts,
|
||||
requester: mas_policy::Requester {
|
||||
ip_address: activity_tracker.ip(),
|
||||
user_agent,
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
|
||||
if !res.valid() {
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
let ctx = CompatLoginPolicyViolationContext::for_violations(
|
||||
res.violations
|
||||
.into_iter()
|
||||
.filter_map(|v| Some(v.code?.as_str()))
|
||||
.collect(),
|
||||
)
|
||||
.with_session(session)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
.with_language(locale);
|
||||
|
||||
let content = templates.render_compat_login_policy_violation(&ctx)?;
|
||||
|
||||
return Ok((StatusCode::FORBIDDEN, cookie_jar, Html(content)).into_response());
|
||||
}
|
||||
|
||||
// Note that if the login is not Pending,
|
||||
// this fails and aborts the transaction.
|
||||
repo.compat_sso_login()
|
||||
|
||||
@@ -272,6 +272,7 @@ where
|
||||
BoxRepository: FromRequestParts<S>,
|
||||
BoxClock: FromRequestParts<S>,
|
||||
BoxRng: FromRequestParts<S>,
|
||||
Policy: FromRequestParts<S>,
|
||||
{
|
||||
// A sub-router for human-facing routes with error handling
|
||||
let human_router = Router::new()
|
||||
|
||||
@@ -17,7 +17,7 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A well-known policy code.
|
||||
#[derive(Deserialize, Debug, Clone, Copy, JsonSchema)]
|
||||
#[derive(Deserialize, Debug, Clone, Copy, JsonSchema, PartialEq, Eq)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum Code {
|
||||
/// The username is too short.
|
||||
|
||||
Reference in New Issue
Block a user