Log out oauth & compat sessions when receiving a backchannel logout request

This commit is contained in:
Quentin Gliech
2025-07-04 12:49:07 +02:00
parent 84d9e47e23
commit e8627166a9
9 changed files with 146 additions and 7 deletions

View File

@@ -283,6 +283,9 @@ pub async fn config_sync(
mas_config::UpstreamOAuth2OnBackchannelLogout::LogoutBrowserOnly => {
mas_data_model::UpstreamOAuthProviderOnBackchannelLogout::LogoutBrowserOnly
}
mas_config::UpstreamOAuth2OnBackchannelLogout::LogoutAll => {
mas_data_model::UpstreamOAuthProviderOnBackchannelLogout::LogoutAll
}
};
repo.upstream_oauth_provider()

View File

@@ -418,6 +418,10 @@ pub enum OnBackchannelLogout {
/// Only log out the MAS 'browser session' started by this OIDC session
LogoutBrowserOnly,
/// Log out all sessions started by this OIDC session, including MAS
/// 'browser sessions' and client sessions
LogoutAll,
}
impl OnBackchannelLogout {

View File

@@ -221,6 +221,7 @@ pub struct InvalidUpstreamOAuth2TokenAuthMethod(String);
pub enum OnBackchannelLogout {
DoNothing,
LogoutBrowserOnly,
LogoutAll,
}
impl OnBackchannelLogout {
@@ -229,6 +230,7 @@ impl OnBackchannelLogout {
match self {
Self::DoNothing => "do_nothing",
Self::LogoutBrowserOnly => "logout_browser_only",
Self::LogoutAll => "logout_all",
}
}
}
@@ -246,6 +248,7 @@ impl std::str::FromStr for OnBackchannelLogout {
match s {
"do_nothing" => Ok(Self::DoNothing),
"logout_browser_only" => Ok(Self::LogoutBrowserOnly),
"logout_all" => Ok(Self::LogoutAll),
s => Err(InvalidUpstreamOAuth2OnBackchannelLogout(s.to_owned())),
}
}

View File

@@ -3,7 +3,7 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use axum::{
Form, Json,
@@ -22,7 +22,11 @@ use mas_oidc_client::{
requests::jose::{JwtVerificationData, verify_signed_jwt},
};
use mas_storage::{
BoxClock, BoxRepository, upstream_oauth2::UpstreamOAuthSessionFilter,
BoxClock, BoxRepository, BoxRng, Pagination,
compat::CompatSessionFilter,
oauth2::OAuth2SessionFilter,
queue::{QueueJobRepositoryExt as _, SyncDevicesJob},
upstream_oauth2::UpstreamOAuthSessionFilter,
user::BrowserSessionFilter,
};
use oauth2_types::errors::{ClientError, ClientErrorCode};
@@ -131,6 +135,7 @@ const EVENTS: Claim<LogoutTokenEvents> = Claim::new("events");
)]
pub(crate) async fn post(
clock: BoxClock,
mut rng: BoxRng,
mut repo: BoxRepository,
State(metadata_cache): State<MetadataCache>,
State(client): State<reqwest::Client>,
@@ -242,10 +247,67 @@ pub(crate) async fn post(
}
UpstreamOAuthProviderOnBackchannelLogout::LogoutBrowserOnly => {
let filter = BrowserSessionFilter::new()
.authenticated_by_upstream_sessions_only(auth_session_filter);
.authenticated_by_upstream_sessions_only(auth_session_filter)
.active_only();
let affected = repo.browser_session().finish_bulk(&clock, filter).await?;
tracing::info!("Finished {affected} browser sessions");
}
UpstreamOAuthProviderOnBackchannelLogout::LogoutAll => {
let browser_session_filter = BrowserSessionFilter::new()
.authenticated_by_upstream_sessions_only(auth_session_filter);
// We need to loop through all the browser sessions to find all the
// users affected so that we can trigger a device sync job for them
let mut cursor = Pagination::first(1000);
let mut user_ids = HashSet::new();
loop {
let browser_sessions = repo
.browser_session()
.list(browser_session_filter, cursor)
.await?;
for browser_session in browser_sessions.edges {
user_ids.insert(browser_session.user.id);
cursor = cursor.after(browser_session.id);
}
if !browser_sessions.has_next_page {
break;
}
}
let browser_sessions_affected = repo
.browser_session()
.finish_bulk(&clock, browser_session_filter.active_only())
.await?;
let oauth2_session_filter = OAuth2SessionFilter::new()
.active_only()
.for_browser_sessions(browser_session_filter);
let oauth2_sessions_affected = repo
.oauth2_session()
.finish_bulk(&clock, oauth2_session_filter)
.await?;
let compat_session_filter = CompatSessionFilter::new()
.active_only()
.for_browser_sessions(browser_session_filter);
let compat_sessions_affected = repo
.compat_session()
.finish_bulk(&clock, compat_session_filter)
.await?;
tracing::info!(
"Finished {browser_sessions_affected} browser sessions, {oauth2_sessions_affected} OAuth 2.0 sessions and {compat_sessions_affected} compatibility sessions"
);
for user_id in user_ids {
tracing::info!(user.id = %user_id, "Queueing a device sync job for user");
let job = SyncDevicesJob::new_for_id(user_id);
repo.queue_job().schedule_job(&mut rng, &clock, job).await?;
}
}
}
repo.save().await?;

View File

@@ -27,7 +27,7 @@ use uuid::Uuid;
use crate::{
DatabaseError, DatabaseInconsistencyError,
filter::{Filter, StatementExt, StatementWithJoinsExt},
iden::{CompatSessions, CompatSsoLogins},
iden::{CompatSessions, CompatSsoLogins, UserSessions},
pagination::QueryBuilderExt,
tracing::ExecuteExt,
};
@@ -190,6 +190,18 @@ impl Filter for CompatSessionFilter<'_> {
Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
.eq(Uuid::from(browser_session.id))
}))
.add_option(self.browser_session_filter().map(|browser_session_filter| {
Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)).in_subquery(
Query::select()
.expr(Expr::col((
UserSessions::Table,
UserSessions::UserSessionId,
)))
.apply_filter(browser_session_filter)
.from(UserSessions::Table)
.take(),
)
}))
.add_option(self.state().map(|state| {
if state.is_active() {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()

View File

@@ -24,7 +24,7 @@ use uuid::Uuid;
use crate::{
DatabaseError, DatabaseInconsistencyError,
filter::{Filter, StatementExt},
iden::{OAuth2Clients, OAuth2Sessions},
iden::{OAuth2Clients, OAuth2Sessions, UserSessions},
pagination::QueryBuilderExt,
tracing::ExecuteExt,
};
@@ -141,6 +141,18 @@ impl Filter for OAuth2SessionFilter<'_> {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
.eq(Uuid::from(browser_session.id))
}))
.add_option(self.browser_session_filter().map(|browser_session_filter| {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)).in_subquery(
Query::select()
.expr(Expr::col((
UserSessions::Table,
UserSessions::UserSessionId,
)))
.apply_filter(browser_session_filter)
.from(UserSessions::Table)
.take(),
)
}))
.add_option(self.state().map(|state| {
if state.is_active() {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()

View File

@@ -12,7 +12,7 @@ use mas_data_model::{BrowserSession, CompatSession, CompatSsoLogin, Device, User
use rand_core::RngCore;
use ulid::Ulid;
use crate::{Clock, Page, Pagination, repository_impl};
use crate::{Clock, Page, Pagination, repository_impl, user::BrowserSessionFilter};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CompatSessionState {
@@ -59,6 +59,7 @@ impl CompatSessionType {
pub struct CompatSessionFilter<'a> {
user: Option<&'a User>,
browser_session: Option<&'a BrowserSession>,
browser_session_filter: Option<BrowserSessionFilter<'a>>,
state: Option<CompatSessionState>,
auth_type: Option<CompatSessionType>,
device: Option<&'a Device>,
@@ -106,12 +107,28 @@ impl<'a> CompatSessionFilter<'a> {
self
}
/// Set the browser sessions filter
#[must_use]
pub fn for_browser_sessions(
mut self,
browser_session_filter: BrowserSessionFilter<'a>,
) -> Self {
self.browser_session_filter = Some(browser_session_filter);
self
}
/// Get the browser session filter
#[must_use]
pub fn browser_session(&self) -> Option<&'a BrowserSession> {
self.browser_session
}
/// Get the browser sessions filter
#[must_use]
pub fn browser_session_filter(&self) -> Option<BrowserSessionFilter<'a>> {
self.browser_session_filter
}
/// Only return sessions with a last active time before the given time
#[must_use]
pub fn with_last_active_before(mut self, last_active_before: DateTime<Utc>) -> Self {

View File

@@ -13,7 +13,7 @@ use oauth2_types::scope::Scope;
use rand_core::RngCore;
use ulid::Ulid;
use crate::{Clock, Pagination, pagination::Page, repository_impl};
use crate::{Clock, Pagination, pagination::Page, repository_impl, user::BrowserSessionFilter};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OAuth2SessionState {
@@ -49,6 +49,7 @@ pub struct OAuth2SessionFilter<'a> {
user: Option<&'a User>,
any_user: Option<bool>,
browser_session: Option<&'a BrowserSession>,
browser_session_filter: Option<BrowserSessionFilter<'a>>,
device: Option<&'a Device>,
client: Option<&'a Client>,
client_kind: Option<ClientKind>,
@@ -109,6 +110,16 @@ impl<'a> OAuth2SessionFilter<'a> {
self
}
/// List sessions started by a set of browser sessions
#[must_use]
pub fn for_browser_sessions(
mut self,
browser_session_filter: BrowserSessionFilter<'a>,
) -> Self {
self.browser_session_filter = Some(browser_session_filter);
self
}
/// Get the browser session filter
///
/// Returns [`None`] if no browser session filter was set
@@ -117,6 +128,14 @@ impl<'a> OAuth2SessionFilter<'a> {
self.browser_session
}
/// Get the browser sessions filter
///
/// Returns [`None`] if no browser session filter was set
#[must_use]
pub fn browser_session_filter(&self) -> Option<BrowserSessionFilter<'a>> {
self.browser_session_filter
}
/// List sessions for a specific client
#[must_use]
pub fn for_client(mut self, client: &'a Client) -> Self {

View File

@@ -2459,6 +2459,13 @@
"enum": [
"logout_browser_only"
]
},
{
"description": "Log out all sessions started by this OIDC session, including MAS 'browser sessions' and client sessions",
"type": "string",
"enum": [
"logout_all"
]
}
]
},