Record the next refresh token ID when refreshing

This will help us determine whether we had a double-refresh happening
This commit is contained in:
Quentin Gliech
2024-12-10 14:06:06 +01:00
parent fb85651793
commit b3756e4ae4
10 changed files with 125 additions and 40 deletions

View File

@@ -109,6 +109,7 @@ pub enum RefreshTokenState {
Valid,
Consumed {
consumed_at: DateTime<Utc>,
next_refresh_token_id: Option<Ulid>,
},
}
@@ -118,9 +119,16 @@ impl RefreshTokenState {
/// # Errors
///
/// Returns an error if the refresh token is already consumed.
fn consume(self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
fn consume(
self,
consumed_at: DateTime<Utc>,
replaced_by: &RefreshToken,
) -> Result<Self, InvalidTransitionError> {
match self {
Self::Valid => Ok(Self::Consumed { consumed_at }),
Self::Valid => Ok(Self::Consumed {
consumed_at,
next_refresh_token_id: Some(replaced_by.id),
}),
Self::Consumed { .. } => Err(InvalidTransitionError),
}
}
@@ -140,6 +148,18 @@ impl RefreshTokenState {
pub fn is_consumed(&self) -> bool {
matches!(self, Self::Consumed { .. })
}
/// Returns the next refresh token ID, if any.
#[must_use]
pub fn next_refresh_token_id(&self) -> Option<Ulid> {
match self {
Self::Valid => None,
Self::Consumed {
next_refresh_token_id,
..
} => *next_refresh_token_id,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -171,8 +191,12 @@ impl RefreshToken {
/// # Errors
///
/// Returns an error if the refresh token is already consumed.
pub fn consume(mut self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
self.state = self.state.consume(consumed_at)?;
pub fn consume(
mut self,
consumed_at: DateTime<Utc>,
replaced_by: &Self,
) -> Result<Self, InvalidTransitionError> {
self.state = self.state.consume(consumed_at, replaced_by)?;
Ok(self)
}
}

View File

@@ -544,7 +544,7 @@ async fn refresh_token_grant(
let refresh_token = repo
.oauth2_refresh_token()
.consume(clock, refresh_token)
.consume(clock, refresh_token, &new_refresh_token)
.await?;
if let Some(access_token_id) = refresh_token.access_token_id {

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n FROM oauth2_refresh_tokens\n\n WHERE refresh_token = $1\n ",
"query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n , next_oauth2_refresh_token_id\n FROM oauth2_refresh_tokens\n\n WHERE refresh_token = $1\n ",
"describe": {
"columns": [
{
@@ -32,6 +32,11 @@
"ordinal": 5,
"name": "oauth2_session_id",
"type_info": "Uuid"
},
{
"ordinal": 6,
"name": "next_oauth2_refresh_token_id",
"type_info": "Uuid"
}
],
"parameters": {
@@ -45,8 +50,9 @@
false,
true,
true,
false
false,
true
]
},
"hash": "e709869c062ac50248b1f9f8f808cc2f5e7bef58a6c2e42a7bb0c1cb8f508671"
"hash": "265a981142194216593cd09cd8f6af36c103e030358262e46a4bd5e4006dc630"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n FROM oauth2_refresh_tokens\n\n WHERE oauth2_refresh_token_id = $1\n ",
"query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n , next_oauth2_refresh_token_id\n FROM oauth2_refresh_tokens\n\n WHERE oauth2_refresh_token_id = $1\n ",
"describe": {
"columns": [
{
@@ -32,6 +32,11 @@
"ordinal": 5,
"name": "oauth2_session_id",
"type_info": "Uuid"
},
{
"ordinal": 6,
"name": "next_oauth2_refresh_token_id",
"type_info": "Uuid"
}
],
"parameters": {
@@ -45,8 +50,9 @@
false,
true,
true,
false
false,
true
]
},
"hash": "a6fa7811d0a7c62c7cccff96dc82db5b25462fa7669fde1941ccab4712585b20"
"hash": "5b94c22d44692a16fa5a6edd5dac019c36bf5983182b0871de6e85036d8df466"
}

View File

@@ -1,15 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE oauth2_refresh_tokens\n SET consumed_at = $2\n WHERE oauth2_refresh_token_id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid",
"Timestamptz"
]
},
"nullable": []
},
"hash": "b6a6f5386dc89e4bc2ce56d578a29341848fce336d339b6bbf425956f5ed5032"
}

View File

@@ -0,0 +1,16 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE oauth2_refresh_tokens\n SET consumed_at = $2,\n next_oauth2_refresh_token_id = $3\n WHERE oauth2_refresh_token_id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid",
"Timestamptz",
"Uuid"
]
},
"nullable": []
},
"hash": "ffbfef8b7e72ec4bae02b6bbe862980b5fe575ae8432a000e9c4e4307caa2d9b"
}

View File

@@ -0,0 +1,9 @@
-- Copyright 2024 New Vector Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only
-- Please see LICENSE in the repository root for full details.
-- Add a reference to the 'next' refresh token when it was consumed and replaced
ALTER TABLE oauth2_refresh_tokens
ADD COLUMN "next_oauth2_refresh_token_id" UUID
REFERENCES oauth2_refresh_tokens (oauth2_refresh_token_id);

View File

@@ -341,6 +341,19 @@ mod tests {
clock.advance(Duration::try_minutes(-6).unwrap()); // Go back in time
assert!(access_token.is_valid(clock.now()));
// Create a new refresh token to be able to consume the old one
let new_refresh_token = repo
.oauth2_refresh_token()
.add(
&mut rng,
&clock,
&session,
&access_token,
"ddeeff".to_owned(),
)
.await
.unwrap();
// Mark the access token as revoked
let access_token = repo
.oauth2_access_token()
@@ -353,7 +366,7 @@ mod tests {
assert!(refresh_token.is_valid());
let refresh_token = repo
.oauth2_refresh_token()
.consume(&clock, refresh_token)
.consume(&clock, refresh_token, &new_refresh_token)
.await
.unwrap();
assert!(!refresh_token.is_valid());

View File

@@ -13,7 +13,7 @@ use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError};
use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError};
/// An implementation of [`OAuth2RefreshTokenRepository`] for a PostgreSQL
/// connection
@@ -36,23 +36,39 @@ struct OAuth2RefreshTokenLookup {
consumed_at: Option<DateTime<Utc>>,
oauth2_access_token_id: Option<Uuid>,
oauth2_session_id: Uuid,
next_oauth2_refresh_token_id: Option<Uuid>,
}
impl From<OAuth2RefreshTokenLookup> for RefreshToken {
fn from(value: OAuth2RefreshTokenLookup) -> Self {
let state = match value.consumed_at {
None => RefreshTokenState::Valid,
Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
impl TryFrom<OAuth2RefreshTokenLookup> for RefreshToken {
type Error = DatabaseInconsistencyError;
fn try_from(value: OAuth2RefreshTokenLookup) -> Result<Self, Self::Error> {
let id = value.oauth2_refresh_token_id.into();
let state = match (value.consumed_at, value.next_oauth2_refresh_token_id) {
(None, None) => RefreshTokenState::Valid,
(Some(consumed_at), None) => RefreshTokenState::Consumed {
consumed_at,
next_refresh_token_id: None,
},
(Some(consumed_at), Some(id)) => RefreshTokenState::Consumed {
consumed_at,
next_refresh_token_id: Some(Ulid::from(id)),
},
(None, Some(_)) => {
return Err(DatabaseInconsistencyError::on("oauth2_refresh_tokens")
.column("next_oauth2_refresh_token_id")
.row(id))
}
};
RefreshToken {
id: value.oauth2_refresh_token_id.into(),
Ok(RefreshToken {
id,
state,
session_id: value.oauth2_session_id.into(),
refresh_token: value.refresh_token,
created_at: value.created_at,
access_token_id: value.oauth2_access_token_id.map(Ulid::from),
}
})
}
}
@@ -79,18 +95,20 @@ impl OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'_> {
, consumed_at
, oauth2_access_token_id
, oauth2_session_id
, next_oauth2_refresh_token_id
FROM oauth2_refresh_tokens
WHERE oauth2_refresh_token_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_optional(&mut *self.conn)
.await?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
@@ -114,6 +132,7 @@ impl OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'_> {
, consumed_at
, oauth2_access_token_id
, oauth2_session_id
, next_oauth2_refresh_token_id
FROM oauth2_refresh_tokens
WHERE refresh_token = $1
@@ -126,7 +145,7 @@ impl OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'_> {
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
@@ -194,24 +213,28 @@ impl OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'_> {
&mut self,
clock: &dyn Clock,
refresh_token: RefreshToken,
replaced_by: &RefreshToken,
) -> Result<RefreshToken, Self::Error> {
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_refresh_tokens
SET consumed_at = $2
SET consumed_at = $2,
next_oauth2_refresh_token_id = $3
WHERE oauth2_refresh_token_id = $1
"#,
Uuid::from(refresh_token.id),
consumed_at,
Uuid::from(replaced_by.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
refresh_token
.consume(consumed_at)
.consume(consumed_at, replaced_by)
.map_err(DatabaseError::to_invalid_operation)
}
}

View File

@@ -80,6 +80,7 @@ pub trait OAuth2RefreshTokenRepository: Send + Sync {
///
/// * `clock`: The clock used to generate timestamps
/// * `refresh_token`: The [`RefreshToken`] to consume
/// * `replaced_by`: The [`RefreshToken`] which replaced this one
///
/// # Errors
///
@@ -89,6 +90,7 @@ pub trait OAuth2RefreshTokenRepository: Send + Sync {
&mut self,
clock: &dyn Clock,
refresh_token: RefreshToken,
replaced_by: &RefreshToken,
) -> Result<RefreshToken, Self::Error>;
}
@@ -113,5 +115,6 @@ repository_impl!(OAuth2RefreshTokenRepository:
&mut self,
clock: &dyn Clock,
refresh_token: RefreshToken,
replaced_by: &RefreshToken,
) -> Result<RefreshToken, Self::Error>;
);