Convert ViolationCode into ViolationVariant to allow adding fields on each variant (#5553)

Concretely, for the [device](https://github.com/element-hq/matrix-authentication-service/issues/4339) [limiting](https://github.com/element-hq/backend-internal/issues/199) stuff, we'll need the policy to be able to tell MAS how many sessions should be removed before the user can log in.
This way the UI will be able to guide the user through it.

Intended to look something like:
```rust
ViolationVariant::TooManySessions {
    /// How many devices to remove
    need_to_remove: u32
}
```
This commit is contained in:
Eric Eastwood
2026-03-25 17:23:29 -05:00
committed by GitHub
6 changed files with 26 additions and 19 deletions

View File

@@ -16,7 +16,7 @@ use mas_data_model::{
User,
};
use mas_matrix::HomeserverConnection;
use mas_policy::{Policy, Requester, ViolationCode, model::CompatLogin};
use mas_policy::{Policy, Requester, ViolationVariant, model::CompatLogin};
use mas_storage::{
BoxRepository, BoxRepositoryFactory, RepositoryAccess,
compat::{
@@ -605,7 +605,7 @@ async fn token_login(
// that removing a session wouldn't actually unblock the login.
if res.violations.len() == 1 {
let violation = &res.violations[0];
if violation.code == Some(ViolationCode::TooManySessions) {
if violation.variant == Some(ViolationVariant::TooManySessions) {
// The only violation is having reached the session limit.
return Err(RouteError::PolicyHardSessionLimitReached);
}
@@ -738,7 +738,7 @@ async fn user_password_login(
// that removing a session wouldn't actually unblock the login.
if res.violations.len() == 1 {
let violation = &res.violations[0];
if violation.code == Some(ViolationCode::TooManySessions) {
if violation.variant == Some(ViolationVariant::TooManySessions) {
// The only violation is having reached the session limit.
return Err(RouteError::PolicyHardSessionLimitReached);
}

View File

@@ -1108,13 +1108,13 @@ pub(crate) async fn post(
form_state.add_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Policy {
code: violation.code.map(|c| c.as_str()),
code: violation.variant.map(|c| c.as_str()),
message: violation.msg,
},
);
}
_ => form_state.add_error_on_form(FormError::Policy {
code: violation.code.map(|c| c.as_str()),
code: violation.variant.map(|c| c.as_str()),
message: violation.msg,
}),
}

View File

@@ -263,7 +263,7 @@ pub(crate) async fn post(
Some("email") => state.add_error_on_field(
RegisterFormField::Email,
FieldError::Policy {
code: violation.code.map(|c| c.as_str()),
code: violation.variant.map(|c| c.as_str()),
message: violation.msg,
},
),
@@ -274,7 +274,7 @@ pub(crate) async fn post(
state.add_error_on_field(
RegisterFormField::Username,
FieldError::Policy {
code: violation.code.map(|c| c.as_str()),
code: violation.variant.map(|c| c.as_str()),
message: violation.msg,
},
);
@@ -282,12 +282,12 @@ pub(crate) async fn post(
Some("password") => state.add_error_on_field(
RegisterFormField::Password,
FieldError::Policy {
code: violation.code.map(|c| c.as_str()),
code: violation.variant.map(|c| c.as_str()),
message: violation.msg,
},
),
_ => state.add_error_on_form(FormError::Policy {
code: violation.code.map(|c| c.as_str()),
code: violation.variant.map(|c| c.as_str()),
message: violation.msg,
}),
}

View File

@@ -19,9 +19,9 @@ use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt};
pub use self::model::{
AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, CompatLoginInput,
EmailInput, EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester,
Violation,
AuthorizationGrantInput, ClientRegistrationInput, CompatLoginInput, EmailInput,
EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
ViolationVariant,
};
#[derive(Debug, Error)]

View File

@@ -16,10 +16,11 @@ use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
/// A well-known policy code.
/// Violation variants identified by a well-known policy code (under the `code`
/// key).
#[derive(Serialize, Deserialize, Debug, Clone, Copy, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum Code {
#[serde(tag = "code", rename_all = "kebab-case")]
pub enum ViolationVariant {
/// The username is too short.
UsernameTooShort,
@@ -54,7 +55,7 @@ pub enum Code {
TooManySessions,
}
impl Code {
impl ViolationVariant {
/// Returns the code as a string
#[must_use]
pub fn as_str(&self) -> &'static str {
@@ -80,7 +81,13 @@ pub struct Violation {
pub msg: String,
pub redirect_uri: Option<String>,
pub field: Option<String>,
pub code: Option<Code>,
// We flatten as policies expect `code` as another top-level field.
//
// This also means all of the extra fields from the variant will be splatted at this
// level which is fine (arbitrary).
#[serde(flatten)]
pub variant: Option<ViolationVariant>,
}
/// The result of a policy evaluation.

View File

@@ -29,7 +29,7 @@ use mas_data_model::{
};
use mas_i18n::DataLocale;
use mas_iana::jose::JsonWebSignatureAlg;
use mas_policy::{Violation, ViolationCode};
use mas_policy::{Violation, ViolationVariant};
use mas_router::{Account, GraphQL, PostAuthAction, UrlBuilder};
use oauth2_types::scope::{OPENID, Scope};
use rand::{
@@ -890,7 +890,7 @@ impl TemplateContext for CompatLoginPolicyViolationContext {
msg: "user has too many active sessions".to_owned(),
redirect_uri: None,
field: None,
code: Some(ViolationCode::TooManySessions),
variant: Some(ViolationVariant::TooManySessions),
}],
},
])