Record the next refresh token ID when refreshing
This will help us determine whether we had a double-refresh happening
This commit is contained in:
@@ -109,6 +109,7 @@ pub enum RefreshTokenState {
|
||||
Valid,
|
||||
Consumed {
|
||||
consumed_at: DateTime<Utc>,
|
||||
next_refresh_token_id: Option<Ulid>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -118,9 +119,16 @@ impl RefreshTokenState {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the refresh token is already consumed.
|
||||
fn consume(self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
fn consume(
|
||||
self,
|
||||
consumed_at: DateTime<Utc>,
|
||||
replaced_by: &RefreshToken,
|
||||
) -> Result<Self, InvalidTransitionError> {
|
||||
match self {
|
||||
Self::Valid => Ok(Self::Consumed { consumed_at }),
|
||||
Self::Valid => Ok(Self::Consumed {
|
||||
consumed_at,
|
||||
next_refresh_token_id: Some(replaced_by.id),
|
||||
}),
|
||||
Self::Consumed { .. } => Err(InvalidTransitionError),
|
||||
}
|
||||
}
|
||||
@@ -140,6 +148,18 @@ impl RefreshTokenState {
|
||||
pub fn is_consumed(&self) -> bool {
|
||||
matches!(self, Self::Consumed { .. })
|
||||
}
|
||||
|
||||
/// Returns the next refresh token ID, if any.
|
||||
#[must_use]
|
||||
pub fn next_refresh_token_id(&self) -> Option<Ulid> {
|
||||
match self {
|
||||
Self::Valid => None,
|
||||
Self::Consumed {
|
||||
next_refresh_token_id,
|
||||
..
|
||||
} => *next_refresh_token_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@@ -171,8 +191,12 @@ impl RefreshToken {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the refresh token is already consumed.
|
||||
pub fn consume(mut self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
self.state = self.state.consume(consumed_at)?;
|
||||
pub fn consume(
|
||||
mut self,
|
||||
consumed_at: DateTime<Utc>,
|
||||
replaced_by: &Self,
|
||||
) -> Result<Self, InvalidTransitionError> {
|
||||
self.state = self.state.consume(consumed_at, replaced_by)?;
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -544,7 +544,7 @@ async fn refresh_token_grant(
|
||||
|
||||
let refresh_token = repo
|
||||
.oauth2_refresh_token()
|
||||
.consume(clock, refresh_token)
|
||||
.consume(clock, refresh_token, &new_refresh_token)
|
||||
.await?;
|
||||
|
||||
if let Some(access_token_id) = refresh_token.access_token_id {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n FROM oauth2_refresh_tokens\n\n WHERE refresh_token = $1\n ",
|
||||
"query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n , next_oauth2_refresh_token_id\n FROM oauth2_refresh_tokens\n\n WHERE refresh_token = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -32,6 +32,11 @@
|
||||
"ordinal": 5,
|
||||
"name": "oauth2_session_id",
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "next_oauth2_refresh_token_id",
|
||||
"type_info": "Uuid"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
@@ -45,8 +50,9 @@
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false
|
||||
false,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "e709869c062ac50248b1f9f8f808cc2f5e7bef58a6c2e42a7bb0c1cb8f508671"
|
||||
"hash": "265a981142194216593cd09cd8f6af36c103e030358262e46a4bd5e4006dc630"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n FROM oauth2_refresh_tokens\n\n WHERE oauth2_refresh_token_id = $1\n ",
|
||||
"query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n , next_oauth2_refresh_token_id\n FROM oauth2_refresh_tokens\n\n WHERE oauth2_refresh_token_id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -32,6 +32,11 @@
|
||||
"ordinal": 5,
|
||||
"name": "oauth2_session_id",
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "next_oauth2_refresh_token_id",
|
||||
"type_info": "Uuid"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
@@ -45,8 +50,9 @@
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false
|
||||
false,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "a6fa7811d0a7c62c7cccff96dc82db5b25462fa7669fde1941ccab4712585b20"
|
||||
"hash": "5b94c22d44692a16fa5a6edd5dac019c36bf5983182b0871de6e85036d8df466"
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE oauth2_refresh_tokens\n SET consumed_at = $2\n WHERE oauth2_refresh_token_id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Timestamptz"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "b6a6f5386dc89e4bc2ce56d578a29341848fce336d339b6bbf425956f5ed5032"
|
||||
}
|
||||
16
crates/storage-pg/.sqlx/query-ffbfef8b7e72ec4bae02b6bbe862980b5fe575ae8432a000e9c4e4307caa2d9b.json
generated
Normal file
16
crates/storage-pg/.sqlx/query-ffbfef8b7e72ec4bae02b6bbe862980b5fe575ae8432a000e9c4e4307caa2d9b.json
generated
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE oauth2_refresh_tokens\n SET consumed_at = $2,\n next_oauth2_refresh_token_id = $3\n WHERE oauth2_refresh_token_id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Timestamptz",
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "ffbfef8b7e72ec4bae02b6bbe862980b5fe575ae8432a000e9c4e4307caa2d9b"
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
-- Copyright 2024 New Vector Ltd.
|
||||
--
|
||||
-- SPDX-License-Identifier: AGPL-3.0-only
|
||||
-- Please see LICENSE in the repository root for full details.
|
||||
|
||||
-- Add a reference to the 'next' refresh token when it was consumed and replaced
|
||||
ALTER TABLE oauth2_refresh_tokens
|
||||
ADD COLUMN "next_oauth2_refresh_token_id" UUID
|
||||
REFERENCES oauth2_refresh_tokens (oauth2_refresh_token_id);
|
||||
@@ -341,6 +341,19 @@ mod tests {
|
||||
clock.advance(Duration::try_minutes(-6).unwrap()); // Go back in time
|
||||
assert!(access_token.is_valid(clock.now()));
|
||||
|
||||
// Create a new refresh token to be able to consume the old one
|
||||
let new_refresh_token = repo
|
||||
.oauth2_refresh_token()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
&access_token,
|
||||
"ddeeff".to_owned(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Mark the access token as revoked
|
||||
let access_token = repo
|
||||
.oauth2_access_token()
|
||||
@@ -353,7 +366,7 @@ mod tests {
|
||||
assert!(refresh_token.is_valid());
|
||||
let refresh_token = repo
|
||||
.oauth2_refresh_token()
|
||||
.consume(&clock, refresh_token)
|
||||
.consume(&clock, refresh_token, &new_refresh_token)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!refresh_token.is_valid());
|
||||
|
||||
@@ -13,7 +13,7 @@ use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{tracing::ExecuteExt, DatabaseError};
|
||||
use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError};
|
||||
|
||||
/// An implementation of [`OAuth2RefreshTokenRepository`] for a PostgreSQL
|
||||
/// connection
|
||||
@@ -36,23 +36,39 @@ struct OAuth2RefreshTokenLookup {
|
||||
consumed_at: Option<DateTime<Utc>>,
|
||||
oauth2_access_token_id: Option<Uuid>,
|
||||
oauth2_session_id: Uuid,
|
||||
next_oauth2_refresh_token_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
impl From<OAuth2RefreshTokenLookup> for RefreshToken {
|
||||
fn from(value: OAuth2RefreshTokenLookup) -> Self {
|
||||
let state = match value.consumed_at {
|
||||
None => RefreshTokenState::Valid,
|
||||
Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
|
||||
impl TryFrom<OAuth2RefreshTokenLookup> for RefreshToken {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
fn try_from(value: OAuth2RefreshTokenLookup) -> Result<Self, Self::Error> {
|
||||
let id = value.oauth2_refresh_token_id.into();
|
||||
let state = match (value.consumed_at, value.next_oauth2_refresh_token_id) {
|
||||
(None, None) => RefreshTokenState::Valid,
|
||||
(Some(consumed_at), None) => RefreshTokenState::Consumed {
|
||||
consumed_at,
|
||||
next_refresh_token_id: None,
|
||||
},
|
||||
(Some(consumed_at), Some(id)) => RefreshTokenState::Consumed {
|
||||
consumed_at,
|
||||
next_refresh_token_id: Some(Ulid::from(id)),
|
||||
},
|
||||
(None, Some(_)) => {
|
||||
return Err(DatabaseInconsistencyError::on("oauth2_refresh_tokens")
|
||||
.column("next_oauth2_refresh_token_id")
|
||||
.row(id))
|
||||
}
|
||||
};
|
||||
|
||||
RefreshToken {
|
||||
id: value.oauth2_refresh_token_id.into(),
|
||||
Ok(RefreshToken {
|
||||
id,
|
||||
state,
|
||||
session_id: value.oauth2_session_id.into(),
|
||||
refresh_token: value.refresh_token,
|
||||
created_at: value.created_at,
|
||||
access_token_id: value.oauth2_access_token_id.map(Ulid::from),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,18 +95,20 @@ impl OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'_> {
|
||||
, consumed_at
|
||||
, oauth2_access_token_id
|
||||
, oauth2_session_id
|
||||
, next_oauth2_refresh_token_id
|
||||
FROM oauth2_refresh_tokens
|
||||
|
||||
WHERE oauth2_refresh_token_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_optional(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@@ -114,6 +132,7 @@ impl OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'_> {
|
||||
, consumed_at
|
||||
, oauth2_access_token_id
|
||||
, oauth2_session_id
|
||||
, next_oauth2_refresh_token_id
|
||||
FROM oauth2_refresh_tokens
|
||||
|
||||
WHERE refresh_token = $1
|
||||
@@ -126,7 +145,7 @@ impl OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'_> {
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@@ -194,24 +213,28 @@ impl OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'_> {
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
refresh_token: RefreshToken,
|
||||
replaced_by: &RefreshToken,
|
||||
) -> Result<RefreshToken, Self::Error> {
|
||||
let consumed_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_refresh_tokens
|
||||
SET consumed_at = $2
|
||||
SET consumed_at = $2,
|
||||
next_oauth2_refresh_token_id = $3
|
||||
WHERE oauth2_refresh_token_id = $1
|
||||
"#,
|
||||
Uuid::from(refresh_token.id),
|
||||
consumed_at,
|
||||
Uuid::from(replaced_by.id),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
refresh_token
|
||||
.consume(consumed_at)
|
||||
.consume(consumed_at, replaced_by)
|
||||
.map_err(DatabaseError::to_invalid_operation)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,6 +80,7 @@ pub trait OAuth2RefreshTokenRepository: Send + Sync {
|
||||
///
|
||||
/// * `clock`: The clock used to generate timestamps
|
||||
/// * `refresh_token`: The [`RefreshToken`] to consume
|
||||
/// * `replaced_by`: The [`RefreshToken`] which replaced this one
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
@@ -89,6 +90,7 @@ pub trait OAuth2RefreshTokenRepository: Send + Sync {
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
refresh_token: RefreshToken,
|
||||
replaced_by: &RefreshToken,
|
||||
) -> Result<RefreshToken, Self::Error>;
|
||||
}
|
||||
|
||||
@@ -113,5 +115,6 @@ repository_impl!(OAuth2RefreshTokenRepository:
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
refresh_token: RefreshToken,
|
||||
replaced_by: &RefreshToken,
|
||||
) -> Result<RefreshToken, Self::Error>;
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user