Allow filtering upstream sessions by sub and sid claims

This commit is contained in:
Quentin Gliech
2025-06-27 15:52:35 +02:00
parent a3acec4973
commit aaf4bf588f
5 changed files with 154 additions and 9 deletions

View File

@@ -0,0 +1,15 @@
-- no-transaction
-- Copyright 2025 New Vector Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
-- Please see LICENSE in the repository root for full details.
-- We'll be requesting authorization sessions by provider, sub and sid, so we'll
-- need to index those columns
CREATE INDEX CONCURRENTLY IF NOT EXISTS
upstream_oauth_authorization_sessions_sub_sid_idx
ON upstream_oauth_authorization_sessions (
upstream_oauth_provider_id,
(id_token_claims->>'sub'),
(id_token_claims->>'sid')
);

View File

@@ -0,0 +1,15 @@
-- no-transaction
-- Copyright 2025 New Vector Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
-- Please see LICENSE in the repository root for full details.
-- We'll be requesting authorization sessions by provider, sub and sid, so we'll
-- need to index those columns
CREATE INDEX CONCURRENTLY IF NOT EXISTS
upstream_oauth_authorization_sessions_sid_sub_idx
ON upstream_oauth_authorization_sessions (
upstream_oauth_provider_id,
(id_token_claims->>'sid'),
(id_token_claims->>'sub')
);

View File

@@ -499,15 +499,45 @@ mod tests {
0
);
let mut links = Vec::with_capacity(3);
for subject in ["alice", "bob", "charlie"] {
let link = repo
.upstream_oauth_link()
.add(&mut rng, &clock, &provider, subject.to_owned(), None)
.await
.unwrap();
links.push(link);
}
let mut ids = Vec::with_capacity(20);
let sids = ["one", "two"].into_iter().cycle();
// Create 20 sessions
for idx in 0..20 {
for (idx, (link, sid)) in links.iter().cycle().zip(sids).enumerate().take(20) {
let state = format!("state-{idx}");
let session = repo
.upstream_oauth_session()
.add(&mut rng, &clock, &provider, state, None, None)
.await
.unwrap();
let id_token_claims = serde_json::json!({
"sub": link.subject,
"sid": sid,
"aud": provider.client_id,
"iss": "https://example.com/",
});
let session = repo
.upstream_oauth_session()
.complete_with_link(
&clock,
session,
link,
None,
Some(id_token_claims),
None,
None,
)
.await
.unwrap();
ids.push(session.id);
clock.advance(Duration::microseconds(10 * 1000 * 1000));
}
@@ -577,5 +607,41 @@ mod tests {
assert!(!page.has_next_page);
let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
assert_eq!(&edge_ids, &ids[6..11]);
// Check the sub/sid filters
assert_eq!(
repo.upstream_oauth_session()
.count(filter.with_sub_claim("alice").with_sid_claim("one"))
.await
.unwrap(),
4
);
assert_eq!(
repo.upstream_oauth_session()
.count(filter.with_sub_claim("bob").with_sid_claim("two"))
.await
.unwrap(),
4
);
let page = repo
.upstream_oauth_session()
.list(
filter.with_sub_claim("alice").with_sid_claim("one"),
Pagination::first(10),
)
.await
.unwrap();
assert_eq!(page.edges.len(), 4);
for edge in page.edges {
assert_eq!(
edge.id_token_claims().unwrap().get("sub").unwrap().as_str(),
Some("alice")
);
assert_eq!(
edge.id_token_claims().unwrap().get("sid").unwrap().as_str(),
Some("one")
);
}
}
}

View File

@@ -15,7 +15,7 @@ use mas_storage::{
upstream_oauth2::{UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository},
};
use rand::RngCore;
use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr};
use sea_query_binder::SqlxBinder;
use sqlx::PgConnection;
use ulid::Ulid;
@@ -31,13 +31,30 @@ use crate::{
impl Filter for UpstreamOAuthSessionFilter<'_> {
fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
sea_query::Condition::all().add_option(self.provider().map(|provider| {
Expr::col((
UpstreamOAuthAuthorizationSessions::Table,
UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
))
.eq(Uuid::from(provider.id))
}))
sea_query::Condition::all()
.add_option(self.provider().map(|provider| {
Expr::col((
UpstreamOAuthAuthorizationSessions::Table,
UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
))
.eq(Uuid::from(provider.id))
}))
.add_option(self.sub_claim().map(|sub| {
Expr::col((
UpstreamOAuthAuthorizationSessions::Table,
UpstreamOAuthAuthorizationSessions::IdTokenClaims,
))
.cast_json_field("sub")
.eq(sub)
}))
.add_option(self.sid_claim().map(|sid| {
Expr::col((
UpstreamOAuthAuthorizationSessions::Table,
UpstreamOAuthAuthorizationSessions::IdTokenClaims,
))
.cast_json_field("sid")
.eq(sid)
}))
}
}

View File

@@ -15,6 +15,8 @@ use crate::{Clock, Pagination, pagination::Page, repository_impl};
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct UpstreamOAuthSessionFilter<'a> {
provider: Option<&'a UpstreamOAuthProvider>,
sub_claim: Option<&'a str>,
sid_claim: Option<&'a str>,
}
impl<'a> UpstreamOAuthSessionFilter<'a> {
@@ -38,6 +40,36 @@ impl<'a> UpstreamOAuthSessionFilter<'a> {
pub fn provider(&self) -> Option<&UpstreamOAuthProvider> {
self.provider
}
/// Set the `sub` claim to filter by
#[must_use]
pub fn with_sub_claim(mut self, sub_claim: &'a str) -> Self {
self.sub_claim = Some(sub_claim);
self
}
/// Get the `sub` claim filter
///
/// Returns [`None`] if no filter was set
#[must_use]
pub fn sub_claim(&self) -> Option<&str> {
self.sub_claim
}
/// Set the `sid` claim to filter by
#[must_use]
pub fn with_sid_claim(mut self, sid_claim: &'a str) -> Self {
self.sid_claim = Some(sid_claim);
self
}
/// Get the `sid` claim filter
///
/// Returns [`None`] if no filter was set
#[must_use]
pub fn sid_claim(&self) -> Option<&str> {
self.sid_claim
}
}
/// An [`UpstreamOAuthSessionRepository`] helps interacting with