From add9650e10c79461484703a0d2d3961c521f2fd1 Mon Sep 17 00:00:00 2001 From: Olivier 'reivilibre Date: Fri, 13 Feb 2026 11:49:06 +0000 Subject: [PATCH 1/2] Convert ViolationCode into ViolationVariant to allow adding fields on each variant --- crates/handlers/src/compat/login.rs | 6 +++--- crates/handlers/src/upstream_oauth2/link.rs | 4 ++-- crates/handlers/src/views/register/password.rs | 8 ++++---- crates/policy/src/lib.rs | 6 +++--- crates/policy/src/model.rs | 13 ++++++++----- crates/templates/src/context.rs | 4 ++-- 6 files changed, 22 insertions(+), 19 deletions(-) diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index ebb5d32c1..1c3710644 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -16,7 +16,7 @@ use mas_data_model::{ User, }; use mas_matrix::HomeserverConnection; -use mas_policy::{Policy, Requester, ViolationCode, model::CompatLogin}; +use mas_policy::{Policy, Requester, ViolationVariant, model::CompatLogin}; use mas_storage::{ BoxRepository, BoxRepositoryFactory, RepositoryAccess, compat::{ @@ -605,7 +605,7 @@ async fn token_login( // that removing a session wouldn't actually unblock the login. if res.violations.len() == 1 { let violation = &res.violations[0]; - if violation.code == Some(ViolationCode::TooManySessions) { + if violation.variant == Some(ViolationVariant::TooManySessions) { // The only violation is having reached the session limit. return Err(RouteError::PolicyHardSessionLimitReached); } @@ -738,7 +738,7 @@ async fn user_password_login( // that removing a session wouldn't actually unblock the login. if res.violations.len() == 1 { let violation = &res.violations[0]; - if violation.code == Some(ViolationCode::TooManySessions) { + if violation.variant == Some(ViolationVariant::TooManySessions) { // The only violation is having reached the session limit. return Err(RouteError::PolicyHardSessionLimitReached); } diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 357a7a72f..4ed536ad0 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -1108,13 +1108,13 @@ pub(crate) async fn post( form_state.add_error_on_field( mas_templates::UpstreamRegisterFormField::Username, FieldError::Policy { - code: violation.code.map(|c| c.as_str()), + code: violation.variant.map(|c| c.as_str()), message: violation.msg, }, ); } _ => form_state.add_error_on_form(FormError::Policy { - code: violation.code.map(|c| c.as_str()), + code: violation.variant.map(|c| c.as_str()), message: violation.msg, }), } diff --git a/crates/handlers/src/views/register/password.rs b/crates/handlers/src/views/register/password.rs index 65ba5fe0d..39643e972 100644 --- a/crates/handlers/src/views/register/password.rs +++ b/crates/handlers/src/views/register/password.rs @@ -263,7 +263,7 @@ pub(crate) async fn post( Some("email") => state.add_error_on_field( RegisterFormField::Email, FieldError::Policy { - code: violation.code.map(|c| c.as_str()), + code: violation.variant.map(|c| c.as_str()), message: violation.msg, }, ), @@ -274,7 +274,7 @@ pub(crate) async fn post( state.add_error_on_field( RegisterFormField::Username, FieldError::Policy { - code: violation.code.map(|c| c.as_str()), + code: violation.variant.map(|c| c.as_str()), message: violation.msg, }, ); @@ -282,12 +282,12 @@ pub(crate) async fn post( Some("password") => state.add_error_on_field( RegisterFormField::Password, FieldError::Policy { - code: violation.code.map(|c| c.as_str()), + code: violation.variant.map(|c| c.as_str()), message: violation.msg, }, ), _ => state.add_error_on_form(FormError::Policy { - code: violation.code.map(|c| c.as_str()), + code: violation.variant.map(|c| c.as_str()), message: violation.msg, }), } diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index dcb68dd36..a5d4805ad 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -19,9 +19,9 @@ use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt}; pub use self::model::{ - AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, CompatLoginInput, - EmailInput, EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, - Violation, + AuthorizationGrantInput, ClientRegistrationInput, CompatLoginInput, EmailInput, + EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation, + ViolationVariant, }; #[derive(Debug, Error)] diff --git a/crates/policy/src/model.rs b/crates/policy/src/model.rs index a9f5fb502..8ba780297 100644 --- a/crates/policy/src/model.rs +++ b/crates/policy/src/model.rs @@ -16,10 +16,11 @@ use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -/// A well-known policy code. +/// Violation variants identified by a well-known policy code (under the `code` +/// key). #[derive(Serialize, Deserialize, Debug, Clone, Copy, JsonSchema, PartialEq, Eq)] -#[serde(rename_all = "kebab-case")] -pub enum Code { +#[serde(tag = "code", rename_all = "kebab-case")] +pub enum ViolationVariant { /// The username is too short. UsernameTooShort, @@ -54,7 +55,7 @@ pub enum Code { TooManySessions, } -impl Code { +impl ViolationVariant { /// Returns the code as a string #[must_use] pub fn as_str(&self) -> &'static str { @@ -80,7 +81,9 @@ pub struct Violation { pub msg: String, pub redirect_uri: Option, pub field: Option, - pub code: Option, + + #[serde(flatten)] + pub variant: Option, } /// The result of a policy evaluation. diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index d43556fae..25123970b 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -29,7 +29,7 @@ use mas_data_model::{ }; use mas_i18n::DataLocale; use mas_iana::jose::JsonWebSignatureAlg; -use mas_policy::{Violation, ViolationCode}; +use mas_policy::{Violation, ViolationVariant}; use mas_router::{Account, GraphQL, PostAuthAction, UrlBuilder}; use oauth2_types::scope::{OPENID, Scope}; use rand::{ @@ -890,7 +890,7 @@ impl TemplateContext for CompatLoginPolicyViolationContext { msg: "user has too many active sessions".to_owned(), redirect_uri: None, field: None, - code: Some(ViolationCode::TooManySessions), + variant: Some(ViolationVariant::TooManySessions), }], }, ]) From e4b08f90064103293a5d7172d4e5e13609fef253 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 25 Mar 2026 13:12:04 -0500 Subject: [PATCH 2/2] Explain `code` being splatted See https://github.com/element-hq/matrix-authentication-service/pull/5553#discussion_r2984636426 --- crates/policy/src/model.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/crates/policy/src/model.rs b/crates/policy/src/model.rs index 8ba780297..a3bf24b5f 100644 --- a/crates/policy/src/model.rs +++ b/crates/policy/src/model.rs @@ -82,6 +82,10 @@ pub struct Violation { pub redirect_uri: Option, pub field: Option, + // We flatten as policies expect `code` as another top-level field. + // + // This also means all of the extra fields from the variant will be splatted at this + // level which is fine (arbitrary). #[serde(flatten)] pub variant: Option, }