diff --git a/crates/handlers/src/graphql/mutations/user_email.rs b/crates/handlers/src/graphql/mutations/user_email.rs index 38048d39a..4a0c51853 100644 --- a/crates/handlers/src/graphql/mutations/user_email.rs +++ b/crates/handlers/src/graphql/mutations/user_email.rs @@ -424,7 +424,11 @@ impl UserEmailMutations { if !skip_policy_check { let mut policy = state.policy().await?; - let res = policy.evaluate_email(&input.email).await?; + let res = policy + .evaluate_email(mas_policy::EmailInput { + email: &input.email, + }) + .await?; if !res.valid() { return Ok(AddEmailPayload::Denied { violations: res.violations, @@ -610,7 +614,11 @@ impl UserEmailMutations { // Check if the email address is allowed by the policy let mut policy = state.policy().await?; - let res = policy.evaluate_email(&input.email).await?; + let res = policy + .evaluate_email(mas_policy::EmailInput { + email: &input.email, + }) + .await?; if !res.valid() { return Ok(StartEmailAuthenticationPayload::Denied { violations: res.violations, diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index a9efb2ae2..ba2786240 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -226,7 +226,12 @@ pub(crate) async fn complete( // Run through the policy let res = policy - .evaluate_authorization_grant(&grant, client, &browser_session.user) + .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { + user: Some(&browser_session.user), + client, + scope: &grant.scope, + grant_type: mas_policy::GrantType::AuthorizationCode, + }) .await?; if !res.valid() { diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index c5a479e41..b2b99b372 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -111,7 +111,12 @@ pub(crate) async fn get( let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let res = policy - .evaluate_authorization_grant(&grant, &client, &session.user) + .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { + user: Some(&session.user), + client: &client, + scope: &grant.scope, + grant_type: mas_policy::GrantType::AuthorizationCode, + }) .await?; if res.valid() { @@ -185,7 +190,12 @@ pub(crate) async fn post( .ok_or(RouteError::NoSuchClient)?; let res = policy - .evaluate_authorization_grant(&grant, &client, &session.user) + .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { + user: Some(&session.user), + client: &client, + scope: &grant.scope, + grant_type: mas_policy::GrantType::AuthorizationCode, + }) .await?; if !res.valid() { diff --git a/crates/handlers/src/oauth2/device/consent.rs b/crates/handlers/src/oauth2/device/consent.rs index 6025f3d1e..9f0045ad9 100644 --- a/crates/handlers/src/oauth2/device/consent.rs +++ b/crates/handlers/src/oauth2/device/consent.rs @@ -82,7 +82,12 @@ pub(crate) async fn get( // Evaluate the policy let res = policy - .evaluate_device_code_grant(&grant, &client, &session.user) + .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { + grant_type: mas_policy::GrantType::DeviceCode, + client: &client, + scope: &grant.scope, + user: Some(&session.user), + }) .await?; if !res.valid() { warn!(violation = ?res, "Device code grant for client {} denied by policy", client.id); @@ -157,7 +162,12 @@ pub(crate) async fn post( // Evaluate the policy let res = policy - .evaluate_device_code_grant(&grant, &client, &session.user) + .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { + grant_type: mas_policy::GrantType::DeviceCode, + client: &client, + scope: &grant.scope, + user: Some(&session.user), + }) .await?; if !res.valid() { warn!(violation = ?res, "Device code grant for client {} denied by policy", client.id); diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index bd608dd69..1dcba9f65 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -244,7 +244,11 @@ pub(crate) async fn post( } } - let res = policy.evaluate_client_registration(&metadata).await?; + let res = policy + .evaluate_client_registration(mas_policy::ClientRegistrationInput { + client_metadata: &metadata, + }) + .await?; if !res.valid() { return Err(RouteError::PolicyDenied(res.violations)); } diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index ad6a15618..70705293e 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -676,7 +676,12 @@ async fn client_credentials_grant( // Make the request go through the policy engine let res = policy - .evaluate_client_credentials_grant(&scope, client) + .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { + user: None, + client, + scope: &scope, + grant_type: mas_policy::GrantType::ClientCredentials, + }) .await?; if !res.valid() { return Err(RouteError::DeniedByPolicy(res.violations)); diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 6b2922cce..ec22c2c52 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -441,7 +441,11 @@ pub(crate) async fn get( } let res = policy - .evaluate_upstream_oauth_register(&localpart, None) + .evaluate_register(mas_policy::RegisterInput { + registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2, + username: &localpart, + email: None, + }) .await?; if res.valid() { @@ -752,8 +756,13 @@ pub(crate) async fn post( // Policy check let res = policy - .evaluate_upstream_oauth_register(&username, email.as_deref()) + .evaluate_register(mas_policy::RegisterInput { + registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2, + username: &username, + email: email.as_deref(), + }) .await?; + if !res.valid() { let form_state = res.violations diff --git a/crates/handlers/src/views/register/password.rs b/crates/handlers/src/views/register/password.rs index c2177c484..0cc59cbe3 100644 --- a/crates/handlers/src/views/register/password.rs +++ b/crates/handlers/src/views/register/password.rs @@ -233,7 +233,11 @@ pub(crate) async fn post( } let res = policy - .evaluate_register(&form.username, &form.email) + .evaluate_register(mas_policy::RegisterInput { + registration_method: mas_policy::RegistrationMethod::Password, + username: &form.username, + email: Some(&form.email), + }) .await?; for violation in res.violations { diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index 9ffe2f511..bd338714f 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -6,8 +6,6 @@ pub mod model; -use mas_data_model::{AuthorizationGrant, Client, DeviceCodeGrant, User}; -use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope}; use opa_wasm::{ wasmtime::{Config, Engine, Module, OptLevel, Store}, Runtime, @@ -16,9 +14,10 @@ use serde::Serialize; use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt}; -use self::model::{AuthorizationGrantInput, ClientRegistrationInput, EmailInput, RegisterInput}; -pub use self::model::{Code as ViolationCode, EvaluationResult, Violation}; -use crate::model::GrantType; +pub use self::model::{ + AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, EmailInput, + EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation, +}; #[derive(Debug, Error)] pub enum LoadError { @@ -190,16 +189,14 @@ impl Policy { name = "policy.evaluate_email", skip_all, fields( - input.email = email, + %input.email, ), err, )] pub async fn evaluate_email( &mut self, - email: &str, + input: EmailInput<'_>, ) -> Result { - let input = EmailInput { email }; - let [res]: [EvaluationResult; 1] = self .instance .evaluate(&mut self.store, &self.entrypoints.email, &input) @@ -212,44 +209,16 @@ impl Policy { name = "policy.evaluate.register", skip_all, fields( - input.registration_method = "password", - input.user.username = username, - input.user.email = email, + ?input.registration_method, + input.username = input.username, + input.email = input.email, ), err, )] pub async fn evaluate_register( &mut self, - username: &str, - email: &str, + input: RegisterInput<'_>, ) -> Result { - let input = RegisterInput::Password { username, email }; - - let [res]: [EvaluationResult; 1] = self - .instance - .evaluate(&mut self.store, &self.entrypoints.register, &input) - .await?; - - Ok(res) - } - - #[tracing::instrument( - name = "policy.evaluate.upstream_oauth_register", - skip_all, - fields( - input.registration_method = "password", - input.user.username = username, - input.user.email = email, - ), - err, - )] - pub async fn evaluate_upstream_oauth_register( - &mut self, - username: &str, - email: Option<&str>, - ) -> Result { - let input = RegisterInput::UpstreamOAuth2 { username, email }; - let [res]: [EvaluationResult; 1] = self .instance .evaluate(&mut self.store, &self.entrypoints.register, &input) @@ -261,10 +230,8 @@ impl Policy { #[tracing::instrument(skip(self))] pub async fn evaluate_client_registration( &mut self, - client_metadata: &VerifiedClientMetadata, + input: ClientRegistrationInput<'_>, ) -> Result { - let input = ClientRegistrationInput { client_metadata }; - let [res]: [EvaluationResult; 1] = self .instance .evaluate( @@ -281,95 +248,15 @@ impl Policy { name = "policy.evaluate.authorization_grant", skip_all, fields( - input.authorization_grant.id = %authorization_grant.id, - input.scope = %authorization_grant.scope, - input.client.id = %client.id, - input.user.id = %user.id, + %input.scope, + %input.client.id, ), err, )] pub async fn evaluate_authorization_grant( &mut self, - authorization_grant: &AuthorizationGrant, - client: &Client, - user: &User, + input: AuthorizationGrantInput<'_>, ) -> Result { - let input = AuthorizationGrantInput { - user: Some(user), - client, - scope: &authorization_grant.scope, - grant_type: GrantType::AuthorizationCode, - }; - - let [res]: [EvaluationResult; 1] = self - .instance - .evaluate( - &mut self.store, - &self.entrypoints.authorization_grant, - &input, - ) - .await?; - - Ok(res) - } - - #[tracing::instrument( - name = "policy.evaluate.client_credentials_grant", - skip_all, - fields( - input.scope = %scope, - input.client.id = %client.id, - ), - err, - )] - pub async fn evaluate_client_credentials_grant( - &mut self, - scope: &Scope, - client: &Client, - ) -> Result { - let input = AuthorizationGrantInput { - user: None, - client, - scope, - grant_type: GrantType::ClientCredentials, - }; - - let [res]: [EvaluationResult; 1] = self - .instance - .evaluate( - &mut self.store, - &self.entrypoints.authorization_grant, - &input, - ) - .await?; - - Ok(res) - } - - #[tracing::instrument( - name = "policy.evaluate.device_code_grant", - skip_all, - fields( - input.device_code_grant.id = %device_code_grant.id, - input.scope = %device_code_grant.scope, - input.client.id = %client.id, - input.user.id = %user.id, - ), - err, - )] - pub async fn evaluate_device_code_grant( - &mut self, - device_code_grant: &DeviceCodeGrant, - client: &Client, - user: &User, - ) -> Result { - let input = AuthorizationGrantInput { - user: Some(user), - client, - scope: &device_code_grant.scope, - grant_type: GrantType::DeviceCode, - }; - let [res]: [EvaluationResult; 1] = self .instance .evaluate( @@ -385,6 +272,7 @@ impl Policy { #[cfg(test)] mod tests { + use super::*; #[tokio::test] @@ -415,19 +303,31 @@ mod tests { let mut policy = factory.instantiate().await.unwrap(); let res = policy - .evaluate_register("hello", "hello@example.com") + .evaluate_register(RegisterInput { + registration_method: RegistrationMethod::Password, + username: "hello", + email: Some("hello@example.com"), + }) .await .unwrap(); assert!(!res.valid()); let res = policy - .evaluate_register("hello", "hello@foo.element.io") + .evaluate_register(RegisterInput { + registration_method: RegistrationMethod::Password, + username: "hello", + email: Some("hello@foo.element.io"), + }) .await .unwrap(); assert!(res.valid()); let res = policy - .evaluate_register("hello", "hello@staging.element.io") + .evaluate_register(RegisterInput { + registration_method: RegistrationMethod::Password, + username: "hello", + email: Some("hello@staging.element.io"), + }) .await .unwrap(); assert!(!res.valid()); diff --git a/crates/policy/src/model.rs b/crates/policy/src/model.rs index aebca8928..1868b699e 100644 --- a/crates/policy/src/model.rs +++ b/crates/policy/src/model.rs @@ -92,21 +92,27 @@ impl EvaluationResult { } } +#[derive(Serialize, Debug)] +#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] +pub enum RegistrationMethod { + #[serde(rename = "password")] + Password, + + #[serde(rename = "upstream-oauth2")] + UpstreamOAuth2, +} + /// Input for the user registration policy. #[derive(Serialize, Debug)] #[serde(tag = "registration_method")] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] -pub enum RegisterInput<'a> { - #[serde(rename = "password")] - Password { username: &'a str, email: &'a str }, +pub struct RegisterInput<'a> { + pub registration_method: RegistrationMethod, - #[serde(rename = "upstream-oauth2")] - UpstreamOAuth2 { - username: &'a str, + pub username: &'a str, - #[serde(skip_serializing_if = "Option::is_none")] - email: Option<&'a str>, - }, + #[serde(skip_serializing_if = "Option::is_none")] + pub email: Option<&'a str>, } /// Input for the client registration policy.