Expose the compat login policy from the policy engine
This commit is contained in:
@@ -145,6 +145,7 @@ pub async fn policy_factory_from_config(
|
||||
register: config.register_entrypoint.clone(),
|
||||
client_registration: config.client_registration_entrypoint.clone(),
|
||||
authorization_grant: config.authorization_grant_entrypoint.clone(),
|
||||
compat_login: config.compat_login_entrypoint.clone(),
|
||||
email: config.email_entrypoint.clone(),
|
||||
};
|
||||
|
||||
|
||||
@@ -62,6 +62,14 @@ fn is_default_password_entrypoint(value: &String) -> bool {
|
||||
*value == default_password_entrypoint()
|
||||
}
|
||||
|
||||
fn default_compat_login_entrypoint() -> String {
|
||||
"compat_login/violation".to_owned()
|
||||
}
|
||||
|
||||
fn is_default_compat_login_entrypoint(value: &String) -> bool {
|
||||
*value == default_compat_login_entrypoint()
|
||||
}
|
||||
|
||||
fn default_email_entrypoint() -> String {
|
||||
"email/violation".to_owned()
|
||||
}
|
||||
@@ -111,6 +119,13 @@ pub struct PolicyConfig {
|
||||
)]
|
||||
pub authorization_grant_entrypoint: String,
|
||||
|
||||
/// Entrypoint to use when evaluating compatibility logins
|
||||
#[serde(
|
||||
default = "default_compat_login_entrypoint",
|
||||
skip_serializing_if = "is_default_compat_login_entrypoint"
|
||||
)]
|
||||
pub compat_login_entrypoint: String,
|
||||
|
||||
/// Entrypoint to use when changing password
|
||||
#[serde(
|
||||
default = "default_password_entrypoint",
|
||||
@@ -137,6 +152,7 @@ impl Default for PolicyConfig {
|
||||
client_registration_entrypoint: default_client_registration_entrypoint(),
|
||||
register_entrypoint: default_register_entrypoint(),
|
||||
authorization_grant_entrypoint: default_authorization_grant_entrypoint(),
|
||||
compat_login_entrypoint: default_compat_login_entrypoint(),
|
||||
password_entrypoint: default_password_entrypoint(),
|
||||
email_entrypoint: default_email_entrypoint(),
|
||||
data: default_data(),
|
||||
|
||||
@@ -19,8 +19,9 @@ use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
|
||||
pub use self::model::{
|
||||
AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, EmailInput,
|
||||
EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
|
||||
AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, CompatLoginInput,
|
||||
EmailInput, EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester,
|
||||
Violation,
|
||||
};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
@@ -72,15 +73,17 @@ pub struct Entrypoints {
|
||||
pub register: String,
|
||||
pub client_registration: String,
|
||||
pub authorization_grant: String,
|
||||
pub compat_login: String,
|
||||
pub email: String,
|
||||
}
|
||||
|
||||
impl Entrypoints {
|
||||
fn all(&self) -> [&str; 4] {
|
||||
fn all(&self) -> [&str; 5] {
|
||||
[
|
||||
self.register.as_str(),
|
||||
self.client_registration.as_str(),
|
||||
self.authorization_grant.as_str(),
|
||||
self.compat_login.as_str(),
|
||||
self.email.as_str(),
|
||||
]
|
||||
}
|
||||
@@ -459,6 +462,30 @@ impl Policy {
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Evaluate the `compat_login` entrypoint.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the policy engine fails to evaluate the entrypoint.
|
||||
#[tracing::instrument(
|
||||
name = "policy.evaluate.compat_login",
|
||||
skip_all,
|
||||
fields(
|
||||
%input.user.id,
|
||||
),
|
||||
)]
|
||||
pub async fn evaluate_compat_login(
|
||||
&mut self,
|
||||
input: CompatLoginInput<'_>,
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(&mut self.store, &self.entrypoints.compat_login, &input)
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -468,6 +495,16 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
fn make_entrypoints() -> Entrypoints {
|
||||
Entrypoints {
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
compat_login: "compat_login/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register() {
|
||||
let data = Data::new("example.com".to_owned(), None).with_rest(serde_json::json!({
|
||||
@@ -484,14 +521,9 @@ mod tests {
|
||||
|
||||
let file = tokio::fs::File::open(path).await.unwrap();
|
||||
|
||||
let entrypoints = Entrypoints {
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
};
|
||||
|
||||
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
|
||||
let factory = PolicyFactory::load(file, data, make_entrypoints())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut policy = factory.instantiate().await.unwrap();
|
||||
|
||||
@@ -551,14 +583,9 @@ mod tests {
|
||||
|
||||
let file = tokio::fs::File::open(path).await.unwrap();
|
||||
|
||||
let entrypoints = Entrypoints {
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
};
|
||||
|
||||
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
|
||||
let factory = PolicyFactory::load(file, data, make_entrypoints())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut policy = factory.instantiate().await.unwrap();
|
||||
|
||||
@@ -620,14 +647,9 @@ mod tests {
|
||||
|
||||
let file = tokio::fs::File::open(path).await.unwrap();
|
||||
|
||||
let entrypoints = Entrypoints {
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
};
|
||||
|
||||
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
|
||||
let factory = PolicyFactory::load(file, data, make_entrypoints())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// That is around 1 MB of JSON data. Each element is a 5-digit string, so 8
|
||||
// characters including the quotes and a comma.
|
||||
|
||||
@@ -1883,6 +1883,10 @@
|
||||
"description": "Entrypoint to use when evaluating authorization grants",
|
||||
"type": "string"
|
||||
},
|
||||
"compat_login_entrypoint": {
|
||||
"description": "Entrypoint to use when evaluating compatibility logins",
|
||||
"type": "string"
|
||||
},
|
||||
"password_entrypoint": {
|
||||
"description": "Entrypoint to use when changing password",
|
||||
"type": "string"
|
||||
|
||||
Reference in New Issue
Block a user