Log out oauth & compat sessions when receiving a backchannel logout request
This commit is contained in:
@@ -283,6 +283,9 @@ pub async fn config_sync(
|
||||
mas_config::UpstreamOAuth2OnBackchannelLogout::LogoutBrowserOnly => {
|
||||
mas_data_model::UpstreamOAuthProviderOnBackchannelLogout::LogoutBrowserOnly
|
||||
}
|
||||
mas_config::UpstreamOAuth2OnBackchannelLogout::LogoutAll => {
|
||||
mas_data_model::UpstreamOAuthProviderOnBackchannelLogout::LogoutAll
|
||||
}
|
||||
};
|
||||
|
||||
repo.upstream_oauth_provider()
|
||||
|
||||
@@ -418,6 +418,10 @@ pub enum OnBackchannelLogout {
|
||||
|
||||
/// Only log out the MAS 'browser session' started by this OIDC session
|
||||
LogoutBrowserOnly,
|
||||
|
||||
/// Log out all sessions started by this OIDC session, including MAS
|
||||
/// 'browser sessions' and client sessions
|
||||
LogoutAll,
|
||||
}
|
||||
|
||||
impl OnBackchannelLogout {
|
||||
|
||||
@@ -221,6 +221,7 @@ pub struct InvalidUpstreamOAuth2TokenAuthMethod(String);
|
||||
pub enum OnBackchannelLogout {
|
||||
DoNothing,
|
||||
LogoutBrowserOnly,
|
||||
LogoutAll,
|
||||
}
|
||||
|
||||
impl OnBackchannelLogout {
|
||||
@@ -229,6 +230,7 @@ impl OnBackchannelLogout {
|
||||
match self {
|
||||
Self::DoNothing => "do_nothing",
|
||||
Self::LogoutBrowserOnly => "logout_browser_only",
|
||||
Self::LogoutAll => "logout_all",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -246,6 +248,7 @@ impl std::str::FromStr for OnBackchannelLogout {
|
||||
match s {
|
||||
"do_nothing" => Ok(Self::DoNothing),
|
||||
"logout_browser_only" => Ok(Self::LogoutBrowserOnly),
|
||||
"logout_all" => Ok(Self::LogoutAll),
|
||||
s => Err(InvalidUpstreamOAuth2OnBackchannelLogout(s.to_owned())),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use axum::{
|
||||
Form, Json,
|
||||
@@ -22,7 +22,11 @@ use mas_oidc_client::{
|
||||
requests::jose::{JwtVerificationData, verify_signed_jwt},
|
||||
};
|
||||
use mas_storage::{
|
||||
BoxClock, BoxRepository, upstream_oauth2::UpstreamOAuthSessionFilter,
|
||||
BoxClock, BoxRepository, BoxRng, Pagination,
|
||||
compat::CompatSessionFilter,
|
||||
oauth2::OAuth2SessionFilter,
|
||||
queue::{QueueJobRepositoryExt as _, SyncDevicesJob},
|
||||
upstream_oauth2::UpstreamOAuthSessionFilter,
|
||||
user::BrowserSessionFilter,
|
||||
};
|
||||
use oauth2_types::errors::{ClientError, ClientErrorCode};
|
||||
@@ -131,6 +135,7 @@ const EVENTS: Claim<LogoutTokenEvents> = Claim::new("events");
|
||||
)]
|
||||
pub(crate) async fn post(
|
||||
clock: BoxClock,
|
||||
mut rng: BoxRng,
|
||||
mut repo: BoxRepository,
|
||||
State(metadata_cache): State<MetadataCache>,
|
||||
State(client): State<reqwest::Client>,
|
||||
@@ -242,10 +247,67 @@ pub(crate) async fn post(
|
||||
}
|
||||
UpstreamOAuthProviderOnBackchannelLogout::LogoutBrowserOnly => {
|
||||
let filter = BrowserSessionFilter::new()
|
||||
.authenticated_by_upstream_sessions_only(auth_session_filter);
|
||||
.authenticated_by_upstream_sessions_only(auth_session_filter)
|
||||
.active_only();
|
||||
let affected = repo.browser_session().finish_bulk(&clock, filter).await?;
|
||||
tracing::info!("Finished {affected} browser sessions");
|
||||
}
|
||||
UpstreamOAuthProviderOnBackchannelLogout::LogoutAll => {
|
||||
let browser_session_filter = BrowserSessionFilter::new()
|
||||
.authenticated_by_upstream_sessions_only(auth_session_filter);
|
||||
|
||||
// We need to loop through all the browser sessions to find all the
|
||||
// users affected so that we can trigger a device sync job for them
|
||||
let mut cursor = Pagination::first(1000);
|
||||
let mut user_ids = HashSet::new();
|
||||
loop {
|
||||
let browser_sessions = repo
|
||||
.browser_session()
|
||||
.list(browser_session_filter, cursor)
|
||||
.await?;
|
||||
for browser_session in browser_sessions.edges {
|
||||
user_ids.insert(browser_session.user.id);
|
||||
cursor = cursor.after(browser_session.id);
|
||||
}
|
||||
|
||||
if !browser_sessions.has_next_page {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let browser_sessions_affected = repo
|
||||
.browser_session()
|
||||
.finish_bulk(&clock, browser_session_filter.active_only())
|
||||
.await?;
|
||||
|
||||
let oauth2_session_filter = OAuth2SessionFilter::new()
|
||||
.active_only()
|
||||
.for_browser_sessions(browser_session_filter);
|
||||
|
||||
let oauth2_sessions_affected = repo
|
||||
.oauth2_session()
|
||||
.finish_bulk(&clock, oauth2_session_filter)
|
||||
.await?;
|
||||
|
||||
let compat_session_filter = CompatSessionFilter::new()
|
||||
.active_only()
|
||||
.for_browser_sessions(browser_session_filter);
|
||||
|
||||
let compat_sessions_affected = repo
|
||||
.compat_session()
|
||||
.finish_bulk(&clock, compat_session_filter)
|
||||
.await?;
|
||||
|
||||
tracing::info!(
|
||||
"Finished {browser_sessions_affected} browser sessions, {oauth2_sessions_affected} OAuth 2.0 sessions and {compat_sessions_affected} compatibility sessions"
|
||||
);
|
||||
|
||||
for user_id in user_ids {
|
||||
tracing::info!(user.id = %user_id, "Queueing a device sync job for user");
|
||||
let job = SyncDevicesJob::new_for_id(user_id);
|
||||
repo.queue_job().schedule_job(&mut rng, &clock, job).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
@@ -27,7 +27,7 @@ use uuid::Uuid;
|
||||
use crate::{
|
||||
DatabaseError, DatabaseInconsistencyError,
|
||||
filter::{Filter, StatementExt, StatementWithJoinsExt},
|
||||
iden::{CompatSessions, CompatSsoLogins},
|
||||
iden::{CompatSessions, CompatSsoLogins, UserSessions},
|
||||
pagination::QueryBuilderExt,
|
||||
tracing::ExecuteExt,
|
||||
};
|
||||
@@ -190,6 +190,18 @@ impl Filter for CompatSessionFilter<'_> {
|
||||
Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
|
||||
.eq(Uuid::from(browser_session.id))
|
||||
}))
|
||||
.add_option(self.browser_session_filter().map(|browser_session_filter| {
|
||||
Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)).in_subquery(
|
||||
Query::select()
|
||||
.expr(Expr::col((
|
||||
UserSessions::Table,
|
||||
UserSessions::UserSessionId,
|
||||
)))
|
||||
.apply_filter(browser_session_filter)
|
||||
.from(UserSessions::Table)
|
||||
.take(),
|
||||
)
|
||||
}))
|
||||
.add_option(self.state().map(|state| {
|
||||
if state.is_active() {
|
||||
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
|
||||
|
||||
@@ -24,7 +24,7 @@ use uuid::Uuid;
|
||||
use crate::{
|
||||
DatabaseError, DatabaseInconsistencyError,
|
||||
filter::{Filter, StatementExt},
|
||||
iden::{OAuth2Clients, OAuth2Sessions},
|
||||
iden::{OAuth2Clients, OAuth2Sessions, UserSessions},
|
||||
pagination::QueryBuilderExt,
|
||||
tracing::ExecuteExt,
|
||||
};
|
||||
@@ -141,6 +141,18 @@ impl Filter for OAuth2SessionFilter<'_> {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
|
||||
.eq(Uuid::from(browser_session.id))
|
||||
}))
|
||||
.add_option(self.browser_session_filter().map(|browser_session_filter| {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)).in_subquery(
|
||||
Query::select()
|
||||
.expr(Expr::col((
|
||||
UserSessions::Table,
|
||||
UserSessions::UserSessionId,
|
||||
)))
|
||||
.apply_filter(browser_session_filter)
|
||||
.from(UserSessions::Table)
|
||||
.take(),
|
||||
)
|
||||
}))
|
||||
.add_option(self.state().map(|state| {
|
||||
if state.is_active() {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
|
||||
|
||||
@@ -12,7 +12,7 @@ use mas_data_model::{BrowserSession, CompatSession, CompatSsoLogin, Device, User
|
||||
use rand_core::RngCore;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{Clock, Page, Pagination, repository_impl};
|
||||
use crate::{Clock, Page, Pagination, repository_impl, user::BrowserSessionFilter};
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum CompatSessionState {
|
||||
@@ -59,6 +59,7 @@ impl CompatSessionType {
|
||||
pub struct CompatSessionFilter<'a> {
|
||||
user: Option<&'a User>,
|
||||
browser_session: Option<&'a BrowserSession>,
|
||||
browser_session_filter: Option<BrowserSessionFilter<'a>>,
|
||||
state: Option<CompatSessionState>,
|
||||
auth_type: Option<CompatSessionType>,
|
||||
device: Option<&'a Device>,
|
||||
@@ -106,12 +107,28 @@ impl<'a> CompatSessionFilter<'a> {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the browser sessions filter
|
||||
#[must_use]
|
||||
pub fn for_browser_sessions(
|
||||
mut self,
|
||||
browser_session_filter: BrowserSessionFilter<'a>,
|
||||
) -> Self {
|
||||
self.browser_session_filter = Some(browser_session_filter);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the browser session filter
|
||||
#[must_use]
|
||||
pub fn browser_session(&self) -> Option<&'a BrowserSession> {
|
||||
self.browser_session
|
||||
}
|
||||
|
||||
/// Get the browser sessions filter
|
||||
#[must_use]
|
||||
pub fn browser_session_filter(&self) -> Option<BrowserSessionFilter<'a>> {
|
||||
self.browser_session_filter
|
||||
}
|
||||
|
||||
/// Only return sessions with a last active time before the given time
|
||||
#[must_use]
|
||||
pub fn with_last_active_before(mut self, last_active_before: DateTime<Utc>) -> Self {
|
||||
|
||||
@@ -13,7 +13,7 @@ use oauth2_types::scope::Scope;
|
||||
use rand_core::RngCore;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{Clock, Pagination, pagination::Page, repository_impl};
|
||||
use crate::{Clock, Pagination, pagination::Page, repository_impl, user::BrowserSessionFilter};
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum OAuth2SessionState {
|
||||
@@ -49,6 +49,7 @@ pub struct OAuth2SessionFilter<'a> {
|
||||
user: Option<&'a User>,
|
||||
any_user: Option<bool>,
|
||||
browser_session: Option<&'a BrowserSession>,
|
||||
browser_session_filter: Option<BrowserSessionFilter<'a>>,
|
||||
device: Option<&'a Device>,
|
||||
client: Option<&'a Client>,
|
||||
client_kind: Option<ClientKind>,
|
||||
@@ -109,6 +110,16 @@ impl<'a> OAuth2SessionFilter<'a> {
|
||||
self
|
||||
}
|
||||
|
||||
/// List sessions started by a set of browser sessions
|
||||
#[must_use]
|
||||
pub fn for_browser_sessions(
|
||||
mut self,
|
||||
browser_session_filter: BrowserSessionFilter<'a>,
|
||||
) -> Self {
|
||||
self.browser_session_filter = Some(browser_session_filter);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the browser session filter
|
||||
///
|
||||
/// Returns [`None`] if no browser session filter was set
|
||||
@@ -117,6 +128,14 @@ impl<'a> OAuth2SessionFilter<'a> {
|
||||
self.browser_session
|
||||
}
|
||||
|
||||
/// Get the browser sessions filter
|
||||
///
|
||||
/// Returns [`None`] if no browser session filter was set
|
||||
#[must_use]
|
||||
pub fn browser_session_filter(&self) -> Option<BrowserSessionFilter<'a>> {
|
||||
self.browser_session_filter
|
||||
}
|
||||
|
||||
/// List sessions for a specific client
|
||||
#[must_use]
|
||||
pub fn for_client(mut self, client: &'a Client) -> Self {
|
||||
|
||||
@@ -2459,6 +2459,13 @@
|
||||
"enum": [
|
||||
"logout_browser_only"
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "Log out all sessions started by this OIDC session, including MAS 'browser sessions' and client sessions",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"logout_all"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user