From 2c95c0a9a032764f249edf2da66fe347842b5d0d Mon Sep 17 00:00:00 2001 From: Olivier 'reivilibre Date: Tue, 25 Nov 2025 18:20:14 +0000 Subject: [PATCH] Expose the compat login policy from the policy engine --- crates/cli/src/util.rs | 1 + crates/config/src/sections/policy.rs | 16 ++++++ crates/policy/src/lib.rs | 76 ++++++++++++++++++---------- docs/config.schema.json | 4 ++ 4 files changed, 70 insertions(+), 27 deletions(-) diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index c0f31557b..454276150 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -145,6 +145,7 @@ pub async fn policy_factory_from_config( register: config.register_entrypoint.clone(), client_registration: config.client_registration_entrypoint.clone(), authorization_grant: config.authorization_grant_entrypoint.clone(), + compat_login: config.compat_login_entrypoint.clone(), email: config.email_entrypoint.clone(), }; diff --git a/crates/config/src/sections/policy.rs b/crates/config/src/sections/policy.rs index 37d052ade..3b816b713 100644 --- a/crates/config/src/sections/policy.rs +++ b/crates/config/src/sections/policy.rs @@ -62,6 +62,14 @@ fn is_default_password_entrypoint(value: &String) -> bool { *value == default_password_entrypoint() } +fn default_compat_login_entrypoint() -> String { + "compat_login/violation".to_owned() +} + +fn is_default_compat_login_entrypoint(value: &String) -> bool { + *value == default_compat_login_entrypoint() +} + fn default_email_entrypoint() -> String { "email/violation".to_owned() } @@ -111,6 +119,13 @@ pub struct PolicyConfig { )] pub authorization_grant_entrypoint: String, + /// Entrypoint to use when evaluating compatibility logins + #[serde( + default = "default_compat_login_entrypoint", + skip_serializing_if = "is_default_compat_login_entrypoint" + )] + pub compat_login_entrypoint: String, + /// Entrypoint to use when changing password #[serde( default = "default_password_entrypoint", @@ -137,6 +152,7 @@ impl Default for PolicyConfig { client_registration_entrypoint: default_client_registration_entrypoint(), register_entrypoint: default_register_entrypoint(), authorization_grant_entrypoint: default_authorization_grant_entrypoint(), + compat_login_entrypoint: default_compat_login_entrypoint(), password_entrypoint: default_password_entrypoint(), email_entrypoint: default_email_entrypoint(), data: default_data(), diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index 8a038aea8..dcb68dd36 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -19,8 +19,9 @@ use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt}; pub use self::model::{ - AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, EmailInput, - EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation, + AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, CompatLoginInput, + EmailInput, EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, + Violation, }; #[derive(Debug, Error)] @@ -72,15 +73,17 @@ pub struct Entrypoints { pub register: String, pub client_registration: String, pub authorization_grant: String, + pub compat_login: String, pub email: String, } impl Entrypoints { - fn all(&self) -> [&str; 4] { + fn all(&self) -> [&str; 5] { [ self.register.as_str(), self.client_registration.as_str(), self.authorization_grant.as_str(), + self.compat_login.as_str(), self.email.as_str(), ] } @@ -459,6 +462,30 @@ impl Policy { Ok(res) } + + /// Evaluate the `compat_login` entrypoint. + /// + /// # Errors + /// + /// Returns an error if the policy engine fails to evaluate the entrypoint. + #[tracing::instrument( + name = "policy.evaluate.compat_login", + skip_all, + fields( + %input.user.id, + ), + )] + pub async fn evaluate_compat_login( + &mut self, + input: CompatLoginInput<'_>, + ) -> Result { + let [res]: [EvaluationResult; 1] = self + .instance + .evaluate(&mut self.store, &self.entrypoints.compat_login, &input) + .await?; + + Ok(res) + } } #[cfg(test)] @@ -468,6 +495,16 @@ mod tests { use super::*; + fn make_entrypoints() -> Entrypoints { + Entrypoints { + register: "register/violation".to_owned(), + client_registration: "client_registration/violation".to_owned(), + authorization_grant: "authorization_grant/violation".to_owned(), + compat_login: "compat_login/violation".to_owned(), + email: "email/violation".to_owned(), + } + } + #[tokio::test] async fn test_register() { let data = Data::new("example.com".to_owned(), None).with_rest(serde_json::json!({ @@ -484,14 +521,9 @@ mod tests { let file = tokio::fs::File::open(path).await.unwrap(); - let entrypoints = Entrypoints { - register: "register/violation".to_owned(), - client_registration: "client_registration/violation".to_owned(), - authorization_grant: "authorization_grant/violation".to_owned(), - email: "email/violation".to_owned(), - }; - - let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap(); + let factory = PolicyFactory::load(file, data, make_entrypoints()) + .await + .unwrap(); let mut policy = factory.instantiate().await.unwrap(); @@ -551,14 +583,9 @@ mod tests { let file = tokio::fs::File::open(path).await.unwrap(); - let entrypoints = Entrypoints { - register: "register/violation".to_owned(), - client_registration: "client_registration/violation".to_owned(), - authorization_grant: "authorization_grant/violation".to_owned(), - email: "email/violation".to_owned(), - }; - - let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap(); + let factory = PolicyFactory::load(file, data, make_entrypoints()) + .await + .unwrap(); let mut policy = factory.instantiate().await.unwrap(); @@ -620,14 +647,9 @@ mod tests { let file = tokio::fs::File::open(path).await.unwrap(); - let entrypoints = Entrypoints { - register: "register/violation".to_owned(), - client_registration: "client_registration/violation".to_owned(), - authorization_grant: "authorization_grant/violation".to_owned(), - email: "email/violation".to_owned(), - }; - - let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap(); + let factory = PolicyFactory::load(file, data, make_entrypoints()) + .await + .unwrap(); // That is around 1 MB of JSON data. Each element is a 5-digit string, so 8 // characters including the quotes and a comma. diff --git a/docs/config.schema.json b/docs/config.schema.json index cda68f145..496cd2c5b 100644 --- a/docs/config.schema.json +++ b/docs/config.schema.json @@ -1883,6 +1883,10 @@ "description": "Entrypoint to use when evaluating authorization grants", "type": "string" }, + "compat_login_entrypoint": { + "description": "Entrypoint to use when evaluating compatibility logins", + "type": "string" + }, "password_entrypoint": { "description": "Entrypoint to use when changing password", "type": "string"