Pass an input object to the policy evaluation instead of multiple arguments
This commit is contained in:
@@ -424,7 +424,11 @@ impl UserEmailMutations {
|
||||
|
||||
if !skip_policy_check {
|
||||
let mut policy = state.policy().await?;
|
||||
let res = policy.evaluate_email(&input.email).await?;
|
||||
let res = policy
|
||||
.evaluate_email(mas_policy::EmailInput {
|
||||
email: &input.email,
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
return Ok(AddEmailPayload::Denied {
|
||||
violations: res.violations,
|
||||
@@ -610,7 +614,11 @@ impl UserEmailMutations {
|
||||
|
||||
// Check if the email address is allowed by the policy
|
||||
let mut policy = state.policy().await?;
|
||||
let res = policy.evaluate_email(&input.email).await?;
|
||||
let res = policy
|
||||
.evaluate_email(mas_policy::EmailInput {
|
||||
email: &input.email,
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
return Ok(StartEmailAuthenticationPayload::Denied {
|
||||
violations: res.violations,
|
||||
|
||||
@@ -226,7 +226,12 @@ pub(crate) async fn complete(
|
||||
|
||||
// Run through the policy
|
||||
let res = policy
|
||||
.evaluate_authorization_grant(&grant, client, &browser_session.user)
|
||||
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
|
||||
user: Some(&browser_session.user),
|
||||
client,
|
||||
scope: &grant.scope,
|
||||
grant_type: mas_policy::GrantType::AuthorizationCode,
|
||||
})
|
||||
.await?;
|
||||
|
||||
if !res.valid() {
|
||||
|
||||
@@ -111,7 +111,12 @@ pub(crate) async fn get(
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
|
||||
let res = policy
|
||||
.evaluate_authorization_grant(&grant, &client, &session.user)
|
||||
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
|
||||
user: Some(&session.user),
|
||||
client: &client,
|
||||
scope: &grant.scope,
|
||||
grant_type: mas_policy::GrantType::AuthorizationCode,
|
||||
})
|
||||
.await?;
|
||||
|
||||
if res.valid() {
|
||||
@@ -185,7 +190,12 @@ pub(crate) async fn post(
|
||||
.ok_or(RouteError::NoSuchClient)?;
|
||||
|
||||
let res = policy
|
||||
.evaluate_authorization_grant(&grant, &client, &session.user)
|
||||
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
|
||||
user: Some(&session.user),
|
||||
client: &client,
|
||||
scope: &grant.scope,
|
||||
grant_type: mas_policy::GrantType::AuthorizationCode,
|
||||
})
|
||||
.await?;
|
||||
|
||||
if !res.valid() {
|
||||
|
||||
@@ -82,7 +82,12 @@ pub(crate) async fn get(
|
||||
|
||||
// Evaluate the policy
|
||||
let res = policy
|
||||
.evaluate_device_code_grant(&grant, &client, &session.user)
|
||||
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
|
||||
grant_type: mas_policy::GrantType::DeviceCode,
|
||||
client: &client,
|
||||
scope: &grant.scope,
|
||||
user: Some(&session.user),
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
warn!(violation = ?res, "Device code grant for client {} denied by policy", client.id);
|
||||
@@ -157,7 +162,12 @@ pub(crate) async fn post(
|
||||
|
||||
// Evaluate the policy
|
||||
let res = policy
|
||||
.evaluate_device_code_grant(&grant, &client, &session.user)
|
||||
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
|
||||
grant_type: mas_policy::GrantType::DeviceCode,
|
||||
client: &client,
|
||||
scope: &grant.scope,
|
||||
user: Some(&session.user),
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
warn!(violation = ?res, "Device code grant for client {} denied by policy", client.id);
|
||||
|
||||
@@ -244,7 +244,11 @@ pub(crate) async fn post(
|
||||
}
|
||||
}
|
||||
|
||||
let res = policy.evaluate_client_registration(&metadata).await?;
|
||||
let res = policy
|
||||
.evaluate_client_registration(mas_policy::ClientRegistrationInput {
|
||||
client_metadata: &metadata,
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
return Err(RouteError::PolicyDenied(res.violations));
|
||||
}
|
||||
|
||||
@@ -676,7 +676,12 @@ async fn client_credentials_grant(
|
||||
|
||||
// Make the request go through the policy engine
|
||||
let res = policy
|
||||
.evaluate_client_credentials_grant(&scope, client)
|
||||
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
|
||||
user: None,
|
||||
client,
|
||||
scope: &scope,
|
||||
grant_type: mas_policy::GrantType::ClientCredentials,
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
return Err(RouteError::DeniedByPolicy(res.violations));
|
||||
|
||||
@@ -441,7 +441,11 @@ pub(crate) async fn get(
|
||||
}
|
||||
|
||||
let res = policy
|
||||
.evaluate_upstream_oauth_register(&localpart, None)
|
||||
.evaluate_register(mas_policy::RegisterInput {
|
||||
registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
|
||||
username: &localpart,
|
||||
email: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
if res.valid() {
|
||||
@@ -752,8 +756,13 @@ pub(crate) async fn post(
|
||||
|
||||
// Policy check
|
||||
let res = policy
|
||||
.evaluate_upstream_oauth_register(&username, email.as_deref())
|
||||
.evaluate_register(mas_policy::RegisterInput {
|
||||
registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
|
||||
username: &username,
|
||||
email: email.as_deref(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
if !res.valid() {
|
||||
let form_state =
|
||||
res.violations
|
||||
|
||||
@@ -233,7 +233,11 @@ pub(crate) async fn post(
|
||||
}
|
||||
|
||||
let res = policy
|
||||
.evaluate_register(&form.username, &form.email)
|
||||
.evaluate_register(mas_policy::RegisterInput {
|
||||
registration_method: mas_policy::RegistrationMethod::Password,
|
||||
username: &form.username,
|
||||
email: Some(&form.email),
|
||||
})
|
||||
.await?;
|
||||
|
||||
for violation in res.violations {
|
||||
|
||||
@@ -6,8 +6,6 @@
|
||||
|
||||
pub mod model;
|
||||
|
||||
use mas_data_model::{AuthorizationGrant, Client, DeviceCodeGrant, User};
|
||||
use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope};
|
||||
use opa_wasm::{
|
||||
wasmtime::{Config, Engine, Module, OptLevel, Store},
|
||||
Runtime,
|
||||
@@ -16,9 +14,10 @@ use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
|
||||
use self::model::{AuthorizationGrantInput, ClientRegistrationInput, EmailInput, RegisterInput};
|
||||
pub use self::model::{Code as ViolationCode, EvaluationResult, Violation};
|
||||
use crate::model::GrantType;
|
||||
pub use self::model::{
|
||||
AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, EmailInput,
|
||||
EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
|
||||
};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum LoadError {
|
||||
@@ -190,16 +189,14 @@ impl Policy {
|
||||
name = "policy.evaluate_email",
|
||||
skip_all,
|
||||
fields(
|
||||
input.email = email,
|
||||
%input.email,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn evaluate_email(
|
||||
&mut self,
|
||||
email: &str,
|
||||
input: EmailInput<'_>,
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let input = EmailInput { email };
|
||||
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(&mut self.store, &self.entrypoints.email, &input)
|
||||
@@ -212,44 +209,16 @@ impl Policy {
|
||||
name = "policy.evaluate.register",
|
||||
skip_all,
|
||||
fields(
|
||||
input.registration_method = "password",
|
||||
input.user.username = username,
|
||||
input.user.email = email,
|
||||
?input.registration_method,
|
||||
input.username = input.username,
|
||||
input.email = input.email,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn evaluate_register(
|
||||
&mut self,
|
||||
username: &str,
|
||||
email: &str,
|
||||
input: RegisterInput<'_>,
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let input = RegisterInput::Password { username, email };
|
||||
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(&mut self.store, &self.entrypoints.register, &input)
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "policy.evaluate.upstream_oauth_register",
|
||||
skip_all,
|
||||
fields(
|
||||
input.registration_method = "password",
|
||||
input.user.username = username,
|
||||
input.user.email = email,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn evaluate_upstream_oauth_register(
|
||||
&mut self,
|
||||
username: &str,
|
||||
email: Option<&str>,
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let input = RegisterInput::UpstreamOAuth2 { username, email };
|
||||
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(&mut self.store, &self.entrypoints.register, &input)
|
||||
@@ -261,10 +230,8 @@ impl Policy {
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub async fn evaluate_client_registration(
|
||||
&mut self,
|
||||
client_metadata: &VerifiedClientMetadata,
|
||||
input: ClientRegistrationInput<'_>,
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let input = ClientRegistrationInput { client_metadata };
|
||||
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(
|
||||
@@ -281,95 +248,15 @@ impl Policy {
|
||||
name = "policy.evaluate.authorization_grant",
|
||||
skip_all,
|
||||
fields(
|
||||
input.authorization_grant.id = %authorization_grant.id,
|
||||
input.scope = %authorization_grant.scope,
|
||||
input.client.id = %client.id,
|
||||
input.user.id = %user.id,
|
||||
%input.scope,
|
||||
%input.client.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn evaluate_authorization_grant(
|
||||
&mut self,
|
||||
authorization_grant: &AuthorizationGrant,
|
||||
client: &Client,
|
||||
user: &User,
|
||||
input: AuthorizationGrantInput<'_>,
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let input = AuthorizationGrantInput {
|
||||
user: Some(user),
|
||||
client,
|
||||
scope: &authorization_grant.scope,
|
||||
grant_type: GrantType::AuthorizationCode,
|
||||
};
|
||||
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(
|
||||
&mut self.store,
|
||||
&self.entrypoints.authorization_grant,
|
||||
&input,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "policy.evaluate.client_credentials_grant",
|
||||
skip_all,
|
||||
fields(
|
||||
input.scope = %scope,
|
||||
input.client.id = %client.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn evaluate_client_credentials_grant(
|
||||
&mut self,
|
||||
scope: &Scope,
|
||||
client: &Client,
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let input = AuthorizationGrantInput {
|
||||
user: None,
|
||||
client,
|
||||
scope,
|
||||
grant_type: GrantType::ClientCredentials,
|
||||
};
|
||||
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(
|
||||
&mut self.store,
|
||||
&self.entrypoints.authorization_grant,
|
||||
&input,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "policy.evaluate.device_code_grant",
|
||||
skip_all,
|
||||
fields(
|
||||
input.device_code_grant.id = %device_code_grant.id,
|
||||
input.scope = %device_code_grant.scope,
|
||||
input.client.id = %client.id,
|
||||
input.user.id = %user.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn evaluate_device_code_grant(
|
||||
&mut self,
|
||||
device_code_grant: &DeviceCodeGrant,
|
||||
client: &Client,
|
||||
user: &User,
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let input = AuthorizationGrantInput {
|
||||
user: Some(user),
|
||||
client,
|
||||
scope: &device_code_grant.scope,
|
||||
grant_type: GrantType::DeviceCode,
|
||||
};
|
||||
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(
|
||||
@@ -385,6 +272,7 @@ impl Policy {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
@@ -415,19 +303,31 @@ mod tests {
|
||||
let mut policy = factory.instantiate().await.unwrap();
|
||||
|
||||
let res = policy
|
||||
.evaluate_register("hello", "hello@example.com")
|
||||
.evaluate_register(RegisterInput {
|
||||
registration_method: RegistrationMethod::Password,
|
||||
username: "hello",
|
||||
email: Some("hello@example.com"),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!res.valid());
|
||||
|
||||
let res = policy
|
||||
.evaluate_register("hello", "hello@foo.element.io")
|
||||
.evaluate_register(RegisterInput {
|
||||
registration_method: RegistrationMethod::Password,
|
||||
username: "hello",
|
||||
email: Some("hello@foo.element.io"),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(res.valid());
|
||||
|
||||
let res = policy
|
||||
.evaluate_register("hello", "hello@staging.element.io")
|
||||
.evaluate_register(RegisterInput {
|
||||
registration_method: RegistrationMethod::Password,
|
||||
username: "hello",
|
||||
email: Some("hello@staging.element.io"),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!res.valid());
|
||||
|
||||
@@ -92,21 +92,27 @@ impl EvaluationResult {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
|
||||
pub enum RegistrationMethod {
|
||||
#[serde(rename = "password")]
|
||||
Password,
|
||||
|
||||
#[serde(rename = "upstream-oauth2")]
|
||||
UpstreamOAuth2,
|
||||
}
|
||||
|
||||
/// Input for the user registration policy.
|
||||
#[derive(Serialize, Debug)]
|
||||
#[serde(tag = "registration_method")]
|
||||
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
|
||||
pub enum RegisterInput<'a> {
|
||||
#[serde(rename = "password")]
|
||||
Password { username: &'a str, email: &'a str },
|
||||
pub struct RegisterInput<'a> {
|
||||
pub registration_method: RegistrationMethod,
|
||||
|
||||
#[serde(rename = "upstream-oauth2")]
|
||||
UpstreamOAuth2 {
|
||||
username: &'a str,
|
||||
pub username: &'a str,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
email: Option<&'a str>,
|
||||
},
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub email: Option<&'a str>,
|
||||
}
|
||||
|
||||
/// Input for the client registration policy.
|
||||
|
||||
Reference in New Issue
Block a user