From aaf4bf588ffa74c3a104c6365677b74d50bad225 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 27 Jun 2025 15:52:35 +0200 Subject: [PATCH] Allow filtering upstream sessions by sub and sid claims --- ...m_oauth2_id_token_claims_sub_sid_index.sql | 15 ++++ ...m_oauth2_id_token_claims_sid_sub_index.sql | 15 ++++ crates/storage-pg/src/upstream_oauth2/mod.rs | 68 ++++++++++++++++++- .../storage-pg/src/upstream_oauth2/session.rs | 33 ++++++--- crates/storage/src/upstream_oauth2/session.rs | 32 +++++++++ 5 files changed, 154 insertions(+), 9 deletions(-) create mode 100644 crates/storage-pg/migrations/20250602212103_upstream_oauth2_id_token_claims_sub_sid_index.sql create mode 100644 crates/storage-pg/migrations/20250602212104_upstream_oauth2_id_token_claims_sid_sub_index.sql diff --git a/crates/storage-pg/migrations/20250602212103_upstream_oauth2_id_token_claims_sub_sid_index.sql b/crates/storage-pg/migrations/20250602212103_upstream_oauth2_id_token_claims_sub_sid_index.sql new file mode 100644 index 000000000..327022168 --- /dev/null +++ b/crates/storage-pg/migrations/20250602212103_upstream_oauth2_id_token_claims_sub_sid_index.sql @@ -0,0 +1,15 @@ +-- no-transaction +-- Copyright 2025 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +-- Please see LICENSE in the repository root for full details. + +-- We'll be requesting authorization sessions by provider, sub and sid, so we'll +-- need to index those columns +CREATE INDEX CONCURRENTLY IF NOT EXISTS + upstream_oauth_authorization_sessions_sub_sid_idx + ON upstream_oauth_authorization_sessions ( + upstream_oauth_provider_id, + (id_token_claims->>'sub'), + (id_token_claims->>'sid') + ); diff --git a/crates/storage-pg/migrations/20250602212104_upstream_oauth2_id_token_claims_sid_sub_index.sql b/crates/storage-pg/migrations/20250602212104_upstream_oauth2_id_token_claims_sid_sub_index.sql new file mode 100644 index 000000000..097c3da32 --- /dev/null +++ b/crates/storage-pg/migrations/20250602212104_upstream_oauth2_id_token_claims_sid_sub_index.sql @@ -0,0 +1,15 @@ +-- no-transaction +-- Copyright 2025 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +-- Please see LICENSE in the repository root for full details. + +-- We'll be requesting authorization sessions by provider, sub and sid, so we'll +-- need to index those columns +CREATE INDEX CONCURRENTLY IF NOT EXISTS + upstream_oauth_authorization_sessions_sid_sub_idx + ON upstream_oauth_authorization_sessions ( + upstream_oauth_provider_id, + (id_token_claims->>'sid'), + (id_token_claims->>'sub') + ); diff --git a/crates/storage-pg/src/upstream_oauth2/mod.rs b/crates/storage-pg/src/upstream_oauth2/mod.rs index 2158d872c..84a52defd 100644 --- a/crates/storage-pg/src/upstream_oauth2/mod.rs +++ b/crates/storage-pg/src/upstream_oauth2/mod.rs @@ -499,15 +499,45 @@ mod tests { 0 ); + let mut links = Vec::with_capacity(3); + for subject in ["alice", "bob", "charlie"] { + let link = repo + .upstream_oauth_link() + .add(&mut rng, &clock, &provider, subject.to_owned(), None) + .await + .unwrap(); + links.push(link); + } + let mut ids = Vec::with_capacity(20); + let sids = ["one", "two"].into_iter().cycle(); // Create 20 sessions - for idx in 0..20 { + for (idx, (link, sid)) in links.iter().cycle().zip(sids).enumerate().take(20) { let state = format!("state-{idx}"); let session = repo .upstream_oauth_session() .add(&mut rng, &clock, &provider, state, None, None) .await .unwrap(); + let id_token_claims = serde_json::json!({ + "sub": link.subject, + "sid": sid, + "aud": provider.client_id, + "iss": "https://example.com/", + }); + let session = repo + .upstream_oauth_session() + .complete_with_link( + &clock, + session, + link, + None, + Some(id_token_claims), + None, + None, + ) + .await + .unwrap(); ids.push(session.id); clock.advance(Duration::microseconds(10 * 1000 * 1000)); } @@ -577,5 +607,41 @@ mod tests { assert!(!page.has_next_page); let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect(); assert_eq!(&edge_ids, &ids[6..11]); + + // Check the sub/sid filters + assert_eq!( + repo.upstream_oauth_session() + .count(filter.with_sub_claim("alice").with_sid_claim("one")) + .await + .unwrap(), + 4 + ); + assert_eq!( + repo.upstream_oauth_session() + .count(filter.with_sub_claim("bob").with_sid_claim("two")) + .await + .unwrap(), + 4 + ); + + let page = repo + .upstream_oauth_session() + .list( + filter.with_sub_claim("alice").with_sid_claim("one"), + Pagination::first(10), + ) + .await + .unwrap(); + assert_eq!(page.edges.len(), 4); + for edge in page.edges { + assert_eq!( + edge.id_token_claims().unwrap().get("sub").unwrap().as_str(), + Some("alice") + ); + assert_eq!( + edge.id_token_claims().unwrap().get("sid").unwrap().as_str(), + Some("one") + ); + } } } diff --git a/crates/storage-pg/src/upstream_oauth2/session.rs b/crates/storage-pg/src/upstream_oauth2/session.rs index a595f600d..8cc04eeb6 100644 --- a/crates/storage-pg/src/upstream_oauth2/session.rs +++ b/crates/storage-pg/src/upstream_oauth2/session.rs @@ -15,7 +15,7 @@ use mas_storage::{ upstream_oauth2::{UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository}, }; use rand::RngCore; -use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def}; +use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr}; use sea_query_binder::SqlxBinder; use sqlx::PgConnection; use ulid::Ulid; @@ -31,13 +31,30 @@ use crate::{ impl Filter for UpstreamOAuthSessionFilter<'_> { fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition { - sea_query::Condition::all().add_option(self.provider().map(|provider| { - Expr::col(( - UpstreamOAuthAuthorizationSessions::Table, - UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId, - )) - .eq(Uuid::from(provider.id)) - })) + sea_query::Condition::all() + .add_option(self.provider().map(|provider| { + Expr::col(( + UpstreamOAuthAuthorizationSessions::Table, + UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId, + )) + .eq(Uuid::from(provider.id)) + })) + .add_option(self.sub_claim().map(|sub| { + Expr::col(( + UpstreamOAuthAuthorizationSessions::Table, + UpstreamOAuthAuthorizationSessions::IdTokenClaims, + )) + .cast_json_field("sub") + .eq(sub) + })) + .add_option(self.sid_claim().map(|sid| { + Expr::col(( + UpstreamOAuthAuthorizationSessions::Table, + UpstreamOAuthAuthorizationSessions::IdTokenClaims, + )) + .cast_json_field("sid") + .eq(sid) + })) } } diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index 66fcb1ba8..d6505285b 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -15,6 +15,8 @@ use crate::{Clock, Pagination, pagination::Page, repository_impl}; #[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] pub struct UpstreamOAuthSessionFilter<'a> { provider: Option<&'a UpstreamOAuthProvider>, + sub_claim: Option<&'a str>, + sid_claim: Option<&'a str>, } impl<'a> UpstreamOAuthSessionFilter<'a> { @@ -38,6 +40,36 @@ impl<'a> UpstreamOAuthSessionFilter<'a> { pub fn provider(&self) -> Option<&UpstreamOAuthProvider> { self.provider } + + /// Set the `sub` claim to filter by + #[must_use] + pub fn with_sub_claim(mut self, sub_claim: &'a str) -> Self { + self.sub_claim = Some(sub_claim); + self + } + + /// Get the `sub` claim filter + /// + /// Returns [`None`] if no filter was set + #[must_use] + pub fn sub_claim(&self) -> Option<&str> { + self.sub_claim + } + + /// Set the `sid` claim to filter by + #[must_use] + pub fn with_sid_claim(mut self, sid_claim: &'a str) -> Self { + self.sid_claim = Some(sid_claim); + self + } + + /// Get the `sid` claim filter + /// + /// Returns [`None`] if no filter was set + #[must_use] + pub fn sid_claim(&self) -> Option<&str> { + self.sid_claim + } } /// An [`UpstreamOAuthSessionRepository`] helps interacting with