From e8627166a93d969d1832312acb4248c400d8632e Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 4 Jul 2025 12:49:07 +0200 Subject: [PATCH] Log out oauth & compat sessions when receiving a backchannel logout request --- crates/cli/src/sync.rs | 3 + crates/config/src/sections/upstream_oauth2.rs | 4 ++ .../src/upstream_oauth2/provider.rs | 3 + .../src/upstream_oauth2/backchannel_logout.rs | 68 ++++++++++++++++++- crates/storage-pg/src/compat/session.rs | 14 +++- crates/storage-pg/src/oauth2/session.rs | 14 +++- crates/storage/src/compat/session.rs | 19 +++++- crates/storage/src/oauth2/session.rs | 21 +++++- docs/config.schema.json | 7 ++ 9 files changed, 146 insertions(+), 7 deletions(-) diff --git a/crates/cli/src/sync.rs b/crates/cli/src/sync.rs index d8433c291..363c2a0f8 100644 --- a/crates/cli/src/sync.rs +++ b/crates/cli/src/sync.rs @@ -283,6 +283,9 @@ pub async fn config_sync( mas_config::UpstreamOAuth2OnBackchannelLogout::LogoutBrowserOnly => { mas_data_model::UpstreamOAuthProviderOnBackchannelLogout::LogoutBrowserOnly } + mas_config::UpstreamOAuth2OnBackchannelLogout::LogoutAll => { + mas_data_model::UpstreamOAuthProviderOnBackchannelLogout::LogoutAll + } }; repo.upstream_oauth_provider() diff --git a/crates/config/src/sections/upstream_oauth2.rs b/crates/config/src/sections/upstream_oauth2.rs index 2cf43b530..2162c9fe4 100644 --- a/crates/config/src/sections/upstream_oauth2.rs +++ b/crates/config/src/sections/upstream_oauth2.rs @@ -418,6 +418,10 @@ pub enum OnBackchannelLogout { /// Only log out the MAS 'browser session' started by this OIDC session LogoutBrowserOnly, + + /// Log out all sessions started by this OIDC session, including MAS + /// 'browser sessions' and client sessions + LogoutAll, } impl OnBackchannelLogout { diff --git a/crates/data-model/src/upstream_oauth2/provider.rs b/crates/data-model/src/upstream_oauth2/provider.rs index c384366df..3a71c03c3 100644 --- a/crates/data-model/src/upstream_oauth2/provider.rs +++ b/crates/data-model/src/upstream_oauth2/provider.rs @@ -221,6 +221,7 @@ pub struct InvalidUpstreamOAuth2TokenAuthMethod(String); pub enum OnBackchannelLogout { DoNothing, LogoutBrowserOnly, + LogoutAll, } impl OnBackchannelLogout { @@ -229,6 +230,7 @@ impl OnBackchannelLogout { match self { Self::DoNothing => "do_nothing", Self::LogoutBrowserOnly => "logout_browser_only", + Self::LogoutAll => "logout_all", } } } @@ -246,6 +248,7 @@ impl std::str::FromStr for OnBackchannelLogout { match s { "do_nothing" => Ok(Self::DoNothing), "logout_browser_only" => Ok(Self::LogoutBrowserOnly), + "logout_all" => Ok(Self::LogoutAll), s => Err(InvalidUpstreamOAuth2OnBackchannelLogout(s.to_owned())), } } diff --git a/crates/handlers/src/upstream_oauth2/backchannel_logout.rs b/crates/handlers/src/upstream_oauth2/backchannel_logout.rs index 71d8b674f..9e2a034b9 100644 --- a/crates/handlers/src/upstream_oauth2/backchannel_logout.rs +++ b/crates/handlers/src/upstream_oauth2/backchannel_logout.rs @@ -3,7 +3,7 @@ // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // Please see LICENSE files in the repository root for full details. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use axum::{ Form, Json, @@ -22,7 +22,11 @@ use mas_oidc_client::{ requests::jose::{JwtVerificationData, verify_signed_jwt}, }; use mas_storage::{ - BoxClock, BoxRepository, upstream_oauth2::UpstreamOAuthSessionFilter, + BoxClock, BoxRepository, BoxRng, Pagination, + compat::CompatSessionFilter, + oauth2::OAuth2SessionFilter, + queue::{QueueJobRepositoryExt as _, SyncDevicesJob}, + upstream_oauth2::UpstreamOAuthSessionFilter, user::BrowserSessionFilter, }; use oauth2_types::errors::{ClientError, ClientErrorCode}; @@ -131,6 +135,7 @@ const EVENTS: Claim = Claim::new("events"); )] pub(crate) async fn post( clock: BoxClock, + mut rng: BoxRng, mut repo: BoxRepository, State(metadata_cache): State, State(client): State, @@ -242,10 +247,67 @@ pub(crate) async fn post( } UpstreamOAuthProviderOnBackchannelLogout::LogoutBrowserOnly => { let filter = BrowserSessionFilter::new() - .authenticated_by_upstream_sessions_only(auth_session_filter); + .authenticated_by_upstream_sessions_only(auth_session_filter) + .active_only(); let affected = repo.browser_session().finish_bulk(&clock, filter).await?; tracing::info!("Finished {affected} browser sessions"); } + UpstreamOAuthProviderOnBackchannelLogout::LogoutAll => { + let browser_session_filter = BrowserSessionFilter::new() + .authenticated_by_upstream_sessions_only(auth_session_filter); + + // We need to loop through all the browser sessions to find all the + // users affected so that we can trigger a device sync job for them + let mut cursor = Pagination::first(1000); + let mut user_ids = HashSet::new(); + loop { + let browser_sessions = repo + .browser_session() + .list(browser_session_filter, cursor) + .await?; + for browser_session in browser_sessions.edges { + user_ids.insert(browser_session.user.id); + cursor = cursor.after(browser_session.id); + } + + if !browser_sessions.has_next_page { + break; + } + } + + let browser_sessions_affected = repo + .browser_session() + .finish_bulk(&clock, browser_session_filter.active_only()) + .await?; + + let oauth2_session_filter = OAuth2SessionFilter::new() + .active_only() + .for_browser_sessions(browser_session_filter); + + let oauth2_sessions_affected = repo + .oauth2_session() + .finish_bulk(&clock, oauth2_session_filter) + .await?; + + let compat_session_filter = CompatSessionFilter::new() + .active_only() + .for_browser_sessions(browser_session_filter); + + let compat_sessions_affected = repo + .compat_session() + .finish_bulk(&clock, compat_session_filter) + .await?; + + tracing::info!( + "Finished {browser_sessions_affected} browser sessions, {oauth2_sessions_affected} OAuth 2.0 sessions and {compat_sessions_affected} compatibility sessions" + ); + + for user_id in user_ids { + tracing::info!(user.id = %user_id, "Queueing a device sync job for user"); + let job = SyncDevicesJob::new_for_id(user_id); + repo.queue_job().schedule_job(&mut rng, &clock, job).await?; + } + } } repo.save().await?; diff --git a/crates/storage-pg/src/compat/session.rs b/crates/storage-pg/src/compat/session.rs index 19e6366d6..d5d41fb7b 100644 --- a/crates/storage-pg/src/compat/session.rs +++ b/crates/storage-pg/src/compat/session.rs @@ -27,7 +27,7 @@ use uuid::Uuid; use crate::{ DatabaseError, DatabaseInconsistencyError, filter::{Filter, StatementExt, StatementWithJoinsExt}, - iden::{CompatSessions, CompatSsoLogins}, + iden::{CompatSessions, CompatSsoLogins, UserSessions}, pagination::QueryBuilderExt, tracing::ExecuteExt, }; @@ -190,6 +190,18 @@ impl Filter for CompatSessionFilter<'_> { Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)) .eq(Uuid::from(browser_session.id)) })) + .add_option(self.browser_session_filter().map(|browser_session_filter| { + Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)).in_subquery( + Query::select() + .expr(Expr::col(( + UserSessions::Table, + UserSessions::UserSessionId, + ))) + .apply_filter(browser_session_filter) + .from(UserSessions::Table) + .take(), + ) + })) .add_option(self.state().map(|state| { if state.is_active() { Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null() diff --git a/crates/storage-pg/src/oauth2/session.rs b/crates/storage-pg/src/oauth2/session.rs index 3aa3877b1..00fc501a0 100644 --- a/crates/storage-pg/src/oauth2/session.rs +++ b/crates/storage-pg/src/oauth2/session.rs @@ -24,7 +24,7 @@ use uuid::Uuid; use crate::{ DatabaseError, DatabaseInconsistencyError, filter::{Filter, StatementExt}, - iden::{OAuth2Clients, OAuth2Sessions}, + iden::{OAuth2Clients, OAuth2Sessions, UserSessions}, pagination::QueryBuilderExt, tracing::ExecuteExt, }; @@ -141,6 +141,18 @@ impl Filter for OAuth2SessionFilter<'_> { Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)) .eq(Uuid::from(browser_session.id)) })) + .add_option(self.browser_session_filter().map(|browser_session_filter| { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)).in_subquery( + Query::select() + .expr(Expr::col(( + UserSessions::Table, + UserSessions::UserSessionId, + ))) + .apply_filter(browser_session_filter) + .from(UserSessions::Table) + .take(), + ) + })) .add_option(self.state().map(|state| { if state.is_active() { Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null() diff --git a/crates/storage/src/compat/session.rs b/crates/storage/src/compat/session.rs index 2b964ba22..5287b4cee 100644 --- a/crates/storage/src/compat/session.rs +++ b/crates/storage/src/compat/session.rs @@ -12,7 +12,7 @@ use mas_data_model::{BrowserSession, CompatSession, CompatSsoLogin, Device, User use rand_core::RngCore; use ulid::Ulid; -use crate::{Clock, Page, Pagination, repository_impl}; +use crate::{Clock, Page, Pagination, repository_impl, user::BrowserSessionFilter}; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum CompatSessionState { @@ -59,6 +59,7 @@ impl CompatSessionType { pub struct CompatSessionFilter<'a> { user: Option<&'a User>, browser_session: Option<&'a BrowserSession>, + browser_session_filter: Option>, state: Option, auth_type: Option, device: Option<&'a Device>, @@ -106,12 +107,28 @@ impl<'a> CompatSessionFilter<'a> { self } + /// Set the browser sessions filter + #[must_use] + pub fn for_browser_sessions( + mut self, + browser_session_filter: BrowserSessionFilter<'a>, + ) -> Self { + self.browser_session_filter = Some(browser_session_filter); + self + } + /// Get the browser session filter #[must_use] pub fn browser_session(&self) -> Option<&'a BrowserSession> { self.browser_session } + /// Get the browser sessions filter + #[must_use] + pub fn browser_session_filter(&self) -> Option> { + self.browser_session_filter + } + /// Only return sessions with a last active time before the given time #[must_use] pub fn with_last_active_before(mut self, last_active_before: DateTime) -> Self { diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index faf933a7f..5d217c1e2 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -13,7 +13,7 @@ use oauth2_types::scope::Scope; use rand_core::RngCore; use ulid::Ulid; -use crate::{Clock, Pagination, pagination::Page, repository_impl}; +use crate::{Clock, Pagination, pagination::Page, repository_impl, user::BrowserSessionFilter}; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum OAuth2SessionState { @@ -49,6 +49,7 @@ pub struct OAuth2SessionFilter<'a> { user: Option<&'a User>, any_user: Option, browser_session: Option<&'a BrowserSession>, + browser_session_filter: Option>, device: Option<&'a Device>, client: Option<&'a Client>, client_kind: Option, @@ -109,6 +110,16 @@ impl<'a> OAuth2SessionFilter<'a> { self } + /// List sessions started by a set of browser sessions + #[must_use] + pub fn for_browser_sessions( + mut self, + browser_session_filter: BrowserSessionFilter<'a>, + ) -> Self { + self.browser_session_filter = Some(browser_session_filter); + self + } + /// Get the browser session filter /// /// Returns [`None`] if no browser session filter was set @@ -117,6 +128,14 @@ impl<'a> OAuth2SessionFilter<'a> { self.browser_session } + /// Get the browser sessions filter + /// + /// Returns [`None`] if no browser session filter was set + #[must_use] + pub fn browser_session_filter(&self) -> Option> { + self.browser_session_filter + } + /// List sessions for a specific client #[must_use] pub fn for_client(mut self, client: &'a Client) -> Self { diff --git a/docs/config.schema.json b/docs/config.schema.json index cf2793c25..abb811fca 100644 --- a/docs/config.schema.json +++ b/docs/config.schema.json @@ -2459,6 +2459,13 @@ "enum": [ "logout_browser_only" ] + }, + { + "description": "Log out all sessions started by this OIDC session, including MAS 'browser sessions' and client sessions", + "type": "string", + "enum": [ + "logout_all" + ] } ] },