Convert ViolationCode into ViolationVariant to allow adding fields on each variant
This commit is contained in:
@@ -16,7 +16,7 @@ use mas_data_model::{
|
||||
User,
|
||||
};
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use mas_policy::{Policy, Requester, ViolationCode, model::CompatLogin};
|
||||
use mas_policy::{Policy, Requester, ViolationVariant, model::CompatLogin};
|
||||
use mas_storage::{
|
||||
BoxRepository, BoxRepositoryFactory, RepositoryAccess,
|
||||
compat::{
|
||||
@@ -605,7 +605,7 @@ async fn token_login(
|
||||
// that removing a session wouldn't actually unblock the login.
|
||||
if res.violations.len() == 1 {
|
||||
let violation = &res.violations[0];
|
||||
if violation.code == Some(ViolationCode::TooManySessions) {
|
||||
if violation.variant == Some(ViolationVariant::TooManySessions) {
|
||||
// The only violation is having reached the session limit.
|
||||
return Err(RouteError::PolicyHardSessionLimitReached);
|
||||
}
|
||||
@@ -738,7 +738,7 @@ async fn user_password_login(
|
||||
// that removing a session wouldn't actually unblock the login.
|
||||
if res.violations.len() == 1 {
|
||||
let violation = &res.violations[0];
|
||||
if violation.code == Some(ViolationCode::TooManySessions) {
|
||||
if violation.variant == Some(ViolationVariant::TooManySessions) {
|
||||
// The only violation is having reached the session limit.
|
||||
return Err(RouteError::PolicyHardSessionLimitReached);
|
||||
}
|
||||
|
||||
@@ -1108,13 +1108,13 @@ pub(crate) async fn post(
|
||||
form_state.add_error_on_field(
|
||||
mas_templates::UpstreamRegisterFormField::Username,
|
||||
FieldError::Policy {
|
||||
code: violation.code.map(|c| c.as_str()),
|
||||
code: violation.variant.map(|c| c.as_str()),
|
||||
message: violation.msg,
|
||||
},
|
||||
);
|
||||
}
|
||||
_ => form_state.add_error_on_form(FormError::Policy {
|
||||
code: violation.code.map(|c| c.as_str()),
|
||||
code: violation.variant.map(|c| c.as_str()),
|
||||
message: violation.msg,
|
||||
}),
|
||||
}
|
||||
|
||||
@@ -263,7 +263,7 @@ pub(crate) async fn post(
|
||||
Some("email") => state.add_error_on_field(
|
||||
RegisterFormField::Email,
|
||||
FieldError::Policy {
|
||||
code: violation.code.map(|c| c.as_str()),
|
||||
code: violation.variant.map(|c| c.as_str()),
|
||||
message: violation.msg,
|
||||
},
|
||||
),
|
||||
@@ -274,7 +274,7 @@ pub(crate) async fn post(
|
||||
state.add_error_on_field(
|
||||
RegisterFormField::Username,
|
||||
FieldError::Policy {
|
||||
code: violation.code.map(|c| c.as_str()),
|
||||
code: violation.variant.map(|c| c.as_str()),
|
||||
message: violation.msg,
|
||||
},
|
||||
);
|
||||
@@ -282,12 +282,12 @@ pub(crate) async fn post(
|
||||
Some("password") => state.add_error_on_field(
|
||||
RegisterFormField::Password,
|
||||
FieldError::Policy {
|
||||
code: violation.code.map(|c| c.as_str()),
|
||||
code: violation.variant.map(|c| c.as_str()),
|
||||
message: violation.msg,
|
||||
},
|
||||
),
|
||||
_ => state.add_error_on_form(FormError::Policy {
|
||||
code: violation.code.map(|c| c.as_str()),
|
||||
code: violation.variant.map(|c| c.as_str()),
|
||||
message: violation.msg,
|
||||
}),
|
||||
}
|
||||
|
||||
@@ -19,9 +19,9 @@ use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
|
||||
pub use self::model::{
|
||||
AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, CompatLoginInput,
|
||||
EmailInput, EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester,
|
||||
Violation,
|
||||
AuthorizationGrantInput, ClientRegistrationInput, CompatLoginInput, EmailInput,
|
||||
EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
|
||||
ViolationVariant,
|
||||
};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
||||
@@ -16,10 +16,11 @@ use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A well-known policy code.
|
||||
/// Violation variants identified by a well-known policy code (under the `code`
|
||||
/// key).
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Copy, JsonSchema, PartialEq, Eq)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum Code {
|
||||
#[serde(tag = "code", rename_all = "kebab-case")]
|
||||
pub enum ViolationVariant {
|
||||
/// The username is too short.
|
||||
UsernameTooShort,
|
||||
|
||||
@@ -54,7 +55,7 @@ pub enum Code {
|
||||
TooManySessions,
|
||||
}
|
||||
|
||||
impl Code {
|
||||
impl ViolationVariant {
|
||||
/// Returns the code as a string
|
||||
#[must_use]
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
@@ -80,7 +81,9 @@ pub struct Violation {
|
||||
pub msg: String,
|
||||
pub redirect_uri: Option<String>,
|
||||
pub field: Option<String>,
|
||||
pub code: Option<Code>,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub variant: Option<ViolationVariant>,
|
||||
}
|
||||
|
||||
/// The result of a policy evaluation.
|
||||
|
||||
@@ -29,7 +29,7 @@ use mas_data_model::{
|
||||
};
|
||||
use mas_i18n::DataLocale;
|
||||
use mas_iana::jose::JsonWebSignatureAlg;
|
||||
use mas_policy::{Violation, ViolationCode};
|
||||
use mas_policy::{Violation, ViolationVariant};
|
||||
use mas_router::{Account, GraphQL, PostAuthAction, UrlBuilder};
|
||||
use oauth2_types::scope::{OPENID, Scope};
|
||||
use rand::{
|
||||
@@ -890,7 +890,7 @@ impl TemplateContext for CompatLoginPolicyViolationContext {
|
||||
msg: "user has too many active sessions".to_owned(),
|
||||
redirect_uri: None,
|
||||
field: None,
|
||||
code: Some(ViolationCode::TooManySessions),
|
||||
variant: Some(ViolationVariant::TooManySessions),
|
||||
}],
|
||||
},
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user