Fix compat token refresh giving back a consumed token

This commit is contained in:
Olivier 'reivilibre
2026-02-13 14:57:27 +00:00
parent 1341400325
commit 9018f52d61
6 changed files with 53 additions and 21 deletions

View File

@@ -156,7 +156,7 @@ pub(crate) async fn post(
.await?; .await?;
repo.compat_refresh_token() repo.compat_refresh_token()
.consume(&clock, refresh_token) .consume_and_replace(&clock, refresh_token, &new_refresh_token)
.await?; .await?;
if let Some(access_token) = access_token { if let Some(access_token) = access_token {

View File

@@ -123,8 +123,8 @@ async fn test_compat_refresh(pool: sqlx::PgPool) {
assert_eq!( assert_eq!(
second_refresh_response, second_refresh_response,
RefreshResponse { RefreshResponse {
access_token: "???".to_owned(), access_token: "mct_Wc6Hx4l9DGzqGtgLoYqtrtBUBcWlE4_ZFyTp2".to_owned(),
refresh_token: "???".to_owned(), refresh_token: "mcr_Yp7FM44zJN5qePGMLvvMXC4Ds1A3lC_0YcYCM".to_owned(),
expires_in_ms: 300_000 expires_in_ms: 300_000
} }
); );
@@ -185,8 +185,7 @@ async fn test_refresh_with_consumed_token(pool: sqlx::PgPool) {
let _first_refresh_response: RefreshResponse = first_refresh_response.json(); let _first_refresh_response: RefreshResponse = first_refresh_response.json();
// Try to use the same refresh token again - should fail because it's consumed // Try to use the same refresh token again - should fail because it's consumed
let second_refresh_request = let second_refresh_request = Request::post("/_matrix/client/v3/refresh").json(&refresh_request);
Request::post("/_matrix/client/v3/refresh").json(&refresh_request);
let second_refresh_response = state.request(second_refresh_request).await; let second_refresh_response = state.request(second_refresh_request).await;
second_refresh_response.assert_status(StatusCode::UNAUTHORIZED); second_refresh_response.assert_status(StatusCode::UNAUTHORIZED);

View File

@@ -1,15 +1,16 @@
{ {
"db_name": "PostgreSQL", "db_name": "PostgreSQL",
"query": "\n UPDATE compat_refresh_tokens\n SET consumed_at = $2\n WHERE compat_session_id = $1\n AND consumed_at IS NULL\n ", "query": "\n UPDATE compat_refresh_tokens\n SET consumed_at = $2\n WHERE compat_session_id = $1\n AND consumed_at IS NULL\n AND compat_refresh_token_id <> $3\n ",
"describe": { "describe": {
"columns": [], "columns": [],
"parameters": { "parameters": {
"Left": [ "Left": [
"Uuid", "Uuid",
"Timestamptz" "Timestamptz",
"Uuid"
] ]
}, },
"nullable": [] "nullable": []
}, },
"hash": "f75e44b528234dac708640ad9a111f3f6b468a91bf0d5b574795bf8c80605f19" "hash": "4e64540bbffe5f4b9c4a6589012cf69eb67adaa4d40fc1910dfcd2640e32ab37"
} }

View File

@@ -437,6 +437,7 @@ mod tests {
async fn test_refresh_token_repository(pool: PgPool) { async fn test_refresh_token_repository(pool: PgPool) {
const ACCESS_TOKEN: &str = "access_token"; const ACCESS_TOKEN: &str = "access_token";
const REFRESH_TOKEN: &str = "refresh_token"; const REFRESH_TOKEN: &str = "refresh_token";
const REFRESH_TOKEN2: &str = "refresh_token2";
let mut rng = ChaChaRng::seed_from_u64(42); let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default(); let clock = MockClock::default();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
@@ -508,16 +509,28 @@ mod tests {
assert!(refresh_token_lookup.is_valid()); assert!(refresh_token_lookup.is_valid());
assert!(!refresh_token_lookup.is_consumed()); assert!(!refresh_token_lookup.is_consumed());
// Consume it // Consume the first token, but to do so we need a 2nd to replace it with
let refresh_token2 = repo
.compat_refresh_token()
.add(
&mut rng,
&clock,
&session,
&access_token,
REFRESH_TOKEN2.to_owned(),
)
.await
.unwrap();
let refresh_token = repo let refresh_token = repo
.compat_refresh_token() .compat_refresh_token()
.consume(&clock, refresh_token) .consume_and_replace(&clock, refresh_token, &refresh_token2)
.await .await
.unwrap(); .unwrap();
assert!(!refresh_token.is_valid()); assert!(!refresh_token.is_valid());
assert!(refresh_token.is_consumed()); assert!(refresh_token.is_consumed());
// Reload it and check again // Reload the first token and check again
let refresh_token_lookup = repo let refresh_token_lookup = repo
.compat_refresh_token() .compat_refresh_token()
.find_by_token(REFRESH_TOKEN) .find_by_token(REFRESH_TOKEN)
@@ -530,7 +543,7 @@ mod tests {
// Consuming it again should not work // Consuming it again should not work
assert!( assert!(
repo.compat_refresh_token() repo.compat_refresh_token()
.consume(&clock, refresh_token) .consume_and_replace(&clock, refresh_token, &refresh_token2)
.await .await
.is_err() .is_err()
); );

View File

@@ -185,20 +185,26 @@ impl CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'_> {
} }
#[tracing::instrument( #[tracing::instrument(
name = "db.compat_refresh_token.consume", name = "db.compat_refresh_token.consume_and_replace",
skip_all, skip_all,
fields( fields(
db.query.text, db.query.text,
%compat_refresh_token.id, %compat_refresh_token.id,
%successor_compat_refresh_token.id,
compat_session.id = %compat_refresh_token.session_id, compat_session.id = %compat_refresh_token.session_id,
), ),
err, err,
)] )]
async fn consume( async fn consume_and_replace(
&mut self, &mut self,
clock: &dyn Clock, clock: &dyn Clock,
compat_refresh_token: CompatRefreshToken, compat_refresh_token: CompatRefreshToken,
successor_compat_refresh_token: &CompatRefreshToken,
) -> Result<CompatRefreshToken, Self::Error> { ) -> Result<CompatRefreshToken, Self::Error> {
if compat_refresh_token.session_id != successor_compat_refresh_token.session_id {
return Err(DatabaseError::invalid_operation());
}
let consumed_at = clock.now(); let consumed_at = clock.now();
let res = sqlx::query!( let res = sqlx::query!(
r#" r#"
@@ -206,9 +212,11 @@ impl CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'_> {
SET consumed_at = $2 SET consumed_at = $2
WHERE compat_session_id = $1 WHERE compat_session_id = $1
AND consumed_at IS NULL AND consumed_at IS NULL
AND compat_refresh_token_id <> $3
"#, "#,
Uuid::from(compat_refresh_token.session_id), Uuid::from(compat_refresh_token.session_id),
consumed_at, consumed_at,
Uuid::from(successor_compat_refresh_token.id),
) )
.traced() .traced()
.execute(&mut *self.conn) .execute(&mut *self.conn)

View File

@@ -69,16 +69,22 @@ pub trait CompatRefreshTokenRepository: Send + Sync {
token: String, token: String,
) -> Result<CompatRefreshToken, Self::Error>; ) -> Result<CompatRefreshToken, Self::Error>;
/// Consume a compat refresh token. /// Consume the given compat refresh token, as well as all other refresh
/// tokens from the same session, except for the given successor compat
/// refresh token.
/// ///
/// This also marks other refresh tokens in the same session as consumed. /// The given successor refresh token will thereafter be the only valid
/// refresh token for the session.
///
/// # Historical context
///
/// When using a refresh token, we must be able to mark multiple other
/// refresh tokens in the same session as consumed.
/// This is desirable because the syn2mas migration process can import /// This is desirable because the syn2mas migration process can import
/// multiple refresh tokens for one device (compat session). /// multiple refresh tokens for one device (compat session).
/// But once the user uses one of those, the others should no longer /// But once the user uses one of those, the others should no longer
/// be valid. /// be valid.
/// ///
/// Returns the consumed compat refresh token
///
/// # Parameters /// # Parameters
/// ///
/// * `clock`: The clock used to generate timestamps /// * `clock`: The clock used to generate timestamps
@@ -86,11 +92,15 @@ pub trait CompatRefreshTokenRepository: Send + Sync {
/// ///
/// # Errors /// # Errors
/// ///
/// Returns [`Self::Error`] if the underlying repository fails /// - Returns [`Self::Error`] if the underlying repository fails
async fn consume( /// - Returns an error if `compat_refresh_token` is not valid to be
/// consumed.
/// - Returns an error if no refresh tokens would be consumed.
async fn consume_and_replace(
&mut self, &mut self,
clock: &dyn Clock, clock: &dyn Clock,
compat_refresh_token: CompatRefreshToken, compat_refresh_token: CompatRefreshToken,
successor_compat_refresh_token: &CompatRefreshToken,
) -> Result<CompatRefreshToken, Self::Error>; ) -> Result<CompatRefreshToken, Self::Error>;
} }
@@ -111,9 +121,10 @@ repository_impl!(CompatRefreshTokenRepository:
token: String, token: String,
) -> Result<CompatRefreshToken, Self::Error>; ) -> Result<CompatRefreshToken, Self::Error>;
async fn consume( async fn consume_and_replace(
&mut self, &mut self,
clock: &dyn Clock, clock: &dyn Clock,
compat_refresh_token: CompatRefreshToken, compat_refresh_token: CompatRefreshToken,
successor_compat_refresh_token: &CompatRefreshToken,
) -> Result<CompatRefreshToken, Self::Error>; ) -> Result<CompatRefreshToken, Self::Error>;
); );