Add session counts to policy input
This commit is contained in:
@@ -32,7 +32,7 @@ use super::callback::CallbackDestination;
|
||||
use crate::{
|
||||
BoundActivityTracker, PreferredLanguage, impl_from_error_for_route,
|
||||
oauth2::generate_id_token,
|
||||
session::{SessionOrFallback, load_session_or_fallback},
|
||||
session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
|
||||
};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
@@ -136,10 +136,15 @@ pub(crate) async fn get(
|
||||
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
|
||||
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user)
|
||||
.await
|
||||
.map_err(|e| RouteError::Internal(e.into()))?;
|
||||
|
||||
let res = policy
|
||||
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
|
||||
user: Some(&session.user),
|
||||
client: &client,
|
||||
session_counts: Some(session_counts),
|
||||
scope: &grant.scope,
|
||||
grant_type: mas_policy::GrantType::AuthorizationCode,
|
||||
requester: mas_policy::Requester {
|
||||
@@ -235,10 +240,15 @@ pub(crate) async fn post(
|
||||
return Err(RouteError::GrantNotPending(grant.id));
|
||||
}
|
||||
|
||||
let session_counts = count_user_sessions_for_limiting(&mut repo, &browser_session.user)
|
||||
.await
|
||||
.map_err(|e| RouteError::Internal(e.into()))?;
|
||||
|
||||
let res = policy
|
||||
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
|
||||
user: Some(&browser_session.user),
|
||||
client: &client,
|
||||
session_counts: Some(session_counts),
|
||||
scope: &grant.scope,
|
||||
grant_type: mas_policy::GrantType::AuthorizationCode,
|
||||
requester: mas_policy::Requester {
|
||||
|
||||
@@ -27,7 +27,7 @@ use ulid::Ulid;
|
||||
|
||||
use crate::{
|
||||
BoundActivityTracker, PreferredLanguage,
|
||||
session::{SessionOrFallback, load_session_or_fallback},
|
||||
session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
|
||||
};
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
@@ -103,11 +103,16 @@ pub(crate) async fn get(
|
||||
.context("Client not found")
|
||||
.map_err(InternalError::from_anyhow)?;
|
||||
|
||||
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user)
|
||||
.await
|
||||
.map_err(InternalError::from_anyhow)?;
|
||||
|
||||
// Evaluate the policy
|
||||
let res = policy
|
||||
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
|
||||
grant_type: mas_policy::GrantType::DeviceCode,
|
||||
client: &client,
|
||||
session_counts: Some(session_counts),
|
||||
scope: &grant.scope,
|
||||
user: Some(&session.user),
|
||||
requester: mas_policy::Requester {
|
||||
@@ -205,11 +210,16 @@ pub(crate) async fn post(
|
||||
.context("Client not found")
|
||||
.map_err(InternalError::from_anyhow)?;
|
||||
|
||||
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user)
|
||||
.await
|
||||
.map_err(InternalError::from_anyhow)?;
|
||||
|
||||
// Evaluate the policy
|
||||
let res = policy
|
||||
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
|
||||
grant_type: mas_policy::GrantType::DeviceCode,
|
||||
client: &client,
|
||||
session_counts: Some(session_counts),
|
||||
scope: &grant.scope,
|
||||
user: Some(&session.user),
|
||||
requester: mas_policy::Requester {
|
||||
|
||||
@@ -781,6 +781,7 @@ async fn client_credentials_grant(
|
||||
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
|
||||
user: None,
|
||||
client,
|
||||
session_counts: None,
|
||||
scope: &scope,
|
||||
grant_type: mas_policy::GrantType::ClientCredentials,
|
||||
requester: mas_policy::Requester {
|
||||
|
||||
@@ -168,6 +168,10 @@ pub struct AuthorizationGrantInput<'a> {
|
||||
#[schemars(with = "Option<std::collections::HashMap<String, serde_json::Value>>")]
|
||||
pub user: Option<&'a User>,
|
||||
|
||||
/// How many sessions the user has.
|
||||
/// Not populated if it's not a user logging in.
|
||||
pub session_counts: Option<SessionCounts>,
|
||||
|
||||
#[schemars(with = "std::collections::HashMap<String, serde_json::Value>")]
|
||||
pub client: &'a Client,
|
||||
|
||||
|
||||
@@ -14,6 +14,14 @@
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
},
|
||||
"session_counts": {
|
||||
"description": "How many sessions the user has. Not populated if it's not a user logging in.",
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/SessionCounts"
|
||||
}
|
||||
]
|
||||
},
|
||||
"client": {
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
@@ -29,6 +37,38 @@
|
||||
}
|
||||
},
|
||||
"definitions": {
|
||||
"SessionCounts": {
|
||||
"description": "Information about how many sessions the user has",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"compat",
|
||||
"oauth2",
|
||||
"personal",
|
||||
"total"
|
||||
],
|
||||
"properties": {
|
||||
"total": {
|
||||
"type": "integer",
|
||||
"format": "uint64",
|
||||
"minimum": 0.0
|
||||
},
|
||||
"oauth2": {
|
||||
"type": "integer",
|
||||
"format": "uint64",
|
||||
"minimum": 0.0
|
||||
},
|
||||
"compat": {
|
||||
"type": "integer",
|
||||
"format": "uint64",
|
||||
"minimum": 0.0
|
||||
},
|
||||
"personal": {
|
||||
"type": "integer",
|
||||
"format": "uint64",
|
||||
"minimum": 0.0
|
||||
}
|
||||
}
|
||||
},
|
||||
"GrantType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
|
||||
Reference in New Issue
Block a user