diff --git a/crates/handlers/src/oauth2/authorization/consent.rs b/crates/handlers/src/oauth2/authorization/consent.rs index 968aec08a..9e8491141 100644 --- a/crates/handlers/src/oauth2/authorization/consent.rs +++ b/crates/handlers/src/oauth2/authorization/consent.rs @@ -32,7 +32,7 @@ use super::callback::CallbackDestination; use crate::{ BoundActivityTracker, PreferredLanguage, impl_from_error_for_route, oauth2::generate_id_token, - session::{SessionOrFallback, load_session_or_fallback}, + session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback}, }; #[derive(Debug, Error)] @@ -136,10 +136,15 @@ pub(crate) async fn get( let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); + let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user) + .await + .map_err(|e| RouteError::Internal(e.into()))?; + let res = policy .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { user: Some(&session.user), client: &client, + session_counts: Some(session_counts), scope: &grant.scope, grant_type: mas_policy::GrantType::AuthorizationCode, requester: mas_policy::Requester { @@ -235,10 +240,15 @@ pub(crate) async fn post( return Err(RouteError::GrantNotPending(grant.id)); } + let session_counts = count_user_sessions_for_limiting(&mut repo, &browser_session.user) + .await + .map_err(|e| RouteError::Internal(e.into()))?; + let res = policy .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { user: Some(&browser_session.user), client: &client, + session_counts: Some(session_counts), scope: &grant.scope, grant_type: mas_policy::GrantType::AuthorizationCode, requester: mas_policy::Requester { diff --git a/crates/handlers/src/oauth2/device/consent.rs b/crates/handlers/src/oauth2/device/consent.rs index 30a35aa17..22cf1fca0 100644 --- a/crates/handlers/src/oauth2/device/consent.rs +++ b/crates/handlers/src/oauth2/device/consent.rs @@ -27,7 +27,7 @@ use ulid::Ulid; use crate::{ BoundActivityTracker, PreferredLanguage, - session::{SessionOrFallback, load_session_or_fallback}, + session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback}, }; #[derive(Deserialize, Debug)] @@ -103,11 +103,16 @@ pub(crate) async fn get( .context("Client not found") .map_err(InternalError::from_anyhow)?; + let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user) + .await + .map_err(InternalError::from_anyhow)?; + // Evaluate the policy let res = policy .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { grant_type: mas_policy::GrantType::DeviceCode, client: &client, + session_counts: Some(session_counts), scope: &grant.scope, user: Some(&session.user), requester: mas_policy::Requester { @@ -205,11 +210,16 @@ pub(crate) async fn post( .context("Client not found") .map_err(InternalError::from_anyhow)?; + let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user) + .await + .map_err(InternalError::from_anyhow)?; + // Evaluate the policy let res = policy .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { grant_type: mas_policy::GrantType::DeviceCode, client: &client, + session_counts: Some(session_counts), scope: &grant.scope, user: Some(&session.user), requester: mas_policy::Requester { diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 4a63d8290..99506ac29 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -781,6 +781,7 @@ async fn client_credentials_grant( .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { user: None, client, + session_counts: None, scope: &scope, grant_type: mas_policy::GrantType::ClientCredentials, requester: mas_policy::Requester { diff --git a/crates/policy/src/model.rs b/crates/policy/src/model.rs index 8f778f3d1..9977d6653 100644 --- a/crates/policy/src/model.rs +++ b/crates/policy/src/model.rs @@ -168,6 +168,10 @@ pub struct AuthorizationGrantInput<'a> { #[schemars(with = "Option>")] pub user: Option<&'a User>, + /// How many sessions the user has. + /// Not populated if it's not a user logging in. + pub session_counts: Option, + #[schemars(with = "std::collections::HashMap")] pub client: &'a Client, diff --git a/policies/schema/authorization_grant_input.json b/policies/schema/authorization_grant_input.json index f23bf7a73..a5d49e304 100644 --- a/policies/schema/authorization_grant_input.json +++ b/policies/schema/authorization_grant_input.json @@ -14,6 +14,14 @@ "type": "object", "additionalProperties": true }, + "session_counts": { + "description": "How many sessions the user has. Not populated if it's not a user logging in.", + "allOf": [ + { + "$ref": "#/definitions/SessionCounts" + } + ] + }, "client": { "type": "object", "additionalProperties": true @@ -29,6 +37,38 @@ } }, "definitions": { + "SessionCounts": { + "description": "Information about how many sessions the user has", + "type": "object", + "required": [ + "compat", + "oauth2", + "personal", + "total" + ], + "properties": { + "total": { + "type": "integer", + "format": "uint64", + "minimum": 0.0 + }, + "oauth2": { + "type": "integer", + "format": "uint64", + "minimum": 0.0 + }, + "compat": { + "type": "integer", + "format": "uint64", + "minimum": 0.0 + }, + "personal": { + "type": "integer", + "format": "uint64", + "minimum": 0.0 + } + } + }, "GrantType": { "type": "string", "enum": [