Pass an input object to the policy evaluation instead of multiple arguments

This commit is contained in:
Quentin Gliech
2025-02-14 17:15:26 +01:00
parent b8fb25faed
commit 72384b8e03
10 changed files with 112 additions and 151 deletions

View File

@@ -424,7 +424,11 @@ impl UserEmailMutations {
if !skip_policy_check {
let mut policy = state.policy().await?;
let res = policy.evaluate_email(&input.email).await?;
let res = policy
.evaluate_email(mas_policy::EmailInput {
email: &input.email,
})
.await?;
if !res.valid() {
return Ok(AddEmailPayload::Denied {
violations: res.violations,
@@ -610,7 +614,11 @@ impl UserEmailMutations {
// Check if the email address is allowed by the policy
let mut policy = state.policy().await?;
let res = policy.evaluate_email(&input.email).await?;
let res = policy
.evaluate_email(mas_policy::EmailInput {
email: &input.email,
})
.await?;
if !res.valid() {
return Ok(StartEmailAuthenticationPayload::Denied {
violations: res.violations,

View File

@@ -226,7 +226,12 @@ pub(crate) async fn complete(
// Run through the policy
let res = policy
.evaluate_authorization_grant(&grant, client, &browser_session.user)
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: Some(&browser_session.user),
client,
scope: &grant.scope,
grant_type: mas_policy::GrantType::AuthorizationCode,
})
.await?;
if !res.valid() {

View File

@@ -111,7 +111,12 @@ pub(crate) async fn get(
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let res = policy
.evaluate_authorization_grant(&grant, &client, &session.user)
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: Some(&session.user),
client: &client,
scope: &grant.scope,
grant_type: mas_policy::GrantType::AuthorizationCode,
})
.await?;
if res.valid() {
@@ -185,7 +190,12 @@ pub(crate) async fn post(
.ok_or(RouteError::NoSuchClient)?;
let res = policy
.evaluate_authorization_grant(&grant, &client, &session.user)
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: Some(&session.user),
client: &client,
scope: &grant.scope,
grant_type: mas_policy::GrantType::AuthorizationCode,
})
.await?;
if !res.valid() {

View File

@@ -82,7 +82,12 @@ pub(crate) async fn get(
// Evaluate the policy
let res = policy
.evaluate_device_code_grant(&grant, &client, &session.user)
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
grant_type: mas_policy::GrantType::DeviceCode,
client: &client,
scope: &grant.scope,
user: Some(&session.user),
})
.await?;
if !res.valid() {
warn!(violation = ?res, "Device code grant for client {} denied by policy", client.id);
@@ -157,7 +162,12 @@ pub(crate) async fn post(
// Evaluate the policy
let res = policy
.evaluate_device_code_grant(&grant, &client, &session.user)
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
grant_type: mas_policy::GrantType::DeviceCode,
client: &client,
scope: &grant.scope,
user: Some(&session.user),
})
.await?;
if !res.valid() {
warn!(violation = ?res, "Device code grant for client {} denied by policy", client.id);

View File

@@ -244,7 +244,11 @@ pub(crate) async fn post(
}
}
let res = policy.evaluate_client_registration(&metadata).await?;
let res = policy
.evaluate_client_registration(mas_policy::ClientRegistrationInput {
client_metadata: &metadata,
})
.await?;
if !res.valid() {
return Err(RouteError::PolicyDenied(res.violations));
}

View File

@@ -676,7 +676,12 @@ async fn client_credentials_grant(
// Make the request go through the policy engine
let res = policy
.evaluate_client_credentials_grant(&scope, client)
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: None,
client,
scope: &scope,
grant_type: mas_policy::GrantType::ClientCredentials,
})
.await?;
if !res.valid() {
return Err(RouteError::DeniedByPolicy(res.violations));

View File

@@ -441,7 +441,11 @@ pub(crate) async fn get(
}
let res = policy
.evaluate_upstream_oauth_register(&localpart, None)
.evaluate_register(mas_policy::RegisterInput {
registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
username: &localpart,
email: None,
})
.await?;
if res.valid() {
@@ -752,8 +756,13 @@ pub(crate) async fn post(
// Policy check
let res = policy
.evaluate_upstream_oauth_register(&username, email.as_deref())
.evaluate_register(mas_policy::RegisterInput {
registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
username: &username,
email: email.as_deref(),
})
.await?;
if !res.valid() {
let form_state =
res.violations

View File

@@ -233,7 +233,11 @@ pub(crate) async fn post(
}
let res = policy
.evaluate_register(&form.username, &form.email)
.evaluate_register(mas_policy::RegisterInput {
registration_method: mas_policy::RegistrationMethod::Password,
username: &form.username,
email: Some(&form.email),
})
.await?;
for violation in res.violations {

View File

@@ -6,8 +6,6 @@
pub mod model;
use mas_data_model::{AuthorizationGrant, Client, DeviceCodeGrant, User};
use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope};
use opa_wasm::{
wasmtime::{Config, Engine, Module, OptLevel, Store},
Runtime,
@@ -16,9 +14,10 @@ use serde::Serialize;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt};
use self::model::{AuthorizationGrantInput, ClientRegistrationInput, EmailInput, RegisterInput};
pub use self::model::{Code as ViolationCode, EvaluationResult, Violation};
use crate::model::GrantType;
pub use self::model::{
AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, EmailInput,
EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
};
#[derive(Debug, Error)]
pub enum LoadError {
@@ -190,16 +189,14 @@ impl Policy {
name = "policy.evaluate_email",
skip_all,
fields(
input.email = email,
%input.email,
),
err,
)]
pub async fn evaluate_email(
&mut self,
email: &str,
input: EmailInput<'_>,
) -> Result<EvaluationResult, EvaluationError> {
let input = EmailInput { email };
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(&mut self.store, &self.entrypoints.email, &input)
@@ -212,44 +209,16 @@ impl Policy {
name = "policy.evaluate.register",
skip_all,
fields(
input.registration_method = "password",
input.user.username = username,
input.user.email = email,
?input.registration_method,
input.username = input.username,
input.email = input.email,
),
err,
)]
pub async fn evaluate_register(
&mut self,
username: &str,
email: &str,
input: RegisterInput<'_>,
) -> Result<EvaluationResult, EvaluationError> {
let input = RegisterInput::Password { username, email };
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(&mut self.store, &self.entrypoints.register, &input)
.await?;
Ok(res)
}
#[tracing::instrument(
name = "policy.evaluate.upstream_oauth_register",
skip_all,
fields(
input.registration_method = "password",
input.user.username = username,
input.user.email = email,
),
err,
)]
pub async fn evaluate_upstream_oauth_register(
&mut self,
username: &str,
email: Option<&str>,
) -> Result<EvaluationResult, EvaluationError> {
let input = RegisterInput::UpstreamOAuth2 { username, email };
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(&mut self.store, &self.entrypoints.register, &input)
@@ -261,10 +230,8 @@ impl Policy {
#[tracing::instrument(skip(self))]
pub async fn evaluate_client_registration(
&mut self,
client_metadata: &VerifiedClientMetadata,
input: ClientRegistrationInput<'_>,
) -> Result<EvaluationResult, EvaluationError> {
let input = ClientRegistrationInput { client_metadata };
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(
@@ -281,95 +248,15 @@ impl Policy {
name = "policy.evaluate.authorization_grant",
skip_all,
fields(
input.authorization_grant.id = %authorization_grant.id,
input.scope = %authorization_grant.scope,
input.client.id = %client.id,
input.user.id = %user.id,
%input.scope,
%input.client.id,
),
err,
)]
pub async fn evaluate_authorization_grant(
&mut self,
authorization_grant: &AuthorizationGrant,
client: &Client,
user: &User,
input: AuthorizationGrantInput<'_>,
) -> Result<EvaluationResult, EvaluationError> {
let input = AuthorizationGrantInput {
user: Some(user),
client,
scope: &authorization_grant.scope,
grant_type: GrantType::AuthorizationCode,
};
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(
&mut self.store,
&self.entrypoints.authorization_grant,
&input,
)
.await?;
Ok(res)
}
#[tracing::instrument(
name = "policy.evaluate.client_credentials_grant",
skip_all,
fields(
input.scope = %scope,
input.client.id = %client.id,
),
err,
)]
pub async fn evaluate_client_credentials_grant(
&mut self,
scope: &Scope,
client: &Client,
) -> Result<EvaluationResult, EvaluationError> {
let input = AuthorizationGrantInput {
user: None,
client,
scope,
grant_type: GrantType::ClientCredentials,
};
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(
&mut self.store,
&self.entrypoints.authorization_grant,
&input,
)
.await?;
Ok(res)
}
#[tracing::instrument(
name = "policy.evaluate.device_code_grant",
skip_all,
fields(
input.device_code_grant.id = %device_code_grant.id,
input.scope = %device_code_grant.scope,
input.client.id = %client.id,
input.user.id = %user.id,
),
err,
)]
pub async fn evaluate_device_code_grant(
&mut self,
device_code_grant: &DeviceCodeGrant,
client: &Client,
user: &User,
) -> Result<EvaluationResult, EvaluationError> {
let input = AuthorizationGrantInput {
user: Some(user),
client,
scope: &device_code_grant.scope,
grant_type: GrantType::DeviceCode,
};
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(
@@ -385,6 +272,7 @@ impl Policy {
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
@@ -415,19 +303,31 @@ mod tests {
let mut policy = factory.instantiate().await.unwrap();
let res = policy
.evaluate_register("hello", "hello@example.com")
.evaluate_register(RegisterInput {
registration_method: RegistrationMethod::Password,
username: "hello",
email: Some("hello@example.com"),
})
.await
.unwrap();
assert!(!res.valid());
let res = policy
.evaluate_register("hello", "hello@foo.element.io")
.evaluate_register(RegisterInput {
registration_method: RegistrationMethod::Password,
username: "hello",
email: Some("hello@foo.element.io"),
})
.await
.unwrap();
assert!(res.valid());
let res = policy
.evaluate_register("hello", "hello@staging.element.io")
.evaluate_register(RegisterInput {
registration_method: RegistrationMethod::Password,
username: "hello",
email: Some("hello@staging.element.io"),
})
.await
.unwrap();
assert!(!res.valid());

View File

@@ -92,21 +92,27 @@ impl EvaluationResult {
}
}
#[derive(Serialize, Debug)]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
pub enum RegistrationMethod {
#[serde(rename = "password")]
Password,
#[serde(rename = "upstream-oauth2")]
UpstreamOAuth2,
}
/// Input for the user registration policy.
#[derive(Serialize, Debug)]
#[serde(tag = "registration_method")]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
pub enum RegisterInput<'a> {
#[serde(rename = "password")]
Password { username: &'a str, email: &'a str },
pub struct RegisterInput<'a> {
pub registration_method: RegistrationMethod,
#[serde(rename = "upstream-oauth2")]
UpstreamOAuth2 {
username: &'a str,
pub username: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
email: Option<&'a str>,
},
#[serde(skip_serializing_if = "Option::is_none")]
pub email: Option<&'a str>,
}
/// Input for the client registration policy.