Add a repository for device code grants

This commit is contained in:
Quentin Gliech
2023-12-07 16:29:01 +01:00
parent d7b2414792
commit 45b7a6a931
9 changed files with 1036 additions and 10 deletions

View File

@@ -17,7 +17,7 @@ use oauth2_types::scope::Scope;
use serde::Serialize;
use ulid::Ulid;
use crate::{BrowserSession, InvalidTransitionError};
use crate::{BrowserSession, InvalidTransitionError, Session};
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case", tag = "state")]
@@ -117,7 +117,7 @@ impl DeviceCodeGrantState {
/// [`Fulfilled`]: DeviceCodeGrantState::Fulfilled
pub fn exchange(
self,
session_id: Ulid,
session: &Session,
exchanged_at: DateTime<Utc>,
) -> Result<Self, InvalidTransitionError> {
match self {
@@ -129,7 +129,7 @@ impl DeviceCodeGrantState {
browser_session_id,
fulfilled_at,
exchanged_at,
session_id,
session_id: session.id,
}),
_ => Err(InvalidTransitionError),
}
@@ -251,11 +251,11 @@ impl DeviceCodeGrant {
/// [`Fulfilled`]: DeviceCodeGrantState::Fulfilled
pub fn exchange(
self,
session_id: Ulid,
session: &Session,
exchanged_at: DateTime<Utc>,
) -> Result<Self, InvalidTransitionError> {
Ok(Self {
state: self.state.exchange(session_id, exchanged_at)?,
state: self.state.exchange(session, exchanged_at)?,
..self
})
}

View File

@@ -0,0 +1,76 @@
-- Copyright 2023 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--- Adds a table to store device codes for OAuth 2.0 device code flows
--
--
-- This has 4 possible states, only going in one direction:
--
-- [[ Pending ]]
-- | |
-- | [ Rejected ] -- The `rejected_at` and `user_session_id` fields are set
-- |
-- [ Fulfilled ] -- The `fulfilled_at` and `user_session_id` fields are set
-- |
-- [ Exchanged ] -- The `exchanged_at` and `oauth2_session_id` fields are also set
--
CREATE TABLE "oauth2_device_code_grant" (
"oauth2_device_code_grant_id" UUID NOT NULL
PRIMARY KEY,
-- The client who initiated the device code grant
"oauth2_client_id" UUID NOT NULL
REFERENCES "oauth2_clients" ("oauth2_client_id")
ON DELETE CASCADE,
-- The scope requested
"scope" TEXT NOT NULL,
-- The random code that is displayed to the user
"user_code" TEXT NOT NULL
UNIQUE,
-- The random code that the client uses to poll for the access token
"device_code" TEXT NOT NULL
UNIQUE,
-- Timestamp when the device code was created
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL,
-- Timestamp when the device code expires
"expires_at" TIMESTAMP WITH TIME ZONE NOT NULL,
-- When the device code was fulfilled, i.e. the user has granted access
-- This is mutually exclusive with rejected_at
"fulfilled_at" TIMESTAMP WITH TIME ZONE,
-- When the device code was rejected, i.e. the user has denied access
-- This is mutually exclusive with fulfilled_at
"rejected_at" TIMESTAMP WITH TIME ZONE,
-- When the device code was exchanged
-- This means "fulfilled_at" has also been set
"exchanged_at" TIMESTAMP WITH TIME ZONE,
-- The OAuth 2.0 session generated for this device code
-- This means "exchanged_at" has also been set
"oauth2_session_id" UUID
REFERENCES "oauth2_sessions" ("oauth2_session_id")
ON DELETE CASCADE,
-- The browser session ID that the user used to authenticate
-- This means "fulfilled_at" or "rejected_at" has also been set
"user_session_id" UUID
REFERENCES "user_sessions" ("user_session_id")
);

View File

@@ -281,6 +281,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi
requires_consent,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
@@ -340,6 +341,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi
"#,
Uuid::from(id),
)
.traced()
.fetch_optional(&mut *self.conn)
.await?;
@@ -427,6 +429,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi
fulfilled_at,
Uuid::from(session.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
@@ -465,6 +468,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi
Uuid::from(grant.id),
exchanged_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
@@ -501,6 +505,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi
"#,
Uuid::from(grant.id),
)
.traced()
.execute(&mut *self.conn)
.await?;

View File

@@ -0,0 +1,463 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{BrowserSession, DeviceCodeGrant, DeviceCodeGrantState, Session};
use mas_storage::{
oauth2::{OAuth2DeviceCodeGrantParams, OAuth2DeviceCodeGrantRepository},
Clock,
};
use oauth2_types::scope::Scope;
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{errors::DatabaseInconsistencyError, DatabaseError, ExecuteExt};
/// An implementation of [`OAuth2DeviceCodeGrantRepository`] for a PostgreSQL
/// connection
pub struct PgOAuth2DeviceCodeGrantRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2DeviceCodeGrantRepository<'c> {
/// Create a new [`PgOAuth2DeviceCodeGrantRepository`] from an active
/// PostgreSQL connection
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct OAuth2DeviceGrantLookup {
oauth2_device_code_grant_id: Uuid,
oauth2_client_id: Uuid,
scope: String,
device_code: String,
user_code: String,
created_at: DateTime<Utc>,
expires_at: DateTime<Utc>,
fulfilled_at: Option<DateTime<Utc>>,
rejected_at: Option<DateTime<Utc>>,
exchanged_at: Option<DateTime<Utc>>,
user_session_id: Option<Uuid>,
oauth2_session_id: Option<Uuid>,
}
impl TryFrom<OAuth2DeviceGrantLookup> for DeviceCodeGrant {
type Error = DatabaseInconsistencyError;
fn try_from(
OAuth2DeviceGrantLookup {
oauth2_device_code_grant_id,
oauth2_client_id,
scope,
device_code,
user_code,
created_at,
expires_at,
fulfilled_at,
rejected_at,
exchanged_at,
user_session_id,
oauth2_session_id,
}: OAuth2DeviceGrantLookup,
) -> Result<Self, Self::Error> {
let id = Ulid::from(oauth2_device_code_grant_id);
let client_id = Ulid::from(oauth2_client_id);
let scope: Scope = scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("scope")
.row(id)
.source(e)
})?;
let state = match (
fulfilled_at,
rejected_at,
exchanged_at,
user_session_id,
oauth2_session_id,
) {
(None, None, None, None, None) => DeviceCodeGrantState::Pending,
(Some(fulfilled_at), None, None, Some(user_session_id), None) => {
DeviceCodeGrantState::Fulfilled {
browser_session_id: Ulid::from(user_session_id),
fulfilled_at,
}
}
(None, Some(rejected_at), None, Some(user_session_id), None) => {
DeviceCodeGrantState::Rejected {
browser_session_id: Ulid::from(user_session_id),
rejected_at,
}
}
(
Some(fulfilled_at),
None,
Some(exchanged_at),
Some(user_session_id),
Some(oauth2_session_id),
) => DeviceCodeGrantState::Exchanged {
browser_session_id: Ulid::from(user_session_id),
session_id: Ulid::from(oauth2_session_id),
fulfilled_at,
exchanged_at,
},
_ => return Err(DatabaseInconsistencyError::on("oauth2_device_code_grant").row(id)),
};
Ok(DeviceCodeGrant {
id,
state,
client_id,
scope,
user_code,
device_code,
created_at,
expires_at,
})
}
}
#[async_trait]
impl<'c> OAuth2DeviceCodeGrantRepository for PgOAuth2DeviceCodeGrantRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.oauth2_device_code_grant.add",
skip_all,
fields(
db.statement,
oauth2_device_code.id,
oauth2_device_code.scope = %params.scope,
oauth2_client.id = %params.client.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
params: OAuth2DeviceCodeGrantParams<'_>,
) -> Result<DeviceCodeGrant, Self::Error> {
let now = clock.now();
let id = Ulid::from_datetime_with_source(now.into(), rng);
tracing::Span::current().record("oauth2_device_code.id", tracing::field::display(id));
let created_at = now;
let expires_at = now + params.expires_in;
let client_id = params.client.id;
sqlx::query!(
r#"
INSERT INTO "oauth2_device_code_grant"
( oauth2_device_code_grant_id
, oauth2_client_id
, scope
, device_code
, user_code
, created_at
, expires_at
)
VALUES
($1, $2, $3, $4, $5, $6, $7)
"#,
Uuid::from(id),
Uuid::from(client_id),
params.scope.to_string(),
&params.device_code,
&params.user_code,
created_at,
expires_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(DeviceCodeGrant {
id,
state: DeviceCodeGrantState::Pending,
client_id,
scope: params.scope,
user_code: params.user_code,
device_code: params.device_code,
created_at,
expires_at,
})
}
#[tracing::instrument(
name = "db.oauth2_device_code_grant.lookup",
skip_all,
fields(
db.statement,
oauth2_device_code.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<DeviceCodeGrant>, Self::Error> {
let res = sqlx::query_as!(
OAuth2DeviceGrantLookup,
r#"
SELECT oauth2_device_code_grant_id
, oauth2_client_id
, scope
, device_code
, user_code
, created_at
, expires_at
, fulfilled_at
, rejected_at
, exchanged_at
, user_session_id
, oauth2_session_id
FROM
oauth2_device_code_grant
WHERE oauth2_device_code_grant_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_optional(&mut *self.conn)
.await?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_device_code_grant.find_by_user_code",
skip_all,
fields(
db.statement,
oauth2_device_code.user_code = %user_code,
),
err,
)]
async fn find_by_user_code(
&mut self,
user_code: &str,
) -> Result<Option<DeviceCodeGrant>, Self::Error> {
let res = sqlx::query_as!(
OAuth2DeviceGrantLookup,
r#"
SELECT oauth2_device_code_grant_id
, oauth2_client_id
, scope
, device_code
, user_code
, created_at
, expires_at
, fulfilled_at
, rejected_at
, exchanged_at
, user_session_id
, oauth2_session_id
FROM
oauth2_device_code_grant
WHERE user_code = $1
"#,
user_code,
)
.traced()
.fetch_optional(&mut *self.conn)
.await?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_device_code_grant.find_by_device_code",
skip_all,
fields(
db.statement,
oauth2_device_code.device_code = %device_code,
),
err,
)]
async fn find_by_device_code(
&mut self,
device_code: &str,
) -> Result<Option<DeviceCodeGrant>, Self::Error> {
let res = sqlx::query_as!(
OAuth2DeviceGrantLookup,
r#"
SELECT oauth2_device_code_grant_id
, oauth2_client_id
, scope
, device_code
, user_code
, created_at
, expires_at
, fulfilled_at
, rejected_at
, exchanged_at
, user_session_id
, oauth2_session_id
FROM
oauth2_device_code_grant
WHERE device_code = $1
"#,
device_code,
)
.traced()
.fetch_optional(&mut *self.conn)
.await?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_device_code_grant.fulfill",
skip_all,
fields(
db.statement,
oauth2_device_code.id = %device_code_grant.id,
oauth2_client.id = %device_code_grant.client_id,
browser_session.id = %browser_session.id,
user.id = %browser_session.user.id,
),
err,
)]
async fn fulfill(
&mut self,
clock: &dyn Clock,
device_code_grant: DeviceCodeGrant,
browser_session: &BrowserSession,
) -> Result<DeviceCodeGrant, Self::Error> {
let fulfilled_at = clock.now();
let device_code_grant = device_code_grant
.fulfill(&browser_session, fulfilled_at)
.map_err(DatabaseError::to_invalid_operation)?;
let res = sqlx::query!(
r#"
UPDATE oauth2_device_code_grant
SET fulfilled_at = $1
, user_session_id = $2
WHERE oauth2_device_code_grant_id = $3
"#,
fulfilled_at,
Uuid::from(browser_session.id),
Uuid::from(device_code_grant.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(device_code_grant)
}
#[tracing::instrument(
name = "db.oauth2_device_code_grant.reject",
skip_all,
fields(
db.statement,
oauth2_device_code.id = %device_code_grant.id,
oauth2_client.id = %device_code_grant.client_id,
browser_session.id = %browser_session.id,
user.id = %browser_session.user.id,
),
err,
)]
async fn reject(
&mut self,
clock: &dyn Clock,
device_code_grant: DeviceCodeGrant,
browser_session: &BrowserSession,
) -> Result<DeviceCodeGrant, Self::Error> {
let fulfilled_at = clock.now();
let device_code_grant = device_code_grant
.reject(&browser_session, fulfilled_at)
.map_err(DatabaseError::to_invalid_operation)?;
let res = sqlx::query!(
r#"
UPDATE oauth2_device_code_grant
SET rejected_at = $1
, user_session_id = $2
WHERE oauth2_device_code_grant_id = $3
"#,
fulfilled_at,
Uuid::from(browser_session.id),
Uuid::from(device_code_grant.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(device_code_grant)
}
#[tracing::instrument(
name = "db.oauth2_device_code_grant.exchange",
skip_all,
fields(
db.statement,
oauth2_device_code.id = %device_code_grant.id,
oauth2_client.id = %device_code_grant.client_id,
oauth2_session.id = %session.id,
),
err,
)]
async fn exchange(
&mut self,
clock: &dyn Clock,
device_code_grant: DeviceCodeGrant,
session: &Session,
) -> Result<DeviceCodeGrant, Self::Error> {
let exchanged_at = clock.now();
let device_code_grant = device_code_grant
.exchange(session, exchanged_at)
.map_err(DatabaseError::to_invalid_operation)?;
let res = sqlx::query!(
r#"
UPDATE oauth2_device_code_grant
SET exchanged_at = $1
, oauth2_session_id = $2
WHERE oauth2_device_code_grant_id = $3
"#,
exchanged_at,
Uuid::from(session.id),
Uuid::from(device_code_grant.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(device_code_grant)
}
}

View File

@@ -18,12 +18,14 @@
mod access_token;
mod authorization_grant;
mod client;
mod device_code_grant;
mod refresh_token;
mod session;
pub use self::{
access_token::PgOAuth2AccessTokenRepository,
authorization_grant::PgOAuth2AuthorizationGrantRepository, client::PgOAuth2ClientRepository,
device_code_grant::PgOAuth2DeviceCodeGrantRepository,
refresh_token::PgOAuth2RefreshTokenRepository, session::PgOAuth2SessionRepository,
};
@@ -33,7 +35,7 @@ mod tests {
use mas_data_model::AuthorizationCode;
use mas_storage::{
clock::MockClock,
oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
oauth2::{OAuth2DeviceCodeGrantParams, OAuth2SessionFilter, OAuth2SessionRepository},
Clock, Pagination, Repository,
};
use oauth2_types::{
@@ -690,4 +692,226 @@ mod tests {
assert_eq!(list.edges[0], session11);
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);
}
/// Test the [`OAuth2DeviceCodeGrantRepository`] implementation
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_device_code_grant_repository(pool: PgPool) {
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Provision a client
let client = repo
.oauth2_client()
.add(
&mut rng,
&clock,
vec!["https://example.com/redirect".parse().unwrap()],
None,
None,
vec![GrantType::AuthorizationCode],
Vec::new(), // TODO: contacts are not yet saved
// vec!["contact@example.com".to_owned()],
Some("Example".to_owned()),
Some("https://example.com/logo.png".parse().unwrap()),
Some("https://example.com/".parse().unwrap()),
Some("https://example.com/policy".parse().unwrap()),
Some("https://example.com/tos".parse().unwrap()),
Some("https://example.com/jwks.json".parse().unwrap()),
None,
None,
None,
None,
None,
Some("https://example.com/login".parse().unwrap()),
)
.await
.unwrap();
// Provision a user
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
// Provision a browser session
let browser_session = repo
.browser_session()
.add(&mut rng, &clock, &user, None)
.await
.unwrap();
let user_code = "usercode";
let device_code = "devicecode";
let scope = Scope::from_iter([OPENID, EMAIL]);
// Create a device code grant
let grant = repo
.oauth2_device_code_grant()
.add(
&mut rng,
&clock,
OAuth2DeviceCodeGrantParams {
client: &client,
scope: scope.clone(),
device_code: device_code.to_owned(),
user_code: user_code.to_owned(),
expires_in: Duration::minutes(5),
},
)
.await
.unwrap();
assert!(grant.is_pending());
// Check that we can find the grant by ID
let id = grant.id;
let lookup = repo.oauth2_device_code_grant().lookup(id).await.unwrap();
assert_eq!(lookup.as_ref(), Some(&grant));
// Check that we can find the grant by device code
let lookup = repo
.oauth2_device_code_grant()
.find_by_device_code(device_code)
.await
.unwrap();
assert_eq!(lookup.as_ref(), Some(&grant));
// Check that we can find the grant by user code
let lookup = repo
.oauth2_device_code_grant()
.find_by_user_code(user_code)
.await
.unwrap();
assert_eq!(lookup.as_ref(), Some(&grant));
// Let's mark it as fulfilled
let grant = repo
.oauth2_device_code_grant()
.fulfill(&clock, grant, &browser_session)
.await
.unwrap();
assert!(!grant.is_pending());
assert!(grant.is_fulfilled());
// Check that we can't mark it as rejected now
let res = repo
.oauth2_device_code_grant()
.reject(&clock, grant, &browser_session)
.await;
assert!(res.is_err());
// Look it up again
let grant = repo
.oauth2_device_code_grant()
.lookup(id)
.await
.unwrap()
.unwrap();
// We can't mark it as fulfilled again
let res = repo
.oauth2_device_code_grant()
.fulfill(&clock, grant, &browser_session)
.await;
assert!(res.is_err());
// Look it up again
let grant = repo
.oauth2_device_code_grant()
.lookup(id)
.await
.unwrap()
.unwrap();
// Create an OAuth 2.0 session
let session = repo
.oauth2_session()
.add_from_browser_session(&mut rng, &clock, &client, &browser_session, scope.clone())
.await
.unwrap();
// We can mark it as exchanged
let grant = repo
.oauth2_device_code_grant()
.exchange(&clock, grant, &session)
.await
.unwrap();
assert!(!grant.is_pending());
assert!(!grant.is_fulfilled());
assert!(grant.is_exchanged());
// We can't mark it as exchanged again
let res = repo
.oauth2_device_code_grant()
.exchange(&clock, grant, &session)
.await;
assert!(res.is_err());
// Do a new grant to reject it
let grant = repo
.oauth2_device_code_grant()
.add(
&mut rng,
&clock,
OAuth2DeviceCodeGrantParams {
client: &client,
scope: scope.clone(),
device_code: "second_devicecode".to_owned(),
user_code: "second_usercode".to_owned(),
expires_in: Duration::minutes(5),
},
)
.await
.unwrap();
let id = grant.id;
// We can mark it as rejected
let grant = repo
.oauth2_device_code_grant()
.reject(&clock, grant, &browser_session)
.await
.unwrap();
assert!(!grant.is_pending());
assert!(grant.is_rejected());
// We can't mark it as rejected again
let res = repo
.oauth2_device_code_grant()
.reject(&clock, grant, &browser_session)
.await;
assert!(res.is_err());
// Look it up again
let grant = repo
.oauth2_device_code_grant()
.lookup(id)
.await
.unwrap()
.unwrap();
// We can't mark it as fulfilled
let res = repo
.oauth2_device_code_grant()
.fulfill(&clock, grant, &browser_session)
.await;
assert!(res.is_err());
// Look it up again
let grant = repo
.oauth2_device_code_grant()
.lookup(id)
.await
.unwrap()
.unwrap();
// We can't mark it as exchanged
let res = repo
.oauth2_device_code_grant()
.exchange(&clock, grant, &session)
.await;
assert!(res.is_err());
}
}

View File

@@ -24,7 +24,7 @@ use mas_storage::{
job::JobRepository,
oauth2::{
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
OAuth2RefreshTokenRepository, OAuth2SessionRepository,
OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
},
upstream_oauth2::{
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
@@ -45,7 +45,8 @@ use crate::{
job::PgJobRepository,
oauth2::{
PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository,
PgOAuth2ClientRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
PgOAuth2ClientRepository, PgOAuth2DeviceCodeGrantRepository,
PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
},
upstream_oauth2::{
PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
@@ -220,6 +221,12 @@ where
Box::new(PgOAuth2RefreshTokenRepository::new(self.conn.as_mut()))
}
fn oauth2_device_code_grant<'c>(
&'c mut self,
) -> Box<dyn OAuth2DeviceCodeGrantRepository<Error = Self::Error> + 'c> {
Box::new(PgOAuth2DeviceCodeGrantRepository::new(self.conn.as_mut()))
}
fn compat_session<'c>(
&'c mut self,
) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c> {

View File

@@ -0,0 +1,228 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::Duration;
use mas_data_model::{BrowserSession, Client, DeviceCodeGrant, Session};
use oauth2_types::scope::Scope;
use rand_core::RngCore;
use ulid::Ulid;
use crate::{repository_impl, Clock};
/// Parameters used to create a new [`DeviceCodeGrant`]
pub struct OAuth2DeviceCodeGrantParams<'a> {
/// The client which requested the device code grant
pub client: &'a Client,
/// The scope requested by the client
pub scope: Scope,
/// The device code which the client uses to poll for authorisation
pub device_code: String,
/// The user code which the client uses to display to the user
pub user_code: String,
/// After how long the device code expires
pub expires_in: Duration,
}
/// An [`OAuth2DeviceCodeGrantRepository`] helps interacting with
/// [`DeviceCodeGrant`] saved in the storage backend.
#[async_trait]
pub trait OAuth2DeviceCodeGrantRepository: Send + Sync {
/// The error type returned by the repository
type Error;
/// Create a new device code grant
///
/// Returns the newly created device code grant
///
/// # Parameters
///
/// * `rng`: A random number generator
/// * `clock`: The clock used to generate timestamps
/// * `params`: The parameters used to create the device code grant. See the
/// fields of [`DeviceCodeGrantParams`]
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
params: OAuth2DeviceCodeGrantParams<'_>,
) -> Result<DeviceCodeGrant, Self::Error>;
/// Lookup a device code grant by its ID
///
/// Returns the device code grant if found, [`None`] otherwise
///
/// # Parameters
///
/// * `id`: The ID of the device code grant
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn lookup(&mut self, id: Ulid) -> Result<Option<DeviceCodeGrant>, Self::Error>;
/// Lookup a device code grant by its device code
///
/// Returns the device code grant if found, [`None`] otherwise
///
/// # Parameters
///
/// * `device_code`: The device code of the device code grant
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn find_by_device_code(
&mut self,
device_code: &str,
) -> Result<Option<DeviceCodeGrant>, Self::Error>;
/// Lookup a device code grant by its user code
///
/// Returns the device code grant if found, [`None`] otherwise
///
/// # Parameters
///
/// * `user_code`: The user code of the device code grant
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn find_by_user_code(
&mut self,
user_code: &str,
) -> Result<Option<DeviceCodeGrant>, Self::Error>;
/// Mark the device code grant as fulfilled with the given browser session
///
/// Returns the updated device code grant
///
/// # Parameters
///
/// * `clock`: The clock used to generate timestamps
/// * `device_code_grant`: The device code grant to fulfill
/// * `browser_session`: The browser session which was used to fulfill the
/// device code grant
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails or if the
/// device code grant is not in the [`Pending`] state
///
/// [`Pending`]: DeviceCodeGrantState::Pending
async fn fulfill(
&mut self,
clock: &dyn Clock,
device_code_grant: DeviceCodeGrant,
browser_session: &BrowserSession,
) -> Result<DeviceCodeGrant, Self::Error>;
/// Mark the device code grant as rejected with the given browser session
///
/// Returns the updated device code grant
///
/// # Parameters
///
/// * `clock`: The clock used to generate timestamps
/// * `device_code_grant`: The device code grant to reject
/// * `browser_session`: The browser session which was used to reject the
/// device code grant
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails or if the
/// device code grant is not in the [`Pending`] state
///
/// [`Pending`]: DeviceCodeGrantState::Pending
async fn reject(
&mut self,
clock: &dyn Clock,
device_code_grant: DeviceCodeGrant,
browser_session: &BrowserSession,
) -> Result<DeviceCodeGrant, Self::Error>;
/// Mark the device code grant as exchanged and store the session which was
/// created
///
/// Returns the updated device code grant
///
/// # Parameters
///
/// * `clock`: The clock used to generate timestamps
/// * `device_code_grant`: The device code grant to exchange
/// * `session`: The OAuth 2.0 session which was created
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails or if the
/// device code grant is not in the [`Fulfilled`] state
///
/// [`Fulfilled`]: DeviceCodeGrantState::Fulfilled
async fn exchange(
&mut self,
clock: &dyn Clock,
device_code_grant: DeviceCodeGrant,
session: &Session,
) -> Result<DeviceCodeGrant, Self::Error>;
}
repository_impl!(OAuth2DeviceCodeGrantRepository:
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
params: OAuth2DeviceCodeGrantParams<'_>,
) -> Result<DeviceCodeGrant, Self::Error>;
async fn lookup(&mut self, id: Ulid) -> Result<Option<DeviceCodeGrant>, Self::Error>;
async fn find_by_device_code(
&mut self,
device_code: &str,
) -> Result<Option<DeviceCodeGrant>, Self::Error>;
async fn find_by_user_code(
&mut self,
user_code: &str,
) -> Result<Option<DeviceCodeGrant>, Self::Error>;
async fn fulfill(
&mut self,
clock: &dyn Clock,
device_code_grant: DeviceCodeGrant,
browser_session: &BrowserSession,
) -> Result<DeviceCodeGrant, Self::Error>;
async fn reject(
&mut self,
clock: &dyn Clock,
device_code_grant: DeviceCodeGrant,
browser_session: &BrowserSession,
) -> Result<DeviceCodeGrant, Self::Error>;
async fn exchange(
&mut self,
clock: &dyn Clock,
device_code_grant: DeviceCodeGrant,
session: &Session,
) -> Result<DeviceCodeGrant, Self::Error>;
);

View File

@@ -17,6 +17,7 @@
mod access_token;
mod authorization_grant;
mod client;
mod device_code_grant;
mod refresh_token;
mod session;
@@ -24,6 +25,7 @@ pub use self::{
access_token::OAuth2AccessTokenRepository,
authorization_grant::OAuth2AuthorizationGrantRepository,
client::OAuth2ClientRepository,
device_code_grant::{OAuth2DeviceCodeGrantParams, OAuth2DeviceCodeGrantRepository},
refresh_token::OAuth2RefreshTokenRepository,
session::{OAuth2SessionFilter, OAuth2SessionRepository},
};

View File

@@ -24,7 +24,7 @@ use crate::{
job::JobRepository,
oauth2::{
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
OAuth2RefreshTokenRepository, OAuth2SessionRepository,
OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
},
upstream_oauth2::{
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
@@ -178,6 +178,11 @@ pub trait RepositoryAccess: Send {
&'c mut self,
) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c>;
/// Get an [`OAuth2DeviceCodeGrantRepository`]
fn oauth2_device_code_grant<'c>(
&'c mut self,
) -> Box<dyn OAuth2DeviceCodeGrantRepository<Error = Self::Error> + 'c>;
/// Get a [`CompatSessionRepository`]
fn compat_session<'c>(
&'c mut self,
@@ -217,7 +222,8 @@ mod impls {
job::JobRepository,
oauth2::{
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository,
OAuth2ClientRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository,
OAuth2SessionRepository,
},
upstream_oauth2::{
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
@@ -360,6 +366,15 @@ mod impls {
))
}
fn oauth2_device_code_grant<'c>(
&'c mut self,
) -> Box<dyn OAuth2DeviceCodeGrantRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(
self.inner.oauth2_device_code_grant(),
&mut self.mapper,
))
}
fn compat_session<'c>(
&'c mut self,
) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c> {
@@ -472,6 +487,12 @@ mod impls {
(**self).oauth2_refresh_token()
}
fn oauth2_device_code_grant<'c>(
&'c mut self,
) -> Box<dyn OAuth2DeviceCodeGrantRepository<Error = Self::Error> + 'c> {
(**self).oauth2_device_code_grant()
}
fn compat_session<'c>(
&'c mut self,
) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c> {