From f98957617e0356a0d9158107076109728fffa110 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 12 Jan 2026 11:03:31 +0100 Subject: [PATCH] Cleanup consumed refresh tokens --- ...0ba438b6b18306dfe1454fa4124c0207b3deb.json | 30 +++++++++ ...0_oauth_refresh_token_not_consumed_idx.sql | 11 ++++ ...37_oauth_refresh_token_consumed_at_idx.sql | 11 ++++ crates/storage-pg/src/oauth2/refresh_token.rs | 61 +++++++++++++++++++ crates/storage/src/oauth2/refresh_token.rs | 32 ++++++++++ crates/storage/src/queue/tasks.rs | 8 +++ crates/tasks/src/database.rs | 51 +++++++++++++++- crates/tasks/src/lib.rs | 7 +++ 8 files changed, 209 insertions(+), 2 deletions(-) create mode 100644 crates/storage-pg/.sqlx/query-093d42238578771b4183b48c1680ba438b6b18306dfe1454fa4124c0207b3deb.json create mode 100644 crates/storage-pg/migrations/20260112094550_oauth_refresh_token_not_consumed_idx.sql create mode 100644 crates/storage-pg/migrations/20260112094837_oauth_refresh_token_consumed_at_idx.sql diff --git a/crates/storage-pg/.sqlx/query-093d42238578771b4183b48c1680ba438b6b18306dfe1454fa4124c0207b3deb.json b/crates/storage-pg/.sqlx/query-093d42238578771b4183b48c1680ba438b6b18306dfe1454fa4124c0207b3deb.json new file mode 100644 index 000000000..3af6ac9c9 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-093d42238578771b4183b48c1680ba438b6b18306dfe1454fa4124c0207b3deb.json @@ -0,0 +1,30 @@ +{ + "db_name": "PostgreSQL", + "query": "\n WITH\n to_delete AS (\n SELECT rts_to_del.oauth2_refresh_token_id\n FROM oauth2_refresh_tokens rts_to_del\n LEFT JOIN oauth2_refresh_tokens next_rts\n ON rts_to_del.next_oauth2_refresh_token_id = next_rts.oauth2_refresh_token_id\n WHERE rts_to_del.consumed_at IS NOT NULL\n AND (rts_to_del.next_oauth2_refresh_token_id IS NULL OR next_rts.consumed_at IS NOT NULL)\n AND ($1::timestamptz IS NULL OR rts_to_del.consumed_at >= $1::timestamptz)\n AND rts_to_del.consumed_at < $2::timestamptz\n ORDER BY rts_to_del.consumed_at ASC\n LIMIT $3\n ),\n\n deleted AS (\n DELETE FROM oauth2_refresh_tokens\n USING to_delete\n WHERE oauth2_refresh_tokens.oauth2_refresh_token_id = to_delete.oauth2_refresh_token_id\n RETURNING oauth2_refresh_tokens.consumed_at\n )\n\n SELECT\n COUNT(*) as \"count!\",\n MAX(consumed_at) as last_consumed_at\n FROM deleted\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "count!", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "last_consumed_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Timestamptz", + "Timestamptz", + "Int8" + ] + }, + "nullable": [ + null, + null + ] + }, + "hash": "093d42238578771b4183b48c1680ba438b6b18306dfe1454fa4124c0207b3deb" +} diff --git a/crates/storage-pg/migrations/20260112094550_oauth_refresh_token_not_consumed_idx.sql b/crates/storage-pg/migrations/20260112094550_oauth_refresh_token_not_consumed_idx.sql new file mode 100644 index 000000000..6d44b4652 --- /dev/null +++ b/crates/storage-pg/migrations/20260112094550_oauth_refresh_token_not_consumed_idx.sql @@ -0,0 +1,11 @@ +-- no-transaction +-- Copyright 2026 Element Creations Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +-- Please see LICENSE in the repository root for full details. + +-- Adds a partial index on oauth2_refresh_tokens that are consumed +-- to speed up cleaning up of consumed tokens +CREATE INDEX CONCURRENTLY IF NOT EXISTS oauth_refresh_token_not_consumed_idx + ON oauth2_refresh_tokens (oauth2_refresh_token_id) + WHERE consumed_at IS NOT NULL; diff --git a/crates/storage-pg/migrations/20260112094837_oauth_refresh_token_consumed_at_idx.sql b/crates/storage-pg/migrations/20260112094837_oauth_refresh_token_consumed_at_idx.sql new file mode 100644 index 000000000..ff99e7a60 --- /dev/null +++ b/crates/storage-pg/migrations/20260112094837_oauth_refresh_token_consumed_at_idx.sql @@ -0,0 +1,11 @@ +-- no-transaction +-- Copyright 2026 Element Creations Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +-- Please see LICENSE in the repository root for full details. + +-- Adds a partial index on oauth2_refresh_tokens on the consumed_at field, +-- including other interesting fields to speed up cleaning up of consumed tokens +CREATE INDEX CONCURRENTLY IF NOT EXISTS oauth_refresh_token_consumed_at_idx + ON oauth2_refresh_tokens (consumed_at, next_oauth2_refresh_token_id, oauth2_refresh_token_id) + WHERE consumed_at IS NOT NULL; diff --git a/crates/storage-pg/src/oauth2/refresh_token.rs b/crates/storage-pg/src/oauth2/refresh_token.rs index d98e986c8..cdf328bbd 100644 --- a/crates/storage-pg/src/oauth2/refresh_token.rs +++ b/crates/storage-pg/src/oauth2/refresh_token.rs @@ -336,4 +336,65 @@ impl OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'_> { res.last_revoked_at, )) } + + #[tracing::instrument( + name = "db.oauth2_refresh_token.cleanup_consumed", + skip_all, + fields( + db.query.text, + ), + err, + )] + async fn cleanup_consumed( + &mut self, + since: Option>, + until: DateTime, + limit: usize, + ) -> Result<(usize, Option>), Self::Error> { + // We only consider a token as consumed if also the next token has its + // `consumed_at` set. This makes the query a bit expensive to compute, + // but is optimised to two index scans and a nested join using the + // `oauth2_refresh_token_not_consumed_idx` and + // `oauth2_refresh_token_consumed_at_idx` indexes. + let res = sqlx::query!( + r#" + WITH + to_delete AS ( + SELECT rts_to_del.oauth2_refresh_token_id + FROM oauth2_refresh_tokens rts_to_del + LEFT JOIN oauth2_refresh_tokens next_rts + ON rts_to_del.next_oauth2_refresh_token_id = next_rts.oauth2_refresh_token_id + WHERE rts_to_del.consumed_at IS NOT NULL + AND (rts_to_del.next_oauth2_refresh_token_id IS NULL OR next_rts.consumed_at IS NOT NULL) + AND ($1::timestamptz IS NULL OR rts_to_del.consumed_at >= $1::timestamptz) + AND rts_to_del.consumed_at < $2::timestamptz + ORDER BY rts_to_del.consumed_at ASC + LIMIT $3 + ), + + deleted AS ( + DELETE FROM oauth2_refresh_tokens + USING to_delete + WHERE oauth2_refresh_tokens.oauth2_refresh_token_id = to_delete.oauth2_refresh_token_id + RETURNING oauth2_refresh_tokens.consumed_at + ) + + SELECT + COUNT(*) as "count!", + MAX(consumed_at) as last_consumed_at + FROM deleted + "#, + since, + until, + i64::try_from(limit).unwrap_or(i64::MAX), + ) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + Ok(( + res.count.try_into().unwrap_or(usize::MAX), + res.last_consumed_at, + )) + } } diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 1dd55db16..8a352d836 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -133,6 +133,31 @@ pub trait OAuth2RefreshTokenRepository: Send + Sync { until: chrono::DateTime, limit: usize, ) -> Result<(usize, Option>), Self::Error>; + + /// Cleanup consumed refresh tokens that were consumed before a certain time + /// + /// A token is considered as fully consumed only if both the `consumed_at` + /// column is set and the next refresh token in the chain also has its + /// `consumed_at` set. + /// + /// Returns the number of deleted tokens and the last `consumed_at` + /// timestamp processed + /// + /// # Parameters + /// + /// * `since`: An optional timestamp to start from + /// * `until`: The timestamp before which to revoke tokens + /// * `limit`: The maximum number of tokens to revoke + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn cleanup_consumed( + &mut self, + since: Option>, + until: chrono::DateTime, + limit: usize, + ) -> Result<(usize, Option>), Self::Error>; } repository_impl!(OAuth2RefreshTokenRepository: @@ -171,4 +196,11 @@ repository_impl!(OAuth2RefreshTokenRepository: until: chrono::DateTime, limit: usize, ) -> Result<(usize, Option>), Self::Error>; + + async fn cleanup_consumed( + &mut self, + since: Option>, + until: chrono::DateTime, + limit: usize, + ) -> Result<(usize, Option>), Self::Error>; ); diff --git a/crates/storage/src/queue/tasks.rs b/crates/storage/src/queue/tasks.rs index b033f68f7..7abfd13c8 100644 --- a/crates/storage/src/queue/tasks.rs +++ b/crates/storage/src/queue/tasks.rs @@ -342,6 +342,14 @@ impl InsertableJob for CleanupRevokedOAuthRefreshTokensJob { const QUEUE_NAME: &'static str = "cleanup-revoked-oauth-refresh-tokens"; } +/// Cleanup consumed OAuth 2.0 refresh tokens +#[derive(Serialize, Deserialize, Debug, Clone, Default)] +pub struct CleanupConsumedOAuthRefreshTokensJob; + +impl InsertableJob for CleanupConsumedOAuthRefreshTokensJob { + const QUEUE_NAME: &'static str = "cleanup-consumed-oauth-refresh-tokens"; +} + /// Scheduled job to expire inactive sessions /// /// This job will trigger jobs to expire inactive compat, oauth and user diff --git a/crates/tasks/src/database.rs b/crates/tasks/src/database.rs index f1cffea3f..9e6e67658 100644 --- a/crates/tasks/src/database.rs +++ b/crates/tasks/src/database.rs @@ -11,8 +11,9 @@ use std::time::Duration; use async_trait::async_trait; use mas_storage::queue::{ - CleanupExpiredOAuthAccessTokensJob, CleanupRevokedOAuthAccessTokensJob, - CleanupRevokedOAuthRefreshTokensJob, PruneStalePolicyDataJob, + CleanupConsumedOAuthRefreshTokensJob, CleanupExpiredOAuthAccessTokensJob, + CleanupRevokedOAuthAccessTokensJob, CleanupRevokedOAuthRefreshTokensJob, + PruneStalePolicyDataJob, }; use tracing::{debug, info}; @@ -167,6 +168,52 @@ impl RunnableJob for CleanupRevokedOAuthRefreshTokensJob { } } +#[async_trait] +impl RunnableJob for CleanupConsumedOAuthRefreshTokensJob { + #[tracing::instrument(name = "job.cleanup_consumed_oauth_refresh_tokens", skip_all)] + async fn run(&self, state: &State, context: JobContext) -> Result<(), JobError> { + // Cleanup tokens that were consumed more than an hour ago + let until = state.clock.now() - chrono::Duration::hours(1); + let mut total = 0; + + // Run until we get cancelled. We don't schedule a retry if we get cancelled, as + // this is a scheduled job and it will end up being rescheduled later anyway. + let mut since = None; + while !context.cancellation_token.is_cancelled() { + let mut repo = state.repository().await.map_err(JobError::retry)?; + + // This returns the number of deleted tokens, and the last consumed_at timestamp + let (count, last_consumed_at) = repo + .oauth2_refresh_token() + .cleanup_consumed(since, until, BATCH_SIZE) + .await + .map_err(JobError::retry)?; + repo.save().await.map_err(JobError::retry)?; + + since = last_consumed_at; + total += count; + + // Check how many we deleted. If we deleted exactly BATCH_SIZE, + // there might be more to delete + if count != BATCH_SIZE { + break; + } + } + + if total == 0 { + debug!("no token to clean up"); + } else { + info!(count = total, "cleaned up consumed tokens"); + } + + Ok(()) + } + + fn timeout(&self) -> Option { + Some(Duration::from_secs(60)) + } +} + #[async_trait] impl RunnableJob for PruneStalePolicyDataJob { #[tracing::instrument(name = "job.prune_stale_policy_data", skip_all)] diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index eb1ee3bc9..f2570049b 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -132,6 +132,7 @@ pub async fn init( .register_handler::() .register_handler::() .register_handler::() + .register_handler::() .register_handler::() .register_handler::() .register_handler::() @@ -159,6 +160,12 @@ pub async fn init( "0 10 * * * *".parse()?, mas_storage::queue::CleanupRevokedOAuthRefreshTokensJob, ) + .add_schedule( + "cleanup-consumed-oauth-refresh-tokens", + // Run this job every hour + "0 20 * * * *".parse()?, + mas_storage::queue::CleanupConsumedOAuthRefreshTokensJob, + ) .add_schedule( "cleanup-expired-oauth-access-tokens", // Run this job every 4 hours