diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index 6aa0fb53d..5e57ce5c0 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -16,7 +16,7 @@ use mas_data_model::{ User, }; use mas_matrix::HomeserverConnection; -use mas_policy::{Policy, Requester, ViolationCode, model::CompatLogin}; +use mas_policy::{Policy, Requester, ViolationVariant, model::CompatLogin}; use mas_storage::{ BoxRepository, BoxRepositoryFactory, RepositoryAccess, compat::{ @@ -605,7 +605,7 @@ async fn token_login( // that removing a session wouldn't actually unblock the login. if res.violations.len() == 1 { let violation = &res.violations[0]; - if violation.code == Some(ViolationCode::TooManySessions) { + if violation.variant == Some(ViolationVariant::TooManySessions) { // The only violation is having reached the session limit. return Err(RouteError::PolicyHardSessionLimitReached); } @@ -738,7 +738,7 @@ async fn user_password_login( // that removing a session wouldn't actually unblock the login. if res.violations.len() == 1 { let violation = &res.violations[0]; - if violation.code == Some(ViolationCode::TooManySessions) { + if violation.variant == Some(ViolationVariant::TooManySessions) { // The only violation is having reached the session limit. return Err(RouteError::PolicyHardSessionLimitReached); } diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index e2c3e9fb9..e081b0f79 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -1108,13 +1108,13 @@ pub(crate) async fn post( form_state.add_error_on_field( mas_templates::UpstreamRegisterFormField::Username, FieldError::Policy { - code: violation.code.map(|c| c.as_str()), + code: violation.variant.map(|c| c.as_str()), message: violation.msg, }, ); } _ => form_state.add_error_on_form(FormError::Policy { - code: violation.code.map(|c| c.as_str()), + code: violation.variant.map(|c| c.as_str()), message: violation.msg, }), } diff --git a/crates/handlers/src/views/register/password.rs b/crates/handlers/src/views/register/password.rs index 65ba5fe0d..39643e972 100644 --- a/crates/handlers/src/views/register/password.rs +++ b/crates/handlers/src/views/register/password.rs @@ -263,7 +263,7 @@ pub(crate) async fn post( Some("email") => state.add_error_on_field( RegisterFormField::Email, FieldError::Policy { - code: violation.code.map(|c| c.as_str()), + code: violation.variant.map(|c| c.as_str()), message: violation.msg, }, ), @@ -274,7 +274,7 @@ pub(crate) async fn post( state.add_error_on_field( RegisterFormField::Username, FieldError::Policy { - code: violation.code.map(|c| c.as_str()), + code: violation.variant.map(|c| c.as_str()), message: violation.msg, }, ); @@ -282,12 +282,12 @@ pub(crate) async fn post( Some("password") => state.add_error_on_field( RegisterFormField::Password, FieldError::Policy { - code: violation.code.map(|c| c.as_str()), + code: violation.variant.map(|c| c.as_str()), message: violation.msg, }, ), _ => state.add_error_on_form(FormError::Policy { - code: violation.code.map(|c| c.as_str()), + code: violation.variant.map(|c| c.as_str()), message: violation.msg, }), } diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index dcb68dd36..a5d4805ad 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -19,9 +19,9 @@ use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt}; pub use self::model::{ - AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, CompatLoginInput, - EmailInput, EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, - Violation, + AuthorizationGrantInput, ClientRegistrationInput, CompatLoginInput, EmailInput, + EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation, + ViolationVariant, }; #[derive(Debug, Error)] diff --git a/crates/policy/src/model.rs b/crates/policy/src/model.rs index a9f5fb502..a3bf24b5f 100644 --- a/crates/policy/src/model.rs +++ b/crates/policy/src/model.rs @@ -16,10 +16,11 @@ use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -/// A well-known policy code. +/// Violation variants identified by a well-known policy code (under the `code` +/// key). #[derive(Serialize, Deserialize, Debug, Clone, Copy, JsonSchema, PartialEq, Eq)] -#[serde(rename_all = "kebab-case")] -pub enum Code { +#[serde(tag = "code", rename_all = "kebab-case")] +pub enum ViolationVariant { /// The username is too short. UsernameTooShort, @@ -54,7 +55,7 @@ pub enum Code { TooManySessions, } -impl Code { +impl ViolationVariant { /// Returns the code as a string #[must_use] pub fn as_str(&self) -> &'static str { @@ -80,7 +81,13 @@ pub struct Violation { pub msg: String, pub redirect_uri: Option, pub field: Option, - pub code: Option, + + // We flatten as policies expect `code` as another top-level field. + // + // This also means all of the extra fields from the variant will be splatted at this + // level which is fine (arbitrary). + #[serde(flatten)] + pub variant: Option, } /// The result of a policy evaluation. diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index d43556fae..25123970b 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -29,7 +29,7 @@ use mas_data_model::{ }; use mas_i18n::DataLocale; use mas_iana::jose::JsonWebSignatureAlg; -use mas_policy::{Violation, ViolationCode}; +use mas_policy::{Violation, ViolationVariant}; use mas_router::{Account, GraphQL, PostAuthAction, UrlBuilder}; use oauth2_types::scope::{OPENID, Scope}; use rand::{ @@ -890,7 +890,7 @@ impl TemplateContext for CompatLoginPolicyViolationContext { msg: "user has too many active sessions".to_owned(), redirect_uri: None, field: None, - code: Some(ViolationCode::TooManySessions), + variant: Some(ViolationVariant::TooManySessions), }], }, ])