Cleanup consumed refresh tokens

This commit is contained in:
Quentin Gliech
2026-01-12 11:03:31 +01:00
parent ab25c23829
commit f98957617e
8 changed files with 209 additions and 2 deletions

View File

@@ -0,0 +1,30 @@
{
"db_name": "PostgreSQL",
"query": "\n WITH\n to_delete AS (\n SELECT rts_to_del.oauth2_refresh_token_id\n FROM oauth2_refresh_tokens rts_to_del\n LEFT JOIN oauth2_refresh_tokens next_rts\n ON rts_to_del.next_oauth2_refresh_token_id = next_rts.oauth2_refresh_token_id\n WHERE rts_to_del.consumed_at IS NOT NULL\n AND (rts_to_del.next_oauth2_refresh_token_id IS NULL OR next_rts.consumed_at IS NOT NULL)\n AND ($1::timestamptz IS NULL OR rts_to_del.consumed_at >= $1::timestamptz)\n AND rts_to_del.consumed_at < $2::timestamptz\n ORDER BY rts_to_del.consumed_at ASC\n LIMIT $3\n ),\n\n deleted AS (\n DELETE FROM oauth2_refresh_tokens\n USING to_delete\n WHERE oauth2_refresh_tokens.oauth2_refresh_token_id = to_delete.oauth2_refresh_token_id\n RETURNING oauth2_refresh_tokens.consumed_at\n )\n\n SELECT\n COUNT(*) as \"count!\",\n MAX(consumed_at) as last_consumed_at\n FROM deleted\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "count!",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "last_consumed_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Timestamptz",
"Timestamptz",
"Int8"
]
},
"nullable": [
null,
null
]
},
"hash": "093d42238578771b4183b48c1680ba438b6b18306dfe1454fa4124c0207b3deb"
}

View File

@@ -0,0 +1,11 @@
-- no-transaction
-- Copyright 2026 Element Creations Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
-- Please see LICENSE in the repository root for full details.
--
-- Adds a partial index on the ID of refresh tokens that have been consumed
-- (rows where `consumed_at` is set), to speed up the cleanup of consumed
-- tokens.
--
-- The index is built with CONCURRENTLY, which PostgreSQL does not allow inside
-- a transaction block; this is presumably why the migration starts with the
-- `no-transaction` marker.
--
-- NOTE(review): the index name says `not_consumed`, but the predicate below
-- selects *consumed* tokens (`consumed_at IS NOT NULL`). Confirm the name is
-- intentional before relying on it.
CREATE INDEX CONCURRENTLY IF NOT EXISTS oauth_refresh_token_not_consumed_idx
ON oauth2_refresh_tokens (oauth2_refresh_token_id)
WHERE consumed_at IS NOT NULL;

View File

@@ -0,0 +1,11 @@
-- no-transaction
-- Copyright 2026 Element Creations Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
-- Please see LICENSE in the repository root for full details.
--
-- Adds a partial index on oauth2_refresh_tokens, restricted to consumed tokens
-- (rows where `consumed_at` is set), to speed up the cleanup of consumed
-- tokens.
--
-- The index is keyed on `consumed_at` first, followed by
-- `next_oauth2_refresh_token_id` and `oauth2_refresh_token_id`, which are the
-- columns the cleanup query filters, orders and joins on.
--
-- The index is built with CONCURRENTLY, which PostgreSQL does not allow inside
-- a transaction block; this is presumably why the migration starts with the
-- `no-transaction` marker.
CREATE INDEX CONCURRENTLY IF NOT EXISTS oauth_refresh_token_consumed_at_idx
ON oauth2_refresh_tokens (consumed_at, next_oauth2_refresh_token_id, oauth2_refresh_token_id)
WHERE consumed_at IS NOT NULL;

View File

@@ -336,4 +336,65 @@ impl OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'_> {
res.last_revoked_at,
))
}
#[tracing::instrument(
name = "db.oauth2_refresh_token.cleanup_consumed",
skip_all,
fields(
db.query.text,
),
err,
)]
async fn cleanup_consumed(
&mut self,
since: Option<DateTime<Utc>>,
until: DateTime<Utc>,
limit: usize,
) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
// Deletes, in a single statement, up to `limit` consumed tokens whose
// `consumed_at` falls in `[since, until)` (no lower bound when `since` is
// `None`), oldest first. It then reports how many rows were deleted and the
// greatest `consumed_at` among them.
//
// A consumed token is only deleted if it either has no next token, or its
// next token also has its `consumed_at` set. This makes the query a bit
// expensive to compute, but is optimised to two index scans and a nested
// join using the `oauth_refresh_token_not_consumed_idx` and
// `oauth_refresh_token_consumed_at_idx` indexes.
//
// NOTE(review): the query text is mirrored in the sqlx offline query cache,
// so editing it presumably requires regenerating that cache entry.
let res = sqlx::query!(
r#"
WITH
to_delete AS (
SELECT rts_to_del.oauth2_refresh_token_id
FROM oauth2_refresh_tokens rts_to_del
LEFT JOIN oauth2_refresh_tokens next_rts
ON rts_to_del.next_oauth2_refresh_token_id = next_rts.oauth2_refresh_token_id
WHERE rts_to_del.consumed_at IS NOT NULL
AND (rts_to_del.next_oauth2_refresh_token_id IS NULL OR next_rts.consumed_at IS NOT NULL)
AND ($1::timestamptz IS NULL OR rts_to_del.consumed_at >= $1::timestamptz)
AND rts_to_del.consumed_at < $2::timestamptz
ORDER BY rts_to_del.consumed_at ASC
LIMIT $3
),
deleted AS (
DELETE FROM oauth2_refresh_tokens
USING to_delete
WHERE oauth2_refresh_tokens.oauth2_refresh_token_id = to_delete.oauth2_refresh_token_id
RETURNING oauth2_refresh_tokens.consumed_at
)
SELECT
COUNT(*) as "count!",
MAX(consumed_at) as last_consumed_at
FROM deleted
"#,
since,
until,
// The `LIMIT` parameter is an i64; saturate rather than fail if `limit`
// does not fit
i64::try_from(limit).unwrap_or(i64::MAX),
)
.traced()
.fetch_one(&mut *self.conn)
.await?;
// The count comes back as an i64; saturate rather than fail if it does not
// fit in a usize. `last_consumed_at` is `None` when no row was deleted.
Ok((
res.count.try_into().unwrap_or(usize::MAX),
res.last_consumed_at,
))
}
}

View File

@@ -133,6 +133,31 @@ pub trait OAuth2RefreshTokenRepository: Send + Sync {
until: chrono::DateTime<chrono::Utc>,
limit: usize,
) -> Result<(usize, Option<chrono::DateTime<chrono::Utc>>), Self::Error>;
/// Clean up refresh tokens that were consumed before a certain time
///
/// A token is considered as fully consumed only if both the `consumed_at`
/// column is set and the next refresh token in the chain also has its
/// `consumed_at` set.
///
/// Returns the number of deleted tokens and the latest `consumed_at`
/// timestamp among them (`None` if nothing was deleted). That timestamp can
/// be passed back as `since` to continue with the next batch.
///
/// # Parameters
///
/// * `since`: An optional timestamp to start from (inclusive)
/// * `until`: The timestamp before which to delete tokens (exclusive)
/// * `limit`: The maximum number of tokens to delete
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn cleanup_consumed(
&mut self,
since: Option<chrono::DateTime<chrono::Utc>>,
until: chrono::DateTime<chrono::Utc>,
limit: usize,
) -> Result<(usize, Option<chrono::DateTime<chrono::Utc>>), Self::Error>;
}
repository_impl!(OAuth2RefreshTokenRepository:
@@ -171,4 +196,11 @@ repository_impl!(OAuth2RefreshTokenRepository:
until: chrono::DateTime<chrono::Utc>,
limit: usize,
) -> Result<(usize, Option<chrono::DateTime<chrono::Utc>>), Self::Error>;
async fn cleanup_consumed(
&mut self,
since: Option<chrono::DateTime<chrono::Utc>>,
until: chrono::DateTime<chrono::Utc>,
limit: usize,
) -> Result<(usize, Option<chrono::DateTime<chrono::Utc>>), Self::Error>;
);

View File

@@ -342,6 +342,14 @@ impl InsertableJob for CleanupRevokedOAuthRefreshTokensJob {
const QUEUE_NAME: &'static str = "cleanup-revoked-oauth-refresh-tokens";
}
/// Clean up consumed OAuth 2.0 refresh tokens
///
/// This is a unit struct: the job carries no parameters of its own.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct CleanupConsumedOAuthRefreshTokensJob;
// Jobs of this type go on the `cleanup-consumed-oauth-refresh-tokens` queue
impl InsertableJob for CleanupConsumedOAuthRefreshTokensJob {
const QUEUE_NAME: &'static str = "cleanup-consumed-oauth-refresh-tokens";
}
/// Scheduled job to expire inactive sessions
///
/// This job will trigger jobs to expire inactive compat, oauth and user

View File

@@ -11,8 +11,9 @@ use std::time::Duration;
use async_trait::async_trait;
use mas_storage::queue::{
CleanupExpiredOAuthAccessTokensJob, CleanupRevokedOAuthAccessTokensJob,
CleanupRevokedOAuthRefreshTokensJob, PruneStalePolicyDataJob,
CleanupConsumedOAuthRefreshTokensJob, CleanupExpiredOAuthAccessTokensJob,
CleanupRevokedOAuthAccessTokensJob, CleanupRevokedOAuthRefreshTokensJob,
PruneStalePolicyDataJob,
};
use tracing::{debug, info};
@@ -167,6 +168,52 @@ impl RunnableJob for CleanupRevokedOAuthRefreshTokensJob {
}
}
#[async_trait]
impl RunnableJob for CleanupConsumedOAuthRefreshTokensJob {
    /// Deletes consumed refresh tokens in batches of `BATCH_SIZE`, each batch
    /// in its own repository session, until a batch comes back short or the
    /// job is cancelled.
    #[tracing::instrument(name = "job.cleanup_consumed_oauth_refresh_tokens", skip_all)]
    async fn run(&self, state: &State, context: JobContext) -> Result<(), JobError> {
        // Only tokens consumed more than one hour ago are eligible
        let cutoff = state.clock.now() - chrono::Duration::hours(1);

        // Lower bound for the next batch: the latest `consumed_at` deleted so
        // far, so each query resumes where the previous one stopped
        let mut cursor = None;
        let mut deleted_total = 0;

        loop {
            // Stop quietly on cancellation: this is a scheduled job, so it
            // will be run again later without us asking for a retry
            if context.cancellation_token.is_cancelled() {
                break;
            }

            let mut repo = state.repository().await.map_err(JobError::retry)?;

            // Yields how many tokens were deleted, and the latest
            // `consumed_at` among them
            let (deleted, last_consumed_at) = repo
                .oauth2_refresh_token()
                .cleanup_consumed(cursor, cutoff, BATCH_SIZE)
                .await
                .map_err(JobError::retry)?;

            repo.save().await.map_err(JobError::retry)?;

            cursor = last_consumed_at;
            deleted_total += deleted;

            // Only a full batch suggests there may be more tokens to delete
            if deleted != BATCH_SIZE {
                break;
            }
        }

        match deleted_total {
            0 => {
                debug!("no token to clean up");
            }
            n => {
                info!(count = n, "cleaned up consumed tokens");
            }
        }

        Ok(())
    }

    /// Upper bound on how long a single run of this job may take.
    fn timeout(&self) -> Option<Duration> {
        let one_minute = Duration::from_secs(60);
        Some(one_minute)
    }
}
#[async_trait]
impl RunnableJob for PruneStalePolicyDataJob {
#[tracing::instrument(name = "job.prune_stale_policy_data", skip_all)]

View File

@@ -132,6 +132,7 @@ pub async fn init(
.register_handler::<mas_storage::queue::CleanupRevokedOAuthAccessTokensJob>()
.register_handler::<mas_storage::queue::CleanupExpiredOAuthAccessTokensJob>()
.register_handler::<mas_storage::queue::CleanupRevokedOAuthRefreshTokensJob>()
.register_handler::<mas_storage::queue::CleanupConsumedOAuthRefreshTokensJob>()
.register_handler::<mas_storage::queue::DeactivateUserJob>()
.register_handler::<mas_storage::queue::DeleteDeviceJob>()
.register_handler::<mas_storage::queue::ProvisionDeviceJob>()
@@ -159,6 +160,12 @@ pub async fn init(
"0 10 * * * *".parse()?,
mas_storage::queue::CleanupRevokedOAuthRefreshTokensJob,
)
.add_schedule(
"cleanup-consumed-oauth-refresh-tokens",
// Run this job every hour
"0 20 * * * *".parse()?,
mas_storage::queue::CleanupConsumedOAuthRefreshTokensJob,
)
.add_schedule(
"cleanup-expired-oauth-access-tokens",
// Run this job every 4 hours