Allow filtering upstream sessions by sub and sid claims
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
-- no-transaction
|
||||
-- Copyright 2025 New Vector Ltd.
|
||||
--
|
||||
-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
-- Please see LICENSE in the repository root for full details.
|
||||
|
||||
-- We'll be requesting authorization sessions by provider, sub and sid, so we'll
|
||||
-- need to index those columns
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS
|
||||
upstream_oauth_authorization_sessions_sub_sid_idx
|
||||
ON upstream_oauth_authorization_sessions (
|
||||
upstream_oauth_provider_id,
|
||||
(id_token_claims->>'sub'),
|
||||
(id_token_claims->>'sid')
|
||||
);
|
||||
@@ -0,0 +1,15 @@
|
||||
-- no-transaction
|
||||
-- Copyright 2025 New Vector Ltd.
|
||||
--
|
||||
-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
-- Please see LICENSE in the repository root for full details.
|
||||
|
||||
-- We'll be requesting authorization sessions by provider, sub and sid, so we'll
|
||||
-- need to index those columns
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS
|
||||
upstream_oauth_authorization_sessions_sid_sub_idx
|
||||
ON upstream_oauth_authorization_sessions (
|
||||
upstream_oauth_provider_id,
|
||||
(id_token_claims->>'sid'),
|
||||
(id_token_claims->>'sub')
|
||||
);
|
||||
@@ -499,15 +499,45 @@ mod tests {
|
||||
0
|
||||
);
|
||||
|
||||
let mut links = Vec::with_capacity(3);
|
||||
for subject in ["alice", "bob", "charlie"] {
|
||||
let link = repo
|
||||
.upstream_oauth_link()
|
||||
.add(&mut rng, &clock, &provider, subject.to_owned(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
links.push(link);
|
||||
}
|
||||
|
||||
let mut ids = Vec::with_capacity(20);
|
||||
let sids = ["one", "two"].into_iter().cycle();
|
||||
// Create 20 sessions
|
||||
for idx in 0..20 {
|
||||
for (idx, (link, sid)) in links.iter().cycle().zip(sids).enumerate().take(20) {
|
||||
let state = format!("state-{idx}");
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.add(&mut rng, &clock, &provider, state, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let id_token_claims = serde_json::json!({
|
||||
"sub": link.subject,
|
||||
"sid": sid,
|
||||
"aud": provider.client_id,
|
||||
"iss": "https://example.com/",
|
||||
});
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.complete_with_link(
|
||||
&clock,
|
||||
session,
|
||||
link,
|
||||
None,
|
||||
Some(id_token_claims),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
ids.push(session.id);
|
||||
clock.advance(Duration::microseconds(10 * 1000 * 1000));
|
||||
}
|
||||
@@ -577,5 +607,41 @@ mod tests {
|
||||
assert!(!page.has_next_page);
|
||||
let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
|
||||
assert_eq!(&edge_ids, &ids[6..11]);
|
||||
|
||||
// Check the sub/sid filters
|
||||
assert_eq!(
|
||||
repo.upstream_oauth_session()
|
||||
.count(filter.with_sub_claim("alice").with_sid_claim("one"))
|
||||
.await
|
||||
.unwrap(),
|
||||
4
|
||||
);
|
||||
assert_eq!(
|
||||
repo.upstream_oauth_session()
|
||||
.count(filter.with_sub_claim("bob").with_sid_claim("two"))
|
||||
.await
|
||||
.unwrap(),
|
||||
4
|
||||
);
|
||||
|
||||
let page = repo
|
||||
.upstream_oauth_session()
|
||||
.list(
|
||||
filter.with_sub_claim("alice").with_sid_claim("one"),
|
||||
Pagination::first(10),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(page.edges.len(), 4);
|
||||
for edge in page.edges {
|
||||
assert_eq!(
|
||||
edge.id_token_claims().unwrap().get("sub").unwrap().as_str(),
|
||||
Some("alice")
|
||||
);
|
||||
assert_eq!(
|
||||
edge.id_token_claims().unwrap().get("sid").unwrap().as_str(),
|
||||
Some("one")
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ use mas_storage::{
|
||||
upstream_oauth2::{UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository},
|
||||
};
|
||||
use rand::RngCore;
|
||||
use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
|
||||
use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr};
|
||||
use sea_query_binder::SqlxBinder;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
@@ -31,13 +31,30 @@ use crate::{
|
||||
|
||||
impl Filter for UpstreamOAuthSessionFilter<'_> {
|
||||
fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
|
||||
sea_query::Condition::all().add_option(self.provider().map(|provider| {
|
||||
Expr::col((
|
||||
UpstreamOAuthAuthorizationSessions::Table,
|
||||
UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
|
||||
))
|
||||
.eq(Uuid::from(provider.id))
|
||||
}))
|
||||
sea_query::Condition::all()
|
||||
.add_option(self.provider().map(|provider| {
|
||||
Expr::col((
|
||||
UpstreamOAuthAuthorizationSessions::Table,
|
||||
UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
|
||||
))
|
||||
.eq(Uuid::from(provider.id))
|
||||
}))
|
||||
.add_option(self.sub_claim().map(|sub| {
|
||||
Expr::col((
|
||||
UpstreamOAuthAuthorizationSessions::Table,
|
||||
UpstreamOAuthAuthorizationSessions::IdTokenClaims,
|
||||
))
|
||||
.cast_json_field("sub")
|
||||
.eq(sub)
|
||||
}))
|
||||
.add_option(self.sid_claim().map(|sid| {
|
||||
Expr::col((
|
||||
UpstreamOAuthAuthorizationSessions::Table,
|
||||
UpstreamOAuthAuthorizationSessions::IdTokenClaims,
|
||||
))
|
||||
.cast_json_field("sid")
|
||||
.eq(sid)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,8 @@ use crate::{Clock, Pagination, pagination::Page, repository_impl};
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
|
||||
pub struct UpstreamOAuthSessionFilter<'a> {
|
||||
provider: Option<&'a UpstreamOAuthProvider>,
|
||||
sub_claim: Option<&'a str>,
|
||||
sid_claim: Option<&'a str>,
|
||||
}
|
||||
|
||||
impl<'a> UpstreamOAuthSessionFilter<'a> {
|
||||
@@ -38,6 +40,36 @@ impl<'a> UpstreamOAuthSessionFilter<'a> {
|
||||
pub fn provider(&self) -> Option<&UpstreamOAuthProvider> {
|
||||
self.provider
|
||||
}
|
||||
|
||||
/// Set the `sub` claim to filter by
|
||||
#[must_use]
|
||||
pub fn with_sub_claim(mut self, sub_claim: &'a str) -> Self {
|
||||
self.sub_claim = Some(sub_claim);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the `sub` claim filter
|
||||
///
|
||||
/// Returns [`None`] if no filter was set
|
||||
#[must_use]
|
||||
pub fn sub_claim(&self) -> Option<&str> {
|
||||
self.sub_claim
|
||||
}
|
||||
|
||||
/// Set the `sid` claim to filter by
|
||||
#[must_use]
|
||||
pub fn with_sid_claim(mut self, sid_claim: &'a str) -> Self {
|
||||
self.sid_claim = Some(sid_claim);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the `sid` claim filter
|
||||
///
|
||||
/// Returns [`None`] if no filter was set
|
||||
#[must_use]
|
||||
pub fn sid_claim(&self) -> Option<&str> {
|
||||
self.sid_claim
|
||||
}
|
||||
}
|
||||
|
||||
/// An [`UpstreamOAuthSessionRepository`] helps interacting with
|
||||
|
||||
Reference in New Issue
Block a user