Convert ViolationCode into ViolationVariant to allow adding fields on each variant (#5553)
Concretely, for the [device](https://github.com/element-hq/matrix-authentication-service/issues/4339) [limiting](https://github.com/element-hq/backend-internal/issues/199) stuff, we'll need the policy to be able to tell MAS how many sessions should be removed before the user can log in. This way the UI will be able to guide the user through it. Intended to look something like: ```rust ViolationVariant::TooManySessions { /// How many devices to remove need_to_remove: u32 } ```
This commit is contained in:
@@ -16,7 +16,7 @@ use mas_data_model::{
|
||||
User,
|
||||
};
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use mas_policy::{Policy, Requester, ViolationCode, model::CompatLogin};
|
||||
use mas_policy::{Policy, Requester, ViolationVariant, model::CompatLogin};
|
||||
use mas_storage::{
|
||||
BoxRepository, BoxRepositoryFactory, RepositoryAccess,
|
||||
compat::{
|
||||
@@ -605,7 +605,7 @@ async fn token_login(
|
||||
// that removing a session wouldn't actually unblock the login.
|
||||
if res.violations.len() == 1 {
|
||||
let violation = &res.violations[0];
|
||||
if violation.code == Some(ViolationCode::TooManySessions) {
|
||||
if violation.variant == Some(ViolationVariant::TooManySessions) {
|
||||
// The only violation is having reached the session limit.
|
||||
return Err(RouteError::PolicyHardSessionLimitReached);
|
||||
}
|
||||
@@ -738,7 +738,7 @@ async fn user_password_login(
|
||||
// that removing a session wouldn't actually unblock the login.
|
||||
if res.violations.len() == 1 {
|
||||
let violation = &res.violations[0];
|
||||
if violation.code == Some(ViolationCode::TooManySessions) {
|
||||
if violation.variant == Some(ViolationVariant::TooManySessions) {
|
||||
// The only violation is having reached the session limit.
|
||||
return Err(RouteError::PolicyHardSessionLimitReached);
|
||||
}
|
||||
|
||||
@@ -1108,13 +1108,13 @@ pub(crate) async fn post(
|
||||
form_state.add_error_on_field(
|
||||
mas_templates::UpstreamRegisterFormField::Username,
|
||||
FieldError::Policy {
|
||||
code: violation.code.map(|c| c.as_str()),
|
||||
code: violation.variant.map(|c| c.as_str()),
|
||||
message: violation.msg,
|
||||
},
|
||||
);
|
||||
}
|
||||
_ => form_state.add_error_on_form(FormError::Policy {
|
||||
code: violation.code.map(|c| c.as_str()),
|
||||
code: violation.variant.map(|c| c.as_str()),
|
||||
message: violation.msg,
|
||||
}),
|
||||
}
|
||||
|
||||
@@ -263,7 +263,7 @@ pub(crate) async fn post(
|
||||
Some("email") => state.add_error_on_field(
|
||||
RegisterFormField::Email,
|
||||
FieldError::Policy {
|
||||
code: violation.code.map(|c| c.as_str()),
|
||||
code: violation.variant.map(|c| c.as_str()),
|
||||
message: violation.msg,
|
||||
},
|
||||
),
|
||||
@@ -274,7 +274,7 @@ pub(crate) async fn post(
|
||||
state.add_error_on_field(
|
||||
RegisterFormField::Username,
|
||||
FieldError::Policy {
|
||||
code: violation.code.map(|c| c.as_str()),
|
||||
code: violation.variant.map(|c| c.as_str()),
|
||||
message: violation.msg,
|
||||
},
|
||||
);
|
||||
@@ -282,12 +282,12 @@ pub(crate) async fn post(
|
||||
Some("password") => state.add_error_on_field(
|
||||
RegisterFormField::Password,
|
||||
FieldError::Policy {
|
||||
code: violation.code.map(|c| c.as_str()),
|
||||
code: violation.variant.map(|c| c.as_str()),
|
||||
message: violation.msg,
|
||||
},
|
||||
),
|
||||
_ => state.add_error_on_form(FormError::Policy {
|
||||
code: violation.code.map(|c| c.as_str()),
|
||||
code: violation.variant.map(|c| c.as_str()),
|
||||
message: violation.msg,
|
||||
}),
|
||||
}
|
||||
|
||||
@@ -19,9 +19,9 @@ use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
|
||||
pub use self::model::{
|
||||
AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, CompatLoginInput,
|
||||
EmailInput, EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester,
|
||||
Violation,
|
||||
AuthorizationGrantInput, ClientRegistrationInput, CompatLoginInput, EmailInput,
|
||||
EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
|
||||
ViolationVariant,
|
||||
};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
||||
@@ -16,10 +16,11 @@ use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A well-known policy code.
|
||||
/// Violation variants identified by a well-known policy code (under the `code`
|
||||
/// key).
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Copy, JsonSchema, PartialEq, Eq)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum Code {
|
||||
#[serde(tag = "code", rename_all = "kebab-case")]
|
||||
pub enum ViolationVariant {
|
||||
/// The username is too short.
|
||||
UsernameTooShort,
|
||||
|
||||
@@ -54,7 +55,7 @@ pub enum Code {
|
||||
TooManySessions,
|
||||
}
|
||||
|
||||
impl Code {
|
||||
impl ViolationVariant {
|
||||
/// Returns the code as a string
|
||||
#[must_use]
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
@@ -80,7 +81,13 @@ pub struct Violation {
|
||||
pub msg: String,
|
||||
pub redirect_uri: Option<String>,
|
||||
pub field: Option<String>,
|
||||
pub code: Option<Code>,
|
||||
|
||||
// We flatten as policies expect `code` as another top-level field.
|
||||
//
|
||||
// This also means all of the extra fields from the variant will be splatted at this
|
||||
// level which is fine (arbitrary).
|
||||
#[serde(flatten)]
|
||||
pub variant: Option<ViolationVariant>,
|
||||
}
|
||||
|
||||
/// The result of a policy evaluation.
|
||||
|
||||
@@ -29,7 +29,7 @@ use mas_data_model::{
|
||||
};
|
||||
use mas_i18n::DataLocale;
|
||||
use mas_iana::jose::JsonWebSignatureAlg;
|
||||
use mas_policy::{Violation, ViolationCode};
|
||||
use mas_policy::{Violation, ViolationVariant};
|
||||
use mas_router::{Account, GraphQL, PostAuthAction, UrlBuilder};
|
||||
use oauth2_types::scope::{OPENID, Scope};
|
||||
use rand::{
|
||||
@@ -890,7 +890,7 @@ impl TemplateContext for CompatLoginPolicyViolationContext {
|
||||
msg: "user has too many active sessions".to_owned(),
|
||||
redirect_uri: None,
|
||||
field: None,
|
||||
code: Some(ViolationCode::TooManySessions),
|
||||
variant: Some(ViolationVariant::TooManySessions),
|
||||
}],
|
||||
},
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user