Expose the compat login policy from the policy engine

This commit is contained in:
Olivier 'reivilibre
2025-11-25 18:20:14 +00:00
parent 069b57758b
commit 2c95c0a9a0
4 changed files with 70 additions and 27 deletions

View File

@@ -145,6 +145,7 @@ pub async fn policy_factory_from_config(
register: config.register_entrypoint.clone(),
client_registration: config.client_registration_entrypoint.clone(),
authorization_grant: config.authorization_grant_entrypoint.clone(),
compat_login: config.compat_login_entrypoint.clone(),
email: config.email_entrypoint.clone(),
};

View File

@@ -62,6 +62,14 @@ fn is_default_password_entrypoint(value: &String) -> bool {
*value == default_password_entrypoint()
}
fn default_compat_login_entrypoint() -> String {
"compat_login/violation".to_owned()
}
fn is_default_compat_login_entrypoint(value: &String) -> bool {
*value == default_compat_login_entrypoint()
}
fn default_email_entrypoint() -> String {
"email/violation".to_owned()
}
@@ -111,6 +119,13 @@ pub struct PolicyConfig {
)]
pub authorization_grant_entrypoint: String,
/// Entrypoint to use when evaluating compatibility logins
#[serde(
default = "default_compat_login_entrypoint",
skip_serializing_if = "is_default_compat_login_entrypoint"
)]
pub compat_login_entrypoint: String,
/// Entrypoint to use when changing password
#[serde(
default = "default_password_entrypoint",
@@ -137,6 +152,7 @@ impl Default for PolicyConfig {
client_registration_entrypoint: default_client_registration_entrypoint(),
register_entrypoint: default_register_entrypoint(),
authorization_grant_entrypoint: default_authorization_grant_entrypoint(),
compat_login_entrypoint: default_compat_login_entrypoint(),
password_entrypoint: default_password_entrypoint(),
email_entrypoint: default_email_entrypoint(),
data: default_data(),

View File

@@ -19,8 +19,9 @@ use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt};
pub use self::model::{
AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, EmailInput,
EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, CompatLoginInput,
EmailInput, EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester,
Violation,
};
#[derive(Debug, Error)]
@@ -72,15 +73,17 @@ pub struct Entrypoints {
pub register: String,
pub client_registration: String,
pub authorization_grant: String,
pub compat_login: String,
pub email: String,
}
impl Entrypoints {
fn all(&self) -> [&str; 4] {
fn all(&self) -> [&str; 5] {
[
self.register.as_str(),
self.client_registration.as_str(),
self.authorization_grant.as_str(),
self.compat_login.as_str(),
self.email.as_str(),
]
}
@@ -459,6 +462,30 @@ impl Policy {
Ok(res)
}
/// Evaluate the `compat_login` entrypoint.
///
/// # Errors
///
/// Returns an error if the policy engine fails to evaluate the entrypoint.
#[tracing::instrument(
name = "policy.evaluate.compat_login",
skip_all,
fields(
%input.user.id,
),
)]
pub async fn evaluate_compat_login(
&mut self,
input: CompatLoginInput<'_>,
) -> Result<EvaluationResult, EvaluationError> {
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(&mut self.store, &self.entrypoints.compat_login, &input)
.await?;
Ok(res)
}
}
#[cfg(test)]
@@ -468,6 +495,16 @@ mod tests {
use super::*;
fn make_entrypoints() -> Entrypoints {
Entrypoints {
register: "register/violation".to_owned(),
client_registration: "client_registration/violation".to_owned(),
authorization_grant: "authorization_grant/violation".to_owned(),
compat_login: "compat_login/violation".to_owned(),
email: "email/violation".to_owned(),
}
}
#[tokio::test]
async fn test_register() {
let data = Data::new("example.com".to_owned(), None).with_rest(serde_json::json!({
@@ -484,14 +521,9 @@ mod tests {
let file = tokio::fs::File::open(path).await.unwrap();
let entrypoints = Entrypoints {
register: "register/violation".to_owned(),
client_registration: "client_registration/violation".to_owned(),
authorization_grant: "authorization_grant/violation".to_owned(),
email: "email/violation".to_owned(),
};
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
let factory = PolicyFactory::load(file, data, make_entrypoints())
.await
.unwrap();
let mut policy = factory.instantiate().await.unwrap();
@@ -551,14 +583,9 @@ mod tests {
let file = tokio::fs::File::open(path).await.unwrap();
let entrypoints = Entrypoints {
register: "register/violation".to_owned(),
client_registration: "client_registration/violation".to_owned(),
authorization_grant: "authorization_grant/violation".to_owned(),
email: "email/violation".to_owned(),
};
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
let factory = PolicyFactory::load(file, data, make_entrypoints())
.await
.unwrap();
let mut policy = factory.instantiate().await.unwrap();
@@ -620,14 +647,9 @@ mod tests {
let file = tokio::fs::File::open(path).await.unwrap();
let entrypoints = Entrypoints {
register: "register/violation".to_owned(),
client_registration: "client_registration/violation".to_owned(),
authorization_grant: "authorization_grant/violation".to_owned(),
email: "email/violation".to_owned(),
};
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
let factory = PolicyFactory::load(file, data, make_entrypoints())
.await
.unwrap();
// That is around 1 MB of JSON data. Each element is a 5-digit string, so 8
// characters including the quotes and a comma.

View File

@@ -1883,6 +1883,10 @@
"description": "Entrypoint to use when evaluating authorization grants",
"type": "string"
},
"compat_login_entrypoint": {
"description": "Entrypoint to use when evaluating compatibility logins",
"type": "string"
},
"password_entrypoint": {
"description": "Entrypoint to use when changing password",
"type": "string"