From a82074fdb189a9bbc53af4edf59704d4bde08464 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 30 Dec 2022 10:16:22 +0100 Subject: [PATCH 01/45] WIP: repository pattern for upstream oauth2 links --- Cargo.lock | 4 +- Cargo.toml | 6 + crates/graphql/src/lib.rs | 3 +- crates/graphql/src/model/users.rs | 12 +- .../handlers/src/upstream_oauth2/callback.rs | 14 +- crates/handlers/src/upstream_oauth2/link.rs | 22 +- crates/handlers/src/views/shared.rs | 7 +- crates/storage/Cargo.toml | 1 + crates/storage/sqlx-data.json | 148 +++--- crates/storage/src/lib.rs | 3 + crates/storage/src/pagination.rs | 8 + crates/storage/src/repository.rs | 41 ++ crates/storage/src/upstream_oauth2/link.rs | 432 ++++++++++-------- crates/storage/src/upstream_oauth2/mod.rs | 5 +- 14 files changed, 419 insertions(+), 287 deletions(-) create mode 100644 crates/storage/src/repository.rs diff --git a/Cargo.lock b/Cargo.lock index 059e91c04..bc0b37baf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3097,6 +3097,7 @@ dependencies = [ name = "mas-storage" version = "0.1.0" dependencies = [ + "async-trait", "chrono", "mas-data-model", "mas-iana", @@ -5575,8 +5576,7 @@ checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81" [[package]] name = "ulid" version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13a3aaa69b04e5b66cc27309710a569ea23593612387d67daaf102e73aa974fd" +source = "git+https://github.com/sandhose/ulid-rs.git?rev=f1ef6fd736c4d3cbc7cf314fad707f0803de46ed#f1ef6fd736c4d3cbc7cf314fad707f0803de46ed" dependencies = [ "rand 0.8.5", "serde", diff --git a/Cargo.toml b/Cargo.toml index f621be0cc..9799f34b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,3 +7,9 @@ opt-level = 3 [profile.dev.package.sqlx-macros] opt-level = 3 + +# Until https://github.com/dylanhart/ulid-rs/pull/56 gets merged and released +[patch.crates-io.ulid] +git = "https://github.com/sandhose/ulid-rs.git" +#branch = "relax-sized-on-rng" +rev = "f1ef6fd736c4d3cbc7cf314fad707f0803de46ed" diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index 1e691a96e..9a86ecbe5 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -30,6 +30,7 @@ use async_graphql::{ connection::{query, Connection, Edge, OpaqueCursor}, Context, Description, EmptyMutation, EmptySubscription, ID, }; +use mas_storage::{Repository, UpstreamOAuthLinkRepository}; use model::CreationEvent; use sqlx::PgPool; @@ -171,7 +172,7 @@ impl RootQuery { let Some(session) = session else { return Ok(None) }; let current_user = session.user; - let link = mas_storage::upstream_oauth2::lookup_link(&mut conn, id).await?; + let link = conn.upstream_oauth_link().lookup(id).await?; // Ensure that the link belongs to the current user let link = link.filter(|link| link.user_id == Some(current_user.id)); diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index ad8bfa435..01fcfb0eb 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -17,6 +17,7 @@ use async_graphql::{ Context, Description, Object, ID, }; use chrono::{DateTime, Utc}; +use mas_storage::{Repository, UpstreamOAuthLinkRepository}; use sqlx::PgPool; use super::{ @@ -285,14 +286,13 @@ impl User { }) .transpose()?; - let (has_previous_page, has_next_page, edges) = - mas_storage::upstream_oauth2::get_paginated_user_links( - &mut conn, &self.0, before_id, after_id, first, last, - ) + let page = conn + .upstream_oauth_link() + .list_paginated(&self.0, before_id, after_id, first, last) .await?; - let mut connection = Connection::new(has_previous_page, has_next_page); - connection.edges.extend(edges.into_iter().map(|s| { + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + connection.edges.extend(page.edges.into_iter().map(|s| { Edge::new( OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Link, s.id)), UpstreamOAuth2Link::new(s), diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index ab31641c8..6158f9413 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -25,8 +25,9 @@ use mas_oidc_client::requests::{ authorization_code::AuthorizationValidationData, jose::JwtVerificationData, }; use mas_router::{Route, UrlBuilder}; -use mas_storage::upstream_oauth2::{ - add_link, complete_session, lookup_link_by_subject, lookup_session, +use mas_storage::{ + upstream_oauth2::{complete_session, lookup_session}, + Repository, UpstreamOAuthLinkRepository, }; use oauth2_types::errors::ClientErrorCode; use serde::Deserialize; @@ -231,12 +232,17 @@ pub(crate) async fn get( let subject = mas_jose::claims::SUB.extract_required(&mut id_token)?; // Look for an existing link - let maybe_link = lookup_link_by_subject(&mut txn, &provider, &subject).await?; + let maybe_link = txn + .upstream_oauth_link() + .find_by_subject(&provider, &subject) + .await?; let link = if let Some(link) = maybe_link { link } else { - add_link(&mut txn, &mut rng, &clock, &provider, subject).await? + txn.upstream_oauth_link() + .add(&mut rng, &clock, &provider, subject) + .await? }; let session = complete_session(&mut txn, &clock, session, &link, response.id_token).await?; diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 15c5ac93d..4a109ba6e 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -25,10 +25,9 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_storage::{ - upstream_oauth2::{ - associate_link_to_user, consume_session, lookup_link, lookup_session_on_link, - }, + upstream_oauth2::{consume_session, lookup_session_on_link}, user::{add_user, authenticate_session_with_upstream, lookup_user, start_session}, + Repository, UpstreamOAuthLinkRepository, }; use mas_templates::{ EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, @@ -104,7 +103,9 @@ pub(crate) async fn get( .lookup_link(link_id) .map_err(|_| RouteError::MissingCookie)?; - let link = lookup_link(&mut txn, link_id) + let link = txn + .upstream_oauth_link() + .lookup(link_id) .await? .ok_or(RouteError::LinkNotFound)?; @@ -205,7 +206,9 @@ pub(crate) async fn post( post_auth_action: post_auth_action.cloned(), }; - let link = lookup_link(&mut txn, link_id) + let link = txn + .upstream_oauth_link() + .lookup(link_id) .await? .ok_or(RouteError::LinkNotFound)?; @@ -224,7 +227,10 @@ pub(crate) async fn post( let mut session = match (maybe_user_session, link.user_id, form) { (Some(session), None, FormData::Link) => { - associate_link_to_user(&mut txn, &link, &session.user).await?; + txn.upstream_oauth_link() + .associate_to_user(&link, &session.user) + .await?; + session } @@ -235,7 +241,9 @@ pub(crate) async fn post( (None, None, FormData::Register { username }) => { let user = add_user(&mut txn, &mut rng, &clock, &username).await?; - associate_link_to_user(&mut txn, &link, &user).await?; + txn.upstream_oauth_link() + .associate_to_user(&link, &user) + .await?; start_session(&mut txn, &mut rng, &clock, user).await? } diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index fcdef3b4a..d4b190025 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -15,7 +15,8 @@ use anyhow::Context; use mas_router::{PostAuthAction, Route}; use mas_storage::{ - compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id, + compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id, Repository, + UpstreamOAuthLinkRepository, }; use mas_templates::{PostAuthContext, PostAuthContextInner}; use serde::{Deserialize, Serialize}; @@ -63,7 +64,9 @@ impl OptionalPostAuthAction { PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword, PostAuthAction::LinkUpstream { id } => { - let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, id) + let link = conn + .upstream_oauth_link() + .lookup(id) .await? .context("Failed to load upstream OAuth 2.0 link")?; diff --git a/crates/storage/Cargo.toml b/crates/storage/Cargo.toml index b0ed4c5e5..71240129f 100644 --- a/crates/storage/Cargo.toml +++ b/crates/storage/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" license = "Apache-2.0" [dependencies] +async-trait = "0.1.60" sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline", "json", "uuid"] } chrono = { version = "0.4.23", features = ["serde"] } serde = { version = "1.0.152", features = ["derive"] } diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 1ce99d797..63368ec0c 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -241,19 +241,6 @@ }, "query": "\n UPDATE user_emails\n SET confirmed_at = $2\n WHERE user_email_id = $1\n " }, - "1e7b1b7e06b5d97d81dc4a8524bb223c3dc7ddbbcce7cc2a142dbfbdd6a2902e": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid" - ] - } - }, - "query": "\n UPDATE upstream_oauth_links\n SET user_id = $1\n WHERE upstream_oauth_link_id = $2\n " - }, "1eb6d13e75d8f526c2785749a020731c18012f03e07995213acd38ab560ce497": { "describe": { "columns": [], @@ -882,6 +869,50 @@ }, "query": "\n INSERT INTO user_emails (user_email_id, user_id, email, created_at)\n VALUES ($1, $2, $3, $4)\n " }, + "4187907bfc770b2c76f741671d5e672f5c35eed7c9a9e57ff52888b1768a5ed6": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_link_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "upstream_oauth_provider_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "user_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "subject", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 4, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_link_id = $1\n " + }, "42bfb0de5bbea2d580f1ff2322255731a4a5655ba80fc2dba0b55a0add8c55c0": { "describe": { "columns": [ @@ -1043,50 +1074,6 @@ }, "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n oauth2_session_id = os.oauth2_session_id,\n fulfilled_at = os.created_at\n FROM oauth2_sessions os\n WHERE\n og.oauth2_authorization_grant_id = $1\n AND os.oauth2_session_id = $2\n RETURNING fulfilled_at AS \"fulfilled_at!: DateTime\"\n " }, - "47d4048365144c7bfc14790dfb8fa7f862d2952075a68cd5e90ac76d9e6d1388": { - "describe": { - "columns": [ - { - "name": "upstream_oauth_link_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "upstream_oauth_provider_id", - "ordinal": 1, - "type_info": "Uuid" - }, - { - "name": "user_id", - "ordinal": 2, - "type_info": "Uuid" - }, - { - "name": "subject", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 4, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - false, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_link_id = $1\n " - }, "4f8ec19f3f1bfe0268fe102a24e5a9fa542e77eccbebdce65e6deb1c197adf36": { "describe": { "columns": [ @@ -1345,6 +1332,21 @@ }, "query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n\n ORDER BY ue.email ASC\n " }, + "5f6b7e38ef9bc3b39deabba277d0255fb8cfb2adaa65f47b78a8fac11d8c91c3": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO upstream_oauth_links (\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n ) VALUES ($1, $2, NULL, $3, $4)\n " + }, "60d039442cfa57e187602c0ff5e386e32fb774b5ad2d2f2c616040819b76873e": { "describe": { "columns": [], @@ -1837,6 +1839,19 @@ }, "query": "\n SELECT\n c.oauth2_client_id,\n c.encrypted_client_secret,\n ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\",\n c.grant_type_authorization_code,\n c.grant_type_refresh_token,\n c.client_name,\n c.logo_uri,\n c.client_uri,\n c.policy_uri,\n c.tos_uri,\n c.jwks_uri,\n c.jwks,\n c.id_token_signed_response_alg,\n c.userinfo_signed_response_alg,\n c.token_endpoint_auth_method,\n c.token_endpoint_auth_signing_alg,\n c.initiate_login_uri\n FROM oauth2_clients c\n\n WHERE c.oauth2_client_id = ANY($1::uuid[])\n " }, + "7ce387b1b0aaf10e72adde667b19521b66eaafa51f73bf2f95e38b8f3b64a229": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid" + ] + } + }, + "query": "\n UPDATE upstream_oauth_links\n SET user_id = $1\n WHERE upstream_oauth_link_id = $2\n " + }, "7cf5ae665b15ba78b01bb1dfa304150a89fd7203f4ee15b0753cb2143049a3dc": { "describe": { "columns": [ @@ -2687,21 +2702,6 @@ }, "query": "\n DELETE FROM user_emails\n WHERE user_emails.user_email_id = $1\n " }, - "e1dc9dd2bf26a341050a53151bf51f7638448ccc2bd458bbdfe87cc22f086313": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO upstream_oauth_links (\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n ) VALUES ($1, $2, NULL, $3, $4)\n " - }, "e30562e9637d3a723a91adca6336a8d083657ce6d7fe9551fcd6a9d672453d3c": { "describe": { "columns": [], @@ -2729,7 +2729,7 @@ }, "query": "\n INSERT INTO user_sessions (user_session_id, user_id, created_at)\n VALUES ($1, $2, $3)\n " }, - "f71cb5761bfc15d8bc3ba7ee49b63fb3c3ea9691745688eb5fd91f4f6e1ec018": { + "e6dc63984aced9e19c20e90e9cd75d6f6d7ade64f782697715ac4da077b2e1fc": { "describe": { "columns": [ { @@ -2772,7 +2772,7 @@ ] } }, - "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " + "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " }, "fb71ac6539039313fd90b29ac943330e54c7b62b2778727726e2f60a554f9c5a": { "describe": { diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index f059e376c..268652016 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -178,8 +178,11 @@ impl Clock { pub mod compat; pub mod oauth2; pub(crate) mod pagination; +pub(crate) mod repository; pub mod upstream_oauth2; pub mod user; +pub use self::{repository::Repository, upstream_oauth2::UpstreamOAuthLinkRepository}; + /// Embedded migrations, allowing them to run on startup pub static MIGRATOR: Migrator = sqlx::migrate!(); diff --git a/crates/storage/src/pagination.rs b/crates/storage/src/pagination.rs index 956556750..a240c554e 100644 --- a/crates/storage/src/pagination.rs +++ b/crates/storage/src/pagination.rs @@ -111,6 +111,14 @@ pub fn process_page( Ok((has_previous_page, has_next_page, page)) } +pub struct Page { + pub has_next_page: bool, + pub has_previous_page: bool, + pub edges: Vec, +} + +impl Page {} + pub trait QueryBuilderExt { fn generate_pagination( &mut self, diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs new file mode 100644 index 000000000..0bfc25216 --- /dev/null +++ b/crates/storage/src/repository.rs @@ -0,0 +1,41 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use sqlx::{PgConnection, Postgres, Transaction}; + +use crate::upstream_oauth2::PgUpstreamOAuthLinkRepository; + +pub trait Repository { + type UpstreamOAuthLinkRepository<'c> + where + Self: 'c; + + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; +} + +impl Repository for PgConnection { + type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; + + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { + PgUpstreamOAuthLinkRepository::new(self) + } +} + +impl<'t> Repository for Transaction<'t, Postgres> { + type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; + + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { + PgUpstreamOAuthLinkRepository::new(self) + } +} diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index 931b2b7d6..3849af3c4 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -12,19 +12,71 @@ // See the License for the specific language governing permissions and // limitations under the License. +use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User}; -use rand::Rng; -use sqlx::{PgExecutor, QueryBuilder}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; use crate::{ - pagination::{process_page, QueryBuilderExt}, + pagination::{process_page, Page, QueryBuilderExt}, Clock, DatabaseError, LookupResultExt, }; +#[async_trait] +pub trait UpstreamOAuthLinkRepository: Send + Sync { + type Error; + + /// Lookup an upstream OAuth link by its ID + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Find an upstream OAuth link for a provider by its subject + async fn find_by_subject( + &mut self, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: &str, + ) -> Result, Self::Error>; + + /// Add a new upstream OAuth link + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: String, + ) -> Result; + + /// Associate an upstream OAuth link to a user + async fn associate_to_user( + &mut self, + upstream_oauth_link: &UpstreamOAuthLink, + user: &User, + ) -> Result<(), Self::Error>; + + /// Get a paginated list of upstream OAuth links + async fn list_paginated( + &mut self, + user: &User, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error>; +} + +pub struct PgUpstreamOAuthLinkRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUpstreamOAuthLinkRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + #[derive(sqlx::FromRow)] struct LinkLookup { upstream_oauth_link_id: Uuid, @@ -46,197 +98,203 @@ impl From for UpstreamOAuthLink { } } -#[tracing::instrument( - skip_all, - fields(upstream_oauth_link.id = %id), - err, -)] -pub async fn lookup_link( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - LinkLookup, - r#" - SELECT - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - FROM upstream_oauth_links - WHERE upstream_oauth_link_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .await - .to_option()? - .map(Into::into); +#[async_trait] +impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { + type Error = DatabaseError; - Ok(res) -} + #[tracing::instrument( + skip_all, + fields(upstream_oauth_link.id = %id), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + LinkLookup, + r#" + SELECT + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + FROM upstream_oauth_links + WHERE upstream_oauth_link_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()? + .map(Into::into); -#[tracing::instrument( - skip_all, - fields( - upstream_oauth_link.subject = subject, - %upstream_oauth_provider.id, - %upstream_oauth_provider.issuer, - %upstream_oauth_provider.client_id, - ), - err, -)] -pub async fn lookup_link_by_subject( - executor: impl PgExecutor<'_>, - upstream_oauth_provider: &UpstreamOAuthProvider, - subject: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - LinkLookup, - r#" - SELECT - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - FROM upstream_oauth_links - WHERE upstream_oauth_provider_id = $1 - AND subject = $2 - "#, - Uuid::from(upstream_oauth_provider.id), - subject, - ) - .fetch_one(executor) - .await - .to_option()? - .map(Into::into); + Ok(res) + } - Ok(res) -} + #[tracing::instrument( + skip_all, + fields( + upstream_oauth_link.subject = subject, + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, + ), + err, + )] + async fn find_by_subject( + &mut self, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + LinkLookup, + r#" + SELECT + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + FROM upstream_oauth_links + WHERE upstream_oauth_provider_id = $1 + AND subject = $2 + "#, + Uuid::from(upstream_oauth_provider.id), + subject, + ) + .fetch_one(&mut *self.conn) + .await + .to_option()? + .map(Into::into); -#[tracing::instrument( - skip_all, - fields( - upstream_oauth_link.id, - upstream_oauth_link.subject = subject, - %upstream_oauth_provider.id, - %upstream_oauth_provider.issuer, - %upstream_oauth_provider.client_id, - ), - err, -)] -pub async fn add_link( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - upstream_oauth_provider: &UpstreamOAuthProvider, - subject: String, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id)); + Ok(res) + } - sqlx::query!( - r#" - INSERT INTO upstream_oauth_links ( - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - ) VALUES ($1, $2, NULL, $3, $4) - "#, - Uuid::from(id), - Uuid::from(upstream_oauth_provider.id), - &subject, - created_at, - ) - .execute(executor) - .await?; + #[tracing::instrument( + skip_all, + fields( + upstream_oauth_link.id, + upstream_oauth_link.subject = subject, + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id)); - Ok(UpstreamOAuthLink { - id, - provider_id: upstream_oauth_provider.id, - user_id: None, - subject, - created_at, - }) -} - -#[tracing::instrument( - skip_all, - fields( - %upstream_oauth_link.id, - %upstream_oauth_link.subject, - %user.id, - %user.username, - ), - err, -)] -pub async fn associate_link_to_user( - executor: impl PgExecutor<'_>, - upstream_oauth_link: &UpstreamOAuthLink, - user: &User, -) -> Result<(), sqlx::Error> { - sqlx::query!( - r#" - UPDATE upstream_oauth_links - SET user_id = $1 - WHERE upstream_oauth_link_id = $2 - "#, - Uuid::from(user.id), - Uuid::from(upstream_oauth_link.id), - ) - .execute(executor) - .await?; - - Ok(()) -} - -#[tracing::instrument( - skip_all, - fields(%user.id, %user.username), - err -)] -pub async fn get_paginated_user_links( - executor: impl PgExecutor<'_>, - user: &User, - before: Option, - after: Option, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), DatabaseError> { - let mut query = QueryBuilder::new( - r#" - SELECT - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - FROM upstream_oauth_links - "#, - ); - - query - .push(" WHERE user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("upstream_oauth_link_id", before, after, first, last)?; - - let span = info_span!( - "Fetch paginated upstream OAuth 2.0 user links", - db.statement = query.sql() - ); - let page: Vec = query - .build_query_as() - .fetch_all(executor) - .instrument(span) + sqlx::query!( + r#" + INSERT INTO upstream_oauth_links ( + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + ) VALUES ($1, $2, NULL, $3, $4) + "#, + Uuid::from(id), + Uuid::from(upstream_oauth_provider.id), + &subject, + created_at, + ) + .execute(&mut *self.conn) .await?; - let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; + Ok(UpstreamOAuthLink { + id, + provider_id: upstream_oauth_provider.id, + user_id: None, + subject, + created_at, + }) + } - let page: Vec<_> = page.into_iter().map(Into::into).collect(); - Ok((has_previous_page, has_next_page, page)) + #[tracing::instrument( + skip_all, + fields( + %upstream_oauth_link.id, + %upstream_oauth_link.subject, + %user.id, + %user.username, + ), + err, + )] + async fn associate_to_user( + &mut self, + upstream_oauth_link: &UpstreamOAuthLink, + user: &User, + ) -> Result<(), Self::Error> { + sqlx::query!( + r#" + UPDATE upstream_oauth_links + SET user_id = $1 + WHERE upstream_oauth_link_id = $2 + "#, + Uuid::from(user.id), + Uuid::from(upstream_oauth_link.id), + ) + .execute(&mut *self.conn) + .await?; + + Ok(()) + } + + #[tracing::instrument( + skip_all, + fields(%user.id, %user.username), + err + )] + async fn list_paginated( + &mut self, + user: &User, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + FROM upstream_oauth_links + "#, + ); + + query + .push(" WHERE user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("upstream_oauth_link_id", before, after, first, last)?; + + let span = info_span!( + "Fetch paginated upstream OAuth 2.0 user links", + db.statement = query.sql() + ); + let page: Vec = query + .build_query_as() + .fetch_all(&mut *self.conn) + .instrument(span) + .await?; + + let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?; + + let edges: Vec<_> = edges.into_iter().map(Into::into).collect(); + Ok(Page { + has_next_page, + has_previous_page, + edges, + }) + } } diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index 4b1d517a6..4842fb475 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -17,10 +17,7 @@ mod provider; mod session; pub use self::{ - link::{ - add_link, associate_link_to_user, get_paginated_user_links, lookup_link, - lookup_link_by_subject, - }, + link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository}, provider::{add_provider, get_paginated_providers, get_providers, lookup_provider}, session::{ add_session, complete_session, consume_session, lookup_session, lookup_session_on_link, From 9b1dc0880a70b8aea4709fb9184fdbb8e6c2c778 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 30 Dec 2022 10:55:37 +0100 Subject: [PATCH 02/45] storage: repository pattern for upstream oauth2 providers --- crates/cli/src/commands/manage.rs | 28 +- crates/graphql/src/lib.rs | 17 +- crates/graphql/src/model/upstream_oauth.rs | 4 +- .../handlers/src/upstream_oauth2/authorize.rs | 6 +- crates/handlers/src/views/login.rs | 7 +- crates/handlers/src/views/shared.rs | 13 +- crates/storage/sqlx-data.json | 244 ++++++------ crates/storage/src/repository.rs | 17 +- crates/storage/src/upstream_oauth2/link.rs | 2 +- crates/storage/src/upstream_oauth2/mod.rs | 2 +- .../storage/src/upstream_oauth2/provider.rs | 349 ++++++++++-------- 11 files changed, 380 insertions(+), 309 deletions(-) diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 5472b78c7..a92f35af7 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -19,10 +19,11 @@ use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod use mas_router::UrlBuilder; use mas_storage::{ oauth2::client::{insert_client_from_config, lookup_client, truncate_clients}, + upstream_oauth2::UpstreamOAuthProviderRepository, user::{ add_user_password, lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, }, - Clock, + Clock, Repository, }; use oauth2_types::scope::Scope; use rand::SeedableRng; @@ -329,18 +330,19 @@ impl Options { .map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes())) .transpose()?; - let provider = mas_storage::upstream_oauth2::add_provider( - &mut conn, - &mut rng, - &clock, - issuer.clone(), - scope.clone(), - token_endpoint_auth_method, - token_endpoint_signing_alg, - client_id.clone(), - encrypted_client_secret, - ) - .await?; + let provider = conn + .upstream_oauth_provider() + .add( + &mut rng, + &clock, + issuer.clone(), + scope.clone(), + token_endpoint_auth_method, + token_endpoint_signing_alg, + client_id.clone(), + encrypted_client_secret, + ) + .await?; let redirect_uri = url_builder.upstream_oauth_callback(provider.id); let auth_uri = url_builder.upstream_oauth_authorize(provider.id); diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index 9a86ecbe5..8f3ef3219 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -30,7 +30,9 @@ use async_graphql::{ connection::{query, Connection, Edge, OpaqueCursor}, Context, Description, EmptyMutation, EmptySubscription, ID, }; -use mas_storage::{Repository, UpstreamOAuthLinkRepository}; +use mas_storage::{ + upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository, +}; use model::CreationEvent; use sqlx::PgPool; @@ -190,7 +192,7 @@ impl RootQuery { let database = ctx.data::()?; let mut conn = database.acquire().await?; - let provider = mas_storage::upstream_oauth2::lookup_provider(&mut conn, id).await?; + let provider = conn.upstream_oauth_provider().lookup(id).await?; Ok(provider.map(UpstreamOAuth2Provider::new)) } @@ -227,14 +229,13 @@ impl RootQuery { }) .transpose()?; - let (has_previous_page, has_next_page, edges) = - mas_storage::upstream_oauth2::get_paginated_providers( - &mut conn, before_id, after_id, first, last, - ) + let page = conn + .upstream_oauth_provider() + .list_paginated(before_id, after_id, first, last) .await?; - let mut connection = Connection::new(has_previous_page, has_next_page); - connection.edges.extend(edges.into_iter().map(|p| { + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + connection.edges.extend(page.edges.into_iter().map(|p| { Edge::new( OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Provider, p.id)), UpstreamOAuth2Provider::new(p), diff --git a/crates/graphql/src/model/upstream_oauth.rs b/crates/graphql/src/model/upstream_oauth.rs index 87164dd4e..2de6f2f71 100644 --- a/crates/graphql/src/model/upstream_oauth.rs +++ b/crates/graphql/src/model/upstream_oauth.rs @@ -15,6 +15,7 @@ use anyhow::Context as _; use async_graphql::{Context, Object, ID}; use chrono::{DateTime, Utc}; +use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Repository}; use sqlx::PgPool; use super::{NodeType, User}; @@ -101,7 +102,8 @@ impl UpstreamOAuth2Link { // Fetch on-the-fly let database = ctx.data::()?; let mut conn = database.acquire().await?; - mas_storage::upstream_oauth2::lookup_provider(&mut conn, self.link.provider_id) + conn.upstream_oauth_provider() + .lookup(self.link.provider_id) .await? .context("Upstream OAuth 2.0 provider not found")? }; diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index 787124512..5e69f4169 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -22,7 +22,7 @@ use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_keystore::Encrypter; use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_router::UrlBuilder; -use mas_storage::upstream_oauth2::lookup_provider; +use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Repository}; use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -66,7 +66,9 @@ pub(crate) async fn get( let mut txn = pool.begin().await?; - let provider = lookup_provider(&mut txn, provider_id) + let provider = txn + .upstream_oauth_provider() + .lookup(provider_id) .await? .ok_or(RouteError::ProviderNotFound)?; diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 24fc17b75..fd54175d2 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -24,11 +24,12 @@ use mas_axum_utils::{ use mas_data_model::BrowserSession; use mas_keystore::Encrypter; use mas_storage::{ + upstream_oauth2::UpstreamOAuthProviderRepository, user::{ add_user_password, authenticate_session_with_password, lookup_user_by_username, lookup_user_password, start_session, }, - Clock, + Clock, Repository, }; use mas_templates::{ FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, @@ -69,7 +70,7 @@ pub(crate) async fn get( let reply = query.go_next(); Ok((cookie_jar, reply).into_response()) } else { - let providers = mas_storage::upstream_oauth2::get_providers(&mut conn).await?; + let providers = conn.upstream_oauth_provider().all().await?; let content = render( LoginContext::default().with_upstrem_providers(providers), query, @@ -114,7 +115,7 @@ pub(crate) async fn post( }; if !state.is_valid() { - let providers = mas_storage::upstream_oauth2::get_providers(&mut conn).await?; + let providers = conn.upstream_oauth_provider().all().await?; let content = render( LoginContext::default() .with_form_state(state) diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index d4b190025..6035c74db 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -15,8 +15,8 @@ use anyhow::Context; use mas_router::{PostAuthAction, Route}; use mas_storage::{ - compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id, Repository, - UpstreamOAuthLinkRepository, + compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id, + upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository, }; use mas_templates::{PostAuthContext, PostAuthContextInner}; use serde::{Deserialize, Serialize}; @@ -70,10 +70,11 @@ impl OptionalPostAuthAction { .await? .context("Failed to load upstream OAuth 2.0 link")?; - let provider = - mas_storage::upstream_oauth2::lookup_provider(&mut *conn, link.provider_id) - .await? - .context("Failed to load upstream OAuth 2.0 provider")?; + let provider = conn + .upstream_oauth_provider() + .lookup(link.provider_id) + .await? + .context("Failed to load upstream OAuth 2.0 provider")?; let provider = Box::new(provider); let link = Box::new(link); diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 63368ec0c..52b0118fd 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -116,68 +116,6 @@ }, "query": "\n SELECT\n c.oauth2_client_id,\n c.encrypted_client_secret,\n ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\",\n c.grant_type_authorization_code,\n c.grant_type_refresh_token,\n c.client_name,\n c.logo_uri,\n c.client_uri,\n c.policy_uri,\n c.tos_uri,\n c.jwks_uri,\n c.jwks,\n c.id_token_signed_response_alg,\n c.userinfo_signed_response_alg,\n c.token_endpoint_auth_method,\n c.token_endpoint_auth_signing_alg,\n c.initiate_login_uri\n FROM oauth2_clients c\n\n WHERE c.oauth2_client_id = $1\n " }, - "0af182315b36766eca8e232280986bade0202d1b1d64160a99cd14eadcbfc25b": { - "describe": { - "columns": [ - { - "name": "upstream_oauth_provider_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "issuer", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "scope", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "client_id", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "encrypted_client_secret", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "token_endpoint_signing_alg", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "token_endpoint_auth_method", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 7, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - true, - true, - false, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n " - }, "0b49cde0b7b79f79ec261502ab89bcffa81f9f5ed2f922a41b1718274b9e3073": { "describe": { "columns": [ @@ -241,6 +179,66 @@ }, "query": "\n UPDATE user_emails\n SET confirmed_at = $2\n WHERE user_email_id = $1\n " }, + "154e2e4488ff87e09163698750b56a43127cee4e1392785416a586d40a4d9b21": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_provider_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "issuer", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "scope", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "client_id", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "encrypted_client_secret", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "token_endpoint_signing_alg", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "token_endpoint_auth_method", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 7, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + true, + false, + false + ], + "parameters": { + "Left": [] + } + }, + "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n " + }, "1eb6d13e75d8f526c2785749a020731c18012f03e07995213acd38ab560ce497": { "describe": { "columns": [], @@ -2089,6 +2087,68 @@ }, "query": "\n SELECT COUNT(*)\n FROM user_emails ue\n WHERE ue.user_id = $1\n " }, + "8f7a9fb1f24c24f8dbc3c193df2a742c9ac730ab958587b67297de2d4b843863": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_provider_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "issuer", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "scope", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "client_id", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "encrypted_client_secret", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "token_endpoint_signing_alg", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "token_endpoint_auth_method", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 7, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n " + }, "99f5f9eb0adc5ec120ed8194cbf6a8545155bef09e6d94d92fb67fd1b14d4f28": { "describe": { "columns": [], @@ -2586,66 +2646,6 @@ }, "query": "\n INSERT INTO oauth2_clients\n (oauth2_client_id,\n encrypted_client_secret,\n grant_type_authorization_code,\n grant_type_refresh_token,\n client_name,\n logo_uri,\n client_uri,\n policy_uri,\n tos_uri,\n jwks_uri,\n jwks,\n id_token_signed_response_alg,\n userinfo_signed_response_alg,\n token_endpoint_auth_method,\n token_endpoint_auth_signing_alg,\n initiate_login_uri)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)\n " }, - "cf00e0ad529bcb5c0640adcfe0880a3560d9739f355b90ca3ba88dd1eaf26565": { - "describe": { - "columns": [ - { - "name": "upstream_oauth_provider_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "issuer", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "scope", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "client_id", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "encrypted_client_secret", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "token_endpoint_signing_alg", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "token_endpoint_auth_method", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 7, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - true, - true, - false, - false - ], - "parameters": { - "Left": [] - } - }, - "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n " - }, "d1738c27339b81f0844da4bd9b040b9b07a91aa4d9b199b98f24c9cee5709b2b": { "describe": { "columns": [], diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 0bfc25216..c1d259fc8 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -14,28 +14,43 @@ use sqlx::{PgConnection, Postgres, Transaction}; -use crate::upstream_oauth2::PgUpstreamOAuthLinkRepository; +use crate::upstream_oauth2::{PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository}; pub trait Repository { type UpstreamOAuthLinkRepository<'c> where Self: 'c; + type UpstreamOAuthProviderRepository<'c> + where + Self: 'c; + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; + fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>; } impl Repository for PgConnection { type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; + type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) } + + fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { + PgUpstreamOAuthProviderRepository::new(self) + } } impl<'t> Repository for Transaction<'t, Postgres> { type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; + type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) } + + fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { + PgUpstreamOAuthProviderRepository::new(self) + } } diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index 3849af3c4..100e98336 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -56,7 +56,7 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync { user: &User, ) -> Result<(), Self::Error>; - /// Get a paginated list of upstream OAuth links + /// Get a paginated list of upstream OAuth links on a user async fn list_paginated( &mut self, user: &User, diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index 4842fb475..d29b5e719 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -18,7 +18,7 @@ mod session; pub use self::{ link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository}, - provider::{add_provider, get_paginated_providers, get_providers, lookup_provider}, + provider::{PgUpstreamOAuthProviderRepository, UpstreamOAuthProviderRepository}, session::{ add_session, complete_session, consume_session, lookup_session, lookup_session_on_link, }, diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index 360b9a4af..3d8ba1417 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -12,21 +12,66 @@ // See the License for the specific language governing permissions and // limitations under the License. +use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::UpstreamOAuthProvider; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use oauth2_types::scope::Scope; -use rand::Rng; -use sqlx::{PgExecutor, QueryBuilder}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; use crate::{ - pagination::{process_page, QueryBuilderExt}, + pagination::{process_page, Page, QueryBuilderExt}, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, }; +#[async_trait] +pub trait UpstreamOAuthProviderRepository: Send + Sync { + type Error; + + /// Lookup an upstream OAuth provider by its ID + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Add a new upstream OAuth provider + #[allow(clippy::too_many_arguments)] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + issuer: String, + scope: Scope, + token_endpoint_auth_method: OAuthClientAuthenticationMethod, + token_endpoint_signing_alg: Option, + client_id: String, + encrypted_client_secret: Option, + ) -> Result; + + /// Get a paginated list of upstream OAuth providers + async fn list_paginated( + &mut self, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error>; + + /// Get all upstream OAuth providers + async fn all(&mut self) -> Result, Self::Error>; +} + +pub struct PgUpstreamOAuthProviderRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUpstreamOAuthProviderRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + #[derive(sqlx::FromRow)] struct ProviderLookup { upstream_oauth_provider_id: Uuid, @@ -79,71 +124,72 @@ impl TryFrom for UpstreamOAuthProvider { } } -#[tracing::instrument( - skip_all, - fields(upstream_oauth_provider.id = %id), - err, -)] -pub async fn lookup_provider( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - ProviderLookup, - r#" - SELECT - upstream_oauth_provider_id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at - FROM upstream_oauth_providers - WHERE upstream_oauth_provider_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .await - .to_option()?; +#[async_trait] +impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'c> { + type Error = DatabaseError; - let res = res - .map(UpstreamOAuthProvider::try_from) - .transpose() - .map_err(DatabaseError::from)?; + #[tracing::instrument( + skip_all, + fields(upstream_oauth_provider.id = %id), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + ProviderLookup, + r#" + SELECT + upstream_oauth_provider_id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at + FROM upstream_oauth_providers + WHERE upstream_oauth_provider_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; - Ok(res) -} + let res = res + .map(UpstreamOAuthProvider::try_from) + .transpose() + .map_err(DatabaseError::from)?; -#[tracing::instrument( - skip_all, - fields( - upstream_oauth_provider.id, - upstream_oauth_provider.issuer = %issuer, - upstream_oauth_provider.client_id = %client_id, - ), - err, -)] -#[allow(clippy::too_many_arguments)] -pub async fn add_provider( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - issuer: String, - scope: Scope, - token_endpoint_auth_method: OAuthClientAuthenticationMethod, - token_endpoint_signing_alg: Option, - client_id: String, - encrypted_client_secret: Option, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id)); + Ok(res) + } - sqlx::query!( - r#" + #[tracing::instrument( + skip_all, + fields( + upstream_oauth_provider.id, + upstream_oauth_provider.issuer = %issuer, + upstream_oauth_provider.client_id = %client_id, + ), + err, + )] + #[allow(clippy::too_many_arguments)] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + issuer: String, + scope: Scope, + token_endpoint_auth_method: OAuthClientAuthenticationMethod, + token_endpoint_signing_alg: Option, + client_id: String, + encrypted_client_secret: Option, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id)); + + sqlx::query!( + r#" INSERT INTO upstream_oauth_providers ( upstream_oauth_provider_id, issuer, @@ -155,94 +201,95 @@ pub async fn add_provider( created_at ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) "#, - Uuid::from(id), - &issuer, - scope.to_string(), - token_endpoint_auth_method.to_string(), - token_endpoint_signing_alg.as_ref().map(ToString::to_string), - &client_id, - encrypted_client_secret.as_deref(), - created_at, - ) - .execute(executor) - .await?; - - Ok(UpstreamOAuthProvider { - id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at, - }) -} - -#[tracing::instrument(skip_all, err)] -pub async fn get_paginated_providers( - executor: impl PgExecutor<'_>, - before: Option, - after: Option, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), DatabaseError> { - let mut query = QueryBuilder::new( - r#" - SELECT - upstream_oauth_provider_id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at - FROM upstream_oauth_providers - WHERE 1 = 1 - "#, - ); - - query.generate_pagination("upstream_oauth_provider_id", before, after, first, last)?; - - let span = info_span!( - "Fetch paginated upstream OAuth 2.0 providers", - db.statement = query.sql() - ); - let page: Vec = query - .build_query_as() - .fetch_all(executor) - .instrument(span) + Uuid::from(id), + &issuer, + scope.to_string(), + token_endpoint_auth_method.to_string(), + token_endpoint_signing_alg.as_ref().map(ToString::to_string), + &client_id, + encrypted_client_secret.as_deref(), + created_at, + ) + .execute(&mut *self.conn) .await?; - let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; + Ok(UpstreamOAuthProvider { + id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at, + }) + } - let page: Result, _> = page.into_iter().map(TryInto::try_into).collect(); - Ok((has_previous_page, has_next_page, page?)) -} - -#[tracing::instrument(skip_all, err)] -pub async fn get_providers( - executor: impl PgExecutor<'_>, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - ProviderLookup, - r#" - SELECT - upstream_oauth_provider_id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at - FROM upstream_oauth_providers - "#, - ) - .fetch_all(executor) - .await?; - - let res: Result, _> = res.into_iter().map(TryInto::try_into).collect(); - Ok(res?) + async fn list_paginated( + &mut self, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT + upstream_oauth_provider_id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at + FROM upstream_oauth_providers + WHERE 1 = 1 + "#, + ); + + query.generate_pagination("upstream_oauth_provider_id", before, after, first, last)?; + + let span = info_span!( + "Fetch paginated upstream OAuth 2.0 providers", + db.statement = query.sql() + ); + let page: Vec = query + .build_query_as() + .fetch_all(&mut *self.conn) + .instrument(span) + .await?; + + let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?; + + let edges: Result, _> = edges.into_iter().map(TryInto::try_into).collect(); + Ok(Page { + has_next_page, + has_previous_page, + edges: edges?, + }) + } + #[tracing::instrument(skip_all, err)] + async fn all(&mut self) -> Result, Self::Error> { + let res = sqlx::query_as!( + ProviderLookup, + r#" + SELECT + upstream_oauth_provider_id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at + FROM upstream_oauth_providers + "#, + ) + .fetch_all(&mut *self.conn) + .await?; + + let res: Result, _> = res.into_iter().map(TryInto::try_into).collect(); + Ok(res?) + } } From 0fb0e6d5cdca0e18339b284331b8b66478cd78d7 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 30 Dec 2022 15:39:51 +0100 Subject: [PATCH 03/45] storage: upstream oauth session repository + unit tests --- Cargo.lock | 1 + .../handlers/src/upstream_oauth2/authorize.rs | 26 +- .../handlers/src/upstream_oauth2/callback.rs | 24 +- crates/handlers/src/upstream_oauth2/link.rs | 36 +- crates/storage/Cargo.toml | 2 +- crates/storage/sqlx-data.json | 213 ++----- crates/storage/src/repository.rs | 20 +- crates/storage/src/upstream_oauth2/mod.rs | 110 +++- crates/storage/src/upstream_oauth2/session.rs | 527 ++++++++---------- 9 files changed, 469 insertions(+), 490 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bc0b37baf..b5e1e65be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3104,6 +3104,7 @@ dependencies = [ "mas-jose", "oauth2-types", "rand 0.8.5", + "rand_chacha 0.3.1", "serde", "serde_json", "sqlx", diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index 5e69f4169..178eba1ab 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -22,7 +22,10 @@ use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_keystore::Encrypter; use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_router::UrlBuilder; -use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Repository}; +use mas_storage::{ + upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, + Repository, +}; use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -97,16 +100,17 @@ pub(crate) async fn get( &mut rng, )?; - let session = mas_storage::upstream_oauth2::add_session( - &mut txn, - &mut rng, - &clock, - &provider, - data.state.clone(), - data.code_challenge_verifier, - data.nonce, - ) - .await?; + let session = txn + .upstream_oauth_session() + .add( + &mut rng, + &clock, + &provider, + data.state.clone(), + data.code_challenge_verifier, + data.nonce, + ) + .await?; let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar) .add(session.id, provider.id, data.state, query.post_auth_action) diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 6158f9413..295f7307b 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -26,7 +26,7 @@ use mas_oidc_client::requests::{ }; use mas_router::{Route, UrlBuilder}; use mas_storage::{ - upstream_oauth2::{complete_session, lookup_session}, + upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, Repository, UpstreamOAuthLinkRepository, }; use oauth2_types::errors::ClientErrorCode; @@ -65,6 +65,9 @@ pub(crate) enum RouteError { #[error("Session not found")] SessionNotFound, + #[error("Provider not found")] + ProviderNotFound, + #[error("Provider mismatch")] ProviderMismatch, @@ -105,6 +108,7 @@ impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { + Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(), Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(), Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(), @@ -127,16 +131,24 @@ pub(crate) async fn get( let mut txn = pool.begin().await?; + let provider = txn + .upstream_oauth_provider() + .lookup(provider_id) + .await? + .ok_or(RouteError::ProviderNotFound)?; + let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); let (session_id, _post_auth_action) = sessions_cookie .find_session(provider_id, ¶ms.state) .map_err(|_| RouteError::MissingCookie)?; - let (provider, session) = lookup_session(&mut txn, session_id) + let session = txn + .upstream_oauth_session() + .lookup(session_id) .await? .ok_or(RouteError::SessionNotFound)?; - if provider.id != provider_id { + if provider.id != session.provider_id { // The provider in the session cookie should match the one from the URL return Err(RouteError::ProviderMismatch); } @@ -245,7 +257,11 @@ pub(crate) async fn get( .await? }; - let session = complete_session(&mut txn, &clock, session, &link, response.id_token).await?; + let session = txn + .upstream_oauth_session() + .complete_with_link(&clock, session, &link, response.id_token) + .await?; + let cookie_jar = sessions_cookie .add_link_to_session(session.id, link.id)? .save(cookie_jar, clock.now()); diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 4a109ba6e..c01d97996 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -25,7 +25,7 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_storage::{ - upstream_oauth2::{consume_session, lookup_session_on_link}, + upstream_oauth2::UpstreamOAuthSessionRepository, user::{add_user, authenticate_session_with_upstream, lookup_user, start_session}, Repository, UpstreamOAuthLinkRepository, }; @@ -109,12 +109,18 @@ pub(crate) async fn get( .await? .ok_or(RouteError::LinkNotFound)?; - // This checks that we're in a browser session which is allowed to consume this - // link: the upstream auth session should have been started in this browser. - let upstream_session = lookup_session_on_link(&mut txn, &link, session_id) + let upstream_session = txn + .upstream_oauth_session() + .lookup(session_id) .await? .ok_or(RouteError::SessionNotFound)?; + // This checks that we're in a browser session which is allowed to consume this + // link: the upstream auth session should have been started in this browser. + if upstream_session.link_id != Some(link.id) { + return Err(RouteError::SessionNotFound); + } + if upstream_session.consumed() { return Err(RouteError::SessionConsumed); } @@ -127,7 +133,10 @@ pub(crate) async fn get( (Some(mut session), Some(user_id)) if session.user.id == user_id => { // Session already linked, and link matches the currently logged // user. Mark the session as consumed and renew the authentication. - consume_session(&mut txn, &clock, upstream_session).await?; + txn.upstream_oauth_session() + .consume(&clock, upstream_session) + .await?; + authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link) .await?; @@ -212,12 +221,18 @@ pub(crate) async fn post( .await? .ok_or(RouteError::LinkNotFound)?; - // This checks that we're in a browser session which is allowed to consume this - // link: the upstream auth session should have been started in this browser. - let upstream_session = lookup_session_on_link(&mut txn, &link, session_id) + let upstream_session = txn + .upstream_oauth_session() + .lookup(session_id) .await? .ok_or(RouteError::SessionNotFound)?; + // This checks that we're in a browser session which is allowed to consume this + // link: the upstream auth session should have been started in this browser. + if upstream_session.link_id != Some(link.id) { + return Err(RouteError::SessionNotFound); + } + if upstream_session.consumed() { return Err(RouteError::SessionConsumed); } @@ -251,7 +266,10 @@ pub(crate) async fn post( _ => return Err(RouteError::InvalidFormAction), }; - consume_session(&mut txn, &clock, upstream_session).await?; + txn.upstream_oauth_session() + .consume(&clock, upstream_session) + .await?; + authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link).await?; let cookie_jar = sessions_cookie diff --git a/crates/storage/Cargo.toml b/crates/storage/Cargo.toml index 71240129f..fb6c0fdce 100644 --- a/crates/storage/Cargo.toml +++ b/crates/storage/Cargo.toml @@ -14,8 +14,8 @@ serde_json = "1.0.91" thiserror = "1.0.38" tracing = "0.1.37" -# Password hashing rand = "0.8.5" +rand_chacha = "0.3.1" url = { version = "2.3.1", features = ["serde"] } uuid = "1.2.2" ulid = { version = "1.0.0", features = ["uuid", "serde"] } diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 52b0118fd..8167fef34 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -521,81 +521,6 @@ }, "query": "\n INSERT INTO users (user_id, username, created_at)\n VALUES ($1, $2, $3)\n " }, - "2ca7b990c11e84db62fb7887a2bc3410ec1eee2f6a0ec124db36575111970ca9": { - "describe": { - "columns": [ - { - "name": "upstream_oauth_authorization_session_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "upstream_oauth_provider_id", - "ordinal": 1, - "type_info": "Uuid" - }, - { - "name": "upstream_oauth_link_id", - "ordinal": 2, - "type_info": "Uuid" - }, - { - "name": "state", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "code_challenge_verifier", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "nonce", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "id_token", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 7, - "type_info": "Timestamptz" - }, - { - "name": "completed_at", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "consumed_at", - "ordinal": 9, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - false, - true, - false, - true, - false, - true, - true - ], - "parameters": { - "Left": [ - "Uuid", - "Uuid" - ] - } - }, - "query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n AND upstream_oauth_link_id = $2\n " - }, "2e581d57db471b96091860cd0252361d16332deeffabab0dace405ead55324be": { "describe": { "columns": [ @@ -708,21 +633,6 @@ }, "query": "\n UPDATE compat_sso_logins\n SET\n exchanged_at = $2\n WHERE\n compat_sso_login_id = $1\n " }, - "2fb8f1aef96571a6f3f6260d7836de99ff24ba1947747e08b0e8d64097507442": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz", - "Text", - "Uuid" - ] - } - }, - "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n " - }, "360466ff599c67c9af2ac75399c0b536a22c1178972a0172b707bcc81d47357b": { "describe": { "columns": [], @@ -1388,7 +1298,24 @@ }, "query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n " }, - "65c7600f1af07cb6ea49d89ae6fbca5374a57c5a866c8aadd7b75ed1d2d1d0cd": { + "64e6ea47c2e877c1ebe4338d64d9ad8a6c1c777d1daea024b8ca2e7f0dd75b0f": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Text", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)\n " + }, + "67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c": { "describe": { "columns": [ { @@ -1440,41 +1367,6 @@ "name": "consumed_at", "ordinal": 9, "type_info": "Timestamptz" - }, - { - "name": "provider_issuer", - "ordinal": 10, - "type_info": "Text" - }, - { - "name": "provider_scope", - "ordinal": 11, - "type_info": "Text" - }, - { - "name": "provider_client_id", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "provider_encrypted_client_secret", - "ordinal": 13, - "type_info": "Text" - }, - { - "name": "provider_token_endpoint_auth_method", - "ordinal": 14, - "type_info": "Text" - }, - { - "name": "provider_token_endpoint_signing_alg", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "provider_created_at", - "ordinal": 16, - "type_info": "Timestamptz" } ], "nullable": [ @@ -1487,14 +1379,7 @@ true, false, true, - true, - false, - false, - false, - true, - false, - true, - false + true ], "parameters": { "Left": [ @@ -1502,7 +1387,20 @@ ] } }, - "query": "\n SELECT\n ua.upstream_oauth_authorization_session_id,\n ua.upstream_oauth_provider_id,\n ua.upstream_oauth_link_id,\n ua.state,\n ua.code_challenge_verifier,\n ua.nonce,\n ua.id_token,\n ua.created_at,\n ua.completed_at,\n ua.consumed_at,\n up.issuer AS \"provider_issuer\",\n up.scope AS \"provider_scope\",\n up.client_id AS \"provider_client_id\",\n up.encrypted_client_secret AS \"provider_encrypted_client_secret\",\n up.token_endpoint_auth_method AS \"provider_token_endpoint_auth_method\",\n up.token_endpoint_signing_alg AS \"provider_token_endpoint_signing_alg\",\n up.created_at AS \"provider_created_at\"\n FROM upstream_oauth_authorization_sessions ua\n INNER JOIN upstream_oauth_providers up\n USING (upstream_oauth_provider_id)\n WHERE upstream_oauth_authorization_session_id = $1\n " + "query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n " + }, + "689ffbfc5137ec788e89062ad679bbe6b23a8861c09a7246dc1659c28f12bf8d": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Timestamptz", + "Uuid" + ] + } + }, + "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET consumed_at = $1\n WHERE upstream_oauth_authorization_session_id = $2\n " }, "6bf0da5ba3dd07b499193a2e0ddeea6e712f9df8f7f28874ff56a952a9f10e54": { "describe": { @@ -2420,6 +2318,21 @@ }, "query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n AND ue.user_email_id = $2\n " }, + "b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz", + "Text", + "Uuid" + ] + } + }, + "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n " + }, "bc768c63a7737818967bc28560de714bbbd262bdf3ab73d297263bb73dcd9f5e": { "describe": { "columns": [], @@ -2702,19 +2615,6 @@ }, "query": "\n DELETE FROM user_emails\n WHERE user_emails.user_email_id = $1\n " }, - "e30562e9637d3a723a91adca6336a8d083657ce6d7fe9551fcd6a9d672453d3c": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Timestamptz", - "Uuid" - ] - } - }, - "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET consumed_at = $1\n WHERE upstream_oauth_authorization_session_id = $2\n " - }, "e446e37d48c8838ef2e0d0fd82f8f7b04893c84ad46747cdf193ebd83755ceb2": { "describe": { "columns": [], @@ -2773,22 +2673,5 @@ } }, "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " - }, - "fb71ac6539039313fd90b29ac943330e54c7b62b2778727726e2f60a554f9c5a": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Text", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)\n " } } \ No newline at end of file diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index c1d259fc8..9e6ca8074 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -14,7 +14,10 @@ use sqlx::{PgConnection, Postgres, Transaction}; -use crate::upstream_oauth2::{PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository}; +use crate::upstream_oauth2::{ + PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, + PgUpstreamOAuthSessionRepository, +}; pub trait Repository { type UpstreamOAuthLinkRepository<'c> @@ -25,13 +28,19 @@ pub trait Repository { where Self: 'c; + type UpstreamOAuthSessionRepository<'c> + where + Self: 'c; + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>; + fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>; } impl Repository for PgConnection { type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; + type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -40,11 +49,16 @@ impl Repository for PgConnection { fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { PgUpstreamOAuthProviderRepository::new(self) } + + fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> { + PgUpstreamOAuthSessionRepository::new(self) + } } impl<'t> Repository for Transaction<'t, Postgres> { type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; + type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -53,4 +67,8 @@ impl<'t> Repository for Transaction<'t, Postgres> { fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { PgUpstreamOAuthProviderRepository::new(self) } + + fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> { + PgUpstreamOAuthSessionRepository::new(self) + } } diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index d29b5e719..1abcd1d02 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -19,7 +19,111 @@ mod session; pub use self::{ link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository}, provider::{PgUpstreamOAuthProviderRepository, UpstreamOAuthProviderRepository}, - session::{ - add_session, complete_session, consume_session, lookup_session, lookup_session_on_link, - }, + session::{PgUpstreamOAuthSessionRepository, UpstreamOAuthSessionRepository}, }; + +#[cfg(test)] +mod tests { + use oauth2_types::scope::{Scope, OPENID}; + use rand::SeedableRng; + use sqlx::PgPool; + + use super::*; + use crate::{Clock, Repository}; + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_repository(pool: PgPool) -> Result<(), Box> { + let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); + let clock = Clock::default(); + let mut conn = pool.acquire().await?; + + // The provider list should be empty at the start + let all_providers = conn.upstream_oauth_provider().all().await?; + assert!(all_providers.is_empty()); + + // Let's add a provider + let provider = conn + .upstream_oauth_provider() + .add( + &mut rng, + &clock, + "https://example.com/".to_owned(), + Scope::from_iter([OPENID]), + mas_iana::oauth::OAuthClientAuthenticationMethod::None, + None, + "client-id".to_owned(), + None, + ) + .await?; + + // Look it up in the database + let provider = conn + .upstream_oauth_provider() + .lookup(provider.id) + .await? + .expect("provider to be found in the database"); + assert_eq!(provider.issuer, "https://example.com/"); + assert_eq!(provider.client_id, "client-id"); + + // Start a session + let session = conn + .upstream_oauth_session() + .add( + &mut rng, + &clock, + &provider, + "some-state".to_owned(), + None, + "some-nonce".to_owned(), + ) + .await?; + + // Look it up in the database + let session = conn + .upstream_oauth_session() + .lookup(session.id) + .await? + .expect("session to be found in the database"); + assert_eq!(session.provider_id, provider.id); + assert_eq!(session.link_id, None); + assert!(!session.completed()); + assert!(!session.consumed()); + + // Create a link + let link = conn + .upstream_oauth_link() + .add(&mut rng, &clock, &provider, "a-subject".to_owned()) + .await?; + + // We can look it up by its ID + conn.upstream_oauth_link() + .lookup(link.id) + .await? + .expect("link to be found in database"); + + // or by its subject + let link = conn + .upstream_oauth_link() + .find_by_subject(&provider, "a-subject") + .await? + .expect("link to be found in database"); + assert_eq!(link.subject, "a-subject"); + assert_eq!(link.provider_id, provider.id); + + let session = conn + .upstream_oauth_session() + .complete_with_link(&clock, session, &link, None) + .await?; + assert!(session.completed()); + assert!(!session.consumed()); + assert_eq!(session.link_id, Some(link.id)); + + let session = conn + .upstream_oauth_session() + .consume(&clock, session) + .await?; + assert!(session.consumed()); + + Ok(()) + } +} diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index 5e013f241..f8dffcf39 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -12,261 +12,62 @@ // See the License for the specific language governing permissions and // limitations under the License. +use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider}; -use rand::Rng; -use sqlx::PgExecutor; +use rand::RngCore; +use sqlx::PgConnection; use ulid::Ulid; use uuid::Uuid; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; +use crate::{Clock, DatabaseError, LookupResultExt}; -struct SessionAndProviderLookup { - upstream_oauth_authorization_session_id: Uuid, - upstream_oauth_provider_id: Uuid, - upstream_oauth_link_id: Option, - state: String, - code_challenge_verifier: Option, - nonce: String, - id_token: Option, - created_at: DateTime, - completed_at: Option>, - consumed_at: Option>, - provider_issuer: String, - provider_scope: String, - provider_client_id: String, - provider_encrypted_client_secret: Option, - provider_token_endpoint_auth_method: String, - provider_token_endpoint_signing_alg: Option, - provider_created_at: DateTime, +#[async_trait] +pub trait UpstreamOAuthSessionRepository: Send + Sync { + type Error; + + /// Lookup a session by its ID + async fn lookup( + &mut self, + id: Ulid, + ) -> Result, Self::Error>; + + /// Add a session to the database + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + state: String, + code_challenge_verifier: Option, + nonce: String, + ) -> Result; + + /// Mark a session as completed and associate the given link + async fn complete_with_link( + &mut self, + clock: &Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + upstream_oauth_link: &UpstreamOAuthLink, + id_token: Option, + ) -> Result; + + /// Mark a session as consumed + async fn consume( + &mut self, + clock: &Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + ) -> Result; } -/// Lookup a session and its provider by its ID -#[tracing::instrument( - skip_all, - fields(upstream_oauth_authorization_session.id = %id), - err, -)] -pub async fn lookup_session( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - SessionAndProviderLookup, - r#" - SELECT - ua.upstream_oauth_authorization_session_id, - ua.upstream_oauth_provider_id, - ua.upstream_oauth_link_id, - ua.state, - ua.code_challenge_verifier, - ua.nonce, - ua.id_token, - ua.created_at, - ua.completed_at, - ua.consumed_at, - up.issuer AS "provider_issuer", - up.scope AS "provider_scope", - up.client_id AS "provider_client_id", - up.encrypted_client_secret AS "provider_encrypted_client_secret", - up.token_endpoint_auth_method AS "provider_token_endpoint_auth_method", - up.token_endpoint_signing_alg AS "provider_token_endpoint_signing_alg", - up.created_at AS "provider_created_at" - FROM upstream_oauth_authorization_sessions ua - INNER JOIN upstream_oauth_providers up - USING (upstream_oauth_provider_id) - WHERE upstream_oauth_authorization_session_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - let id = res.upstream_oauth_provider_id.into(); - let provider = UpstreamOAuthProvider { - id, - issuer: res.provider_issuer, - scope: res.provider_scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("scope") - .row(id) - .source(e) - })?, - client_id: res.provider_client_id, - encrypted_client_secret: res.provider_encrypted_client_secret, - token_endpoint_auth_method: res.provider_token_endpoint_auth_method.parse().map_err( - |e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("token_endpoint_auth_method") - .row(id) - .source(e) - }, - )?, - token_endpoint_signing_alg: res - .provider_token_endpoint_signing_alg - .map(|x| x.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("token_endpoint_signing_alg") - .row(id) - .source(e) - })?, - created_at: res.provider_created_at, - }; - - let session = UpstreamOAuthAuthorizationSession { - id: res.upstream_oauth_authorization_session_id.into(), - provider_id: provider.id, - link_id: res.upstream_oauth_link_id.map(Ulid::from), - state: res.state, - code_challenge_verifier: res.code_challenge_verifier, - nonce: res.nonce, - id_token: res.id_token, - created_at: res.created_at, - completed_at: res.completed_at, - consumed_at: res.consumed_at, - }; - - Ok(Some((provider, session))) +pub struct PgUpstreamOAuthSessionRepository<'c> { + conn: &'c mut PgConnection, } -/// Add a session to the database -#[tracing::instrument( - skip_all, - fields( - %upstream_oauth_provider.id, - %upstream_oauth_provider.issuer, - %upstream_oauth_provider.client_id, - upstream_oauth_authorization_session.id, - ), - err, -)] -pub async fn add_session( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - upstream_oauth_provider: &UpstreamOAuthProvider, - state: String, - code_challenge_verifier: Option, - nonce: String, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record( - "upstream_oauth_authorization_session.id", - tracing::field::display(id), - ); - - sqlx::query!( - r#" - INSERT INTO upstream_oauth_authorization_sessions ( - upstream_oauth_authorization_session_id, - upstream_oauth_provider_id, - state, - code_challenge_verifier, - nonce, - created_at, - completed_at, - consumed_at, - id_token - ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL) - "#, - Uuid::from(id), - Uuid::from(upstream_oauth_provider.id), - &state, - code_challenge_verifier.as_deref(), - nonce, - created_at, - ) - .execute(executor) - .await?; - - Ok(UpstreamOAuthAuthorizationSession { - id, - provider_id: upstream_oauth_provider.id, - link_id: None, - state, - code_challenge_verifier, - nonce, - id_token: None, - created_at, - completed_at: None, - consumed_at: None, - }) -} - -/// Mark a session as completed and associate the given link -#[tracing::instrument( - skip_all, - fields( - %upstream_oauth_authorization_session.id, - %upstream_oauth_link.id, - ), - err, -)] -pub async fn complete_session( - executor: impl PgExecutor<'_>, - clock: &Clock, - mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, - upstream_oauth_link: &UpstreamOAuthLink, - id_token: Option, -) -> Result { - let completed_at = clock.now(); - sqlx::query!( - r#" - UPDATE upstream_oauth_authorization_sessions - SET upstream_oauth_link_id = $1, - completed_at = $2, - id_token = $3 - WHERE upstream_oauth_authorization_session_id = $4 - "#, - Uuid::from(upstream_oauth_link.id), - completed_at, - id_token, - Uuid::from(upstream_oauth_authorization_session.id), - ) - .execute(executor) - .await?; - - upstream_oauth_authorization_session.completed_at = Some(completed_at); - upstream_oauth_authorization_session.id_token = id_token; - - Ok(upstream_oauth_authorization_session) -} - -/// Mark a session as consumed -#[tracing::instrument( - skip_all, - fields( - %upstream_oauth_authorization_session.id, - ), - err, -)] -pub async fn consume_session( - executor: impl PgExecutor<'_>, - clock: &Clock, - mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, -) -> Result { - let consumed_at = clock.now(); - sqlx::query!( - r#" - UPDATE upstream_oauth_authorization_sessions - SET consumed_at = $1 - WHERE upstream_oauth_authorization_session_id = $2 - "#, - consumed_at, - Uuid::from(upstream_oauth_authorization_session.id), - ) - .execute(executor) - .await?; - - upstream_oauth_authorization_session.consumed_at = Some(consumed_at); - - Ok(upstream_oauth_authorization_session) +impl<'c> PgUpstreamOAuthSessionRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } } struct SessionLookup { @@ -282,57 +83,191 @@ struct SessionLookup { consumed_at: Option>, } -/// Lookup a session, which belongs to a link, by its ID -#[tracing::instrument( - skip_all, - fields( - upstream_oauth_authorization_session.id = %id, - %upstream_oauth_link.id, - ), - err, -)] -pub async fn lookup_session_on_link( - executor: impl PgExecutor<'_>, - upstream_oauth_link: &UpstreamOAuthLink, - id: Ulid, -) -> Result, sqlx::Error> { - let res = sqlx::query_as!( - SessionLookup, - r#" - SELECT - upstream_oauth_authorization_session_id, - upstream_oauth_provider_id, - upstream_oauth_link_id, - state, - code_challenge_verifier, - nonce, - id_token, - created_at, - completed_at, - consumed_at - FROM upstream_oauth_authorization_sessions - WHERE upstream_oauth_authorization_session_id = $1 - AND upstream_oauth_link_id = $2 - "#, - Uuid::from(id), - Uuid::from(upstream_oauth_link.id), - ) - .fetch_one(executor) - .await - .to_option()?; +#[async_trait] +impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> { + type Error = DatabaseError; - let Some(res) = res else { return Ok(None) }; + #[tracing::instrument( + skip_all, + fields(upstream_oauth_provider.id = %id), + err, + )] + async fn lookup( + &mut self, + id: Ulid, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + SessionLookup, + r#" + SELECT + upstream_oauth_authorization_session_id, + upstream_oauth_provider_id, + upstream_oauth_link_id, + state, + code_challenge_verifier, + nonce, + id_token, + created_at, + completed_at, + consumed_at + FROM upstream_oauth_authorization_sessions + WHERE upstream_oauth_authorization_session_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; - Ok(Some(UpstreamOAuthAuthorizationSession { - id: res.upstream_oauth_authorization_session_id.into(), - provider_id: res.upstream_oauth_provider_id.into(), - link_id: res.upstream_oauth_link_id.map(Ulid::from), - state: res.state, - code_challenge_verifier: res.code_challenge_verifier, - nonce: res.nonce, - id_token: res.id_token, - created_at: res.created_at, - completed_at: res.completed_at, - consumed_at: res.consumed_at, - })) + let Some(res) = res else { return Ok(None) }; + + let session = UpstreamOAuthAuthorizationSession { + id: res.upstream_oauth_authorization_session_id.into(), + provider_id: res.upstream_oauth_provider_id.into(), + link_id: res.upstream_oauth_link_id.map(Ulid::from), + state: res.state, + code_challenge_verifier: res.code_challenge_verifier, + nonce: res.nonce, + id_token: res.id_token, + created_at: res.created_at, + completed_at: res.completed_at, + consumed_at: res.consumed_at, + }; + + Ok(Some(session)) + } + + #[tracing::instrument( + skip_all, + fields( + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, + upstream_oauth_authorization_session.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + state: String, + code_challenge_verifier: Option, + nonce: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record( + "upstream_oauth_authorization_session.id", + tracing::field::display(id), + ); + + sqlx::query!( + r#" + INSERT INTO upstream_oauth_authorization_sessions ( + upstream_oauth_authorization_session_id, + upstream_oauth_provider_id, + state, + code_challenge_verifier, + nonce, + created_at, + completed_at, + consumed_at, + id_token + ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL) + "#, + Uuid::from(id), + Uuid::from(upstream_oauth_provider.id), + &state, + code_challenge_verifier.as_deref(), + nonce, + created_at, + ) + .execute(&mut *self.conn) + .await?; + + Ok(UpstreamOAuthAuthorizationSession { + id, + provider_id: upstream_oauth_provider.id, + link_id: None, + state, + code_challenge_verifier, + nonce, + id_token: None, + created_at, + completed_at: None, + consumed_at: None, + }) + } + + #[tracing::instrument( + skip_all, + fields( + %upstream_oauth_authorization_session.id, + %upstream_oauth_link.id, + ), + err, + )] + async fn complete_with_link( + &mut self, + clock: &Clock, + mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + upstream_oauth_link: &UpstreamOAuthLink, + id_token: Option, + ) -> Result { + let completed_at = clock.now(); + sqlx::query!( + r#" + UPDATE upstream_oauth_authorization_sessions + SET upstream_oauth_link_id = $1, + completed_at = $2, + id_token = $3 + WHERE upstream_oauth_authorization_session_id = $4 + "#, + Uuid::from(upstream_oauth_link.id), + completed_at, + id_token, + Uuid::from(upstream_oauth_authorization_session.id), + ) + .execute(&mut *self.conn) + .await?; + + upstream_oauth_authorization_session.completed_at = Some(completed_at); + upstream_oauth_authorization_session.id_token = id_token; + upstream_oauth_authorization_session.link_id = Some(upstream_oauth_link.id); + + Ok(upstream_oauth_authorization_session) + } + + /// Mark a session as consumed + #[tracing::instrument( + skip_all, + fields( + %upstream_oauth_authorization_session.id, + ), + err, + )] + async fn consume( + &mut self, + clock: &Clock, + mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + ) -> Result { + let consumed_at = clock.now(); + sqlx::query!( + r#" + UPDATE upstream_oauth_authorization_sessions + SET consumed_at = $1 + WHERE upstream_oauth_authorization_session_id = $2 + "#, + consumed_at, + Uuid::from(upstream_oauth_authorization_session.id), + ) + .execute(&mut *self.conn) + .await?; + + upstream_oauth_authorization_session.consumed_at = Some(consumed_at); + + Ok(upstream_oauth_authorization_session) + } } From bd7f949300be47052dfc89ca6c022d8adae96d66 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 2 Jan 2023 15:28:44 +0100 Subject: [PATCH 04/45] storage: user and user email repository --- crates/cli/src/commands/manage.rs | 19 +- crates/data-model/src/users.rs | 11 +- crates/graphql/src/lib.rs | 10 +- crates/graphql/src/model/upstream_oauth.rs | 9 +- crates/graphql/src/model/users.rs | 29 +- crates/handlers/src/compat/login.rs | 8 +- .../handlers/src/compat/login_sso_complete.rs | 18 +- crates/handlers/src/oauth2/userinfo.rs | 24 +- crates/handlers/src/upstream_oauth2/link.rs | 27 +- .../handlers/src/views/account/emails/add.rs | 8 +- .../handlers/src/views/account/emails/mod.rs | 104 +- .../src/views/account/emails/verify.rs | 39 +- crates/handlers/src/views/account/mod.rs | 7 +- crates/handlers/src/views/login.rs | 10 +- crates/handlers/src/views/register.rs | 34 +- crates/storage/sqlx-data.json | 2407 ++++++++--------- crates/storage/src/compat.rs | 124 +- crates/storage/src/lib.rs | 2 +- crates/storage/src/oauth2/access_token.rs | 64 +- .../storage/src/oauth2/authorization_grant.rs | 51 +- crates/storage/src/oauth2/refresh_token.rs | 39 +- crates/storage/src/repository.rs | 39 +- crates/storage/src/user/email.rs | 555 ++++ crates/storage/src/user/mod.rs | 928 ++----- crates/templates/src/context.rs | 4 +- docs/development/database.md | 2 + 26 files changed, 2148 insertions(+), 2424 deletions(-) create mode 100644 crates/storage/src/user/email.rs diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index a92f35af7..d46130e28 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -20,9 +20,7 @@ use mas_router::UrlBuilder; use mas_storage::{ oauth2::client::{insert_client_from_config, lookup_client, truncate_clients}, upstream_oauth2::UpstreamOAuthProviderRepository, - user::{ - add_user_password, lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, - }, + user::{add_user_password, UserEmailRepository, UserRepository}, Clock, Repository, }; use oauth2_types::scope::Scope; @@ -202,7 +200,9 @@ impl Options { let password_manager = password_manager_from_config(&passwords_config).await?; let mut txn = pool.begin().await?; - let user = lookup_user_by_username(&mut txn, username) + let user = txn + .user() + .find_by_username(username) .await? .context("User not found")?; @@ -232,13 +232,18 @@ impl Options { let pool = database_from_config(&config).await?; let mut txn = pool.begin().await?; - let user = lookup_user_by_username(&mut txn, username) + let user = txn + .user() + .find_by_username(username) .await? .context("User not found")?; - let email = lookup_user_email(&mut txn, &user, email) + + let email = txn + .user_email() + .find(&user, email) .await? .context("Email not found")?; - let email = mark_user_email_as_verified(&mut txn, &clock, email).await?; + let email = txn.user_email().mark_as_verified(&clock, email).await?; txn.commit().await?; info!(?email, "Email marked as verified"); diff --git a/crates/data-model/src/users.rs b/crates/data-model/src/users.rs index 4d9c884a0..995535d76 100644 --- a/crates/data-model/src/users.rs +++ b/crates/data-model/src/users.rs @@ -22,7 +22,7 @@ pub struct User { pub id: Ulid, pub username: String, pub sub: String, - pub primary_email: Option, + pub primary_user_email_id: Option, } impl User { @@ -32,7 +32,7 @@ impl User { id: Ulid::from_datetime_with_source(now.into(), rng), username: "john".to_owned(), sub: "123-456".to_owned(), - primary_email: None, + primary_user_email_id: None, }] } } @@ -89,6 +89,7 @@ impl BrowserSession { #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct UserEmail { pub id: Ulid, + pub user_id: Ulid, pub email: String, pub created_at: DateTime, pub confirmed_at: Option>, @@ -100,12 +101,14 @@ impl UserEmail { vec![ Self { id: Ulid::from_datetime_with_source(now.into(), rng), + user_id: Ulid::from_datetime_with_source(now.into(), rng), email: "alice@example.com".to_owned(), created_at: now, confirmed_at: Some(now), }, Self { id: Ulid::from_datetime_with_source(now.into(), rng), + user_id: Ulid::from_datetime_with_source(now.into(), rng), email: "bob@example.com".to_owned(), created_at: now, confirmed_at: None, @@ -124,7 +127,7 @@ pub enum UserEmailVerificationState { #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct UserEmailVerification { pub id: Ulid, - pub email: UserEmail, + pub user_email_id: Ulid, pub code: String, pub created_at: DateTime, pub state: UserEmailVerificationState, @@ -152,8 +155,8 @@ impl UserEmailVerification { .into_iter() .map(move |email| Self { id: Ulid::from_datetime_with_source(now.into(), &mut rng), + user_email_id: email.id, code: "123456".to_owned(), - email, created_at: now - Duration::minutes(10), state: state.clone(), }) diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index 8f3ef3219..f01d83701 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -31,7 +31,8 @@ use async_graphql::{ Context, Description, EmptyMutation, EmptySubscription, ID, }; use mas_storage::{ - upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository, + upstream_oauth2::UpstreamOAuthProviderRepository, user::UserEmailRepository, Repository, + UpstreamOAuthLinkRepository, }; use model::CreationEvent; use sqlx::PgPool; @@ -154,8 +155,11 @@ impl RootQuery { let Some(session) = session else { return Ok(None) }; let current_user = session.user; - let user_email = - mas_storage::user::lookup_user_email_by_id(&mut conn, ¤t_user, id).await?; + let user_email = conn + .user_email() + .lookup(id) + .await? + .filter(|e| e.user_id == current_user.id); Ok(user_email.map(UserEmail)) } diff --git a/crates/graphql/src/model/upstream_oauth.rs b/crates/graphql/src/model/upstream_oauth.rs index 2de6f2f71..249a09286 100644 --- a/crates/graphql/src/model/upstream_oauth.rs +++ b/crates/graphql/src/model/upstream_oauth.rs @@ -15,7 +15,9 @@ use anyhow::Context as _; use async_graphql::{Context, Object, ID}; use chrono::{DateTime, Utc}; -use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Repository}; +use mas_storage::{ + upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, Repository, +}; use sqlx::PgPool; use super::{NodeType, User}; @@ -120,7 +122,10 @@ impl UpstreamOAuth2Link { // Fetch on-the-fly let database = ctx.data::()?; let mut conn = database.acquire().await?; - mas_storage::user::lookup_user(&mut conn, *user_id).await? + conn.user() + .lookup(*user_id) + .await? + .context("User not found")? } else { return Ok(None); }; diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index 01fcfb0eb..58119c093 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -17,7 +17,7 @@ use async_graphql::{ Context, Description, Object, ID, }; use chrono::{DateTime, Utc}; -use mas_storage::{Repository, UpstreamOAuthLinkRepository}; +use mas_storage::{user::UserEmailRepository, Repository, UpstreamOAuthLinkRepository}; use sqlx::PgPool; use super::{ @@ -54,8 +54,14 @@ impl User { } /// Primary email address of the user. - async fn primary_email(&self) -> Option { - self.0.primary_email.clone().map(UserEmail) + async fn primary_email( + &self, + ctx: &Context<'_>, + ) -> Result, async_graphql::Error> { + let database = ctx.data::()?; + let mut conn = database.acquire().await?; + + Ok(conn.user_email().get_primary(&self.0).await?.map(UserEmail)) } /// Get the list of compatibility SSO logins, chronologically sorted @@ -182,18 +188,17 @@ impl User { .map(|x: OpaqueCursor| x.extract_for_type(NodeType::UserEmail)) .transpose()?; - let (has_previous_page, has_next_page, edges) = - mas_storage::user::get_paginated_user_emails( - &mut conn, &self.0, before_id, after_id, first, last, - ) + let page = conn + .user_email() + .list_paginated(&self.0, before_id, after_id, first, last) .await?; let mut connection = Connection::with_additional_fields( - has_previous_page, - has_next_page, + page.has_previous_page, + page.has_next_page, UserEmailsPagination(self.0.clone()), ); - connection.edges.extend(edges.into_iter().map(|u| { + connection.edges.extend(page.edges.into_iter().map(|u| { Edge::new( OpaqueCursor(NodeCursor(NodeType::UserEmail, u.id)), UserEmail(u), @@ -339,9 +344,9 @@ pub struct UserEmailsPagination(mas_data_model::User); #[Object] impl UserEmailsPagination { /// Identifies the total count of items in the connection. - async fn total_count(&self, ctx: &Context<'_>) -> Result { + async fn total_count(&self, ctx: &Context<'_>) -> Result { let mut conn = ctx.data::()?.acquire().await?; - let count = mas_storage::user::count_user_emails(&mut conn, &self.0).await?; + let count = conn.user_email().count(&self.0).await?; Ok(count) } } diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index 6d5ca8255..dd4d4742b 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -21,8 +21,8 @@ use mas_storage::{ add_compat_access_token, add_compat_refresh_token, get_compat_sso_login_by_token, mark_compat_sso_login_as_exchanged, start_compat_session, }, - user::{add_user_password, lookup_user_by_username, lookup_user_password}, - Clock, + user::{add_user_password, lookup_user_password, UserRepository}, + Clock, Repository, }; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; @@ -314,7 +314,9 @@ async fn user_password_login( let (clock, mut rng) = crate::clock_and_rng(); // Find the user - let user = lookup_user_by_username(&mut *txn, &username) + let user = txn + .user() + .find_by_username(&username) .await? .ok_or(RouteError::UserNotFound)?; diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index e0416cd17..497908427 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -80,14 +80,8 @@ pub async fn get( return Ok((cookie_jar, url).into_response()); }; - // TODO: make that more generic - if session - .user - .primary_email - .as_ref() - .and_then(|e| e.confirmed_at) - .is_none() - { + // TODO: make that more generic, check that the email has been confirmed + if session.user.primary_user_email_id.is_none() { let destination = mas_router::AccountAddEmail::default() .and_then(PostAuthAction::continue_compat_sso_login(id)); return Ok((cookie_jar, destination.go()).into_response()); @@ -149,13 +143,7 @@ pub async fn post( }; // TODO: make that more generic - if session - .user - .primary_email - .as_ref() - .and_then(|e| e.confirmed_at) - .is_none() - { + if session.user.primary_user_email_id.is_none() { let destination = mas_router::AccountAddEmail::default() .and_then(PostAuthAction::continue_compat_sso_login(id)); return Ok((cookie_jar, destination.go()).into_response()); diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index 0369739d2..225870ad3 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -28,6 +28,7 @@ use mas_jose::{ }; use mas_keystore::Keystore; use mas_router::UrlBuilder; +use mas_storage::{user::UserEmailRepository, Repository}; use oauth2_types::scope; use serde::Serialize; use serde_with::skip_serializing_none; @@ -66,6 +67,7 @@ pub enum RouteError { } impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_storage::DatabaseError); impl_from_error_for_route!(mas_keystore::WrongAlgorithmError); impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError); @@ -92,19 +94,19 @@ pub async fn get( let session = user_authorization.protected(&mut conn).await?; let user = session.browser_session.user; - let mut user_info = UserInfo { - sub: user.sub, - username: user.username, - email: None, - email_verified: None, + + let user_email = if session.scope.contains(&scope::EMAIL) { + conn.user_email().get_primary(&user).await? + } else { + None }; - if session.scope.contains(&scope::EMAIL) { - if let Some(email) = user.primary_email { - user_info.email_verified = Some(email.confirmed_at.is_some()); - user_info.email = Some(email.email); - } - } + let user_info = UserInfo { + sub: user.sub.clone(), + username: user.username.clone(), + email_verified: user_email.as_ref().map(|u| u.confirmed_at.is_some()), + email: user_email.map(|u| u.email), + }; if let Some(alg) = session.client.userinfo_signed_response_alg { let key = key_store diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index c01d97996..dbd06059b 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -26,7 +26,7 @@ use mas_axum_utils::{ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::UpstreamOAuthSessionRepository, - user::{add_user, authenticate_session_with_upstream, lookup_user, start_session}, + user::{authenticate_session_with_upstream, start_session, UserRepository}, Repository, UpstreamOAuthLinkRepository, }; use mas_templates::{ @@ -51,6 +51,10 @@ pub(crate) enum RouteError { #[error("Session not found")] SessionNotFound, + /// Couldn't find the user + #[error("User not found")] + UserNotFound, + /// Session was already consumed #[error("Session already consumed")] SessionConsumed, @@ -157,7 +161,11 @@ pub(crate) async fn get( // Session already linked, but link doesn't match the currently // logged user. Suggest logging out of the current user // and logging in with the new one - let user = lookup_user(&mut txn, user_id).await?; + let user = txn + .user() + .lookup(user_id) + .await? + .ok_or(RouteError::UserNotFound)?; let ctx = UpstreamExistingLinkContext::new(user) .with_session(user_session) @@ -177,7 +185,11 @@ pub(crate) async fn get( (None, Some(user_id)) => { // Session linked, but user not logged in: do the login - let user = lookup_user(&mut txn, user_id).await?; + let user = txn + .user() + .lookup(user_id) + .await? + .ok_or(RouteError::UserNotFound)?; let ctx = UpstreamExistingLinkContext::new(user).with_csrf(csrf_token.form_value()); @@ -250,12 +262,17 @@ pub(crate) async fn post( } (None, Some(user_id), FormData::Login) => { - let user = lookup_user(&mut txn, user_id).await?; + let user = txn + .user() + .lookup(user_id) + .await? + .ok_or(RouteError::UserNotFound)?; + start_session(&mut txn, &mut rng, &clock, user).await? } (None, None, FormData::Register { username }) => { - let user = add_user(&mut txn, &mut rng, &clock, &username).await?; + let user = txn.user().add(&mut rng, &clock, username).await?; txn.upstream_oauth_link() .associate_to_user(&link, &user) .await?; diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index 06fe7e067..c7cd27676 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -24,7 +24,7 @@ use mas_axum_utils::{ use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::user::add_user_email; +use mas_storage::{user::UserEmailRepository, Repository}; use mas_templates::{EmailAddContext, TemplateContext, Templates}; use serde::Deserialize; use sqlx::PgPool; @@ -88,7 +88,11 @@ pub(crate) async fn post( return Ok((cookie_jar, login.go()).into_response()); }; - let user_email = add_user_email(&mut txn, &mut rng, &clock, &session.user, form.email).await?; + let user_email = txn + .user_email() + .add(&mut rng, &clock, &session.user, form.email) + .await?; + let next = mas_router::AccountVerifyEmail::new(user_email.id); let next = if let Some(action) = query.post_auth_action { next.and_then(action) diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index 061e360c0..e6e1e3410 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use anyhow::{anyhow, Context}; use axum::{ extract::{Form, State}, response::{Html, IntoResponse, Response}, @@ -27,17 +28,11 @@ use mas_data_model::{BrowserSession, User, UserEmail}; use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{ - user::{ - add_user_email, add_user_email_verification_code, get_user_email, get_user_emails, - remove_user_email, set_user_email_as_primary, - }, - Clock, -}; +use mas_storage::{user::UserEmailRepository, Clock, Repository}; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; use rand::{distributions::Uniform, Rng}; use serde::Deserialize; -use sqlx::{PgExecutor, PgPool}; +use sqlx::{PgConnection, PgPool}; use tracing::info; pub mod add; @@ -79,11 +74,11 @@ async fn render( templates: Templates, session: BrowserSession, cookie_jar: PrivateCookieJar, - executor: impl PgExecutor<'_>, + conn: &mut PgConnection, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), rng); - let emails = get_user_emails(executor, &session.user).await?; + let emails = conn.user_email().all(&session.user).await?; let ctx = AccountEmailsContext::new(emails) .with_session(session) @@ -96,7 +91,7 @@ async fn render( async fn start_email_verification( mailer: &Mailer, - executor: impl PgExecutor<'_>, + conn: &mut PgConnection, mut rng: impl Rng + Send, clock: &Clock, user: &User, @@ -108,15 +103,10 @@ async fn start_email_verification( let address: Address = user_email.email.parse()?; - let verification = add_user_email_verification_code( - executor, - &mut rng, - clock, - user_email, - Duration::hours(8), - code, - ) - .await?; + let verification = conn + .user_email() + .add_verification_code(&mut rng, clock, &user_email, Duration::hours(8), code) + .await?; // And send the verification email let mailbox = Mailbox::new(Some(user.username.clone()), address); @@ -126,7 +116,7 @@ async fn start_email_verification( mailer.send_verification_email(mailbox, &context).await?; info!( - email.id = %verification.email.id, + email.id = %user_email.id, "Verification email sent" ); Ok(()) @@ -157,49 +147,65 @@ pub(crate) async fn post( match form { ManagementForm::Add { email } => { - let user_email = - add_user_email(&mut txn, &mut rng, &clock, &session.user, email).await?; - let next = mas_router::AccountVerifyEmail::new(user_email.id); - start_email_verification( - &mailer, - &mut txn, - &mut rng, - &clock, - &session.user, - user_email, - ) - .await?; + let email = txn + .user_email() + .add(&mut rng, &clock, &session.user, email) + .await?; + + let next = mas_router::AccountVerifyEmail::new(email.id); + start_email_verification(&mailer, &mut txn, &mut rng, &clock, &session.user, email) + .await?; txn.commit().await?; return Ok((cookie_jar, next.go()).into_response()); } ManagementForm::ResendConfirmation { id } => { let id = id.parse()?; - let user_email = get_user_email(&mut txn, &session.user, id).await?; - let next = mas_router::AccountVerifyEmail::new(user_email.id); - start_email_verification( - &mailer, - &mut txn, - &mut rng, - &clock, - &session.user, - user_email, - ) - .await?; + let email = txn + .user_email() + .lookup(id) + .await? + .context("Email not found")?; + + if email.user_id != session.user.id { + return Err(anyhow!("Email not found").into()); + } + + let next = mas_router::AccountVerifyEmail::new(email.id); + start_email_verification(&mailer, &mut txn, &mut rng, &clock, &session.user, email) + .await?; txn.commit().await?; return Ok((cookie_jar, next.go()).into_response()); } ManagementForm::Remove { id } => { let id = id.parse()?; - let email = get_user_email(&mut txn, &session.user, id).await?; - remove_user_email(&mut txn, email).await?; + let email = txn + .user_email() + .lookup(id) + .await? + .context("Email not found")?; + + if email.user_id != session.user.id { + return Err(anyhow!("Email not found").into()); + } + + txn.user_email().remove(email).await?; } ManagementForm::SetPrimary { id } => { let id = id.parse()?; - let email = get_user_email(&mut txn, &session.user, id).await?; - set_user_email_as_primary(&mut txn, &email).await?; - session.user.primary_email = Some(email); + let email = txn + .user_email() + .lookup(id) + .await? + .context("Email not found")?; + + if email.user_id != session.user.id { + return Err(anyhow!("Email not found").into()); + } + + txn.user_email().set_as_primary(&email).await?; + session.user.primary_user_email_id = Some(email.id); } }; diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index 0ce6503a5..1192743e7 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -24,13 +24,7 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{ - user::{ - consume_email_verification, lookup_user_email_by_id, lookup_user_email_verification_code, - mark_user_email_as_verified, set_user_email_as_primary, - }, - Clock, -}; +use mas_storage::{user::UserEmailRepository, Clock, Repository}; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use serde::Deserialize; use sqlx::PgPool; @@ -65,8 +59,11 @@ pub(crate) async fn get( return Ok((cookie_jar, login.go()).into_response()); }; - let user_email = lookup_user_email_by_id(&mut conn, &session.user, id) + let user_email = conn + .user_email() + .lookup(id) .await? + .filter(|u| u.user_id == session.user.id) .context("Could not find user email")?; if user_email.confirmed_at.is_some() { @@ -106,23 +103,31 @@ pub(crate) async fn post( return Ok((cookie_jar, login.go()).into_response()); }; - let email = lookup_user_email_by_id(&mut txn, &session.user, id) + let user_email = txn + .user_email() + .lookup(id) .await? + .filter(|u| u.user_id == session.user.id) .context("Could not find user email")?; - if session.user.primary_email.is_none() { - set_user_email_as_primary(&mut txn, &email).await?; - } - - // TODO: make those 8 hours configurable - let verification = lookup_user_email_verification_code(&mut txn, &clock, email, &form.code) + let verification = txn + .user_email() + .find_verification_code(&clock, &user_email, &form.code) .await? .context("Invalid code")?; // TODO: display nice errors if the code was already consumed or expired - let verification = consume_email_verification(&mut txn, &clock, verification).await?; + txn.user_email() + .consume_verification_code(&clock, verification) + .await?; - let _email = mark_user_email_as_verified(&mut txn, &clock, verification.email).await?; + if session.user.primary_user_email_id.is_none() { + txn.user_email().set_as_primary(&user_email).await?; + } + + txn.user_email() + .mark_as_verified(&clock, user_email) + .await?; txn.commit().await?; diff --git a/crates/handlers/src/views/account/mod.rs b/crates/handlers/src/views/account/mod.rs index 87eec0961..07a70898f 100644 --- a/crates/handlers/src/views/account/mod.rs +++ b/crates/handlers/src/views/account/mod.rs @@ -23,7 +23,10 @@ use axum_extra::extract::PrivateCookieJar; use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt}; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::user::{count_active_sessions, get_user_emails}; +use mas_storage::{ + user::{count_active_sessions, UserEmailRepository}, + Repository, +}; use mas_templates::{AccountContext, TemplateContext, Templates}; use sqlx::PgPool; @@ -49,7 +52,7 @@ pub(crate) async fn get( let active_sessions = count_active_sessions(&mut conn, &session.user).await?; - let emails = get_user_emails(&mut conn, &session.user).await?; + let emails = conn.user_email().all(&session.user).await?; let ctx = AccountContext::new(active_sessions, emails) .with_session(session) diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index fd54175d2..5ba76b726 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -26,8 +26,8 @@ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::UpstreamOAuthProviderRepository, user::{ - add_user_password, authenticate_session_with_password, lookup_user_by_username, - lookup_user_password, start_session, + add_user_password, authenticate_session_with_password, lookup_user_password, start_session, + UserRepository, }, Clock, Repository, }; @@ -130,8 +130,6 @@ pub(crate) async fn post( return Ok((cookie_jar, Html(content)).into_response()); } - lookup_user_by_username(&mut conn, &form.username).await?; - match login( password_manager, &mut conn, @@ -175,7 +173,9 @@ async fn login( ) -> Result { // XXX: we're loosing the error context here // First, lookup the user - let user = lookup_user_by_username(&mut *conn, username) + let user = conn + .user() + .find_by_username(username) .await .map_err(|_e| FormError::Internal)? .ok_or(FormError::InvalidCredentials)?; diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index 9a12efac2..01dc21167 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -31,9 +31,12 @@ use mas_email::Mailer; use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::Route; -use mas_storage::user::{ - add_user, add_user_email, add_user_email_verification_code, add_user_password, - authenticate_session_with_password, start_session, username_exists, +use mas_storage::{ + user::{ + add_user_password, authenticate_session_with_password, start_session, UserEmailRepository, + UserRepository, + }, + Repository, }; use mas_templates::{ EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField, @@ -114,7 +117,7 @@ pub(crate) async fn post( if form.username.is_empty() { state.add_error_on_field(RegisterFormField::Username, FieldError::Required); - } else if username_exists(&mut txn, &form.username).await? { + } else if txn.user().exists(&form.username).await? { state.add_error_on_field(RegisterFormField::Username, FieldError::Exists); } @@ -185,7 +188,7 @@ pub(crate) async fn post( return Ok((cookie_jar, Html(content)).into_response()); } - let user = add_user(&mut txn, &mut rng, &clock, &form.username).await?; + let user = txn.user().add(&mut rng, &clock, form.username).await?; let password = Zeroizing::new(form.password.into_bytes()); let (version, hashed_password) = password_manager.hash(&mut rng, password).await?; let user_password = add_user_password( @@ -199,7 +202,10 @@ pub(crate) async fn post( ) .await?; - let user_email = add_user_email(&mut txn, &mut rng, &clock, &user, form.email).await?; + let user_email = txn + .user_email() + .add(&mut rng, &clock, &user, form.email) + .await?; // First, generate a code let range = Uniform::::from(0..1_000_000); @@ -208,15 +214,10 @@ pub(crate) async fn post( let address: Address = user_email.email.parse()?; - let verification = add_user_email_verification_code( - &mut txn, - &mut rng, - &clock, - user_email, - Duration::hours(8), - code, - ) - .await?; + let verification = txn + .user_email() + .add_verification_code(&mut rng, &clock, &user_email, Duration::hours(8), code) + .await?; // And send the verification email let mailbox = Mailbox::new(Some(user.username.clone()), address); @@ -225,8 +226,7 @@ pub(crate) async fn post( mailer.send_verification_email(mailbox, &context).await?; - let next = mas_router::AccountVerifyEmail::new(verification.email.id) - .and_maybe(query.post_auth_action); + let next = mas_router::AccountVerifyEmail::new(user_email.id).and_maybe(query.post_auth_action); let mut session = start_session(&mut txn, &mut rng, &clock, user).await?; authenticate_session_with_password(&mut txn, &mut rng, &clock, &mut session, &user_password) diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 8167fef34..3191a9dbf 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -1,5 +1,103 @@ { "db": "PostgreSQL", + "03bc4a14e97e011fec04e5788a967e04838cf978984254ecfd2c8b8a979da1c8": { + "describe": { + "columns": [ + { + "name": "oauth2_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "oauth2_access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "oauth2_access_token_created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_access_token_expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_session_id!", + "ordinal": 4, + "type_info": "Uuid" + }, + { + "name": "oauth2_client_id!", + "ordinal": 5, + "type_info": "Uuid" + }, + { + "name": "scope!", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "user_session_id!", + "ordinal": 7, + "type_info": "Uuid" + }, + { + "name": "user_session_created_at!", + "ordinal": 8, + "type_info": "Timestamptz" + }, + { + "name": "user_id!", + "ordinal": 9, + "type_info": "Uuid" + }, + { + "name": "user_username!", + "ordinal": 10, + "type_info": "Text" + }, + { + "name": "user_primary_user_email_id", + "ordinal": 11, + "type_info": "Uuid" + }, + { + "name": "user_session_last_authentication_id?", + "ordinal": 12, + "type_info": "Uuid" + }, + { + "name": "user_session_last_authentication_created_at?", + "ordinal": 13, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + false, + false, + false, + false, + false, + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT at.oauth2_access_token_id\n , at.access_token AS \"oauth2_access_token\"\n , at.created_at AS \"oauth2_access_token_created_at\"\n , at.expires_at AS \"oauth2_access_token_expires_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"scope!\"\n , us.user_session_id AS \"user_session_id!\"\n , us.created_at AS \"user_session_created_at!\"\n , u.user_id AS \"user_id!\"\n , u.username AS \"user_username!\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n , usa.user_session_authentication_id AS \"user_session_last_authentication_id?\"\n , usa.created_at AS \"user_session_last_authentication_created_at?\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n INNER JOIN user_sessions us\n USING (user_session_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n\n WHERE at.access_token = $1\n AND at.revoked_at IS NULL\n AND os.finished_at IS NULL\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " + }, "05b50b7ae0109063c50fe70e83635a31920e44a7fbaa2b4f07552ba2f83a28d7": { "describe": { "columns": [ @@ -116,7 +214,7 @@ }, "query": "\n SELECT\n c.oauth2_client_id,\n c.encrypted_client_secret,\n ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\",\n c.grant_type_authorization_code,\n c.grant_type_refresh_token,\n c.client_name,\n c.logo_uri,\n c.client_uri,\n c.policy_uri,\n c.tos_uri,\n c.jwks_uri,\n c.jwks,\n c.id_token_signed_response_alg,\n c.userinfo_signed_response_alg,\n c.token_endpoint_auth_method,\n c.token_endpoint_auth_signing_alg,\n c.initiate_login_uri\n FROM oauth2_clients c\n\n WHERE c.oauth2_client_id = $1\n " }, - "0b49cde0b7b79f79ec261502ab89bcffa81f9f5ed2f922a41b1718274b9e3073": { + "08d7df347c806ef14b6d0fb031cab041d79ba48528420160e23286369db7af35": { "describe": { "columns": [ { @@ -125,28 +223,71 @@ "type_info": "Uuid" }, { - "name": "user_username", + "name": "username", "ordinal": 1, "type_info": "Text" }, { - "name": "user_email_id?", + "name": "primary_user_email_id", "ordinal": 2, "type_info": "Uuid" }, { - "name": "user_email?", + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n FROM users\n WHERE user_id = $1\n " + }, + "09d995295b2e4f180181ec96023b1e524ddae9098694eedc4dcce857e3095c0e": { + "describe": { + "columns": [ + { + "name": "user_session_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "user_session_created_at", + "ordinal": 1, + "type_info": "Timestamptz" + }, + { + "name": "user_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "user_username", "ordinal": 3, "type_info": "Text" }, { - "name": "user_email_created_at?", + "name": "user_primary_user_email_id", "ordinal": 4, - "type_info": "Timestamptz" + "type_info": "Uuid" }, { - "name": "user_email_confirmed_at?", + "name": "last_authentication_id?", "ordinal": 5, + "type_info": "Uuid" + }, + { + "name": "last_authd_at?", + "ordinal": 6, "type_info": "Timestamptz" } ], @@ -155,29 +296,17 @@ false, false, false, + true, false, - true + false ], "parameters": { "Left": [ - "Text" + "Uuid" ] } }, - "query": "\n SELECT\n u.user_id,\n u.username AS user_username,\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n FROM users u\n\n LEFT JOIN user_emails ue\n USING (user_id)\n\n WHERE u.username = $1\n " - }, - "1166343ad1563cb66ab387368f67320a53c34edf388bdb991359ebdf324497d5": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE user_emails\n SET confirmed_at = $2\n WHERE user_email_id = $1\n " + "query": "\n SELECT s.user_session_id\n , s.created_at AS \"user_session_created_at\"\n , u.user_id\n , u.username AS \"user_username\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n , a.user_session_authentication_id AS \"last_authentication_id?\"\n , a.created_at AS \"last_authd_at?\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications a\n USING (user_session_id)\n WHERE s.user_session_id = $1 AND s.finished_at IS NULL\n ORDER BY a.created_at DESC\n LIMIT 1\n " }, "154e2e4488ff87e09163698750b56a43127cee4e1392785416a586d40a4d9b21": { "describe": { @@ -239,56 +368,99 @@ }, "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n " }, - "1eb6d13e75d8f526c2785749a020731c18012f03e07995213acd38ab560ce497": { + "16a1c5fe5a4c5481212560d79d589b550dfefe7480c5ee4febcbfaaa01ee93a4": { "describe": { - "columns": [], - "nullable": [], + "columns": [ + { + "name": "compat_sso_login_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "compat_sso_login_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "compat_sso_login_redirect_uri", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "compat_sso_login_created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_sso_login_fulfilled_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "compat_sso_login_exchanged_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id?", + "ordinal": 6, + "type_info": "Uuid" + }, + { + "name": "compat_session_created_at?", + "ordinal": 7, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_finished_at?", + "ordinal": 8, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_device_id?", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "user_id?", + "ordinal": 10, + "type_info": "Uuid" + }, + { + "name": "user_username?", + "ordinal": 11, + "type_info": "Text" + }, + { + "name": "user_primary_user_email_id?", + "ordinal": 12, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + true, + false, + false, + true, + false, + false, + false, + true + ], "parameters": { "Left": [ - "Uuid", - "Uuid", - "Timestamptz" + "Text" ] } }, - "query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at)\n VALUES ($1, $2, $3)\n " + "query": "\n SELECT\n cl.compat_sso_login_id,\n cl.login_token AS \"compat_sso_login_token\",\n cl.redirect_uri AS \"compat_sso_login_redirect_uri\",\n cl.created_at AS \"compat_sso_login_created_at\",\n cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\",\n cl.exchanged_at AS \"compat_sso_login_exchanged_at\",\n cs.compat_session_id AS \"compat_session_id?\",\n cs.created_at AS \"compat_session_created_at?\",\n cs.finished_at AS \"compat_session_finished_at?\",\n cs.device_id AS \"compat_session_device_id?\",\n u.user_id AS \"user_id?\",\n u.username AS \"user_username?\",\n u.primary_user_email_id AS \"user_primary_user_email_id?\"\n FROM compat_sso_logins cl\n LEFT JOIN compat_sessions cs\n USING (compat_session_id)\n LEFT JOIN users u\n USING (user_id)\n WHERE cl.login_token = $1\n " }, - "1ee5cecfafd4726a4ebc08da8a34c09178e6e1e072581c8fca9d3d76967792cb": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Text", - "Text", - "Text", - "Text", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n " - }, - "2153118b364a33582e7f598acce3789fcb8d938948a819b15cf0b6d37edf58b2": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Timestamptz", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO compat_access_tokens\n (compat_access_token_id, compat_session_id, access_token, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n " - }, - "24d6154b138a5e9105b996d6447e45a5c208e157f6583b4220cf58813d6f436c": { + "1a5e0d1d88065bb4e7f790942856d1d94ecdb30a7007f3277ca3f7cbdabd4dff": { "describe": { "columns": [ { @@ -407,33 +579,18 @@ "type_info": "Text" }, { - "name": "user_session_last_authentication_id?", + "name": "user_primary_user_email_id?", "ordinal": 23, "type_info": "Uuid" }, { - "name": "user_session_last_authentication_created_at?", + "name": "user_session_last_authentication_id?", "ordinal": 24, - "type_info": "Timestamptz" - }, - { - "name": "user_email_id?", - "ordinal": 25, "type_info": "Uuid" }, { - "name": "user_email?", - "ordinal": 26, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 27, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 28, + "name": "user_session_last_authentication_created_at?", + "ordinal": 25, "type_info": "Timestamptz" } ], @@ -461,6 +618,223 @@ false, false, false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT\n og.oauth2_authorization_grant_id,\n og.created_at AS oauth2_authorization_grant_created_at,\n og.cancelled_at AS oauth2_authorization_grant_cancelled_at,\n og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at,\n og.exchanged_at AS oauth2_authorization_grant_exchanged_at,\n og.scope AS oauth2_authorization_grant_scope,\n og.state AS oauth2_authorization_grant_state,\n og.redirect_uri AS oauth2_authorization_grant_redirect_uri,\n og.response_mode AS oauth2_authorization_grant_response_mode,\n og.nonce AS oauth2_authorization_grant_nonce,\n og.max_age AS oauth2_authorization_grant_max_age,\n og.oauth2_client_id AS oauth2_client_id,\n og.authorization_code AS oauth2_authorization_grant_code,\n og.response_type_code AS oauth2_authorization_grant_response_type_code,\n og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token,\n og.code_challenge AS oauth2_authorization_grant_code_challenge,\n og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method,\n og.requires_consent AS oauth2_authorization_grant_requires_consent,\n os.oauth2_session_id AS \"oauth2_session_id?\",\n us.user_session_id AS \"user_session_id?\",\n us.created_at AS \"user_session_created_at?\",\n u.user_id AS \"user_id?\",\n u.username AS \"user_username?\",\n u.primary_user_email_id AS \"user_primary_user_email_id?\",\n usa.user_session_authentication_id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n FROM\n oauth2_authorization_grants og\n LEFT JOIN oauth2_sessions os\n USING (oauth2_session_id)\n LEFT JOIN user_sessions us\n USING (user_session_id)\n LEFT JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n\n WHERE og.authorization_code = $1\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " + }, + "1b448fe73e12bef622b75857e4c9b257c9529ca18da7f63d127e63184f4bc94b": { + "describe": { + "columns": [ + { + "name": "oauth2_authorization_grant_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "oauth2_authorization_grant_created_at", + "ordinal": 1, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_cancelled_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_fulfilled_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_exchanged_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_scope", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_state", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_redirect_uri", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_response_mode", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_nonce", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_max_age", + "ordinal": 10, + "type_info": "Int4" + }, + { + "name": "oauth2_client_id", + "ordinal": 11, + "type_info": "Uuid" + }, + { + "name": "oauth2_authorization_grant_code", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_response_type_code", + "ordinal": 13, + "type_info": "Bool" + }, + { + "name": "oauth2_authorization_grant_response_type_id_token", + "ordinal": 14, + "type_info": "Bool" + }, + { + "name": "oauth2_authorization_grant_code_challenge", + "ordinal": 15, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_code_challenge_method", + "ordinal": 16, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_requires_consent", + "ordinal": 17, + "type_info": "Bool" + }, + { + "name": "oauth2_session_id?", + "ordinal": 18, + "type_info": "Uuid" + }, + { + "name": "user_session_id?", + "ordinal": 19, + "type_info": "Uuid" + }, + { + "name": "user_session_created_at?", + "ordinal": 20, + "type_info": "Timestamptz" + }, + { + "name": "user_id?", + "ordinal": 21, + "type_info": "Uuid" + }, + { + "name": "user_username?", + "ordinal": 22, + "type_info": "Text" + }, + { + "name": "user_primary_user_email_id?", + "ordinal": 23, + "type_info": "Uuid" + }, + { + "name": "user_session_last_authentication_id?", + "ordinal": 24, + "type_info": "Uuid" + }, + { + "name": "user_session_last_authentication_created_at?", + "ordinal": 25, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true, + false, + false, + true, + true, + false, + false, + false, + false, + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT\n og.oauth2_authorization_grant_id,\n og.created_at AS oauth2_authorization_grant_created_at,\n og.cancelled_at AS oauth2_authorization_grant_cancelled_at,\n og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at,\n og.exchanged_at AS oauth2_authorization_grant_exchanged_at,\n og.scope AS oauth2_authorization_grant_scope,\n og.state AS oauth2_authorization_grant_state,\n og.redirect_uri AS oauth2_authorization_grant_redirect_uri,\n og.response_mode AS oauth2_authorization_grant_response_mode,\n og.nonce AS oauth2_authorization_grant_nonce,\n og.max_age AS oauth2_authorization_grant_max_age,\n og.oauth2_client_id AS oauth2_client_id,\n og.authorization_code AS oauth2_authorization_grant_code,\n og.response_type_code AS oauth2_authorization_grant_response_type_code,\n og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token,\n og.code_challenge AS oauth2_authorization_grant_code_challenge,\n og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method,\n og.requires_consent AS oauth2_authorization_grant_requires_consent,\n os.oauth2_session_id AS \"oauth2_session_id?\",\n us.user_session_id AS \"user_session_id?\",\n us.created_at AS \"user_session_created_at?\",\n u.user_id AS \"user_id?\",\n u.username AS \"user_username?\",\n u.primary_user_email_id AS \"user_primary_user_email_id?\",\n usa.user_session_authentication_id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n FROM\n oauth2_authorization_grants og\n LEFT JOIN oauth2_sessions os\n USING (oauth2_session_id)\n LEFT JOIN user_sessions us\n USING (user_session_id)\n LEFT JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n\n WHERE og.oauth2_authorization_grant_id = $1\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " + }, + "1d372f36c382ab16264cea54537af3544ea6d6d75d10b432b07dbd0dadd2fa4e": { + "describe": { + "columns": [ + { + "name": "user_email_confirmation_code_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "user_email_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "code", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 5, + "type_info": "Timestamptz" + } + ], + "nullable": [ false, false, false, @@ -470,11 +844,61 @@ ], "parameters": { "Left": [ - "Text" + "Text", + "Uuid" ] } }, - "query": "\n SELECT\n og.oauth2_authorization_grant_id,\n og.created_at AS oauth2_authorization_grant_created_at,\n og.cancelled_at AS oauth2_authorization_grant_cancelled_at,\n og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at,\n og.exchanged_at AS oauth2_authorization_grant_exchanged_at,\n og.scope AS oauth2_authorization_grant_scope,\n og.state AS oauth2_authorization_grant_state,\n og.redirect_uri AS oauth2_authorization_grant_redirect_uri,\n og.response_mode AS oauth2_authorization_grant_response_mode,\n og.nonce AS oauth2_authorization_grant_nonce,\n og.max_age AS oauth2_authorization_grant_max_age,\n og.oauth2_client_id AS oauth2_client_id,\n og.authorization_code AS oauth2_authorization_grant_code,\n og.response_type_code AS oauth2_authorization_grant_response_type_code,\n og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token,\n og.code_challenge AS oauth2_authorization_grant_code_challenge,\n og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method,\n og.requires_consent AS oauth2_authorization_grant_requires_consent,\n os.oauth2_session_id AS \"oauth2_session_id?\",\n us.user_session_id AS \"user_session_id?\",\n us.created_at AS \"user_session_created_at?\",\n u.user_id AS \"user_id?\",\n u.username AS \"user_username?\",\n usa.user_session_authentication_id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n FROM\n oauth2_authorization_grants og\n LEFT JOIN oauth2_sessions os\n USING (oauth2_session_id)\n LEFT JOIN user_sessions us\n USING (user_session_id)\n LEFT JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE og.authorization_code = $1\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " + "query": "\n SELECT user_email_confirmation_code_id\n , user_email_id\n , code\n , created_at\n , expires_at\n , consumed_at\n FROM user_email_confirmation_codes\n WHERE code = $1\n AND user_email_id = $2\n " + }, + "1eb6d13e75d8f526c2785749a020731c18012f03e07995213acd38ab560ce497": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at)\n VALUES ($1, $2, $3)\n " + }, + "1ee5cecfafd4726a4ebc08da8a34c09178e6e1e072581c8fca9d3d76967792cb": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Text", + "Text", + "Text", + "Text", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n " + }, + "2153118b364a33582e7f598acce3789fcb8d938948a819b15cf0b6d37edf58b2": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO compat_access_tokens\n (compat_access_token_id, compat_session_id, access_token, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n " }, "262bee715889dc3e608639549600a131e641951ff979634e7c97afc74bbc1605": { "describe": { @@ -507,119 +931,6 @@ }, "query": "\n INSERT INTO oauth2_clients\n (oauth2_client_id,\n encrypted_client_secret,\n grant_type_authorization_code,\n grant_type_refresh_token,\n token_endpoint_auth_method,\n jwks,\n jwks_uri)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7)\n " }, - "27a729b229491d179391b19b634f07291312bd238380c5a7ea0f60e9b71dfb14": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO users (user_id, username, created_at)\n VALUES ($1, $2, $3)\n " - }, - "2e581d57db471b96091860cd0252361d16332deeffabab0dace405ead55324be": { - "describe": { - "columns": [ - { - "name": "compat_access_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "compat_access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_access_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "compat_access_token_expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id", - "ordinal": 4, - "type_info": "Uuid" - }, - { - "name": "compat_session_created_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_finished_at", - "ordinal": 6, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "user_id!", - "ordinal": 8, - "type_info": "Uuid" - }, - { - "name": "user_username!", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "user_email_id?", - "ordinal": 10, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 11, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 12, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 13, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - true, - false, - false, - true, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text", - "Timestamptz" - ] - } - }, - "query": "\n SELECT\n ct.compat_access_token_id,\n ct.access_token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n ct.expires_at AS \"compat_access_token_expires_at\",\n cs.compat_session_id,\n cs.created_at AS \"compat_session_created_at\",\n cs.finished_at AS \"compat_session_finished_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.user_id AS \"user_id!\",\n u.username AS \"user_username!\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM compat_access_tokens ct\n INNER JOIN compat_sessions cs\n USING (compat_session_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE ct.access_token = $1\n AND (ct.expires_at < $2 OR ct.expires_at IS NULL)\n AND cs.finished_at IS NULL \n " - }, "2e756fe7be50128c0acc5f79df3a084230e9ca13cd45bd0858f97e59da20006e": { "describe": { "columns": [], @@ -649,133 +960,17 @@ }, "query": "\n INSERT INTO compat_refresh_tokens\n (compat_refresh_token_id, compat_session_id,\n compat_access_token_id, refresh_token, created_at)\n VALUES ($1, $2, $3, $4, $5)\n " }, - "3a19b087ae9e4dab770f102de1cb62628525fc72c7b052e1c146161ab088c02b": { + "3d66f3121b11ce923b9c60609b510a8ca899640e78cc8f5b03168622928ffe94": { "describe": { - "columns": [ - { - "name": "user_email_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "user_email", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "user_email_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at", - "ordinal": 3, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Uuid", - "Text" - ] - } - }, - "query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n AND ue.email = $2\n " - }, - "3df0838b660466f69ee681337fe6753133748defb715e53c8381badcc3e8bca9": { - "describe": { - "columns": [ - { - "name": "user_session_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "user_id", - "ordinal": 1, - "type_info": "Uuid" - }, - { - "name": "username", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "last_authentication_id?", - "ordinal": 4, - "type_info": "Uuid" - }, - { - "name": "last_authd_at?", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "user_email_id?", - "ordinal": 6, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 9, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - false, - false, - true - ], + "columns": [], + "nullable": [], "parameters": { "Left": [ "Uuid" ] } }, - "query": "\n SELECT\n s.user_session_id,\n u.user_id,\n u.username,\n s.created_at,\n a.user_session_authentication_id AS \"last_authentication_id?\",\n a.created_at AS \"last_authd_at?\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications a\n USING (user_session_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n WHERE s.user_session_id = $1 AND s.finished_at IS NULL\n ORDER BY a.created_at DESC\n LIMIT 1\n " - }, - "3e8f862ed05ce3e58c181ac6e0bd71e0a6a88419611af6f4117d14d9c36cb1ef": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO user_emails (user_email_id, user_id, email, created_at)\n VALUES ($1, $2, $3, $4)\n " + "query": "\n DELETE FROM user_emails\n WHERE user_email_id = $1\n " }, "4187907bfc770b2c76f741671d5e672f5c35eed7c9a9e57ff52888b1768a5ed6": { "describe": { @@ -821,102 +1016,36 @@ }, "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_link_id = $1\n " }, - "42bfb0de5bbea2d580f1ff2322255731a4a5655ba80fc2dba0b55a0add8c55c0": { + "4192c1144c0ea530cf1aa77993a38e94cd5cf8b5c42cb037efb7917c6fc44a1d": { "describe": { "columns": [ { - "name": "compat_sso_login_id", + "name": "user_email_id", "ordinal": 0, "type_info": "Uuid" }, { - "name": "compat_sso_login_token", + "name": "user_id", "ordinal": 1, - "type_info": "Text" + "type_info": "Uuid" }, { - "name": "compat_sso_login_redirect_uri", + "name": "email", "ordinal": 2, "type_info": "Text" }, { - "name": "compat_sso_login_created_at", + "name": "created_at", "ordinal": 3, "type_info": "Timestamptz" }, { - "name": "compat_sso_login_fulfilled_at", + "name": "confirmed_at", "ordinal": 4, "type_info": "Timestamptz" - }, - { - "name": "compat_sso_login_exchanged_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id?", - "ordinal": 6, - "type_info": "Uuid" - }, - { - "name": "compat_session_created_at?", - "ordinal": 7, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_finished_at?", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id?", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "user_id?", - "ordinal": 10, - "type_info": "Uuid" - }, - { - "name": "user_username?", - "ordinal": 11, - "type_info": "Text" - }, - { - "name": "user_email_id?", - "ordinal": 12, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 13, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 14, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 15, - "type_info": "Timestamptz" } ], "nullable": [ - false, - false, - false, - false, - true, - true, - false, - false, - true, - false, - false, false, false, false, @@ -929,7 +1058,7 @@ ] } }, - "query": "\n SELECT\n cl.compat_sso_login_id,\n cl.login_token AS \"compat_sso_login_token\",\n cl.redirect_uri AS \"compat_sso_login_redirect_uri\",\n cl.created_at AS \"compat_sso_login_created_at\",\n cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\",\n cl.exchanged_at AS \"compat_sso_login_exchanged_at\",\n cs.compat_session_id AS \"compat_session_id?\",\n cs.created_at AS \"compat_session_created_at?\",\n cs.finished_at AS \"compat_session_finished_at?\",\n cs.device_id AS \"compat_session_device_id?\",\n u.user_id AS \"user_id?\",\n u.username AS \"user_username?\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n FROM compat_sso_logins cl\n LEFT JOIN compat_sessions cs\n USING (compat_session_id)\n LEFT JOIN users u\n USING (user_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n WHERE cl.compat_sso_login_id = $1\n " + "query": "\n SELECT user_email_id\n , user_id\n , email\n , created_at\n , confirmed_at\n FROM user_emails\n\n WHERE user_email_id = $1\n " }, "43a5cafbdc8037e9fb779812a0793cf0859902aa0dc8d25d4c33d231d3d1118b": { "describe": { @@ -982,122 +1111,6 @@ }, "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n oauth2_session_id = os.oauth2_session_id,\n fulfilled_at = os.created_at\n FROM oauth2_sessions os\n WHERE\n og.oauth2_authorization_grant_id = $1\n AND os.oauth2_session_id = $2\n RETURNING fulfilled_at AS \"fulfilled_at!: DateTime\"\n " }, - "4f8ec19f3f1bfe0268fe102a24e5a9fa542e77eccbebdce65e6deb1c197adf36": { - "describe": { - "columns": [ - { - "name": "oauth2_access_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "oauth2_access_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_access_token_expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_session_id!", - "ordinal": 4, - "type_info": "Uuid" - }, - { - "name": "oauth2_client_id!", - "ordinal": 5, - "type_info": "Uuid" - }, - { - "name": "scope!", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "user_session_id!", - "ordinal": 7, - "type_info": "Uuid" - }, - { - "name": "user_session_created_at!", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "user_id!", - "ordinal": 9, - "type_info": "Uuid" - }, - { - "name": "user_username!", - "ordinal": 10, - "type_info": "Text" - }, - { - "name": "user_session_last_authentication_id?", - "ordinal": 11, - "type_info": "Uuid" - }, - { - "name": "user_session_last_authentication_created_at?", - "ordinal": 12, - "type_info": "Timestamptz" - }, - { - "name": "user_email_id?", - "ordinal": 13, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 14, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 15, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 16, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT\n at.oauth2_access_token_id,\n at.access_token AS \"oauth2_access_token\",\n at.created_at AS \"oauth2_access_token_created_at\",\n at.expires_at AS \"oauth2_access_token_expires_at\",\n os.oauth2_session_id AS \"oauth2_session_id!\",\n os.oauth2_client_id AS \"oauth2_client_id!\",\n os.scope AS \"scope!\",\n us.user_session_id AS \"user_session_id!\",\n us.created_at AS \"user_session_created_at!\",\n u.user_id AS \"user_id!\",\n u.username AS \"user_username!\",\n usa.user_session_authentication_id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n INNER JOIN user_sessions us\n USING (user_session_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE at.access_token = $1\n AND at.revoked_at IS NULL\n AND os.finished_at IS NULL\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " - }, "51158bfcaa1a8d8e051bffe7c5ba0369bf53fb162f7622626054e89e68fc07bd": { "describe": { "columns": [ @@ -1119,6 +1132,104 @@ }, "query": "\n SELECT scope_token\n FROM oauth2_consents\n WHERE user_id = $1 AND oauth2_client_id = $2\n " }, + "51bf417d259989d1228ba86fa11432e9428dece97b79e93f13921d0a510a9428": { + "describe": { + "columns": [ + { + "name": "compat_refresh_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "compat_refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "compat_refresh_token_created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "compat_access_token_id", + "ordinal": 3, + "type_info": "Uuid" + }, + { + "name": "compat_access_token", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "compat_access_token_created_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "compat_access_token_expires_at", + "ordinal": 6, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 7, + "type_info": "Uuid" + }, + { + "name": "compat_session_created_at", + "ordinal": 8, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_finished_at", + "ordinal": 9, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_device_id", + "ordinal": 10, + "type_info": "Text" + }, + { + "name": "user_id", + "ordinal": 11, + "type_info": "Uuid" + }, + { + "name": "user_username!", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "user_primary_user_email_id", + "ordinal": 13, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + false, + false, + true, + false, + false, + true, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT\n cr.compat_refresh_token_id,\n cr.refresh_token AS \"compat_refresh_token\",\n cr.created_at AS \"compat_refresh_token_created_at\",\n ct.compat_access_token_id,\n ct.access_token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n ct.expires_at AS \"compat_access_token_expires_at\",\n cs.compat_session_id,\n cs.created_at AS \"compat_session_created_at\",\n cs.finished_at AS \"compat_session_finished_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.user_id,\n u.username AS \"user_username!\",\n u.primary_user_email_id AS \"user_primary_user_email_id\"\n\n FROM compat_refresh_tokens cr\n INNER JOIN compat_sessions cs\n USING (compat_session_id)\n INNER JOIN compat_access_tokens ct\n USING (compat_access_token_id)\n INNER JOIN users u\n USING (user_id)\n\n WHERE cr.refresh_token = $1\n AND cr.consumed_at IS NULL\n AND cs.finished_at IS NULL\n " + }, "559a486756d08d101eb7188ef6637b9d24c024d056795b8121f7f04a7f9db6a3": { "describe": { "columns": [ @@ -1140,56 +1251,6 @@ }, "query": "\n UPDATE compat_sessions cs\n SET finished_at = $2\n FROM compat_access_tokens ca\n WHERE ca.access_token = $1\n AND ca.compat_session_id = cs.compat_session_id\n AND cs.finished_at IS NULL\n RETURNING cs.compat_session_id\n " }, - "59439585536bb4e547a6cf58a8bc6ac735f29c225bcbeac7d371f09166789a73": { - "describe": { - "columns": [ - { - "name": "user_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "user_username", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "user_email_id?", - "ordinal": 2, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 5, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n u.user_id,\n u.username AS user_username,\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n FROM users u\n\n LEFT JOIN user_emails ue\n USING (user_id)\n\n WHERE u.user_id = $1\n " - }, "5b5d5c82da37c6f2d8affacfb02119965c04d1f2a9cc53dbf5bd4c12584969a0": { "describe": { "columns": [], @@ -1202,44 +1263,6 @@ }, "query": "\n DELETE FROM oauth2_access_tokens\n WHERE expires_at < $1\n " }, - "5ccde09ee3fe43e7b492d73fa67708b5dcb2b7496c4d05bcfcf0ea63c7576d48": { - "describe": { - "columns": [ - { - "name": "user_email_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "user_email", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "user_email_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at", - "ordinal": 3, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n\n ORDER BY ue.email ASC\n " - }, "5f6b7e38ef9bc3b39deabba277d0255fb8cfb2adaa65f47b78a8fac11d8c91c3": { "describe": { "columns": [], @@ -1415,210 +1438,6 @@ }, "query": "\n UPDATE oauth2_access_tokens\n SET revoked_at = $2\n WHERE oauth2_access_token_id = $1\n " }, - "7262f81a335a984c4051383d2ede7455ff65ed90fbd3151d625f8a21fd26cb05": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Timestamptz", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO user_email_confirmation_codes\n (user_email_confirmation_code_id, user_email_id, code, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n " - }, - "75a16693cabdf57012f741e789b19d0a0f96fcd1e41bb2af92f2991b722cc9f1": { - "describe": { - "columns": [ - { - "name": "oauth2_authorization_grant_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_created_at", - "ordinal": 1, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_cancelled_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_fulfilled_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_exchanged_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_scope", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_state", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_redirect_uri", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_mode", - "ordinal": 8, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_nonce", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_max_age", - "ordinal": 10, - "type_info": "Int4" - }, - { - "name": "oauth2_client_id", - "ordinal": 11, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_code", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_type_code", - "ordinal": 13, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_response_type_id_token", - "ordinal": 14, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_code_challenge", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_code_challenge_method", - "ordinal": 16, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_requires_consent", - "ordinal": 17, - "type_info": "Bool" - }, - { - "name": "oauth2_session_id?", - "ordinal": 18, - "type_info": "Uuid" - }, - { - "name": "user_session_id?", - "ordinal": 19, - "type_info": "Uuid" - }, - { - "name": "user_session_created_at?", - "ordinal": 20, - "type_info": "Timestamptz" - }, - { - "name": "user_id?", - "ordinal": 21, - "type_info": "Uuid" - }, - { - "name": "user_username?", - "ordinal": 22, - "type_info": "Text" - }, - { - "name": "user_session_last_authentication_id?", - "ordinal": 23, - "type_info": "Uuid" - }, - { - "name": "user_session_last_authentication_created_at?", - "ordinal": 24, - "type_info": "Timestamptz" - }, - { - "name": "user_email_id?", - "ordinal": 25, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 26, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 27, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 28, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - true, - true, - false, - true, - false, - false, - true, - true, - false, - true, - false, - false, - true, - true, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n og.oauth2_authorization_grant_id,\n og.created_at AS oauth2_authorization_grant_created_at,\n og.cancelled_at AS oauth2_authorization_grant_cancelled_at,\n og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at,\n og.exchanged_at AS oauth2_authorization_grant_exchanged_at,\n og.scope AS oauth2_authorization_grant_scope,\n og.state AS oauth2_authorization_grant_state,\n og.redirect_uri AS oauth2_authorization_grant_redirect_uri,\n og.response_mode AS oauth2_authorization_grant_response_mode,\n og.nonce AS oauth2_authorization_grant_nonce,\n og.max_age AS oauth2_authorization_grant_max_age,\n og.oauth2_client_id AS oauth2_client_id,\n og.authorization_code AS oauth2_authorization_grant_code,\n og.response_type_code AS oauth2_authorization_grant_response_type_code,\n og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token,\n og.code_challenge AS oauth2_authorization_grant_code_challenge,\n og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method,\n og.requires_consent AS oauth2_authorization_grant_requires_consent,\n os.oauth2_session_id AS \"oauth2_session_id?\",\n us.user_session_id AS \"user_session_id?\",\n us.created_at AS \"user_session_created_at?\",\n u.user_id AS \"user_id?\",\n u.username AS \"user_username?\",\n usa.user_session_authentication_id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n FROM\n oauth2_authorization_grants og\n LEFT JOIN oauth2_sessions os\n USING (oauth2_session_id)\n LEFT JOIN user_sessions us\n USING (user_session_id)\n LEFT JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE og.oauth2_authorization_grant_id = $1\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " - }, "7756a60c36a64a259f7450d6eb77ee92303638ca374a63f23ac4944ccf9f4436": { "describe": { "columns": [ @@ -1748,185 +1567,6 @@ }, "query": "\n UPDATE upstream_oauth_links\n SET user_id = $1\n WHERE upstream_oauth_link_id = $2\n " }, - "7cf5ae665b15ba78b01bb1dfa304150a89fd7203f4ee15b0753cb2143049a3dc": { - "describe": { - "columns": [ - { - "name": "oauth2_refresh_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_refresh_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "oauth2_refresh_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_access_token_id?", - "ordinal": 3, - "type_info": "Uuid" - }, - { - "name": "oauth2_access_token?", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "oauth2_access_token_created_at?", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_access_token_expires_at?", - "ordinal": 6, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_session_id!", - "ordinal": 7, - "type_info": "Uuid" - }, - { - "name": "oauth2_client_id!", - "ordinal": 8, - "type_info": "Uuid" - }, - { - "name": "oauth2_session_scope!", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "user_session_id!", - "ordinal": 10, - "type_info": "Uuid" - }, - { - "name": "user_session_created_at!", - "ordinal": 11, - "type_info": "Timestamptz" - }, - { - "name": "user_id!", - "ordinal": 12, - "type_info": "Uuid" - }, - { - "name": "user_username!", - "ordinal": 13, - "type_info": "Text" - }, - { - "name": "user_session_last_authentication_id?", - "ordinal": 14, - "type_info": "Uuid" - }, - { - "name": "user_session_last_authentication_created_at?", - "ordinal": 15, - "type_info": "Timestamptz" - }, - { - "name": "user_email_id?", - "ordinal": 16, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 17, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 18, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 19, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT\n rt.oauth2_refresh_token_id,\n rt.refresh_token AS oauth2_refresh_token,\n rt.created_at AS oauth2_refresh_token_created_at,\n at.oauth2_access_token_id AS \"oauth2_access_token_id?\",\n at.access_token AS \"oauth2_access_token?\",\n at.created_at AS \"oauth2_access_token_created_at?\",\n at.expires_at AS \"oauth2_access_token_expires_at?\",\n os.oauth2_session_id AS \"oauth2_session_id!\",\n os.oauth2_client_id AS \"oauth2_client_id!\",\n os.scope AS \"oauth2_session_scope!\",\n us.user_session_id AS \"user_session_id!\",\n us.created_at AS \"user_session_created_at!\",\n u.user_id AS \"user_id!\",\n u.username AS \"user_username!\",\n usa.user_session_authentication_id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n LEFT JOIN oauth2_access_tokens at\n USING (oauth2_access_token_id)\n INNER JOIN user_sessions us\n USING (user_session_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE rt.refresh_token = $1\n AND rt.consumed_at IS NULL\n AND rt.revoked_at IS NULL\n AND us.finished_at IS NULL\n AND os.finished_at IS NULL\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " - }, - "7d600dd15e9dac72c8071c854799fc2ac69777ade5e2d7d2d944b0dedf8ecdf8": { - "describe": { - "columns": [ - { - "name": "user_email_confirmation_code_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "code", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "consumed_at", - "ordinal": 4, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text", - "Uuid" - ] - } - }, - "query": "\n SELECT\n ec.user_email_confirmation_code_id,\n ec.code,\n ec.created_at,\n ec.expires_at,\n ec.consumed_at\n FROM user_email_confirmation_codes ec\n WHERE ec.code = $1\n AND ec.user_email_id = $2\n " - }, "7e3247e35ecf5335f0656c53bcde27264a9efb8dccb6246344950614f487dcaf": { "describe": { "columns": [], @@ -1940,17 +1580,43 @@ }, "query": "\n UPDATE compat_access_tokens\n SET expires_at = $2\n WHERE compat_access_token_id = $1\n " }, - "819d6472e5bcbd83a83f3a7680e8dc88e77f3970d6beddcf54e8416c880bd496": { + "836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c": { "describe": { - "columns": [], - "nullable": [], + "columns": [ + { + "name": "user_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "username", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "primary_user_email_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false + ], "parameters": { "Left": [ - "Uuid" + "Text" ] } }, - "query": "\n UPDATE users\n SET primary_user_email_id = user_emails.user_email_id\n FROM user_emails\n WHERE user_emails.user_email_id = $1\n AND users.user_id = user_emails.user_id\n " + "query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n FROM users\n WHERE username = $1\n " }, "874e677f82c221c5bb621c12f293bcef4e70c68c87ec003fcd475bcb994b5a4c": { "describe": { @@ -1965,26 +1631,6 @@ }, "query": "\n UPDATE oauth2_refresh_tokens\n SET consumed_at = $2\n WHERE oauth2_refresh_token_id = $1\n " }, - "89e0d338348588831a7a810763a1901073f7a7cb81d51c18bb987a5be10c1202": { - "describe": { - "columns": [ - { - "name": "count", - "ordinal": 0, - "type_info": "Int8" - } - ], - "nullable": [ - null - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT COUNT(*)\n FROM user_emails ue\n WHERE ue.user_id = $1\n " - }, "8f7a9fb1f24c24f8dbc3c193df2a742c9ac730ab958587b67297de2d4b843863": { "describe": { "columns": [ @@ -2047,6 +1693,159 @@ }, "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n " }, + "90b5512c0c9dc3b3eb6500056cc72f9993216d9b553c2e33a7edec26ffb0fc59": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE user_emails\n SET confirmed_at = $2\n WHERE user_email_id = $1\n " + }, + "90fe32cb9c88a262a682c0db700fef7d69d6ce0be1f930d9f16c50b921a8b819": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO user_emails (user_email_id, user_id, email, created_at)\n VALUES ($1, $2, $3, $4)\n " + }, + "921d77c194609615a7e9a6fd806e9cc17a7927e3e5deb58f3917ceeb9ab4dede": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE user_email_confirmation_codes\n SET consumed_at = $2\n WHERE user_email_confirmation_code_id = $1\n " + }, + "94fd96446b237c87bd6bf741f3c42b37ee751b87b7fcc459602bdf8c46962443": { + "describe": { + "columns": [ + { + "name": "exists!", + "ordinal": 0, + "type_info": "Bool" + } + ], + "nullable": [ + null + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT EXISTS(\n SELECT 1 FROM users WHERE username = $1\n ) AS \"exists!\"\n " + }, + "976ac2435784128eab195c8e6b9bd6e8d7b3a9142c2a34538de03817a3c94e99": { + "describe": { + "columns": [ + { + "name": "compat_sso_login_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "compat_sso_login_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "compat_sso_login_redirect_uri", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "compat_sso_login_created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_sso_login_fulfilled_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "compat_sso_login_exchanged_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id?", + "ordinal": 6, + "type_info": "Uuid" + }, + { + "name": "compat_session_created_at?", + "ordinal": 7, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_finished_at?", + "ordinal": 8, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_device_id?", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "user_id?", + "ordinal": 10, + "type_info": "Uuid" + }, + { + "name": "user_username?", + "ordinal": 11, + "type_info": "Text" + }, + { + "name": "user_primary_user_email_id?", + "ordinal": 12, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + true, + false, + false, + true, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT\n cl.compat_sso_login_id,\n cl.login_token AS \"compat_sso_login_token\",\n cl.redirect_uri AS \"compat_sso_login_redirect_uri\",\n cl.created_at AS \"compat_sso_login_created_at\",\n cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\",\n cl.exchanged_at AS \"compat_sso_login_exchanged_at\",\n cs.compat_session_id AS \"compat_session_id?\",\n cs.created_at AS \"compat_session_created_at?\",\n cs.finished_at AS \"compat_session_finished_at?\",\n cs.device_id AS \"compat_session_device_id?\",\n u.user_id AS \"user_id?\",\n u.username AS \"user_username?\",\n u.primary_user_email_id AS \"user_primary_user_email_id?\"\n FROM compat_sso_logins cl\n LEFT JOIN compat_sessions cs\n USING (compat_session_id)\n LEFT JOIN users u\n USING (user_id)\n WHERE cl.compat_sso_login_id = $1\n " + }, "99f5f9eb0adc5ec120ed8194cbf6a8545155bef09e6d94d92fb67fd1b14d4f28": { "describe": { "columns": [], @@ -2137,6 +1936,50 @@ }, "query": "\n SELECT up.user_password_id\n , up.hashed_password\n , up.version\n , up.upgraded_from_id\n , up.created_at\n FROM user_passwords up\n WHERE up.user_id = $1\n ORDER BY up.created_at DESC\n LIMIT 1\n " }, + "a300fe99c95679c5664646a6a525c0491829e97db45f3234483872ed38436322": { + "describe": { + "columns": [ + { + "name": "user_email_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "user_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "email", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "confirmed_at", + "ordinal": 4, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT user_email_id\n , user_id\n , email\n , created_at\n , confirmed_at\n FROM user_emails\n\n WHERE user_id = $1\n\n ORDER BY email ASC\n " + }, "a5a7dad633396e087239d5629092e4a305908ffce9c2610db07372f719070546": { "describe": { "columns": [], @@ -2149,137 +1992,7 @@ }, "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n requires_consent = 'f'\n WHERE\n og.oauth2_authorization_grant_id = $1\n " }, - "a8117b4dd167167b477fb4ebda52789e376defbdc67f3d9093aa06308b2f856e": { - "describe": { - "columns": [ - { - "name": "compat_sso_login_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "compat_sso_login_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_sso_login_redirect_uri", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "compat_sso_login_created_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_sso_login_fulfilled_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "compat_sso_login_exchanged_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id?", - "ordinal": 6, - "type_info": "Uuid" - }, - { - "name": "compat_session_created_at?", - "ordinal": 7, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_finished_at?", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id?", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "user_id?", - "ordinal": 10, - "type_info": "Uuid" - }, - { - "name": "user_username?", - "ordinal": 11, - "type_info": "Text" - }, - { - "name": "user_email_id?", - "ordinal": 12, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 13, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 14, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 15, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - true, - true, - false, - false, - true, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT\n cl.compat_sso_login_id,\n cl.login_token AS \"compat_sso_login_token\",\n cl.redirect_uri AS \"compat_sso_login_redirect_uri\",\n cl.created_at AS \"compat_sso_login_created_at\",\n cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\",\n cl.exchanged_at AS \"compat_sso_login_exchanged_at\",\n cs.compat_session_id AS \"compat_session_id?\",\n cs.created_at AS \"compat_session_created_at?\",\n cs.finished_at AS \"compat_session_finished_at?\",\n cs.device_id AS \"compat_session_device_id?\",\n u.user_id AS \"user_id?\",\n u.username AS \"user_username?\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n FROM compat_sso_logins cl\n LEFT JOIN compat_sessions cs\n USING (compat_session_id)\n LEFT JOIN users u\n USING (user_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n WHERE cl.login_token = $1\n " - }, - "af77bad7259175464c5ad57f9662571c17b29552ebb70e4b6022584b41bdff0d": { - "describe": { - "columns": [ - { - "name": "exists!", - "ordinal": 0, - "type_info": "Bool" - } - ], - "nullable": [ - null - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT EXISTS(\n SELECT 1 FROM users WHERE username = $1\n ) AS \"exists!\"\n " - }, - "b5b955169ebe6c399e53b74627c11c8219c0736ef2b5b6b44be568a35fd5389f": { + "aff08a8caabeb62f4929e6e901e7ca7c55e284c18c5c1d1e78821dd9bc961412": { "describe": { "columns": [ { @@ -2288,18 +2001,23 @@ "type_info": "Uuid" }, { - "name": "user_email", + "name": "user_id", "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "email", + "ordinal": 2, "type_info": "Text" }, { - "name": "user_email_created_at", - "ordinal": 2, + "name": "created_at", + "ordinal": 3, "type_info": "Timestamptz" }, { - "name": "user_email_confirmed_at", - "ordinal": 3, + "name": "confirmed_at", + "ordinal": 4, "type_info": "Timestamptz" } ], @@ -2307,16 +2025,47 @@ false, false, false, + false, true ], "parameters": { "Left": [ "Uuid", - "Uuid" + "Text" ] } }, - "query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n AND ue.user_email_id = $2\n " + "query": "\n SELECT user_email_id\n , user_id\n , email\n , created_at\n , confirmed_at\n FROM user_emails\n\n WHERE user_id = $1 AND email = $2\n " + }, + "b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO users (user_id, username, created_at)\n VALUES ($1, $2, $3)\n " + }, + "b515bbfb331e46acd3c0219f09223cc5d8d31cb41287e693dcb82c6e199f7991": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO user_email_confirmation_codes\n (user_email_confirmation_code_id, user_email_id, code, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n " }, "b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64": { "describe": { @@ -2348,6 +2097,18 @@ }, "query": "\n INSERT INTO oauth2_sessions\n (oauth2_session_id, user_session_id, oauth2_client_id, scope, created_at)\n SELECT\n $1,\n $2,\n og.oauth2_client_id,\n og.scope,\n $3\n FROM\n oauth2_authorization_grants og\n WHERE\n og.oauth2_authorization_grant_id = $4\n " }, + "bd1f6daa5fa1b10250c01f8b3fbe451646a9ceeefa6f72b9c4e29b6d05f17641": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n UPDATE users\n SET primary_user_email_id = user_emails.user_email_id\n FROM user_emails\n WHERE user_emails.user_email_id = $1\n AND users.user_id = user_emails.user_id\n " + }, "bd7a4a008851f3f6d7591e3463e4369cee08820af57dcd3faf95f8e9be82857d": { "describe": { "columns": [], @@ -2365,122 +2126,6 @@ }, "query": "\n INSERT INTO user_passwords\n (user_password_id, user_id, hashed_password, version, upgraded_from_id, created_at)\n VALUES ($1, $2, $3, $4, $5, $6)\n " }, - "c52c911bf39ada298bfdc4526028f1b29fdcb6f557b288bb7ea2472b160c8698": { - "describe": { - "columns": [ - { - "name": "compat_refresh_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "compat_refresh_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_refresh_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "compat_access_token_id", - "ordinal": 3, - "type_info": "Uuid" - }, - { - "name": "compat_access_token", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "compat_access_token_created_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_access_token_expires_at", - "ordinal": 6, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id", - "ordinal": 7, - "type_info": "Uuid" - }, - { - "name": "compat_session_created_at", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_finished_at", - "ordinal": 9, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id", - "ordinal": 10, - "type_info": "Text" - }, - { - "name": "user_id", - "ordinal": 11, - "type_info": "Uuid" - }, - { - "name": "user_username!", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "user_email_id?", - "ordinal": 13, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 14, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 15, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 16, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - false, - true, - false, - false, - true, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT\n cr.compat_refresh_token_id,\n cr.refresh_token AS \"compat_refresh_token\",\n cr.created_at AS \"compat_refresh_token_created_at\",\n ct.compat_access_token_id,\n ct.access_token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n ct.expires_at AS \"compat_access_token_expires_at\",\n cs.compat_session_id,\n cs.created_at AS \"compat_session_created_at\",\n cs.finished_at AS \"compat_session_finished_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.user_id,\n u.username AS \"user_username!\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM compat_refresh_tokens cr\n INNER JOIN compat_sessions cs\n USING (compat_session_id)\n INNER JOIN compat_access_tokens ct\n USING (compat_access_token_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE cr.refresh_token = $1\n AND cr.consumed_at IS NULL\n AND cs.finished_at IS NULL\n " - }, "c88376abdba124ff0487a9a69d2345c7d69d7394f355111ec369cfa6d45fb40f": { "describe": { "columns": [], @@ -2559,6 +2204,142 @@ }, "query": "\n INSERT INTO oauth2_clients\n (oauth2_client_id,\n encrypted_client_secret,\n grant_type_authorization_code,\n grant_type_refresh_token,\n client_name,\n logo_uri,\n client_uri,\n policy_uri,\n tos_uri,\n jwks_uri,\n jwks,\n id_token_signed_response_alg,\n userinfo_signed_response_alg,\n token_endpoint_auth_method,\n token_endpoint_auth_signing_alg,\n initiate_login_uri)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)\n " }, + "d023d7346ec1f32da9459db3c39dffd8a4e3d4e91cdf096928de4517d3f8c622": { + "describe": { + "columns": [ + { + "name": "oauth2_refresh_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "oauth2_refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "oauth2_refresh_token_created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_access_token_id?", + "ordinal": 3, + "type_info": "Uuid" + }, + { + "name": "oauth2_access_token?", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "oauth2_access_token_created_at?", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_access_token_expires_at?", + "ordinal": 6, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_session_id!", + "ordinal": 7, + "type_info": "Uuid" + }, + { + "name": "oauth2_client_id!", + "ordinal": 8, + "type_info": "Uuid" + }, + { + "name": "oauth2_session_scope!", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "user_session_id!", + "ordinal": 10, + "type_info": "Uuid" + }, + { + "name": "user_session_created_at!", + "ordinal": 11, + "type_info": "Timestamptz" + }, + { + "name": "user_id!", + "ordinal": 12, + "type_info": "Uuid" + }, + { + "name": "user_username!", + "ordinal": 13, + "type_info": "Text" + }, + { + "name": "user_primary_user_email_id", + "ordinal": 14, + "type_info": "Uuid" + }, + { + "name": "user_session_last_authentication_id?", + "ordinal": 15, + "type_info": "Uuid" + }, + { + "name": "user_session_last_authentication_created_at?", + "ordinal": 16, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + false, + false, + false, + false, + false, + false, + false, + false, + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT\n rt.oauth2_refresh_token_id,\n rt.refresh_token AS oauth2_refresh_token,\n rt.created_at AS oauth2_refresh_token_created_at,\n at.oauth2_access_token_id AS \"oauth2_access_token_id?\",\n at.access_token AS \"oauth2_access_token?\",\n at.created_at AS \"oauth2_access_token_created_at?\",\n at.expires_at AS \"oauth2_access_token_expires_at?\",\n os.oauth2_session_id AS \"oauth2_session_id!\",\n os.oauth2_client_id AS \"oauth2_client_id!\",\n os.scope AS \"oauth2_session_scope!\",\n us.user_session_id AS \"user_session_id!\",\n us.created_at AS \"user_session_created_at!\",\n u.user_id AS \"user_id!\",\n u.username AS \"user_username!\",\n u.primary_user_email_id AS \"user_primary_user_email_id\",\n usa.user_session_authentication_id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n LEFT JOIN oauth2_access_tokens at\n USING (oauth2_access_token_id)\n INNER JOIN user_sessions us\n USING (user_session_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE rt.refresh_token = $1\n AND rt.consumed_at IS NULL\n AND rt.revoked_at IS NULL\n AND us.finished_at IS NULL\n AND os.finished_at IS NULL\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " + }, + "d12a513b81b3ef658eae1f0a719933323f28c6ee260b52cafe337dd3d19e865c": { + "describe": { + "columns": [ + { + "name": "count", + "ordinal": 0, + "type_info": "Int8" + } + ], + "nullable": [ + null + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT COUNT(*)\n FROM user_emails\n WHERE user_id = $1\n " + }, "d1738c27339b81f0844da4bd9b040b9b07a91aa4d9b199b98f24c9cee5709b2b": { "describe": { "columns": [], @@ -2574,19 +2355,6 @@ }, "query": "\n INSERT INTO compat_sso_logins\n (compat_sso_login_id, login_token, redirect_uri, created_at)\n VALUES ($1, $2, $3, $4)\n " }, - "d55a321e8935f4effda29d9620a0f622125cb38472785049ee21c2616a6bd068": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE user_email_confirmation_codes\n SET consumed_at = $2\n WHERE user_email_confirmation_code_id = $1\n " - }, "d8677b3b6ee594c230fad98c1aa1c6e3d983375bf5b701c7b52468e7f906abf9": { "describe": { "columns": [], @@ -2603,18 +2371,6 @@ }, "query": "\n INSERT INTO oauth2_refresh_tokens\n (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,\n refresh_token, created_at)\n VALUES\n ($1, $2, $3, $4, $5)\n " }, - "e16ac9f75be25ef6873f1851e916df3ea730422409decc0344f7f05ce3c3841f": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n DELETE FROM user_emails\n WHERE user_emails.user_email_id = $1\n " - }, "e446e37d48c8838ef2e0d0fd82f8f7b04893c84ad46747cdf193ebd83755ceb2": { "describe": { "columns": [], @@ -2673,5 +2429,86 @@ } }, "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " + }, + "f624e1bdbff4e97b300362d1bbd86035e4a0fdd8ffe16c3bfb9bc451ba60851b": { + "describe": { + "columns": [ + { + "name": "compat_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "compat_access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "compat_access_token_created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "compat_access_token_expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + }, + { + "name": "compat_session_created_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_finished_at", + "ordinal": 6, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_device_id", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "user_id!", + "ordinal": 8, + "type_info": "Uuid" + }, + { + "name": "user_username!", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "user_primary_user_email_id", + "ordinal": 10, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false, + false, + true, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Text", + "Timestamptz" + ] + } + }, + "query": "\n SELECT\n ct.compat_access_token_id,\n ct.access_token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n ct.expires_at AS \"compat_access_token_expires_at\",\n cs.compat_session_id,\n cs.created_at AS \"compat_session_created_at\",\n cs.finished_at AS \"compat_session_finished_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.user_id AS \"user_id!\",\n u.username AS \"user_username!\",\n u.primary_user_email_id AS \"user_primary_user_email_id\"\n\n FROM compat_access_tokens ct\n INNER JOIN compat_sessions cs\n USING (compat_session_id)\n INNER JOIN users u\n USING (user_id)\n\n WHERE ct.access_token = $1\n AND (ct.expires_at < $2 OR ct.expires_at IS NULL)\n AND cs.finished_at IS NULL \n " } } \ No newline at end of file diff --git a/crates/storage/src/compat.rs b/crates/storage/src/compat.rs index 9737fcb9c..4708e91bd 100644 --- a/crates/storage/src/compat.rs +++ b/crates/storage/src/compat.rs @@ -15,7 +15,7 @@ use chrono::{DateTime, Duration, Utc}; use mas_data_model::{ CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState, - Device, User, UserEmail, + Device, User, }; use rand::Rng; use sqlx::{Acquire, PgExecutor, Postgres, QueryBuilder}; @@ -40,10 +40,7 @@ struct CompatAccessTokenLookup { compat_session_device_id: String, user_id: Uuid, user_username: String, - user_email_id: Option, - user_email: Option, - user_email_created_at: Option>, - user_email_confirmed_at: Option>, + user_primary_user_email_id: Option, } #[tracing::instrument(skip_all, err)] @@ -66,18 +63,13 @@ pub async fn lookup_active_compat_access_token( cs.device_id AS "compat_session_device_id", u.user_id AS "user_id!", u.username AS "user_username!", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" + u.primary_user_email_id AS "user_primary_user_email_id" FROM compat_access_tokens ct INNER JOIN compat_sessions cs USING (compat_session_id) INNER JOIN users u USING (user_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id WHERE ct.access_token = $1 AND (ct.expires_at < $2 OR ct.expires_at IS NULL) @@ -101,32 +93,11 @@ pub async fn lookup_active_compat_access_token( }; let user_id = Ulid::from(res.user_id); - let primary_email = match ( - res.user_email_id, - res.user_email, - res.user_email_created_at, - res.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("compat_sessions") - .column("user_id") - .row(user_id) - .into()) - } - }; - let user = User { id: user_id, username: res.user_username, sub: user_id.to_string(), - primary_email, + primary_user_email_id: res.user_primary_user_email_id.map(Into::into), }; let id = res.compat_session_id.into(); @@ -162,10 +133,7 @@ pub struct CompatRefreshTokenLookup { compat_session_device_id: String, user_id: Uuid, user_username: String, - user_email_id: Option, - user_email: Option, - user_email_created_at: Option>, - user_email_confirmed_at: Option>, + user_primary_user_email_id: Option, } #[tracing::instrument(skip_all, err)] @@ -191,10 +159,7 @@ pub async fn lookup_active_compat_refresh_token( cs.device_id AS "compat_session_device_id", u.user_id, u.username AS "user_username!", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" + u.primary_user_email_id AS "user_primary_user_email_id" FROM compat_refresh_tokens cr INNER JOIN compat_sessions cs @@ -203,8 +168,6 @@ pub async fn lookup_active_compat_refresh_token( USING (compat_access_token_id) INNER JOIN users u USING (user_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id WHERE cr.refresh_token = $1 AND cr.consumed_at IS NULL @@ -233,32 +196,11 @@ pub async fn lookup_active_compat_refresh_token( }; let user_id = Ulid::from(res.user_id); - let primary_email = match ( - res.user_email_id, - res.user_email, - res.user_email_created_at, - res.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("users") - .column("primary_user_email_id") - .row(user_id) - .into()) - } - }; - let user = User { id: user_id, username: res.user_username, sub: user_id.to_string(), - primary_email, + primary_user_email_id: res.user_primary_user_email_id.map(Into::into), }; let session_id = res.compat_session_id.into(); @@ -528,10 +470,7 @@ struct CompatSsoLoginLookup { compat_session_device_id: Option, user_id: Option, user_username: Option, - user_email_id: Option, - user_email: Option, - user_email_created_at: Option>, - user_email_confirmed_at: Option>, + user_primary_user_email_id: Option, } impl TryFrom for CompatSsoLogin { @@ -546,32 +485,18 @@ impl TryFrom for CompatSsoLogin { .source(e) })?; - let primary_email = match ( - res.user_email_id, - res.user_email, - res.user_email_created_at, - res.user_email_confirmed_at, + let user = match ( + res.user_id, + res.user_username, + res.user_primary_user_email_id, ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("users").column("primary_user_email_id")) - } - }; - - let user = match (res.user_id, res.user_username, primary_email) { - (Some(id), Some(username), primary_email) => { + (Some(id), Some(username), primary_email_id) => { let id = Ulid::from(id); Some(User { id, username, sub: id.to_string(), - primary_email, + primary_user_email_id: primary_email_id.map(Into::into), }) } @@ -667,17 +592,12 @@ pub async fn get_compat_sso_login_by_id( cs.device_id AS "compat_session_device_id?", u.user_id AS "user_id?", u.username AS "user_username?", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" + u.primary_user_email_id AS "user_primary_user_email_id?" FROM compat_sso_logins cl LEFT JOIN compat_sessions cs USING (compat_session_id) LEFT JOIN users u USING (user_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id WHERE cl.compat_sso_login_id = $1 "#, Uuid::from(id), @@ -725,17 +645,12 @@ pub async fn get_paginated_user_compat_sso_logins( cs.device_id AS "compat_session_device_id", u.user_id AS "user_id", u.username AS "user_username", - ue.user_email_id AS "user_email_id", - ue.email AS "user_email", - ue.created_at AS "user_email_created_at", - ue.confirmed_at AS "user_email_confirmed_at" + u.primary_user_email_id AS "user_primary_user_email_id?" FROM compat_sso_logins cl LEFT JOIN compat_sessions cs USING (compat_session_id) LEFT JOIN users u USING (user_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id "#, ); @@ -781,17 +696,12 @@ pub async fn get_compat_sso_login_by_token( cs.device_id AS "compat_session_device_id?", u.user_id AS "user_id?", u.username AS "user_username?", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" + u.primary_user_email_id AS "user_primary_user_email_id?" FROM compat_sso_logins cl LEFT JOIN compat_sessions cs USING (compat_session_id) LEFT JOIN users u USING (user_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id WHERE cl.login_token = $1 "#, token, diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index 268652016..bda7ee2ce 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -161,7 +161,7 @@ impl DatabaseInconsistencyError { } } -#[derive(Default, Debug, Clone, Copy)] +#[derive(Default, Debug, Clone)] pub struct Clock { _private: (), } diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index be9831426..71e014e43 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -13,7 +13,7 @@ // limitations under the License. use chrono::{DateTime, Duration, Utc}; -use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail}; +use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User}; use rand::Rng; use sqlx::{PgConnection, PgExecutor}; use ulid::Ulid; @@ -84,12 +84,9 @@ pub struct OAuth2AccessTokenLookup { user_session_created_at: DateTime, user_id: Uuid, user_username: String, + user_primary_user_email_id: Option, user_session_last_authentication_id: Option, user_session_last_authentication_created_at: Option>, - user_email_id: Option, - user_email: Option, - user_email_created_at: Option>, - user_email_confirmed_at: Option>, } #[allow(clippy::too_many_lines)] @@ -100,24 +97,20 @@ pub async fn lookup_active_access_token( let res = sqlx::query_as!( OAuth2AccessTokenLookup, r#" - SELECT - at.oauth2_access_token_id, - at.access_token AS "oauth2_access_token", - at.created_at AS "oauth2_access_token_created_at", - at.expires_at AS "oauth2_access_token_expires_at", - os.oauth2_session_id AS "oauth2_session_id!", - os.oauth2_client_id AS "oauth2_client_id!", - os.scope AS "scope!", - us.user_session_id AS "user_session_id!", - us.created_at AS "user_session_created_at!", - u.user_id AS "user_id!", - u.username AS "user_username!", - usa.user_session_authentication_id AS "user_session_last_authentication_id?", - usa.created_at AS "user_session_last_authentication_created_at?", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" + SELECT at.oauth2_access_token_id + , at.access_token AS "oauth2_access_token" + , at.created_at AS "oauth2_access_token_created_at" + , at.expires_at AS "oauth2_access_token_expires_at" + , os.oauth2_session_id AS "oauth2_session_id!" + , os.oauth2_client_id AS "oauth2_client_id!" + , os.scope AS "scope!" + , us.user_session_id AS "user_session_id!" + , us.created_at AS "user_session_created_at!" + , u.user_id AS "user_id!" + , u.username AS "user_username!" + , u.primary_user_email_id AS "user_primary_user_email_id" + , usa.user_session_authentication_id AS "user_session_last_authentication_id?" + , usa.created_at AS "user_session_last_authentication_created_at?" FROM oauth2_access_tokens at INNER JOIN oauth2_sessions os @@ -128,8 +121,6 @@ pub async fn lookup_active_access_token( USING (user_id) LEFT JOIN user_session_authentications usa USING (user_session_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id WHERE at.access_token = $1 AND at.revoked_at IS NULL @@ -162,32 +153,11 @@ pub async fn lookup_active_access_token( })?; let user_id = Ulid::from(res.user_id); - let primary_email = match ( - res.user_email_id, - res.user_email, - res.user_email_created_at, - res.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("users") - .column("primary_user_email_id") - .row(user_id) - .into()) - } - }; - let user = User { id: user_id, username: res.user_username, sub: user_id.to_string(), - primary_email, + primary_user_email_id: res.user_primary_user_email_id.map(Into::into), }; let last_authentication = match ( diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index e00274cc0..957400d9e 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -17,7 +17,7 @@ use std::num::NonZeroU32; use chrono::{DateTime, Utc}; use mas_data_model::{ Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession, - Client, Pkce, Session, User, UserEmail, + Client, Pkce, Session, User, }; use mas_iana::oauth::PkceCodeChallengeMethod; use oauth2_types::{requests::ResponseMode, scope::Scope}; @@ -154,12 +154,9 @@ struct GrantLookup { user_session_created_at: Option>, user_id: Option, user_username: Option, + user_primary_user_email_id: Option, user_session_last_authentication_id: Option, user_session_last_authentication_created_at: Option>, - user_email_id: Option, - user_email: Option, - user_email_created_at: Option>, - user_email_confirmed_at: Option>, } impl GrantLookup { @@ -197,34 +194,14 @@ impl GrantLookup { _ => return Err(DatabaseInconsistencyError::on("user_session_authentications").into()), }; - let primary_email = match ( - self.user_email_id, - self.user_email, - self.user_email_created_at, - self.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("users") - .column("primary_user_email_id") - .into()) - } - }; - let session = match ( self.oauth2_session_id, self.user_session_id, self.user_session_created_at, self.user_id, self.user_username, + self.user_primary_user_email_id, last_authentication, - primary_email, ) { ( Some(session_id), @@ -232,15 +209,15 @@ impl GrantLookup { Some(user_session_created_at), Some(user_id), Some(user_username), + user_primary_user_email_id, last_authentication, - primary_email, ) => { let user_id = Ulid::from(user_id); let user = User { id: user_id, username: user_username, sub: user_id.to_string(), - primary_email, + primary_user_email_id: user_primary_user_email_id.map(Into::into), }; let browser_session = BrowserSession { @@ -439,12 +416,9 @@ pub async fn get_grant_by_id( us.created_at AS "user_session_created_at?", u.user_id AS "user_id?", u.username AS "user_username?", + u.primary_user_email_id AS "user_primary_user_email_id?", usa.user_session_authentication_id AS "user_session_last_authentication_id?", - usa.created_at AS "user_session_last_authentication_created_at?", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" + usa.created_at AS "user_session_last_authentication_created_at?" FROM oauth2_authorization_grants og LEFT JOIN oauth2_sessions os @@ -455,8 +429,6 @@ pub async fn get_grant_by_id( USING (user_id) LEFT JOIN user_session_authentications usa USING (user_session_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id WHERE og.oauth2_authorization_grant_id = $1 @@ -508,12 +480,9 @@ pub async fn lookup_grant_by_code( us.created_at AS "user_session_created_at?", u.user_id AS "user_id?", u.username AS "user_username?", + u.primary_user_email_id AS "user_primary_user_email_id?", usa.user_session_authentication_id AS "user_session_last_authentication_id?", - usa.created_at AS "user_session_last_authentication_created_at?", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" + usa.created_at AS "user_session_last_authentication_created_at?" FROM oauth2_authorization_grants og LEFT JOIN oauth2_sessions os @@ -524,8 +493,6 @@ pub async fn lookup_grant_by_code( USING (user_id) LEFT JOIN user_session_authentications usa USING (user_session_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id WHERE og.authorization_code = $1 diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 74e111c9e..79d90c010 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -13,9 +13,7 @@ // limitations under the License. use chrono::{DateTime, Utc}; -use mas_data_model::{ - AccessToken, Authentication, BrowserSession, RefreshToken, Session, User, UserEmail, -}; +use mas_data_model::{AccessToken, Authentication, BrowserSession, RefreshToken, Session, User}; use rand::Rng; use sqlx::{PgConnection, PgExecutor}; use ulid::Ulid; @@ -87,12 +85,9 @@ struct OAuth2RefreshTokenLookup { user_session_created_at: DateTime, user_id: Uuid, user_username: String, + user_primary_user_email_id: Option, user_session_last_authentication_id: Option, user_session_last_authentication_created_at: Option>, - user_email_id: Option, - user_email: Option, - user_email_created_at: Option>, - user_email_confirmed_at: Option>, } #[tracing::instrument(skip_all, err)] @@ -119,12 +114,9 @@ pub async fn lookup_active_refresh_token( us.created_at AS "user_session_created_at!", u.user_id AS "user_id!", u.username AS "user_username!", + u.primary_user_email_id AS "user_primary_user_email_id", usa.user_session_authentication_id AS "user_session_last_authentication_id?", - usa.created_at AS "user_session_last_authentication_created_at?", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" + usa.created_at AS "user_session_last_authentication_created_at?" FROM oauth2_refresh_tokens rt INNER JOIN oauth2_sessions os USING (oauth2_session_id) @@ -190,32 +182,11 @@ pub async fn lookup_active_refresh_token( })?; let user_id = Ulid::from(res.user_id); - let primary_email = match ( - res.user_email_id, - res.user_email, - res.user_email_created_at, - res.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("users") - .column("primary_user_email_id") - .row(user_id) - .into()) - } - }; - let user = User { id: user_id, username: res.user_username, sub: user_id.to_string(), - primary_email, + primary_user_email_id: res.user_primary_user_email_id.map(Into::into), }; let last_authentication = match ( diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 9e6ca8074..f321e9417 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -14,9 +14,12 @@ use sqlx::{PgConnection, Postgres, Transaction}; -use crate::upstream_oauth2::{ - PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, - PgUpstreamOAuthSessionRepository, +use crate::{ + upstream_oauth2::{ + PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, + PgUpstreamOAuthSessionRepository, + }, + user::{PgUserEmailRepository, PgUserRepository}, }; pub trait Repository { @@ -32,15 +35,27 @@ pub trait Repository { where Self: 'c; + type UserRepository<'c> + where + Self: 'c; + + type UserEmailRepository<'c> + where + Self: 'c; + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>; fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>; + fn user(&mut self) -> Self::UserRepository<'_>; + fn user_email(&mut self) -> Self::UserEmailRepository<'_>; } impl Repository for PgConnection { type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; + type UserRepository<'c> = PgUserRepository<'c> where Self: 'c; + type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -53,12 +68,22 @@ impl Repository for PgConnection { fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> { PgUpstreamOAuthSessionRepository::new(self) } + + fn user(&mut self) -> Self::UserRepository<'_> { + PgUserRepository::new(self) + } + + fn user_email(&mut self) -> Self::UserEmailRepository<'_> { + PgUserEmailRepository::new(self) + } } impl<'t> Repository for Transaction<'t, Postgres> { type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; + type UserRepository<'c> = PgUserRepository<'c> where Self: 'c; + type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -71,4 +96,12 @@ impl<'t> Repository for Transaction<'t, Postgres> { fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> { PgUpstreamOAuthSessionRepository::new(self) } + + fn user(&mut self) -> Self::UserRepository<'_> { + PgUserRepository::new(self) + } + + fn user_email(&mut self) -> Self::UserEmailRepository<'_> { + PgUserEmailRepository::new(self) + } } diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs new file mode 100644 index 000000000..83784a56e --- /dev/null +++ b/crates/storage/src/user/email.rs @@ -0,0 +1,555 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{User, UserEmail, UserEmailVerification, UserEmailVerificationState}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + pagination::{process_page, Page, QueryBuilderExt}, + Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, +}; + +#[async_trait] +pub trait UserEmailRepository: Send + Sync { + type Error; + + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + async fn find(&mut self, user: &User, email: &str) -> Result, Self::Error>; + async fn get_primary(&mut self, user: &User) -> Result, Self::Error>; + + async fn all(&mut self, user: &User) -> Result, Self::Error>; + async fn list_paginated( + &mut self, + user: &User, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error>; + async fn count(&mut self, user: &User) -> Result; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + email: String, + ) -> Result; + async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error>; + + async fn mark_as_verified( + &mut self, + clock: &Clock, + user_email: UserEmail, + ) -> Result; + + async fn set_as_primary(&mut self, user_email: &UserEmail) -> Result<(), Self::Error>; + + async fn add_verification_code( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user_email: &UserEmail, + max_age: chrono::Duration, + code: String, + ) -> Result; + + async fn find_verification_code( + &mut self, + clock: &Clock, + user_email: &UserEmail, + code: &str, + ) -> Result, Self::Error>; + + async fn consume_verification_code( + &mut self, + clock: &Clock, + verification: UserEmailVerification, + ) -> Result; +} + +pub struct PgUserEmailRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUserEmailRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(Debug, Clone, sqlx::FromRow)] +struct UserEmailLookup { + user_email_id: Uuid, + user_id: Uuid, + email: String, + created_at: DateTime, + confirmed_at: Option>, +} + +impl From for UserEmail { + fn from(e: UserEmailLookup) -> UserEmail { + UserEmail { + id: e.user_email_id.into(), + user_id: e.user_id.into(), + email: e.email, + created_at: e.created_at, + confirmed_at: e.confirmed_at, + } + } +} + +struct UserEmailConfirmationCodeLookup { + user_email_confirmation_code_id: Uuid, + user_email_id: Uuid, + code: String, + created_at: DateTime, + expires_at: DateTime, + consumed_at: Option>, +} + +impl UserEmailConfirmationCodeLookup { + fn into_verification(self, clock: &Clock) -> UserEmailVerification { + let now = clock.now(); + let state = if let Some(when) = self.consumed_at { + UserEmailVerificationState::AlreadyUsed { when } + } else if self.expires_at < now { + UserEmailVerificationState::Expired { + when: self.expires_at, + } + } else { + UserEmailVerificationState::Valid + }; + + UserEmailVerification { + id: self.user_email_confirmation_code_id.into(), + user_email_id: self.user_email_id.into(), + code: self.code, + state, + created_at: self.created_at, + } + } +} + +#[async_trait] +impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + skip_all, + fields(user_email.id = %id), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserEmailLookup, + r#" + SELECT user_email_id + , user_id + , email + , created_at + , confirmed_at + FROM user_emails + + WHERE user_email_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(user_email) = res else { return Ok(None) }; + + Ok(Some(user_email.into())) + } + + #[tracing::instrument( + skip_all, + fields(%user.id, user_email.email = email), + err, + )] + async fn find(&mut self, user: &User, email: &str) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserEmailLookup, + r#" + SELECT user_email_id + , user_id + , email + , created_at + , confirmed_at + FROM user_emails + + WHERE user_id = $1 AND email = $2 + "#, + Uuid::from(user.id), + email, + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(user_email) = res else { return Ok(None) }; + + Ok(Some(user_email.into())) + } + + #[tracing::instrument( + skip_all, + fields(%user.id), + err, + )] + async fn get_primary(&mut self, user: &User) -> Result, Self::Error> { + let Some(id) = user.primary_user_email_id else { return Ok(None) }; + + let user_email = self.lookup(id).await?.ok_or_else(|| { + DatabaseInconsistencyError::on("users") + .column("primary_user_email_id") + .row(user.id) + })?; + + Ok(Some(user_email)) + } + + #[tracing::instrument( + skip_all, + fields(%user.id), + err, + )] + async fn all(&mut self, user: &User) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserEmailLookup, + r#" + SELECT user_email_id + , user_id + , email + , created_at + , confirmed_at + FROM user_emails + + WHERE user_id = $1 + + ORDER BY email ASC + "#, + Uuid::from(user.id), + ) + .fetch_all(&mut *self.conn) + .await?; + + Ok(res.into_iter().map(Into::into).collect()) + } + + #[tracing::instrument( + skip_all, + fields(%user.id), + err, + )] + async fn list_paginated( + &mut self, + user: &User, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, DatabaseError> { + let mut query = QueryBuilder::new( + r#" + SELECT user_email_id + , user_id + , email + , created_at + , confirmed_at + FROM user_emails + "#, + ); + + query + .push(" WHERE user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("ue.user_email_id", before, after, first, last)?; + + let edges: Vec = query.build_query_as().fetch_all(&mut *self.conn).await?; + + let (has_previous_page, has_next_page, edges) = process_page(edges, first, last)?; + + let edges = edges.into_iter().map(Into::into).collect(); + + Ok(Page { + has_next_page, + has_previous_page, + edges, + }) + } + + #[tracing::instrument( + skip_all, + fields(%user.id), + err, + )] + async fn count(&mut self, user: &User) -> Result { + let res = sqlx::query_scalar!( + r#" + SELECT COUNT(*) + FROM user_emails + WHERE user_id = $1 + "#, + Uuid::from(user.id), + ) + .fetch_one(&mut *self.conn) + .await?; + + let res = res.unwrap_or_default(); + + Ok(res + .try_into() + .map_err(DatabaseError::to_invalid_operation)?) + } + + #[tracing::instrument( + skip_all, + fields( + %user.id, + user_email.id, + user_email.email = email, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + email: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_email.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO user_emails (user_email_id, user_id, email, created_at) + VALUES ($1, $2, $3, $4) + "#, + Uuid::from(id), + Uuid::from(user.id), + &email, + created_at, + ) + .execute(&mut *self.conn) + .await?; + + Ok(UserEmail { + id, + user_id: user.id, + email, + created_at, + confirmed_at: None, + }) + } + + #[tracing::instrument( + skip_all, + fields( + user.id = %user_email.user_id, + %user_email.id, + %user_email.email, + ), + err, + )] + async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> { + sqlx::query!( + r#" + DELETE FROM user_emails + WHERE user_email_id = $1 + "#, + Uuid::from(user_email.id), + ) + .execute(&mut *self.conn) + .await?; + + Ok(()) + } + + async fn mark_as_verified( + &mut self, + clock: &Clock, + mut user_email: UserEmail, + ) -> Result { + let confirmed_at = clock.now(); + sqlx::query!( + r#" + UPDATE user_emails + SET confirmed_at = $2 + WHERE user_email_id = $1 + "#, + Uuid::from(user_email.id), + confirmed_at, + ) + .execute(&mut *self.conn) + .await?; + + user_email.confirmed_at = Some(confirmed_at); + Ok(user_email) + } + + async fn set_as_primary(&mut self, user_email: &UserEmail) -> Result<(), Self::Error> { + sqlx::query!( + r#" + UPDATE users + SET primary_user_email_id = user_emails.user_email_id + FROM user_emails + WHERE user_emails.user_email_id = $1 + AND users.user_id = user_emails.user_id + "#, + Uuid::from(user_email.id), + ) + .execute(&mut *self.conn) + .await?; + + Ok(()) + } + + #[tracing::instrument( + skip_all, + fields( + %user_email.id, + %user_email.email, + user_email_verification.id, + user_email_verification.code = code, + ), + err, + )] + async fn add_verification_code( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user_email: &UserEmail, + max_age: chrono::Duration, + code: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id)); + let expires_at = created_at + max_age; + + sqlx::query!( + r#" + INSERT INTO user_email_confirmation_codes + (user_email_confirmation_code_id, user_email_id, code, created_at, expires_at) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(user_email.id), + code, + created_at, + expires_at, + ) + .execute(&mut *self.conn) + .await?; + + let verification = UserEmailVerification { + id, + user_email_id: user_email.id, + code, + created_at, + state: UserEmailVerificationState::Valid, + }; + + Ok(verification) + } + + #[tracing::instrument( + skip_all, + fields( + %user_email.id, + user.id = %user_email.user_id, + ), + err, + )] + async fn find_verification_code( + &mut self, + clock: &Clock, + user_email: &UserEmail, + code: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserEmailConfirmationCodeLookup, + r#" + SELECT user_email_confirmation_code_id + , user_email_id + , code + , created_at + , expires_at + , consumed_at + FROM user_email_confirmation_codes + WHERE code = $1 + AND user_email_id = $2 + "#, + code, + Uuid::from(user_email.id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into_verification(clock))) + } + + #[tracing::instrument( + skip_all, + fields( + %user_email_verification.id, + user_email.id = %user_email_verification.user_email_id, + ), + err, + )] + async fn consume_verification_code( + &mut self, + clock: &Clock, + mut user_email_verification: UserEmailVerification, + ) -> Result { + if !matches!( + user_email_verification.state, + UserEmailVerificationState::Valid + ) { + return Err(DatabaseError::invalid_operation()); + } + + let consumed_at = clock.now(); + + sqlx::query!( + r#" + UPDATE user_email_confirmation_codes + SET consumed_at = $2 + WHERE user_email_confirmation_code_id = $1 + "#, + Uuid::from(user_email_verification.id), + consumed_at + ) + .execute(&mut *self.conn) + .await?; + + user_email_verification.state = + UserEmailVerificationState::AlreadyUsed { when: consumed_at }; + + Ok(user_email_verification) + } +} diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 1b8c2c61d..54b3689cc 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -12,13 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +use async_trait::async_trait; use chrono::{DateTime, Utc}; -use mas_data_model::{ - Authentication, BrowserSession, User, UserEmail, UserEmailVerification, - UserEmailVerificationState, -}; -use rand::Rng; -use sqlx::{PgExecutor, QueryBuilder}; +use mas_data_model::{Authentication, BrowserSession, User}; +use rand::{Rng, RngCore}; +use sqlx::{PgConnection, PgExecutor, QueryBuilder}; use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; @@ -29,35 +27,188 @@ use crate::{ }; mod authentication; +mod email; mod password; pub use self::{ authentication::{authenticate_session_with_password, authenticate_session_with_upstream}, + email::{PgUserEmailRepository, UserEmailRepository}, password::{add_user_password, lookup_user_password}, }; +#[async_trait] +pub trait UserRepository: Send + Sync { + type Error; + + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + async fn find_by_username(&mut self, username: &str) -> Result, Self::Error>; + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + username: String, + ) -> Result; + async fn exists(&mut self, username: &str) -> Result; +} + +pub struct PgUserRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUserRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + #[derive(Debug, Clone)] struct UserLookup { user_id: Uuid, - user_username: String, - user_email_id: Option, - user_email: Option, - user_email_created_at: Option>, - user_email_confirmed_at: Option>, + username: String, + primary_user_email_id: Option, + + #[allow(dead_code)] + created_at: DateTime, +} + +impl From for User { + fn from(value: UserLookup) -> Self { + let id = value.user_id.into(); + Self { + id, + username: value.username, + sub: id.to_string(), + primary_user_email_id: value.primary_user_email_id.map(Into::into), + } + } +} + +#[async_trait] +impl<'c> UserRepository for PgUserRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + skip_all, + fields(user.id = %id), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserLookup, + r#" + SELECT user_id + , username + , primary_user_email_id + , created_at + FROM users + WHERE user_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + skip_all, + fields(user.username = username), + err, + )] + async fn find_by_username(&mut self, username: &str) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserLookup, + r#" + SELECT user_id + , username + , primary_user_email_id + , created_at + FROM users + WHERE username = $1 + "#, + username, + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + skip_all, + fields( + user.username = username, + user.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + username: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO users (user_id, username, created_at) + VALUES ($1, $2, $3) + "#, + Uuid::from(id), + username, + created_at, + ) + .execute(&mut *self.conn) + .await?; + + Ok(User { + id, + username, + sub: id.to_string(), + primary_user_email_id: None, + }) + } + + #[tracing::instrument( + skip_all, + fields(user.username = username), + err, + )] + async fn exists(&mut self, username: &str) -> Result { + let exists = sqlx::query_scalar!( + r#" + SELECT EXISTS( + SELECT 1 FROM users WHERE username = $1 + ) AS "exists!" + "#, + username + ) + .fetch_one(&mut *self.conn) + .await?; + + Ok(exists) + } } #[derive(sqlx::FromRow)] struct SessionLookup { user_session_id: Uuid, + user_session_created_at: DateTime, user_id: Uuid, - username: String, - created_at: DateTime, + user_username: String, + user_primary_user_email_id: Option, last_authentication_id: Option, last_authd_at: Option>, - user_email_id: Option, - user_email: Option, - user_email_created_at: Option>, - user_email_confirmed_at: Option>, } impl TryInto for SessionLookup { @@ -65,31 +216,11 @@ impl TryInto for SessionLookup { fn try_into(self) -> Result { let id = Ulid::from(self.user_id); - let primary_email = match ( - self.user_email_id, - self.user_email, - self.user_email_created_at, - self.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("users") - .column("primary_user_email_id") - .row(id)) - } - }; - let user = User { id, - username: self.username, + username: self.user_username, sub: id.to_string(), - primary_email, + primary_user_email_id: self.user_primary_user_email_id.map(Into::into), }; let last_authentication = match (self.last_authentication_id, self.last_authd_at) { @@ -108,7 +239,7 @@ impl TryInto for SessionLookup { Ok(BrowserSession { id: self.user_session_id.into(), user, - created_at: self.created_at, + created_at: self.user_session_created_at, last_authentication, }) } @@ -126,24 +257,18 @@ pub async fn lookup_active_session( let res = sqlx::query_as!( SessionLookup, r#" - SELECT - s.user_session_id, - u.user_id, - u.username, - s.created_at, - a.user_session_authentication_id AS "last_authentication_id?", - a.created_at AS "last_authd_at?", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" + SELECT s.user_session_id + , s.created_at AS "user_session_created_at" + , u.user_id + , u.username AS "user_username" + , u.primary_user_email_id AS "user_primary_user_email_id" + , a.user_session_authentication_id AS "last_authentication_id?" + , a.created_at AS "last_authd_at?" FROM user_sessions s INNER JOIN users u USING (user_id) LEFT JOIN user_session_authentications a USING (user_session_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id WHERE s.user_session_id = $1 AND s.finished_at IS NULL ORDER BY a.created_at DESC LIMIT 1 @@ -279,44 +404,6 @@ pub async fn count_active_sessions( Ok(res) } -#[tracing::instrument( - skip_all, - fields( - user.username = username, - user.id, - ), - err, -)] -pub async fn add_user( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - username: &str, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("user.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO users (user_id, username, created_at) - VALUES ($1, $2, $3) - "#, - Uuid::from(id), - username, - created_at, - ) - .execute(executor) - .await?; - - Ok(User { - id, - username: username.to_owned(), - sub: id.to_string(), - primary_email: None, - }) -} - #[tracing::instrument( skip_all, fields(%user_session.id), @@ -343,662 +430,3 @@ pub async fn end_session( DatabaseError::ensure_affected_rows(&res, 1) } - -#[tracing::instrument( - skip_all, - fields(user.username = username), - err, -)] -pub async fn lookup_user_by_username( - executor: impl PgExecutor<'_>, - username: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - UserLookup, - r#" - SELECT - u.user_id, - u.username AS user_username, - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" - FROM users u - - LEFT JOIN user_emails ue - USING (user_id) - - WHERE u.username = $1 - "#, - username, - ) - .fetch_one(executor) - .instrument(info_span!("Fetch user")) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - let id = Ulid::from(res.user_id); - let primary_email = match ( - res.user_email_id, - res.user_email, - res.user_email_created_at, - res.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("users") - .column("primary_user_email_id") - .row(id) - .into()) - } - }; - - Ok(Some(User { - id, - username: res.user_username, - sub: id.to_string(), - primary_email, - })) -} - -#[tracing::instrument( - skip_all, - fields(user.id = %id), - err, -)] -pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result { - let res = sqlx::query_as!( - UserLookup, - r#" - SELECT - u.user_id, - u.username AS user_username, - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" - FROM users u - - LEFT JOIN user_emails ue - USING (user_id) - - WHERE u.user_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .instrument(info_span!("Fetch user")) - .await?; - - let id = Ulid::from(res.user_id); - let primary_email = match ( - res.user_email_id, - res.user_email, - res.user_email_created_at, - res.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("users") - .column("primary_user_email_id") - .row(id) - .into()) - } - }; - - Ok(User { - id, - username: res.user_username, - sub: id.to_string(), - primary_email, - }) -} - -#[tracing::instrument( - skip_all, - fields(user.username = username), - err, -)] -pub async fn username_exists( - executor: impl PgExecutor<'_>, - username: &str, -) -> Result { - sqlx::query_scalar!( - r#" - SELECT EXISTS( - SELECT 1 FROM users WHERE username = $1 - ) AS "exists!" - "#, - username - ) - .fetch_one(executor) - .await -} - -#[derive(Debug, Clone, sqlx::FromRow)] -struct UserEmailLookup { - user_email_id: Uuid, - user_email: String, - user_email_created_at: DateTime, - user_email_confirmed_at: Option>, -} - -impl From for UserEmail { - fn from(e: UserEmailLookup) -> UserEmail { - UserEmail { - id: e.user_email_id.into(), - email: e.user_email, - created_at: e.user_email_created_at, - confirmed_at: e.user_email_confirmed_at, - } - } -} - -#[tracing::instrument( - skip_all, - fields(%user.id, %user.username), - err, -)] -pub async fn get_user_emails( - executor: impl PgExecutor<'_>, - user: &User, -) -> Result, sqlx::Error> { - let res = sqlx::query_as!( - UserEmailLookup, - r#" - SELECT - ue.user_email_id, - ue.email AS "user_email", - ue.created_at AS "user_email_created_at", - ue.confirmed_at AS "user_email_confirmed_at" - FROM user_emails ue - - WHERE ue.user_id = $1 - - ORDER BY ue.email ASC - "#, - Uuid::from(user.id), - ) - .fetch_all(executor) - .instrument(info_span!("Fetch user emails")) - .await?; - - Ok(res.into_iter().map(Into::into).collect()) -} - -#[tracing::instrument( - skip_all, - fields(%user.id, %user.username), - err, -)] -pub async fn count_user_emails( - executor: impl PgExecutor<'_>, - user: &User, -) -> Result { - let res = sqlx::query_scalar!( - r#" - SELECT COUNT(*) - FROM user_emails ue - WHERE ue.user_id = $1 - "#, - Uuid::from(user.id), - ) - .fetch_one(executor) - .instrument(info_span!("Count user emails")) - .await?; - - Ok(res.unwrap_or_default()) -} - -#[tracing::instrument( - skip_all, - fields(%user.id, %user.username), - err, -)] -pub async fn get_paginated_user_emails( - executor: impl PgExecutor<'_>, - user: &User, - before: Option, - after: Option, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), DatabaseError> { - let mut query = QueryBuilder::new( - r#" - SELECT - ue.user_email_id, - ue.email AS "user_email", - ue.created_at AS "user_email_created_at", - ue.confirmed_at AS "user_email_confirmed_at" - FROM user_emails ue - "#, - ); - - query - .push(" WHERE ue.user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("ue.user_email_id", before, after, first, last)?; - - let span = info_span!("Fetch paginated user sessions", db.statement = query.sql()); - let page: Vec = query - .build_query_as() - .fetch_all(executor) - .instrument(span) - .await?; - - let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; - - Ok(( - has_previous_page, - has_next_page, - page.into_iter().map(Into::into).collect(), - )) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %user.username, - user_email.id = %id, - ), - err, -)] -pub async fn get_user_email( - executor: impl PgExecutor<'_>, - user: &User, - id: Ulid, -) -> Result { - let res = sqlx::query_as!( - UserEmailLookup, - r#" - SELECT - ue.user_email_id, - ue.email AS "user_email", - ue.created_at AS "user_email_created_at", - ue.confirmed_at AS "user_email_confirmed_at" - FROM user_emails ue - - WHERE ue.user_id = $1 - AND ue.user_email_id = $2 - "#, - Uuid::from(user.id), - Uuid::from(id), - ) - .fetch_one(executor) - .instrument(info_span!("Fetch user emails")) - .await?; - - Ok(res.into()) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %user.username, - user_email.id, - user_email.email = %email, - ), - err, -)] -pub async fn add_user_email( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - user: &User, - email: String, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("user_email.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO user_emails (user_email_id, user_id, email, created_at) - VALUES ($1, $2, $3, $4) - "#, - Uuid::from(id), - Uuid::from(user.id), - &email, - created_at, - ) - .execute(executor) - .instrument(info_span!("Add user email")) - .await?; - - Ok(UserEmail { - id, - email, - created_at, - confirmed_at: None, - }) -} - -#[tracing::instrument( - skip_all, - fields( - %user_email.id, - %user_email.email, - ), - err, -)] -pub async fn set_user_email_as_primary( - executor: impl PgExecutor<'_>, - user_email: &UserEmail, -) -> Result<(), sqlx::Error> { - sqlx::query!( - r#" - UPDATE users - SET primary_user_email_id = user_emails.user_email_id - FROM user_emails - WHERE user_emails.user_email_id = $1 - AND users.user_id = user_emails.user_id - "#, - Uuid::from(user_email.id), - ) - .execute(executor) - .instrument(info_span!("Add user email")) - .await?; - - Ok(()) -} - -#[tracing::instrument( - skip_all, - fields( - %user_email.id, - %user_email.email, - ), - err, -)] -pub async fn remove_user_email( - executor: impl PgExecutor<'_>, - user_email: UserEmail, -) -> Result<(), sqlx::Error> { - sqlx::query!( - r#" - DELETE FROM user_emails - WHERE user_emails.user_email_id = $1 - "#, - Uuid::from(user_email.id), - ) - .execute(executor) - .instrument(info_span!("Remove user email")) - .await?; - - Ok(()) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - user_email.email = email, - ), - err, -)] -pub async fn lookup_user_email( - executor: impl PgExecutor<'_>, - user: &User, - email: &str, -) -> Result, sqlx::Error> { - let res = sqlx::query_as!( - UserEmailLookup, - r#" - SELECT - ue.user_email_id, - ue.email AS "user_email", - ue.created_at AS "user_email_created_at", - ue.confirmed_at AS "user_email_confirmed_at" - FROM user_emails ue - - WHERE ue.user_id = $1 - AND ue.email = $2 - "#, - Uuid::from(user.id), - email, - ) - .fetch_one(executor) - .instrument(info_span!("Lookup user email")) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - user_email.id = %id, - ), - err, -)] -pub async fn lookup_user_email_by_id( - executor: impl PgExecutor<'_>, - user: &User, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - UserEmailLookup, - r#" - SELECT - ue.user_email_id, - ue.email AS "user_email", - ue.created_at AS "user_email_created_at", - ue.confirmed_at AS "user_email_confirmed_at" - FROM user_emails ue - - WHERE ue.user_id = $1 - AND ue.user_email_id = $2 - "#, - Uuid::from(user.id), - Uuid::from(id), - ) - .fetch_one(executor) - .instrument(info_span!("Lookup user email")) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) -} - -#[tracing::instrument( - skip_all, - fields(%user_email.id), - err, -)] -pub async fn mark_user_email_as_verified( - executor: impl PgExecutor<'_>, - clock: &Clock, - mut user_email: UserEmail, -) -> Result { - let confirmed_at = clock.now(); - sqlx::query!( - r#" - UPDATE user_emails - SET confirmed_at = $2 - WHERE user_email_id = $1 - "#, - Uuid::from(user_email.id), - confirmed_at, - ) - .execute(executor) - .instrument(info_span!("Confirm user email")) - .await?; - - user_email.confirmed_at = Some(confirmed_at); - - Ok(user_email) -} - -struct UserEmailConfirmationCodeLookup { - user_email_confirmation_code_id: Uuid, - code: String, - created_at: DateTime, - expires_at: DateTime, - consumed_at: Option>, -} - -#[tracing::instrument( - skip_all, - fields(%user_email.id), - err, -)] -pub async fn lookup_user_email_verification_code( - executor: impl PgExecutor<'_>, - clock: &Clock, - user_email: UserEmail, - code: &str, -) -> Result, DatabaseError> { - let now = clock.now(); - - let res = sqlx::query_as!( - UserEmailConfirmationCodeLookup, - r#" - SELECT - ec.user_email_confirmation_code_id, - ec.code, - ec.created_at, - ec.expires_at, - ec.consumed_at - FROM user_email_confirmation_codes ec - WHERE ec.code = $1 - AND ec.user_email_id = $2 - "#, - code, - Uuid::from(user_email.id), - ) - .fetch_one(executor) - .instrument(info_span!("Lookup user email verification")) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - let state = if let Some(when) = res.consumed_at { - UserEmailVerificationState::AlreadyUsed { when } - } else if res.expires_at < now { - UserEmailVerificationState::Expired { - when: res.expires_at, - } - } else { - UserEmailVerificationState::Valid - }; - - Ok(Some(UserEmailVerification { - id: res.user_email_confirmation_code_id.into(), - code: res.code, - email: user_email, - state, - created_at: res.created_at, - })) -} - -#[tracing::instrument( - skip_all, - fields( - %user_email_verification.id, - ), - err, -)] -pub async fn consume_email_verification( - executor: impl PgExecutor<'_>, - clock: &Clock, - mut user_email_verification: UserEmailVerification, -) -> Result { - if !matches!( - user_email_verification.state, - UserEmailVerificationState::Valid - ) { - return Err(DatabaseError::invalid_operation()); - } - - let consumed_at = clock.now(); - - sqlx::query!( - r#" - UPDATE user_email_confirmation_codes - SET consumed_at = $2 - WHERE user_email_confirmation_code_id = $1 - "#, - Uuid::from(user_email_verification.id), - consumed_at - ) - .execute(executor) - .instrument(info_span!("Consume user email verification")) - .await?; - - user_email_verification.state = UserEmailVerificationState::AlreadyUsed { when: consumed_at }; - - Ok(user_email_verification) -} - -#[tracing::instrument( - skip_all, - fields( - %user_email.id, - %user_email.email, - user_email_confirmation.id, - user_email_confirmation.code = code, - ), - err, -)] -pub async fn add_user_email_verification_code( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - user_email: UserEmail, - max_age: chrono::Duration, - code: String, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id)); - let expires_at = created_at + max_age; - - sqlx::query!( - r#" - INSERT INTO user_email_confirmation_codes - (user_email_confirmation_code_id, user_email_id, code, created_at, expires_at) - VALUES ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(user_email.id), - code, - created_at, - expires_at, - ) - .execute(executor) - .instrument(info_span!("Add user email verification code")) - .await?; - - let verification = UserEmailVerification { - id, - email: user_email, - code, - created_at, - state: UserEmailVerificationState::Valid, - }; - - Ok(verification) -} diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index 642a760bc..8e8e0c8ae 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -618,6 +618,7 @@ impl TemplateContext for EmailVerificationContext { .map(|user| { let email = UserEmail { id: Ulid::from_datetime_with_source(now.into(), rng), + user_id: user.id, email: "foobar@example.com".to_owned(), created_at: now, confirmed_at: None, @@ -625,8 +626,8 @@ impl TemplateContext for EmailVerificationContext { let verification = UserEmailVerification { id: Ulid::from_datetime_with_source(now.into(), rng), + user_email_id: email.id, code: "123456".to_owned(), - email, created_at: now, state: mas_data_model::UserEmailVerificationState::Valid, }; @@ -684,6 +685,7 @@ impl TemplateContext for EmailVerificationPageContext { { let email = UserEmail { id: Ulid::from_datetime_with_source(now.into(), rng), + user_id: Ulid::from_datetime_with_source(now.into(), rng), email: "foobar@example.com".to_owned(), created_at: now, confirmed_at: None, diff --git a/docs/development/database.md b/docs/development/database.md index e9583124a..513fe7eff 100644 --- a/docs/development/database.md +++ b/docs/development/database.md @@ -35,6 +35,8 @@ Note that migrations are embedded in the final binary and can be run from the se ## Writing database interactions +**TODO**: *This section is outdated.* + A typical interaction with the database look like this: ```rust From 9f371a7a531d594dbc45b0429389914798c540ad Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 2 Jan 2023 16:46:34 +0100 Subject: [PATCH 05/45] storage: trace storage operations better --- crates/storage/src/lib.rs | 1 + crates/storage/src/tracing.rs | 29 +++++++++ crates/storage/src/upstream_oauth2/link.rs | 31 ++++++--- .../storage/src/upstream_oauth2/provider.rs | 37 ++++++++--- crates/storage/src/upstream_oauth2/session.rs | 18 +++++- crates/storage/src/user/email.rs | 63 ++++++++++++++++--- crates/storage/src/user/mod.rs | 25 +++++++- 7 files changed, 176 insertions(+), 28 deletions(-) create mode 100644 crates/storage/src/tracing.rs diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index bda7ee2ce..09a70023f 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -179,6 +179,7 @@ pub mod compat; pub mod oauth2; pub(crate) mod pagination; pub(crate) mod repository; +pub(crate) mod tracing; pub mod upstream_oauth2; pub mod user; diff --git a/crates/storage/src/tracing.rs b/crates/storage/src/tracing.rs new file mode 100644 index 000000000..60eb284c9 --- /dev/null +++ b/crates/storage/src/tracing.rs @@ -0,0 +1,29 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub trait ExecuteExt<'q, DB> { + /// Records the statement as `db.statement` in the current span + fn traced(self) -> Self; +} + +impl<'q, DB, T> ExecuteExt<'q, DB> for T +where + T: sqlx::Execute<'q, DB>, + DB: sqlx::Database, +{ + fn traced(self) -> Self { + tracing::Span::current().record("db.statement", self.sql()); + self + } +} diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index 100e98336..0d443671c 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -17,12 +17,12 @@ use chrono::{DateTime, Utc}; use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User}; use rand::RngCore; use sqlx::{PgConnection, QueryBuilder}; -use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; use crate::{ pagination::{process_page, Page, QueryBuilderExt}, + tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt, }; @@ -103,8 +103,12 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { type Error = DatabaseError; #[tracing::instrument( + name = "db.upstream_oauth_link.lookup", skip_all, - fields(upstream_oauth_link.id = %id), + fields( + db.statement, + upstream_oauth_link.id = %id, + ), err, )] async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { @@ -122,6 +126,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { "#, Uuid::from(id), ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()? @@ -131,8 +136,10 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { } #[tracing::instrument( + name = "db.upstream_oauth_link.find_by_subject", skip_all, fields( + db.statement, upstream_oauth_link.subject = subject, %upstream_oauth_provider.id, %upstream_oauth_provider.issuer, @@ -161,6 +168,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { Uuid::from(upstream_oauth_provider.id), subject, ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()? @@ -170,8 +178,10 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { } #[tracing::instrument( + name = "db.upstream_oauth_link.add", skip_all, fields( + db.statement, upstream_oauth_link.id, upstream_oauth_link.subject = subject, %upstream_oauth_provider.id, @@ -206,6 +216,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { &subject, created_at, ) + .traced() .execute(&mut *self.conn) .await?; @@ -219,8 +230,10 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { } #[tracing::instrument( + name = "db.upstream_oauth_link.associate_to_user", skip_all, fields( + db.statement, %upstream_oauth_link.id, %upstream_oauth_link.subject, %user.id, @@ -242,6 +255,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { Uuid::from(user.id), Uuid::from(upstream_oauth_link.id), ) + .traced() .execute(&mut *self.conn) .await?; @@ -249,8 +263,13 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { } #[tracing::instrument( + name = "db.upstream_oauth_link.list_paginated", skip_all, - fields(%user.id, %user.username), + fields( + db.statement, + %user.id, + %user.username, + ), err )] async fn list_paginated( @@ -278,14 +297,10 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { .push_bind(Uuid::from(user.id)) .generate_pagination("upstream_oauth_link_id", before, after, first, last)?; - let span = info_span!( - "Fetch paginated upstream OAuth 2.0 user links", - db.statement = query.sql() - ); let page: Vec = query .build_query_as() + .traced() .fetch_all(&mut *self.conn) - .instrument(span) .await?; let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?; diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index 3d8ba1417..a7efb6c88 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -19,12 +19,12 @@ use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod use oauth2_types::scope::Scope; use rand::RngCore; use sqlx::{PgConnection, QueryBuilder}; -use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; use crate::{ pagination::{process_page, Page, QueryBuilderExt}, + tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, }; @@ -129,8 +129,12 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' type Error = DatabaseError; #[tracing::instrument( + name = "db.upstream_oauth_provider.lookup", skip_all, - fields(upstream_oauth_provider.id = %id), + fields( + db.statement, + upstream_oauth_provider.id = %id, + ), err, )] async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { @@ -151,6 +155,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' "#, Uuid::from(id), ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()?; @@ -164,8 +169,10 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' } #[tracing::instrument( + name = "db.upstream_oauth_provider.add", skip_all, fields( + db.statement, upstream_oauth_provider.id, upstream_oauth_provider.issuer = %issuer, upstream_oauth_provider.client_id = %client_id, @@ -210,6 +217,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' encrypted_client_secret.as_deref(), created_at, ) + .traced() .execute(&mut *self.conn) .await?; @@ -225,6 +233,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' }) } + #[tracing::instrument( + name = "db.upstream_oauth_provider.list_paginated", + skip_all, + fields( + db.statement, + ), + err, + )] async fn list_paginated( &mut self, before: Option, @@ -250,14 +266,10 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' query.generate_pagination("upstream_oauth_provider_id", before, after, first, last)?; - let span = info_span!( - "Fetch paginated upstream OAuth 2.0 providers", - db.statement = query.sql() - ); let page: Vec = query .build_query_as() + .traced() .fetch_all(&mut *self.conn) - .instrument(span) .await?; let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?; @@ -269,7 +281,15 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' edges: edges?, }) } - #[tracing::instrument(skip_all, err)] + + #[tracing::instrument( + name = "db.upstream_oauth_provider.all", + skip_all, + fields( + db.statement, + ), + err, + )] async fn all(&mut self) -> Result, Self::Error> { let res = sqlx::query_as!( ProviderLookup, @@ -286,6 +306,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' FROM upstream_oauth_providers "#, ) + .traced() .fetch_all(&mut *self.conn) .await?; diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index f8dffcf39..f13c6ec88 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -20,7 +20,7 @@ use sqlx::PgConnection; use ulid::Ulid; use uuid::Uuid; -use crate::{Clock, DatabaseError, LookupResultExt}; +use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; #[async_trait] pub trait UpstreamOAuthSessionRepository: Send + Sync { @@ -88,8 +88,12 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> type Error = DatabaseError; #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.lookup", skip_all, - fields(upstream_oauth_provider.id = %id), + fields( + db.statement, + upstream_oauth_provider.id = %id, + ), err, )] async fn lookup( @@ -115,6 +119,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> "#, Uuid::from(id), ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()?; @@ -138,8 +143,10 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> } #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.add", skip_all, fields( + db.statement, %upstream_oauth_provider.id, %upstream_oauth_provider.issuer, %upstream_oauth_provider.client_id, @@ -184,6 +191,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> nonce, created_at, ) + .traced() .execute(&mut *self.conn) .await?; @@ -202,8 +210,10 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> } #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.complete_with_link", skip_all, fields( + db.statement, %upstream_oauth_authorization_session.id, %upstream_oauth_link.id, ), @@ -230,6 +240,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> id_token, Uuid::from(upstream_oauth_authorization_session.id), ) + .traced() .execute(&mut *self.conn) .await?; @@ -242,8 +253,10 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> /// Mark a session as consumed #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.consume", skip_all, fields( + db.statement, %upstream_oauth_authorization_session.id, ), err, @@ -263,6 +276,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> consumed_at, Uuid::from(upstream_oauth_authorization_session.id), ) + .traced() .execute(&mut *self.conn) .await?; diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs index 83784a56e..2f7486110 100644 --- a/crates/storage/src/user/email.rs +++ b/crates/storage/src/user/email.rs @@ -22,6 +22,7 @@ use uuid::Uuid; use crate::{ pagination::{process_page, Page, QueryBuilderExt}, + tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, }; @@ -152,8 +153,12 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { type Error = DatabaseError; #[tracing::instrument( + name = "db.user_email.lookup", skip_all, - fields(user_email.id = %id), + fields( + db.statement, + user_email.id = %id, + ), err, )] async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { @@ -171,6 +176,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { "#, Uuid::from(id), ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()?; @@ -181,8 +187,13 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.find", skip_all, - fields(%user.id, user_email.email = email), + fields( + db.statement, + %user.id, + user_email.email = email, + ), err, )] async fn find(&mut self, user: &User, email: &str) -> Result, Self::Error> { @@ -201,6 +212,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { Uuid::from(user.id), email, ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()?; @@ -211,8 +223,12 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.get_primary", skip_all, - fields(%user.id), + fields( + db.statement, + %user.id, + ), err, )] async fn get_primary(&mut self, user: &User) -> Result, Self::Error> { @@ -228,8 +244,12 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.all", skip_all, - fields(%user.id), + fields( + db.statement, + %user.id, + ), err, )] async fn all(&mut self, user: &User) -> Result, Self::Error> { @@ -249,6 +269,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { "#, Uuid::from(user.id), ) + .traced() .fetch_all(&mut *self.conn) .await?; @@ -256,8 +277,12 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.list_paginated", skip_all, - fields(%user.id), + fields( + db.statement, + %user.id, + ), err, )] async fn list_paginated( @@ -284,7 +309,11 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { .push_bind(Uuid::from(user.id)) .generate_pagination("ue.user_email_id", before, after, first, last)?; - let edges: Vec = query.build_query_as().fetch_all(&mut *self.conn).await?; + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; let (has_previous_page, has_next_page, edges) = process_page(edges, first, last)?; @@ -298,8 +327,12 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.count", skip_all, - fields(%user.id), + fields( + db.statement, + %user.id, + ), err, )] async fn count(&mut self, user: &User) -> Result { @@ -311,6 +344,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { "#, Uuid::from(user.id), ) + .traced() .fetch_one(&mut *self.conn) .await?; @@ -322,8 +356,10 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.add", skip_all, fields( + db.statement, %user.id, user_email.id, user_email.email = email, @@ -351,6 +387,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { &email, created_at, ) + .traced() .execute(&mut *self.conn) .await?; @@ -364,8 +401,10 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.remove", skip_all, fields( + db.statement, user.id = %user_email.user_id, %user_email.id, %user_email.email, @@ -380,6 +419,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { "#, Uuid::from(user_email.id), ) + .traced() .execute(&mut *self.conn) .await?; @@ -426,8 +466,10 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.add_verification_code", skip_all, fields( + db.statement, %user_email.id, %user_email.email, user_email_verification.id, @@ -460,6 +502,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { created_at, expires_at, ) + .traced() .execute(&mut *self.conn) .await?; @@ -475,8 +518,10 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.find_verification_code", skip_all, fields( + db.statement, %user_email.id, user.id = %user_email.user_id, ), @@ -504,6 +549,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { code, Uuid::from(user_email.id), ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()?; @@ -514,8 +560,10 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.consume_verification_code", skip_all, fields( + db.statement, %user_email_verification.id, user_email.id = %user_email_verification.user_email_id, ), @@ -544,6 +592,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { Uuid::from(user_email_verification.id), consumed_at ) + .traced() .execute(&mut *self.conn) .await?; diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 54b3689cc..50f71752e 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -23,6 +23,7 @@ use uuid::Uuid; use crate::{ pagination::{process_page, QueryBuilderExt}, + tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, }; @@ -88,8 +89,12 @@ impl<'c> UserRepository for PgUserRepository<'c> { type Error = DatabaseError; #[tracing::instrument( + name = "db.user.lookup", skip_all, - fields(user.id = %id), + fields( + db.statement, + user.id = %id, + ), err, )] async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { @@ -105,6 +110,7 @@ impl<'c> UserRepository for PgUserRepository<'c> { "#, Uuid::from(id), ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()?; @@ -115,8 +121,12 @@ impl<'c> UserRepository for PgUserRepository<'c> { } #[tracing::instrument( + name = "db.user.find_by_username", skip_all, - fields(user.username = username), + fields( + db.statement, + user.username = username, + ), err, )] async fn find_by_username(&mut self, username: &str) -> Result, Self::Error> { @@ -132,6 +142,7 @@ impl<'c> UserRepository for PgUserRepository<'c> { "#, username, ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()?; @@ -142,8 +153,10 @@ impl<'c> UserRepository for PgUserRepository<'c> { } #[tracing::instrument( + name = "db.user.add", skip_all, fields( + db.statement, user.username = username, user.id, ), @@ -168,6 +181,7 @@ impl<'c> UserRepository for PgUserRepository<'c> { username, created_at, ) + .traced() .execute(&mut *self.conn) .await?; @@ -180,8 +194,12 @@ impl<'c> UserRepository for PgUserRepository<'c> { } #[tracing::instrument( + name = "db.user.exists", skip_all, - fields(user.username = username), + fields( + db.statement, + user.username = username, + ), err, )] async fn exists(&mut self, username: &str) -> Result { @@ -193,6 +211,7 @@ impl<'c> UserRepository for PgUserRepository<'c> { "#, username ) + .traced() .fetch_one(&mut *self.conn) .await?; From f77923599b04da4813337e83b0bfdb4222ab779d Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 3 Jan 2023 15:21:47 +0100 Subject: [PATCH 06/45] strorage: browser session and user password repositories --- crates/axum-utils/src/session.rs | 16 +- crates/cli/src/commands/manage.rs | 15 +- crates/data-model/src/users.rs | 7 + crates/graphql/src/lib.rs | 7 +- crates/graphql/src/model/users.rs | 16 +- crates/handlers/src/compat/login.rs | 26 +- crates/handlers/src/graphql.rs | 6 +- crates/handlers/src/upstream_oauth2/link.rs | 19 +- crates/handlers/src/views/account/mod.rs | 4 +- crates/handlers/src/views/account/password.rs | 35 +- crates/handlers/src/views/login.rs | 39 +- crates/handlers/src/views/logout.rs | 4 +- crates/handlers/src/views/reauth.rs | 36 +- crates/handlers/src/views/register.rs | 26 +- crates/storage/sqlx-data.json | 320 ++++++------- crates/storage/src/oauth2/access_token.rs | 1 + .../storage/src/oauth2/authorization_grant.rs | 1 + crates/storage/src/oauth2/mod.rs | 13 +- crates/storage/src/oauth2/refresh_token.rs | 1 + crates/storage/src/repository.rs | 35 +- crates/storage/src/user/authentication.rs | 105 ----- crates/storage/src/user/mod.rs | 250 +---------- crates/storage/src/user/password.rs | 229 ++++++---- crates/storage/src/user/session.rs | 425 ++++++++++++++++++ crates/templates/src/context.rs | 4 +- 25 files changed, 914 insertions(+), 726 deletions(-) delete mode 100644 crates/storage/src/user/authentication.rs create mode 100644 crates/storage/src/user/session.rs diff --git a/crates/axum-utils/src/session.rs b/crates/axum-utils/src/session.rs index a63c22668..64887895e 100644 --- a/crates/axum-utils/src/session.rs +++ b/crates/axum-utils/src/session.rs @@ -14,9 +14,9 @@ use axum_extra::extract::cookie::{Cookie, PrivateCookieJar}; use mas_data_model::BrowserSession; -use mas_storage::{user::lookup_active_session, DatabaseError}; +use mas_storage::{user::BrowserSessionRepository, DatabaseError, Repository}; use serde::{Deserialize, Serialize}; -use sqlx::{Executor, Postgres}; +use sqlx::PgConnection; use ulid::Ulid; use crate::CookieExt; @@ -46,7 +46,7 @@ impl SessionInfo { /// Load the [`BrowserSession`] from database pub async fn load_session( &self, - executor: impl Executor<'_, Database = Postgres>, + conn: &mut PgConnection, ) -> Result, DatabaseError> { let session_id = if let Some(id) = self.current { id @@ -54,8 +54,14 @@ impl SessionInfo { return Ok(None); }; - let res = lookup_active_session(executor, session_id).await?; - Ok(res) + let maybe_session = conn + .browser_session() + .lookup(session_id) + .await? + // Ensure that the session is still active + .filter(BrowserSession::active); + + Ok(maybe_session) } } diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index d46130e28..60b94bfe1 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -20,7 +20,7 @@ use mas_router::UrlBuilder; use mas_storage::{ oauth2::client::{insert_client_from_config, lookup_client, truncate_clients}, upstream_oauth2::UpstreamOAuthProviderRepository, - user::{add_user_password, UserEmailRepository, UserRepository}, + user::{UserEmailRepository, UserPasswordRepository, UserRepository}, Clock, Repository, }; use oauth2_types::scope::Scope; @@ -210,16 +210,9 @@ impl Options { let (version, hashed_password) = password_manager.hash(&mut rng, password).await?; - add_user_password( - &mut txn, - &mut rng, - &clock, - &user, - version, - hashed_password, - None, - ) - .await?; + txn.user_password() + .add(&mut rng, &clock, &user, version, hashed_password, None) + .await?; info!(%user.id, %user.username, "Password changed"); txn.commit().await?; diff --git a/crates/data-model/src/users.rs b/crates/data-model/src/users.rs index 995535d76..638ed77e3 100644 --- a/crates/data-model/src/users.rs +++ b/crates/data-model/src/users.rs @@ -57,10 +57,16 @@ pub struct BrowserSession { pub id: Ulid, pub user: User, pub created_at: DateTime, + pub finished_at: Option>, pub last_authentication: Option, } impl BrowserSession { + #[must_use] + pub fn active(&self) -> bool { + self.finished_at.is_none() + } + #[must_use] pub fn was_authenticated_after(&self, after: DateTime) -> bool { if let Some(auth) = &self.last_authentication { @@ -80,6 +86,7 @@ impl BrowserSession { id: Ulid::from_datetime_with_source(now.into(), rng), user, created_at: now, + finished_at: None, last_authentication: None, }) .collect() diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index f01d83701..b79b6fe9f 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -31,8 +31,9 @@ use async_graphql::{ Context, Description, EmptyMutation, EmptySubscription, ID, }; use mas_storage::{ - upstream_oauth2::UpstreamOAuthProviderRepository, user::UserEmailRepository, Repository, - UpstreamOAuthLinkRepository, + upstream_oauth2::UpstreamOAuthProviderRepository, + user::{BrowserSessionRepository, UserEmailRepository}, + Repository, UpstreamOAuthLinkRepository, }; use model::CreationEvent; use sqlx::PgPool; @@ -128,7 +129,7 @@ impl RootQuery { let Some(session) = session else { return Ok(None) }; let current_user = session.user; - let browser_session = mas_storage::user::lookup_active_session(&mut conn, id).await?; + let browser_session = conn.browser_session().lookup(id).await?; let ret = browser_session.and_then(|browser_session| { if browser_session.user.id == current_user.id { diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index 58119c093..dc40d6cd3 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -17,7 +17,10 @@ use async_graphql::{ Context, Description, Object, ID, }; use chrono::{DateTime, Utc}; -use mas_storage::{user::UserEmailRepository, Repository, UpstreamOAuthLinkRepository}; +use mas_storage::{ + user::{BrowserSessionRepository, UserEmailRepository}, + Repository, UpstreamOAuthLinkRepository, +}; use sqlx::PgPool; use super::{ @@ -140,14 +143,13 @@ impl User { .map(|x: OpaqueCursor| x.extract_for_type(NodeType::BrowserSession)) .transpose()?; - let (has_previous_page, has_next_page, edges) = - mas_storage::user::get_paginated_user_sessions( - &mut conn, &self.0, before_id, after_id, first, last, - ) + let page = conn + .browser_session() + .list_active_paginated(&self.0, before_id, after_id, first, last) .await?; - let mut connection = Connection::new(has_previous_page, has_next_page); - connection.edges.extend(edges.into_iter().map(|u| { + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + connection.edges.extend(page.edges.into_iter().map(|u| { Edge::new( OpaqueCursor(NodeCursor(NodeType::BrowserSession, u.id)), BrowserSession(u), diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index dd4d4742b..c3b910024 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -21,7 +21,7 @@ use mas_storage::{ add_compat_access_token, add_compat_refresh_token, get_compat_sso_login_by_token, mark_compat_sso_login_as_exchanged, start_compat_session, }, - user::{add_user_password, lookup_user_password, UserRepository}, + user::{UserPasswordRepository, UserRepository}, Clock, Repository, }; use serde::{Deserialize, Serialize}; @@ -321,7 +321,9 @@ async fn user_password_login( .ok_or(RouteError::UserNotFound)?; // Lookup its password - let user_password = lookup_user_password(&mut *txn, &user) + let user_password = txn + .user_password() + .active(&user) .await? .ok_or(RouteError::NoPassword)?; @@ -340,16 +342,16 @@ async fn user_password_login( if let Some((version, hashed_password)) = new_password_hash { // Save the upgraded password if needed - add_user_password( - &mut *txn, - &mut rng, - &clock, - &user, - version, - hashed_password, - Some(user_password), - ) - .await?; + txn.user_password() + .add( + &mut rng, + &clock, + &user, + version, + hashed_password, + Some(&user_password), + ) + .await?; } // Now that the user credentials have been verified, start a new compat session diff --git a/crates/handlers/src/graphql.rs b/crates/handlers/src/graphql.rs index ba6919940..2177388bc 100644 --- a/crates/handlers/src/graphql.rs +++ b/crates/handlers/src/graphql.rs @@ -67,7 +67,8 @@ pub async fn post( let content_type = content_type.map(|TypedHeader(h)| h.to_string()); let (session_info, _cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&pool).await?; + let mut conn = pool.acquire().await?; + let maybe_session = session_info.load_session(&mut conn).await?; let mut request = async_graphql::http::receive_batch_body( content_type, @@ -116,7 +117,8 @@ pub async fn get( RawQuery(query): RawQuery, ) -> Result { let (session_info, _cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&pool).await?; + let mut conn = pool.acquire().await?; + let maybe_session = session_info.load_session(&mut conn).await?; let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?; diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index dbd06059b..80fa04f71 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -26,7 +26,7 @@ use mas_axum_utils::{ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::UpstreamOAuthSessionRepository, - user::{authenticate_session_with_upstream, start_session, UserRepository}, + user::{BrowserSessionRepository, UserRepository}, Repository, UpstreamOAuthLinkRepository, }; use mas_templates::{ @@ -134,14 +134,16 @@ pub(crate) async fn get( let maybe_user_session = user_session_info.load_session(&mut txn).await?; let render = match (maybe_user_session, link.user_id) { - (Some(mut session), Some(user_id)) if session.user.id == user_id => { + (Some(session), Some(user_id)) if session.user.id == user_id => { // Session already linked, and link matches the currently logged // user. Mark the session as consumed and renew the authentication. txn.upstream_oauth_session() .consume(&clock, upstream_session) .await?; - authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link) + let session = txn + .browser_session() + .authenticate_with_upstream(&mut rng, &clock, session, &link) .await?; cookie_jar = cookie_jar.set_session(&session); @@ -252,7 +254,7 @@ pub(crate) async fn post( let (user_session_info, cookie_jar) = cookie_jar.session_info(); let maybe_user_session = user_session_info.load_session(&mut txn).await?; - let mut session = match (maybe_user_session, link.user_id, form) { + let session = match (maybe_user_session, link.user_id, form) { (Some(session), None, FormData::Link) => { txn.upstream_oauth_link() .associate_to_user(&link, &session.user) @@ -268,7 +270,7 @@ pub(crate) async fn post( .await? .ok_or(RouteError::UserNotFound)?; - start_session(&mut txn, &mut rng, &clock, user).await? + txn.browser_session().add(&mut rng, &clock, &user).await? } (None, None, FormData::Register { username }) => { @@ -277,7 +279,7 @@ pub(crate) async fn post( .associate_to_user(&link, &user) .await?; - start_session(&mut txn, &mut rng, &clock, user).await? + txn.browser_session().add(&mut rng, &clock, &user).await? } _ => return Err(RouteError::InvalidFormAction), @@ -287,7 +289,10 @@ pub(crate) async fn post( .consume(&clock, upstream_session) .await?; - authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link).await?; + let session = txn + .browser_session() + .authenticate_with_upstream(&mut rng, &clock, session, &link) + .await?; let cookie_jar = sessions_cookie .consume_link(link_id)? diff --git a/crates/handlers/src/views/account/mod.rs b/crates/handlers/src/views/account/mod.rs index 07a70898f..0188aef29 100644 --- a/crates/handlers/src/views/account/mod.rs +++ b/crates/handlers/src/views/account/mod.rs @@ -24,7 +24,7 @@ use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt}; use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ - user::{count_active_sessions, UserEmailRepository}, + user::{BrowserSessionRepository, UserEmailRepository}, Repository, }; use mas_templates::{AccountContext, TemplateContext, Templates}; @@ -50,7 +50,7 @@ pub(crate) async fn get( return Ok((cookie_jar, login.go()).into_response()); }; - let active_sessions = count_active_sessions(&mut conn, &session.user).await?; + let active_sessions = conn.browser_session().count_active(&session.user).await?; let emails = conn.user_email().all(&session.user).await?; diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index 2ba4b3f8a..42c0194be 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -26,8 +26,8 @@ use mas_data_model::BrowserSession; use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ - user::{add_user_password, authenticate_session_with_password, lookup_user_password}, - Clock, + user::{BrowserSessionRepository, UserPasswordRepository}, + Clock, Repository, }; use mas_templates::{EmptyContext, TemplateContext, Templates}; use rand::Rng; @@ -98,14 +98,16 @@ pub(crate) async fn post( let maybe_session = session_info.load_session(&mut txn).await?; - let mut session = if let Some(session) = maybe_session { + let session = if let Some(session) = maybe_session { session } else { let login = mas_router::Login::and_then(mas_router::PostAuthAction::ChangePassword); return Ok((cookie_jar, login.go()).into_response()); }; - let user_password = lookup_user_password(&mut txn, &session.user) + let user_password = txn + .user_password() + .active(&session.user) .await? .context("user has no password")?; @@ -127,18 +129,21 @@ pub(crate) async fn post( } let (version, hashed_password) = password_manager.hash(&mut rng, new_password).await?; - let user_password = add_user_password( - &mut txn, - &mut rng, - &clock, - &session.user, - version, - hashed_password, - None, - ) - .await?; + let user_password = txn + .user_password() + .add( + &mut rng, + &clock, + &session.user, + version, + hashed_password, + None, + ) + .await?; - authenticate_session_with_password(&mut txn, &mut rng, &clock, &mut session, &user_password) + let session = txn + .browser_session() + .authenticate_with_password(&mut rng, &clock, session, &user_password) .await?; let reply = render(&mut rng, &clock, templates.clone(), session, cookie_jar).await?; diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 5ba76b726..1ef5efbb8 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -25,10 +25,7 @@ use mas_data_model::BrowserSession; use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::UpstreamOAuthProviderRepository, - user::{ - add_user_password, authenticate_session_with_password, lookup_user_password, start_session, - UserRepository, - }, + user::{BrowserSessionRepository, UserPasswordRepository, UserRepository}, Clock, Repository, }; use mas_templates::{ @@ -181,7 +178,9 @@ async fn login( .ok_or(FormError::InvalidCredentials)?; // And its password - let user_password = lookup_user_password(&mut *conn, &user) + let user_password = conn + .user_password() + .active(&user) .await .map_err(|_e| FormError::Internal)? .ok_or(FormError::InvalidCredentials)?; @@ -201,28 +200,32 @@ async fn login( let user_password = if let Some((version, new_password_hash)) = new_password_hash { // Save the upgraded password - add_user_password( - &mut *conn, - &mut rng, - clock, - &user, - version, - new_password_hash, - Some(user_password), - ) - .await - .map_err(|_| FormError::Internal)? + conn.user_password() + .add( + &mut rng, + clock, + &user, + version, + new_password_hash, + Some(&user_password), + ) + .await + .map_err(|_| FormError::Internal)? } else { user_password }; // Start a new session - let mut user_session = start_session(&mut *conn, &mut rng, clock, user) + let user_session = conn + .browser_session() + .add(&mut rng, clock, &user) .await .map_err(|_| FormError::Internal)?; // And mark it as authenticated by the password - authenticate_session_with_password(&mut *conn, rng, clock, &mut user_session, &user_password) + let user_session = conn + .browser_session() + .authenticate_with_password(&mut rng, clock, user_session, &user_password) .await .map_err(|_| FormError::Internal)?; diff --git a/crates/handlers/src/views/logout.rs b/crates/handlers/src/views/logout.rs index 07043e64a..88c4a9c2d 100644 --- a/crates/handlers/src/views/logout.rs +++ b/crates/handlers/src/views/logout.rs @@ -23,7 +23,7 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::{PostAuthAction, Route}; -use mas_storage::{user::end_session, Clock}; +use mas_storage::{user::BrowserSessionRepository, Clock, Repository}; use sqlx::PgPool; pub(crate) async fn post( @@ -41,7 +41,7 @@ pub(crate) async fn post( let maybe_session = session_info.load_session(&mut txn).await?; if let Some(session) = maybe_session { - end_session(&mut txn, &clock, &session).await?; + txn.browser_session().finish(&clock, session).await?; cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended()); } diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index 875189a76..7911a9301 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -24,8 +24,9 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::user::{ - add_user_password, authenticate_session_with_password, lookup_user_password, +use mas_storage::{ + user::{BrowserSessionRepository, UserPasswordRepository}, + Repository, }; use mas_templates::{ReauthContext, TemplateContext, Templates}; use serde::Deserialize; @@ -93,7 +94,7 @@ pub(crate) async fn post( let maybe_session = session_info.load_session(&mut txn).await?; - let mut session = if let Some(session) = maybe_session { + let session = if let Some(session) = maybe_session { session } else { // If there is no session, redirect to the login screen, keeping the @@ -103,7 +104,9 @@ pub(crate) async fn post( }; // Load the user password - let user_password = lookup_user_password(&mut txn, &session.user) + let user_password = txn + .user_password() + .active(&session.user) .await? .context("User has no password")?; @@ -122,22 +125,25 @@ pub(crate) async fn post( let user_password = if let Some((version, new_password_hash)) = new_password_hash { // Save the upgraded password - add_user_password( - &mut *txn, - &mut rng, - &clock, - &session.user, - version, - new_password_hash, - Some(user_password), - ) - .await? + txn.user_password() + .add( + &mut rng, + &clock, + &session.user, + version, + new_password_hash, + Some(&user_password), + ) + .await? } else { user_password }; // Mark the session as authenticated by the password - authenticate_session_with_password(&mut txn, rng, &clock, &mut session, &user_password).await?; + let session = txn + .browser_session() + .authenticate_with_password(&mut rng, &clock, session, &user_password) + .await?; let cookie_jar = cookie_jar.set_session(&session); txn.commit().await?; diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index 01dc21167..b2fe9fe0e 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -32,10 +32,7 @@ use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::Route; use mas_storage::{ - user::{ - add_user_password, authenticate_session_with_password, start_session, UserEmailRepository, - UserRepository, - }, + user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, Repository, }; use mas_templates::{ @@ -191,16 +188,10 @@ pub(crate) async fn post( let user = txn.user().add(&mut rng, &clock, form.username).await?; let password = Zeroizing::new(form.password.into_bytes()); let (version, hashed_password) = password_manager.hash(&mut rng, password).await?; - let user_password = add_user_password( - &mut txn, - &mut rng, - &clock, - &user, - version, - hashed_password, - None, - ) - .await?; + let user_password = txn + .user_password() + .add(&mut rng, &clock, &user, version, hashed_password, None) + .await?; let user_email = txn .user_email() @@ -228,8 +219,11 @@ pub(crate) async fn post( let next = mas_router::AccountVerifyEmail::new(user_email.id).and_maybe(query.post_auth_action); - let mut session = start_session(&mut txn, &mut rng, &clock, user).await?; - authenticate_session_with_password(&mut txn, &mut rng, &clock, &mut session, &user_password) + let session = txn.browser_session().add(&mut rng, &clock, &user).await?; + + let session = txn + .browser_session() + .authenticate_with_password(&mut rng, &clock, session, &user_password) .await?; txn.commit().await?; diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 3191a9dbf..1676c69a6 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -252,62 +252,6 @@ }, "query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n FROM users\n WHERE user_id = $1\n " }, - "09d995295b2e4f180181ec96023b1e524ddae9098694eedc4dcce857e3095c0e": { - "describe": { - "columns": [ - { - "name": "user_session_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "user_session_created_at", - "ordinal": 1, - "type_info": "Timestamptz" - }, - { - "name": "user_id", - "ordinal": 2, - "type_info": "Uuid" - }, - { - "name": "user_username", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "user_primary_user_email_id", - "ordinal": 4, - "type_info": "Uuid" - }, - { - "name": "last_authentication_id?", - "ordinal": 5, - "type_info": "Uuid" - }, - { - "name": "last_authd_at?", - "ordinal": 6, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - true, - false, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT s.user_session_id\n , s.created_at AS \"user_session_created_at\"\n , u.user_id\n , u.username AS \"user_username\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n , a.user_session_authentication_id AS \"last_authentication_id?\"\n , a.created_at AS \"last_authd_at?\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications a\n USING (user_session_id)\n WHERE s.user_session_id = $1 AND s.finished_at IS NULL\n ORDER BY a.created_at DESC\n LIMIT 1\n " - }, "154e2e4488ff87e09163698750b56a43127cee4e1392785416a586d40a4d9b21": { "describe": { "columns": [ @@ -851,20 +795,6 @@ }, "query": "\n SELECT user_email_confirmation_code_id\n , user_email_id\n , code\n , created_at\n , expires_at\n , consumed_at\n FROM user_email_confirmation_codes\n WHERE code = $1\n AND user_email_id = $2\n " }, - "1eb6d13e75d8f526c2785749a020731c18012f03e07995213acd38ab560ce497": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at)\n VALUES ($1, $2, $3)\n " - }, "1ee5cecfafd4726a4ebc08da8a34c09178e6e1e072581c8fca9d3d76967792cb": { "describe": { "columns": [], @@ -1060,6 +990,20 @@ }, "query": "\n SELECT user_email_id\n , user_id\n , email\n , created_at\n , confirmed_at\n FROM user_emails\n\n WHERE user_email_id = $1\n " }, + "41c1aafbd338c24476f27d342cf80eef7de2836e85b078232d143d6712fc2be4": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at)\n VALUES ($1, $2, $3)\n " + }, "43a5cafbdc8037e9fb779812a0793cf0859902aa0dc8d25d4c33d231d3d1118b": { "describe": { "columns": [], @@ -1076,6 +1020,50 @@ }, "query": "\n INSERT INTO oauth2_access_tokens\n (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)\n VALUES\n ($1, $2, $3, $4, $5)\n " }, + "446a8d7bd8532a751810401adfab924dc20785c91770ed43d62df2e590e8da71": { + "describe": { + "columns": [ + { + "name": "user_password_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "hashed_password", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "version", + "ordinal": 2, + "type_info": "Int4" + }, + { + "name": "upgraded_from_id", + "ordinal": 3, + "type_info": "Uuid" + }, + { + "name": "created_at", + "ordinal": 4, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT up.user_password_id\n , up.hashed_password\n , up.version\n , up.upgraded_from_id\n , up.created_at\n FROM user_passwords up\n WHERE up.user_id = $1\n ORDER BY up.created_at DESC\n LIMIT 1\n " + }, "4693f2b9b3d51ff4a05e233b6667161ebc97f331d96bf5f1c61069e1c8492105": { "describe": { "columns": [], @@ -1308,19 +1296,6 @@ }, "query": "\n INSERT INTO oauth2_consents\n (oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)\n SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)\n ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5\n " }, - "64a56818dd16ac6368efe3e34196a77b7feda1eb87b696e0063a51bf50e499e5": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Timestamptz", - "Uuid" - ] - } - }, - "query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n " - }, "64e6ea47c2e877c1ebe4338d64d9ad8a6c1c777d1daea024b8ca2e7f0dd75b0f": { "describe": { "columns": [], @@ -1438,6 +1413,43 @@ }, "query": "\n UPDATE oauth2_access_tokens\n SET revoked_at = $2\n WHERE oauth2_access_token_id = $1\n " }, + "6f97b5f9ad0d4d15387150bea3839fb7f81015f7ceef61ecaadba64521895cff": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Int4", + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO user_passwords\n (user_password_id, user_id, hashed_password, version, upgraded_from_id, created_at)\n VALUES ($1, $2, $3, $4, $5, $6)\n " + }, + "751d549073d77ded84aea1aaba36d3b130ec71bc592d722eb75b959b80f0b4ff": { + "describe": { + "columns": [ + { + "name": "count!", + "ordinal": 0, + "type_info": "Int8" + } + ], + "nullable": [ + null + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT COUNT(*) as \"count!\"\n FROM user_sessions s\n WHERE s.user_id = $1 AND s.finished_at IS NULL\n " + }, "7756a60c36a64a259f7450d6eb77ee92303638ca374a63f23ac4944ccf9f4436": { "describe": { "columns": [ @@ -1554,6 +1566,68 @@ }, "query": "\n SELECT\n c.oauth2_client_id,\n c.encrypted_client_secret,\n ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\",\n c.grant_type_authorization_code,\n c.grant_type_refresh_token,\n c.client_name,\n c.logo_uri,\n c.client_uri,\n c.policy_uri,\n c.tos_uri,\n c.jwks_uri,\n c.jwks,\n c.id_token_signed_response_alg,\n c.userinfo_signed_response_alg,\n c.token_endpoint_auth_method,\n c.token_endpoint_auth_signing_alg,\n c.initiate_login_uri\n FROM oauth2_clients c\n\n WHERE c.oauth2_client_id = ANY($1::uuid[])\n " }, + "79295f3d3a75f831e9469aabfa720d381a254d00dbe39fef1e9652029d51b89b": { + "describe": { + "columns": [ + { + "name": "user_session_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "user_session_created_at", + "ordinal": 1, + "type_info": "Timestamptz" + }, + { + "name": "user_session_finished_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "user_id", + "ordinal": 3, + "type_info": "Uuid" + }, + { + "name": "user_username", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "user_primary_user_email_id", + "ordinal": 5, + "type_info": "Uuid" + }, + { + "name": "last_authentication_id?", + "ordinal": 6, + "type_info": "Uuid" + }, + { + "name": "last_authd_at?", + "ordinal": 7, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT s.user_session_id\n , s.created_at AS \"user_session_created_at\"\n , s.finished_at AS \"user_session_finished_at\"\n , u.user_id\n , u.username AS \"user_username\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n , a.user_session_authentication_id AS \"last_authentication_id?\"\n , a.created_at AS \"last_authd_at?\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications a\n USING (user_session_id)\n WHERE s.user_session_id = $1\n ORDER BY a.created_at DESC\n LIMIT 1\n " + }, "7ce387b1b0aaf10e72adde667b19521b66eaafa51f73bf2f95e38b8f3b64a229": { "describe": { "columns": [], @@ -1872,70 +1946,6 @@ }, "query": "\n UPDATE oauth2_sessions\n SET finished_at = $2\n WHERE oauth2_session_id = $1\n " }, - "9edf5e8a3e00a7cdd8e55b97105df7831ee580096299df4bd6c1ed7c96b95e83": { - "describe": { - "columns": [ - { - "name": "count!", - "ordinal": 0, - "type_info": "Int8" - } - ], - "nullable": [ - null - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT COUNT(*) as \"count!\"\n FROM user_sessions s\n WHERE s.user_id = $1 AND s.finished_at IS NULL\n " - }, - "a1c19d9d7f1522d126787c7f9946ed51cbbd8f27a4947bc371acab3e7bf23267": { - "describe": { - "columns": [ - { - "name": "user_password_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "hashed_password", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "version", - "ordinal": 2, - "type_info": "Int4" - }, - { - "name": "upgraded_from_id", - "ordinal": 3, - "type_info": "Uuid" - }, - { - "name": "created_at", - "ordinal": 4, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - true, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT up.user_password_id\n , up.hashed_password\n , up.version\n , up.upgraded_from_id\n , up.created_at\n FROM user_passwords up\n WHERE up.user_id = $1\n ORDER BY up.created_at DESC\n LIMIT 1\n " - }, "a300fe99c95679c5664646a6a525c0491829e97db45f3234483872ed38436322": { "describe": { "columns": [ @@ -2109,7 +2119,7 @@ }, "query": "\n UPDATE users\n SET primary_user_email_id = user_emails.user_email_id\n FROM user_emails\n WHERE user_emails.user_email_id = $1\n AND users.user_id = user_emails.user_id\n " }, - "bd7a4a008851f3f6d7591e3463e4369cee08820af57dcd3faf95f8e9be82857d": { + "c1d90a7f2287ec779c81a521fab19e5ede3fa95484033e0312c30d9b6ecc03f0": { "describe": { "columns": [], "nullable": [], @@ -2117,14 +2127,11 @@ "Left": [ "Uuid", "Uuid", - "Text", - "Int4", - "Uuid", "Timestamptz" ] } }, - "query": "\n INSERT INTO user_passwords\n (user_password_id, user_id, hashed_password, version, upgraded_from_id, created_at)\n VALUES ($1, $2, $3, $4, $5, $6)\n " + "query": "\n INSERT INTO user_sessions (user_session_id, user_id, created_at)\n VALUES ($1, $2, $3)\n " }, "c88376abdba124ff0487a9a69d2345c7d69d7394f355111ec369cfa6d45fb40f": { "describe": { @@ -2371,19 +2378,18 @@ }, "query": "\n INSERT INTO oauth2_refresh_tokens\n (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,\n refresh_token, created_at)\n VALUES\n ($1, $2, $3, $4, $5)\n " }, - "e446e37d48c8838ef2e0d0fd82f8f7b04893c84ad46747cdf193ebd83755ceb2": { + "dbf4be84eeff9ea51b00185faae2d453ab449017ed492bf6711dc7fceb630880": { "describe": { "columns": [], "nullable": [], "parameters": { "Left": [ - "Uuid", - "Uuid", - "Timestamptz" + "Timestamptz", + "Uuid" ] } }, - "query": "\n INSERT INTO user_sessions (user_session_id, user_id, created_at)\n VALUES ($1, $2, $3)\n " + "query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n " }, "e6dc63984aced9e19c20e90e9cd75d6f6d7ade64f782697715ac4da077b2e1fc": { "describe": { diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index 71e014e43..5c2347d2e 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -175,6 +175,7 @@ pub async fn lookup_active_access_token( let browser_session = BrowserSession { id: res.user_session_id.into(), created_at: res.user_session_created_at, + finished_at: None, user, last_authentication, }; diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index 957400d9e..b7ffb30de 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -224,6 +224,7 @@ impl GrantLookup { id: user_session_id.into(), user, created_at: user_session_created_at, + finished_at: None, last_authentication, }; diff --git a/crates/storage/src/oauth2/mod.rs b/crates/storage/src/oauth2/mod.rs index 81a743633..bdc9c1b57 100644 --- a/crates/storage/src/oauth2/mod.rs +++ b/crates/storage/src/oauth2/mod.rs @@ -23,8 +23,8 @@ use uuid::Uuid; use self::client::lookup_clients; use crate::{ pagination::{process_page, QueryBuilderExt}, - user::lookup_active_session, - Clock, DatabaseError, DatabaseInconsistencyError, + user::BrowserSessionRepository, + Clock, DatabaseError, DatabaseInconsistencyError, Repository, }; pub mod access_token; @@ -134,11 +134,10 @@ pub async fn get_paginated_user_oauth_sessions( // ideal let mut browser_sessions: HashMap = HashMap::new(); for id in browser_session_ids { - let v = lookup_active_session(&mut *conn, id) - .await? - .ok_or_else(|| { - DatabaseInconsistencyError::on("oauth2_sessions").column("user_session_id") - })?; + let v = conn.browser_session().lookup(id).await?.ok_or_else(|| { + DatabaseInconsistencyError::on("oauth2_sessions").column("user_session_id") + })?; + browser_sessions.insert(id, v); } diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 79d90c010..5c2b63180 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -204,6 +204,7 @@ pub async fn lookup_active_refresh_token( let browser_session = BrowserSession { id: res.user_session_id.into(), created_at: res.user_session_created_at, + finished_at: None, user, last_authentication, }; diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index f321e9417..ef1a567a3 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -19,7 +19,10 @@ use crate::{ PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, PgUpstreamOAuthSessionRepository, }, - user::{PgUserEmailRepository, PgUserRepository}, + user::{ + PgBrowserSessionRepository, PgUserEmailRepository, PgUserPasswordRepository, + PgUserRepository, + }, }; pub trait Repository { @@ -43,11 +46,21 @@ pub trait Repository { where Self: 'c; + type UserPasswordRepository<'c> + where + Self: 'c; + + type BrowserSessionRepository<'c> + where + Self: 'c; + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>; fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>; fn user(&mut self) -> Self::UserRepository<'_>; fn user_email(&mut self) -> Self::UserEmailRepository<'_>; + fn user_password(&mut self) -> Self::UserPasswordRepository<'_>; + fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>; } impl Repository for PgConnection { @@ -56,6 +69,8 @@ impl Repository for PgConnection { type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; type UserRepository<'c> = PgUserRepository<'c> where Self: 'c; type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c; + type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; + type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -76,6 +91,14 @@ impl Repository for PgConnection { fn user_email(&mut self) -> Self::UserEmailRepository<'_> { PgUserEmailRepository::new(self) } + + fn user_password(&mut self) -> Self::UserPasswordRepository<'_> { + PgUserPasswordRepository::new(self) + } + + fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> { + PgBrowserSessionRepository::new(self) + } } impl<'t> Repository for Transaction<'t, Postgres> { @@ -84,6 +107,8 @@ impl<'t> Repository for Transaction<'t, Postgres> { type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; type UserRepository<'c> = PgUserRepository<'c> where Self: 'c; type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c; + type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; + type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -104,4 +129,12 @@ impl<'t> Repository for Transaction<'t, Postgres> { fn user_email(&mut self) -> Self::UserEmailRepository<'_> { PgUserEmailRepository::new(self) } + + fn user_password(&mut self) -> Self::UserPasswordRepository<'_> { + PgUserPasswordRepository::new(self) + } + + fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> { + PgBrowserSessionRepository::new(self) + } } diff --git a/crates/storage/src/user/authentication.rs b/crates/storage/src/user/authentication.rs deleted file mode 100644 index 546b54a28..000000000 --- a/crates/storage/src/user/authentication.rs +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink}; -use rand::Rng; -use sqlx::PgExecutor; -use ulid::Ulid; -use uuid::Uuid; - -use crate::Clock; - -#[tracing::instrument( - skip_all, - fields( - user.id = %user_session.user.id, - %user_password.id, - %user_session.id, - user_session_authentication.id, - ), - err, -)] -pub async fn authenticate_session_with_password( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - user_session: &mut BrowserSession, - user_password: &Password, -) -> Result<(), sqlx::Error> { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record( - "user_session_authentication.id", - tracing::field::display(id), - ); - - sqlx::query!( - r#" - INSERT INTO user_session_authentications - (user_session_authentication_id, user_session_id, created_at) - VALUES ($1, $2, $3) - "#, - Uuid::from(id), - Uuid::from(user_session.id), - created_at, - ) - .execute(executor) - .await?; - - user_session.last_authentication = Some(Authentication { id, created_at }); - - Ok(()) -} - -#[tracing::instrument( - skip_all, - fields( - user.id = %user_session.user.id, - %upstream_oauth_link.id, - %user_session.id, - user_session_authentication.id, - ), - err, -)] -pub async fn authenticate_session_with_upstream( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - user_session: &mut BrowserSession, - upstream_oauth_link: &UpstreamOAuthLink, -) -> Result<(), sqlx::Error> { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record( - "user_session_authentication.id", - tracing::field::display(id), - ); - - sqlx::query!( - r#" - INSERT INTO user_session_authentications - (user_session_authentication_id, user_session_id, created_at) - VALUES ($1, $2, $3) - "#, - Uuid::from(id), - Uuid::from(user_session.id), - created_at, - ) - .execute(executor) - .await?; - - user_session.last_authentication = Some(Authentication { id, created_at }); - - Ok(()) -} diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 50f71752e..592cb59de 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -14,27 +14,22 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use mas_data_model::{Authentication, BrowserSession, User}; -use rand::{Rng, RngCore}; -use sqlx::{PgConnection, PgExecutor, QueryBuilder}; -use tracing::{info_span, Instrument}; +use mas_data_model::User; +use rand::RngCore; +use sqlx::PgConnection; use ulid::Ulid; use uuid::Uuid; -use crate::{ - pagination::{process_page, QueryBuilderExt}, - tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, -}; +use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; -mod authentication; mod email; mod password; +mod session; pub use self::{ - authentication::{authenticate_session_with_password, authenticate_session_with_upstream}, email::{PgUserEmailRepository, UserEmailRepository}, - password::{add_user_password, lookup_user_password}, + password::{PgUserPasswordRepository, UserPasswordRepository}, + session::{BrowserSessionRepository, PgBrowserSessionRepository}, }; #[async_trait] @@ -218,234 +213,3 @@ impl<'c> UserRepository for PgUserRepository<'c> { Ok(exists) } } - -#[derive(sqlx::FromRow)] -struct SessionLookup { - user_session_id: Uuid, - user_session_created_at: DateTime, - user_id: Uuid, - user_username: String, - user_primary_user_email_id: Option, - last_authentication_id: Option, - last_authd_at: Option>, -} - -impl TryInto for SessionLookup { - type Error = DatabaseInconsistencyError; - - fn try_into(self) -> Result { - let id = Ulid::from(self.user_id); - let user = User { - id, - username: self.user_username, - sub: id.to_string(), - primary_user_email_id: self.user_primary_user_email_id.map(Into::into), - }; - - let last_authentication = match (self.last_authentication_id, self.last_authd_at) { - (Some(id), Some(created_at)) => Some(Authentication { - id: id.into(), - created_at, - }), - (None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on( - "user_session_authentications", - )) - } - }; - - Ok(BrowserSession { - id: self.user_session_id.into(), - user, - created_at: self.user_session_created_at, - last_authentication, - }) - } -} - -#[tracing::instrument( - skip_all, - fields(user_session.id = %id), - err, -)] -pub async fn lookup_active_session( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - SessionLookup, - r#" - SELECT s.user_session_id - , s.created_at AS "user_session_created_at" - , u.user_id - , u.username AS "user_username" - , u.primary_user_email_id AS "user_primary_user_email_id" - , a.user_session_authentication_id AS "last_authentication_id?" - , a.created_at AS "last_authd_at?" - FROM user_sessions s - INNER JOIN users u - USING (user_id) - LEFT JOIN user_session_authentications a - USING (user_session_id) - WHERE s.user_session_id = $1 AND s.finished_at IS NULL - ORDER BY a.created_at DESC - LIMIT 1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %user.username, - ), - err, -)] -pub async fn get_paginated_user_sessions( - executor: impl PgExecutor<'_>, - user: &User, - before: Option, - after: Option, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), DatabaseError> { - let mut query = QueryBuilder::new( - r#" - SELECT - s.user_session_id, - u.user_id, - u.username, - s.created_at, - a.user_session_authentication_id AS "last_authentication_id", - a.created_at AS "last_authd_at", - ue.user_email_id AS "user_email_id", - ue.email AS "user_email", - ue.created_at AS "user_email_created_at", - ue.confirmed_at AS "user_email_confirmed_at" - FROM user_sessions s - INNER JOIN users u - USING (user_id) - LEFT JOIN user_session_authentications a - USING (user_session_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id - "#, - ); - - query - .push(" WHERE s.finished_at IS NULL AND s.user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("s.user_session_id", before, after, first, last)?; - - let span = info_span!("Fetch paginated user emails", db.statement = query.sql()); - let page: Vec = query - .build_query_as() - .fetch_all(executor) - .instrument(span) - .await?; - - let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; - - let page: Result, _> = page.into_iter().map(TryInto::try_into).collect(); - Ok((has_previous_page, has_next_page, page?)) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - user_session.id, - ), - err, -)] -pub async fn start_session( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - user: User, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("user_session.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO user_sessions (user_session_id, user_id, created_at) - VALUES ($1, $2, $3) - "#, - Uuid::from(id), - Uuid::from(user.id), - created_at, - ) - .execute(executor) - .await?; - - let session = BrowserSession { - id, - user, - created_at, - last_authentication: None, - }; - - Ok(session) -} - -#[tracing::instrument( - skip_all, - fields(%user.id), - err, -)] -pub async fn count_active_sessions( - executor: impl PgExecutor<'_>, - user: &User, -) -> Result { - let res = sqlx::query_scalar!( - r#" - SELECT COUNT(*) as "count!" - FROM user_sessions s - WHERE s.user_id = $1 AND s.finished_at IS NULL - "#, - Uuid::from(user.id), - ) - .fetch_one(executor) - .await?; - - Ok(res) -} - -#[tracing::instrument( - skip_all, - fields(%user_session.id), - err, -)] -pub async fn end_session( - executor: impl PgExecutor<'_>, - clock: &Clock, - user_session: &BrowserSession, -) -> Result<(), DatabaseError> { - let now = clock.now(); - let res = sqlx::query!( - r#" - UPDATE user_sessions - SET finished_at = $1 - WHERE user_session_id = $2 - "#, - now, - Uuid::from(user_session.id), - ) - .execute(executor) - .instrument(info_span!("End session")) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1) -} diff --git a/crates/storage/src/user/password.rs b/crates/storage/src/user/password.rs index 14ac52226..56c8a439c 100644 --- a/crates/storage/src/user/password.rs +++ b/crates/storage/src/user/password.rs @@ -12,63 +12,42 @@ // See the License for the specific language governing permissions and // limitations under the License. +use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::{Password, User}; -use rand::Rng; -use sqlx::PgExecutor; +use rand::RngCore; +use sqlx::PgConnection; use ulid::Ulid; use uuid::Uuid; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; +use crate::{ + tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, +}; -#[tracing::instrument( - skip_all, - fields( - %user.id, - %user.username, - user_password.id, - user_password.version = version, - ), - err, -)] -pub async fn add_user_password( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - user: &User, - version: u16, - hashed_password: String, - upgraded_from: Option, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("user_password.id", tracing::field::display(id)); +#[async_trait] +pub trait UserPasswordRepository: Send + Sync { + type Error; - let upgraded_from_id = upgraded_from.map(|p| p.id); + async fn active(&mut self, user: &User) -> Result, Self::Error>; + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + version: u16, + hashed_password: String, + upgraded_from: Option<&Password>, + ) -> Result; +} - sqlx::query!( - r#" - INSERT INTO user_passwords - (user_password_id, user_id, hashed_password, version, upgraded_from_id, created_at) - VALUES ($1, $2, $3, $4, $5, $6) - "#, - Uuid::from(id), - Uuid::from(user.id), - hashed_password, - i32::from(version), - upgraded_from_id.map(Uuid::from), - created_at, - ) - .execute(executor) - .await?; +pub struct PgUserPasswordRepository<'c> { + conn: &'c mut PgConnection, +} - Ok(Password { - id, - hashed_password, - version, - upgraded_from_id, - created_at, - }) +impl<'c> PgUserPasswordRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } } struct UserPasswordLookup { @@ -79,57 +58,115 @@ struct UserPasswordLookup { created_at: DateTime, } -#[tracing::instrument( - skip_all, - fields( - %user.id, - %user.username, - ), - err, -)] -pub async fn lookup_user_password( - executor: impl PgExecutor<'_>, - user: &User, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - UserPasswordLookup, - r#" - SELECT up.user_password_id - , up.hashed_password - , up.version - , up.upgraded_from_id - , up.created_at - FROM user_passwords up - WHERE up.user_id = $1 - ORDER BY up.created_at DESC - LIMIT 1 - "#, - Uuid::from(user.id), - ) - .fetch_one(executor) - .await - .to_option()?; +#[async_trait] +impl<'c> UserPasswordRepository for PgUserPasswordRepository<'c> { + type Error = DatabaseError; - let Some(res) = res else { return Ok(None) }; + #[tracing::instrument( + name = "db.user_password.active", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + ), + err, + )] + async fn active(&mut self, user: &User) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserPasswordLookup, + r#" + SELECT up.user_password_id + , up.hashed_password + , up.version + , up.upgraded_from_id + , up.created_at + FROM user_passwords up + WHERE up.user_id = $1 + ORDER BY up.created_at DESC + LIMIT 1 + "#, + Uuid::from(user.id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; - let id = Ulid::from(res.user_password_id); + let Some(res) = res else { return Ok(None) }; - let version = res.version.try_into().map_err(|e| { - DatabaseInconsistencyError::on("user_passwords") - .column("version") - .row(id) - .source(e) - })?; + let id = Ulid::from(res.user_password_id); - let upgraded_from_id = res.upgraded_from_id.map(Ulid::from); - let created_at = res.created_at; - let hashed_password = res.hashed_password; + let version = res.version.try_into().map_err(|e| { + DatabaseInconsistencyError::on("user_passwords") + .column("version") + .row(id) + .source(e) + })?; - Ok(Some(Password { - id, - hashed_password, - version, - upgraded_from_id, - created_at, - })) + let upgraded_from_id = res.upgraded_from_id.map(Ulid::from); + let created_at = res.created_at; + let hashed_password = res.hashed_password; + + Ok(Some(Password { + id, + hashed_password, + version, + upgraded_from_id, + created_at, + })) + } + + #[tracing::instrument( + name = "db.user_password.add", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + user_password.id, + user_password.version = version, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + version: u16, + hashed_password: String, + upgraded_from: Option<&Password>, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_password.id", tracing::field::display(id)); + + let upgraded_from_id = upgraded_from.map(|p| p.id); + + sqlx::query!( + r#" + INSERT INTO user_passwords + (user_password_id, user_id, hashed_password, version, upgraded_from_id, created_at) + VALUES ($1, $2, $3, $4, $5, $6) + "#, + Uuid::from(id), + Uuid::from(user.id), + hashed_password, + i32::from(version), + upgraded_from_id.map(Uuid::from), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(Password { + id, + hashed_password, + version, + upgraded_from_id, + created_at, + }) + } } diff --git a/crates/storage/src/user/session.rs b/crates/storage/src/user/session.rs new file mode 100644 index 000000000..01102ca93 --- /dev/null +++ b/crates/storage/src/user/session.rs @@ -0,0 +1,425 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + pagination::{process_page, Page, QueryBuilderExt}, + tracing::ExecuteExt, + Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, +}; + +#[async_trait] +pub trait BrowserSessionRepository: Send + Sync { + type Error; + + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + ) -> Result; + async fn finish( + &mut self, + clock: &Clock, + user_session: BrowserSession, + ) -> Result; + async fn list_active_paginated( + &mut self, + user: &User, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error>; + async fn count_active(&mut self, user: &User) -> Result; + + async fn authenticate_with_password( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user_session: BrowserSession, + user_password: &Password, + ) -> Result; + + async fn authenticate_with_upstream( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user_session: BrowserSession, + upstream_oauth_link: &UpstreamOAuthLink, + ) -> Result; +} + +pub struct PgBrowserSessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgBrowserSessionRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct SessionLookup { + user_session_id: Uuid, + user_session_created_at: DateTime, + user_session_finished_at: Option>, + user_id: Uuid, + user_username: String, + user_primary_user_email_id: Option, + last_authentication_id: Option, + last_authd_at: Option>, +} + +impl TryInto for SessionLookup { + type Error = DatabaseInconsistencyError; + + fn try_into(self) -> Result { + let id = Ulid::from(self.user_id); + let user = User { + id, + username: self.user_username, + sub: id.to_string(), + primary_user_email_id: self.user_primary_user_email_id.map(Into::into), + }; + + let last_authentication = match (self.last_authentication_id, self.last_authd_at) { + (Some(id), Some(created_at)) => Some(Authentication { + id: id.into(), + created_at, + }), + (None, None) => None, + _ => { + return Err(DatabaseInconsistencyError::on( + "user_session_authentications", + )) + } + }; + + Ok(BrowserSession { + id: self.user_session_id.into(), + user, + created_at: self.user_session_created_at, + finished_at: self.user_session_finished_at, + last_authentication, + }) + } +} + +#[async_trait] +impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.browser_session.lookup", + skip_all, + fields( + db.statement, + user_session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + SessionLookup, + r#" + SELECT s.user_session_id + , s.created_at AS "user_session_created_at" + , s.finished_at AS "user_session_finished_at" + , u.user_id + , u.username AS "user_username" + , u.primary_user_email_id AS "user_primary_user_email_id" + , a.user_session_authentication_id AS "last_authentication_id?" + , a.created_at AS "last_authd_at?" + FROM user_sessions s + INNER JOIN users u + USING (user_id) + LEFT JOIN user_session_authentications a + USING (user_session_id) + WHERE s.user_session_id = $1 + ORDER BY a.created_at DESC + LIMIT 1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.browser_session.add", + skip_all, + fields( + db.statement, + %user.id, + user_session.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_session.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO user_sessions (user_session_id, user_id, created_at) + VALUES ($1, $2, $3) + "#, + Uuid::from(id), + Uuid::from(user.id), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + let session = BrowserSession { + id, + // XXX + user: user.clone(), + created_at, + finished_at: None, + last_authentication: None, + }; + + Ok(session) + } + + #[tracing::instrument( + name = "db.browser_session.finish", + skip_all, + fields( + db.statement, + %user_session.id, + ), + err, + )] + async fn finish( + &mut self, + clock: &Clock, + mut user_session: BrowserSession, + ) -> Result { + let finished_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE user_sessions + SET finished_at = $1 + WHERE user_session_id = $2 + "#, + finished_at, + Uuid::from(user_session.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + user_session.finished_at = Some(finished_at); + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(user_session) + } + + #[tracing::instrument( + name = "db.browser_session.list_active_paginated", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn list_active_paginated( + &mut self, + user: &User, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error> { + // TODO: ordering of last authentication is wrong + let mut query = QueryBuilder::new( + r#" + SELECT DISTINCT ON (s.user_session_id) + s.user_session_id, + u.user_id, + u.username, + s.created_at, + a.user_session_authentication_id AS "last_authentication_id", + a.created_at AS "last_authd_at", + FROM user_sessions s + INNER JOIN users u + USING (user_id) + LEFT JOIN user_session_authentications a + USING (user_session_id) + "#, + ); + + query + .push(" WHERE s.finished_at IS NULL AND s.user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("s.user_session_id", before, after, first, last)?; + + let page: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?; + + let edges: Result, _> = edges.into_iter().map(TryInto::try_into).collect(); + Ok(Page { + has_previous_page, + has_next_page, + edges: edges?, + }) + } + + #[tracing::instrument( + name = "db.browser_session.count_active", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn count_active(&mut self, user: &User) -> Result { + let res = sqlx::query_scalar!( + r#" + SELECT COUNT(*) as "count!" + FROM user_sessions s + WHERE s.user_id = $1 AND s.finished_at IS NULL + "#, + Uuid::from(user.id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + res.try_into().map_err(DatabaseError::to_invalid_operation) + } + + #[tracing::instrument( + name = "db.browser_session.authenticate_with_password", + skip_all, + fields( + db.statement, + %user_session.id, + %user_password.id, + user_session_authentication.id, + ), + err, + )] + async fn authenticate_with_password( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + mut user_session: BrowserSession, + user_password: &Password, + ) -> Result { + let _user_password = user_password; + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record( + "user_session_authentication.id", + tracing::field::display(id), + ); + + sqlx::query!( + r#" + INSERT INTO user_session_authentications + (user_session_authentication_id, user_session_id, created_at) + VALUES ($1, $2, $3) + "#, + Uuid::from(id), + Uuid::from(user_session.id), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + user_session.last_authentication = Some(Authentication { id, created_at }); + + Ok(user_session) + } + + #[tracing::instrument( + name = "db.browser_session.authenticate_with_upstream", + skip_all, + fields( + db.statement, + %user_session.id, + %upstream_oauth_link.id, + user_session_authentication.id, + ), + err, + )] + async fn authenticate_with_upstream( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + mut user_session: BrowserSession, + upstream_oauth_link: &UpstreamOAuthLink, + ) -> Result { + let _upstream_oauth_link = upstream_oauth_link; + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record( + "user_session_authentication.id", + tracing::field::display(id), + ); + + sqlx::query!( + r#" + INSERT INTO user_session_authentications + (user_session_authentication_id, user_session_id, created_at) + VALUES ($1, $2, $3) + "#, + Uuid::from(id), + Uuid::from(user_session.id), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + user_session.last_authentication = Some(Authentication { id, created_at }); + + Ok(user_session) + } +} diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index 8e8e0c8ae..cd70289b6 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -532,14 +532,14 @@ where { /// Context used by the `account/index.html` template #[derive(Serialize)] pub struct AccountContext { - active_sessions: i64, + active_sessions: usize, emails: Vec, } impl AccountContext { /// Constructs a context for the "my account" page #[must_use] - pub fn new(active_sessions: i64, emails: Vec) -> Self { + pub fn new(active_sessions: usize, emails: Vec) -> Self { Self { active_sessions, emails, From 3e312ef4696e1d574f3f8badda1ccebc38017a07 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 3 Jan 2023 16:43:18 +0100 Subject: [PATCH 07/45] Allow updating clients from the config without truncating them --- crates/cli/src/commands/manage.rs | 27 +++++++++-------- crates/storage/sqlx-data.json | 46 +++++++++++------------------ crates/storage/src/oauth2/client.rs | 30 ++++++++++--------- 3 files changed, 48 insertions(+), 55 deletions(-) diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 60b94bfe1..b6c0e4658 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -18,7 +18,7 @@ use mas_config::{DatabaseConfig, PasswordsConfig, RootConfig}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_router::UrlBuilder; use mas_storage::{ - oauth2::client::{insert_client_from_config, lookup_client, truncate_clients}, + oauth2::client::{insert_client_from_config, lookup_client}, upstream_oauth2::UpstreamOAuthProviderRepository, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, Clock, Repository, @@ -146,9 +146,9 @@ enum Subcommand { /// Import clients from config ImportClients { - /// Remove all clients before importing + /// Update existing clients #[arg(long)] - truncate: bool, + update: bool, }, /// Set a user password @@ -244,27 +244,28 @@ impl Options { Ok(()) } - SC::ImportClients { truncate } => { + SC::ImportClients { update } => { let config: RootConfig = root.load_config()?; let pool = database_from_config(&config.database).await?; let encrypter = config.secrets.encrypter(); let mut txn = pool.begin().await?; - if *truncate { - warn!("Removing all clients first"); - truncate_clients(&mut txn).await?; - } - for client in config.clients.iter() { let client_id = client.client_id; - let res = lookup_client(&mut txn, client_id).await?; - if res.is_some() { - warn!(%client_id, "Skipping already imported client"); + + let existing = lookup_client(&mut txn, client_id).await?.is_some(); + if !update && existing { + warn!(%client_id, "Skipping already imported client. Run with --update to update existing clients."); continue; } - info!(%client_id, "Importing client"); + if existing { + info!(%client_id, "Updating client"); + } else { + info!(%client_id, "Importing client"); + } + let client_secret = client.client_secret(); let client_auth_method = client.client_auth_method(); let jwks = client.jwks(); diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 1676c69a6..010367428 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -843,24 +843,6 @@ }, "query": "\n UPDATE oauth2_authorization_grants\n SET exchanged_at = $2\n WHERE oauth2_authorization_grant_id = $1\n " }, - "26a9391df9f1128673cdaf431fe8c5e4a83b576ddf7b02d92abfab6deadd4fa2": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Bool", - "Bool", - "Text", - "Jsonb", - "Text" - ] - } - }, - "query": "\n INSERT INTO oauth2_clients\n (oauth2_client_id,\n encrypted_client_secret,\n grant_type_authorization_code,\n grant_type_refresh_token,\n token_endpoint_auth_method,\n jwks,\n jwks_uri)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7)\n " - }, "2e756fe7be50128c0acc5f79df3a084230e9ca13cd45bd0858f97e59da20006e": { "describe": { "columns": [], @@ -2119,6 +2101,24 @@ }, "query": "\n UPDATE users\n SET primary_user_email_id = user_emails.user_email_id\n FROM user_emails\n WHERE user_emails.user_email_id = $1\n AND users.user_id = user_emails.user_id\n " }, + "c0b4996085f6f2127e1e8cfdf18b9029c22096fadfe6de59dce01c789791edb5": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Bool", + "Bool", + "Text", + "Jsonb", + "Text" + ] + } + }, + "query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , token_endpoint_auth_method\n , jwks\n , jwks_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7)\n ON CONFLICT (oauth2_client_id)\n DO\n UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret\n , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code\n , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token\n , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method\n , jwks = EXCLUDED.jwks\n , jwks_uri = EXCLUDED.jwks_uri\n " + }, "c1d90a7f2287ec779c81a521fab19e5ede3fa95484033e0312c30d9b6ecc03f0": { "describe": { "columns": [], @@ -2174,16 +2174,6 @@ }, "query": "\n INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)\n VALUES ($1, $2, $3, $4)\n " }, - "cb8ba981330e58a6c8580f6e394a721df110e1f2206e080434aa821c44c0164b": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [] - } - }, - "query": "TRUNCATE oauth2_client_redirect_uris, oauth2_clients CASCADE" - }, "cc9e30678d673546efca336ee8e550083eed71459611fa2db52264e51e175901": { "describe": { "columns": [], diff --git a/crates/storage/src/oauth2/client.rs b/crates/storage/src/oauth2/client.rs index 6138e377c..164b0e80b 100644 --- a/crates/storage/src/oauth2/client.rs +++ b/crates/storage/src/oauth2/client.rs @@ -481,15 +481,24 @@ pub async fn insert_client_from_config( sqlx::query!( r#" INSERT INTO oauth2_clients - (oauth2_client_id, - encrypted_client_secret, - grant_type_authorization_code, - grant_type_refresh_token, - token_endpoint_auth_method, - jwks, - jwks_uri) + ( oauth2_client_id + , encrypted_client_secret + , grant_type_authorization_code + , grant_type_refresh_token + , token_endpoint_auth_method + , jwks + , jwks_uri + ) VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (oauth2_client_id) + DO + UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret + , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code + , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token + , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method + , jwks = EXCLUDED.jwks + , jwks_uri = EXCLUDED.jwks_uri "#, Uuid::from(client_id), encrypted_client_secret, @@ -529,10 +538,3 @@ pub async fn insert_client_from_config( Ok(()) } - -pub async fn truncate_clients(executor: impl PgExecutor<'_>) -> Result<(), sqlx::Error> { - sqlx::query!("TRUNCATE oauth2_client_redirect_uris, oauth2_clients CASCADE") - .execute(executor) - .await?; - Ok(()) -} From 26b6023b337920b939c9b5cd887011417a7e6bd9 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 4 Jan 2023 11:28:08 +0100 Subject: [PATCH 08/45] Switch the `ulid` dependency override the main repo --- Cargo.lock | 2 +- Cargo.toml | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b5e1e65be..f299df15d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5577,7 +5577,7 @@ checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81" [[package]] name = "ulid" version = "1.0.0" -source = "git+https://github.com/sandhose/ulid-rs.git?rev=f1ef6fd736c4d3cbc7cf314fad707f0803de46ed#f1ef6fd736c4d3cbc7cf314fad707f0803de46ed" +source = "git+https://github.com/dylanhart/ulid-rs.git?rev=0b9295c2db2114cd87aa19abcc1fc00c16b272db#0b9295c2db2114cd87aa19abcc1fc00c16b272db" dependencies = [ "rand 0.8.5", "serde", diff --git a/Cargo.toml b/Cargo.toml index 9799f34b6..644ecbf4c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,8 +8,7 @@ opt-level = 3 [profile.dev.package.sqlx-macros] opt-level = 3 -# Until https://github.com/dylanhart/ulid-rs/pull/56 gets merged and released +# Until https://github.com/dylanhart/ulid-rs/pull/56 gets released [patch.crates-io.ulid] -git = "https://github.com/sandhose/ulid-rs.git" -#branch = "relax-sized-on-rng" -rev = "f1ef6fd736c4d3cbc7cf314fad707f0803de46ed" +git = "https://github.com/dylanhart/ulid-rs.git" +rev = "0b9295c2db2114cd87aa19abcc1fc00c16b272db" From 94b6d31fe9532a0f24df952dabbab5621b8aaaf2 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 4 Jan 2023 14:48:29 +0100 Subject: [PATCH 09/45] storage: OAuth2 client repository --- crates/axum-utils/src/client_authorization.rs | 11 +- crates/cli/src/commands/manage.rs | 29 +- crates/graphql/src/lib.rs | 3 +- crates/graphql/src/model/oauth.rs | 6 +- .../handlers/src/oauth2/authorization/mod.rs | 9 +- crates/handlers/src/oauth2/registration.rs | 72 +- crates/oauth2-types/src/registration/mod.rs | 5 + crates/storage/sqlx-data.json | 570 +++++++------- crates/storage/src/oauth2/access_token.rs | 8 +- .../storage/src/oauth2/authorization_grant.rs | 11 +- crates/storage/src/oauth2/client.rs | 705 +++++++++++------- crates/storage/src/oauth2/mod.rs | 4 +- crates/storage/src/oauth2/refresh_token.rs | 8 +- crates/storage/src/repository.rs | 16 + crates/storage/src/tracing.rs | 15 +- 15 files changed, 833 insertions(+), 639 deletions(-) diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index 00baf4a52..0e6771b4e 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -31,10 +31,10 @@ use mas_http::HttpServiceExt; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt}; use mas_keystore::Encrypter; -use mas_storage::{oauth2::client::lookup_client_by_client_id, DatabaseError}; +use mas_storage::{oauth2::client::OAuth2ClientRepository, DatabaseError, Repository}; use serde::{de::DeserializeOwned, Deserialize}; use serde_json::Value; -use sqlx::PgExecutor; +use sqlx::PgConnection; use thiserror::Error; use tower::{Service, ServiceExt}; @@ -73,10 +73,7 @@ pub enum Credentials { } impl Credentials { - pub async fn fetch( - &self, - executor: impl PgExecutor<'_>, - ) -> Result, DatabaseError> { + pub async fn fetch(&self, conn: &mut PgConnection) -> Result, DatabaseError> { let client_id = match self { Credentials::None { client_id } | Credentials::ClientSecretBasic { client_id, .. } @@ -84,7 +81,7 @@ impl Credentials { | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id, }; - lookup_client_by_client_id(executor, client_id).await + conn.oauth2_client().find_by_client_id(client_id).await } #[tracing::instrument(skip_all, err)] diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index b6c0e4658..9e2b39882 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -18,7 +18,7 @@ use mas_config::{DatabaseConfig, PasswordsConfig, RootConfig}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_router::UrlBuilder; use mas_storage::{ - oauth2::client::{insert_client_from_config, lookup_client}, + oauth2::client::OAuth2ClientRepository, upstream_oauth2::UpstreamOAuthProviderRepository, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, Clock, Repository, @@ -254,7 +254,7 @@ impl Options { for client in config.clients.iter() { let client_id = client.client_id; - let existing = lookup_client(&mut txn, client_id).await?.is_some(); + let existing = txn.oauth2_client().lookup(client_id).await?.is_some(); if !update && existing { warn!(%client_id, "Skipping already imported client. Run with --update to update existing clients."); continue; @@ -270,25 +270,24 @@ impl Options { let client_auth_method = client.client_auth_method(); let jwks = client.jwks(); let jwks_uri = client.jwks_uri(); - let redirect_uris = &client.redirect_uris; // TODO: should be moved somewhere else let encrypted_client_secret = client_secret .map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes())) .transpose()?; - insert_client_from_config( - &mut txn, - &mut rng, - &clock, - client_id, - client_auth_method, - encrypted_client_secret.as_deref(), - jwks, - jwks_uri, - redirect_uris, - ) - .await?; + txn.oauth2_client() + .add_from_config( + &mut rng, + &clock, + client_id, + client_auth_method, + encrypted_client_secret, + jwks.cloned(), + jwks_uri.cloned(), + client.redirect_uris.clone(), + ) + .await?; } txn.commit().await?; diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index b79b6fe9f..ffa63396b 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -31,6 +31,7 @@ use async_graphql::{ Context, Description, EmptyMutation, EmptySubscription, ID, }; use mas_storage::{ + oauth2::client::OAuth2ClientRepository, upstream_oauth2::UpstreamOAuthProviderRepository, user::{BrowserSessionRepository, UserEmailRepository}, Repository, UpstreamOAuthLinkRepository, @@ -95,7 +96,7 @@ impl RootQuery { let database = ctx.data::()?; let mut conn = database.acquire().await?; - let client = mas_storage::oauth2::client::lookup_client(&mut conn, id).await?; + let client = conn.oauth2_client().lookup(id).await?; Ok(client.map(OAuth2Client)) } diff --git a/crates/graphql/src/model/oauth.rs b/crates/graphql/src/model/oauth.rs index 89598ffa5..5f1236f2e 100644 --- a/crates/graphql/src/model/oauth.rs +++ b/crates/graphql/src/model/oauth.rs @@ -14,7 +14,7 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; -use mas_storage::oauth2::client::lookup_client; +use mas_storage::{oauth2::client::OAuth2ClientRepository, Repository}; use oauth2_types::scope::Scope; use sqlx::PgPool; use ulid::Ulid; @@ -115,7 +115,9 @@ impl OAuth2Consent { /// OAuth 2.0 client for which the user granted access. pub async fn client(&self, ctx: &Context<'_>) -> Result { let mut conn = ctx.data::()?.acquire().await?; - let client = lookup_client(&mut conn, self.client_id) + let client = conn + .oauth2_client() + .lookup(self.client_id) .await? .context("Could not load client")?; Ok(OAuth2Client(client)) diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 1b999ffc6..36d15d2bc 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -25,8 +25,9 @@ use mas_data_model::{AuthorizationCode, Pkce}; use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; -use mas_storage::oauth2::{ - authorization_grant::new_authorization_grant, client::lookup_client_by_client_id, +use mas_storage::{ + oauth2::{authorization_grant::new_authorization_grant, client::OAuth2ClientRepository}, + Repository, }; use mas_templates::Templates; use oauth2_types::{ @@ -141,7 +142,9 @@ pub(crate) async fn get( let mut txn = pool.begin().await?; // First, figure out what client it is - let client = lookup_client_by_client_id(&mut txn, ¶ms.auth.client_id) + let client = txn + .oauth2_client() + .find_by_client_id(¶ms.auth.client_id) .await? .ok_or(RouteError::ClientNotFound)?; diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index 25b734cfd..b12194eb5 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -19,7 +19,7 @@ use hyper::StatusCode; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_keystore::Encrypter; use mas_policy::{PolicyFactory, Violation}; -use mas_storage::oauth2::client::insert_client; +use mas_storage::{oauth2::client::OAuth2ClientRepository, Repository}; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, registration::{ @@ -30,7 +30,6 @@ use rand::distributions::{Alphanumeric, DistString}; use sqlx::PgPool; use thiserror::Error; use tracing::info; -use ulid::Ulid; use crate::impl_from_error_for_route; @@ -50,6 +49,7 @@ pub(crate) enum RouteError { } impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_storage::DatabaseError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); @@ -124,16 +124,9 @@ pub(crate) async fn post( return Err(RouteError::PolicyDenied(res.violations)); } - // Contacts was checked by the policy - let contacts = metadata.contacts.as_deref().unwrap_or_default(); - // Grab a txn let mut txn = pool.begin().await?; - let now = clock.now(); - // Let's generate a random client ID - let client_id = Ulid::from_datetime_with_source(now.into(), &mut rng); - let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method { Some( OAuthClientAuthenticationMethod::ClientSecretJwt @@ -148,41 +141,42 @@ pub(crate) async fn post( _ => (None, None), }; - insert_client( - &mut txn, - &mut rng, - &clock, - client_id, - metadata.redirect_uris(), - encrypted_client_secret.as_deref(), - //&metadata.response_types(), - metadata.grant_types(), - contacts, - metadata - .client_name - .as_ref() - .map(|l| l.non_localized().as_ref()), - metadata.logo_uri.as_ref().map(Localized::non_localized), - metadata.client_uri.as_ref().map(Localized::non_localized), - metadata.policy_uri.as_ref().map(Localized::non_localized), - metadata.tos_uri.as_ref().map(Localized::non_localized), - metadata.jwks_uri.as_ref(), - metadata.jwks.as_ref(), - // XXX: those might not be right, should be function calls - metadata.id_token_signed_response_alg.as_ref(), - metadata.userinfo_signed_response_alg.as_ref(), - metadata.token_endpoint_auth_method.as_ref(), - metadata.token_endpoint_auth_signing_alg.as_ref(), - metadata.initiate_login_uri.as_ref(), - ) - .await?; + let client = txn + .oauth2_client() + .add( + &mut rng, + &clock, + metadata.redirect_uris().to_vec(), + encrypted_client_secret, + //&metadata.response_types(), + metadata.grant_types().to_vec(), + metadata.contacts.clone().unwrap_or_default(), + metadata + .client_name + .clone() + .map(Localized::to_non_localized), + metadata.logo_uri.clone().map(Localized::to_non_localized), + metadata.client_uri.clone().map(Localized::to_non_localized), + metadata.policy_uri.clone().map(Localized::to_non_localized), + metadata.tos_uri.clone().map(Localized::to_non_localized), + metadata.jwks_uri.clone(), + metadata.jwks.clone(), + // XXX: those might not be right, should be function calls + metadata.id_token_signed_response_alg.clone(), + metadata.userinfo_signed_response_alg.clone(), + metadata.token_endpoint_auth_method.clone(), + metadata.token_endpoint_auth_signing_alg.clone(), + metadata.initiate_login_uri.clone(), + ) + .await?; txn.commit().await?; let response = ClientRegistrationResponse { - client_id: client_id.to_string(), + client_id: client.client_id, client_secret, - client_id_issued_at: Some(now), + // XXX: we should have a `created_at` field on the clients + client_id_issued_at: Some(client.id.datetime().into()), client_secret_expires_at: None, }; diff --git a/crates/oauth2-types/src/registration/mod.rs b/crates/oauth2-types/src/registration/mod.rs index 18aa24fa0..0d9589966 100644 --- a/crates/oauth2-types/src/registration/mod.rs +++ b/crates/oauth2-types/src/registration/mod.rs @@ -90,6 +90,11 @@ impl Localized { &self.non_localized } + /// Get the non-localized variant. + pub fn to_non_localized(self) -> T { + self.non_localized + } + /// Get the variant corresponding to the given language, if it exists. pub fn get(&self, language: Option<&LanguageTag>) -> Option<&T> { match language { diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 010367428..feced0ba8 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -98,122 +98,6 @@ }, "query": "\n SELECT at.oauth2_access_token_id\n , at.access_token AS \"oauth2_access_token\"\n , at.created_at AS \"oauth2_access_token_created_at\"\n , at.expires_at AS \"oauth2_access_token_expires_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"scope!\"\n , us.user_session_id AS \"user_session_id!\"\n , us.created_at AS \"user_session_created_at!\"\n , u.user_id AS \"user_id!\"\n , u.username AS \"user_username!\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n , usa.user_session_authentication_id AS \"user_session_last_authentication_id?\"\n , usa.created_at AS \"user_session_last_authentication_created_at?\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n INNER JOIN user_sessions us\n USING (user_session_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n\n WHERE at.access_token = $1\n AND at.revoked_at IS NULL\n AND os.finished_at IS NULL\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " }, - "05b50b7ae0109063c50fe70e83635a31920e44a7fbaa2b4f07552ba2f83a28d7": { - "describe": { - "columns": [ - { - "name": "oauth2_client_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "encrypted_client_secret", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "redirect_uris!", - "ordinal": 2, - "type_info": "TextArray" - }, - { - "name": "grant_type_authorization_code", - "ordinal": 3, - "type_info": "Bool" - }, - { - "name": "grant_type_refresh_token", - "ordinal": 4, - "type_info": "Bool" - }, - { - "name": "client_name", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "logo_uri", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "client_uri", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "policy_uri", - "ordinal": 8, - "type_info": "Text" - }, - { - "name": "tos_uri", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "jwks_uri", - "ordinal": 10, - "type_info": "Text" - }, - { - "name": "jwks", - "ordinal": 11, - "type_info": "Jsonb" - }, - { - "name": "id_token_signed_response_alg", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "userinfo_signed_response_alg", - "ordinal": 13, - "type_info": "Text" - }, - { - "name": "token_endpoint_auth_method", - "ordinal": 14, - "type_info": "Text" - }, - { - "name": "token_endpoint_auth_signing_alg", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "initiate_login_uri", - "ordinal": 16, - "type_info": "Text" - } - ], - "nullable": [ - false, - true, - null, - false, - false, - true, - true, - true, - true, - true, - true, - true, - true, - true, - true, - true, - true - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n c.oauth2_client_id,\n c.encrypted_client_secret,\n ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\",\n c.grant_type_authorization_code,\n c.grant_type_refresh_token,\n c.client_name,\n c.logo_uri,\n c.client_uri,\n c.policy_uri,\n c.tos_uri,\n c.jwks_uri,\n c.jwks,\n c.id_token_signed_response_alg,\n c.userinfo_signed_response_alg,\n c.token_endpoint_auth_method,\n c.token_endpoint_auth_signing_alg,\n c.initiate_login_uri\n FROM oauth2_clients c\n\n WHERE c.oauth2_client_id = $1\n " - }, "08d7df347c806ef14b6d0fb031cab041d79ba48528420160e23286369db7af35": { "describe": { "columns": [ @@ -1046,20 +930,6 @@ }, "query": "\n SELECT up.user_password_id\n , up.hashed_password\n , up.version\n , up.upgraded_from_id\n , up.created_at\n FROM user_passwords up\n WHERE up.user_id = $1\n ORDER BY up.created_at DESC\n LIMIT 1\n " }, - "4693f2b9b3d51ff4a05e233b6667161ebc97f331d96bf5f1c61069e1c8492105": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "UuidArray", - "Uuid", - "TextArray" - ] - } - }, - "query": "\n INSERT INTO oauth2_client_redirect_uris\n (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)\n SELECT id, $2, redirect_uri\n FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)\n " - }, "46c5ae7052504bfd7b94f20e61b9cf92570779a794bccda23dd654fb8523f340": { "describe": { "columns": [ @@ -1432,7 +1302,147 @@ }, "query": "\n SELECT COUNT(*) as \"count!\"\n FROM user_sessions s\n WHERE s.user_id = $1 AND s.finished_at IS NULL\n " }, - "7756a60c36a64a259f7450d6eb77ee92303638ca374a63f23ac4944ccf9f4436": { + "79295f3d3a75f831e9469aabfa720d381a254d00dbe39fef1e9652029d51b89b": { + "describe": { + "columns": [ + { + "name": "user_session_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "user_session_created_at", + "ordinal": 1, + "type_info": "Timestamptz" + }, + { + "name": "user_session_finished_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "user_id", + "ordinal": 3, + "type_info": "Uuid" + }, + { + "name": "user_username", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "user_primary_user_email_id", + "ordinal": 5, + "type_info": "Uuid" + }, + { + "name": "last_authentication_id?", + "ordinal": 6, + "type_info": "Uuid" + }, + { + "name": "last_authd_at?", + "ordinal": 7, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT s.user_session_id\n , s.created_at AS \"user_session_created_at\"\n , s.finished_at AS \"user_session_finished_at\"\n , u.user_id\n , u.username AS \"user_username\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n , a.user_session_authentication_id AS \"last_authentication_id?\"\n , a.created_at AS \"last_authd_at?\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications a\n USING (user_session_id)\n WHERE s.user_session_id = $1\n ORDER BY a.created_at DESC\n LIMIT 1\n " + }, + "7be139553610ace03193a99fe27fcb4e3d50c90accdaf22ca1cfeefdc9734300": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "UuidArray", + "Uuid", + "TextArray" + ] + } + }, + "query": "\n INSERT INTO oauth2_client_redirect_uris\n (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)\n SELECT id, $2, redirect_uri\n FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)\n " + }, + "7ce387b1b0aaf10e72adde667b19521b66eaafa51f73bf2f95e38b8f3b64a229": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid" + ] + } + }, + "query": "\n UPDATE upstream_oauth_links\n SET user_id = $1\n WHERE upstream_oauth_link_id = $2\n " + }, + "7e3247e35ecf5335f0656c53bcde27264a9efb8dccb6246344950614f487dcaf": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE compat_access_tokens\n SET expires_at = $2\n WHERE compat_access_token_id = $1\n " + }, + "836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c": { + "describe": { + "columns": [ + { + "name": "user_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "username", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "primary_user_email_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n FROM users\n WHERE username = $1\n " + }, + "85499663f1adc7b7439592063f06914089f6243126a177b365bde37db5f6b33d": { "describe": { "columns": [ { @@ -1546,133 +1556,7 @@ ] } }, - "query": "\n SELECT\n c.oauth2_client_id,\n c.encrypted_client_secret,\n ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\",\n c.grant_type_authorization_code,\n c.grant_type_refresh_token,\n c.client_name,\n c.logo_uri,\n c.client_uri,\n c.policy_uri,\n c.tos_uri,\n c.jwks_uri,\n c.jwks,\n c.id_token_signed_response_alg,\n c.userinfo_signed_response_alg,\n c.token_endpoint_auth_method,\n c.token_endpoint_auth_signing_alg,\n c.initiate_login_uri\n FROM oauth2_clients c\n\n WHERE c.oauth2_client_id = ANY($1::uuid[])\n " - }, - "79295f3d3a75f831e9469aabfa720d381a254d00dbe39fef1e9652029d51b89b": { - "describe": { - "columns": [ - { - "name": "user_session_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "user_session_created_at", - "ordinal": 1, - "type_info": "Timestamptz" - }, - { - "name": "user_session_finished_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "user_id", - "ordinal": 3, - "type_info": "Uuid" - }, - { - "name": "user_username", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "user_primary_user_email_id", - "ordinal": 5, - "type_info": "Uuid" - }, - { - "name": "last_authentication_id?", - "ordinal": 6, - "type_info": "Uuid" - }, - { - "name": "last_authd_at?", - "ordinal": 7, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - false, - false, - true, - false, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT s.user_session_id\n , s.created_at AS \"user_session_created_at\"\n , s.finished_at AS \"user_session_finished_at\"\n , u.user_id\n , u.username AS \"user_username\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n , a.user_session_authentication_id AS \"last_authentication_id?\"\n , a.created_at AS \"last_authd_at?\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications a\n USING (user_session_id)\n WHERE s.user_session_id = $1\n ORDER BY a.created_at DESC\n LIMIT 1\n " - }, - "7ce387b1b0aaf10e72adde667b19521b66eaafa51f73bf2f95e38b8f3b64a229": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid" - ] - } - }, - "query": "\n UPDATE upstream_oauth_links\n SET user_id = $1\n WHERE upstream_oauth_link_id = $2\n " - }, - "7e3247e35ecf5335f0656c53bcde27264a9efb8dccb6246344950614f487dcaf": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE compat_access_tokens\n SET expires_at = $2\n WHERE compat_access_token_id = $1\n " - }, - "836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c": { - "describe": { - "columns": [ - { - "name": "user_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "username", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "primary_user_email_id", - "ordinal": 2, - "type_info": "Uuid" - }, - { - "name": "created_at", - "ordinal": 3, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n FROM users\n WHERE username = $1\n " + "query": "\n SELECT oauth2_client_id\n , encrypted_client_secret\n , ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\"\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n FROM oauth2_clients c\n\n WHERE oauth2_client_id = ANY($1::uuid[])\n " }, "874e677f82c221c5bb621c12f293bcef4e70c68c87ec003fcd475bcb994b5a4c": { "describe": { @@ -2174,33 +2058,6 @@ }, "query": "\n INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)\n VALUES ($1, $2, $3, $4)\n " }, - "cc9e30678d673546efca336ee8e550083eed71459611fa2db52264e51e175901": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Bool", - "Bool", - "Text", - "Text", - "Text", - "Text", - "Text", - "Text", - "Jsonb", - "Text", - "Text", - "Text", - "Text", - "Text" - ] - } - }, - "query": "\n INSERT INTO oauth2_clients\n (oauth2_client_id,\n encrypted_client_secret,\n grant_type_authorization_code,\n grant_type_refresh_token,\n client_name,\n logo_uri,\n client_uri,\n policy_uri,\n tos_uri,\n jwks_uri,\n jwks,\n id_token_signed_response_alg,\n userinfo_signed_response_alg,\n token_endpoint_auth_method,\n token_endpoint_auth_signing_alg,\n initiate_login_uri)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)\n " - }, "d023d7346ec1f32da9459db3c39dffd8a4e3d4e91cdf096928de4517d3f8c622": { "describe": { "columns": [ @@ -2368,6 +2225,122 @@ }, "query": "\n INSERT INTO oauth2_refresh_tokens\n (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,\n refresh_token, created_at)\n VALUES\n ($1, $2, $3, $4, $5)\n " }, + "db90cbc406a399f5447bd2c1d8018464f83b927dec620353516c0285b76fcf24": { + "describe": { + "columns": [ + { + "name": "oauth2_client_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "encrypted_client_secret", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "redirect_uris!", + "ordinal": 2, + "type_info": "TextArray" + }, + { + "name": "grant_type_authorization_code", + "ordinal": 3, + "type_info": "Bool" + }, + { + "name": "grant_type_refresh_token", + "ordinal": 4, + "type_info": "Bool" + }, + { + "name": "client_name", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "logo_uri", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "client_uri", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "policy_uri", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "tos_uri", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "jwks_uri", + "ordinal": 10, + "type_info": "Text" + }, + { + "name": "jwks", + "ordinal": 11, + "type_info": "Jsonb" + }, + { + "name": "id_token_signed_response_alg", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "userinfo_signed_response_alg", + "ordinal": 13, + "type_info": "Text" + }, + { + "name": "token_endpoint_auth_method", + "ordinal": 14, + "type_info": "Text" + }, + { + "name": "token_endpoint_auth_signing_alg", + "ordinal": 15, + "type_info": "Text" + }, + { + "name": "initiate_login_uri", + "ordinal": 16, + "type_info": "Text" + } + ], + "nullable": [ + false, + true, + null, + false, + false, + true, + true, + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT oauth2_client_id\n , encrypted_client_secret\n , ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\"\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n FROM oauth2_clients c\n\n WHERE oauth2_client_id = $1\n " + }, "dbf4be84eeff9ea51b00185faae2d453ab449017ed492bf6711dc7fceb630880": { "describe": { "columns": [], @@ -2426,6 +2399,33 @@ }, "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " }, + "f5edcd4c306ca8179cdf9d4aab59fbba971b54611c91345849920954dd8089b3": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Bool", + "Bool", + "Text", + "Text", + "Text", + "Text", + "Text", + "Text", + "Jsonb", + "Text", + "Text", + "Text", + "Text", + "Text" + ] + } + }, + "query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)\n " + }, "f624e1bdbff4e97b300362d1bbd86035e4a0fdd8ffe16c3bfb9bc451ba60851b": { "describe": { "columns": [ diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index 5c2347d2e..e41f4812f 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -19,8 +19,8 @@ use sqlx::{PgConnection, PgExecutor}; use ulid::Ulid; use uuid::Uuid; -use super::client::lookup_client; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError}; +use super::client::OAuth2ClientRepository; +use crate::{Clock, DatabaseError, DatabaseInconsistencyError, Repository}; #[tracing::instrument( skip_all, @@ -144,7 +144,9 @@ pub async fn lookup_active_access_token( }; let session_id = res.oauth2_session_id.into(); - let client = lookup_client(&mut *conn, res.oauth2_client_id.into()) + let client = conn + .oauth2_client() + .lookup(res.oauth2_client_id.into()) .await? .ok_or_else(|| { DatabaseInconsistencyError::on("oauth2_sessions") diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index b7ffb30de..bfd918600 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -27,8 +27,8 @@ use ulid::Ulid; use url::Url; use uuid::Uuid; -use super::client::lookup_client; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; +use super::client::OAuth2ClientRepository; +use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Repository}; #[tracing::instrument( skip_all, @@ -163,7 +163,7 @@ impl GrantLookup { #[allow(clippy::too_many_lines)] async fn into_authorization_grant( self, - executor: impl PgExecutor<'_>, + conn: &mut PgConnection, ) -> Result { let id = self.oauth2_authorization_grant_id.into(); let scope: Scope = self.oauth2_authorization_grant_scope.parse().map_err(|e| { @@ -173,8 +173,9 @@ impl GrantLookup { .source(e) })?; - // TODO: don't unwrap - let client = lookup_client(executor, self.oauth2_client_id.into()) + let client = conn + .oauth2_client() + .lookup(self.oauth2_client_id.into()) .await? .ok_or_else(|| { DatabaseInconsistencyError::on("oauth2_authorization_grants") diff --git a/crates/storage/src/oauth2/client.rs b/crates/storage/src/oauth2/client.rs index 164b0e80b..afe789db4 100644 --- a/crates/storage/src/oauth2/client.rs +++ b/crates/storage/src/oauth2/client.rs @@ -12,8 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::HashMap, string::ToString}; +use std::{ + collections::{BTreeMap, BTreeSet}, + string::ToString, +}; +use async_trait::async_trait; use mas_data_model::{Client, JwksOrJwksUri}; use mas_iana::{ jose::JsonWebSignatureAlg, @@ -21,17 +25,83 @@ use mas_iana::{ }; use mas_jose::jwk::PublicJsonWebKeySet; use oauth2_types::requests::GrantType; -use rand::Rng; -use sqlx::{PgConnection, PgExecutor}; +use rand::{Rng, RngCore}; +use sqlx::PgConnection; +use tracing::{info_span, Instrument}; use ulid::Ulid; use url::Url; use uuid::Uuid; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; +use crate::{ + tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, +}; + +#[async_trait] +pub trait OAuth2ClientRepository: Send + Sync { + type Error; + + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_client_id(&mut self, client_id: &str) -> Result, Self::Error> { + let Ok(id) = client_id.parse() else { return Ok(None) }; + self.lookup(id).await + } + + async fn load_batch( + &mut self, + ids: BTreeSet, + ) -> Result, Self::Error>; + + #[allow(clippy::too_many_arguments)] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + redirect_uris: Vec, + encrypted_client_secret: Option, + grant_types: Vec, + contacts: Vec, + client_name: Option, + logo_uri: Option, + client_uri: Option, + policy_uri: Option, + tos_uri: Option, + jwks_uri: Option, + jwks: Option, + id_token_signed_response_alg: Option, + userinfo_signed_response_alg: Option, + token_endpoint_auth_method: Option, + token_endpoint_auth_signing_alg: Option, + initiate_login_uri: Option, + ) -> Result; + + #[allow(clippy::too_many_arguments)] + async fn add_from_config( + &mut self, + mut rng: impl Rng + Send, + clock: &Clock, + client_id: Ulid, + client_auth_method: OAuthClientAuthenticationMethod, + encrypted_client_secret: Option, + jwks: Option, + jwks_uri: Option, + redirect_uris: Vec, + ) -> Result; +} + +pub struct PgOAuth2ClientRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2ClientRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} // XXX: response_types & contacts #[derive(Debug)] -pub struct OAuth2ClientLookup { +struct OAuth2ClientLookup { oauth2_client_id: Uuid, encrypted_client_secret: Option, redirect_uris: Vec, @@ -234,252 +304,305 @@ impl TryInto for OAuth2ClientLookup { } } -#[tracing::instrument(skip_all, err)] -pub async fn lookup_clients( - executor: impl PgExecutor<'_>, - ids: impl IntoIterator + Send, -) -> Result, DatabaseError> { - let ids: Vec = ids.into_iter().map(Uuid::from).collect(); - let res = sqlx::query_as!( - OAuth2ClientLookup, - r#" - SELECT - c.oauth2_client_id, - c.encrypted_client_secret, - ARRAY( - SELECT redirect_uri - FROM oauth2_client_redirect_uris r - WHERE r.oauth2_client_id = c.oauth2_client_id - ) AS "redirect_uris!", - c.grant_type_authorization_code, - c.grant_type_refresh_token, - c.client_name, - c.logo_uri, - c.client_uri, - c.policy_uri, - c.tos_uri, - c.jwks_uri, - c.jwks, - c.id_token_signed_response_alg, - c.userinfo_signed_response_alg, - c.token_endpoint_auth_method, - c.token_endpoint_auth_signing_alg, - c.initiate_login_uri - FROM oauth2_clients c +#[async_trait] +impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { + type Error = DatabaseError; - WHERE c.oauth2_client_id = ANY($1::uuid[]) - "#, - &ids, - ) - .fetch_all(executor) - .await?; + #[tracing::instrument( + name = "db.oauth2_client.lookup", + skip_all, + fields( + db.statement, + oauth2_client.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2ClientLookup, + r#" + SELECT oauth2_client_id + , encrypted_client_secret + , ARRAY( + SELECT redirect_uri + FROM oauth2_client_redirect_uris r + WHERE r.oauth2_client_id = c.oauth2_client_id + ) AS "redirect_uris!" + , grant_type_authorization_code + , grant_type_refresh_token + , client_name + , logo_uri + , client_uri + , policy_uri + , tos_uri + , jwks_uri + , jwks + , id_token_signed_response_alg + , userinfo_signed_response_alg + , token_endpoint_auth_method + , token_endpoint_auth_signing_alg + , initiate_login_uri + FROM oauth2_clients c - res.into_iter() - .map(|r| { - r.try_into() - .map(|c: Client| (c.id, c)) - .map_err(DatabaseError::from) - }) - .collect() -} + WHERE oauth2_client_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; -#[tracing::instrument( - skip_all, - fields(client.id = %id), - err, -)] -pub async fn lookup_client( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - OAuth2ClientLookup, - r#" - SELECT - c.oauth2_client_id, - c.encrypted_client_secret, - ARRAY( - SELECT redirect_uri - FROM oauth2_client_redirect_uris r - WHERE r.oauth2_client_id = c.oauth2_client_id - ) AS "redirect_uris!", - c.grant_type_authorization_code, - c.grant_type_refresh_token, - c.client_name, - c.logo_uri, - c.client_uri, - c.policy_uri, - c.tos_uri, - c.jwks_uri, - c.jwks, - c.id_token_signed_response_alg, - c.userinfo_signed_response_alg, - c.token_endpoint_auth_method, - c.token_endpoint_auth_signing_alg, - c.initiate_login_uri - FROM oauth2_clients c + let Some(res) = res else { return Ok(None) }; - WHERE c.oauth2_client_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .await - .to_option()?; + Ok(Some(res.try_into()?)) + } - let Some(res) = res else { return Ok(None) }; + #[tracing::instrument( + name = "db.oauth2_client.load_batch", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn load_batch( + &mut self, + ids: BTreeSet, + ) -> Result, Self::Error> { + let ids: Vec = ids.into_iter().map(Uuid::from).collect(); + let res = sqlx::query_as!( + OAuth2ClientLookup, + r#" + SELECT oauth2_client_id + , encrypted_client_secret + , ARRAY( + SELECT redirect_uri + FROM oauth2_client_redirect_uris r + WHERE r.oauth2_client_id = c.oauth2_client_id + ) AS "redirect_uris!" + , grant_type_authorization_code + , grant_type_refresh_token + , client_name + , logo_uri + , client_uri + , policy_uri + , tos_uri + , jwks_uri + , jwks + , id_token_signed_response_alg + , userinfo_signed_response_alg + , token_endpoint_auth_method + , token_endpoint_auth_signing_alg + , initiate_login_uri + FROM oauth2_clients c - Ok(Some(res.try_into()?)) -} + WHERE oauth2_client_id = ANY($1::uuid[]) + "#, + &ids, + ) + .traced() + .fetch_all(&mut *self.conn) + .await?; -#[tracing::instrument( - skip_all, - fields(client.id = client_id), - err, -)] -pub async fn lookup_client_by_client_id( - executor: impl PgExecutor<'_>, - client_id: &str, -) -> Result, DatabaseError> { - let Ok(id) = client_id.parse() else { return Ok(None) }; - lookup_client(executor, id).await -} + res.into_iter() + .map(|r| { + r.try_into() + .map(|c: Client| (c.id, c)) + .map_err(DatabaseError::from) + }) + .collect() + } -#[tracing::instrument( - skip_all, - fields(client.id = %client_id, client.name = client_name), - err, -)] -#[allow(clippy::too_many_arguments)] -pub async fn insert_client( - conn: &mut PgConnection, - mut rng: impl Rng + Send, - clock: &Clock, - client_id: Ulid, - redirect_uris: &[Url], - encrypted_client_secret: Option<&str>, - grant_types: &[GrantType], - _contacts: &[String], - client_name: Option<&str>, - logo_uri: Option<&Url>, - client_uri: Option<&Url>, - policy_uri: Option<&Url>, - tos_uri: Option<&Url>, - jwks_uri: Option<&Url>, - jwks: Option<&PublicJsonWebKeySet>, - id_token_signed_response_alg: Option<&JsonWebSignatureAlg>, - userinfo_signed_response_alg: Option<&JsonWebSignatureAlg>, - token_endpoint_auth_method: Option<&OAuthClientAuthenticationMethod>, - token_endpoint_auth_signing_alg: Option<&JsonWebSignatureAlg>, - initiate_login_uri: Option<&Url>, -) -> Result<(), sqlx::Error> { - let grant_type_authorization_code = grant_types.contains(&GrantType::AuthorizationCode); - let grant_type_refresh_token = grant_types.contains(&GrantType::RefreshToken); - let logo_uri = logo_uri.map(Url::as_str); - let client_uri = client_uri.map(Url::as_str); - let policy_uri = policy_uri.map(Url::as_str); - let tos_uri = tos_uri.map(Url::as_str); - let jwks = jwks.map(serde_json::to_value).transpose().unwrap(); // TODO - let jwks_uri = jwks_uri.map(Url::as_str); - let id_token_signed_response_alg = id_token_signed_response_alg.map(ToString::to_string); - let userinfo_signed_response_alg = userinfo_signed_response_alg.map(ToString::to_string); - let token_endpoint_auth_method = token_endpoint_auth_method.map(ToString::to_string); - let token_endpoint_auth_signing_alg = token_endpoint_auth_signing_alg.map(ToString::to_string); - let initiate_login_uri = initiate_login_uri.map(Url::as_str); + #[tracing::instrument( + name = "db.oauth2_client.add", + skip_all, + fields( + db.statement, + client.id, + client.name = client_name + ), + err, + )] + #[allow(clippy::too_many_lines)] + async fn add( + &mut self, + mut rng: &mut (dyn RngCore + Send), + clock: &Clock, + redirect_uris: Vec, + encrypted_client_secret: Option, + grant_types: Vec, + contacts: Vec, + client_name: Option, + logo_uri: Option, + client_uri: Option, + policy_uri: Option, + tos_uri: Option, + jwks_uri: Option, + jwks: Option, + id_token_signed_response_alg: Option, + userinfo_signed_response_alg: Option, + token_endpoint_auth_method: Option, + token_endpoint_auth_signing_alg: Option, + initiate_login_uri: Option, + ) -> Result { + let now = clock.now(); + let id = Ulid::from_datetime_with_source(now.into(), rng); + tracing::Span::current().record("client.id", tracing::field::display(id)); - sqlx::query!( - r#" - INSERT INTO oauth2_clients - (oauth2_client_id, - encrypted_client_secret, - grant_type_authorization_code, - grant_type_refresh_token, - client_name, - logo_uri, - client_uri, - policy_uri, - tos_uri, - jwks_uri, - jwks, - id_token_signed_response_alg, - userinfo_signed_response_alg, - token_endpoint_auth_method, - token_endpoint_auth_signing_alg, - initiate_login_uri) - VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) - "#, - Uuid::from(client_id), - encrypted_client_secret, - grant_type_authorization_code, - grant_type_refresh_token, - client_name, - logo_uri, - client_uri, - policy_uri, - tos_uri, - jwks_uri, - jwks, - id_token_signed_response_alg, - userinfo_signed_response_alg, - token_endpoint_auth_method, - token_endpoint_auth_signing_alg, - initiate_login_uri, - ) - .execute(&mut *conn) - .await?; + let jwks_json = jwks + .as_ref() + .map(serde_json::to_value) + .transpose() + .map_err(DatabaseError::to_invalid_operation)?; - let now = clock.now(); - let (ids, redirect_uris): (Vec, Vec) = redirect_uris - .iter() - .map(|uri| { - ( - Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), - uri.as_str().to_owned(), + sqlx::query!( + r#" + INSERT INTO oauth2_clients + ( oauth2_client_id + , encrypted_client_secret + , grant_type_authorization_code + , grant_type_refresh_token + , client_name + , logo_uri + , client_uri + , policy_uri + , tos_uri + , jwks_uri + , jwks + , id_token_signed_response_alg + , userinfo_signed_response_alg + , token_endpoint_auth_method + , token_endpoint_auth_signing_alg + , initiate_login_uri + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + "#, + Uuid::from(id), + encrypted_client_secret, + grant_types.contains(&GrantType::AuthorizationCode), + grant_types.contains(&GrantType::RefreshToken), + client_name, + logo_uri.as_ref().map(Url::as_str), + client_uri.as_ref().map(Url::as_str), + policy_uri.as_ref().map(Url::as_str), + tos_uri.as_ref().map(Url::as_str), + jwks_uri.as_ref().map(Url::as_str), + jwks_json, + id_token_signed_response_alg + .as_ref() + .map(ToString::to_string), + userinfo_signed_response_alg + .as_ref() + .map(ToString::to_string), + token_endpoint_auth_method.as_ref().map(ToString::to_string), + token_endpoint_auth_signing_alg + .as_ref() + .map(ToString::to_string), + initiate_login_uri.as_ref().map(Url::as_str), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + { + let span = info_span!( + "db.oauth2_client.add.redirect_uris", + db.statement = tracing::field::Empty, + client.id = %id, + ); + + let (uri_ids, redirect_uris): (Vec, Vec) = redirect_uris + .iter() + .map(|uri| { + ( + Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), + uri.as_str().to_owned(), + ) + }) + .unzip(); + + sqlx::query!( + r#" + INSERT INTO oauth2_client_redirect_uris + (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) + SELECT id, $2, redirect_uri + FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) + "#, + &uri_ids, + Uuid::from(id), + &redirect_uris, ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + let jwks = match (jwks, jwks_uri) { + (None, None) => None, + (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)), + (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)), + _ => return Err(DatabaseError::invalid_operation()), + }; + + Ok(Client { + id, + client_id: id.to_string(), + encrypted_client_secret, + redirect_uris, + response_types: vec![ + OAuthAuthorizationEndpointResponseType::Code, + OAuthAuthorizationEndpointResponseType::IdToken, + OAuthAuthorizationEndpointResponseType::None, + ], + grant_types, + contacts, + client_name, + logo_uri, + client_uri, + policy_uri, + tos_uri, + jwks, + id_token_signed_response_alg, + userinfo_signed_response_alg, + token_endpoint_auth_method, + token_endpoint_auth_signing_alg, + initiate_login_uri, }) - .unzip(); + } - sqlx::query!( - r#" - INSERT INTO oauth2_client_redirect_uris - (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) - SELECT id, $2, redirect_uri - FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) - "#, - &ids, - Uuid::from(client_id), - &redirect_uris, - ) - .execute(&mut *conn) - .await?; + #[tracing::instrument( + name = "db.oauth2_client.add_from_config", + skip_all, + fields( + db.statement, + client.id = %client_id, + ), + err, + )] + async fn add_from_config( + &mut self, + mut rng: impl Rng + Send, + clock: &Clock, + client_id: Ulid, + client_auth_method: OAuthClientAuthenticationMethod, + encrypted_client_secret: Option, + jwks: Option, + jwks_uri: Option, + redirect_uris: Vec, + ) -> Result { + let jwks_json = jwks + .as_ref() + .map(serde_json::to_value) + .transpose() + .map_err(DatabaseError::to_invalid_operation)?; - Ok(()) -} + let client_auth_method = client_auth_method.to_string(); -#[allow(clippy::too_many_arguments)] -pub async fn insert_client_from_config( - conn: &mut PgConnection, - mut rng: impl Rng + Send, - clock: &Clock, - client_id: Ulid, - client_auth_method: OAuthClientAuthenticationMethod, - encrypted_client_secret: Option<&str>, - jwks: Option<&PublicJsonWebKeySet>, - jwks_uri: Option<&Url>, - redirect_uris: &[Url], -) -> Result<(), DatabaseError> { - let jwks = jwks - .map(serde_json::to_value) - .transpose() - .map_err(DatabaseError::to_invalid_operation)?; - - let jwks_uri = jwks_uri.map(Url::as_str); - - let client_auth_method = client_auth_method.to_string(); - - sqlx::query!( - r#" + sqlx::query!( + r#" INSERT INTO oauth2_clients ( oauth2_client_id , encrypted_client_secret @@ -500,41 +623,83 @@ pub async fn insert_client_from_config( , jwks = EXCLUDED.jwks , jwks_uri = EXCLUDED.jwks_uri "#, - Uuid::from(client_id), - encrypted_client_secret, - true, - true, - client_auth_method, - jwks, - jwks_uri, - ) - .execute(&mut *conn) - .await?; + Uuid::from(client_id), + encrypted_client_secret, + true, + true, + client_auth_method, + jwks_json, + jwks_uri.as_ref().map(Url::as_str), + ) + .traced() + .execute(&mut *self.conn) + .await?; - let now = clock.now(); - let (ids, redirect_uris): (Vec, Vec) = redirect_uris - .iter() - .map(|uri| { - ( - Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), - uri.as_str().to_owned(), + { + let span = info_span!( + "db.oauth2_client.add_from_config.redirect_uris", + client.id = %client_id, + db.statement = tracing::field::Empty, + ); + + let now = clock.now(); + let (ids, redirect_uris): (Vec, Vec) = redirect_uris + .iter() + .map(|uri| { + ( + Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), + uri.as_str().to_owned(), + ) + }) + .unzip(); + + sqlx::query!( + r#" + INSERT INTO oauth2_client_redirect_uris + (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) + SELECT id, $2, redirect_uri + FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) + "#, + &ids, + Uuid::from(client_id), + &redirect_uris, ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + let jwks = match (jwks, jwks_uri) { + (None, None) => None, + (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)), + (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)), + _ => return Err(DatabaseError::invalid_operation()), + }; + + Ok(Client { + id: client_id, + client_id: client_id.to_string(), + encrypted_client_secret, + redirect_uris, + response_types: vec![ + OAuthAuthorizationEndpointResponseType::Code, + OAuthAuthorizationEndpointResponseType::IdToken, + OAuthAuthorizationEndpointResponseType::None, + ], + grant_types: Vec::new(), + contacts: Vec::new(), + client_name: None, + logo_uri: None, + client_uri: None, + policy_uri: None, + tos_uri: None, + jwks, + id_token_signed_response_alg: None, + userinfo_signed_response_alg: None, + token_endpoint_auth_method: None, + token_endpoint_auth_signing_alg: None, + initiate_login_uri: None, }) - .unzip(); - - sqlx::query!( - r#" - INSERT INTO oauth2_client_redirect_uris - (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) - SELECT id, $2, redirect_uri - FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) - "#, - &ids, - Uuid::from(client_id), - &redirect_uris, - ) - .execute(&mut *conn) - .await?; - - Ok(()) + } } diff --git a/crates/storage/src/oauth2/mod.rs b/crates/storage/src/oauth2/mod.rs index bdc9c1b57..c0153a3a0 100644 --- a/crates/storage/src/oauth2/mod.rs +++ b/crates/storage/src/oauth2/mod.rs @@ -20,7 +20,7 @@ use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; -use self::client::lookup_clients; +use self::client::OAuth2ClientRepository; use crate::{ pagination::{process_page, QueryBuilderExt}, user::BrowserSessionRepository, @@ -128,7 +128,7 @@ pub async fn get_paginated_user_oauth_sessions( let browser_session_ids: BTreeSet = page.iter().map(|i| Ulid::from(i.user_session_id)).collect(); - let clients = lookup_clients(&mut *conn, client_ids).await?; + let clients = conn.oauth2_client().load_batch(client_ids).await?; // TODO: this can generate N queries instead of batching. This is less than // ideal diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 5c2b63180..57abf103f 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -19,8 +19,8 @@ use sqlx::{PgConnection, PgExecutor}; use ulid::Ulid; use uuid::Uuid; -use super::client::lookup_client; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError}; +use super::client::OAuth2ClientRepository; +use crate::{Clock, DatabaseError, DatabaseInconsistencyError, Repository}; #[tracing::instrument( skip_all, @@ -173,7 +173,9 @@ pub async fn lookup_active_refresh_token( }; let session_id = res.oauth2_session_id.into(); - let client = lookup_client(&mut *conn, res.oauth2_client_id.into()) + let client = conn + .oauth2_client() + .lookup(res.oauth2_client_id.into()) .await? .ok_or_else(|| { DatabaseInconsistencyError::on("oauth2_sessions") diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index ef1a567a3..4bca22530 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -15,6 +15,7 @@ use sqlx::{PgConnection, Postgres, Transaction}; use crate::{ + oauth2::client::PgOAuth2ClientRepository, upstream_oauth2::{ PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, PgUpstreamOAuthSessionRepository, @@ -54,6 +55,10 @@ pub trait Repository { where Self: 'c; + type OAuth2ClientRepository<'c> + where + Self: 'c; + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>; fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>; @@ -61,6 +66,7 @@ pub trait Repository { fn user_email(&mut self) -> Self::UserEmailRepository<'_>; fn user_password(&mut self) -> Self::UserPasswordRepository<'_>; fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>; + fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>; } impl Repository for PgConnection { @@ -71,6 +77,7 @@ impl Repository for PgConnection { type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c; type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; + type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -99,6 +106,10 @@ impl Repository for PgConnection { fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> { PgBrowserSessionRepository::new(self) } + + fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { + PgOAuth2ClientRepository::new(self) + } } impl<'t> Repository for Transaction<'t, Postgres> { @@ -109,6 +120,7 @@ impl<'t> Repository for Transaction<'t, Postgres> { type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c; type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; + type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -137,4 +149,8 @@ impl<'t> Repository for Transaction<'t, Postgres> { fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> { PgBrowserSessionRepository::new(self) } + + fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { + PgOAuth2ClientRepository::new(self) + } } diff --git a/crates/storage/src/tracing.rs b/crates/storage/src/tracing.rs index 60eb284c9..08c62e465 100644 --- a/crates/storage/src/tracing.rs +++ b/crates/storage/src/tracing.rs @@ -12,9 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -pub trait ExecuteExt<'q, DB> { +use tracing::Span; + +pub trait ExecuteExt<'q, DB>: Sized { /// Records the statement as `db.statement` in the current span - fn traced(self) -> Self; + fn traced(self) -> Self { + self.record(&Span::current()) + } + + /// Records the statement as `db.statement` in the given span + fn record(self, span: &Span) -> Self; } impl<'q, DB, T> ExecuteExt<'q, DB> for T @@ -22,8 +29,8 @@ where T: sqlx::Execute<'q, DB>, DB: sqlx::Database, { - fn traced(self) -> Self { - tracing::Span::current().record("db.statement", self.sql()); + fn record(self, span: &Span) -> Self { + span.record("db.statement", self.sql()); self } } From 9d8eee12f8f4602d3e78d2e29418582415cfaf47 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 4 Jan 2023 16:02:42 +0100 Subject: [PATCH 10/45] Better tracing spans --- crates/cli/src/commands/config.rs | 8 +++++++- crates/cli/src/commands/database.rs | 3 +++ crates/cli/src/commands/debug.rs | 5 ++++- crates/cli/src/commands/manage.rs | 21 ++++++++++++++++++++- crates/cli/src/commands/server.rs | 6 +++++- crates/cli/src/commands/templates.rs | 3 +++ crates/cli/src/util.rs | 1 + crates/config/src/sections/secrets.rs | 1 + crates/email/src/mailer.rs | 3 ++- crates/handlers/src/passwords.rs | 24 +++++++++++++++--------- crates/policy/src/lib.rs | 25 ++++++++++++++++++++----- crates/templates/src/lib.rs | 25 ++++++++++++++++++++----- 12 files changed, 101 insertions(+), 24 deletions(-) diff --git a/crates/cli/src/commands/config.rs b/crates/cli/src/commands/config.rs index a1d6376e0..f8e3aaa54 100644 --- a/crates/cli/src/commands/config.rs +++ b/crates/cli/src/commands/config.rs @@ -15,7 +15,7 @@ use clap::Parser; use mas_config::{ConfigurationSection, RootConfig}; use rand::SeedableRng; -use tracing::info; +use tracing::{info, info_span}; #[derive(Parser, Debug)] pub(super) struct Options { @@ -40,6 +40,8 @@ impl Options { use Subcommand as SC; match &self.subcommand { SC::Dump => { + let _span = info_span!("cli.config.dump").entered(); + let config: RootConfig = root.load_config()?; serde_yaml::to_writer(std::io::stdout(), &config)?; @@ -47,11 +49,15 @@ impl Options { Ok(()) } SC::Check => { + let _span = info_span!("cli.config.check").entered(); + let _config: RootConfig = root.load_config()?; info!(path = ?root.config, "Configuration file looks good"); Ok(()) } SC::Generate => { + let _span = info_span!("cli.config.generate").entered(); + // XXX: we should disallow SeedableRng::from_entropy let rng = rand_chacha::ChaChaRng::from_entropy(); let config = RootConfig::load_and_generate(rng).await?; diff --git a/crates/cli/src/commands/database.rs b/crates/cli/src/commands/database.rs index 338fdbf91..ca59ce1dd 100644 --- a/crates/cli/src/commands/database.rs +++ b/crates/cli/src/commands/database.rs @@ -16,6 +16,7 @@ use anyhow::Context; use clap::Parser; use mas_config::DatabaseConfig; use mas_storage::MIGRATOR; +use tracing::{info_span, Instrument}; use crate::util::database_from_config; @@ -33,12 +34,14 @@ enum Subcommand { impl Options { pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> { + let _span = info_span!("cli.database.migrate").entered(); let config: DatabaseConfig = root.load_config()?; let pool = database_from_config(&config).await?; // Run pending migrations MIGRATOR .run(&pool) + .instrument(info_span!("db.migrate")) .await .context("could not run migrations")?; diff --git a/crates/cli/src/commands/debug.rs b/crates/cli/src/commands/debug.rs index 7b2d6cd8f..24f9262ce 100644 --- a/crates/cli/src/commands/debug.rs +++ b/crates/cli/src/commands/debug.rs @@ -19,7 +19,7 @@ use mas_handlers::HttpClientFactory; use mas_http::HttpServiceExt; use tokio::io::AsyncWriteExt; use tower::{Service, ServiceExt}; -use tracing::info; +use tracing::{info, info_span}; use crate::util::policy_factory_from_config; @@ -74,6 +74,7 @@ impl Options { json: false, url, } => { + let _span = info_span!("cli.debug.http").entered(); let mut client = http_client_factory.client("cli-debug-http").await?; let request = hyper::Request::builder() .uri(url) @@ -98,6 +99,7 @@ impl Options { json: true, url, } => { + let _span = info_span!("cli.debug.http").entered(); let mut client = http_client_factory .client("cli-debug-http") .await? @@ -122,6 +124,7 @@ impl Options { } SC::Policy => { + let _span = info_span!("cli.debug.policy").entered(); let config: PolicyConfig = root.load_config()?; info!("Loading and compiling the policy module"); let policy_factory = policy_factory_from_config(&config).await?; diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 9e2b39882..62d62db87 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -25,7 +25,7 @@ use mas_storage::{ }; use oauth2_types::scope::Scope; use rand::SeedableRng; -use tracing::{info, warn}; +use tracing::{info, info_span, warn}; use crate::util::{database_from_config, password_manager_from_config}; @@ -193,6 +193,9 @@ impl Options { match &self.subcommand { SC::SetPassword { username, password } => { + let _span = + info_span!("cli.manage.set_password", user.username = %username).entered(); + let database_config: DatabaseConfig = root.load_config()?; let passwords_config: PasswordsConfig = root.load_config()?; @@ -221,6 +224,13 @@ impl Options { } SC::VerifyEmail { username, email } => { + let _span = info_span!( + "cli.manage.verify_email", + user.username = username, + user_email.email = email + ) + .entered(); + let config: DatabaseConfig = root.load_config()?; let pool = database_from_config(&config).await?; let mut txn = pool.begin().await?; @@ -245,6 +255,8 @@ impl Options { } SC::ImportClients { update } => { + let _span = info_span!("cli.manage.import_clients").entered(); + let config: RootConfig = root.load_config()?; let pool = database_from_config(&config.database).await?; let encrypter = config.secrets.encrypter(); @@ -303,6 +315,13 @@ impl Options { client_secret, signing_alg, } => { + let _span = info_span!( + "cli.manage.add_oauth_upstream", + upstream_oauth_provider.issuer = issuer, + upstream_oauth_provider.client_id = client_id, + ) + .entered(); + let config: RootConfig = root.load_config()?; let encrypter = config.secrets.encrypter(); let pool = database_from_config(&config.database).await?; diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 606547827..fb2a3f168 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -24,7 +24,7 @@ use mas_router::UrlBuilder; use mas_storage::MIGRATOR; use mas_tasks::TaskQueue; use tokio::signal::unix::SignalKind; -use tracing::{info, warn}; +use tracing::{info, info_span, warn, Instrument}; use crate::util::{ database_from_config, mailer_from_config, password_manager_from_config, @@ -45,6 +45,7 @@ pub(super) struct Options { impl Options { #[allow(clippy::too_many_lines)] pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> { + let span = info_span!("cli.run.init").entered(); let config: RootConfig = root.load_config()?; // Connect to the database @@ -55,6 +56,7 @@ impl Options { info!("Running pending migrations"); MIGRATOR .run(&pool) + .instrument(info_span!("db.migrate")) .await .context("could not run migrations")?; } @@ -186,6 +188,8 @@ impl Options { .with_signal(SignalKind::terminate())? .with_signal(SignalKind::interrupt())?; + span.exit(); + mas_listener::server::run_servers(servers, shutdown).await; Ok(()) diff --git a/crates/cli/src/commands/templates.rs b/crates/cli/src/commands/templates.rs index 186a097e9..a3a1bf9cc 100644 --- a/crates/cli/src/commands/templates.rs +++ b/crates/cli/src/commands/templates.rs @@ -17,6 +17,7 @@ use clap::Parser; use mas_storage::Clock; use mas_templates::Templates; use rand::SeedableRng; +use tracing::info_span; #[derive(Parser, Debug)] pub(super) struct Options { @@ -38,6 +39,8 @@ impl Options { use Subcommand as SC; match &self.subcommand { SC::Check { path } => { + let _span = info_span!("cli.templates.check").entered(); + let clock = Clock::default(); // XXX: we should disallow SeedableRng::from_entropy let mut rng = rand_chacha::ChaChaRng::from_entropy(); diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index ec525478f..b6485e951 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -110,6 +110,7 @@ pub async fn templates_from_config( Templates::load(config.path.clone(), url_builder.clone()).await } +#[tracing::instrument(name = "db.connect", skip_all, err(Debug))] pub async fn database_from_config(config: &DatabaseConfig) -> Result { let mut options = match &config.options { DatabaseConnectConfig::Uri { uri } => uri diff --git a/crates/config/src/sections/secrets.rs b/crates/config/src/sections/secrets.rs index 3a0f79a41..afe624f46 100644 --- a/crates/config/src/sections/secrets.rs +++ b/crates/config/src/sections/secrets.rs @@ -86,6 +86,7 @@ impl SecretsConfig { /// # Errors /// /// Returns an error when a key could not be imported + #[tracing::instrument(name = "secrets.load", skip_all, err(Debug))] pub async fn key_store(&self) -> anyhow::Result { let mut keys = Vec::with_capacity(self.keys.len()); for item in &self.keys { diff --git a/crates/email/src/mailer.rs b/crates/email/src/mailer.rs index 177cc8d10..979a9f7bb 100644 --- a/crates/email/src/mailer.rs +++ b/crates/email/src/mailer.rs @@ -100,10 +100,10 @@ impl Mailer { /// /// Will return `Err` if the email failed rendering or failed sending #[tracing::instrument( + name = "email.verification.send", skip_all, fields( email.to = %to, - email.from = %self.from, user.id = %context.user().id, user_email_verification.id = %context.verification().id, user_email_verification.code = context.verification().code, @@ -125,6 +125,7 @@ impl Mailer { /// # Errors /// /// Returns an error if the connection failed + #[tracing::instrument(name = "email.test_connection", skip_all, err)] pub async fn test_connection(&self) -> Result<(), crate::transport::Error> { self.transport.test_connection().await } diff --git a/crates/handlers/src/passwords.rs b/crates/handlers/src/passwords.rs index 89326c9a3..4be71a45a 100644 --- a/crates/handlers/src/passwords.rs +++ b/crates/handlers/src/passwords.rs @@ -71,7 +71,7 @@ impl PasswordManager { /// # Errors /// /// Returns an error if the hashing failed - #[tracing::instrument(skip_all)] + #[tracing::instrument(name = "passwords.hash", skip_all)] pub async fn hash( &self, rng: R, @@ -82,13 +82,16 @@ impl PasswordManager { let rng = rand_chacha::ChaChaRng::from_rng(rng)?; let hashers = self.hashers.clone(); let default_hasher_version = self.default_hasher; + let span = tracing::Span::current(); let hashed = tokio::task::spawn_blocking(move || { - let default_hasher = hashers - .get(&default_hasher_version) - .context("Default hasher not found")?; + span.in_scope(move || { + let default_hasher = hashers + .get(&default_hasher_version) + .context("Default hasher not found")?; - default_hasher.hash_blocking(rng, &password) + default_hasher.hash_blocking(rng, &password) + }) }) .await??; @@ -100,7 +103,7 @@ impl PasswordManager { /// # Errors /// /// Returns an error if the password hash verification failed - #[tracing::instrument(skip_all, fields(%scheme))] + #[tracing::instrument(name = "passwords.verify", skip_all, fields(%scheme))] pub async fn verify( &self, scheme: SchemeVersion, @@ -108,10 +111,13 @@ impl PasswordManager { hashed_password: String, ) -> Result<(), anyhow::Error> { let hashers = self.hashers.clone(); + let span = tracing::Span::current(); tokio::task::spawn_blocking(move || { - let hasher = hashers.get(&scheme).context("Hashing scheme not found")?; - hasher.verify_blocking(&hashed_password, &password) + span.in_scope(move || { + let hasher = hashers.get(&scheme).context("Hashing scheme not found")?; + hasher.verify_blocking(&hashed_password, &password) + }) }) .await??; @@ -124,7 +130,7 @@ impl PasswordManager { /// # Errors /// /// Returns an error if the password hash verification failed - #[tracing::instrument(skip_all, fields(%scheme))] + #[tracing::instrument(name = "passwords.verify_and_upgrade", skip_all, fields(%scheme))] pub async fn verify_and_upgrade( &self, rng: R, diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index 27cfb1b32..d895339a1 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -69,7 +69,7 @@ pub struct PolicyFactory { } impl PolicyFactory { - #[tracing::instrument(skip(source), err)] + #[tracing::instrument(name = "policy.load", skip(source), err)] pub async fn load( mut source: impl AsyncRead + std::marker::Unpin, data: serde_json::Value, @@ -108,7 +108,7 @@ impl PolicyFactory { authorization_grant_endpoint, }; - // Try to instanciate + // Try to instantiate factory .instantiate() .await @@ -117,7 +117,7 @@ impl PolicyFactory { Ok(factory) } - #[tracing::instrument(skip(self), err)] + #[tracing::instrument(name = "policy.instantiate", skip_all, err)] pub async fn instantiate(&self) -> Result { let mut store = Store::new(&self.engine, ()); let runtime = Runtime::new(&mut store, &self.module) @@ -189,7 +189,14 @@ pub enum EvaluationError { } impl Policy { - #[tracing::instrument(skip(self, password))] + #[tracing::instrument( + name = "policy.evaluate.register", + skip_all, + fields( + data.username = username, + ), + err, + )] pub async fn evaluate_register( &mut self, username: &str, @@ -234,7 +241,15 @@ impl Policy { Ok(res) } - #[tracing::instrument(skip(self))] + #[tracing::instrument( + name = "policy.evaluate.authorization_grant", + skip_all, + fields( + data.authorization_grant.id = %authorization_grant.id, + data.user.id = %user.id, + ), + err, + )] pub async fn evaluate_authorization_grant( &mut self, authorization_grant: &AuthorizationGrant, diff --git a/crates/templates/src/lib.rs b/crates/templates/src/lib.rs index 496bcc26a..938a701a4 100644 --- a/crates/templates/src/lib.rs +++ b/crates/templates/src/lib.rs @@ -96,6 +96,12 @@ impl Templates { } /// Load the templates from the given config + #[tracing::instrument( + name = "templates.load", + skip_all, + fields(%path), + err, + )] pub async fn load( path: Utf8PathBuf, url_builder: UrlBuilder, @@ -110,14 +116,17 @@ impl Templates { async fn load_(path: &Utf8Path, url_builder: UrlBuilder) -> Result { let path = path.to_owned(); + let span = tracing::Span::current(); // This uses blocking I/Os, do that in a blocking task let mut tera = tokio::task::spawn_blocking(move || { - let path = path.canonicalize_utf8()?; - let path = format!("{path}/**/*.{{html,txt,subject}}"); + span.in_scope(move || { + let path = path.canonicalize_utf8()?; + let path = format!("{path}/**/*.{{html,txt,subject}}"); - info!(%path, "Loading templates from filesystem"); - Tera::new(&path) + info!(%path, "Loading templates from filesystem"); + Tera::new(&path) + }) }) .await??; @@ -138,7 +147,13 @@ impl Templates { } /// Reload the templates on disk - pub async fn reload(&self) -> anyhow::Result<()> { + #[tracing::instrument( + name = "templates.reload", + skip_all, + fields(path = %self.path), + err, + )] + pub async fn reload(&self) -> Result<(), TemplateLoadingError> { // Prepare the new Tera instance let new_tera = Self::load_(&self.path, self.url_builder.clone()).await?; From 6053e24d735a58b5e878e2fd59304dbd8cd07fc9 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 4 Jan 2023 18:06:17 +0100 Subject: [PATCH 11/45] storage: Load with less joins This is done to simplify some queries, to avoid loading more data than necessary, and in preparation of a proper cache layer --- crates/data-model/src/compat.rs | 4 +- crates/data-model/src/oauth2/session.rs | 7 +- crates/graphql/src/model/compat_sessions.rs | 15 +- crates/graphql/src/model/oauth.rs | 40 +- crates/handlers/src/compat/login.rs | 29 +- .../handlers/src/compat/login_sso_complete.rs | 2 +- crates/handlers/src/oauth2/introspection.rs | 51 +- crates/handlers/src/oauth2/token.rs | 29 +- crates/handlers/src/oauth2/userinfo.rs | 35 +- crates/storage/sqlx-data.json | 1186 +++++++---------- crates/storage/src/compat.rs | 201 +-- crates/storage/src/oauth2/access_token.rs | 74 +- .../storage/src/oauth2/authorization_grant.rs | 184 +-- crates/storage/src/oauth2/mod.rs | 57 +- crates/storage/src/oauth2/refresh_token.rs | 99 +- crates/storage/src/oauth2/session.rs | 20 + 16 files changed, 824 insertions(+), 1209 deletions(-) create mode 100644 crates/storage/src/oauth2/session.rs diff --git a/crates/data-model/src/compat.rs b/crates/data-model/src/compat.rs index d6f772db3..07ff9aaaa 100644 --- a/crates/data-model/src/compat.rs +++ b/crates/data-model/src/compat.rs @@ -23,8 +23,6 @@ use thiserror::Error; use ulid::Ulid; use url::Url; -use crate::User; - static DEVICE_ID_LENGTH: usize = 10; #[derive(Debug, Clone, PartialEq, Eq, Serialize)] @@ -85,7 +83,7 @@ impl TryFrom for Device { #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct CompatSession { pub id: Ulid, - pub user: User, + pub user_id: Ulid, pub device: Device, pub created_at: DateTime, pub finished_at: Option>, diff --git a/crates/data-model/src/oauth2/session.rs b/crates/data-model/src/oauth2/session.rs index ff222ca83..29454feb2 100644 --- a/crates/data-model/src/oauth2/session.rs +++ b/crates/data-model/src/oauth2/session.rs @@ -16,13 +16,10 @@ use oauth2_types::scope::Scope; use serde::Serialize; use ulid::Ulid; -use super::client::Client; -use crate::users::BrowserSession; - #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct Session { pub id: Ulid, - pub browser_session: BrowserSession, - pub client: Client, + pub user_session_id: Ulid, + pub client_id: Ulid, pub scope: Scope, } diff --git a/crates/graphql/src/model/compat_sessions.rs b/crates/graphql/src/model/compat_sessions.rs index 5b272b184..f3610233d 100644 --- a/crates/graphql/src/model/compat_sessions.rs +++ b/crates/graphql/src/model/compat_sessions.rs @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use async_graphql::{Description, Object, ID}; +use anyhow::Context as _; +use async_graphql::{Context, Description, Object, ID}; use chrono::{DateTime, Utc}; use mas_data_model::CompatSsoLoginState; +use mas_storage::{user::UserRepository, Repository}; +use sqlx::PgPool; use url::Url; use super::{NodeType, User}; @@ -32,8 +35,14 @@ impl CompatSession { } /// The user authorized for this session. - async fn user(&self) -> User { - User(self.0.user.clone()) + async fn user(&self, ctx: &Context<'_>) -> Result { + let mut conn = ctx.data::()?.acquire().await?; + let user = conn + .user() + .lookup(self.0.user_id) + .await? + .context("Could not load user")?; + Ok(User(user)) } /// The Matrix Device ID of this session. diff --git a/crates/graphql/src/model/oauth.rs b/crates/graphql/src/model/oauth.rs index 5f1236f2e..8e418e6c8 100644 --- a/crates/graphql/src/model/oauth.rs +++ b/crates/graphql/src/model/oauth.rs @@ -14,7 +14,9 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; -use mas_storage::{oauth2::client::OAuth2ClientRepository, Repository}; +use mas_storage::{ + oauth2::client::OAuth2ClientRepository, user::BrowserSessionRepository, Repository, +}; use oauth2_types::scope::Scope; use sqlx::PgPool; use ulid::Ulid; @@ -35,8 +37,15 @@ impl OAuth2Session { } /// OAuth 2.0 client used by this session. - pub async fn client(&self) -> OAuth2Client { - OAuth2Client(self.0.client.clone()) + pub async fn client(&self, ctx: &Context<'_>) -> Result { + let mut conn = ctx.data::()?.acquire().await?; + let client = conn + .oauth2_client() + .lookup(self.0.client_id) + .await? + .context("Could not load client")?; + + Ok(OAuth2Client(client)) } /// Scope granted for this session. @@ -45,13 +54,30 @@ impl OAuth2Session { } /// The browser session which started this OAuth 2.0 session. - pub async fn browser_session(&self) -> BrowserSession { - BrowserSession(self.0.browser_session.clone()) + pub async fn browser_session( + &self, + ctx: &Context<'_>, + ) -> Result { + let mut conn = ctx.data::()?.acquire().await?; + let browser_session = conn + .browser_session() + .lookup(self.0.user_session_id) + .await? + .context("Could not load browser session")?; + + Ok(BrowserSession(browser_session)) } /// User authorized for this session. - pub async fn user(&self) -> User { - User(self.0.browser_session.user.clone()) + pub async fn user(&self, ctx: &Context<'_>) -> Result { + let mut conn = ctx.data::()?.acquire().await?; + let browser_session = conn + .browser_session() + .lookup(self.0.user_session_id) + .await? + .context("Could not load browser session")?; + + Ok(User(browser_session.user)) } } diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index c3b910024..c59c7dd81 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -15,7 +15,7 @@ use axum::{extract::State, response::IntoResponse, Json}; use chrono::Duration; use hyper::StatusCode; -use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType}; +use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType, User}; use mas_storage::{ compat::{ add_compat_access_token, add_compat_refresh_token, get_compat_sso_login_by_token, @@ -197,7 +197,7 @@ pub(crate) async fn post( ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); let mut txn = pool.begin().await?; - let session = match input.credentials { + let (session, user) = match input.credentials { Credentials::Password { identifier: Identifier::User { user }, password, @@ -210,7 +210,7 @@ pub(crate) async fn post( } }; - let user_id = format!("@{username}:{homeserver}", username = session.user.username); + let user_id = format!("@{username}:{homeserver}", username = user.username); // If the client asked for a refreshable token, make it expire let expires_in = if input.refresh_token { @@ -262,13 +262,13 @@ async fn token_login( txn: &mut Transaction<'_, Postgres>, clock: &Clock, token: &str, -) -> Result { +) -> Result<(CompatSession, User), RouteError> { let login = get_compat_sso_login_by_token(&mut *txn, token) .await? .ok_or(RouteError::InvalidLoginToken)?; let now = clock.now(); - match login.state { + let user_id = match login.state { CompatSsoLoginState::Pending => { tracing::error!( compat_sso_login.id = %login.id, @@ -278,11 +278,14 @@ async fn token_login( } CompatSsoLoginState::Fulfilled { fulfilled_at: fullfilled_at, + ref session, .. } => { if now > fullfilled_at + Duration::seconds(30) { return Err(RouteError::LoginTookTooLong); } + + session.user_id } CompatSsoLoginState::Exchanged { exchanged_at, .. } => { if now > exchanged_at + Duration::seconds(30) { @@ -295,12 +298,18 @@ async fn token_login( return Err(RouteError::InvalidLoginToken); } - } + }; + + let user = txn + .user() + .lookup(user_id) + .await? + .ok_or(RouteError::UserNotFound)?; let login = mark_compat_sso_login_as_exchanged(&mut *txn, clock, login).await?; match login.state { - CompatSsoLoginState::Exchanged { session, .. } => Ok(session), + CompatSsoLoginState::Exchanged { session, .. } => Ok((session, user)), _ => unreachable!(), } } @@ -310,7 +319,7 @@ async fn user_password_login( txn: &mut Transaction<'_, Postgres>, username: String, password: String, -) -> Result { +) -> Result<(CompatSession, User), RouteError> { let (clock, mut rng) = crate::clock_and_rng(); // Find the user @@ -356,7 +365,7 @@ async fn user_password_login( // Now that the user credentials have been verified, start a new compat session let device = Device::generate(&mut rng); - let session = start_compat_session(&mut *txn, &mut rng, &clock, user, device).await?; + let session = start_compat_session(&mut *txn, &mut rng, &clock, &user, device).await?; - Ok(session) + Ok((session, user)) } diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 497908427..f31856d6c 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -182,7 +182,7 @@ pub async fn post( let device = Device::generate(&mut rng); let _login = - fullfill_compat_sso_login(&mut txn, &mut rng, &clock, session.user, login, device).await?; + fullfill_compat_sso_login(&mut txn, &mut rng, &clock, &session.user, login, device).await?; txn.commit().await?; diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index c2e68261b..71f3f1488 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -26,7 +26,8 @@ use mas_storage::{ oauth2::{ access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token, }, - Clock, + user::{BrowserSessionRepository, UserRepository}, + Clock, Repository, }; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, @@ -171,16 +172,23 @@ pub(crate) async fn post( .await? .ok_or(RouteError::UnknownToken)?; + let browser_session = conn + .browser_session() + .lookup(session.user_session_id) + .await? + // XXX: is that the right error to bubble up? + .ok_or(RouteError::UnknownToken)?; + IntrospectionResponse { active: true, scope: Some(session.scope), - client_id: Some(session.client.client_id), - username: Some(session.browser_session.user.username), + client_id: Some(session.client_id.to_string()), + username: Some(browser_session.user.username), token_type: Some(OAuthTokenTypeHint::AccessToken), exp: Some(token.expires_at), iat: Some(token.created_at), nbf: Some(token.created_at), - sub: Some(session.browser_session.user.sub), + sub: Some(browser_session.user.sub), aud: None, iss: None, jti: None, @@ -191,16 +199,23 @@ pub(crate) async fn post( .await? .ok_or(RouteError::UnknownToken)?; + let browser_session = conn + .browser_session() + .lookup(session.user_session_id) + .await? + // XXX: is that the right error to bubble up? + .ok_or(RouteError::UnknownToken)?; + IntrospectionResponse { active: true, scope: Some(session.scope), - client_id: Some(session.client.client_id), - username: Some(session.browser_session.user.username), + client_id: Some(session.client_id.to_string()), + username: Some(browser_session.user.username), token_type: Some(OAuthTokenTypeHint::RefreshToken), exp: None, iat: Some(token.created_at), nbf: Some(token.created_at), - sub: Some(session.browser_session.user.sub), + sub: Some(browser_session.user.sub), aud: None, iss: None, jti: None, @@ -211,6 +226,13 @@ pub(crate) async fn post( .await? .ok_or(RouteError::UnknownToken)?; + let user = conn + .user() + .lookup(session.user_id) + .await? + // XXX: is that the right error to bubble up? + .ok_or(RouteError::UnknownToken)?; + let device_scope = session.device.to_scope_token(); let scope = [API_SCOPE, device_scope].into_iter().collect(); @@ -218,12 +240,12 @@ pub(crate) async fn post( active: true, scope: Some(scope), client_id: Some("legacy".into()), - username: Some(session.user.username), + username: Some(user.username), token_type: Some(OAuthTokenTypeHint::AccessToken), exp: token.expires_at, iat: Some(token.created_at), nbf: Some(token.created_at), - sub: Some(session.user.sub), + sub: Some(user.sub), aud: None, iss: None, jti: None, @@ -235,6 +257,13 @@ pub(crate) async fn post( .await? .ok_or(RouteError::UnknownToken)?; + let user = conn + .user() + .lookup(session.user_id) + .await? + // XXX: is that the right error to bubble up? + .ok_or(RouteError::UnknownToken)?; + let device_scope = session.device.to_scope_token(); let scope = [API_SCOPE, device_scope].into_iter().collect(); @@ -242,12 +271,12 @@ pub(crate) async fn post( active: true, scope: Some(scope), client_id: Some("legacy".into()), - username: Some(session.user.username), + username: Some(user.username), token_type: Some(OAuthTokenTypeHint::RefreshToken), exp: None, iat: Some(refresh_token.created_at), nbf: Some(refresh_token.created_at), - sub: Some(session.user.sub), + sub: Some(user.sub), aud: None, iss: None, jti: None, diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index a6d899f4f..473dcab82 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -31,11 +31,15 @@ use mas_jose::{ }; use mas_keystore::{Encrypter, Keystore}; use mas_router::UrlBuilder; -use mas_storage::oauth2::{ - access_token::{add_access_token, revoke_access_token}, - authorization_grant::{exchange_grant, lookup_grant_by_code}, - end_oauth_session, - refresh_token::{add_refresh_token, consume_refresh_token, lookup_active_refresh_token}, +use mas_storage::{ + oauth2::{ + access_token::{add_access_token, revoke_access_token}, + authorization_grant::{exchange_grant, lookup_grant_by_code}, + end_oauth_session, + refresh_token::{add_refresh_token, consume_refresh_token, lookup_active_refresh_token}, + }, + user::BrowserSessionRepository, + Repository, }; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, @@ -102,12 +106,15 @@ pub(crate) enum RouteError { #[error("no suitable key found for signing")] InvalidSigningKey, + + #[error("failed to load browser session")] + NoSuchBrowserSession, } impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { - Self::Internal(_) | Self::InvalidSigningKey => ( + Self::Internal(_) | Self::InvalidSigningKey | Self::NoSuchBrowserSession => ( StatusCode::INTERNAL_SERVER_ERROR, Json(ClientError::from(ClientErrorCode::ServerError)), ), @@ -253,7 +260,7 @@ async fn authorization_code_grant( // This should never happen, since we looked up in the database using the code let code = authz_grant.code.as_ref().ok_or(RouteError::InvalidGrant)?; - if client.client_id != session.client.client_id { + if client.id != session.client_id { return Err(RouteError::UnauthorizedClient); } @@ -267,7 +274,11 @@ async fn authorization_code_grant( } }; - let browser_session = &session.browser_session; + let browser_session = txn + .browser_session() + .lookup(session.user_session_id) + .await? + .ok_or(RouteError::NoSuchBrowserSession)?; let ttl = Duration::minutes(5); let access_token_str = TokenType::AccessToken.generate(&mut rng); @@ -357,7 +368,7 @@ async fn refresh_token_grant( .await? .ok_or(RouteError::InvalidGrant)?; - if client.client_id != session.client.client_id { + if client.id != session.client_id { // As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 return Err(RouteError::InvalidGrant); } diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index 225870ad3..699b049a5 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -28,7 +28,11 @@ use mas_jose::{ }; use mas_keystore::Keystore; use mas_router::UrlBuilder; -use mas_storage::{user::UserEmailRepository, Repository}; +use mas_storage::{ + oauth2::client::OAuth2ClientRepository, + user::{BrowserSessionRepository, UserEmailRepository}, + Repository, +}; use oauth2_types::scope; use serde::Serialize; use serde_with::skip_serializing_none; @@ -64,6 +68,12 @@ pub enum RouteError { #[error("no suitable key found for signing")] InvalidSigningKey, + + #[error("failed to load client")] + NoSuchClient, + + #[error("failed to load browser session")] + NoSuchBrowserSession, } impl_from_error_for_route!(sqlx::Error); @@ -74,7 +84,10 @@ impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { - Self::Internal(_) | Self::InvalidSigningKey => { + Self::Internal(_) + | Self::InvalidSigningKey + | Self::NoSuchClient + | Self::NoSuchBrowserSession => { (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response() } Self::AuthorizationVerificationError(_e) => StatusCode::UNAUTHORIZED.into_response(), @@ -93,7 +106,13 @@ pub async fn get( let session = user_authorization.protected(&mut conn).await?; - let user = session.browser_session.user; + let browser_session = conn + .browser_session() + .lookup(session.user_session_id) + .await? + .ok_or(RouteError::NoSuchBrowserSession)?; + + let user = browser_session.user; let user_email = if session.scope.contains(&scope::EMAIL) { conn.user_email().get_primary(&user).await? @@ -108,7 +127,13 @@ pub async fn get( email: user_email.map(|u| u.email), }; - if let Some(alg) = session.client.userinfo_signed_response_alg { + let client = conn + .oauth2_client() + .lookup(session.client_id) + .await? + .ok_or(RouteError::NoSuchClient)?; + + if let Some(alg) = client.userinfo_signed_response_alg { let key = key_store .signing_key_for_algorithm(&alg) .ok_or(RouteError::InvalidSigningKey)?; @@ -119,7 +144,7 @@ pub async fn get( let user_info = SignedUserInfo { iss: url_builder.oidc_issuer().to_string(), - aud: session.client.client_id, + aud: client.client_id, user_info, }; diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index feced0ba8..4784c030f 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -1,77 +1,67 @@ { "db": "PostgreSQL", - "03bc4a14e97e011fec04e5788a967e04838cf978984254ecfd2c8b8a979da1c8": { + "021f845e564500457e2e0c8614beb1d9fd10b4b5f13515478f7ca25b5474d016": { "describe": { "columns": [ { - "name": "oauth2_access_token_id", + "name": "compat_refresh_token_id", "ordinal": 0, "type_info": "Uuid" }, { - "name": "oauth2_access_token", + "name": "compat_refresh_token", "ordinal": 1, "type_info": "Text" }, { - "name": "oauth2_access_token_created_at", + "name": "compat_refresh_token_created_at", "ordinal": 2, "type_info": "Timestamptz" }, { - "name": "oauth2_access_token_expires_at", + "name": "compat_access_token_id", "ordinal": 3, - "type_info": "Timestamptz" + "type_info": "Uuid" }, { - "name": "oauth2_session_id!", + "name": "compat_access_token", "ordinal": 4, - "type_info": "Uuid" - }, - { - "name": "oauth2_client_id!", - "ordinal": 5, - "type_info": "Uuid" - }, - { - "name": "scope!", - "ordinal": 6, "type_info": "Text" }, { - "name": "user_session_id!", + "name": "compat_access_token_created_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "compat_access_token_expires_at", + "ordinal": 6, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", "ordinal": 7, "type_info": "Uuid" }, { - "name": "user_session_created_at!", + "name": "compat_session_created_at", "ordinal": 8, "type_info": "Timestamptz" }, { - "name": "user_id!", + "name": "compat_session_finished_at", "ordinal": 9, - "type_info": "Uuid" + "type_info": "Timestamptz" }, { - "name": "user_username!", + "name": "compat_session_device_id", "ordinal": 10, "type_info": "Text" }, { - "name": "user_primary_user_email_id", + "name": "user_id", "ordinal": 11, "type_info": "Uuid" - }, - { - "name": "user_session_last_authentication_id?", - "ordinal": 12, - "type_info": "Uuid" - }, - { - "name": "user_session_last_authentication_created_at?", - "ordinal": 13, - "type_info": "Timestamptz" } ], "nullable": [ @@ -81,9 +71,7 @@ false, false, false, - false, - false, - false, + true, false, false, true, @@ -96,7 +84,7 @@ ] } }, - "query": "\n SELECT at.oauth2_access_token_id\n , at.access_token AS \"oauth2_access_token\"\n , at.created_at AS \"oauth2_access_token_created_at\"\n , at.expires_at AS \"oauth2_access_token_expires_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"scope!\"\n , us.user_session_id AS \"user_session_id!\"\n , us.created_at AS \"user_session_created_at!\"\n , u.user_id AS \"user_id!\"\n , u.username AS \"user_username!\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n , usa.user_session_authentication_id AS \"user_session_last_authentication_id?\"\n , usa.created_at AS \"user_session_last_authentication_created_at?\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n INNER JOIN user_sessions us\n USING (user_session_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n\n WHERE at.access_token = $1\n AND at.revoked_at IS NULL\n AND os.finished_at IS NULL\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " + "query": "\n SELECT cr.compat_refresh_token_id\n , cr.refresh_token AS \"compat_refresh_token\"\n , cr.created_at AS \"compat_refresh_token_created_at\"\n , ct.compat_access_token_id\n , ct.access_token AS \"compat_access_token\"\n , ct.created_at AS \"compat_access_token_created_at\"\n , ct.expires_at AS \"compat_access_token_expires_at\"\n , cs.compat_session_id\n , cs.created_at AS \"compat_session_created_at\"\n , cs.finished_at AS \"compat_session_finished_at\"\n , cs.device_id AS \"compat_session_device_id\"\n , cs.user_id\n\n FROM compat_refresh_tokens cr\n INNER JOIN compat_sessions cs\n USING (compat_session_id)\n INNER JOIN compat_access_tokens ct\n USING (compat_access_token_id)\n\n WHERE cr.refresh_token = $1\n AND cr.consumed_at IS NULL\n AND cs.finished_at IS NULL\n " }, "08d7df347c806ef14b6d0fb031cab041d79ba48528420160e23286369db7af35": { "describe": { @@ -196,438 +184,6 @@ }, "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n " }, - "16a1c5fe5a4c5481212560d79d589b550dfefe7480c5ee4febcbfaaa01ee93a4": { - "describe": { - "columns": [ - { - "name": "compat_sso_login_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "compat_sso_login_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_sso_login_redirect_uri", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "compat_sso_login_created_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_sso_login_fulfilled_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "compat_sso_login_exchanged_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id?", - "ordinal": 6, - "type_info": "Uuid" - }, - { - "name": "compat_session_created_at?", - "ordinal": 7, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_finished_at?", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id?", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "user_id?", - "ordinal": 10, - "type_info": "Uuid" - }, - { - "name": "user_username?", - "ordinal": 11, - "type_info": "Text" - }, - { - "name": "user_primary_user_email_id?", - "ordinal": 12, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - false, - true, - true, - false, - false, - true, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT\n cl.compat_sso_login_id,\n cl.login_token AS \"compat_sso_login_token\",\n cl.redirect_uri AS \"compat_sso_login_redirect_uri\",\n cl.created_at AS \"compat_sso_login_created_at\",\n cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\",\n cl.exchanged_at AS \"compat_sso_login_exchanged_at\",\n cs.compat_session_id AS \"compat_session_id?\",\n cs.created_at AS \"compat_session_created_at?\",\n cs.finished_at AS \"compat_session_finished_at?\",\n cs.device_id AS \"compat_session_device_id?\",\n u.user_id AS \"user_id?\",\n u.username AS \"user_username?\",\n u.primary_user_email_id AS \"user_primary_user_email_id?\"\n FROM compat_sso_logins cl\n LEFT JOIN compat_sessions cs\n USING (compat_session_id)\n LEFT JOIN users u\n USING (user_id)\n WHERE cl.login_token = $1\n " - }, - "1a5e0d1d88065bb4e7f790942856d1d94ecdb30a7007f3277ca3f7cbdabd4dff": { - "describe": { - "columns": [ - { - "name": "oauth2_authorization_grant_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_created_at", - "ordinal": 1, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_cancelled_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_fulfilled_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_exchanged_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_scope", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_state", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_redirect_uri", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_mode", - "ordinal": 8, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_nonce", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_max_age", - "ordinal": 10, - "type_info": "Int4" - }, - { - "name": "oauth2_client_id", - "ordinal": 11, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_code", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_type_code", - "ordinal": 13, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_response_type_id_token", - "ordinal": 14, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_code_challenge", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_code_challenge_method", - "ordinal": 16, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_requires_consent", - "ordinal": 17, - "type_info": "Bool" - }, - { - "name": "oauth2_session_id?", - "ordinal": 18, - "type_info": "Uuid" - }, - { - "name": "user_session_id?", - "ordinal": 19, - "type_info": "Uuid" - }, - { - "name": "user_session_created_at?", - "ordinal": 20, - "type_info": "Timestamptz" - }, - { - "name": "user_id?", - "ordinal": 21, - "type_info": "Uuid" - }, - { - "name": "user_username?", - "ordinal": 22, - "type_info": "Text" - }, - { - "name": "user_primary_user_email_id?", - "ordinal": 23, - "type_info": "Uuid" - }, - { - "name": "user_session_last_authentication_id?", - "ordinal": 24, - "type_info": "Uuid" - }, - { - "name": "user_session_last_authentication_created_at?", - "ordinal": 25, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - true, - true, - false, - true, - false, - false, - true, - true, - false, - true, - false, - false, - true, - true, - false, - false, - false, - false, - false, - false, - true, - false, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT\n og.oauth2_authorization_grant_id,\n og.created_at AS oauth2_authorization_grant_created_at,\n og.cancelled_at AS oauth2_authorization_grant_cancelled_at,\n og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at,\n og.exchanged_at AS oauth2_authorization_grant_exchanged_at,\n og.scope AS oauth2_authorization_grant_scope,\n og.state AS oauth2_authorization_grant_state,\n og.redirect_uri AS oauth2_authorization_grant_redirect_uri,\n og.response_mode AS oauth2_authorization_grant_response_mode,\n og.nonce AS oauth2_authorization_grant_nonce,\n og.max_age AS oauth2_authorization_grant_max_age,\n og.oauth2_client_id AS oauth2_client_id,\n og.authorization_code AS oauth2_authorization_grant_code,\n og.response_type_code AS oauth2_authorization_grant_response_type_code,\n og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token,\n og.code_challenge AS oauth2_authorization_grant_code_challenge,\n og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method,\n og.requires_consent AS oauth2_authorization_grant_requires_consent,\n os.oauth2_session_id AS \"oauth2_session_id?\",\n us.user_session_id AS \"user_session_id?\",\n us.created_at AS \"user_session_created_at?\",\n u.user_id AS \"user_id?\",\n u.username AS \"user_username?\",\n u.primary_user_email_id AS \"user_primary_user_email_id?\",\n usa.user_session_authentication_id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n FROM\n oauth2_authorization_grants og\n LEFT JOIN oauth2_sessions os\n USING (oauth2_session_id)\n LEFT JOIN user_sessions us\n USING (user_session_id)\n LEFT JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n\n WHERE og.authorization_code = $1\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " - }, - "1b448fe73e12bef622b75857e4c9b257c9529ca18da7f63d127e63184f4bc94b": { - "describe": { - "columns": [ - { - "name": "oauth2_authorization_grant_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_created_at", - "ordinal": 1, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_cancelled_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_fulfilled_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_exchanged_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_scope", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_state", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_redirect_uri", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_mode", - "ordinal": 8, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_nonce", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_max_age", - "ordinal": 10, - "type_info": "Int4" - }, - { - "name": "oauth2_client_id", - "ordinal": 11, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_code", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_type_code", - "ordinal": 13, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_response_type_id_token", - "ordinal": 14, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_code_challenge", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_code_challenge_method", - "ordinal": 16, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_requires_consent", - "ordinal": 17, - "type_info": "Bool" - }, - { - "name": "oauth2_session_id?", - "ordinal": 18, - "type_info": "Uuid" - }, - { - "name": "user_session_id?", - "ordinal": 19, - "type_info": "Uuid" - }, - { - "name": "user_session_created_at?", - "ordinal": 20, - "type_info": "Timestamptz" - }, - { - "name": "user_id?", - "ordinal": 21, - "type_info": "Uuid" - }, - { - "name": "user_username?", - "ordinal": 22, - "type_info": "Text" - }, - { - "name": "user_primary_user_email_id?", - "ordinal": 23, - "type_info": "Uuid" - }, - { - "name": "user_session_last_authentication_id?", - "ordinal": 24, - "type_info": "Uuid" - }, - { - "name": "user_session_last_authentication_created_at?", - "ordinal": 25, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - true, - true, - false, - true, - false, - false, - true, - true, - false, - true, - false, - false, - true, - true, - false, - false, - false, - false, - false, - false, - true, - false, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n og.oauth2_authorization_grant_id,\n og.created_at AS oauth2_authorization_grant_created_at,\n og.cancelled_at AS oauth2_authorization_grant_cancelled_at,\n og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at,\n og.exchanged_at AS oauth2_authorization_grant_exchanged_at,\n og.scope AS oauth2_authorization_grant_scope,\n og.state AS oauth2_authorization_grant_state,\n og.redirect_uri AS oauth2_authorization_grant_redirect_uri,\n og.response_mode AS oauth2_authorization_grant_response_mode,\n og.nonce AS oauth2_authorization_grant_nonce,\n og.max_age AS oauth2_authorization_grant_max_age,\n og.oauth2_client_id AS oauth2_client_id,\n og.authorization_code AS oauth2_authorization_grant_code,\n og.response_type_code AS oauth2_authorization_grant_response_type_code,\n og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token,\n og.code_challenge AS oauth2_authorization_grant_code_challenge,\n og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method,\n og.requires_consent AS oauth2_authorization_grant_requires_consent,\n os.oauth2_session_id AS \"oauth2_session_id?\",\n us.user_session_id AS \"user_session_id?\",\n us.created_at AS \"user_session_created_at?\",\n u.user_id AS \"user_id?\",\n u.username AS \"user_username?\",\n u.primary_user_email_id AS \"user_primary_user_email_id?\",\n usa.user_session_authentication_id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n FROM\n oauth2_authorization_grants og\n LEFT JOIN oauth2_sessions os\n USING (oauth2_session_id)\n LEFT JOIN user_sessions us\n USING (user_session_id)\n LEFT JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n\n WHERE og.oauth2_authorization_grant_id = $1\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " - }, "1d372f36c382ab16264cea54537af3544ea6d6d75d10b432b07dbd0dadd2fa4e": { "describe": { "columns": [ @@ -951,6 +507,86 @@ }, "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n oauth2_session_id = os.oauth2_session_id,\n fulfilled_at = os.created_at\n FROM oauth2_sessions os\n WHERE\n og.oauth2_authorization_grant_id = $1\n AND os.oauth2_session_id = $2\n RETURNING fulfilled_at AS \"fulfilled_at!: DateTime\"\n " }, + "4f8b0cd13d9488c2dd0f183d090d3856da15dcdb57a8c113febbee665a2a3ac5": { + "describe": { + "columns": [ + { + "name": "compat_sso_login_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "compat_sso_login_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "compat_sso_login_redirect_uri", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "compat_sso_login_created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_sso_login_fulfilled_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "compat_sso_login_exchanged_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id?", + "ordinal": 6, + "type_info": "Uuid" + }, + { + "name": "compat_session_created_at?", + "ordinal": 7, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_finished_at?", + "ordinal": 8, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_device_id?", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "user_id?", + "ordinal": 10, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + true, + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT cl.compat_sso_login_id\n , cl.login_token AS \"compat_sso_login_token\"\n , cl.redirect_uri AS \"compat_sso_login_redirect_uri\"\n , cl.created_at AS \"compat_sso_login_created_at\"\n , cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\"\n , cl.exchanged_at AS \"compat_sso_login_exchanged_at\"\n , cs.compat_session_id AS \"compat_session_id?\"\n , cs.created_at AS \"compat_session_created_at?\"\n , cs.finished_at AS \"compat_session_finished_at?\"\n , cs.device_id AS \"compat_session_device_id?\"\n , cs.user_id AS \"user_id?\"\n FROM compat_sso_logins cl\n LEFT JOIN compat_sessions cs\n USING (compat_session_id)\n WHERE cl.login_token = $1\n " + }, "51158bfcaa1a8d8e051bffe7c5ba0369bf53fb162f7622626054e89e68fc07bd": { "describe": { "columns": [ @@ -972,104 +608,6 @@ }, "query": "\n SELECT scope_token\n FROM oauth2_consents\n WHERE user_id = $1 AND oauth2_client_id = $2\n " }, - "51bf417d259989d1228ba86fa11432e9428dece97b79e93f13921d0a510a9428": { - "describe": { - "columns": [ - { - "name": "compat_refresh_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "compat_refresh_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_refresh_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "compat_access_token_id", - "ordinal": 3, - "type_info": "Uuid" - }, - { - "name": "compat_access_token", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "compat_access_token_created_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_access_token_expires_at", - "ordinal": 6, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id", - "ordinal": 7, - "type_info": "Uuid" - }, - { - "name": "compat_session_created_at", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_finished_at", - "ordinal": 9, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id", - "ordinal": 10, - "type_info": "Text" - }, - { - "name": "user_id", - "ordinal": 11, - "type_info": "Uuid" - }, - { - "name": "user_username!", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "user_primary_user_email_id", - "ordinal": 13, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - false, - true, - false, - false, - true, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT\n cr.compat_refresh_token_id,\n cr.refresh_token AS \"compat_refresh_token\",\n cr.created_at AS \"compat_refresh_token_created_at\",\n ct.compat_access_token_id,\n ct.access_token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n ct.expires_at AS \"compat_access_token_expires_at\",\n cs.compat_session_id,\n cs.created_at AS \"compat_session_created_at\",\n cs.finished_at AS \"compat_session_finished_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.user_id,\n u.username AS \"user_username!\",\n u.primary_user_email_id AS \"user_primary_user_email_id\"\n\n FROM compat_refresh_tokens cr\n INNER JOIN compat_sessions cs\n USING (compat_session_id)\n INNER JOIN compat_access_tokens ct\n USING (compat_access_token_id)\n INNER JOIN users u\n USING (user_id)\n\n WHERE cr.refresh_token = $1\n AND cr.consumed_at IS NULL\n AND cs.finished_at IS NULL\n " - }, "559a486756d08d101eb7188ef6637b9d24c024d056795b8121f7f04a7f9db6a3": { "describe": { "columns": [ @@ -1674,27 +1212,7 @@ }, "query": "\n UPDATE user_email_confirmation_codes\n SET consumed_at = $2\n WHERE user_email_confirmation_code_id = $1\n " }, - "94fd96446b237c87bd6bf741f3c42b37ee751b87b7fcc459602bdf8c46962443": { - "describe": { - "columns": [ - { - "name": "exists!", - "ordinal": 0, - "type_info": "Bool" - } - ], - "nullable": [ - null - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT EXISTS(\n SELECT 1 FROM users WHERE username = $1\n ) AS \"exists!\"\n " - }, - "976ac2435784128eab195c8e6b9bd6e8d7b3a9142c2a34538de03817a3c94e99": { + "92ef320b75ca479ed1a38f6d654fdb953431188a8654c806fd5f98444b00c012": { "describe": { "columns": [ { @@ -1751,16 +1269,6 @@ "name": "user_id?", "ordinal": 10, "type_info": "Uuid" - }, - { - "name": "user_username?", - "ordinal": 11, - "type_info": "Text" - }, - { - "name": "user_primary_user_email_id?", - "ordinal": 12, - "type_info": "Uuid" } ], "nullable": [ @@ -1774,9 +1282,7 @@ false, true, false, - false, - false, - true + false ], "parameters": { "Left": [ @@ -1784,7 +1290,27 @@ ] } }, - "query": "\n SELECT\n cl.compat_sso_login_id,\n cl.login_token AS \"compat_sso_login_token\",\n cl.redirect_uri AS \"compat_sso_login_redirect_uri\",\n cl.created_at AS \"compat_sso_login_created_at\",\n cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\",\n cl.exchanged_at AS \"compat_sso_login_exchanged_at\",\n cs.compat_session_id AS \"compat_session_id?\",\n cs.created_at AS \"compat_session_created_at?\",\n cs.finished_at AS \"compat_session_finished_at?\",\n cs.device_id AS \"compat_session_device_id?\",\n u.user_id AS \"user_id?\",\n u.username AS \"user_username?\",\n u.primary_user_email_id AS \"user_primary_user_email_id?\"\n FROM compat_sso_logins cl\n LEFT JOIN compat_sessions cs\n USING (compat_session_id)\n LEFT JOIN users u\n USING (user_id)\n WHERE cl.compat_sso_login_id = $1\n " + "query": "\n SELECT cl.compat_sso_login_id\n , cl.login_token AS \"compat_sso_login_token\"\n , cl.redirect_uri AS \"compat_sso_login_redirect_uri\"\n , cl.created_at AS \"compat_sso_login_created_at\"\n , cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\"\n , cl.exchanged_at AS \"compat_sso_login_exchanged_at\"\n , cs.compat_session_id AS \"compat_session_id?\"\n , cs.created_at AS \"compat_session_created_at?\"\n , cs.finished_at AS \"compat_session_finished_at?\"\n , cs.device_id AS \"compat_session_device_id?\"\n , cs.user_id AS \"user_id?\"\n\n FROM compat_sso_logins cl\n LEFT JOIN compat_sessions cs\n USING (compat_session_id)\n WHERE cl.compat_sso_login_id = $1\n " + }, + "94fd96446b237c87bd6bf741f3c42b37ee751b87b7fcc459602bdf8c46962443": { + "describe": { + "columns": [ + { + "name": "exists!", + "ordinal": 0, + "type_info": "Bool" + } + ], + "nullable": [ + null + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT EXISTS(\n SELECT 1 FROM users WHERE username = $1\n ) AS \"exists!\"\n " }, "99f5f9eb0adc5ec120ed8194cbf6a8545155bef09e6d94d92fb67fd1b14d4f28": { "describe": { @@ -2017,6 +1543,140 @@ }, "query": "\n INSERT INTO user_sessions (user_session_id, user_id, created_at)\n VALUES ($1, $2, $3)\n " }, + "c467144ae98322e3ed6d34df6626d63c15bdfc7137e12097cfb6f9398f7029ca": { + "describe": { + "columns": [ + { + "name": "oauth2_authorization_grant_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "oauth2_authorization_grant_created_at", + "ordinal": 1, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_cancelled_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_fulfilled_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_exchanged_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_scope", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_state", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_redirect_uri", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_response_mode", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_nonce", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_max_age", + "ordinal": 10, + "type_info": "Int4" + }, + { + "name": "oauth2_client_id", + "ordinal": 11, + "type_info": "Uuid" + }, + { + "name": "oauth2_authorization_grant_code", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_response_type_code", + "ordinal": 13, + "type_info": "Bool" + }, + { + "name": "oauth2_authorization_grant_response_type_id_token", + "ordinal": 14, + "type_info": "Bool" + }, + { + "name": "oauth2_authorization_grant_code_challenge", + "ordinal": 15, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_code_challenge_method", + "ordinal": 16, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_requires_consent", + "ordinal": 17, + "type_info": "Bool" + }, + { + "name": "oauth2_session_id?", + "ordinal": 18, + "type_info": "Uuid" + }, + { + "name": "user_session_id?", + "ordinal": 19, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + true, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true, + false, + false, + true, + true, + false, + false, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT og.oauth2_authorization_grant_id\n , og.created_at AS oauth2_authorization_grant_created_at\n , og.cancelled_at AS oauth2_authorization_grant_cancelled_at\n , og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , og.exchanged_at AS oauth2_authorization_grant_exchanged_at\n , og.scope AS oauth2_authorization_grant_scope\n , og.state AS oauth2_authorization_grant_state\n , og.redirect_uri AS oauth2_authorization_grant_redirect_uri\n , og.response_mode AS oauth2_authorization_grant_response_mode\n , og.nonce AS oauth2_authorization_grant_nonce\n , og.max_age AS oauth2_authorization_grant_max_age\n , og.oauth2_client_id AS oauth2_client_id\n , og.authorization_code AS oauth2_authorization_grant_code\n , og.response_type_code AS oauth2_authorization_grant_response_type_code\n , og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , og.code_challenge AS oauth2_authorization_grant_code_challenge\n , og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , og.requires_consent AS oauth2_authorization_grant_requires_consent\n , os.oauth2_session_id AS \"oauth2_session_id?\"\n , os.user_session_id AS \"user_session_id?\"\n FROM\n oauth2_authorization_grants og\n LEFT JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE og.authorization_code = $1\n " + }, "c88376abdba124ff0487a9a69d2345c7d69d7394f355111ec369cfa6d45fb40f": { "describe": { "columns": [], @@ -2043,22 +1703,7 @@ }, "query": "\n INSERT INTO oauth2_authorization_grants (\n oauth2_authorization_grant_id,\n oauth2_client_id,\n redirect_uri,\n scope,\n state,\n nonce,\n max_age,\n response_mode,\n code_challenge,\n code_challenge_method,\n response_type_code,\n response_type_id_token,\n authorization_code,\n requires_consent,\n created_at\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)\n " }, - "caf54e4659306a746747aa61906bdb2cb8da51176e90435aa8b9754ebf3e4d60": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)\n VALUES ($1, $2, $3, $4)\n " - }, - "d023d7346ec1f32da9459db3c39dffd8a4e3d4e91cdf096928de4517d3f8c622": { + "cad4d47709278a9ddbebfc91642967b465bafa596827d9b86a336841b2cfbf0c": { "describe": { "columns": [ { @@ -2115,36 +1760,6 @@ "name": "user_session_id!", "ordinal": 10, "type_info": "Uuid" - }, - { - "name": "user_session_created_at!", - "ordinal": 11, - "type_info": "Timestamptz" - }, - { - "name": "user_id!", - "ordinal": 12, - "type_info": "Uuid" - }, - { - "name": "user_username!", - "ordinal": 13, - "type_info": "Text" - }, - { - "name": "user_primary_user_email_id", - "ordinal": 14, - "type_info": "Uuid" - }, - { - "name": "user_session_last_authentication_id?", - "ordinal": 15, - "type_info": "Uuid" - }, - { - "name": "user_session_last_authentication_created_at?", - "ordinal": 16, - "type_info": "Timestamptz" } ], "nullable": [ @@ -2158,12 +1773,6 @@ false, false, false, - false, - false, - false, - false, - true, - false, false ], "parameters": { @@ -2172,7 +1781,156 @@ ] } }, - "query": "\n SELECT\n rt.oauth2_refresh_token_id,\n rt.refresh_token AS oauth2_refresh_token,\n rt.created_at AS oauth2_refresh_token_created_at,\n at.oauth2_access_token_id AS \"oauth2_access_token_id?\",\n at.access_token AS \"oauth2_access_token?\",\n at.created_at AS \"oauth2_access_token_created_at?\",\n at.expires_at AS \"oauth2_access_token_expires_at?\",\n os.oauth2_session_id AS \"oauth2_session_id!\",\n os.oauth2_client_id AS \"oauth2_client_id!\",\n os.scope AS \"oauth2_session_scope!\",\n us.user_session_id AS \"user_session_id!\",\n us.created_at AS \"user_session_created_at!\",\n u.user_id AS \"user_id!\",\n u.username AS \"user_username!\",\n u.primary_user_email_id AS \"user_primary_user_email_id\",\n usa.user_session_authentication_id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n LEFT JOIN oauth2_access_tokens at\n USING (oauth2_access_token_id)\n INNER JOIN user_sessions us\n USING (user_session_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE rt.refresh_token = $1\n AND rt.consumed_at IS NULL\n AND rt.revoked_at IS NULL\n AND us.finished_at IS NULL\n AND os.finished_at IS NULL\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " + "query": "\n SELECT rt.oauth2_refresh_token_id\n , rt.refresh_token AS oauth2_refresh_token\n , rt.created_at AS oauth2_refresh_token_created_at\n , at.oauth2_access_token_id AS \"oauth2_access_token_id?\"\n , at.access_token AS \"oauth2_access_token?\"\n , at.created_at AS \"oauth2_access_token_created_at?\"\n , at.expires_at AS \"oauth2_access_token_expires_at?\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"oauth2_session_scope!\"\n , os.user_session_id AS \"user_session_id!\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n LEFT JOIN oauth2_access_tokens at\n USING (oauth2_access_token_id)\n\n WHERE rt.refresh_token = $1\n AND rt.consumed_at IS NULL\n AND rt.revoked_at IS NULL\n AND os.finished_at IS NULL\n " + }, + "caf54e4659306a746747aa61906bdb2cb8da51176e90435aa8b9754ebf3e4d60": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)\n VALUES ($1, $2, $3, $4)\n " + }, + "d08b787fc422b6699ffc0a491ecf92fb993db0aca51534b315bcfa4891baca84": { + "describe": { + "columns": [ + { + "name": "oauth2_authorization_grant_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "oauth2_authorization_grant_created_at", + "ordinal": 1, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_cancelled_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_fulfilled_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_exchanged_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_scope", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_state", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_redirect_uri", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_response_mode", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_nonce", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_max_age", + "ordinal": 10, + "type_info": "Int4" + }, + { + "name": "oauth2_client_id", + "ordinal": 11, + "type_info": "Uuid" + }, + { + "name": "oauth2_authorization_grant_code", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_response_type_code", + "ordinal": 13, + "type_info": "Bool" + }, + { + "name": "oauth2_authorization_grant_response_type_id_token", + "ordinal": 14, + "type_info": "Bool" + }, + { + "name": "oauth2_authorization_grant_code_challenge", + "ordinal": 15, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_code_challenge_method", + "ordinal": 16, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_requires_consent", + "ordinal": 17, + "type_info": "Bool" + }, + { + "name": "oauth2_session_id?", + "ordinal": 18, + "type_info": "Uuid" + }, + { + "name": "user_session_id?", + "ordinal": 19, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + true, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true, + false, + false, + true, + true, + false, + false, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT og.oauth2_authorization_grant_id\n , og.created_at AS oauth2_authorization_grant_created_at\n , og.cancelled_at AS oauth2_authorization_grant_cancelled_at\n , og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , og.exchanged_at AS oauth2_authorization_grant_exchanged_at\n , og.scope AS oauth2_authorization_grant_scope\n , og.state AS oauth2_authorization_grant_state\n , og.redirect_uri AS oauth2_authorization_grant_redirect_uri\n , og.response_mode AS oauth2_authorization_grant_response_mode\n , og.nonce AS oauth2_authorization_grant_nonce\n , og.max_age AS oauth2_authorization_grant_max_age\n , og.oauth2_client_id AS oauth2_client_id\n , og.authorization_code AS oauth2_authorization_grant_code\n , og.response_type_code AS oauth2_authorization_grant_response_type_code\n , og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , og.code_challenge AS oauth2_authorization_grant_code_challenge\n , og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , og.requires_consent AS oauth2_authorization_grant_requires_consent\n , os.oauth2_session_id AS \"oauth2_session_id?\"\n , os.user_session_id AS \"user_session_id?\"\n FROM\n oauth2_authorization_grants og\n LEFT JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE og.oauth2_authorization_grant_id = $1\n " }, "d12a513b81b3ef658eae1f0a719933323f28c6ee260b52cafe337dd3d19e865c": { "describe": { @@ -2209,6 +1967,75 @@ }, "query": "\n INSERT INTO compat_sso_logins\n (compat_sso_login_id, login_token, redirect_uri, created_at)\n VALUES ($1, $2, $3, $4)\n " }, + "d4c6c070a0cd889cef9e0cfd65c64522a03f0bae12ee7c6b74343ec8f38d24c1": { + "describe": { + "columns": [ + { + "name": "compat_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "compat_access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "compat_access_token_created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "compat_access_token_expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + }, + { + "name": "compat_session_created_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_finished_at", + "ordinal": 6, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_device_id", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "user_id!", + "ordinal": 8, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Text", + "Timestamptz" + ] + } + }, + "query": "\n SELECT ct.compat_access_token_id\n , ct.access_token AS \"compat_access_token\"\n , ct.created_at AS \"compat_access_token_created_at\"\n , ct.expires_at AS \"compat_access_token_expires_at\"\n , cs.compat_session_id\n , cs.created_at AS \"compat_session_created_at\"\n , cs.finished_at AS \"compat_session_finished_at\"\n , cs.device_id AS \"compat_session_device_id\"\n , cs.user_id AS \"user_id!\"\n\n FROM compat_access_tokens ct\n INNER JOIN compat_sessions cs\n USING (compat_session_id)\n\n WHERE ct.access_token = $1\n AND (ct.expires_at < $2 OR ct.expires_at IS NULL)\n AND cs.finished_at IS NULL \n " + }, "d8677b3b6ee594c230fad98c1aa1c6e3d983375bf5b701c7b52468e7f906abf9": { "describe": { "columns": [], @@ -2426,62 +2253,47 @@ }, "query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)\n " }, - "f624e1bdbff4e97b300362d1bbd86035e4a0fdd8ffe16c3bfb9bc451ba60851b": { + "fba88894ee24cd181f50412571a19ee658f77012d330e7dab43a3c18d549355a": { "describe": { "columns": [ { - "name": "compat_access_token_id", + "name": "oauth2_access_token_id", "ordinal": 0, "type_info": "Uuid" }, { - "name": "compat_access_token", + "name": "oauth2_access_token", "ordinal": 1, "type_info": "Text" }, { - "name": "compat_access_token_created_at", + "name": "oauth2_access_token_created_at", "ordinal": 2, "type_info": "Timestamptz" }, { - "name": "compat_access_token_expires_at", + "name": "oauth2_access_token_expires_at", "ordinal": 3, "type_info": "Timestamptz" }, { - "name": "compat_session_id", + "name": "oauth2_session_id!", "ordinal": 4, "type_info": "Uuid" }, { - "name": "compat_session_created_at", + "name": "oauth2_client_id!", "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_finished_at", - "ordinal": 6, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "user_id!", - "ordinal": 8, "type_info": "Uuid" }, { - "name": "user_username!", - "ordinal": 9, + "name": "scope!", + "ordinal": 6, "type_info": "Text" }, { - "name": "user_primary_user_email_id", - "ordinal": 10, + "name": "user_session_id!", + "ordinal": 7, "type_info": "Uuid" } ], @@ -2489,22 +2301,18 @@ false, false, false, - true, - false, - false, - true, false, false, false, - true + false, + false ], "parameters": { "Left": [ - "Text", - "Timestamptz" + "Text" ] } }, - "query": "\n SELECT\n ct.compat_access_token_id,\n ct.access_token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n ct.expires_at AS \"compat_access_token_expires_at\",\n cs.compat_session_id,\n cs.created_at AS \"compat_session_created_at\",\n cs.finished_at AS \"compat_session_finished_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.user_id AS \"user_id!\",\n u.username AS \"user_username!\",\n u.primary_user_email_id AS \"user_primary_user_email_id\"\n\n FROM compat_access_tokens ct\n INNER JOIN compat_sessions cs\n USING (compat_session_id)\n INNER JOIN users u\n USING (user_id)\n\n WHERE ct.access_token = $1\n AND (ct.expires_at < $2 OR ct.expires_at IS NULL)\n AND cs.finished_at IS NULL \n " + "query": "\n SELECT at.oauth2_access_token_id\n , at.access_token AS \"oauth2_access_token\"\n , at.created_at AS \"oauth2_access_token_created_at\"\n , at.expires_at AS \"oauth2_access_token_expires_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"scope!\"\n , os.user_session_id AS \"user_session_id!\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE at.access_token = $1\n AND at.revoked_at IS NULL\n AND os.finished_at IS NULL\n " } } \ No newline at end of file diff --git a/crates/storage/src/compat.rs b/crates/storage/src/compat.rs index 4708e91bd..ba47990d1 100644 --- a/crates/storage/src/compat.rs +++ b/crates/storage/src/compat.rs @@ -39,8 +39,6 @@ struct CompatAccessTokenLookup { compat_session_finished_at: Option>, compat_session_device_id: String, user_id: Uuid, - user_username: String, - user_primary_user_email_id: Option, } #[tracing::instrument(skip_all, err)] @@ -52,24 +50,19 @@ pub async fn lookup_active_compat_access_token( let res = sqlx::query_as!( CompatAccessTokenLookup, r#" - SELECT - ct.compat_access_token_id, - ct.access_token AS "compat_access_token", - ct.created_at AS "compat_access_token_created_at", - ct.expires_at AS "compat_access_token_expires_at", - cs.compat_session_id, - cs.created_at AS "compat_session_created_at", - cs.finished_at AS "compat_session_finished_at", - cs.device_id AS "compat_session_device_id", - u.user_id AS "user_id!", - u.username AS "user_username!", - u.primary_user_email_id AS "user_primary_user_email_id" + SELECT ct.compat_access_token_id + , ct.access_token AS "compat_access_token" + , ct.created_at AS "compat_access_token_created_at" + , ct.expires_at AS "compat_access_token_expires_at" + , cs.compat_session_id + , cs.created_at AS "compat_session_created_at" + , cs.finished_at AS "compat_session_finished_at" + , cs.device_id AS "compat_session_device_id" + , cs.user_id AS "user_id!" FROM compat_access_tokens ct INNER JOIN compat_sessions cs USING (compat_session_id) - INNER JOIN users u - USING (user_id) WHERE ct.access_token = $1 AND (ct.expires_at < $2 OR ct.expires_at IS NULL) @@ -92,14 +85,6 @@ pub async fn lookup_active_compat_access_token( expires_at: res.compat_access_token_expires_at, }; - let user_id = Ulid::from(res.user_id); - let user = User { - id: user_id, - username: res.user_username, - sub: user_id.to_string(), - primary_user_email_id: res.user_primary_user_email_id.map(Into::into), - }; - let id = res.compat_session_id.into(); let device = Device::try_from(res.compat_session_device_id).map_err(|e| { DatabaseInconsistencyError::on("compat_sessions") @@ -110,7 +95,7 @@ pub async fn lookup_active_compat_access_token( let session = CompatSession { id, - user, + user_id: res.user_id.into(), device, created_at: res.compat_session_created_at, finished_at: res.compat_session_finished_at, @@ -132,8 +117,6 @@ pub struct CompatRefreshTokenLookup { compat_session_finished_at: Option>, compat_session_device_id: String, user_id: Uuid, - user_username: String, - user_primary_user_email_id: Option, } #[tracing::instrument(skip_all, err)] @@ -145,29 +128,24 @@ pub async fn lookup_active_compat_refresh_token( let res = sqlx::query_as!( CompatRefreshTokenLookup, r#" - SELECT - cr.compat_refresh_token_id, - cr.refresh_token AS "compat_refresh_token", - cr.created_at AS "compat_refresh_token_created_at", - ct.compat_access_token_id, - ct.access_token AS "compat_access_token", - ct.created_at AS "compat_access_token_created_at", - ct.expires_at AS "compat_access_token_expires_at", - cs.compat_session_id, - cs.created_at AS "compat_session_created_at", - cs.finished_at AS "compat_session_finished_at", - cs.device_id AS "compat_session_device_id", - u.user_id, - u.username AS "user_username!", - u.primary_user_email_id AS "user_primary_user_email_id" + SELECT cr.compat_refresh_token_id + , cr.refresh_token AS "compat_refresh_token" + , cr.created_at AS "compat_refresh_token_created_at" + , ct.compat_access_token_id + , ct.access_token AS "compat_access_token" + , ct.created_at AS "compat_access_token_created_at" + , ct.expires_at AS "compat_access_token_expires_at" + , cs.compat_session_id + , cs.created_at AS "compat_session_created_at" + , cs.finished_at AS "compat_session_finished_at" + , cs.device_id AS "compat_session_device_id" + , cs.user_id FROM compat_refresh_tokens cr INNER JOIN compat_sessions cs USING (compat_session_id) INNER JOIN compat_access_tokens ct USING (compat_access_token_id) - INNER JOIN users u - USING (user_id) WHERE cr.refresh_token = $1 AND cr.consumed_at IS NULL @@ -195,25 +173,17 @@ pub async fn lookup_active_compat_refresh_token( expires_at: res.compat_access_token_expires_at, }; - let user_id = Ulid::from(res.user_id); - let user = User { - id: user_id, - username: res.user_username, - sub: user_id.to_string(), - primary_user_email_id: res.user_primary_user_email_id.map(Into::into), - }; - - let session_id = res.compat_session_id.into(); + let id = res.compat_session_id.into(); let device = Device::try_from(res.compat_session_device_id).map_err(|e| { DatabaseInconsistencyError::on("compat_sessions") .column("device_id") - .row(session_id) + .row(id) .source(e) })?; let session = CompatSession { - id: session_id, - user, + id, + user_id: res.user_id.into(), device, created_at: res.compat_session_created_at, finished_at: res.compat_session_finished_at, @@ -228,7 +198,7 @@ pub async fn lookup_active_compat_refresh_token( compat_session.id = %session.id, compat_session.device.id = session.device.as_str(), compat_access_token.id, - user.id = %session.user.id, + user.id = %session.user_id, ), err, )] @@ -305,7 +275,7 @@ pub async fn expire_compat_access_token( compat_session.device.id = session.device.as_str(), compat_access_token.id = %access_token.id, compat_refresh_token.id, - user.id = %session.user.id, + user.id = %session.user_id, ), err, )] @@ -469,8 +439,6 @@ struct CompatSsoLoginLookup { compat_session_finished_at: Option>, compat_session_device_id: Option, user_id: Option, - user_username: Option, - user_primary_user_email_id: Option, } impl TryFrom for CompatSsoLogin { @@ -485,33 +453,14 @@ impl TryFrom for CompatSsoLogin { .source(e) })?; - let user = match ( - res.user_id, - res.user_username, - res.user_primary_user_email_id, - ) { - (Some(id), Some(username), primary_email_id) => { - let id = Ulid::from(id); - Some(User { - id, - username, - sub: id.to_string(), - primary_user_email_id: primary_email_id.map(Into::into), - }) - } - - (None, None, None) => None, - _ => return Err(DatabaseInconsistencyError::on("compat_sessions").column("user_id")), - }; - let session = match ( res.compat_session_id, res.compat_session_device_id, res.compat_session_created_at, res.compat_session_finished_at, - user, + res.user_id, ) { - (Some(id), Some(device_id), Some(created_at), finished_at, Some(user)) => { + (Some(id), Some(device_id), Some(created_at), finished_at, Some(user_id)) => { let id = id.into(); let device = Device::try_from(device_id).map_err(|e| { DatabaseInconsistencyError::on("compat_sessions") @@ -521,7 +470,7 @@ impl TryFrom for CompatSsoLogin { })?; Some(CompatSession { id, - user, + user_id: user_id.into(), device, created_at, finished_at, @@ -579,25 +528,21 @@ pub async fn get_compat_sso_login_by_id( let res = sqlx::query_as!( CompatSsoLoginLookup, r#" - SELECT - cl.compat_sso_login_id, - cl.login_token AS "compat_sso_login_token", - cl.redirect_uri AS "compat_sso_login_redirect_uri", - cl.created_at AS "compat_sso_login_created_at", - cl.fulfilled_at AS "compat_sso_login_fulfilled_at", - cl.exchanged_at AS "compat_sso_login_exchanged_at", - cs.compat_session_id AS "compat_session_id?", - cs.created_at AS "compat_session_created_at?", - cs.finished_at AS "compat_session_finished_at?", - cs.device_id AS "compat_session_device_id?", - u.user_id AS "user_id?", - u.username AS "user_username?", - u.primary_user_email_id AS "user_primary_user_email_id?" + SELECT cl.compat_sso_login_id + , cl.login_token AS "compat_sso_login_token" + , cl.redirect_uri AS "compat_sso_login_redirect_uri" + , cl.created_at AS "compat_sso_login_created_at" + , cl.fulfilled_at AS "compat_sso_login_fulfilled_at" + , cl.exchanged_at AS "compat_sso_login_exchanged_at" + , cs.compat_session_id AS "compat_session_id?" + , cs.created_at AS "compat_session_created_at?" + , cs.finished_at AS "compat_session_finished_at?" + , cs.device_id AS "compat_session_device_id?" + , cs.user_id AS "user_id?" + FROM compat_sso_logins cl LEFT JOIN compat_sessions cs USING (compat_session_id) - LEFT JOIN users u - USING (user_id) WHERE cl.compat_sso_login_id = $1 "#, Uuid::from(id), @@ -632,25 +577,20 @@ pub async fn get_paginated_user_compat_sso_logins( // because we already have them let mut query = QueryBuilder::new( r#" - SELECT - cl.compat_sso_login_id, - cl.login_token AS "compat_sso_login_token", - cl.redirect_uri AS "compat_sso_login_redirect_uri", - cl.created_at AS "compat_sso_login_created_at", - cl.fulfilled_at AS "compat_sso_login_fulfilled_at", - cl.exchanged_at AS "compat_sso_login_exchanged_at", - cs.compat_session_id AS "compat_session_id", - cs.created_at AS "compat_session_created_at", - cs.finished_at AS "compat_session_finished_at", - cs.device_id AS "compat_session_device_id", - u.user_id AS "user_id", - u.username AS "user_username", - u.primary_user_email_id AS "user_primary_user_email_id?" + SELECT cl.compat_sso_login_id + , cl.login_token AS "compat_sso_login_token" + , cl.redirect_uri AS "compat_sso_login_redirect_uri" + , cl.created_at AS "compat_sso_login_created_at" + , cl.fulfilled_at AS "compat_sso_login_fulfilled_at" + , cl.exchanged_at AS "compat_sso_login_exchanged_at" + , cs.compat_session_id AS "compat_session_id" + , cs.created_at AS "compat_session_created_at" + , cs.finished_at AS "compat_session_finished_at" + , cs.device_id AS "compat_session_device_id" + , cs.user_id FROM compat_sso_logins cl LEFT JOIN compat_sessions cs USING (compat_session_id) - LEFT JOIN users u - USING (user_id) "#, ); @@ -683,25 +623,20 @@ pub async fn get_compat_sso_login_by_token( let res = sqlx::query_as!( CompatSsoLoginLookup, r#" - SELECT - cl.compat_sso_login_id, - cl.login_token AS "compat_sso_login_token", - cl.redirect_uri AS "compat_sso_login_redirect_uri", - cl.created_at AS "compat_sso_login_created_at", - cl.fulfilled_at AS "compat_sso_login_fulfilled_at", - cl.exchanged_at AS "compat_sso_login_exchanged_at", - cs.compat_session_id AS "compat_session_id?", - cs.created_at AS "compat_session_created_at?", - cs.finished_at AS "compat_session_finished_at?", - cs.device_id AS "compat_session_device_id?", - u.user_id AS "user_id?", - u.username AS "user_username?", - u.primary_user_email_id AS "user_primary_user_email_id?" + SELECT cl.compat_sso_login_id + , cl.login_token AS "compat_sso_login_token" + , cl.redirect_uri AS "compat_sso_login_redirect_uri" + , cl.created_at AS "compat_sso_login_created_at" + , cl.fulfilled_at AS "compat_sso_login_fulfilled_at" + , cl.exchanged_at AS "compat_sso_login_exchanged_at" + , cs.compat_session_id AS "compat_session_id?" + , cs.created_at AS "compat_session_created_at?" + , cs.finished_at AS "compat_session_finished_at?" + , cs.device_id AS "compat_session_device_id?" + , cs.user_id AS "user_id?" FROM compat_sso_logins cl LEFT JOIN compat_sessions cs USING (compat_session_id) - LEFT JOIN users u - USING (user_id) WHERE cl.login_token = $1 "#, token, @@ -729,7 +664,7 @@ pub async fn start_compat_session( executor: impl PgExecutor<'_>, mut rng: impl Rng + Send, clock: &Clock, - user: User, + user: &User, device: Device, ) -> Result { let created_at = clock.now(); @@ -751,7 +686,7 @@ pub async fn start_compat_session( Ok(CompatSession { id, - user, + user_id: user.id, device, created_at, finished_at: None, @@ -773,7 +708,7 @@ pub async fn fullfill_compat_sso_login( conn: impl Acquire<'_, Database = Postgres> + Send, mut rng: impl Rng + Send, clock: &Clock, - user: User, + user: &User, mut compat_sso_login: CompatSsoLogin, device: Device, ) -> Result { diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index e41f4812f..c85d1b485 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -13,21 +13,20 @@ // limitations under the License. use chrono::{DateTime, Duration, Utc}; -use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User}; +use mas_data_model::{AccessToken, Session}; use rand::Rng; use sqlx::{PgConnection, PgExecutor}; use ulid::Ulid; use uuid::Uuid; -use super::client::OAuth2ClientRepository; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError, Repository}; +use crate::{Clock, DatabaseError, DatabaseInconsistencyError}; #[tracing::instrument( skip_all, fields( %session.id, - client.id = %session.client.id, - user.id = %session.browser_session.user.id, + user_session.id = %session.user_session_id, + client.id = %session.client_id, access_token.id, ), err, @@ -81,12 +80,6 @@ pub struct OAuth2AccessTokenLookup { oauth2_client_id: Uuid, scope: String, user_session_id: Uuid, - user_session_created_at: DateTime, - user_id: Uuid, - user_username: String, - user_primary_user_email_id: Option, - user_session_last_authentication_id: Option, - user_session_last_authentication_created_at: Option>, } #[allow(clippy::too_many_lines)] @@ -104,30 +97,15 @@ pub async fn lookup_active_access_token( , os.oauth2_session_id AS "oauth2_session_id!" , os.oauth2_client_id AS "oauth2_client_id!" , os.scope AS "scope!" - , us.user_session_id AS "user_session_id!" - , us.created_at AS "user_session_created_at!" - , u.user_id AS "user_id!" - , u.username AS "user_username!" - , u.primary_user_email_id AS "user_primary_user_email_id" - , usa.user_session_authentication_id AS "user_session_last_authentication_id?" - , usa.created_at AS "user_session_last_authentication_created_at?" + , os.user_session_id AS "user_session_id!" FROM oauth2_access_tokens at INNER JOIN oauth2_sessions os USING (oauth2_session_id) - INNER JOIN user_sessions us - USING (user_session_id) - INNER JOIN users u - USING (user_id) - LEFT JOIN user_session_authentications usa - USING (user_session_id) WHERE at.access_token = $1 AND at.revoked_at IS NULL AND os.finished_at IS NULL - - ORDER BY usa.created_at DESC - LIMIT 1 "#, token, ) @@ -144,44 +122,6 @@ pub async fn lookup_active_access_token( }; let session_id = res.oauth2_session_id.into(); - let client = conn - .oauth2_client() - .lookup(res.oauth2_client_id.into()) - .await? - .ok_or_else(|| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("client_id") - .row(session_id) - })?; - - let user_id = Ulid::from(res.user_id); - let user = User { - id: user_id, - username: res.user_username, - sub: user_id.to_string(), - primary_user_email_id: res.user_primary_user_email_id.map(Into::into), - }; - - let last_authentication = match ( - res.user_session_last_authentication_id, - res.user_session_last_authentication_created_at, - ) { - (None, None) => None, - (Some(id), Some(created_at)) => Some(Authentication { - id: id.into(), - created_at, - }), - _ => return Err(DatabaseInconsistencyError::on("user_session_authentications").into()), - }; - - let browser_session = BrowserSession { - id: res.user_session_id.into(), - created_at: res.user_session_created_at, - finished_at: None, - user, - last_authentication, - }; - let scope = res.scope.parse().map_err(|e| { DatabaseInconsistencyError::on("oauth2_sessions") .column("scope") @@ -191,8 +131,8 @@ pub async fn lookup_active_access_token( let session = Session { id: session_id, - client, - browser_session, + client_id: res.oauth2_client_id.into(), + user_session_id: res.user_session_id.into(), scope, }; diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index bfd918600..3a18ef415 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -16,8 +16,8 @@ use std::num::NonZeroU32; use chrono::{DateTime, Utc}; use mas_data_model::{ - Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession, - Client, Pkce, Session, User, + AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession, Client, Pkce, + Session, }; use mas_iana::oauth::PkceCodeChallengeMethod; use oauth2_types::{requests::ResponseMode, scope::Scope}; @@ -151,12 +151,6 @@ struct GrantLookup { oauth2_client_id: Uuid, oauth2_session_id: Option, user_session_id: Option, - user_session_created_at: Option>, - user_id: Option, - user_username: Option, - user_primary_user_email_id: Option, - user_session_last_authentication_id: Option, - user_session_last_authentication_created_at: Option>, } impl GrantLookup { @@ -183,65 +177,20 @@ impl GrantLookup { .row(id) })?; - let last_authentication = match ( - self.user_session_last_authentication_id, - self.user_session_last_authentication_created_at, - ) { - (Some(id), Some(created_at)) => Some(Authentication { - id: id.into(), - created_at, - }), - (None, None) => None, - _ => return Err(DatabaseInconsistencyError::on("user_session_authentications").into()), - }; - - let session = match ( - self.oauth2_session_id, - self.user_session_id, - self.user_session_created_at, - self.user_id, - self.user_username, - self.user_primary_user_email_id, - last_authentication, - ) { - ( - Some(session_id), - Some(user_session_id), - Some(user_session_created_at), - Some(user_id), - Some(user_username), - user_primary_user_email_id, - last_authentication, - ) => { - let user_id = Ulid::from(user_id); - let user = User { - id: user_id, - username: user_username, - sub: user_id.to_string(), - primary_user_email_id: user_primary_user_email_id.map(Into::into), - }; - - let browser_session = BrowserSession { - id: user_session_id.into(), - user, - created_at: user_session_created_at, - finished_at: None, - last_authentication, - }; - - let client = client.clone(); + let session = match (self.oauth2_session_id, self.user_session_id) { + (Some(session_id), Some(user_session_id)) => { let scope = scope.clone(); let session = Session { id: session_id.into(), - client, - browser_session, + client_id: client.id, + user_session_id: user_session_id.into(), scope, }; Some(session) } - (None, None, None, None, None, None, None) => None, + (None, None) => None, _ => { return Err( DatabaseInconsistencyError::on("oauth2_authorization_grants") @@ -394,48 +343,32 @@ pub async fn get_grant_by_id( let res = sqlx::query_as!( GrantLookup, r#" - SELECT - og.oauth2_authorization_grant_id, - og.created_at AS oauth2_authorization_grant_created_at, - og.cancelled_at AS oauth2_authorization_grant_cancelled_at, - og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at, - og.exchanged_at AS oauth2_authorization_grant_exchanged_at, - og.scope AS oauth2_authorization_grant_scope, - og.state AS oauth2_authorization_grant_state, - og.redirect_uri AS oauth2_authorization_grant_redirect_uri, - og.response_mode AS oauth2_authorization_grant_response_mode, - og.nonce AS oauth2_authorization_grant_nonce, - og.max_age AS oauth2_authorization_grant_max_age, - og.oauth2_client_id AS oauth2_client_id, - og.authorization_code AS oauth2_authorization_grant_code, - og.response_type_code AS oauth2_authorization_grant_response_type_code, - og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token, - og.code_challenge AS oauth2_authorization_grant_code_challenge, - og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method, - og.requires_consent AS oauth2_authorization_grant_requires_consent, - os.oauth2_session_id AS "oauth2_session_id?", - us.user_session_id AS "user_session_id?", - us.created_at AS "user_session_created_at?", - u.user_id AS "user_id?", - u.username AS "user_username?", - u.primary_user_email_id AS "user_primary_user_email_id?", - usa.user_session_authentication_id AS "user_session_last_authentication_id?", - usa.created_at AS "user_session_last_authentication_created_at?" + SELECT og.oauth2_authorization_grant_id + , og.created_at AS oauth2_authorization_grant_created_at + , og.cancelled_at AS oauth2_authorization_grant_cancelled_at + , og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at + , og.exchanged_at AS oauth2_authorization_grant_exchanged_at + , og.scope AS oauth2_authorization_grant_scope + , og.state AS oauth2_authorization_grant_state + , og.redirect_uri AS oauth2_authorization_grant_redirect_uri + , og.response_mode AS oauth2_authorization_grant_response_mode + , og.nonce AS oauth2_authorization_grant_nonce + , og.max_age AS oauth2_authorization_grant_max_age + , og.oauth2_client_id AS oauth2_client_id + , og.authorization_code AS oauth2_authorization_grant_code + , og.response_type_code AS oauth2_authorization_grant_response_type_code + , og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token + , og.code_challenge AS oauth2_authorization_grant_code_challenge + , og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method + , og.requires_consent AS oauth2_authorization_grant_requires_consent + , os.oauth2_session_id AS "oauth2_session_id?" + , os.user_session_id AS "user_session_id?" FROM oauth2_authorization_grants og LEFT JOIN oauth2_sessions os USING (oauth2_session_id) - LEFT JOIN user_sessions us - USING (user_session_id) - LEFT JOIN users u - USING (user_id) - LEFT JOIN user_session_authentications usa - USING (user_session_id) WHERE og.oauth2_authorization_grant_id = $1 - - ORDER BY usa.created_at DESC - LIMIT 1 "#, Uuid::from(id), ) @@ -458,48 +391,32 @@ pub async fn lookup_grant_by_code( let res = sqlx::query_as!( GrantLookup, r#" - SELECT - og.oauth2_authorization_grant_id, - og.created_at AS oauth2_authorization_grant_created_at, - og.cancelled_at AS oauth2_authorization_grant_cancelled_at, - og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at, - og.exchanged_at AS oauth2_authorization_grant_exchanged_at, - og.scope AS oauth2_authorization_grant_scope, - og.state AS oauth2_authorization_grant_state, - og.redirect_uri AS oauth2_authorization_grant_redirect_uri, - og.response_mode AS oauth2_authorization_grant_response_mode, - og.nonce AS oauth2_authorization_grant_nonce, - og.max_age AS oauth2_authorization_grant_max_age, - og.oauth2_client_id AS oauth2_client_id, - og.authorization_code AS oauth2_authorization_grant_code, - og.response_type_code AS oauth2_authorization_grant_response_type_code, - og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token, - og.code_challenge AS oauth2_authorization_grant_code_challenge, - og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method, - og.requires_consent AS oauth2_authorization_grant_requires_consent, - os.oauth2_session_id AS "oauth2_session_id?", - us.user_session_id AS "user_session_id?", - us.created_at AS "user_session_created_at?", - u.user_id AS "user_id?", - u.username AS "user_username?", - u.primary_user_email_id AS "user_primary_user_email_id?", - usa.user_session_authentication_id AS "user_session_last_authentication_id?", - usa.created_at AS "user_session_last_authentication_created_at?" + SELECT og.oauth2_authorization_grant_id + , og.created_at AS oauth2_authorization_grant_created_at + , og.cancelled_at AS oauth2_authorization_grant_cancelled_at + , og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at + , og.exchanged_at AS oauth2_authorization_grant_exchanged_at + , og.scope AS oauth2_authorization_grant_scope + , og.state AS oauth2_authorization_grant_state + , og.redirect_uri AS oauth2_authorization_grant_redirect_uri + , og.response_mode AS oauth2_authorization_grant_response_mode + , og.nonce AS oauth2_authorization_grant_nonce + , og.max_age AS oauth2_authorization_grant_max_age + , og.oauth2_client_id AS oauth2_client_id + , og.authorization_code AS oauth2_authorization_grant_code + , og.response_type_code AS oauth2_authorization_grant_response_type_code + , og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token + , og.code_challenge AS oauth2_authorization_grant_code_challenge + , og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method + , og.requires_consent AS oauth2_authorization_grant_requires_consent + , os.oauth2_session_id AS "oauth2_session_id?" + , os.user_session_id AS "user_session_id?" FROM oauth2_authorization_grants og LEFT JOIN oauth2_sessions os USING (oauth2_session_id) - LEFT JOIN user_sessions us - USING (user_session_id) - LEFT JOIN users u - USING (user_id) - LEFT JOIN user_session_authentications usa - USING (user_session_id) WHERE og.authorization_code = $1 - - ORDER BY usa.created_at DESC - LIMIT 1 "#, code, ) @@ -561,8 +478,8 @@ pub async fn derive_session( Ok(Session { id, - browser_session, - client: grant.client.clone(), + user_session_id: browser_session.id, + client_id: grant.client.id, scope: grant.scope.clone(), }) } @@ -573,8 +490,7 @@ pub async fn derive_session( %grant.id, client.id = %grant.client.id, %session.id, - user_session.id = %session.browser_session.id, - user.id = %session.browser_session.user.id, + user_session.id = %session.user_session_id, ), err, )] diff --git a/crates/storage/src/oauth2/mod.rs b/crates/storage/src/oauth2/mod.rs index c0153a3a0..66313139a 100644 --- a/crates/storage/src/oauth2/mod.rs +++ b/crates/storage/src/oauth2/mod.rs @@ -12,19 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::{BTreeSet, HashMap}; - -use mas_data_model::{BrowserSession, Session, User}; +use mas_data_model::{Session, User}; use sqlx::{PgConnection, PgExecutor, QueryBuilder}; use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; -use self::client::OAuth2ClientRepository; use crate::{ pagination::{process_page, QueryBuilderExt}, - user::BrowserSessionRepository, - Clock, DatabaseError, DatabaseInconsistencyError, Repository, + Clock, DatabaseError, DatabaseInconsistencyError, }; pub mod access_token; @@ -32,14 +28,14 @@ pub mod authorization_grant; pub mod client; pub mod consent; pub mod refresh_token; +pub mod session; #[tracing::instrument( skip_all, fields( %session.id, - user.id = %session.browser_session.user.id, - user_session.id = %session.browser_session.id, - client.id = %session.client.id, + user_session.id = %session.user_session_id, + client.id = %session.client_id, ), err, )] @@ -120,49 +116,10 @@ pub async fn get_paginated_user_oauth_sessions( let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; - let client_ids: BTreeSet = page - .iter() - .map(|i| Ulid::from(i.oauth2_client_id)) - .collect(); - - let browser_session_ids: BTreeSet = - page.iter().map(|i| Ulid::from(i.user_session_id)).collect(); - - let clients = conn.oauth2_client().load_batch(client_ids).await?; - - // TODO: this can generate N queries instead of batching. This is less than - // ideal - let mut browser_sessions: HashMap = HashMap::new(); - for id in browser_session_ids { - let v = conn.browser_session().lookup(id).await?.ok_or_else(|| { - DatabaseInconsistencyError::on("oauth2_sessions").column("user_session_id") - })?; - - browser_sessions.insert(id, v); - } - let page: Result, DatabaseInconsistencyError> = page .into_iter() .map(|item| { let id = Ulid::from(item.oauth2_session_id); - let client = clients - .get(&Ulid::from(item.oauth2_client_id)) - .ok_or_else(|| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("oauth2_client_id") - .row(id) - })? - .clone(); - - let browser_session = browser_sessions - .get(&Ulid::from(item.user_session_id)) - .ok_or_else(|| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("user_session_id") - .row(id) - })? - .clone(); - let scope = item.scope.parse().map_err(|e| { DatabaseInconsistencyError::on("oauth2_sessions") .column("scope") @@ -172,8 +129,8 @@ pub async fn get_paginated_user_oauth_sessions( Ok(Session { id: Ulid::from(item.oauth2_session_id), - client, - browser_session, + client_id: item.oauth2_client_id.into(), + user_session_id: item.user_session_id.into(), scope, }) }) diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 57abf103f..bf223a794 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -13,22 +13,20 @@ // limitations under the License. use chrono::{DateTime, Utc}; -use mas_data_model::{AccessToken, Authentication, BrowserSession, RefreshToken, Session, User}; +use mas_data_model::{AccessToken, RefreshToken, Session}; use rand::Rng; use sqlx::{PgConnection, PgExecutor}; use ulid::Ulid; use uuid::Uuid; -use super::client::OAuth2ClientRepository; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError, Repository}; +use crate::{Clock, DatabaseError, DatabaseInconsistencyError}; #[tracing::instrument( skip_all, fields( %session.id, - user.id = %session.browser_session.user.id, - user_session.id = %session.browser_session.id, - client.id = %session.client.id, + user_session.id = %session.user_session_id, + client.id = %session.client_id, refresh_token.id, ), err, @@ -82,12 +80,6 @@ struct OAuth2RefreshTokenLookup { oauth2_client_id: Uuid, oauth2_session_scope: String, user_session_id: Uuid, - user_session_created_at: DateTime, - user_id: Uuid, - user_username: String, - user_primary_user_email_id: Option, - user_session_last_authentication_id: Option, - user_session_last_authentication_created_at: Option>, } #[tracing::instrument(skip_all, err)] @@ -99,46 +91,27 @@ pub async fn lookup_active_refresh_token( let res = sqlx::query_as!( OAuth2RefreshTokenLookup, r#" - SELECT - rt.oauth2_refresh_token_id, - rt.refresh_token AS oauth2_refresh_token, - rt.created_at AS oauth2_refresh_token_created_at, - at.oauth2_access_token_id AS "oauth2_access_token_id?", - at.access_token AS "oauth2_access_token?", - at.created_at AS "oauth2_access_token_created_at?", - at.expires_at AS "oauth2_access_token_expires_at?", - os.oauth2_session_id AS "oauth2_session_id!", - os.oauth2_client_id AS "oauth2_client_id!", - os.scope AS "oauth2_session_scope!", - us.user_session_id AS "user_session_id!", - us.created_at AS "user_session_created_at!", - u.user_id AS "user_id!", - u.username AS "user_username!", - u.primary_user_email_id AS "user_primary_user_email_id", - usa.user_session_authentication_id AS "user_session_last_authentication_id?", - usa.created_at AS "user_session_last_authentication_created_at?" + SELECT rt.oauth2_refresh_token_id + , rt.refresh_token AS oauth2_refresh_token + , rt.created_at AS oauth2_refresh_token_created_at + , at.oauth2_access_token_id AS "oauth2_access_token_id?" + , at.access_token AS "oauth2_access_token?" + , at.created_at AS "oauth2_access_token_created_at?" + , at.expires_at AS "oauth2_access_token_expires_at?" + , os.oauth2_session_id AS "oauth2_session_id!" + , os.oauth2_client_id AS "oauth2_client_id!" + , os.scope AS "oauth2_session_scope!" + , os.user_session_id AS "user_session_id!" FROM oauth2_refresh_tokens rt INNER JOIN oauth2_sessions os USING (oauth2_session_id) LEFT JOIN oauth2_access_tokens at USING (oauth2_access_token_id) - INNER JOIN user_sessions us - USING (user_session_id) - INNER JOIN users u - USING (user_id) - LEFT JOIN user_session_authentications usa - USING (user_session_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id WHERE rt.refresh_token = $1 AND rt.consumed_at IS NULL AND rt.revoked_at IS NULL - AND us.finished_at IS NULL AND os.finished_at IS NULL - - ORDER BY usa.created_at DESC - LIMIT 1 "#, token, ) @@ -173,44 +146,6 @@ pub async fn lookup_active_refresh_token( }; let session_id = res.oauth2_session_id.into(); - let client = conn - .oauth2_client() - .lookup(res.oauth2_client_id.into()) - .await? - .ok_or_else(|| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("client_id") - .row(session_id) - })?; - - let user_id = Ulid::from(res.user_id); - let user = User { - id: user_id, - username: res.user_username, - sub: user_id.to_string(), - primary_user_email_id: res.user_primary_user_email_id.map(Into::into), - }; - - let last_authentication = match ( - res.user_session_last_authentication_id, - res.user_session_last_authentication_created_at, - ) { - (None, None) => None, - (Some(id), Some(created_at)) => Some(Authentication { - id: id.into(), - created_at, - }), - _ => return Err(DatabaseInconsistencyError::on("user_session_authentications").into()), - }; - - let browser_session = BrowserSession { - id: res.user_session_id.into(), - created_at: res.user_session_created_at, - finished_at: None, - user, - last_authentication, - }; - let scope = res.oauth2_session_scope.parse().map_err(|e| { DatabaseInconsistencyError::on("oauth2_sessions") .column("scope") @@ -220,8 +155,8 @@ pub async fn lookup_active_refresh_token( let session = Session { id: session_id, - client, - browser_session, + client_id: res.oauth2_client_id.into(), + user_session_id: res.user_session_id.into(), scope, }; diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs new file mode 100644 index 000000000..71efb5bd7 --- /dev/null +++ b/crates/storage/src/oauth2/session.rs @@ -0,0 +1,20 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; + +#[async_trait] +pub trait OAuth2SessionRepository { + type Error; +} From 644eb61dd435b9c24d7b91a04302bf47f1387fe8 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 5 Jan 2023 15:16:03 +0100 Subject: [PATCH 12/45] storage: oauth2 session repository --- crates/axum-utils/src/client_authorization.rs | 2 +- crates/cli/src/commands/manage.rs | 2 +- crates/data-model/src/oauth2/session.rs | 2 + crates/graphql/src/lib.rs | 2 +- crates/graphql/src/model/oauth.rs | 4 +- crates/graphql/src/model/users.rs | 12 +- .../src/oauth2/authorization/complete.rs | 15 +- .../handlers/src/oauth2/authorization/mod.rs | 2 +- crates/handlers/src/oauth2/registration.rs | 2 +- crates/handlers/src/oauth2/token.rs | 4 +- crates/handlers/src/oauth2/userinfo.rs | 2 +- crates/storage/sqlx-data.json | 57 ++--- crates/storage/src/oauth2/access_token.rs | 1 + .../storage/src/oauth2/authorization_grant.rs | 59 +---- crates/storage/src/oauth2/mod.rs | 127 +--------- crates/storage/src/oauth2/refresh_token.rs | 1 + crates/storage/src/oauth2/session.rs | 223 ++++++++++++++++++ crates/storage/src/repository.rs | 17 +- 18 files changed, 307 insertions(+), 227 deletions(-) diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index 0e6771b4e..6f212369b 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -31,7 +31,7 @@ use mas_http::HttpServiceExt; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt}; use mas_keystore::Encrypter; -use mas_storage::{oauth2::client::OAuth2ClientRepository, DatabaseError, Repository}; +use mas_storage::{oauth2::OAuth2ClientRepository, DatabaseError, Repository}; use serde::{de::DeserializeOwned, Deserialize}; use serde_json::Value; use sqlx::PgConnection; diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 62d62db87..d159f3e30 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -18,7 +18,7 @@ use mas_config::{DatabaseConfig, PasswordsConfig, RootConfig}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_router::UrlBuilder; use mas_storage::{ - oauth2::client::OAuth2ClientRepository, + oauth2::OAuth2ClientRepository, upstream_oauth2::UpstreamOAuthProviderRepository, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, Clock, Repository, diff --git a/crates/data-model/src/oauth2/session.rs b/crates/data-model/src/oauth2/session.rs index 29454feb2..aec48ac13 100644 --- a/crates/data-model/src/oauth2/session.rs +++ b/crates/data-model/src/oauth2/session.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use chrono::{DateTime, Utc}; use oauth2_types::scope::Scope; use serde::Serialize; use ulid::Ulid; @@ -22,4 +23,5 @@ pub struct Session { pub user_session_id: Ulid, pub client_id: Ulid, pub scope: Scope, + pub finished_at: Option>, } diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index ffa63396b..d01be16ce 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -31,7 +31,7 @@ use async_graphql::{ Context, Description, EmptyMutation, EmptySubscription, ID, }; use mas_storage::{ - oauth2::client::OAuth2ClientRepository, + oauth2::OAuth2ClientRepository, upstream_oauth2::UpstreamOAuthProviderRepository, user::{BrowserSessionRepository, UserEmailRepository}, Repository, UpstreamOAuthLinkRepository, diff --git a/crates/graphql/src/model/oauth.rs b/crates/graphql/src/model/oauth.rs index 8e418e6c8..0ab2bc684 100644 --- a/crates/graphql/src/model/oauth.rs +++ b/crates/graphql/src/model/oauth.rs @@ -14,9 +14,7 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; -use mas_storage::{ - oauth2::client::OAuth2ClientRepository, user::BrowserSessionRepository, Repository, -}; +use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, Repository}; use oauth2_types::scope::Scope; use sqlx::PgPool; use ulid::Ulid; diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index dc40d6cd3..2f241ced6 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -18,6 +18,7 @@ use async_graphql::{ }; use chrono::{DateTime, Utc}; use mas_storage::{ + oauth2::OAuth2SessionRepository, user::{BrowserSessionRepository, UserEmailRepository}, Repository, UpstreamOAuthLinkRepository, }; @@ -241,14 +242,13 @@ impl User { .map(|x: OpaqueCursor| x.extract_for_type(NodeType::OAuth2Session)) .transpose()?; - let (has_previous_page, has_next_page, edges) = - mas_storage::oauth2::get_paginated_user_oauth_sessions( - &mut conn, &self.0, before_id, after_id, first, last, - ) + let page = conn + .oauth2_session() + .list_paginated(&self.0, before_id, after_id, first, last) .await?; - let mut connection = Connection::new(has_previous_page, has_next_page); - connection.edges.extend(edges.into_iter().map(|s| { + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + connection.edges.extend(page.edges.into_iter().map(|s| { Edge::new( OpaqueCursor(NodeCursor(NodeType::OAuth2Session, s.id)), OAuth2Session(s), diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index eb4b8889d..01c89ff3a 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -25,9 +25,13 @@ use mas_data_model::{AuthorizationGrant, BrowserSession}; use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; -use mas_storage::oauth2::{ - authorization_grant::{derive_session, fulfill_grant, get_grant_by_id}, - consent::fetch_client_consent, +use mas_storage::{ + oauth2::{ + authorization_grant::{fulfill_grant, get_grant_by_id}, + consent::fetch_client_consent, + OAuth2SessionRepository, + }, + Repository, }; use mas_templates::Templates; use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse}; @@ -193,7 +197,10 @@ pub(crate) async fn complete( } // All good, let's start the session - let session = derive_session(&mut txn, &mut rng, &clock, &grant, browser_session).await?; + let session = txn + .oauth2_session() + .create_from_grant(&mut rng, &clock, &grant, &browser_session) + .await?; let grant = fulfill_grant(&mut txn, grant, session.clone()).await?; diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 36d15d2bc..cfcd936eb 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -26,7 +26,7 @@ use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ - oauth2::{authorization_grant::new_authorization_grant, client::OAuth2ClientRepository}, + oauth2::{authorization_grant::new_authorization_grant, OAuth2ClientRepository}, Repository, }; use mas_templates::Templates; diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index b12194eb5..a6ff61587 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -19,7 +19,7 @@ use hyper::StatusCode; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_keystore::Encrypter; use mas_policy::{PolicyFactory, Violation}; -use mas_storage::{oauth2::client::OAuth2ClientRepository, Repository}; +use mas_storage::{oauth2::OAuth2ClientRepository, Repository}; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, registration::{ diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 473dcab82..391bdde2b 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -35,8 +35,8 @@ use mas_storage::{ oauth2::{ access_token::{add_access_token, revoke_access_token}, authorization_grant::{exchange_grant, lookup_grant_by_code}, - end_oauth_session, refresh_token::{add_refresh_token, consume_refresh_token, lookup_active_refresh_token}, + OAuth2SessionRepository, }, user::BrowserSessionRepository, Repository, @@ -234,7 +234,7 @@ async fn authorization_code_grant( // Ending the session if the token was already exchanged more than 20s ago if now - exchanged_at > Duration::seconds(20) { debug!("Ending potentially compromised session"); - end_oauth_session(&mut txn, &clock, session).await?; + txn.oauth2_session().finish(&clock, session).await?; txn.commit().await?; } diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index 699b049a5..d2b2b6150 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -29,7 +29,7 @@ use mas_jose::{ use mas_keystore::Keystore; use mas_router::UrlBuilder; use mas_storage::{ - oauth2::client::OAuth2ClientRepository, + oauth2::OAuth2ClientRepository, user::{BrowserSessionRepository, UserEmailRepository}, Repository, }; diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 4784c030f..740e9a3a5 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -629,6 +629,22 @@ }, "query": "\n UPDATE compat_sessions cs\n SET finished_at = $2\n FROM compat_access_tokens ca\n WHERE ca.access_token = $1\n AND ca.compat_session_id = cs.compat_session_id\n AND cs.finished_at IS NULL\n RETURNING cs.compat_session_id\n " }, + "583ae9a0db9cd55fa57a179339550f3dab1bfc76f35ad488e1560ea37f7ed029": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO oauth2_sessions\n ( oauth2_session_id\n , user_session_id\n , oauth2_client_id\n , scope\n , created_at\n )\n VALUES ($1, $2, $3, $4, $5)\n " + }, "5b5d5c82da37c6f2d8affacfb02119965c04d1f2a9cc53dbf5bd4c12584969a0": { "describe": { "columns": [], @@ -1325,19 +1341,6 @@ }, "query": "\n UPDATE compat_refresh_tokens\n SET consumed_at = $2\n WHERE compat_refresh_token_id = $1\n " }, - "9c1ef3114bfe22884d893bb11dc6054421c28cce4bd828cfe6a4ad46c062481a": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE oauth2_sessions\n SET finished_at = $2\n WHERE oauth2_session_id = $1\n " - }, "a300fe99c95679c5664646a6a525c0491829e97db45f3234483872ed38436322": { "describe": { "columns": [ @@ -1469,6 +1472,19 @@ }, "query": "\n INSERT INTO user_email_confirmation_codes\n (user_email_confirmation_code_id, user_email_id, code, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n " }, + "b700dc3f7d0f86f4904725d8357e34b7e457f857ed37c467c314142877fd5367": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE oauth2_sessions\n SET finished_at = $2\n WHERE oauth2_session_id = $1\n " + }, "b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64": { "describe": { "columns": [], @@ -1484,21 +1500,6 @@ }, "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n " }, - "bc768c63a7737818967bc28560de714bbbd262bdf3ab73d297263bb73dcd9f5e": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Timestamptz", - "Uuid" - ] - } - }, - "query": "\n INSERT INTO oauth2_sessions\n (oauth2_session_id, user_session_id, oauth2_client_id, scope, created_at)\n SELECT\n $1,\n $2,\n og.oauth2_client_id,\n og.scope,\n $3\n FROM\n oauth2_authorization_grants og\n WHERE\n og.oauth2_authorization_grant_id = $4\n " - }, "bd1f6daa5fa1b10250c01f8b3fbe451646a9ceeefa6f72b9c4e29b6d05f17641": { "describe": { "columns": [], diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index c85d1b485..cadb93e01 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -134,6 +134,7 @@ pub async fn lookup_active_access_token( client_id: res.oauth2_client_id.into(), user_session_id: res.user_session_id.into(), scope, + finished_at: None, }; Ok(Some((access_token, session))) diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index 3a18ef415..29577d599 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -16,8 +16,7 @@ use std::num::NonZeroU32; use chrono::{DateTime, Utc}; use mas_data_model::{ - AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession, Client, Pkce, - Session, + AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session, }; use mas_iana::oauth::PkceCodeChallengeMethod; use oauth2_types::{requests::ResponseMode, scope::Scope}; @@ -27,7 +26,7 @@ use ulid::Ulid; use url::Url; use uuid::Uuid; -use super::client::OAuth2ClientRepository; +use super::OAuth2ClientRepository; use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Repository}; #[tracing::instrument( @@ -186,6 +185,7 @@ impl GrantLookup { client_id: client.id, user_session_id: user_session_id.into(), scope, + finished_at: None, }; Some(session) @@ -431,59 +431,6 @@ pub async fn lookup_grant_by_code( Ok(Some(grant)) } -#[tracing::instrument( - skip_all, - fields( - %grant.id, - client.id = %grant.client.id, - session.id, - user_session.id = %browser_session.id, - user.id = %browser_session.user.id, - ), - err, -)] -pub async fn derive_session( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - grant: &AuthorizationGrant, - browser_session: BrowserSession, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("session.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO oauth2_sessions - (oauth2_session_id, user_session_id, oauth2_client_id, scope, created_at) - SELECT - $1, - $2, - og.oauth2_client_id, - og.scope, - $3 - FROM - oauth2_authorization_grants og - WHERE - og.oauth2_authorization_grant_id = $4 - "#, - Uuid::from(id), - Uuid::from(browser_session.id), - created_at, - Uuid::from(grant.id), - ) - .execute(executor) - .await?; - - Ok(Session { - id, - user_session_id: browser_session.id, - client_id: grant.client.id, - scope: grant.scope.clone(), - }) -} - #[tracing::instrument( skip_all, fields( diff --git a/crates/storage/src/oauth2/mod.rs b/crates/storage/src/oauth2/mod.rs index 66313139a..b02216a65 100644 --- a/crates/storage/src/oauth2/mod.rs +++ b/crates/storage/src/oauth2/mod.rs @@ -12,129 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -use mas_data_model::{Session, User}; -use sqlx::{PgConnection, PgExecutor, QueryBuilder}; -use tracing::{info_span, Instrument}; -use ulid::Ulid; -use uuid::Uuid; - -use crate::{ - pagination::{process_page, QueryBuilderExt}, - Clock, DatabaseError, DatabaseInconsistencyError, -}; - pub mod access_token; pub mod authorization_grant; -pub mod client; +mod client; pub mod consent; pub mod refresh_token; -pub mod session; +mod session; -#[tracing::instrument( - skip_all, - fields( - %session.id, - user_session.id = %session.user_session_id, - client.id = %session.client_id, - ), - err, -)] -pub async fn end_oauth_session( - executor: impl PgExecutor<'_>, - clock: &Clock, - session: Session, -) -> Result<(), DatabaseError> { - let finished_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE oauth2_sessions - SET finished_at = $2 - WHERE oauth2_session_id = $1 - "#, - Uuid::from(session.id), - finished_at, - ) - .execute(executor) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1) -} - -#[derive(sqlx::FromRow)] -struct OAuthSessionLookup { - oauth2_session_id: Uuid, - user_session_id: Uuid, - oauth2_client_id: Uuid, - scope: String, -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %user.username, - ), - err, -)] -pub async fn get_paginated_user_oauth_sessions( - conn: &mut PgConnection, - user: &User, - before: Option, - after: Option, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), DatabaseError> { - let mut query = QueryBuilder::new( - r#" - SELECT - os.oauth2_session_id, - os.user_session_id, - os.oauth2_client_id, - os.scope, - os.created_at, - os.finished_at - FROM oauth2_sessions os - LEFT JOIN user_sessions us - USING (user_session_id) - "#, - ); - - query - .push(" WHERE us.user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("oauth2_session_id", before, after, first, last)?; - - let span = info_span!( - "Fetch paginated user oauth sessions", - db.statement = query.sql() - ); - let page: Vec = query - .build_query_as() - .fetch_all(&mut *conn) - .instrument(span) - .await?; - - let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; - - let page: Result, DatabaseInconsistencyError> = page - .into_iter() - .map(|item| { - let id = Ulid::from(item.oauth2_session_id); - let scope = item.scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("scope") - .row(id) - .source(e) - })?; - - Ok(Session { - id: Ulid::from(item.oauth2_session_id), - client_id: item.oauth2_client_id.into(), - user_session_id: item.user_session_id.into(), - scope, - }) - }) - .collect(); - - Ok((has_previous_page, has_next_page, page?)) -} +pub use self::{ + client::{OAuth2ClientRepository, PgOAuth2ClientRepository}, + session::{OAuth2SessionRepository, PgOAuth2SessionRepository}, +}; diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index bf223a794..61ace6fa8 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -158,6 +158,7 @@ pub async fn lookup_active_refresh_token( client_id: res.oauth2_client_id.into(), user_session_id: res.user_session_id.into(), scope, + finished_at: None, }; Ok(Some((refresh_token, session))) diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 71efb5bd7..5841a1d91 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -13,8 +13,231 @@ // limitations under the License. use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{AuthorizationGrant, BrowserSession, Session, User}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + pagination::{process_page, Page, QueryBuilderExt}, + tracing::ExecuteExt, + Clock, DatabaseError, DatabaseInconsistencyError, +}; #[async_trait] pub trait OAuth2SessionRepository { type Error; + + async fn create_from_grant( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + grant: &AuthorizationGrant, + user_session: &BrowserSession, + ) -> Result; + + async fn finish(&mut self, clock: &Clock, session: Session) -> Result; + + async fn list_paginated( + &mut self, + user: &User, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error>; +} + +pub struct PgOAuth2SessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2SessionRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct OAuthSessionLookup { + oauth2_session_id: Uuid, + user_session_id: Uuid, + oauth2_client_id: Uuid, + scope: String, + finished_at: Option>, +} + +impl TryFrom for Session { + type Error = DatabaseInconsistencyError; + + fn try_from(value: OAuthSessionLookup) -> Result { + let id = Ulid::from(value.oauth2_session_id); + let scope = value.scope.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_sessions") + .column("scope") + .row(id) + .source(e) + })?; + + Ok(Session { + id, + client_id: value.oauth2_client_id.into(), + user_session_id: value.user_session_id.into(), + scope, + finished_at: value.finished_at, + }) + } +} + +#[async_trait] +impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.oauth2_session.create_from_grant", + skip_all, + fields( + db.statement, + %user_session.id, + user.id = %user_session.user.id, + %grant.id, + client.id = %grant.client.id, + session.id, + session.scope = %grant.scope, + ), + err, + )] + async fn create_from_grant( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + grant: &AuthorizationGrant, + user_session: &BrowserSession, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("session.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO oauth2_sessions + ( oauth2_session_id + , user_session_id + , oauth2_client_id + , scope + , created_at + ) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(user_session.id), + Uuid::from(grant.client.id), + grant.scope.to_string(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(Session { + id, + user_session_id: user_session.id, + client_id: grant.client.id, + scope: grant.scope.clone(), + finished_at: None, + }) + } + + #[tracing::instrument( + name = "db.oauth2_session.finish", + skip_all, + fields( + db.statement, + %session.id, + %session.scope, + user_session.id = %session.user_session_id, + client.id = %session.client_id, + ), + err, + )] + async fn finish( + &mut self, + clock: &Clock, + mut session: Session, + ) -> Result { + let finished_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_sessions + SET finished_at = $2 + WHERE oauth2_session_id = $1 + "#, + Uuid::from(session.id), + finished_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + session.finished_at = Some(finished_at); + + Ok(session) + } + + #[tracing::instrument( + name = "db.oauth2_session.list_paginated", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + ), + err, + )] + async fn list_paginated( + &mut self, + user: &User, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT oauth2_session_id + , user_session_id + , oauth2_client_id + , scope + , created_at + , finished_at + FROM oauth2_sessions os + "#, + ); + + query + .push(" WHERE us.user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("oauth2_session_id", before, after, first, last)?; + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let (has_previous_page, has_next_page, edges) = process_page(edges, first, last)?; + + let edges: Result, DatabaseInconsistencyError> = + edges.into_iter().map(Session::try_from).collect(); + + Ok(Page { + has_next_page, + has_previous_page, + edges: edges?, + }) + } } diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 4bca22530..8eda57016 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -15,7 +15,7 @@ use sqlx::{PgConnection, Postgres, Transaction}; use crate::{ - oauth2::client::PgOAuth2ClientRepository, + oauth2::{PgOAuth2ClientRepository, PgOAuth2SessionRepository}, upstream_oauth2::{ PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, PgUpstreamOAuthSessionRepository, @@ -59,6 +59,10 @@ pub trait Repository { where Self: 'c; + type OAuth2SessionRepository<'c> + where + Self: 'c; + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>; fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>; @@ -67,6 +71,7 @@ pub trait Repository { fn user_password(&mut self) -> Self::UserPasswordRepository<'_>; fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>; fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>; + fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_>; } impl Repository for PgConnection { @@ -78,6 +83,7 @@ impl Repository for PgConnection { type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; + type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -110,6 +116,10 @@ impl Repository for PgConnection { fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { PgOAuth2ClientRepository::new(self) } + + fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { + PgOAuth2SessionRepository::new(self) + } } impl<'t> Repository for Transaction<'t, Postgres> { @@ -121,6 +131,7 @@ impl<'t> Repository for Transaction<'t, Postgres> { type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; + type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -153,4 +164,8 @@ impl<'t> Repository for Transaction<'t, Postgres> { fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { PgOAuth2ClientRepository::new(self) } + + fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { + PgOAuth2SessionRepository::new(self) + } } From 8b8b21329e0a03a56d6781c8b7d05cccb5791426 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 5 Jan 2023 16:49:19 +0100 Subject: [PATCH 13/45] storage: do less joins on authorization grants and refresh tokens --- .../src/oauth2/authorization_grant.rs | 12 +- crates/data-model/src/tokens.rs | 2 +- crates/handlers/src/oauth2/token.rs | 58 ++++---- crates/storage/src/oauth2/access_token.rs | 6 +- .../storage/src/oauth2/authorization_grant.rs | 134 +++++++----------- crates/storage/src/oauth2/refresh_token.rs | 36 +---- crates/storage/src/oauth2/session.rs | 41 +++++- 7 files changed, 140 insertions(+), 149 deletions(-) diff --git a/crates/data-model/src/oauth2/authorization_grant.rs b/crates/data-model/src/oauth2/authorization_grant.rs index cb85a2654..a7222cda6 100644 --- a/crates/data-model/src/oauth2/authorization_grant.rs +++ b/crates/data-model/src/oauth2/authorization_grant.rs @@ -63,11 +63,11 @@ pub enum AuthorizationGrantStage { #[default] Pending, Fulfilled { - session: Session, + session_id: Ulid, fulfilled_at: DateTime, }, Exchanged { - session: Session, + session_id: Ulid, fulfilled_at: DateTime, exchanged_at: DateTime, }, @@ -85,12 +85,12 @@ impl AuthorizationGrantStage { pub fn fulfill( self, fulfilled_at: DateTime, - session: Session, + session: &Session, ) -> Result { match self { Self::Pending => Ok(Self::Fulfilled { fulfilled_at, - session, + session_id: session.id, }), _ => Err(InvalidTransitionError), } @@ -100,11 +100,11 @@ impl AuthorizationGrantStage { match self { Self::Fulfilled { fulfilled_at, - session, + session_id, } => Ok(Self::Exchanged { fulfilled_at, exchanged_at, - session, + session_id, }), _ => Err(InvalidTransitionError), } diff --git a/crates/data-model/src/tokens.rs b/crates/data-model/src/tokens.rs index 93b29f6dd..7b058820d 100644 --- a/crates/data-model/src/tokens.rs +++ b/crates/data-model/src/tokens.rs @@ -33,7 +33,7 @@ pub struct RefreshToken { pub id: Ulid, pub refresh_token: String, pub created_at: DateTime, - pub access_token: Option, + pub access_token_id: Option, } /// Type of token to generate or validate diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 391bdde2b..eb0e20dde 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -109,12 +109,18 @@ pub(crate) enum RouteError { #[error("failed to load browser session")] NoSuchBrowserSession, + + #[error("failed to load oauth session")] + NoSuchOAuthSession, } impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { - Self::Internal(_) | Self::InvalidSigningKey | Self::NoSuchBrowserSession => ( + Self::Internal(_) + | Self::InvalidSigningKey + | Self::NoSuchBrowserSession + | Self::NoSuchOAuthSession => ( StatusCode::INTERNAL_SERVER_ERROR, Json(ClientError::from(ClientErrorCode::ServerError)), ), @@ -219,7 +225,7 @@ async fn authorization_code_grant( let now = clock.now(); - let session = match authz_grant.stage { + let session_id = match authz_grant.stage { AuthorizationGrantStage::Cancelled { cancelled_at } => { debug!(%cancelled_at, "Authorization grant was cancelled"); return Err(RouteError::InvalidGrant); @@ -227,13 +233,18 @@ async fn authorization_code_grant( AuthorizationGrantStage::Exchanged { exchanged_at, fulfilled_at, - session, + session_id, } => { debug!(%exchanged_at, %fulfilled_at, "Authorization code was already exchanged"); // Ending the session if the token was already exchanged more than 20s ago if now - exchanged_at > Duration::seconds(20) { debug!("Ending potentially compromised session"); + let session = txn + .oauth2_session() + .lookup(session_id) + .await? + .ok_or(RouteError::NoSuchOAuthSession)?; txn.oauth2_session().finish(&clock, session).await?; txn.commit().await?; } @@ -245,7 +256,7 @@ async fn authorization_code_grant( return Err(RouteError::InvalidGrant); } AuthorizationGrantStage::Fulfilled { - ref session, + session_id, fulfilled_at, } => { if now - fulfilled_at > Duration::minutes(10) { @@ -253,10 +264,16 @@ async fn authorization_code_grant( return Err(RouteError::InvalidGrant); } - session + session_id } }; + let session = txn + .oauth2_session() + .lookup(session_id) + .await? + .ok_or(RouteError::NoSuchOAuthSession)?; + // This should never happen, since we looked up in the database using the code let code = authz_grant.code.as_ref().ok_or(RouteError::InvalidGrant)?; @@ -284,23 +301,16 @@ async fn authorization_code_grant( let access_token_str = TokenType::AccessToken.generate(&mut rng); let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); - let access_token = add_access_token( - &mut txn, - &mut rng, - &clock, - session, - access_token_str.clone(), - ttl, - ) - .await?; + let access_token = + add_access_token(&mut txn, &mut rng, &clock, &session, access_token_str, ttl).await?; - let _refresh_token = add_refresh_token( + let refresh_token = add_refresh_token( &mut txn, &mut rng, &clock, - session, - access_token, - refresh_token_str.clone(), + &session, + &access_token, + refresh_token_str, ) .await?; @@ -328,7 +338,7 @@ async fn authorization_code_grant( .signing_key_for_algorithm(&alg) .ok_or(RouteError::InvalidSigningKey)?; - claims::AT_HASH.insert(&mut claims, hash_token(&alg, &access_token_str)?)?; + claims::AT_HASH.insert(&mut claims, hash_token(&alg, &access_token.access_token)?)?; claims::C_HASH.insert(&mut claims, hash_token(&alg, &grant.code)?)?; let signer = key.params().signing_key_for_alg(&alg)?; @@ -341,9 +351,9 @@ async fn authorization_code_grant( None }; - let mut params = AccessTokenResponse::new(access_token_str) + let mut params = AccessTokenResponse::new(access_token.access_token) .with_expires_in(ttl) - .with_refresh_token(refresh_token_str) + .with_refresh_token(refresh_token.refresh_token) .with_scope(session.scope.clone()); if let Some(id_token) = id_token { @@ -392,15 +402,15 @@ async fn refresh_token_grant( &mut rng, &clock, &session, - new_access_token, + &new_access_token, refresh_token_str, ) .await?; consume_refresh_token(&mut txn, &clock, &refresh_token).await?; - if let Some(access_token) = refresh_token.access_token { - revoke_access_token(&mut txn, &clock, access_token).await?; + if let Some(access_token_id) = refresh_token.access_token_id { + revoke_access_token(&mut txn, &clock, access_token_id).await?; } let params = AccessTokenResponse::new(access_token_str) diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index cadb93e01..cd4cafbfc 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -142,13 +142,13 @@ pub async fn lookup_active_access_token( #[tracing::instrument( skip_all, - fields(%access_token.id), + fields(access_token.id = %access_token_id), err, )] pub async fn revoke_access_token( executor: impl PgExecutor<'_>, clock: &Clock, - access_token: AccessToken, + access_token_id: Ulid, ) -> Result<(), DatabaseError> { let revoked_at = clock.now(); let res = sqlx::query!( @@ -157,7 +157,7 @@ pub async fn revoke_access_token( SET revoked_at = $2 WHERE oauth2_access_token_id = $1 "#, - Uuid::from(access_token.id), + Uuid::from(access_token_id), revoked_at, ) .execute(executor) diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index 29577d599..33bd8b5d2 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -149,7 +149,6 @@ struct GrantLookup { oauth2_authorization_grant_requires_consent: bool, oauth2_client_id: Uuid, oauth2_session_id: Option, - user_session_id: Option, } impl GrantLookup { @@ -176,45 +175,22 @@ impl GrantLookup { .row(id) })?; - let session = match (self.oauth2_session_id, self.user_session_id) { - (Some(session_id), Some(user_session_id)) => { - let scope = scope.clone(); - - let session = Session { - id: session_id.into(), - client_id: client.id, - user_session_id: user_session_id.into(), - scope, - finished_at: None, - }; - - Some(session) - } - (None, None) => None, - _ => { - return Err( - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("oauth2_session_id") - .row(id) - .into(), - ) - } - }; - let stage = match ( self.oauth2_authorization_grant_fulfilled_at, self.oauth2_authorization_grant_exchanged_at, self.oauth2_authorization_grant_cancelled_at, - session, + self.oauth2_session_id, ) { (None, None, None, None) => AuthorizationGrantStage::Pending, - (Some(fulfilled_at), None, None, Some(session)) => AuthorizationGrantStage::Fulfilled { - session, - fulfilled_at, - }, - (Some(fulfilled_at), Some(exchanged_at), None, Some(session)) => { + (Some(fulfilled_at), None, None, Some(session_id)) => { + AuthorizationGrantStage::Fulfilled { + session_id: session_id.into(), + fulfilled_at, + } + } + (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => { AuthorizationGrantStage::Exchanged { - session, + session_id: session_id.into(), fulfilled_at, exchanged_at, } @@ -343,32 +319,29 @@ pub async fn get_grant_by_id( let res = sqlx::query_as!( GrantLookup, r#" - SELECT og.oauth2_authorization_grant_id - , og.created_at AS oauth2_authorization_grant_created_at - , og.cancelled_at AS oauth2_authorization_grant_cancelled_at - , og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at - , og.exchanged_at AS oauth2_authorization_grant_exchanged_at - , og.scope AS oauth2_authorization_grant_scope - , og.state AS oauth2_authorization_grant_state - , og.redirect_uri AS oauth2_authorization_grant_redirect_uri - , og.response_mode AS oauth2_authorization_grant_response_mode - , og.nonce AS oauth2_authorization_grant_nonce - , og.max_age AS oauth2_authorization_grant_max_age - , og.oauth2_client_id AS oauth2_client_id - , og.authorization_code AS oauth2_authorization_grant_code - , og.response_type_code AS oauth2_authorization_grant_response_type_code - , og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token - , og.code_challenge AS oauth2_authorization_grant_code_challenge - , og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method - , og.requires_consent AS oauth2_authorization_grant_requires_consent - , os.oauth2_session_id AS "oauth2_session_id?" - , os.user_session_id AS "user_session_id?" + SELECT oauth2_authorization_grant_id + , created_at AS oauth2_authorization_grant_created_at + , cancelled_at AS oauth2_authorization_grant_cancelled_at + , fulfilled_at AS oauth2_authorization_grant_fulfilled_at + , exchanged_at AS oauth2_authorization_grant_exchanged_at + , scope AS oauth2_authorization_grant_scope + , state AS oauth2_authorization_grant_state + , redirect_uri AS oauth2_authorization_grant_redirect_uri + , response_mode AS oauth2_authorization_grant_response_mode + , nonce AS oauth2_authorization_grant_nonce + , max_age AS oauth2_authorization_grant_max_age + , oauth2_client_id AS oauth2_client_id + , authorization_code AS oauth2_authorization_grant_code + , response_type_code AS oauth2_authorization_grant_response_type_code + , response_type_id_token AS oauth2_authorization_grant_response_type_id_token + , code_challenge AS oauth2_authorization_grant_code_challenge + , code_challenge_method AS oauth2_authorization_grant_code_challenge_method + , requires_consent AS oauth2_authorization_grant_requires_consent + , oauth2_session_id AS "oauth2_session_id?" FROM - oauth2_authorization_grants og - LEFT JOIN oauth2_sessions os - USING (oauth2_session_id) + oauth2_authorization_grants - WHERE og.oauth2_authorization_grant_id = $1 + WHERE oauth2_authorization_grant_id = $1 "#, Uuid::from(id), ) @@ -391,32 +364,29 @@ pub async fn lookup_grant_by_code( let res = sqlx::query_as!( GrantLookup, r#" - SELECT og.oauth2_authorization_grant_id - , og.created_at AS oauth2_authorization_grant_created_at - , og.cancelled_at AS oauth2_authorization_grant_cancelled_at - , og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at - , og.exchanged_at AS oauth2_authorization_grant_exchanged_at - , og.scope AS oauth2_authorization_grant_scope - , og.state AS oauth2_authorization_grant_state - , og.redirect_uri AS oauth2_authorization_grant_redirect_uri - , og.response_mode AS oauth2_authorization_grant_response_mode - , og.nonce AS oauth2_authorization_grant_nonce - , og.max_age AS oauth2_authorization_grant_max_age - , og.oauth2_client_id AS oauth2_client_id - , og.authorization_code AS oauth2_authorization_grant_code - , og.response_type_code AS oauth2_authorization_grant_response_type_code - , og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token - , og.code_challenge AS oauth2_authorization_grant_code_challenge - , og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method - , og.requires_consent AS oauth2_authorization_grant_requires_consent - , os.oauth2_session_id AS "oauth2_session_id?" - , os.user_session_id AS "user_session_id?" + SELECT oauth2_authorization_grant_id + , created_at AS oauth2_authorization_grant_created_at + , cancelled_at AS oauth2_authorization_grant_cancelled_at + , fulfilled_at AS oauth2_authorization_grant_fulfilled_at + , exchanged_at AS oauth2_authorization_grant_exchanged_at + , scope AS oauth2_authorization_grant_scope + , state AS oauth2_authorization_grant_state + , redirect_uri AS oauth2_authorization_grant_redirect_uri + , response_mode AS oauth2_authorization_grant_response_mode + , nonce AS oauth2_authorization_grant_nonce + , max_age AS oauth2_authorization_grant_max_age + , oauth2_client_id AS oauth2_client_id + , authorization_code AS oauth2_authorization_grant_code + , response_type_code AS oauth2_authorization_grant_response_type_code + , response_type_id_token AS oauth2_authorization_grant_response_type_id_token + , code_challenge AS oauth2_authorization_grant_code_challenge + , code_challenge_method AS oauth2_authorization_grant_code_challenge_method + , requires_consent AS oauth2_authorization_grant_requires_consent + , oauth2_session_id AS "oauth2_session_id?" FROM - oauth2_authorization_grants og - LEFT JOIN oauth2_sessions os - USING (oauth2_session_id) + oauth2_authorization_grants - WHERE og.authorization_code = $1 + WHERE authorization_code = $1 "#, code, ) @@ -466,7 +436,7 @@ pub async fn fulfill_grant( grant.stage = grant .stage - .fulfill(fulfilled_at, session) + .fulfill(fulfilled_at, &session) .map_err(DatabaseError::to_invalid_operation)?; Ok(grant) diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 61ace6fa8..e4c35c710 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -36,7 +36,7 @@ pub async fn add_refresh_token( mut rng: impl Rng + Send, clock: &Clock, session: &Session, - access_token: AccessToken, + access_token: &AccessToken, refresh_token: String, ) -> Result { let created_at = clock.now(); @@ -63,7 +63,7 @@ pub async fn add_refresh_token( Ok(RefreshToken { id, refresh_token, - access_token: Some(access_token), + access_token_id: Some(access_token.id), created_at, }) } @@ -73,9 +73,6 @@ struct OAuth2RefreshTokenLookup { oauth2_refresh_token: String, oauth2_refresh_token_created_at: DateTime, oauth2_access_token_id: Option, - oauth2_access_token: Option, - oauth2_access_token_created_at: Option>, - oauth2_access_token_expires_at: Option>, oauth2_session_id: Uuid, oauth2_client_id: Uuid, oauth2_session_scope: String, @@ -94,10 +91,7 @@ pub async fn lookup_active_refresh_token( SELECT rt.oauth2_refresh_token_id , rt.refresh_token AS oauth2_refresh_token , rt.created_at AS oauth2_refresh_token_created_at - , at.oauth2_access_token_id AS "oauth2_access_token_id?" - , at.access_token AS "oauth2_access_token?" - , at.created_at AS "oauth2_access_token_created_at?" - , at.expires_at AS "oauth2_access_token_expires_at?" + , rt.oauth2_access_token_id AS "oauth2_access_token_id?" , os.oauth2_session_id AS "oauth2_session_id!" , os.oauth2_client_id AS "oauth2_client_id!" , os.scope AS "oauth2_session_scope!" @@ -105,8 +99,6 @@ pub async fn lookup_active_refresh_token( FROM oauth2_refresh_tokens rt INNER JOIN oauth2_sessions os USING (oauth2_session_id) - LEFT JOIN oauth2_access_tokens at - USING (oauth2_access_token_id) WHERE rt.refresh_token = $1 AND rt.consumed_at IS NULL @@ -118,31 +110,11 @@ pub async fn lookup_active_refresh_token( .fetch_one(&mut *conn) .await?; - let access_token = match ( - res.oauth2_access_token_id, - res.oauth2_access_token, - res.oauth2_access_token_created_at, - res.oauth2_access_token_expires_at, - ) { - (None, None, None, None) => None, - (Some(id), Some(access_token), Some(created_at), Some(expires_at)) => { - let id = Ulid::from(id); - Some(AccessToken { - id, - jti: id.to_string(), - access_token, - created_at, - expires_at, - }) - } - _ => return Err(DatabaseInconsistencyError::on("oauth2_access_tokens").into()), - }; - let refresh_token = RefreshToken { id: res.oauth2_refresh_token_id.into(), refresh_token: res.oauth2_refresh_token, created_at: res.oauth2_refresh_token_created_at, - access_token, + access_token_id: res.oauth2_access_token_id.map(Ulid::from), }; let session_id = res.oauth2_session_id.into(); diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 5841a1d91..7acaf8431 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -23,13 +23,15 @@ use uuid::Uuid; use crate::{ pagination::{process_page, Page, QueryBuilderExt}, tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, + Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, }; #[async_trait] pub trait OAuth2SessionRepository { type Error; + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + async fn create_from_grant( &mut self, rng: &mut (dyn RngCore + Send), @@ -66,6 +68,8 @@ struct OAuthSessionLookup { user_session_id: Uuid, oauth2_client_id: Uuid, scope: String, + #[allow(dead_code)] + created_at: DateTime, finished_at: Option>, } @@ -95,6 +99,41 @@ impl TryFrom for Session { impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { type Error = DatabaseError; + #[tracing::instrument( + name = "db.oauth2_session.lookup", + skip_all, + fields( + db.statement, + session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuthSessionLookup, + r#" + SELECT oauth2_session_id + , user_session_id + , oauth2_client_id + , scope + , created_at + , finished_at + FROM oauth2_sessions + + WHERE oauth2_session_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(session) = res else { return Ok(None) }; + + Ok(Some(session.try_into()?)) + } + #[tracing::instrument( name = "db.oauth2_session.create_from_grant", skip_all, From 31779f5222c53cbb1fbfd114d6dcd34a386b5bb6 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 9 Jan 2023 10:49:51 +0100 Subject: [PATCH 14/45] data-model: don't embed the client in the auth grant --- crates/data-model/src/lib.rs | 6 ++ .../src/oauth2/authorization_grant.rs | 10 +- crates/data-model/src/oauth2/session.rs | 13 +++ .../src/oauth2/authorization/complete.rs | 15 ++- .../handlers/src/oauth2/authorization/mod.rs | 10 +- crates/handlers/src/oauth2/consent.rs | 21 ++++- .../storage/src/oauth2/authorization_grant.rs | 92 ++++++++----------- crates/storage/src/oauth2/session.rs | 18 ++-- 8 files changed, 104 insertions(+), 81 deletions(-) diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index d104642ea..c5f4539cf 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -23,12 +23,18 @@ clippy::type_repetition_in_bounds )] +use thiserror::Error; + pub(crate) mod compat; pub(crate) mod oauth2; pub(crate) mod tokens; pub(crate) mod upstream_oauth2; pub(crate) mod users; +#[derive(Debug, Error)] +#[error("invalid state transition")] +pub struct InvalidTransitionError; + pub use self::{ compat::{ CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState, diff --git a/crates/data-model/src/oauth2/authorization_grant.rs b/crates/data-model/src/oauth2/authorization_grant.rs index a7222cda6..10f619c71 100644 --- a/crates/data-model/src/oauth2/authorization_grant.rs +++ b/crates/data-model/src/oauth2/authorization_grant.rs @@ -21,11 +21,11 @@ use oauth2_types::{ requests::ResponseMode, }; use serde::Serialize; -use thiserror::Error; use ulid::Ulid; use url::Url; -use super::{client::Client, session::Session}; +use super::session::Session; +use crate::InvalidTransitionError; #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct Pkce { @@ -53,10 +53,6 @@ pub struct AuthorizationCode { pub pkce: Option, } -#[derive(Debug, Error)] -#[error("invalid state transition")] -pub struct InvalidTransitionError; - #[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)] #[serde(tag = "stage", rename_all = "lowercase")] pub enum AuthorizationGrantStage { @@ -132,7 +128,7 @@ pub struct AuthorizationGrant { #[serde(flatten)] pub stage: AuthorizationGrantStage, pub code: Option, - pub client: Client, + pub client_id: Ulid, pub redirect_uri: Url, pub scope: oauth2_types::scope::Scope, pub state: Option, diff --git a/crates/data-model/src/oauth2/session.rs b/crates/data-model/src/oauth2/session.rs index aec48ac13..bbadd3a71 100644 --- a/crates/data-model/src/oauth2/session.rs +++ b/crates/data-model/src/oauth2/session.rs @@ -17,6 +17,8 @@ use oauth2_types::scope::Scope; use serde::Serialize; use ulid::Ulid; +use crate::InvalidTransitionError; + #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct Session { pub id: Ulid, @@ -25,3 +27,14 @@ pub struct Session { pub scope: Scope, pub finished_at: Option>, } + +impl Session { + pub fn finish(mut self, finished_at: DateTime) -> Result { + if self.finished_at.is_some() { + return Err(InvalidTransitionError); + } + + self.finished_at = Some(finished_at); + Ok(self) + } +} diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index 01c89ff3a..b5cfd6ffc 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -29,7 +29,7 @@ use mas_storage::{ oauth2::{ authorization_grant::{fulfill_grant, get_grant_by_id}, consent::fetch_client_consent, - OAuth2SessionRepository, + OAuth2ClientRepository, OAuth2SessionRepository, }, Repository, }; @@ -125,6 +125,7 @@ pub(crate) async fn get( } Err(GrantCompletionError::NotPending) => Err(RouteError::NotPending), Err(GrantCompletionError::Internal(e)) => Err(RouteError::Internal(e)), + Err(e) => Err(RouteError::Internal(e.into())), } } @@ -144,6 +145,9 @@ pub enum GrantCompletionError { #[error("denied by the policy")] PolicyViolation, + + #[error("failed to load client")] + NoSuchClient, } impl_from_error_for_route!(GrantCompletionError: sqlx::Error); @@ -182,8 +186,13 @@ pub(crate) async fn complete( return Err(GrantCompletionError::PolicyViolation); } - let current_consent = - fetch_client_consent(&mut txn, &browser_session.user, &grant.client).await?; + let client = txn + .oauth2_client() + .lookup(grant.client_id) + .await? + .ok_or(GrantCompletionError::NoSuchClient)?; + + let current_consent = fetch_client_consent(&mut txn, &browser_session.user, &client).await?; let lacks_consent = grant .scope diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index cfcd936eb..faf7015db 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -360,7 +360,10 @@ pub(crate) async fn get( Err(GrantCompletionError::Internal(e)) => { return Err(RouteError::Internal(e)) } - Err(e @ GrantCompletionError::NotPending) => { + Err( + e @ (GrantCompletionError::NotPending + | GrantCompletionError::NoSuchClient), + ) => { // This should never happen return Err(RouteError::Internal(Box::new(e))); } @@ -390,7 +393,10 @@ pub(crate) async fn get( Err(GrantCompletionError::Internal(e)) => { return Err(RouteError::Internal(e)) } - Err(e @ GrantCompletionError::NotPending) => { + Err( + e @ (GrantCompletionError::NotPending + | GrantCompletionError::NoSuchClient), + ) => { // This should never happen return Err(RouteError::Internal(Box::new(e))); } diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index 451077837..64b72d58e 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -28,9 +28,13 @@ use mas_data_model::AuthorizationGrantStage; use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; -use mas_storage::oauth2::{ - authorization_grant::{get_grant_by_id, give_consent_to_grant}, - consent::insert_client_consent, +use mas_storage::{ + oauth2::{ + authorization_grant::{get_grant_by_id, give_consent_to_grant}, + consent::insert_client_consent, + OAuth2ClientRepository, + }, + Repository, }; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; use sqlx::PgPool; @@ -55,6 +59,9 @@ pub enum RouteError { #[error("Policy violation")] PolicyViolation, + + #[error("Failed to load client")] + NoSuchClient, } impl_from_error_for_route!(sqlx::Error); @@ -160,6 +167,12 @@ pub(crate) async fn post( return Err(RouteError::PolicyViolation); } + let client = txn + .oauth2_client() + .lookup(grant.client_id) + .await? + .ok_or(RouteError::NoSuchClient)?; + // Do not consent for the "urn:matrix:org.matrix.msc2967.client:device:*" scope let scope_without_device = grant .scope @@ -172,7 +185,7 @@ pub(crate) async fn post( &mut rng, &clock, &session.user, - &grant.client, + &client, &scope_without_device, ) .await?; diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index 33bd8b5d2..c5d969765 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -26,8 +26,7 @@ use ulid::Ulid; use url::Url; use uuid::Uuid; -use super::OAuth2ClientRepository; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Repository}; +use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; #[tracing::instrument( skip_all, @@ -116,7 +115,7 @@ pub async fn new_authorization_grant( stage: AuthorizationGrantStage::Pending, code, redirect_uri, - client, + client_id: client.id, scope, state, nonce, @@ -151,35 +150,27 @@ struct GrantLookup { oauth2_session_id: Option, } -impl GrantLookup { - #[allow(clippy::too_many_lines)] - async fn into_authorization_grant( - self, - conn: &mut PgConnection, - ) -> Result { - let id = self.oauth2_authorization_grant_id.into(); - let scope: Scope = self.oauth2_authorization_grant_scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("scope") - .row(id) - .source(e) - })?; +impl TryFrom for AuthorizationGrant { + type Error = DatabaseInconsistencyError; - let client = conn - .oauth2_client() - .lookup(self.oauth2_client_id.into()) - .await? - .ok_or_else(|| { + #[allow(clippy::too_many_lines)] + fn try_from(value: GrantLookup) -> Result { + let id = value.oauth2_authorization_grant_id.into(); + let scope: Scope = value + .oauth2_authorization_grant_scope + .parse() + .map_err(|e| { DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("client_id") + .column("scope") .row(id) + .source(e) })?; let stage = match ( - self.oauth2_authorization_grant_fulfilled_at, - self.oauth2_authorization_grant_exchanged_at, - self.oauth2_authorization_grant_cancelled_at, - self.oauth2_session_id, + value.oauth2_authorization_grant_fulfilled_at, + value.oauth2_authorization_grant_exchanged_at, + value.oauth2_authorization_grant_cancelled_at, + value.oauth2_session_id, ) { (None, None, None, None) => AuthorizationGrantStage::Pending, (Some(fulfilled_at), None, None, Some(session_id)) => { @@ -202,15 +193,14 @@ impl GrantLookup { return Err( DatabaseInconsistencyError::on("oauth2_authorization_grants") .column("stage") - .row(id) - .into(), + .row(id), ); } }; let pkce = match ( - self.oauth2_authorization_grant_code_challenge, - self.oauth2_authorization_grant_code_challenge_method, + value.oauth2_authorization_grant_code_challenge, + value.oauth2_authorization_grant_code_challenge_method, ) { (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => { Some(Pkce { @@ -227,15 +217,14 @@ impl GrantLookup { return Err( DatabaseInconsistencyError::on("oauth2_authorization_grants") .column("code_challenge_method") - .row(id) - .into(), + .row(id), ); } }; let code: Option = match ( - self.oauth2_authorization_grant_response_type_code, - self.oauth2_authorization_grant_code, + value.oauth2_authorization_grant_response_type_code, + value.oauth2_authorization_grant_code, pkce, ) { (false, None, None) => None, @@ -244,13 +233,12 @@ impl GrantLookup { return Err( DatabaseInconsistencyError::on("oauth2_authorization_grants") .column("authorization_code") - .row(id) - .into(), + .row(id), ); } }; - let redirect_uri = self + let redirect_uri = value .oauth2_authorization_grant_redirect_uri .parse() .map_err(|e| { @@ -260,7 +248,7 @@ impl GrantLookup { .source(e) })?; - let response_mode = self + let response_mode = value .oauth2_authorization_grant_response_mode .parse() .map_err(|e| { @@ -270,7 +258,7 @@ impl GrantLookup { .source(e) })?; - let max_age = self + let max_age = value .oauth2_authorization_grant_max_age .map(u32::try_from) .transpose() @@ -292,17 +280,17 @@ impl GrantLookup { Ok(AuthorizationGrant { id, stage, - client, + client_id: value.oauth2_client_id.into(), code, scope, - state: self.oauth2_authorization_grant_state, - nonce: self.oauth2_authorization_grant_nonce, + state: value.oauth2_authorization_grant_state, + nonce: value.oauth2_authorization_grant_nonce, max_age, response_mode, redirect_uri, - created_at: self.oauth2_authorization_grant_created_at, - response_type_id_token: self.oauth2_authorization_grant_response_type_id_token, - requires_consent: self.oauth2_authorization_grant_requires_consent, + created_at: value.oauth2_authorization_grant_created_at, + response_type_id_token: value.oauth2_authorization_grant_response_type_id_token, + requires_consent: value.oauth2_authorization_grant_requires_consent, }) } } @@ -351,9 +339,7 @@ pub async fn get_grant_by_id( let Some(res) = res else { return Ok(None) }; - let grant = res.into_authorization_grant(&mut *conn).await?; - - Ok(Some(grant)) + Ok(Some(res.try_into()?)) } #[tracing::instrument(skip_all, err)] @@ -396,16 +382,14 @@ pub async fn lookup_grant_by_code( let Some(res) = res else { return Ok(None) }; - let grant = res.into_authorization_grant(&mut *conn).await?; - - Ok(Some(grant)) + Ok(Some(res.try_into()?)) } #[tracing::instrument( skip_all, fields( %grant.id, - client.id = %grant.client.id, + client.id = %grant.client_id, %session.id, user_session.id = %session.user_session_id, ), @@ -446,7 +430,7 @@ pub async fn fulfill_grant( skip_all, fields( %grant.id, - client.id = %grant.client.id, + client.id = %grant.client_id, ), err, )] @@ -476,7 +460,7 @@ pub async fn give_consent_to_grant( skip_all, fields( %grant.id, - client.id = %grant.client.id, + client.id = %grant.client_id, ), err, )] diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 7acaf8431..3a681a845 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -142,7 +142,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { %user_session.id, user.id = %user_session.user.id, %grant.id, - client.id = %grant.client.id, + client.id = %grant.client_id, session.id, session.scope = %grant.scope, ), @@ -172,7 +172,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { "#, Uuid::from(id), Uuid::from(user_session.id), - Uuid::from(grant.client.id), + Uuid::from(grant.client_id), grant.scope.to_string(), created_at, ) @@ -183,7 +183,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { Ok(Session { id, user_session_id: user_session.id, - client_id: grant.client.id, + client_id: grant.client_id, scope: grant.scope.clone(), finished_at: None, }) @@ -201,11 +201,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { ), err, )] - async fn finish( - &mut self, - clock: &Clock, - mut session: Session, - ) -> Result { + async fn finish(&mut self, clock: &Clock, session: Session) -> Result { let finished_at = clock.now(); let res = sqlx::query!( r#" @@ -222,9 +218,9 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { DatabaseError::ensure_affected_rows(&res, 1)?; - session.finished_at = Some(finished_at); - - Ok(session) + session + .finish(finished_at) + .map_err(DatabaseError::to_invalid_operation) } #[tracing::instrument( From 2b2f452d96dd720376d6677939aedfc42c340282 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 9 Jan 2023 18:02:32 +0100 Subject: [PATCH 15/45] data-model: have more structs use a state machine --- .../src/{compat.rs => compat/device.rs} | 52 +- crates/data-model/src/compat/mod.rs | 41 + crates/data-model/src/compat/session.rs | 79 ++ crates/data-model/src/compat/sso_login.rs | 148 +++ crates/data-model/src/lib.rs | 11 +- crates/data-model/src/oauth2/mod.rs | 4 +- crates/data-model/src/oauth2/session.rs | 61 +- crates/data-model/src/upstream_oauth2/link.rs | 26 + crates/data-model/src/upstream_oauth2/mod.rs | 61 +- .../src/upstream_oauth2/provider.rs | 31 + .../data-model/src/upstream_oauth2/session.rs | 170 ++++ crates/graphql/src/model/compat_sessions.rs | 22 +- .../handlers/src/upstream_oauth2/callback.rs | 6 +- crates/handlers/src/upstream_oauth2/link.rs | 8 +- crates/storage/sqlx-data.json | 852 +++++++++--------- crates/storage/src/compat.rs | 59 +- crates/storage/src/oauth2/access_token.rs | 7 +- crates/storage/src/oauth2/refresh_token.rs | 7 +- crates/storage/src/oauth2/session.rs | 13 +- crates/storage/src/upstream_oauth2/mod.rs | 15 +- crates/storage/src/upstream_oauth2/session.rs | 96 +- 21 files changed, 1148 insertions(+), 621 deletions(-) rename crates/data-model/src/{compat.rs => compat/device.rs} (64%) create mode 100644 crates/data-model/src/compat/mod.rs create mode 100644 crates/data-model/src/compat/session.rs create mode 100644 crates/data-model/src/compat/sso_login.rs create mode 100644 crates/data-model/src/upstream_oauth2/link.rs create mode 100644 crates/data-model/src/upstream_oauth2/provider.rs create mode 100644 crates/data-model/src/upstream_oauth2/session.rs diff --git a/crates/data-model/src/compat.rs b/crates/data-model/src/compat/device.rs similarity index 64% rename from crates/data-model/src/compat.rs rename to crates/data-model/src/compat/device.rs index 07ff9aaaa..84bdd067e 100644 --- a/crates/data-model/src/compat.rs +++ b/crates/data-model/src/compat/device.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use chrono::{DateTime, Utc}; use oauth2_types::scope::ScopeToken; use rand::{ distributions::{Alphanumeric, DistString}, @@ -20,8 +19,6 @@ use rand::{ }; use serde::Serialize; use thiserror::Error; -use ulid::Ulid; -use url::Url; static DEVICE_ID_LENGTH: usize = 10; @@ -79,50 +76,3 @@ impl TryFrom for Device { Ok(Self { id }) } } - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct CompatSession { - pub id: Ulid, - pub user_id: Ulid, - pub device: Device, - pub created_at: DateTime, - pub finished_at: Option>, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CompatAccessToken { - pub id: Ulid, - pub token: String, - pub created_at: DateTime, - pub expires_at: Option>, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CompatRefreshToken { - pub id: Ulid, - pub token: String, - pub created_at: DateTime, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub enum CompatSsoLoginState { - Pending, - Fulfilled { - fulfilled_at: DateTime, - session: CompatSession, - }, - Exchanged { - fulfilled_at: DateTime, - exchanged_at: DateTime, - session: CompatSession, - }, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct CompatSsoLogin { - pub id: Ulid, - pub redirect_uri: Url, - pub login_token: String, - pub created_at: DateTime, - pub state: CompatSsoLoginState, -} diff --git a/crates/data-model/src/compat/mod.rs b/crates/data-model/src/compat/mod.rs new file mode 100644 index 000000000..f6e19bd81 --- /dev/null +++ b/crates/data-model/src/compat/mod.rs @@ -0,0 +1,41 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use ulid::Ulid; + +mod device; +mod session; +mod sso_login; + +pub use self::{ + device::Device, + session::{CompatSession, CompatSessionState}, + sso_login::{CompatSsoLogin, CompatSsoLoginState}, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CompatAccessToken { + pub id: Ulid, + pub token: String, + pub created_at: DateTime, + pub expires_at: Option>, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CompatRefreshToken { + pub id: Ulid, + pub token: String, + pub created_at: DateTime, +} diff --git a/crates/data-model/src/compat/session.rs b/crates/data-model/src/compat/session.rs new file mode 100644 index 000000000..2c4cdf2d0 --- /dev/null +++ b/crates/data-model/src/compat/session.rs @@ -0,0 +1,79 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use serde::Serialize; +use ulid::Ulid; + +use super::Device; +use crate::InvalidTransitionError; + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)] +pub enum CompatSessionState { + #[default] + Valid, + Finished { + finished_at: DateTime, + }, +} + +impl CompatSessionState { + /// Returns `true` if the compta session state is [`Valid`]. + /// + /// [`Valid`]: ComptaSessionState::Valid + #[must_use] + pub fn is_valid(&self) -> bool { + matches!(self, Self::Valid) + } + + /// Returns `true` if the compta session state is [`Finished`]. + /// + /// [`Finished`]: ComptaSessionState::Finished + #[must_use] + pub fn is_finished(&self) -> bool { + matches!(self, Self::Finished { .. }) + } + + pub fn finish(self, finished_at: DateTime) -> Result { + match self { + Self::Valid => Ok(Self::Finished { finished_at }), + Self::Finished { .. } => Err(InvalidTransitionError), + } + } + + #[must_use] + pub fn finished_at(&self) -> Option> { + match self { + CompatSessionState::Valid => None, + CompatSessionState::Finished { finished_at } => Some(*finished_at), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct CompatSession { + pub id: Ulid, + pub state: CompatSessionState, + pub user_id: Ulid, + pub device: Device, + pub created_at: DateTime, +} + +impl std::ops::Deref for CompatSession { + type Target = CompatSessionState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} diff --git a/crates/data-model/src/compat/sso_login.rs b/crates/data-model/src/compat/sso_login.rs new file mode 100644 index 000000000..7e494f828 --- /dev/null +++ b/crates/data-model/src/compat/sso_login.rs @@ -0,0 +1,148 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use serde::Serialize; +use ulid::Ulid; +use url::Url; + +use super::CompatSession; +use crate::InvalidTransitionError; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub enum CompatSsoLoginState { + Pending, + Fulfilled { + fulfilled_at: DateTime, + session: CompatSession, + }, + Exchanged { + fulfilled_at: DateTime, + exchanged_at: DateTime, + session: CompatSession, + }, +} + +impl CompatSsoLoginState { + /// Returns `true` if the compat sso login state is [`Pending`]. + /// + /// [`Pending`]: CompatSsoLoginState::Pending + #[must_use] + pub fn is_pending(&self) -> bool { + matches!(self, Self::Pending) + } + + /// Returns `true` if the compat sso login state is [`Fulfilled`]. + /// + /// [`Fulfilled`]: CompatSsoLoginState::Fulfilled + #[must_use] + pub fn is_fulfilled(&self) -> bool { + matches!(self, Self::Fulfilled { .. }) + } + + /// Returns `true` if the compat sso login state is [`Exchanged`]. + /// + /// [`Exchanged`]: CompatSsoLoginState::Exchanged + #[must_use] + pub fn is_exchanged(&self) -> bool { + matches!(self, Self::Exchanged { .. }) + } + + #[must_use] + pub fn fulfilled_at(&self) -> Option> { + match self { + Self::Pending => None, + Self::Fulfilled { fulfilled_at, .. } | Self::Exchanged { fulfilled_at, .. } => { + Some(*fulfilled_at) + } + } + } + + #[must_use] + pub fn exchanged_at(&self) -> Option> { + match self { + Self::Pending | Self::Fulfilled { .. } => None, + Self::Exchanged { exchanged_at, .. } => Some(*exchanged_at), + } + } + + #[must_use] + pub fn session(&self) -> Option<&CompatSession> { + match self { + Self::Pending => None, + Self::Fulfilled { session, .. } | Self::Exchanged { session, .. } => Some(session), + } + } + + pub fn fulfill( + self, + fulfilled_at: DateTime, + session: CompatSession, + ) -> Result { + match self { + Self::Pending => Ok(Self::Fulfilled { + fulfilled_at, + session, + }), + Self::Fulfilled { .. } | Self::Exchanged { .. } => Err(InvalidTransitionError), + } + } + + pub fn exchange(self, exchanged_at: DateTime) -> Result { + match self { + Self::Fulfilled { + fulfilled_at, + session, + } => Ok(Self::Exchanged { + fulfilled_at, + exchanged_at, + session, + }), + Self::Pending { .. } | Self::Exchanged { .. } => Err(InvalidTransitionError), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct CompatSsoLogin { + pub id: Ulid, + pub redirect_uri: Url, + pub login_token: String, + pub created_at: DateTime, + pub state: CompatSsoLoginState, +} + +impl std::ops::Deref for CompatSsoLogin { + type Target = CompatSsoLoginState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + +impl CompatSsoLogin { + pub fn fulfill( + mut self, + fulfilled_at: DateTime, + session: CompatSession, + ) -> Result { + self.state = self.state.fulfill(fulfilled_at, session)?; + Ok(self) + } + + pub fn exchange(mut self, exchanged_at: DateTime) -> Result { + self.state = self.state.exchange(exchanged_at)?; + Ok(self) + } +} diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index c5f4539cf..879dc641f 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -37,16 +37,17 @@ pub struct InvalidTransitionError; pub use self::{ compat::{ - CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState, - Device, + CompatAccessToken, CompatRefreshToken, CompatSession, CompatSessionState, CompatSsoLogin, + CompatSsoLoginState, Device, }, oauth2::{ AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, - InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, + InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState, }, tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType}, upstream_oauth2::{ - UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider, + UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, + UpstreamOAuthLink, UpstreamOAuthProvider, }, users::{ Authentication, BrowserSession, Password, User, UserEmail, UserEmailVerification, diff --git a/crates/data-model/src/oauth2/mod.rs b/crates/data-model/src/oauth2/mod.rs index ef512260a..bc76b091d 100644 --- a/crates/data-model/src/oauth2/mod.rs +++ b/crates/data-model/src/oauth2/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,5 +19,5 @@ pub(self) mod session; pub use self::{ authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce}, client::{Client, InvalidRedirectUriError, JwksOrJwksUri}, - session::Session, + session::{Session, SessionState}, }; diff --git a/crates/data-model/src/oauth2/session.rs b/crates/data-model/src/oauth2/session.rs index bbadd3a71..68ac821cd 100644 --- a/crates/data-model/src/oauth2/session.rs +++ b/crates/data-model/src/oauth2/session.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,22 +19,69 @@ use ulid::Ulid; use crate::InvalidTransitionError; +trait T { + type State; +} + +impl T for Session { + type State = SessionState; +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)] +pub enum SessionState { + #[default] + Valid, + Finished { + finished_at: DateTime, + }, +} + +impl SessionState { + /// Returns `true` if the session state is [`Valid`]. + /// + /// [`Valid`]: SessionState::Valid + #[must_use] + pub fn is_valid(&self) -> bool { + matches!(self, Self::Valid) + } + + /// Returns `true` if the session state is [`Finished`]. + /// + /// [`Finished`]: SessionState::Finished + #[must_use] + pub fn is_finished(&self) -> bool { + matches!(self, Self::Finished { .. }) + } + + pub fn finish(self, finished_at: DateTime) -> Result { + match self { + Self::Valid => Ok(Self::Finished { finished_at }), + Self::Finished { .. } => Err(InvalidTransitionError), + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct Session { pub id: Ulid, + pub state: SessionState, + pub created_at: DateTime, pub user_session_id: Ulid, pub client_id: Ulid, pub scope: Scope, - pub finished_at: Option>, +} + +impl std::ops::Deref for Session { + type Target = SessionState; + + fn deref(&self) -> &Self::Target { + &self.state + } } impl Session { pub fn finish(mut self, finished_at: DateTime) -> Result { - if self.finished_at.is_some() { - return Err(InvalidTransitionError); - } - - self.finished_at = Some(finished_at); + self.state = self.state.finish(finished_at)?; Ok(self) } } diff --git a/crates/data-model/src/upstream_oauth2/link.rs b/crates/data-model/src/upstream_oauth2/link.rs new file mode 100644 index 000000000..c0699173c --- /dev/null +++ b/crates/data-model/src/upstream_oauth2/link.rs @@ -0,0 +1,26 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use serde::Serialize; +use ulid::Ulid; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct UpstreamOAuthLink { + pub id: Ulid, + pub provider_id: Ulid, + pub user_id: Option, + pub subject: String, + pub created_at: DateTime, +} diff --git a/crates/data-model/src/upstream_oauth2/mod.rs b/crates/data-model/src/upstream_oauth2/mod.rs index 08fbf6c0b..90780a8bf 100644 --- a/crates/data-model/src/upstream_oauth2/mod.rs +++ b/crates/data-model/src/upstream_oauth2/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,55 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use chrono::{DateTime, Utc}; -use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; -use oauth2_types::scope::Scope; -use serde::Serialize; -use ulid::Ulid; +mod link; +mod provider; +mod session; -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct UpstreamOAuthProvider { - pub id: Ulid, - pub issuer: String, - pub scope: Scope, - pub client_id: String, - pub encrypted_client_secret: Option, - pub token_endpoint_signing_alg: Option, - pub token_endpoint_auth_method: OAuthClientAuthenticationMethod, - pub created_at: DateTime, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct UpstreamOAuthLink { - pub id: Ulid, - pub provider_id: Ulid, - pub user_id: Option, - pub subject: String, - pub created_at: DateTime, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct UpstreamOAuthAuthorizationSession { - pub id: Ulid, - pub provider_id: Ulid, - pub link_id: Option, - pub state: String, - pub code_challenge_verifier: Option, - pub nonce: String, - pub created_at: DateTime, - pub completed_at: Option>, - pub consumed_at: Option>, - pub id_token: Option, -} - -impl UpstreamOAuthAuthorizationSession { - #[must_use] - pub const fn completed(&self) -> bool { - self.completed_at.is_some() - } - - #[must_use] - pub const fn consumed(&self) -> bool { - self.consumed_at.is_some() - } -} +pub use self::{ + link::UpstreamOAuthLink, + provider::UpstreamOAuthProvider, + session::{UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState}, +}; diff --git a/crates/data-model/src/upstream_oauth2/provider.rs b/crates/data-model/src/upstream_oauth2/provider.rs new file mode 100644 index 000000000..919b22217 --- /dev/null +++ b/crates/data-model/src/upstream_oauth2/provider.rs @@ -0,0 +1,31 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; +use oauth2_types::scope::Scope; +use serde::Serialize; +use ulid::Ulid; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct UpstreamOAuthProvider { + pub id: Ulid, + pub issuer: String, + pub scope: Scope, + pub client_id: String, + pub encrypted_client_secret: Option, + pub token_endpoint_signing_alg: Option, + pub token_endpoint_auth_method: OAuthClientAuthenticationMethod, + pub created_at: DateTime, +} diff --git a/crates/data-model/src/upstream_oauth2/session.rs b/crates/data-model/src/upstream_oauth2/session.rs new file mode 100644 index 000000000..9ce612668 --- /dev/null +++ b/crates/data-model/src/upstream_oauth2/session.rs @@ -0,0 +1,170 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use serde::Serialize; +use ulid::Ulid; + +use super::UpstreamOAuthLink; +use crate::InvalidTransitionError; + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)] +pub enum UpstreamOAuthAuthorizationSessionState { + #[default] + Pending, + Completed { + completed_at: DateTime, + link_id: Ulid, + id_token: Option, + }, + Consumed { + completed_at: DateTime, + consumed_at: DateTime, + link_id: Ulid, + id_token: Option, + }, +} + +impl UpstreamOAuthAuthorizationSessionState { + pub fn complete( + self, + completed_at: DateTime, + link: &UpstreamOAuthLink, + id_token: Option, + ) -> Result { + match self { + Self::Pending => Ok(Self::Completed { + completed_at, + link_id: link.id, + id_token, + }), + Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError), + } + } + + pub fn consume(self, consumed_at: DateTime) -> Result { + match self { + Self::Completed { + completed_at, + link_id, + id_token, + } => Ok(Self::Consumed { + completed_at, + link_id, + consumed_at, + id_token, + }), + Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError), + } + } + + #[must_use] + pub fn link_id(&self) -> Option { + match self { + Self::Pending => None, + Self::Completed { link_id, .. } | Self::Consumed { link_id, .. } => Some(*link_id), + } + } + + #[must_use] + pub fn completed_at(&self) -> Option> { + match self { + Self::Pending => None, + Self::Completed { completed_at, .. } | Self::Consumed { completed_at, .. } => { + Some(*completed_at) + } + } + } + + #[must_use] + pub fn id_token(&self) -> Option<&str> { + match self { + Self::Pending => None, + Self::Completed { id_token, .. } | Self::Consumed { id_token, .. } => { + id_token.as_deref() + } + } + } + + #[must_use] + pub fn consumed_at(&self) -> Option> { + match self { + Self::Pending | Self::Completed { .. } => None, + Self::Consumed { consumed_at, .. } => Some(*consumed_at), + } + } + + /// Returns `true` if the upstream oauth authorization session state is + /// [`Pending`]. + /// + /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending + #[must_use] + pub fn is_pending(&self) -> bool { + matches!(self, Self::Pending) + } + + /// Returns `true` if the upstream oauth authorization session state is + /// [`Completed`]. + /// + /// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed + #[must_use] + pub fn is_completed(&self) -> bool { + matches!(self, Self::Completed { .. }) + } + + /// Returns `true` if the upstream oauth authorization session state is + /// [`Consumed`]. + /// + /// [`Consumed`]: UpstreamOAuthAuthorizationSessionState::Consumed + #[must_use] + pub fn is_consumed(&self) -> bool { + matches!(self, Self::Consumed { .. }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct UpstreamOAuthAuthorizationSession { + pub id: Ulid, + pub state: UpstreamOAuthAuthorizationSessionState, + pub provider_id: Ulid, + pub state_str: String, + pub code_challenge_verifier: Option, + pub nonce: String, + pub created_at: DateTime, +} + +impl std::ops::Deref for UpstreamOAuthAuthorizationSession { + type Target = UpstreamOAuthAuthorizationSessionState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + +impl UpstreamOAuthAuthorizationSession { + pub fn complete( + mut self, + completed_at: DateTime, + link: &UpstreamOAuthLink, + id_token: Option, + ) -> Result { + self.state = self.state.complete(completed_at, link, id_token)?; + Ok(self) + } + + pub fn consume(mut self, consumed_at: DateTime) -> Result { + self.state = self.state.consume(consumed_at)?; + Ok(self) + } +} diff --git a/crates/graphql/src/model/compat_sessions.rs b/crates/graphql/src/model/compat_sessions.rs index f3610233d..38fe866ce 100644 --- a/crates/graphql/src/model/compat_sessions.rs +++ b/crates/graphql/src/model/compat_sessions.rs @@ -15,7 +15,6 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; use chrono::{DateTime, Utc}; -use mas_data_model::CompatSsoLoginState; use mas_storage::{user::UserRepository, Repository}; use sqlx::PgPool; use url::Url; @@ -57,7 +56,7 @@ impl CompatSession { /// When the session ended. pub async fn finished_at(&self) -> Option> { - self.0.finished_at + self.0.finished_at() } } @@ -86,29 +85,16 @@ impl CompatSsoLogin { /// When the login was fulfilled, and the user was redirected back to the /// client. async fn fulfilled_at(&self) -> Option> { - match &self.0.state { - CompatSsoLoginState::Pending => None, - CompatSsoLoginState::Fulfilled { fulfilled_at, .. } - | CompatSsoLoginState::Exchanged { fulfilled_at, .. } => Some(*fulfilled_at), - } + self.0.fulfilled_at() } /// When the client exchanged the login token sent during the redirection. async fn exchanged_at(&self) -> Option> { - match &self.0.state { - CompatSsoLoginState::Pending | CompatSsoLoginState::Fulfilled { .. } => None, - CompatSsoLoginState::Exchanged { exchanged_at, .. } => Some(*exchanged_at), - } + self.0.exchanged_at() } /// The compat session which was started by this login. async fn session(&self) -> Option { - match &self.0.state { - CompatSsoLoginState::Pending => None, - CompatSsoLoginState::Fulfilled { session, .. } - | CompatSsoLoginState::Exchanged { session, .. } => { - Some(CompatSession(session.clone())) - } - } + self.0.session().cloned().map(CompatSession) } } diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 295f7307b..8cb9a605d 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -153,12 +153,12 @@ pub(crate) async fn get( return Err(RouteError::ProviderMismatch); } - if params.state != session.state { + if params.state != session.state_str { // The state in the session cookie should match the one from the params return Err(RouteError::StateMismatch); } - if session.completed() { + if !session.is_pending() { // The session was already completed return Err(RouteError::AlreadyCompleted); } @@ -207,7 +207,7 @@ pub(crate) async fn get( // TODO: all that should be borrowed let validation_data = AuthorizationValidationData { - state: session.state.clone(), + state: session.state_str.clone(), nonce: session.nonce.clone(), code_challenge_verifier: session.code_challenge_verifier.clone(), redirect_uri, diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 80fa04f71..10e1f80e8 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -121,11 +121,11 @@ pub(crate) async fn get( // This checks that we're in a browser session which is allowed to consume this // link: the upstream auth session should have been started in this browser. - if upstream_session.link_id != Some(link.id) { + if upstream_session.link_id() != Some(link.id) { return Err(RouteError::SessionNotFound); } - if upstream_session.consumed() { + if upstream_session.is_consumed() { return Err(RouteError::SessionConsumed); } @@ -243,11 +243,11 @@ pub(crate) async fn post( // This checks that we're in a browser session which is allowed to consume this // link: the upstream auth session should have been started in this browser. - if upstream_session.link_id != Some(link.id) { + if upstream_session.link_id() != Some(link.id) { return Err(RouteError::SessionNotFound); } - if upstream_session.consumed() { + if upstream_session.is_consumed() { return Err(RouteError::SessionConsumed); } diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 740e9a3a5..e8a33bd24 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -657,6 +657,74 @@ }, "query": "\n DELETE FROM oauth2_access_tokens\n WHERE expires_at < $1\n " }, + "5f0e2aec0d7766d3674af3e68417921fec7068e83845e218a4a00d86487557f9": { + "describe": { + "columns": [ + { + "name": "oauth2_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "oauth2_access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "oauth2_access_token_created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_access_token_expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_session_created_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_session_id!", + "ordinal": 5, + "type_info": "Uuid" + }, + { + "name": "oauth2_client_id!", + "ordinal": 6, + "type_info": "Uuid" + }, + { + "name": "scope!", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "user_session_id!", + "ordinal": 8, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + false, + false, + false, + false, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT at.oauth2_access_token_id\n , at.access_token AS \"oauth2_access_token\"\n , at.created_at AS \"oauth2_access_token_created_at\"\n , at.expires_at AS \"oauth2_access_token_expires_at\"\n , os.created_at AS \"oauth2_session_created_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"scope!\"\n , os.user_session_id AS \"user_session_id!\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE at.access_token = $1\n AND at.revoked_at IS NULL\n AND os.finished_at IS NULL\n " + }, "5f6b7e38ef9bc3b39deabba277d0255fb8cfb2adaa65f47b78a8fac11d8c91c3": { "describe": { "columns": [], @@ -1397,6 +1465,134 @@ }, "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n requires_consent = 'f'\n WHERE\n og.oauth2_authorization_grant_id = $1\n " }, + "aa2fd69c595f94d8598715766a79671dba8f87b9d7af6ac30e3fa1fbc8cce28a": { + "describe": { + "columns": [ + { + "name": "oauth2_authorization_grant_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "oauth2_authorization_grant_created_at", + "ordinal": 1, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_cancelled_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_fulfilled_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_exchanged_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_scope", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_state", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_redirect_uri", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_response_mode", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_nonce", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_max_age", + "ordinal": 10, + "type_info": "Int4" + }, + { + "name": "oauth2_client_id", + "ordinal": 11, + "type_info": "Uuid" + }, + { + "name": "oauth2_authorization_grant_code", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_response_type_code", + "ordinal": 13, + "type_info": "Bool" + }, + { + "name": "oauth2_authorization_grant_response_type_id_token", + "ordinal": 14, + "type_info": "Bool" + }, + { + "name": "oauth2_authorization_grant_code_challenge", + "ordinal": 15, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_code_challenge_method", + "ordinal": 16, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_requires_consent", + "ordinal": 17, + "type_info": "Bool" + }, + { + "name": "oauth2_session_id?", + "ordinal": 18, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + true, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT oauth2_authorization_grant_id\n , created_at AS oauth2_authorization_grant_created_at\n , cancelled_at AS oauth2_authorization_grant_cancelled_at\n , fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , exchanged_at AS oauth2_authorization_grant_exchanged_at\n , scope AS oauth2_authorization_grant_scope\n , state AS oauth2_authorization_grant_state\n , redirect_uri AS oauth2_authorization_grant_redirect_uri\n , response_mode AS oauth2_authorization_grant_response_mode\n , nonce AS oauth2_authorization_grant_nonce\n , max_age AS oauth2_authorization_grant_max_age\n , oauth2_client_id AS oauth2_client_id\n , authorization_code AS oauth2_authorization_grant_code\n , response_type_code AS oauth2_authorization_grant_response_type_code\n , response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , code_challenge AS oauth2_authorization_grant_code_challenge\n , code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , requires_consent AS oauth2_authorization_grant_requires_consent\n , oauth2_session_id AS \"oauth2_session_id?\"\n FROM\n oauth2_authorization_grants\n\n WHERE authorization_code = $1\n " + }, "aff08a8caabeb62f4929e6e901e7ca7c55e284c18c5c1d1e78821dd9bc961412": { "describe": { "columns": [ @@ -1442,6 +1638,134 @@ }, "query": "\n SELECT user_email_id\n , user_id\n , email\n , created_at\n , confirmed_at\n FROM user_emails\n\n WHERE user_id = $1 AND email = $2\n " }, + "b12f7ba71ad522261f54ffbb6739a7a06214b4f01e3ed6f7fdaa2033d249f3fb": { + "describe": { + "columns": [ + { + "name": "oauth2_authorization_grant_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "oauth2_authorization_grant_created_at", + "ordinal": 1, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_cancelled_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_fulfilled_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_exchanged_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_scope", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_state", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_redirect_uri", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_response_mode", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_nonce", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_max_age", + "ordinal": 10, + "type_info": "Int4" + }, + { + "name": "oauth2_client_id", + "ordinal": 11, + "type_info": "Uuid" + }, + { + "name": "oauth2_authorization_grant_code", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_response_type_code", + "ordinal": 13, + "type_info": "Bool" + }, + { + "name": "oauth2_authorization_grant_response_type_id_token", + "ordinal": 14, + "type_info": "Bool" + }, + { + "name": "oauth2_authorization_grant_code_challenge", + "ordinal": 15, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_code_challenge_method", + "ordinal": 16, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_requires_consent", + "ordinal": 17, + "type_info": "Bool" + }, + { + "name": "oauth2_session_id?", + "ordinal": 18, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + true, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT oauth2_authorization_grant_id\n , created_at AS oauth2_authorization_grant_created_at\n , cancelled_at AS oauth2_authorization_grant_cancelled_at\n , fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , exchanged_at AS oauth2_authorization_grant_exchanged_at\n , scope AS oauth2_authorization_grant_scope\n , state AS oauth2_authorization_grant_state\n , redirect_uri AS oauth2_authorization_grant_redirect_uri\n , response_mode AS oauth2_authorization_grant_response_mode\n , nonce AS oauth2_authorization_grant_nonce\n , max_age AS oauth2_authorization_grant_max_age\n , oauth2_client_id AS oauth2_client_id\n , authorization_code AS oauth2_authorization_grant_code\n , response_type_code AS oauth2_authorization_grant_response_type_code\n , response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , code_challenge AS oauth2_authorization_grant_code_challenge\n , code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , requires_consent AS oauth2_authorization_grant_requires_consent\n , oauth2_session_id AS \"oauth2_session_id?\"\n FROM\n oauth2_authorization_grants\n\n WHERE oauth2_authorization_grant_id = $1\n " + }, "b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537": { "describe": { "columns": [], @@ -1544,140 +1868,6 @@ }, "query": "\n INSERT INTO user_sessions (user_session_id, user_id, created_at)\n VALUES ($1, $2, $3)\n " }, - "c467144ae98322e3ed6d34df6626d63c15bdfc7137e12097cfb6f9398f7029ca": { - "describe": { - "columns": [ - { - "name": "oauth2_authorization_grant_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_created_at", - "ordinal": 1, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_cancelled_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_fulfilled_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_exchanged_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_scope", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_state", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_redirect_uri", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_mode", - "ordinal": 8, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_nonce", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_max_age", - "ordinal": 10, - "type_info": "Int4" - }, - { - "name": "oauth2_client_id", - "ordinal": 11, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_code", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_type_code", - "ordinal": 13, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_response_type_id_token", - "ordinal": 14, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_code_challenge", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_code_challenge_method", - "ordinal": 16, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_requires_consent", - "ordinal": 17, - "type_info": "Bool" - }, - { - "name": "oauth2_session_id?", - "ordinal": 18, - "type_info": "Uuid" - }, - { - "name": "user_session_id?", - "ordinal": 19, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - true, - true, - true, - false, - true, - false, - false, - true, - true, - false, - true, - false, - false, - true, - true, - false, - false, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT og.oauth2_authorization_grant_id\n , og.created_at AS oauth2_authorization_grant_created_at\n , og.cancelled_at AS oauth2_authorization_grant_cancelled_at\n , og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , og.exchanged_at AS oauth2_authorization_grant_exchanged_at\n , og.scope AS oauth2_authorization_grant_scope\n , og.state AS oauth2_authorization_grant_state\n , og.redirect_uri AS oauth2_authorization_grant_redirect_uri\n , og.response_mode AS oauth2_authorization_grant_response_mode\n , og.nonce AS oauth2_authorization_grant_nonce\n , og.max_age AS oauth2_authorization_grant_max_age\n , og.oauth2_client_id AS oauth2_client_id\n , og.authorization_code AS oauth2_authorization_grant_code\n , og.response_type_code AS oauth2_authorization_grant_response_type_code\n , og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , og.code_challenge AS oauth2_authorization_grant_code_challenge\n , og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , og.requires_consent AS oauth2_authorization_grant_requires_consent\n , os.oauth2_session_id AS \"oauth2_session_id?\"\n , os.user_session_id AS \"user_session_id?\"\n FROM\n oauth2_authorization_grants og\n LEFT JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE og.authorization_code = $1\n " - }, "c88376abdba124ff0487a9a69d2345c7d69d7394f355111ec369cfa6d45fb40f": { "describe": { "columns": [], @@ -1704,86 +1894,6 @@ }, "query": "\n INSERT INTO oauth2_authorization_grants (\n oauth2_authorization_grant_id,\n oauth2_client_id,\n redirect_uri,\n scope,\n state,\n nonce,\n max_age,\n response_mode,\n code_challenge,\n code_challenge_method,\n response_type_code,\n response_type_id_token,\n authorization_code,\n requires_consent,\n created_at\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)\n " }, - "cad4d47709278a9ddbebfc91642967b465bafa596827d9b86a336841b2cfbf0c": { - "describe": { - "columns": [ - { - "name": "oauth2_refresh_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_refresh_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "oauth2_refresh_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_access_token_id?", - "ordinal": 3, - "type_info": "Uuid" - }, - { - "name": "oauth2_access_token?", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "oauth2_access_token_created_at?", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_access_token_expires_at?", - "ordinal": 6, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_session_id!", - "ordinal": 7, - "type_info": "Uuid" - }, - { - "name": "oauth2_client_id!", - "ordinal": 8, - "type_info": "Uuid" - }, - { - "name": "oauth2_session_scope!", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "user_session_id!", - "ordinal": 10, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT rt.oauth2_refresh_token_id\n , rt.refresh_token AS oauth2_refresh_token\n , rt.created_at AS oauth2_refresh_token_created_at\n , at.oauth2_access_token_id AS \"oauth2_access_token_id?\"\n , at.access_token AS \"oauth2_access_token?\"\n , at.created_at AS \"oauth2_access_token_created_at?\"\n , at.expires_at AS \"oauth2_access_token_expires_at?\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"oauth2_session_scope!\"\n , os.user_session_id AS \"user_session_id!\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n LEFT JOIN oauth2_access_tokens at\n USING (oauth2_access_token_id)\n\n WHERE rt.refresh_token = $1\n AND rt.consumed_at IS NULL\n AND rt.revoked_at IS NULL\n AND os.finished_at IS NULL\n " - }, "caf54e4659306a746747aa61906bdb2cb8da51176e90435aa8b9754ebf3e4d60": { "describe": { "columns": [], @@ -1799,140 +1909,6 @@ }, "query": "\n INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)\n VALUES ($1, $2, $3, $4)\n " }, - "d08b787fc422b6699ffc0a491ecf92fb993db0aca51534b315bcfa4891baca84": { - "describe": { - "columns": [ - { - "name": "oauth2_authorization_grant_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_created_at", - "ordinal": 1, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_cancelled_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_fulfilled_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_exchanged_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_scope", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_state", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_redirect_uri", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_mode", - "ordinal": 8, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_nonce", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_max_age", - "ordinal": 10, - "type_info": "Int4" - }, - { - "name": "oauth2_client_id", - "ordinal": 11, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_code", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_type_code", - "ordinal": 13, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_response_type_id_token", - "ordinal": 14, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_code_challenge", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_code_challenge_method", - "ordinal": 16, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_requires_consent", - "ordinal": 17, - "type_info": "Bool" - }, - { - "name": "oauth2_session_id?", - "ordinal": 18, - "type_info": "Uuid" - }, - { - "name": "user_session_id?", - "ordinal": 19, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - true, - true, - true, - false, - true, - false, - false, - true, - true, - false, - true, - false, - false, - true, - true, - false, - false, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT og.oauth2_authorization_grant_id\n , og.created_at AS oauth2_authorization_grant_created_at\n , og.cancelled_at AS oauth2_authorization_grant_cancelled_at\n , og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , og.exchanged_at AS oauth2_authorization_grant_exchanged_at\n , og.scope AS oauth2_authorization_grant_scope\n , og.state AS oauth2_authorization_grant_state\n , og.redirect_uri AS oauth2_authorization_grant_redirect_uri\n , og.response_mode AS oauth2_authorization_grant_response_mode\n , og.nonce AS oauth2_authorization_grant_nonce\n , og.max_age AS oauth2_authorization_grant_max_age\n , og.oauth2_client_id AS oauth2_client_id\n , og.authorization_code AS oauth2_authorization_grant_code\n , og.response_type_code AS oauth2_authorization_grant_response_type_code\n , og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , og.code_challenge AS oauth2_authorization_grant_code_challenge\n , og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , og.requires_consent AS oauth2_authorization_grant_requires_consent\n , os.oauth2_session_id AS \"oauth2_session_id?\"\n , os.user_session_id AS \"user_session_id?\"\n FROM\n oauth2_authorization_grants og\n LEFT JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE og.oauth2_authorization_grant_id = $1\n " - }, "d12a513b81b3ef658eae1f0a719933323f28c6ee260b52cafe337dd3d19e865c": { "describe": { "columns": [ @@ -2182,6 +2158,74 @@ }, "query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n " }, + "e25b8071b59075c4be9fac283410ec4acf771fdf06076ef7bbb11bf086c4bc03": { + "describe": { + "columns": [ + { + "name": "oauth2_refresh_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "oauth2_refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "oauth2_refresh_token_created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_access_token_id?", + "ordinal": 3, + "type_info": "Uuid" + }, + { + "name": "oauth2_session_created_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_session_id!", + "ordinal": 5, + "type_info": "Uuid" + }, + { + "name": "oauth2_client_id!", + "ordinal": 6, + "type_info": "Uuid" + }, + { + "name": "oauth2_session_scope!", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "user_session_id!", + "ordinal": 8, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false, + false, + false, + false, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT rt.oauth2_refresh_token_id\n , rt.refresh_token AS oauth2_refresh_token\n , rt.created_at AS oauth2_refresh_token_created_at\n , rt.oauth2_access_token_id AS \"oauth2_access_token_id?\"\n , os.created_at AS \"oauth2_session_created_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"oauth2_session_scope!\"\n , os.user_session_id AS \"user_session_id!\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE rt.refresh_token = $1\n AND rt.consumed_at IS NULL\n AND rt.revoked_at IS NULL\n AND os.finished_at IS NULL\n " + }, "e6dc63984aced9e19c20e90e9cd75d6f6d7ade64f782697715ac4da077b2e1fc": { "describe": { "columns": [ @@ -2227,6 +2271,56 @@ }, "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " }, + "f0ace1af3775192a555c4ebb59b81183f359771f9f77e5fad759d38d872541d1": { + "describe": { + "columns": [ + { + "name": "oauth2_session_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "user_session_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "oauth2_client_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "scope", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "finished_at", + "ordinal": 5, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT oauth2_session_id\n , user_session_id\n , oauth2_client_id\n , scope\n , created_at\n , finished_at\n FROM oauth2_sessions\n\n WHERE oauth2_session_id = $1\n " + }, "f5edcd4c306ca8179cdf9d4aab59fbba971b54611c91345849920954dd8089b3": { "describe": { "columns": [], @@ -2253,67 +2347,5 @@ } }, "query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)\n " - }, - "fba88894ee24cd181f50412571a19ee658f77012d330e7dab43a3c18d549355a": { - "describe": { - "columns": [ - { - "name": "oauth2_access_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "oauth2_access_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_access_token_expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_session_id!", - "ordinal": 4, - "type_info": "Uuid" - }, - { - "name": "oauth2_client_id!", - "ordinal": 5, - "type_info": "Uuid" - }, - { - "name": "scope!", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "user_session_id!", - "ordinal": 7, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT at.oauth2_access_token_id\n , at.access_token AS \"oauth2_access_token\"\n , at.created_at AS \"oauth2_access_token_created_at\"\n , at.expires_at AS \"oauth2_access_token_expires_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"scope!\"\n , os.user_session_id AS \"user_session_id!\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE at.access_token = $1\n AND at.revoked_at IS NULL\n AND os.finished_at IS NULL\n " } } \ No newline at end of file diff --git a/crates/storage/src/compat.rs b/crates/storage/src/compat.rs index ba47990d1..c328f9e13 100644 --- a/crates/storage/src/compat.rs +++ b/crates/storage/src/compat.rs @@ -14,8 +14,8 @@ use chrono::{DateTime, Duration, Utc}; use mas_data_model::{ - CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState, - Device, User, + CompatAccessToken, CompatRefreshToken, CompatSession, CompatSessionState, CompatSsoLogin, + CompatSsoLoginState, Device, User, }; use rand::Rng; use sqlx::{Acquire, PgExecutor, Postgres, QueryBuilder}; @@ -93,12 +93,17 @@ pub async fn lookup_active_compat_access_token( .source(e) })?; + let state = match res.compat_session_finished_at { + None => CompatSessionState::Valid, + Some(finished_at) => CompatSessionState::Finished { finished_at }, + }; + let session = CompatSession { id, + state, user_id: res.user_id.into(), device, created_at: res.compat_session_created_at, - finished_at: res.compat_session_finished_at, }; Ok(Some((token, session))) @@ -181,12 +186,17 @@ pub async fn lookup_active_compat_refresh_token( .source(e) })?; + let state = match res.compat_session_finished_at { + None => CompatSessionState::Valid, + Some(finished_at) => CompatSessionState::Finished { finished_at }, + }; + let session = CompatSession { id, + state, user_id: res.user_id.into(), device, created_at: res.compat_session_created_at, - finished_at: res.compat_session_finished_at, }; Ok(Some((refresh_token, access_token, session))) @@ -468,12 +478,18 @@ impl TryFrom for CompatSsoLogin { .row(id) .source(e) })?; + + let state = match finished_at { + None => CompatSessionState::Valid, + Some(finished_at) => CompatSessionState::Finished { finished_at }, + }; + Some(CompatSession { id, + state, user_id: user_id.into(), device, created_at, - finished_at, }) } (None, None, None, None, None) => None, @@ -686,10 +702,10 @@ pub async fn start_compat_session( Ok(CompatSession { id, + state: CompatSessionState::default(), user_id: user.id, device, created_at, - finished_at: None, }) } @@ -709,7 +725,7 @@ pub async fn fullfill_compat_sso_login( mut rng: impl Rng + Send, clock: &Clock, user: &User, - mut compat_sso_login: CompatSsoLogin, + compat_sso_login: CompatSsoLogin, device: Device, ) -> Result { if !matches!(compat_sso_login.state, CompatSsoLoginState::Pending) { @@ -719,8 +735,12 @@ pub async fn fullfill_compat_sso_login( let mut txn = conn.begin().await?; let session = start_compat_session(&mut txn, &mut rng, clock, user, device).await?; + let session_id = session.id; let fulfilled_at = clock.now(); + let compat_sso_login = compat_sso_login + .fulfill(fulfilled_at, session) + .map_err(DatabaseError::to_invalid_operation)?; sqlx::query!( r#" UPDATE compat_sso_logins @@ -731,20 +751,13 @@ pub async fn fullfill_compat_sso_login( compat_sso_login_id = $1 "#, Uuid::from(compat_sso_login.id), - Uuid::from(session.id), + Uuid::from(session_id), fulfilled_at, ) .execute(&mut txn) .instrument(tracing::info_span!("Update compat SSO login")) .await?; - let state = CompatSsoLoginState::Fulfilled { - fulfilled_at, - session, - }; - - compat_sso_login.state = state; - txn.commit().await?; Ok(compat_sso_login) @@ -761,13 +774,13 @@ pub async fn fullfill_compat_sso_login( pub async fn mark_compat_sso_login_as_exchanged( executor: impl PgExecutor<'_>, clock: &Clock, - mut compat_sso_login: CompatSsoLogin, + compat_sso_login: CompatSsoLogin, ) -> Result { - let CompatSsoLoginState::Fulfilled { fulfilled_at, session } = compat_sso_login.state else { - return Err(DatabaseError::invalid_operation()); - }; - let exchanged_at = clock.now(); + let compat_sso_login = compat_sso_login + .exchange(exchanged_at) + .map_err(DatabaseError::to_invalid_operation)?; + sqlx::query!( r#" UPDATE compat_sso_logins @@ -783,11 +796,5 @@ pub async fn mark_compat_sso_login_as_exchanged( .instrument(tracing::info_span!("Update compat SSO login")) .await?; - let state = CompatSsoLoginState::Exchanged { - fulfilled_at, - exchanged_at, - session, - }; - compat_sso_login.state = state; Ok(compat_sso_login) } diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index cd4cafbfc..58c13b190 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -13,7 +13,7 @@ // limitations under the License. use chrono::{DateTime, Duration, Utc}; -use mas_data_model::{AccessToken, Session}; +use mas_data_model::{AccessToken, Session, SessionState}; use rand::Rng; use sqlx::{PgConnection, PgExecutor}; use ulid::Ulid; @@ -76,6 +76,7 @@ pub struct OAuth2AccessTokenLookup { oauth2_access_token: String, oauth2_access_token_created_at: DateTime, oauth2_access_token_expires_at: DateTime, + oauth2_session_created_at: DateTime, oauth2_session_id: Uuid, oauth2_client_id: Uuid, scope: String, @@ -94,6 +95,7 @@ pub async fn lookup_active_access_token( , at.access_token AS "oauth2_access_token" , at.created_at AS "oauth2_access_token_created_at" , at.expires_at AS "oauth2_access_token_expires_at" + , os.created_at AS "oauth2_session_created_at" , os.oauth2_session_id AS "oauth2_session_id!" , os.oauth2_client_id AS "oauth2_client_id!" , os.scope AS "scope!" @@ -131,10 +133,11 @@ pub async fn lookup_active_access_token( let session = Session { id: session_id, + state: SessionState::Valid, + created_at: res.oauth2_session_created_at, client_id: res.oauth2_client_id.into(), user_session_id: res.user_session_id.into(), scope, - finished_at: None, }; Ok(Some((access_token, session))) diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index e4c35c710..f49b38e8a 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -13,7 +13,7 @@ // limitations under the License. use chrono::{DateTime, Utc}; -use mas_data_model::{AccessToken, RefreshToken, Session}; +use mas_data_model::{AccessToken, RefreshToken, Session, SessionState}; use rand::Rng; use sqlx::{PgConnection, PgExecutor}; use ulid::Ulid; @@ -73,6 +73,7 @@ struct OAuth2RefreshTokenLookup { oauth2_refresh_token: String, oauth2_refresh_token_created_at: DateTime, oauth2_access_token_id: Option, + oauth2_session_created_at: DateTime, oauth2_session_id: Uuid, oauth2_client_id: Uuid, oauth2_session_scope: String, @@ -92,6 +93,7 @@ pub async fn lookup_active_refresh_token( , rt.refresh_token AS oauth2_refresh_token , rt.created_at AS oauth2_refresh_token_created_at , rt.oauth2_access_token_id AS "oauth2_access_token_id?" + , os.created_at AS "oauth2_session_created_at" , os.oauth2_session_id AS "oauth2_session_id!" , os.oauth2_client_id AS "oauth2_client_id!" , os.scope AS "oauth2_session_scope!" @@ -127,10 +129,11 @@ pub async fn lookup_active_refresh_token( let session = Session { id: session_id, + state: SessionState::Valid, + created_at: res.oauth2_session_created_at, client_id: res.oauth2_client_id.into(), user_session_id: res.user_session_id.into(), scope, - finished_at: None, }; Ok(Some((refresh_token, session))) diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 3a681a845..9df2f61d2 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -14,7 +14,7 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use mas_data_model::{AuthorizationGrant, BrowserSession, Session, User}; +use mas_data_model::{AuthorizationGrant, BrowserSession, Session, SessionState, User}; use rand::RngCore; use sqlx::{PgConnection, QueryBuilder}; use ulid::Ulid; @@ -85,12 +85,18 @@ impl TryFrom for Session { .source(e) })?; + let state = match value.finished_at { + None => SessionState::Valid, + Some(finished_at) => SessionState::Finished { finished_at }, + }; + Ok(Session { id, + state, + created_at: value.created_at, client_id: value.oauth2_client_id.into(), user_session_id: value.user_session_id.into(), scope, - finished_at: value.finished_at, }) } } @@ -182,10 +188,11 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { Ok(Session { id, + state: SessionState::Valid, + created_at, user_session_id: user_session.id, client_id: grant.client_id, scope: grant.scope.clone(), - finished_at: None, }) } diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index 1abcd1d02..e195056c8 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -85,9 +85,10 @@ mod tests { .await? .expect("session to be found in the database"); assert_eq!(session.provider_id, provider.id); - assert_eq!(session.link_id, None); - assert!(!session.completed()); - assert!(!session.consumed()); + assert_eq!(session.link_id(), None); + assert!(session.is_pending()); + assert!(!session.is_completed()); + assert!(!session.is_consumed()); // Create a link let link = conn @@ -114,15 +115,15 @@ mod tests { .upstream_oauth_session() .complete_with_link(&clock, session, &link, None) .await?; - assert!(session.completed()); - assert!(!session.consumed()); - assert_eq!(session.link_id, Some(link.id)); + assert!(session.is_completed()); + assert!(!session.is_consumed()); + assert_eq!(session.link_id(), Some(link.id)); let session = conn .upstream_oauth_session() .consume(&clock, session) .await?; - assert!(session.consumed()); + assert!(session.is_consumed()); Ok(()) } diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index f13c6ec88..d5da6ef8b 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -14,13 +14,18 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider}; +use mas_data_model::{ + UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, + UpstreamOAuthProvider, +}; use rand::RngCore; use sqlx::PgConnection; use ulid::Ulid; use uuid::Uuid; -use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; +use crate::{ + tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, +}; #[async_trait] pub trait UpstreamOAuthSessionRepository: Send + Sync { @@ -83,6 +88,52 @@ struct SessionLookup { consumed_at: Option>, } +impl TryFrom for UpstreamOAuthAuthorizationSession { + type Error = DatabaseInconsistencyError; + + fn try_from(value: SessionLookup) -> Result { + let id = value.upstream_oauth_authorization_session_id.into(); + let state = match ( + value.upstream_oauth_link_id, + value.id_token, + value.completed_at, + value.consumed_at, + ) { + (None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending, + (Some(link_id), id_token, Some(completed_at), None) => { + UpstreamOAuthAuthorizationSessionState::Completed { + completed_at, + link_id: link_id.into(), + id_token, + } + } + (Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => { + UpstreamOAuthAuthorizationSessionState::Consumed { + completed_at, + link_id: link_id.into(), + id_token, + consumed_at, + } + } + _ => { + return Err( + DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id), + ) + } + }; + + Ok(Self { + id, + provider_id: value.upstream_oauth_provider_id.into(), + state_str: value.state, + nonce: value.nonce, + code_challenge_verifier: value.code_challenge_verifier, + created_at: value.created_at, + state, + }) + } +} + #[async_trait] impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> { type Error = DatabaseError; @@ -126,20 +177,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> let Some(res) = res else { return Ok(None) }; - let session = UpstreamOAuthAuthorizationSession { - id: res.upstream_oauth_authorization_session_id.into(), - provider_id: res.upstream_oauth_provider_id.into(), - link_id: res.upstream_oauth_link_id.map(Ulid::from), - state: res.state, - code_challenge_verifier: res.code_challenge_verifier, - nonce: res.nonce, - id_token: res.id_token, - created_at: res.created_at, - completed_at: res.completed_at, - consumed_at: res.consumed_at, - }; - - Ok(Some(session)) + Ok(Some(res.try_into()?)) } #[tracing::instrument( @@ -159,7 +197,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> rng: &mut (dyn RngCore + Send), clock: &Clock, upstream_oauth_provider: &UpstreamOAuthProvider, - state: String, + state_str: String, code_challenge_verifier: Option, nonce: String, ) -> Result { @@ -186,7 +224,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> "#, Uuid::from(id), Uuid::from(upstream_oauth_provider.id), - &state, + &state_str, code_challenge_verifier.as_deref(), nonce, created_at, @@ -197,15 +235,12 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> Ok(UpstreamOAuthAuthorizationSession { id, + state: UpstreamOAuthAuthorizationSessionState::default(), provider_id: upstream_oauth_provider.id, - link_id: None, - state, + state_str, code_challenge_verifier, nonce, - id_token: None, created_at, - completed_at: None, - consumed_at: None, }) } @@ -222,11 +257,12 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> async fn complete_with_link( &mut self, clock: &Clock, - mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, upstream_oauth_link: &UpstreamOAuthLink, id_token: Option, ) -> Result { let completed_at = clock.now(); + sqlx::query!( r#" UPDATE upstream_oauth_authorization_sessions @@ -244,9 +280,9 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> .execute(&mut *self.conn) .await?; - upstream_oauth_authorization_session.completed_at = Some(completed_at); - upstream_oauth_authorization_session.id_token = id_token; - upstream_oauth_authorization_session.link_id = Some(upstream_oauth_link.id); + let upstream_oauth_authorization_session = upstream_oauth_authorization_session + .complete(completed_at, upstream_oauth_link, id_token) + .map_err(DatabaseError::to_invalid_operation)?; Ok(upstream_oauth_authorization_session) } @@ -264,7 +300,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> async fn consume( &mut self, clock: &Clock, - mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, ) -> Result { let consumed_at = clock.now(); sqlx::query!( @@ -280,7 +316,9 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> .execute(&mut *self.conn) .await?; - upstream_oauth_authorization_session.consumed_at = Some(consumed_at); + let upstream_oauth_authorization_session = upstream_oauth_authorization_session + .consume(consumed_at) + .map_err(DatabaseError::to_invalid_operation)?; Ok(upstream_oauth_authorization_session) } From f0a44fcd5e7580c225697b1e7b84f4d21bd3f851 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 10 Jan 2023 18:49:35 +0100 Subject: [PATCH 16/45] storage: do less joins in compat sessions --- crates/data-model/src/compat/mod.rs | 65 +++ crates/data-model/src/compat/session.rs | 7 + crates/data-model/src/compat/sso_login.rs | 20 +- crates/data-model/src/lib.rs | 4 +- crates/graphql/src/model/compat_sessions.rs | 16 +- crates/handlers/src/compat/login.rs | 37 +- crates/handlers/src/compat/logout.rs | 29 +- crates/handlers/src/compat/refresh.rs | 44 +- crates/handlers/src/oauth2/introspection.rs | 26 +- crates/storage/sqlx-data.json | 541 +++++++++----------- crates/storage/src/compat.rs | 369 ++++++------- 11 files changed, 616 insertions(+), 542 deletions(-) diff --git a/crates/data-model/src/compat/mod.rs b/crates/data-model/src/compat/mod.rs index f6e19bd81..d0e560c70 100644 --- a/crates/data-model/src/compat/mod.rs +++ b/crates/data-model/src/compat/mod.rs @@ -24,18 +24,83 @@ pub use self::{ session::{CompatSession, CompatSessionState}, sso_login::{CompatSsoLogin, CompatSsoLoginState}, }; +use crate::InvalidTransitionError; #[derive(Debug, Clone, PartialEq, Eq)] pub struct CompatAccessToken { pub id: Ulid, + pub session_id: Ulid, pub token: String, pub created_at: DateTime, pub expires_at: Option>, } +impl CompatAccessToken { + #[must_use] + pub fn is_valid(&self, now: DateTime) -> bool { + if let Some(expires_at) = self.expires_at { + expires_at > now + } else { + true + } + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub enum CompatRefreshTokenState { + #[default] + Valid, + Consumed { + consumed_at: DateTime, + }, +} + +impl CompatRefreshTokenState { + /// Returns `true` if the compat refresh token state is [`Valid`]. + /// + /// [`Valid`]: CompatRefreshTokenState::Valid + #[must_use] + pub fn is_valid(&self) -> bool { + matches!(self, Self::Valid) + } + + /// Returns `true` if the compat refresh token state is [`Consumed`]. + /// + /// [`Consumed`]: CompatRefreshTokenState::Consumed + #[must_use] + pub fn is_consumed(&self) -> bool { + matches!(self, Self::Consumed { .. }) + } + + pub fn consume(self, consumed_at: DateTime) -> Result { + match self { + Self::Valid => Ok(Self::Consumed { consumed_at }), + Self::Consumed { .. } => Err(InvalidTransitionError), + } + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct CompatRefreshToken { pub id: Ulid, + pub state: CompatRefreshTokenState, + pub session_id: Ulid, + pub access_token_id: Ulid, pub token: String, pub created_at: DateTime, } + +impl std::ops::Deref for CompatRefreshToken { + type Target = CompatRefreshTokenState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + +impl CompatRefreshToken { + pub fn consume(mut self, consumed_at: DateTime) -> Result { + self.state = self.state.consume(consumed_at)?; + Ok(self) + } +} diff --git a/crates/data-model/src/compat/session.rs b/crates/data-model/src/compat/session.rs index 2c4cdf2d0..a5c2c17de 100644 --- a/crates/data-model/src/compat/session.rs +++ b/crates/data-model/src/compat/session.rs @@ -77,3 +77,10 @@ impl std::ops::Deref for CompatSession { &self.state } } + +impl CompatSession { + pub fn finish(mut self, finished_at: DateTime) -> Result { + self.state = self.state.finish(finished_at)?; + Ok(self) + } +} diff --git a/crates/data-model/src/compat/sso_login.rs b/crates/data-model/src/compat/sso_login.rs index 7e494f828..54fd96b31 100644 --- a/crates/data-model/src/compat/sso_login.rs +++ b/crates/data-model/src/compat/sso_login.rs @@ -25,12 +25,12 @@ pub enum CompatSsoLoginState { Pending, Fulfilled { fulfilled_at: DateTime, - session: CompatSession, + session_id: Ulid, }, Exchanged { fulfilled_at: DateTime, exchanged_at: DateTime, - session: CompatSession, + session_id: Ulid, }, } @@ -78,22 +78,24 @@ impl CompatSsoLoginState { } #[must_use] - pub fn session(&self) -> Option<&CompatSession> { + pub fn session_id(&self) -> Option { match self { Self::Pending => None, - Self::Fulfilled { session, .. } | Self::Exchanged { session, .. } => Some(session), + Self::Fulfilled { session_id, .. } | Self::Exchanged { session_id, .. } => { + Some(*session_id) + } } } pub fn fulfill( self, fulfilled_at: DateTime, - session: CompatSession, + session: &CompatSession, ) -> Result { match self { Self::Pending => Ok(Self::Fulfilled { fulfilled_at, - session, + session_id: session.id, }), Self::Fulfilled { .. } | Self::Exchanged { .. } => Err(InvalidTransitionError), } @@ -103,11 +105,11 @@ impl CompatSsoLoginState { match self { Self::Fulfilled { fulfilled_at, - session, + session_id, } => Ok(Self::Exchanged { fulfilled_at, exchanged_at, - session, + session_id, }), Self::Pending { .. } | Self::Exchanged { .. } => Err(InvalidTransitionError), } @@ -135,7 +137,7 @@ impl CompatSsoLogin { pub fn fulfill( mut self, fulfilled_at: DateTime, - session: CompatSession, + session: &CompatSession, ) -> Result { self.state = self.state.fulfill(fulfilled_at, session)?; Ok(self) diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 879dc641f..8454f05d3 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -37,8 +37,8 @@ pub struct InvalidTransitionError; pub use self::{ compat::{ - CompatAccessToken, CompatRefreshToken, CompatSession, CompatSessionState, CompatSsoLogin, - CompatSsoLoginState, Device, + CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession, + CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device, }, oauth2::{ AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, diff --git a/crates/graphql/src/model/compat_sessions.rs b/crates/graphql/src/model/compat_sessions.rs index 38fe866ce..394639cff 100644 --- a/crates/graphql/src/model/compat_sessions.rs +++ b/crates/graphql/src/model/compat_sessions.rs @@ -15,7 +15,7 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; use chrono::{DateTime, Utc}; -use mas_storage::{user::UserRepository, Repository}; +use mas_storage::{compat::lookup_compat_session, user::UserRepository, Repository}; use sqlx::PgPool; use url::Url; @@ -94,7 +94,17 @@ impl CompatSsoLogin { } /// The compat session which was started by this login. - async fn session(&self) -> Option { - self.0.session().cloned().map(CompatSession) + async fn session( + &self, + ctx: &Context<'_>, + ) -> Result, async_graphql::Error> { + let Some(session_id) = self.0.session_id() else { return Ok(None) }; + + let mut conn = ctx.data::()?.acquire().await?; + let session = lookup_compat_session(&mut conn, session_id) + .await? + .context("Could not load compat session")?; + + Ok(Some(CompatSession(session))) } } diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index c59c7dd81..f36d520bd 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -19,7 +19,7 @@ use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType, User use mas_storage::{ compat::{ add_compat_access_token, add_compat_refresh_token, get_compat_sso_login_by_token, - mark_compat_sso_login_as_exchanged, start_compat_session, + lookup_compat_session, mark_compat_sso_login_as_exchanged, start_compat_session, }, user::{UserPasswordRepository, UserRepository}, Clock, Repository, @@ -137,6 +137,9 @@ pub enum RouteError { #[error("user not found")] UserNotFound, + #[error("session not found")] + SessionNotFound, + #[error("user has no password")] NoPassword, @@ -156,7 +159,7 @@ impl_from_error_for_route!(mas_storage::DatabaseError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { - Self::Internal(_) => MatrixError { + Self::Internal(_) | Self::SessionNotFound => MatrixError { errcode: "M_UNKNOWN", error: "Internal server error", status: StatusCode::INTERNAL_SERVER_ERROR, @@ -268,7 +271,7 @@ async fn token_login( .ok_or(RouteError::InvalidLoginToken)?; let now = clock.now(); - let user_id = match login.state { + let session_id = match login.state { CompatSsoLoginState::Pending => { tracing::error!( compat_sso_login.id = %login.id, @@ -277,21 +280,26 @@ async fn token_login( return Err(RouteError::InvalidLoginToken); } CompatSsoLoginState::Fulfilled { - fulfilled_at: fullfilled_at, - ref session, + fulfilled_at, + session_id, .. } => { - if now > fullfilled_at + Duration::seconds(30) { + if now > fulfilled_at + Duration::seconds(30) { return Err(RouteError::LoginTookTooLong); } - session.user_id + session_id } - CompatSsoLoginState::Exchanged { exchanged_at, .. } => { + CompatSsoLoginState::Exchanged { + exchanged_at, + session_id, + .. + } => { if now > exchanged_at + Duration::seconds(30) { // TODO: log that session out tracing::error!( compat_sso_login.id = %login.id, + compat_session.id = %session_id, "Login token exchanged a second time more than 30s after" ); } @@ -300,18 +308,19 @@ async fn token_login( } }; + let session = lookup_compat_session(&mut *txn, session_id) + .await? + .ok_or(RouteError::SessionNotFound)?; + let user = txn .user() - .lookup(user_id) + .lookup(session.user_id) .await? .ok_or(RouteError::UserNotFound)?; - let login = mark_compat_sso_login_as_exchanged(&mut *txn, clock, login).await?; + mark_compat_sso_login_as_exchanged(&mut *txn, clock, login).await?; - match login.state { - CompatSsoLoginState::Exchanged { session, .. } => Ok((session, user)), - _ => unreachable!(), - } + Ok((session, user)) } async fn user_password_login( diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index 4dca7797f..e16c8c98b 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -16,7 +16,10 @@ use axum::{extract::State, response::IntoResponse, Json, TypedHeader}; use headers::{authorization::Bearer, Authorization}; use hyper::StatusCode; use mas_data_model::TokenType; -use mas_storage::{compat::compat_logout, Clock}; +use mas_storage::{ + compat::{end_compat_session, find_compat_access_token, lookup_compat_session}, + Clock, +}; use sqlx::PgPool; use thiserror::Error; @@ -36,12 +39,10 @@ pub enum RouteError { #[error("Invalid access token")] InvalidAuthorization, - - #[error("Logout failed")] - LogoutFailed, } impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_storage::DatabaseError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -56,7 +57,7 @@ impl IntoResponse for RouteError { error: "Missing access token", status: StatusCode::UNAUTHORIZED, }, - Self::InvalidAuthorization | Self::LogoutFailed | Self::TokenFormat(_) => MatrixError { + Self::InvalidAuthorization | Self::TokenFormat(_) => MatrixError { errcode: "M_UNKNOWN_TOKEN", error: "Invalid access token", status: StatusCode::UNAUTHORIZED, @@ -71,7 +72,7 @@ pub(crate) async fn post( maybe_authorization: Option>>, ) -> Result { let clock = Clock::default(); - let mut conn = pool.acquire().await?; + let mut txn = pool.begin().await?; let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?; @@ -82,9 +83,19 @@ pub(crate) async fn post( return Err(RouteError::InvalidAuthorization); } - if !compat_logout(&mut conn, &clock, token).await? { - return Err(RouteError::LogoutFailed); - } + let token = find_compat_access_token(&mut txn, token) + .await? + .filter(|t| t.is_valid(clock.now())) + .ok_or(RouteError::InvalidAuthorization)?; + + let session = lookup_compat_session(&mut txn, token.session_id) + .await? + .filter(|s| s.is_valid()) + .ok_or(RouteError::InvalidAuthorization)?; + + end_compat_session(&mut txn, &clock, session).await?; + + txn.commit().await?; Ok(Json(serde_json::json!({}))) } diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs index 912a0f1a5..58e9eb8e4 100644 --- a/crates/handlers/src/compat/refresh.rs +++ b/crates/handlers/src/compat/refresh.rs @@ -18,7 +18,8 @@ use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; use mas_storage::compat::{ add_compat_access_token, add_compat_refresh_token, consume_compat_refresh_token, - expire_compat_access_token, lookup_active_compat_refresh_token, + expire_compat_access_token, find_compat_refresh_token, lookup_compat_access_token, + lookup_compat_session, }; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DurationMilliSeconds}; @@ -40,17 +41,26 @@ pub enum RouteError { #[error("invalid token")] InvalidToken, + + #[error("refresh token already consumed")] + RefreshTokenConsumed, + + #[error("invalid session")] + InvalidSession, + + #[error("unknown session")] + UnknownSession, } impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { - Self::Internal(_) => MatrixError { + Self::Internal(_) | Self::UnknownSession => MatrixError { errcode: "M_UNKNOWN", error: "Internal error", status: StatusCode::INTERNAL_SERVER_ERROR, }, - Self::InvalidToken => MatrixError { + Self::InvalidToken | Self::InvalidSession | Self::RefreshTokenConsumed => MatrixError { errcode: "M_UNKNOWN_TOKEN", error: "Invalid refresh token", status: StatusCode::UNAUTHORIZED, @@ -91,10 +101,25 @@ pub(crate) async fn post( return Err(RouteError::InvalidToken); } - let (refresh_token, access_token, session) = - lookup_active_compat_refresh_token(&mut txn, &input.refresh_token) - .await? - .ok_or(RouteError::InvalidToken)?; + let refresh_token = find_compat_refresh_token(&mut txn, &input.refresh_token) + .await? + .ok_or(RouteError::InvalidToken)?; + + if !refresh_token.is_valid() { + return Err(RouteError::RefreshTokenConsumed); + } + + let session = lookup_compat_session(&mut txn, refresh_token.session_id) + .await? + .ok_or(RouteError::UnknownSession)?; + + if !session.is_valid() { + return Err(RouteError::InvalidSession); + } + + let access_token = lookup_compat_access_token(&mut txn, refresh_token.access_token_id) + .await? + .filter(|t| t.is_valid(clock.now())); let new_refresh_token_str = TokenType::CompatRefreshToken.generate(&mut rng); let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng); @@ -120,7 +145,10 @@ pub(crate) async fn post( .await?; consume_compat_refresh_token(&mut txn, &clock, refresh_token).await?; - expire_compat_access_token(&mut txn, &clock, access_token).await?; + + if let Some(access_token) = access_token { + expire_compat_access_token(&mut txn, &clock, access_token).await?; + } txn.commit().await?; diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 71f3f1488..2cf34c979 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -22,7 +22,7 @@ use mas_data_model::{TokenFormatError, TokenType}; use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint}; use mas_keystore::Encrypter; use mas_storage::{ - compat::{lookup_active_compat_access_token, lookup_active_compat_refresh_token}, + compat::{find_compat_access_token, find_compat_refresh_token, lookup_compat_session}, oauth2::{ access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token, }, @@ -194,6 +194,7 @@ pub(crate) async fn post( jti: None, } } + TokenType::RefreshToken => { let (token, session) = lookup_active_refresh_token(&mut conn, token) .await? @@ -221,9 +222,16 @@ pub(crate) async fn post( jti: None, } } + TokenType::CompatAccessToken => { - let (token, session) = lookup_active_compat_access_token(&mut conn, &clock, token) + let token = find_compat_access_token(&mut conn, token) .await? + .filter(|t| t.is_valid(clock.now())) + .ok_or(RouteError::UnknownToken)?; + + let session = lookup_compat_session(&mut conn, token.session_id) + .await? + .filter(|s| s.is_valid()) .ok_or(RouteError::UnknownToken)?; let user = conn @@ -251,11 +259,17 @@ pub(crate) async fn post( jti: None, } } + TokenType::CompatRefreshToken => { - let (refresh_token, _access_token, session) = - lookup_active_compat_refresh_token(&mut conn, token) - .await? - .ok_or(RouteError::UnknownToken)?; + let refresh_token = find_compat_refresh_token(&mut conn, token) + .await? + .filter(|t| t.is_valid()) + .ok_or(RouteError::UnknownToken)?; + + let session = lookup_compat_session(&mut conn, refresh_token.session_id) + .await? + .filter(|s| s.is_valid()) + .ok_or(RouteError::UnknownToken)?; let user = conn .user() diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index e8a33bd24..d31c59a35 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -1,91 +1,5 @@ { "db": "PostgreSQL", - "021f845e564500457e2e0c8614beb1d9fd10b4b5f13515478f7ca25b5474d016": { - "describe": { - "columns": [ - { - "name": "compat_refresh_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "compat_refresh_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_refresh_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "compat_access_token_id", - "ordinal": 3, - "type_info": "Uuid" - }, - { - "name": "compat_access_token", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "compat_access_token_created_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_access_token_expires_at", - "ordinal": 6, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id", - "ordinal": 7, - "type_info": "Uuid" - }, - { - "name": "compat_session_created_at", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_finished_at", - "ordinal": 9, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id", - "ordinal": 10, - "type_info": "Text" - }, - { - "name": "user_id", - "ordinal": 11, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - false, - true, - false, - false, - true, - false, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT cr.compat_refresh_token_id\n , cr.refresh_token AS \"compat_refresh_token\"\n , cr.created_at AS \"compat_refresh_token_created_at\"\n , ct.compat_access_token_id\n , ct.access_token AS \"compat_access_token\"\n , ct.created_at AS \"compat_access_token_created_at\"\n , ct.expires_at AS \"compat_access_token_expires_at\"\n , cs.compat_session_id\n , cs.created_at AS \"compat_session_created_at\"\n , cs.finished_at AS \"compat_session_finished_at\"\n , cs.device_id AS \"compat_session_device_id\"\n , cs.user_id\n\n FROM compat_refresh_tokens cr\n INNER JOIN compat_sessions cs\n USING (compat_session_id)\n INNER JOIN compat_access_tokens ct\n USING (compat_access_token_id)\n\n WHERE cr.refresh_token = $1\n AND cr.consumed_at IS NULL\n AND cs.finished_at IS NULL\n " - }, "08d7df347c806ef14b6d0fb031cab041d79ba48528420160e23286369db7af35": { "describe": { "columns": [ @@ -312,6 +226,50 @@ }, "query": "\n INSERT INTO compat_refresh_tokens\n (compat_refresh_token_id, compat_session_id,\n compat_access_token_id, refresh_token, created_at)\n VALUES ($1, $2, $3, $4, $5)\n " }, + "3cf8e061206620071b39d0262cd165bb367b12b8e904180730d8acfa5af3d4b9": { + "describe": { + "columns": [ + { + "name": "compat_session_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "device_id", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "user_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "finished_at", + "ordinal": 4, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT compat_session_id\n , device_id\n , user_id\n , created_at\n , finished_at\n FROM compat_sessions\n WHERE compat_session_id = $1\n " + }, "3d66f3121b11ce923b9c60609b510a8ca899640e78cc8f5b03168622928ffe94": { "describe": { "columns": [], @@ -507,7 +465,20 @@ }, "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n oauth2_session_id = os.oauth2_session_id,\n fulfilled_at = os.created_at\n FROM oauth2_sessions os\n WHERE\n og.oauth2_authorization_grant_id = $1\n AND os.oauth2_session_id = $2\n RETURNING fulfilled_at AS \"fulfilled_at!: DateTime\"\n " }, - "4f8b0cd13d9488c2dd0f183d090d3856da15dcdb57a8c113febbee665a2a3ac5": { + "4c4dbb846bb98d84f6b7f886f8af9833c7efe27b8b4f297077887232bef322ee": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE compat_sessions cs\n SET finished_at = $2\n WHERE compat_session_id = $1\n " + }, + "4f080990eb6dd9f6128f3a1aee195b99d5f286fa0f6c27d744f73848343879d4": { "describe": { "columns": [ { @@ -541,29 +512,9 @@ "type_info": "Timestamptz" }, { - "name": "compat_session_id?", + "name": "compat_session_id", "ordinal": 6, "type_info": "Uuid" - }, - { - "name": "compat_session_created_at?", - "ordinal": 7, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_finished_at?", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id?", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "user_id?", - "ordinal": 10, - "type_info": "Uuid" } ], "nullable": [ @@ -573,19 +524,15 @@ false, true, true, - false, - false, - true, - false, - false + true ], "parameters": { "Left": [ - "Text" + "Uuid" ] } }, - "query": "\n SELECT cl.compat_sso_login_id\n , cl.login_token AS \"compat_sso_login_token\"\n , cl.redirect_uri AS \"compat_sso_login_redirect_uri\"\n , cl.created_at AS \"compat_sso_login_created_at\"\n , cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\"\n , cl.exchanged_at AS \"compat_sso_login_exchanged_at\"\n , cs.compat_session_id AS \"compat_session_id?\"\n , cs.created_at AS \"compat_session_created_at?\"\n , cs.finished_at AS \"compat_session_finished_at?\"\n , cs.device_id AS \"compat_session_device_id?\"\n , cs.user_id AS \"user_id?\"\n FROM compat_sso_logins cl\n LEFT JOIN compat_sessions cs\n USING (compat_session_id)\n WHERE cl.login_token = $1\n " + "query": "\n SELECT cl.compat_sso_login_id\n , cl.login_token AS \"compat_sso_login_token\"\n , cl.redirect_uri AS \"compat_sso_login_redirect_uri\"\n , cl.created_at AS \"compat_sso_login_created_at\"\n , cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\"\n , cl.exchanged_at AS \"compat_sso_login_exchanged_at\"\n , cl.compat_session_id AS \"compat_session_id\"\n\n FROM compat_sso_logins cl\n WHERE cl.compat_sso_login_id = $1\n " }, "51158bfcaa1a8d8e051bffe7c5ba0369bf53fb162f7622626054e89e68fc07bd": { "describe": { @@ -608,27 +555,6 @@ }, "query": "\n SELECT scope_token\n FROM oauth2_consents\n WHERE user_id = $1 AND oauth2_client_id = $2\n " }, - "559a486756d08d101eb7188ef6637b9d24c024d056795b8121f7f04a7f9db6a3": { - "describe": { - "columns": [ - { - "name": "compat_session_id", - "ordinal": 0, - "type_info": "Uuid" - } - ], - "nullable": [ - false - ], - "parameters": { - "Left": [ - "Text", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE compat_sessions cs\n SET finished_at = $2\n FROM compat_access_tokens ca\n WHERE ca.access_token = $1\n AND ca.compat_session_id = cs.compat_session_id\n AND cs.finished_at IS NULL\n RETURNING cs.compat_session_id\n " - }, "583ae9a0db9cd55fa57a179339550f3dab1bfc76f35ad488e1560ea37f7ed029": { "describe": { "columns": [], @@ -1296,86 +1222,6 @@ }, "query": "\n UPDATE user_email_confirmation_codes\n SET consumed_at = $2\n WHERE user_email_confirmation_code_id = $1\n " }, - "92ef320b75ca479ed1a38f6d654fdb953431188a8654c806fd5f98444b00c012": { - "describe": { - "columns": [ - { - "name": "compat_sso_login_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "compat_sso_login_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_sso_login_redirect_uri", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "compat_sso_login_created_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_sso_login_fulfilled_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "compat_sso_login_exchanged_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id?", - "ordinal": 6, - "type_info": "Uuid" - }, - { - "name": "compat_session_created_at?", - "ordinal": 7, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_finished_at?", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id?", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "user_id?", - "ordinal": 10, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - false, - true, - true, - false, - false, - true, - false, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT cl.compat_sso_login_id\n , cl.login_token AS \"compat_sso_login_token\"\n , cl.redirect_uri AS \"compat_sso_login_redirect_uri\"\n , cl.created_at AS \"compat_sso_login_created_at\"\n , cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\"\n , cl.exchanged_at AS \"compat_sso_login_exchanged_at\"\n , cs.compat_session_id AS \"compat_session_id?\"\n , cs.created_at AS \"compat_session_created_at?\"\n , cs.finished_at AS \"compat_session_finished_at?\"\n , cs.device_id AS \"compat_session_device_id?\"\n , cs.user_id AS \"user_id?\"\n\n FROM compat_sso_logins cl\n LEFT JOIN compat_sessions cs\n USING (compat_session_id)\n WHERE cl.compat_sso_login_id = $1\n " - }, "94fd96446b237c87bd6bf741f3c42b37ee751b87b7fcc459602bdf8c46962443": { "describe": { "columns": [ @@ -1868,6 +1714,106 @@ }, "query": "\n INSERT INTO user_sessions (user_session_id, user_id, created_at)\n VALUES ($1, $2, $3)\n " }, + "c3e60701299be7728108b8967ec5396fb186adaac360d6a0152d25e4a4f46f46": { + "describe": { + "columns": [ + { + "name": "compat_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT compat_access_token_id\n , access_token\n , created_at\n , expires_at\n , compat_session_id\n\n FROM compat_access_tokens\n\n WHERE compat_access_token_id = $1\n " + }, + "c78246fc8737491352f71ea9410e79df8de88596c8197405cda36eb8c8187810": { + "describe": { + "columns": [ + { + "name": "compat_sso_login_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "compat_sso_login_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "compat_sso_login_redirect_uri", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "compat_sso_login_created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_sso_login_fulfilled_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "compat_sso_login_exchanged_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 6, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + true, + true + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT cl.compat_sso_login_id\n , cl.login_token AS \"compat_sso_login_token\"\n , cl.redirect_uri AS \"compat_sso_login_redirect_uri\"\n , cl.created_at AS \"compat_sso_login_created_at\"\n , cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\"\n , cl.exchanged_at AS \"compat_sso_login_exchanged_at\"\n , cl.compat_session_id AS \"compat_session_id\"\n FROM compat_sso_logins cl\n WHERE cl.login_token = $1\n " + }, "c88376abdba124ff0487a9a69d2345c7d69d7394f355111ec369cfa6d45fb40f": { "describe": { "columns": [], @@ -1894,6 +1840,56 @@ }, "query": "\n INSERT INTO oauth2_authorization_grants (\n oauth2_authorization_grant_id,\n oauth2_client_id,\n redirect_uri,\n scope,\n state,\n nonce,\n max_age,\n response_mode,\n code_challenge,\n code_challenge_method,\n response_type_code,\n response_type_id_token,\n authorization_code,\n requires_consent,\n created_at\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)\n " }, + "ca63558e877bd115aa7ca24de0cc3f78a13cb55105758fe0bd930da513f75504": { + "describe": { + "columns": [ + { + "name": "compat_refresh_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + }, + { + "name": "compat_access_token_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT compat_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , compat_session_id\n , compat_access_token_id\n\n FROM compat_refresh_tokens\n\n WHERE refresh_token = $1\n " + }, "caf54e4659306a746747aa61906bdb2cb8da51176e90435aa8b9754ebf3e4d60": { "describe": { "columns": [], @@ -1909,6 +1905,50 @@ }, "query": "\n INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)\n VALUES ($1, $2, $3, $4)\n " }, + "cf43b82bdf534400f900cff3c5356083db0f9e5407e288b64f43d7ac100de058": { + "describe": { + "columns": [ + { + "name": "compat_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT compat_access_token_id\n , access_token\n , created_at\n , expires_at\n , compat_session_id\n\n FROM compat_access_tokens\n\n WHERE access_token = $1\n " + }, "d12a513b81b3ef658eae1f0a719933323f28c6ee260b52cafe337dd3d19e865c": { "describe": { "columns": [ @@ -1944,75 +1984,6 @@ }, "query": "\n INSERT INTO compat_sso_logins\n (compat_sso_login_id, login_token, redirect_uri, created_at)\n VALUES ($1, $2, $3, $4)\n " }, - "d4c6c070a0cd889cef9e0cfd65c64522a03f0bae12ee7c6b74343ec8f38d24c1": { - "describe": { - "columns": [ - { - "name": "compat_access_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "compat_access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_access_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "compat_access_token_expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id", - "ordinal": 4, - "type_info": "Uuid" - }, - { - "name": "compat_session_created_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_finished_at", - "ordinal": 6, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "user_id!", - "ordinal": 8, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - true, - false, - false, - true, - false, - false - ], - "parameters": { - "Left": [ - "Text", - "Timestamptz" - ] - } - }, - "query": "\n SELECT ct.compat_access_token_id\n , ct.access_token AS \"compat_access_token\"\n , ct.created_at AS \"compat_access_token_created_at\"\n , ct.expires_at AS \"compat_access_token_expires_at\"\n , cs.compat_session_id\n , cs.created_at AS \"compat_session_created_at\"\n , cs.finished_at AS \"compat_session_finished_at\"\n , cs.device_id AS \"compat_session_device_id\"\n , cs.user_id AS \"user_id!\"\n\n FROM compat_access_tokens ct\n INNER JOIN compat_sessions cs\n USING (compat_session_id)\n\n WHERE ct.access_token = $1\n AND (ct.expires_at < $2 OR ct.expires_at IS NULL)\n AND cs.finished_at IS NULL \n " - }, "d8677b3b6ee594c230fad98c1aa1c6e3d983375bf5b701c7b52468e7f906abf9": { "describe": { "columns": [], diff --git a/crates/storage/src/compat.rs b/crates/storage/src/compat.rs index c328f9e13..3befa8dda 100644 --- a/crates/storage/src/compat.rs +++ b/crates/storage/src/compat.rs @@ -14,8 +14,8 @@ use chrono::{DateTime, Duration, Utc}; use mas_data_model::{ - CompatAccessToken, CompatRefreshToken, CompatSession, CompatSessionState, CompatSsoLogin, - CompatSsoLoginState, Device, User, + CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession, + CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device, User, }; use rand::Rng; use sqlx::{Acquire, PgExecutor, Postgres, QueryBuilder}; @@ -29,71 +29,47 @@ use crate::{ Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, }; -struct CompatAccessTokenLookup { - compat_access_token_id: Uuid, - compat_access_token: String, - compat_access_token_created_at: DateTime, - compat_access_token_expires_at: Option>, +struct CompatSessionLookup { compat_session_id: Uuid, - compat_session_created_at: DateTime, - compat_session_finished_at: Option>, - compat_session_device_id: String, + device_id: String, user_id: Uuid, + created_at: DateTime, + finished_at: Option>, } #[tracing::instrument(skip_all, err)] -pub async fn lookup_active_compat_access_token( +pub async fn lookup_compat_session( executor: impl PgExecutor<'_>, - clock: &Clock, - token: &str, -) -> Result, DatabaseError> { + session_id: Ulid, +) -> Result, DatabaseError> { let res = sqlx::query_as!( - CompatAccessTokenLookup, + CompatSessionLookup, r#" - SELECT ct.compat_access_token_id - , ct.access_token AS "compat_access_token" - , ct.created_at AS "compat_access_token_created_at" - , ct.expires_at AS "compat_access_token_expires_at" - , cs.compat_session_id - , cs.created_at AS "compat_session_created_at" - , cs.finished_at AS "compat_session_finished_at" - , cs.device_id AS "compat_session_device_id" - , cs.user_id AS "user_id!" - - FROM compat_access_tokens ct - INNER JOIN compat_sessions cs - USING (compat_session_id) - - WHERE ct.access_token = $1 - AND (ct.expires_at < $2 OR ct.expires_at IS NULL) - AND cs.finished_at IS NULL + SELECT compat_session_id + , device_id + , user_id + , created_at + , finished_at + FROM compat_sessions + WHERE compat_session_id = $1 "#, - token, - clock.now(), + Uuid::from(session_id), ) .fetch_one(executor) - .instrument(info_span!("Fetch compat access token")) .await .to_option()?; let Some(res) = res else { return Ok(None) }; - let token = CompatAccessToken { - id: res.compat_access_token_id.into(), - token: res.compat_access_token, - created_at: res.compat_access_token_created_at, - expires_at: res.compat_access_token_expires_at, - }; - let id = res.compat_session_id.into(); - let device = Device::try_from(res.compat_session_device_id).map_err(|e| { + let device = Device::try_from(res.device_id).map_err(|e| { DatabaseInconsistencyError::on("compat_sessions") .column("device_id") .row(id) .source(e) })?; - let state = match res.compat_session_finished_at { + let state = match res.finished_at { None => CompatSessionState::Valid, Some(finished_at) => CompatSessionState::Finished { finished_at }, }; @@ -103,103 +79,148 @@ pub async fn lookup_active_compat_access_token( state, user_id: res.user_id.into(), device, - created_at: res.compat_session_created_at, + created_at: res.created_at, }; - Ok(Some((token, session))) + Ok(Some(session)) } -pub struct CompatRefreshTokenLookup { - compat_refresh_token_id: Uuid, - compat_refresh_token: String, - compat_refresh_token_created_at: DateTime, +struct CompatAccessTokenLookup { compat_access_token_id: Uuid, - compat_access_token: String, - compat_access_token_created_at: DateTime, - compat_access_token_expires_at: Option>, + access_token: String, + created_at: DateTime, + expires_at: Option>, compat_session_id: Uuid, - compat_session_created_at: DateTime, - compat_session_finished_at: Option>, - compat_session_device_id: String, - user_id: Uuid, +} + +impl From for CompatAccessToken { + fn from(value: CompatAccessTokenLookup) -> Self { + Self { + id: value.compat_access_token_id.into(), + session_id: value.compat_session_id.into(), + token: value.access_token, + created_at: value.created_at, + expires_at: value.expires_at, + } + } } #[tracing::instrument(skip_all, err)] -#[allow(clippy::type_complexity)] -pub async fn lookup_active_compat_refresh_token( +pub async fn find_compat_access_token( executor: impl PgExecutor<'_>, token: &str, -) -> Result, DatabaseError> { +) -> Result, DatabaseError> { let res = sqlx::query_as!( - CompatRefreshTokenLookup, + CompatAccessTokenLookup, r#" - SELECT cr.compat_refresh_token_id - , cr.refresh_token AS "compat_refresh_token" - , cr.created_at AS "compat_refresh_token_created_at" - , ct.compat_access_token_id - , ct.access_token AS "compat_access_token" - , ct.created_at AS "compat_access_token_created_at" - , ct.expires_at AS "compat_access_token_expires_at" - , cs.compat_session_id - , cs.created_at AS "compat_session_created_at" - , cs.finished_at AS "compat_session_finished_at" - , cs.device_id AS "compat_session_device_id" - , cs.user_id + SELECT compat_access_token_id + , access_token + , created_at + , expires_at + , compat_session_id - FROM compat_refresh_tokens cr - INNER JOIN compat_sessions cs - USING (compat_session_id) - INNER JOIN compat_access_tokens ct - USING (compat_access_token_id) + FROM compat_access_tokens - WHERE cr.refresh_token = $1 - AND cr.consumed_at IS NULL - AND cs.finished_at IS NULL + WHERE access_token = $1 + "#, + token, + ) + .fetch_one(executor) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) +} + +#[tracing::instrument( + skip_all, + fields( + compat_access_token.id = %id, + ), + err, +)] +pub async fn lookup_compat_access_token( + executor: impl PgExecutor<'_>, + id: Ulid, +) -> Result, DatabaseError> { + let res = sqlx::query_as!( + CompatAccessTokenLookup, + r#" + SELECT compat_access_token_id + , access_token + , created_at + , expires_at + , compat_session_id + + FROM compat_access_tokens + + WHERE compat_access_token_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(executor) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) +} + +pub struct CompatRefreshTokenLookup { + compat_refresh_token_id: Uuid, + refresh_token: String, + created_at: DateTime, + consumed_at: Option>, + compat_access_token_id: Uuid, + compat_session_id: Uuid, +} + +#[tracing::instrument(skip_all, err)] +#[allow(clippy::type_complexity)] +pub async fn find_compat_refresh_token( + executor: impl PgExecutor<'_>, + token: &str, +) -> Result, DatabaseError> { + let res = sqlx::query_as!( + CompatRefreshTokenLookup, + r#" + SELECT compat_refresh_token_id + , refresh_token + , created_at + , consumed_at + , compat_session_id + , compat_access_token_id + + FROM compat_refresh_tokens + + WHERE refresh_token = $1 "#, token, ) .fetch_one(executor) - .instrument(info_span!("Fetch compat refresh token")) .await .to_option()?; let Some(res) = res else { return Ok(None); }; + let state = match res.consumed_at { + None => CompatRefreshTokenState::Valid, + Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at }, + }; + let refresh_token = CompatRefreshToken { id: res.compat_refresh_token_id.into(), - token: res.compat_refresh_token, - created_at: res.compat_refresh_token_created_at, - }; - - let access_token = CompatAccessToken { - id: res.compat_access_token_id.into(), - token: res.compat_access_token, - created_at: res.compat_access_token_created_at, - expires_at: res.compat_access_token_expires_at, - }; - - let id = res.compat_session_id.into(); - let device = Device::try_from(res.compat_session_device_id).map_err(|e| { - DatabaseInconsistencyError::on("compat_sessions") - .column("device_id") - .row(id) - .source(e) - })?; - - let state = match res.compat_session_finished_at { - None => CompatSessionState::Valid, - Some(finished_at) => CompatSessionState::Finished { finished_at }, - }; - - let session = CompatSession { - id, state, - user_id: res.user_id.into(), - device, - created_at: res.compat_session_created_at, + session_id: res.compat_session_id.into(), + access_token_id: res.compat_access_token_id.into(), + token: res.refresh_token, + created_at: res.created_at, }; - Ok(Some((refresh_token, access_token, session))) + Ok(Some(refresh_token)) } #[tracing::instrument( @@ -244,6 +265,7 @@ pub async fn add_compat_access_token( Ok(CompatAccessToken { id, + session_id: session.id, token, created_at, expires_at, @@ -320,6 +342,9 @@ pub async fn add_compat_refresh_token( Ok(CompatRefreshToken { id, + state: CompatRefreshTokenState::default(), + session_id: session.id, + access_token_id: access_token.id, token, created_at, }) @@ -327,42 +352,35 @@ pub async fn add_compat_refresh_token( #[tracing::instrument( skip_all, - fields(compat_session.id), + fields(%compat_session.id), err, )] -pub async fn compat_logout( +pub async fn end_compat_session( executor: impl PgExecutor<'_>, clock: &Clock, - token: &str, -) -> Result { + compat_session: CompatSession, +) -> Result { let finished_at = clock.now(); - // TODO: this does not check for token expiration - let res = sqlx::query_scalar!( + + let res = sqlx::query!( r#" UPDATE compat_sessions cs SET finished_at = $2 - FROM compat_access_tokens ca - WHERE ca.access_token = $1 - AND ca.compat_session_id = cs.compat_session_id - AND cs.finished_at IS NULL - RETURNING cs.compat_session_id + WHERE compat_session_id = $1 "#, - token, + Uuid::from(compat_session.id), finished_at, ) - .fetch_one(executor) - .await - .to_option()?; + .execute(executor) + .await?; - if let Some(compat_session_id) = res { - tracing::Span::current().record( - "compat_session.id", - tracing::field::display(compat_session_id), - ); - Ok(true) - } else { - Ok(false) - } + DatabaseError::ensure_affected_rows(&res, 1)?; + + let compat_session = compat_session + .finish(finished_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(compat_session) } #[tracing::instrument( @@ -445,10 +463,6 @@ struct CompatSsoLoginLookup { compat_sso_login_fulfilled_at: Option>, compat_sso_login_exchanged_at: Option>, compat_session_id: Option, - compat_session_created_at: Option>, - compat_session_finished_at: Option>, - compat_session_device_id: Option, - user_id: Option, } impl TryFrom for CompatSsoLogin { @@ -463,58 +477,21 @@ impl TryFrom for CompatSsoLogin { .source(e) })?; - let session = match ( - res.compat_session_id, - res.compat_session_device_id, - res.compat_session_created_at, - res.compat_session_finished_at, - res.user_id, - ) { - (Some(id), Some(device_id), Some(created_at), finished_at, Some(user_id)) => { - let id = id.into(); - let device = Device::try_from(device_id).map_err(|e| { - DatabaseInconsistencyError::on("compat_sessions") - .column("device") - .row(id) - .source(e) - })?; - - let state = match finished_at { - None => CompatSessionState::Valid, - Some(finished_at) => CompatSessionState::Finished { finished_at }, - }; - - Some(CompatSession { - id, - state, - user_id: user_id.into(), - device, - created_at, - }) - } - (None, None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("compat_sso_logins") - .column("compat_session_id") - .row(id)) - } - }; - let state = match ( res.compat_sso_login_fulfilled_at, res.compat_sso_login_exchanged_at, - session, + res.compat_session_id, ) { (None, None, None) => CompatSsoLoginState::Pending, - (Some(fulfilled_at), None, Some(session)) => CompatSsoLoginState::Fulfilled { + (Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled { fulfilled_at, - session, + session_id: session_id.into(), }, - (Some(fulfilled_at), Some(exchanged_at), Some(session)) => { + (Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => { CompatSsoLoginState::Exchanged { fulfilled_at, exchanged_at, - session, + session_id: session_id.into(), } } _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)), @@ -550,15 +527,9 @@ pub async fn get_compat_sso_login_by_id( , cl.created_at AS "compat_sso_login_created_at" , cl.fulfilled_at AS "compat_sso_login_fulfilled_at" , cl.exchanged_at AS "compat_sso_login_exchanged_at" - , cs.compat_session_id AS "compat_session_id?" - , cs.created_at AS "compat_session_created_at?" - , cs.finished_at AS "compat_session_finished_at?" - , cs.device_id AS "compat_session_device_id?" - , cs.user_id AS "user_id?" + , cl.compat_session_id AS "compat_session_id" FROM compat_sso_logins cl - LEFT JOIN compat_sessions cs - USING (compat_session_id) WHERE cl.compat_sso_login_id = $1 "#, Uuid::from(id), @@ -589,8 +560,6 @@ pub async fn get_paginated_user_compat_sso_logins( first: Option, last: Option, ) -> Result<(bool, bool, Vec), DatabaseError> { - // TODO: this queries too much (like user info) which we probably don't need - // because we already have them let mut query = QueryBuilder::new( r#" SELECT cl.compat_sso_login_id @@ -599,14 +568,8 @@ pub async fn get_paginated_user_compat_sso_logins( , cl.created_at AS "compat_sso_login_created_at" , cl.fulfilled_at AS "compat_sso_login_fulfilled_at" , cl.exchanged_at AS "compat_sso_login_exchanged_at" - , cs.compat_session_id AS "compat_session_id" - , cs.created_at AS "compat_session_created_at" - , cs.finished_at AS "compat_session_finished_at" - , cs.device_id AS "compat_session_device_id" - , cs.user_id + , cl.compat_session_id AS "compat_session_id" FROM compat_sso_logins cl - LEFT JOIN compat_sessions cs - USING (compat_session_id) "#, ); @@ -645,14 +608,8 @@ pub async fn get_compat_sso_login_by_token( , cl.created_at AS "compat_sso_login_created_at" , cl.fulfilled_at AS "compat_sso_login_fulfilled_at" , cl.exchanged_at AS "compat_sso_login_exchanged_at" - , cs.compat_session_id AS "compat_session_id?" - , cs.created_at AS "compat_session_created_at?" - , cs.finished_at AS "compat_session_finished_at?" - , cs.device_id AS "compat_session_device_id?" - , cs.user_id AS "user_id?" + , cl.compat_session_id AS "compat_session_id" FROM compat_sso_logins cl - LEFT JOIN compat_sessions cs - USING (compat_session_id) WHERE cl.login_token = $1 "#, token, @@ -739,7 +696,7 @@ pub async fn fullfill_compat_sso_login( let fulfilled_at = clock.now(); let compat_sso_login = compat_sso_login - .fulfill(fulfilled_at, session) + .fulfill(fulfilled_at, &session) .map_err(DatabaseError::to_invalid_operation)?; sqlx::query!( r#" From 3a1fc8982cac429cbcf8f9c7346a34120ee01db0 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 11 Jan 2023 12:14:52 +0100 Subject: [PATCH 17/45] storage: cleanup access/refresh token lookups --- crates/axum-utils/src/user_authorization.rs | 30 +- crates/data-model/src/lib.rs | 4 +- crates/data-model/src/tokens.rs | 112 +++++++- crates/handlers/src/oauth2/introspection.rs | 29 +- crates/handlers/src/oauth2/token.rs | 22 +- crates/handlers/src/oauth2/userinfo.rs | 4 +- crates/storage/sqlx-data.json | 286 ++++++++++---------- crates/storage/src/oauth2/access_token.rs | 145 ++++++---- crates/storage/src/oauth2/refresh_token.rs | 83 +++--- 9 files changed, 452 insertions(+), 263 deletions(-) diff --git a/crates/axum-utils/src/user_authorization.rs b/crates/axum-utils/src/user_authorization.rs index 923ef34d7..a76e1e9a1 100644 --- a/crates/axum-utils/src/user_authorization.rs +++ b/crates/axum-utils/src/user_authorization.rs @@ -24,10 +24,14 @@ use axum::{ response::{IntoResponse, Response}, BoxError, }; +use chrono::{DateTime, Utc}; use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName}; use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode}; use mas_data_model::Session; -use mas_storage::{oauth2::access_token::lookup_active_access_token, DatabaseError}; +use mas_storage::{ + oauth2::{access_token::find_access_token, OAuth2SessionRepository}, + DatabaseError, Repository, +}; use serde::{de::DeserializeOwned, Deserialize}; use sqlx::PgConnection; use thiserror::Error; @@ -49,7 +53,7 @@ enum AccessToken { } impl AccessToken { - pub async fn fetch( + async fn fetch( &self, conn: &mut PgConnection, ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> { @@ -58,7 +62,13 @@ impl AccessToken { AccessToken::None => return Err(AuthorizationVerificationError::MissingToken), }; - let (token, session) = lookup_active_access_token(conn, token.as_str()) + let token = find_access_token(conn, token.as_str()) + .await? + .ok_or(AuthorizationVerificationError::InvalidToken)?; + + let session = conn + .oauth2_session() + .lookup(token.session_id) .await? .ok_or(AuthorizationVerificationError::InvalidToken)?; @@ -77,13 +87,18 @@ impl UserAuthorization { pub async fn protected_form( self, conn: &mut PgConnection, + now: DateTime, ) -> Result<(Session, F), AuthorizationVerificationError> { let form = match self.form { Some(f) => f, None => return Err(AuthorizationVerificationError::MissingForm), }; - let (_token, session) = self.access_token.fetch(conn).await?; + let (token, session) = self.access_token.fetch(conn).await?; + + if !token.is_valid(now) || !session.is_valid() { + return Err(AuthorizationVerificationError::InvalidToken); + } Ok((session, form)) } @@ -92,8 +107,13 @@ impl UserAuthorization { pub async fn protected( self, conn: &mut PgConnection, + now: DateTime, ) -> Result { - let (_token, session) = self.access_token.fetch(conn).await?; + let (token, session) = self.access_token.fetch(conn).await?; + + if !token.is_valid(now) || !session.is_valid() { + return Err(AuthorizationVerificationError::InvalidToken); + } Ok(session) } diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 8454f05d3..bde11fbed 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -44,7 +44,9 @@ pub use self::{ AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState, }, - tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType}, + tokens::{ + AccessToken, AccessTokenState, RefreshToken, RefreshTokenState, TokenFormatError, TokenType, + }, upstream_oauth2::{ UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, UpstreamOAuthProvider, diff --git a/crates/data-model/src/tokens.rs b/crates/data-model/src/tokens.rs index 7b058820d..120f293e7 100644 --- a/crates/data-model/src/tokens.rs +++ b/crates/data-model/src/tokens.rs @@ -19,23 +19,133 @@ use rand::{distributions::Alphanumeric, Rng}; use thiserror::Error; use ulid::Ulid; +use crate::InvalidTransitionError; + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub enum AccessTokenState { + #[default] + Valid, + Revoked { + revoked_at: DateTime, + }, +} + +impl AccessTokenState { + fn revoke(self, revoked_at: DateTime) -> Result { + match self { + Self::Valid => Ok(Self::Revoked { revoked_at }), + Self::Revoked { .. } => Err(InvalidTransitionError), + } + } + + /// Returns `true` if the refresh token state is [`Valid`]. + /// + /// [`Valid`]: RefreshTokenState::Valid + #[must_use] + pub fn is_valid(&self) -> bool { + matches!(self, Self::Valid) + } + + /// Returns `true` if the refresh token state is [`Revoked`]. + /// + /// [`Revoked`]: RefreshTokenState::Revoked + #[must_use] + pub fn is_revoked(&self) -> bool { + matches!(self, Self::Revoked { .. }) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct AccessToken { pub id: Ulid, - pub jti: String, + pub state: AccessTokenState, + pub session_id: Ulid, pub access_token: String, pub created_at: DateTime, pub expires_at: DateTime, } +impl AccessToken { + #[must_use] + pub fn jti(&self) -> String { + self.id.to_string() + } + + #[must_use] + pub fn is_valid(&self, now: DateTime) -> bool { + self.state.is_valid() && self.expires_at > now + } + + pub fn revoke(mut self, revoked_at: DateTime) -> Result { + self.state = self.state.revoke(revoked_at)?; + Ok(self) + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub enum RefreshTokenState { + #[default] + Valid, + Consumed { + consumed_at: DateTime, + }, +} + +impl RefreshTokenState { + fn consume(self, consumed_at: DateTime) -> Result { + match self { + Self::Valid => Ok(Self::Consumed { consumed_at }), + Self::Consumed { .. } => Err(InvalidTransitionError), + } + } + + /// Returns `true` if the refresh token state is [`Valid`]. + /// + /// [`Valid`]: RefreshTokenState::Valid + #[must_use] + pub fn is_valid(&self) -> bool { + matches!(self, Self::Valid) + } + + /// Returns `true` if the refresh token state is [`Consumed`]. + /// + /// [`Consumed`]: RefreshTokenState::Consumed + #[must_use] + pub fn is_consumed(&self) -> bool { + matches!(self, Self::Consumed { .. }) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct RefreshToken { pub id: Ulid, + pub state: RefreshTokenState, pub refresh_token: String, + pub session_id: Ulid, pub created_at: DateTime, pub access_token_id: Option, } +impl std::ops::Deref for RefreshToken { + type Target = RefreshTokenState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + +impl RefreshToken { + #[must_use] + pub fn jti(&self) -> String { + self.id.to_string() + } + + pub fn consume(mut self, consumed_at: DateTime) -> Result { + self.state = self.state.consume(consumed_at)?; + Ok(self) + } +} + /// Type of token to generate or validate #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TokenType { diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 2cf34c979..3dec02db1 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -24,7 +24,8 @@ use mas_keystore::Encrypter; use mas_storage::{ compat::{find_compat_access_token, find_compat_refresh_token, lookup_compat_session}, oauth2::{ - access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token, + access_token::find_access_token, refresh_token::lookup_refresh_token, + OAuth2SessionRepository, }, user::{BrowserSessionRepository, UserRepository}, Clock, Repository, @@ -168,8 +169,17 @@ pub(crate) async fn post( let reply = match token_type { TokenType::AccessToken => { - let (token, session) = lookup_active_access_token(&mut conn, token) + let token = find_access_token(&mut conn, token) .await? + .filter(|t| t.is_valid(clock.now())) + .ok_or(RouteError::UnknownToken)?; + + let session = conn + .oauth2_session() + .lookup(token.session_id) + .await? + .filter(|s| s.is_valid()) + // XXX: is that the right error to bubble up? .ok_or(RouteError::UnknownToken)?; let browser_session = conn @@ -191,13 +201,22 @@ pub(crate) async fn post( sub: Some(browser_session.user.sub), aud: None, iss: None, - jti: None, + jti: Some(token.jti()), } } TokenType::RefreshToken => { - let (token, session) = lookup_active_refresh_token(&mut conn, token) + let token = lookup_refresh_token(&mut conn, token) .await? + .filter(|t| t.is_valid()) + .ok_or(RouteError::UnknownToken)?; + + let session = conn + .oauth2_session() + .lookup(token.session_id) + .await? + .filter(|s| s.is_valid()) + // XXX: is that the right error to bubble up? .ok_or(RouteError::UnknownToken)?; let browser_session = conn @@ -219,7 +238,7 @@ pub(crate) async fn post( sub: Some(browser_session.user.sub), aud: None, iss: None, - jti: None, + jti: Some(token.jti()), } } diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index eb0e20dde..75ddb4a64 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -33,9 +33,9 @@ use mas_keystore::{Encrypter, Keystore}; use mas_router::UrlBuilder; use mas_storage::{ oauth2::{ - access_token::{add_access_token, revoke_access_token}, + access_token::{add_access_token, lookup_access_token, revoke_access_token}, authorization_grant::{exchange_grant, lookup_grant_by_code}, - refresh_token::{add_refresh_token, consume_refresh_token, lookup_active_refresh_token}, + refresh_token::{add_refresh_token, consume_refresh_token, lookup_refresh_token}, OAuth2SessionRepository, }, user::BrowserSessionRepository, @@ -374,10 +374,20 @@ async fn refresh_token_grant( ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let (refresh_token, session) = lookup_active_refresh_token(&mut txn, &grant.refresh_token) + let refresh_token = lookup_refresh_token(&mut txn, &grant.refresh_token) .await? .ok_or(RouteError::InvalidGrant)?; + let session = txn + .oauth2_session() + .lookup(refresh_token.session_id) + .await? + .ok_or(RouteError::NoSuchOAuthSession)?; + + if !refresh_token.is_valid() || !session.is_valid() { + return Err(RouteError::InvalidGrant); + } + if client.id != session.client_id { // As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 return Err(RouteError::InvalidGrant); @@ -407,10 +417,12 @@ async fn refresh_token_grant( ) .await?; - consume_refresh_token(&mut txn, &clock, &refresh_token).await?; + let refresh_token = consume_refresh_token(&mut txn, &clock, refresh_token).await?; if let Some(access_token_id) = refresh_token.access_token_id { - revoke_access_token(&mut txn, &clock, access_token_id).await?; + if let Some(access_token) = lookup_access_token(&mut txn, access_token_id).await? { + revoke_access_token(&mut txn, &clock, access_token).await?; + } } let params = AccessTokenResponse::new(access_token_str) diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index d2b2b6150..49b6c5f1c 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -101,10 +101,10 @@ pub async fn get( State(key_store): State, user_authorization: UserAuthorization, ) -> Result { - let (_clock, mut rng) = crate::clock_and_rng(); + let (clock, mut rng) = crate::clock_and_rng(); let mut conn = pool.acquire().await?; - let session = user_authorization.protected(&mut conn).await?; + let session = user_authorization.protected(&mut conn, clock.now()).await?; let browser_session = conn .browser_session() diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index d31c59a35..5324fa2a5 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -583,74 +583,6 @@ }, "query": "\n DELETE FROM oauth2_access_tokens\n WHERE expires_at < $1\n " }, - "5f0e2aec0d7766d3674af3e68417921fec7068e83845e218a4a00d86487557f9": { - "describe": { - "columns": [ - { - "name": "oauth2_access_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "oauth2_access_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_access_token_expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_session_created_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_session_id!", - "ordinal": 5, - "type_info": "Uuid" - }, - { - "name": "oauth2_client_id!", - "ordinal": 6, - "type_info": "Uuid" - }, - { - "name": "scope!", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "user_session_id!", - "ordinal": 8, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - false, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT at.oauth2_access_token_id\n , at.access_token AS \"oauth2_access_token\"\n , at.created_at AS \"oauth2_access_token_created_at\"\n , at.expires_at AS \"oauth2_access_token_expires_at\"\n , os.created_at AS \"oauth2_session_created_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"scope!\"\n , os.user_session_id AS \"user_session_id!\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE at.access_token = $1\n AND at.revoked_at IS NULL\n AND os.finished_at IS NULL\n " - }, "5f6b7e38ef9bc3b39deabba277d0255fb8cfb2adaa65f47b78a8fac11d8c91c3": { "describe": { "columns": [], @@ -1612,6 +1544,56 @@ }, "query": "\n SELECT oauth2_authorization_grant_id\n , created_at AS oauth2_authorization_grant_created_at\n , cancelled_at AS oauth2_authorization_grant_cancelled_at\n , fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , exchanged_at AS oauth2_authorization_grant_exchanged_at\n , scope AS oauth2_authorization_grant_scope\n , state AS oauth2_authorization_grant_state\n , redirect_uri AS oauth2_authorization_grant_redirect_uri\n , response_mode AS oauth2_authorization_grant_response_mode\n , nonce AS oauth2_authorization_grant_nonce\n , max_age AS oauth2_authorization_grant_max_age\n , oauth2_client_id AS oauth2_client_id\n , authorization_code AS oauth2_authorization_grant_code\n , response_type_code AS oauth2_authorization_grant_response_type_code\n , response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , code_challenge AS oauth2_authorization_grant_code_challenge\n , code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , requires_consent AS oauth2_authorization_grant_requires_consent\n , oauth2_session_id AS \"oauth2_session_id?\"\n FROM\n oauth2_authorization_grants\n\n WHERE oauth2_authorization_grant_id = $1\n " }, + "b20e846843cf88810fbc0f4b0fa3159117f035841758d682d90c614c374f6059": { + "describe": { + "columns": [ + { + "name": "oauth2_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "revoked_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_session_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT oauth2_access_token_id\n , access_token\n , created_at\n , expires_at\n , revoked_at\n , oauth2_session_id\n\n FROM oauth2_access_tokens\n\n WHERE oauth2_access_token_id = $1\n " + }, "b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537": { "describe": { "columns": [], @@ -1984,6 +1966,106 @@ }, "query": "\n INSERT INTO compat_sso_logins\n (compat_sso_login_id, login_token, redirect_uri, created_at)\n VALUES ($1, $2, $3, $4)\n " }, + "d1f1aac41bb2f0d194f9b3d846663c267865d0d22dd5fa8a668daf29dca88d36": { + "describe": { + "columns": [ + { + "name": "oauth2_refresh_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_access_token_id", + "ordinal": 4, + "type_info": "Uuid" + }, + { + "name": "oauth2_session_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + true, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n FROM oauth2_refresh_tokens\n\n WHERE refresh_token = $1\n " + }, + "d2b1af24f88b2f05eb219f7cbdcfa9680bafe9f77fa1772097875b3718bd1aff": { + "describe": { + "columns": [ + { + "name": "oauth2_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "revoked_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_session_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT oauth2_access_token_id\n , access_token\n , created_at\n , expires_at\n , revoked_at\n , oauth2_session_id\n\n FROM oauth2_access_tokens\n\n WHERE access_token = $1\n " + }, "d8677b3b6ee594c230fad98c1aa1c6e3d983375bf5b701c7b52468e7f906abf9": { "describe": { "columns": [], @@ -2129,74 +2211,6 @@ }, "query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n " }, - "e25b8071b59075c4be9fac283410ec4acf771fdf06076ef7bbb11bf086c4bc03": { - "describe": { - "columns": [ - { - "name": "oauth2_refresh_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_refresh_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "oauth2_refresh_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_access_token_id?", - "ordinal": 3, - "type_info": "Uuid" - }, - { - "name": "oauth2_session_created_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_session_id!", - "ordinal": 5, - "type_info": "Uuid" - }, - { - "name": "oauth2_client_id!", - "ordinal": 6, - "type_info": "Uuid" - }, - { - "name": "oauth2_session_scope!", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "user_session_id!", - "ordinal": 8, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - true, - false, - false, - false, - false, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT rt.oauth2_refresh_token_id\n , rt.refresh_token AS oauth2_refresh_token\n , rt.created_at AS oauth2_refresh_token_created_at\n , rt.oauth2_access_token_id AS \"oauth2_access_token_id?\"\n , os.created_at AS \"oauth2_session_created_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"oauth2_session_scope!\"\n , os.user_session_id AS \"user_session_id!\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE rt.refresh_token = $1\n AND rt.consumed_at IS NULL\n AND rt.revoked_at IS NULL\n AND os.finished_at IS NULL\n " - }, "e6dc63984aced9e19c20e90e9cd75d6f6d7ade64f782697715ac4da077b2e1fc": { "describe": { "columns": [ diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index 58c13b190..8389dff44 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -13,13 +13,13 @@ // limitations under the License. use chrono::{DateTime, Duration, Utc}; -use mas_data_model::{AccessToken, Session, SessionState}; +use mas_data_model::{AccessToken, AccessTokenState, Session}; use rand::Rng; use sqlx::{PgConnection, PgExecutor}; use ulid::Ulid; use uuid::Uuid; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError}; +use crate::{Clock, DatabaseError, LookupResultExt}; #[tracing::instrument( skip_all, @@ -63,8 +63,9 @@ pub async fn add_access_token( Ok(AccessToken { id, + state: AccessTokenState::default(), access_token, - jti: id.to_string(), + session_id: session.id, created_at, expires_at, }) @@ -73,74 +74,59 @@ pub async fn add_access_token( #[derive(Debug)] pub struct OAuth2AccessTokenLookup { oauth2_access_token_id: Uuid, - oauth2_access_token: String, - oauth2_access_token_created_at: DateTime, - oauth2_access_token_expires_at: DateTime, - oauth2_session_created_at: DateTime, oauth2_session_id: Uuid, - oauth2_client_id: Uuid, - scope: String, - user_session_id: Uuid, + access_token: String, + created_at: DateTime, + expires_at: DateTime, + revoked_at: Option>, } -#[allow(clippy::too_many_lines)] -pub async fn lookup_active_access_token( +impl From for AccessToken { + fn from(value: OAuth2AccessTokenLookup) -> Self { + let state = match value.revoked_at { + None => AccessTokenState::Valid, + Some(revoked_at) => AccessTokenState::Revoked { revoked_at }, + }; + + Self { + id: value.oauth2_access_token_id.into(), + state, + session_id: value.oauth2_session_id.into(), + access_token: value.access_token, + created_at: value.created_at, + expires_at: value.expires_at, + } + } +} + +#[tracing::instrument(skip_all, err)] +pub async fn find_access_token( conn: &mut PgConnection, token: &str, -) -> Result, DatabaseError> { +) -> Result, DatabaseError> { let res = sqlx::query_as!( OAuth2AccessTokenLookup, r#" - SELECT at.oauth2_access_token_id - , at.access_token AS "oauth2_access_token" - , at.created_at AS "oauth2_access_token_created_at" - , at.expires_at AS "oauth2_access_token_expires_at" - , os.created_at AS "oauth2_session_created_at" - , os.oauth2_session_id AS "oauth2_session_id!" - , os.oauth2_client_id AS "oauth2_client_id!" - , os.scope AS "scope!" - , os.user_session_id AS "user_session_id!" + SELECT oauth2_access_token_id + , access_token + , created_at + , expires_at + , revoked_at + , oauth2_session_id - FROM oauth2_access_tokens at - INNER JOIN oauth2_sessions os - USING (oauth2_session_id) + FROM oauth2_access_tokens - WHERE at.access_token = $1 - AND at.revoked_at IS NULL - AND os.finished_at IS NULL + WHERE access_token = $1 "#, token, ) .fetch_one(&mut *conn) - .await?; + .await + .to_option()?; - let access_token_id = Ulid::from(res.oauth2_access_token_id); - let access_token = AccessToken { - id: access_token_id, - jti: access_token_id.to_string(), - access_token: res.oauth2_access_token, - created_at: res.oauth2_access_token_created_at, - expires_at: res.oauth2_access_token_expires_at, - }; + let Some(res) = res else { return Ok(None) }; - let session_id = res.oauth2_session_id.into(); - let scope = res.scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("scope") - .row(session_id) - .source(e) - })?; - - let session = Session { - id: session_id, - state: SessionState::Valid, - created_at: res.oauth2_session_created_at, - client_id: res.oauth2_client_id.into(), - user_session_id: res.user_session_id.into(), - scope, - }; - - Ok(Some((access_token, session))) + Ok(Some(res.into())) } #[tracing::instrument( @@ -148,11 +134,48 @@ pub async fn lookup_active_access_token( fields(access_token.id = %access_token_id), err, )] +pub async fn lookup_access_token( + conn: &mut PgConnection, + access_token_id: Ulid, +) -> Result, DatabaseError> { + let res = sqlx::query_as!( + OAuth2AccessTokenLookup, + r#" + SELECT oauth2_access_token_id + , access_token + , created_at + , expires_at + , revoked_at + , oauth2_session_id + + FROM oauth2_access_tokens + + WHERE oauth2_access_token_id = $1 + "#, + Uuid::from(access_token_id), + ) + .fetch_one(&mut *conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) +} + +#[tracing::instrument( + skip_all, + fields( + %access_token.id, + session.id = %access_token.session_id, + ), + err, +)] pub async fn revoke_access_token( executor: impl PgExecutor<'_>, clock: &Clock, - access_token_id: Ulid, -) -> Result<(), DatabaseError> { + access_token: AccessToken, +) -> Result { let revoked_at = clock.now(); let res = sqlx::query!( r#" @@ -160,13 +183,17 @@ pub async fn revoke_access_token( SET revoked_at = $2 WHERE oauth2_access_token_id = $1 "#, - Uuid::from(access_token_id), + Uuid::from(access_token.id), revoked_at, ) .execute(executor) .await?; - DatabaseError::ensure_affected_rows(&res, 1) + DatabaseError::ensure_affected_rows(&res, 1)?; + + access_token + .revoke(revoked_at) + .map_err(DatabaseError::to_invalid_operation) } pub async fn cleanup_expired( diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index f49b38e8a..29f6ab342 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -13,13 +13,13 @@ // limitations under the License. use chrono::{DateTime, Utc}; -use mas_data_model::{AccessToken, RefreshToken, Session, SessionState}; +use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session}; use rand::Rng; use sqlx::{PgConnection, PgExecutor}; use ulid::Ulid; use uuid::Uuid; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError}; +use crate::{Clock, DatabaseError}; #[tracing::instrument( skip_all, @@ -62,6 +62,8 @@ pub async fn add_refresh_token( Ok(RefreshToken { id, + state: RefreshTokenState::default(), + session_id: session.id, refresh_token, access_token_id: Some(access_token.id), created_at, @@ -70,73 +72,52 @@ pub async fn add_refresh_token( struct OAuth2RefreshTokenLookup { oauth2_refresh_token_id: Uuid, - oauth2_refresh_token: String, - oauth2_refresh_token_created_at: DateTime, + refresh_token: String, + created_at: DateTime, + consumed_at: Option>, oauth2_access_token_id: Option, - oauth2_session_created_at: DateTime, oauth2_session_id: Uuid, - oauth2_client_id: Uuid, - oauth2_session_scope: String, - user_session_id: Uuid, } #[tracing::instrument(skip_all, err)] #[allow(clippy::too_many_lines)] -pub async fn lookup_active_refresh_token( +pub async fn lookup_refresh_token( conn: &mut PgConnection, token: &str, -) -> Result, DatabaseError> { +) -> Result, DatabaseError> { let res = sqlx::query_as!( OAuth2RefreshTokenLookup, r#" - SELECT rt.oauth2_refresh_token_id - , rt.refresh_token AS oauth2_refresh_token - , rt.created_at AS oauth2_refresh_token_created_at - , rt.oauth2_access_token_id AS "oauth2_access_token_id?" - , os.created_at AS "oauth2_session_created_at" - , os.oauth2_session_id AS "oauth2_session_id!" - , os.oauth2_client_id AS "oauth2_client_id!" - , os.scope AS "oauth2_session_scope!" - , os.user_session_id AS "user_session_id!" - FROM oauth2_refresh_tokens rt - INNER JOIN oauth2_sessions os - USING (oauth2_session_id) + SELECT oauth2_refresh_token_id + , refresh_token + , created_at + , consumed_at + , oauth2_access_token_id + , oauth2_session_id + FROM oauth2_refresh_tokens - WHERE rt.refresh_token = $1 - AND rt.consumed_at IS NULL - AND rt.revoked_at IS NULL - AND os.finished_at IS NULL + WHERE refresh_token = $1 "#, token, ) .fetch_one(&mut *conn) .await?; + let state = match res.consumed_at { + None => RefreshTokenState::Valid, + Some(consumed_at) => RefreshTokenState::Consumed { consumed_at }, + }; + let refresh_token = RefreshToken { id: res.oauth2_refresh_token_id.into(), - refresh_token: res.oauth2_refresh_token, - created_at: res.oauth2_refresh_token_created_at, + state, + session_id: res.oauth2_session_id.into(), + refresh_token: res.refresh_token, + created_at: res.created_at, access_token_id: res.oauth2_access_token_id.map(Ulid::from), }; - let session_id = res.oauth2_session_id.into(); - let scope = res.oauth2_session_scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("scope") - .row(session_id) - .source(e) - })?; - - let session = Session { - id: session_id, - state: SessionState::Valid, - created_at: res.oauth2_session_created_at, - client_id: res.oauth2_client_id.into(), - user_session_id: res.user_session_id.into(), - scope, - }; - - Ok(Some((refresh_token, session))) + Ok(Some(refresh_token)) } #[tracing::instrument( @@ -149,8 +130,8 @@ pub async fn lookup_active_refresh_token( pub async fn consume_refresh_token( executor: impl PgExecutor<'_>, clock: &Clock, - refresh_token: &RefreshToken, -) -> Result<(), DatabaseError> { + refresh_token: RefreshToken, +) -> Result { let consumed_at = clock.now(); let res = sqlx::query!( r#" @@ -164,5 +145,9 @@ pub async fn consume_refresh_token( .execute(executor) .await?; - DatabaseError::ensure_affected_rows(&res, 1) + DatabaseError::ensure_affected_rows(&res, 1)?; + + refresh_token + .consume(consumed_at) + .map_err(DatabaseError::to_invalid_operation) } From af267657c516976529d3de2ac6f6e1fe8c1d5a5c Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 12 Jan 2023 15:41:26 +0100 Subject: [PATCH 18/45] storage: repository pattern for the compat layer --- crates/data-model/src/compat/sso_login.rs | 3 +- crates/graphql/src/model/compat_sessions.rs | 6 +- crates/graphql/src/model/users.rs | 12 +- crates/handlers/src/compat/login.rs | 45 +- .../handlers/src/compat/login_sso_complete.rs | 23 +- .../handlers/src/compat/login_sso_redirect.rs | 8 +- crates/handlers/src/compat/logout.rs | 14 +- crates/handlers/src/compat/refresh.rs | 65 +- crates/handlers/src/oauth2/introspection.rs | 24 +- crates/handlers/src/views/shared.rs | 6 +- crates/storage/sqlx-data.json | 754 +++++++++-------- crates/storage/src/compat.rs | 757 ------------------ crates/storage/src/compat/access_token.rs | 246 ++++++ crates/storage/src/compat/mod.rs | 25 + crates/storage/src/compat/refresh_token.rs | 260 ++++++ crates/storage/src/compat/session.rs | 220 +++++ crates/storage/src/compat/sso_login.rs | 397 +++++++++ crates/storage/src/repository.rs | 64 ++ 18 files changed, 1738 insertions(+), 1191 deletions(-) delete mode 100644 crates/storage/src/compat.rs create mode 100644 crates/storage/src/compat/access_token.rs create mode 100644 crates/storage/src/compat/mod.rs create mode 100644 crates/storage/src/compat/refresh_token.rs create mode 100644 crates/storage/src/compat/session.rs create mode 100644 crates/storage/src/compat/sso_login.rs diff --git a/crates/data-model/src/compat/sso_login.rs b/crates/data-model/src/compat/sso_login.rs index 54fd96b31..ccc7bb370 100644 --- a/crates/data-model/src/compat/sso_login.rs +++ b/crates/data-model/src/compat/sso_login.rs @@ -20,8 +20,9 @@ use url::Url; use super::CompatSession; use crate::InvalidTransitionError; -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)] pub enum CompatSsoLoginState { + #[default] Pending, Fulfilled { fulfilled_at: DateTime, diff --git a/crates/graphql/src/model/compat_sessions.rs b/crates/graphql/src/model/compat_sessions.rs index 394639cff..3c94c672a 100644 --- a/crates/graphql/src/model/compat_sessions.rs +++ b/crates/graphql/src/model/compat_sessions.rs @@ -15,7 +15,7 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; use chrono::{DateTime, Utc}; -use mas_storage::{compat::lookup_compat_session, user::UserRepository, Repository}; +use mas_storage::{compat::CompatSessionRepository, user::UserRepository, Repository}; use sqlx::PgPool; use url::Url; @@ -101,7 +101,9 @@ impl CompatSsoLogin { let Some(session_id) = self.0.session_id() else { return Ok(None) }; let mut conn = ctx.data::()?.acquire().await?; - let session = lookup_compat_session(&mut conn, session_id) + let session = conn + .compat_session() + .lookup(session_id) .await? .context("Could not load compat session")?; diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index 2f241ced6..b19a1ae12 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -18,6 +18,7 @@ use async_graphql::{ }; use chrono::{DateTime, Utc}; use mas_storage::{ + compat::CompatSsoLoginRepository, oauth2::OAuth2SessionRepository, user::{BrowserSessionRepository, UserEmailRepository}, Repository, UpstreamOAuthLinkRepository, @@ -96,14 +97,13 @@ impl User { .map(|x: OpaqueCursor| x.extract_for_type(NodeType::CompatSsoLogin)) .transpose()?; - let (has_previous_page, has_next_page, edges) = - mas_storage::compat::get_paginated_user_compat_sso_logins( - &mut conn, &self.0, before_id, after_id, first, last, - ) + let page = conn + .compat_sso_login() + .list_paginated(&self.0, before_id, after_id, first, last) .await?; - let mut connection = Connection::new(has_previous_page, has_next_page); - connection.edges.extend(edges.into_iter().map(|u| { + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + connection.edges.extend(page.edges.into_iter().map(|u| { Edge::new( OpaqueCursor(NodeCursor(NodeType::CompatSsoLogin, u.id)), CompatSsoLogin(u), diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index f36d520bd..e7376f722 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -18,8 +18,8 @@ use hyper::StatusCode; use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType, User}; use mas_storage::{ compat::{ - add_compat_access_token, add_compat_refresh_token, get_compat_sso_login_by_token, - lookup_compat_session, mark_compat_sso_login_as_exchanged, start_compat_session, + CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, + CompatSsoLoginRepository, }, user::{UserPasswordRepository, UserRepository}, Clock, Repository, @@ -224,27 +224,17 @@ pub(crate) async fn post( }; let access_token = TokenType::CompatAccessToken.generate(&mut rng); - let access_token = add_compat_access_token( - &mut txn, - &mut rng, - &clock, - &session, - access_token, - expires_in, - ) - .await?; + let access_token = txn + .compat_access_token() + .add(&mut rng, &clock, &session, access_token, expires_in) + .await?; let refresh_token = if input.refresh_token { let refresh_token = TokenType::CompatRefreshToken.generate(&mut rng); - let refresh_token = add_compat_refresh_token( - &mut txn, - &mut rng, - &clock, - &session, - &access_token, - refresh_token, - ) - .await?; + let refresh_token = txn + .compat_refresh_token() + .add(&mut rng, &clock, &session, &access_token, refresh_token) + .await?; Some(refresh_token.token) } else { None @@ -266,7 +256,9 @@ async fn token_login( clock: &Clock, token: &str, ) -> Result<(CompatSession, User), RouteError> { - let login = get_compat_sso_login_by_token(&mut *txn, token) + let login = txn + .compat_sso_login() + .find_by_token(token) .await? .ok_or(RouteError::InvalidLoginToken)?; @@ -308,7 +300,9 @@ async fn token_login( } }; - let session = lookup_compat_session(&mut *txn, session_id) + let session = txn + .compat_session() + .lookup(session_id) .await? .ok_or(RouteError::SessionNotFound)?; @@ -318,7 +312,7 @@ async fn token_login( .await? .ok_or(RouteError::UserNotFound)?; - mark_compat_sso_login_as_exchanged(&mut *txn, clock, login).await?; + txn.compat_sso_login().exchange(clock, login).await?; Ok((session, user)) } @@ -374,7 +368,10 @@ async fn user_password_login( // Now that the user credentials have been verified, start a new compat session let device = Device::generate(&mut rng); - let session = start_compat_session(&mut *txn, &mut rng, &clock, &user, device).await?; + let session = txn + .compat_session() + .add(&mut rng, &clock, &user, device) + .await?; Ok((session, user)) } diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index f31856d6c..333524246 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -29,7 +29,10 @@ use mas_axum_utils::{ use mas_data_model::Device; use mas_keystore::Encrypter; use mas_router::{CompatLoginSsoAction, PostAuthAction, Route}; -use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id}; +use mas_storage::{ + compat::{CompatSessionRepository, CompatSsoLoginRepository}, + Repository, +}; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; use serde::{Deserialize, Serialize}; use sqlx::PgPool; @@ -87,7 +90,9 @@ pub async fn get( return Ok((cookie_jar, destination.go()).into_response()); } - let login = get_compat_sso_login_by_id(&mut conn, id) + let login = conn + .compat_sso_login() + .lookup(id) .await? .context("Could not find compat SSO login")?; @@ -149,7 +154,9 @@ pub async fn post( return Ok((cookie_jar, destination.go()).into_response()); } - let login = get_compat_sso_login_by_id(&mut txn, id) + let login = txn + .compat_sso_login() + .lookup(id) .await? .context("Could not find compat SSO login")?; @@ -181,8 +188,14 @@ pub async fn post( }; let device = Device::generate(&mut rng); - let _login = - fullfill_compat_sso_login(&mut txn, &mut rng, &clock, &session.user, login, device).await?; + let compat_session = txn + .compat_session() + .add(&mut rng, &clock, &session.user, device) + .await?; + + txn.compat_sso_login() + .fulfill(&clock, login, &compat_session) + .await?; txn.commit().await?; diff --git a/crates/handlers/src/compat/login_sso_redirect.rs b/crates/handlers/src/compat/login_sso_redirect.rs index f90862c72..9c23b733f 100644 --- a/crates/handlers/src/compat/login_sso_redirect.rs +++ b/crates/handlers/src/compat/login_sso_redirect.rs @@ -19,7 +19,7 @@ use axum::{ }; use hyper::StatusCode; use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder}; -use mas_storage::compat::insert_compat_sso_login; +use mas_storage::{compat::CompatSsoLoginRepository, Repository}; use rand::distributions::{Alphanumeric, DistString}; use serde::Deserialize; use serde_with::serde; @@ -49,6 +49,7 @@ pub enum RouteError { } impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_storage::DatabaseError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -80,7 +81,10 @@ pub async fn get( let token = Alphanumeric.sample_string(&mut rng, 32); let mut conn = pool.acquire().await?; - let login = insert_compat_sso_login(&mut conn, &mut rng, &clock, token, redirect_url).await?; + let login = conn + .compat_sso_login() + .add(&mut rng, &clock, token, redirect_url) + .await?; Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.id, params.action))) } diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index e16c8c98b..25125c72f 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -17,8 +17,8 @@ use headers::{authorization::Bearer, Authorization}; use hyper::StatusCode; use mas_data_model::TokenType; use mas_storage::{ - compat::{end_compat_session, find_compat_access_token, lookup_compat_session}, - Clock, + compat::{CompatAccessTokenRepository, CompatSessionRepository}, + Clock, Repository, }; use sqlx::PgPool; use thiserror::Error; @@ -83,17 +83,21 @@ pub(crate) async fn post( return Err(RouteError::InvalidAuthorization); } - let token = find_compat_access_token(&mut txn, token) + let token = txn + .compat_access_token() + .find_by_token(token) .await? .filter(|t| t.is_valid(clock.now())) .ok_or(RouteError::InvalidAuthorization)?; - let session = lookup_compat_session(&mut txn, token.session_id) + let session = txn + .compat_session() + .lookup(token.session_id) .await? .filter(|s| s.is_valid()) .ok_or(RouteError::InvalidAuthorization)?; - end_compat_session(&mut txn, &clock, session).await?; + txn.compat_session().finish(&clock, session).await?; txn.commit().await?; diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs index 58e9eb8e4..7bfc940a4 100644 --- a/crates/handlers/src/compat/refresh.rs +++ b/crates/handlers/src/compat/refresh.rs @@ -16,10 +16,9 @@ use axum::{extract::State, response::IntoResponse, Json}; use chrono::Duration; use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; -use mas_storage::compat::{ - add_compat_access_token, add_compat_refresh_token, consume_compat_refresh_token, - expire_compat_access_token, find_compat_refresh_token, lookup_compat_access_token, - lookup_compat_session, +use mas_storage::{ + compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, + Repository, }; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DurationMilliSeconds}; @@ -101,7 +100,9 @@ pub(crate) async fn post( return Err(RouteError::InvalidToken); } - let refresh_token = find_compat_refresh_token(&mut txn, &input.refresh_token) + let refresh_token = txn + .compat_refresh_token() + .find_by_token(&input.refresh_token) .await? .ok_or(RouteError::InvalidToken)?; @@ -109,7 +110,9 @@ pub(crate) async fn post( return Err(RouteError::RefreshTokenConsumed); } - let session = lookup_compat_session(&mut txn, refresh_token.session_id) + let session = txn + .compat_session() + .lookup(refresh_token.session_id) .await? .ok_or(RouteError::UnknownSession)?; @@ -117,7 +120,9 @@ pub(crate) async fn post( return Err(RouteError::InvalidSession); } - let access_token = lookup_compat_access_token(&mut txn, refresh_token.access_token_id) + let access_token = txn + .compat_access_token() + .lookup(refresh_token.access_token_id) .await? .filter(|t| t.is_valid(clock.now())); @@ -125,29 +130,35 @@ pub(crate) async fn post( let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng); let expires_in = Duration::minutes(5); - let new_access_token = add_compat_access_token( - &mut txn, - &mut rng, - &clock, - &session, - new_access_token_str, - Some(expires_in), - ) - .await?; - let new_refresh_token = add_compat_refresh_token( - &mut txn, - &mut rng, - &clock, - &session, - &new_access_token, - new_refresh_token_str, - ) - .await?; + let new_access_token = txn + .compat_access_token() + .add( + &mut rng, + &clock, + &session, + new_access_token_str, + Some(expires_in), + ) + .await?; + let new_refresh_token = txn + .compat_refresh_token() + .add( + &mut rng, + &clock, + &session, + &new_access_token, + new_refresh_token_str, + ) + .await?; - consume_compat_refresh_token(&mut txn, &clock, refresh_token).await?; + txn.compat_refresh_token() + .consume(&clock, refresh_token) + .await?; if let Some(access_token) = access_token { - expire_compat_access_token(&mut txn, &clock, access_token).await?; + txn.compat_access_token() + .expire(&clock, access_token) + .await?; } txn.commit().await?; diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 3dec02db1..ef6ba5b21 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -22,7 +22,7 @@ use mas_data_model::{TokenFormatError, TokenType}; use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint}; use mas_keystore::Encrypter; use mas_storage::{ - compat::{find_compat_access_token, find_compat_refresh_token, lookup_compat_session}, + compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, oauth2::{ access_token::find_access_token, refresh_token::lookup_refresh_token, OAuth2SessionRepository, @@ -243,12 +243,16 @@ pub(crate) async fn post( } TokenType::CompatAccessToken => { - let token = find_compat_access_token(&mut conn, token) + let access_token = conn + .compat_access_token() + .find_by_token(token) .await? .filter(|t| t.is_valid(clock.now())) .ok_or(RouteError::UnknownToken)?; - let session = lookup_compat_session(&mut conn, token.session_id) + let session = conn + .compat_session() + .lookup(access_token.session_id) .await? .filter(|s| s.is_valid()) .ok_or(RouteError::UnknownToken)?; @@ -269,9 +273,9 @@ pub(crate) async fn post( client_id: Some("legacy".into()), username: Some(user.username), token_type: Some(OAuthTokenTypeHint::AccessToken), - exp: token.expires_at, - iat: Some(token.created_at), - nbf: Some(token.created_at), + exp: access_token.expires_at, + iat: Some(access_token.created_at), + nbf: Some(access_token.created_at), sub: Some(user.sub), aud: None, iss: None, @@ -280,12 +284,16 @@ pub(crate) async fn post( } TokenType::CompatRefreshToken => { - let refresh_token = find_compat_refresh_token(&mut conn, token) + let refresh_token = conn + .compat_refresh_token() + .find_by_token(token) .await? .filter(|t| t.is_valid()) .ok_or(RouteError::UnknownToken)?; - let session = lookup_compat_session(&mut conn, refresh_token.session_id) + let session = conn + .compat_session() + .lookup(refresh_token.session_id) .await? .filter(|s| s.is_valid()) .ok_or(RouteError::UnknownToken)?; diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index 6035c74db..3872588f0 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -15,7 +15,7 @@ use anyhow::Context; use mas_router::{PostAuthAction, Route}; use mas_storage::{ - compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id, + compat::CompatSsoLoginRepository, oauth2::authorization_grant::get_grant_by_id, upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository, }; use mas_templates::{PostAuthContext, PostAuthContextInner}; @@ -54,7 +54,9 @@ impl OptionalPostAuthAction { } PostAuthAction::ContinueCompatSsoLogin { id } => { - let login = get_compat_sso_login_by_id(conn, id) + let login = conn + .compat_sso_login() + .lookup(id) .await? .context("Failed to load compat SSO login")?; let login = Box::new(login); diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 5324fa2a5..6e7082dd8 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -98,6 +98,21 @@ }, "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n " }, + "18c3e56c72ef26bd42653c379767ffdd97bb06cb1686dfbf4099f3ad3d7b22c8": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)\n VALUES ($1, $2, $3, $4)\n " + }, "1d372f36c382ab16264cea54537af3544ea6d6d75d10b432b07dbd0dadd2fa4e": { "describe": { "columns": [ @@ -168,22 +183,6 @@ }, "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n " }, - "2153118b364a33582e7f598acce3789fcb8d938948a819b15cf0b6d37edf58b2": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Timestamptz", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO compat_access_tokens\n (compat_access_token_id, compat_session_id, access_token, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n " - }, "262bee715889dc3e608639549600a131e641951ff979634e7c97afc74bbc1605": { "describe": { "columns": [], @@ -197,79 +196,6 @@ }, "query": "\n UPDATE oauth2_authorization_grants\n SET exchanged_at = $2\n WHERE oauth2_authorization_grant_id = $1\n " }, - "2e756fe7be50128c0acc5f79df3a084230e9ca13cd45bd0858f97e59da20006e": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE compat_sso_logins\n SET\n exchanged_at = $2\n WHERE\n compat_sso_login_id = $1\n " - }, - "360466ff599c67c9af2ac75399c0b536a22c1178972a0172b707bcc81d47357b": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Uuid", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO compat_refresh_tokens\n (compat_refresh_token_id, compat_session_id,\n compat_access_token_id, refresh_token, created_at)\n VALUES ($1, $2, $3, $4, $5)\n " - }, - "3cf8e061206620071b39d0262cd165bb367b12b8e904180730d8acfa5af3d4b9": { - "describe": { - "columns": [ - { - "name": "compat_session_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "device_id", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "user_id", - "ordinal": 2, - "type_info": "Uuid" - }, - { - "name": "created_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "finished_at", - "ordinal": 4, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT compat_session_id\n , device_id\n , user_id\n , created_at\n , finished_at\n FROM compat_sessions\n WHERE compat_session_id = $1\n " - }, "3d66f3121b11ce923b9c60609b510a8ca899640e78cc8f5b03168622928ffe94": { "describe": { "columns": [], @@ -384,6 +310,56 @@ }, "query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at)\n VALUES ($1, $2, $3)\n " }, + "432e199b0d47fe299d840c91159726c0a4f89f65b4dc3e33ddad58aabf6b148b": { + "describe": { + "columns": [ + { + "name": "compat_refresh_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + }, + { + "name": "compat_access_token_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT compat_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , compat_session_id\n , compat_access_token_id\n\n FROM compat_refresh_tokens\n\n WHERE refresh_token = $1\n " + }, "43a5cafbdc8037e9fb779812a0793cf0859902aa0dc8d25d4c33d231d3d1118b": { "describe": { "columns": [], @@ -465,20 +441,7 @@ }, "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n oauth2_session_id = os.oauth2_session_id,\n fulfilled_at = os.created_at\n FROM oauth2_sessions os\n WHERE\n og.oauth2_authorization_grant_id = $1\n AND os.oauth2_session_id = $2\n RETURNING fulfilled_at AS \"fulfilled_at!: DateTime\"\n " }, - "4c4dbb846bb98d84f6b7f886f8af9833c7efe27b8b4f297077887232bef322ee": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE compat_sessions cs\n SET finished_at = $2\n WHERE compat_session_id = $1\n " - }, - "4f080990eb6dd9f6128f3a1aee195b99d5f286fa0f6c27d744f73848343879d4": { + "478f0ad710da8bfd803c6cddd982bc504d1b6bd0f5283de53c8c7b1b4b7dafd4": { "describe": { "columns": [ { @@ -487,27 +450,27 @@ "type_info": "Uuid" }, { - "name": "compat_sso_login_token", + "name": "login_token", "ordinal": 1, "type_info": "Text" }, { - "name": "compat_sso_login_redirect_uri", + "name": "redirect_uri", "ordinal": 2, "type_info": "Text" }, { - "name": "compat_sso_login_created_at", + "name": "created_at", "ordinal": 3, "type_info": "Timestamptz" }, { - "name": "compat_sso_login_fulfilled_at", + "name": "fulfilled_at", "ordinal": 4, "type_info": "Timestamptz" }, { - "name": "compat_sso_login_exchanged_at", + "name": "exchanged_at", "ordinal": 5, "type_info": "Timestamptz" }, @@ -528,11 +491,25 @@ ], "parameters": { "Left": [ - "Uuid" + "Text" ] } }, - "query": "\n SELECT cl.compat_sso_login_id\n , cl.login_token AS \"compat_sso_login_token\"\n , cl.redirect_uri AS \"compat_sso_login_redirect_uri\"\n , cl.created_at AS \"compat_sso_login_created_at\"\n , cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\"\n , cl.exchanged_at AS \"compat_sso_login_exchanged_at\"\n , cl.compat_session_id AS \"compat_session_id\"\n\n FROM compat_sso_logins cl\n WHERE cl.compat_sso_login_id = $1\n " + "query": "\n SELECT compat_sso_login_id\n , login_token\n , redirect_uri\n , created_at\n , fulfilled_at\n , exchanged_at\n , compat_session_id\n\n FROM compat_sso_logins\n WHERE login_token = $1\n " + }, + "4d79ce892e4595edb8b801e94fb0cbef28facdfd2e45d1c72c57f47418fbe24b": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE compat_sso_logins\n SET\n compat_session_id = $2,\n fulfilled_at = $3\n WHERE\n compat_sso_login_id = $1\n " }, "51158bfcaa1a8d8e051bffe7c5ba0369bf53fb162f7622626054e89e68fc07bd": { "describe": { @@ -555,6 +532,50 @@ }, "query": "\n SELECT scope_token\n FROM oauth2_consents\n WHERE user_id = $1 AND oauth2_client_id = $2\n " }, + "53ad718642644b47a2d49f768d81bd993088526923769a9147281686c2d47591": { + "describe": { + "columns": [ + { + "name": "compat_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT compat_access_token_id\n , access_token\n , created_at\n , expires_at\n , compat_session_id\n\n FROM compat_access_tokens\n\n WHERE access_token = $1\n " + }, "583ae9a0db9cd55fa57a179339550f3dab1bfc76f35ad488e1560ea37f7ed029": { "describe": { "columns": [], @@ -598,20 +619,6 @@ }, "query": "\n INSERT INTO upstream_oauth_links (\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n ) VALUES ($1, $2, NULL, $3, $4)\n " }, - "60d039442cfa57e187602c0ff5e386e32fb774b5ad2d2f2c616040819b76873e": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE compat_sso_logins\n SET\n compat_session_id = $2,\n fulfilled_at = $3\n WHERE\n compat_sso_login_id = $1\n " - }, "62d05e8e4317bdb180298737d422e64d161c5ad3813913a6f7d67a53569ea76a": { "describe": { "columns": [], @@ -745,6 +752,21 @@ }, "query": "\n UPDATE oauth2_access_tokens\n SET revoked_at = $2\n WHERE oauth2_access_token_id = $1\n " }, + "6e21e7d816f806da9bb5176931bdb550dee05c44c9d93f53df95fe3b4a840347": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO compat_sso_logins\n (compat_sso_login_id, login_token, redirect_uri, created_at)\n VALUES ($1, $2, $3, $4)\n " + }, "6f97b5f9ad0d4d15387150bea3839fb7f81015f7ceef61ecaadba64521895cff": { "describe": { "columns": [], @@ -782,6 +804,50 @@ }, "query": "\n SELECT COUNT(*) as \"count!\"\n FROM user_sessions s\n WHERE s.user_id = $1 AND s.finished_at IS NULL\n " }, + "77dfa9fae1a9c77b70476d7da19d3313a02886994cfff0690451229fb5ae2f77": { + "describe": { + "columns": [ + { + "name": "compat_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT compat_access_token_id\n , access_token\n , created_at\n , expires_at\n , compat_session_id\n\n FROM compat_access_tokens\n\n WHERE compat_access_token_id = $1\n " + }, "79295f3d3a75f831e9469aabfa720d381a254d00dbe39fef1e9652029d51b89b": { "describe": { "columns": [ @@ -871,19 +937,6 @@ }, "query": "\n UPDATE upstream_oauth_links\n SET user_id = $1\n WHERE upstream_oauth_link_id = $2\n " }, - "7e3247e35ecf5335f0656c53bcde27264a9efb8dccb6246344950614f487dcaf": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE compat_access_tokens\n SET expires_at = $2\n WHERE compat_access_token_id = $1\n " - }, "836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c": { "describe": { "columns": [ @@ -1154,6 +1207,19 @@ }, "query": "\n UPDATE user_email_confirmation_codes\n SET consumed_at = $2\n WHERE user_email_confirmation_code_id = $1\n " }, + "9348d87f9e06b614c7e90bdc93bcf38236766aaf4d894bf768debdff2b59fae2": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE compat_sso_logins\n SET\n exchanged_at = $2\n WHERE\n compat_sso_login_id = $1\n " + }, "94fd96446b237c87bd6bf741f3c42b37ee751b87b7fcc459602bdf8c46962443": { "describe": { "columns": [ @@ -1174,18 +1240,21 @@ }, "query": "\n SELECT EXISTS(\n SELECT 1 FROM users WHERE username = $1\n ) AS \"exists!\"\n " }, - "99f5f9eb0adc5ec120ed8194cbf6a8545155bef09e6d94d92fb67fd1b14d4f28": { + "9f7bdc034c618e47e49c467d0d7f5b8c297d055abe248cc876dbc12c5a7dc920": { "describe": { "columns": [], "nullable": [], "parameters": { "Left": [ "Uuid", + "Uuid", + "Uuid", + "Text", "Timestamptz" ] } }, - "query": "\n UPDATE compat_refresh_tokens\n SET consumed_at = $2\n WHERE compat_refresh_token_id = $1\n " + "query": "\n INSERT INTO compat_refresh_tokens\n (compat_refresh_token_id, compat_session_id,\n compat_access_token_id, refresh_token, created_at)\n VALUES ($1, $2, $3, $4, $5)\n " }, "a300fe99c95679c5664646a6a525c0491829e97db45f3234483872ed38436322": { "describe": { @@ -1243,6 +1312,22 @@ }, "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n requires_consent = 'f'\n WHERE\n og.oauth2_authorization_grant_id = $1\n " }, + "a7f780528882a2ae66c45435215763eed0582264861436eab3f862e3eb12cab1": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO compat_access_tokens\n (compat_access_token_id, compat_session_id, access_token, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n " + }, "aa2fd69c595f94d8598715766a79671dba8f87b9d7af6ac30e3fa1fbc8cce28a": { "describe": { "columns": [ @@ -1371,6 +1456,19 @@ }, "query": "\n SELECT oauth2_authorization_grant_id\n , created_at AS oauth2_authorization_grant_created_at\n , cancelled_at AS oauth2_authorization_grant_cancelled_at\n , fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , exchanged_at AS oauth2_authorization_grant_exchanged_at\n , scope AS oauth2_authorization_grant_scope\n , state AS oauth2_authorization_grant_state\n , redirect_uri AS oauth2_authorization_grant_redirect_uri\n , response_mode AS oauth2_authorization_grant_response_mode\n , nonce AS oauth2_authorization_grant_nonce\n , max_age AS oauth2_authorization_grant_max_age\n , oauth2_client_id AS oauth2_client_id\n , authorization_code AS oauth2_authorization_grant_code\n , response_type_code AS oauth2_authorization_grant_response_type_code\n , response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , code_challenge AS oauth2_authorization_grant_code_challenge\n , code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , requires_consent AS oauth2_authorization_grant_requires_consent\n , oauth2_session_id AS \"oauth2_session_id?\"\n FROM\n oauth2_authorization_grants\n\n WHERE authorization_code = $1\n " }, + "ab34912b42a48a8b5c8d63e271b99b7d0b690a2471873c6654b1b6cf2079b95c": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE compat_sessions cs\n SET finished_at = $2\n WHERE compat_session_id = $1\n " + }, "aff08a8caabeb62f4929e6e901e7ca7c55e284c18c5c1d1e78821dd9bc961412": { "describe": { "columns": [ @@ -1652,6 +1750,19 @@ }, "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n " }, + "bbf62633c561706a762089bbab2f76a9ba3e2ed3539ef16accb601fb609c2ec9": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE compat_access_tokens\n SET expires_at = $2\n WHERE compat_access_token_id = $1\n " + }, "bd1f6daa5fa1b10250c01f8b3fbe451646a9ceeefa6f72b9c4e29b6d05f17641": { "describe": { "columns": [], @@ -1696,106 +1807,6 @@ }, "query": "\n INSERT INTO user_sessions (user_session_id, user_id, created_at)\n VALUES ($1, $2, $3)\n " }, - "c3e60701299be7728108b8967ec5396fb186adaac360d6a0152d25e4a4f46f46": { - "describe": { - "columns": [ - { - "name": "compat_access_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id", - "ordinal": 4, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - true, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT compat_access_token_id\n , access_token\n , created_at\n , expires_at\n , compat_session_id\n\n FROM compat_access_tokens\n\n WHERE compat_access_token_id = $1\n " - }, - "c78246fc8737491352f71ea9410e79df8de88596c8197405cda36eb8c8187810": { - "describe": { - "columns": [ - { - "name": "compat_sso_login_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "compat_sso_login_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_sso_login_redirect_uri", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "compat_sso_login_created_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_sso_login_fulfilled_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "compat_sso_login_exchanged_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id", - "ordinal": 6, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - false, - true, - true, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT cl.compat_sso_login_id\n , cl.login_token AS \"compat_sso_login_token\"\n , cl.redirect_uri AS \"compat_sso_login_redirect_uri\"\n , cl.created_at AS \"compat_sso_login_created_at\"\n , cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\"\n , cl.exchanged_at AS \"compat_sso_login_exchanged_at\"\n , cl.compat_session_id AS \"compat_session_id\"\n FROM compat_sso_logins cl\n WHERE cl.login_token = $1\n " - }, "c88376abdba124ff0487a9a69d2345c7d69d7394f355111ec369cfa6d45fb40f": { "describe": { "columns": [], @@ -1822,114 +1833,18 @@ }, "query": "\n INSERT INTO oauth2_authorization_grants (\n oauth2_authorization_grant_id,\n oauth2_client_id,\n redirect_uri,\n scope,\n state,\n nonce,\n max_age,\n response_mode,\n code_challenge,\n code_challenge_method,\n response_type_code,\n response_type_id_token,\n authorization_code,\n requires_consent,\n created_at\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)\n " }, - "ca63558e877bd115aa7ca24de0cc3f78a13cb55105758fe0bd930da513f75504": { - "describe": { - "columns": [ - { - "name": "compat_refresh_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "refresh_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "consumed_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id", - "ordinal": 4, - "type_info": "Uuid" - }, - { - "name": "compat_access_token_id", - "ordinal": 5, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - true, - false, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT compat_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , compat_session_id\n , compat_access_token_id\n\n FROM compat_refresh_tokens\n\n WHERE refresh_token = $1\n " - }, - "caf54e4659306a746747aa61906bdb2cb8da51176e90435aa8b9754ebf3e4d60": { + "d0b403e9c843ef19fa5ad60bec32ebf14a1ba0d01681c3836366d3f55e7851f4": { "describe": { "columns": [], "nullable": [], "parameters": { "Left": [ "Uuid", - "Uuid", - "Text", "Timestamptz" ] } }, - "query": "\n INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)\n VALUES ($1, $2, $3, $4)\n " - }, - "cf43b82bdf534400f900cff3c5356083db0f9e5407e288b64f43d7ac100de058": { - "describe": { - "columns": [ - { - "name": "compat_access_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id", - "ordinal": 4, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - true, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT compat_access_token_id\n , access_token\n , created_at\n , expires_at\n , compat_session_id\n\n FROM compat_access_tokens\n\n WHERE access_token = $1\n " + "query": "\n UPDATE compat_refresh_tokens\n SET consumed_at = $2\n WHERE compat_refresh_token_id = $1\n " }, "d12a513b81b3ef658eae1f0a719933323f28c6ee260b52cafe337dd3d19e865c": { "describe": { @@ -1951,21 +1866,6 @@ }, "query": "\n SELECT COUNT(*)\n FROM user_emails\n WHERE user_id = $1\n " }, - "d1738c27339b81f0844da4bd9b040b9b07a91aa4d9b199b98f24c9cee5709b2b": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO compat_sso_logins\n (compat_sso_login_id, login_token, redirect_uri, created_at)\n VALUES ($1, $2, $3, $4)\n " - }, "d1f1aac41bb2f0d194f9b3d846663c267865d0d22dd5fa8a668daf29dca88d36": { "describe": { "columns": [ @@ -2211,6 +2111,112 @@ }, "query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n " }, + "ddb22dd9ae9367af65a607e1fdc48b3d9581d67deea0c168f24e02090082bb82": { + "describe": { + "columns": [ + { + "name": "compat_sso_login_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "login_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "redirect_uri", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "fulfilled_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "exchanged_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 6, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + true, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT compat_sso_login_id\n , login_token\n , redirect_uri\n , created_at\n , fulfilled_at\n , exchanged_at\n , compat_session_id\n\n FROM compat_sso_logins\n WHERE compat_sso_login_id = $1\n " + }, + "e35d56de7136d43d0803ec825b0612e4185cef838f105d66f18cb24865e45140": { + "describe": { + "columns": [ + { + "name": "compat_refresh_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + }, + { + "name": "compat_access_token_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT compat_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , compat_session_id\n , compat_access_token_id\n\n FROM compat_refresh_tokens\n\n WHERE compat_refresh_token_id = $1\n " + }, "e6dc63984aced9e19c20e90e9cd75d6f6d7ade64f782697715ac4da077b2e1fc": { "describe": { "columns": [ @@ -2306,6 +2312,50 @@ }, "query": "\n SELECT oauth2_session_id\n , user_session_id\n , oauth2_client_id\n , scope\n , created_at\n , finished_at\n FROM oauth2_sessions\n\n WHERE oauth2_session_id = $1\n " }, + "f3ee06958d827b152c57328caa0a6030c372cb99cdb60e4b75a28afeb5096f45": { + "describe": { + "columns": [ + { + "name": "compat_session_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "device_id", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "user_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "finished_at", + "ordinal": 4, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT compat_session_id\n , device_id\n , user_id\n , created_at\n , finished_at\n FROM compat_sessions\n WHERE compat_session_id = $1\n " + }, "f5edcd4c306ca8179cdf9d4aab59fbba971b54611c91345849920954dd8089b3": { "describe": { "columns": [], diff --git a/crates/storage/src/compat.rs b/crates/storage/src/compat.rs deleted file mode 100644 index 3befa8dda..000000000 --- a/crates/storage/src/compat.rs +++ /dev/null @@ -1,757 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use chrono::{DateTime, Duration, Utc}; -use mas_data_model::{ - CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession, - CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device, User, -}; -use rand::Rng; -use sqlx::{Acquire, PgExecutor, Postgres, QueryBuilder}; -use tracing::{info_span, Instrument}; -use ulid::Ulid; -use url::Url; -use uuid::Uuid; - -use crate::{ - pagination::{process_page, QueryBuilderExt}, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, -}; - -struct CompatSessionLookup { - compat_session_id: Uuid, - device_id: String, - user_id: Uuid, - created_at: DateTime, - finished_at: Option>, -} - -#[tracing::instrument(skip_all, err)] -pub async fn lookup_compat_session( - executor: impl PgExecutor<'_>, - session_id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - CompatSessionLookup, - r#" - SELECT compat_session_id - , device_id - , user_id - , created_at - , finished_at - FROM compat_sessions - WHERE compat_session_id = $1 - "#, - Uuid::from(session_id), - ) - .fetch_one(executor) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - let id = res.compat_session_id.into(); - let device = Device::try_from(res.device_id).map_err(|e| { - DatabaseInconsistencyError::on("compat_sessions") - .column("device_id") - .row(id) - .source(e) - })?; - - let state = match res.finished_at { - None => CompatSessionState::Valid, - Some(finished_at) => CompatSessionState::Finished { finished_at }, - }; - - let session = CompatSession { - id, - state, - user_id: res.user_id.into(), - device, - created_at: res.created_at, - }; - - Ok(Some(session)) -} - -struct CompatAccessTokenLookup { - compat_access_token_id: Uuid, - access_token: String, - created_at: DateTime, - expires_at: Option>, - compat_session_id: Uuid, -} - -impl From for CompatAccessToken { - fn from(value: CompatAccessTokenLookup) -> Self { - Self { - id: value.compat_access_token_id.into(), - session_id: value.compat_session_id.into(), - token: value.access_token, - created_at: value.created_at, - expires_at: value.expires_at, - } - } -} - -#[tracing::instrument(skip_all, err)] -pub async fn find_compat_access_token( - executor: impl PgExecutor<'_>, - token: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - CompatAccessTokenLookup, - r#" - SELECT compat_access_token_id - , access_token - , created_at - , expires_at - , compat_session_id - - FROM compat_access_tokens - - WHERE access_token = $1 - "#, - token, - ) - .fetch_one(executor) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) -} - -#[tracing::instrument( - skip_all, - fields( - compat_access_token.id = %id, - ), - err, -)] -pub async fn lookup_compat_access_token( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - CompatAccessTokenLookup, - r#" - SELECT compat_access_token_id - , access_token - , created_at - , expires_at - , compat_session_id - - FROM compat_access_tokens - - WHERE compat_access_token_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) -} - -pub struct CompatRefreshTokenLookup { - compat_refresh_token_id: Uuid, - refresh_token: String, - created_at: DateTime, - consumed_at: Option>, - compat_access_token_id: Uuid, - compat_session_id: Uuid, -} - -#[tracing::instrument(skip_all, err)] -#[allow(clippy::type_complexity)] -pub async fn find_compat_refresh_token( - executor: impl PgExecutor<'_>, - token: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - CompatRefreshTokenLookup, - r#" - SELECT compat_refresh_token_id - , refresh_token - , created_at - , consumed_at - , compat_session_id - , compat_access_token_id - - FROM compat_refresh_tokens - - WHERE refresh_token = $1 - "#, - token, - ) - .fetch_one(executor) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None); }; - - let state = match res.consumed_at { - None => CompatRefreshTokenState::Valid, - Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at }, - }; - - let refresh_token = CompatRefreshToken { - id: res.compat_refresh_token_id.into(), - state, - session_id: res.compat_session_id.into(), - access_token_id: res.compat_access_token_id.into(), - token: res.refresh_token, - created_at: res.created_at, - }; - - Ok(Some(refresh_token)) -} - -#[tracing::instrument( - skip_all, - fields( - compat_session.id = %session.id, - compat_session.device.id = session.device.as_str(), - compat_access_token.id, - user.id = %session.user_id, - ), - err, -)] -pub async fn add_compat_access_token( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - session: &CompatSession, - token: String, - expires_after: Option, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("compat_access_token.id", tracing::field::display(id)); - - let expires_at = expires_after.map(|expires_after| created_at + expires_after); - - sqlx::query!( - r#" - INSERT INTO compat_access_tokens - (compat_access_token_id, compat_session_id, access_token, created_at, expires_at) - VALUES ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(session.id), - token, - created_at, - expires_at, - ) - .execute(executor) - .instrument(tracing::info_span!("Insert compat access token")) - .await?; - - Ok(CompatAccessToken { - id, - session_id: session.id, - token, - created_at, - expires_at, - }) -} - -#[tracing::instrument( - skip_all, - fields( - compat_access_token.id = %access_token.id, - ), - err, -)] -pub async fn expire_compat_access_token( - executor: impl PgExecutor<'_>, - clock: &Clock, - access_token: CompatAccessToken, -) -> Result<(), DatabaseError> { - let expires_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE compat_access_tokens - SET expires_at = $2 - WHERE compat_access_token_id = $1 - "#, - Uuid::from(access_token.id), - expires_at, - ) - .execute(executor) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1) -} - -#[tracing::instrument( - skip_all, - fields( - compat_session.id = %session.id, - compat_session.device.id = session.device.as_str(), - compat_access_token.id = %access_token.id, - compat_refresh_token.id, - user.id = %session.user_id, - ), - err, -)] -pub async fn add_compat_refresh_token( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - session: &CompatSession, - access_token: &CompatAccessToken, - token: String, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO compat_refresh_tokens - (compat_refresh_token_id, compat_session_id, - compat_access_token_id, refresh_token, created_at) - VALUES ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(session.id), - Uuid::from(access_token.id), - token, - created_at, - ) - .execute(executor) - .instrument(tracing::info_span!("Insert compat refresh token")) - .await?; - - Ok(CompatRefreshToken { - id, - state: CompatRefreshTokenState::default(), - session_id: session.id, - access_token_id: access_token.id, - token, - created_at, - }) -} - -#[tracing::instrument( - skip_all, - fields(%compat_session.id), - err, -)] -pub async fn end_compat_session( - executor: impl PgExecutor<'_>, - clock: &Clock, - compat_session: CompatSession, -) -> Result { - let finished_at = clock.now(); - - let res = sqlx::query!( - r#" - UPDATE compat_sessions cs - SET finished_at = $2 - WHERE compat_session_id = $1 - "#, - Uuid::from(compat_session.id), - finished_at, - ) - .execute(executor) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - let compat_session = compat_session - .finish(finished_at) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(compat_session) -} - -#[tracing::instrument( - skip_all, - fields( - compat_refresh_token.id = %refresh_token.id, - ), - err, -)] -pub async fn consume_compat_refresh_token( - executor: impl PgExecutor<'_>, - clock: &Clock, - refresh_token: CompatRefreshToken, -) -> Result<(), DatabaseError> { - let consumed_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE compat_refresh_tokens - SET consumed_at = $2 - WHERE compat_refresh_token_id = $1 - "#, - Uuid::from(refresh_token.id), - consumed_at, - ) - .execute(executor) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1) -} - -#[tracing::instrument( - skip_all, - fields( - compat_sso_login.id, - compat_sso_login.redirect_uri = %redirect_uri, - ), - err, -)] -pub async fn insert_compat_sso_login( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - login_token: String, - redirect_uri: Url, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO compat_sso_logins - (compat_sso_login_id, login_token, redirect_uri, created_at) - VALUES ($1, $2, $3, $4) - "#, - Uuid::from(id), - &login_token, - redirect_uri.as_str(), - created_at, - ) - .execute(executor) - .instrument(tracing::info_span!("Insert compat SSO login")) - .await?; - - Ok(CompatSsoLogin { - id, - login_token, - redirect_uri, - created_at, - state: CompatSsoLoginState::Pending, - }) -} - -#[derive(sqlx::FromRow)] -struct CompatSsoLoginLookup { - compat_sso_login_id: Uuid, - compat_sso_login_token: String, - compat_sso_login_redirect_uri: String, - compat_sso_login_created_at: DateTime, - compat_sso_login_fulfilled_at: Option>, - compat_sso_login_exchanged_at: Option>, - compat_session_id: Option, -} - -impl TryFrom for CompatSsoLogin { - type Error = DatabaseInconsistencyError; - - fn try_from(res: CompatSsoLoginLookup) -> Result { - let id = res.compat_sso_login_id.into(); - let redirect_uri = Url::parse(&res.compat_sso_login_redirect_uri).map_err(|e| { - DatabaseInconsistencyError::on("compat_sso_logins") - .column("redirect_uri") - .row(id) - .source(e) - })?; - - let state = match ( - res.compat_sso_login_fulfilled_at, - res.compat_sso_login_exchanged_at, - res.compat_session_id, - ) { - (None, None, None) => CompatSsoLoginState::Pending, - (Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled { - fulfilled_at, - session_id: session_id.into(), - }, - (Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => { - CompatSsoLoginState::Exchanged { - fulfilled_at, - exchanged_at, - session_id: session_id.into(), - } - } - _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)), - }; - - Ok(CompatSsoLogin { - id, - login_token: res.compat_sso_login_token, - redirect_uri, - created_at: res.compat_sso_login_created_at, - state, - }) - } -} - -#[tracing::instrument( - skip_all, - fields( - compat_sso_login.id = %id, - ), - err, -)] -pub async fn get_compat_sso_login_by_id( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - CompatSsoLoginLookup, - r#" - SELECT cl.compat_sso_login_id - , cl.login_token AS "compat_sso_login_token" - , cl.redirect_uri AS "compat_sso_login_redirect_uri" - , cl.created_at AS "compat_sso_login_created_at" - , cl.fulfilled_at AS "compat_sso_login_fulfilled_at" - , cl.exchanged_at AS "compat_sso_login_exchanged_at" - , cl.compat_session_id AS "compat_session_id" - - FROM compat_sso_logins cl - WHERE cl.compat_sso_login_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .instrument(tracing::info_span!("Lookup compat SSO login")) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %user.username, - ), - err, -)] -pub async fn get_paginated_user_compat_sso_logins( - executor: impl PgExecutor<'_>, - user: &User, - before: Option, - after: Option, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), DatabaseError> { - let mut query = QueryBuilder::new( - r#" - SELECT cl.compat_sso_login_id - , cl.login_token AS "compat_sso_login_token" - , cl.redirect_uri AS "compat_sso_login_redirect_uri" - , cl.created_at AS "compat_sso_login_created_at" - , cl.fulfilled_at AS "compat_sso_login_fulfilled_at" - , cl.exchanged_at AS "compat_sso_login_exchanged_at" - , cl.compat_session_id AS "compat_session_id" - FROM compat_sso_logins cl - "#, - ); - - query - .push(" WHERE cs.user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("cl.compat_sso_login_id", before, after, first, last)?; - - let span = info_span!( - "Fetch paginated user compat SSO logins", - db.statement = query.sql() - ); - let page: Vec = query - .build_query_as() - .fetch_all(executor) - .instrument(span) - .await?; - - let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; - - let page: Result, _> = page.into_iter().map(TryInto::try_into).collect(); - Ok((has_previous_page, has_next_page, page?)) -} - -#[tracing::instrument(skip_all, err)] -pub async fn get_compat_sso_login_by_token( - executor: impl PgExecutor<'_>, - token: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - CompatSsoLoginLookup, - r#" - SELECT cl.compat_sso_login_id - , cl.login_token AS "compat_sso_login_token" - , cl.redirect_uri AS "compat_sso_login_redirect_uri" - , cl.created_at AS "compat_sso_login_created_at" - , cl.fulfilled_at AS "compat_sso_login_fulfilled_at" - , cl.exchanged_at AS "compat_sso_login_exchanged_at" - , cl.compat_session_id AS "compat_session_id" - FROM compat_sso_logins cl - WHERE cl.login_token = $1 - "#, - token, - ) - .fetch_one(executor) - .instrument(tracing::info_span!("Lookup compat SSO login")) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - compat_session.id, - compat_session.device.id = device.as_str(), - ), - err, -)] -pub async fn start_compat_session( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - user: &User, - device: Device, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("compat_session.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at) - VALUES ($1, $2, $3, $4) - "#, - Uuid::from(id), - Uuid::from(user.id), - device.as_str(), - created_at, - ) - .execute(executor) - .await?; - - Ok(CompatSession { - id, - state: CompatSessionState::default(), - user_id: user.id, - device, - created_at, - }) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %compat_sso_login.id, - %compat_sso_login.redirect_uri, - compat_session.id, - compat_session.device.id = device.as_str(), - ), - err, -)] -pub async fn fullfill_compat_sso_login( - conn: impl Acquire<'_, Database = Postgres> + Send, - mut rng: impl Rng + Send, - clock: &Clock, - user: &User, - compat_sso_login: CompatSsoLogin, - device: Device, -) -> Result { - if !matches!(compat_sso_login.state, CompatSsoLoginState::Pending) { - return Err(DatabaseError::invalid_operation()); - }; - - let mut txn = conn.begin().await?; - - let session = start_compat_session(&mut txn, &mut rng, clock, user, device).await?; - let session_id = session.id; - - let fulfilled_at = clock.now(); - let compat_sso_login = compat_sso_login - .fulfill(fulfilled_at, &session) - .map_err(DatabaseError::to_invalid_operation)?; - sqlx::query!( - r#" - UPDATE compat_sso_logins - SET - compat_session_id = $2, - fulfilled_at = $3 - WHERE - compat_sso_login_id = $1 - "#, - Uuid::from(compat_sso_login.id), - Uuid::from(session_id), - fulfilled_at, - ) - .execute(&mut txn) - .instrument(tracing::info_span!("Update compat SSO login")) - .await?; - - txn.commit().await?; - - Ok(compat_sso_login) -} - -#[tracing::instrument( - skip_all, - fields( - %compat_sso_login.id, - %compat_sso_login.redirect_uri, - ), - err, -)] -pub async fn mark_compat_sso_login_as_exchanged( - executor: impl PgExecutor<'_>, - clock: &Clock, - compat_sso_login: CompatSsoLogin, -) -> Result { - let exchanged_at = clock.now(); - let compat_sso_login = compat_sso_login - .exchange(exchanged_at) - .map_err(DatabaseError::to_invalid_operation)?; - - sqlx::query!( - r#" - UPDATE compat_sso_logins - SET - exchanged_at = $2 - WHERE - compat_sso_login_id = $1 - "#, - Uuid::from(compat_sso_login.id), - exchanged_at, - ) - .execute(executor) - .instrument(tracing::info_span!("Update compat SSO login")) - .await?; - - Ok(compat_sso_login) -} diff --git a/crates/storage/src/compat/access_token.rs b/crates/storage/src/compat/access_token.rs new file mode 100644 index 000000000..86d2dd198 --- /dev/null +++ b/crates/storage/src/compat/access_token.rs @@ -0,0 +1,246 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Duration, Utc}; +use mas_data_model::{CompatAccessToken, CompatSession}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; + +#[async_trait] +pub trait CompatAccessTokenRepository: Send + Sync { + type Error; + + /// Lookup a compat access token by its ID + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Find a compat access token by its token + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error>; + + /// Add a new compat access token to the database + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + compat_session: &CompatSession, + token: String, + expires_after: Option, + ) -> Result; + + /// Set the expiration time of the compat access token to now + async fn expire( + &mut self, + clock: &Clock, + compat_access_token: CompatAccessToken, + ) -> Result; +} + +pub struct PgCompatAccessTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatAccessTokenRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct CompatAccessTokenLookup { + compat_access_token_id: Uuid, + access_token: String, + created_at: DateTime, + expires_at: Option>, + compat_session_id: Uuid, +} + +impl From for CompatAccessToken { + fn from(value: CompatAccessTokenLookup) -> Self { + Self { + id: value.compat_access_token_id.into(), + session_id: value.compat_session_id.into(), + token: value.access_token, + created_at: value.created_at, + expires_at: value.expires_at, + } + } +} + +#[async_trait] +impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_access_token.lookup", + skip_all, + fields( + db.statement, + compat_session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatAccessTokenLookup, + r#" + SELECT compat_access_token_id + , access_token + , created_at + , expires_at + , compat_session_id + + FROM compat_access_tokens + + WHERE compat_access_token_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_access_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatAccessTokenLookup, + r#" + SELECT compat_access_token_id + , access_token + , created_at + , expires_at + , compat_session_id + + FROM compat_access_tokens + + WHERE access_token = $1 + "#, + access_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_access_token.add", + skip_all, + fields( + db.statement, + compat_access_token.id, + %compat_session.id, + user.id = %compat_session.user_id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + compat_session: &CompatSession, + token: String, + expires_after: Option, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_access_token.id", tracing::field::display(id)); + + let expires_at = expires_after.map(|expires_after| created_at + expires_after); + + sqlx::query!( + r#" + INSERT INTO compat_access_tokens + (compat_access_token_id, compat_session_id, access_token, created_at, expires_at) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(compat_session.id), + token, + created_at, + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatAccessToken { + id, + session_id: compat_session.id, + token, + created_at, + expires_at, + }) + } + + #[tracing::instrument( + name = "db.compat_access_token.expire", + skip_all, + fields( + db.statement, + %compat_access_token.id, + compat_session.id = %compat_access_token.session_id, + ), + err, + )] + async fn expire( + &mut self, + clock: &Clock, + mut compat_access_token: CompatAccessToken, + ) -> Result { + let expires_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE compat_access_tokens + SET expires_at = $2 + WHERE compat_access_token_id = $1 + "#, + Uuid::from(compat_access_token.id), + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + compat_access_token.expires_at = Some(expires_at); + Ok(compat_access_token) + } +} diff --git a/crates/storage/src/compat/mod.rs b/crates/storage/src/compat/mod.rs new file mode 100644 index 000000000..3a91f8c7d --- /dev/null +++ b/crates/storage/src/compat/mod.rs @@ -0,0 +1,25 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod access_token; +mod refresh_token; +mod session; +mod sso_login; + +pub use self::{ + access_token::{CompatAccessTokenRepository, PgCompatAccessTokenRepository}, + refresh_token::{CompatRefreshTokenRepository, PgCompatRefreshTokenRepository}, + session::{CompatSessionRepository, PgCompatSessionRepository}, + sso_login::{CompatSsoLoginRepository, PgCompatSsoLoginRepository}, +}; diff --git a/crates/storage/src/compat/refresh_token.rs b/crates/storage/src/compat/refresh_token.rs new file mode 100644 index 000000000..300546226 --- /dev/null +++ b/crates/storage/src/compat/refresh_token.rs @@ -0,0 +1,260 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{ + CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession, +}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; + +#[async_trait] +pub trait CompatRefreshTokenRepository: Send + Sync { + type Error; + + /// Lookup a compat refresh token by its ID + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Find a compat refresh token by its token + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error>; + + /// Add a new compat refresh token to the database + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + compat_session: &CompatSession, + compat_access_token: &CompatAccessToken, + token: String, + ) -> Result; + + /// Consume a compat refresh token + async fn consume( + &mut self, + clock: &Clock, + compat_refresh_token: CompatRefreshToken, + ) -> Result; +} + +pub struct PgCompatRefreshTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatRefreshTokenRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct CompatRefreshTokenLookup { + compat_refresh_token_id: Uuid, + refresh_token: String, + created_at: DateTime, + consumed_at: Option>, + compat_access_token_id: Uuid, + compat_session_id: Uuid, +} + +impl From for CompatRefreshToken { + fn from(value: CompatRefreshTokenLookup) -> Self { + let state = match value.consumed_at { + Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at }, + None => CompatRefreshTokenState::Valid, + }; + + Self { + id: value.compat_refresh_token_id.into(), + state, + session_id: value.compat_session_id.into(), + token: value.refresh_token, + created_at: value.created_at, + access_token_id: value.compat_access_token_id.into(), + } + } +} + +#[async_trait] +impl<'c> CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_refresh_token.lookup", + skip_all, + fields( + db.statement, + compat_refresh_token.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatRefreshTokenLookup, + r#" + SELECT compat_refresh_token_id + , refresh_token + , created_at + , consumed_at + , compat_session_id + , compat_access_token_id + + FROM compat_refresh_tokens + + WHERE compat_refresh_token_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_refresh_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatRefreshTokenLookup, + r#" + SELECT compat_refresh_token_id + , refresh_token + , created_at + , consumed_at + , compat_session_id + , compat_access_token_id + + FROM compat_refresh_tokens + + WHERE refresh_token = $1 + "#, + refresh_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_refresh_token.add", + skip_all, + fields( + db.statement, + compat_refresh_token.id, + %compat_session.id, + user.id = %compat_session.user_id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + compat_session: &CompatSession, + compat_access_token: &CompatAccessToken, + token: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO compat_refresh_tokens + (compat_refresh_token_id, compat_session_id, + compat_access_token_id, refresh_token, created_at) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(compat_session.id), + Uuid::from(compat_access_token.id), + token, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatRefreshToken { + id, + state: CompatRefreshTokenState::default(), + session_id: compat_session.id, + access_token_id: compat_access_token.id, + token, + created_at, + }) + } + + #[tracing::instrument( + name = "db.compat_refresh_token.consume", + skip_all, + fields( + db.statement, + %compat_refresh_token.id, + compat_session.id = %compat_refresh_token.session_id, + ), + err, + )] + async fn consume( + &mut self, + clock: &Clock, + compat_refresh_token: CompatRefreshToken, + ) -> Result { + let consumed_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE compat_refresh_tokens + SET consumed_at = $2 + WHERE compat_refresh_token_id = $1 + "#, + Uuid::from(compat_refresh_token.id), + consumed_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + let compat_refresh_token = compat_refresh_token + .consume(consumed_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(compat_refresh_token) + } +} diff --git a/crates/storage/src/compat/session.rs b/crates/storage/src/compat/session.rs new file mode 100644 index 000000000..3068be731 --- /dev/null +++ b/crates/storage/src/compat/session.rs @@ -0,0 +1,220 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{CompatSession, CompatSessionState, Device, User}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, +}; + +#[async_trait] +pub trait CompatSessionRepository: Send + Sync { + type Error; + + /// Lookup a compat session by its ID + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Start a new compat session + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + device: Device, + ) -> Result; + + /// End a compat session + async fn finish( + &mut self, + clock: &Clock, + compat_session: CompatSession, + ) -> Result; +} + +pub struct PgCompatSessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatSessionRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct CompatSessionLookup { + compat_session_id: Uuid, + device_id: String, + user_id: Uuid, + created_at: DateTime, + finished_at: Option>, +} + +impl TryFrom for CompatSession { + type Error = DatabaseInconsistencyError; + + fn try_from(value: CompatSessionLookup) -> Result { + let id = value.compat_session_id.into(); + let device = Device::try_from(value.device_id).map_err(|e| { + DatabaseInconsistencyError::on("compat_sessions") + .column("device_id") + .row(id) + .source(e) + })?; + + let state = match value.finished_at { + None => CompatSessionState::Valid, + Some(finished_at) => CompatSessionState::Finished { finished_at }, + }; + + let session = CompatSession { + id, + state, + user_id: value.user_id.into(), + device, + created_at: value.created_at, + }; + + Ok(session) + } +} + +#[async_trait] +impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_session.lookup", + skip_all, + fields( + db.statement, + compat_session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatSessionLookup, + r#" + SELECT compat_session_id + , device_id + , user_id + , created_at + , finished_at + FROM compat_sessions + WHERE compat_session_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.compat_session.add", + skip_all, + fields( + db.statement, + compat_session.id, + %user.id, + %user.username, + compat_session.device.id = device.as_str(), + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + device: Device, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_session.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at) + VALUES ($1, $2, $3, $4) + "#, + Uuid::from(id), + Uuid::from(user.id), + device.as_str(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatSession { + id, + state: CompatSessionState::default(), + user_id: user.id, + device, + created_at, + }) + } + + #[tracing::instrument( + name = "db.compat_session.finish", + skip_all, + fields( + db.statement, + %compat_session.id, + user.id = %compat_session.user_id, + compat_session.device.id = compat_session.device.as_str(), + ), + err, + )] + async fn finish( + &mut self, + clock: &Clock, + compat_session: CompatSession, + ) -> Result { + let finished_at = clock.now(); + + let res = sqlx::query!( + r#" + UPDATE compat_sessions cs + SET finished_at = $2 + WHERE compat_session_id = $1 + "#, + Uuid::from(compat_session.id), + finished_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + let compat_session = compat_session + .finish(finished_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(compat_session) + } +} diff --git a/crates/storage/src/compat/sso_login.rs b/crates/storage/src/compat/sso_login.rs new file mode 100644 index 000000000..cba777d3d --- /dev/null +++ b/crates/storage/src/compat/sso_login.rs @@ -0,0 +1,397 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState, User}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use url::Url; +use uuid::Uuid; + +use crate::{ + pagination::{process_page, Page, QueryBuilderExt}, + tracing::ExecuteExt, + Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, +}; + +#[async_trait] +pub trait CompatSsoLoginRepository: Send + Sync { + type Error; + + /// Lookup a compat SSO login by its ID + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Find a compat SSO login by its login token + async fn find_by_token( + &mut self, + login_token: &str, + ) -> Result, Self::Error>; + + /// Start a new compat SSO login token + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + login_token: String, + redirect_uri: Url, + ) -> Result; + + /// Fulfill a compat SSO login by providing a compat session + async fn fulfill( + &mut self, + clock: &Clock, + compat_sso_login: CompatSsoLogin, + compat_session: &CompatSession, + ) -> Result; + + /// Mark a compat SSO login as exchanged + async fn exchange( + &mut self, + clock: &Clock, + compat_sso_login: CompatSsoLogin, + ) -> Result; + + /// Get a paginated list of compat SSO logins for a user + async fn list_paginated( + &mut self, + user: &User, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error>; +} + +pub struct PgCompatSsoLoginRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatSsoLoginRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct CompatSsoLoginLookup { + compat_sso_login_id: Uuid, + login_token: String, + redirect_uri: String, + created_at: DateTime, + fulfilled_at: Option>, + exchanged_at: Option>, + compat_session_id: Option, +} + +impl TryFrom for CompatSsoLogin { + type Error = DatabaseInconsistencyError; + + fn try_from(res: CompatSsoLoginLookup) -> Result { + let id = res.compat_sso_login_id.into(); + let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| { + DatabaseInconsistencyError::on("compat_sso_logins") + .column("redirect_uri") + .row(id) + .source(e) + })?; + + let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) { + (None, None, None) => CompatSsoLoginState::Pending, + (Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled { + fulfilled_at, + session_id: session_id.into(), + }, + (Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => { + CompatSsoLoginState::Exchanged { + fulfilled_at, + exchanged_at, + session_id: session_id.into(), + } + } + _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)), + }; + + Ok(CompatSsoLogin { + id, + login_token: res.login_token, + redirect_uri, + created_at: res.created_at, + state, + }) + } +} + +#[async_trait] +impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_sso_login.lookup", + skip_all, + fields( + db.statement, + compat_sso_login.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatSsoLoginLookup, + r#" + SELECT compat_sso_login_id + , login_token + , redirect_uri + , created_at + , fulfilled_at + , exchanged_at + , compat_session_id + + FROM compat_sso_logins + WHERE compat_sso_login_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.compat_sso_login.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + login_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatSsoLoginLookup, + r#" + SELECT compat_sso_login_id + , login_token + , redirect_uri + , created_at + , fulfilled_at + , exchanged_at + , compat_session_id + + FROM compat_sso_logins + WHERE login_token = $1 + "#, + login_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.compat_sso_login.add", + skip_all, + fields( + db.statement, + compat_sso_login.id, + compat_sso_login.redirect_uri = %redirect_uri, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + login_token: String, + redirect_uri: Url, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO compat_sso_logins + (compat_sso_login_id, login_token, redirect_uri, created_at) + VALUES ($1, $2, $3, $4) + "#, + Uuid::from(id), + &login_token, + redirect_uri.as_str(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatSsoLogin { + id, + login_token, + redirect_uri, + created_at, + state: CompatSsoLoginState::default(), + }) + } + + #[tracing::instrument( + name = "db.compat_sso_login.fulfill", + skip_all, + fields( + db.statement, + %compat_sso_login.id, + %compat_session.id, + compat_session.device.id = compat_session.device.as_str(), + user.id = %compat_session.user_id, + ), + err, + )] + async fn fulfill( + &mut self, + clock: &Clock, + compat_sso_login: CompatSsoLogin, + compat_session: &CompatSession, + ) -> Result { + let fulfilled_at = clock.now(); + let compat_sso_login = compat_sso_login + .fulfill(fulfilled_at, compat_session) + .map_err(DatabaseError::to_invalid_operation)?; + + let res = sqlx::query!( + r#" + UPDATE compat_sso_logins + SET + compat_session_id = $2, + fulfilled_at = $3 + WHERE + compat_sso_login_id = $1 + "#, + Uuid::from(compat_sso_login.id), + Uuid::from(compat_session.id), + fulfilled_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(compat_sso_login) + } + + #[tracing::instrument( + name = "db.compat_sso_login.exchange", + skip_all, + fields( + db.statement, + %compat_sso_login.id, + ), + err, + )] + async fn exchange( + &mut self, + clock: &Clock, + compat_sso_login: CompatSsoLogin, + ) -> Result { + let exchanged_at = clock.now(); + let compat_sso_login = compat_sso_login + .exchange(exchanged_at) + .map_err(DatabaseError::to_invalid_operation)?; + + let res = sqlx::query!( + r#" + UPDATE compat_sso_logins + SET + exchanged_at = $2 + WHERE + compat_sso_login_id = $1 + "#, + Uuid::from(compat_sso_login.id), + exchanged_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(compat_sso_login) + } + + #[tracing::instrument( + name = "db.compat_sso_login.list_paginated", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + ), + err + )] + async fn list_paginated( + &mut self, + user: &User, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT cl.compat_sso_login_id + , cl.login_token + , cl.redirect_uri + , cl.created_at + , cl.fulfilled_at + , cl.exchanged_at + , cl.compat_session_id + + FROM compat_sso_logins cl + INNER JOIN compat_sessions ON compat_session_id + "#, + ); + + query + .push(" WHERE user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("cl.compat_sso_login_id", before, after, first, last)?; + + let page: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?; + + let edges: Result, _> = edges.into_iter().map(TryInto::try_into).collect(); + Ok(Page { + has_next_page, + has_previous_page, + edges: edges?, + }) + } +} diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 8eda57016..ddd6e1ea8 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -15,6 +15,10 @@ use sqlx::{PgConnection, Postgres, Transaction}; use crate::{ + compat::{ + PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository, + PgCompatSsoLoginRepository, + }, oauth2::{PgOAuth2ClientRepository, PgOAuth2SessionRepository}, upstream_oauth2::{ PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, @@ -63,6 +67,22 @@ pub trait Repository { where Self: 'c; + type CompatSessionRepository<'c> + where + Self: 'c; + + type CompatSsoLoginRepository<'c> + where + Self: 'c; + + type CompatAccessTokenRepository<'c> + where + Self: 'c; + + type CompatRefreshTokenRepository<'c> + where + Self: 'c; + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>; fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>; @@ -72,6 +92,10 @@ pub trait Repository { fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>; fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>; fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_>; + fn compat_session(&mut self) -> Self::CompatSessionRepository<'_>; + fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_>; + fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_>; + fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_>; } impl Repository for PgConnection { @@ -84,6 +108,10 @@ impl Repository for PgConnection { type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; + type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c; + type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c; + type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c; + type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -120,6 +148,22 @@ impl Repository for PgConnection { fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { PgOAuth2SessionRepository::new(self) } + + fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { + PgCompatSessionRepository::new(self) + } + + fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> { + PgCompatSsoLoginRepository::new(self) + } + + fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> { + PgCompatAccessTokenRepository::new(self) + } + + fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> { + PgCompatRefreshTokenRepository::new(self) + } } impl<'t> Repository for Transaction<'t, Postgres> { @@ -132,6 +176,10 @@ impl<'t> Repository for Transaction<'t, Postgres> { type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; + type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c; + type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c; + type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c; + type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -168,4 +216,20 @@ impl<'t> Repository for Transaction<'t, Postgres> { fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { PgOAuth2SessionRepository::new(self) } + + fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { + PgCompatSessionRepository::new(self) + } + + fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> { + PgCompatSsoLoginRepository::new(self) + } + + fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> { + PgCompatAccessTokenRepository::new(self) + } + + fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> { + PgCompatRefreshTokenRepository::new(self) + } } From 8e5b3e46eafbe0f72a1eef65c88de8eb33f1a22f Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 12 Jan 2023 18:26:04 +0100 Subject: [PATCH 19/45] storage: remaining oauth2 repositories - authorization grants - access tokens - refresh tokens --- crates/axum-utils/src/user_authorization.rs | 6 +- .../src/oauth2/authorization_grant.rs | 26 +- .../src/oauth2/authorization/complete.rs | 20 +- .../handlers/src/oauth2/authorization/mod.rs | 36 +- crates/handlers/src/oauth2/consent.rs | 34 +- crates/handlers/src/oauth2/introspection.rs | 13 +- crates/handlers/src/oauth2/token.rs | 84 +- crates/handlers/src/views/shared.rs | 6 +- crates/storage/sqlx-data.json | 1077 +++++++++-------- crates/storage/src/oauth2/access_token.rs | 362 +++--- .../storage/src/oauth2/authorization_grant.rs | 781 ++++++------ crates/storage/src/oauth2/client.rs | 113 +- crates/storage/src/oauth2/consent.rs | 110 -- crates/storage/src/oauth2/mod.rs | 10 +- crates/storage/src/oauth2/refresh_token.rs | 329 +++-- crates/storage/src/repository.rs | 50 +- crates/tasks/src/database.rs | 9 +- 17 files changed, 1700 insertions(+), 1366 deletions(-) delete mode 100644 crates/storage/src/oauth2/consent.rs diff --git a/crates/axum-utils/src/user_authorization.rs b/crates/axum-utils/src/user_authorization.rs index a76e1e9a1..ec60103da 100644 --- a/crates/axum-utils/src/user_authorization.rs +++ b/crates/axum-utils/src/user_authorization.rs @@ -29,7 +29,7 @@ use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, Header use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode}; use mas_data_model::Session; use mas_storage::{ - oauth2::{access_token::find_access_token, OAuth2SessionRepository}, + oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository}, DatabaseError, Repository, }; use serde::{de::DeserializeOwned, Deserialize}; @@ -62,7 +62,9 @@ impl AccessToken { AccessToken::None => return Err(AuthorizationVerificationError::MissingToken), }; - let token = find_access_token(conn, token.as_str()) + let token = conn + .oauth2_access_token() + .find_by_token(token.as_str()) .await? .ok_or(AuthorizationVerificationError::InvalidToken)?; diff --git a/crates/data-model/src/oauth2/authorization_grant.rs b/crates/data-model/src/oauth2/authorization_grant.rs index 10f619c71..76572f489 100644 --- a/crates/data-model/src/oauth2/authorization_grant.rs +++ b/crates/data-model/src/oauth2/authorization_grant.rs @@ -78,7 +78,7 @@ impl AuthorizationGrantStage { Self::Pending } - pub fn fulfill( + fn fulfill( self, fulfilled_at: DateTime, session: &Session, @@ -92,7 +92,7 @@ impl AuthorizationGrantStage { } } - pub fn exchange(self, exchanged_at: DateTime) -> Result { + fn exchange(self, exchanged_at: DateTime) -> Result { match self { Self::Fulfilled { fulfilled_at, @@ -106,7 +106,7 @@ impl AuthorizationGrantStage { } } - pub fn cancel(self, cancelled_at: DateTime) -> Result { + fn cancel(self, cancelled_at: DateTime) -> Result { match self { Self::Pending => Ok(Self::Cancelled { cancelled_at }), _ => Err(InvalidTransitionError), @@ -146,4 +146,24 @@ impl AuthorizationGrant { let max_age: Option = self.max_age.map(|x| x.get().into()); self.created_at - Duration::seconds(max_age.unwrap_or(3600 * 24 * 365)) } + + pub fn exchange(mut self, exchanged_at: DateTime) -> Result { + self.stage = self.stage.exchange(exchanged_at)?; + Ok(self) + } + + pub fn fulfill( + mut self, + fulfilled_at: DateTime, + session: &Session, + ) -> Result { + self.stage = self.stage.fulfill(fulfilled_at, session)?; + Ok(self) + } + + // TODO: this is not used? + pub fn cancel(mut self, canceld_at: DateTime) -> Result { + self.stage = self.stage.cancel(canceld_at)?; + Ok(self) + } } diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index b5cfd6ffc..9f462c502 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -26,11 +26,7 @@ use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ - oauth2::{ - authorization_grant::{fulfill_grant, get_grant_by_id}, - consent::fetch_client_consent, - OAuth2ClientRepository, OAuth2SessionRepository, - }, + oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository}, Repository, }; use mas_templates::Templates; @@ -94,7 +90,9 @@ pub(crate) async fn get( let maybe_session = session_info.load_session(&mut txn).await?; - let grant = get_grant_by_id(&mut txn, grant_id) + let grant = txn + .oauth2_authorization_grant() + .lookup(grant_id) .await? .ok_or(RouteError::NotFound)?; @@ -192,7 +190,10 @@ pub(crate) async fn complete( .await? .ok_or(GrantCompletionError::NoSuchClient)?; - let current_consent = fetch_client_consent(&mut txn, &browser_session.user, &client).await?; + let current_consent = txn + .oauth2_client() + .get_consent_for_user(&client, &browser_session.user) + .await?; let lacks_consent = grant .scope @@ -211,7 +212,10 @@ pub(crate) async fn complete( .create_from_grant(&mut rng, &clock, &grant, &browser_session) .await?; - let grant = fulfill_grant(&mut txn, grant, session.clone()).await?; + let grant = txn + .oauth2_authorization_grant() + .fulfill(&clock, &session, grant) + .await?; // Yep! Let's complete the auth now let mut params = AuthorizationResponse::default(); diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index faf7015db..b33b69129 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -26,7 +26,7 @@ use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ - oauth2::{authorization_grant::new_authorization_grant, OAuth2ClientRepository}, + oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, Repository, }; use mas_templates::Templates; @@ -275,23 +275,23 @@ pub(crate) async fn get( let requires_consent = prompt.contains(&Prompt::Consent); - let grant = new_authorization_grant( - &mut txn, - &mut rng, - &clock, - client, - redirect_uri.clone(), - params.auth.scope, - code, - params.auth.state.clone(), - params.auth.nonce, - params.auth.max_age, - None, - response_mode, - response_type.has_id_token(), - requires_consent, - ) - .await?; + let grant = txn + .oauth2_authorization_grant() + .add( + &mut rng, + &clock, + &client, + redirect_uri.clone(), + params.auth.scope, + code, + params.auth.state.clone(), + params.auth.nonce, + params.auth.max_age, + response_mode, + response_type.has_id_token(), + requires_consent, + ) + .await?; let continue_grant = PostAuthAction::continue_grant(grant.id); let res = match maybe_session { diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index 64b72d58e..94bf1346a 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -29,11 +29,7 @@ use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ - oauth2::{ - authorization_grant::{get_grant_by_id, give_consent_to_grant}, - consent::insert_client_consent, - OAuth2ClientRepository, - }, + oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, Repository, }; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; @@ -91,7 +87,9 @@ pub(crate) async fn get( let maybe_session = session_info.load_session(&mut conn).await?; - let grant = get_grant_by_id(&mut conn, grant_id) + let grant = conn + .oauth2_authorization_grant() + .lookup(grant_id) .await? .ok_or(RouteError::GrantNotFound)?; @@ -146,7 +144,9 @@ pub(crate) async fn post( let maybe_session = session_info.load_session(&mut txn).await?; - let grant = get_grant_by_id(&mut txn, grant_id) + let grant = txn + .oauth2_authorization_grant() + .lookup(grant_id) .await? .ok_or(RouteError::GrantNotFound)?; let next = PostAuthAction::continue_grant(grant_id); @@ -180,17 +180,17 @@ pub(crate) async fn post( .filter(|s| !s.starts_with("urn:matrix:org.matrix.msc2967.client:device:")) .cloned() .collect(); - insert_client_consent( - &mut txn, - &mut rng, - &clock, - &session.user, - &client, - &scope_without_device, - ) - .await?; + txn.oauth2_client() + .give_consent_for_user( + &mut rng, + &clock, + &client, + &session.user, + &scope_without_device, + ) + .await?; - let _grant = give_consent_to_grant(&mut txn, grant).await?; + txn.oauth2_authorization_grant().give_consent(grant).await?; txn.commit().await?; diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index ef6ba5b21..d032695ae 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -23,10 +23,7 @@ use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint}; use mas_keystore::Encrypter; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, - oauth2::{ - access_token::find_access_token, refresh_token::lookup_refresh_token, - OAuth2SessionRepository, - }, + oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository}, user::{BrowserSessionRepository, UserRepository}, Clock, Repository, }; @@ -169,7 +166,9 @@ pub(crate) async fn post( let reply = match token_type { TokenType::AccessToken => { - let token = find_access_token(&mut conn, token) + let token = conn + .oauth2_access_token() + .find_by_token(token) .await? .filter(|t| t.is_valid(clock.now())) .ok_or(RouteError::UnknownToken)?; @@ -206,7 +205,9 @@ pub(crate) async fn post( } TokenType::RefreshToken => { - let token = lookup_refresh_token(&mut conn, token) + let token = conn + .oauth2_refresh_token() + .find_by_token(token) .await? .filter(|t| t.is_valid()) .ok_or(RouteError::UnknownToken)?; diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 75ddb4a64..97f249c27 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -33,10 +33,8 @@ use mas_keystore::{Encrypter, Keystore}; use mas_router::UrlBuilder; use mas_storage::{ oauth2::{ - access_token::{add_access_token, lookup_access_token, revoke_access_token}, - authorization_grant::{exchange_grant, lookup_grant_by_code}, - refresh_token::{add_refresh_token, consume_refresh_token, lookup_refresh_token}, - OAuth2SessionRepository, + OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, + OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, user::BrowserSessionRepository, Repository, @@ -217,9 +215,9 @@ async fn authorization_code_grant( ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - // TODO: there is a bunch of unnecessary cloning here - // TODO: handle "not found" cases - let authz_grant = lookup_grant_by_code(&mut txn, &grant.code) + let authz_grant = txn + .oauth2_authorization_grant() + .find_by_code(&grant.code) .await? .ok_or(RouteError::GrantNotFound)?; @@ -301,18 +299,15 @@ async fn authorization_code_grant( let access_token_str = TokenType::AccessToken.generate(&mut rng); let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); - let access_token = - add_access_token(&mut txn, &mut rng, &clock, &session, access_token_str, ttl).await?; + let access_token = txn + .oauth2_access_token() + .add(&mut rng, &clock, &session, access_token_str, ttl) + .await?; - let refresh_token = add_refresh_token( - &mut txn, - &mut rng, - &clock, - &session, - &access_token, - refresh_token_str, - ) - .await?; + let refresh_token = txn + .oauth2_refresh_token() + .add(&mut rng, &clock, &session, &access_token, refresh_token_str) + .await?; let id_token = if session.scope.contains(&scope::OPENID) { let mut claims = HashMap::new(); @@ -360,7 +355,9 @@ async fn authorization_code_grant( params = params.with_id_token(id_token); } - exchange_grant(&mut txn, &clock, authz_grant).await?; + txn.oauth2_authorization_grant() + .exchange(&clock, authz_grant) + .await?; txn.commit().await?; @@ -374,7 +371,9 @@ async fn refresh_token_grant( ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let refresh_token = lookup_refresh_token(&mut txn, &grant.refresh_token) + let refresh_token = txn + .oauth2_refresh_token() + .find_by_token(&grant.refresh_token) .await? .ok_or(RouteError::InvalidGrant)?; @@ -397,31 +396,32 @@ async fn refresh_token_grant( let access_token_str = TokenType::AccessToken.generate(&mut rng); let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); - let new_access_token = add_access_token( - &mut txn, - &mut rng, - &clock, - &session, - access_token_str.clone(), - ttl, - ) - .await?; + let new_access_token = txn + .oauth2_access_token() + .add(&mut rng, &clock, &session, access_token_str.clone(), ttl) + .await?; - let new_refresh_token = add_refresh_token( - &mut txn, - &mut rng, - &clock, - &session, - &new_access_token, - refresh_token_str, - ) - .await?; + let new_refresh_token = txn + .oauth2_refresh_token() + .add( + &mut rng, + &clock, + &session, + &new_access_token, + refresh_token_str, + ) + .await?; - let refresh_token = consume_refresh_token(&mut txn, &clock, refresh_token).await?; + let refresh_token = txn + .oauth2_refresh_token() + .consume(&clock, refresh_token) + .await?; if let Some(access_token_id) = refresh_token.access_token_id { - if let Some(access_token) = lookup_access_token(&mut txn, access_token_id).await? { - revoke_access_token(&mut txn, &clock, access_token).await?; + if let Some(access_token) = txn.oauth2_access_token().lookup(access_token_id).await? { + txn.oauth2_access_token() + .revoke(&clock, access_token) + .await?; } } diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index 3872588f0..57d537628 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -15,7 +15,7 @@ use anyhow::Context; use mas_router::{PostAuthAction, Route}; use mas_storage::{ - compat::CompatSsoLoginRepository, oauth2::authorization_grant::get_grant_by_id, + compat::CompatSsoLoginRepository, oauth2::OAuth2AuthorizationGrantRepository, upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository, }; use mas_templates::{PostAuthContext, PostAuthContextInner}; @@ -46,7 +46,9 @@ impl OptionalPostAuthAction { let Some(action) = self.post_auth_action.clone() else { return Ok(None) }; let ctx = match action { PostAuthAction::ContinueAuthorizationGrant { id } => { - let grant = get_grant_by_id(conn, id) + let grant = conn + .oauth2_authorization_grant() + .lookup(id) .await? .context("Failed to load authorization grant")?; let grant = Box::new(grant); diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 6e7082dd8..5dd182501 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -1,5 +1,19 @@ { "db": "PostgreSQL", + "015f7ad7c8d5403ce4dfb71d598fd9af472689d5aef7c1c4b1c594ca57c02237": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz", + "Uuid" + ] + } + }, + "query": "\n UPDATE oauth2_authorization_grants\n SET fulfilled_at = $2\n , oauth2_session_id = $3\n WHERE oauth2_authorization_grant_id = $1\n " + }, "08d7df347c806ef14b6d0fb031cab041d79ba48528420160e23286369db7af35": { "describe": { "columns": [ @@ -113,6 +127,18 @@ }, "query": "\n INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)\n VALUES ($1, $2, $3, $4)\n " }, + "1a8701f5672de052bb766933f60b93249acc7237b996e8b93cd61b9f69c902ff": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Timestamptz" + ] + } + }, + "query": "\n DELETE FROM oauth2_access_tokens\n WHERE expires_at < $1\n " + }, "1d372f36c382ab16264cea54537af3544ea6d6d75d10b432b07dbd0dadd2fa4e": { "describe": { "columns": [ @@ -183,7 +209,7 @@ }, "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n " }, - "262bee715889dc3e608639549600a131e641951ff979634e7c97afc74bbc1605": { + "2564bf6366eb59268c41fb25bb40d0e4e9e1fd1f9ea53b7a359c9025d7304223": { "describe": { "columns": [], "nullable": [], @@ -194,7 +220,7 @@ ] } }, - "query": "\n UPDATE oauth2_authorization_grants\n SET exchanged_at = $2\n WHERE oauth2_authorization_grant_id = $1\n " + "query": "\n UPDATE oauth2_access_tokens\n SET revoked_at = $2\n WHERE oauth2_access_token_id = $1\n " }, "3d66f3121b11ce923b9c60609b510a8ca899640e78cc8f5b03168622928ffe94": { "describe": { @@ -360,22 +386,6 @@ }, "query": "\n SELECT compat_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , compat_session_id\n , compat_access_token_id\n\n FROM compat_refresh_tokens\n\n WHERE refresh_token = $1\n " }, - "43a5cafbdc8037e9fb779812a0793cf0859902aa0dc8d25d4c33d231d3d1118b": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Timestamptz", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO oauth2_access_tokens\n (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)\n VALUES\n ($1, $2, $3, $4, $5)\n " - }, "446a8d7bd8532a751810401adfab924dc20785c91770ed43d62df2e590e8da71": { "describe": { "columns": [ @@ -420,26 +430,55 @@ }, "query": "\n SELECT up.user_password_id\n , up.hashed_password\n , up.version\n , up.upgraded_from_id\n , up.created_at\n FROM user_passwords up\n WHERE up.user_id = $1\n ORDER BY up.created_at DESC\n LIMIT 1\n " }, - "46c5ae7052504bfd7b94f20e61b9cf92570779a794bccda23dd654fb8523f340": { + "477f79556e5777b38feb85013b4f04dbb8230e4b0b0bcc45f669d7b8d0b91db4": { "describe": { "columns": [ { - "name": "fulfilled_at!: DateTime", + "name": "oauth2_access_token_id", "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "revoked_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_session_id", + "ordinal": 5, + "type_info": "Uuid" } ], "nullable": [ - true + false, + false, + false, + false, + true, + false ], "parameters": { "Left": [ - "Uuid", - "Uuid" + "Text" ] } }, - "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n oauth2_session_id = os.oauth2_session_id,\n fulfilled_at = os.created_at\n FROM oauth2_sessions os\n WHERE\n og.oauth2_authorization_grant_id = $1\n AND os.oauth2_session_id = $2\n RETURNING fulfilled_at AS \"fulfilled_at!: DateTime\"\n " + "query": "\n SELECT oauth2_access_token_id\n , access_token\n , created_at\n , expires_at\n , revoked_at\n , oauth2_session_id\n\n FROM oauth2_access_tokens\n\n WHERE access_token = $1\n " }, "478f0ad710da8bfd803c6cddd982bc504d1b6bd0f5283de53c8c7b1b4b7dafd4": { "describe": { @@ -497,6 +536,134 @@ }, "query": "\n SELECT compat_sso_login_id\n , login_token\n , redirect_uri\n , created_at\n , fulfilled_at\n , exchanged_at\n , compat_session_id\n\n FROM compat_sso_logins\n WHERE login_token = $1\n " }, + "496813daf6f8486353e7f509a64362626daebb0121c3c9420b96e2d8157f1e07": { + "describe": { + "columns": [ + { + "name": "oauth2_authorization_grant_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "created_at", + "ordinal": 1, + "type_info": "Timestamptz" + }, + { + "name": "cancelled_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "fulfilled_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "exchanged_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "scope", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "state", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "redirect_uri", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "response_mode", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "nonce", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "max_age", + "ordinal": 10, + "type_info": "Int4" + }, + { + "name": "oauth2_client_id", + "ordinal": 11, + "type_info": "Uuid" + }, + { + "name": "authorization_code", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "response_type_code", + "ordinal": 13, + "type_info": "Bool" + }, + { + "name": "response_type_id_token", + "ordinal": 14, + "type_info": "Bool" + }, + { + "name": "code_challenge", + "ordinal": 15, + "type_info": "Text" + }, + { + "name": "code_challenge_method", + "ordinal": 16, + "type_info": "Text" + }, + { + "name": "requires_consent", + "ordinal": 17, + "type_info": "Bool" + }, + { + "name": "oauth2_session_id", + "ordinal": 18, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + true, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT oauth2_authorization_grant_id\n , created_at\n , cancelled_at\n , fulfilled_at\n , exchanged_at\n , scope\n , state\n , redirect_uri\n , response_mode\n , nonce\n , max_age\n , oauth2_client_id\n , authorization_code\n , response_type_code\n , response_type_id_token\n , code_challenge\n , code_challenge_method\n , requires_consent\n , oauth2_session_id\n FROM\n oauth2_authorization_grants\n\n WHERE oauth2_authorization_grant_id = $1\n " + }, "4d79ce892e4595edb8b801e94fb0cbef28facdfd2e45d1c72c57f47418fbe24b": { "describe": { "columns": [], @@ -511,27 +678,6 @@ }, "query": "\n UPDATE compat_sso_logins\n SET\n compat_session_id = $2,\n fulfilled_at = $3\n WHERE\n compat_sso_login_id = $1\n " }, - "51158bfcaa1a8d8e051bffe7c5ba0369bf53fb162f7622626054e89e68fc07bd": { - "describe": { - "columns": [ - { - "name": "scope_token", - "ordinal": 0, - "type_info": "Text" - } - ], - "nullable": [ - false - ], - "parameters": { - "Left": [ - "Uuid", - "Uuid" - ] - } - }, - "query": "\n SELECT scope_token\n FROM oauth2_consents\n WHERE user_id = $1 AND oauth2_client_id = $2\n " - }, "53ad718642644b47a2d49f768d81bd993088526923769a9147281686c2d47591": { "describe": { "columns": [ @@ -592,18 +738,6 @@ }, "query": "\n INSERT INTO oauth2_sessions\n ( oauth2_session_id\n , user_session_id\n , oauth2_client_id\n , scope\n , created_at\n )\n VALUES ($1, $2, $3, $4, $5)\n " }, - "5b5d5c82da37c6f2d8affacfb02119965c04d1f2a9cc53dbf5bd4c12584969a0": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Timestamptz" - ] - } - }, - "query": "\n DELETE FROM oauth2_access_tokens\n WHERE expires_at < $1\n " - }, "5f6b7e38ef9bc3b39deabba277d0255fb8cfb2adaa65f47b78a8fac11d8c91c3": { "describe": { "columns": [], @@ -619,22 +753,6 @@ }, "query": "\n INSERT INTO upstream_oauth_links (\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n ) VALUES ($1, $2, NULL, $3, $4)\n " }, - "62d05e8e4317bdb180298737d422e64d161c5ad3813913a6f7d67a53569ea76a": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "UuidArray", - "Uuid", - "Uuid", - "TextArray", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO oauth2_consents\n (oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)\n SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)\n ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5\n " - }, "64e6ea47c2e877c1ebe4338d64d9ad8a6c1c777d1daea024b8ca2e7f0dd75b0f": { "describe": { "columns": [], @@ -739,18 +857,133 @@ }, "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET consumed_at = $1\n WHERE upstream_oauth_authorization_session_id = $2\n " }, - "6bf0da5ba3dd07b499193a2e0ddeea6e712f9df8f7f28874ff56a952a9f10e54": { + "6a3b543ec53ce242866d1e84de26728e6dd275cae745f9c646e3824d859c5384": { "describe": { - "columns": [], - "nullable": [], + "columns": [ + { + "name": "oauth2_authorization_grant_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "created_at", + "ordinal": 1, + "type_info": "Timestamptz" + }, + { + "name": "cancelled_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "fulfilled_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "exchanged_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "scope", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "state", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "redirect_uri", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "response_mode", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "nonce", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "max_age", + "ordinal": 10, + "type_info": "Int4" + }, + { + "name": "oauth2_client_id", + "ordinal": 11, + "type_info": "Uuid" + }, + { + "name": "authorization_code", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "response_type_code", + "ordinal": 13, + "type_info": "Bool" + }, + { + "name": "response_type_id_token", + "ordinal": 14, + "type_info": "Bool" + }, + { + "name": "code_challenge", + "ordinal": 15, + "type_info": "Text" + }, + { + "name": "code_challenge_method", + "ordinal": 16, + "type_info": "Text" + }, + { + "name": "requires_consent", + "ordinal": 17, + "type_info": "Bool" + }, + { + "name": "oauth2_session_id", + "ordinal": 18, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + true, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true + ], "parameters": { "Left": [ - "Uuid", - "Timestamptz" + "Text" ] } }, - "query": "\n UPDATE oauth2_access_tokens\n SET revoked_at = $2\n WHERE oauth2_access_token_id = $1\n " + "query": "\n SELECT oauth2_authorization_grant_id\n , created_at\n , cancelled_at\n , fulfilled_at\n , exchanged_at\n , scope\n , state\n , redirect_uri\n , response_mode\n , nonce\n , max_age\n , oauth2_client_id\n , authorization_code\n , response_type_code\n , response_type_id_token\n , code_challenge\n , code_challenge_method\n , requires_consent\n , oauth2_session_id\n FROM\n oauth2_authorization_grants\n\n WHERE authorization_code = $1\n " }, "6e21e7d816f806da9bb5176931bdb550dee05c44c9d93f53df95fe3b4a840347": { "describe": { @@ -1091,18 +1324,26 @@ }, "query": "\n SELECT oauth2_client_id\n , encrypted_client_secret\n , ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\"\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n FROM oauth2_clients c\n\n WHERE oauth2_client_id = ANY($1::uuid[])\n " }, - "874e677f82c221c5bb621c12f293bcef4e70c68c87ec003fcd475bcb994b5a4c": { + "8b7297c263336d70c2b647212b16f7ae39bc5cb1572e3a2dcfcd67f196a1fa39": { "describe": { - "columns": [], - "nullable": [], + "columns": [ + { + "name": "scope_token", + "ordinal": 0, + "type_info": "Text" + } + ], + "nullable": [ + false + ], "parameters": { "Left": [ "Uuid", - "Timestamptz" + "Uuid" ] } }, - "query": "\n UPDATE oauth2_refresh_tokens\n SET consumed_at = $2\n WHERE oauth2_refresh_token_id = $1\n " + "query": "\n SELECT scope_token\n FROM oauth2_consents\n WHERE user_id = $1 AND oauth2_client_id = $2\n " }, "8f7a9fb1f24c24f8dbc3c193df2a742c9ac730ab958587b67297de2d4b843863": { "describe": { @@ -1240,6 +1481,22 @@ }, "query": "\n SELECT EXISTS(\n SELECT 1 FROM users WHERE username = $1\n ) AS \"exists!\"\n " }, + "9a6c197ff4ad80217262d48f8792ce7e16bc5df0677c7cd4ecb4fdbc5ee86395": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "UuidArray", + "Uuid", + "Uuid", + "TextArray", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO oauth2_consents\n (oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)\n SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)\n ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5\n " + }, "9f7bdc034c618e47e49c467d0d7f5b8c297d055abe248cc876dbc12c5a7dc920": { "describe": { "columns": [], @@ -1256,6 +1513,22 @@ }, "query": "\n INSERT INTO compat_refresh_tokens\n (compat_refresh_token_id, compat_session_id,\n compat_access_token_id, refresh_token, created_at)\n VALUES ($1, $2, $3, $4, $5)\n " }, + "a2f7433f06fb4f6a7ad5ac6c1db18705276bce41e9b19d5d7e910ad4b767fb5e": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO oauth2_refresh_tokens\n (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,\n refresh_token, created_at)\n VALUES\n ($1, $2, $3, $4, $5)\n " + }, "a300fe99c95679c5664646a6a525c0491829e97db45f3234483872ed38436322": { "describe": { "columns": [ @@ -1300,17 +1573,55 @@ }, "query": "\n SELECT user_email_id\n , user_id\n , email\n , created_at\n , confirmed_at\n FROM user_emails\n\n WHERE user_id = $1\n\n ORDER BY email ASC\n " }, - "a5a7dad633396e087239d5629092e4a305908ffce9c2610db07372f719070546": { + "a6fa7811d0a7c62c7cccff96dc82db5b25462fa7669fde1941ccab4712585b20": { "describe": { - "columns": [], - "nullable": [], + "columns": [ + { + "name": "oauth2_refresh_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_access_token_id", + "ordinal": 4, + "type_info": "Uuid" + }, + { + "name": "oauth2_session_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + true, + false + ], "parameters": { "Left": [ "Uuid" ] } }, - "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n requires_consent = 'f'\n WHERE\n og.oauth2_authorization_grant_id = $1\n " + "query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n FROM oauth2_refresh_tokens\n\n WHERE oauth2_refresh_token_id = $1\n " }, "a7f780528882a2ae66c45435215763eed0582264861436eab3f862e3eb12cab1": { "describe": { @@ -1328,134 +1639,6 @@ }, "query": "\n INSERT INTO compat_access_tokens\n (compat_access_token_id, compat_session_id, access_token, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n " }, - "aa2fd69c595f94d8598715766a79671dba8f87b9d7af6ac30e3fa1fbc8cce28a": { - "describe": { - "columns": [ - { - "name": "oauth2_authorization_grant_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_created_at", - "ordinal": 1, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_cancelled_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_fulfilled_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_exchanged_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_scope", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_state", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_redirect_uri", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_mode", - "ordinal": 8, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_nonce", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_max_age", - "ordinal": 10, - "type_info": "Int4" - }, - { - "name": "oauth2_client_id", - "ordinal": 11, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_code", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_type_code", - "ordinal": 13, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_response_type_id_token", - "ordinal": 14, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_code_challenge", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_code_challenge_method", - "ordinal": 16, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_requires_consent", - "ordinal": 17, - "type_info": "Bool" - }, - { - "name": "oauth2_session_id?", - "ordinal": 18, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - true, - true, - true, - false, - true, - false, - false, - true, - true, - false, - true, - false, - false, - true, - true, - false, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT oauth2_authorization_grant_id\n , created_at AS oauth2_authorization_grant_created_at\n , cancelled_at AS oauth2_authorization_grant_cancelled_at\n , fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , exchanged_at AS oauth2_authorization_grant_exchanged_at\n , scope AS oauth2_authorization_grant_scope\n , state AS oauth2_authorization_grant_state\n , redirect_uri AS oauth2_authorization_grant_redirect_uri\n , response_mode AS oauth2_authorization_grant_response_mode\n , nonce AS oauth2_authorization_grant_nonce\n , max_age AS oauth2_authorization_grant_max_age\n , oauth2_client_id AS oauth2_client_id\n , authorization_code AS oauth2_authorization_grant_code\n , response_type_code AS oauth2_authorization_grant_response_type_code\n , response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , code_challenge AS oauth2_authorization_grant_code_challenge\n , code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , requires_consent AS oauth2_authorization_grant_requires_consent\n , oauth2_session_id AS \"oauth2_session_id?\"\n FROM\n oauth2_authorization_grants\n\n WHERE authorization_code = $1\n " - }, "ab34912b42a48a8b5c8d63e271b99b7d0b690a2471873c6654b1b6cf2079b95c": { "describe": { "columns": [], @@ -1469,6 +1652,22 @@ }, "query": "\n UPDATE compat_sessions cs\n SET finished_at = $2\n WHERE compat_session_id = $1\n " }, + "afa86e79e3de2a83265cb0db8549d378a2f11b2a27bbd86d60558318c87eb698": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO oauth2_access_tokens\n (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)\n VALUES\n ($1, $2, $3, $4, $5)\n " + }, "aff08a8caabeb62f4929e6e901e7ca7c55e284c18c5c1d1e78821dd9bc961412": { "describe": { "columns": [ @@ -1514,184 +1713,6 @@ }, "query": "\n SELECT user_email_id\n , user_id\n , email\n , created_at\n , confirmed_at\n FROM user_emails\n\n WHERE user_id = $1 AND email = $2\n " }, - "b12f7ba71ad522261f54ffbb6739a7a06214b4f01e3ed6f7fdaa2033d249f3fb": { - "describe": { - "columns": [ - { - "name": "oauth2_authorization_grant_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_created_at", - "ordinal": 1, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_cancelled_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_fulfilled_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_exchanged_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_scope", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_state", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_redirect_uri", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_mode", - "ordinal": 8, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_nonce", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_max_age", - "ordinal": 10, - "type_info": "Int4" - }, - { - "name": "oauth2_client_id", - "ordinal": 11, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_code", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_type_code", - "ordinal": 13, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_response_type_id_token", - "ordinal": 14, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_code_challenge", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_code_challenge_method", - "ordinal": 16, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_requires_consent", - "ordinal": 17, - "type_info": "Bool" - }, - { - "name": "oauth2_session_id?", - "ordinal": 18, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - true, - true, - true, - false, - true, - false, - false, - true, - true, - false, - true, - false, - false, - true, - true, - false, - true - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT oauth2_authorization_grant_id\n , created_at AS oauth2_authorization_grant_created_at\n , cancelled_at AS oauth2_authorization_grant_cancelled_at\n , fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , exchanged_at AS oauth2_authorization_grant_exchanged_at\n , scope AS oauth2_authorization_grant_scope\n , state AS oauth2_authorization_grant_state\n , redirect_uri AS oauth2_authorization_grant_redirect_uri\n , response_mode AS oauth2_authorization_grant_response_mode\n , nonce AS oauth2_authorization_grant_nonce\n , max_age AS oauth2_authorization_grant_max_age\n , oauth2_client_id AS oauth2_client_id\n , authorization_code AS oauth2_authorization_grant_code\n , response_type_code AS oauth2_authorization_grant_response_type_code\n , response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , code_challenge AS oauth2_authorization_grant_code_challenge\n , code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , requires_consent AS oauth2_authorization_grant_requires_consent\n , oauth2_session_id AS \"oauth2_session_id?\"\n FROM\n oauth2_authorization_grants\n\n WHERE oauth2_authorization_grant_id = $1\n " - }, - "b20e846843cf88810fbc0f4b0fa3159117f035841758d682d90c614c374f6059": { - "describe": { - "columns": [ - { - "name": "oauth2_access_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "revoked_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_session_id", - "ordinal": 5, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - false, - true, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT oauth2_access_token_id\n , access_token\n , created_at\n , expires_at\n , revoked_at\n , oauth2_session_id\n\n FROM oauth2_access_tokens\n\n WHERE oauth2_access_token_id = $1\n " - }, "b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537": { "describe": { "columns": [], @@ -1722,6 +1743,19 @@ }, "query": "\n INSERT INTO user_email_confirmation_codes\n (user_email_confirmation_code_id, user_email_id, code, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n " }, + "b6a6f5386dc89e4bc2ce56d578a29341848fce336d339b6bbf425956f5ed5032": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE oauth2_refresh_tokens\n SET consumed_at = $2\n WHERE oauth2_refresh_token_id = $1\n " + }, "b700dc3f7d0f86f4904725d8357e34b7e457f857ed37c467c314142877fd5367": { "describe": { "columns": [], @@ -1793,21 +1827,7 @@ }, "query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , token_endpoint_auth_method\n , jwks\n , jwks_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7)\n ON CONFLICT (oauth2_client_id)\n DO\n UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret\n , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code\n , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token\n , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method\n , jwks = EXCLUDED.jwks\n , jwks_uri = EXCLUDED.jwks_uri\n " }, - "c1d90a7f2287ec779c81a521fab19e5ede3fa95484033e0312c30d9b6ecc03f0": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO user_sessions (user_session_id, user_id, created_at)\n VALUES ($1, $2, $3)\n " - }, - "c88376abdba124ff0487a9a69d2345c7d69d7394f355111ec369cfa6d45fb40f": { + "c0ed9d70e496433d8686a499055d8a8376459109b6154a2c0c13b28462afa523": { "describe": { "columns": [], "nullable": [], @@ -1831,7 +1851,34 @@ ] } }, - "query": "\n INSERT INTO oauth2_authorization_grants (\n oauth2_authorization_grant_id,\n oauth2_client_id,\n redirect_uri,\n scope,\n state,\n nonce,\n max_age,\n response_mode,\n code_challenge,\n code_challenge_method,\n response_type_code,\n response_type_id_token,\n authorization_code,\n requires_consent,\n created_at\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)\n " + "query": "\n INSERT INTO oauth2_authorization_grants (\n oauth2_authorization_grant_id,\n oauth2_client_id,\n redirect_uri,\n scope,\n state,\n nonce,\n max_age,\n response_mode,\n code_challenge,\n code_challenge_method,\n response_type_code,\n response_type_id_token,\n authorization_code,\n requires_consent,\n created_at\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)\n " + }, + "c1d90a7f2287ec779c81a521fab19e5ede3fa95484033e0312c30d9b6ecc03f0": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO user_sessions (user_session_id, user_id, created_at)\n VALUES ($1, $2, $3)\n " + }, + "c5e7dbb22488aca427b85b3415bd1f1a1766ff865f2e08a5daa095d2a1ccbd56": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE oauth2_authorization_grants\n SET exchanged_at = $2\n WHERE oauth2_authorization_grant_id = $1\n " }, "d0b403e9c843ef19fa5ad60bec32ebf14a1ba0d01681c3836366d3f55e7851f4": { "describe": { @@ -1866,121 +1913,17 @@ }, "query": "\n SELECT COUNT(*)\n FROM user_emails\n WHERE user_id = $1\n " }, - "d1f1aac41bb2f0d194f9b3d846663c267865d0d22dd5fa8a668daf29dca88d36": { - "describe": { - "columns": [ - { - "name": "oauth2_refresh_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "refresh_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "consumed_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_access_token_id", - "ordinal": 4, - "type_info": "Uuid" - }, - { - "name": "oauth2_session_id", - "ordinal": 5, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - true, - true, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n FROM oauth2_refresh_tokens\n\n WHERE refresh_token = $1\n " - }, - "d2b1af24f88b2f05eb219f7cbdcfa9680bafe9f77fa1772097875b3718bd1aff": { - "describe": { - "columns": [ - { - "name": "oauth2_access_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "revoked_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_session_id", - "ordinal": 5, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - false, - true, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT oauth2_access_token_id\n , access_token\n , created_at\n , expires_at\n , revoked_at\n , oauth2_session_id\n\n FROM oauth2_access_tokens\n\n WHERE access_token = $1\n " - }, - "d8677b3b6ee594c230fad98c1aa1c6e3d983375bf5b701c7b52468e7f906abf9": { + "d83421d4a16f4ad084dd0db5abb56d3688851c36a48a50aa6104e8291e73630d": { "describe": { "columns": [], "nullable": [], "parameters": { "Left": [ - "Uuid", - "Uuid", - "Uuid", - "Text", - "Timestamptz" + "Uuid" ] } }, - "query": "\n INSERT INTO oauth2_refresh_tokens\n (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,\n refresh_token, created_at)\n VALUES\n ($1, $2, $3, $4, $5)\n " + "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n requires_consent = 'f'\n WHERE\n og.oauth2_authorization_grant_id = $1\n " }, "db90cbc406a399f5447bd2c1d8018464f83b927dec620353516c0285b76fcf24": { "describe": { @@ -2111,6 +2054,56 @@ }, "query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n " }, + "dd16942318bf38d9a245b2c86fedd3cbd6b65e7a13465552d79cd3c022122fd4": { + "describe": { + "columns": [ + { + "name": "oauth2_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "revoked_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_session_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT oauth2_access_token_id\n , access_token\n , created_at\n , expires_at\n , revoked_at\n , oauth2_session_id\n\n FROM oauth2_access_tokens\n\n WHERE oauth2_access_token_id = $1\n " + }, "ddb22dd9ae9367af65a607e1fdc48b3d9581d67deea0c168f24e02090082bb82": { "describe": { "columns": [ @@ -2262,6 +2255,56 @@ }, "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " }, + "e709869c062ac50248b1f9f8f808cc2f5e7bef58a6c2e42a7bb0c1cb8f508671": { + "describe": { + "columns": [ + { + "name": "oauth2_refresh_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_access_token_id", + "ordinal": 4, + "type_info": "Uuid" + }, + { + "name": "oauth2_session_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + true, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n FROM oauth2_refresh_tokens\n\n WHERE refresh_token = $1\n " + }, "f0ace1af3775192a555c4ebb59b81183f359771f9f77e5fad759d38d872541d1": { "describe": { "columns": [ diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index 8389dff44..db10ed72e 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,67 +12,61 @@ // See the License for the specific language governing permissions and // limitations under the License. +use async_trait::async_trait; use chrono::{DateTime, Duration, Utc}; use mas_data_model::{AccessToken, AccessTokenState, Session}; -use rand::Rng; -use sqlx::{PgConnection, PgExecutor}; +use rand::RngCore; +use sqlx::PgConnection; use ulid::Ulid; use uuid::Uuid; -use crate::{Clock, DatabaseError, LookupResultExt}; +use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; -#[tracing::instrument( - skip_all, - fields( - %session.id, - user_session.id = %session.user_session_id, - client.id = %session.client_id, - access_token.id, - ), - err, -)] -pub async fn add_access_token( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - session: &Session, - access_token: String, - expires_after: Duration, -) -> Result { - let created_at = clock.now(); - let expires_at = created_at + expires_after; - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); +#[async_trait] +pub trait OAuth2AccessTokenRepository: Send + Sync { + type Error; - tracing::Span::current().record("access_token.id", tracing::field::display(id)); + /// Lookup an access token by its ID + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; - sqlx::query!( - r#" - INSERT INTO oauth2_access_tokens - (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at) - VALUES - ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(session.id), - &access_token, - created_at, - expires_at, - ) - .execute(executor) - .await?; + /// Find an access token by its token + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error>; - Ok(AccessToken { - id, - state: AccessTokenState::default(), - access_token, - session_id: session.id, - created_at, - expires_at, - }) + /// Add a new access token to the database + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + session: &Session, + access_token: String, + expires_after: Duration, + ) -> Result; + + /// Revoke an access token + async fn revoke( + &mut self, + clock: &Clock, + access_token: AccessToken, + ) -> Result; + + /// Cleanup expired access tokens + async fn cleanup_expired(&mut self, clock: &Clock) -> Result; } -#[derive(Debug)] -pub struct OAuth2AccessTokenLookup { +pub struct PgOAuth2AccessTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2AccessTokenRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct OAuth2AccessTokenLookup { oauth2_access_token_id: Uuid, oauth2_session_id: Uuid, access_token: String, @@ -99,118 +93,164 @@ impl From for AccessToken { } } -#[tracing::instrument(skip_all, err)] -pub async fn find_access_token( - conn: &mut PgConnection, - token: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - OAuth2AccessTokenLookup, - r#" - SELECT oauth2_access_token_id - , access_token - , created_at - , expires_at - , revoked_at - , oauth2_session_id +#[async_trait] +impl<'c> OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'c> { + type Error = DatabaseError; - FROM oauth2_access_tokens + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2AccessTokenLookup, + r#" + SELECT oauth2_access_token_id + , access_token + , created_at + , expires_at + , revoked_at + , oauth2_session_id - WHERE access_token = $1 - "#, - token, - ) - .fetch_one(&mut *conn) - .await - .to_option()?; + FROM oauth2_access_tokens - let Some(res) = res else { return Ok(None) }; + WHERE oauth2_access_token_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; - Ok(Some(res.into())) -} - -#[tracing::instrument( - skip_all, - fields(access_token.id = %access_token_id), - err, -)] -pub async fn lookup_access_token( - conn: &mut PgConnection, - access_token_id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - OAuth2AccessTokenLookup, - r#" - SELECT oauth2_access_token_id - , access_token - , created_at - , expires_at - , revoked_at - , oauth2_session_id - - FROM oauth2_access_tokens - - WHERE oauth2_access_token_id = $1 - "#, - Uuid::from(access_token_id), - ) - .fetch_one(&mut *conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) -} - -#[tracing::instrument( - skip_all, - fields( - %access_token.id, - session.id = %access_token.session_id, - ), - err, -)] -pub async fn revoke_access_token( - executor: impl PgExecutor<'_>, - clock: &Clock, - access_token: AccessToken, -) -> Result { - let revoked_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE oauth2_access_tokens - SET revoked_at = $2 - WHERE oauth2_access_token_id = $1 - "#, - Uuid::from(access_token.id), - revoked_at, - ) - .execute(executor) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - access_token - .revoke(revoked_at) - .map_err(DatabaseError::to_invalid_operation) -} - -pub async fn cleanup_expired( - executor: impl PgExecutor<'_>, - clock: &Clock, -) -> Result { - // Cleanup token which expired more than 15 minutes ago - let threshold = clock.now() - Duration::minutes(15); - let res = sqlx::query!( - r#" - DELETE FROM oauth2_access_tokens - WHERE expires_at < $1 - "#, - threshold, - ) - .execute(executor) - .await?; - - Ok(res.rows_affected()) + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.oauth2_access_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2AccessTokenLookup, + r#" + SELECT oauth2_access_token_id + , access_token + , created_at + , expires_at + , revoked_at + , oauth2_session_id + + FROM oauth2_access_tokens + + WHERE access_token = $1 + "#, + access_token, + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.oauth2_access_token.add", + skip_all, + fields( + db.statement, + %session.id, + user_session.id = %session.user_session_id, + client.id = %session.client_id, + access_token.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + session: &Session, + access_token: String, + expires_after: Duration, + ) -> Result { + let created_at = clock.now(); + let expires_at = created_at + expires_after; + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + + tracing::Span::current().record("access_token.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO oauth2_access_tokens + (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at) + VALUES + ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(session.id), + &access_token, + created_at, + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(AccessToken { + id, + state: AccessTokenState::default(), + access_token, + session_id: session.id, + created_at, + expires_at, + }) + } + + async fn revoke( + &mut self, + clock: &Clock, + access_token: AccessToken, + ) -> Result { + let revoked_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_access_tokens + SET revoked_at = $2 + WHERE oauth2_access_token_id = $1 + "#, + Uuid::from(access_token.id), + revoked_at, + ) + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + access_token + .revoke(revoked_at) + .map_err(DatabaseError::to_invalid_operation) + } + + async fn cleanup_expired(&mut self, clock: &Clock) -> Result { + // Cleanup token which expired more than 15 minutes ago + let threshold = clock.now() - Duration::minutes(15); + let res = sqlx::query!( + r#" + DELETE FROM oauth2_access_tokens + WHERE expires_at < $1 + "#, + threshold, + ) + .execute(&mut *self.conn) + .await?; + + Ok(res.rows_affected().try_into().unwrap_or(usize::MAX)) + } } diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index c5d969765..91df93138 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -14,138 +14,97 @@ use std::num::NonZeroU32; +use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::{ AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session, }; use mas_iana::oauth::PkceCodeChallengeMethod; use oauth2_types::{requests::ResponseMode, scope::Scope}; -use rand::Rng; -use sqlx::{PgConnection, PgExecutor}; +use rand::RngCore; +use sqlx::PgConnection; use ulid::Ulid; use url::Url; use uuid::Uuid; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; +use crate::{ + tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, +}; -#[tracing::instrument( - skip_all, - fields( - %client.id, - grant.id, - ), - err, -)] -#[allow(clippy::too_many_arguments)] -pub async fn new_authorization_grant( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - client: Client, - redirect_uri: Url, - scope: Scope, - code: Option, - state: Option, - nonce: Option, - max_age: Option, - _acr_values: Option, - response_mode: ResponseMode, - response_type_id_token: bool, - requires_consent: bool, -) -> Result { - let code_challenge = code - .as_ref() - .and_then(|c| c.pkce.as_ref()) - .map(|p| &p.challenge); - let code_challenge_method = code - .as_ref() - .and_then(|c| c.pkce.as_ref()) - .map(|p| p.challenge_method.to_string()); - // TODO: this conversion is a bit ugly - let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX)); - let code_str = code.as_ref().map(|c| &c.code); +#[async_trait] +pub trait OAuth2AuthorizationGrantRepository { + type Error; - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("grant.id", tracing::field::display(id)); + #[allow(clippy::too_many_arguments)] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + client: &Client, + redirect_uri: Url, + scope: Scope, + code: Option, + state: Option, + nonce: Option, + max_age: Option, + response_mode: ResponseMode, + response_type_id_token: bool, + requires_consent: bool, + ) -> Result; - sqlx::query!( - r#" - INSERT INTO oauth2_authorization_grants ( - oauth2_authorization_grant_id, - oauth2_client_id, - redirect_uri, - scope, - state, - nonce, - max_age, - response_mode, - code_challenge, - code_challenge_method, - response_type_code, - response_type_id_token, - authorization_code, - requires_consent, - created_at - ) - VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) - "#, - Uuid::from(id), - Uuid::from(client.id), - redirect_uri.to_string(), - scope.to_string(), - state, - nonce, - max_age_i32, - response_mode.to_string(), - code_challenge, - code_challenge_method, - code.is_some(), - response_type_id_token, - code_str, - requires_consent, - created_at, - ) - .execute(executor) - .await?; + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; - Ok(AuthorizationGrant { - id, - stage: AuthorizationGrantStage::Pending, - code, - redirect_uri, - client_id: client.id, - scope, - state, - nonce, - max_age, - response_mode, - created_at, - response_type_id_token, - requires_consent, - }) + async fn find_by_code(&mut self, code: &str) + -> Result, Self::Error>; + + async fn fulfill( + &mut self, + clock: &Clock, + session: &Session, + authorization_grant: AuthorizationGrant, + ) -> Result; + + async fn exchange( + &mut self, + clock: &Clock, + authorization_grant: AuthorizationGrant, + ) -> Result; + + async fn give_consent( + &mut self, + authorization_grant: AuthorizationGrant, + ) -> Result; +} + +pub struct PgOAuth2AuthorizationGrantRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2AuthorizationGrantRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } } #[allow(clippy::struct_excessive_bools)] struct GrantLookup { oauth2_authorization_grant_id: Uuid, - oauth2_authorization_grant_created_at: DateTime, - oauth2_authorization_grant_cancelled_at: Option>, - oauth2_authorization_grant_fulfilled_at: Option>, - oauth2_authorization_grant_exchanged_at: Option>, - oauth2_authorization_grant_scope: String, - oauth2_authorization_grant_state: Option, - oauth2_authorization_grant_nonce: Option, - oauth2_authorization_grant_redirect_uri: String, - oauth2_authorization_grant_response_mode: String, - oauth2_authorization_grant_max_age: Option, - oauth2_authorization_grant_response_type_code: bool, - oauth2_authorization_grant_response_type_id_token: bool, - oauth2_authorization_grant_code: Option, - oauth2_authorization_grant_code_challenge: Option, - oauth2_authorization_grant_code_challenge_method: Option, - oauth2_authorization_grant_requires_consent: bool, + created_at: DateTime, + cancelled_at: Option>, + fulfilled_at: Option>, + exchanged_at: Option>, + scope: String, + state: Option, + nonce: Option, + redirect_uri: String, + response_mode: String, + max_age: Option, + response_type_code: bool, + response_type_id_token: bool, + authorization_code: Option, + code_challenge: Option, + code_challenge_method: Option, + requires_consent: bool, oauth2_client_id: Uuid, oauth2_session_id: Option, } @@ -156,20 +115,17 @@ impl TryFrom for AuthorizationGrant { #[allow(clippy::too_many_lines)] fn try_from(value: GrantLookup) -> Result { let id = value.oauth2_authorization_grant_id.into(); - let scope: Scope = value - .oauth2_authorization_grant_scope - .parse() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("scope") - .row(id) - .source(e) - })?; + let scope: Scope = value.scope.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("scope") + .row(id) + .source(e) + })?; let stage = match ( - value.oauth2_authorization_grant_fulfilled_at, - value.oauth2_authorization_grant_exchanged_at, - value.oauth2_authorization_grant_cancelled_at, + value.fulfilled_at, + value.exchanged_at, + value.cancelled_at, value.oauth2_session_id, ) { (None, None, None, None) => AuthorizationGrantStage::Pending, @@ -198,10 +154,7 @@ impl TryFrom for AuthorizationGrant { } }; - let pkce = match ( - value.oauth2_authorization_grant_code_challenge, - value.oauth2_authorization_grant_code_challenge_method, - ) { + let pkce = match (value.code_challenge, value.code_challenge_method) { (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => { Some(Pkce { challenge_method: PkceCodeChallengeMethod::Plain, @@ -222,44 +175,35 @@ impl TryFrom for AuthorizationGrant { } }; - let code: Option = match ( - value.oauth2_authorization_grant_response_type_code, - value.oauth2_authorization_grant_code, - pkce, - ) { - (false, None, None) => None, - (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }), - _ => { - return Err( - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("authorization_code") - .row(id), - ); - } - }; + let code: Option = + match (value.response_type_code, value.authorization_code, pkce) { + (false, None, None) => None, + (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }), + _ => { + return Err( + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("authorization_code") + .row(id), + ); + } + }; - let redirect_uri = value - .oauth2_authorization_grant_redirect_uri - .parse() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("redirect_uri") - .row(id) - .source(e) - })?; + let redirect_uri = value.redirect_uri.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("redirect_uri") + .row(id) + .source(e) + })?; - let response_mode = value - .oauth2_authorization_grant_response_mode - .parse() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("response_mode") - .row(id) - .source(e) - })?; + let response_mode = value.response_mode.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("response_mode") + .row(id) + .source(e) + })?; let max_age = value - .oauth2_authorization_grant_max_age + .max_age .map(u32::try_from) .transpose() .map_err(|e| { @@ -283,209 +227,330 @@ impl TryFrom for AuthorizationGrant { client_id: value.oauth2_client_id.into(), code, scope, - state: value.oauth2_authorization_grant_state, - nonce: value.oauth2_authorization_grant_nonce, + state: value.state, + nonce: value.nonce, max_age, response_mode, redirect_uri, - created_at: value.oauth2_authorization_grant_created_at, - response_type_id_token: value.oauth2_authorization_grant_response_type_id_token, - requires_consent: value.oauth2_authorization_grant_requires_consent, + created_at: value.created_at, + response_type_id_token: value.response_type_id_token, + requires_consent: value.requires_consent, }) } } -#[tracing::instrument( - skip_all, - fields(grant.id = %id), - err, -)] -pub async fn get_grant_by_id( - conn: &mut PgConnection, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - GrantLookup, - r#" - SELECT oauth2_authorization_grant_id - , created_at AS oauth2_authorization_grant_created_at - , cancelled_at AS oauth2_authorization_grant_cancelled_at - , fulfilled_at AS oauth2_authorization_grant_fulfilled_at - , exchanged_at AS oauth2_authorization_grant_exchanged_at - , scope AS oauth2_authorization_grant_scope - , state AS oauth2_authorization_grant_state - , redirect_uri AS oauth2_authorization_grant_redirect_uri - , response_mode AS oauth2_authorization_grant_response_mode - , nonce AS oauth2_authorization_grant_nonce - , max_age AS oauth2_authorization_grant_max_age - , oauth2_client_id AS oauth2_client_id - , authorization_code AS oauth2_authorization_grant_code - , response_type_code AS oauth2_authorization_grant_response_type_code - , response_type_id_token AS oauth2_authorization_grant_response_type_id_token - , code_challenge AS oauth2_authorization_grant_code_challenge - , code_challenge_method AS oauth2_authorization_grant_code_challenge_method - , requires_consent AS oauth2_authorization_grant_requires_consent - , oauth2_session_id AS "oauth2_session_id?" - FROM - oauth2_authorization_grants +#[async_trait] +impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'c> { + type Error = DatabaseError; - WHERE oauth2_authorization_grant_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(&mut *conn) - .await - .to_option()?; + #[tracing::instrument( + name = "db.oauth2_authorization_grant.add", + skip_all, + fields( + db.statement, + grant.id, + grant.scope = %scope, + %client.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + client: &Client, + redirect_uri: Url, + scope: Scope, + code: Option, + state: Option, + nonce: Option, + max_age: Option, + response_mode: ResponseMode, + response_type_id_token: bool, + requires_consent: bool, + ) -> Result { + let code_challenge = code + .as_ref() + .and_then(|c| c.pkce.as_ref()) + .map(|p| &p.challenge); + let code_challenge_method = code + .as_ref() + .and_then(|c| c.pkce.as_ref()) + .map(|p| p.challenge_method.to_string()); + // TODO: this conversion is a bit ugly + let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX)); + let code_str = code.as_ref().map(|c| &c.code); - let Some(res) = res else { return Ok(None) }; + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("grant.id", tracing::field::display(id)); - Ok(Some(res.try_into()?)) -} - -#[tracing::instrument(skip_all, err)] -pub async fn lookup_grant_by_code( - conn: &mut PgConnection, - code: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - GrantLookup, - r#" - SELECT oauth2_authorization_grant_id - , created_at AS oauth2_authorization_grant_created_at - , cancelled_at AS oauth2_authorization_grant_cancelled_at - , fulfilled_at AS oauth2_authorization_grant_fulfilled_at - , exchanged_at AS oauth2_authorization_grant_exchanged_at - , scope AS oauth2_authorization_grant_scope - , state AS oauth2_authorization_grant_state - , redirect_uri AS oauth2_authorization_grant_redirect_uri - , response_mode AS oauth2_authorization_grant_response_mode - , nonce AS oauth2_authorization_grant_nonce - , max_age AS oauth2_authorization_grant_max_age - , oauth2_client_id AS oauth2_client_id - , authorization_code AS oauth2_authorization_grant_code - , response_type_code AS oauth2_authorization_grant_response_type_code - , response_type_id_token AS oauth2_authorization_grant_response_type_id_token - , code_challenge AS oauth2_authorization_grant_code_challenge - , code_challenge_method AS oauth2_authorization_grant_code_challenge_method - , requires_consent AS oauth2_authorization_grant_requires_consent - , oauth2_session_id AS "oauth2_session_id?" - FROM - oauth2_authorization_grants - - WHERE authorization_code = $1 - "#, - code, - ) - .fetch_one(&mut *conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) -} - -#[tracing::instrument( - skip_all, - fields( - %grant.id, - client.id = %grant.client_id, - %session.id, - user_session.id = %session.user_session_id, - ), - err, -)] -pub async fn fulfill_grant( - executor: impl PgExecutor<'_>, - mut grant: AuthorizationGrant, - session: Session, -) -> Result { - let fulfilled_at = sqlx::query_scalar!( - r#" - UPDATE oauth2_authorization_grants AS og - SET - oauth2_session_id = os.oauth2_session_id, - fulfilled_at = os.created_at - FROM oauth2_sessions os - WHERE - og.oauth2_authorization_grant_id = $1 - AND os.oauth2_session_id = $2 - RETURNING fulfilled_at AS "fulfilled_at!: DateTime" - "#, - Uuid::from(grant.id), - Uuid::from(session.id), - ) - .fetch_one(executor) - .await?; - - grant.stage = grant - .stage - .fulfill(fulfilled_at, &session) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(grant) -} - -#[tracing::instrument( - skip_all, - fields( - %grant.id, - client.id = %grant.client_id, - ), - err, -)] -pub async fn give_consent_to_grant( - executor: impl PgExecutor<'_>, - mut grant: AuthorizationGrant, -) -> Result { - sqlx::query!( - r#" - UPDATE oauth2_authorization_grants AS og - SET - requires_consent = 'f' - WHERE - og.oauth2_authorization_grant_id = $1 - "#, - Uuid::from(grant.id), - ) - .execute(executor) - .await?; - - grant.requires_consent = false; - - Ok(grant) -} - -#[tracing::instrument( - skip_all, - fields( - %grant.id, - client.id = %grant.client_id, - ), - err, -)] -pub async fn exchange_grant( - executor: impl PgExecutor<'_>, - clock: &Clock, - mut grant: AuthorizationGrant, -) -> Result { - let exchanged_at = clock.now(); - sqlx::query!( - r#" - UPDATE oauth2_authorization_grants - SET exchanged_at = $2 - WHERE oauth2_authorization_grant_id = $1 - "#, - Uuid::from(grant.id), - exchanged_at, - ) - .execute(executor) - .await?; - - grant.stage = grant - .stage - .exchange(exchanged_at) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(grant) + sqlx::query!( + r#" + INSERT INTO oauth2_authorization_grants ( + oauth2_authorization_grant_id, + oauth2_client_id, + redirect_uri, + scope, + state, + nonce, + max_age, + response_mode, + code_challenge, + code_challenge_method, + response_type_code, + response_type_id_token, + authorization_code, + requires_consent, + created_at + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) + "#, + Uuid::from(id), + Uuid::from(client.id), + redirect_uri.to_string(), + scope.to_string(), + state, + nonce, + max_age_i32, + response_mode.to_string(), + code_challenge, + code_challenge_method, + code.is_some(), + response_type_id_token, + code_str, + requires_consent, + created_at, + ) + .execute(&mut *self.conn) + .await?; + + Ok(AuthorizationGrant { + id, + stage: AuthorizationGrantStage::Pending, + code, + redirect_uri, + client_id: client.id, + scope, + state, + nonce, + max_age, + response_mode, + created_at, + response_type_id_token, + requires_consent, + }) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.lookup", + skip_all, + fields( + db.statement, + grant.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + GrantLookup, + r#" + SELECT oauth2_authorization_grant_id + , created_at + , cancelled_at + , fulfilled_at + , exchanged_at + , scope + , state + , redirect_uri + , response_mode + , nonce + , max_age + , oauth2_client_id + , authorization_code + , response_type_code + , response_type_id_token + , code_challenge + , code_challenge_method + , requires_consent + , oauth2_session_id + FROM + oauth2_authorization_grants + + WHERE oauth2_authorization_grant_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.find_by_code", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_code( + &mut self, + code: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + GrantLookup, + r#" + SELECT oauth2_authorization_grant_id + , created_at + , cancelled_at + , fulfilled_at + , exchanged_at + , scope + , state + , redirect_uri + , response_mode + , nonce + , max_age + , oauth2_client_id + , authorization_code + , response_type_code + , response_type_id_token + , code_challenge + , code_challenge_method + , requires_consent + , oauth2_session_id + FROM + oauth2_authorization_grants + + WHERE authorization_code = $1 + "#, + code, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.fulfill", + skip_all, + fields( + db.statement, + %grant.id, + client.id = %grant.client_id, + %session.id, + user_session.id = %session.user_session_id, + ), + err, + )] + async fn fulfill( + &mut self, + clock: &Clock, + session: &Session, + grant: AuthorizationGrant, + ) -> Result { + let fulfilled_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_authorization_grants + SET fulfilled_at = $2 + , oauth2_session_id = $3 + WHERE oauth2_authorization_grant_id = $1 + "#, + Uuid::from(grant.id), + fulfilled_at, + Uuid::from(session.id), + ) + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + // XXX: check affected rows & new methods + let grant = grant + .fulfill(fulfilled_at, session) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(grant) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.exchange", + skip_all, + fields( + db.statement, + %grant.id, + client.id = %grant.client_id, + ), + err, + )] + async fn exchange( + &mut self, + clock: &Clock, + grant: AuthorizationGrant, + ) -> Result { + let exchanged_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_authorization_grants + SET exchanged_at = $2 + WHERE oauth2_authorization_grant_id = $1 + "#, + Uuid::from(grant.id), + exchanged_at, + ) + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + let grant = grant + .exchange(exchanged_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(grant) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.give_consent", + skip_all, + fields( + db.statement, + %grant.id, + client.id = %grant.client_id, + ), + err, + )] + async fn give_consent( + &mut self, + mut grant: AuthorizationGrant, + ) -> Result { + sqlx::query!( + r#" + UPDATE oauth2_authorization_grants AS og + SET + requires_consent = 'f' + WHERE + og.oauth2_authorization_grant_id = $1 + "#, + Uuid::from(grant.id), + ) + .execute(&mut *self.conn) + .await?; + + grant.requires_consent = false; + + Ok(grant) + } } diff --git a/crates/storage/src/oauth2/client.rs b/crates/storage/src/oauth2/client.rs index afe789db4..756017b8e 100644 --- a/crates/storage/src/oauth2/client.rs +++ b/crates/storage/src/oauth2/client.rs @@ -14,17 +14,21 @@ use std::{ collections::{BTreeMap, BTreeSet}, + str::FromStr, string::ToString, }; use async_trait::async_trait; -use mas_data_model::{Client, JwksOrJwksUri}; +use mas_data_model::{Client, JwksOrJwksUri, User}; use mas_iana::{ jose::JsonWebSignatureAlg, oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod}, }; use mas_jose::jwk::PublicJsonWebKeySet; -use oauth2_types::requests::GrantType; +use oauth2_types::{ + requests::GrantType, + scope::{Scope, ScopeToken}, +}; use rand::{Rng, RngCore}; use sqlx::PgConnection; use tracing::{info_span, Instrument}; @@ -87,6 +91,21 @@ pub trait OAuth2ClientRepository: Send + Sync { jwks_uri: Option, redirect_uris: Vec, ) -> Result; + + async fn get_consent_for_user( + &mut self, + client: &Client, + user: &User, + ) -> Result; + + async fn give_consent_for_user( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + client: &Client, + user: &User, + scope: &Scope, + ) -> Result<(), Self::Error>; } pub struct PgOAuth2ClientRepository<'c> { @@ -702,4 +721,94 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { initiate_login_uri: None, }) } + + #[tracing::instrument( + name = "db.oauth2_client.get_consent_for_user", + skip_all, + fields( + db.statement, + %user.id, + %client.id, + ), + err, + )] + async fn get_consent_for_user( + &mut self, + client: &Client, + user: &User, + ) -> Result { + let scope_tokens: Vec = sqlx::query_scalar!( + r#" + SELECT scope_token + FROM oauth2_consents + WHERE user_id = $1 AND oauth2_client_id = $2 + "#, + Uuid::from(user.id), + Uuid::from(client.id), + ) + .fetch_all(&mut *self.conn) + .await?; + + let scope: Result = scope_tokens + .into_iter() + .map(|s| ScopeToken::from_str(&s)) + .collect(); + + let scope = scope.map_err(|e| { + DatabaseInconsistencyError::on("oauth2_consents") + .column("scope_token") + .source(e) + })?; + + Ok(scope) + } + + #[tracing::instrument( + skip_all, + fields( + db.statement, + %user.id, + %client.id, + %scope, + ), + err, + )] + async fn give_consent_for_user( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + client: &Client, + user: &User, + scope: &Scope, + ) -> Result<(), Self::Error> { + let now = clock.now(); + let (tokens, ids): (Vec, Vec) = scope + .iter() + .map(|token| { + ( + token.to_string(), + Uuid::from(Ulid::from_datetime_with_source(now.into(), rng)), + ) + }) + .unzip(); + + sqlx::query!( + r#" + INSERT INTO oauth2_consents + (oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at) + SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token) + ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5 + "#, + &ids, + Uuid::from(user.id), + Uuid::from(client.id), + &tokens, + now, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(()) + } } diff --git a/crates/storage/src/oauth2/consent.rs b/crates/storage/src/oauth2/consent.rs deleted file mode 100644 index c1a5080df..000000000 --- a/crates/storage/src/oauth2/consent.rs +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::str::FromStr; - -use mas_data_model::{Client, User}; -use oauth2_types::scope::{Scope, ScopeToken}; -use rand::Rng; -use sqlx::PgExecutor; -use ulid::Ulid; -use uuid::Uuid; - -use crate::{Clock, DatabaseError, DatabaseInconsistencyError}; - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %client.id, - ), - err, -)] -pub async fn fetch_client_consent( - executor: impl PgExecutor<'_>, - user: &User, - client: &Client, -) -> Result { - let scope_tokens: Vec = sqlx::query_scalar!( - r#" - SELECT scope_token - FROM oauth2_consents - WHERE user_id = $1 AND oauth2_client_id = $2 - "#, - Uuid::from(user.id), - Uuid::from(client.id), - ) - .fetch_all(executor) - .await?; - - let scope: Result = scope_tokens - .into_iter() - .map(|s| ScopeToken::from_str(&s)) - .collect(); - - let scope = scope.map_err(|e| { - DatabaseInconsistencyError::on("oauth2_consents") - .column("scope_token") - .source(e) - })?; - - Ok(scope) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %client.id, - %scope, - ), - err, -)] -pub async fn insert_client_consent( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - user: &User, - client: &Client, - scope: &Scope, -) -> Result<(), sqlx::Error> { - let now = clock.now(); - let (tokens, ids): (Vec, Vec) = scope - .iter() - .map(|token| { - ( - token.to_string(), - Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), - ) - }) - .unzip(); - - sqlx::query!( - r#" - INSERT INTO oauth2_consents - (oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at) - SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token) - ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5 - "#, - &ids, - Uuid::from(user.id), - Uuid::from(client.id), - &tokens, - now, - ) - .execute(executor) - .await?; - - Ok(()) -} diff --git a/crates/storage/src/oauth2/mod.rs b/crates/storage/src/oauth2/mod.rs index b02216a65..480c45155 100644 --- a/crates/storage/src/oauth2/mod.rs +++ b/crates/storage/src/oauth2/mod.rs @@ -12,14 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -pub mod access_token; +mod access_token; pub mod authorization_grant; mod client; -pub mod consent; -pub mod refresh_token; +mod refresh_token; mod session; pub use self::{ + access_token::{OAuth2AccessTokenRepository, PgOAuth2AccessTokenRepository}, + authorization_grant::{ + OAuth2AuthorizationGrantRepository, PgOAuth2AuthorizationGrantRepository, + }, client::{OAuth2ClientRepository, PgOAuth2ClientRepository}, + refresh_token::{OAuth2RefreshTokenRepository, PgOAuth2RefreshTokenRepository}, session::{OAuth2SessionRepository, PgOAuth2SessionRepository}, }; diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 29f6ab342..5d3bb0133 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -12,62 +12,55 @@ // See the License for the specific language governing permissions and // limitations under the License. +use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session}; -use rand::Rng; -use sqlx::{PgConnection, PgExecutor}; +use rand::RngCore; +use sqlx::PgConnection; use ulid::Ulid; use uuid::Uuid; -use crate::{Clock, DatabaseError}; +use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; -#[tracing::instrument( - skip_all, - fields( - %session.id, - user_session.id = %session.user_session_id, - client.id = %session.client_id, - refresh_token.id, - ), - err, -)] -pub async fn add_refresh_token( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - session: &Session, - access_token: &AccessToken, - refresh_token: String, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("refresh_token.id", tracing::field::display(id)); +#[async_trait] +pub trait OAuth2RefreshTokenRepository: Send + Sync { + type Error; - sqlx::query!( - r#" - INSERT INTO oauth2_refresh_tokens - (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id, - refresh_token, created_at) - VALUES - ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(session.id), - Uuid::from(access_token.id), - refresh_token, - created_at, - ) - .execute(executor) - .await?; + /// Lookup a refresh token by its ID + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; - Ok(RefreshToken { - id, - state: RefreshTokenState::default(), - session_id: session.id, - refresh_token, - access_token_id: Some(access_token.id), - created_at, - }) + /// Find a refresh token by its token + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error>; + + /// Add a new refresh token to the database + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + session: &Session, + access_token: &AccessToken, + refresh_token: String, + ) -> Result; + + /// Consume a refresh token + async fn consume( + &mut self, + clock: &Clock, + refresh_token: RefreshToken, + ) -> Result; +} + +pub struct PgOAuth2RefreshTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2RefreshTokenRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } } struct OAuth2RefreshTokenLookup { @@ -79,75 +72,183 @@ struct OAuth2RefreshTokenLookup { oauth2_session_id: Uuid, } -#[tracing::instrument(skip_all, err)] -#[allow(clippy::too_many_lines)] -pub async fn lookup_refresh_token( - conn: &mut PgConnection, - token: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - OAuth2RefreshTokenLookup, - r#" - SELECT oauth2_refresh_token_id - , refresh_token - , created_at - , consumed_at - , oauth2_access_token_id - , oauth2_session_id - FROM oauth2_refresh_tokens +impl From for RefreshToken { + fn from(value: OAuth2RefreshTokenLookup) -> Self { + let state = match value.consumed_at { + None => RefreshTokenState::Valid, + Some(consumed_at) => RefreshTokenState::Consumed { consumed_at }, + }; - WHERE refresh_token = $1 - "#, - token, - ) - .fetch_one(&mut *conn) - .await?; - - let state = match res.consumed_at { - None => RefreshTokenState::Valid, - Some(consumed_at) => RefreshTokenState::Consumed { consumed_at }, - }; - - let refresh_token = RefreshToken { - id: res.oauth2_refresh_token_id.into(), - state, - session_id: res.oauth2_session_id.into(), - refresh_token: res.refresh_token, - created_at: res.created_at, - access_token_id: res.oauth2_access_token_id.map(Ulid::from), - }; - - Ok(Some(refresh_token)) + RefreshToken { + id: value.oauth2_refresh_token_id.into(), + state, + session_id: value.oauth2_session_id.into(), + refresh_token: value.refresh_token, + created_at: value.created_at, + access_token_id: value.oauth2_access_token_id.map(Ulid::from), + } + } } -#[tracing::instrument( - skip_all, - fields( - %refresh_token.id, - ), - err, -)] -pub async fn consume_refresh_token( - executor: impl PgExecutor<'_>, - clock: &Clock, - refresh_token: RefreshToken, -) -> Result { - let consumed_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE oauth2_refresh_tokens - SET consumed_at = $2 - WHERE oauth2_refresh_token_id = $1 - "#, - Uuid::from(refresh_token.id), - consumed_at, - ) - .execute(executor) - .await?; +#[async_trait] +impl<'c> OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'c> { + type Error = DatabaseError; - DatabaseError::ensure_affected_rows(&res, 1)?; + #[tracing::instrument( + name = "db.oauth2_refresh_token.lookup", + skip_all, + fields( + db.statement, + refresh_token.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2RefreshTokenLookup, + r#" + SELECT oauth2_refresh_token_id + , refresh_token + , created_at + , consumed_at + , oauth2_access_token_id + , oauth2_session_id + FROM oauth2_refresh_tokens - refresh_token - .consume(consumed_at) - .map_err(DatabaseError::to_invalid_operation) + WHERE oauth2_refresh_token_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.oauth2_refresh_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2RefreshTokenLookup, + r#" + SELECT oauth2_refresh_token_id + , refresh_token + , created_at + , consumed_at + , oauth2_access_token_id + , oauth2_session_id + FROM oauth2_refresh_tokens + + WHERE refresh_token = $1 + "#, + refresh_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.oauth2_refresh_token.add", + skip_all, + fields( + db.statement, + %session.id, + user_session.id = %session.user_session_id, + client.id = %session.client_id, + refresh_token.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + session: &Session, + access_token: &AccessToken, + refresh_token: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("refresh_token.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO oauth2_refresh_tokens + (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id, + refresh_token, created_at) + VALUES + ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(session.id), + Uuid::from(access_token.id), + refresh_token, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(RefreshToken { + id, + state: RefreshTokenState::default(), + session_id: session.id, + refresh_token, + access_token_id: Some(access_token.id), + created_at, + }) + } + + #[tracing::instrument( + name = "db.oauth2_refresh_token.consume", + skip_all, + fields( + db.statement, + %refresh_token.id, + session.id = %refresh_token.session_id, + ), + err, + )] + async fn consume( + &mut self, + clock: &Clock, + refresh_token: RefreshToken, + ) -> Result { + let consumed_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_refresh_tokens + SET consumed_at = $2 + WHERE oauth2_refresh_token_id = $1 + "#, + Uuid::from(refresh_token.id), + consumed_at, + ) + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + refresh_token + .consume(consumed_at) + .map_err(DatabaseError::to_invalid_operation) + } } diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index ddd6e1ea8..b9bf5683f 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -19,7 +19,10 @@ use crate::{ PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository, PgCompatSsoLoginRepository, }, - oauth2::{PgOAuth2ClientRepository, PgOAuth2SessionRepository}, + oauth2::{ + PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository, + PgOAuth2ClientRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository, + }, upstream_oauth2::{ PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, PgUpstreamOAuthSessionRepository, @@ -63,10 +66,22 @@ pub trait Repository { where Self: 'c; + type OAuth2AuthorizationGrantRepository<'c> + where + Self: 'c; + type OAuth2SessionRepository<'c> where Self: 'c; + type OAuth2AccessTokenRepository<'c> + where + Self: 'c; + + type OAuth2RefreshTokenRepository<'c> + where + Self: 'c; + type CompatSessionRepository<'c> where Self: 'c; @@ -91,7 +106,10 @@ pub trait Repository { fn user_password(&mut self) -> Self::UserPasswordRepository<'_>; fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>; fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>; + fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_>; fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_>; + fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_>; + fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_>; fn compat_session(&mut self) -> Self::CompatSessionRepository<'_>; fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_>; fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_>; @@ -107,7 +125,10 @@ impl Repository for PgConnection { type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; + type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c; type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; + type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c; + type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c; type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c; type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c; type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c; @@ -145,10 +166,22 @@ impl Repository for PgConnection { PgOAuth2ClientRepository::new(self) } + fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> { + PgOAuth2AuthorizationGrantRepository::new(self) + } + fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { PgOAuth2SessionRepository::new(self) } + fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> { + PgOAuth2AccessTokenRepository::new(self) + } + + fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> { + PgOAuth2RefreshTokenRepository::new(self) + } + fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { PgCompatSessionRepository::new(self) } @@ -175,7 +208,10 @@ impl<'t> Repository for Transaction<'t, Postgres> { type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; + type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c; type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; + type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c; + type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c; type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c; type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c; type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c; @@ -213,10 +249,22 @@ impl<'t> Repository for Transaction<'t, Postgres> { PgOAuth2ClientRepository::new(self) } + fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> { + PgOAuth2AuthorizationGrantRepository::new(self) + } + fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { PgOAuth2SessionRepository::new(self) } + fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> { + PgOAuth2AccessTokenRepository::new(self) + } + + fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> { + PgOAuth2RefreshTokenRepository::new(self) + } + fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { PgCompatSessionRepository::new(self) } diff --git a/crates/tasks/src/database.rs b/crates/tasks/src/database.rs index 5e72141e5..f4d11c6af 100644 --- a/crates/tasks/src/database.rs +++ b/crates/tasks/src/database.rs @@ -14,7 +14,7 @@ //! Database-related tasks -use mas_storage::Clock; +use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock, Repository}; use sqlx::{Pool, Postgres}; use tracing::{debug, error, info}; @@ -32,7 +32,12 @@ impl std::fmt::Debug for CleanupExpired { #[async_trait::async_trait] impl Task for CleanupExpired { async fn run(&self) { - let res = mas_storage::oauth2::access_token::cleanup_expired(&self.0, &self.1).await; + let res = async move { + let mut conn = self.0.acquire().await?; + conn.oauth2_access_token().cleanup_expired(&self.1).await + } + .await; + match res { Ok(0) => { debug!("no token to clean up"); From 2d781d32ec17b9af4d5648852d3d0bb1050f30c9 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 13 Jan 2023 18:03:37 +0100 Subject: [PATCH 20/45] storage: wrap the postgres repository in a struct --- Cargo.lock | 1 - crates/axum-utils/Cargo.toml | 1 - crates/axum-utils/src/client_authorization.rs | 10 +- crates/axum-utils/src/session.rs | 11 +- crates/axum-utils/src/user_authorization.rs | 39 ++-- crates/cli/src/commands/manage.rs | 32 +-- crates/graphql/src/lib.rs | 34 ++- crates/graphql/src/model/compat_sessions.rs | 12 +- crates/graphql/src/model/oauth.rs | 20 +- crates/graphql/src/model/upstream_oauth.rs | 13 +- crates/graphql/src/model/users.rs | 37 ++-- crates/handlers/src/compat/login.rs | 36 ++-- .../handlers/src/compat/login_sso_complete.rs | 20 +- .../handlers/src/compat/login_sso_redirect.rs | 6 +- crates/handlers/src/compat/logout.rs | 12 +- crates/handlers/src/compat/refresh.rs | 20 +- crates/handlers/src/graphql.rs | 11 +- .../src/oauth2/authorization/complete.rs | 28 +-- .../handlers/src/oauth2/authorization/mod.rs | 20 +- crates/handlers/src/oauth2/consent.rs | 24 ++- crates/handlers/src/oauth2/introspection.rs | 33 +-- crates/handlers/src/oauth2/registration.rs | 9 +- crates/handlers/src/oauth2/token.rs | 52 ++--- crates/handlers/src/oauth2/userinfo.rs | 14 +- .../handlers/src/upstream_oauth2/authorize.rs | 10 +- .../handlers/src/upstream_oauth2/callback.rs | 21 +- crates/handlers/src/upstream_oauth2/link.rs | 48 ++--- .../handlers/src/views/account/emails/add.rs | 16 +- .../handlers/src/views/account/emails/mod.rs | 46 ++-- .../src/views/account/emails/verify.rs | 24 +-- crates/handlers/src/views/account/mod.rs | 10 +- crates/handlers/src/views/account/password.rs | 18 +- crates/handlers/src/views/index.rs | 5 +- crates/handlers/src/views/login.rs | 40 ++-- crates/handlers/src/views/logout.rs | 10 +- crates/handlers/src/views/reauth.rs | 20 +- crates/handlers/src/views/register.rs | 34 +-- crates/handlers/src/views/shared.rs | 19 +- crates/storage/src/lib.rs | 2 +- .../storage/src/oauth2/authorization_grant.rs | 2 +- crates/storage/src/oauth2/session.rs | 2 +- crates/storage/src/repository.rs | 201 +++++++----------- crates/storage/src/upstream_oauth2/mod.rs | 24 +-- crates/tasks/src/database.rs | 6 +- 44 files changed, 505 insertions(+), 548 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f299df15d..b780f20fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2673,7 +2673,6 @@ dependencies = [ "serde_json", "serde_urlencoded", "serde_with", - "sqlx", "thiserror", "tokio", "tower", diff --git a/crates/axum-utils/Cargo.toml b/crates/axum-utils/Cargo.toml index 0b6572c7c..1cbae2282 100644 --- a/crates/axum-utils/Cargo.toml +++ b/crates/axum-utils/Cargo.toml @@ -21,7 +21,6 @@ serde = "1.0.152" serde_with = "2.1.0" serde_urlencoded = "0.7.1" serde_json = "1.0.91" -sqlx = "0.6.2" thiserror = "1.0.38" tokio = "1.23.0" tower = { version = "0.4.13", features = ["util"] } diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index 6f212369b..09090230c 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -31,10 +31,9 @@ use mas_http::HttpServiceExt; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt}; use mas_keystore::Encrypter; -use mas_storage::{oauth2::OAuth2ClientRepository, DatabaseError, Repository}; +use mas_storage::{oauth2::OAuth2ClientRepository, Repository}; use serde::{de::DeserializeOwned, Deserialize}; use serde_json::Value; -use sqlx::PgConnection; use thiserror::Error; use tower::{Service, ServiceExt}; @@ -73,7 +72,10 @@ pub enum Credentials { } impl Credentials { - pub async fn fetch(&self, conn: &mut PgConnection) -> Result, DatabaseError> { + pub async fn fetch<'r, R>(&self, repo: &'r mut R) -> Result, R::Error> + where + R: Repository, + { let client_id = match self { Credentials::None { client_id } | Credentials::ClientSecretBasic { client_id, .. } @@ -81,7 +83,7 @@ impl Credentials { | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id, }; - conn.oauth2_client().find_by_client_id(client_id).await + repo.oauth2_client().find_by_client_id(client_id).await } #[tracing::instrument(skip_all, err)] diff --git a/crates/axum-utils/src/session.rs b/crates/axum-utils/src/session.rs index 64887895e..719613675 100644 --- a/crates/axum-utils/src/session.rs +++ b/crates/axum-utils/src/session.rs @@ -14,9 +14,8 @@ use axum_extra::extract::cookie::{Cookie, PrivateCookieJar}; use mas_data_model::BrowserSession; -use mas_storage::{user::BrowserSessionRepository, DatabaseError, Repository}; +use mas_storage::{user::BrowserSessionRepository, Repository}; use serde::{Deserialize, Serialize}; -use sqlx::PgConnection; use ulid::Ulid; use crate::CookieExt; @@ -44,17 +43,17 @@ impl SessionInfo { } /// Load the [`BrowserSession`] from database - pub async fn load_session( + pub async fn load_session( &self, - conn: &mut PgConnection, - ) -> Result, DatabaseError> { + repo: &mut R, + ) -> Result, R::Error> { let session_id = if let Some(id) = self.current { id } else { return Ok(None); }; - let maybe_session = conn + let maybe_session = repo .browser_session() .lookup(session_id) .await? diff --git a/crates/axum-utils/src/user_authorization.rs b/crates/axum-utils/src/user_authorization.rs index ec60103da..11d793122 100644 --- a/crates/axum-utils/src/user_authorization.rs +++ b/crates/axum-utils/src/user_authorization.rs @@ -30,10 +30,9 @@ use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode use mas_data_model::Session; use mas_storage::{ oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository}, - DatabaseError, Repository, + Repository, }; use serde::{de::DeserializeOwned, Deserialize}; -use sqlx::PgConnection; use thiserror::Error; #[derive(Debug, Deserialize)] @@ -53,22 +52,23 @@ enum AccessToken { } impl AccessToken { - async fn fetch( + async fn fetch( &self, - conn: &mut PgConnection, - ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> { + repo: &mut R, + ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> + { let token = match self { AccessToken::Form(t) | AccessToken::Header(t) => t, AccessToken::None => return Err(AuthorizationVerificationError::MissingToken), }; - let token = conn + let token = repo .oauth2_access_token() .find_by_token(token.as_str()) .await? .ok_or(AuthorizationVerificationError::InvalidToken)?; - let session = conn + let session = repo .oauth2_session() .lookup(token.session_id) .await? @@ -86,17 +86,17 @@ pub struct UserAuthorization { impl UserAuthorization { // TODO: take scopes to validate as parameter - pub async fn protected_form( + pub async fn protected_form( self, - conn: &mut PgConnection, + repo: &mut R, now: DateTime, - ) -> Result<(Session, F), AuthorizationVerificationError> { + ) -> Result<(Session, F), AuthorizationVerificationError> { let form = match self.form { Some(f) => f, None => return Err(AuthorizationVerificationError::MissingForm), }; - let (token, session) = self.access_token.fetch(conn).await?; + let (token, session) = self.access_token.fetch(repo).await?; if !token.is_valid(now) || !session.is_valid() { return Err(AuthorizationVerificationError::InvalidToken); @@ -106,12 +106,12 @@ impl UserAuthorization { } // TODO: take scopes to validate as parameter - pub async fn protected( + pub async fn protected( self, - conn: &mut PgConnection, + repo: &mut R, now: DateTime, - ) -> Result { - let (token, session) = self.access_token.fetch(conn).await?; + ) -> Result> { + let (token, session) = self.access_token.fetch(repo).await?; if !token.is_valid(now) || !session.is_valid() { return Err(AuthorizationVerificationError::InvalidToken); @@ -129,7 +129,7 @@ pub enum UserAuthorizationError { } #[derive(Debug, Error)] -pub enum AuthorizationVerificationError { +pub enum AuthorizationVerificationError { #[error("missing token")] MissingToken, @@ -140,7 +140,7 @@ pub enum AuthorizationVerificationError { MissingForm, #[error(transparent)] - Internal(#[from] DatabaseError), + Internal(#[from] E), } enum BearerError { @@ -248,7 +248,10 @@ impl IntoResponse for UserAuthorizationError { } } -impl IntoResponse for AuthorizationVerificationError { +impl IntoResponse for AuthorizationVerificationError +where + E: ToString, +{ fn into_response(self) -> Response { match self { Self::MissingForm | Self::MissingToken => { diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index d159f3e30..153789400 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -21,7 +21,7 @@ use mas_storage::{ oauth2::OAuth2ClientRepository, upstream_oauth2::UpstreamOAuthProviderRepository, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, - Clock, Repository, + Clock, PgRepository, Repository, }; use oauth2_types::scope::Scope; use rand::SeedableRng; @@ -202,8 +202,8 @@ impl Options { let pool = database_from_config(&database_config).await?; let password_manager = password_manager_from_config(&passwords_config).await?; - let mut txn = pool.begin().await?; - let user = txn + let mut repo = PgRepository::from_pool(&pool).await?; + let user = repo .user() .find_by_username(username) .await? @@ -213,12 +213,12 @@ impl Options { let (version, hashed_password) = password_manager.hash(&mut rng, password).await?; - txn.user_password() + repo.user_password() .add(&mut rng, &clock, &user, version, hashed_password, None) .await?; info!(%user.id, %user.username, "Password changed"); - txn.commit().await?; + repo.save().await?; Ok(()) } @@ -233,22 +233,22 @@ impl Options { let config: DatabaseConfig = root.load_config()?; let pool = database_from_config(&config).await?; - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; - let user = txn + let user = repo .user() .find_by_username(username) .await? .context("User not found")?; - let email = txn + let email = repo .user_email() .find(&user, email) .await? .context("Email not found")?; - let email = txn.user_email().mark_as_verified(&clock, email).await?; + let email = repo.user_email().mark_as_verified(&clock, email).await?; - txn.commit().await?; + repo.save().await?; info!(?email, "Email marked as verified"); Ok(()) @@ -261,12 +261,12 @@ impl Options { let pool = database_from_config(&config.database).await?; let encrypter = config.secrets.encrypter(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; for client in config.clients.iter() { let client_id = client.client_id; - let existing = txn.oauth2_client().lookup(client_id).await?.is_some(); + let existing = repo.oauth2_client().lookup(client_id).await?.is_some(); if !update && existing { warn!(%client_id, "Skipping already imported client. Run with --update to update existing clients."); continue; @@ -288,7 +288,7 @@ impl Options { .map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes())) .transpose()?; - txn.oauth2_client() + repo.oauth2_client() .add_from_config( &mut rng, &clock, @@ -302,7 +302,7 @@ impl Options { .await?; } - txn.commit().await?; + repo.save().await?; Ok(()) } @@ -326,7 +326,7 @@ impl Options { let encrypter = config.secrets.encrypter(); let pool = database_from_config(&config.database).await?; let url_builder = UrlBuilder::new(config.http.public_base); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let requires_client_secret = token_endpoint_auth_method.requires_client_secret(); @@ -347,7 +347,7 @@ impl Options { .map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes())) .transpose()?; - let provider = conn + let provider = repo .upstream_oauth_provider() .add( &mut rng, diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index d01be16ce..6e58bec74 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -32,9 +32,9 @@ use async_graphql::{ }; use mas_storage::{ oauth2::OAuth2ClientRepository, - upstream_oauth2::UpstreamOAuthProviderRepository, + upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository}, user::{BrowserSessionRepository, UserEmailRepository}, - Repository, UpstreamOAuthLinkRepository, + PgRepository, Repository, }; use model::CreationEvent; use sqlx::PgPool; @@ -93,10 +93,9 @@ impl RootQuery { id: ID, ) -> Result, async_graphql::Error> { let id = NodeType::OAuth2Client.extract_ulid(&id)?; - let database = ctx.data::()?; - let mut conn = database.acquire().await?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; - let client = conn.oauth2_client().lookup(id).await?; + let client = repo.oauth2_client().lookup(id).await?; Ok(client.map(OAuth2Client)) } @@ -124,13 +123,12 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let id = NodeType::BrowserSession.extract_ulid(&id)?; let session = ctx.data_opt::().cloned(); - let database = ctx.data::()?; - let mut conn = database.acquire().await?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; let Some(session) = session else { return Ok(None) }; let current_user = session.user; - let browser_session = conn.browser_session().lookup(id).await?; + let browser_session = repo.browser_session().lookup(id).await?; let ret = browser_session.and_then(|browser_session| { if browser_session.user.id == current_user.id { @@ -151,13 +149,12 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let id = NodeType::UserEmail.extract_ulid(&id)?; let session = ctx.data_opt::().cloned(); - let database = ctx.data::()?; - let mut conn = database.acquire().await?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; let Some(session) = session else { return Ok(None) }; let current_user = session.user; - let user_email = conn + let user_email = repo .user_email() .lookup(id) .await? @@ -174,13 +171,12 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?; let session = ctx.data_opt::().cloned(); - let database = ctx.data::()?; - let mut conn = database.acquire().await?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; let Some(session) = session else { return Ok(None) }; let current_user = session.user; - let link = conn.upstream_oauth_link().lookup(id).await?; + let link = repo.upstream_oauth_link().lookup(id).await?; // Ensure that the link belongs to the current user let link = link.filter(|link| link.user_id == Some(current_user.id)); @@ -195,10 +191,9 @@ impl RootQuery { id: ID, ) -> Result, async_graphql::Error> { let id = NodeType::UpstreamOAuth2Provider.extract_ulid(&id)?; - let database = ctx.data::()?; - let mut conn = database.acquire().await?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; - let provider = conn.upstream_oauth_provider().lookup(id).await?; + let provider = repo.upstream_oauth_provider().lookup(id).await?; Ok(provider.map(UpstreamOAuth2Provider::new)) } @@ -215,7 +210,7 @@ impl RootQuery { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; query( after, @@ -223,7 +218,6 @@ impl RootQuery { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| { x.extract_for_type(NodeType::UpstreamOAuth2Provider) @@ -235,7 +229,7 @@ impl RootQuery { }) .transpose()?; - let page = conn + let page = repo .upstream_oauth_provider() .list_paginated(before_id, after_id, first, last) .await?; diff --git a/crates/graphql/src/model/compat_sessions.rs b/crates/graphql/src/model/compat_sessions.rs index 3c94c672a..a2196e36f 100644 --- a/crates/graphql/src/model/compat_sessions.rs +++ b/crates/graphql/src/model/compat_sessions.rs @@ -15,7 +15,9 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; use chrono::{DateTime, Utc}; -use mas_storage::{compat::CompatSessionRepository, user::UserRepository, Repository}; +use mas_storage::{ + compat::CompatSessionRepository, user::UserRepository, PgRepository, Repository, +}; use sqlx::PgPool; use url::Url; @@ -35,8 +37,8 @@ impl CompatSession { /// The user authorized for this session. async fn user(&self, ctx: &Context<'_>) -> Result { - let mut conn = ctx.data::()?.acquire().await?; - let user = conn + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let user = repo .user() .lookup(self.0.user_id) .await? @@ -100,8 +102,8 @@ impl CompatSsoLogin { ) -> Result, async_graphql::Error> { let Some(session_id) = self.0.session_id() else { return Ok(None) }; - let mut conn = ctx.data::()?.acquire().await?; - let session = conn + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let session = repo .compat_session() .lookup(session_id) .await? diff --git a/crates/graphql/src/model/oauth.rs b/crates/graphql/src/model/oauth.rs index 0ab2bc684..171c800fb 100644 --- a/crates/graphql/src/model/oauth.rs +++ b/crates/graphql/src/model/oauth.rs @@ -14,7 +14,9 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; -use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, Repository}; +use mas_storage::{ + oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, PgRepository, Repository, +}; use oauth2_types::scope::Scope; use sqlx::PgPool; use ulid::Ulid; @@ -36,8 +38,8 @@ impl OAuth2Session { /// OAuth 2.0 client used by this session. pub async fn client(&self, ctx: &Context<'_>) -> Result { - let mut conn = ctx.data::()?.acquire().await?; - let client = conn + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let client = repo .oauth2_client() .lookup(self.0.client_id) .await? @@ -56,8 +58,8 @@ impl OAuth2Session { &self, ctx: &Context<'_>, ) -> Result { - let mut conn = ctx.data::()?.acquire().await?; - let browser_session = conn + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let browser_session = repo .browser_session() .lookup(self.0.user_session_id) .await? @@ -68,8 +70,8 @@ impl OAuth2Session { /// User authorized for this session. pub async fn user(&self, ctx: &Context<'_>) -> Result { - let mut conn = ctx.data::()?.acquire().await?; - let browser_session = conn + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let browser_session = repo .browser_session() .lookup(self.0.user_session_id) .await? @@ -138,8 +140,8 @@ impl OAuth2Consent { /// OAuth 2.0 client for which the user granted access. pub async fn client(&self, ctx: &Context<'_>) -> Result { - let mut conn = ctx.data::()?.acquire().await?; - let client = conn + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let client = repo .oauth2_client() .lookup(self.client_id) .await? diff --git a/crates/graphql/src/model/upstream_oauth.rs b/crates/graphql/src/model/upstream_oauth.rs index 249a09286..4a4c223b4 100644 --- a/crates/graphql/src/model/upstream_oauth.rs +++ b/crates/graphql/src/model/upstream_oauth.rs @@ -16,7 +16,8 @@ use anyhow::Context as _; use async_graphql::{Context, Object, ID}; use chrono::{DateTime, Utc}; use mas_storage::{ - upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, Repository, + upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, PgRepository, + Repository, }; use sqlx::PgPool; @@ -102,9 +103,8 @@ impl UpstreamOAuth2Link { provider.clone() } else { // Fetch on-the-fly - let database = ctx.data::()?; - let mut conn = database.acquire().await?; - conn.upstream_oauth_provider() + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + repo.upstream_oauth_provider() .lookup(self.link.provider_id) .await? .context("Upstream OAuth 2.0 provider not found")? @@ -120,9 +120,8 @@ impl UpstreamOAuth2Link { user.clone() } else if let Some(user_id) = &self.link.user_id { // Fetch on-the-fly - let database = ctx.data::()?; - let mut conn = database.acquire().await?; - conn.user() + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + repo.user() .lookup(*user_id) .await? .context("User not found")? diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index b19a1ae12..9cd8d53bc 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -20,8 +20,9 @@ use chrono::{DateTime, Utc}; use mas_storage::{ compat::CompatSsoLoginRepository, oauth2::OAuth2SessionRepository, + upstream_oauth2::UpstreamOAuthLinkRepository, user::{BrowserSessionRepository, UserEmailRepository}, - Repository, UpstreamOAuthLinkRepository, + PgRepository, Repository, }; use sqlx::PgPool; @@ -63,10 +64,9 @@ impl User { &self, ctx: &Context<'_>, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; - let mut conn = database.acquire().await?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; - Ok(conn.user_email().get_primary(&self.0).await?.map(UserEmail)) + Ok(repo.user_email().get_primary(&self.0).await?.map(UserEmail)) } /// Get the list of compatibility SSO logins, chronologically sorted @@ -81,7 +81,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; query( after, @@ -89,7 +89,6 @@ impl User { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| x.extract_for_type(NodeType::CompatSsoLogin)) .transpose()?; @@ -97,7 +96,7 @@ impl User { .map(|x: OpaqueCursor| x.extract_for_type(NodeType::CompatSsoLogin)) .transpose()?; - let page = conn + let page = repo .compat_sso_login() .list_paginated(&self.0, before_id, after_id, first, last) .await?; @@ -128,7 +127,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; query( after, @@ -136,7 +135,6 @@ impl User { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| x.extract_for_type(NodeType::BrowserSession)) .transpose()?; @@ -144,7 +142,7 @@ impl User { .map(|x: OpaqueCursor| x.extract_for_type(NodeType::BrowserSession)) .transpose()?; - let page = conn + let page = repo .browser_session() .list_active_paginated(&self.0, before_id, after_id, first, last) .await?; @@ -175,7 +173,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; query( after, @@ -183,7 +181,6 @@ impl User { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| x.extract_for_type(NodeType::UserEmail)) .transpose()?; @@ -191,7 +188,7 @@ impl User { .map(|x: OpaqueCursor| x.extract_for_type(NodeType::UserEmail)) .transpose()?; - let page = conn + let page = repo .user_email() .list_paginated(&self.0, before_id, after_id, first, last) .await?; @@ -226,7 +223,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; query( after, @@ -234,7 +231,6 @@ impl User { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| x.extract_for_type(NodeType::OAuth2Session)) .transpose()?; @@ -242,7 +238,7 @@ impl User { .map(|x: OpaqueCursor| x.extract_for_type(NodeType::OAuth2Session)) .transpose()?; - let page = conn + let page = repo .oauth2_session() .list_paginated(&self.0, before_id, after_id, first, last) .await?; @@ -273,7 +269,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; query( after, @@ -281,7 +277,6 @@ impl User { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| { x.extract_for_type(NodeType::UpstreamOAuth2Link) @@ -293,7 +288,7 @@ impl User { }) .transpose()?; - let page = conn + let page = repo .upstream_oauth_link() .list_paginated(&self.0, before_id, after_id, first, last) .await?; @@ -347,8 +342,8 @@ pub struct UserEmailsPagination(mas_data_model::User); impl UserEmailsPagination { /// Identifies the total count of items in the connection. async fn total_count(&self, ctx: &Context<'_>) -> Result { - let mut conn = ctx.data::()?.acquire().await?; - let count = conn.user_email().count(&self.0).await?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let count = repo.user_email().count(&self.0).await?; Ok(count) } } diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index e7376f722..f344f7e0e 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -22,11 +22,11 @@ use mas_storage::{ CompatSsoLoginRepository, }, user::{UserPasswordRepository, UserRepository}, - Clock, Repository, + Clock, PgRepository, Repository, }; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; -use sqlx::{PgPool, Postgres, Transaction}; +use sqlx::PgPool; use thiserror::Error; use zeroize::Zeroizing; @@ -199,14 +199,14 @@ pub(crate) async fn post( Json(input): Json, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (session, user) = match input.credentials { Credentials::Password { identifier: Identifier::User { user }, password, - } => user_password_login(&password_manager, &mut txn, user, password).await?, + } => user_password_login(&password_manager, &mut repo, user, password).await?, - Credentials::Token { token } => token_login(&mut txn, &clock, &token).await?, + Credentials::Token { token } => token_login(&mut repo, &clock, &token).await?, _ => { return Err(RouteError::Unsupported); @@ -224,14 +224,14 @@ pub(crate) async fn post( }; let access_token = TokenType::CompatAccessToken.generate(&mut rng); - let access_token = txn + let access_token = repo .compat_access_token() .add(&mut rng, &clock, &session, access_token, expires_in) .await?; let refresh_token = if input.refresh_token { let refresh_token = TokenType::CompatRefreshToken.generate(&mut rng); - let refresh_token = txn + let refresh_token = repo .compat_refresh_token() .add(&mut rng, &clock, &session, &access_token, refresh_token) .await?; @@ -240,7 +240,7 @@ pub(crate) async fn post( None }; - txn.commit().await?; + repo.save().await?; Ok(Json(ResponseBody { access_token: access_token.token, @@ -252,11 +252,11 @@ pub(crate) async fn post( } async fn token_login( - txn: &mut Transaction<'_, Postgres>, + repo: &mut PgRepository, clock: &Clock, token: &str, ) -> Result<(CompatSession, User), RouteError> { - let login = txn + let login = repo .compat_sso_login() .find_by_token(token) .await? @@ -300,40 +300,40 @@ async fn token_login( } }; - let session = txn + let session = repo .compat_session() .lookup(session_id) .await? .ok_or(RouteError::SessionNotFound)?; - let user = txn + let user = repo .user() .lookup(session.user_id) .await? .ok_or(RouteError::UserNotFound)?; - txn.compat_sso_login().exchange(clock, login).await?; + repo.compat_sso_login().exchange(clock, login).await?; Ok((session, user)) } async fn user_password_login( password_manager: &PasswordManager, - txn: &mut Transaction<'_, Postgres>, + repo: &mut PgRepository, username: String, password: String, ) -> Result<(CompatSession, User), RouteError> { let (clock, mut rng) = crate::clock_and_rng(); // Find the user - let user = txn + let user = repo .user() .find_by_username(&username) .await? .ok_or(RouteError::UserNotFound)?; // Lookup its password - let user_password = txn + let user_password = repo .user_password() .active(&user) .await? @@ -354,7 +354,7 @@ async fn user_password_login( if let Some((version, hashed_password)) = new_password_hash { // Save the upgraded password if needed - txn.user_password() + repo.user_password() .add( &mut rng, &clock, @@ -368,7 +368,7 @@ async fn user_password_login( // Now that the user credentials have been verified, start a new compat session let device = Device::generate(&mut rng); - let session = txn + let session = repo .compat_session() .add(&mut rng, &clock, &user, device) .await?; diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 333524246..7ca61ab2c 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -31,7 +31,7 @@ use mas_keystore::Encrypter; use mas_router::{CompatLoginSsoAction, PostAuthAction, Route}; use mas_storage::{ compat::{CompatSessionRepository, CompatSsoLoginRepository}, - Repository, + PgRepository, Repository, }; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; use serde::{Deserialize, Serialize}; @@ -60,12 +60,12 @@ pub async fn get( Query(params): Query, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -90,7 +90,7 @@ pub async fn get( return Ok((cookie_jar, destination.go()).into_response()); } - let login = conn + let login = repo .compat_sso_login() .lookup(id) .await? @@ -124,12 +124,12 @@ pub async fn post( Form(form): Form>, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); cookie_jar.verify_form(clock.now(), form)?; - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -154,7 +154,7 @@ pub async fn post( return Ok((cookie_jar, destination.go()).into_response()); } - let login = txn + let login = repo .compat_sso_login() .lookup(id) .await? @@ -188,16 +188,16 @@ pub async fn post( }; let device = Device::generate(&mut rng); - let compat_session = txn + let compat_session = repo .compat_session() .add(&mut rng, &clock, &session.user, device) .await?; - txn.compat_sso_login() + repo.compat_sso_login() .fulfill(&clock, login, &compat_session) .await?; - txn.commit().await?; + repo.save().await?; Ok((cookie_jar, Redirect::to(redirect_uri.as_str())).into_response()) } diff --git a/crates/handlers/src/compat/login_sso_redirect.rs b/crates/handlers/src/compat/login_sso_redirect.rs index 9c23b733f..befd3e323 100644 --- a/crates/handlers/src/compat/login_sso_redirect.rs +++ b/crates/handlers/src/compat/login_sso_redirect.rs @@ -19,7 +19,7 @@ use axum::{ }; use hyper::StatusCode; use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder}; -use mas_storage::{compat::CompatSsoLoginRepository, Repository}; +use mas_storage::{compat::CompatSsoLoginRepository, PgRepository, Repository}; use rand::distributions::{Alphanumeric, DistString}; use serde::Deserialize; use serde_with::serde; @@ -80,8 +80,8 @@ pub async fn get( } let token = Alphanumeric.sample_string(&mut rng, 32); - let mut conn = pool.acquire().await?; - let login = conn + let mut repo = PgRepository::from_pool(&pool).await?; + let login = repo .compat_sso_login() .add(&mut rng, &clock, token, redirect_url) .await?; diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index 25125c72f..762f77b2f 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -18,7 +18,7 @@ use hyper::StatusCode; use mas_data_model::TokenType; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, - Clock, Repository, + Clock, PgRepository, Repository, }; use sqlx::PgPool; use thiserror::Error; @@ -72,7 +72,7 @@ pub(crate) async fn post( maybe_authorization: Option>>, ) -> Result { let clock = Clock::default(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?; @@ -83,23 +83,23 @@ pub(crate) async fn post( return Err(RouteError::InvalidAuthorization); } - let token = txn + let token = repo .compat_access_token() .find_by_token(token) .await? .filter(|t| t.is_valid(clock.now())) .ok_or(RouteError::InvalidAuthorization)?; - let session = txn + let session = repo .compat_session() .lookup(token.session_id) .await? .filter(|s| s.is_valid()) .ok_or(RouteError::InvalidAuthorization)?; - txn.compat_session().finish(&clock, session).await?; + repo.compat_session().finish(&clock, session).await?; - txn.commit().await?; + repo.save().await?; Ok(Json(serde_json::json!({}))) } diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs index 7bfc940a4..ea6d5d238 100644 --- a/crates/handlers/src/compat/refresh.rs +++ b/crates/handlers/src/compat/refresh.rs @@ -18,7 +18,7 @@ use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, - Repository, + PgRepository, Repository, }; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DurationMilliSeconds}; @@ -92,7 +92,7 @@ pub(crate) async fn post( Json(input): Json, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let token_type = TokenType::check(&input.refresh_token)?; @@ -100,7 +100,7 @@ pub(crate) async fn post( return Err(RouteError::InvalidToken); } - let refresh_token = txn + let refresh_token = repo .compat_refresh_token() .find_by_token(&input.refresh_token) .await? @@ -110,7 +110,7 @@ pub(crate) async fn post( return Err(RouteError::RefreshTokenConsumed); } - let session = txn + let session = repo .compat_session() .lookup(refresh_token.session_id) .await? @@ -120,7 +120,7 @@ pub(crate) async fn post( return Err(RouteError::InvalidSession); } - let access_token = txn + let access_token = repo .compat_access_token() .lookup(refresh_token.access_token_id) .await? @@ -130,7 +130,7 @@ pub(crate) async fn post( let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng); let expires_in = Duration::minutes(5); - let new_access_token = txn + let new_access_token = repo .compat_access_token() .add( &mut rng, @@ -140,7 +140,7 @@ pub(crate) async fn post( Some(expires_in), ) .await?; - let new_refresh_token = txn + let new_refresh_token = repo .compat_refresh_token() .add( &mut rng, @@ -151,17 +151,17 @@ pub(crate) async fn post( ) .await?; - txn.compat_refresh_token() + repo.compat_refresh_token() .consume(&clock, refresh_token) .await?; if let Some(access_token) = access_token { - txn.compat_access_token() + repo.compat_access_token() .expire(&clock, access_token) .await?; } - txn.commit().await?; + repo.save().await?; Ok(Json(ResponseBody { access_token: new_access_token.token, diff --git a/crates/handlers/src/graphql.rs b/crates/handlers/src/graphql.rs index 2177388bc..d3a610b6a 100644 --- a/crates/handlers/src/graphql.rs +++ b/crates/handlers/src/graphql.rs @@ -28,6 +28,7 @@ use hyper::header::CACHE_CONTROL; use mas_axum_utils::{FancyError, SessionInfoExt}; use mas_graphql::Schema; use mas_keystore::Encrypter; +use mas_storage::PgRepository; use sqlx::PgPool; use tracing::{info_span, Instrument}; @@ -67,8 +68,9 @@ pub async fn post( let content_type = content_type.map(|TypedHeader(h)| h.to_string()); let (session_info, _cookie_jar) = cookie_jar.session_info(); - let mut conn = pool.acquire().await?; - let maybe_session = session_info.load_session(&mut conn).await?; + let mut repo = PgRepository::from_pool(&pool).await?; + let maybe_session = session_info.load_session(&mut repo).await?; + repo.cancel().await?; let mut request = async_graphql::http::receive_batch_body( content_type, @@ -117,8 +119,9 @@ pub async fn get( RawQuery(query): RawQuery, ) -> Result { let (session_info, _cookie_jar) = cookie_jar.session_info(); - let mut conn = pool.acquire().await?; - let maybe_session = session_info.load_session(&mut conn).await?; + let mut repo = PgRepository::from_pool(&pool).await?; + let maybe_session = session_info.load_session(&mut repo).await?; + repo.cancel().await?; let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?; diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index 9f462c502..c983e79c7 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -27,11 +27,11 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository}, - Repository, + PgRepository, Repository, }; use mas_templates::Templates; use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse}; -use sqlx::{PgPool, Postgres, Transaction}; +use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -84,13 +84,13 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, Path(grant_id): Path, ) -> Result { - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; - let grant = txn + let grant = repo .oauth2_authorization_grant() .lookup(grant_id) .await? @@ -107,7 +107,7 @@ pub(crate) async fn get( return Ok((cookie_jar, mas_router::Login::and_then(continue_grant).go()).into_response()); }; - match complete(grant, session, &policy_factory, txn).await { + match complete(grant, session, &policy_factory, repo).await { Ok(params) => { let res = callback_destination.go(&templates, params).await?; Ok((cookie_jar, res).into_response()) @@ -159,7 +159,7 @@ pub(crate) async fn complete( grant: AuthorizationGrant, browser_session: BrowserSession, policy_factory: &PolicyFactory, - mut txn: Transaction<'_, Postgres>, + mut repo: PgRepository, ) -> Result>, GrantCompletionError> { let (clock, mut rng) = crate::clock_and_rng(); @@ -170,7 +170,7 @@ pub(crate) async fn complete( // Check if the authentication is fresh enough if !browser_session.was_authenticated_after(grant.max_auth_time()) { - txn.commit().await?; + repo.save().await?; return Err(GrantCompletionError::RequiresReauth); } @@ -184,13 +184,13 @@ pub(crate) async fn complete( return Err(GrantCompletionError::PolicyViolation); } - let client = txn + let client = repo .oauth2_client() .lookup(grant.client_id) .await? .ok_or(GrantCompletionError::NoSuchClient)?; - let current_consent = txn + let current_consent = repo .oauth2_client() .get_consent_for_user(&client, &browser_session.user) .await?; @@ -202,17 +202,17 @@ pub(crate) async fn complete( // Check if the client lacks consent *or* if consent was explicitely asked if lacks_consent || grant.requires_consent { - txn.commit().await?; + repo.save().await?; return Err(GrantCompletionError::RequiresConsent); } // All good, let's start the session - let session = txn + let session = repo .oauth2_session() .create_from_grant(&mut rng, &clock, &grant, &browser_session) .await?; - let grant = txn + let grant = repo .oauth2_authorization_grant() .fulfill(&clock, &session, grant) .await?; @@ -233,6 +233,6 @@ pub(crate) async fn complete( )); } - txn.commit().await?; + repo.save().await?; Ok(params) } diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index b33b69129..155f72f73 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -27,7 +27,7 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, - Repository, + PgRepository, Repository, }; use mas_templates::Templates; use oauth2_types::{ @@ -139,10 +139,10 @@ pub(crate) async fn get( Form(params): Form, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; // First, figure out what client it is - let client = txn + let client = repo .oauth2_client() .find_by_client_id(¶ms.auth.client_id) .await? @@ -170,7 +170,7 @@ pub(crate) async fn get( let templates = templates.clone(); let callback_destination = callback_destination.clone(); async move { - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let prompt = params.auth.prompt.as_deref().unwrap_or_default(); // Check if the request/request_uri/registration params are used. If so, reply @@ -275,7 +275,7 @@ pub(crate) async fn get( let requires_consent = prompt.contains(&Prompt::Consent); - let grant = txn + let grant = repo .oauth2_authorization_grant() .add( &mut rng, @@ -302,7 +302,7 @@ pub(crate) async fn get( } None if prompt.contains(&Prompt::Create) => { // Client asked for a registration, show the registration prompt - txn.commit().await?; + repo.save().await?; mas_router::Register::and_then(continue_grant) .go() @@ -310,7 +310,7 @@ pub(crate) async fn get( } None => { // Other cases where we don't have a session, ask for a login - txn.commit().await?; + repo.save().await?; mas_router::Login::and_then(continue_grant) .go() @@ -323,7 +323,7 @@ pub(crate) async fn get( || prompt.contains(&Prompt::SelectAccount) => { // TODO: better pages here - txn.commit().await?; + repo.save().await?; mas_router::Reauth::and_then(continue_grant) .go() @@ -333,7 +333,7 @@ pub(crate) async fn get( // Else, we immediately try to complete the authorization grant Some(user_session) if prompt.contains(&Prompt::None) => { // With prompt=none, we should get back to the client immediately - match self::complete::complete(grant, user_session, &policy_factory, txn).await + match self::complete::complete(grant, user_session, &policy_factory, repo).await { Ok(params) => callback_destination.go(&templates, params).await?, Err(GrantCompletionError::RequiresConsent) => { @@ -372,7 +372,7 @@ pub(crate) async fn get( Some(user_session) => { let grant_id = grant.id; // Else, we show the relevant reauth/consent page if necessary - match self::complete::complete(grant, user_session, &policy_factory, txn).await + match self::complete::complete(grant, user_session, &policy_factory, repo).await { Ok(params) => callback_destination.go(&templates, params).await?, Err( diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index 94bf1346a..f3d4fd46e 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -30,7 +30,7 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, - Repository, + PgRepository, Repository, }; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; use sqlx::PgPool; @@ -81,13 +81,13 @@ pub(crate) async fn get( Path(grant_id): Path, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; - let grant = conn + let grant = repo .oauth2_authorization_grant() .lookup(grant_id) .await? @@ -136,15 +136,15 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; cookie_jar.verify_form(clock.now(), form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; - let grant = txn + let grant = repo .oauth2_authorization_grant() .lookup(grant_id) .await? @@ -167,7 +167,7 @@ pub(crate) async fn post( return Err(RouteError::PolicyViolation); } - let client = txn + let client = repo .oauth2_client() .lookup(grant.client_id) .await? @@ -180,7 +180,7 @@ pub(crate) async fn post( .filter(|s| !s.starts_with("urn:matrix:org.matrix.msc2967.client:device:")) .cloned() .collect(); - txn.oauth2_client() + repo.oauth2_client() .give_consent_for_user( &mut rng, &clock, @@ -190,9 +190,11 @@ pub(crate) async fn post( ) .await?; - txn.oauth2_authorization_grant().give_consent(grant).await?; + repo.oauth2_authorization_grant() + .give_consent(grant) + .await?; - txn.commit().await?; + repo.save().await?; Ok((cookie_jar, next.go_next()).into_response()) } diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index d032695ae..2837928f7 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -25,7 +25,7 @@ use mas_storage::{ compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository}, user::{BrowserSessionRepository, UserRepository}, - Clock, Repository, + Clock, PgRepository, Repository, }; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, @@ -130,12 +130,13 @@ pub(crate) async fn post( client_authorization: ClientAuthorization, ) -> Result { let clock = Clock::default(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let client = client_authorization .credentials - .fetch(&mut conn) - .await? + .fetch(&mut repo) + .await + .unwrap() .ok_or(RouteError::ClientNotFound)?; let method = match &client.token_endpoint_auth_method { @@ -166,14 +167,14 @@ pub(crate) async fn post( let reply = match token_type { TokenType::AccessToken => { - let token = conn + let token = repo .oauth2_access_token() .find_by_token(token) .await? .filter(|t| t.is_valid(clock.now())) .ok_or(RouteError::UnknownToken)?; - let session = conn + let session = repo .oauth2_session() .lookup(token.session_id) .await? @@ -181,7 +182,7 @@ pub(crate) async fn post( // XXX: is that the right error to bubble up? .ok_or(RouteError::UnknownToken)?; - let browser_session = conn + let browser_session = repo .browser_session() .lookup(session.user_session_id) .await? @@ -205,14 +206,14 @@ pub(crate) async fn post( } TokenType::RefreshToken => { - let token = conn + let token = repo .oauth2_refresh_token() .find_by_token(token) .await? .filter(|t| t.is_valid()) .ok_or(RouteError::UnknownToken)?; - let session = conn + let session = repo .oauth2_session() .lookup(token.session_id) .await? @@ -220,7 +221,7 @@ pub(crate) async fn post( // XXX: is that the right error to bubble up? .ok_or(RouteError::UnknownToken)?; - let browser_session = conn + let browser_session = repo .browser_session() .lookup(session.user_session_id) .await? @@ -244,21 +245,21 @@ pub(crate) async fn post( } TokenType::CompatAccessToken => { - let access_token = conn + let access_token = repo .compat_access_token() .find_by_token(token) .await? .filter(|t| t.is_valid(clock.now())) .ok_or(RouteError::UnknownToken)?; - let session = conn + let session = repo .compat_session() .lookup(access_token.session_id) .await? .filter(|s| s.is_valid()) .ok_or(RouteError::UnknownToken)?; - let user = conn + let user = repo .user() .lookup(session.user_id) .await? @@ -285,21 +286,21 @@ pub(crate) async fn post( } TokenType::CompatRefreshToken => { - let refresh_token = conn + let refresh_token = repo .compat_refresh_token() .find_by_token(token) .await? .filter(|t| t.is_valid()) .ok_or(RouteError::UnknownToken)?; - let session = conn + let session = repo .compat_session() .lookup(refresh_token.session_id) .await? .filter(|s| s.is_valid()) .ok_or(RouteError::UnknownToken)?; - let user = conn + let user = repo .user() .lookup(session.user_id) .await? diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index a6ff61587..d6180f9aa 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -19,7 +19,7 @@ use hyper::StatusCode; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_keystore::Encrypter; use mas_policy::{PolicyFactory, Violation}; -use mas_storage::{oauth2::OAuth2ClientRepository, Repository}; +use mas_storage::{oauth2::OAuth2ClientRepository, PgRepository, Repository}; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, registration::{ @@ -124,8 +124,7 @@ pub(crate) async fn post( return Err(RouteError::PolicyDenied(res.violations)); } - // Grab a txn - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method { Some( @@ -141,7 +140,7 @@ pub(crate) async fn post( _ => (None, None), }; - let client = txn + let client = repo .oauth2_client() .add( &mut rng, @@ -170,7 +169,7 @@ pub(crate) async fn post( ) .await?; - txn.commit().await?; + repo.save().await?; let response = ClientRegistrationResponse { client_id: client.client_id, diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 97f249c27..6365a0ada 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -37,7 +37,7 @@ use mas_storage::{ OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, user::BrowserSessionRepository, - Repository, + PgRepository, Repository, }; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, @@ -49,7 +49,7 @@ use oauth2_types::{ }; use serde::Serialize; use serde_with::{serde_as, skip_serializing_none}; -use sqlx::{PgPool, Postgres, Transaction}; +use sqlx::PgPool; use thiserror::Error; use tracing::debug; use url::Url; @@ -166,11 +166,11 @@ pub(crate) async fn post( State(encrypter): State, client_authorization: ClientAuthorization, ) -> Result { - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let client = client_authorization .credentials - .fetch(&mut txn) + .fetch(&mut repo) .await? .ok_or(RouteError::ClientNotFound)?; @@ -188,10 +188,10 @@ pub(crate) async fn post( let reply = match form { AccessTokenRequest::AuthorizationCode(grant) => { - authorization_code_grant(&grant, &client, &key_store, &url_builder, txn).await? + authorization_code_grant(&grant, &client, &key_store, &url_builder, repo).await? } AccessTokenRequest::RefreshToken(grant) => { - refresh_token_grant(&grant, &client, txn).await? + refresh_token_grant(&grant, &client, repo).await? } _ => { return Err(RouteError::InvalidGrant); @@ -211,11 +211,11 @@ async fn authorization_code_grant( client: &Client, key_store: &Keystore, url_builder: &UrlBuilder, - mut txn: Transaction<'_, Postgres>, + mut repo: PgRepository, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let authz_grant = txn + let authz_grant = repo .oauth2_authorization_grant() .find_by_code(&grant.code) .await? @@ -238,13 +238,13 @@ async fn authorization_code_grant( // Ending the session if the token was already exchanged more than 20s ago if now - exchanged_at > Duration::seconds(20) { debug!("Ending potentially compromised session"); - let session = txn + let session = repo .oauth2_session() .lookup(session_id) .await? .ok_or(RouteError::NoSuchOAuthSession)?; - txn.oauth2_session().finish(&clock, session).await?; - txn.commit().await?; + repo.oauth2_session().finish(&clock, session).await?; + repo.save().await?; } return Err(RouteError::InvalidGrant); @@ -266,7 +266,7 @@ async fn authorization_code_grant( } }; - let session = txn + let session = repo .oauth2_session() .lookup(session_id) .await? @@ -289,7 +289,7 @@ async fn authorization_code_grant( } }; - let browser_session = txn + let browser_session = repo .browser_session() .lookup(session.user_session_id) .await? @@ -299,12 +299,12 @@ async fn authorization_code_grant( let access_token_str = TokenType::AccessToken.generate(&mut rng); let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); - let access_token = txn + let access_token = repo .oauth2_access_token() .add(&mut rng, &clock, &session, access_token_str, ttl) .await?; - let refresh_token = txn + let refresh_token = repo .oauth2_refresh_token() .add(&mut rng, &clock, &session, &access_token, refresh_token_str) .await?; @@ -355,11 +355,11 @@ async fn authorization_code_grant( params = params.with_id_token(id_token); } - txn.oauth2_authorization_grant() + repo.oauth2_authorization_grant() .exchange(&clock, authz_grant) .await?; - txn.commit().await?; + repo.save().await?; Ok(params) } @@ -367,17 +367,17 @@ async fn authorization_code_grant( async fn refresh_token_grant( grant: &RefreshTokenGrant, client: &Client, - mut txn: Transaction<'_, Postgres>, + mut repo: PgRepository, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let refresh_token = txn + let refresh_token = repo .oauth2_refresh_token() .find_by_token(&grant.refresh_token) .await? .ok_or(RouteError::InvalidGrant)?; - let session = txn + let session = repo .oauth2_session() .lookup(refresh_token.session_id) .await? @@ -396,12 +396,12 @@ async fn refresh_token_grant( let access_token_str = TokenType::AccessToken.generate(&mut rng); let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); - let new_access_token = txn + let new_access_token = repo .oauth2_access_token() .add(&mut rng, &clock, &session, access_token_str.clone(), ttl) .await?; - let new_refresh_token = txn + let new_refresh_token = repo .oauth2_refresh_token() .add( &mut rng, @@ -412,14 +412,14 @@ async fn refresh_token_grant( ) .await?; - let refresh_token = txn + let refresh_token = repo .oauth2_refresh_token() .consume(&clock, refresh_token) .await?; if let Some(access_token_id) = refresh_token.access_token_id { - if let Some(access_token) = txn.oauth2_access_token().lookup(access_token_id).await? { - txn.oauth2_access_token() + if let Some(access_token) = repo.oauth2_access_token().lookup(access_token_id).await? { + repo.oauth2_access_token() .revoke(&clock, access_token) .await?; } @@ -430,7 +430,7 @@ async fn refresh_token_grant( .with_refresh_token(new_refresh_token.refresh_token) .with_scope(session.scope); - txn.commit().await?; + repo.save().await?; Ok(params) } diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index 49b6c5f1c..a125c5dde 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -31,7 +31,7 @@ use mas_router::UrlBuilder; use mas_storage::{ oauth2::OAuth2ClientRepository, user::{BrowserSessionRepository, UserEmailRepository}, - Repository, + DatabaseError, PgRepository, Repository, }; use oauth2_types::scope; use serde::Serialize; @@ -64,7 +64,7 @@ pub enum RouteError { Internal(Box), #[error("failed to authenticate")] - AuthorizationVerificationError(#[from] AuthorizationVerificationError), + AuthorizationVerificationError(#[from] AuthorizationVerificationError), #[error("no suitable key found for signing")] InvalidSigningKey, @@ -102,11 +102,11 @@ pub async fn get( user_authorization: UserAuthorization, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; - let session = user_authorization.protected(&mut conn, clock.now()).await?; + let session = user_authorization.protected(&mut repo, clock.now()).await?; - let browser_session = conn + let browser_session = repo .browser_session() .lookup(session.user_session_id) .await? @@ -115,7 +115,7 @@ pub async fn get( let user = browser_session.user; let user_email = if session.scope.contains(&scope::EMAIL) { - conn.user_email().get_primary(&user).await? + repo.user_email().get_primary(&user).await? } else { None }; @@ -127,7 +127,7 @@ pub async fn get( email: user_email.map(|u| u.email), }; - let client = conn + let client = repo .oauth2_client() .lookup(session.client_id) .await? diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index 178eba1ab..bdd19b7b8 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -24,7 +24,7 @@ use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_router::UrlBuilder; use mas_storage::{ upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, - Repository, + PgRepository, Repository, }; use sqlx::PgPool; use thiserror::Error; @@ -67,9 +67,9 @@ pub(crate) async fn get( ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; - let provider = txn + let provider = repo .upstream_oauth_provider() .lookup(provider_id) .await? @@ -100,7 +100,7 @@ pub(crate) async fn get( &mut rng, )?; - let session = txn + let session = repo .upstream_oauth_session() .add( &mut rng, @@ -116,7 +116,7 @@ pub(crate) async fn get( .add(session.id, provider.id, data.state, query.post_auth_action) .save(cookie_jar, clock.now()); - txn.commit().await?; + repo.save().await?; Ok((cookie_jar, Redirect::temporary(url.as_str()))) } diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 8cb9a605d..521efd7b4 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -26,8 +26,11 @@ use mas_oidc_client::requests::{ }; use mas_router::{Route, UrlBuilder}; use mas_storage::{ - upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, - Repository, UpstreamOAuthLinkRepository, + upstream_oauth2::{ + UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, + UpstreamOAuthSessionRepository, + }, + PgRepository, Repository, }; use oauth2_types::errors::ClientErrorCode; use serde::Deserialize; @@ -129,9 +132,9 @@ pub(crate) async fn get( ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; - let provider = txn + let provider = repo .upstream_oauth_provider() .lookup(provider_id) .await? @@ -142,7 +145,7 @@ pub(crate) async fn get( .find_session(provider_id, ¶ms.state) .map_err(|_| RouteError::MissingCookie)?; - let session = txn + let session = repo .upstream_oauth_session() .lookup(session_id) .await? @@ -244,7 +247,7 @@ pub(crate) async fn get( let subject = mas_jose::claims::SUB.extract_required(&mut id_token)?; // Look for an existing link - let maybe_link = txn + let maybe_link = repo .upstream_oauth_link() .find_by_subject(&provider, &subject) .await?; @@ -252,12 +255,12 @@ pub(crate) async fn get( let link = if let Some(link) = maybe_link { link } else { - txn.upstream_oauth_link() + repo.upstream_oauth_link() .add(&mut rng, &clock, &provider, subject) .await? }; - let session = txn + let session = repo .upstream_oauth_session() .complete_with_link(&clock, session, &link, response.id_token) .await?; @@ -266,7 +269,7 @@ pub(crate) async fn get( .add_link_to_session(session.id, link.id)? .save(cookie_jar, clock.now()); - txn.commit().await?; + repo.save().await?; Ok(( cookie_jar, diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 10e1f80e8..18849be84 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -25,9 +25,9 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_storage::{ - upstream_oauth2::UpstreamOAuthSessionRepository, + upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, user::{BrowserSessionRepository, UserRepository}, - Repository, UpstreamOAuthLinkRepository, + PgRepository, Repository, }; use mas_templates::{ EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, @@ -99,7 +99,7 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, Path(link_id): Path, ) -> Result { - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (clock, mut rng) = crate::clock_and_rng(); let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); @@ -107,13 +107,13 @@ pub(crate) async fn get( .lookup_link(link_id) .map_err(|_| RouteError::MissingCookie)?; - let link = txn + let link = repo .upstream_oauth_link() .lookup(link_id) .await? .ok_or(RouteError::LinkNotFound)?; - let upstream_session = txn + let upstream_session = repo .upstream_oauth_session() .lookup(session_id) .await? @@ -131,24 +131,24 @@ pub(crate) async fn get( let (user_session_info, cookie_jar) = cookie_jar.session_info(); let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); - let maybe_user_session = user_session_info.load_session(&mut txn).await?; + let maybe_user_session = user_session_info.load_session(&mut repo).await?; let render = match (maybe_user_session, link.user_id) { (Some(session), Some(user_id)) if session.user.id == user_id => { // Session already linked, and link matches the currently logged // user. Mark the session as consumed and renew the authentication. - txn.upstream_oauth_session() + repo.upstream_oauth_session() .consume(&clock, upstream_session) .await?; - let session = txn + let session = repo .browser_session() .authenticate_with_upstream(&mut rng, &clock, session, &link) .await?; cookie_jar = cookie_jar.set_session(&session); - txn.commit().await?; + repo.save().await?; let ctx = EmptyContext .with_session(session) @@ -163,7 +163,7 @@ pub(crate) async fn get( // Session already linked, but link doesn't match the currently // logged user. Suggest logging out of the current user // and logging in with the new one - let user = txn + let user = repo .user() .lookup(user_id) .await? @@ -187,7 +187,7 @@ pub(crate) async fn get( (None, Some(user_id)) => { // Session linked, but user not logged in: do the login - let user = txn + let user = repo .user() .lookup(user_id) .await? @@ -216,8 +216,8 @@ pub(crate) async fn post( Path(link_id): Path, Form(form): Form>, ) -> Result { - let mut txn = pool.begin().await?; let (clock, mut rng) = crate::clock_and_rng(); + let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); @@ -229,13 +229,13 @@ pub(crate) async fn post( post_auth_action: post_auth_action.cloned(), }; - let link = txn + let link = repo .upstream_oauth_link() .lookup(link_id) .await? .ok_or(RouteError::LinkNotFound)?; - let upstream_session = txn + let upstream_session = repo .upstream_oauth_session() .lookup(session_id) .await? @@ -252,11 +252,11 @@ pub(crate) async fn post( } let (user_session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_user_session = user_session_info.load_session(&mut txn).await?; + let maybe_user_session = user_session_info.load_session(&mut repo).await?; let session = match (maybe_user_session, link.user_id, form) { (Some(session), None, FormData::Link) => { - txn.upstream_oauth_link() + repo.upstream_oauth_link() .associate_to_user(&link, &session.user) .await?; @@ -264,32 +264,32 @@ pub(crate) async fn post( } (None, Some(user_id), FormData::Login) => { - let user = txn + let user = repo .user() .lookup(user_id) .await? .ok_or(RouteError::UserNotFound)?; - txn.browser_session().add(&mut rng, &clock, &user).await? + repo.browser_session().add(&mut rng, &clock, &user).await? } (None, None, FormData::Register { username }) => { - let user = txn.user().add(&mut rng, &clock, username).await?; - txn.upstream_oauth_link() + let user = repo.user().add(&mut rng, &clock, username).await?; + repo.upstream_oauth_link() .associate_to_user(&link, &user) .await?; - txn.browser_session().add(&mut rng, &clock, &user).await? + repo.browser_session().add(&mut rng, &clock, &user).await? } _ => return Err(RouteError::InvalidFormAction), }; - txn.upstream_oauth_session() + repo.upstream_oauth_session() .consume(&clock, upstream_session) .await?; - let session = txn + let session = repo .browser_session() .authenticate_with_upstream(&mut rng, &clock, session, &link) .await?; @@ -299,7 +299,7 @@ pub(crate) async fn post( .save(cookie_jar, clock.now()); let cookie_jar = cookie_jar.set_session(&session); - txn.commit().await?; + repo.save().await?; Ok((cookie_jar, post_auth_action.go_next())) } diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index c7cd27676..e0cc063d3 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -24,7 +24,7 @@ use mas_axum_utils::{ use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, Repository}; +use mas_storage::{user::UserEmailRepository, PgRepository, Repository}; use mas_templates::{EmailAddContext, TemplateContext, Templates}; use serde::Deserialize; use sqlx::PgPool; @@ -43,12 +43,12 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -74,12 +74,12 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -88,7 +88,7 @@ pub(crate) async fn post( return Ok((cookie_jar, login.go()).into_response()); }; - let user_email = txn + let user_email = repo .user_email() .add(&mut rng, &clock, &session.user, form.email) .await?; @@ -101,7 +101,7 @@ pub(crate) async fn post( }; start_email_verification( &mailer, - &mut txn, + &mut repo, &mut rng, &clock, &session.user, @@ -109,7 +109,7 @@ pub(crate) async fn post( ) .await?; - txn.commit().await?; + repo.save().await?; Ok((cookie_jar, next.go()).into_response()) } diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index e6e1e3410..3fda398ad 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -28,11 +28,11 @@ use mas_data_model::{BrowserSession, User, UserEmail}; use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, Clock, Repository}; +use mas_storage::{user::UserEmailRepository, Clock, PgRepository, Repository}; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; use rand::{distributions::Uniform, Rng}; use serde::Deserialize; -use sqlx::{PgConnection, PgPool}; +use sqlx::PgPool; use tracing::info; pub mod add; @@ -54,14 +54,14 @@ pub(crate) async fn get( ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if let Some(session) = maybe_session { - render(&mut rng, &clock, templates, session, cookie_jar, &mut conn).await + render(&mut rng, &clock, templates, session, cookie_jar, &mut repo).await } else { let login = mas_router::Login::default(); Ok((cookie_jar, login.go()).into_response()) @@ -74,11 +74,11 @@ async fn render( templates: Templates, session: BrowserSession, cookie_jar: PrivateCookieJar, - conn: &mut PgConnection, + repo: &mut impl Repository, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), rng); - let emails = conn.user_email().all(&session.user).await?; + let emails = repo.user_email().all(&session.user).await?; let ctx = AccountEmailsContext::new(emails) .with_session(session) @@ -91,7 +91,7 @@ async fn render( async fn start_email_verification( mailer: &Mailer, - conn: &mut PgConnection, + repo: &mut impl Repository, mut rng: impl Rng + Send, clock: &Clock, user: &User, @@ -103,7 +103,7 @@ async fn start_email_verification( let address: Address = user_email.email.parse()?; - let verification = conn + let verification = repo .user_email() .add_verification_code(&mut rng, clock, &user_email, Duration::hours(8), code) .await?; @@ -130,11 +130,11 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let mut session = if let Some(session) = maybe_session { session @@ -147,21 +147,21 @@ pub(crate) async fn post( match form { ManagementForm::Add { email } => { - let email = txn + let email = repo .user_email() .add(&mut rng, &clock, &session.user, email) .await?; let next = mas_router::AccountVerifyEmail::new(email.id); - start_email_verification(&mailer, &mut txn, &mut rng, &clock, &session.user, email) + start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email) .await?; - txn.commit().await?; + repo.save().await?; return Ok((cookie_jar, next.go()).into_response()); } ManagementForm::ResendConfirmation { id } => { let id = id.parse()?; - let email = txn + let email = repo .user_email() .lookup(id) .await? @@ -172,15 +172,15 @@ pub(crate) async fn post( } let next = mas_router::AccountVerifyEmail::new(email.id); - start_email_verification(&mailer, &mut txn, &mut rng, &clock, &session.user, email) + start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email) .await?; - txn.commit().await?; + repo.save().await?; return Ok((cookie_jar, next.go()).into_response()); } ManagementForm::Remove { id } => { let id = id.parse()?; - let email = txn + let email = repo .user_email() .lookup(id) .await? @@ -190,11 +190,11 @@ pub(crate) async fn post( return Err(anyhow!("Email not found").into()); } - txn.user_email().remove(email).await?; + repo.user_email().remove(email).await?; } ManagementForm::SetPrimary { id } => { let id = id.parse()?; - let email = txn + let email = repo .user_email() .lookup(id) .await? @@ -204,7 +204,7 @@ pub(crate) async fn post( return Err(anyhow!("Email not found").into()); } - txn.user_email().set_as_primary(&email).await?; + repo.user_email().set_as_primary(&email).await?; session.user.primary_user_email_id = Some(email.id); } }; @@ -215,11 +215,11 @@ pub(crate) async fn post( templates.clone(), session, cookie_jar, - &mut txn, + &mut repo, ) .await?; - txn.commit().await?; + repo.save().await?; Ok(reply) } diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index 1192743e7..085b9a337 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -24,7 +24,7 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, Clock, Repository}; +use mas_storage::{user::UserEmailRepository, Clock, PgRepository, Repository}; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use serde::Deserialize; use sqlx::PgPool; @@ -45,12 +45,12 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -59,7 +59,7 @@ pub(crate) async fn get( return Ok((cookie_jar, login.go()).into_response()); }; - let user_email = conn + let user_email = repo .user_email() .lookup(id) .await? @@ -89,12 +89,12 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let clock = Clock::default(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -103,33 +103,33 @@ pub(crate) async fn post( return Ok((cookie_jar, login.go()).into_response()); }; - let user_email = txn + let user_email = repo .user_email() .lookup(id) .await? .filter(|u| u.user_id == session.user.id) .context("Could not find user email")?; - let verification = txn + let verification = repo .user_email() .find_verification_code(&clock, &user_email, &form.code) .await? .context("Invalid code")?; // TODO: display nice errors if the code was already consumed or expired - txn.user_email() + repo.user_email() .consume_verification_code(&clock, verification) .await?; if session.user.primary_user_email_id.is_none() { - txn.user_email().set_as_primary(&user_email).await?; + repo.user_email().set_as_primary(&user_email).await?; } - txn.user_email() + repo.user_email() .mark_as_verified(&clock, user_email) .await?; - txn.commit().await?; + repo.save().await?; let destination = query.go_next_or_default(&mas_router::AccountEmails); Ok((cookie_jar, destination).into_response()) diff --git a/crates/handlers/src/views/account/mod.rs b/crates/handlers/src/views/account/mod.rs index 0188aef29..5017db00e 100644 --- a/crates/handlers/src/views/account/mod.rs +++ b/crates/handlers/src/views/account/mod.rs @@ -25,7 +25,7 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository}, - Repository, + PgRepository, Repository, }; use mas_templates::{AccountContext, TemplateContext, Templates}; use sqlx::PgPool; @@ -36,12 +36,12 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -50,9 +50,9 @@ pub(crate) async fn get( return Ok((cookie_jar, login.go()).into_response()); }; - let active_sessions = conn.browser_session().count_active(&session.user).await?; + let active_sessions = repo.browser_session().count_active(&session.user).await?; - let emails = conn.user_email().all(&session.user).await?; + let emails = repo.user_email().all(&session.user).await?; let ctx = AccountContext::new(active_sessions, emails) .with_session(session) diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index 42c0194be..8d4964323 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -27,7 +27,7 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository}, - Clock, Repository, + Clock, PgRepository, Repository, }; use mas_templates::{EmptyContext, TemplateContext, Templates}; use rand::Rng; @@ -50,11 +50,11 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if let Some(session) = maybe_session { render(&mut rng, &clock, templates, session, cookie_jar).await @@ -90,13 +90,13 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -105,7 +105,7 @@ pub(crate) async fn post( return Ok((cookie_jar, login.go()).into_response()); }; - let user_password = txn + let user_password = repo .user_password() .active(&session.user) .await? @@ -129,7 +129,7 @@ pub(crate) async fn post( } let (version, hashed_password) = password_manager.hash(&mut rng, new_password).await?; - let user_password = txn + let user_password = repo .user_password() .add( &mut rng, @@ -141,14 +141,14 @@ pub(crate) async fn post( ) .await?; - let session = txn + let session = repo .browser_session() .authenticate_with_password(&mut rng, &clock, session, &user_password) .await?; let reply = render(&mut rng, &clock, templates.clone(), session, cookie_jar).await?; - txn.commit().await?; + repo.save().await?; Ok(reply) } diff --git a/crates/handlers/src/views/index.rs b/crates/handlers/src/views/index.rs index 2471296da..49668daeb 100644 --- a/crates/handlers/src/views/index.rs +++ b/crates/handlers/src/views/index.rs @@ -20,6 +20,7 @@ use axum_extra::extract::PrivateCookieJar; use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt}; use mas_keystore::Encrypter; use mas_router::UrlBuilder; +use mas_storage::PgRepository; use mas_templates::{IndexContext, TemplateContext, Templates}; use sqlx::PgPool; @@ -30,11 +31,11 @@ pub async fn get( cookie_jar: PrivateCookieJar, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let session = session_info.load_session(&mut conn).await?; + let session = session_info.load_session(&mut repo).await?; let ctx = IndexContext::new(url_builder.oidc_discovery()) .maybe_with_session(session) diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 1ef5efbb8..76ffa4558 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -26,14 +26,14 @@ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::UpstreamOAuthProviderRepository, user::{BrowserSessionRepository, UserPasswordRepository, UserRepository}, - Clock, Repository, + Clock, PgRepository, Repository, }; use mas_templates::{ FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, }; use rand::{CryptoRng, Rng}; use serde::{Deserialize, Serialize}; -use sqlx::{PgConnection, PgPool}; +use sqlx::PgPool; use zeroize::Zeroizing; use super::shared::OptionalPostAuthAction; @@ -56,23 +56,23 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if maybe_session.is_some() { let reply = query.go_next(); Ok((cookie_jar, reply).into_response()) } else { - let providers = conn.upstream_oauth_provider().all().await?; + let providers = repo.upstream_oauth_provider().all().await?; let content = render( LoginContext::default().with_upstrem_providers(providers), query, csrf_token, - &mut conn, + &mut repo, &templates, ) .await?; @@ -90,7 +90,7 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; @@ -112,14 +112,14 @@ pub(crate) async fn post( }; if !state.is_valid() { - let providers = conn.upstream_oauth_provider().all().await?; + let providers = repo.upstream_oauth_provider().all().await?; let content = render( LoginContext::default() .with_form_state(state) .with_upstrem_providers(providers), query, csrf_token, - &mut conn, + &mut repo, &templates, ) .await?; @@ -129,7 +129,7 @@ pub(crate) async fn post( match login( password_manager, - &mut conn, + &mut repo, rng, &clock, &form.username, @@ -138,6 +138,8 @@ pub(crate) async fn post( .await { Ok(session_info) => { + repo.save().await?; + let cookie_jar = cookie_jar.set_session(&session_info); let reply = query.go_next(); Ok((cookie_jar, reply).into_response()) @@ -149,7 +151,7 @@ pub(crate) async fn post( LoginContext::default().with_form_state(state), query, csrf_token, - &mut conn, + &mut repo, &templates, ) .await?; @@ -162,7 +164,7 @@ pub(crate) async fn post( // TODO: move that logic elsewhere? async fn login( password_manager: PasswordManager, - conn: &mut PgConnection, + repo: &mut impl Repository, mut rng: impl Rng + CryptoRng + Send, clock: &Clock, username: &str, @@ -170,7 +172,7 @@ async fn login( ) -> Result { // XXX: we're loosing the error context here // First, lookup the user - let user = conn + let user = repo .user() .find_by_username(username) .await @@ -178,7 +180,7 @@ async fn login( .ok_or(FormError::InvalidCredentials)?; // And its password - let user_password = conn + let user_password = repo .user_password() .active(&user) .await @@ -200,7 +202,7 @@ async fn login( let user_password = if let Some((version, new_password_hash)) = new_password_hash { // Save the upgraded password - conn.user_password() + repo.user_password() .add( &mut rng, clock, @@ -216,14 +218,14 @@ async fn login( }; // Start a new session - let user_session = conn + let user_session = repo .browser_session() .add(&mut rng, clock, &user) .await .map_err(|_| FormError::Internal)?; // And mark it as authenticated by the password - let user_session = conn + let user_session = repo .browser_session() .authenticate_with_password(&mut rng, clock, user_session, &user_password) .await @@ -236,10 +238,10 @@ async fn render( ctx: LoginContext, action: OptionalPostAuthAction, csrf_token: CsrfToken, - conn: &mut PgConnection, + repo: &mut impl Repository, templates: &Templates, ) -> Result { - let next = action.load_context(conn).await?; + let next = action.load_context(repo).await?; let ctx = if let Some(next) = next { ctx.with_post_action(next) } else { diff --git a/crates/handlers/src/views/logout.rs b/crates/handlers/src/views/logout.rs index 88c4a9c2d..156e6afbe 100644 --- a/crates/handlers/src/views/logout.rs +++ b/crates/handlers/src/views/logout.rs @@ -23,7 +23,7 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::{PostAuthAction, Route}; -use mas_storage::{user::BrowserSessionRepository, Clock, Repository}; +use mas_storage::{user::BrowserSessionRepository, Clock, PgRepository, Repository}; use sqlx::PgPool; pub(crate) async fn post( @@ -32,20 +32,20 @@ pub(crate) async fn post( Form(form): Form>>, ) -> Result { let clock = Clock::default(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; let (session_info, mut cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if let Some(session) = maybe_session { - txn.browser_session().finish(&clock, session).await?; + repo.browser_session().finish(&clock, session).await?; cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended()); } - txn.commit().await?; + repo.save().await?; let destination = if let Some(action) = form { action.go_next() diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index 7911a9301..aac51abd7 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -26,7 +26,7 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository}, - Repository, + PgRepository, Repository, }; use mas_templates::{ReauthContext, TemplateContext, Templates}; use serde::Deserialize; @@ -48,12 +48,12 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -65,7 +65,7 @@ pub(crate) async fn get( }; let ctx = ReauthContext::default(); - let next = query.load_context(&mut conn).await?; + let next = query.load_context(&mut repo).await?; let ctx = if let Some(next) = next { ctx.with_post_action(next) } else { @@ -86,13 +86,13 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -104,7 +104,7 @@ pub(crate) async fn post( }; // Load the user password - let user_password = txn + let user_password = repo .user_password() .active(&session.user) .await? @@ -125,7 +125,7 @@ pub(crate) async fn post( let user_password = if let Some((version, new_password_hash)) = new_password_hash { // Save the upgraded password - txn.user_password() + repo.user_password() .add( &mut rng, &clock, @@ -140,13 +140,13 @@ pub(crate) async fn post( }; // Mark the session as authenticated by the password - let session = txn + let session = repo .browser_session() .authenticate_with_password(&mut rng, &clock, session, &user_password) .await?; let cookie_jar = cookie_jar.set_session(&session); - txn.commit().await?; + repo.save().await?; let reply = query.go_next(); Ok((cookie_jar, reply).into_response()) diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index b2fe9fe0e..a014eb9db 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -33,7 +33,7 @@ use mas_policy::PolicyFactory; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - Repository, + PgRepository, Repository, }; use mas_templates::{ EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField, @@ -41,7 +41,7 @@ use mas_templates::{ }; use rand::{distributions::Uniform, Rng}; use serde::{Deserialize, Serialize}; -use sqlx::{PgConnection, PgPool}; +use sqlx::PgPool; use zeroize::Zeroizing; use super::shared::OptionalPostAuthAction; @@ -66,12 +66,12 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if maybe_session.is_some() { let reply = query.go_next(); @@ -81,7 +81,7 @@ pub(crate) async fn get( RegisterContext::default(), query, csrf_token, - &mut conn, + &mut repo, &templates, ) .await?; @@ -102,7 +102,7 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; @@ -114,7 +114,7 @@ pub(crate) async fn post( if form.username.is_empty() { state.add_error_on_field(RegisterFormField::Username, FieldError::Required); - } else if txn.user().exists(&form.username).await? { + } else if repo.user().exists(&form.username).await? { state.add_error_on_field(RegisterFormField::Username, FieldError::Exists); } @@ -177,7 +177,7 @@ pub(crate) async fn post( RegisterContext::default().with_form_state(state), query, csrf_token, - &mut txn, + &mut repo, &templates, ) .await?; @@ -185,15 +185,15 @@ pub(crate) async fn post( return Ok((cookie_jar, Html(content)).into_response()); } - let user = txn.user().add(&mut rng, &clock, form.username).await?; + let user = repo.user().add(&mut rng, &clock, form.username).await?; let password = Zeroizing::new(form.password.into_bytes()); let (version, hashed_password) = password_manager.hash(&mut rng, password).await?; - let user_password = txn + let user_password = repo .user_password() .add(&mut rng, &clock, &user, version, hashed_password, None) .await?; - let user_email = txn + let user_email = repo .user_email() .add(&mut rng, &clock, &user, form.email) .await?; @@ -205,7 +205,7 @@ pub(crate) async fn post( let address: Address = user_email.email.parse()?; - let verification = txn + let verification = repo .user_email() .add_verification_code(&mut rng, &clock, &user_email, Duration::hours(8), code) .await?; @@ -219,14 +219,14 @@ pub(crate) async fn post( let next = mas_router::AccountVerifyEmail::new(user_email.id).and_maybe(query.post_auth_action); - let session = txn.browser_session().add(&mut rng, &clock, &user).await?; + let session = repo.browser_session().add(&mut rng, &clock, &user).await?; - let session = txn + let session = repo .browser_session() .authenticate_with_password(&mut rng, &clock, session, &user_password) .await?; - txn.commit().await?; + repo.save().await?; let cookie_jar = cookie_jar.set_session(&session); Ok((cookie_jar, next.go()).into_response()) @@ -236,10 +236,10 @@ async fn render( ctx: RegisterContext, action: OptionalPostAuthAction, csrf_token: CsrfToken, - conn: &mut PgConnection, + repo: &mut impl Repository, templates: &Templates, ) -> Result { - let next = action.load_context(conn).await?; + let next = action.load_context(repo).await?; let ctx = if let Some(next) = next { ctx.with_post_action(next) } else { diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index 57d537628..db3c33920 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -15,12 +15,13 @@ use anyhow::Context; use mas_router::{PostAuthAction, Route}; use mas_storage::{ - compat::CompatSsoLoginRepository, oauth2::OAuth2AuthorizationGrantRepository, - upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository, + compat::CompatSsoLoginRepository, + oauth2::OAuth2AuthorizationGrantRepository, + upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository}, + Repository, }; use mas_templates::{PostAuthContext, PostAuthContextInner}; use serde::{Deserialize, Serialize}; -use sqlx::PgConnection; #[derive(Serialize, Deserialize, Default, Debug, Clone)] pub(crate) struct OptionalPostAuthAction { @@ -39,14 +40,14 @@ impl OptionalPostAuthAction { self.go_next_or_default(&mas_router::Index) } - pub async fn load_context( + pub async fn load_context( &self, - conn: &mut PgConnection, + repo: &mut R, ) -> anyhow::Result> { let Some(action) = self.post_auth_action.clone() else { return Ok(None) }; let ctx = match action { PostAuthAction::ContinueAuthorizationGrant { id } => { - let grant = conn + let grant = repo .oauth2_authorization_grant() .lookup(id) .await? @@ -56,7 +57,7 @@ impl OptionalPostAuthAction { } PostAuthAction::ContinueCompatSsoLogin { id } => { - let login = conn + let login = repo .compat_sso_login() .lookup(id) .await? @@ -68,13 +69,13 @@ impl OptionalPostAuthAction { PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword, PostAuthAction::LinkUpstream { id } => { - let link = conn + let link = repo .upstream_oauth_link() .lookup(id) .await? .context("Failed to load upstream OAuth 2.0 link")?; - let provider = conn + let provider = repo .upstream_oauth_provider() .lookup(link.provider_id) .await? diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index 09a70023f..97aeee243 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -183,7 +183,7 @@ pub(crate) mod tracing; pub mod upstream_oauth2; pub mod user; -pub use self::{repository::Repository, upstream_oauth2::UpstreamOAuthLinkRepository}; +pub use self::repository::{PgRepository, Repository}; /// Embedded migrations, allowing them to run on startup pub static MIGRATOR: Migrator = sqlx::migrate!(); diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index 91df93138..c57c5dcd3 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -32,7 +32,7 @@ use crate::{ }; #[async_trait] -pub trait OAuth2AuthorizationGrantRepository { +pub trait OAuth2AuthorizationGrantRepository: Send + Sync { type Error; #[allow(clippy::too_many_arguments)] diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 9df2f61d2..c28bc4efb 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -27,7 +27,7 @@ use crate::{ }; #[async_trait] -pub trait OAuth2SessionRepository { +pub trait OAuth2SessionRepository: Send + Sync { type Error; async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index b9bf5683f..1fde4b417 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -12,89 +12,100 @@ // See the License for the specific language governing permissions and // limitations under the License. -use sqlx::{PgConnection, Postgres, Transaction}; +use sqlx::{PgPool, Postgres, Transaction}; use crate::{ compat::{ - PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository, - PgCompatSsoLoginRepository, + CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, + CompatSsoLoginRepository, PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, + PgCompatSessionRepository, PgCompatSsoLoginRepository, }, oauth2::{ - PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository, - PgOAuth2ClientRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository, + OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, + OAuth2RefreshTokenRepository, OAuth2SessionRepository, PgOAuth2AccessTokenRepository, + PgOAuth2AuthorizationGrantRepository, PgOAuth2ClientRepository, + PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository, }, upstream_oauth2::{ PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, - PgUpstreamOAuthSessionRepository, + PgUpstreamOAuthSessionRepository, UpstreamOAuthLinkRepository, + UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, }, user::{ - PgBrowserSessionRepository, PgUserEmailRepository, PgUserPasswordRepository, - PgUserRepository, + BrowserSessionRepository, PgBrowserSessionRepository, PgUserEmailRepository, + PgUserPasswordRepository, PgUserRepository, UserEmailRepository, UserPasswordRepository, + UserRepository, }, + DatabaseError, }; -pub trait Repository { - type UpstreamOAuthLinkRepository<'c> +pub trait Repository: Send { + type Error: std::error::Error + Send + Sync + 'static; + + type UpstreamOAuthLinkRepository<'c>: UpstreamOAuthLinkRepository + 'c where Self: 'c; - type UpstreamOAuthProviderRepository<'c> + type UpstreamOAuthProviderRepository<'c>: UpstreamOAuthProviderRepository + + 'c where Self: 'c; - type UpstreamOAuthSessionRepository<'c> + type UpstreamOAuthSessionRepository<'c>: UpstreamOAuthSessionRepository + + 'c where Self: 'c; - type UserRepository<'c> + type UserRepository<'c>: UserRepository + 'c where Self: 'c; - type UserEmailRepository<'c> + type UserEmailRepository<'c>: UserEmailRepository + 'c where Self: 'c; - type UserPasswordRepository<'c> + type UserPasswordRepository<'c>: UserPasswordRepository + 'c where Self: 'c; - type BrowserSessionRepository<'c> + type BrowserSessionRepository<'c>: BrowserSessionRepository + 'c where Self: 'c; - type OAuth2ClientRepository<'c> + type OAuth2ClientRepository<'c>: OAuth2ClientRepository + 'c where Self: 'c; - type OAuth2AuthorizationGrantRepository<'c> + type OAuth2AuthorizationGrantRepository<'c>: OAuth2AuthorizationGrantRepository + + 'c where Self: 'c; - type OAuth2SessionRepository<'c> + type OAuth2SessionRepository<'c>: OAuth2SessionRepository + 'c where Self: 'c; - type OAuth2AccessTokenRepository<'c> + type OAuth2AccessTokenRepository<'c>: OAuth2AccessTokenRepository + 'c where Self: 'c; - type OAuth2RefreshTokenRepository<'c> + type OAuth2RefreshTokenRepository<'c>: OAuth2RefreshTokenRepository + 'c where Self: 'c; - type CompatSessionRepository<'c> + type CompatSessionRepository<'c>: CompatSessionRepository + 'c where Self: 'c; - type CompatSsoLoginRepository<'c> + type CompatSsoLoginRepository<'c>: CompatSsoLoginRepository + 'c where Self: 'c; - type CompatAccessTokenRepository<'c> + type CompatAccessTokenRepository<'c>: CompatAccessTokenRepository + 'c where Self: 'c; - type CompatRefreshTokenRepository<'c> + type CompatRefreshTokenRepository<'c>: CompatRefreshTokenRepository + 'c where Self: 'c; @@ -116,7 +127,30 @@ pub trait Repository { fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_>; } -impl Repository for PgConnection { +pub struct PgRepository { + txn: Transaction<'static, Postgres>, +} + +impl PgRepository { + pub async fn from_pool(pool: &PgPool) -> Result { + let txn = pool.begin().await?; + Ok(PgRepository { txn }) + } + + pub async fn save(self) -> Result<(), DatabaseError> { + self.txn.commit().await?; + Ok(()) + } + + pub async fn cancel(self) -> Result<(), DatabaseError> { + self.txn.rollback().await?; + Ok(()) + } +} + +impl Repository for PgRepository { + type Error = DatabaseError; + type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; @@ -135,149 +169,66 @@ impl Repository for PgConnection { type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { - PgUpstreamOAuthLinkRepository::new(self) + PgUpstreamOAuthLinkRepository::new(&mut self.txn) } fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { - PgUpstreamOAuthProviderRepository::new(self) + PgUpstreamOAuthProviderRepository::new(&mut self.txn) } fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> { - PgUpstreamOAuthSessionRepository::new(self) + PgUpstreamOAuthSessionRepository::new(&mut self.txn) } fn user(&mut self) -> Self::UserRepository<'_> { - PgUserRepository::new(self) + PgUserRepository::new(&mut self.txn) } fn user_email(&mut self) -> Self::UserEmailRepository<'_> { - PgUserEmailRepository::new(self) + PgUserEmailRepository::new(&mut self.txn) } fn user_password(&mut self) -> Self::UserPasswordRepository<'_> { - PgUserPasswordRepository::new(self) + PgUserPasswordRepository::new(&mut self.txn) } fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> { - PgBrowserSessionRepository::new(self) + PgBrowserSessionRepository::new(&mut self.txn) } fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { - PgOAuth2ClientRepository::new(self) + PgOAuth2ClientRepository::new(&mut self.txn) } fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> { - PgOAuth2AuthorizationGrantRepository::new(self) + PgOAuth2AuthorizationGrantRepository::new(&mut self.txn) } fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { - PgOAuth2SessionRepository::new(self) + PgOAuth2SessionRepository::new(&mut self.txn) } fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> { - PgOAuth2AccessTokenRepository::new(self) + PgOAuth2AccessTokenRepository::new(&mut self.txn) } fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> { - PgOAuth2RefreshTokenRepository::new(self) + PgOAuth2RefreshTokenRepository::new(&mut self.txn) } fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { - PgCompatSessionRepository::new(self) + PgCompatSessionRepository::new(&mut self.txn) } fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> { - PgCompatSsoLoginRepository::new(self) + PgCompatSsoLoginRepository::new(&mut self.txn) } fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> { - PgCompatAccessTokenRepository::new(self) + PgCompatAccessTokenRepository::new(&mut self.txn) } fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> { - PgCompatRefreshTokenRepository::new(self) - } -} - -impl<'t> Repository for Transaction<'t, Postgres> { - type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; - type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; - type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; - type UserRepository<'c> = PgUserRepository<'c> where Self: 'c; - type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c; - type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; - type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; - type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; - type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c; - type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; - type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c; - type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c; - type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c; - type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c; - type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c; - type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c; - - fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { - PgUpstreamOAuthLinkRepository::new(self) - } - - fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { - PgUpstreamOAuthProviderRepository::new(self) - } - - fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> { - PgUpstreamOAuthSessionRepository::new(self) - } - - fn user(&mut self) -> Self::UserRepository<'_> { - PgUserRepository::new(self) - } - - fn user_email(&mut self) -> Self::UserEmailRepository<'_> { - PgUserEmailRepository::new(self) - } - - fn user_password(&mut self) -> Self::UserPasswordRepository<'_> { - PgUserPasswordRepository::new(self) - } - - fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> { - PgBrowserSessionRepository::new(self) - } - - fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { - PgOAuth2ClientRepository::new(self) - } - - fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> { - PgOAuth2AuthorizationGrantRepository::new(self) - } - - fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { - PgOAuth2SessionRepository::new(self) - } - - fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> { - PgOAuth2AccessTokenRepository::new(self) - } - - fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> { - PgOAuth2RefreshTokenRepository::new(self) - } - - fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { - PgCompatSessionRepository::new(self) - } - - fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> { - PgCompatSsoLoginRepository::new(self) - } - - fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> { - PgCompatAccessTokenRepository::new(self) - } - - fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> { - PgCompatRefreshTokenRepository::new(self) + PgCompatRefreshTokenRepository::new(&mut self.txn) } } diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index e195056c8..d2a247314 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -29,20 +29,20 @@ mod tests { use sqlx::PgPool; use super::*; - use crate::{Clock, Repository}; + use crate::{Clock, PgRepository, Repository}; #[sqlx::test(migrator = "crate::MIGRATOR")] async fn test_repository(pool: PgPool) -> Result<(), Box> { let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); let clock = Clock::default(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; // The provider list should be empty at the start - let all_providers = conn.upstream_oauth_provider().all().await?; + let all_providers = repo.upstream_oauth_provider().all().await?; assert!(all_providers.is_empty()); // Let's add a provider - let provider = conn + let provider = repo .upstream_oauth_provider() .add( &mut rng, @@ -57,7 +57,7 @@ mod tests { .await?; // Look it up in the database - let provider = conn + let provider = repo .upstream_oauth_provider() .lookup(provider.id) .await? @@ -66,7 +66,7 @@ mod tests { assert_eq!(provider.client_id, "client-id"); // Start a session - let session = conn + let session = repo .upstream_oauth_session() .add( &mut rng, @@ -79,7 +79,7 @@ mod tests { .await?; // Look it up in the database - let session = conn + let session = repo .upstream_oauth_session() .lookup(session.id) .await? @@ -91,19 +91,19 @@ mod tests { assert!(!session.is_consumed()); // Create a link - let link = conn + let link = repo .upstream_oauth_link() .add(&mut rng, &clock, &provider, "a-subject".to_owned()) .await?; // We can look it up by its ID - conn.upstream_oauth_link() + repo.upstream_oauth_link() .lookup(link.id) .await? .expect("link to be found in database"); // or by its subject - let link = conn + let link = repo .upstream_oauth_link() .find_by_subject(&provider, "a-subject") .await? @@ -111,7 +111,7 @@ mod tests { assert_eq!(link.subject, "a-subject"); assert_eq!(link.provider_id, provider.id); - let session = conn + let session = repo .upstream_oauth_session() .complete_with_link(&clock, session, &link, None) .await?; @@ -119,7 +119,7 @@ mod tests { assert!(!session.is_consumed()); assert_eq!(session.link_id(), Some(link.id)); - let session = conn + let session = repo .upstream_oauth_session() .consume(&clock, session) .await?; diff --git a/crates/tasks/src/database.rs b/crates/tasks/src/database.rs index f4d11c6af..39a33b8dd 100644 --- a/crates/tasks/src/database.rs +++ b/crates/tasks/src/database.rs @@ -14,7 +14,7 @@ //! Database-related tasks -use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock, Repository}; +use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock, PgRepository, Repository}; use sqlx::{Pool, Postgres}; use tracing::{debug, error, info}; @@ -33,8 +33,8 @@ impl std::fmt::Debug for CleanupExpired { impl Task for CleanupExpired { async fn run(&self) { let res = async move { - let mut conn = self.0.acquire().await?; - conn.oauth2_access_token().cleanup_expired(&self.1).await + let mut repo = PgRepository::from_pool(&self.0).await?; + repo.oauth2_access_token().cleanup_expired(&self.1).await } .await; From d0d7f1653f95e1ad03a2e642071c75057101d530 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 13 Jan 2023 18:25:25 +0100 Subject: [PATCH 21/45] storage: simplify pagination --- crates/storage/src/compat/sso_login.rs | 14 +-- crates/storage/src/oauth2/session.rs | 14 +-- crates/storage/src/pagination.rs | 88 +++++++++++++------ crates/storage/src/upstream_oauth2/link.rs | 14 +-- .../storage/src/upstream_oauth2/provider.rs | 14 +-- crates/storage/src/user/email.rs | 13 +-- crates/storage/src/user/session.rs | 32 +++---- 7 files changed, 90 insertions(+), 99 deletions(-) diff --git a/crates/storage/src/compat/sso_login.rs b/crates/storage/src/compat/sso_login.rs index cba777d3d..8cb84dc10 100644 --- a/crates/storage/src/compat/sso_login.rs +++ b/crates/storage/src/compat/sso_login.rs @@ -22,7 +22,7 @@ use url::Url; use uuid::Uuid; use crate::{ - pagination::{process_page, Page, QueryBuilderExt}, + pagination::{Page, QueryBuilderExt}, tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, }; @@ -379,19 +379,13 @@ impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { .push_bind(Uuid::from(user.id)) .generate_pagination("cl.compat_sso_login_id", before, after, first, last)?; - let page: Vec = query + let edges: Vec = query .build_query_as() .traced() .fetch_all(&mut *self.conn) .await?; - let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?; - - let edges: Result, _> = edges.into_iter().map(TryInto::try_into).collect(); - Ok(Page { - has_next_page, - has_previous_page, - edges: edges?, - }) + let page = Page::process(edges, first, last)?.try_map(CompatSsoLogin::try_from)?; + Ok(page) } } diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index c28bc4efb..0a6b5c999 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -21,7 +21,7 @@ use ulid::Ulid; use uuid::Uuid; use crate::{ - pagination::{process_page, Page, QueryBuilderExt}, + pagination::{Page, QueryBuilderExt}, tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, }; @@ -271,15 +271,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { .fetch_all(&mut *self.conn) .await?; - let (has_previous_page, has_next_page, edges) = process_page(edges, first, last)?; - - let edges: Result, DatabaseInconsistencyError> = - edges.into_iter().map(Session::try_from).collect(); - - Ok(Page { - has_next_page, - has_previous_page, - edges: edges?, - }) + let page = Page::process(edges, first, last)?.try_map(Session::try_from)?; + Ok(page) } } diff --git a/crates/storage/src/pagination.rs b/crates/storage/src/pagination.rs index a240c554e..5887117c3 100644 --- a/crates/storage/src/pagination.rs +++ b/crates/storage/src/pagination.rs @@ -82,41 +82,71 @@ where Ok(()) } -/// Process a page returned by a paginated query -pub fn process_page( - mut page: Vec, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), InvalidPagination> { - let limit = match (first, last) { - (Some(count), _) | (_, Some(count)) => count, - _ => return Err(InvalidPagination), - }; - - let is_full = page.len() == (limit + 1); - if is_full { - page.pop(); - } - - let (has_previous_page, has_next_page) = if first.is_some() { - (false, is_full) - } else if last.is_some() { - // 6. If the last argument is provided, I reverse the order of the results - page.reverse(); - (is_full, false) - } else { - unreachable!() - }; - - Ok((has_previous_page, has_next_page, page)) -} - pub struct Page { pub has_next_page: bool, pub has_previous_page: bool, pub edges: Vec, } +impl Page { + /// Process a page returned by a paginated query + pub fn process( + mut edges: Vec, + first: Option, + last: Option, + ) -> Result { + let limit = match (first, last) { + (Some(count), _) | (_, Some(count)) => count, + _ => return Err(InvalidPagination), + }; + + let is_full = edges.len() == (limit + 1); + if is_full { + edges.pop(); + } + + let (has_previous_page, has_next_page) = if first.is_some() { + (false, is_full) + } else if last.is_some() { + // 6. If the last argument is provided, I reverse the order of the results + edges.reverse(); + (is_full, false) + } else { + unreachable!() + }; + + Ok(Page { + has_next_page, + has_previous_page, + edges, + }) + } + + pub fn map(self, f: F) -> Page + where + F: FnMut(T) -> T2, + { + let edges = self.edges.into_iter().map(f).collect(); + Page { + has_next_page: self.has_next_page, + has_previous_page: self.has_previous_page, + edges, + } + } + + pub fn try_map(self, f: F) -> Result, E> + where + F: FnMut(T) -> Result, + { + let edges: Result, E> = self.edges.into_iter().map(f).collect(); + Ok(Page { + has_next_page: self.has_next_page, + has_previous_page: self.has_previous_page, + edges: edges?, + }) + } +} + impl Page {} pub trait QueryBuilderExt { diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index 0d443671c..f72a8504a 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -21,7 +21,7 @@ use ulid::Ulid; use uuid::Uuid; use crate::{ - pagination::{process_page, Page, QueryBuilderExt}, + pagination::{Page, QueryBuilderExt}, tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt, }; @@ -297,19 +297,13 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { .push_bind(Uuid::from(user.id)) .generate_pagination("upstream_oauth_link_id", before, after, first, last)?; - let page: Vec = query + let edges: Vec = query .build_query_as() .traced() .fetch_all(&mut *self.conn) .await?; - let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?; - - let edges: Vec<_> = edges.into_iter().map(Into::into).collect(); - Ok(Page { - has_next_page, - has_previous_page, - edges, - }) + let page = Page::process(edges, first, last)?.map(UpstreamOAuthLink::from); + Ok(page) } } diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index a7efb6c88..088e9e93b 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -23,7 +23,7 @@ use ulid::Ulid; use uuid::Uuid; use crate::{ - pagination::{process_page, Page, QueryBuilderExt}, + pagination::{Page, QueryBuilderExt}, tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, }; @@ -266,20 +266,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' query.generate_pagination("upstream_oauth_provider_id", before, after, first, last)?; - let page: Vec = query + let edges: Vec = query .build_query_as() .traced() .fetch_all(&mut *self.conn) .await?; - let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?; - - let edges: Result, _> = edges.into_iter().map(TryInto::try_into).collect(); - Ok(Page { - has_next_page, - has_previous_page, - edges: edges?, - }) + let page = Page::process(edges, first, last)?.try_map(TryInto::try_into)?; + Ok(page) } #[tracing::instrument( diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs index 2f7486110..d725dea5d 100644 --- a/crates/storage/src/user/email.rs +++ b/crates/storage/src/user/email.rs @@ -21,7 +21,7 @@ use ulid::Ulid; use uuid::Uuid; use crate::{ - pagination::{process_page, Page, QueryBuilderExt}, + pagination::{Page, QueryBuilderExt}, tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, }; @@ -315,15 +315,8 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { .fetch_all(&mut *self.conn) .await?; - let (has_previous_page, has_next_page, edges) = process_page(edges, first, last)?; - - let edges = edges.into_iter().map(Into::into).collect(); - - Ok(Page { - has_next_page, - has_previous_page, - edges, - }) + let page = Page::process(edges, first, last)?.map(UserEmail::from); + Ok(page) } #[tracing::instrument( diff --git a/crates/storage/src/user/session.rs b/crates/storage/src/user/session.rs index 01102ca93..f2ceed2a3 100644 --- a/crates/storage/src/user/session.rs +++ b/crates/storage/src/user/session.rs @@ -21,7 +21,7 @@ use ulid::Ulid; use uuid::Uuid; use crate::{ - pagination::{process_page, Page, QueryBuilderExt}, + pagination::{Page, QueryBuilderExt}, tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, }; @@ -91,19 +91,19 @@ struct SessionLookup { last_authd_at: Option>, } -impl TryInto for SessionLookup { +impl TryFrom for BrowserSession { type Error = DatabaseInconsistencyError; - fn try_into(self) -> Result { - let id = Ulid::from(self.user_id); + fn try_from(value: SessionLookup) -> Result { + let id = Ulid::from(value.user_id); let user = User { id, - username: self.user_username, + username: value.user_username, sub: id.to_string(), - primary_user_email_id: self.user_primary_user_email_id.map(Into::into), + primary_user_email_id: value.user_primary_user_email_id.map(Into::into), }; - let last_authentication = match (self.last_authentication_id, self.last_authd_at) { + let last_authentication = match (value.last_authentication_id, value.last_authd_at) { (Some(id), Some(created_at)) => Some(Authentication { id: id.into(), created_at, @@ -117,10 +117,10 @@ impl TryInto for SessionLookup { }; Ok(BrowserSession { - id: self.user_session_id.into(), + id: value.user_session_id.into(), user, - created_at: self.user_session_created_at, - finished_at: self.user_session_finished_at, + created_at: value.user_session_created_at, + finished_at: value.user_session_finished_at, last_authentication, }) } @@ -292,20 +292,14 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { .push_bind(Uuid::from(user.id)) .generate_pagination("s.user_session_id", before, after, first, last)?; - let page: Vec = query + let edges: Vec = query .build_query_as() .traced() .fetch_all(&mut *self.conn) .await?; - let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?; - - let edges: Result, _> = edges.into_iter().map(TryInto::try_into).collect(); - Ok(Page { - has_previous_page, - has_next_page, - edges: edges?, - }) + let page = Page::process(edges, first, last)?.try_map(BrowserSession::try_from)?; + Ok(page) } #[tracing::instrument( From 8c3b78ec61529b078224ee3667f8d7faec47f210 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 16 Jan 2023 17:56:51 +0100 Subject: [PATCH 22/45] storage: tests for the user {,email,password,session} repositories --- crates/storage/sqlx-data.json | 12 + crates/storage/src/lib.rs | 72 +++++- crates/storage/src/user/email.rs | 19 ++ crates/storage/src/user/mod.rs | 5 +- crates/storage/src/user/tests.rs | 394 +++++++++++++++++++++++++++++++ 5 files changed, 500 insertions(+), 2 deletions(-) create mode 100644 crates/storage/src/user/tests.rs diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 5dd182501..8148f796d 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -209,6 +209,18 @@ }, "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n " }, + "1f6297fb323e9f2fbfa1c9e3225c0b3037c8c4714533a6240c62275332aa58dc": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n DELETE FROM user_email_confirmation_codes\n WHERE user_email_id = $1\n " + }, "2564bf6366eb59268c41fb25bb40d0e4e9e1fd1f9ea53b7a359c9025d7304223": { "describe": { "columns": [], diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index 97aeee243..49c9c3a97 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -161,18 +161,88 @@ impl DatabaseInconsistencyError { } } -#[derive(Default, Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct Clock { _private: (), + + #[cfg(test)] + mock: Option>, } impl Clock { #[must_use] pub fn now(&self) -> DateTime { + #[cfg(test)] + if let Some(timestamp) = &self.mock { + let timestamp = timestamp.load(std::sync::atomic::Ordering::Relaxed); + return chrono::TimeZone::timestamp_opt(&Utc, timestamp, 0).unwrap(); + } + // This is the clock used elsewhere, it's fine to call Utc::now here #[allow(clippy::disallowed_methods)] Utc::now() } + + #[cfg(test)] + pub fn mock() -> Self { + use std::sync::{atomic::AtomicI64, Arc}; + + use chrono::TimeZone; + + let datetime = Utc.with_ymd_and_hms(2022, 01, 16, 14, 40, 0).unwrap(); + let timestamp = datetime.timestamp(); + + Self { + mock: Some(Arc::new(AtomicI64::new(timestamp))), + _private: (), + } + } + + #[cfg(test)] + pub fn advance(&self, duration: chrono::Duration) { + let timestamp = self + .mock + .as_ref() + .expect("Clock::advance should only be called on mocked clocks in tests"); + timestamp.fetch_add(duration.num_seconds(), std::sync::atomic::Ordering::Relaxed); + } +} + +#[cfg(test)] +mod tests { + use chrono::Duration; + + use super::*; + + #[test] + fn test_mocked_clock() { + let clock = Clock::mock(); + + // Time should be frozen, and give out the same timestamp on each call + let first = clock.now(); + std::thread::sleep(std::time::Duration::from_millis(10)); + let second = clock.now(); + + assert_eq!(first, second); + + // Clock can be advanced by a fixed duration + clock.advance(Duration::seconds(10)); + let third = clock.now(); + assert_eq!(first + Duration::seconds(10), third); + } + + #[test] + fn test_real_clock() { + let clock = Clock::default(); + + // Time should not be frozen + let first = clock.now(); + std::thread::sleep(std::time::Duration::from_millis(10)); + let second = clock.now(); + + assert_ne!(first, second); + assert!(first < second); + } } pub mod compat; diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs index d725dea5d..2d5ad9876 100644 --- a/crates/storage/src/user/email.rs +++ b/crates/storage/src/user/email.rs @@ -17,6 +17,7 @@ use chrono::{DateTime, Utc}; use mas_data_model::{User, UserEmail, UserEmailVerification, UserEmailVerificationState}; use rand::RngCore; use sqlx::{PgConnection, QueryBuilder}; +use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; @@ -405,7 +406,23 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { err, )] async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> { + let span = info_span!( + "db.user_email.remove.codes", + db.statement = tracing::field::Empty + ); sqlx::query!( + r#" + DELETE FROM user_email_confirmation_codes + WHERE user_email_id = $1 + "#, + Uuid::from(user_email.id), + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + + let res = sqlx::query!( r#" DELETE FROM user_emails WHERE user_email_id = $1 @@ -416,6 +433,8 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { .execute(&mut *self.conn) .await?; + DatabaseError::ensure_affected_rows(&res, 1)?; + Ok(()) } diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 592cb59de..9dd3d2ca9 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -26,6 +26,9 @@ mod email; mod password; mod session; +#[cfg(test)] +mod tests; + pub use self::{ email::{PgUserEmailRepository, UserEmailRepository}, password::{PgUserPasswordRepository, UserPasswordRepository}, diff --git a/crates/storage/src/user/tests.rs b/crates/storage/src/user/tests.rs new file mode 100644 index 000000000..fca35ce05 --- /dev/null +++ b/crates/storage/src/user/tests.rs @@ -0,0 +1,394 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::Duration; +use rand::SeedableRng; +use rand_chacha::ChaChaRng; +use sqlx::PgPool; + +use crate::{ + user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, + Clock, PgRepository, Repository, +}; + +/// Test the user repository, by adding and looking up a user +#[sqlx::test(migrator = "crate::MIGRATOR")] +async fn test_user_repo(pool: PgPool) { + const USERNAME: &str = "john"; + + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + + // Initially, the user shouldn't exist + assert!(!repo.user().exists(USERNAME).await.unwrap()); + assert!(repo + .user() + .find_by_username(USERNAME) + .await + .unwrap() + .is_none()); + + // Adding the user should work + let user = repo + .user() + .add(&mut rng, &clock, USERNAME.to_owned()) + .await + .unwrap(); + + // And now it should exist + assert!(repo.user().exists(USERNAME).await.unwrap()); + assert!(repo + .user() + .find_by_username(USERNAME) + .await + .unwrap() + .is_some()); + assert!(repo.user().lookup(user.id).await.unwrap().is_some()); + + // Adding a second time should give a conflict + assert!(repo + .user() + .add(&mut rng, &clock, USERNAME.to_owned()) + .await + .is_err()); + + repo.save().await.unwrap(); +} + +/// Test the user email repository, by trying out most of its methods +#[sqlx::test(migrator = "crate::MIGRATOR")] +async fn test_user_email_repo(pool: PgPool) { + const USERNAME: &str = "john"; + const CODE: &str = "012345"; + const CODE2: &str = "543210"; + const EMAIL: &str = "john@example.com"; + + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + + let user = repo + .user() + .add(&mut rng, &clock, USERNAME.to_owned()) + .await + .unwrap(); + + // The user email should not exist yet + assert!(repo + .user_email() + .find(&user, &EMAIL) + .await + .unwrap() + .is_none()); + + assert_eq!(repo.user_email().count(&user).await.unwrap(), 0); + + let user_email = repo + .user_email() + .add(&mut rng, &clock, &user, EMAIL.to_owned()) + .await + .unwrap(); + + assert_eq!(user_email.user_id, user.id); + assert_eq!(user_email.email, EMAIL); + assert!(user_email.confirmed_at.is_none()); + + assert_eq!(repo.user_email().count(&user).await.unwrap(), 1); + + assert!(repo + .user_email() + .find(&user, &EMAIL) + .await + .unwrap() + .is_some()); + + let user_email = repo + .user_email() + .lookup(user_email.id) + .await + .unwrap() + .expect("user email was not found"); + + assert_eq!(user_email.user_id, user.id); + assert_eq!(user_email.email, EMAIL); + + let verification = repo + .user_email() + .add_verification_code( + &mut rng, + &clock, + &user_email, + Duration::hours(8), + CODE.to_owned(), + ) + .await + .unwrap(); + + let verification_id = verification.id; + assert_eq!(verification.user_email_id, user_email.id); + assert_eq!(verification.code, CODE); + + // A single user email can have multiple verification at the same time + let _verification2 = repo + .user_email() + .add_verification_code( + &mut rng, + &clock, + &user_email, + Duration::hours(8), + CODE2.to_owned(), + ) + .await + .unwrap(); + + let verification = repo + .user_email() + .find_verification_code(&clock, &user_email, CODE) + .await + .unwrap() + .expect("user email verification was not found"); + + assert_eq!(verification.id, verification_id); + assert_eq!(verification.user_email_id, user_email.id); + assert_eq!(verification.code, CODE); + + // Consuming the verification code + repo.user_email() + .consume_verification_code(&clock, verification) + .await + .unwrap(); + + // Mark the email as verified + repo.user_email() + .mark_as_verified(&clock, user_email) + .await + .unwrap(); + + // Reload the user_email + let user_email = repo + .user_email() + .find(&user, &EMAIL) + .await + .unwrap() + .expect("user email was not found"); + + // The email should be marked as verified now + assert!(user_email.confirmed_at.is_some()); + + // Reload the verification + let verification = repo + .user_email() + .find_verification_code(&clock, &user_email, CODE) + .await + .unwrap() + .expect("user email verification was not found"); + + // Consuming a second time should not work + assert!(repo + .user_email() + .consume_verification_code(&clock, verification) + .await + .is_err()); + + // The user shouldn't have a primary email yet + assert!(repo + .user_email() + .get_primary(&user) + .await + .unwrap() + .is_none()); + + repo.user_email().set_as_primary(&user_email).await.unwrap(); + + // Reload the user + let user = repo + .user() + .lookup(user.id) + .await + .unwrap() + .expect("user was not found"); + + // Now it should have one + assert!(repo + .user_email() + .get_primary(&user) + .await + .unwrap() + .is_some()); + + // Deleting the user email should work + repo.user_email().remove(user_email).await.unwrap(); + assert_eq!(repo.user_email().count(&user).await.unwrap(), 0); + + // Reload the user + let user = repo + .user() + .lookup(user.id) + .await + .unwrap() + .expect("user was not found"); + + // The primary user email should be gone + assert!(repo + .user_email() + .get_primary(&user) + .await + .unwrap() + .is_none()); + + repo.save().await.unwrap(); +} + +#[sqlx::test(migrator = "crate::MIGRATOR")] +async fn test_user_password_repo(pool: PgPool) { + const USERNAME: &str = "john"; + const FIRST_PASSWORD_HASH: &str = "doesntmatter"; + const SECOND_PASSWORD_HASH: &str = "alsodoesntmatter"; + + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + + let user = repo + .user() + .add(&mut rng, &clock, USERNAME.to_owned()) + .await + .unwrap(); + + // User should have no active password + assert!(repo.user_password().active(&user).await.unwrap().is_none()); + + // Insert a first password + let first_password = repo + .user_password() + .add( + &mut rng, + &clock, + &user, + 1, + FIRST_PASSWORD_HASH.to_owned(), + None, + ) + .await + .unwrap(); + + // User should now have an active password + let first_password_lookup = repo + .user_password() + .active(&user) + .await + .unwrap() + .expect("user should have an active password"); + + assert_eq!(first_password.id, first_password_lookup.id); + assert_eq!(first_password_lookup.hashed_password, FIRST_PASSWORD_HASH); + assert_eq!(first_password_lookup.version, 1); + assert_eq!(first_password_lookup.upgraded_from_id, None); + + // Getting the last inserted password is based on the clock, so we need to + // advance it + clock.advance(Duration::seconds(10)); + + let second_password = repo + .user_password() + .add( + &mut rng, + &clock, + &user, + 2, + SECOND_PASSWORD_HASH.to_owned(), + Some(&first_password), + ) + .await + .unwrap(); + + // User should now have an active password + let second_password_lookup = repo + .user_password() + .active(&user) + .await + .unwrap() + .expect("user should have an active password"); + + assert_eq!(second_password.id, second_password_lookup.id); + assert_eq!(second_password_lookup.hashed_password, SECOND_PASSWORD_HASH); + assert_eq!(second_password_lookup.version, 2); + assert_eq!( + second_password_lookup.upgraded_from_id, + Some(first_password.id) + ); + + repo.save().await.unwrap(); +} + +#[sqlx::test(migrator = "crate::MIGRATOR")] +async fn test_user_session(pool: PgPool) { + const USERNAME: &str = "john"; + + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + + let user = repo + .user() + .add(&mut rng, &clock, USERNAME.to_owned()) + .await + .unwrap(); + + assert_eq!(repo.browser_session().count_active(&user).await.unwrap(), 0); + + let session = repo + .browser_session() + .add(&mut rng, &clock, &user) + .await + .unwrap(); + assert_eq!(session.user.id, user.id); + assert!(session.finished_at.is_none()); + + assert_eq!(repo.browser_session().count_active(&user).await.unwrap(), 1); + + let session_lookup = repo + .browser_session() + .lookup(session.id) + .await + .unwrap() + .expect("user session not found"); + + assert_eq!(session_lookup.id, session.id); + assert_eq!(session_lookup.user.id, user.id); + assert!(session_lookup.finished_at.is_none()); + + // Finish the session + repo.browser_session() + .finish(&clock, session_lookup) + .await + .unwrap(); + + // The active session counter is back to 0 + assert_eq!(repo.browser_session().count_active(&user).await.unwrap(), 0); + + // Reload the session + let session_lookup = repo + .browser_session() + .lookup(session.id) + .await + .unwrap() + .expect("user session not found"); + + assert_eq!(session_lookup.id, session.id); + assert_eq!(session_lookup.user.id, user.id); + // This time the session is finished + assert!(session_lookup.finished_at.is_some()); +} From 7e116b1a1cd596d7f20694031dca9bd754a37f5f Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 16 Jan 2023 19:27:43 +0100 Subject: [PATCH 23/45] storage: test compat {session, access token, refresh token} repositories --- crates/storage/src/compat/mod.rs | 293 ++++++++++++++++++++++ crates/storage/src/upstream_oauth2/mod.rs | 31 ++- 2 files changed, 322 insertions(+), 2 deletions(-) diff --git a/crates/storage/src/compat/mod.rs b/crates/storage/src/compat/mod.rs index 3a91f8c7d..c37081b8c 100644 --- a/crates/storage/src/compat/mod.rs +++ b/crates/storage/src/compat/mod.rs @@ -23,3 +23,296 @@ pub use self::{ session::{CompatSessionRepository, PgCompatSessionRepository}, sso_login::{CompatSsoLoginRepository, PgCompatSsoLoginRepository}, }; + +#[cfg(test)] +mod tests { + use chrono::Duration; + use mas_data_model::Device; + use rand::SeedableRng; + use rand_chacha::ChaChaRng; + use sqlx::PgPool; + + use super::*; + use crate::{user::UserRepository, Clock, PgRepository, Repository}; + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_session_repository(pool: PgPool) { + const FIRST_TOKEN: &str = "first_access_token"; + const SECOND_TOKEN: &str = "second_access_token"; + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + // Create a user + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + + // Start a compat session for that user + let device = Device::generate(&mut rng); + let device_str = device.as_str().to_owned(); + let session = repo + .compat_session() + .add(&mut rng, &clock, &user, device) + .await + .unwrap(); + assert_eq!(session.user_id, user.id); + assert_eq!(session.device.as_str(), device_str); + assert!(session.is_valid()); + assert!(!session.is_finished()); + + // Lookup the session and check it didn't change + let session_lookup = repo + .compat_session() + .lookup(session.id) + .await + .unwrap() + .expect("compat session not found"); + assert_eq!(session_lookup.id, session.id); + assert_eq!(session_lookup.user_id, user.id); + assert_eq!(session_lookup.device.as_str(), device_str); + assert!(session_lookup.is_valid()); + assert!(!session_lookup.is_finished()); + + // Finish the session + let session = repo.compat_session().finish(&clock, session).await.unwrap(); + assert!(!session.is_valid()); + assert!(session.is_finished()); + + // Reload the session and check again + let session_lookup = repo + .compat_session() + .lookup(session.id) + .await + .unwrap() + .expect("compat session not found"); + assert!(!session_lookup.is_valid()); + assert!(session_lookup.is_finished()); + } + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_access_token_repository(pool: PgPool) { + const FIRST_TOKEN: &str = "first_access_token"; + const SECOND_TOKEN: &str = "second_access_token"; + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + // Create a user + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + + // Start a compat session for that user + let device = Device::generate(&mut rng); + let session = repo + .compat_session() + .add(&mut rng, &clock, &user, device) + .await + .unwrap(); + + // Add an access token to that session + let token = repo + .compat_access_token() + .add( + &mut rng, + &clock, + &session, + FIRST_TOKEN.to_owned(), + Some(Duration::minutes(1)), + ) + .await + .unwrap(); + assert_eq!(token.session_id, session.id); + assert_eq!(token.token, FIRST_TOKEN); + + // Commit the txn and grab a new transaction, to test a conflict + repo.save().await.unwrap(); + + { + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + // Adding the same token a second time should conflict + assert!(repo + .compat_access_token() + .add( + &mut rng, + &clock, + &session, + FIRST_TOKEN.to_owned(), + Some(Duration::minutes(1)), + ) + .await + .is_err()); + repo.cancel().await.unwrap(); + } + + // Grab a new repo + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + // Looking up via ID works + let token_lookup = repo + .compat_access_token() + .lookup(token.id) + .await + .unwrap() + .expect("compat access token not found"); + assert_eq!(token.id, token_lookup.id); + assert_eq!(token_lookup.session_id, session.id); + + // Looking up via the token value works + let token_lookup = repo + .compat_access_token() + .find_by_token(FIRST_TOKEN) + .await + .unwrap() + .expect("compat access token not found"); + assert_eq!(token.id, token_lookup.id); + assert_eq!(token_lookup.session_id, session.id); + + // Token is currently valid + assert!(token.is_valid(clock.now())); + + clock.advance(Duration::minutes(1)); + // Token should have expired + assert!(!token.is_valid(clock.now())); + + // Add a second access token, this time without expiration + let token = repo + .compat_access_token() + .add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None) + .await + .unwrap(); + assert_eq!(token.session_id, session.id); + assert_eq!(token.token, SECOND_TOKEN); + + // Token is currently valid + assert!(token.is_valid(clock.now())); + + // Make it expire + repo.compat_access_token() + .expire(&clock, token) + .await + .unwrap(); + + // Reload it + let token = repo + .compat_access_token() + .find_by_token(SECOND_TOKEN) + .await + .unwrap() + .expect("compat access token not found"); + + // Token is not valid anymore + assert!(!token.is_valid(clock.now())); + + repo.save().await.unwrap(); + } + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_refresh_token_repository(pool: PgPool) { + const ACCESS_TOKEN: &str = "access_token"; + const REFRESH_TOKEN: &str = "refresh_token"; + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + // Create a user + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + + // Start a compat session for that user + let device = Device::generate(&mut rng); + let session = repo + .compat_session() + .add(&mut rng, &clock, &user, device) + .await + .unwrap(); + + // Add an access token to that session + let access_token = repo + .compat_access_token() + .add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None) + .await + .unwrap(); + + let refresh_token = repo + .compat_refresh_token() + .add( + &mut rng, + &clock, + &session, + &access_token, + REFRESH_TOKEN.to_owned(), + ) + .await + .unwrap(); + assert_eq!(refresh_token.session_id, session.id); + assert_eq!(refresh_token.access_token_id, access_token.id); + assert_eq!(refresh_token.token, REFRESH_TOKEN); + assert!(refresh_token.is_valid()); + assert!(!refresh_token.is_consumed()); + + // Look it up by ID and check everything matches + let refresh_token_lookup = repo + .compat_refresh_token() + .lookup(refresh_token.id) + .await + .unwrap() + .expect("refresh token not found"); + assert_eq!(refresh_token_lookup.id, refresh_token.id); + assert_eq!(refresh_token_lookup.session_id, session.id); + assert_eq!(refresh_token_lookup.access_token_id, access_token.id); + assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN); + assert!(refresh_token_lookup.is_valid()); + assert!(!refresh_token_lookup.is_consumed()); + + // Look it up by token and check everything matches + let refresh_token_lookup = repo + .compat_refresh_token() + .find_by_token(REFRESH_TOKEN) + .await + .unwrap() + .expect("refresh token not found"); + assert_eq!(refresh_token_lookup.id, refresh_token.id); + assert_eq!(refresh_token_lookup.session_id, session.id); + assert_eq!(refresh_token_lookup.access_token_id, access_token.id); + assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN); + assert!(refresh_token_lookup.is_valid()); + assert!(!refresh_token_lookup.is_consumed()); + + // Consume it + let refresh_token = repo + .compat_refresh_token() + .consume(&clock, refresh_token) + .await + .unwrap(); + assert!(!refresh_token.is_valid()); + assert!(refresh_token.is_consumed()); + + // Reload it and check again + let refresh_token_lookup = repo + .compat_refresh_token() + .find_by_token(REFRESH_TOKEN) + .await + .unwrap() + .expect("refresh token not found"); + assert!(!refresh_token_lookup.is_valid()); + assert!(refresh_token_lookup.is_consumed()); + + // Consuming it again should not work + assert!(repo + .compat_refresh_token() + .consume(&clock, refresh_token) + .await + .is_err()); + + repo.save().await.unwrap(); + } +} diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index d2a247314..0763624f4 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -29,12 +29,12 @@ mod tests { use sqlx::PgPool; use super::*; - use crate::{Clock, PgRepository, Repository}; + use crate::{user::UserRepository, Clock, PgRepository, Repository}; #[sqlx::test(migrator = "crate::MIGRATOR")] async fn test_repository(pool: PgPool) -> Result<(), Box> { let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); - let clock = Clock::default(); + let clock = Clock::mock(); let mut repo = PgRepository::from_pool(&pool).await?; // The provider list should be empty at the start @@ -115,6 +115,12 @@ mod tests { .upstream_oauth_session() .complete_with_link(&clock, session, &link, None) .await?; + // Reload the session + let session = repo + .upstream_oauth_session() + .lookup(session.id) + .await? + .expect("session to be found in the database"); assert!(session.is_completed()); assert!(!session.is_consumed()); assert_eq!(session.link_id(), Some(link.id)); @@ -123,8 +129,29 @@ mod tests { .upstream_oauth_session() .consume(&clock, session) .await?; + // Reload the session + let session = repo + .upstream_oauth_session() + .lookup(session.id) + .await? + .expect("session to be found in the database"); assert!(session.is_consumed()); + let user = repo.user().add(&mut rng, &clock, "john".to_owned()).await?; + repo.upstream_oauth_link() + .associate_to_user(&link, &user) + .await?; + + let links = repo + .upstream_oauth_link() + .list_paginated(&user, None, None, Some(10), None) + .await?; + assert!(!links.has_previous_page); + assert!(!links.has_next_page); + assert_eq!(links.edges.len(), 1); + assert_eq!(links.edges[0].id, link.id); + assert_eq!(links.edges[0].user_id, Some(user.id)); + Ok(()) } } From 3ccaafbbe9e1b7ef3142d9d06bdbdca9dd3d06ae Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 17 Jan 2023 15:09:53 +0100 Subject: [PATCH 24/45] storage: simplify the paginated queries --- crates/graphql/src/lib.rs | 5 +- crates/graphql/src/model/users.rs | 17 +- crates/storage/src/compat/sso_login.rs | 18 +- crates/storage/src/lib.rs | 5 +- crates/storage/src/oauth2/session.rs | 16 +- crates/storage/src/pagination.rs | 254 +++++++++++------- crates/storage/src/upstream_oauth2/link.rs | 16 +- crates/storage/src/upstream_oauth2/mod.rs | 4 +- .../storage/src/upstream_oauth2/provider.rs | 16 +- crates/storage/src/user/email.rs | 16 +- crates/storage/src/user/session.rs | 18 +- 11 files changed, 207 insertions(+), 178 deletions(-) diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index 6e58bec74..765db8e56 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -34,7 +34,7 @@ use mas_storage::{ oauth2::OAuth2ClientRepository, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository}, user::{BrowserSessionRepository, UserEmailRepository}, - PgRepository, Repository, + Pagination, PgRepository, Repository, }; use model::CreationEvent; use sqlx::PgPool; @@ -228,10 +228,11 @@ impl RootQuery { x.extract_for_type(NodeType::UpstreamOAuth2Provider) }) .transpose()?; + let pagination = Pagination::try_new(before_id, after_id, first, last)?; let page = repo .upstream_oauth_provider() - .list_paginated(before_id, after_id, first, last) + .list_paginated(&pagination) .await?; let mut connection = Connection::new(page.has_previous_page, page.has_next_page); diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index 9cd8d53bc..9a9062d0c 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -22,7 +22,7 @@ use mas_storage::{ oauth2::OAuth2SessionRepository, upstream_oauth2::UpstreamOAuthLinkRepository, user::{BrowserSessionRepository, UserEmailRepository}, - PgRepository, Repository, + Pagination, PgRepository, Repository, }; use sqlx::PgPool; @@ -95,10 +95,11 @@ impl User { let before_id = before .map(|x: OpaqueCursor| x.extract_for_type(NodeType::CompatSsoLogin)) .transpose()?; + let pagination = Pagination::try_new(before_id, after_id, first, last)?; let page = repo .compat_sso_login() - .list_paginated(&self.0, before_id, after_id, first, last) + .list_paginated(&self.0, &pagination) .await?; let mut connection = Connection::new(page.has_previous_page, page.has_next_page); @@ -141,10 +142,11 @@ impl User { let before_id = before .map(|x: OpaqueCursor| x.extract_for_type(NodeType::BrowserSession)) .transpose()?; + let pagination = Pagination::try_new(before_id, after_id, first, last)?; let page = repo .browser_session() - .list_active_paginated(&self.0, before_id, after_id, first, last) + .list_active_paginated(&self.0, &pagination) .await?; let mut connection = Connection::new(page.has_previous_page, page.has_next_page); @@ -187,10 +189,11 @@ impl User { let before_id = before .map(|x: OpaqueCursor| x.extract_for_type(NodeType::UserEmail)) .transpose()?; + let pagination = Pagination::try_new(before_id, after_id, first, last)?; let page = repo .user_email() - .list_paginated(&self.0, before_id, after_id, first, last) + .list_paginated(&self.0, &pagination) .await?; let mut connection = Connection::with_additional_fields( @@ -237,10 +240,11 @@ impl User { let before_id = before .map(|x: OpaqueCursor| x.extract_for_type(NodeType::OAuth2Session)) .transpose()?; + let pagination = Pagination::try_new(before_id, after_id, first, last)?; let page = repo .oauth2_session() - .list_paginated(&self.0, before_id, after_id, first, last) + .list_paginated(&self.0, &pagination) .await?; let mut connection = Connection::new(page.has_previous_page, page.has_next_page); @@ -287,10 +291,11 @@ impl User { x.extract_for_type(NodeType::UpstreamOAuth2Link) }) .transpose()?; + let pagination = Pagination::try_new(before_id, after_id, first, last)?; let page = repo .upstream_oauth_link() - .list_paginated(&self.0, before_id, after_id, first, last) + .list_paginated(&self.0, &pagination) .await?; let mut connection = Connection::new(page.has_previous_page, page.has_next_page); diff --git a/crates/storage/src/compat/sso_login.rs b/crates/storage/src/compat/sso_login.rs index 8cb84dc10..31c3da3c8 100644 --- a/crates/storage/src/compat/sso_login.rs +++ b/crates/storage/src/compat/sso_login.rs @@ -24,7 +24,7 @@ use uuid::Uuid; use crate::{ pagination::{Page, QueryBuilderExt}, tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, + Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination, }; #[async_trait] @@ -68,10 +68,7 @@ pub trait CompatSsoLoginRepository: Send + Sync { async fn list_paginated( &mut self, user: &User, - before: Option, - after: Option, - first: Option, - last: Option, + pagination: &Pagination, ) -> Result, Self::Error>; } @@ -354,10 +351,7 @@ impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { async fn list_paginated( &mut self, user: &User, - before: Option, - after: Option, - first: Option, - last: Option, + pagination: &Pagination, ) -> Result, Self::Error> { let mut query = QueryBuilder::new( r#" @@ -377,7 +371,7 @@ impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { query .push(" WHERE user_id = ") .push_bind(Uuid::from(user.id)) - .generate_pagination("cl.compat_sso_login_id", before, after, first, last)?; + .generate_pagination("cl.compat_sso_login_id", &pagination); let edges: Vec = query .build_query_as() @@ -385,7 +379,9 @@ impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { .fetch_all(&mut *self.conn) .await?; - let page = Page::process(edges, first, last)?.try_map(CompatSsoLogin::try_from)?; + let page = pagination + .process(edges) + .try_map(CompatSsoLogin::try_from)?; Ok(page) } } diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index 49c9c3a97..e92d37fe7 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -253,7 +253,10 @@ pub(crate) mod tracing; pub mod upstream_oauth2; pub mod user; -pub use self::repository::{PgRepository, Repository}; +pub use self::{ + pagination::Pagination, + repository::{PgRepository, Repository}, +}; /// Embedded migrations, allowing them to run on startup pub static MIGRATOR: Migrator = sqlx::migrate!(); diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 0a6b5c999..0fa8cb8fc 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -23,7 +23,7 @@ use uuid::Uuid; use crate::{ pagination::{Page, QueryBuilderExt}, tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, + Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination, }; #[async_trait] @@ -45,10 +45,7 @@ pub trait OAuth2SessionRepository: Send + Sync { async fn list_paginated( &mut self, user: &User, - before: Option, - after: Option, - first: Option, - last: Option, + pagination: &Pagination, ) -> Result, Self::Error>; } @@ -243,10 +240,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { async fn list_paginated( &mut self, user: &User, - before: Option, - after: Option, - first: Option, - last: Option, + pagination: &Pagination, ) -> Result, Self::Error> { let mut query = QueryBuilder::new( r#" @@ -263,7 +257,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { query .push(" WHERE us.user_id = ") .push_bind(Uuid::from(user.id)) - .generate_pagination("oauth2_session_id", before, after, first, last)?; + .generate_pagination("oauth2_session_id", pagination); let edges: Vec = query .build_query_as() @@ -271,7 +265,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { .fetch_all(&mut *self.conn) .await?; - let page = Page::process(edges, first, last)?.try_map(Session::try_from)?; + let page = pagination.process(edges).try_map(Session::try_from)?; Ok(page) } } diff --git a/crates/storage/src/pagination.rs b/crates/storage/src/pagination.rs index 5887117c3..1fbcbc516 100644 --- a/crates/storage/src/pagination.rs +++ b/crates/storage/src/pagination.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,74 +12,166 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Utilities to manage paginated queries. + use sqlx::{Database, QueryBuilder}; use thiserror::Error; use ulid::Ulid; use uuid::Uuid; +/// An error returned when invalid pagination parameters are provided #[derive(Debug, Error)] #[error("Either 'first' or 'last' must be specified")] pub struct InvalidPagination; -/// Add cursor-based pagination to a query, as used in paginated GraphQL -/// connections -pub fn generate_pagination<'a, DB>( - query: &mut QueryBuilder<'a, DB>, - id_field: &'static str, +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Pagination { before: Option, after: Option, - first: Option, - last: Option, -) -> Result<(), InvalidPagination> -where - DB: Database, - Uuid: sqlx::Type + sqlx::Encode<'a, DB>, - i64: sqlx::Type + sqlx::Encode<'a, DB>, -{ - // ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564 - // 1. Start from the greedy query: SELECT * FROM table + count: usize, + direction: PaginationDirection, +} - // 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE` - // clause - if let Some(after) = after { - query - .push(" AND ") - .push(id_field) - .push(" > ") - .push_bind(Uuid::from(after)); +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PaginationDirection { + Forward, + Backward, +} + +impl Pagination { + /// Creates a new [`Pagination`] from user-provided parameters. + /// + /// # Errors + /// + /// Either `first` or `last` must be provided, else this function will + /// return an [`InvalidPagination`] error. + pub const fn try_new( + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result { + let (direction, count) = match (first, last) { + (Some(first), _) => (PaginationDirection::Forward, first), + (_, Some(last)) => (PaginationDirection::Backward, last), + (None, None) => return Err(InvalidPagination), + }; + + Ok(Self { + before, + after, + count, + direction, + }) } - // 3. If the before argument is provided, add `id < parsed_cursor` to the - // `WHERE` clause - if let Some(before) = before { - query - .push(" AND ") - .push(id_field) - .push(" < ") - .push_bind(Uuid::from(before)); + /// Creates a [`Pagination`] which gets the first N items + pub const fn first(first: usize) -> Self { + Self { + before: None, + after: None, + count: first, + direction: PaginationDirection::Forward, + } } - // 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to - // the query - if let Some(count) = first { - query - .push(" ORDER BY ") - .push(id_field) - .push(" ASC LIMIT ") - .push_bind((count + 1) as i64); - // 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` - // to the query - } else if let Some(count) = last { - query - .push(" ORDER BY ") - .push(id_field) - .push(" DESC LIMIT ") - .push_bind((count + 1) as i64); - } else { - return Err(InvalidPagination); + /// Creates a [`Pagination`] which gets the last N items + pub const fn last(last: usize) -> Self { + Self { + before: None, + after: None, + count: last, + direction: PaginationDirection::Backward, + } } - Ok(()) + /// Get items before the given cursor + pub const fn before(mut self, id: Ulid) -> Self { + self.before = Some(id); + self + } + + /// Get items after the given cursor + pub const fn after(mut self, id: Ulid) -> Self { + self.after = Some(id); + self + } + + /// Add cursor-based pagination to a query, as used in paginated GraphQL + /// connections + fn generate_pagination<'a, DB>(&self, query: &mut QueryBuilder<'a, DB>, id_field: &'static str) + where + DB: Database, + Uuid: sqlx::Type + sqlx::Encode<'a, DB>, + i64: sqlx::Type + sqlx::Encode<'a, DB>, + { + // ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564 + // 1. Start from the greedy query: SELECT * FROM table + + // 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE` + // clause + if let Some(after) = self.after { + query + .push(" AND ") + .push(id_field) + .push(" > ") + .push_bind(Uuid::from(after)); + } + + // 3. If the before argument is provided, add `id < parsed_cursor` to the + // `WHERE` clause + if let Some(before) = self.before { + query + .push(" AND ") + .push(id_field) + .push(" < ") + .push_bind(Uuid::from(before)); + } + + match self.direction { + // 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the + // query + PaginationDirection::Forward => { + query + .push(" ORDER BY ") + .push(id_field) + .push(" ASC LIMIT ") + .push_bind((self.count + 1) as i64); + } + // 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the + // query + PaginationDirection::Backward => { + query + .push(" ORDER BY ") + .push(id_field) + .push(" DESC LIMIT ") + .push_bind((self.count + 1) as i64); + } + }; + } + + /// Process a page returned by a paginated query + pub fn process(&self, mut edges: Vec) -> Page { + let is_full = edges.len() == (self.count + 1); + if is_full { + edges.pop(); + } + + let (has_previous_page, has_next_page) = match self.direction { + PaginationDirection::Forward => (false, is_full), + PaginationDirection::Backward => { + // 6. If the last argument is provided, I reverse the order of the results + edges.reverse(); + (is_full, false) + } + }; + + Page { + has_next_page, + has_previous_page, + edges, + } + } } pub struct Page { @@ -89,39 +181,6 @@ pub struct Page { } impl Page { - /// Process a page returned by a paginated query - pub fn process( - mut edges: Vec, - first: Option, - last: Option, - ) -> Result { - let limit = match (first, last) { - (Some(count), _) | (_, Some(count)) => count, - _ => return Err(InvalidPagination), - }; - - let is_full = edges.len() == (limit + 1); - if is_full { - edges.pop(); - } - - let (has_previous_page, has_next_page) = if first.is_some() { - (false, is_full) - } else if last.is_some() { - // 6. If the last argument is provided, I reverse the order of the results - edges.reverse(); - (is_full, false) - } else { - unreachable!() - }; - - Ok(Page { - has_next_page, - has_previous_page, - edges, - }) - } - pub fn map(self, f: F) -> Page where F: FnMut(T) -> T2, @@ -147,17 +206,13 @@ impl Page { } } -impl Page {} - +/// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination +/// to a query pub trait QueryBuilderExt { - fn generate_pagination( - &mut self, - id_field: &'static str, - before: Option, - after: Option, - first: Option, - last: Option, - ) -> Result<&mut Self, InvalidPagination>; + /// Add cursor-based pagination to a query, as used in paginated GraphQL + /// connections + fn generate_pagination(&mut self, id_field: &'static str, pagination: &Pagination) + -> &mut Self; } impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB> @@ -169,12 +224,9 @@ where fn generate_pagination( &mut self, id_field: &'static str, - before: Option, - after: Option, - first: Option, - last: Option, - ) -> Result<&mut Self, InvalidPagination> { - generate_pagination(self, id_field, before, after, first, last)?; - Ok(self) + pagination: &Pagination, + ) -> &mut Self { + pagination.generate_pagination(self, id_field); + self } } diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index f72a8504a..13e86e17e 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -23,7 +23,7 @@ use uuid::Uuid; use crate::{ pagination::{Page, QueryBuilderExt}, tracing::ExecuteExt, - Clock, DatabaseError, LookupResultExt, + Clock, DatabaseError, LookupResultExt, Pagination, }; #[async_trait] @@ -60,10 +60,7 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync { async fn list_paginated( &mut self, user: &User, - before: Option, - after: Option, - first: Option, - last: Option, + pagination: &Pagination, ) -> Result, Self::Error>; } @@ -275,10 +272,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { async fn list_paginated( &mut self, user: &User, - before: Option, - after: Option, - first: Option, - last: Option, + pagination: &Pagination, ) -> Result, Self::Error> { let mut query = QueryBuilder::new( r#" @@ -295,7 +289,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { query .push(" WHERE user_id = ") .push_bind(Uuid::from(user.id)) - .generate_pagination("upstream_oauth_link_id", before, after, first, last)?; + .generate_pagination("upstream_oauth_link_id", pagination); let edges: Vec = query .build_query_as() @@ -303,7 +297,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { .fetch_all(&mut *self.conn) .await?; - let page = Page::process(edges, first, last)?.map(UpstreamOAuthLink::from); + let page = pagination.process(edges).map(UpstreamOAuthLink::from); Ok(page) } } diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index 0763624f4..d72e5f487 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -29,7 +29,7 @@ mod tests { use sqlx::PgPool; use super::*; - use crate::{user::UserRepository, Clock, PgRepository, Repository}; + use crate::{user::UserRepository, Clock, Pagination, PgRepository, Repository}; #[sqlx::test(migrator = "crate::MIGRATOR")] async fn test_repository(pool: PgPool) -> Result<(), Box> { @@ -144,7 +144,7 @@ mod tests { let links = repo .upstream_oauth_link() - .list_paginated(&user, None, None, Some(10), None) + .list_paginated(&user, &Pagination::first(10)) .await?; assert!(!links.has_previous_page); assert!(!links.has_next_page); diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index 088e9e93b..eb09fd799 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -25,7 +25,7 @@ use uuid::Uuid; use crate::{ pagination::{Page, QueryBuilderExt}, tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, + Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination, }; #[async_trait] @@ -52,10 +52,7 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync { /// Get a paginated list of upstream OAuth providers async fn list_paginated( &mut self, - before: Option, - after: Option, - first: Option, - last: Option, + pagination: &Pagination, ) -> Result, Self::Error>; /// Get all upstream OAuth providers @@ -243,10 +240,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' )] async fn list_paginated( &mut self, - before: Option, - after: Option, - first: Option, - last: Option, + pagination: &Pagination, ) -> Result, Self::Error> { let mut query = QueryBuilder::new( r#" @@ -264,7 +258,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' "#, ); - query.generate_pagination("upstream_oauth_provider_id", before, after, first, last)?; + query.generate_pagination("upstream_oauth_provider_id", pagination); let edges: Vec = query .build_query_as() @@ -272,7 +266,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' .fetch_all(&mut *self.conn) .await?; - let page = Page::process(edges, first, last)?.try_map(TryInto::try_into)?; + let page = pagination.process(edges).try_map(TryInto::try_into)?; Ok(page) } diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs index 2d5ad9876..cef4fa271 100644 --- a/crates/storage/src/user/email.rs +++ b/crates/storage/src/user/email.rs @@ -24,7 +24,7 @@ use uuid::Uuid; use crate::{ pagination::{Page, QueryBuilderExt}, tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, + Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination, }; #[async_trait] @@ -39,10 +39,7 @@ pub trait UserEmailRepository: Send + Sync { async fn list_paginated( &mut self, user: &User, - before: Option, - after: Option, - first: Option, - last: Option, + pagination: &Pagination, ) -> Result, Self::Error>; async fn count(&mut self, user: &User) -> Result; @@ -289,10 +286,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { async fn list_paginated( &mut self, user: &User, - before: Option, - after: Option, - first: Option, - last: Option, + pagination: &Pagination, ) -> Result, DatabaseError> { let mut query = QueryBuilder::new( r#" @@ -308,7 +302,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { query .push(" WHERE user_id = ") .push_bind(Uuid::from(user.id)) - .generate_pagination("ue.user_email_id", before, after, first, last)?; + .generate_pagination("ue.user_email_id", &pagination); let edges: Vec = query .build_query_as() @@ -316,7 +310,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { .fetch_all(&mut *self.conn) .await?; - let page = Page::process(edges, first, last)?.map(UserEmail::from); + let page = pagination.process(edges).map(UserEmail::from); Ok(page) } diff --git a/crates/storage/src/user/session.rs b/crates/storage/src/user/session.rs index f2ceed2a3..a837c2041 100644 --- a/crates/storage/src/user/session.rs +++ b/crates/storage/src/user/session.rs @@ -23,7 +23,7 @@ use uuid::Uuid; use crate::{ pagination::{Page, QueryBuilderExt}, tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, + Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination, }; #[async_trait] @@ -45,10 +45,7 @@ pub trait BrowserSessionRepository: Send + Sync { async fn list_active_paginated( &mut self, user: &User, - before: Option, - after: Option, - first: Option, - last: Option, + pagination: &Pagination, ) -> Result, Self::Error>; async fn count_active(&mut self, user: &User) -> Result; @@ -264,10 +261,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { async fn list_active_paginated( &mut self, user: &User, - before: Option, - after: Option, - first: Option, - last: Option, + pagination: &Pagination, ) -> Result, Self::Error> { // TODO: ordering of last authentication is wrong let mut query = QueryBuilder::new( @@ -290,7 +284,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { query .push(" WHERE s.finished_at IS NULL AND s.user_id = ") .push_bind(Uuid::from(user.id)) - .generate_pagination("s.user_session_id", before, after, first, last)?; + .generate_pagination("s.user_session_id", pagination); let edges: Vec = query .build_query_as() @@ -298,7 +292,9 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { .fetch_all(&mut *self.conn) .await?; - let page = Page::process(edges, first, last)?.try_map(BrowserSession::try_from)?; + let page = pagination + .process(edges) + .try_map(BrowserSession::try_from)?; Ok(page) } From 1be5c9f5cfb8ed570311e71b7efea11824e9dcde Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 17 Jan 2023 16:11:13 +0100 Subject: [PATCH 25/45] storage: add tests for the upstream provider paginated list --- crates/storage/src/upstream_oauth2/mod.rs | 144 +++++++++++++++++++--- 1 file changed, 126 insertions(+), 18 deletions(-) diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index d72e5f487..148489350 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -24,6 +24,7 @@ pub use self::{ #[cfg(test)] mod tests { + use chrono::Duration; use oauth2_types::scope::{Scope, OPENID}; use rand::SeedableRng; use sqlx::PgPool; @@ -32,13 +33,13 @@ mod tests { use crate::{user::UserRepository, Clock, Pagination, PgRepository, Repository}; #[sqlx::test(migrator = "crate::MIGRATOR")] - async fn test_repository(pool: PgPool) -> Result<(), Box> { + async fn test_repository(pool: PgPool) { let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); let clock = Clock::mock(); - let mut repo = PgRepository::from_pool(&pool).await?; + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); // The provider list should be empty at the start - let all_providers = repo.upstream_oauth_provider().all().await?; + let all_providers = repo.upstream_oauth_provider().all().await.unwrap(); assert!(all_providers.is_empty()); // Let's add a provider @@ -54,13 +55,15 @@ mod tests { "client-id".to_owned(), None, ) - .await?; + .await + .unwrap(); // Look it up in the database let provider = repo .upstream_oauth_provider() .lookup(provider.id) - .await? + .await + .unwrap() .expect("provider to be found in the database"); assert_eq!(provider.issuer, "https://example.com/"); assert_eq!(provider.client_id, "client-id"); @@ -76,13 +79,15 @@ mod tests { None, "some-nonce".to_owned(), ) - .await?; + .await + .unwrap(); // Look it up in the database let session = repo .upstream_oauth_session() .lookup(session.id) - .await? + .await + .unwrap() .expect("session to be found in the database"); assert_eq!(session.provider_id, provider.id); assert_eq!(session.link_id(), None); @@ -94,19 +99,22 @@ mod tests { let link = repo .upstream_oauth_link() .add(&mut rng, &clock, &provider, "a-subject".to_owned()) - .await?; + .await + .unwrap(); // We can look it up by its ID repo.upstream_oauth_link() .lookup(link.id) - .await? + .await + .unwrap() .expect("link to be found in database"); // or by its subject let link = repo .upstream_oauth_link() .find_by_subject(&provider, "a-subject") - .await? + .await + .unwrap() .expect("link to be found in database"); assert_eq!(link.subject, "a-subject"); assert_eq!(link.provider_id, provider.id); @@ -114,12 +122,14 @@ mod tests { let session = repo .upstream_oauth_session() .complete_with_link(&clock, session, &link, None) - .await?; + .await + .unwrap(); // Reload the session let session = repo .upstream_oauth_session() .lookup(session.id) - .await? + .await + .unwrap() .expect("session to be found in the database"); assert!(session.is_completed()); assert!(!session.is_consumed()); @@ -128,30 +138,128 @@ mod tests { let session = repo .upstream_oauth_session() .consume(&clock, session) - .await?; + .await + .unwrap(); // Reload the session let session = repo .upstream_oauth_session() .lookup(session.id) - .await? + .await + .unwrap() .expect("session to be found in the database"); assert!(session.is_consumed()); - let user = repo.user().add(&mut rng, &clock, "john".to_owned()).await?; + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); repo.upstream_oauth_link() .associate_to_user(&link, &user) - .await?; + .await + .unwrap(); let links = repo .upstream_oauth_link() .list_paginated(&user, &Pagination::first(10)) - .await?; + .await + .unwrap(); assert!(!links.has_previous_page); assert!(!links.has_next_page); assert_eq!(links.edges.len(), 1); assert_eq!(links.edges[0].id, link.id); assert_eq!(links.edges[0].user_id, Some(user.id)); + } - Ok(()) + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_provider_repository_pagination(pool: PgPool) { + const ISSUER: &str = "https://example.com/"; + let scope = Scope::from_iter([OPENID]); + + let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + let mut ids = Vec::with_capacity(20); + // Create 20 providers + for idx in 0..20 { + let client_id = format!("client-{idx}"); + let provider = repo + .upstream_oauth_provider() + .add( + &mut rng, + &clock, + ISSUER.to_owned(), + scope.clone(), + mas_iana::oauth::OAuthClientAuthenticationMethod::None, + None, + client_id, + None, + ) + .await + .unwrap(); + ids.push(provider.id); + clock.advance(Duration::seconds(10)); + } + + // Lookup the first 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(&Pagination::first(10)) + .await + .unwrap(); + + // It returned the first 10 items + assert!(page.has_next_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[..10]); + + // Lookup the next 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(&Pagination::first(10).after(ids[9])) + .await + .unwrap(); + + // It returned the next 10 items + assert!(!page.has_next_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[10..]); + + // Lookup the last 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(&Pagination::last(10)) + .await + .unwrap(); + + // It returned the last 10 items + assert!(page.has_previous_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[10..]); + + // Lookup the previous 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(&Pagination::last(10).before(ids[10])) + .await + .unwrap(); + + // It returned the previous 10 items + assert!(!page.has_previous_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[..10]); + + // Lookup 10 items between two IDs + let page = repo + .upstream_oauth_provider() + .list_paginated(&Pagination::first(10).after(ids[5]).before(ids[8])) + .await + .unwrap(); + + // It returned the items in between + assert!(!page.has_next_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[6..8]); } } From b7d342daf6b6325274da28cce9fe631dee29641f Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 17 Jan 2023 16:44:22 +0100 Subject: [PATCH 26/45] storage: don't use references for pagination --- crates/graphql/src/lib.rs | 2 +- crates/graphql/src/model/users.rs | 10 +++++----- crates/storage/src/compat/sso_login.rs | 6 +++--- crates/storage/src/oauth2/session.rs | 4 ++-- crates/storage/src/pagination.rs | 15 ++++++++------- crates/storage/src/upstream_oauth2/link.rs | 4 ++-- crates/storage/src/upstream_oauth2/mod.rs | 12 ++++++------ crates/storage/src/upstream_oauth2/provider.rs | 4 ++-- crates/storage/src/user/email.rs | 6 +++--- crates/storage/src/user/session.rs | 4 ++-- 10 files changed, 34 insertions(+), 33 deletions(-) diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index 765db8e56..d2de0c24f 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -232,7 +232,7 @@ impl RootQuery { let page = repo .upstream_oauth_provider() - .list_paginated(&pagination) + .list_paginated(pagination) .await?; let mut connection = Connection::new(page.has_previous_page, page.has_next_page); diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index 9a9062d0c..68daff1b5 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -99,7 +99,7 @@ impl User { let page = repo .compat_sso_login() - .list_paginated(&self.0, &pagination) + .list_paginated(&self.0, pagination) .await?; let mut connection = Connection::new(page.has_previous_page, page.has_next_page); @@ -146,7 +146,7 @@ impl User { let page = repo .browser_session() - .list_active_paginated(&self.0, &pagination) + .list_active_paginated(&self.0, pagination) .await?; let mut connection = Connection::new(page.has_previous_page, page.has_next_page); @@ -193,7 +193,7 @@ impl User { let page = repo .user_email() - .list_paginated(&self.0, &pagination) + .list_paginated(&self.0, pagination) .await?; let mut connection = Connection::with_additional_fields( @@ -244,7 +244,7 @@ impl User { let page = repo .oauth2_session() - .list_paginated(&self.0, &pagination) + .list_paginated(&self.0, pagination) .await?; let mut connection = Connection::new(page.has_previous_page, page.has_next_page); @@ -295,7 +295,7 @@ impl User { let page = repo .upstream_oauth_link() - .list_paginated(&self.0, &pagination) + .list_paginated(&self.0, pagination) .await?; let mut connection = Connection::new(page.has_previous_page, page.has_next_page); diff --git a/crates/storage/src/compat/sso_login.rs b/crates/storage/src/compat/sso_login.rs index 31c3da3c8..76cf1ede4 100644 --- a/crates/storage/src/compat/sso_login.rs +++ b/crates/storage/src/compat/sso_login.rs @@ -68,7 +68,7 @@ pub trait CompatSsoLoginRepository: Send + Sync { async fn list_paginated( &mut self, user: &User, - pagination: &Pagination, + pagination: Pagination, ) -> Result, Self::Error>; } @@ -351,7 +351,7 @@ impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { async fn list_paginated( &mut self, user: &User, - pagination: &Pagination, + pagination: Pagination, ) -> Result, Self::Error> { let mut query = QueryBuilder::new( r#" @@ -371,7 +371,7 @@ impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { query .push(" WHERE user_id = ") .push_bind(Uuid::from(user.id)) - .generate_pagination("cl.compat_sso_login_id", &pagination); + .generate_pagination("cl.compat_sso_login_id", pagination); let edges: Vec = query .build_query_as() diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 0fa8cb8fc..dc21fbcb3 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -45,7 +45,7 @@ pub trait OAuth2SessionRepository: Send + Sync { async fn list_paginated( &mut self, user: &User, - pagination: &Pagination, + pagination: Pagination, ) -> Result, Self::Error>; } @@ -240,7 +240,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { async fn list_paginated( &mut self, user: &User, - pagination: &Pagination, + pagination: Pagination, ) -> Result, Self::Error> { let mut query = QueryBuilder::new( r#" diff --git a/crates/storage/src/pagination.rs b/crates/storage/src/pagination.rs index 1fbcbc516..1fa74ddac 100644 --- a/crates/storage/src/pagination.rs +++ b/crates/storage/src/pagination.rs @@ -66,6 +66,7 @@ impl Pagination { } /// Creates a [`Pagination`] which gets the first N items + #[must_use] pub const fn first(first: usize) -> Self { Self { before: None, @@ -76,6 +77,7 @@ impl Pagination { } /// Creates a [`Pagination`] which gets the last N items + #[must_use] pub const fn last(last: usize) -> Self { Self { before: None, @@ -86,12 +88,14 @@ impl Pagination { } /// Get items before the given cursor + #[must_use] pub const fn before(mut self, id: Ulid) -> Self { self.before = Some(id); self } /// Get items after the given cursor + #[must_use] pub const fn after(mut self, id: Ulid) -> Self { self.after = Some(id); self @@ -181,6 +185,7 @@ pub struct Page { } impl Page { + #[must_use] pub fn map(self, f: F) -> Page where F: FnMut(T) -> T2, @@ -193,6 +198,7 @@ impl Page { } } + #[must_use] pub fn try_map(self, f: F) -> Result, E> where F: FnMut(T) -> Result, @@ -211,8 +217,7 @@ impl Page { pub trait QueryBuilderExt { /// Add cursor-based pagination to a query, as used in paginated GraphQL /// connections - fn generate_pagination(&mut self, id_field: &'static str, pagination: &Pagination) - -> &mut Self; + fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self; } impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB> @@ -221,11 +226,7 @@ where Uuid: sqlx::Type + sqlx::Encode<'a, DB>, i64: sqlx::Type + sqlx::Encode<'a, DB>, { - fn generate_pagination( - &mut self, - id_field: &'static str, - pagination: &Pagination, - ) -> &mut Self { + fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self { pagination.generate_pagination(self, id_field); self } diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index 13e86e17e..76364afe6 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -60,7 +60,7 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync { async fn list_paginated( &mut self, user: &User, - pagination: &Pagination, + pagination: Pagination, ) -> Result, Self::Error>; } @@ -272,7 +272,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { async fn list_paginated( &mut self, user: &User, - pagination: &Pagination, + pagination: Pagination, ) -> Result, Self::Error> { let mut query = QueryBuilder::new( r#" diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index 148489350..d1b6809f8 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -161,7 +161,7 @@ mod tests { let links = repo .upstream_oauth_link() - .list_paginated(&user, &Pagination::first(10)) + .list_paginated(&user, Pagination::first(10)) .await .unwrap(); assert!(!links.has_previous_page); @@ -205,7 +205,7 @@ mod tests { // Lookup the first 10 items let page = repo .upstream_oauth_provider() - .list_paginated(&Pagination::first(10)) + .list_paginated(Pagination::first(10)) .await .unwrap(); @@ -217,7 +217,7 @@ mod tests { // Lookup the next 10 items let page = repo .upstream_oauth_provider() - .list_paginated(&Pagination::first(10).after(ids[9])) + .list_paginated(Pagination::first(10).after(ids[9])) .await .unwrap(); @@ -229,7 +229,7 @@ mod tests { // Lookup the last 10 items let page = repo .upstream_oauth_provider() - .list_paginated(&Pagination::last(10)) + .list_paginated(Pagination::last(10)) .await .unwrap(); @@ -241,7 +241,7 @@ mod tests { // Lookup the previous 10 items let page = repo .upstream_oauth_provider() - .list_paginated(&Pagination::last(10).before(ids[10])) + .list_paginated(Pagination::last(10).before(ids[10])) .await .unwrap(); @@ -253,7 +253,7 @@ mod tests { // Lookup 10 items between two IDs let page = repo .upstream_oauth_provider() - .list_paginated(&Pagination::first(10).after(ids[5]).before(ids[8])) + .list_paginated(Pagination::first(10).after(ids[5]).before(ids[8])) .await .unwrap(); diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index eb09fd799..14bd65471 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -52,7 +52,7 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync { /// Get a paginated list of upstream OAuth providers async fn list_paginated( &mut self, - pagination: &Pagination, + pagination: Pagination, ) -> Result, Self::Error>; /// Get all upstream OAuth providers @@ -240,7 +240,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' )] async fn list_paginated( &mut self, - pagination: &Pagination, + pagination: Pagination, ) -> Result, Self::Error> { let mut query = QueryBuilder::new( r#" diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs index cef4fa271..8c8efe1b5 100644 --- a/crates/storage/src/user/email.rs +++ b/crates/storage/src/user/email.rs @@ -39,7 +39,7 @@ pub trait UserEmailRepository: Send + Sync { async fn list_paginated( &mut self, user: &User, - pagination: &Pagination, + pagination: Pagination, ) -> Result, Self::Error>; async fn count(&mut self, user: &User) -> Result; @@ -286,7 +286,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { async fn list_paginated( &mut self, user: &User, - pagination: &Pagination, + pagination: Pagination, ) -> Result, DatabaseError> { let mut query = QueryBuilder::new( r#" @@ -302,7 +302,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { query .push(" WHERE user_id = ") .push_bind(Uuid::from(user.id)) - .generate_pagination("ue.user_email_id", &pagination); + .generate_pagination("ue.user_email_id", pagination); let edges: Vec = query .build_query_as() diff --git a/crates/storage/src/user/session.rs b/crates/storage/src/user/session.rs index a837c2041..10b96da77 100644 --- a/crates/storage/src/user/session.rs +++ b/crates/storage/src/user/session.rs @@ -45,7 +45,7 @@ pub trait BrowserSessionRepository: Send + Sync { async fn list_active_paginated( &mut self, user: &User, - pagination: &Pagination, + pagination: Pagination, ) -> Result, Self::Error>; async fn count_active(&mut self, user: &User) -> Result; @@ -261,7 +261,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { async fn list_active_paginated( &mut self, user: &User, - pagination: &Pagination, + pagination: Pagination, ) -> Result, Self::Error> { // TODO: ordering of last authentication is wrong let mut query = QueryBuilder::new( From eb4ce7e7f08e25f69eace792f69ee354cb13153c Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 18 Jan 2023 09:53:42 +0100 Subject: [PATCH 27/45] Split the storage trait from the implementation --- Cargo.lock | 21 + crates/cli/Cargo.toml | 1 + crates/cli/src/commands/database.rs | 2 +- crates/cli/src/commands/manage.rs | 3 +- crates/cli/src/commands/server.rs | 4 +- crates/data-model/Cargo.toml | 2 +- crates/graphql/Cargo.toml | 1 + crates/graphql/src/lib.rs | 3 +- crates/graphql/src/model/compat_sessions.rs | 5 +- crates/graphql/src/model/oauth.rs | 5 +- crates/graphql/src/model/upstream_oauth.rs | 4 +- crates/graphql/src/model/users.rs | 3 +- crates/handlers/Cargo.toml | 1 + crates/handlers/src/compat/login.rs | 5 +- .../handlers/src/compat/login_sso_complete.rs | 3 +- .../handlers/src/compat/login_sso_redirect.rs | 5 +- crates/handlers/src/compat/logout.rs | 5 +- crates/handlers/src/compat/refresh.rs | 5 +- crates/handlers/src/graphql.rs | 2 +- crates/handlers/src/health.rs | 4 +- .../src/oauth2/authorization/complete.rs | 7 +- .../handlers/src/oauth2/authorization/mod.rs | 5 +- crates/handlers/src/oauth2/consent.rs | 5 +- crates/handlers/src/oauth2/introspection.rs | 5 +- crates/handlers/src/oauth2/registration.rs | 5 +- crates/handlers/src/oauth2/token.rs | 5 +- crates/handlers/src/oauth2/userinfo.rs | 9 +- .../handlers/src/upstream_oauth2/authorize.rs | 5 +- .../handlers/src/upstream_oauth2/callback.rs | 5 +- crates/handlers/src/upstream_oauth2/link.rs | 5 +- .../handlers/src/views/account/emails/add.rs | 3 +- .../handlers/src/views/account/emails/mod.rs | 3 +- .../src/views/account/emails/verify.rs | 3 +- crates/handlers/src/views/account/mod.rs | 3 +- crates/handlers/src/views/account/password.rs | 3 +- crates/handlers/src/views/index.rs | 2 +- crates/handlers/src/views/login.rs | 3 +- crates/handlers/src/views/logout.rs | 3 +- crates/handlers/src/views/reauth.rs | 3 +- crates/handlers/src/views/register.rs | 3 +- crates/storage-pg/Cargo.toml | 27 + crates/{storage => storage-pg}/build.rs | 2 +- .../migrations/20221018142001_init.sql | 0 .../20221121151402_upstream_oauth.sql | 0 .../20221213145242_password_schemes.sql | 0 crates/{storage => storage-pg}/sqlx-data.json | 36 +- crates/storage-pg/src/compat/access_token.rs | 216 +++++ crates/storage-pg/src/compat/mod.rs | 322 ++++++++ crates/storage-pg/src/compat/refresh_token.rs | 230 ++++++ crates/storage-pg/src/compat/session.rs | 195 +++++ crates/storage-pg/src/compat/sso_login.rs | 342 ++++++++ crates/storage-pg/src/lib.rs | 170 ++++ crates/storage-pg/src/oauth2/access_token.rs | 223 ++++++ .../src/oauth2/authorization_grant.rs | 510 ++++++++++++ crates/storage-pg/src/oauth2/client.rs | 745 ++++++++++++++++++ crates/storage-pg/src/oauth2/mod.rs | 25 + crates/storage-pg/src/oauth2/refresh_token.rs | 224 ++++++ crates/storage-pg/src/oauth2/session.rs | 248 ++++++ crates/storage-pg/src/pagination.rs | 78 ++ crates/storage-pg/src/repository.rs | 142 ++++ crates/{storage => storage-pg}/src/tracing.rs | 2 +- crates/storage-pg/src/upstream_oauth2/link.rs | 262 ++++++ crates/storage-pg/src/upstream_oauth2/mod.rs | 271 +++++++ .../src/upstream_oauth2/provider.rs | 273 +++++++ .../storage-pg/src/upstream_oauth2/session.rs | 286 +++++++ crates/storage-pg/src/user/email.rs | 554 +++++++++++++ crates/storage-pg/src/user/mod.rs | 203 +++++ crates/storage-pg/src/user/password.rs | 155 ++++ crates/storage-pg/src/user/session.rs | 375 +++++++++ .../{storage => storage-pg}/src/user/tests.rs | 15 +- crates/storage/Cargo.toml | 12 +- crates/storage/src/compat/access_token.rs | 198 +---- crates/storage/src/compat/mod.rs | 299 +------ crates/storage/src/compat/refresh_token.rs | 213 +---- crates/storage/src/compat/session.rs | 180 +---- crates/storage/src/compat/sso_login.rs | 325 +------- crates/storage/src/lib.rs | 154 +--- crates/storage/src/oauth2/access_token.rs | 207 +---- .../storage/src/oauth2/authorization_grant.rs | 493 +----------- crates/storage/src/oauth2/client.rs | 732 +---------------- crates/storage/src/oauth2/mod.rs | 12 +- crates/storage/src/oauth2/refresh_token.rs | 209 +---- crates/storage/src/oauth2/session.rs | 234 +----- crates/storage/src/pagination.rs | 87 +- crates/storage/src/repository.rs | 129 +-- crates/storage/src/upstream_oauth2/link.rs | 249 +----- crates/storage/src/upstream_oauth2/mod.rs | 249 +----- .../storage/src/upstream_oauth2/provider.rs | 255 +----- crates/storage/src/upstream_oauth2/session.rs | 273 +------ crates/storage/src/user/email.rs | 540 +------------ crates/storage/src/user/mod.rs | 179 +---- crates/storage/src/user/password.rs | 141 +--- crates/storage/src/user/session.rs | 361 +-------- crates/tasks/Cargo.toml | 1 + crates/tasks/src/database.rs | 3 +- 95 files changed, 6294 insertions(+), 5741 deletions(-) create mode 100644 crates/storage-pg/Cargo.toml rename crates/{storage => storage-pg}/build.rs (92%) rename crates/{storage => storage-pg}/migrations/20221018142001_init.sql (100%) rename crates/{storage => storage-pg}/migrations/20221121151402_upstream_oauth.sql (100%) rename crates/{storage => storage-pg}/migrations/20221213145242_password_schemes.sql (100%) rename crates/{storage => storage-pg}/sqlx-data.json (98%) create mode 100644 crates/storage-pg/src/compat/access_token.rs create mode 100644 crates/storage-pg/src/compat/mod.rs create mode 100644 crates/storage-pg/src/compat/refresh_token.rs create mode 100644 crates/storage-pg/src/compat/session.rs create mode 100644 crates/storage-pg/src/compat/sso_login.rs create mode 100644 crates/storage-pg/src/lib.rs create mode 100644 crates/storage-pg/src/oauth2/access_token.rs create mode 100644 crates/storage-pg/src/oauth2/authorization_grant.rs create mode 100644 crates/storage-pg/src/oauth2/client.rs create mode 100644 crates/storage-pg/src/oauth2/mod.rs create mode 100644 crates/storage-pg/src/oauth2/refresh_token.rs create mode 100644 crates/storage-pg/src/oauth2/session.rs create mode 100644 crates/storage-pg/src/pagination.rs create mode 100644 crates/storage-pg/src/repository.rs rename crates/{storage => storage-pg}/src/tracing.rs (95%) create mode 100644 crates/storage-pg/src/upstream_oauth2/link.rs create mode 100644 crates/storage-pg/src/upstream_oauth2/mod.rs create mode 100644 crates/storage-pg/src/upstream_oauth2/provider.rs create mode 100644 crates/storage-pg/src/upstream_oauth2/session.rs create mode 100644 crates/storage-pg/src/user/email.rs create mode 100644 crates/storage-pg/src/user/mod.rs create mode 100644 crates/storage-pg/src/user/password.rs create mode 100644 crates/storage-pg/src/user/session.rs rename crates/{storage => storage-pg}/src/user/tests.rs (98%) diff --git a/Cargo.lock b/Cargo.lock index b780f20fa..6ecc512d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2705,6 +2705,7 @@ dependencies = [ "mas-router", "mas-spa", "mas-storage", + "mas-storage-pg", "mas-tasks", "mas-templates", "oauth2-types", @@ -2803,6 +2804,7 @@ dependencies = [ "chrono", "mas-data-model", "mas-storage", + "mas-storage-pg", "oauth2-types", "serde", "sqlx", @@ -2843,6 +2845,7 @@ dependencies = [ "mas-policy", "mas-router", "mas-storage", + "mas-storage-pg", "mas-templates", "mime", "oauth2-types", @@ -3103,6 +3106,23 @@ dependencies = [ "mas-jose", "oauth2-types", "rand 0.8.5", + "thiserror", + "ulid", + "url", +] + +[[package]] +name = "mas-storage-pg" +version = "0.1.0" +dependencies = [ + "async-trait", + "chrono", + "mas-data-model", + "mas-iana", + "mas-jose", + "mas-storage", + "oauth2-types", + "rand 0.8.5", "rand_chacha 0.3.1", "serde", "serde_json", @@ -3121,6 +3141,7 @@ dependencies = [ "async-trait", "futures-util", "mas-storage", + "mas-storage-pg", "sqlx", "tokio", "tokio-stream", diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index a3bc4ec3b..451e8d325 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -50,6 +50,7 @@ mas-policy = { path = "../policy" } mas-router = { path = "../router" } mas-spa = { path = "../spa" } mas-storage = { path = "../storage" } +mas-storage-pg = { path = "../storage-pg" } mas-tasks = { path = "../tasks" } mas-templates = { path = "../templates" } oauth2-types = { path = "../oauth2-types" } diff --git a/crates/cli/src/commands/database.rs b/crates/cli/src/commands/database.rs index ca59ce1dd..0e4d68af6 100644 --- a/crates/cli/src/commands/database.rs +++ b/crates/cli/src/commands/database.rs @@ -15,7 +15,7 @@ use anyhow::Context; use clap::Parser; use mas_config::DatabaseConfig; -use mas_storage::MIGRATOR; +use mas_storage_pg::MIGRATOR; use tracing::{info_span, Instrument}; use crate::util::database_from_config; diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 153789400..c608db836 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -21,8 +21,9 @@ use mas_storage::{ oauth2::OAuth2ClientRepository, upstream_oauth2::UpstreamOAuthProviderRepository, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, - Clock, PgRepository, Repository, + Clock, Repository, }; +use mas_storage_pg::PgRepository; use oauth2_types::scope::Scope; use rand::SeedableRng; use tracing::{info, info_span, warn}; diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index fb2a3f168..1a7e39e69 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ use mas_config::RootConfig; use mas_handlers::{AppState, HttpClientFactory, MatrixHomeserver}; use mas_listener::{server::Server, shutdown::ShutdownStream}; use mas_router::UrlBuilder; -use mas_storage::MIGRATOR; +use mas_storage_pg::MIGRATOR; use mas_tasks::TaskQueue; use tokio::signal::unix::SignalKind; use tracing::{info, info_span, warn, Instrument}; diff --git a/crates/data-model/Cargo.toml b/crates/data-model/Cargo.toml index 661e9bbcc..b96d7b83a 100644 --- a/crates/data-model/Cargo.toml +++ b/crates/data-model/Cargo.toml @@ -11,8 +11,8 @@ thiserror = "1.0.38" serde = "1.0.152" url = { version = "2.3.1", features = ["serde"] } crc = "3.0.0" +ulid = { version = "1.0.0", features = ["serde"] } rand = "0.8.5" -ulid = "1.0.0" rand_chacha = "0.3.1" mas-iana = { path = "../iana" } diff --git a/crates/graphql/Cargo.toml b/crates/graphql/Cargo.toml index ff3159b08..16f7e5b51 100644 --- a/crates/graphql/Cargo.toml +++ b/crates/graphql/Cargo.toml @@ -19,6 +19,7 @@ url = "2.3.1" oauth2-types = { path = "../oauth2-types" } mas-data-model = { path = "../data-model" } mas-storage = { path = "../storage" } +mas-storage-pg = { path = "../storage-pg" } [[bin]] name = "schema" diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index d2de0c24f..159387ae3 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -34,8 +34,9 @@ use mas_storage::{ oauth2::OAuth2ClientRepository, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository}, user::{BrowserSessionRepository, UserEmailRepository}, - Pagination, PgRepository, Repository, + Pagination, Repository, }; +use mas_storage_pg::PgRepository; use model::CreationEvent; use sqlx::PgPool; diff --git a/crates/graphql/src/model/compat_sessions.rs b/crates/graphql/src/model/compat_sessions.rs index a2196e36f..e5cd66bce 100644 --- a/crates/graphql/src/model/compat_sessions.rs +++ b/crates/graphql/src/model/compat_sessions.rs @@ -15,9 +15,8 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; use chrono::{DateTime, Utc}; -use mas_storage::{ - compat::CompatSessionRepository, user::UserRepository, PgRepository, Repository, -}; +use mas_storage::{compat::CompatSessionRepository, user::UserRepository, Repository}; +use mas_storage_pg::PgRepository; use sqlx::PgPool; use url::Url; diff --git a/crates/graphql/src/model/oauth.rs b/crates/graphql/src/model/oauth.rs index 171c800fb..90a0c6b7f 100644 --- a/crates/graphql/src/model/oauth.rs +++ b/crates/graphql/src/model/oauth.rs @@ -14,9 +14,8 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; -use mas_storage::{ - oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, PgRepository, Repository, -}; +use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, Repository}; +use mas_storage_pg::PgRepository; use oauth2_types::scope::Scope; use sqlx::PgPool; use ulid::Ulid; diff --git a/crates/graphql/src/model/upstream_oauth.rs b/crates/graphql/src/model/upstream_oauth.rs index 4a4c223b4..5767f8d4b 100644 --- a/crates/graphql/src/model/upstream_oauth.rs +++ b/crates/graphql/src/model/upstream_oauth.rs @@ -16,9 +16,9 @@ use anyhow::Context as _; use async_graphql::{Context, Object, ID}; use chrono::{DateTime, Utc}; use mas_storage::{ - upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, PgRepository, - Repository, + upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, Repository, }; +use mas_storage_pg::PgRepository; use sqlx::PgPool; use super::{NodeType, User}; diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index 68daff1b5..3f587eb06 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -22,8 +22,9 @@ use mas_storage::{ oauth2::OAuth2SessionRepository, upstream_oauth2::UpstreamOAuthLinkRepository, user::{BrowserSessionRepository, UserEmailRepository}, - Pagination, PgRepository, Repository, + Pagination, Repository, }; +use mas_storage_pg::PgRepository; use sqlx::PgPool; use super::{ diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml index 47dd27755..b8fbb5af9 100644 --- a/crates/handlers/Cargo.toml +++ b/crates/handlers/Cargo.toml @@ -68,6 +68,7 @@ mas-oidc-client = { path = "../oidc-client" } mas-policy = { path = "../policy" } mas-router = { path = "../router" } mas-storage = { path = "../storage" } +mas-storage-pg = { path = "../storage-pg" } mas-templates = { path = "../templates" } oauth2-types = { path = "../oauth2-types" } diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index f344f7e0e..bfd36d8ae 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -22,8 +22,9 @@ use mas_storage::{ CompatSsoLoginRepository, }, user::{UserPasswordRepository, UserRepository}, - Clock, PgRepository, Repository, + Clock, Repository, }; +use mas_storage_pg::PgRepository; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; use sqlx::PgPool; @@ -154,7 +155,7 @@ pub enum RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 7ca61ab2c..1fea922e8 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -31,8 +31,9 @@ use mas_keystore::Encrypter; use mas_router::{CompatLoginSsoAction, PostAuthAction, Route}; use mas_storage::{ compat::{CompatSessionRepository, CompatSsoLoginRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; use serde::{Deserialize, Serialize}; use sqlx::PgPool; diff --git a/crates/handlers/src/compat/login_sso_redirect.rs b/crates/handlers/src/compat/login_sso_redirect.rs index befd3e323..38aa08943 100644 --- a/crates/handlers/src/compat/login_sso_redirect.rs +++ b/crates/handlers/src/compat/login_sso_redirect.rs @@ -19,7 +19,8 @@ use axum::{ }; use hyper::StatusCode; use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder}; -use mas_storage::{compat::CompatSsoLoginRepository, PgRepository, Repository}; +use mas_storage::{compat::CompatSsoLoginRepository, Repository}; +use mas_storage_pg::PgRepository; use rand::distributions::{Alphanumeric, DistString}; use serde::Deserialize; use serde_with::serde; @@ -49,7 +50,7 @@ pub enum RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index 762f77b2f..21229fe76 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -18,8 +18,9 @@ use hyper::StatusCode; use mas_data_model::TokenType; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, - Clock, PgRepository, Repository, + Clock, Repository, }; +use mas_storage_pg::PgRepository; use sqlx::PgPool; use thiserror::Error; @@ -42,7 +43,7 @@ pub enum RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs index ea6d5d238..e16013950 100644 --- a/crates/handlers/src/compat/refresh.rs +++ b/crates/handlers/src/compat/refresh.rs @@ -18,8 +18,9 @@ use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DurationMilliSeconds}; use sqlx::PgPool; @@ -70,7 +71,7 @@ impl IntoResponse for RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl From for RouteError { fn from(_e: TokenFormatError) -> Self { diff --git a/crates/handlers/src/graphql.rs b/crates/handlers/src/graphql.rs index d3a610b6a..fcc6aa3c3 100644 --- a/crates/handlers/src/graphql.rs +++ b/crates/handlers/src/graphql.rs @@ -28,7 +28,7 @@ use hyper::header::CACHE_CONTROL; use mas_axum_utils::{FancyError, SessionInfoExt}; use mas_graphql::Schema; use mas_keystore::Encrypter; -use mas_storage::PgRepository; +use mas_storage_pg::PgRepository; use sqlx::PgPool; use tracing::{info_span, Instrument}; diff --git a/crates/handlers/src/health.rs b/crates/handlers/src/health.rs index 6322dffde..10638497b 100644 --- a/crates/handlers/src/health.rs +++ b/crates/handlers/src/health.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ mod tests { use super::*; - #[sqlx::test(migrator = "mas_storage::MIGRATOR")] + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_get_health(pool: PgPool) -> Result<(), anyhow::Error> { let state = crate::test_state(pool).await?; let app = crate::healthcheck_router().with_state(state); diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index c983e79c7..05554e128 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -27,8 +27,9 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::Templates; use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse}; use sqlx::PgPool; @@ -70,7 +71,7 @@ impl IntoResponse for RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); @@ -149,7 +150,7 @@ pub enum GrantCompletionError { } impl_from_error_for_route!(GrantCompletionError: sqlx::Error); -impl_from_error_for_route!(GrantCompletionError: mas_storage::DatabaseError); +impl_from_error_for_route!(GrantCompletionError: mas_storage_pg::DatabaseError); impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError); impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError); diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 155f72f73..43bda928b 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -27,8 +27,9 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::Templates; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, @@ -91,7 +92,7 @@ impl IntoResponse for RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(self::callback::CallbackDestinationError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index f3d4fd46e..b0f752f79 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -30,8 +30,9 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; use sqlx::PgPool; use thiserror::Error; @@ -62,7 +63,7 @@ pub enum RouteError { impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_templates::TemplateError); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 2837928f7..e8f9941f8 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -25,8 +25,9 @@ use mas_storage::{ compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository}, user::{BrowserSessionRepository, UserRepository}, - Clock, PgRepository, Repository, + Clock, Repository, }; +use mas_storage_pg::PgRepository; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, requests::{IntrospectionRequest, IntrospectionResponse}, @@ -97,7 +98,7 @@ impl IntoResponse for RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl From for RouteError { fn from(_e: TokenFormatError) -> Self { diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index d6180f9aa..8e9489e81 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -19,7 +19,8 @@ use hyper::StatusCode; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_keystore::Encrypter; use mas_policy::{PolicyFactory, Violation}; -use mas_storage::{oauth2::OAuth2ClientRepository, PgRepository, Repository}; +use mas_storage::{oauth2::OAuth2ClientRepository, Repository}; +use mas_storage_pg::PgRepository; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, registration::{ @@ -49,7 +50,7 @@ pub(crate) enum RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 6365a0ada..67ecb4985 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -37,8 +37,9 @@ use mas_storage::{ OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, user::BrowserSessionRepository, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, pkce::CodeChallengeError, @@ -151,7 +152,7 @@ impl IntoResponse for RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_keystore::WrongAlgorithmError); impl_from_error_for_route!(mas_jose::claims::ClaimError); impl_from_error_for_route!(mas_jose::claims::TokenHashError); diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index a125c5dde..2f5600379 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -31,8 +31,9 @@ use mas_router::UrlBuilder; use mas_storage::{ oauth2::OAuth2ClientRepository, user::{BrowserSessionRepository, UserEmailRepository}, - DatabaseError, PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use oauth2_types::scope; use serde::Serialize; use serde_with::skip_serializing_none; @@ -64,7 +65,9 @@ pub enum RouteError { Internal(Box), #[error("failed to authenticate")] - AuthorizationVerificationError(#[from] AuthorizationVerificationError), + AuthorizationVerificationError( + #[from] AuthorizationVerificationError, + ), #[error("no suitable key found for signing")] InvalidSigningKey, @@ -77,7 +80,7 @@ pub enum RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_keystore::WrongAlgorithmError); impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError); diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index bdd19b7b8..fcf5a7d14 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -24,8 +24,9 @@ use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_router::UrlBuilder; use mas_storage::{ upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -46,7 +47,7 @@ impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_http::ClientInitError); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 521efd7b4..d243666d4 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -30,8 +30,9 @@ use mas_storage::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, }, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use oauth2_types::errors::ClientErrorCode; use serde::Deserialize; use sqlx::PgPool; @@ -99,7 +100,7 @@ pub(crate) enum RouteError { Internal(Box), } -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_http::ClientInitError); impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 18849be84..8709ff219 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -27,8 +27,9 @@ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, user::{BrowserSessionRepository, UserRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::{ EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink, @@ -73,7 +74,7 @@ impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError); impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index e0cc063d3..e99c8e4d7 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -24,7 +24,8 @@ use mas_axum_utils::{ use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, PgRepository, Repository}; +use mas_storage::{user::UserEmailRepository, Repository}; +use mas_storage_pg::PgRepository; use mas_templates::{EmailAddContext, TemplateContext, Templates}; use serde::Deserialize; use sqlx::PgPool; diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index 3fda398ad..4d70ab332 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -28,7 +28,8 @@ use mas_data_model::{BrowserSession, User, UserEmail}; use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, Clock, PgRepository, Repository}; +use mas_storage::{user::UserEmailRepository, Clock, Repository}; +use mas_storage_pg::PgRepository; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; use rand::{distributions::Uniform, Rng}; use serde::Deserialize; diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index 085b9a337..2b398b424 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -24,7 +24,8 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, Clock, PgRepository, Repository}; +use mas_storage::{user::UserEmailRepository, Clock, Repository}; +use mas_storage_pg::PgRepository; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use serde::Deserialize; use sqlx::PgPool; diff --git a/crates/handlers/src/views/account/mod.rs b/crates/handlers/src/views/account/mod.rs index 5017db00e..8d2eb3e2d 100644 --- a/crates/handlers/src/views/account/mod.rs +++ b/crates/handlers/src/views/account/mod.rs @@ -25,8 +25,9 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::{AccountContext, TemplateContext, Templates}; use sqlx::PgPool; diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index 8d4964323..089093f63 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -27,8 +27,9 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository}, - Clock, PgRepository, Repository, + Clock, Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::{EmptyContext, TemplateContext, Templates}; use rand::Rng; use serde::Deserialize; diff --git a/crates/handlers/src/views/index.rs b/crates/handlers/src/views/index.rs index 49668daeb..cab7c743e 100644 --- a/crates/handlers/src/views/index.rs +++ b/crates/handlers/src/views/index.rs @@ -20,7 +20,7 @@ use axum_extra::extract::PrivateCookieJar; use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt}; use mas_keystore::Encrypter; use mas_router::UrlBuilder; -use mas_storage::PgRepository; +use mas_storage_pg::PgRepository; use mas_templates::{IndexContext, TemplateContext, Templates}; use sqlx::PgPool; diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 76ffa4558..87ba9e848 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -26,8 +26,9 @@ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::UpstreamOAuthProviderRepository, user::{BrowserSessionRepository, UserPasswordRepository, UserRepository}, - Clock, PgRepository, Repository, + Clock, Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::{ FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, }; diff --git a/crates/handlers/src/views/logout.rs b/crates/handlers/src/views/logout.rs index 156e6afbe..373264d06 100644 --- a/crates/handlers/src/views/logout.rs +++ b/crates/handlers/src/views/logout.rs @@ -23,7 +23,8 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::{PostAuthAction, Route}; -use mas_storage::{user::BrowserSessionRepository, Clock, PgRepository, Repository}; +use mas_storage::{user::BrowserSessionRepository, Clock, Repository}; +use mas_storage_pg::PgRepository; use sqlx::PgPool; pub(crate) async fn post( diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index aac51abd7..49249f3cf 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -26,8 +26,9 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::{ReauthContext, TemplateContext, Templates}; use serde::Deserialize; use sqlx::PgPool; diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index a014eb9db..58db6ec16 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -33,8 +33,9 @@ use mas_policy::PolicyFactory; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::{ EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField, TemplateContext, Templates, ToFormState, diff --git a/crates/storage-pg/Cargo.toml b/crates/storage-pg/Cargo.toml new file mode 100644 index 000000000..fad6e30ee --- /dev/null +++ b/crates/storage-pg/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "mas-storage-pg" +version = "0.1.0" +authors = ["Quentin Gliech "] +edition = "2021" +license = "Apache-2.0" + +[dependencies] +async-trait = "0.1.60" +sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline", "json", "uuid"] } +chrono = { version = "0.4.23", features = ["serde"] } +serde = { version = "1.0.152", features = ["derive"] } +serde_json = "1.0.91" +thiserror = "1.0.38" +tracing = "0.1.37" + +rand = "0.8.5" +rand_chacha = "0.3.1" +url = { version = "2.3.1", features = ["serde"] } +uuid = "1.2.2" +ulid = { version = "1.0.0", features = ["uuid", "serde"] } + +oauth2-types = { path = "../oauth2-types" } +mas-storage = { path = "../storage" } +mas-data-model = { path = "../data-model" } +mas-iana = { path = "../iana" } +mas-jose = { path = "../jose" } diff --git a/crates/storage/build.rs b/crates/storage-pg/build.rs similarity index 92% rename from crates/storage/build.rs rename to crates/storage-pg/build.rs index dd5b11428..dca71bd6d 100644 --- a/crates/storage/build.rs +++ b/crates/storage-pg/build.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/crates/storage/migrations/20221018142001_init.sql b/crates/storage-pg/migrations/20221018142001_init.sql similarity index 100% rename from crates/storage/migrations/20221018142001_init.sql rename to crates/storage-pg/migrations/20221018142001_init.sql diff --git a/crates/storage/migrations/20221121151402_upstream_oauth.sql b/crates/storage-pg/migrations/20221121151402_upstream_oauth.sql similarity index 100% rename from crates/storage/migrations/20221121151402_upstream_oauth.sql rename to crates/storage-pg/migrations/20221121151402_upstream_oauth.sql diff --git a/crates/storage/migrations/20221213145242_password_schemes.sql b/crates/storage-pg/migrations/20221213145242_password_schemes.sql similarity index 100% rename from crates/storage/migrations/20221213145242_password_schemes.sql rename to crates/storage-pg/migrations/20221213145242_password_schemes.sql diff --git a/crates/storage/sqlx-data.json b/crates/storage-pg/sqlx-data.json similarity index 98% rename from crates/storage/sqlx-data.json rename to crates/storage-pg/sqlx-data.json index 8148f796d..94527512c 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage-pg/sqlx-data.json @@ -1336,6 +1336,24 @@ }, "query": "\n SELECT oauth2_client_id\n , encrypted_client_secret\n , ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\"\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n FROM oauth2_clients c\n\n WHERE oauth2_client_id = ANY($1::uuid[])\n " }, + "8a79c7c392dd930628caadec80c9b2645501475ab4feacbac59ca1bc52b16c3f": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Bool", + "Bool", + "Text", + "Jsonb", + "Text" + ] + } + }, + "query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , token_endpoint_auth_method\n , jwks\n , jwks_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7)\n ON CONFLICT (oauth2_client_id)\n DO\n UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret\n , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code\n , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token\n , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method\n , jwks = EXCLUDED.jwks\n , jwks_uri = EXCLUDED.jwks_uri\n " + }, "8b7297c263336d70c2b647212b16f7ae39bc5cb1572e3a2dcfcd67f196a1fa39": { "describe": { "columns": [ @@ -1821,24 +1839,6 @@ }, "query": "\n UPDATE users\n SET primary_user_email_id = user_emails.user_email_id\n FROM user_emails\n WHERE user_emails.user_email_id = $1\n AND users.user_id = user_emails.user_id\n " }, - "c0b4996085f6f2127e1e8cfdf18b9029c22096fadfe6de59dce01c789791edb5": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Bool", - "Bool", - "Text", - "Jsonb", - "Text" - ] - } - }, - "query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , token_endpoint_auth_method\n , jwks\n , jwks_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7)\n ON CONFLICT (oauth2_client_id)\n DO\n UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret\n , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code\n , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token\n , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method\n , jwks = EXCLUDED.jwks\n , jwks_uri = EXCLUDED.jwks_uri\n " - }, "c0ed9d70e496433d8686a499055d8a8376459109b6154a2c0c13b28462afa523": { "describe": { "columns": [], diff --git a/crates/storage-pg/src/compat/access_token.rs b/crates/storage-pg/src/compat/access_token.rs new file mode 100644 index 000000000..5f73ed9e4 --- /dev/null +++ b/crates/storage-pg/src/compat/access_token.rs @@ -0,0 +1,216 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Duration, Utc}; +use mas_data_model::{CompatAccessToken, CompatSession}; +use mas_storage::{compat::CompatAccessTokenRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +pub struct PgCompatAccessTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatAccessTokenRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct CompatAccessTokenLookup { + compat_access_token_id: Uuid, + access_token: String, + created_at: DateTime, + expires_at: Option>, + compat_session_id: Uuid, +} + +impl From for CompatAccessToken { + fn from(value: CompatAccessTokenLookup) -> Self { + Self { + id: value.compat_access_token_id.into(), + session_id: value.compat_session_id.into(), + token: value.access_token, + created_at: value.created_at, + expires_at: value.expires_at, + } + } +} + +#[async_trait] +impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_access_token.lookup", + skip_all, + fields( + db.statement, + compat_session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatAccessTokenLookup, + r#" + SELECT compat_access_token_id + , access_token + , created_at + , expires_at + , compat_session_id + + FROM compat_access_tokens + + WHERE compat_access_token_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_access_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatAccessTokenLookup, + r#" + SELECT compat_access_token_id + , access_token + , created_at + , expires_at + , compat_session_id + + FROM compat_access_tokens + + WHERE access_token = $1 + "#, + access_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_access_token.add", + skip_all, + fields( + db.statement, + compat_access_token.id, + %compat_session.id, + user.id = %compat_session.user_id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + compat_session: &CompatSession, + token: String, + expires_after: Option, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_access_token.id", tracing::field::display(id)); + + let expires_at = expires_after.map(|expires_after| created_at + expires_after); + + sqlx::query!( + r#" + INSERT INTO compat_access_tokens + (compat_access_token_id, compat_session_id, access_token, created_at, expires_at) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(compat_session.id), + token, + created_at, + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatAccessToken { + id, + session_id: compat_session.id, + token, + created_at, + expires_at, + }) + } + + #[tracing::instrument( + name = "db.compat_access_token.expire", + skip_all, + fields( + db.statement, + %compat_access_token.id, + compat_session.id = %compat_access_token.session_id, + ), + err, + )] + async fn expire( + &mut self, + clock: &Clock, + mut compat_access_token: CompatAccessToken, + ) -> Result { + let expires_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE compat_access_tokens + SET expires_at = $2 + WHERE compat_access_token_id = $1 + "#, + Uuid::from(compat_access_token.id), + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + compat_access_token.expires_at = Some(expires_at); + Ok(compat_access_token) + } +} diff --git a/crates/storage-pg/src/compat/mod.rs b/crates/storage-pg/src/compat/mod.rs new file mode 100644 index 000000000..732ce3aaa --- /dev/null +++ b/crates/storage-pg/src/compat/mod.rs @@ -0,0 +1,322 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod access_token; +mod refresh_token; +mod session; +mod sso_login; + +pub use self::{ + access_token::PgCompatAccessTokenRepository, refresh_token::PgCompatRefreshTokenRepository, + session::PgCompatSessionRepository, sso_login::PgCompatSsoLoginRepository, +}; + +#[cfg(test)] +mod tests { + use chrono::Duration; + use mas_data_model::Device; + use mas_storage::{ + compat::{ + CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, + }, + user::UserRepository, + Clock, Repository, + }; + use rand::SeedableRng; + use rand_chacha::ChaChaRng; + use sqlx::PgPool; + + use crate::PgRepository; + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_session_repository(pool: PgPool) { + const FIRST_TOKEN: &str = "first_access_token"; + const SECOND_TOKEN: &str = "second_access_token"; + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + // Create a user + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + + // Start a compat session for that user + let device = Device::generate(&mut rng); + let device_str = device.as_str().to_owned(); + let session = repo + .compat_session() + .add(&mut rng, &clock, &user, device) + .await + .unwrap(); + assert_eq!(session.user_id, user.id); + assert_eq!(session.device.as_str(), device_str); + assert!(session.is_valid()); + assert!(!session.is_finished()); + + // Lookup the session and check it didn't change + let session_lookup = repo + .compat_session() + .lookup(session.id) + .await + .unwrap() + .expect("compat session not found"); + assert_eq!(session_lookup.id, session.id); + assert_eq!(session_lookup.user_id, user.id); + assert_eq!(session_lookup.device.as_str(), device_str); + assert!(session_lookup.is_valid()); + assert!(!session_lookup.is_finished()); + + // Finish the session + let session = repo.compat_session().finish(&clock, session).await.unwrap(); + assert!(!session.is_valid()); + assert!(session.is_finished()); + + // Reload the session and check again + let session_lookup = repo + .compat_session() + .lookup(session.id) + .await + .unwrap() + .expect("compat session not found"); + assert!(!session_lookup.is_valid()); + assert!(session_lookup.is_finished()); + } + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_access_token_repository(pool: PgPool) { + const FIRST_TOKEN: &str = "first_access_token"; + const SECOND_TOKEN: &str = "second_access_token"; + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + // Create a user + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + + // Start a compat session for that user + let device = Device::generate(&mut rng); + let session = repo + .compat_session() + .add(&mut rng, &clock, &user, device) + .await + .unwrap(); + + // Add an access token to that session + let token = repo + .compat_access_token() + .add( + &mut rng, + &clock, + &session, + FIRST_TOKEN.to_owned(), + Some(Duration::minutes(1)), + ) + .await + .unwrap(); + assert_eq!(token.session_id, session.id); + assert_eq!(token.token, FIRST_TOKEN); + + // Commit the txn and grab a new transaction, to test a conflict + repo.save().await.unwrap(); + + { + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + // Adding the same token a second time should conflict + assert!(repo + .compat_access_token() + .add( + &mut rng, + &clock, + &session, + FIRST_TOKEN.to_owned(), + Some(Duration::minutes(1)), + ) + .await + .is_err()); + repo.cancel().await.unwrap(); + } + + // Grab a new repo + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + // Looking up via ID works + let token_lookup = repo + .compat_access_token() + .lookup(token.id) + .await + .unwrap() + .expect("compat access token not found"); + assert_eq!(token.id, token_lookup.id); + assert_eq!(token_lookup.session_id, session.id); + + // Looking up via the token value works + let token_lookup = repo + .compat_access_token() + .find_by_token(FIRST_TOKEN) + .await + .unwrap() + .expect("compat access token not found"); + assert_eq!(token.id, token_lookup.id); + assert_eq!(token_lookup.session_id, session.id); + + // Token is currently valid + assert!(token.is_valid(clock.now())); + + clock.advance(Duration::minutes(1)); + // Token should have expired + assert!(!token.is_valid(clock.now())); + + // Add a second access token, this time without expiration + let token = repo + .compat_access_token() + .add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None) + .await + .unwrap(); + assert_eq!(token.session_id, session.id); + assert_eq!(token.token, SECOND_TOKEN); + + // Token is currently valid + assert!(token.is_valid(clock.now())); + + // Make it expire + repo.compat_access_token() + .expire(&clock, token) + .await + .unwrap(); + + // Reload it + let token = repo + .compat_access_token() + .find_by_token(SECOND_TOKEN) + .await + .unwrap() + .expect("compat access token not found"); + + // Token is not valid anymore + assert!(!token.is_valid(clock.now())); + + repo.save().await.unwrap(); + } + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_refresh_token_repository(pool: PgPool) { + const ACCESS_TOKEN: &str = "access_token"; + const REFRESH_TOKEN: &str = "refresh_token"; + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + // Create a user + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + + // Start a compat session for that user + let device = Device::generate(&mut rng); + let session = repo + .compat_session() + .add(&mut rng, &clock, &user, device) + .await + .unwrap(); + + // Add an access token to that session + let access_token = repo + .compat_access_token() + .add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None) + .await + .unwrap(); + + let refresh_token = repo + .compat_refresh_token() + .add( + &mut rng, + &clock, + &session, + &access_token, + REFRESH_TOKEN.to_owned(), + ) + .await + .unwrap(); + assert_eq!(refresh_token.session_id, session.id); + assert_eq!(refresh_token.access_token_id, access_token.id); + assert_eq!(refresh_token.token, REFRESH_TOKEN); + assert!(refresh_token.is_valid()); + assert!(!refresh_token.is_consumed()); + + // Look it up by ID and check everything matches + let refresh_token_lookup = repo + .compat_refresh_token() + .lookup(refresh_token.id) + .await + .unwrap() + .expect("refresh token not found"); + assert_eq!(refresh_token_lookup.id, refresh_token.id); + assert_eq!(refresh_token_lookup.session_id, session.id); + assert_eq!(refresh_token_lookup.access_token_id, access_token.id); + assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN); + assert!(refresh_token_lookup.is_valid()); + assert!(!refresh_token_lookup.is_consumed()); + + // Look it up by token and check everything matches + let refresh_token_lookup = repo + .compat_refresh_token() + .find_by_token(REFRESH_TOKEN) + .await + .unwrap() + .expect("refresh token not found"); + assert_eq!(refresh_token_lookup.id, refresh_token.id); + assert_eq!(refresh_token_lookup.session_id, session.id); + assert_eq!(refresh_token_lookup.access_token_id, access_token.id); + assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN); + assert!(refresh_token_lookup.is_valid()); + assert!(!refresh_token_lookup.is_consumed()); + + // Consume it + let refresh_token = repo + .compat_refresh_token() + .consume(&clock, refresh_token) + .await + .unwrap(); + assert!(!refresh_token.is_valid()); + assert!(refresh_token.is_consumed()); + + // Reload it and check again + let refresh_token_lookup = repo + .compat_refresh_token() + .find_by_token(REFRESH_TOKEN) + .await + .unwrap() + .expect("refresh token not found"); + assert!(!refresh_token_lookup.is_valid()); + assert!(refresh_token_lookup.is_consumed()); + + // Consuming it again should not work + assert!(repo + .compat_refresh_token() + .consume(&clock, refresh_token) + .await + .is_err()); + + repo.save().await.unwrap(); + } +} diff --git a/crates/storage-pg/src/compat/refresh_token.rs b/crates/storage-pg/src/compat/refresh_token.rs new file mode 100644 index 000000000..314e81479 --- /dev/null +++ b/crates/storage-pg/src/compat/refresh_token.rs @@ -0,0 +1,230 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{ + CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession, +}; +use mas_storage::{compat::CompatRefreshTokenRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +pub struct PgCompatRefreshTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatRefreshTokenRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct CompatRefreshTokenLookup { + compat_refresh_token_id: Uuid, + refresh_token: String, + created_at: DateTime, + consumed_at: Option>, + compat_access_token_id: Uuid, + compat_session_id: Uuid, +} + +impl From for CompatRefreshToken { + fn from(value: CompatRefreshTokenLookup) -> Self { + let state = match value.consumed_at { + Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at }, + None => CompatRefreshTokenState::Valid, + }; + + Self { + id: value.compat_refresh_token_id.into(), + state, + session_id: value.compat_session_id.into(), + token: value.refresh_token, + created_at: value.created_at, + access_token_id: value.compat_access_token_id.into(), + } + } +} + +#[async_trait] +impl<'c> CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_refresh_token.lookup", + skip_all, + fields( + db.statement, + compat_refresh_token.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatRefreshTokenLookup, + r#" + SELECT compat_refresh_token_id + , refresh_token + , created_at + , consumed_at + , compat_session_id + , compat_access_token_id + + FROM compat_refresh_tokens + + WHERE compat_refresh_token_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_refresh_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatRefreshTokenLookup, + r#" + SELECT compat_refresh_token_id + , refresh_token + , created_at + , consumed_at + , compat_session_id + , compat_access_token_id + + FROM compat_refresh_tokens + + WHERE refresh_token = $1 + "#, + refresh_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_refresh_token.add", + skip_all, + fields( + db.statement, + compat_refresh_token.id, + %compat_session.id, + user.id = %compat_session.user_id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + compat_session: &CompatSession, + compat_access_token: &CompatAccessToken, + token: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO compat_refresh_tokens + (compat_refresh_token_id, compat_session_id, + compat_access_token_id, refresh_token, created_at) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(compat_session.id), + Uuid::from(compat_access_token.id), + token, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatRefreshToken { + id, + state: CompatRefreshTokenState::default(), + session_id: compat_session.id, + access_token_id: compat_access_token.id, + token, + created_at, + }) + } + + #[tracing::instrument( + name = "db.compat_refresh_token.consume", + skip_all, + fields( + db.statement, + %compat_refresh_token.id, + compat_session.id = %compat_refresh_token.session_id, + ), + err, + )] + async fn consume( + &mut self, + clock: &Clock, + compat_refresh_token: CompatRefreshToken, + ) -> Result { + let consumed_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE compat_refresh_tokens + SET consumed_at = $2 + WHERE compat_refresh_token_id = $1 + "#, + Uuid::from(compat_refresh_token.id), + consumed_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + let compat_refresh_token = compat_refresh_token + .consume(consumed_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(compat_refresh_token) + } +} diff --git a/crates/storage-pg/src/compat/session.rs b/crates/storage-pg/src/compat/session.rs new file mode 100644 index 000000000..a6e65f9a4 --- /dev/null +++ b/crates/storage-pg/src/compat/session.rs @@ -0,0 +1,195 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{CompatSession, CompatSessionState, Device, User}; +use mas_storage::{compat::CompatSessionRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; + +pub struct PgCompatSessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatSessionRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct CompatSessionLookup { + compat_session_id: Uuid, + device_id: String, + user_id: Uuid, + created_at: DateTime, + finished_at: Option>, +} + +impl TryFrom for CompatSession { + type Error = DatabaseInconsistencyError; + + fn try_from(value: CompatSessionLookup) -> Result { + let id = value.compat_session_id.into(); + let device = Device::try_from(value.device_id).map_err(|e| { + DatabaseInconsistencyError::on("compat_sessions") + .column("device_id") + .row(id) + .source(e) + })?; + + let state = match value.finished_at { + None => CompatSessionState::Valid, + Some(finished_at) => CompatSessionState::Finished { finished_at }, + }; + + let session = CompatSession { + id, + state, + user_id: value.user_id.into(), + device, + created_at: value.created_at, + }; + + Ok(session) + } +} + +#[async_trait] +impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_session.lookup", + skip_all, + fields( + db.statement, + compat_session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatSessionLookup, + r#" + SELECT compat_session_id + , device_id + , user_id + , created_at + , finished_at + FROM compat_sessions + WHERE compat_session_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.compat_session.add", + skip_all, + fields( + db.statement, + compat_session.id, + %user.id, + %user.username, + compat_session.device.id = device.as_str(), + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + device: Device, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_session.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at) + VALUES ($1, $2, $3, $4) + "#, + Uuid::from(id), + Uuid::from(user.id), + device.as_str(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatSession { + id, + state: CompatSessionState::default(), + user_id: user.id, + device, + created_at, + }) + } + + #[tracing::instrument( + name = "db.compat_session.finish", + skip_all, + fields( + db.statement, + %compat_session.id, + user.id = %compat_session.user_id, + compat_session.device.id = compat_session.device.as_str(), + ), + err, + )] + async fn finish( + &mut self, + clock: &Clock, + compat_session: CompatSession, + ) -> Result { + let finished_at = clock.now(); + + let res = sqlx::query!( + r#" + UPDATE compat_sessions cs + SET finished_at = $2 + WHERE compat_session_id = $1 + "#, + Uuid::from(compat_session.id), + finished_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + let compat_session = compat_session + .finish(finished_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(compat_session) + } +} diff --git a/crates/storage-pg/src/compat/sso_login.rs b/crates/storage-pg/src/compat/sso_login.rs new file mode 100644 index 000000000..a2eeb9263 --- /dev/null +++ b/crates/storage-pg/src/compat/sso_login.rs @@ -0,0 +1,342 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState, User}; +use mas_storage::{compat::CompatSsoLoginRepository, Clock, Page, Pagination}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use url::Url; +use uuid::Uuid; + +use crate::{ + pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, + LookupResultExt, +}; + +pub struct PgCompatSsoLoginRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatSsoLoginRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct CompatSsoLoginLookup { + compat_sso_login_id: Uuid, + login_token: String, + redirect_uri: String, + created_at: DateTime, + fulfilled_at: Option>, + exchanged_at: Option>, + compat_session_id: Option, +} + +impl TryFrom for CompatSsoLogin { + type Error = DatabaseInconsistencyError; + + fn try_from(res: CompatSsoLoginLookup) -> Result { + let id = res.compat_sso_login_id.into(); + let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| { + DatabaseInconsistencyError::on("compat_sso_logins") + .column("redirect_uri") + .row(id) + .source(e) + })?; + + let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) { + (None, None, None) => CompatSsoLoginState::Pending, + (Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled { + fulfilled_at, + session_id: session_id.into(), + }, + (Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => { + CompatSsoLoginState::Exchanged { + fulfilled_at, + exchanged_at, + session_id: session_id.into(), + } + } + _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)), + }; + + Ok(CompatSsoLogin { + id, + login_token: res.login_token, + redirect_uri, + created_at: res.created_at, + state, + }) + } +} + +#[async_trait] +impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_sso_login.lookup", + skip_all, + fields( + db.statement, + compat_sso_login.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatSsoLoginLookup, + r#" + SELECT compat_sso_login_id + , login_token + , redirect_uri + , created_at + , fulfilled_at + , exchanged_at + , compat_session_id + + FROM compat_sso_logins + WHERE compat_sso_login_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.compat_sso_login.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + login_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatSsoLoginLookup, + r#" + SELECT compat_sso_login_id + , login_token + , redirect_uri + , created_at + , fulfilled_at + , exchanged_at + , compat_session_id + + FROM compat_sso_logins + WHERE login_token = $1 + "#, + login_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.compat_sso_login.add", + skip_all, + fields( + db.statement, + compat_sso_login.id, + compat_sso_login.redirect_uri = %redirect_uri, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + login_token: String, + redirect_uri: Url, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO compat_sso_logins + (compat_sso_login_id, login_token, redirect_uri, created_at) + VALUES ($1, $2, $3, $4) + "#, + Uuid::from(id), + &login_token, + redirect_uri.as_str(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatSsoLogin { + id, + login_token, + redirect_uri, + created_at, + state: CompatSsoLoginState::default(), + }) + } + + #[tracing::instrument( + name = "db.compat_sso_login.fulfill", + skip_all, + fields( + db.statement, + %compat_sso_login.id, + %compat_session.id, + compat_session.device.id = compat_session.device.as_str(), + user.id = %compat_session.user_id, + ), + err, + )] + async fn fulfill( + &mut self, + clock: &Clock, + compat_sso_login: CompatSsoLogin, + compat_session: &CompatSession, + ) -> Result { + let fulfilled_at = clock.now(); + let compat_sso_login = compat_sso_login + .fulfill(fulfilled_at, compat_session) + .map_err(DatabaseError::to_invalid_operation)?; + + let res = sqlx::query!( + r#" + UPDATE compat_sso_logins + SET + compat_session_id = $2, + fulfilled_at = $3 + WHERE + compat_sso_login_id = $1 + "#, + Uuid::from(compat_sso_login.id), + Uuid::from(compat_session.id), + fulfilled_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(compat_sso_login) + } + + #[tracing::instrument( + name = "db.compat_sso_login.exchange", + skip_all, + fields( + db.statement, + %compat_sso_login.id, + ), + err, + )] + async fn exchange( + &mut self, + clock: &Clock, + compat_sso_login: CompatSsoLogin, + ) -> Result { + let exchanged_at = clock.now(); + let compat_sso_login = compat_sso_login + .exchange(exchanged_at) + .map_err(DatabaseError::to_invalid_operation)?; + + let res = sqlx::query!( + r#" + UPDATE compat_sso_logins + SET + exchanged_at = $2 + WHERE + compat_sso_login_id = $1 + "#, + Uuid::from(compat_sso_login.id), + exchanged_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(compat_sso_login) + } + + #[tracing::instrument( + name = "db.compat_sso_login.list_paginated", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + ), + err + )] + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT cl.compat_sso_login_id + , cl.login_token + , cl.redirect_uri + , cl.created_at + , cl.fulfilled_at + , cl.exchanged_at + , cl.compat_session_id + + FROM compat_sso_logins cl + INNER JOIN compat_sessions ON compat_session_id + "#, + ); + + query + .push(" WHERE user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("cl.compat_sso_login_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination + .process(edges) + .try_map(CompatSsoLogin::try_from)?; + Ok(page) + } +} diff --git a/crates/storage-pg/src/lib.rs b/crates/storage-pg/src/lib.rs new file mode 100644 index 000000000..459c8c3bf --- /dev/null +++ b/crates/storage-pg/src/lib.rs @@ -0,0 +1,170 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Interactions with the database + +#![forbid(unsafe_code)] +#![deny( + clippy::all, + clippy::str_to_string, + clippy::future_not_send, + rustdoc::broken_intra_doc_links +)] +#![warn(clippy::pedantic)] +#![allow( + clippy::missing_errors_doc, + clippy::missing_panics_doc, + clippy::module_name_repetitions +)] + +use sqlx::{migrate::Migrator, postgres::PgQueryResult}; +use thiserror::Error; +use ulid::Ulid; + +trait LookupResultExt { + type Output; + + /// Transform a [`Result`] from a sqlx query to transform "not found" errors + /// into [`None`] + fn to_option(self) -> Result, sqlx::Error>; +} + +impl LookupResultExt for Result { + type Output = T; + + fn to_option(self) -> Result, sqlx::Error> { + match self { + Ok(v) => Ok(Some(v)), + Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(e), + } + } +} + +/// Generic error when interacting with the database +#[derive(Debug, Error)] +#[error(transparent)] +pub enum DatabaseError { + /// An error which came from the database itself + Driver(#[from] sqlx::Error), + + /// An error which occured while converting the data from the database + Inconsistency(#[from] DatabaseInconsistencyError), + + /// An error which happened because the requested database operation is + /// invalid + #[error("Invalid database operation")] + InvalidOperation { + #[source] + source: Option>, + }, + + /// An error which happens when an operation affects not enough or too many + /// rows + #[error("Expected {expected} rows to be affected, but {actual} rows were affected")] + RowsAffected { expected: u64, actual: u64 }, +} + +impl DatabaseError { + pub(crate) fn ensure_affected_rows( + result: &PgQueryResult, + expected: u64, + ) -> Result<(), DatabaseError> { + let actual = result.rows_affected(); + if actual == expected { + Ok(()) + } else { + Err(DatabaseError::RowsAffected { expected, actual }) + } + } + + pub(crate) fn to_invalid_operation(e: E) -> Self { + Self::InvalidOperation { + source: Some(Box::new(e)), + } + } + + pub(crate) const fn invalid_operation() -> Self { + Self::InvalidOperation { source: None } + } +} + +#[derive(Debug, Error)] +pub struct DatabaseInconsistencyError { + table: &'static str, + column: Option<&'static str>, + row: Option, + + #[source] + source: Option>, +} + +impl std::fmt::Display for DatabaseInconsistencyError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Database inconsistency on table {}", self.table)?; + if let Some(column) = self.column { + write!(f, " column {column}")?; + } + if let Some(row) = self.row { + write!(f, " row {row}")?; + } + + Ok(()) + } +} + +impl DatabaseInconsistencyError { + #[must_use] + pub(crate) const fn on(table: &'static str) -> Self { + Self { + table, + column: None, + row: None, + source: None, + } + } + + #[must_use] + pub(crate) const fn column(mut self, column: &'static str) -> Self { + self.column = Some(column); + self + } + + #[must_use] + pub(crate) const fn row(mut self, row: Ulid) -> Self { + self.row = Some(row); + self + } + + pub(crate) fn source( + mut self, + source: E, + ) -> Self { + self.source = Some(Box::new(source)); + self + } +} + +pub mod compat; +pub mod oauth2; +pub(crate) mod pagination; +pub(crate) mod repository; +pub(crate) mod tracing; +pub mod upstream_oauth2; +pub mod user; + +pub use self::repository::PgRepository; + +/// Embedded migrations, allowing them to run on startup +pub static MIGRATOR: Migrator = sqlx::migrate!(); diff --git a/crates/storage-pg/src/oauth2/access_token.rs b/crates/storage-pg/src/oauth2/access_token.rs new file mode 100644 index 000000000..33d95242f --- /dev/null +++ b/crates/storage-pg/src/oauth2/access_token.rs @@ -0,0 +1,223 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Duration, Utc}; +use mas_data_model::{AccessToken, AccessTokenState, Session}; +use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +pub struct PgOAuth2AccessTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2AccessTokenRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct OAuth2AccessTokenLookup { + oauth2_access_token_id: Uuid, + oauth2_session_id: Uuid, + access_token: String, + created_at: DateTime, + expires_at: DateTime, + revoked_at: Option>, +} + +impl From for AccessToken { + fn from(value: OAuth2AccessTokenLookup) -> Self { + let state = match value.revoked_at { + None => AccessTokenState::Valid, + Some(revoked_at) => AccessTokenState::Revoked { revoked_at }, + }; + + Self { + id: value.oauth2_access_token_id.into(), + state, + session_id: value.oauth2_session_id.into(), + access_token: value.access_token, + created_at: value.created_at, + expires_at: value.expires_at, + } + } +} + +#[async_trait] +impl<'c> OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'c> { + type Error = DatabaseError; + + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2AccessTokenLookup, + r#" + SELECT oauth2_access_token_id + , access_token + , created_at + , expires_at + , revoked_at + , oauth2_session_id + + FROM oauth2_access_tokens + + WHERE oauth2_access_token_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.oauth2_access_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2AccessTokenLookup, + r#" + SELECT oauth2_access_token_id + , access_token + , created_at + , expires_at + , revoked_at + , oauth2_session_id + + FROM oauth2_access_tokens + + WHERE access_token = $1 + "#, + access_token, + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.oauth2_access_token.add", + skip_all, + fields( + db.statement, + %session.id, + user_session.id = %session.user_session_id, + client.id = %session.client_id, + access_token.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + session: &Session, + access_token: String, + expires_after: Duration, + ) -> Result { + let created_at = clock.now(); + let expires_at = created_at + expires_after; + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + + tracing::Span::current().record("access_token.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO oauth2_access_tokens + (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at) + VALUES + ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(session.id), + &access_token, + created_at, + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(AccessToken { + id, + state: AccessTokenState::default(), + access_token, + session_id: session.id, + created_at, + expires_at, + }) + } + + async fn revoke( + &mut self, + clock: &Clock, + access_token: AccessToken, + ) -> Result { + let revoked_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_access_tokens + SET revoked_at = $2 + WHERE oauth2_access_token_id = $1 + "#, + Uuid::from(access_token.id), + revoked_at, + ) + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + access_token + .revoke(revoked_at) + .map_err(DatabaseError::to_invalid_operation) + } + + async fn cleanup_expired(&mut self, clock: &Clock) -> Result { + // Cleanup token which expired more than 15 minutes ago + let threshold = clock.now() - Duration::minutes(15); + let res = sqlx::query!( + r#" + DELETE FROM oauth2_access_tokens + WHERE expires_at < $1 + "#, + threshold, + ) + .execute(&mut *self.conn) + .await?; + + Ok(res.rows_affected().try_into().unwrap_or(usize::MAX)) + } +} diff --git a/crates/storage-pg/src/oauth2/authorization_grant.rs b/crates/storage-pg/src/oauth2/authorization_grant.rs new file mode 100644 index 000000000..027111a76 --- /dev/null +++ b/crates/storage-pg/src/oauth2/authorization_grant.rs @@ -0,0 +1,510 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::num::NonZeroU32; + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{ + AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session, +}; +use mas_iana::oauth::PkceCodeChallengeMethod; +use mas_storage::{oauth2::OAuth2AuthorizationGrantRepository, Clock}; +use oauth2_types::{requests::ResponseMode, scope::Scope}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use url::Url; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; + +pub struct PgOAuth2AuthorizationGrantRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2AuthorizationGrantRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[allow(clippy::struct_excessive_bools)] +struct GrantLookup { + oauth2_authorization_grant_id: Uuid, + created_at: DateTime, + cancelled_at: Option>, + fulfilled_at: Option>, + exchanged_at: Option>, + scope: String, + state: Option, + nonce: Option, + redirect_uri: String, + response_mode: String, + max_age: Option, + response_type_code: bool, + response_type_id_token: bool, + authorization_code: Option, + code_challenge: Option, + code_challenge_method: Option, + requires_consent: bool, + oauth2_client_id: Uuid, + oauth2_session_id: Option, +} + +impl TryFrom for AuthorizationGrant { + type Error = DatabaseInconsistencyError; + + #[allow(clippy::too_many_lines)] + fn try_from(value: GrantLookup) -> Result { + let id = value.oauth2_authorization_grant_id.into(); + let scope: Scope = value.scope.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("scope") + .row(id) + .source(e) + })?; + + let stage = match ( + value.fulfilled_at, + value.exchanged_at, + value.cancelled_at, + value.oauth2_session_id, + ) { + (None, None, None, None) => AuthorizationGrantStage::Pending, + (Some(fulfilled_at), None, None, Some(session_id)) => { + AuthorizationGrantStage::Fulfilled { + session_id: session_id.into(), + fulfilled_at, + } + } + (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => { + AuthorizationGrantStage::Exchanged { + session_id: session_id.into(), + fulfilled_at, + exchanged_at, + } + } + (None, None, Some(cancelled_at), None) => { + AuthorizationGrantStage::Cancelled { cancelled_at } + } + _ => { + return Err( + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("stage") + .row(id), + ); + } + }; + + let pkce = match (value.code_challenge, value.code_challenge_method) { + (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => { + Some(Pkce { + challenge_method: PkceCodeChallengeMethod::Plain, + challenge, + }) + } + (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce { + challenge_method: PkceCodeChallengeMethod::S256, + challenge, + }), + (None, None) => None, + _ => { + return Err( + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("code_challenge_method") + .row(id), + ); + } + }; + + let code: Option = + match (value.response_type_code, value.authorization_code, pkce) { + (false, None, None) => None, + (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }), + _ => { + return Err( + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("authorization_code") + .row(id), + ); + } + }; + + let redirect_uri = value.redirect_uri.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("redirect_uri") + .row(id) + .source(e) + })?; + + let response_mode = value.response_mode.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("response_mode") + .row(id) + .source(e) + })?; + + let max_age = value + .max_age + .map(u32::try_from) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("max_age") + .row(id) + .source(e) + })? + .map(NonZeroU32::try_from) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("max_age") + .row(id) + .source(e) + })?; + + Ok(AuthorizationGrant { + id, + stage, + client_id: value.oauth2_client_id.into(), + code, + scope, + state: value.state, + nonce: value.nonce, + max_age, + response_mode, + redirect_uri, + created_at: value.created_at, + response_type_id_token: value.response_type_id_token, + requires_consent: value.requires_consent, + }) + } +} + +#[async_trait] +impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.add", + skip_all, + fields( + db.statement, + grant.id, + grant.scope = %scope, + %client.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + client: &Client, + redirect_uri: Url, + scope: Scope, + code: Option, + state: Option, + nonce: Option, + max_age: Option, + response_mode: ResponseMode, + response_type_id_token: bool, + requires_consent: bool, + ) -> Result { + let code_challenge = code + .as_ref() + .and_then(|c| c.pkce.as_ref()) + .map(|p| &p.challenge); + let code_challenge_method = code + .as_ref() + .and_then(|c| c.pkce.as_ref()) + .map(|p| p.challenge_method.to_string()); + // TODO: this conversion is a bit ugly + let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX)); + let code_str = code.as_ref().map(|c| &c.code); + + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("grant.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO oauth2_authorization_grants ( + oauth2_authorization_grant_id, + oauth2_client_id, + redirect_uri, + scope, + state, + nonce, + max_age, + response_mode, + code_challenge, + code_challenge_method, + response_type_code, + response_type_id_token, + authorization_code, + requires_consent, + created_at + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) + "#, + Uuid::from(id), + Uuid::from(client.id), + redirect_uri.to_string(), + scope.to_string(), + state, + nonce, + max_age_i32, + response_mode.to_string(), + code_challenge, + code_challenge_method, + code.is_some(), + response_type_id_token, + code_str, + requires_consent, + created_at, + ) + .execute(&mut *self.conn) + .await?; + + Ok(AuthorizationGrant { + id, + stage: AuthorizationGrantStage::Pending, + code, + redirect_uri, + client_id: client.id, + scope, + state, + nonce, + max_age, + response_mode, + created_at, + response_type_id_token, + requires_consent, + }) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.lookup", + skip_all, + fields( + db.statement, + grant.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + GrantLookup, + r#" + SELECT oauth2_authorization_grant_id + , created_at + , cancelled_at + , fulfilled_at + , exchanged_at + , scope + , state + , redirect_uri + , response_mode + , nonce + , max_age + , oauth2_client_id + , authorization_code + , response_type_code + , response_type_id_token + , code_challenge + , code_challenge_method + , requires_consent + , oauth2_session_id + FROM + oauth2_authorization_grants + + WHERE oauth2_authorization_grant_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.find_by_code", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_code( + &mut self, + code: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + GrantLookup, + r#" + SELECT oauth2_authorization_grant_id + , created_at + , cancelled_at + , fulfilled_at + , exchanged_at + , scope + , state + , redirect_uri + , response_mode + , nonce + , max_age + , oauth2_client_id + , authorization_code + , response_type_code + , response_type_id_token + , code_challenge + , code_challenge_method + , requires_consent + , oauth2_session_id + FROM + oauth2_authorization_grants + + WHERE authorization_code = $1 + "#, + code, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.fulfill", + skip_all, + fields( + db.statement, + %grant.id, + client.id = %grant.client_id, + %session.id, + user_session.id = %session.user_session_id, + ), + err, + )] + async fn fulfill( + &mut self, + clock: &Clock, + session: &Session, + grant: AuthorizationGrant, + ) -> Result { + let fulfilled_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_authorization_grants + SET fulfilled_at = $2 + , oauth2_session_id = $3 + WHERE oauth2_authorization_grant_id = $1 + "#, + Uuid::from(grant.id), + fulfilled_at, + Uuid::from(session.id), + ) + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + // XXX: check affected rows & new methods + let grant = grant + .fulfill(fulfilled_at, session) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(grant) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.exchange", + skip_all, + fields( + db.statement, + %grant.id, + client.id = %grant.client_id, + ), + err, + )] + async fn exchange( + &mut self, + clock: &Clock, + grant: AuthorizationGrant, + ) -> Result { + let exchanged_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_authorization_grants + SET exchanged_at = $2 + WHERE oauth2_authorization_grant_id = $1 + "#, + Uuid::from(grant.id), + exchanged_at, + ) + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + let grant = grant + .exchange(exchanged_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(grant) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.give_consent", + skip_all, + fields( + db.statement, + %grant.id, + client.id = %grant.client_id, + ), + err, + )] + async fn give_consent( + &mut self, + mut grant: AuthorizationGrant, + ) -> Result { + sqlx::query!( + r#" + UPDATE oauth2_authorization_grants AS og + SET + requires_consent = 'f' + WHERE + og.oauth2_authorization_grant_id = $1 + "#, + Uuid::from(grant.id), + ) + .execute(&mut *self.conn) + .await?; + + grant.requires_consent = false; + + Ok(grant) + } +} diff --git a/crates/storage-pg/src/oauth2/client.rs b/crates/storage-pg/src/oauth2/client.rs new file mode 100644 index 000000000..4430c669c --- /dev/null +++ b/crates/storage-pg/src/oauth2/client.rs @@ -0,0 +1,745 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + collections::{BTreeMap, BTreeSet}, + str::FromStr, + string::ToString, +}; + +use async_trait::async_trait; +use mas_data_model::{Client, JwksOrJwksUri, User}; +use mas_iana::{ + jose::JsonWebSignatureAlg, + oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod}, +}; +use mas_jose::jwk::PublicJsonWebKeySet; +use mas_storage::{oauth2::OAuth2ClientRepository, Clock}; +use oauth2_types::{ + requests::GrantType, + scope::{Scope, ScopeToken}, +}; +use rand::{Rng, RngCore}; +use sqlx::PgConnection; +use tracing::{info_span, Instrument}; +use ulid::Ulid; +use url::Url; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; + +pub struct PgOAuth2ClientRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2ClientRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +// XXX: response_types & contacts +#[derive(Debug)] +struct OAuth2ClientLookup { + oauth2_client_id: Uuid, + encrypted_client_secret: Option, + redirect_uris: Vec, + // response_types: Vec, + grant_type_authorization_code: bool, + grant_type_refresh_token: bool, + // contacts: Vec, + client_name: Option, + logo_uri: Option, + client_uri: Option, + policy_uri: Option, + tos_uri: Option, + jwks_uri: Option, + jwks: Option, + id_token_signed_response_alg: Option, + userinfo_signed_response_alg: Option, + token_endpoint_auth_method: Option, + token_endpoint_auth_signing_alg: Option, + initiate_login_uri: Option, +} + +impl TryInto for OAuth2ClientLookup { + type Error = DatabaseInconsistencyError; + + #[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing + fn try_into(self) -> Result { + let id = Ulid::from(self.oauth2_client_id); + + let redirect_uris: Result, _> = + self.redirect_uris.iter().map(|s| s.parse()).collect(); + let redirect_uris = redirect_uris.map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("redirect_uris") + .row(id) + .source(e) + })?; + + let response_types = vec![ + OAuthAuthorizationEndpointResponseType::Code, + OAuthAuthorizationEndpointResponseType::IdToken, + OAuthAuthorizationEndpointResponseType::None, + ]; + /* XXX + let response_types: Result, _> = + self.response_types.iter().map(|s| s.parse()).collect(); + let response_types = response_types.map_err(|source| ClientFetchError::ParseField { + field: "response_types", + source, + })?; + */ + + let mut grant_types = Vec::new(); + if self.grant_type_authorization_code { + grant_types.push(GrantType::AuthorizationCode); + } + if self.grant_type_refresh_token { + grant_types.push(GrantType::RefreshToken); + } + + let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("logo_uri") + .row(id) + .source(e) + })?; + + let client_uri = self + .client_uri + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("client_uri") + .row(id) + .source(e) + })?; + + let policy_uri = self + .policy_uri + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("policy_uri") + .row(id) + .source(e) + })?; + + let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("tos_uri") + .row(id) + .source(e) + })?; + + let id_token_signed_response_alg = self + .id_token_signed_response_alg + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("id_token_signed_response_alg") + .row(id) + .source(e) + })?; + + let userinfo_signed_response_alg = self + .userinfo_signed_response_alg + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("userinfo_signed_response_alg") + .row(id) + .source(e) + })?; + + let token_endpoint_auth_method = self + .token_endpoint_auth_method + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("token_endpoint_auth_method") + .row(id) + .source(e) + })?; + + let token_endpoint_auth_signing_alg = self + .token_endpoint_auth_signing_alg + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("token_endpoint_auth_signing_alg") + .row(id) + .source(e) + })?; + + let initiate_login_uri = self + .initiate_login_uri + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("initiate_login_uri") + .row(id) + .source(e) + })?; + + let jwks = match (self.jwks, self.jwks_uri) { + (None, None) => None, + (Some(jwks), None) => { + let jwks = serde_json::from_value(jwks).map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("jwks") + .row(id) + .source(e) + })?; + Some(JwksOrJwksUri::Jwks(jwks)) + } + (None, Some(jwks_uri)) => { + let jwks_uri = jwks_uri.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("jwks_uri") + .row(id) + .source(e) + })?; + + Some(JwksOrJwksUri::JwksUri(jwks_uri)) + } + _ => { + return Err(DatabaseInconsistencyError::on("oauth2_clients") + .column("jwks(_uri)") + .row(id)) + } + }; + + Ok(Client { + id, + client_id: id.to_string(), + encrypted_client_secret: self.encrypted_client_secret, + redirect_uris, + response_types, + grant_types, + // contacts: self.contacts, + contacts: vec![], + client_name: self.client_name, + logo_uri, + client_uri, + policy_uri, + tos_uri, + jwks, + id_token_signed_response_alg, + userinfo_signed_response_alg, + token_endpoint_auth_method, + token_endpoint_auth_signing_alg, + initiate_login_uri, + }) + } +} + +#[async_trait] +impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.oauth2_client.lookup", + skip_all, + fields( + db.statement, + oauth2_client.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2ClientLookup, + r#" + SELECT oauth2_client_id + , encrypted_client_secret + , ARRAY( + SELECT redirect_uri + FROM oauth2_client_redirect_uris r + WHERE r.oauth2_client_id = c.oauth2_client_id + ) AS "redirect_uris!" + , grant_type_authorization_code + , grant_type_refresh_token + , client_name + , logo_uri + , client_uri + , policy_uri + , tos_uri + , jwks_uri + , jwks + , id_token_signed_response_alg + , userinfo_signed_response_alg + , token_endpoint_auth_method + , token_endpoint_auth_signing_alg + , initiate_login_uri + FROM oauth2_clients c + + WHERE oauth2_client_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.oauth2_client.load_batch", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn load_batch( + &mut self, + ids: BTreeSet, + ) -> Result, Self::Error> { + let ids: Vec = ids.into_iter().map(Uuid::from).collect(); + let res = sqlx::query_as!( + OAuth2ClientLookup, + r#" + SELECT oauth2_client_id + , encrypted_client_secret + , ARRAY( + SELECT redirect_uri + FROM oauth2_client_redirect_uris r + WHERE r.oauth2_client_id = c.oauth2_client_id + ) AS "redirect_uris!" + , grant_type_authorization_code + , grant_type_refresh_token + , client_name + , logo_uri + , client_uri + , policy_uri + , tos_uri + , jwks_uri + , jwks + , id_token_signed_response_alg + , userinfo_signed_response_alg + , token_endpoint_auth_method + , token_endpoint_auth_signing_alg + , initiate_login_uri + FROM oauth2_clients c + + WHERE oauth2_client_id = ANY($1::uuid[]) + "#, + &ids, + ) + .traced() + .fetch_all(&mut *self.conn) + .await?; + + res.into_iter() + .map(|r| { + r.try_into() + .map(|c: Client| (c.id, c)) + .map_err(DatabaseError::from) + }) + .collect() + } + + #[tracing::instrument( + name = "db.oauth2_client.add", + skip_all, + fields( + db.statement, + client.id, + client.name = client_name + ), + err, + )] + #[allow(clippy::too_many_lines)] + async fn add( + &mut self, + mut rng: &mut (dyn RngCore + Send), + clock: &Clock, + redirect_uris: Vec, + encrypted_client_secret: Option, + grant_types: Vec, + contacts: Vec, + client_name: Option, + logo_uri: Option, + client_uri: Option, + policy_uri: Option, + tos_uri: Option, + jwks_uri: Option, + jwks: Option, + id_token_signed_response_alg: Option, + userinfo_signed_response_alg: Option, + token_endpoint_auth_method: Option, + token_endpoint_auth_signing_alg: Option, + initiate_login_uri: Option, + ) -> Result { + let now = clock.now(); + let id = Ulid::from_datetime_with_source(now.into(), rng); + tracing::Span::current().record("client.id", tracing::field::display(id)); + + let jwks_json = jwks + .as_ref() + .map(serde_json::to_value) + .transpose() + .map_err(DatabaseError::to_invalid_operation)?; + + sqlx::query!( + r#" + INSERT INTO oauth2_clients + ( oauth2_client_id + , encrypted_client_secret + , grant_type_authorization_code + , grant_type_refresh_token + , client_name + , logo_uri + , client_uri + , policy_uri + , tos_uri + , jwks_uri + , jwks + , id_token_signed_response_alg + , userinfo_signed_response_alg + , token_endpoint_auth_method + , token_endpoint_auth_signing_alg + , initiate_login_uri + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + "#, + Uuid::from(id), + encrypted_client_secret, + grant_types.contains(&GrantType::AuthorizationCode), + grant_types.contains(&GrantType::RefreshToken), + client_name, + logo_uri.as_ref().map(Url::as_str), + client_uri.as_ref().map(Url::as_str), + policy_uri.as_ref().map(Url::as_str), + tos_uri.as_ref().map(Url::as_str), + jwks_uri.as_ref().map(Url::as_str), + jwks_json, + id_token_signed_response_alg + .as_ref() + .map(ToString::to_string), + userinfo_signed_response_alg + .as_ref() + .map(ToString::to_string), + token_endpoint_auth_method.as_ref().map(ToString::to_string), + token_endpoint_auth_signing_alg + .as_ref() + .map(ToString::to_string), + initiate_login_uri.as_ref().map(Url::as_str), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + { + let span = info_span!( + "db.oauth2_client.add.redirect_uris", + db.statement = tracing::field::Empty, + client.id = %id, + ); + + let (uri_ids, redirect_uris): (Vec, Vec) = redirect_uris + .iter() + .map(|uri| { + ( + Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), + uri.as_str().to_owned(), + ) + }) + .unzip(); + + sqlx::query!( + r#" + INSERT INTO oauth2_client_redirect_uris + (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) + SELECT id, $2, redirect_uri + FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) + "#, + &uri_ids, + Uuid::from(id), + &redirect_uris, + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + let jwks = match (jwks, jwks_uri) { + (None, None) => None, + (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)), + (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)), + _ => return Err(DatabaseError::invalid_operation()), + }; + + Ok(Client { + id, + client_id: id.to_string(), + encrypted_client_secret, + redirect_uris, + response_types: vec![ + OAuthAuthorizationEndpointResponseType::Code, + OAuthAuthorizationEndpointResponseType::IdToken, + OAuthAuthorizationEndpointResponseType::None, + ], + grant_types, + contacts, + client_name, + logo_uri, + client_uri, + policy_uri, + tos_uri, + jwks, + id_token_signed_response_alg, + userinfo_signed_response_alg, + token_endpoint_auth_method, + token_endpoint_auth_signing_alg, + initiate_login_uri, + }) + } + + #[tracing::instrument( + name = "db.oauth2_client.add_from_config", + skip_all, + fields( + db.statement, + client.id = %client_id, + ), + err, + )] + async fn add_from_config( + &mut self, + mut rng: impl Rng + Send, + clock: &Clock, + client_id: Ulid, + client_auth_method: OAuthClientAuthenticationMethod, + encrypted_client_secret: Option, + jwks: Option, + jwks_uri: Option, + redirect_uris: Vec, + ) -> Result { + let jwks_json = jwks + .as_ref() + .map(serde_json::to_value) + .transpose() + .map_err(DatabaseError::to_invalid_operation)?; + + let client_auth_method = client_auth_method.to_string(); + + sqlx::query!( + r#" + INSERT INTO oauth2_clients + ( oauth2_client_id + , encrypted_client_secret + , grant_type_authorization_code + , grant_type_refresh_token + , token_endpoint_auth_method + , jwks + , jwks_uri + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (oauth2_client_id) + DO + UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret + , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code + , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token + , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method + , jwks = EXCLUDED.jwks + , jwks_uri = EXCLUDED.jwks_uri + "#, + Uuid::from(client_id), + encrypted_client_secret, + true, + true, + client_auth_method, + jwks_json, + jwks_uri.as_ref().map(Url::as_str), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + { + let span = info_span!( + "db.oauth2_client.add_from_config.redirect_uris", + client.id = %client_id, + db.statement = tracing::field::Empty, + ); + + let now = clock.now(); + let (ids, redirect_uris): (Vec, Vec) = redirect_uris + .iter() + .map(|uri| { + ( + Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), + uri.as_str().to_owned(), + ) + }) + .unzip(); + + sqlx::query!( + r#" + INSERT INTO oauth2_client_redirect_uris + (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) + SELECT id, $2, redirect_uri + FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) + "#, + &ids, + Uuid::from(client_id), + &redirect_uris, + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + let jwks = match (jwks, jwks_uri) { + (None, None) => None, + (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)), + (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)), + _ => return Err(DatabaseError::invalid_operation()), + }; + + Ok(Client { + id: client_id, + client_id: client_id.to_string(), + encrypted_client_secret, + redirect_uris, + response_types: vec![ + OAuthAuthorizationEndpointResponseType::Code, + OAuthAuthorizationEndpointResponseType::IdToken, + OAuthAuthorizationEndpointResponseType::None, + ], + grant_types: Vec::new(), + contacts: Vec::new(), + client_name: None, + logo_uri: None, + client_uri: None, + policy_uri: None, + tos_uri: None, + jwks, + id_token_signed_response_alg: None, + userinfo_signed_response_alg: None, + token_endpoint_auth_method: None, + token_endpoint_auth_signing_alg: None, + initiate_login_uri: None, + }) + } + + #[tracing::instrument( + name = "db.oauth2_client.get_consent_for_user", + skip_all, + fields( + db.statement, + %user.id, + %client.id, + ), + err, + )] + async fn get_consent_for_user( + &mut self, + client: &Client, + user: &User, + ) -> Result { + let scope_tokens: Vec = sqlx::query_scalar!( + r#" + SELECT scope_token + FROM oauth2_consents + WHERE user_id = $1 AND oauth2_client_id = $2 + "#, + Uuid::from(user.id), + Uuid::from(client.id), + ) + .fetch_all(&mut *self.conn) + .await?; + + let scope: Result = scope_tokens + .into_iter() + .map(|s| ScopeToken::from_str(&s)) + .collect(); + + let scope = scope.map_err(|e| { + DatabaseInconsistencyError::on("oauth2_consents") + .column("scope_token") + .source(e) + })?; + + Ok(scope) + } + + #[tracing::instrument( + skip_all, + fields( + db.statement, + %user.id, + %client.id, + %scope, + ), + err, + )] + async fn give_consent_for_user( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + client: &Client, + user: &User, + scope: &Scope, + ) -> Result<(), Self::Error> { + let now = clock.now(); + let (tokens, ids): (Vec, Vec) = scope + .iter() + .map(|token| { + ( + token.to_string(), + Uuid::from(Ulid::from_datetime_with_source(now.into(), rng)), + ) + }) + .unzip(); + + sqlx::query!( + r#" + INSERT INTO oauth2_consents + (oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at) + SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token) + ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5 + "#, + &ids, + Uuid::from(user.id), + Uuid::from(client.id), + &tokens, + now, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(()) + } +} diff --git a/crates/storage-pg/src/oauth2/mod.rs b/crates/storage-pg/src/oauth2/mod.rs new file mode 100644 index 000000000..edad2beb0 --- /dev/null +++ b/crates/storage-pg/src/oauth2/mod.rs @@ -0,0 +1,25 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod access_token; +pub mod authorization_grant; +mod client; +mod refresh_token; +mod session; + +pub use self::{ + access_token::PgOAuth2AccessTokenRepository, + authorization_grant::PgOAuth2AuthorizationGrantRepository, client::PgOAuth2ClientRepository, + refresh_token::PgOAuth2RefreshTokenRepository, session::PgOAuth2SessionRepository, +}; diff --git a/crates/storage-pg/src/oauth2/refresh_token.rs b/crates/storage-pg/src/oauth2/refresh_token.rs new file mode 100644 index 000000000..47281d934 --- /dev/null +++ b/crates/storage-pg/src/oauth2/refresh_token.rs @@ -0,0 +1,224 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session}; +use mas_storage::{oauth2::OAuth2RefreshTokenRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +pub struct PgOAuth2RefreshTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2RefreshTokenRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct OAuth2RefreshTokenLookup { + oauth2_refresh_token_id: Uuid, + refresh_token: String, + created_at: DateTime, + consumed_at: Option>, + oauth2_access_token_id: Option, + oauth2_session_id: Uuid, +} + +impl From for RefreshToken { + fn from(value: OAuth2RefreshTokenLookup) -> Self { + let state = match value.consumed_at { + None => RefreshTokenState::Valid, + Some(consumed_at) => RefreshTokenState::Consumed { consumed_at }, + }; + + RefreshToken { + id: value.oauth2_refresh_token_id.into(), + state, + session_id: value.oauth2_session_id.into(), + refresh_token: value.refresh_token, + created_at: value.created_at, + access_token_id: value.oauth2_access_token_id.map(Ulid::from), + } + } +} + +#[async_trait] +impl<'c> OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.oauth2_refresh_token.lookup", + skip_all, + fields( + db.statement, + refresh_token.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2RefreshTokenLookup, + r#" + SELECT oauth2_refresh_token_id + , refresh_token + , created_at + , consumed_at + , oauth2_access_token_id + , oauth2_session_id + FROM oauth2_refresh_tokens + + WHERE oauth2_refresh_token_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.oauth2_refresh_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2RefreshTokenLookup, + r#" + SELECT oauth2_refresh_token_id + , refresh_token + , created_at + , consumed_at + , oauth2_access_token_id + , oauth2_session_id + FROM oauth2_refresh_tokens + + WHERE refresh_token = $1 + "#, + refresh_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.oauth2_refresh_token.add", + skip_all, + fields( + db.statement, + %session.id, + user_session.id = %session.user_session_id, + client.id = %session.client_id, + refresh_token.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + session: &Session, + access_token: &AccessToken, + refresh_token: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("refresh_token.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO oauth2_refresh_tokens + (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id, + refresh_token, created_at) + VALUES + ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(session.id), + Uuid::from(access_token.id), + refresh_token, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(RefreshToken { + id, + state: RefreshTokenState::default(), + session_id: session.id, + refresh_token, + access_token_id: Some(access_token.id), + created_at, + }) + } + + #[tracing::instrument( + name = "db.oauth2_refresh_token.consume", + skip_all, + fields( + db.statement, + %refresh_token.id, + session.id = %refresh_token.session_id, + ), + err, + )] + async fn consume( + &mut self, + clock: &Clock, + refresh_token: RefreshToken, + ) -> Result { + let consumed_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_refresh_tokens + SET consumed_at = $2 + WHERE oauth2_refresh_token_id = $1 + "#, + Uuid::from(refresh_token.id), + consumed_at, + ) + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + refresh_token + .consume(consumed_at) + .map_err(DatabaseError::to_invalid_operation) + } +} diff --git a/crates/storage-pg/src/oauth2/session.rs b/crates/storage-pg/src/oauth2/session.rs new file mode 100644 index 000000000..96f798e65 --- /dev/null +++ b/crates/storage-pg/src/oauth2/session.rs @@ -0,0 +1,248 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{AuthorizationGrant, BrowserSession, Session, SessionState, User}; +use mas_storage::{oauth2::OAuth2SessionRepository, Clock, Page, Pagination}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, + LookupResultExt, +}; + +pub struct PgOAuth2SessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2SessionRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct OAuthSessionLookup { + oauth2_session_id: Uuid, + user_session_id: Uuid, + oauth2_client_id: Uuid, + scope: String, + #[allow(dead_code)] + created_at: DateTime, + finished_at: Option>, +} + +impl TryFrom for Session { + type Error = DatabaseInconsistencyError; + + fn try_from(value: OAuthSessionLookup) -> Result { + let id = Ulid::from(value.oauth2_session_id); + let scope = value.scope.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_sessions") + .column("scope") + .row(id) + .source(e) + })?; + + let state = match value.finished_at { + None => SessionState::Valid, + Some(finished_at) => SessionState::Finished { finished_at }, + }; + + Ok(Session { + id, + state, + created_at: value.created_at, + client_id: value.oauth2_client_id.into(), + user_session_id: value.user_session_id.into(), + scope, + }) + } +} + +#[async_trait] +impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.oauth2_session.lookup", + skip_all, + fields( + db.statement, + session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuthSessionLookup, + r#" + SELECT oauth2_session_id + , user_session_id + , oauth2_client_id + , scope + , created_at + , finished_at + FROM oauth2_sessions + + WHERE oauth2_session_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(session) = res else { return Ok(None) }; + + Ok(Some(session.try_into()?)) + } + + #[tracing::instrument( + name = "db.oauth2_session.create_from_grant", + skip_all, + fields( + db.statement, + %user_session.id, + user.id = %user_session.user.id, + %grant.id, + client.id = %grant.client_id, + session.id, + session.scope = %grant.scope, + ), + err, + )] + async fn create_from_grant( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + grant: &AuthorizationGrant, + user_session: &BrowserSession, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("session.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO oauth2_sessions + ( oauth2_session_id + , user_session_id + , oauth2_client_id + , scope + , created_at + ) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(user_session.id), + Uuid::from(grant.client_id), + grant.scope.to_string(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(Session { + id, + state: SessionState::Valid, + created_at, + user_session_id: user_session.id, + client_id: grant.client_id, + scope: grant.scope.clone(), + }) + } + + #[tracing::instrument( + name = "db.oauth2_session.finish", + skip_all, + fields( + db.statement, + %session.id, + %session.scope, + user_session.id = %session.user_session_id, + client.id = %session.client_id, + ), + err, + )] + async fn finish(&mut self, clock: &Clock, session: Session) -> Result { + let finished_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_sessions + SET finished_at = $2 + WHERE oauth2_session_id = $1 + "#, + Uuid::from(session.id), + finished_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + session + .finish(finished_at) + .map_err(DatabaseError::to_invalid_operation) + } + + #[tracing::instrument( + name = "db.oauth2_session.list_paginated", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + ), + err, + )] + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT oauth2_session_id + , user_session_id + , oauth2_client_id + , scope + , created_at + , finished_at + FROM oauth2_sessions os + "#, + ); + + query + .push(" WHERE us.user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("oauth2_session_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination.process(edges).try_map(Session::try_from)?; + Ok(page) + } +} diff --git a/crates/storage-pg/src/pagination.rs b/crates/storage-pg/src/pagination.rs new file mode 100644 index 000000000..97e5220f9 --- /dev/null +++ b/crates/storage-pg/src/pagination.rs @@ -0,0 +1,78 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Utilities to manage paginated queries. + +use mas_storage::{pagination::PaginationDirection, Pagination}; +use sqlx::{Database, QueryBuilder}; +use uuid::Uuid; + +/// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination +/// to a query +pub trait QueryBuilderExt { + /// Add cursor-based pagination to a query, as used in paginated GraphQL + /// connections + fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self; +} + +impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB> +where + DB: Database, + Uuid: sqlx::Type + sqlx::Encode<'a, DB>, + i64: sqlx::Type + sqlx::Encode<'a, DB>, +{ + fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self { + // ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564 + // 1. Start from the greedy query: SELECT * FROM table + + // 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE` + // clause + if let Some(after) = pagination.after { + self.push(" AND ") + .push(id_field) + .push(" > ") + .push_bind(Uuid::from(after)); + } + + // 3. If the before argument is provided, add `id < parsed_cursor` to the + // `WHERE` clause + if let Some(before) = pagination.before { + self.push(" AND ") + .push(id_field) + .push(" < ") + .push_bind(Uuid::from(before)); + } + + match pagination.direction { + // 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the + // query + PaginationDirection::Forward => { + self.push(" ORDER BY ") + .push(id_field) + .push(" ASC LIMIT ") + .push_bind((pagination.count + 1) as i64); + } + // 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the + // query + PaginationDirection::Backward => { + self.push(" ORDER BY ") + .push(id_field) + .push(" DESC LIMIT ") + .push_bind((pagination.count + 1) as i64); + } + }; + + self + } +} diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs new file mode 100644 index 000000000..288181a67 --- /dev/null +++ b/crates/storage-pg/src/repository.rs @@ -0,0 +1,142 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use mas_storage::Repository; +use sqlx::{PgPool, Postgres, Transaction}; + +use crate::{ + compat::{ + PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository, + PgCompatSsoLoginRepository, + }, + oauth2::{ + PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository, + PgOAuth2ClientRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository, + }, + upstream_oauth2::{ + PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, + PgUpstreamOAuthSessionRepository, + }, + user::{ + PgBrowserSessionRepository, PgUserEmailRepository, PgUserPasswordRepository, + PgUserRepository, + }, + DatabaseError, +}; + +pub struct PgRepository { + txn: Transaction<'static, Postgres>, +} + +impl PgRepository { + pub async fn from_pool(pool: &PgPool) -> Result { + let txn = pool.begin().await?; + Ok(PgRepository { txn }) + } + + pub async fn save(self) -> Result<(), DatabaseError> { + self.txn.commit().await?; + Ok(()) + } + + pub async fn cancel(self) -> Result<(), DatabaseError> { + self.txn.rollback().await?; + Ok(()) + } +} + +impl Repository for PgRepository { + type Error = DatabaseError; + + type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; + type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; + type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; + type UserRepository<'c> = PgUserRepository<'c> where Self: 'c; + type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c; + type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; + type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; + type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; + type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c; + type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; + type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c; + type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c; + type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c; + type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c; + type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c; + type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c; + + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { + PgUpstreamOAuthLinkRepository::new(&mut self.txn) + } + + fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { + PgUpstreamOAuthProviderRepository::new(&mut self.txn) + } + + fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> { + PgUpstreamOAuthSessionRepository::new(&mut self.txn) + } + + fn user(&mut self) -> Self::UserRepository<'_> { + PgUserRepository::new(&mut self.txn) + } + + fn user_email(&mut self) -> Self::UserEmailRepository<'_> { + PgUserEmailRepository::new(&mut self.txn) + } + + fn user_password(&mut self) -> Self::UserPasswordRepository<'_> { + PgUserPasswordRepository::new(&mut self.txn) + } + + fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> { + PgBrowserSessionRepository::new(&mut self.txn) + } + + fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { + PgOAuth2ClientRepository::new(&mut self.txn) + } + + fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> { + PgOAuth2AuthorizationGrantRepository::new(&mut self.txn) + } + + fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { + PgOAuth2SessionRepository::new(&mut self.txn) + } + + fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> { + PgOAuth2AccessTokenRepository::new(&mut self.txn) + } + + fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> { + PgOAuth2RefreshTokenRepository::new(&mut self.txn) + } + + fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { + PgCompatSessionRepository::new(&mut self.txn) + } + + fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> { + PgCompatSsoLoginRepository::new(&mut self.txn) + } + + fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> { + PgCompatAccessTokenRepository::new(&mut self.txn) + } + + fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> { + PgCompatRefreshTokenRepository::new(&mut self.txn) + } +} diff --git a/crates/storage/src/tracing.rs b/crates/storage-pg/src/tracing.rs similarity index 95% rename from crates/storage/src/tracing.rs rename to crates/storage-pg/src/tracing.rs index 08c62e465..1210816c5 100644 --- a/crates/storage/src/tracing.rs +++ b/crates/storage-pg/src/tracing.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/crates/storage-pg/src/upstream_oauth2/link.rs b/crates/storage-pg/src/upstream_oauth2/link.rs new file mode 100644 index 000000000..4087e2c78 --- /dev/null +++ b/crates/storage-pg/src/upstream_oauth2/link.rs @@ -0,0 +1,262 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User}; +use mas_storage::{upstream_oauth2::UpstreamOAuthLinkRepository, Clock, Page, Pagination}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +pub struct PgUpstreamOAuthLinkRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUpstreamOAuthLinkRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct LinkLookup { + upstream_oauth_link_id: Uuid, + upstream_oauth_provider_id: Uuid, + user_id: Option, + subject: String, + created_at: DateTime, +} + +impl From for UpstreamOAuthLink { + fn from(value: LinkLookup) -> Self { + UpstreamOAuthLink { + id: Ulid::from(value.upstream_oauth_link_id), + provider_id: Ulid::from(value.upstream_oauth_provider_id), + user_id: value.user_id.map(Ulid::from), + subject: value.subject, + created_at: value.created_at, + } + } +} + +#[async_trait] +impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.upstream_oauth_link.lookup", + skip_all, + fields( + db.statement, + upstream_oauth_link.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + LinkLookup, + r#" + SELECT + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + FROM upstream_oauth_links + WHERE upstream_oauth_link_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()? + .map(Into::into); + + Ok(res) + } + + #[tracing::instrument( + name = "db.upstream_oauth_link.find_by_subject", + skip_all, + fields( + db.statement, + upstream_oauth_link.subject = subject, + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, + ), + err, + )] + async fn find_by_subject( + &mut self, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + LinkLookup, + r#" + SELECT + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + FROM upstream_oauth_links + WHERE upstream_oauth_provider_id = $1 + AND subject = $2 + "#, + Uuid::from(upstream_oauth_provider.id), + subject, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()? + .map(Into::into); + + Ok(res) + } + + #[tracing::instrument( + name = "db.upstream_oauth_link.add", + skip_all, + fields( + db.statement, + upstream_oauth_link.id, + upstream_oauth_link.subject = subject, + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO upstream_oauth_links ( + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + ) VALUES ($1, $2, NULL, $3, $4) + "#, + Uuid::from(id), + Uuid::from(upstream_oauth_provider.id), + &subject, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(UpstreamOAuthLink { + id, + provider_id: upstream_oauth_provider.id, + user_id: None, + subject, + created_at, + }) + } + + #[tracing::instrument( + name = "db.upstream_oauth_link.associate_to_user", + skip_all, + fields( + db.statement, + %upstream_oauth_link.id, + %upstream_oauth_link.subject, + %user.id, + %user.username, + ), + err, + )] + async fn associate_to_user( + &mut self, + upstream_oauth_link: &UpstreamOAuthLink, + user: &User, + ) -> Result<(), Self::Error> { + sqlx::query!( + r#" + UPDATE upstream_oauth_links + SET user_id = $1 + WHERE upstream_oauth_link_id = $2 + "#, + Uuid::from(user.id), + Uuid::from(upstream_oauth_link.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(()) + } + + #[tracing::instrument( + name = "db.upstream_oauth_link.list_paginated", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + ), + err + )] + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + FROM upstream_oauth_links + "#, + ); + + query + .push(" WHERE user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("upstream_oauth_link_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination.process(edges).map(UpstreamOAuthLink::from); + Ok(page) + } +} diff --git a/crates/storage-pg/src/upstream_oauth2/mod.rs b/crates/storage-pg/src/upstream_oauth2/mod.rs new file mode 100644 index 000000000..e77daba2e --- /dev/null +++ b/crates/storage-pg/src/upstream_oauth2/mod.rs @@ -0,0 +1,271 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod link; +mod provider; +mod session; + +pub use self::{ + link::PgUpstreamOAuthLinkRepository, provider::PgUpstreamOAuthProviderRepository, + session::PgUpstreamOAuthSessionRepository, +}; + +#[cfg(test)] +mod tests { + use chrono::Duration; + use mas_storage::{ + upstream_oauth2::{ + UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, + UpstreamOAuthSessionRepository, + }, + user::UserRepository, + Clock, Pagination, Repository, + }; + use oauth2_types::scope::{Scope, OPENID}; + use rand::SeedableRng; + use sqlx::PgPool; + + use crate::PgRepository; + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_repository(pool: PgPool) { + let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + // The provider list should be empty at the start + let all_providers = repo.upstream_oauth_provider().all().await.unwrap(); + assert!(all_providers.is_empty()); + + // Let's add a provider + let provider = repo + .upstream_oauth_provider() + .add( + &mut rng, + &clock, + "https://example.com/".to_owned(), + Scope::from_iter([OPENID]), + mas_iana::oauth::OAuthClientAuthenticationMethod::None, + None, + "client-id".to_owned(), + None, + ) + .await + .unwrap(); + + // Look it up in the database + let provider = repo + .upstream_oauth_provider() + .lookup(provider.id) + .await + .unwrap() + .expect("provider to be found in the database"); + assert_eq!(provider.issuer, "https://example.com/"); + assert_eq!(provider.client_id, "client-id"); + + // Start a session + let session = repo + .upstream_oauth_session() + .add( + &mut rng, + &clock, + &provider, + "some-state".to_owned(), + None, + "some-nonce".to_owned(), + ) + .await + .unwrap(); + + // Look it up in the database + let session = repo + .upstream_oauth_session() + .lookup(session.id) + .await + .unwrap() + .expect("session to be found in the database"); + assert_eq!(session.provider_id, provider.id); + assert_eq!(session.link_id(), None); + assert!(session.is_pending()); + assert!(!session.is_completed()); + assert!(!session.is_consumed()); + + // Create a link + let link = repo + .upstream_oauth_link() + .add(&mut rng, &clock, &provider, "a-subject".to_owned()) + .await + .unwrap(); + + // We can look it up by its ID + repo.upstream_oauth_link() + .lookup(link.id) + .await + .unwrap() + .expect("link to be found in database"); + + // or by its subject + let link = repo + .upstream_oauth_link() + .find_by_subject(&provider, "a-subject") + .await + .unwrap() + .expect("link to be found in database"); + assert_eq!(link.subject, "a-subject"); + assert_eq!(link.provider_id, provider.id); + + let session = repo + .upstream_oauth_session() + .complete_with_link(&clock, session, &link, None) + .await + .unwrap(); + // Reload the session + let session = repo + .upstream_oauth_session() + .lookup(session.id) + .await + .unwrap() + .expect("session to be found in the database"); + assert!(session.is_completed()); + assert!(!session.is_consumed()); + assert_eq!(session.link_id(), Some(link.id)); + + let session = repo + .upstream_oauth_session() + .consume(&clock, session) + .await + .unwrap(); + // Reload the session + let session = repo + .upstream_oauth_session() + .lookup(session.id) + .await + .unwrap() + .expect("session to be found in the database"); + assert!(session.is_consumed()); + + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + repo.upstream_oauth_link() + .associate_to_user(&link, &user) + .await + .unwrap(); + + let links = repo + .upstream_oauth_link() + .list_paginated(&user, Pagination::first(10)) + .await + .unwrap(); + assert!(!links.has_previous_page); + assert!(!links.has_next_page); + assert_eq!(links.edges.len(), 1); + assert_eq!(links.edges[0].id, link.id); + assert_eq!(links.edges[0].user_id, Some(user.id)); + } + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_provider_repository_pagination(pool: PgPool) { + const ISSUER: &str = "https://example.com/"; + let scope = Scope::from_iter([OPENID]); + + let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + let mut ids = Vec::with_capacity(20); + // Create 20 providers + for idx in 0..20 { + let client_id = format!("client-{idx}"); + let provider = repo + .upstream_oauth_provider() + .add( + &mut rng, + &clock, + ISSUER.to_owned(), + scope.clone(), + mas_iana::oauth::OAuthClientAuthenticationMethod::None, + None, + client_id, + None, + ) + .await + .unwrap(); + ids.push(provider.id); + clock.advance(Duration::seconds(10)); + } + + // Lookup the first 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(Pagination::first(10)) + .await + .unwrap(); + + // It returned the first 10 items + assert!(page.has_next_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[..10]); + + // Lookup the next 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(Pagination::first(10).after(ids[9])) + .await + .unwrap(); + + // It returned the next 10 items + assert!(!page.has_next_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[10..]); + + // Lookup the last 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(Pagination::last(10)) + .await + .unwrap(); + + // It returned the last 10 items + assert!(page.has_previous_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[10..]); + + // Lookup the previous 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(Pagination::last(10).before(ids[10])) + .await + .unwrap(); + + // It returned the previous 10 items + assert!(!page.has_previous_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[..10]); + + // Lookup 10 items between two IDs + let page = repo + .upstream_oauth_provider() + .list_paginated(Pagination::first(10).after(ids[5]).before(ids[8])) + .await + .unwrap(); + + // It returned the items in between + assert!(!page.has_next_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[6..8]); + } +} diff --git a/crates/storage-pg/src/upstream_oauth2/provider.rs b/crates/storage-pg/src/upstream_oauth2/provider.rs new file mode 100644 index 000000000..480249eee --- /dev/null +++ b/crates/storage-pg/src/upstream_oauth2/provider.rs @@ -0,0 +1,273 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::UpstreamOAuthProvider; +use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; +use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Clock, Page, Pagination}; +use oauth2_types::scope::Scope; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, + LookupResultExt, +}; + +pub struct PgUpstreamOAuthProviderRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUpstreamOAuthProviderRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct ProviderLookup { + upstream_oauth_provider_id: Uuid, + issuer: String, + scope: String, + client_id: String, + encrypted_client_secret: Option, + token_endpoint_signing_alg: Option, + token_endpoint_auth_method: String, + created_at: DateTime, +} + +impl TryFrom for UpstreamOAuthProvider { + type Error = DatabaseInconsistencyError; + fn try_from(value: ProviderLookup) -> Result { + let id = value.upstream_oauth_provider_id.into(); + let scope = value.scope.parse().map_err(|e| { + DatabaseInconsistencyError::on("upstream_oauth_providers") + .column("scope") + .row(id) + .source(e) + })?; + let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| { + DatabaseInconsistencyError::on("upstream_oauth_providers") + .column("token_endpoint_auth_method") + .row(id) + .source(e) + })?; + let token_endpoint_signing_alg = value + .token_endpoint_signing_alg + .map(|x| x.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("upstream_oauth_providers") + .column("token_endpoint_signing_alg") + .row(id) + .source(e) + })?; + + Ok(UpstreamOAuthProvider { + id, + issuer: value.issuer, + scope, + client_id: value.client_id, + encrypted_client_secret: value.encrypted_client_secret, + token_endpoint_auth_method, + token_endpoint_signing_alg, + created_at: value.created_at, + }) + } +} + +#[async_trait] +impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.upstream_oauth_provider.lookup", + skip_all, + fields( + db.statement, + upstream_oauth_provider.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + ProviderLookup, + r#" + SELECT + upstream_oauth_provider_id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at + FROM upstream_oauth_providers + WHERE upstream_oauth_provider_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let res = res + .map(UpstreamOAuthProvider::try_from) + .transpose() + .map_err(DatabaseError::from)?; + + Ok(res) + } + + #[tracing::instrument( + name = "db.upstream_oauth_provider.add", + skip_all, + fields( + db.statement, + upstream_oauth_provider.id, + upstream_oauth_provider.issuer = %issuer, + upstream_oauth_provider.client_id = %client_id, + ), + err, + )] + #[allow(clippy::too_many_arguments)] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + issuer: String, + scope: Scope, + token_endpoint_auth_method: OAuthClientAuthenticationMethod, + token_endpoint_signing_alg: Option, + client_id: String, + encrypted_client_secret: Option, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO upstream_oauth_providers ( + upstream_oauth_provider_id, + issuer, + scope, + token_endpoint_auth_method, + token_endpoint_signing_alg, + client_id, + encrypted_client_secret, + created_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + "#, + Uuid::from(id), + &issuer, + scope.to_string(), + token_endpoint_auth_method.to_string(), + token_endpoint_signing_alg.as_ref().map(ToString::to_string), + &client_id, + encrypted_client_secret.as_deref(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(UpstreamOAuthProvider { + id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at, + }) + } + + #[tracing::instrument( + name = "db.upstream_oauth_provider.list_paginated", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn list_paginated( + &mut self, + pagination: Pagination, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT + upstream_oauth_provider_id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at + FROM upstream_oauth_providers + WHERE 1 = 1 + "#, + ); + + query.generate_pagination("upstream_oauth_provider_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination.process(edges).try_map(TryInto::try_into)?; + Ok(page) + } + + #[tracing::instrument( + name = "db.upstream_oauth_provider.all", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn all(&mut self) -> Result, Self::Error> { + let res = sqlx::query_as!( + ProviderLookup, + r#" + SELECT + upstream_oauth_provider_id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at + FROM upstream_oauth_providers + "#, + ) + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let res: Result, _> = res.into_iter().map(TryInto::try_into).collect(); + Ok(res?) + } +} diff --git a/crates/storage-pg/src/upstream_oauth2/session.rs b/crates/storage-pg/src/upstream_oauth2/session.rs new file mode 100644 index 000000000..699a463f0 --- /dev/null +++ b/crates/storage-pg/src/upstream_oauth2/session.rs @@ -0,0 +1,286 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{ + UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, + UpstreamOAuthProvider, +}; +use mas_storage::{upstream_oauth2::UpstreamOAuthSessionRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; + +pub struct PgUpstreamOAuthSessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUpstreamOAuthSessionRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct SessionLookup { + upstream_oauth_authorization_session_id: Uuid, + upstream_oauth_provider_id: Uuid, + upstream_oauth_link_id: Option, + state: String, + code_challenge_verifier: Option, + nonce: String, + id_token: Option, + created_at: DateTime, + completed_at: Option>, + consumed_at: Option>, +} + +impl TryFrom for UpstreamOAuthAuthorizationSession { + type Error = DatabaseInconsistencyError; + + fn try_from(value: SessionLookup) -> Result { + let id = value.upstream_oauth_authorization_session_id.into(); + let state = match ( + value.upstream_oauth_link_id, + value.id_token, + value.completed_at, + value.consumed_at, + ) { + (None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending, + (Some(link_id), id_token, Some(completed_at), None) => { + UpstreamOAuthAuthorizationSessionState::Completed { + completed_at, + link_id: link_id.into(), + id_token, + } + } + (Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => { + UpstreamOAuthAuthorizationSessionState::Consumed { + completed_at, + link_id: link_id.into(), + id_token, + consumed_at, + } + } + _ => { + return Err( + DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id), + ) + } + }; + + Ok(Self { + id, + provider_id: value.upstream_oauth_provider_id.into(), + state_str: value.state, + nonce: value.nonce, + code_challenge_verifier: value.code_challenge_verifier, + created_at: value.created_at, + state, + }) + } +} + +#[async_trait] +impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.lookup", + skip_all, + fields( + db.statement, + upstream_oauth_provider.id = %id, + ), + err, + )] + async fn lookup( + &mut self, + id: Ulid, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + SessionLookup, + r#" + SELECT + upstream_oauth_authorization_session_id, + upstream_oauth_provider_id, + upstream_oauth_link_id, + state, + code_challenge_verifier, + nonce, + id_token, + created_at, + completed_at, + consumed_at + FROM upstream_oauth_authorization_sessions + WHERE upstream_oauth_authorization_session_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.add", + skip_all, + fields( + db.statement, + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, + upstream_oauth_authorization_session.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + state_str: String, + code_challenge_verifier: Option, + nonce: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record( + "upstream_oauth_authorization_session.id", + tracing::field::display(id), + ); + + sqlx::query!( + r#" + INSERT INTO upstream_oauth_authorization_sessions ( + upstream_oauth_authorization_session_id, + upstream_oauth_provider_id, + state, + code_challenge_verifier, + nonce, + created_at, + completed_at, + consumed_at, + id_token + ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL) + "#, + Uuid::from(id), + Uuid::from(upstream_oauth_provider.id), + &state_str, + code_challenge_verifier.as_deref(), + nonce, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(UpstreamOAuthAuthorizationSession { + id, + state: UpstreamOAuthAuthorizationSessionState::default(), + provider_id: upstream_oauth_provider.id, + state_str, + code_challenge_verifier, + nonce, + created_at, + }) + } + + #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.complete_with_link", + skip_all, + fields( + db.statement, + %upstream_oauth_authorization_session.id, + %upstream_oauth_link.id, + ), + err, + )] + async fn complete_with_link( + &mut self, + clock: &Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + upstream_oauth_link: &UpstreamOAuthLink, + id_token: Option, + ) -> Result { + let completed_at = clock.now(); + + sqlx::query!( + r#" + UPDATE upstream_oauth_authorization_sessions + SET upstream_oauth_link_id = $1, + completed_at = $2, + id_token = $3 + WHERE upstream_oauth_authorization_session_id = $4 + "#, + Uuid::from(upstream_oauth_link.id), + completed_at, + id_token, + Uuid::from(upstream_oauth_authorization_session.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + let upstream_oauth_authorization_session = upstream_oauth_authorization_session + .complete(completed_at, upstream_oauth_link, id_token) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(upstream_oauth_authorization_session) + } + + /// Mark a session as consumed + #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.consume", + skip_all, + fields( + db.statement, + %upstream_oauth_authorization_session.id, + ), + err, + )] + async fn consume( + &mut self, + clock: &Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + ) -> Result { + let consumed_at = clock.now(); + sqlx::query!( + r#" + UPDATE upstream_oauth_authorization_sessions + SET consumed_at = $1 + WHERE upstream_oauth_authorization_session_id = $2 + "#, + consumed_at, + Uuid::from(upstream_oauth_authorization_session.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + let upstream_oauth_authorization_session = upstream_oauth_authorization_session + .consume(consumed_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(upstream_oauth_authorization_session) + } +} diff --git a/crates/storage-pg/src/user/email.rs b/crates/storage-pg/src/user/email.rs new file mode 100644 index 000000000..936b06491 --- /dev/null +++ b/crates/storage-pg/src/user/email.rs @@ -0,0 +1,554 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{User, UserEmail, UserEmailVerification, UserEmailVerificationState}; +use mas_storage::{user::UserEmailRepository, Clock, Page, Pagination}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use tracing::{info_span, Instrument}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, + LookupResultExt, +}; + +pub struct PgUserEmailRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUserEmailRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(Debug, Clone, sqlx::FromRow)] +struct UserEmailLookup { + user_email_id: Uuid, + user_id: Uuid, + email: String, + created_at: DateTime, + confirmed_at: Option>, +} + +impl From for UserEmail { + fn from(e: UserEmailLookup) -> UserEmail { + UserEmail { + id: e.user_email_id.into(), + user_id: e.user_id.into(), + email: e.email, + created_at: e.created_at, + confirmed_at: e.confirmed_at, + } + } +} + +struct UserEmailConfirmationCodeLookup { + user_email_confirmation_code_id: Uuid, + user_email_id: Uuid, + code: String, + created_at: DateTime, + expires_at: DateTime, + consumed_at: Option>, +} + +impl UserEmailConfirmationCodeLookup { + fn into_verification(self, clock: &Clock) -> UserEmailVerification { + let now = clock.now(); + let state = if let Some(when) = self.consumed_at { + UserEmailVerificationState::AlreadyUsed { when } + } else if self.expires_at < now { + UserEmailVerificationState::Expired { + when: self.expires_at, + } + } else { + UserEmailVerificationState::Valid + }; + + UserEmailVerification { + id: self.user_email_confirmation_code_id.into(), + user_email_id: self.user_email_id.into(), + code: self.code, + state, + created_at: self.created_at, + } + } +} + +#[async_trait] +impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.user_email.lookup", + skip_all, + fields( + db.statement, + user_email.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserEmailLookup, + r#" + SELECT user_email_id + , user_id + , email + , created_at + , confirmed_at + FROM user_emails + + WHERE user_email_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(user_email) = res else { return Ok(None) }; + + Ok(Some(user_email.into())) + } + + #[tracing::instrument( + name = "db.user_email.find", + skip_all, + fields( + db.statement, + %user.id, + user_email.email = email, + ), + err, + )] + async fn find(&mut self, user: &User, email: &str) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserEmailLookup, + r#" + SELECT user_email_id + , user_id + , email + , created_at + , confirmed_at + FROM user_emails + + WHERE user_id = $1 AND email = $2 + "#, + Uuid::from(user.id), + email, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(user_email) = res else { return Ok(None) }; + + Ok(Some(user_email.into())) + } + + #[tracing::instrument( + name = "db.user_email.get_primary", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn get_primary(&mut self, user: &User) -> Result, Self::Error> { + let Some(id) = user.primary_user_email_id else { return Ok(None) }; + + let user_email = self.lookup(id).await?.ok_or_else(|| { + DatabaseInconsistencyError::on("users") + .column("primary_user_email_id") + .row(user.id) + })?; + + Ok(Some(user_email)) + } + + #[tracing::instrument( + name = "db.user_email.all", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn all(&mut self, user: &User) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserEmailLookup, + r#" + SELECT user_email_id + , user_id + , email + , created_at + , confirmed_at + FROM user_emails + + WHERE user_id = $1 + + ORDER BY email ASC + "#, + Uuid::from(user.id), + ) + .traced() + .fetch_all(&mut *self.conn) + .await?; + + Ok(res.into_iter().map(Into::into).collect()) + } + + #[tracing::instrument( + name = "db.user_email.list_paginated", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, DatabaseError> { + let mut query = QueryBuilder::new( + r#" + SELECT user_email_id + , user_id + , email + , created_at + , confirmed_at + FROM user_emails + "#, + ); + + query + .push(" WHERE user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("ue.user_email_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination.process(edges).map(UserEmail::from); + Ok(page) + } + + #[tracing::instrument( + name = "db.user_email.count", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn count(&mut self, user: &User) -> Result { + let res = sqlx::query_scalar!( + r#" + SELECT COUNT(*) + FROM user_emails + WHERE user_id = $1 + "#, + Uuid::from(user.id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + let res = res.unwrap_or_default(); + + Ok(res + .try_into() + .map_err(DatabaseError::to_invalid_operation)?) + } + + #[tracing::instrument( + name = "db.user_email.add", + skip_all, + fields( + db.statement, + %user.id, + user_email.id, + user_email.email = email, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + email: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_email.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO user_emails (user_email_id, user_id, email, created_at) + VALUES ($1, $2, $3, $4) + "#, + Uuid::from(id), + Uuid::from(user.id), + &email, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(UserEmail { + id, + user_id: user.id, + email, + created_at, + confirmed_at: None, + }) + } + + #[tracing::instrument( + name = "db.user_email.remove", + skip_all, + fields( + db.statement, + user.id = %user_email.user_id, + %user_email.id, + %user_email.email, + ), + err, + )] + async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> { + let span = info_span!( + "db.user_email.remove.codes", + db.statement = tracing::field::Empty + ); + sqlx::query!( + r#" + DELETE FROM user_email_confirmation_codes + WHERE user_email_id = $1 + "#, + Uuid::from(user_email.id), + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + + let res = sqlx::query!( + r#" + DELETE FROM user_emails + WHERE user_email_id = $1 + "#, + Uuid::from(user_email.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(()) + } + + async fn mark_as_verified( + &mut self, + clock: &Clock, + mut user_email: UserEmail, + ) -> Result { + let confirmed_at = clock.now(); + sqlx::query!( + r#" + UPDATE user_emails + SET confirmed_at = $2 + WHERE user_email_id = $1 + "#, + Uuid::from(user_email.id), + confirmed_at, + ) + .execute(&mut *self.conn) + .await?; + + user_email.confirmed_at = Some(confirmed_at); + Ok(user_email) + } + + async fn set_as_primary(&mut self, user_email: &UserEmail) -> Result<(), Self::Error> { + sqlx::query!( + r#" + UPDATE users + SET primary_user_email_id = user_emails.user_email_id + FROM user_emails + WHERE user_emails.user_email_id = $1 + AND users.user_id = user_emails.user_id + "#, + Uuid::from(user_email.id), + ) + .execute(&mut *self.conn) + .await?; + + Ok(()) + } + + #[tracing::instrument( + name = "db.user_email.add_verification_code", + skip_all, + fields( + db.statement, + %user_email.id, + %user_email.email, + user_email_verification.id, + user_email_verification.code = code, + ), + err, + )] + async fn add_verification_code( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user_email: &UserEmail, + max_age: chrono::Duration, + code: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id)); + let expires_at = created_at + max_age; + + sqlx::query!( + r#" + INSERT INTO user_email_confirmation_codes + (user_email_confirmation_code_id, user_email_id, code, created_at, expires_at) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(user_email.id), + code, + created_at, + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + let verification = UserEmailVerification { + id, + user_email_id: user_email.id, + code, + created_at, + state: UserEmailVerificationState::Valid, + }; + + Ok(verification) + } + + #[tracing::instrument( + name = "db.user_email.find_verification_code", + skip_all, + fields( + db.statement, + %user_email.id, + user.id = %user_email.user_id, + ), + err, + )] + async fn find_verification_code( + &mut self, + clock: &Clock, + user_email: &UserEmail, + code: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserEmailConfirmationCodeLookup, + r#" + SELECT user_email_confirmation_code_id + , user_email_id + , code + , created_at + , expires_at + , consumed_at + FROM user_email_confirmation_codes + WHERE code = $1 + AND user_email_id = $2 + "#, + code, + Uuid::from(user_email.id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into_verification(clock))) + } + + #[tracing::instrument( + name = "db.user_email.consume_verification_code", + skip_all, + fields( + db.statement, + %user_email_verification.id, + user_email.id = %user_email_verification.user_email_id, + ), + err, + )] + async fn consume_verification_code( + &mut self, + clock: &Clock, + mut user_email_verification: UserEmailVerification, + ) -> Result { + if !matches!( + user_email_verification.state, + UserEmailVerificationState::Valid + ) { + return Err(DatabaseError::invalid_operation()); + } + + let consumed_at = clock.now(); + + sqlx::query!( + r#" + UPDATE user_email_confirmation_codes + SET consumed_at = $2 + WHERE user_email_confirmation_code_id = $1 + "#, + Uuid::from(user_email_verification.id), + consumed_at + ) + .traced() + .execute(&mut *self.conn) + .await?; + + user_email_verification.state = + UserEmailVerificationState::AlreadyUsed { when: consumed_at }; + + Ok(user_email_verification) + } +} diff --git a/crates/storage-pg/src/user/mod.rs b/crates/storage-pg/src/user/mod.rs new file mode 100644 index 000000000..d73202613 --- /dev/null +++ b/crates/storage-pg/src/user/mod.rs @@ -0,0 +1,203 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::User; +use mas_storage::{user::UserRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +mod email; +mod password; +mod session; + +#[cfg(test)] +mod tests; + +pub use self::{ + email::PgUserEmailRepository, password::PgUserPasswordRepository, + session::PgBrowserSessionRepository, +}; + +pub struct PgUserRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUserRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(Debug, Clone)] +struct UserLookup { + user_id: Uuid, + username: String, + primary_user_email_id: Option, + + #[allow(dead_code)] + created_at: DateTime, +} + +impl From for User { + fn from(value: UserLookup) -> Self { + let id = value.user_id.into(); + Self { + id, + username: value.username, + sub: id.to_string(), + primary_user_email_id: value.primary_user_email_id.map(Into::into), + } + } +} + +#[async_trait] +impl<'c> UserRepository for PgUserRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.user.lookup", + skip_all, + fields( + db.statement, + user.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserLookup, + r#" + SELECT user_id + , username + , primary_user_email_id + , created_at + FROM users + WHERE user_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.user.find_by_username", + skip_all, + fields( + db.statement, + user.username = username, + ), + err, + )] + async fn find_by_username(&mut self, username: &str) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserLookup, + r#" + SELECT user_id + , username + , primary_user_email_id + , created_at + FROM users + WHERE username = $1 + "#, + username, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.user.add", + skip_all, + fields( + db.statement, + user.username = username, + user.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + username: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO users (user_id, username, created_at) + VALUES ($1, $2, $3) + "#, + Uuid::from(id), + username, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(User { + id, + username, + sub: id.to_string(), + primary_user_email_id: None, + }) + } + + #[tracing::instrument( + name = "db.user.exists", + skip_all, + fields( + db.statement, + user.username = username, + ), + err, + )] + async fn exists(&mut self, username: &str) -> Result { + let exists = sqlx::query_scalar!( + r#" + SELECT EXISTS( + SELECT 1 FROM users WHERE username = $1 + ) AS "exists!" + "#, + username + ) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + Ok(exists) + } +} diff --git a/crates/storage-pg/src/user/password.rs b/crates/storage-pg/src/user/password.rs new file mode 100644 index 000000000..997b12272 --- /dev/null +++ b/crates/storage-pg/src/user/password.rs @@ -0,0 +1,155 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{Password, User}; +use mas_storage::{user::UserPasswordRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; + +pub struct PgUserPasswordRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUserPasswordRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct UserPasswordLookup { + user_password_id: Uuid, + hashed_password: String, + version: i32, + upgraded_from_id: Option, + created_at: DateTime, +} + +#[async_trait] +impl<'c> UserPasswordRepository for PgUserPasswordRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.user_password.active", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + ), + err, + )] + async fn active(&mut self, user: &User) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserPasswordLookup, + r#" + SELECT up.user_password_id + , up.hashed_password + , up.version + , up.upgraded_from_id + , up.created_at + FROM user_passwords up + WHERE up.user_id = $1 + ORDER BY up.created_at DESC + LIMIT 1 + "#, + Uuid::from(user.id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + let id = Ulid::from(res.user_password_id); + + let version = res.version.try_into().map_err(|e| { + DatabaseInconsistencyError::on("user_passwords") + .column("version") + .row(id) + .source(e) + })?; + + let upgraded_from_id = res.upgraded_from_id.map(Ulid::from); + let created_at = res.created_at; + let hashed_password = res.hashed_password; + + Ok(Some(Password { + id, + hashed_password, + version, + upgraded_from_id, + created_at, + })) + } + + #[tracing::instrument( + name = "db.user_password.add", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + user_password.id, + user_password.version = version, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + version: u16, + hashed_password: String, + upgraded_from: Option<&Password>, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_password.id", tracing::field::display(id)); + + let upgraded_from_id = upgraded_from.map(|p| p.id); + + sqlx::query!( + r#" + INSERT INTO user_passwords + (user_password_id, user_id, hashed_password, version, upgraded_from_id, created_at) + VALUES ($1, $2, $3, $4, $5, $6) + "#, + Uuid::from(id), + Uuid::from(user.id), + hashed_password, + i32::from(version), + upgraded_from_id.map(Uuid::from), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(Password { + id, + hashed_password, + version, + upgraded_from_id, + created_at, + }) + } +} diff --git a/crates/storage-pg/src/user/session.rs b/crates/storage-pg/src/user/session.rs new file mode 100644 index 000000000..d216c0679 --- /dev/null +++ b/crates/storage-pg/src/user/session.rs @@ -0,0 +1,375 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User}; +use mas_storage::{user::BrowserSessionRepository, Clock, Page, Pagination}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, + LookupResultExt, +}; + +pub struct PgBrowserSessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgBrowserSessionRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct SessionLookup { + user_session_id: Uuid, + user_session_created_at: DateTime, + user_session_finished_at: Option>, + user_id: Uuid, + user_username: String, + user_primary_user_email_id: Option, + last_authentication_id: Option, + last_authd_at: Option>, +} + +impl TryFrom for BrowserSession { + type Error = DatabaseInconsistencyError; + + fn try_from(value: SessionLookup) -> Result { + let id = Ulid::from(value.user_id); + let user = User { + id, + username: value.user_username, + sub: id.to_string(), + primary_user_email_id: value.user_primary_user_email_id.map(Into::into), + }; + + let last_authentication = match (value.last_authentication_id, value.last_authd_at) { + (Some(id), Some(created_at)) => Some(Authentication { + id: id.into(), + created_at, + }), + (None, None) => None, + _ => { + return Err(DatabaseInconsistencyError::on( + "user_session_authentications", + )) + } + }; + + Ok(BrowserSession { + id: value.user_session_id.into(), + user, + created_at: value.user_session_created_at, + finished_at: value.user_session_finished_at, + last_authentication, + }) + } +} + +#[async_trait] +impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.browser_session.lookup", + skip_all, + fields( + db.statement, + user_session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + SessionLookup, + r#" + SELECT s.user_session_id + , s.created_at AS "user_session_created_at" + , s.finished_at AS "user_session_finished_at" + , u.user_id + , u.username AS "user_username" + , u.primary_user_email_id AS "user_primary_user_email_id" + , a.user_session_authentication_id AS "last_authentication_id?" + , a.created_at AS "last_authd_at?" + FROM user_sessions s + INNER JOIN users u + USING (user_id) + LEFT JOIN user_session_authentications a + USING (user_session_id) + WHERE s.user_session_id = $1 + ORDER BY a.created_at DESC + LIMIT 1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.browser_session.add", + skip_all, + fields( + db.statement, + %user.id, + user_session.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_session.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO user_sessions (user_session_id, user_id, created_at) + VALUES ($1, $2, $3) + "#, + Uuid::from(id), + Uuid::from(user.id), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + let session = BrowserSession { + id, + // XXX + user: user.clone(), + created_at, + finished_at: None, + last_authentication: None, + }; + + Ok(session) + } + + #[tracing::instrument( + name = "db.browser_session.finish", + skip_all, + fields( + db.statement, + %user_session.id, + ), + err, + )] + async fn finish( + &mut self, + clock: &Clock, + mut user_session: BrowserSession, + ) -> Result { + let finished_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE user_sessions + SET finished_at = $1 + WHERE user_session_id = $2 + "#, + finished_at, + Uuid::from(user_session.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + user_session.finished_at = Some(finished_at); + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(user_session) + } + + #[tracing::instrument( + name = "db.browser_session.list_active_paginated", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn list_active_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error> { + // TODO: ordering of last authentication is wrong + let mut query = QueryBuilder::new( + r#" + SELECT DISTINCT ON (s.user_session_id) + s.user_session_id, + u.user_id, + u.username, + s.created_at, + a.user_session_authentication_id AS "last_authentication_id", + a.created_at AS "last_authd_at", + FROM user_sessions s + INNER JOIN users u + USING (user_id) + LEFT JOIN user_session_authentications a + USING (user_session_id) + "#, + ); + + query + .push(" WHERE s.finished_at IS NULL AND s.user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("s.user_session_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination + .process(edges) + .try_map(BrowserSession::try_from)?; + Ok(page) + } + + #[tracing::instrument( + name = "db.browser_session.count_active", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn count_active(&mut self, user: &User) -> Result { + let res = sqlx::query_scalar!( + r#" + SELECT COUNT(*) as "count!" + FROM user_sessions s + WHERE s.user_id = $1 AND s.finished_at IS NULL + "#, + Uuid::from(user.id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + res.try_into().map_err(DatabaseError::to_invalid_operation) + } + + #[tracing::instrument( + name = "db.browser_session.authenticate_with_password", + skip_all, + fields( + db.statement, + %user_session.id, + %user_password.id, + user_session_authentication.id, + ), + err, + )] + async fn authenticate_with_password( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + mut user_session: BrowserSession, + user_password: &Password, + ) -> Result { + let _user_password = user_password; + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record( + "user_session_authentication.id", + tracing::field::display(id), + ); + + sqlx::query!( + r#" + INSERT INTO user_session_authentications + (user_session_authentication_id, user_session_id, created_at) + VALUES ($1, $2, $3) + "#, + Uuid::from(id), + Uuid::from(user_session.id), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + user_session.last_authentication = Some(Authentication { id, created_at }); + + Ok(user_session) + } + + #[tracing::instrument( + name = "db.browser_session.authenticate_with_upstream", + skip_all, + fields( + db.statement, + %user_session.id, + %upstream_oauth_link.id, + user_session_authentication.id, + ), + err, + )] + async fn authenticate_with_upstream( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + mut user_session: BrowserSession, + upstream_oauth_link: &UpstreamOAuthLink, + ) -> Result { + let _upstream_oauth_link = upstream_oauth_link; + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record( + "user_session_authentication.id", + tracing::field::display(id), + ); + + sqlx::query!( + r#" + INSERT INTO user_session_authentications + (user_session_authentication_id, user_session_id, created_at) + VALUES ($1, $2, $3) + "#, + Uuid::from(id), + Uuid::from(user_session.id), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + user_session.last_authentication = Some(Authentication { id, created_at }); + + Ok(user_session) + } +} diff --git a/crates/storage/src/user/tests.rs b/crates/storage-pg/src/user/tests.rs similarity index 98% rename from crates/storage/src/user/tests.rs rename to crates/storage-pg/src/user/tests.rs index fca35ce05..f0a071b01 100644 --- a/crates/storage/src/user/tests.rs +++ b/crates/storage-pg/src/user/tests.rs @@ -13,14 +13,15 @@ // limitations under the License. use chrono::Duration; +use mas_storage::{ + user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, + Clock, Repository, +}; use rand::SeedableRng; use rand_chacha::ChaChaRng; use sqlx::PgPool; -use crate::{ - user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - Clock, PgRepository, Repository, -}; +use crate::PgRepository; /// Test the user repository, by adding and looking up a user #[sqlx::test(migrator = "crate::MIGRATOR")] @@ -88,7 +89,7 @@ async fn test_user_email_repo(pool: PgPool) { // The user email should not exist yet assert!(repo .user_email() - .find(&user, &EMAIL) + .find(&user, EMAIL) .await .unwrap() .is_none()); @@ -109,7 +110,7 @@ async fn test_user_email_repo(pool: PgPool) { assert!(repo .user_email() - .find(&user, &EMAIL) + .find(&user, EMAIL) .await .unwrap() .is_some()); @@ -179,7 +180,7 @@ async fn test_user_email_repo(pool: PgPool) { // Reload the user_email let user_email = repo .user_email() - .find(&user, &EMAIL) + .find(&user, EMAIL) .await .unwrap() .expect("user email was not found"); diff --git a/crates/storage/Cargo.toml b/crates/storage/Cargo.toml index fb6c0fdce..97089e956 100644 --- a/crates/storage/Cargo.toml +++ b/crates/storage/Cargo.toml @@ -7,18 +7,12 @@ license = "Apache-2.0" [dependencies] async-trait = "0.1.60" -sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline", "json", "uuid"] } -chrono = { version = "0.4.23", features = ["serde"] } -serde = { version = "1.0.152", features = ["derive"] } -serde_json = "1.0.91" +chrono = "0.4.23" thiserror = "1.0.38" -tracing = "0.1.37" rand = "0.8.5" -rand_chacha = "0.3.1" -url = { version = "2.3.1", features = ["serde"] } -uuid = "1.2.2" -ulid = { version = "1.0.0", features = ["uuid", "serde"] } +url = "2.3.1" +ulid = "1.0.0" oauth2-types = { path = "../oauth2-types" } mas-data-model = { path = "../data-model" } diff --git a/crates/storage/src/compat/access_token.rs b/crates/storage/src/compat/access_token.rs index 86d2dd198..46ff6e3f8 100644 --- a/crates/storage/src/compat/access_token.rs +++ b/crates/storage/src/compat/access_token.rs @@ -13,14 +13,12 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Duration, Utc}; +use chrono::Duration; use mas_data_model::{CompatAccessToken, CompatSession}; use rand::RngCore; -use sqlx::PgConnection; use ulid::Ulid; -use uuid::Uuid; -use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; +use crate::Clock; #[async_trait] pub trait CompatAccessTokenRepository: Send + Sync { @@ -52,195 +50,3 @@ pub trait CompatAccessTokenRepository: Send + Sync { compat_access_token: CompatAccessToken, ) -> Result; } - -pub struct PgCompatAccessTokenRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgCompatAccessTokenRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -struct CompatAccessTokenLookup { - compat_access_token_id: Uuid, - access_token: String, - created_at: DateTime, - expires_at: Option>, - compat_session_id: Uuid, -} - -impl From for CompatAccessToken { - fn from(value: CompatAccessTokenLookup) -> Self { - Self { - id: value.compat_access_token_id.into(), - session_id: value.compat_session_id.into(), - token: value.access_token, - created_at: value.created_at, - expires_at: value.expires_at, - } - } -} - -#[async_trait] -impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.compat_access_token.lookup", - skip_all, - fields( - db.statement, - compat_session.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - CompatAccessTokenLookup, - r#" - SELECT compat_access_token_id - , access_token - , created_at - , expires_at - , compat_session_id - - FROM compat_access_tokens - - WHERE compat_access_token_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.compat_access_token.find_by_token", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn find_by_token( - &mut self, - access_token: &str, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - CompatAccessTokenLookup, - r#" - SELECT compat_access_token_id - , access_token - , created_at - , expires_at - , compat_session_id - - FROM compat_access_tokens - - WHERE access_token = $1 - "#, - access_token, - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.compat_access_token.add", - skip_all, - fields( - db.statement, - compat_access_token.id, - %compat_session.id, - user.id = %compat_session.user_id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - compat_session: &CompatSession, - token: String, - expires_after: Option, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("compat_access_token.id", tracing::field::display(id)); - - let expires_at = expires_after.map(|expires_after| created_at + expires_after); - - sqlx::query!( - r#" - INSERT INTO compat_access_tokens - (compat_access_token_id, compat_session_id, access_token, created_at, expires_at) - VALUES ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(compat_session.id), - token, - created_at, - expires_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(CompatAccessToken { - id, - session_id: compat_session.id, - token, - created_at, - expires_at, - }) - } - - #[tracing::instrument( - name = "db.compat_access_token.expire", - skip_all, - fields( - db.statement, - %compat_access_token.id, - compat_session.id = %compat_access_token.session_id, - ), - err, - )] - async fn expire( - &mut self, - clock: &Clock, - mut compat_access_token: CompatAccessToken, - ) -> Result { - let expires_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE compat_access_tokens - SET expires_at = $2 - WHERE compat_access_token_id = $1 - "#, - Uuid::from(compat_access_token.id), - expires_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - compat_access_token.expires_at = Some(expires_at); - Ok(compat_access_token) - } -} diff --git a/crates/storage/src/compat/mod.rs b/crates/storage/src/compat/mod.rs index c37081b8c..634c04a7f 100644 --- a/crates/storage/src/compat/mod.rs +++ b/crates/storage/src/compat/mod.rs @@ -18,301 +18,6 @@ mod session; mod sso_login; pub use self::{ - access_token::{CompatAccessTokenRepository, PgCompatAccessTokenRepository}, - refresh_token::{CompatRefreshTokenRepository, PgCompatRefreshTokenRepository}, - session::{CompatSessionRepository, PgCompatSessionRepository}, - sso_login::{CompatSsoLoginRepository, PgCompatSsoLoginRepository}, + access_token::CompatAccessTokenRepository, refresh_token::CompatRefreshTokenRepository, + session::CompatSessionRepository, sso_login::CompatSsoLoginRepository, }; - -#[cfg(test)] -mod tests { - use chrono::Duration; - use mas_data_model::Device; - use rand::SeedableRng; - use rand_chacha::ChaChaRng; - use sqlx::PgPool; - - use super::*; - use crate::{user::UserRepository, Clock, PgRepository, Repository}; - - #[sqlx::test(migrator = "crate::MIGRATOR")] - async fn test_session_repository(pool: PgPool) { - const FIRST_TOKEN: &str = "first_access_token"; - const SECOND_TOKEN: &str = "second_access_token"; - let mut rng = ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); - - // Create a user - let user = repo - .user() - .add(&mut rng, &clock, "john".to_owned()) - .await - .unwrap(); - - // Start a compat session for that user - let device = Device::generate(&mut rng); - let device_str = device.as_str().to_owned(); - let session = repo - .compat_session() - .add(&mut rng, &clock, &user, device) - .await - .unwrap(); - assert_eq!(session.user_id, user.id); - assert_eq!(session.device.as_str(), device_str); - assert!(session.is_valid()); - assert!(!session.is_finished()); - - // Lookup the session and check it didn't change - let session_lookup = repo - .compat_session() - .lookup(session.id) - .await - .unwrap() - .expect("compat session not found"); - assert_eq!(session_lookup.id, session.id); - assert_eq!(session_lookup.user_id, user.id); - assert_eq!(session_lookup.device.as_str(), device_str); - assert!(session_lookup.is_valid()); - assert!(!session_lookup.is_finished()); - - // Finish the session - let session = repo.compat_session().finish(&clock, session).await.unwrap(); - assert!(!session.is_valid()); - assert!(session.is_finished()); - - // Reload the session and check again - let session_lookup = repo - .compat_session() - .lookup(session.id) - .await - .unwrap() - .expect("compat session not found"); - assert!(!session_lookup.is_valid()); - assert!(session_lookup.is_finished()); - } - - #[sqlx::test(migrator = "crate::MIGRATOR")] - async fn test_access_token_repository(pool: PgPool) { - const FIRST_TOKEN: &str = "first_access_token"; - const SECOND_TOKEN: &str = "second_access_token"; - let mut rng = ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); - - // Create a user - let user = repo - .user() - .add(&mut rng, &clock, "john".to_owned()) - .await - .unwrap(); - - // Start a compat session for that user - let device = Device::generate(&mut rng); - let session = repo - .compat_session() - .add(&mut rng, &clock, &user, device) - .await - .unwrap(); - - // Add an access token to that session - let token = repo - .compat_access_token() - .add( - &mut rng, - &clock, - &session, - FIRST_TOKEN.to_owned(), - Some(Duration::minutes(1)), - ) - .await - .unwrap(); - assert_eq!(token.session_id, session.id); - assert_eq!(token.token, FIRST_TOKEN); - - // Commit the txn and grab a new transaction, to test a conflict - repo.save().await.unwrap(); - - { - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); - // Adding the same token a second time should conflict - assert!(repo - .compat_access_token() - .add( - &mut rng, - &clock, - &session, - FIRST_TOKEN.to_owned(), - Some(Duration::minutes(1)), - ) - .await - .is_err()); - repo.cancel().await.unwrap(); - } - - // Grab a new repo - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); - - // Looking up via ID works - let token_lookup = repo - .compat_access_token() - .lookup(token.id) - .await - .unwrap() - .expect("compat access token not found"); - assert_eq!(token.id, token_lookup.id); - assert_eq!(token_lookup.session_id, session.id); - - // Looking up via the token value works - let token_lookup = repo - .compat_access_token() - .find_by_token(FIRST_TOKEN) - .await - .unwrap() - .expect("compat access token not found"); - assert_eq!(token.id, token_lookup.id); - assert_eq!(token_lookup.session_id, session.id); - - // Token is currently valid - assert!(token.is_valid(clock.now())); - - clock.advance(Duration::minutes(1)); - // Token should have expired - assert!(!token.is_valid(clock.now())); - - // Add a second access token, this time without expiration - let token = repo - .compat_access_token() - .add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None) - .await - .unwrap(); - assert_eq!(token.session_id, session.id); - assert_eq!(token.token, SECOND_TOKEN); - - // Token is currently valid - assert!(token.is_valid(clock.now())); - - // Make it expire - repo.compat_access_token() - .expire(&clock, token) - .await - .unwrap(); - - // Reload it - let token = repo - .compat_access_token() - .find_by_token(SECOND_TOKEN) - .await - .unwrap() - .expect("compat access token not found"); - - // Token is not valid anymore - assert!(!token.is_valid(clock.now())); - - repo.save().await.unwrap(); - } - - #[sqlx::test(migrator = "crate::MIGRATOR")] - async fn test_refresh_token_repository(pool: PgPool) { - const ACCESS_TOKEN: &str = "access_token"; - const REFRESH_TOKEN: &str = "refresh_token"; - let mut rng = ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); - - // Create a user - let user = repo - .user() - .add(&mut rng, &clock, "john".to_owned()) - .await - .unwrap(); - - // Start a compat session for that user - let device = Device::generate(&mut rng); - let session = repo - .compat_session() - .add(&mut rng, &clock, &user, device) - .await - .unwrap(); - - // Add an access token to that session - let access_token = repo - .compat_access_token() - .add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None) - .await - .unwrap(); - - let refresh_token = repo - .compat_refresh_token() - .add( - &mut rng, - &clock, - &session, - &access_token, - REFRESH_TOKEN.to_owned(), - ) - .await - .unwrap(); - assert_eq!(refresh_token.session_id, session.id); - assert_eq!(refresh_token.access_token_id, access_token.id); - assert_eq!(refresh_token.token, REFRESH_TOKEN); - assert!(refresh_token.is_valid()); - assert!(!refresh_token.is_consumed()); - - // Look it up by ID and check everything matches - let refresh_token_lookup = repo - .compat_refresh_token() - .lookup(refresh_token.id) - .await - .unwrap() - .expect("refresh token not found"); - assert_eq!(refresh_token_lookup.id, refresh_token.id); - assert_eq!(refresh_token_lookup.session_id, session.id); - assert_eq!(refresh_token_lookup.access_token_id, access_token.id); - assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN); - assert!(refresh_token_lookup.is_valid()); - assert!(!refresh_token_lookup.is_consumed()); - - // Look it up by token and check everything matches - let refresh_token_lookup = repo - .compat_refresh_token() - .find_by_token(REFRESH_TOKEN) - .await - .unwrap() - .expect("refresh token not found"); - assert_eq!(refresh_token_lookup.id, refresh_token.id); - assert_eq!(refresh_token_lookup.session_id, session.id); - assert_eq!(refresh_token_lookup.access_token_id, access_token.id); - assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN); - assert!(refresh_token_lookup.is_valid()); - assert!(!refresh_token_lookup.is_consumed()); - - // Consume it - let refresh_token = repo - .compat_refresh_token() - .consume(&clock, refresh_token) - .await - .unwrap(); - assert!(!refresh_token.is_valid()); - assert!(refresh_token.is_consumed()); - - // Reload it and check again - let refresh_token_lookup = repo - .compat_refresh_token() - .find_by_token(REFRESH_TOKEN) - .await - .unwrap() - .expect("refresh token not found"); - assert!(!refresh_token_lookup.is_valid()); - assert!(refresh_token_lookup.is_consumed()); - - // Consuming it again should not work - assert!(repo - .compat_refresh_token() - .consume(&clock, refresh_token) - .await - .is_err()); - - repo.save().await.unwrap(); - } -} diff --git a/crates/storage/src/compat/refresh_token.rs b/crates/storage/src/compat/refresh_token.rs index 300546226..7a1057ff8 100644 --- a/crates/storage/src/compat/refresh_token.rs +++ b/crates/storage/src/compat/refresh_token.rs @@ -13,16 +13,11 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{ - CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession, -}; +use mas_data_model::{CompatAccessToken, CompatRefreshToken, CompatSession}; use rand::RngCore; -use sqlx::PgConnection; use ulid::Ulid; -use uuid::Uuid; -use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; +use crate::Clock; #[async_trait] pub trait CompatRefreshTokenRepository: Send + Sync { @@ -54,207 +49,3 @@ pub trait CompatRefreshTokenRepository: Send + Sync { compat_refresh_token: CompatRefreshToken, ) -> Result; } - -pub struct PgCompatRefreshTokenRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgCompatRefreshTokenRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -struct CompatRefreshTokenLookup { - compat_refresh_token_id: Uuid, - refresh_token: String, - created_at: DateTime, - consumed_at: Option>, - compat_access_token_id: Uuid, - compat_session_id: Uuid, -} - -impl From for CompatRefreshToken { - fn from(value: CompatRefreshTokenLookup) -> Self { - let state = match value.consumed_at { - Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at }, - None => CompatRefreshTokenState::Valid, - }; - - Self { - id: value.compat_refresh_token_id.into(), - state, - session_id: value.compat_session_id.into(), - token: value.refresh_token, - created_at: value.created_at, - access_token_id: value.compat_access_token_id.into(), - } - } -} - -#[async_trait] -impl<'c> CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.compat_refresh_token.lookup", - skip_all, - fields( - db.statement, - compat_refresh_token.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - CompatRefreshTokenLookup, - r#" - SELECT compat_refresh_token_id - , refresh_token - , created_at - , consumed_at - , compat_session_id - , compat_access_token_id - - FROM compat_refresh_tokens - - WHERE compat_refresh_token_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.compat_refresh_token.find_by_token", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn find_by_token( - &mut self, - refresh_token: &str, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - CompatRefreshTokenLookup, - r#" - SELECT compat_refresh_token_id - , refresh_token - , created_at - , consumed_at - , compat_session_id - , compat_access_token_id - - FROM compat_refresh_tokens - - WHERE refresh_token = $1 - "#, - refresh_token, - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.compat_refresh_token.add", - skip_all, - fields( - db.statement, - compat_refresh_token.id, - %compat_session.id, - user.id = %compat_session.user_id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - compat_session: &CompatSession, - compat_access_token: &CompatAccessToken, - token: String, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO compat_refresh_tokens - (compat_refresh_token_id, compat_session_id, - compat_access_token_id, refresh_token, created_at) - VALUES ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(compat_session.id), - Uuid::from(compat_access_token.id), - token, - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(CompatRefreshToken { - id, - state: CompatRefreshTokenState::default(), - session_id: compat_session.id, - access_token_id: compat_access_token.id, - token, - created_at, - }) - } - - #[tracing::instrument( - name = "db.compat_refresh_token.consume", - skip_all, - fields( - db.statement, - %compat_refresh_token.id, - compat_session.id = %compat_refresh_token.session_id, - ), - err, - )] - async fn consume( - &mut self, - clock: &Clock, - compat_refresh_token: CompatRefreshToken, - ) -> Result { - let consumed_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE compat_refresh_tokens - SET consumed_at = $2 - WHERE compat_refresh_token_id = $1 - "#, - Uuid::from(compat_refresh_token.id), - consumed_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - let compat_refresh_token = compat_refresh_token - .consume(consumed_at) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(compat_refresh_token) - } -} diff --git a/crates/storage/src/compat/session.rs b/crates/storage/src/compat/session.rs index 3068be731..34bc68381 100644 --- a/crates/storage/src/compat/session.rs +++ b/crates/storage/src/compat/session.rs @@ -13,16 +13,11 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{CompatSession, CompatSessionState, Device, User}; +use mas_data_model::{CompatSession, Device, User}; use rand::RngCore; -use sqlx::PgConnection; use ulid::Ulid; -use uuid::Uuid; -use crate::{ - tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, -}; +use crate::Clock; #[async_trait] pub trait CompatSessionRepository: Send + Sync { @@ -47,174 +42,3 @@ pub trait CompatSessionRepository: Send + Sync { compat_session: CompatSession, ) -> Result; } - -pub struct PgCompatSessionRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgCompatSessionRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -struct CompatSessionLookup { - compat_session_id: Uuid, - device_id: String, - user_id: Uuid, - created_at: DateTime, - finished_at: Option>, -} - -impl TryFrom for CompatSession { - type Error = DatabaseInconsistencyError; - - fn try_from(value: CompatSessionLookup) -> Result { - let id = value.compat_session_id.into(); - let device = Device::try_from(value.device_id).map_err(|e| { - DatabaseInconsistencyError::on("compat_sessions") - .column("device_id") - .row(id) - .source(e) - })?; - - let state = match value.finished_at { - None => CompatSessionState::Valid, - Some(finished_at) => CompatSessionState::Finished { finished_at }, - }; - - let session = CompatSession { - id, - state, - user_id: value.user_id.into(), - device, - created_at: value.created_at, - }; - - Ok(session) - } -} - -#[async_trait] -impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.compat_session.lookup", - skip_all, - fields( - db.statement, - compat_session.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - CompatSessionLookup, - r#" - SELECT compat_session_id - , device_id - , user_id - , created_at - , finished_at - FROM compat_sessions - WHERE compat_session_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) - } - - #[tracing::instrument( - name = "db.compat_session.add", - skip_all, - fields( - db.statement, - compat_session.id, - %user.id, - %user.username, - compat_session.device.id = device.as_str(), - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - user: &User, - device: Device, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("compat_session.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at) - VALUES ($1, $2, $3, $4) - "#, - Uuid::from(id), - Uuid::from(user.id), - device.as_str(), - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(CompatSession { - id, - state: CompatSessionState::default(), - user_id: user.id, - device, - created_at, - }) - } - - #[tracing::instrument( - name = "db.compat_session.finish", - skip_all, - fields( - db.statement, - %compat_session.id, - user.id = %compat_session.user_id, - compat_session.device.id = compat_session.device.as_str(), - ), - err, - )] - async fn finish( - &mut self, - clock: &Clock, - compat_session: CompatSession, - ) -> Result { - let finished_at = clock.now(); - - let res = sqlx::query!( - r#" - UPDATE compat_sessions cs - SET finished_at = $2 - WHERE compat_session_id = $1 - "#, - Uuid::from(compat_session.id), - finished_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - let compat_session = compat_session - .finish(finished_at) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(compat_session) - } -} diff --git a/crates/storage/src/compat/sso_login.rs b/crates/storage/src/compat/sso_login.rs index 76cf1ede4..348e0ac5f 100644 --- a/crates/storage/src/compat/sso_login.rs +++ b/crates/storage/src/compat/sso_login.rs @@ -13,19 +13,12 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState, User}; +use mas_data_model::{CompatSession, CompatSsoLogin, User}; use rand::RngCore; -use sqlx::{PgConnection, QueryBuilder}; use ulid::Ulid; use url::Url; -use uuid::Uuid; -use crate::{ - pagination::{Page, QueryBuilderExt}, - tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination, -}; +use crate::{pagination::Page, Clock, Pagination}; #[async_trait] pub trait CompatSsoLoginRepository: Send + Sync { @@ -71,317 +64,3 @@ pub trait CompatSsoLoginRepository: Send + Sync { pagination: Pagination, ) -> Result, Self::Error>; } - -pub struct PgCompatSsoLoginRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgCompatSsoLoginRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -#[derive(sqlx::FromRow)] -struct CompatSsoLoginLookup { - compat_sso_login_id: Uuid, - login_token: String, - redirect_uri: String, - created_at: DateTime, - fulfilled_at: Option>, - exchanged_at: Option>, - compat_session_id: Option, -} - -impl TryFrom for CompatSsoLogin { - type Error = DatabaseInconsistencyError; - - fn try_from(res: CompatSsoLoginLookup) -> Result { - let id = res.compat_sso_login_id.into(); - let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| { - DatabaseInconsistencyError::on("compat_sso_logins") - .column("redirect_uri") - .row(id) - .source(e) - })?; - - let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) { - (None, None, None) => CompatSsoLoginState::Pending, - (Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled { - fulfilled_at, - session_id: session_id.into(), - }, - (Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => { - CompatSsoLoginState::Exchanged { - fulfilled_at, - exchanged_at, - session_id: session_id.into(), - } - } - _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)), - }; - - Ok(CompatSsoLogin { - id, - login_token: res.login_token, - redirect_uri, - created_at: res.created_at, - state, - }) - } -} - -#[async_trait] -impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.compat_sso_login.lookup", - skip_all, - fields( - db.statement, - compat_sso_login.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - CompatSsoLoginLookup, - r#" - SELECT compat_sso_login_id - , login_token - , redirect_uri - , created_at - , fulfilled_at - , exchanged_at - , compat_session_id - - FROM compat_sso_logins - WHERE compat_sso_login_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) - } - - #[tracing::instrument( - name = "db.compat_sso_login.find_by_token", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn find_by_token( - &mut self, - login_token: &str, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - CompatSsoLoginLookup, - r#" - SELECT compat_sso_login_id - , login_token - , redirect_uri - , created_at - , fulfilled_at - , exchanged_at - , compat_session_id - - FROM compat_sso_logins - WHERE login_token = $1 - "#, - login_token, - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) - } - - #[tracing::instrument( - name = "db.compat_sso_login.add", - skip_all, - fields( - db.statement, - compat_sso_login.id, - compat_sso_login.redirect_uri = %redirect_uri, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - login_token: String, - redirect_uri: Url, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO compat_sso_logins - (compat_sso_login_id, login_token, redirect_uri, created_at) - VALUES ($1, $2, $3, $4) - "#, - Uuid::from(id), - &login_token, - redirect_uri.as_str(), - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(CompatSsoLogin { - id, - login_token, - redirect_uri, - created_at, - state: CompatSsoLoginState::default(), - }) - } - - #[tracing::instrument( - name = "db.compat_sso_login.fulfill", - skip_all, - fields( - db.statement, - %compat_sso_login.id, - %compat_session.id, - compat_session.device.id = compat_session.device.as_str(), - user.id = %compat_session.user_id, - ), - err, - )] - async fn fulfill( - &mut self, - clock: &Clock, - compat_sso_login: CompatSsoLogin, - compat_session: &CompatSession, - ) -> Result { - let fulfilled_at = clock.now(); - let compat_sso_login = compat_sso_login - .fulfill(fulfilled_at, compat_session) - .map_err(DatabaseError::to_invalid_operation)?; - - let res = sqlx::query!( - r#" - UPDATE compat_sso_logins - SET - compat_session_id = $2, - fulfilled_at = $3 - WHERE - compat_sso_login_id = $1 - "#, - Uuid::from(compat_sso_login.id), - Uuid::from(compat_session.id), - fulfilled_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - Ok(compat_sso_login) - } - - #[tracing::instrument( - name = "db.compat_sso_login.exchange", - skip_all, - fields( - db.statement, - %compat_sso_login.id, - ), - err, - )] - async fn exchange( - &mut self, - clock: &Clock, - compat_sso_login: CompatSsoLogin, - ) -> Result { - let exchanged_at = clock.now(); - let compat_sso_login = compat_sso_login - .exchange(exchanged_at) - .map_err(DatabaseError::to_invalid_operation)?; - - let res = sqlx::query!( - r#" - UPDATE compat_sso_logins - SET - exchanged_at = $2 - WHERE - compat_sso_login_id = $1 - "#, - Uuid::from(compat_sso_login.id), - exchanged_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - Ok(compat_sso_login) - } - - #[tracing::instrument( - name = "db.compat_sso_login.list_paginated", - skip_all, - fields( - db.statement, - %user.id, - %user.username, - ), - err - )] - async fn list_paginated( - &mut self, - user: &User, - pagination: Pagination, - ) -> Result, Self::Error> { - let mut query = QueryBuilder::new( - r#" - SELECT cl.compat_sso_login_id - , cl.login_token - , cl.redirect_uri - , cl.created_at - , cl.fulfilled_at - , cl.exchanged_at - , cl.compat_session_id - - FROM compat_sso_logins cl - INNER JOIN compat_sessions ON compat_session_id - "#, - ); - - query - .push(" WHERE user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("cl.compat_sso_login_id", pagination); - - let edges: Vec = query - .build_query_as() - .traced() - .fetch_all(&mut *self.conn) - .await?; - - let page = pagination - .process(edges) - .try_map(CompatSsoLogin::try_from)?; - Ok(page) - } -} diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index e92d37fe7..a65c806cf 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -29,150 +29,19 @@ )] use chrono::{DateTime, Utc}; -use pagination::InvalidPagination; -use sqlx::{migrate::Migrator, postgres::PgQueryResult}; -use thiserror::Error; -use ulid::Ulid; - -trait LookupResultExt { - type Output; - - /// Transform a [`Result`] from a sqlx query to transform "not found" errors - /// into [`None`] - fn to_option(self) -> Result, sqlx::Error>; -} - -impl LookupResultExt for Result { - type Output = T; - - fn to_option(self) -> Result, sqlx::Error> { - match self { - Ok(v) => Ok(Some(v)), - Err(sqlx::Error::RowNotFound) => Ok(None), - Err(e) => Err(e), - } - } -} - -/// Generic error when interacting with the database -#[derive(Debug, Error)] -#[error(transparent)] -pub enum DatabaseError { - /// An error which came from the database itself - Driver(#[from] sqlx::Error), - - /// An error which occured while converting the data from the database - Inconsistency(#[from] DatabaseInconsistencyError), - - /// An error which occured while generating the paginated query - Pagination(#[from] InvalidPagination), - - /// An error which happened because the requested database operation is - /// invalid - #[error("Invalid database operation")] - InvalidOperation { - #[source] - source: Option>, - }, - - /// An error which happens when an operation affects not enough or too many - /// rows - #[error("Expected {expected} rows to be affected, but {actual} rows were affected")] - RowsAffected { expected: u64, actual: u64 }, -} - -impl DatabaseError { - pub(crate) fn ensure_affected_rows( - result: &PgQueryResult, - expected: u64, - ) -> Result<(), DatabaseError> { - let actual = result.rows_affected(); - if actual == expected { - Ok(()) - } else { - Err(DatabaseError::RowsAffected { expected, actual }) - } - } - - pub(crate) fn to_invalid_operation(e: E) -> Self { - Self::InvalidOperation { - source: Some(Box::new(e)), - } - } - - pub(crate) const fn invalid_operation() -> Self { - Self::InvalidOperation { source: None } - } -} - -#[derive(Debug, Error)] -pub struct DatabaseInconsistencyError { - table: &'static str, - column: Option<&'static str>, - row: Option, - - #[source] - source: Option>, -} - -impl std::fmt::Display for DatabaseInconsistencyError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Database inconsistency on table {}", self.table)?; - if let Some(column) = self.column { - write!(f, " column {column}")?; - } - if let Some(row) = self.row { - write!(f, " row {row}")?; - } - - Ok(()) - } -} - -impl DatabaseInconsistencyError { - #[must_use] - pub(crate) const fn on(table: &'static str) -> Self { - Self { - table, - column: None, - row: None, - source: None, - } - } - - #[must_use] - pub(crate) const fn column(mut self, column: &'static str) -> Self { - self.column = Some(column); - self - } - - #[must_use] - pub(crate) const fn row(mut self, row: Ulid) -> Self { - self.row = Some(row); - self - } - - pub(crate) fn source( - mut self, - source: E, - ) -> Self { - self.source = Some(Box::new(source)); - self - } -} #[derive(Debug, Clone, Default)] pub struct Clock { _private: (), - #[cfg(test)] + // #[cfg(test)] mock: Option>, } impl Clock { #[must_use] pub fn now(&self) -> DateTime { - #[cfg(test)] + // #[cfg(test)] if let Some(timestamp) = &self.mock { let timestamp = timestamp.load(std::sync::atomic::Ordering::Relaxed); return chrono::TimeZone::timestamp_opt(&Utc, timestamp, 0).unwrap(); @@ -183,13 +52,14 @@ impl Clock { Utc::now() } - #[cfg(test)] + // #[cfg(test)] + #[must_use] pub fn mock() -> Self { use std::sync::{atomic::AtomicI64, Arc}; use chrono::TimeZone; - let datetime = Utc.with_ymd_and_hms(2022, 01, 16, 14, 40, 0).unwrap(); + let datetime = Utc.with_ymd_and_hms(2022, 1, 16, 14, 40, 0).unwrap(); let timestamp = datetime.timestamp(); Self { @@ -198,7 +68,7 @@ impl Clock { } } - #[cfg(test)] + // #[cfg(test)] pub fn advance(&self, duration: chrono::Duration) { let timestamp = self .mock @@ -247,16 +117,12 @@ mod tests { pub mod compat; pub mod oauth2; -pub(crate) mod pagination; +pub mod pagination; pub(crate) mod repository; -pub(crate) mod tracing; pub mod upstream_oauth2; pub mod user; pub use self::{ - pagination::Pagination, - repository::{PgRepository, Repository}, + pagination::{Page, Pagination}, + repository::Repository, }; - -/// Embedded migrations, allowing them to run on startup -pub static MIGRATOR: Migrator = sqlx::migrate!(); diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index db10ed72e..a0406e44c 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -13,14 +13,12 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Duration, Utc}; -use mas_data_model::{AccessToken, AccessTokenState, Session}; +use chrono::Duration; +use mas_data_model::{AccessToken, Session}; use rand::RngCore; -use sqlx::PgConnection; use ulid::Ulid; -use uuid::Uuid; -use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; +use crate::Clock; #[async_trait] pub trait OAuth2AccessTokenRepository: Send + Sync { @@ -55,202 +53,3 @@ pub trait OAuth2AccessTokenRepository: Send + Sync { /// Cleanup expired access tokens async fn cleanup_expired(&mut self, clock: &Clock) -> Result; } - -pub struct PgOAuth2AccessTokenRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgOAuth2AccessTokenRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -struct OAuth2AccessTokenLookup { - oauth2_access_token_id: Uuid, - oauth2_session_id: Uuid, - access_token: String, - created_at: DateTime, - expires_at: DateTime, - revoked_at: Option>, -} - -impl From for AccessToken { - fn from(value: OAuth2AccessTokenLookup) -> Self { - let state = match value.revoked_at { - None => AccessTokenState::Valid, - Some(revoked_at) => AccessTokenState::Revoked { revoked_at }, - }; - - Self { - id: value.oauth2_access_token_id.into(), - state, - session_id: value.oauth2_session_id.into(), - access_token: value.access_token, - created_at: value.created_at, - expires_at: value.expires_at, - } - } -} - -#[async_trait] -impl<'c> OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'c> { - type Error = DatabaseError; - - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - OAuth2AccessTokenLookup, - r#" - SELECT oauth2_access_token_id - , access_token - , created_at - , expires_at - , revoked_at - , oauth2_session_id - - FROM oauth2_access_tokens - - WHERE oauth2_access_token_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.oauth2_access_token.find_by_token", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn find_by_token( - &mut self, - access_token: &str, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - OAuth2AccessTokenLookup, - r#" - SELECT oauth2_access_token_id - , access_token - , created_at - , expires_at - , revoked_at - , oauth2_session_id - - FROM oauth2_access_tokens - - WHERE access_token = $1 - "#, - access_token, - ) - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.oauth2_access_token.add", - skip_all, - fields( - db.statement, - %session.id, - user_session.id = %session.user_session_id, - client.id = %session.client_id, - access_token.id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - session: &Session, - access_token: String, - expires_after: Duration, - ) -> Result { - let created_at = clock.now(); - let expires_at = created_at + expires_after; - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - - tracing::Span::current().record("access_token.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO oauth2_access_tokens - (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at) - VALUES - ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(session.id), - &access_token, - created_at, - expires_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(AccessToken { - id, - state: AccessTokenState::default(), - access_token, - session_id: session.id, - created_at, - expires_at, - }) - } - - async fn revoke( - &mut self, - clock: &Clock, - access_token: AccessToken, - ) -> Result { - let revoked_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE oauth2_access_tokens - SET revoked_at = $2 - WHERE oauth2_access_token_id = $1 - "#, - Uuid::from(access_token.id), - revoked_at, - ) - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - access_token - .revoke(revoked_at) - .map_err(DatabaseError::to_invalid_operation) - } - - async fn cleanup_expired(&mut self, clock: &Clock) -> Result { - // Cleanup token which expired more than 15 minutes ago - let threshold = clock.now() - Duration::minutes(15); - let res = sqlx::query!( - r#" - DELETE FROM oauth2_access_tokens - WHERE expires_at < $1 - "#, - threshold, - ) - .execute(&mut *self.conn) - .await?; - - Ok(res.rows_affected().try_into().unwrap_or(usize::MAX)) - } -} diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index c57c5dcd3..ce1a716ff 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,21 +15,13 @@ use std::num::NonZeroU32; use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{ - AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session, -}; -use mas_iana::oauth::PkceCodeChallengeMethod; +use mas_data_model::{AuthorizationCode, AuthorizationGrant, Client, Session}; use oauth2_types::{requests::ResponseMode, scope::Scope}; use rand::RngCore; -use sqlx::PgConnection; use ulid::Ulid; use url::Url; -use uuid::Uuid; -use crate::{ - tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, -}; +use crate::Clock; #[async_trait] pub trait OAuth2AuthorizationGrantRepository: Send + Sync { @@ -75,482 +67,3 @@ pub trait OAuth2AuthorizationGrantRepository: Send + Sync { authorization_grant: AuthorizationGrant, ) -> Result; } - -pub struct PgOAuth2AuthorizationGrantRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgOAuth2AuthorizationGrantRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -#[allow(clippy::struct_excessive_bools)] -struct GrantLookup { - oauth2_authorization_grant_id: Uuid, - created_at: DateTime, - cancelled_at: Option>, - fulfilled_at: Option>, - exchanged_at: Option>, - scope: String, - state: Option, - nonce: Option, - redirect_uri: String, - response_mode: String, - max_age: Option, - response_type_code: bool, - response_type_id_token: bool, - authorization_code: Option, - code_challenge: Option, - code_challenge_method: Option, - requires_consent: bool, - oauth2_client_id: Uuid, - oauth2_session_id: Option, -} - -impl TryFrom for AuthorizationGrant { - type Error = DatabaseInconsistencyError; - - #[allow(clippy::too_many_lines)] - fn try_from(value: GrantLookup) -> Result { - let id = value.oauth2_authorization_grant_id.into(); - let scope: Scope = value.scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("scope") - .row(id) - .source(e) - })?; - - let stage = match ( - value.fulfilled_at, - value.exchanged_at, - value.cancelled_at, - value.oauth2_session_id, - ) { - (None, None, None, None) => AuthorizationGrantStage::Pending, - (Some(fulfilled_at), None, None, Some(session_id)) => { - AuthorizationGrantStage::Fulfilled { - session_id: session_id.into(), - fulfilled_at, - } - } - (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => { - AuthorizationGrantStage::Exchanged { - session_id: session_id.into(), - fulfilled_at, - exchanged_at, - } - } - (None, None, Some(cancelled_at), None) => { - AuthorizationGrantStage::Cancelled { cancelled_at } - } - _ => { - return Err( - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("stage") - .row(id), - ); - } - }; - - let pkce = match (value.code_challenge, value.code_challenge_method) { - (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => { - Some(Pkce { - challenge_method: PkceCodeChallengeMethod::Plain, - challenge, - }) - } - (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce { - challenge_method: PkceCodeChallengeMethod::S256, - challenge, - }), - (None, None) => None, - _ => { - return Err( - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("code_challenge_method") - .row(id), - ); - } - }; - - let code: Option = - match (value.response_type_code, value.authorization_code, pkce) { - (false, None, None) => None, - (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }), - _ => { - return Err( - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("authorization_code") - .row(id), - ); - } - }; - - let redirect_uri = value.redirect_uri.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("redirect_uri") - .row(id) - .source(e) - })?; - - let response_mode = value.response_mode.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("response_mode") - .row(id) - .source(e) - })?; - - let max_age = value - .max_age - .map(u32::try_from) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("max_age") - .row(id) - .source(e) - })? - .map(NonZeroU32::try_from) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("max_age") - .row(id) - .source(e) - })?; - - Ok(AuthorizationGrant { - id, - stage, - client_id: value.oauth2_client_id.into(), - code, - scope, - state: value.state, - nonce: value.nonce, - max_age, - response_mode, - redirect_uri, - created_at: value.created_at, - response_type_id_token: value.response_type_id_token, - requires_consent: value.requires_consent, - }) - } -} - -#[async_trait] -impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.oauth2_authorization_grant.add", - skip_all, - fields( - db.statement, - grant.id, - grant.scope = %scope, - %client.id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - client: &Client, - redirect_uri: Url, - scope: Scope, - code: Option, - state: Option, - nonce: Option, - max_age: Option, - response_mode: ResponseMode, - response_type_id_token: bool, - requires_consent: bool, - ) -> Result { - let code_challenge = code - .as_ref() - .and_then(|c| c.pkce.as_ref()) - .map(|p| &p.challenge); - let code_challenge_method = code - .as_ref() - .and_then(|c| c.pkce.as_ref()) - .map(|p| p.challenge_method.to_string()); - // TODO: this conversion is a bit ugly - let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX)); - let code_str = code.as_ref().map(|c| &c.code); - - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("grant.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO oauth2_authorization_grants ( - oauth2_authorization_grant_id, - oauth2_client_id, - redirect_uri, - scope, - state, - nonce, - max_age, - response_mode, - code_challenge, - code_challenge_method, - response_type_code, - response_type_id_token, - authorization_code, - requires_consent, - created_at - ) - VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) - "#, - Uuid::from(id), - Uuid::from(client.id), - redirect_uri.to_string(), - scope.to_string(), - state, - nonce, - max_age_i32, - response_mode.to_string(), - code_challenge, - code_challenge_method, - code.is_some(), - response_type_id_token, - code_str, - requires_consent, - created_at, - ) - .execute(&mut *self.conn) - .await?; - - Ok(AuthorizationGrant { - id, - stage: AuthorizationGrantStage::Pending, - code, - redirect_uri, - client_id: client.id, - scope, - state, - nonce, - max_age, - response_mode, - created_at, - response_type_id_token, - requires_consent, - }) - } - - #[tracing::instrument( - name = "db.oauth2_authorization_grant.lookup", - skip_all, - fields( - db.statement, - grant.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - GrantLookup, - r#" - SELECT oauth2_authorization_grant_id - , created_at - , cancelled_at - , fulfilled_at - , exchanged_at - , scope - , state - , redirect_uri - , response_mode - , nonce - , max_age - , oauth2_client_id - , authorization_code - , response_type_code - , response_type_id_token - , code_challenge - , code_challenge_method - , requires_consent - , oauth2_session_id - FROM - oauth2_authorization_grants - - WHERE oauth2_authorization_grant_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) - } - - #[tracing::instrument( - name = "db.oauth2_authorization_grant.find_by_code", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn find_by_code( - &mut self, - code: &str, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - GrantLookup, - r#" - SELECT oauth2_authorization_grant_id - , created_at - , cancelled_at - , fulfilled_at - , exchanged_at - , scope - , state - , redirect_uri - , response_mode - , nonce - , max_age - , oauth2_client_id - , authorization_code - , response_type_code - , response_type_id_token - , code_challenge - , code_challenge_method - , requires_consent - , oauth2_session_id - FROM - oauth2_authorization_grants - - WHERE authorization_code = $1 - "#, - code, - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) - } - - #[tracing::instrument( - name = "db.oauth2_authorization_grant.fulfill", - skip_all, - fields( - db.statement, - %grant.id, - client.id = %grant.client_id, - %session.id, - user_session.id = %session.user_session_id, - ), - err, - )] - async fn fulfill( - &mut self, - clock: &Clock, - session: &Session, - grant: AuthorizationGrant, - ) -> Result { - let fulfilled_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE oauth2_authorization_grants - SET fulfilled_at = $2 - , oauth2_session_id = $3 - WHERE oauth2_authorization_grant_id = $1 - "#, - Uuid::from(grant.id), - fulfilled_at, - Uuid::from(session.id), - ) - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - // XXX: check affected rows & new methods - let grant = grant - .fulfill(fulfilled_at, session) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(grant) - } - - #[tracing::instrument( - name = "db.oauth2_authorization_grant.exchange", - skip_all, - fields( - db.statement, - %grant.id, - client.id = %grant.client_id, - ), - err, - )] - async fn exchange( - &mut self, - clock: &Clock, - grant: AuthorizationGrant, - ) -> Result { - let exchanged_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE oauth2_authorization_grants - SET exchanged_at = $2 - WHERE oauth2_authorization_grant_id = $1 - "#, - Uuid::from(grant.id), - exchanged_at, - ) - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - let grant = grant - .exchange(exchanged_at) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(grant) - } - - #[tracing::instrument( - name = "db.oauth2_authorization_grant.give_consent", - skip_all, - fields( - db.statement, - %grant.id, - client.id = %grant.client_id, - ), - err, - )] - async fn give_consent( - &mut self, - mut grant: AuthorizationGrant, - ) -> Result { - sqlx::query!( - r#" - UPDATE oauth2_authorization_grants AS og - SET - requires_consent = 'f' - WHERE - og.oauth2_authorization_grant_id = $1 - "#, - Uuid::from(grant.id), - ) - .execute(&mut *self.conn) - .await?; - - grant.requires_consent = false; - - Ok(grant) - } -} diff --git a/crates/storage/src/oauth2/client.rs b/crates/storage/src/oauth2/client.rs index 756017b8e..093369a40 100644 --- a/crates/storage/src/oauth2/client.rs +++ b/crates/storage/src/oauth2/client.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,33 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{ - collections::{BTreeMap, BTreeSet}, - str::FromStr, - string::ToString, -}; +use std::collections::{BTreeMap, BTreeSet}; use async_trait::async_trait; -use mas_data_model::{Client, JwksOrJwksUri, User}; -use mas_iana::{ - jose::JsonWebSignatureAlg, - oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod}, -}; +use mas_data_model::{Client, User}; +use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_jose::jwk::PublicJsonWebKeySet; -use oauth2_types::{ - requests::GrantType, - scope::{Scope, ScopeToken}, -}; +use oauth2_types::{requests::GrantType, scope::Scope}; use rand::{Rng, RngCore}; -use sqlx::PgConnection; -use tracing::{info_span, Instrument}; use ulid::Ulid; use url::Url; -use uuid::Uuid; -use crate::{ - tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, -}; +use crate::Clock; #[async_trait] pub trait OAuth2ClientRepository: Send + Sync { @@ -107,708 +92,3 @@ pub trait OAuth2ClientRepository: Send + Sync { scope: &Scope, ) -> Result<(), Self::Error>; } - -pub struct PgOAuth2ClientRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgOAuth2ClientRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -// XXX: response_types & contacts -#[derive(Debug)] -struct OAuth2ClientLookup { - oauth2_client_id: Uuid, - encrypted_client_secret: Option, - redirect_uris: Vec, - // response_types: Vec, - grant_type_authorization_code: bool, - grant_type_refresh_token: bool, - // contacts: Vec, - client_name: Option, - logo_uri: Option, - client_uri: Option, - policy_uri: Option, - tos_uri: Option, - jwks_uri: Option, - jwks: Option, - id_token_signed_response_alg: Option, - userinfo_signed_response_alg: Option, - token_endpoint_auth_method: Option, - token_endpoint_auth_signing_alg: Option, - initiate_login_uri: Option, -} - -impl TryInto for OAuth2ClientLookup { - type Error = DatabaseInconsistencyError; - - #[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing - fn try_into(self) -> Result { - let id = Ulid::from(self.oauth2_client_id); - - let redirect_uris: Result, _> = - self.redirect_uris.iter().map(|s| s.parse()).collect(); - let redirect_uris = redirect_uris.map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("redirect_uris") - .row(id) - .source(e) - })?; - - let response_types = vec![ - OAuthAuthorizationEndpointResponseType::Code, - OAuthAuthorizationEndpointResponseType::IdToken, - OAuthAuthorizationEndpointResponseType::None, - ]; - /* XXX - let response_types: Result, _> = - self.response_types.iter().map(|s| s.parse()).collect(); - let response_types = response_types.map_err(|source| ClientFetchError::ParseField { - field: "response_types", - source, - })?; - */ - - let mut grant_types = Vec::new(); - if self.grant_type_authorization_code { - grant_types.push(GrantType::AuthorizationCode); - } - if self.grant_type_refresh_token { - grant_types.push(GrantType::RefreshToken); - } - - let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("logo_uri") - .row(id) - .source(e) - })?; - - let client_uri = self - .client_uri - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("client_uri") - .row(id) - .source(e) - })?; - - let policy_uri = self - .policy_uri - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("policy_uri") - .row(id) - .source(e) - })?; - - let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("tos_uri") - .row(id) - .source(e) - })?; - - let id_token_signed_response_alg = self - .id_token_signed_response_alg - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("id_token_signed_response_alg") - .row(id) - .source(e) - })?; - - let userinfo_signed_response_alg = self - .userinfo_signed_response_alg - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("userinfo_signed_response_alg") - .row(id) - .source(e) - })?; - - let token_endpoint_auth_method = self - .token_endpoint_auth_method - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("token_endpoint_auth_method") - .row(id) - .source(e) - })?; - - let token_endpoint_auth_signing_alg = self - .token_endpoint_auth_signing_alg - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("token_endpoint_auth_signing_alg") - .row(id) - .source(e) - })?; - - let initiate_login_uri = self - .initiate_login_uri - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("initiate_login_uri") - .row(id) - .source(e) - })?; - - let jwks = match (self.jwks, self.jwks_uri) { - (None, None) => None, - (Some(jwks), None) => { - let jwks = serde_json::from_value(jwks).map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("jwks") - .row(id) - .source(e) - })?; - Some(JwksOrJwksUri::Jwks(jwks)) - } - (None, Some(jwks_uri)) => { - let jwks_uri = jwks_uri.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("jwks_uri") - .row(id) - .source(e) - })?; - - Some(JwksOrJwksUri::JwksUri(jwks_uri)) - } - _ => { - return Err(DatabaseInconsistencyError::on("oauth2_clients") - .column("jwks(_uri)") - .row(id)) - } - }; - - Ok(Client { - id, - client_id: id.to_string(), - encrypted_client_secret: self.encrypted_client_secret, - redirect_uris, - response_types, - grant_types, - // contacts: self.contacts, - contacts: vec![], - client_name: self.client_name, - logo_uri, - client_uri, - policy_uri, - tos_uri, - jwks, - id_token_signed_response_alg, - userinfo_signed_response_alg, - token_endpoint_auth_method, - token_endpoint_auth_signing_alg, - initiate_login_uri, - }) - } -} - -#[async_trait] -impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.oauth2_client.lookup", - skip_all, - fields( - db.statement, - oauth2_client.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - OAuth2ClientLookup, - r#" - SELECT oauth2_client_id - , encrypted_client_secret - , ARRAY( - SELECT redirect_uri - FROM oauth2_client_redirect_uris r - WHERE r.oauth2_client_id = c.oauth2_client_id - ) AS "redirect_uris!" - , grant_type_authorization_code - , grant_type_refresh_token - , client_name - , logo_uri - , client_uri - , policy_uri - , tos_uri - , jwks_uri - , jwks - , id_token_signed_response_alg - , userinfo_signed_response_alg - , token_endpoint_auth_method - , token_endpoint_auth_signing_alg - , initiate_login_uri - FROM oauth2_clients c - - WHERE oauth2_client_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) - } - - #[tracing::instrument( - name = "db.oauth2_client.load_batch", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn load_batch( - &mut self, - ids: BTreeSet, - ) -> Result, Self::Error> { - let ids: Vec = ids.into_iter().map(Uuid::from).collect(); - let res = sqlx::query_as!( - OAuth2ClientLookup, - r#" - SELECT oauth2_client_id - , encrypted_client_secret - , ARRAY( - SELECT redirect_uri - FROM oauth2_client_redirect_uris r - WHERE r.oauth2_client_id = c.oauth2_client_id - ) AS "redirect_uris!" - , grant_type_authorization_code - , grant_type_refresh_token - , client_name - , logo_uri - , client_uri - , policy_uri - , tos_uri - , jwks_uri - , jwks - , id_token_signed_response_alg - , userinfo_signed_response_alg - , token_endpoint_auth_method - , token_endpoint_auth_signing_alg - , initiate_login_uri - FROM oauth2_clients c - - WHERE oauth2_client_id = ANY($1::uuid[]) - "#, - &ids, - ) - .traced() - .fetch_all(&mut *self.conn) - .await?; - - res.into_iter() - .map(|r| { - r.try_into() - .map(|c: Client| (c.id, c)) - .map_err(DatabaseError::from) - }) - .collect() - } - - #[tracing::instrument( - name = "db.oauth2_client.add", - skip_all, - fields( - db.statement, - client.id, - client.name = client_name - ), - err, - )] - #[allow(clippy::too_many_lines)] - async fn add( - &mut self, - mut rng: &mut (dyn RngCore + Send), - clock: &Clock, - redirect_uris: Vec, - encrypted_client_secret: Option, - grant_types: Vec, - contacts: Vec, - client_name: Option, - logo_uri: Option, - client_uri: Option, - policy_uri: Option, - tos_uri: Option, - jwks_uri: Option, - jwks: Option, - id_token_signed_response_alg: Option, - userinfo_signed_response_alg: Option, - token_endpoint_auth_method: Option, - token_endpoint_auth_signing_alg: Option, - initiate_login_uri: Option, - ) -> Result { - let now = clock.now(); - let id = Ulid::from_datetime_with_source(now.into(), rng); - tracing::Span::current().record("client.id", tracing::field::display(id)); - - let jwks_json = jwks - .as_ref() - .map(serde_json::to_value) - .transpose() - .map_err(DatabaseError::to_invalid_operation)?; - - sqlx::query!( - r#" - INSERT INTO oauth2_clients - ( oauth2_client_id - , encrypted_client_secret - , grant_type_authorization_code - , grant_type_refresh_token - , client_name - , logo_uri - , client_uri - , policy_uri - , tos_uri - , jwks_uri - , jwks - , id_token_signed_response_alg - , userinfo_signed_response_alg - , token_endpoint_auth_method - , token_endpoint_auth_signing_alg - , initiate_login_uri - ) - VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) - "#, - Uuid::from(id), - encrypted_client_secret, - grant_types.contains(&GrantType::AuthorizationCode), - grant_types.contains(&GrantType::RefreshToken), - client_name, - logo_uri.as_ref().map(Url::as_str), - client_uri.as_ref().map(Url::as_str), - policy_uri.as_ref().map(Url::as_str), - tos_uri.as_ref().map(Url::as_str), - jwks_uri.as_ref().map(Url::as_str), - jwks_json, - id_token_signed_response_alg - .as_ref() - .map(ToString::to_string), - userinfo_signed_response_alg - .as_ref() - .map(ToString::to_string), - token_endpoint_auth_method.as_ref().map(ToString::to_string), - token_endpoint_auth_signing_alg - .as_ref() - .map(ToString::to_string), - initiate_login_uri.as_ref().map(Url::as_str), - ) - .traced() - .execute(&mut *self.conn) - .await?; - - { - let span = info_span!( - "db.oauth2_client.add.redirect_uris", - db.statement = tracing::field::Empty, - client.id = %id, - ); - - let (uri_ids, redirect_uris): (Vec, Vec) = redirect_uris - .iter() - .map(|uri| { - ( - Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), - uri.as_str().to_owned(), - ) - }) - .unzip(); - - sqlx::query!( - r#" - INSERT INTO oauth2_client_redirect_uris - (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) - SELECT id, $2, redirect_uri - FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) - "#, - &uri_ids, - Uuid::from(id), - &redirect_uris, - ) - .record(&span) - .execute(&mut *self.conn) - .instrument(span) - .await?; - } - - let jwks = match (jwks, jwks_uri) { - (None, None) => None, - (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)), - (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)), - _ => return Err(DatabaseError::invalid_operation()), - }; - - Ok(Client { - id, - client_id: id.to_string(), - encrypted_client_secret, - redirect_uris, - response_types: vec![ - OAuthAuthorizationEndpointResponseType::Code, - OAuthAuthorizationEndpointResponseType::IdToken, - OAuthAuthorizationEndpointResponseType::None, - ], - grant_types, - contacts, - client_name, - logo_uri, - client_uri, - policy_uri, - tos_uri, - jwks, - id_token_signed_response_alg, - userinfo_signed_response_alg, - token_endpoint_auth_method, - token_endpoint_auth_signing_alg, - initiate_login_uri, - }) - } - - #[tracing::instrument( - name = "db.oauth2_client.add_from_config", - skip_all, - fields( - db.statement, - client.id = %client_id, - ), - err, - )] - async fn add_from_config( - &mut self, - mut rng: impl Rng + Send, - clock: &Clock, - client_id: Ulid, - client_auth_method: OAuthClientAuthenticationMethod, - encrypted_client_secret: Option, - jwks: Option, - jwks_uri: Option, - redirect_uris: Vec, - ) -> Result { - let jwks_json = jwks - .as_ref() - .map(serde_json::to_value) - .transpose() - .map_err(DatabaseError::to_invalid_operation)?; - - let client_auth_method = client_auth_method.to_string(); - - sqlx::query!( - r#" - INSERT INTO oauth2_clients - ( oauth2_client_id - , encrypted_client_secret - , grant_type_authorization_code - , grant_type_refresh_token - , token_endpoint_auth_method - , jwks - , jwks_uri - ) - VALUES - ($1, $2, $3, $4, $5, $6, $7) - ON CONFLICT (oauth2_client_id) - DO - UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret - , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code - , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token - , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method - , jwks = EXCLUDED.jwks - , jwks_uri = EXCLUDED.jwks_uri - "#, - Uuid::from(client_id), - encrypted_client_secret, - true, - true, - client_auth_method, - jwks_json, - jwks_uri.as_ref().map(Url::as_str), - ) - .traced() - .execute(&mut *self.conn) - .await?; - - { - let span = info_span!( - "db.oauth2_client.add_from_config.redirect_uris", - client.id = %client_id, - db.statement = tracing::field::Empty, - ); - - let now = clock.now(); - let (ids, redirect_uris): (Vec, Vec) = redirect_uris - .iter() - .map(|uri| { - ( - Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), - uri.as_str().to_owned(), - ) - }) - .unzip(); - - sqlx::query!( - r#" - INSERT INTO oauth2_client_redirect_uris - (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) - SELECT id, $2, redirect_uri - FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) - "#, - &ids, - Uuid::from(client_id), - &redirect_uris, - ) - .record(&span) - .execute(&mut *self.conn) - .instrument(span) - .await?; - } - - let jwks = match (jwks, jwks_uri) { - (None, None) => None, - (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)), - (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)), - _ => return Err(DatabaseError::invalid_operation()), - }; - - Ok(Client { - id: client_id, - client_id: client_id.to_string(), - encrypted_client_secret, - redirect_uris, - response_types: vec![ - OAuthAuthorizationEndpointResponseType::Code, - OAuthAuthorizationEndpointResponseType::IdToken, - OAuthAuthorizationEndpointResponseType::None, - ], - grant_types: Vec::new(), - contacts: Vec::new(), - client_name: None, - logo_uri: None, - client_uri: None, - policy_uri: None, - tos_uri: None, - jwks, - id_token_signed_response_alg: None, - userinfo_signed_response_alg: None, - token_endpoint_auth_method: None, - token_endpoint_auth_signing_alg: None, - initiate_login_uri: None, - }) - } - - #[tracing::instrument( - name = "db.oauth2_client.get_consent_for_user", - skip_all, - fields( - db.statement, - %user.id, - %client.id, - ), - err, - )] - async fn get_consent_for_user( - &mut self, - client: &Client, - user: &User, - ) -> Result { - let scope_tokens: Vec = sqlx::query_scalar!( - r#" - SELECT scope_token - FROM oauth2_consents - WHERE user_id = $1 AND oauth2_client_id = $2 - "#, - Uuid::from(user.id), - Uuid::from(client.id), - ) - .fetch_all(&mut *self.conn) - .await?; - - let scope: Result = scope_tokens - .into_iter() - .map(|s| ScopeToken::from_str(&s)) - .collect(); - - let scope = scope.map_err(|e| { - DatabaseInconsistencyError::on("oauth2_consents") - .column("scope_token") - .source(e) - })?; - - Ok(scope) - } - - #[tracing::instrument( - skip_all, - fields( - db.statement, - %user.id, - %client.id, - %scope, - ), - err, - )] - async fn give_consent_for_user( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - client: &Client, - user: &User, - scope: &Scope, - ) -> Result<(), Self::Error> { - let now = clock.now(); - let (tokens, ids): (Vec, Vec) = scope - .iter() - .map(|token| { - ( - token.to_string(), - Uuid::from(Ulid::from_datetime_with_source(now.into(), rng)), - ) - }) - .unzip(); - - sqlx::query!( - r#" - INSERT INTO oauth2_consents - (oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at) - SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token) - ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5 - "#, - &ids, - Uuid::from(user.id), - Uuid::from(client.id), - &tokens, - now, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(()) - } -} diff --git a/crates/storage/src/oauth2/mod.rs b/crates/storage/src/oauth2/mod.rs index 480c45155..eaa5e3172 100644 --- a/crates/storage/src/oauth2/mod.rs +++ b/crates/storage/src/oauth2/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,11 +19,7 @@ mod refresh_token; mod session; pub use self::{ - access_token::{OAuth2AccessTokenRepository, PgOAuth2AccessTokenRepository}, - authorization_grant::{ - OAuth2AuthorizationGrantRepository, PgOAuth2AuthorizationGrantRepository, - }, - client::{OAuth2ClientRepository, PgOAuth2ClientRepository}, - refresh_token::{OAuth2RefreshTokenRepository, PgOAuth2RefreshTokenRepository}, - session::{OAuth2SessionRepository, PgOAuth2SessionRepository}, + access_token::OAuth2AccessTokenRepository, + authorization_grant::OAuth2AuthorizationGrantRepository, client::OAuth2ClientRepository, + refresh_token::OAuth2RefreshTokenRepository, session::OAuth2SessionRepository, }; diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 5d3bb0133..1e23634aa 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,14 +13,11 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session}; +use mas_data_model::{AccessToken, RefreshToken, Session}; use rand::RngCore; -use sqlx::PgConnection; use ulid::Ulid; -use uuid::Uuid; -use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; +use crate::Clock; #[async_trait] pub trait OAuth2RefreshTokenRepository: Send + Sync { @@ -52,203 +49,3 @@ pub trait OAuth2RefreshTokenRepository: Send + Sync { refresh_token: RefreshToken, ) -> Result; } - -pub struct PgOAuth2RefreshTokenRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgOAuth2RefreshTokenRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -struct OAuth2RefreshTokenLookup { - oauth2_refresh_token_id: Uuid, - refresh_token: String, - created_at: DateTime, - consumed_at: Option>, - oauth2_access_token_id: Option, - oauth2_session_id: Uuid, -} - -impl From for RefreshToken { - fn from(value: OAuth2RefreshTokenLookup) -> Self { - let state = match value.consumed_at { - None => RefreshTokenState::Valid, - Some(consumed_at) => RefreshTokenState::Consumed { consumed_at }, - }; - - RefreshToken { - id: value.oauth2_refresh_token_id.into(), - state, - session_id: value.oauth2_session_id.into(), - refresh_token: value.refresh_token, - created_at: value.created_at, - access_token_id: value.oauth2_access_token_id.map(Ulid::from), - } - } -} - -#[async_trait] -impl<'c> OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.oauth2_refresh_token.lookup", - skip_all, - fields( - db.statement, - refresh_token.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - OAuth2RefreshTokenLookup, - r#" - SELECT oauth2_refresh_token_id - , refresh_token - , created_at - , consumed_at - , oauth2_access_token_id - , oauth2_session_id - FROM oauth2_refresh_tokens - - WHERE oauth2_refresh_token_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.oauth2_refresh_token.find_by_token", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn find_by_token( - &mut self, - refresh_token: &str, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - OAuth2RefreshTokenLookup, - r#" - SELECT oauth2_refresh_token_id - , refresh_token - , created_at - , consumed_at - , oauth2_access_token_id - , oauth2_session_id - FROM oauth2_refresh_tokens - - WHERE refresh_token = $1 - "#, - refresh_token, - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.oauth2_refresh_token.add", - skip_all, - fields( - db.statement, - %session.id, - user_session.id = %session.user_session_id, - client.id = %session.client_id, - refresh_token.id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - session: &Session, - access_token: &AccessToken, - refresh_token: String, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("refresh_token.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO oauth2_refresh_tokens - (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id, - refresh_token, created_at) - VALUES - ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(session.id), - Uuid::from(access_token.id), - refresh_token, - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(RefreshToken { - id, - state: RefreshTokenState::default(), - session_id: session.id, - refresh_token, - access_token_id: Some(access_token.id), - created_at, - }) - } - - #[tracing::instrument( - name = "db.oauth2_refresh_token.consume", - skip_all, - fields( - db.statement, - %refresh_token.id, - session.id = %refresh_token.session_id, - ), - err, - )] - async fn consume( - &mut self, - clock: &Clock, - refresh_token: RefreshToken, - ) -> Result { - let consumed_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE oauth2_refresh_tokens - SET consumed_at = $2 - WHERE oauth2_refresh_token_id = $1 - "#, - Uuid::from(refresh_token.id), - consumed_at, - ) - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - refresh_token - .consume(consumed_at) - .map_err(DatabaseError::to_invalid_operation) - } -} diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index dc21fbcb3..5e6498d8b 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,18 +13,11 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{AuthorizationGrant, BrowserSession, Session, SessionState, User}; +use mas_data_model::{AuthorizationGrant, BrowserSession, Session, User}; use rand::RngCore; -use sqlx::{PgConnection, QueryBuilder}; use ulid::Ulid; -use uuid::Uuid; -use crate::{ - pagination::{Page, QueryBuilderExt}, - tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination, -}; +use crate::{pagination::Page, Clock, Pagination}; #[async_trait] pub trait OAuth2SessionRepository: Send + Sync { @@ -48,224 +41,3 @@ pub trait OAuth2SessionRepository: Send + Sync { pagination: Pagination, ) -> Result, Self::Error>; } - -pub struct PgOAuth2SessionRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgOAuth2SessionRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -#[derive(sqlx::FromRow)] -struct OAuthSessionLookup { - oauth2_session_id: Uuid, - user_session_id: Uuid, - oauth2_client_id: Uuid, - scope: String, - #[allow(dead_code)] - created_at: DateTime, - finished_at: Option>, -} - -impl TryFrom for Session { - type Error = DatabaseInconsistencyError; - - fn try_from(value: OAuthSessionLookup) -> Result { - let id = Ulid::from(value.oauth2_session_id); - let scope = value.scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("scope") - .row(id) - .source(e) - })?; - - let state = match value.finished_at { - None => SessionState::Valid, - Some(finished_at) => SessionState::Finished { finished_at }, - }; - - Ok(Session { - id, - state, - created_at: value.created_at, - client_id: value.oauth2_client_id.into(), - user_session_id: value.user_session_id.into(), - scope, - }) - } -} - -#[async_trait] -impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.oauth2_session.lookup", - skip_all, - fields( - db.statement, - session.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - OAuthSessionLookup, - r#" - SELECT oauth2_session_id - , user_session_id - , oauth2_client_id - , scope - , created_at - , finished_at - FROM oauth2_sessions - - WHERE oauth2_session_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(session) = res else { return Ok(None) }; - - Ok(Some(session.try_into()?)) - } - - #[tracing::instrument( - name = "db.oauth2_session.create_from_grant", - skip_all, - fields( - db.statement, - %user_session.id, - user.id = %user_session.user.id, - %grant.id, - client.id = %grant.client_id, - session.id, - session.scope = %grant.scope, - ), - err, - )] - async fn create_from_grant( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - grant: &AuthorizationGrant, - user_session: &BrowserSession, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("session.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO oauth2_sessions - ( oauth2_session_id - , user_session_id - , oauth2_client_id - , scope - , created_at - ) - VALUES ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(user_session.id), - Uuid::from(grant.client_id), - grant.scope.to_string(), - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(Session { - id, - state: SessionState::Valid, - created_at, - user_session_id: user_session.id, - client_id: grant.client_id, - scope: grant.scope.clone(), - }) - } - - #[tracing::instrument( - name = "db.oauth2_session.finish", - skip_all, - fields( - db.statement, - %session.id, - %session.scope, - user_session.id = %session.user_session_id, - client.id = %session.client_id, - ), - err, - )] - async fn finish(&mut self, clock: &Clock, session: Session) -> Result { - let finished_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE oauth2_sessions - SET finished_at = $2 - WHERE oauth2_session_id = $1 - "#, - Uuid::from(session.id), - finished_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - session - .finish(finished_at) - .map_err(DatabaseError::to_invalid_operation) - } - - #[tracing::instrument( - name = "db.oauth2_session.list_paginated", - skip_all, - fields( - db.statement, - %user.id, - %user.username, - ), - err, - )] - async fn list_paginated( - &mut self, - user: &User, - pagination: Pagination, - ) -> Result, Self::Error> { - let mut query = QueryBuilder::new( - r#" - SELECT oauth2_session_id - , user_session_id - , oauth2_client_id - , scope - , created_at - , finished_at - FROM oauth2_sessions os - "#, - ); - - query - .push(" WHERE us.user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("oauth2_session_id", pagination); - - let edges: Vec = query - .build_query_as() - .traced() - .fetch_all(&mut *self.conn) - .await?; - - let page = pagination.process(edges).try_map(Session::try_from)?; - Ok(page) - } -} diff --git a/crates/storage/src/pagination.rs b/crates/storage/src/pagination.rs index 1fa74ddac..6af456418 100644 --- a/crates/storage/src/pagination.rs +++ b/crates/storage/src/pagination.rs @@ -14,10 +14,8 @@ //! Utilities to manage paginated queries. -use sqlx::{Database, QueryBuilder}; use thiserror::Error; use ulid::Ulid; -use uuid::Uuid; /// An error returned when invalid pagination parameters are provided #[derive(Debug, Error)] @@ -26,14 +24,14 @@ pub struct InvalidPagination; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Pagination { - before: Option, - after: Option, - count: usize, - direction: PaginationDirection, + pub before: Option, + pub after: Option, + pub count: usize, + pub direction: PaginationDirection, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum PaginationDirection { +pub enum PaginationDirection { Forward, Backward, } @@ -101,60 +99,8 @@ impl Pagination { self } - /// Add cursor-based pagination to a query, as used in paginated GraphQL - /// connections - fn generate_pagination<'a, DB>(&self, query: &mut QueryBuilder<'a, DB>, id_field: &'static str) - where - DB: Database, - Uuid: sqlx::Type + sqlx::Encode<'a, DB>, - i64: sqlx::Type + sqlx::Encode<'a, DB>, - { - // ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564 - // 1. Start from the greedy query: SELECT * FROM table - - // 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE` - // clause - if let Some(after) = self.after { - query - .push(" AND ") - .push(id_field) - .push(" > ") - .push_bind(Uuid::from(after)); - } - - // 3. If the before argument is provided, add `id < parsed_cursor` to the - // `WHERE` clause - if let Some(before) = self.before { - query - .push(" AND ") - .push(id_field) - .push(" < ") - .push_bind(Uuid::from(before)); - } - - match self.direction { - // 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the - // query - PaginationDirection::Forward => { - query - .push(" ORDER BY ") - .push(id_field) - .push(" ASC LIMIT ") - .push_bind((self.count + 1) as i64); - } - // 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the - // query - PaginationDirection::Backward => { - query - .push(" ORDER BY ") - .push(id_field) - .push(" DESC LIMIT ") - .push_bind((self.count + 1) as i64); - } - }; - } - /// Process a page returned by a paginated query + #[must_use] pub fn process(&self, mut edges: Vec) -> Page { let is_full = edges.len() == (self.count + 1); if is_full { @@ -198,7 +144,6 @@ impl Page { } } - #[must_use] pub fn try_map(self, f: F) -> Result, E> where F: FnMut(T) -> Result, @@ -211,23 +156,3 @@ impl Page { }) } } - -/// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination -/// to a query -pub trait QueryBuilderExt { - /// Add cursor-based pagination to a query, as used in paginated GraphQL - /// connections - fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self; -} - -impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB> -where - DB: Database, - Uuid: sqlx::Type + sqlx::Encode<'a, DB>, - i64: sqlx::Type + sqlx::Encode<'a, DB>, -{ - fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self { - pagination.generate_pagination(self, id_field); - self - } -} diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 1fde4b417..55afe41b4 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,31 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -use sqlx::{PgPool, Postgres, Transaction}; - use crate::{ compat::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, - CompatSsoLoginRepository, PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, - PgCompatSessionRepository, PgCompatSsoLoginRepository, + CompatSsoLoginRepository, }, oauth2::{ OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, - OAuth2RefreshTokenRepository, OAuth2SessionRepository, PgOAuth2AccessTokenRepository, - PgOAuth2AuthorizationGrantRepository, PgOAuth2ClientRepository, - PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository, + OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, upstream_oauth2::{ - PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, - PgUpstreamOAuthSessionRepository, UpstreamOAuthLinkRepository, - UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, + UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, + UpstreamOAuthSessionRepository, }, - user::{ - BrowserSessionRepository, PgBrowserSessionRepository, PgUserEmailRepository, - PgUserPasswordRepository, PgUserRepository, UserEmailRepository, UserPasswordRepository, - UserRepository, - }, - DatabaseError, + user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, }; pub trait Repository: Send { @@ -126,109 +115,3 @@ pub trait Repository: Send { fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_>; fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_>; } - -pub struct PgRepository { - txn: Transaction<'static, Postgres>, -} - -impl PgRepository { - pub async fn from_pool(pool: &PgPool) -> Result { - let txn = pool.begin().await?; - Ok(PgRepository { txn }) - } - - pub async fn save(self) -> Result<(), DatabaseError> { - self.txn.commit().await?; - Ok(()) - } - - pub async fn cancel(self) -> Result<(), DatabaseError> { - self.txn.rollback().await?; - Ok(()) - } -} - -impl Repository for PgRepository { - type Error = DatabaseError; - - type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; - type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; - type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; - type UserRepository<'c> = PgUserRepository<'c> where Self: 'c; - type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c; - type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; - type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; - type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; - type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c; - type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; - type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c; - type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c; - type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c; - type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c; - type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c; - type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c; - - fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { - PgUpstreamOAuthLinkRepository::new(&mut self.txn) - } - - fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { - PgUpstreamOAuthProviderRepository::new(&mut self.txn) - } - - fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> { - PgUpstreamOAuthSessionRepository::new(&mut self.txn) - } - - fn user(&mut self) -> Self::UserRepository<'_> { - PgUserRepository::new(&mut self.txn) - } - - fn user_email(&mut self) -> Self::UserEmailRepository<'_> { - PgUserEmailRepository::new(&mut self.txn) - } - - fn user_password(&mut self) -> Self::UserPasswordRepository<'_> { - PgUserPasswordRepository::new(&mut self.txn) - } - - fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> { - PgBrowserSessionRepository::new(&mut self.txn) - } - - fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { - PgOAuth2ClientRepository::new(&mut self.txn) - } - - fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> { - PgOAuth2AuthorizationGrantRepository::new(&mut self.txn) - } - - fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { - PgOAuth2SessionRepository::new(&mut self.txn) - } - - fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> { - PgOAuth2AccessTokenRepository::new(&mut self.txn) - } - - fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> { - PgOAuth2RefreshTokenRepository::new(&mut self.txn) - } - - fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { - PgCompatSessionRepository::new(&mut self.txn) - } - - fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> { - PgCompatSsoLoginRepository::new(&mut self.txn) - } - - fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> { - PgCompatAccessTokenRepository::new(&mut self.txn) - } - - fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> { - PgCompatRefreshTokenRepository::new(&mut self.txn) - } -} diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index 76364afe6..bc20c6eaf 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,18 +13,11 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User}; use rand::RngCore; -use sqlx::{PgConnection, QueryBuilder}; use ulid::Ulid; -use uuid::Uuid; -use crate::{ - pagination::{Page, QueryBuilderExt}, - tracing::ExecuteExt, - Clock, DatabaseError, LookupResultExt, Pagination, -}; +use crate::{pagination::Page, Clock, Pagination}; #[async_trait] pub trait UpstreamOAuthLinkRepository: Send + Sync { @@ -63,241 +56,3 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync { pagination: Pagination, ) -> Result, Self::Error>; } - -pub struct PgUpstreamOAuthLinkRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgUpstreamOAuthLinkRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -#[derive(sqlx::FromRow)] -struct LinkLookup { - upstream_oauth_link_id: Uuid, - upstream_oauth_provider_id: Uuid, - user_id: Option, - subject: String, - created_at: DateTime, -} - -impl From for UpstreamOAuthLink { - fn from(value: LinkLookup) -> Self { - UpstreamOAuthLink { - id: Ulid::from(value.upstream_oauth_link_id), - provider_id: Ulid::from(value.upstream_oauth_provider_id), - user_id: value.user_id.map(Ulid::from), - subject: value.subject, - created_at: value.created_at, - } - } -} - -#[async_trait] -impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.upstream_oauth_link.lookup", - skip_all, - fields( - db.statement, - upstream_oauth_link.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - LinkLookup, - r#" - SELECT - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - FROM upstream_oauth_links - WHERE upstream_oauth_link_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()? - .map(Into::into); - - Ok(res) - } - - #[tracing::instrument( - name = "db.upstream_oauth_link.find_by_subject", - skip_all, - fields( - db.statement, - upstream_oauth_link.subject = subject, - %upstream_oauth_provider.id, - %upstream_oauth_provider.issuer, - %upstream_oauth_provider.client_id, - ), - err, - )] - async fn find_by_subject( - &mut self, - upstream_oauth_provider: &UpstreamOAuthProvider, - subject: &str, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - LinkLookup, - r#" - SELECT - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - FROM upstream_oauth_links - WHERE upstream_oauth_provider_id = $1 - AND subject = $2 - "#, - Uuid::from(upstream_oauth_provider.id), - subject, - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()? - .map(Into::into); - - Ok(res) - } - - #[tracing::instrument( - name = "db.upstream_oauth_link.add", - skip_all, - fields( - db.statement, - upstream_oauth_link.id, - upstream_oauth_link.subject = subject, - %upstream_oauth_provider.id, - %upstream_oauth_provider.issuer, - %upstream_oauth_provider.client_id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - upstream_oauth_provider: &UpstreamOAuthProvider, - subject: String, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO upstream_oauth_links ( - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - ) VALUES ($1, $2, NULL, $3, $4) - "#, - Uuid::from(id), - Uuid::from(upstream_oauth_provider.id), - &subject, - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(UpstreamOAuthLink { - id, - provider_id: upstream_oauth_provider.id, - user_id: None, - subject, - created_at, - }) - } - - #[tracing::instrument( - name = "db.upstream_oauth_link.associate_to_user", - skip_all, - fields( - db.statement, - %upstream_oauth_link.id, - %upstream_oauth_link.subject, - %user.id, - %user.username, - ), - err, - )] - async fn associate_to_user( - &mut self, - upstream_oauth_link: &UpstreamOAuthLink, - user: &User, - ) -> Result<(), Self::Error> { - sqlx::query!( - r#" - UPDATE upstream_oauth_links - SET user_id = $1 - WHERE upstream_oauth_link_id = $2 - "#, - Uuid::from(user.id), - Uuid::from(upstream_oauth_link.id), - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(()) - } - - #[tracing::instrument( - name = "db.upstream_oauth_link.list_paginated", - skip_all, - fields( - db.statement, - %user.id, - %user.username, - ), - err - )] - async fn list_paginated( - &mut self, - user: &User, - pagination: Pagination, - ) -> Result, Self::Error> { - let mut query = QueryBuilder::new( - r#" - SELECT - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - FROM upstream_oauth_links - "#, - ); - - query - .push(" WHERE user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("upstream_oauth_link_id", pagination); - - let edges: Vec = query - .build_query_as() - .traced() - .fetch_all(&mut *self.conn) - .await?; - - let page = pagination.process(edges).map(UpstreamOAuthLink::from); - Ok(page) - } -} diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index d1b6809f8..1648a6448 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,249 +17,6 @@ mod provider; mod session; pub use self::{ - link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository}, - provider::{PgUpstreamOAuthProviderRepository, UpstreamOAuthProviderRepository}, - session::{PgUpstreamOAuthSessionRepository, UpstreamOAuthSessionRepository}, + link::UpstreamOAuthLinkRepository, provider::UpstreamOAuthProviderRepository, + session::UpstreamOAuthSessionRepository, }; - -#[cfg(test)] -mod tests { - use chrono::Duration; - use oauth2_types::scope::{Scope, OPENID}; - use rand::SeedableRng; - use sqlx::PgPool; - - use super::*; - use crate::{user::UserRepository, Clock, Pagination, PgRepository, Repository}; - - #[sqlx::test(migrator = "crate::MIGRATOR")] - async fn test_repository(pool: PgPool) { - let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); - - // The provider list should be empty at the start - let all_providers = repo.upstream_oauth_provider().all().await.unwrap(); - assert!(all_providers.is_empty()); - - // Let's add a provider - let provider = repo - .upstream_oauth_provider() - .add( - &mut rng, - &clock, - "https://example.com/".to_owned(), - Scope::from_iter([OPENID]), - mas_iana::oauth::OAuthClientAuthenticationMethod::None, - None, - "client-id".to_owned(), - None, - ) - .await - .unwrap(); - - // Look it up in the database - let provider = repo - .upstream_oauth_provider() - .lookup(provider.id) - .await - .unwrap() - .expect("provider to be found in the database"); - assert_eq!(provider.issuer, "https://example.com/"); - assert_eq!(provider.client_id, "client-id"); - - // Start a session - let session = repo - .upstream_oauth_session() - .add( - &mut rng, - &clock, - &provider, - "some-state".to_owned(), - None, - "some-nonce".to_owned(), - ) - .await - .unwrap(); - - // Look it up in the database - let session = repo - .upstream_oauth_session() - .lookup(session.id) - .await - .unwrap() - .expect("session to be found in the database"); - assert_eq!(session.provider_id, provider.id); - assert_eq!(session.link_id(), None); - assert!(session.is_pending()); - assert!(!session.is_completed()); - assert!(!session.is_consumed()); - - // Create a link - let link = repo - .upstream_oauth_link() - .add(&mut rng, &clock, &provider, "a-subject".to_owned()) - .await - .unwrap(); - - // We can look it up by its ID - repo.upstream_oauth_link() - .lookup(link.id) - .await - .unwrap() - .expect("link to be found in database"); - - // or by its subject - let link = repo - .upstream_oauth_link() - .find_by_subject(&provider, "a-subject") - .await - .unwrap() - .expect("link to be found in database"); - assert_eq!(link.subject, "a-subject"); - assert_eq!(link.provider_id, provider.id); - - let session = repo - .upstream_oauth_session() - .complete_with_link(&clock, session, &link, None) - .await - .unwrap(); - // Reload the session - let session = repo - .upstream_oauth_session() - .lookup(session.id) - .await - .unwrap() - .expect("session to be found in the database"); - assert!(session.is_completed()); - assert!(!session.is_consumed()); - assert_eq!(session.link_id(), Some(link.id)); - - let session = repo - .upstream_oauth_session() - .consume(&clock, session) - .await - .unwrap(); - // Reload the session - let session = repo - .upstream_oauth_session() - .lookup(session.id) - .await - .unwrap() - .expect("session to be found in the database"); - assert!(session.is_consumed()); - - let user = repo - .user() - .add(&mut rng, &clock, "john".to_owned()) - .await - .unwrap(); - repo.upstream_oauth_link() - .associate_to_user(&link, &user) - .await - .unwrap(); - - let links = repo - .upstream_oauth_link() - .list_paginated(&user, Pagination::first(10)) - .await - .unwrap(); - assert!(!links.has_previous_page); - assert!(!links.has_next_page); - assert_eq!(links.edges.len(), 1); - assert_eq!(links.edges[0].id, link.id); - assert_eq!(links.edges[0].user_id, Some(user.id)); - } - - #[sqlx::test(migrator = "crate::MIGRATOR")] - async fn test_provider_repository_pagination(pool: PgPool) { - const ISSUER: &str = "https://example.com/"; - let scope = Scope::from_iter([OPENID]); - - let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); - - let mut ids = Vec::with_capacity(20); - // Create 20 providers - for idx in 0..20 { - let client_id = format!("client-{idx}"); - let provider = repo - .upstream_oauth_provider() - .add( - &mut rng, - &clock, - ISSUER.to_owned(), - scope.clone(), - mas_iana::oauth::OAuthClientAuthenticationMethod::None, - None, - client_id, - None, - ) - .await - .unwrap(); - ids.push(provider.id); - clock.advance(Duration::seconds(10)); - } - - // Lookup the first 10 items - let page = repo - .upstream_oauth_provider() - .list_paginated(Pagination::first(10)) - .await - .unwrap(); - - // It returned the first 10 items - assert!(page.has_next_page); - let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); - assert_eq!(&edge_ids, &ids[..10]); - - // Lookup the next 10 items - let page = repo - .upstream_oauth_provider() - .list_paginated(Pagination::first(10).after(ids[9])) - .await - .unwrap(); - - // It returned the next 10 items - assert!(!page.has_next_page); - let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); - assert_eq!(&edge_ids, &ids[10..]); - - // Lookup the last 10 items - let page = repo - .upstream_oauth_provider() - .list_paginated(Pagination::last(10)) - .await - .unwrap(); - - // It returned the last 10 items - assert!(page.has_previous_page); - let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); - assert_eq!(&edge_ids, &ids[10..]); - - // Lookup the previous 10 items - let page = repo - .upstream_oauth_provider() - .list_paginated(Pagination::last(10).before(ids[10])) - .await - .unwrap(); - - // It returned the previous 10 items - assert!(!page.has_previous_page); - let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); - assert_eq!(&edge_ids, &ids[..10]); - - // Lookup 10 items between two IDs - let page = repo - .upstream_oauth_provider() - .list_paginated(Pagination::first(10).after(ids[5]).before(ids[8])) - .await - .unwrap(); - - // It returned the items in between - assert!(!page.has_next_page); - let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); - assert_eq!(&edge_ids, &ids[6..8]); - } -} diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index 14bd65471..4be8f1271 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,20 +13,13 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; use mas_data_model::UpstreamOAuthProvider; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use oauth2_types::scope::Scope; use rand::RngCore; -use sqlx::{PgConnection, QueryBuilder}; use ulid::Ulid; -use uuid::Uuid; -use crate::{ - pagination::{Page, QueryBuilderExt}, - tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination, -}; +use crate::{pagination::Page, Clock, Pagination}; #[async_trait] pub trait UpstreamOAuthProviderRepository: Send + Sync { @@ -58,247 +51,3 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync { /// Get all upstream OAuth providers async fn all(&mut self) -> Result, Self::Error>; } - -pub struct PgUpstreamOAuthProviderRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgUpstreamOAuthProviderRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -#[derive(sqlx::FromRow)] -struct ProviderLookup { - upstream_oauth_provider_id: Uuid, - issuer: String, - scope: String, - client_id: String, - encrypted_client_secret: Option, - token_endpoint_signing_alg: Option, - token_endpoint_auth_method: String, - created_at: DateTime, -} - -impl TryFrom for UpstreamOAuthProvider { - type Error = DatabaseInconsistencyError; - fn try_from(value: ProviderLookup) -> Result { - let id = value.upstream_oauth_provider_id.into(); - let scope = value.scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("scope") - .row(id) - .source(e) - })?; - let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("token_endpoint_auth_method") - .row(id) - .source(e) - })?; - let token_endpoint_signing_alg = value - .token_endpoint_signing_alg - .map(|x| x.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("token_endpoint_signing_alg") - .row(id) - .source(e) - })?; - - Ok(UpstreamOAuthProvider { - id, - issuer: value.issuer, - scope, - client_id: value.client_id, - encrypted_client_secret: value.encrypted_client_secret, - token_endpoint_auth_method, - token_endpoint_signing_alg, - created_at: value.created_at, - }) - } -} - -#[async_trait] -impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.upstream_oauth_provider.lookup", - skip_all, - fields( - db.statement, - upstream_oauth_provider.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - ProviderLookup, - r#" - SELECT - upstream_oauth_provider_id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at - FROM upstream_oauth_providers - WHERE upstream_oauth_provider_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let res = res - .map(UpstreamOAuthProvider::try_from) - .transpose() - .map_err(DatabaseError::from)?; - - Ok(res) - } - - #[tracing::instrument( - name = "db.upstream_oauth_provider.add", - skip_all, - fields( - db.statement, - upstream_oauth_provider.id, - upstream_oauth_provider.issuer = %issuer, - upstream_oauth_provider.client_id = %client_id, - ), - err, - )] - #[allow(clippy::too_many_arguments)] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - issuer: String, - scope: Scope, - token_endpoint_auth_method: OAuthClientAuthenticationMethod, - token_endpoint_signing_alg: Option, - client_id: String, - encrypted_client_secret: Option, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO upstream_oauth_providers ( - upstream_oauth_provider_id, - issuer, - scope, - token_endpoint_auth_method, - token_endpoint_signing_alg, - client_id, - encrypted_client_secret, - created_at - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - "#, - Uuid::from(id), - &issuer, - scope.to_string(), - token_endpoint_auth_method.to_string(), - token_endpoint_signing_alg.as_ref().map(ToString::to_string), - &client_id, - encrypted_client_secret.as_deref(), - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(UpstreamOAuthProvider { - id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at, - }) - } - - #[tracing::instrument( - name = "db.upstream_oauth_provider.list_paginated", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn list_paginated( - &mut self, - pagination: Pagination, - ) -> Result, Self::Error> { - let mut query = QueryBuilder::new( - r#" - SELECT - upstream_oauth_provider_id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at - FROM upstream_oauth_providers - WHERE 1 = 1 - "#, - ); - - query.generate_pagination("upstream_oauth_provider_id", pagination); - - let edges: Vec = query - .build_query_as() - .traced() - .fetch_all(&mut *self.conn) - .await?; - - let page = pagination.process(edges).try_map(TryInto::try_into)?; - Ok(page) - } - - #[tracing::instrument( - name = "db.upstream_oauth_provider.all", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn all(&mut self) -> Result, Self::Error> { - let res = sqlx::query_as!( - ProviderLookup, - r#" - SELECT - upstream_oauth_provider_id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at - FROM upstream_oauth_providers - "#, - ) - .traced() - .fetch_all(&mut *self.conn) - .await?; - - let res: Result, _> = res.into_iter().map(TryInto::try_into).collect(); - Ok(res?) - } -} diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index d5da6ef8b..4d41a8eca 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,19 +13,11 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{ - UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, - UpstreamOAuthProvider, -}; +use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider}; use rand::RngCore; -use sqlx::PgConnection; use ulid::Ulid; -use uuid::Uuid; -use crate::{ - tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, -}; +use crate::Clock; #[async_trait] pub trait UpstreamOAuthSessionRepository: Send + Sync { @@ -64,262 +56,3 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync { upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, ) -> Result; } - -pub struct PgUpstreamOAuthSessionRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgUpstreamOAuthSessionRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -struct SessionLookup { - upstream_oauth_authorization_session_id: Uuid, - upstream_oauth_provider_id: Uuid, - upstream_oauth_link_id: Option, - state: String, - code_challenge_verifier: Option, - nonce: String, - id_token: Option, - created_at: DateTime, - completed_at: Option>, - consumed_at: Option>, -} - -impl TryFrom for UpstreamOAuthAuthorizationSession { - type Error = DatabaseInconsistencyError; - - fn try_from(value: SessionLookup) -> Result { - let id = value.upstream_oauth_authorization_session_id.into(); - let state = match ( - value.upstream_oauth_link_id, - value.id_token, - value.completed_at, - value.consumed_at, - ) { - (None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending, - (Some(link_id), id_token, Some(completed_at), None) => { - UpstreamOAuthAuthorizationSessionState::Completed { - completed_at, - link_id: link_id.into(), - id_token, - } - } - (Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => { - UpstreamOAuthAuthorizationSessionState::Consumed { - completed_at, - link_id: link_id.into(), - id_token, - consumed_at, - } - } - _ => { - return Err( - DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id), - ) - } - }; - - Ok(Self { - id, - provider_id: value.upstream_oauth_provider_id.into(), - state_str: value.state, - nonce: value.nonce, - code_challenge_verifier: value.code_challenge_verifier, - created_at: value.created_at, - state, - }) - } -} - -#[async_trait] -impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.upstream_oauth_authorization_session.lookup", - skip_all, - fields( - db.statement, - upstream_oauth_provider.id = %id, - ), - err, - )] - async fn lookup( - &mut self, - id: Ulid, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - SessionLookup, - r#" - SELECT - upstream_oauth_authorization_session_id, - upstream_oauth_provider_id, - upstream_oauth_link_id, - state, - code_challenge_verifier, - nonce, - id_token, - created_at, - completed_at, - consumed_at - FROM upstream_oauth_authorization_sessions - WHERE upstream_oauth_authorization_session_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) - } - - #[tracing::instrument( - name = "db.upstream_oauth_authorization_session.add", - skip_all, - fields( - db.statement, - %upstream_oauth_provider.id, - %upstream_oauth_provider.issuer, - %upstream_oauth_provider.client_id, - upstream_oauth_authorization_session.id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - upstream_oauth_provider: &UpstreamOAuthProvider, - state_str: String, - code_challenge_verifier: Option, - nonce: String, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record( - "upstream_oauth_authorization_session.id", - tracing::field::display(id), - ); - - sqlx::query!( - r#" - INSERT INTO upstream_oauth_authorization_sessions ( - upstream_oauth_authorization_session_id, - upstream_oauth_provider_id, - state, - code_challenge_verifier, - nonce, - created_at, - completed_at, - consumed_at, - id_token - ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL) - "#, - Uuid::from(id), - Uuid::from(upstream_oauth_provider.id), - &state_str, - code_challenge_verifier.as_deref(), - nonce, - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(UpstreamOAuthAuthorizationSession { - id, - state: UpstreamOAuthAuthorizationSessionState::default(), - provider_id: upstream_oauth_provider.id, - state_str, - code_challenge_verifier, - nonce, - created_at, - }) - } - - #[tracing::instrument( - name = "db.upstream_oauth_authorization_session.complete_with_link", - skip_all, - fields( - db.statement, - %upstream_oauth_authorization_session.id, - %upstream_oauth_link.id, - ), - err, - )] - async fn complete_with_link( - &mut self, - clock: &Clock, - upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, - upstream_oauth_link: &UpstreamOAuthLink, - id_token: Option, - ) -> Result { - let completed_at = clock.now(); - - sqlx::query!( - r#" - UPDATE upstream_oauth_authorization_sessions - SET upstream_oauth_link_id = $1, - completed_at = $2, - id_token = $3 - WHERE upstream_oauth_authorization_session_id = $4 - "#, - Uuid::from(upstream_oauth_link.id), - completed_at, - id_token, - Uuid::from(upstream_oauth_authorization_session.id), - ) - .traced() - .execute(&mut *self.conn) - .await?; - - let upstream_oauth_authorization_session = upstream_oauth_authorization_session - .complete(completed_at, upstream_oauth_link, id_token) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(upstream_oauth_authorization_session) - } - - /// Mark a session as consumed - #[tracing::instrument( - name = "db.upstream_oauth_authorization_session.consume", - skip_all, - fields( - db.statement, - %upstream_oauth_authorization_session.id, - ), - err, - )] - async fn consume( - &mut self, - clock: &Clock, - upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, - ) -> Result { - let consumed_at = clock.now(); - sqlx::query!( - r#" - UPDATE upstream_oauth_authorization_sessions - SET consumed_at = $1 - WHERE upstream_oauth_authorization_session_id = $2 - "#, - consumed_at, - Uuid::from(upstream_oauth_authorization_session.id), - ) - .traced() - .execute(&mut *self.conn) - .await?; - - let upstream_oauth_authorization_session = upstream_oauth_authorization_session - .consume(consumed_at) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(upstream_oauth_authorization_session) - } -} diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs index 8c8efe1b5..41a7d2935 100644 --- a/crates/storage/src/user/email.rs +++ b/crates/storage/src/user/email.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,19 +13,11 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{User, UserEmail, UserEmailVerification, UserEmailVerificationState}; +use mas_data_model::{User, UserEmail, UserEmailVerification}; use rand::RngCore; -use sqlx::{PgConnection, QueryBuilder}; -use tracing::{info_span, Instrument}; use ulid::Ulid; -use uuid::Uuid; -use crate::{ - pagination::{Page, QueryBuilderExt}, - tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination, -}; +use crate::{pagination::Page, Clock, Pagination}; #[async_trait] pub trait UserEmailRepository: Send + Sync { @@ -82,529 +74,3 @@ pub trait UserEmailRepository: Send + Sync { verification: UserEmailVerification, ) -> Result; } - -pub struct PgUserEmailRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgUserEmailRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -#[derive(Debug, Clone, sqlx::FromRow)] -struct UserEmailLookup { - user_email_id: Uuid, - user_id: Uuid, - email: String, - created_at: DateTime, - confirmed_at: Option>, -} - -impl From for UserEmail { - fn from(e: UserEmailLookup) -> UserEmail { - UserEmail { - id: e.user_email_id.into(), - user_id: e.user_id.into(), - email: e.email, - created_at: e.created_at, - confirmed_at: e.confirmed_at, - } - } -} - -struct UserEmailConfirmationCodeLookup { - user_email_confirmation_code_id: Uuid, - user_email_id: Uuid, - code: String, - created_at: DateTime, - expires_at: DateTime, - consumed_at: Option>, -} - -impl UserEmailConfirmationCodeLookup { - fn into_verification(self, clock: &Clock) -> UserEmailVerification { - let now = clock.now(); - let state = if let Some(when) = self.consumed_at { - UserEmailVerificationState::AlreadyUsed { when } - } else if self.expires_at < now { - UserEmailVerificationState::Expired { - when: self.expires_at, - } - } else { - UserEmailVerificationState::Valid - }; - - UserEmailVerification { - id: self.user_email_confirmation_code_id.into(), - user_email_id: self.user_email_id.into(), - code: self.code, - state, - created_at: self.created_at, - } - } -} - -#[async_trait] -impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.user_email.lookup", - skip_all, - fields( - db.statement, - user_email.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - UserEmailLookup, - r#" - SELECT user_email_id - , user_id - , email - , created_at - , confirmed_at - FROM user_emails - - WHERE user_email_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(user_email) = res else { return Ok(None) }; - - Ok(Some(user_email.into())) - } - - #[tracing::instrument( - name = "db.user_email.find", - skip_all, - fields( - db.statement, - %user.id, - user_email.email = email, - ), - err, - )] - async fn find(&mut self, user: &User, email: &str) -> Result, Self::Error> { - let res = sqlx::query_as!( - UserEmailLookup, - r#" - SELECT user_email_id - , user_id - , email - , created_at - , confirmed_at - FROM user_emails - - WHERE user_id = $1 AND email = $2 - "#, - Uuid::from(user.id), - email, - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(user_email) = res else { return Ok(None) }; - - Ok(Some(user_email.into())) - } - - #[tracing::instrument( - name = "db.user_email.get_primary", - skip_all, - fields( - db.statement, - %user.id, - ), - err, - )] - async fn get_primary(&mut self, user: &User) -> Result, Self::Error> { - let Some(id) = user.primary_user_email_id else { return Ok(None) }; - - let user_email = self.lookup(id).await?.ok_or_else(|| { - DatabaseInconsistencyError::on("users") - .column("primary_user_email_id") - .row(user.id) - })?; - - Ok(Some(user_email)) - } - - #[tracing::instrument( - name = "db.user_email.all", - skip_all, - fields( - db.statement, - %user.id, - ), - err, - )] - async fn all(&mut self, user: &User) -> Result, Self::Error> { - let res = sqlx::query_as!( - UserEmailLookup, - r#" - SELECT user_email_id - , user_id - , email - , created_at - , confirmed_at - FROM user_emails - - WHERE user_id = $1 - - ORDER BY email ASC - "#, - Uuid::from(user.id), - ) - .traced() - .fetch_all(&mut *self.conn) - .await?; - - Ok(res.into_iter().map(Into::into).collect()) - } - - #[tracing::instrument( - name = "db.user_email.list_paginated", - skip_all, - fields( - db.statement, - %user.id, - ), - err, - )] - async fn list_paginated( - &mut self, - user: &User, - pagination: Pagination, - ) -> Result, DatabaseError> { - let mut query = QueryBuilder::new( - r#" - SELECT user_email_id - , user_id - , email - , created_at - , confirmed_at - FROM user_emails - "#, - ); - - query - .push(" WHERE user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("ue.user_email_id", pagination); - - let edges: Vec = query - .build_query_as() - .traced() - .fetch_all(&mut *self.conn) - .await?; - - let page = pagination.process(edges).map(UserEmail::from); - Ok(page) - } - - #[tracing::instrument( - name = "db.user_email.count", - skip_all, - fields( - db.statement, - %user.id, - ), - err, - )] - async fn count(&mut self, user: &User) -> Result { - let res = sqlx::query_scalar!( - r#" - SELECT COUNT(*) - FROM user_emails - WHERE user_id = $1 - "#, - Uuid::from(user.id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await?; - - let res = res.unwrap_or_default(); - - Ok(res - .try_into() - .map_err(DatabaseError::to_invalid_operation)?) - } - - #[tracing::instrument( - name = "db.user_email.add", - skip_all, - fields( - db.statement, - %user.id, - user_email.id, - user_email.email = email, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - user: &User, - email: String, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("user_email.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO user_emails (user_email_id, user_id, email, created_at) - VALUES ($1, $2, $3, $4) - "#, - Uuid::from(id), - Uuid::from(user.id), - &email, - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(UserEmail { - id, - user_id: user.id, - email, - created_at, - confirmed_at: None, - }) - } - - #[tracing::instrument( - name = "db.user_email.remove", - skip_all, - fields( - db.statement, - user.id = %user_email.user_id, - %user_email.id, - %user_email.email, - ), - err, - )] - async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> { - let span = info_span!( - "db.user_email.remove.codes", - db.statement = tracing::field::Empty - ); - sqlx::query!( - r#" - DELETE FROM user_email_confirmation_codes - WHERE user_email_id = $1 - "#, - Uuid::from(user_email.id), - ) - .record(&span) - .execute(&mut *self.conn) - .instrument(span) - .await?; - - let res = sqlx::query!( - r#" - DELETE FROM user_emails - WHERE user_email_id = $1 - "#, - Uuid::from(user_email.id), - ) - .traced() - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - Ok(()) - } - - async fn mark_as_verified( - &mut self, - clock: &Clock, - mut user_email: UserEmail, - ) -> Result { - let confirmed_at = clock.now(); - sqlx::query!( - r#" - UPDATE user_emails - SET confirmed_at = $2 - WHERE user_email_id = $1 - "#, - Uuid::from(user_email.id), - confirmed_at, - ) - .execute(&mut *self.conn) - .await?; - - user_email.confirmed_at = Some(confirmed_at); - Ok(user_email) - } - - async fn set_as_primary(&mut self, user_email: &UserEmail) -> Result<(), Self::Error> { - sqlx::query!( - r#" - UPDATE users - SET primary_user_email_id = user_emails.user_email_id - FROM user_emails - WHERE user_emails.user_email_id = $1 - AND users.user_id = user_emails.user_id - "#, - Uuid::from(user_email.id), - ) - .execute(&mut *self.conn) - .await?; - - Ok(()) - } - - #[tracing::instrument( - name = "db.user_email.add_verification_code", - skip_all, - fields( - db.statement, - %user_email.id, - %user_email.email, - user_email_verification.id, - user_email_verification.code = code, - ), - err, - )] - async fn add_verification_code( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - user_email: &UserEmail, - max_age: chrono::Duration, - code: String, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id)); - let expires_at = created_at + max_age; - - sqlx::query!( - r#" - INSERT INTO user_email_confirmation_codes - (user_email_confirmation_code_id, user_email_id, code, created_at, expires_at) - VALUES ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(user_email.id), - code, - created_at, - expires_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - let verification = UserEmailVerification { - id, - user_email_id: user_email.id, - code, - created_at, - state: UserEmailVerificationState::Valid, - }; - - Ok(verification) - } - - #[tracing::instrument( - name = "db.user_email.find_verification_code", - skip_all, - fields( - db.statement, - %user_email.id, - user.id = %user_email.user_id, - ), - err, - )] - async fn find_verification_code( - &mut self, - clock: &Clock, - user_email: &UserEmail, - code: &str, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - UserEmailConfirmationCodeLookup, - r#" - SELECT user_email_confirmation_code_id - , user_email_id - , code - , created_at - , expires_at - , consumed_at - FROM user_email_confirmation_codes - WHERE code = $1 - AND user_email_id = $2 - "#, - code, - Uuid::from(user_email.id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into_verification(clock))) - } - - #[tracing::instrument( - name = "db.user_email.consume_verification_code", - skip_all, - fields( - db.statement, - %user_email_verification.id, - user_email.id = %user_email_verification.user_email_id, - ), - err, - )] - async fn consume_verification_code( - &mut self, - clock: &Clock, - mut user_email_verification: UserEmailVerification, - ) -> Result { - if !matches!( - user_email_verification.state, - UserEmailVerificationState::Valid - ) { - return Err(DatabaseError::invalid_operation()); - } - - let consumed_at = clock.now(); - - sqlx::query!( - r#" - UPDATE user_email_confirmation_codes - SET consumed_at = $2 - WHERE user_email_confirmation_code_id = $1 - "#, - Uuid::from(user_email_verification.id), - consumed_at - ) - .traced() - .execute(&mut *self.conn) - .await?; - - user_email_verification.state = - UserEmailVerificationState::AlreadyUsed { when: consumed_at }; - - Ok(user_email_verification) - } -} diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 9dd3d2ca9..23c2f6d1e 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -13,26 +13,18 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; use mas_data_model::User; use rand::RngCore; -use sqlx::PgConnection; use ulid::Ulid; -use uuid::Uuid; -use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; +use crate::Clock; mod email; mod password; mod session; -#[cfg(test)] -mod tests; - pub use self::{ - email::{PgUserEmailRepository, UserEmailRepository}, - password::{PgUserPasswordRepository, UserPasswordRepository}, - session::{BrowserSessionRepository, PgBrowserSessionRepository}, + email::UserEmailRepository, password::UserPasswordRepository, session::BrowserSessionRepository, }; #[async_trait] @@ -49,170 +41,3 @@ pub trait UserRepository: Send + Sync { ) -> Result; async fn exists(&mut self, username: &str) -> Result; } - -pub struct PgUserRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgUserRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -#[derive(Debug, Clone)] -struct UserLookup { - user_id: Uuid, - username: String, - primary_user_email_id: Option, - - #[allow(dead_code)] - created_at: DateTime, -} - -impl From for User { - fn from(value: UserLookup) -> Self { - let id = value.user_id.into(); - Self { - id, - username: value.username, - sub: id.to_string(), - primary_user_email_id: value.primary_user_email_id.map(Into::into), - } - } -} - -#[async_trait] -impl<'c> UserRepository for PgUserRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.user.lookup", - skip_all, - fields( - db.statement, - user.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - UserLookup, - r#" - SELECT user_id - , username - , primary_user_email_id - , created_at - FROM users - WHERE user_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.user.find_by_username", - skip_all, - fields( - db.statement, - user.username = username, - ), - err, - )] - async fn find_by_username(&mut self, username: &str) -> Result, Self::Error> { - let res = sqlx::query_as!( - UserLookup, - r#" - SELECT user_id - , username - , primary_user_email_id - , created_at - FROM users - WHERE username = $1 - "#, - username, - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.user.add", - skip_all, - fields( - db.statement, - user.username = username, - user.id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - username: String, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("user.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO users (user_id, username, created_at) - VALUES ($1, $2, $3) - "#, - Uuid::from(id), - username, - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(User { - id, - username, - sub: id.to_string(), - primary_user_email_id: None, - }) - } - - #[tracing::instrument( - name = "db.user.exists", - skip_all, - fields( - db.statement, - user.username = username, - ), - err, - )] - async fn exists(&mut self, username: &str) -> Result { - let exists = sqlx::query_scalar!( - r#" - SELECT EXISTS( - SELECT 1 FROM users WHERE username = $1 - ) AS "exists!" - "#, - username - ) - .traced() - .fetch_one(&mut *self.conn) - .await?; - - Ok(exists) - } -} diff --git a/crates/storage/src/user/password.rs b/crates/storage/src/user/password.rs index 56c8a439c..2d2d25344 100644 --- a/crates/storage/src/user/password.rs +++ b/crates/storage/src/user/password.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,16 +13,10 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; use mas_data_model::{Password, User}; use rand::RngCore; -use sqlx::PgConnection; -use ulid::Ulid; -use uuid::Uuid; -use crate::{ - tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, -}; +use crate::Clock; #[async_trait] pub trait UserPasswordRepository: Send + Sync { @@ -39,134 +33,3 @@ pub trait UserPasswordRepository: Send + Sync { upgraded_from: Option<&Password>, ) -> Result; } - -pub struct PgUserPasswordRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgUserPasswordRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -struct UserPasswordLookup { - user_password_id: Uuid, - hashed_password: String, - version: i32, - upgraded_from_id: Option, - created_at: DateTime, -} - -#[async_trait] -impl<'c> UserPasswordRepository for PgUserPasswordRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.user_password.active", - skip_all, - fields( - db.statement, - %user.id, - %user.username, - ), - err, - )] - async fn active(&mut self, user: &User) -> Result, Self::Error> { - let res = sqlx::query_as!( - UserPasswordLookup, - r#" - SELECT up.user_password_id - , up.hashed_password - , up.version - , up.upgraded_from_id - , up.created_at - FROM user_passwords up - WHERE up.user_id = $1 - ORDER BY up.created_at DESC - LIMIT 1 - "#, - Uuid::from(user.id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - let id = Ulid::from(res.user_password_id); - - let version = res.version.try_into().map_err(|e| { - DatabaseInconsistencyError::on("user_passwords") - .column("version") - .row(id) - .source(e) - })?; - - let upgraded_from_id = res.upgraded_from_id.map(Ulid::from); - let created_at = res.created_at; - let hashed_password = res.hashed_password; - - Ok(Some(Password { - id, - hashed_password, - version, - upgraded_from_id, - created_at, - })) - } - - #[tracing::instrument( - name = "db.user_password.add", - skip_all, - fields( - db.statement, - %user.id, - %user.username, - user_password.id, - user_password.version = version, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - user: &User, - version: u16, - hashed_password: String, - upgraded_from: Option<&Password>, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("user_password.id", tracing::field::display(id)); - - let upgraded_from_id = upgraded_from.map(|p| p.id); - - sqlx::query!( - r#" - INSERT INTO user_passwords - (user_password_id, user_id, hashed_password, version, upgraded_from_id, created_at) - VALUES ($1, $2, $3, $4, $5, $6) - "#, - Uuid::from(id), - Uuid::from(user.id), - hashed_password, - i32::from(version), - upgraded_from_id.map(Uuid::from), - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(Password { - id, - hashed_password, - version, - upgraded_from_id, - created_at, - }) - } -} diff --git a/crates/storage/src/user/session.rs b/crates/storage/src/user/session.rs index 10b96da77..2e55f40c1 100644 --- a/crates/storage/src/user/session.rs +++ b/crates/storage/src/user/session.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,18 +13,11 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User}; +use mas_data_model::{BrowserSession, Password, UpstreamOAuthLink, User}; use rand::RngCore; -use sqlx::{PgConnection, QueryBuilder}; use ulid::Ulid; -use uuid::Uuid; -use crate::{ - pagination::{Page, QueryBuilderExt}, - tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination, -}; +use crate::{pagination::Page, Clock, Pagination}; #[async_trait] pub trait BrowserSessionRepository: Send + Sync { @@ -65,351 +58,3 @@ pub trait BrowserSessionRepository: Send + Sync { upstream_oauth_link: &UpstreamOAuthLink, ) -> Result; } - -pub struct PgBrowserSessionRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgBrowserSessionRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -#[derive(sqlx::FromRow)] -struct SessionLookup { - user_session_id: Uuid, - user_session_created_at: DateTime, - user_session_finished_at: Option>, - user_id: Uuid, - user_username: String, - user_primary_user_email_id: Option, - last_authentication_id: Option, - last_authd_at: Option>, -} - -impl TryFrom for BrowserSession { - type Error = DatabaseInconsistencyError; - - fn try_from(value: SessionLookup) -> Result { - let id = Ulid::from(value.user_id); - let user = User { - id, - username: value.user_username, - sub: id.to_string(), - primary_user_email_id: value.user_primary_user_email_id.map(Into::into), - }; - - let last_authentication = match (value.last_authentication_id, value.last_authd_at) { - (Some(id), Some(created_at)) => Some(Authentication { - id: id.into(), - created_at, - }), - (None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on( - "user_session_authentications", - )) - } - }; - - Ok(BrowserSession { - id: value.user_session_id.into(), - user, - created_at: value.user_session_created_at, - finished_at: value.user_session_finished_at, - last_authentication, - }) - } -} - -#[async_trait] -impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.browser_session.lookup", - skip_all, - fields( - db.statement, - user_session.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - SessionLookup, - r#" - SELECT s.user_session_id - , s.created_at AS "user_session_created_at" - , s.finished_at AS "user_session_finished_at" - , u.user_id - , u.username AS "user_username" - , u.primary_user_email_id AS "user_primary_user_email_id" - , a.user_session_authentication_id AS "last_authentication_id?" - , a.created_at AS "last_authd_at?" - FROM user_sessions s - INNER JOIN users u - USING (user_id) - LEFT JOIN user_session_authentications a - USING (user_session_id) - WHERE s.user_session_id = $1 - ORDER BY a.created_at DESC - LIMIT 1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) - } - - #[tracing::instrument( - name = "db.browser_session.add", - skip_all, - fields( - db.statement, - %user.id, - user_session.id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - user: &User, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("user_session.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO user_sessions (user_session_id, user_id, created_at) - VALUES ($1, $2, $3) - "#, - Uuid::from(id), - Uuid::from(user.id), - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - let session = BrowserSession { - id, - // XXX - user: user.clone(), - created_at, - finished_at: None, - last_authentication: None, - }; - - Ok(session) - } - - #[tracing::instrument( - name = "db.browser_session.finish", - skip_all, - fields( - db.statement, - %user_session.id, - ), - err, - )] - async fn finish( - &mut self, - clock: &Clock, - mut user_session: BrowserSession, - ) -> Result { - let finished_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE user_sessions - SET finished_at = $1 - WHERE user_session_id = $2 - "#, - finished_at, - Uuid::from(user_session.id), - ) - .traced() - .execute(&mut *self.conn) - .await?; - - user_session.finished_at = Some(finished_at); - - DatabaseError::ensure_affected_rows(&res, 1)?; - - Ok(user_session) - } - - #[tracing::instrument( - name = "db.browser_session.list_active_paginated", - skip_all, - fields( - db.statement, - %user.id, - ), - err, - )] - async fn list_active_paginated( - &mut self, - user: &User, - pagination: Pagination, - ) -> Result, Self::Error> { - // TODO: ordering of last authentication is wrong - let mut query = QueryBuilder::new( - r#" - SELECT DISTINCT ON (s.user_session_id) - s.user_session_id, - u.user_id, - u.username, - s.created_at, - a.user_session_authentication_id AS "last_authentication_id", - a.created_at AS "last_authd_at", - FROM user_sessions s - INNER JOIN users u - USING (user_id) - LEFT JOIN user_session_authentications a - USING (user_session_id) - "#, - ); - - query - .push(" WHERE s.finished_at IS NULL AND s.user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("s.user_session_id", pagination); - - let edges: Vec = query - .build_query_as() - .traced() - .fetch_all(&mut *self.conn) - .await?; - - let page = pagination - .process(edges) - .try_map(BrowserSession::try_from)?; - Ok(page) - } - - #[tracing::instrument( - name = "db.browser_session.count_active", - skip_all, - fields( - db.statement, - %user.id, - ), - err, - )] - async fn count_active(&mut self, user: &User) -> Result { - let res = sqlx::query_scalar!( - r#" - SELECT COUNT(*) as "count!" - FROM user_sessions s - WHERE s.user_id = $1 AND s.finished_at IS NULL - "#, - Uuid::from(user.id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await?; - - res.try_into().map_err(DatabaseError::to_invalid_operation) - } - - #[tracing::instrument( - name = "db.browser_session.authenticate_with_password", - skip_all, - fields( - db.statement, - %user_session.id, - %user_password.id, - user_session_authentication.id, - ), - err, - )] - async fn authenticate_with_password( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - mut user_session: BrowserSession, - user_password: &Password, - ) -> Result { - let _user_password = user_password; - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record( - "user_session_authentication.id", - tracing::field::display(id), - ); - - sqlx::query!( - r#" - INSERT INTO user_session_authentications - (user_session_authentication_id, user_session_id, created_at) - VALUES ($1, $2, $3) - "#, - Uuid::from(id), - Uuid::from(user_session.id), - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - user_session.last_authentication = Some(Authentication { id, created_at }); - - Ok(user_session) - } - - #[tracing::instrument( - name = "db.browser_session.authenticate_with_upstream", - skip_all, - fields( - db.statement, - %user_session.id, - %upstream_oauth_link.id, - user_session_authentication.id, - ), - err, - )] - async fn authenticate_with_upstream( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - mut user_session: BrowserSession, - upstream_oauth_link: &UpstreamOAuthLink, - ) -> Result { - let _upstream_oauth_link = upstream_oauth_link; - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record( - "user_session_authentication.id", - tracing::field::display(id), - ); - - sqlx::query!( - r#" - INSERT INTO user_session_authentications - (user_session_authentication_id, user_session_id, created_at) - VALUES ($1, $2, $3) - "#, - Uuid::from(id), - Uuid::from(user_session.id), - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - user_session.last_authentication = Some(Authentication { id, created_at }); - - Ok(user_session) - } -} diff --git a/crates/tasks/Cargo.toml b/crates/tasks/Cargo.toml index b82e16c2a..99270a72b 100644 --- a/crates/tasks/Cargo.toml +++ b/crates/tasks/Cargo.toml @@ -14,3 +14,4 @@ tracing = "0.1.37" sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] } mas-storage = { path = "../storage" } +mas-storage-pg = { path = "../storage-pg" } diff --git a/crates/tasks/src/database.rs b/crates/tasks/src/database.rs index 39a33b8dd..660688608 100644 --- a/crates/tasks/src/database.rs +++ b/crates/tasks/src/database.rs @@ -14,7 +14,8 @@ //! Database-related tasks -use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock, PgRepository, Repository}; +use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock, Repository}; +use mas_storage_pg::PgRepository; use sqlx::{Pool, Postgres}; use tracing::{debug, error, info}; From 2c26ddb24974c6e0b844c3d85529106aaf6fdc20 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 18 Jan 2023 12:20:30 +0100 Subject: [PATCH 28/45] storage: make the Clock a trait --- crates/cli/src/commands/manage.rs | 4 +- crates/cli/src/commands/templates.rs | 4 +- crates/handlers/src/compat/login.rs | 4 +- .../handlers/src/compat/login_sso_complete.rs | 2 +- crates/handlers/src/compat/logout.rs | 4 +- crates/handlers/src/compat/refresh.rs | 2 +- crates/handlers/src/lib.rs | 4 +- crates/handlers/src/oauth2/consent.rs | 2 +- crates/handlers/src/oauth2/introspection.rs | 4 +- crates/handlers/src/oauth2/token.rs | 2 +- crates/handlers/src/oauth2/userinfo.rs | 2 +- .../handlers/src/upstream_oauth2/authorize.rs | 2 +- .../handlers/src/upstream_oauth2/callback.rs | 2 +- crates/handlers/src/upstream_oauth2/link.rs | 2 +- .../handlers/src/views/account/emails/add.rs | 2 +- .../handlers/src/views/account/emails/mod.rs | 4 +- .../src/views/account/emails/verify.rs | 4 +- crates/handlers/src/views/account/mod.rs | 2 +- crates/handlers/src/views/account/password.rs | 2 +- crates/handlers/src/views/index.rs | 1 + crates/handlers/src/views/login.rs | 2 +- crates/handlers/src/views/logout.rs | 4 +- crates/handlers/src/views/reauth.rs | 2 +- crates/handlers/src/views/register.rs | 2 +- crates/storage-pg/src/compat/access_token.rs | 4 +- crates/storage-pg/src/compat/mod.rs | 7 +- crates/storage-pg/src/compat/refresh_token.rs | 4 +- crates/storage-pg/src/compat/session.rs | 4 +- crates/storage-pg/src/compat/sso_login.rs | 6 +- crates/storage-pg/src/oauth2/access_token.rs | 6 +- .../src/oauth2/authorization_grant.rs | 6 +- crates/storage-pg/src/oauth2/client.rs | 6 +- crates/storage-pg/src/oauth2/refresh_token.rs | 4 +- crates/storage-pg/src/oauth2/session.rs | 4 +- crates/storage-pg/src/upstream_oauth2/link.rs | 2 +- crates/storage-pg/src/upstream_oauth2/mod.rs | 7 +- .../src/upstream_oauth2/provider.rs | 2 +- .../storage-pg/src/upstream_oauth2/session.rs | 6 +- crates/storage-pg/src/user/email.rs | 12 +- crates/storage-pg/src/user/mod.rs | 2 +- crates/storage-pg/src/user/password.rs | 2 +- crates/storage-pg/src/user/session.rs | 8 +- crates/storage-pg/src/user/tests.rs | 17 +-- crates/storage/src/clock.rs | 129 ++++++++++++++++++ crates/storage/src/compat/access_token.rs | 4 +- crates/storage/src/compat/refresh_token.rs | 4 +- crates/storage/src/compat/session.rs | 4 +- crates/storage/src/compat/sso_login.rs | 6 +- crates/storage/src/lib.rs | 88 +----------- crates/storage/src/oauth2/access_token.rs | 6 +- .../storage/src/oauth2/authorization_grant.rs | 6 +- crates/storage/src/oauth2/client.rs | 6 +- crates/storage/src/oauth2/refresh_token.rs | 4 +- crates/storage/src/oauth2/session.rs | 4 +- crates/storage/src/upstream_oauth2/link.rs | 2 +- .../storage/src/upstream_oauth2/provider.rs | 2 +- crates/storage/src/upstream_oauth2/session.rs | 6 +- crates/storage/src/user/email.rs | 10 +- crates/storage/src/user/mod.rs | 2 +- crates/storage/src/user/password.rs | 2 +- crates/storage/src/user/session.rs | 8 +- crates/tasks/src/database.rs | 6 +- 62 files changed, 261 insertions(+), 212 deletions(-) create mode 100644 crates/storage/src/clock.rs diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index c608db836..2f3e88528 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -21,7 +21,7 @@ use mas_storage::{ oauth2::OAuth2ClientRepository, upstream_oauth2::UpstreamOAuthProviderRepository, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, - Clock, Repository, + Repository, SystemClock, }; use mas_storage_pg::PgRepository; use oauth2_types::scope::Scope; @@ -188,7 +188,7 @@ impl Options { #[allow(clippy::too_many_lines)] pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> { use Subcommand as SC; - let clock = Clock::default(); + let clock = SystemClock::default(); // XXX: we should disallow SeedableRng::from_entropy let mut rng = rand_chacha::ChaChaRng::from_entropy(); diff --git a/crates/cli/src/commands/templates.rs b/crates/cli/src/commands/templates.rs index a3a1bf9cc..6f09b7519 100644 --- a/crates/cli/src/commands/templates.rs +++ b/crates/cli/src/commands/templates.rs @@ -14,7 +14,7 @@ use camino::Utf8PathBuf; use clap::Parser; -use mas_storage::Clock; +use mas_storage::{Clock, SystemClock}; use mas_templates::Templates; use rand::SeedableRng; use tracing::info_span; @@ -41,7 +41,7 @@ impl Options { SC::Check { path } => { let _span = info_span!("cli.templates.check").entered(); - let clock = Clock::default(); + let clock = SystemClock::default(); // XXX: we should disallow SeedableRng::from_entropy let mut rng = rand_chacha::ChaChaRng::from_entropy(); let url_builder = mas_router::UrlBuilder::new("https://example.com/".parse()?); diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index bfd36d8ae..9b0a0792f 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -22,7 +22,7 @@ use mas_storage::{ CompatSsoLoginRepository, }, user::{UserPasswordRepository, UserRepository}, - Clock, Repository, + Clock, Repository, SystemClock, }; use mas_storage_pg::PgRepository; use serde::{Deserialize, Serialize}; @@ -254,7 +254,7 @@ pub(crate) async fn post( async fn token_login( repo: &mut PgRepository, - clock: &Clock, + clock: &SystemClock, token: &str, ) -> Result<(CompatSession, User), RouteError> { let login = repo diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 1fea922e8..631d1126d 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -31,7 +31,7 @@ use mas_keystore::Encrypter; use mas_router::{CompatLoginSsoAction, PostAuthAction, Route}; use mas_storage::{ compat::{CompatSessionRepository, CompatSsoLoginRepository}, - Repository, + Clock, Repository, }; use mas_storage_pg::PgRepository; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index 21229fe76..310fef2e7 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -18,7 +18,7 @@ use hyper::StatusCode; use mas_data_model::TokenType; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, - Clock, Repository, + Clock, Repository, SystemClock, }; use mas_storage_pg::PgRepository; use sqlx::PgPool; @@ -72,7 +72,7 @@ pub(crate) async fn post( State(pool): State, maybe_authorization: Option>>, ) -> Result { - let clock = Clock::default(); + let clock = SystemClock::default(); let mut repo = PgRepository::from_pool(&pool).await?; let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?; diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs index e16013950..8b47a81f3 100644 --- a/crates/handlers/src/compat/refresh.rs +++ b/crates/handlers/src/compat/refresh.rs @@ -18,7 +18,7 @@ use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, - Repository, + Clock, Repository, }; use mas_storage_pg::PgRepository; use serde::{Deserialize, Serialize}; diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 501f3176f..0360e7418 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -409,8 +409,8 @@ async fn test_state(pool: PgPool) -> Result { } // XXX: that should be moved somewhere else -fn clock_and_rng() -> (mas_storage::Clock, rand_chacha::ChaChaRng) { - let clock = mas_storage::Clock::default(); +fn clock_and_rng() -> (mas_storage::SystemClock, rand_chacha::ChaChaRng) { + let clock = mas_storage::SystemClock::default(); // This rng is used to source the local rng #[allow(clippy::disallowed_methods)] diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index b0f752f79..8fe4d2acf 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -30,7 +30,7 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, - Repository, + Clock, Repository, }; use mas_storage_pg::PgRepository; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index e8f9941f8..245f2125c 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -25,7 +25,7 @@ use mas_storage::{ compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository}, user::{BrowserSessionRepository, UserRepository}, - Clock, Repository, + Clock, Repository, SystemClock, }; use mas_storage_pg::PgRepository; use oauth2_types::{ @@ -130,7 +130,7 @@ pub(crate) async fn post( State(encrypter): State, client_authorization: ClientAuthorization, ) -> Result { - let clock = Clock::default(); + let clock = SystemClock::default(); let mut repo = PgRepository::from_pool(&pool).await?; let client = client_authorization diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 67ecb4985..de0118f4b 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -37,7 +37,7 @@ use mas_storage::{ OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, user::BrowserSessionRepository, - Repository, + Clock, Repository, }; use mas_storage_pg::PgRepository; use oauth2_types::{ diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index 2f5600379..39bc7587b 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -31,7 +31,7 @@ use mas_router::UrlBuilder; use mas_storage::{ oauth2::OAuth2ClientRepository, user::{BrowserSessionRepository, UserEmailRepository}, - Repository, + Clock, Repository, }; use mas_storage_pg::PgRepository; use oauth2_types::scope; diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index fcf5a7d14..565acea89 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -24,7 +24,7 @@ use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_router::UrlBuilder; use mas_storage::{ upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, - Repository, + Clock, Repository, }; use mas_storage_pg::PgRepository; use sqlx::PgPool; diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index d243666d4..ffde4ef32 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -30,7 +30,7 @@ use mas_storage::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, }, - Repository, + Clock, Repository, }; use mas_storage_pg::PgRepository; use oauth2_types::errors::ClientErrorCode; diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 8709ff219..c0408770b 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -27,7 +27,7 @@ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, user::{BrowserSessionRepository, UserRepository}, - Repository, + Clock, Repository, }; use mas_storage_pg::PgRepository; use mas_templates::{ diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index e99c8e4d7..4c11a485a 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -24,7 +24,7 @@ use mas_axum_utils::{ use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, Repository}; +use mas_storage::{user::UserEmailRepository, Clock, Repository}; use mas_storage_pg::PgRepository; use mas_templates::{EmailAddContext, TemplateContext, Templates}; use serde::Deserialize; diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index 4d70ab332..f1b7c733a 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -71,7 +71,7 @@ pub(crate) async fn get( async fn render( rng: impl Rng + Send, - clock: &Clock, + clock: &impl Clock, templates: Templates, session: BrowserSession, cookie_jar: PrivateCookieJar, @@ -94,7 +94,7 @@ async fn start_email_verification( mailer: &Mailer, repo: &mut impl Repository, mut rng: impl Rng + Send, - clock: &Clock, + clock: &impl Clock, user: &User, user_email: UserEmail, ) -> anyhow::Result<()> { diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index 2b398b424..f37b15e46 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -24,7 +24,7 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, Clock, Repository}; +use mas_storage::{user::UserEmailRepository, Clock, Repository, SystemClock}; use mas_storage_pg::PgRepository; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use serde::Deserialize; @@ -89,7 +89,7 @@ pub(crate) async fn post( Path(id): Path, Form(form): Form>, ) -> Result { - let clock = Clock::default(); + let clock = SystemClock::default(); let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; diff --git a/crates/handlers/src/views/account/mod.rs b/crates/handlers/src/views/account/mod.rs index 8d2eb3e2d..29aaeda32 100644 --- a/crates/handlers/src/views/account/mod.rs +++ b/crates/handlers/src/views/account/mod.rs @@ -25,7 +25,7 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository}, - Repository, + Clock, Repository, }; use mas_storage_pg::PgRepository; use mas_templates::{AccountContext, TemplateContext, Templates}; diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index 089093f63..1624f6f60 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -67,7 +67,7 @@ pub(crate) async fn get( async fn render( rng: impl Rng + Send, - clock: &Clock, + clock: &impl Clock, templates: Templates, session: BrowserSession, cookie_jar: PrivateCookieJar, diff --git a/crates/handlers/src/views/index.rs b/crates/handlers/src/views/index.rs index cab7c743e..2298e83e7 100644 --- a/crates/handlers/src/views/index.rs +++ b/crates/handlers/src/views/index.rs @@ -20,6 +20,7 @@ use axum_extra::extract::PrivateCookieJar; use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt}; use mas_keystore::Encrypter; use mas_router::UrlBuilder; +use mas_storage::Clock; use mas_storage_pg::PgRepository; use mas_templates::{IndexContext, TemplateContext, Templates}; use sqlx::PgPool; diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 87ba9e848..10295dbc3 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -167,7 +167,7 @@ async fn login( password_manager: PasswordManager, repo: &mut impl Repository, mut rng: impl Rng + CryptoRng + Send, - clock: &Clock, + clock: &impl Clock, username: &str, password: &str, ) -> Result { diff --git a/crates/handlers/src/views/logout.rs b/crates/handlers/src/views/logout.rs index 373264d06..781780ca4 100644 --- a/crates/handlers/src/views/logout.rs +++ b/crates/handlers/src/views/logout.rs @@ -23,7 +23,7 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::{PostAuthAction, Route}; -use mas_storage::{user::BrowserSessionRepository, Clock, Repository}; +use mas_storage::{user::BrowserSessionRepository, Clock, Repository, SystemClock}; use mas_storage_pg::PgRepository; use sqlx::PgPool; @@ -32,7 +32,7 @@ pub(crate) async fn post( cookie_jar: PrivateCookieJar, Form(form): Form>>, ) -> Result { - let clock = Clock::default(); + let clock = SystemClock::default(); let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index 49249f3cf..571d2e179 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -26,7 +26,7 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository}, - Repository, + Clock, Repository, }; use mas_storage_pg::PgRepository; use mas_templates::{ReauthContext, TemplateContext, Templates}; diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index 58db6ec16..3e2d87a97 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -33,7 +33,7 @@ use mas_policy::PolicyFactory; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - Repository, + Clock, Repository, }; use mas_storage_pg::PgRepository; use mas_templates::{ diff --git a/crates/storage-pg/src/compat/access_token.rs b/crates/storage-pg/src/compat/access_token.rs index 5f73ed9e4..822c3a8af 100644 --- a/crates/storage-pg/src/compat/access_token.rs +++ b/crates/storage-pg/src/compat/access_token.rs @@ -143,7 +143,7 @@ impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, compat_session: &CompatSession, token: String, expires_after: Option, @@ -191,7 +191,7 @@ impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> { )] async fn expire( &mut self, - clock: &Clock, + clock: &dyn Clock, mut compat_access_token: CompatAccessToken, ) -> Result { let expires_at = clock.now(); diff --git a/crates/storage-pg/src/compat/mod.rs b/crates/storage-pg/src/compat/mod.rs index 732ce3aaa..dd68e4d5f 100644 --- a/crates/storage-pg/src/compat/mod.rs +++ b/crates/storage-pg/src/compat/mod.rs @@ -27,6 +27,7 @@ mod tests { use chrono::Duration; use mas_data_model::Device; use mas_storage::{ + clock::MockClock, compat::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, }, @@ -44,7 +45,7 @@ mod tests { const FIRST_TOKEN: &str = "first_access_token"; const SECOND_TOKEN: &str = "second_access_token"; let mut rng = ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); + let clock = MockClock::default(); let mut repo = PgRepository::from_pool(&pool).await.unwrap(); // Create a user @@ -101,7 +102,7 @@ mod tests { const FIRST_TOKEN: &str = "first_access_token"; const SECOND_TOKEN: &str = "second_access_token"; let mut rng = ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); + let clock = MockClock::default(); let mut repo = PgRepository::from_pool(&pool).await.unwrap(); // Create a user @@ -221,7 +222,7 @@ mod tests { const ACCESS_TOKEN: &str = "access_token"; const REFRESH_TOKEN: &str = "refresh_token"; let mut rng = ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); + let clock = MockClock::default(); let mut repo = PgRepository::from_pool(&pool).await.unwrap(); // Create a user diff --git a/crates/storage-pg/src/compat/refresh_token.rs b/crates/storage-pg/src/compat/refresh_token.rs index 314e81479..991e14381 100644 --- a/crates/storage-pg/src/compat/refresh_token.rs +++ b/crates/storage-pg/src/compat/refresh_token.rs @@ -154,7 +154,7 @@ impl<'c> CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'c> { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, compat_session: &CompatSession, compat_access_token: &CompatAccessToken, token: String, @@ -202,7 +202,7 @@ impl<'c> CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'c> { )] async fn consume( &mut self, - clock: &Clock, + clock: &dyn Clock, compat_refresh_token: CompatRefreshToken, ) -> Result { let consumed_at = clock.now(); diff --git a/crates/storage-pg/src/compat/session.rs b/crates/storage-pg/src/compat/session.rs index a6e65f9a4..16208b4f0 100644 --- a/crates/storage-pg/src/compat/session.rs +++ b/crates/storage-pg/src/compat/session.rs @@ -122,7 +122,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, user: &User, device: Device, ) -> Result { @@ -166,7 +166,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { )] async fn finish( &mut self, - clock: &Clock, + clock: &dyn Clock, compat_session: CompatSession, ) -> Result { let finished_at = clock.now(); diff --git a/crates/storage-pg/src/compat/sso_login.rs b/crates/storage-pg/src/compat/sso_login.rs index a2eeb9263..1b8e0225f 100644 --- a/crates/storage-pg/src/compat/sso_login.rs +++ b/crates/storage-pg/src/compat/sso_login.rs @@ -177,7 +177,7 @@ impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, login_token: String, redirect_uri: Url, ) -> Result { @@ -223,7 +223,7 @@ impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { )] async fn fulfill( &mut self, - clock: &Clock, + clock: &dyn Clock, compat_sso_login: CompatSsoLogin, compat_session: &CompatSession, ) -> Result { @@ -265,7 +265,7 @@ impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { )] async fn exchange( &mut self, - clock: &Clock, + clock: &dyn Clock, compat_sso_login: CompatSsoLogin, ) -> Result { let exchanged_at = clock.now(); diff --git a/crates/storage-pg/src/oauth2/access_token.rs b/crates/storage-pg/src/oauth2/access_token.rs index 33d95242f..ecd5798b0 100644 --- a/crates/storage-pg/src/oauth2/access_token.rs +++ b/crates/storage-pg/src/oauth2/access_token.rs @@ -142,7 +142,7 @@ impl<'c> OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'c> { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, session: &Session, access_token: String, expires_after: Duration, @@ -182,7 +182,7 @@ impl<'c> OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'c> { async fn revoke( &mut self, - clock: &Clock, + clock: &dyn Clock, access_token: AccessToken, ) -> Result { let revoked_at = clock.now(); @@ -205,7 +205,7 @@ impl<'c> OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'c> { .map_err(DatabaseError::to_invalid_operation) } - async fn cleanup_expired(&mut self, clock: &Clock) -> Result { + async fn cleanup_expired(&mut self, clock: &dyn Clock) -> Result { // Cleanup token which expired more than 15 minutes ago let threshold = clock.now() - Duration::minutes(15); let res = sqlx::query!( diff --git a/crates/storage-pg/src/oauth2/authorization_grant.rs b/crates/storage-pg/src/oauth2/authorization_grant.rs index 027111a76..92116a62d 100644 --- a/crates/storage-pg/src/oauth2/authorization_grant.rs +++ b/crates/storage-pg/src/oauth2/authorization_grant.rs @@ -211,7 +211,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, client: &Client, redirect_uri: Url, scope: Scope, @@ -410,7 +410,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi )] async fn fulfill( &mut self, - clock: &Clock, + clock: &dyn Clock, session: &Session, grant: AuthorizationGrant, ) -> Result { @@ -451,7 +451,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi )] async fn exchange( &mut self, - clock: &Clock, + clock: &dyn Clock, grant: AuthorizationGrant, ) -> Result { let exchanged_at = clock.now(); diff --git a/crates/storage-pg/src/oauth2/client.rs b/crates/storage-pg/src/oauth2/client.rs index 4430c669c..1b8a99f93 100644 --- a/crates/storage-pg/src/oauth2/client.rs +++ b/crates/storage-pg/src/oauth2/client.rs @@ -378,7 +378,7 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { async fn add( &mut self, mut rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, redirect_uris: Vec, encrypted_client_secret: Option, grant_types: Vec, @@ -535,7 +535,7 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { async fn add_from_config( &mut self, mut rng: impl Rng + Send, - clock: &Clock, + clock: &dyn Clock, client_id: Ulid, client_auth_method: OAuthClientAuthenticationMethod, encrypted_client_secret: Option, @@ -707,7 +707,7 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { async fn give_consent_for_user( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, client: &Client, user: &User, scope: &Scope, diff --git a/crates/storage-pg/src/oauth2/refresh_token.rs b/crates/storage-pg/src/oauth2/refresh_token.rs index 47281d934..ba2fa5334 100644 --- a/crates/storage-pg/src/oauth2/refresh_token.rs +++ b/crates/storage-pg/src/oauth2/refresh_token.rs @@ -150,7 +150,7 @@ impl<'c> OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'c> { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, session: &Session, access_token: &AccessToken, refresh_token: String, @@ -199,7 +199,7 @@ impl<'c> OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'c> { )] async fn consume( &mut self, - clock: &Clock, + clock: &dyn Clock, refresh_token: RefreshToken, ) -> Result { let consumed_at = clock.now(); diff --git a/crates/storage-pg/src/oauth2/session.rs b/crates/storage-pg/src/oauth2/session.rs index 96f798e65..f1c412790 100644 --- a/crates/storage-pg/src/oauth2/session.rs +++ b/crates/storage-pg/src/oauth2/session.rs @@ -131,7 +131,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { async fn create_from_grant( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, grant: &AuthorizationGrant, user_session: &BrowserSession, ) -> Result { @@ -182,7 +182,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { ), err, )] - async fn finish(&mut self, clock: &Clock, session: Session) -> Result { + async fn finish(&mut self, clock: &dyn Clock, session: Session) -> Result { let finished_at = clock.now(); let res = sqlx::query!( r#" diff --git a/crates/storage-pg/src/upstream_oauth2/link.rs b/crates/storage-pg/src/upstream_oauth2/link.rs index 4087e2c78..c38b344b2 100644 --- a/crates/storage-pg/src/upstream_oauth2/link.rs +++ b/crates/storage-pg/src/upstream_oauth2/link.rs @@ -149,7 +149,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, upstream_oauth_provider: &UpstreamOAuthProvider, subject: String, ) -> Result { diff --git a/crates/storage-pg/src/upstream_oauth2/mod.rs b/crates/storage-pg/src/upstream_oauth2/mod.rs index e77daba2e..af631f15c 100644 --- a/crates/storage-pg/src/upstream_oauth2/mod.rs +++ b/crates/storage-pg/src/upstream_oauth2/mod.rs @@ -25,12 +25,13 @@ pub use self::{ mod tests { use chrono::Duration; use mas_storage::{ + clock::MockClock, upstream_oauth2::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, }, user::UserRepository, - Clock, Pagination, Repository, + Pagination, Repository, }; use oauth2_types::scope::{Scope, OPENID}; use rand::SeedableRng; @@ -41,7 +42,7 @@ mod tests { #[sqlx::test(migrator = "crate::MIGRATOR")] async fn test_repository(pool: PgPool) { let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); + let clock = MockClock::default(); let mut repo = PgRepository::from_pool(&pool).await.unwrap(); // The provider list should be empty at the start @@ -183,7 +184,7 @@ mod tests { let scope = Scope::from_iter([OPENID]); let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); + let clock = MockClock::default(); let mut repo = PgRepository::from_pool(&pool).await.unwrap(); let mut ids = Vec::with_capacity(20); diff --git a/crates/storage-pg/src/upstream_oauth2/provider.rs b/crates/storage-pg/src/upstream_oauth2/provider.rs index 480249eee..dc1b0c85b 100644 --- a/crates/storage-pg/src/upstream_oauth2/provider.rs +++ b/crates/storage-pg/src/upstream_oauth2/provider.rs @@ -149,7 +149,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, issuer: String, scope: Scope, token_endpoint_auth_method: OAuthClientAuthenticationMethod, diff --git a/crates/storage-pg/src/upstream_oauth2/session.rs b/crates/storage-pg/src/upstream_oauth2/session.rs index 699a463f0..3cdef0c73 100644 --- a/crates/storage-pg/src/upstream_oauth2/session.rs +++ b/crates/storage-pg/src/upstream_oauth2/session.rs @@ -156,7 +156,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, upstream_oauth_provider: &UpstreamOAuthProvider, state_str: String, code_challenge_verifier: Option, @@ -217,7 +217,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> )] async fn complete_with_link( &mut self, - clock: &Clock, + clock: &dyn Clock, upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, upstream_oauth_link: &UpstreamOAuthLink, id_token: Option, @@ -260,7 +260,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> )] async fn consume( &mut self, - clock: &Clock, + clock: &dyn Clock, upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, ) -> Result { let consumed_at = clock.now(); diff --git a/crates/storage-pg/src/user/email.rs b/crates/storage-pg/src/user/email.rs index 936b06491..b9d732dec 100644 --- a/crates/storage-pg/src/user/email.rs +++ b/crates/storage-pg/src/user/email.rs @@ -68,7 +68,7 @@ struct UserEmailConfirmationCodeLookup { } impl UserEmailConfirmationCodeLookup { - fn into_verification(self, clock: &Clock) -> UserEmailVerification { + fn into_verification(self, clock: &dyn Clock) -> UserEmailVerification { let now = clock.now(); let state = if let Some(when) = self.consumed_at { UserEmailVerificationState::AlreadyUsed { when } @@ -301,7 +301,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, user: &User, email: String, ) -> Result { @@ -378,7 +378,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { async fn mark_as_verified( &mut self, - clock: &Clock, + clock: &dyn Clock, mut user_email: UserEmail, ) -> Result { let confirmed_at = clock.now(); @@ -430,7 +430,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { async fn add_verification_code( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, user_email: &UserEmail, max_age: chrono::Duration, code: String, @@ -479,7 +479,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { )] async fn find_verification_code( &mut self, - clock: &Clock, + clock: &dyn Clock, user_email: &UserEmail, code: &str, ) -> Result, Self::Error> { @@ -521,7 +521,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { )] async fn consume_verification_code( &mut self, - clock: &Clock, + clock: &dyn Clock, mut user_email_verification: UserEmailVerification, ) -> Result { if !matches!( diff --git a/crates/storage-pg/src/user/mod.rs b/crates/storage-pg/src/user/mod.rs index d73202613..8ec6170f5 100644 --- a/crates/storage-pg/src/user/mod.rs +++ b/crates/storage-pg/src/user/mod.rs @@ -148,7 +148,7 @@ impl<'c> UserRepository for PgUserRepository<'c> { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, username: String, ) -> Result { let created_at = clock.now(); diff --git a/crates/storage-pg/src/user/password.rs b/crates/storage-pg/src/user/password.rs index 997b12272..696b30e77 100644 --- a/crates/storage-pg/src/user/password.rs +++ b/crates/storage-pg/src/user/password.rs @@ -115,7 +115,7 @@ impl<'c> UserPasswordRepository for PgUserPasswordRepository<'c> { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, user: &User, version: u16, hashed_password: String, diff --git a/crates/storage-pg/src/user/session.rs b/crates/storage-pg/src/user/session.rs index d216c0679..e5616fffa 100644 --- a/crates/storage-pg/src/user/session.rs +++ b/crates/storage-pg/src/user/session.rs @@ -142,7 +142,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, user: &User, ) -> Result { let created_at = clock.now(); @@ -185,7 +185,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { )] async fn finish( &mut self, - clock: &Clock, + clock: &dyn Clock, mut user_session: BrowserSession, ) -> Result { let finished_at = clock.now(); @@ -297,7 +297,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { async fn authenticate_with_password( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, mut user_session: BrowserSession, user_password: &Password, ) -> Result { @@ -342,7 +342,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { async fn authenticate_with_upstream( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, mut user_session: BrowserSession, upstream_oauth_link: &UpstreamOAuthLink, ) -> Result { diff --git a/crates/storage-pg/src/user/tests.rs b/crates/storage-pg/src/user/tests.rs index f0a071b01..097bca74f 100644 --- a/crates/storage-pg/src/user/tests.rs +++ b/crates/storage-pg/src/user/tests.rs @@ -14,8 +14,9 @@ use chrono::Duration; use mas_storage::{ + clock::MockClock, user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - Clock, Repository, + Repository, }; use rand::SeedableRng; use rand_chacha::ChaChaRng; @@ -30,7 +31,7 @@ async fn test_user_repo(pool: PgPool) { let mut repo = PgRepository::from_pool(&pool).await.unwrap(); let mut rng = ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); + let clock = MockClock::default(); // Initially, the user shouldn't exist assert!(!repo.user().exists(USERNAME).await.unwrap()); @@ -78,7 +79,7 @@ async fn test_user_email_repo(pool: PgPool) { let mut repo = PgRepository::from_pool(&pool).await.unwrap(); let mut rng = ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); + let clock = MockClock::default(); let user = repo .user() @@ -89,7 +90,7 @@ async fn test_user_email_repo(pool: PgPool) { // The user email should not exist yet assert!(repo .user_email() - .find(&user, EMAIL) + .find(&user, &EMAIL) .await .unwrap() .is_none()); @@ -110,7 +111,7 @@ async fn test_user_email_repo(pool: PgPool) { assert!(repo .user_email() - .find(&user, EMAIL) + .find(&user, &EMAIL) .await .unwrap() .is_some()); @@ -180,7 +181,7 @@ async fn test_user_email_repo(pool: PgPool) { // Reload the user_email let user_email = repo .user_email() - .find(&user, EMAIL) + .find(&user, &EMAIL) .await .unwrap() .expect("user email was not found"); @@ -260,7 +261,7 @@ async fn test_user_password_repo(pool: PgPool) { let mut repo = PgRepository::from_pool(&pool).await.unwrap(); let mut rng = ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); + let clock = MockClock::default(); let user = repo .user() @@ -340,7 +341,7 @@ async fn test_user_session(pool: PgPool) { let mut repo = PgRepository::from_pool(&pool).await.unwrap(); let mut rng = ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); + let clock = MockClock::default(); let user = repo .user() diff --git a/crates/storage/src/clock.rs b/crates/storage/src/clock.rs new file mode 100644 index 000000000..54ca3b4fc --- /dev/null +++ b/crates/storage/src/clock.rs @@ -0,0 +1,129 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! A [`Clock`] is a way to get the current date and time. +//! +//! This module defines two implemetation of the [`Clock`] trait: +//! [`SystemClock`] which uses the system time, and a [`MockClock`], which can +//! be used and freely manipulated in tests. + +use std::sync::atomic::AtomicI64; + +use chrono::{DateTime, TimeZone, Utc}; + +/// Represents a clock which can give the current date and time +pub trait Clock: Sync { + /// Get the current date and time + fn now(&self) -> DateTime; +} + +/// A clock which uses the system time +#[derive(Clone, Default)] +pub struct SystemClock { + _private: (), +} + +impl Clock for SystemClock { + fn now(&self) -> DateTime { + // This is the clock used elsewhere, it's fine to call Utc::now here + #[allow(clippy::disallowed_methods)] + Utc::now() + } +} + +/// A fake clock, which uses a fixed timestamp, and can be advanced with the +/// [`MockClock::advance`] method. +/// +/// ```rust +/// use mas_storage::clock::{Clock, MockClock}; +/// use chrono::Duration; +/// +/// let clock = MockClock::default(); +/// let t1 = clock.now(); +/// let t2 = clock.now(); +/// assert_eq!(t1, t2); +/// +/// clock.advance(Duration::seconds(10)); +/// let t3 = clock.now(); +/// assert_eq!(t2 + Duration::seconds(10), t3); +/// ``` +pub struct MockClock { + timestamp: AtomicI64, +} + +impl Default for MockClock { + fn default() -> Self { + let datetime = Utc.with_ymd_and_hms(2022, 1, 16, 14, 40, 0).unwrap(); + Self::new(datetime) + } +} + +impl MockClock { + /// Create a new clock which starts at the given datetime + #[must_use] + pub fn new(datetime: DateTime) -> Self { + let timestamp = AtomicI64::new(datetime.timestamp()); + Self { timestamp } + } + + /// Move the clock forward by the given amount of time + pub fn advance(&self, duration: chrono::Duration) { + self.timestamp + .fetch_add(duration.num_seconds(), std::sync::atomic::Ordering::Relaxed); + } +} + +impl Clock for MockClock { + fn now(&self) -> DateTime { + let timestamp = self.timestamp.load(std::sync::atomic::Ordering::Relaxed); + chrono::TimeZone::timestamp_opt(&Utc, timestamp, 0).unwrap() + } +} + +#[cfg(test)] +mod tests { + use chrono::Duration; + + use super::*; + + #[test] + fn test_mocked_clock() { + let clock = MockClock::default(); + + // Time should be frozen, and give out the same timestamp on each call + let first = clock.now(); + std::thread::sleep(std::time::Duration::from_millis(10)); + let second = clock.now(); + + assert_eq!(first, second); + + // Clock can be advanced by a fixed duration + clock.advance(Duration::seconds(10)); + let third = clock.now(); + assert_eq!(first + Duration::seconds(10), third); + } + + #[test] + fn test_real_clock() { + let clock = SystemClock::default(); + + // Time should not be frozen + let first = clock.now(); + std::thread::sleep(std::time::Duration::from_millis(10)); + let second = clock.now(); + + assert_ne!(first, second); + assert!(first < second); + } +} diff --git a/crates/storage/src/compat/access_token.rs b/crates/storage/src/compat/access_token.rs index 46ff6e3f8..e135be398 100644 --- a/crates/storage/src/compat/access_token.rs +++ b/crates/storage/src/compat/access_token.rs @@ -37,7 +37,7 @@ pub trait CompatAccessTokenRepository: Send + Sync { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, compat_session: &CompatSession, token: String, expires_after: Option, @@ -46,7 +46,7 @@ pub trait CompatAccessTokenRepository: Send + Sync { /// Set the expiration time of the compat access token to now async fn expire( &mut self, - clock: &Clock, + clock: &dyn Clock, compat_access_token: CompatAccessToken, ) -> Result; } diff --git a/crates/storage/src/compat/refresh_token.rs b/crates/storage/src/compat/refresh_token.rs index 7a1057ff8..f8e8dfaa8 100644 --- a/crates/storage/src/compat/refresh_token.rs +++ b/crates/storage/src/compat/refresh_token.rs @@ -36,7 +36,7 @@ pub trait CompatRefreshTokenRepository: Send + Sync { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, compat_session: &CompatSession, compat_access_token: &CompatAccessToken, token: String, @@ -45,7 +45,7 @@ pub trait CompatRefreshTokenRepository: Send + Sync { /// Consume a compat refresh token async fn consume( &mut self, - clock: &Clock, + clock: &dyn Clock, compat_refresh_token: CompatRefreshToken, ) -> Result; } diff --git a/crates/storage/src/compat/session.rs b/crates/storage/src/compat/session.rs index 34bc68381..fa5cbd6e6 100644 --- a/crates/storage/src/compat/session.rs +++ b/crates/storage/src/compat/session.rs @@ -30,7 +30,7 @@ pub trait CompatSessionRepository: Send + Sync { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, user: &User, device: Device, ) -> Result; @@ -38,7 +38,7 @@ pub trait CompatSessionRepository: Send + Sync { /// End a compat session async fn finish( &mut self, - clock: &Clock, + clock: &dyn Clock, compat_session: CompatSession, ) -> Result; } diff --git a/crates/storage/src/compat/sso_login.rs b/crates/storage/src/compat/sso_login.rs index 348e0ac5f..6ee3270a6 100644 --- a/crates/storage/src/compat/sso_login.rs +++ b/crates/storage/src/compat/sso_login.rs @@ -37,7 +37,7 @@ pub trait CompatSsoLoginRepository: Send + Sync { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, login_token: String, redirect_uri: Url, ) -> Result; @@ -45,7 +45,7 @@ pub trait CompatSsoLoginRepository: Send + Sync { /// Fulfill a compat SSO login by providing a compat session async fn fulfill( &mut self, - clock: &Clock, + clock: &dyn Clock, compat_sso_login: CompatSsoLogin, compat_session: &CompatSession, ) -> Result; @@ -53,7 +53,7 @@ pub trait CompatSsoLoginRepository: Send + Sync { /// Mark a compat SSO login as exchanged async fn exchange( &mut self, - clock: &Clock, + clock: &dyn Clock, compat_sso_login: CompatSsoLogin, ) -> Result; diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index a65c806cf..2a83b55e7 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -28,92 +28,7 @@ clippy::module_name_repetitions )] -use chrono::{DateTime, Utc}; - -#[derive(Debug, Clone, Default)] -pub struct Clock { - _private: (), - - // #[cfg(test)] - mock: Option>, -} - -impl Clock { - #[must_use] - pub fn now(&self) -> DateTime { - // #[cfg(test)] - if let Some(timestamp) = &self.mock { - let timestamp = timestamp.load(std::sync::atomic::Ordering::Relaxed); - return chrono::TimeZone::timestamp_opt(&Utc, timestamp, 0).unwrap(); - } - - // This is the clock used elsewhere, it's fine to call Utc::now here - #[allow(clippy::disallowed_methods)] - Utc::now() - } - - // #[cfg(test)] - #[must_use] - pub fn mock() -> Self { - use std::sync::{atomic::AtomicI64, Arc}; - - use chrono::TimeZone; - - let datetime = Utc.with_ymd_and_hms(2022, 1, 16, 14, 40, 0).unwrap(); - let timestamp = datetime.timestamp(); - - Self { - mock: Some(Arc::new(AtomicI64::new(timestamp))), - _private: (), - } - } - - // #[cfg(test)] - pub fn advance(&self, duration: chrono::Duration) { - let timestamp = self - .mock - .as_ref() - .expect("Clock::advance should only be called on mocked clocks in tests"); - timestamp.fetch_add(duration.num_seconds(), std::sync::atomic::Ordering::Relaxed); - } -} - -#[cfg(test)] -mod tests { - use chrono::Duration; - - use super::*; - - #[test] - fn test_mocked_clock() { - let clock = Clock::mock(); - - // Time should be frozen, and give out the same timestamp on each call - let first = clock.now(); - std::thread::sleep(std::time::Duration::from_millis(10)); - let second = clock.now(); - - assert_eq!(first, second); - - // Clock can be advanced by a fixed duration - clock.advance(Duration::seconds(10)); - let third = clock.now(); - assert_eq!(first + Duration::seconds(10), third); - } - - #[test] - fn test_real_clock() { - let clock = Clock::default(); - - // Time should not be frozen - let first = clock.now(); - std::thread::sleep(std::time::Duration::from_millis(10)); - let second = clock.now(); - - assert_ne!(first, second); - assert!(first < second); - } -} +pub mod clock; pub mod compat; pub mod oauth2; @@ -123,6 +38,7 @@ pub mod upstream_oauth2; pub mod user; pub use self::{ + clock::{Clock, SystemClock}, pagination::{Page, Pagination}, repository::Repository, }; diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index a0406e44c..4bbcf8857 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -37,7 +37,7 @@ pub trait OAuth2AccessTokenRepository: Send + Sync { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, session: &Session, access_token: String, expires_after: Duration, @@ -46,10 +46,10 @@ pub trait OAuth2AccessTokenRepository: Send + Sync { /// Revoke an access token async fn revoke( &mut self, - clock: &Clock, + clock: &dyn Clock, access_token: AccessToken, ) -> Result; /// Cleanup expired access tokens - async fn cleanup_expired(&mut self, clock: &Clock) -> Result; + async fn cleanup_expired(&mut self, clock: &dyn Clock) -> Result; } diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index ce1a716ff..96403f122 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -31,7 +31,7 @@ pub trait OAuth2AuthorizationGrantRepository: Send + Sync { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, client: &Client, redirect_uri: Url, scope: Scope, @@ -51,14 +51,14 @@ pub trait OAuth2AuthorizationGrantRepository: Send + Sync { async fn fulfill( &mut self, - clock: &Clock, + clock: &dyn Clock, session: &Session, authorization_grant: AuthorizationGrant, ) -> Result; async fn exchange( &mut self, - clock: &Clock, + clock: &dyn Clock, authorization_grant: AuthorizationGrant, ) -> Result; diff --git a/crates/storage/src/oauth2/client.rs b/crates/storage/src/oauth2/client.rs index 093369a40..30f37bc63 100644 --- a/crates/storage/src/oauth2/client.rs +++ b/crates/storage/src/oauth2/client.rs @@ -45,7 +45,7 @@ pub trait OAuth2ClientRepository: Send + Sync { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, redirect_uris: Vec, encrypted_client_secret: Option, grant_types: Vec, @@ -68,7 +68,7 @@ pub trait OAuth2ClientRepository: Send + Sync { async fn add_from_config( &mut self, mut rng: impl Rng + Send, - clock: &Clock, + clock: &dyn Clock, client_id: Ulid, client_auth_method: OAuthClientAuthenticationMethod, encrypted_client_secret: Option, @@ -86,7 +86,7 @@ pub trait OAuth2ClientRepository: Send + Sync { async fn give_consent_for_user( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, client: &Client, user: &User, scope: &Scope, diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 1e23634aa..88bb26fbb 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -36,7 +36,7 @@ pub trait OAuth2RefreshTokenRepository: Send + Sync { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, session: &Session, access_token: &AccessToken, refresh_token: String, @@ -45,7 +45,7 @@ pub trait OAuth2RefreshTokenRepository: Send + Sync { /// Consume a refresh token async fn consume( &mut self, - clock: &Clock, + clock: &dyn Clock, refresh_token: RefreshToken, ) -> Result; } diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 5e6498d8b..944c77845 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -28,12 +28,12 @@ pub trait OAuth2SessionRepository: Send + Sync { async fn create_from_grant( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, grant: &AuthorizationGrant, user_session: &BrowserSession, ) -> Result; - async fn finish(&mut self, clock: &Clock, session: Session) -> Result; + async fn finish(&mut self, clock: &dyn Clock, session: Session) -> Result; async fn list_paginated( &mut self, diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index bc20c6eaf..474c21d9d 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -37,7 +37,7 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, upstream_oauth_provider: &UpstreamOAuthProvider, subject: String, ) -> Result; diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index 4be8f1271..10a03e2bc 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -33,7 +33,7 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, issuer: String, scope: Scope, token_endpoint_auth_method: OAuthClientAuthenticationMethod, diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index 4d41a8eca..e1b6abc16 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -33,7 +33,7 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, upstream_oauth_provider: &UpstreamOAuthProvider, state: String, code_challenge_verifier: Option, @@ -43,7 +43,7 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync { /// Mark a session as completed and associate the given link async fn complete_with_link( &mut self, - clock: &Clock, + clock: &dyn Clock, upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, upstream_oauth_link: &UpstreamOAuthLink, id_token: Option, @@ -52,7 +52,7 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync { /// Mark a session as consumed async fn consume( &mut self, - clock: &Clock, + clock: &dyn Clock, upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, ) -> Result; } diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs index 41a7d2935..4b8f846d0 100644 --- a/crates/storage/src/user/email.rs +++ b/crates/storage/src/user/email.rs @@ -38,7 +38,7 @@ pub trait UserEmailRepository: Send + Sync { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, user: &User, email: String, ) -> Result; @@ -46,7 +46,7 @@ pub trait UserEmailRepository: Send + Sync { async fn mark_as_verified( &mut self, - clock: &Clock, + clock: &dyn Clock, user_email: UserEmail, ) -> Result; @@ -55,7 +55,7 @@ pub trait UserEmailRepository: Send + Sync { async fn add_verification_code( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, user_email: &UserEmail, max_age: chrono::Duration, code: String, @@ -63,14 +63,14 @@ pub trait UserEmailRepository: Send + Sync { async fn find_verification_code( &mut self, - clock: &Clock, + clock: &dyn Clock, user_email: &UserEmail, code: &str, ) -> Result, Self::Error>; async fn consume_verification_code( &mut self, - clock: &Clock, + clock: &dyn Clock, verification: UserEmailVerification, ) -> Result; } diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 23c2f6d1e..8e046a9dc 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -36,7 +36,7 @@ pub trait UserRepository: Send + Sync { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, username: String, ) -> Result; async fn exists(&mut self, username: &str) -> Result; diff --git a/crates/storage/src/user/password.rs b/crates/storage/src/user/password.rs index 2d2d25344..306f7d936 100644 --- a/crates/storage/src/user/password.rs +++ b/crates/storage/src/user/password.rs @@ -26,7 +26,7 @@ pub trait UserPasswordRepository: Send + Sync { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, user: &User, version: u16, hashed_password: String, diff --git a/crates/storage/src/user/session.rs b/crates/storage/src/user/session.rs index 2e55f40c1..5499f45ab 100644 --- a/crates/storage/src/user/session.rs +++ b/crates/storage/src/user/session.rs @@ -27,12 +27,12 @@ pub trait BrowserSessionRepository: Send + Sync { async fn add( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, user: &User, ) -> Result; async fn finish( &mut self, - clock: &Clock, + clock: &dyn Clock, user_session: BrowserSession, ) -> Result; async fn list_active_paginated( @@ -45,7 +45,7 @@ pub trait BrowserSessionRepository: Send + Sync { async fn authenticate_with_password( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, user_session: BrowserSession, user_password: &Password, ) -> Result; @@ -53,7 +53,7 @@ pub trait BrowserSessionRepository: Send + Sync { async fn authenticate_with_upstream( &mut self, rng: &mut (dyn RngCore + Send), - clock: &Clock, + clock: &dyn Clock, user_session: BrowserSession, upstream_oauth_link: &UpstreamOAuthLink, ) -> Result; diff --git a/crates/tasks/src/database.rs b/crates/tasks/src/database.rs index 660688608..9e31880c8 100644 --- a/crates/tasks/src/database.rs +++ b/crates/tasks/src/database.rs @@ -14,7 +14,7 @@ //! Database-related tasks -use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock, Repository}; +use mas_storage::{oauth2::OAuth2AccessTokenRepository, Repository, SystemClock}; use mas_storage_pg::PgRepository; use sqlx::{Pool, Postgres}; use tracing::{debug, error, info}; @@ -22,7 +22,7 @@ use tracing::{debug, error, info}; use super::Task; #[derive(Clone)] -struct CleanupExpired(Pool, Clock); +struct CleanupExpired(Pool, SystemClock); impl std::fmt::Debug for CleanupExpired { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -57,5 +57,5 @@ impl Task for CleanupExpired { #[must_use] pub fn cleanup_expired(pool: &Pool) -> impl Task + Clone { // XXX: the clock should come from somewhere else - CleanupExpired(pool.clone(), Clock::default()) + CleanupExpired(pool.clone(), SystemClock::default()) } From 7099a8df203d207d9b20a3da5522a955a2427761 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 18 Jan 2023 12:25:49 +0100 Subject: [PATCH 29/45] Fix rustdoc lints --- crates/axum-utils/src/http_client_factory.rs | 2 +- crates/data-model/src/compat/session.rs | 4 ++-- crates/data-model/src/tokens.rs | 4 ++-- crates/keystore/src/lib.rs | 7 +------ crates/listener/src/lib.rs | 3 +++ crates/listener/src/server.rs | 2 +- crates/spa/src/lib.rs | 2 ++ 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/crates/axum-utils/src/http_client_factory.rs b/crates/axum-utils/src/http_client_factory.rs index 25f5b1558..6eb5f9406 100644 --- a/crates/axum-utils/src/http_client_factory.rs +++ b/crates/axum-utils/src/http_client_factory.rs @@ -56,7 +56,7 @@ impl HttpClientFactory { Ok(layer.layer(client)) } - /// Constructs a new [`HttpService`], suitable for [`mas_oidc_client`] + /// Constructs a new [`HttpService`], suitable for `mas-oidc-client` /// /// # Errors /// diff --git a/crates/data-model/src/compat/session.rs b/crates/data-model/src/compat/session.rs index a5c2c17de..1dbd07228 100644 --- a/crates/data-model/src/compat/session.rs +++ b/crates/data-model/src/compat/session.rs @@ -31,7 +31,7 @@ pub enum CompatSessionState { impl CompatSessionState { /// Returns `true` if the compta session state is [`Valid`]. /// - /// [`Valid`]: ComptaSessionState::Valid + /// [`Valid`]: CompatSessionState::Valid #[must_use] pub fn is_valid(&self) -> bool { matches!(self, Self::Valid) @@ -39,7 +39,7 @@ impl CompatSessionState { /// Returns `true` if the compta session state is [`Finished`]. /// - /// [`Finished`]: ComptaSessionState::Finished + /// [`Finished`]: CompatSessionState::Finished #[must_use] pub fn is_finished(&self) -> bool { matches!(self, Self::Finished { .. }) diff --git a/crates/data-model/src/tokens.rs b/crates/data-model/src/tokens.rs index 120f293e7..2d57a663b 100644 --- a/crates/data-model/src/tokens.rs +++ b/crates/data-model/src/tokens.rs @@ -40,7 +40,7 @@ impl AccessTokenState { /// Returns `true` if the refresh token state is [`Valid`]. /// - /// [`Valid`]: RefreshTokenState::Valid + /// [`Valid`]: AccessTokenState::Valid #[must_use] pub fn is_valid(&self) -> bool { matches!(self, Self::Valid) @@ -48,7 +48,7 @@ impl AccessTokenState { /// Returns `true` if the refresh token state is [`Revoked`]. /// - /// [`Revoked`]: RefreshTokenState::Revoked + /// [`Revoked`]: AccessTokenState::Revoked #[must_use] pub fn is_revoked(&self) -> bool { matches!(self, Self::Revoked { .. }) diff --git a/crates/keystore/src/lib.rs b/crates/keystore/src/lib.rs index f1a15ee65..0899bcdc8 100644 --- a/crates/keystore/src/lib.rs +++ b/crates/keystore/src/lib.rs @@ -15,12 +15,7 @@ //! A crate to store keys which can then be used to sign and verify JWTs. #![forbid(unsafe_code)] -#![deny( - clippy::all, - clippy::str_to_string, - rustdoc::broken_intra_doc_links, - rustdoc::all -)] +#![deny(clippy::all, clippy::str_to_string, rustdoc::broken_intra_doc_links)] #![warn(clippy::pedantic)] use std::{ops::Deref, sync::Arc}; diff --git a/crates/listener/src/lib.rs b/crates/listener/src/lib.rs index 091d5809a..5aaabee4a 100644 --- a/crates/listener/src/lib.rs +++ b/crates/listener/src/lib.rs @@ -22,6 +22,9 @@ #![warn(clippy::pedantic)] #![allow(clippy::module_name_repetitions)] +//! An utility crate to build flexible [`hyper`] listeners, with optional TLS +//! and proxy protocol support. + use self::{maybe_tls::TlsStreamInfo, proxy_protocol::ProxyProtocolV1Info}; pub mod maybe_tls; diff --git a/crates/listener/src/server.rs b/crates/listener/src/server.rs index c919017fc..2527302f1 100644 --- a/crates/listener/src/server.rs +++ b/crates/listener/src/server.rs @@ -39,7 +39,7 @@ pub struct Server { impl Server { /// # Errors /// - /// Returns an error if the listener couldn't be converted via [`TryInfo`] + /// Returns an error if the listener couldn't be converted via [`TryInto`] pub fn try_new(listener: L, service: S) -> Result where L: TryInto, diff --git a/crates/spa/src/lib.rs b/crates/spa/src/lib.rs index 58838f28a..fbef53863 100644 --- a/crates/spa/src/lib.rs +++ b/crates/spa/src/lib.rs @@ -21,6 +21,8 @@ )] #![warn(clippy::pedantic)] +//! A crate to help serve single-page apps built by Vite. + mod vite; use std::{future::Future, pin::Pin}; From 5f2a8b6cb137cd93ebd66878ab1ba8c2ca7eedce Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 18 Jan 2023 12:27:44 +0100 Subject: [PATCH 30/45] Fix rustfmt --- crates/storage-pg/src/oauth2/session.rs | 6 +++++- crates/storage/src/oauth2/session.rs | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/crates/storage-pg/src/oauth2/session.rs b/crates/storage-pg/src/oauth2/session.rs index f1c412790..891c2278f 100644 --- a/crates/storage-pg/src/oauth2/session.rs +++ b/crates/storage-pg/src/oauth2/session.rs @@ -182,7 +182,11 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { ), err, )] - async fn finish(&mut self, clock: &dyn Clock, session: Session) -> Result { + async fn finish( + &mut self, + clock: &dyn Clock, + session: Session, + ) -> Result { let finished_at = clock.now(); let res = sqlx::query!( r#" diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 944c77845..2b6c63b71 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -33,7 +33,8 @@ pub trait OAuth2SessionRepository: Send + Sync { user_session: &BrowserSession, ) -> Result; - async fn finish(&mut self, clock: &dyn Clock, session: Session) -> Result; + async fn finish(&mut self, clock: &dyn Clock, session: Session) + -> Result; async fn list_paginated( &mut self, From 2265327bacdae47be3ee9ce7cca154faced7c77f Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 18 Jan 2023 17:32:54 +0100 Subject: [PATCH 31/45] handlers: box the rng and clock, and extract it from the state --- Cargo.lock | 2 +- crates/axum-utils/src/csrf.rs | 21 +++++++--- crates/axum-utils/src/user_authorization.rs | 15 ++++--- crates/data-model/src/tokens.rs | 4 +- crates/handlers/src/app_state.rs | 39 ++++++++++++++++- crates/handlers/src/compat/login.rs | 29 +++++++++---- .../handlers/src/compat/login_sso_complete.rs | 12 +++--- .../handlers/src/compat/login_sso_redirect.rs | 8 ++-- crates/handlers/src/compat/logout.rs | 5 +-- crates/handlers/src/compat/refresh.rs | 6 +-- crates/handlers/src/lib.rs | 25 +++++------ .../src/oauth2/authorization/complete.rs | 12 +++--- .../handlers/src/oauth2/authorization/mod.rs | 26 +++++++++--- crates/handlers/src/oauth2/consent.rs | 13 +++--- crates/handlers/src/oauth2/introspection.rs | 5 +-- crates/handlers/src/oauth2/registration.rs | 6 +-- crates/handlers/src/oauth2/token.rs | 42 ++++++++++++------- crates/handlers/src/oauth2/userinfo.rs | 8 ++-- .../handlers/src/upstream_oauth2/authorize.rs | 9 ++-- .../handlers/src/upstream_oauth2/callback.rs | 9 ++-- crates/handlers/src/upstream_oauth2/cookie.rs | 11 ++--- crates/handlers/src/upstream_oauth2/link.rs | 16 +++---- .../handlers/src/views/account/emails/add.rs | 12 +++--- .../handlers/src/views/account/emails/mod.rs | 13 +++--- .../src/views/account/emails/verify.rs | 11 ++--- crates/handlers/src/views/account/mod.rs | 7 ++-- crates/handlers/src/views/account/password.rs | 12 +++--- crates/handlers/src/views/index.rs | 7 ++-- crates/handlers/src/views/login.rs | 14 ++++--- crates/handlers/src/views/logout.rs | 6 +-- crates/handlers/src/views/reauth.rs | 12 +++--- crates/handlers/src/views/register.rs | 14 ++++--- crates/storage-pg/src/oauth2/client.rs | 6 +-- crates/storage/Cargo.toml | 2 +- crates/storage/src/clock.rs | 6 +++ crates/storage/src/compat/access_token.rs | 2 +- crates/storage/src/compat/refresh_token.rs | 2 +- crates/storage/src/compat/session.rs | 2 +- crates/storage/src/compat/sso_login.rs | 2 +- crates/storage/src/lib.rs | 5 +++ crates/storage/src/oauth2/access_token.rs | 2 +- .../storage/src/oauth2/authorization_grant.rs | 2 +- crates/storage/src/oauth2/client.rs | 4 +- crates/storage/src/oauth2/refresh_token.rs | 2 +- crates/storage/src/oauth2/session.rs | 2 +- crates/storage/src/upstream_oauth2/link.rs | 2 +- .../storage/src/upstream_oauth2/provider.rs | 2 +- crates/storage/src/upstream_oauth2/session.rs | 2 +- crates/storage/src/user/email.rs | 2 +- crates/storage/src/user/mod.rs | 2 +- crates/storage/src/user/password.rs | 2 +- crates/storage/src/user/session.rs | 2 +- 52 files changed, 291 insertions(+), 193 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6ecc512d5..253809cfc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3105,7 +3105,7 @@ dependencies = [ "mas-iana", "mas-jose", "oauth2-types", - "rand 0.8.5", + "rand_core 0.6.4", "thiserror", "ulid", "url", diff --git a/crates/axum-utils/src/csrf.rs b/crates/axum-utils/src/csrf.rs index e7037d55a..9886fedb1 100644 --- a/crates/axum-utils/src/csrf.rs +++ b/crates/axum-utils/src/csrf.rs @@ -15,6 +15,7 @@ use axum_extra::extract::cookie::{Cookie, PrivateCookieJar}; use chrono::{DateTime, Duration, Utc}; use data_encoding::{DecodeError, BASE64URL_NOPAD}; +use mas_storage::Clock; use rand::{Rng, RngCore}; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, TimestampSeconds}; @@ -108,22 +109,27 @@ pub struct ProtectedForm { } pub trait CsrfExt { - fn csrf_token(self, now: DateTime, rng: R) -> (CsrfToken, Self) + fn csrf_token(self, clock: &C, rng: R) -> (CsrfToken, Self) where - R: RngCore; - fn verify_form(&self, now: DateTime, form: ProtectedForm) -> Result; + R: RngCore, + C: Clock; + fn verify_form(&self, clock: &C, form: ProtectedForm) -> Result + where + C: Clock; } impl CsrfExt for PrivateCookieJar { - fn csrf_token(self, now: DateTime, rng: R) -> (CsrfToken, Self) + fn csrf_token(self, clock: &C, rng: R) -> (CsrfToken, Self) where R: RngCore, + C: Clock, { let jar = self; let mut cookie = jar.get("csrf").unwrap_or_else(|| Cookie::new("csrf", "")); cookie.set_path("/"); cookie.set_http_only(true); + let now = clock.now(); let new_token = cookie .decode() .ok() @@ -136,10 +142,13 @@ impl CsrfExt for PrivateCookieJar { (new_token, jar) } - fn verify_form(&self, now: DateTime, form: ProtectedForm) -> Result { + fn verify_form(&self, clock: &C, form: ProtectedForm) -> Result + where + C: Clock, + { let cookie = self.get("csrf").ok_or(CsrfError::Missing)?; let token: CsrfToken = cookie.decode()?; - let token = token.verify_expiration(now)?; + let token = token.verify_expiration(clock.now())?; token.verify_form_value(&form.csrf)?; Ok(form.inner) } diff --git a/crates/axum-utils/src/user_authorization.rs b/crates/axum-utils/src/user_authorization.rs index 11d793122..9a5956c9e 100644 --- a/crates/axum-utils/src/user_authorization.rs +++ b/crates/axum-utils/src/user_authorization.rs @@ -24,13 +24,12 @@ use axum::{ response::{IntoResponse, Response}, BoxError, }; -use chrono::{DateTime, Utc}; use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName}; use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode}; use mas_data_model::Session; use mas_storage::{ oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository}, - Repository, + Clock, Repository, }; use serde::{de::DeserializeOwned, Deserialize}; use thiserror::Error; @@ -86,10 +85,10 @@ pub struct UserAuthorization { impl UserAuthorization { // TODO: take scopes to validate as parameter - pub async fn protected_form( + pub async fn protected_form( self, repo: &mut R, - now: DateTime, + clock: &C, ) -> Result<(Session, F), AuthorizationVerificationError> { let form = match self.form { Some(f) => f, @@ -98,7 +97,7 @@ impl UserAuthorization { let (token, session) = self.access_token.fetch(repo).await?; - if !token.is_valid(now) || !session.is_valid() { + if !token.is_valid(clock.now()) || !session.is_valid() { return Err(AuthorizationVerificationError::InvalidToken); } @@ -106,14 +105,14 @@ impl UserAuthorization { } // TODO: take scopes to validate as parameter - pub async fn protected( + pub async fn protected( self, repo: &mut R, - now: DateTime, + clock: &C, ) -> Result> { let (token, session) = self.access_token.fetch(repo).await?; - if !token.is_valid(now) || !session.is_valid() { + if !token.is_valid(clock.now()) || !session.is_valid() { return Err(AuthorizationVerificationError::InvalidToken); } diff --git a/crates/data-model/src/tokens.rs b/crates/data-model/src/tokens.rs index 2d57a663b..ad8c407e4 100644 --- a/crates/data-model/src/tokens.rs +++ b/crates/data-model/src/tokens.rs @@ -15,7 +15,7 @@ use chrono::{DateTime, Utc}; use crc::{Crc, CRC_32_ISO_HDLC}; use mas_iana::oauth::OAuthTokenTypeHint; -use rand::{distributions::Alphanumeric, Rng}; +use rand::{distributions::Alphanumeric, Rng, RngCore}; use thiserror::Error; use ulid::Ulid; @@ -193,7 +193,7 @@ impl TokenType { /// AccessToken.generate(thread_rng()); /// RefreshToken.generate(thread_rng()); /// ``` - pub fn generate(self, rng: impl Rng) -> String { + pub fn generate(self, rng: &mut (impl RngCore + ?Sized)) -> String { let random_part: String = rng .sample_iter(&Alphanumeric) .take(30) diff --git a/crates/handlers/src/app_state.rs b/crates/handlers/src/app_state.rs index c9c650904..45446e107 100644 --- a/crates/handlers/src/app_state.rs +++ b/crates/handlers/src/app_state.rs @@ -12,15 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; +use std::{convert::Infallible, sync::Arc}; -use axum::extract::FromRef; +use axum::{ + async_trait, + extract::{FromRef, FromRequestParts}, +}; use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_email::Mailer; use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; +use mas_storage::{BoxClock, BoxRng, SystemClock}; use mas_templates::Templates; +use rand::SeedableRng; use sqlx::PgPool; use crate::{passwords::PasswordManager, MatrixHomeserver}; @@ -105,3 +110,33 @@ impl FromRef for PasswordManager { input.password_manager.clone() } } + +#[async_trait] +impl FromRequestParts for BoxClock { + type Rejection = Infallible; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + _state: &AppState, + ) -> Result { + let clock = SystemClock::default(); + Ok(Box::new(clock)) + } +} + +#[async_trait] +impl FromRequestParts for BoxRng { + type Rejection = Infallible; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + _state: &AppState, + ) -> Result { + // This rng is used to source the local rng + #[allow(clippy::disallowed_methods)] + let rng = rand::thread_rng(); + + let rng = rand_chacha::ChaChaRng::from_rng(rng).expect("Failed to seed RNG"); + Ok(Box::new(rng)) + } +} diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index 9b0a0792f..279b575d3 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -22,9 +22,10 @@ use mas_storage::{ CompatSsoLoginRepository, }, user::{UserPasswordRepository, UserRepository}, - Clock, Repository, SystemClock, + BoxClock, BoxRng, Clock, Repository, }; use mas_storage_pg::PgRepository; +use rand::{CryptoRng, RngCore}; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; use sqlx::PgPool; @@ -154,7 +155,6 @@ pub enum RouteError { InvalidLoginToken, } -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl IntoResponse for RouteError { @@ -194,18 +194,29 @@ impl IntoResponse for RouteError { #[tracing::instrument(skip_all, err)] pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(password_manager): State, State(pool): State, State(homeserver): State, Json(input): Json, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; let (session, user) = match input.credentials { Credentials::Password { identifier: Identifier::User { user }, password, - } => user_password_login(&password_manager, &mut repo, user, password).await?, + } => { + user_password_login( + &mut rng, + &clock, + &password_manager, + &mut repo, + user, + password, + ) + .await? + } Credentials::Token { token } => token_login(&mut repo, &clock, &token).await?, @@ -254,7 +265,7 @@ pub(crate) async fn post( async fn token_login( repo: &mut PgRepository, - clock: &SystemClock, + clock: &dyn Clock, token: &str, ) -> Result<(CompatSession, User), RouteError> { let login = repo @@ -319,13 +330,13 @@ async fn token_login( } async fn user_password_login( + mut rng: &mut (impl RngCore + CryptoRng + Send), + clock: &impl Clock, password_manager: &PasswordManager, repo: &mut PgRepository, username: String, password: String, ) -> Result<(CompatSession, User), RouteError> { - let (clock, mut rng) = crate::clock_and_rng(); - // Find the user let user = repo .user() @@ -358,7 +369,7 @@ async fn user_password_login( repo.user_password() .add( &mut rng, - &clock, + clock, &user, version, hashed_password, @@ -371,7 +382,7 @@ async fn user_password_login( let device = Device::generate(&mut rng); let session = repo .compat_session() - .add(&mut rng, &clock, &user, device) + .add(&mut rng, clock, &user, device) .await?; Ok((session, user)) diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 631d1126d..6201b0c64 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -31,7 +31,7 @@ use mas_keystore::Encrypter; use mas_router::{CompatLoginSsoAction, PostAuthAction, Route}; use mas_storage::{ compat::{CompatSessionRepository, CompatSsoLoginRepository}, - Clock, Repository, + BoxClock, BoxRng, Clock, Repository, }; use mas_storage_pg::PgRepository; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; @@ -54,17 +54,18 @@ pub struct Params { } pub async fn get( + mut rng: BoxRng, + clock: BoxClock, State(pool): State, State(templates): State, cookie_jar: PrivateCookieJar, Path(id): Path, Query(params): Query, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let maybe_session = session_info.load_session(&mut repo).await?; @@ -117,6 +118,8 @@ pub async fn get( } pub async fn post( + mut rng: BoxRng, + clock: BoxClock, State(pool): State, State(templates): State, cookie_jar: PrivateCookieJar, @@ -124,11 +127,10 @@ pub async fn post( Query(params): Query, Form(form): Form>, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); - cookie_jar.verify_form(clock.now(), form)?; + cookie_jar.verify_form(&clock, form)?; let maybe_session = session_info.load_session(&mut repo).await?; diff --git a/crates/handlers/src/compat/login_sso_redirect.rs b/crates/handlers/src/compat/login_sso_redirect.rs index 38aa08943..a8063141a 100644 --- a/crates/handlers/src/compat/login_sso_redirect.rs +++ b/crates/handlers/src/compat/login_sso_redirect.rs @@ -19,7 +19,7 @@ use axum::{ }; use hyper::StatusCode; use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder}; -use mas_storage::{compat::CompatSsoLoginRepository, Repository}; +use mas_storage::{compat::CompatSsoLoginRepository, BoxClock, BoxRng, Repository}; use mas_storage_pg::PgRepository; use rand::distributions::{Alphanumeric, DistString}; use serde::Deserialize; @@ -49,7 +49,6 @@ pub enum RouteError { InvalidRedirectUrl, } -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl IntoResponse for RouteError { @@ -58,14 +57,13 @@ impl IntoResponse for RouteError { } } -#[tracing::instrument(skip(pool, url_builder), err)] pub async fn get( + mut rng: BoxRng, + clock: BoxClock, State(pool): State, State(url_builder): State, Query(params): Query, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - // Check the redirectUrl parameter let redirect_url = params.redirect_url.ok_or(RouteError::MissingRedirectUrl)?; let redirect_url = Url::parse(&redirect_url).map_err(|_| RouteError::InvalidRedirectUrl)?; diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index 310fef2e7..bfc767faf 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -18,7 +18,7 @@ use hyper::StatusCode; use mas_data_model::TokenType; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, - Clock, Repository, SystemClock, + BoxClock, Clock, Repository, }; use mas_storage_pg::PgRepository; use sqlx::PgPool; @@ -42,7 +42,6 @@ pub enum RouteError { InvalidAuthorization, } -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl IntoResponse for RouteError { @@ -69,10 +68,10 @@ impl IntoResponse for RouteError { } pub(crate) async fn post( + clock: BoxClock, State(pool): State, maybe_authorization: Option>>, ) -> Result { - let clock = SystemClock::default(); let mut repo = PgRepository::from_pool(&pool).await?; let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?; diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs index 8b47a81f3..868be9db1 100644 --- a/crates/handlers/src/compat/refresh.rs +++ b/crates/handlers/src/compat/refresh.rs @@ -18,7 +18,7 @@ use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, - Clock, Repository, + BoxClock, BoxRng, Clock, Repository, }; use mas_storage_pg::PgRepository; use serde::{Deserialize, Serialize}; @@ -70,7 +70,6 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl From for RouteError { @@ -89,10 +88,11 @@ pub struct ResponseBody { } pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(pool): State, Json(input): Json, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; let token_type = TokenType::check(&input.refresh_token)?; diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 0360e7418..4d9dcbcdc 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -28,7 +28,7 @@ use std::{convert::Infallible, sync::Arc, time::Duration}; use axum::{ body::{Bytes, HttpBody}, - extract::FromRef, + extract::{FromRef, FromRequestParts}, response::{Html, IntoResponse}, routing::{get, on, post, MethodFilter}, Router, @@ -40,9 +40,9 @@ use mas_http::CorsLayerExt; use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; use mas_router::{Route, UrlBuilder}; +use mas_storage::{BoxClock, BoxRng}; use mas_templates::{ErrorContext, Templates}; use passwords::PasswordManager; -use rand::SeedableRng; use sqlx::PgPool; use tower::util::AndThenLayer; use tower_http::cors::{Any, CorsLayer}; @@ -116,6 +116,8 @@ where S: Clone + Send + Sync + 'static, Keystore: FromRef, UrlBuilder: FromRef, + BoxClock: FromRequestParts, + BoxRng: FromRequestParts, { Router::new() .route( @@ -155,6 +157,8 @@ where PgPool: FromRef, Encrypter: FromRef, HttpClientFactory: FromRef, + BoxClock: FromRequestParts, + BoxRng: FromRequestParts, { // All those routes are API-like, with a common CORS layer Router::new() @@ -208,6 +212,8 @@ where PgPool: FromRef, MatrixHomeserver: FromRef, PasswordManager: FromRef, + BoxClock: FromRequestParts, + BoxRng: FromRequestParts, { Router::new() .route( @@ -255,6 +261,8 @@ where Keystore: FromRef, HttpClientFactory: FromRef, PasswordManager: FromRef, + BoxClock: FromRequestParts, + BoxRng: FromRequestParts, { Router::new() .route( @@ -407,16 +415,3 @@ async fn test_state(pool: PgPool) -> Result { password_manager, }) } - -// XXX: that should be moved somewhere else -fn clock_and_rng() -> (mas_storage::SystemClock, rand_chacha::ChaChaRng) { - let clock = mas_storage::SystemClock::default(); - - // This rng is used to source the local rng - #[allow(clippy::disallowed_methods)] - let rng = rand::thread_rng(); - - let rng = rand_chacha::ChaChaRng::from_rng(rng).expect("Failed to seed RNG"); - - (clock, rng) -} diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index 05554e128..934ba088f 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -27,7 +27,7 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository}, - Repository, + BoxClock, BoxRng, Repository, }; use mas_storage_pg::PgRepository; use mas_templates::Templates; @@ -70,7 +70,6 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); @@ -79,6 +78,8 @@ impl_from_error_for_route!(super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(super::callback::CallbackDestinationError); pub(crate) async fn get( + rng: BoxRng, + clock: BoxClock, State(policy_factory): State>, State(templates): State, State(pool): State, @@ -108,7 +109,7 @@ pub(crate) async fn get( return Ok((cookie_jar, mas_router::Login::and_then(continue_grant).go()).into_response()); }; - match complete(grant, session, &policy_factory, repo).await { + match complete(rng, clock, grant, session, &policy_factory, repo).await { Ok(params) => { let res = callback_destination.go(&templates, params).await?; Ok((cookie_jar, res).into_response()) @@ -149,7 +150,6 @@ pub enum GrantCompletionError { NoSuchClient, } -impl_from_error_for_route!(GrantCompletionError: sqlx::Error); impl_from_error_for_route!(GrantCompletionError: mas_storage_pg::DatabaseError); impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError); @@ -157,13 +157,13 @@ impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError); impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError); pub(crate) async fn complete( + mut rng: BoxRng, + clock: BoxClock, grant: AuthorizationGrant, browser_session: BrowserSession, policy_factory: &PolicyFactory, mut repo: PgRepository, ) -> Result>, GrantCompletionError> { - let (clock, mut rng) = crate::clock_and_rng(); - // Verify that the grant is in a pending stage if !grant.stage.is_pending() { return Err(GrantCompletionError::NotPending); diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 43bda928b..607534823 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -27,7 +27,7 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, - Repository, + BoxClock, BoxRng, Repository, }; use mas_storage_pg::PgRepository; use mas_templates::Templates; @@ -91,7 +91,6 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(self::callback::CallbackDestinationError); impl_from_error_for_route!(mas_policy::LoadError); @@ -133,13 +132,14 @@ fn resolve_response_mode( #[allow(clippy::too_many_lines)] pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(policy_factory): State>, State(templates): State, State(pool): State, cookie_jar: PrivateCookieJar, Form(params): Form, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; // First, figure out what client it is @@ -334,7 +334,15 @@ pub(crate) async fn get( // Else, we immediately try to complete the authorization grant Some(user_session) if prompt.contains(&Prompt::None) => { // With prompt=none, we should get back to the client immediately - match self::complete::complete(grant, user_session, &policy_factory, repo).await + match self::complete::complete( + rng, + clock, + grant, + user_session, + &policy_factory, + repo, + ) + .await { Ok(params) => callback_destination.go(&templates, params).await?, Err(GrantCompletionError::RequiresConsent) => { @@ -373,7 +381,15 @@ pub(crate) async fn get( Some(user_session) => { let grant_id = grant.id; // Else, we show the relevant reauth/consent page if necessary - match self::complete::complete(grant, user_session, &policy_factory, repo).await + match self::complete::complete( + rng, + clock, + grant, + user_session, + &policy_factory, + repo, + ) + .await { Ok(params) => callback_destination.go(&templates, params).await?, Err( diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index 8fe4d2acf..86c832fb2 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -30,7 +30,7 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, - Clock, Repository, + BoxClock, BoxRng, Repository, }; use mas_storage_pg::PgRepository; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; @@ -61,7 +61,6 @@ pub enum RouteError { NoSuchClient, } -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_policy::LoadError); @@ -75,13 +74,14 @@ impl IntoResponse for RouteError { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(policy_factory): State>, State(templates): State, State(pool): State, cookie_jar: PrivateCookieJar, Path(grant_id): Path, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -99,7 +99,7 @@ pub(crate) async fn get( } if let Some(session) = maybe_session { - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let mut policy = policy_factory.instantiate().await?; let res = policy @@ -130,16 +130,17 @@ pub(crate) async fn get( } pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(policy_factory): State>, State(pool): State, cookie_jar: PrivateCookieJar, Path(grant_id): Path, Form(form): Form>, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; - cookie_jar.verify_form(clock.now(), form)?; + cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 245f2125c..d8e64fa06 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -25,7 +25,7 @@ use mas_storage::{ compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository}, user::{BrowserSessionRepository, UserRepository}, - Clock, Repository, SystemClock, + BoxClock, Clock, Repository, }; use mas_storage_pg::PgRepository; use oauth2_types::{ @@ -97,7 +97,6 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl From for RouteError { @@ -125,12 +124,12 @@ const API_SCOPE: ScopeToken = ScopeToken::from_static("urn:matrix:org.matrix.msc #[allow(clippy::too_many_lines)] pub(crate) async fn post( + clock: BoxClock, State(http_client_factory): State, State(pool): State, State(encrypter): State, client_authorization: ClientAuthorization, ) -> Result { - let clock = SystemClock::default(); let mut repo = PgRepository::from_pool(&pool).await?; let client = client_authorization diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index 8e9489e81..da043b8bd 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -19,7 +19,7 @@ use hyper::StatusCode; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_keystore::Encrypter; use mas_policy::{PolicyFactory, Violation}; -use mas_storage::{oauth2::OAuth2ClientRepository, Repository}; +use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRng, Repository}; use mas_storage_pg::PgRepository; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, @@ -49,7 +49,6 @@ pub(crate) enum RouteError { PolicyDenied(Vec), } -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); @@ -108,12 +107,13 @@ impl IntoResponse for RouteError { #[tracing::instrument(skip_all, err)] pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(pool): State, State(policy_factory): State>, State(encrypter): State, Json(body): Json, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); info!(?body, "Client registration"); // Validate the body diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index de0118f4b..3fe916d81 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -37,7 +37,7 @@ use mas_storage::{ OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, user::BrowserSessionRepository, - Clock, Repository, + BoxClock, BoxRng, Clock, Repository, }; use mas_storage_pg::PgRepository; use oauth2_types::{ @@ -151,7 +151,6 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_keystore::WrongAlgorithmError); impl_from_error_for_route!(mas_jose::claims::ClaimError); @@ -160,6 +159,8 @@ impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError); #[tracing::instrument(skip_all, err)] pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(http_client_factory): State, State(key_store): State, State(url_builder): State, @@ -189,10 +190,19 @@ pub(crate) async fn post( let reply = match form { AccessTokenRequest::AuthorizationCode(grant) => { - authorization_code_grant(&grant, &client, &key_store, &url_builder, repo).await? + authorization_code_grant( + &mut rng, + &clock, + &grant, + &client, + &key_store, + &url_builder, + repo, + ) + .await? } AccessTokenRequest::RefreshToken(grant) => { - refresh_token_grant(&grant, &client, repo).await? + refresh_token_grant(&mut rng, &clock, &grant, &client, repo).await? } _ => { return Err(RouteError::InvalidGrant); @@ -208,14 +218,14 @@ pub(crate) async fn post( #[allow(clippy::too_many_lines)] async fn authorization_code_grant( + mut rng: &mut BoxRng, + clock: &impl Clock, grant: &AuthorizationCodeGrant, client: &Client, key_store: &Keystore, url_builder: &UrlBuilder, mut repo: PgRepository, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let authz_grant = repo .oauth2_authorization_grant() .find_by_code(&grant.code) @@ -244,7 +254,7 @@ async fn authorization_code_grant( .lookup(session_id) .await? .ok_or(RouteError::NoSuchOAuthSession)?; - repo.oauth2_session().finish(&clock, session).await?; + repo.oauth2_session().finish(clock, session).await?; repo.save().await?; } @@ -302,12 +312,12 @@ async fn authorization_code_grant( let access_token = repo .oauth2_access_token() - .add(&mut rng, &clock, &session, access_token_str, ttl) + .add(&mut rng, clock, &session, access_token_str, ttl) .await?; let refresh_token = repo .oauth2_refresh_token() - .add(&mut rng, &clock, &session, &access_token, refresh_token_str) + .add(&mut rng, clock, &session, &access_token, refresh_token_str) .await?; let id_token = if session.scope.contains(&scope::OPENID) { @@ -357,7 +367,7 @@ async fn authorization_code_grant( } repo.oauth2_authorization_grant() - .exchange(&clock, authz_grant) + .exchange(clock, authz_grant) .await?; repo.save().await?; @@ -366,12 +376,12 @@ async fn authorization_code_grant( } async fn refresh_token_grant( + mut rng: &mut BoxRng, + clock: &impl Clock, grant: &RefreshTokenGrant, client: &Client, mut repo: PgRepository, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let refresh_token = repo .oauth2_refresh_token() .find_by_token(&grant.refresh_token) @@ -399,14 +409,14 @@ async fn refresh_token_grant( let new_access_token = repo .oauth2_access_token() - .add(&mut rng, &clock, &session, access_token_str.clone(), ttl) + .add(&mut rng, clock, &session, access_token_str.clone(), ttl) .await?; let new_refresh_token = repo .oauth2_refresh_token() .add( &mut rng, - &clock, + clock, &session, &new_access_token, refresh_token_str, @@ -415,13 +425,13 @@ async fn refresh_token_grant( let refresh_token = repo .oauth2_refresh_token() - .consume(&clock, refresh_token) + .consume(clock, refresh_token) .await?; if let Some(access_token_id) = refresh_token.access_token_id { if let Some(access_token) = repo.oauth2_access_token().lookup(access_token_id).await? { repo.oauth2_access_token() - .revoke(&clock, access_token) + .revoke(clock, access_token) .await?; } } diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index 39bc7587b..9d60ac1fa 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -31,7 +31,7 @@ use mas_router::UrlBuilder; use mas_storage::{ oauth2::OAuth2ClientRepository, user::{BrowserSessionRepository, UserEmailRepository}, - Clock, Repository, + BoxClock, BoxRng, Repository, }; use mas_storage_pg::PgRepository; use oauth2_types::scope; @@ -79,7 +79,6 @@ pub enum RouteError { NoSuchBrowserSession, } -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_keystore::WrongAlgorithmError); impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError); @@ -99,15 +98,16 @@ impl IntoResponse for RouteError { } pub async fn get( + mut rng: BoxRng, + clock: BoxClock, State(url_builder): State, State(pool): State, State(key_store): State, user_authorization: UserAuthorization, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; - let session = user_authorization.protected(&mut repo, clock.now()).await?; + let session = user_authorization.protected(&mut repo, &clock).await?; let browser_session = repo .browser_session() diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index 565acea89..d66493170 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -24,7 +24,7 @@ use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_router::UrlBuilder; use mas_storage::{ upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, - Clock, Repository, + BoxClock, BoxRng, Repository, }; use mas_storage_pg::PgRepository; use sqlx::PgPool; @@ -43,7 +43,6 @@ pub(crate) enum RouteError { Internal(Box), } -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_http::ClientInitError); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError); @@ -59,6 +58,8 @@ impl IntoResponse for RouteError { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(http_client_factory): State, State(pool): State, State(url_builder): State, @@ -66,8 +67,6 @@ pub(crate) async fn get( Path(provider_id): Path, Query(query): Query, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut repo = PgRepository::from_pool(&pool).await?; let provider = repo @@ -115,7 +114,7 @@ pub(crate) async fn get( let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar) .add(session.id, provider.id, data.state, query.post_auth_action) - .save(cookie_jar, clock.now()); + .save(cookie_jar, &clock); repo.save().await?; diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index ffde4ef32..fd66af094 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -30,7 +30,7 @@ use mas_storage::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, }, - Clock, Repository, + BoxClock, BoxRng, Clock, Repository, }; use mas_storage_pg::PgRepository; use oauth2_types::errors::ClientErrorCode; @@ -102,7 +102,6 @@ pub(crate) enum RouteError { impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_http::ClientInitError); -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); impl_from_error_for_route!(mas_oidc_client::error::JwksError); impl_from_error_for_route!(mas_oidc_client::error::TokenAuthorizationCodeError); @@ -122,6 +121,8 @@ impl IntoResponse for RouteError { #[allow(clippy::too_many_lines, clippy::too_many_arguments)] pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(http_client_factory): State, State(pool): State, State(url_builder): State, @@ -131,8 +132,6 @@ pub(crate) async fn get( Path(provider_id): Path, Query(params): Query, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut repo = PgRepository::from_pool(&pool).await?; let provider = repo @@ -268,7 +267,7 @@ pub(crate) async fn get( let cookie_jar = sessions_cookie .add_link_to_session(session.id, link.id)? - .save(cookie_jar, clock.now()); + .save(cookie_jar, &clock); repo.save().await?; diff --git a/crates/handlers/src/upstream_oauth2/cookie.rs b/crates/handlers/src/upstream_oauth2/cookie.rs index be1d3edf0..92cfa5650 100644 --- a/crates/handlers/src/upstream_oauth2/cookie.rs +++ b/crates/handlers/src/upstream_oauth2/cookie.rs @@ -18,6 +18,7 @@ use axum_extra::extract::{cookie::Cookie, PrivateCookieJar}; use chrono::{DateTime, Duration, NaiveDateTime, Utc}; use mas_axum_utils::CookieExt; use mas_router::PostAuthAction; +use mas_storage::Clock; use serde::{Deserialize, Serialize}; use thiserror::Error; use time::OffsetDateTime; @@ -65,11 +66,11 @@ impl UpstreamSessions { } /// Save the upstreams sessions to the cookie jar - pub fn save( - self, - cookie_jar: PrivateCookieJar, - now: DateTime, - ) -> PrivateCookieJar { + pub fn save(self, cookie_jar: PrivateCookieJar, clock: &C) -> PrivateCookieJar + where + C: Clock, + { + let now = clock.now(); let this = self.expire(now); let mut cookie = Cookie::named(COOKIE_NAME).encode(&this); cookie.set_path("/"); diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index c0408770b..d318fc3e3 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -27,7 +27,7 @@ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, user::{BrowserSessionRepository, UserRepository}, - Clock, Repository, + BoxClock, BoxRng, Repository, }; use mas_storage_pg::PgRepository; use mas_templates::{ @@ -70,7 +70,6 @@ pub(crate) enum RouteError { Internal(Box), } -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError); impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); @@ -95,14 +94,14 @@ pub(crate) enum FormData { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(pool): State, State(templates): State, cookie_jar: PrivateCookieJar, Path(link_id): Path, ) -> Result { let mut repo = PgRepository::from_pool(&pool).await?; - let (clock, mut rng) = crate::clock_and_rng(); - let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); let (session_id, _post_auth_action) = sessions_cookie .lookup_link(link_id) @@ -131,7 +130,7 @@ pub(crate) async fn get( } let (user_session_info, cookie_jar) = cookie_jar.session_info(); - let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let maybe_user_session = user_session_info.load_session(&mut repo).await?; let render = match (maybe_user_session, link.user_id) { @@ -212,14 +211,15 @@ pub(crate) async fn get( } pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(pool): State, cookie_jar: PrivateCookieJar, Path(link_id): Path, Form(form): Form>, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(clock.now(), form)?; + let form = cookie_jar.verify_form(&clock, form)?; let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); let (session_id, post_auth_action) = sessions_cookie @@ -297,7 +297,7 @@ pub(crate) async fn post( let cookie_jar = sessions_cookie .consume_link(link_id)? - .save(cookie_jar, clock.now()); + .save(cookie_jar, &clock); let cookie_jar = cookie_jar.set_session(&session); repo.save().await?; diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index 4c11a485a..1c8c6665e 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -24,7 +24,7 @@ use mas_axum_utils::{ use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, Clock, Repository}; +use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Repository}; use mas_storage_pg::PgRepository; use mas_templates::{EmailAddContext, TemplateContext, Templates}; use serde::Deserialize; @@ -39,14 +39,15 @@ pub struct EmailForm { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, State(pool): State, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; @@ -68,16 +69,17 @@ pub(crate) async fn get( } pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(pool): State, State(mailer): State, cookie_jar: PrivateCookieJar, Query(query): Query, Form(form): Form>, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(clock.now(), form)?; + let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index f1b7c733a..10772b871 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -28,7 +28,7 @@ use mas_data_model::{BrowserSession, User, UserEmail}; use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, Clock, Repository}; +use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Clock, Repository}; use mas_storage_pg::PgRepository; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; use rand::{distributions::Uniform, Rng}; @@ -49,12 +49,12 @@ pub enum ManagementForm { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, State(pool): State, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -77,7 +77,7 @@ async fn render( cookie_jar: PrivateCookieJar, repo: &mut impl Repository, ) -> Result { - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng); let emails = repo.user_email().all(&session.user).await?; @@ -124,13 +124,14 @@ async fn start_email_verification( } pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, State(pool): State, State(mailer): State, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -144,7 +145,7 @@ pub(crate) async fn post( return Ok((cookie_jar, login.go()).into_response()); }; - let form = cookie_jar.verify_form(clock.now(), form)?; + let form = cookie_jar.verify_form(&clock, form)?; match form { ManagementForm::Add { email } => { diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index f37b15e46..644810e54 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -24,7 +24,7 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, Clock, Repository, SystemClock}; +use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Repository}; use mas_storage_pg::PgRepository; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use serde::Deserialize; @@ -39,16 +39,17 @@ pub struct CodeForm { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, State(pool): State, Query(query): Query, Path(id): Path, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; @@ -83,16 +84,16 @@ pub(crate) async fn get( } pub(crate) async fn post( + clock: BoxClock, State(pool): State, cookie_jar: PrivateCookieJar, Query(query): Query, Path(id): Path, Form(form): Form>, ) -> Result { - let clock = SystemClock::default(); let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(clock.now(), form)?; + let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; diff --git a/crates/handlers/src/views/account/mod.rs b/crates/handlers/src/views/account/mod.rs index 29aaeda32..660c14162 100644 --- a/crates/handlers/src/views/account/mod.rs +++ b/crates/handlers/src/views/account/mod.rs @@ -25,21 +25,22 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository}, - Clock, Repository, + BoxClock, BoxRng, Repository, }; use mas_storage_pg::PgRepository; use mas_templates::{AccountContext, TemplateContext, Templates}; use sqlx::PgPool; pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, State(pool): State, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index 1624f6f60..a9f17123b 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -27,7 +27,7 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository}, - Clock, Repository, + BoxClock, BoxRng, Clock, Repository, }; use mas_storage_pg::PgRepository; use mas_templates::{EmptyContext, TemplateContext, Templates}; @@ -46,11 +46,12 @@ pub struct ChangeForm { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, State(pool): State, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -72,7 +73,7 @@ async fn render( session: BrowserSession, cookie_jar: PrivateCookieJar, ) -> Result { - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng); let ctx = EmptyContext .with_session(session) @@ -84,16 +85,17 @@ async fn render( } pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(password_manager): State, State(templates): State, State(pool): State, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(clock.now(), form)?; + let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/views/index.rs b/crates/handlers/src/views/index.rs index 2298e83e7..ffe500e72 100644 --- a/crates/handlers/src/views/index.rs +++ b/crates/handlers/src/views/index.rs @@ -20,21 +20,22 @@ use axum_extra::extract::PrivateCookieJar; use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt}; use mas_keystore::Encrypter; use mas_router::UrlBuilder; -use mas_storage::Clock; +use mas_storage::{BoxClock, BoxRng}; use mas_storage_pg::PgRepository; use mas_templates::{IndexContext, TemplateContext, Templates}; use sqlx::PgPool; pub async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, State(url_builder): State, State(pool): State, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); let session = session_info.load_session(&mut repo).await?; diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 10295dbc3..d8abcb467 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -26,7 +26,7 @@ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::UpstreamOAuthProviderRepository, user::{BrowserSessionRepository, UserPasswordRepository, UserRepository}, - Clock, Repository, + BoxClock, BoxRng, Clock, Repository, }; use mas_storage_pg::PgRepository; use mas_templates::{ @@ -51,15 +51,16 @@ impl ToFormState for LoginForm { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, State(pool): State, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; @@ -83,6 +84,8 @@ pub(crate) async fn get( } pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(password_manager): State, State(templates): State, State(pool): State, @@ -90,12 +93,11 @@ pub(crate) async fn post( cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(clock.now(), form)?; + let form = cookie_jar.verify_form(&clock, form)?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); // Validate the form let state = { diff --git a/crates/handlers/src/views/logout.rs b/crates/handlers/src/views/logout.rs index 781780ca4..f8491cb95 100644 --- a/crates/handlers/src/views/logout.rs +++ b/crates/handlers/src/views/logout.rs @@ -23,19 +23,19 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::{PostAuthAction, Route}; -use mas_storage::{user::BrowserSessionRepository, Clock, Repository, SystemClock}; +use mas_storage::{user::BrowserSessionRepository, BoxClock, Repository}; use mas_storage_pg::PgRepository; use sqlx::PgPool; pub(crate) async fn post( + clock: BoxClock, State(pool): State, cookie_jar: PrivateCookieJar, Form(form): Form>>, ) -> Result { - let clock = SystemClock::default(); let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(clock.now(), form)?; + let form = cookie_jar.verify_form(&clock, form)?; let (session_info, mut cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index 571d2e179..9c2330a30 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -26,7 +26,7 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository}, - Clock, Repository, + BoxClock, BoxRng, Repository, }; use mas_storage_pg::PgRepository; use mas_templates::{ReauthContext, TemplateContext, Templates}; @@ -43,15 +43,16 @@ pub(crate) struct ReauthForm { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, State(pool): State, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; @@ -80,16 +81,17 @@ pub(crate) async fn get( } pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(password_manager): State, State(pool): State, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(clock.now(), form)?; + let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index 3e2d87a97..a8fc7baee 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -33,7 +33,7 @@ use mas_policy::PolicyFactory; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - Clock, Repository, + BoxClock, BoxRng, Repository, }; use mas_storage_pg::PgRepository; use mas_templates::{ @@ -61,15 +61,16 @@ impl ToFormState for RegisterForm { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, State(pool): State, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; @@ -93,6 +94,8 @@ pub(crate) async fn get( #[allow(clippy::too_many_lines, clippy::too_many_arguments)] pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(password_manager): State, State(mailer): State, State(policy_factory): State>, @@ -102,12 +105,11 @@ pub(crate) async fn post( cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(clock.now(), form)?; + let form = cookie_jar.verify_form(&clock, form)?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); // Validate the form let state = { diff --git a/crates/storage-pg/src/oauth2/client.rs b/crates/storage-pg/src/oauth2/client.rs index 1b8a99f93..e17245aad 100644 --- a/crates/storage-pg/src/oauth2/client.rs +++ b/crates/storage-pg/src/oauth2/client.rs @@ -30,7 +30,7 @@ use oauth2_types::{ requests::GrantType, scope::{Scope, ScopeToken}, }; -use rand::{Rng, RngCore}; +use rand::RngCore; use sqlx::PgConnection; use tracing::{info_span, Instrument}; use ulid::Ulid; @@ -534,7 +534,7 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { )] async fn add_from_config( &mut self, - mut rng: impl Rng + Send, + rng: &mut (dyn RngCore + Send), clock: &dyn Clock, client_id: Ulid, client_auth_method: OAuthClientAuthenticationMethod, @@ -597,7 +597,7 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { .iter() .map(|uri| { ( - Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), + Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut *rng)), uri.as_str().to_owned(), ) }) diff --git a/crates/storage/Cargo.toml b/crates/storage/Cargo.toml index 97089e956..86ca9f078 100644 --- a/crates/storage/Cargo.toml +++ b/crates/storage/Cargo.toml @@ -10,7 +10,7 @@ async-trait = "0.1.60" chrono = "0.4.23" thiserror = "1.0.38" -rand = "0.8.5" +rand_core = "0.6.4" url = "2.3.1" ulid = "1.0.0" diff --git a/crates/storage/src/clock.rs b/crates/storage/src/clock.rs index 54ca3b4fc..04c69f25a 100644 --- a/crates/storage/src/clock.rs +++ b/crates/storage/src/clock.rs @@ -28,6 +28,12 @@ pub trait Clock: Sync { fn now(&self) -> DateTime; } +impl Clock for Box { + fn now(&self) -> DateTime { + (**self).now() + } +} + /// A clock which uses the system time #[derive(Clone, Default)] pub struct SystemClock { diff --git a/crates/storage/src/compat/access_token.rs b/crates/storage/src/compat/access_token.rs index e135be398..32ba1f735 100644 --- a/crates/storage/src/compat/access_token.rs +++ b/crates/storage/src/compat/access_token.rs @@ -15,7 +15,7 @@ use async_trait::async_trait; use chrono::Duration; use mas_data_model::{CompatAccessToken, CompatSession}; -use rand::RngCore; +use rand_core::RngCore; use ulid::Ulid; use crate::Clock; diff --git a/crates/storage/src/compat/refresh_token.rs b/crates/storage/src/compat/refresh_token.rs index f8e8dfaa8..627b59a12 100644 --- a/crates/storage/src/compat/refresh_token.rs +++ b/crates/storage/src/compat/refresh_token.rs @@ -14,7 +14,7 @@ use async_trait::async_trait; use mas_data_model::{CompatAccessToken, CompatRefreshToken, CompatSession}; -use rand::RngCore; +use rand_core::RngCore; use ulid::Ulid; use crate::Clock; diff --git a/crates/storage/src/compat/session.rs b/crates/storage/src/compat/session.rs index fa5cbd6e6..0c5bc125c 100644 --- a/crates/storage/src/compat/session.rs +++ b/crates/storage/src/compat/session.rs @@ -14,7 +14,7 @@ use async_trait::async_trait; use mas_data_model::{CompatSession, Device, User}; -use rand::RngCore; +use rand_core::RngCore; use ulid::Ulid; use crate::Clock; diff --git a/crates/storage/src/compat/sso_login.rs b/crates/storage/src/compat/sso_login.rs index 6ee3270a6..1ed3e5d80 100644 --- a/crates/storage/src/compat/sso_login.rs +++ b/crates/storage/src/compat/sso_login.rs @@ -14,7 +14,7 @@ use async_trait::async_trait; use mas_data_model::{CompatSession, CompatSsoLogin, User}; -use rand::RngCore; +use rand_core::RngCore; use ulid::Ulid; use url::Url; diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index 2a83b55e7..d5a453726 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -37,8 +37,13 @@ pub(crate) mod repository; pub mod upstream_oauth2; pub mod user; +use rand_core::CryptoRngCore; + pub use self::{ clock::{Clock, SystemClock}, pagination::{Page, Pagination}, repository::Repository, }; + +pub type BoxClock = Box; +pub type BoxRng = Box; diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index 4bbcf8857..1148136f0 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -15,7 +15,7 @@ use async_trait::async_trait; use chrono::Duration; use mas_data_model::{AccessToken, Session}; -use rand::RngCore; +use rand_core::RngCore; use ulid::Ulid; use crate::Clock; diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index 96403f122..1130e6a8a 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -17,7 +17,7 @@ use std::num::NonZeroU32; use async_trait::async_trait; use mas_data_model::{AuthorizationCode, AuthorizationGrant, Client, Session}; use oauth2_types::{requests::ResponseMode, scope::Scope}; -use rand::RngCore; +use rand_core::RngCore; use ulid::Ulid; use url::Url; diff --git a/crates/storage/src/oauth2/client.rs b/crates/storage/src/oauth2/client.rs index 30f37bc63..3c7d7dbb3 100644 --- a/crates/storage/src/oauth2/client.rs +++ b/crates/storage/src/oauth2/client.rs @@ -19,7 +19,7 @@ use mas_data_model::{Client, User}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_jose::jwk::PublicJsonWebKeySet; use oauth2_types::{requests::GrantType, scope::Scope}; -use rand::{Rng, RngCore}; +use rand_core::RngCore; use ulid::Ulid; use url::Url; @@ -67,7 +67,7 @@ pub trait OAuth2ClientRepository: Send + Sync { #[allow(clippy::too_many_arguments)] async fn add_from_config( &mut self, - mut rng: impl Rng + Send, + rng: &mut (dyn RngCore + Send), clock: &dyn Clock, client_id: Ulid, client_auth_method: OAuthClientAuthenticationMethod, diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 88bb26fbb..66ec2c328 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -14,7 +14,7 @@ use async_trait::async_trait; use mas_data_model::{AccessToken, RefreshToken, Session}; -use rand::RngCore; +use rand_core::RngCore; use ulid::Ulid; use crate::Clock; diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 2b6c63b71..3813810b8 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -14,7 +14,7 @@ use async_trait::async_trait; use mas_data_model::{AuthorizationGrant, BrowserSession, Session, User}; -use rand::RngCore; +use rand_core::RngCore; use ulid::Ulid; use crate::{pagination::Page, Clock, Pagination}; diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index 474c21d9d..bf9e0aadd 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -14,7 +14,7 @@ use async_trait::async_trait; use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User}; -use rand::RngCore; +use rand_core::RngCore; use ulid::Ulid; use crate::{pagination::Page, Clock, Pagination}; diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index 10a03e2bc..521a7e7a0 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -16,7 +16,7 @@ use async_trait::async_trait; use mas_data_model::UpstreamOAuthProvider; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use oauth2_types::scope::Scope; -use rand::RngCore; +use rand_core::RngCore; use ulid::Ulid; use crate::{pagination::Page, Clock, Pagination}; diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index e1b6abc16..f4441b2ab 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -14,7 +14,7 @@ use async_trait::async_trait; use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider}; -use rand::RngCore; +use rand_core::RngCore; use ulid::Ulid; use crate::Clock; diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs index 4b8f846d0..65ee465b9 100644 --- a/crates/storage/src/user/email.rs +++ b/crates/storage/src/user/email.rs @@ -14,7 +14,7 @@ use async_trait::async_trait; use mas_data_model::{User, UserEmail, UserEmailVerification}; -use rand::RngCore; +use rand_core::RngCore; use ulid::Ulid; use crate::{pagination::Page, Clock, Pagination}; diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 8e046a9dc..b3bd0bc25 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -14,7 +14,7 @@ use async_trait::async_trait; use mas_data_model::User; -use rand::RngCore; +use rand_core::RngCore; use ulid::Ulid; use crate::Clock; diff --git a/crates/storage/src/user/password.rs b/crates/storage/src/user/password.rs index 306f7d936..609198b22 100644 --- a/crates/storage/src/user/password.rs +++ b/crates/storage/src/user/password.rs @@ -14,7 +14,7 @@ use async_trait::async_trait; use mas_data_model::{Password, User}; -use rand::RngCore; +use rand_core::RngCore; use crate::Clock; diff --git a/crates/storage/src/user/session.rs b/crates/storage/src/user/session.rs index 5499f45ab..5556547c0 100644 --- a/crates/storage/src/user/session.rs +++ b/crates/storage/src/user/session.rs @@ -14,7 +14,7 @@ use async_trait::async_trait; use mas_data_model::{BrowserSession, Password, UpstreamOAuthLink, User}; -use rand::RngCore; +use rand_core::RngCore; use ulid::Ulid; use crate::{pagination::Page, Clock, Pagination}; From 34136a2a978f07b99b774a42b3d069e3aa2e9af2 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 18 Jan 2023 18:21:45 +0100 Subject: [PATCH 32/45] handlers: extract the PgRepository from the request Also fix a bunch of clippy errors & doctests --- crates/data-model/src/tokens.rs | 4 +-- crates/handlers/src/app_state.rs | 27 +++++++++++++++++++ crates/handlers/src/compat/login.rs | 4 +-- .../handlers/src/compat/login_sso_complete.rs | 9 ++----- .../handlers/src/compat/login_sso_redirect.rs | 4 +-- crates/handlers/src/compat/logout.rs | 7 ++--- crates/handlers/src/compat/refresh.rs | 7 ++--- crates/handlers/src/lib.rs | 14 ++++++---- .../src/oauth2/authorization/complete.rs | 5 +--- .../handlers/src/oauth2/authorization/mod.rs | 5 +--- crates/handlers/src/oauth2/consent.rs | 9 ++----- crates/handlers/src/oauth2/introspection.rs | 5 +--- crates/handlers/src/oauth2/registration.rs | 5 +--- crates/handlers/src/oauth2/token.rs | 5 +--- crates/handlers/src/oauth2/userinfo.rs | 5 +--- .../handlers/src/upstream_oauth2/authorize.rs | 5 +--- .../handlers/src/upstream_oauth2/callback.rs | 5 +--- crates/handlers/src/upstream_oauth2/link.rs | 7 ++--- .../handlers/src/views/account/emails/add.rs | 9 ++----- .../handlers/src/views/account/emails/mod.rs | 9 ++----- .../src/views/account/emails/verify.rs | 9 ++----- crates/handlers/src/views/account/mod.rs | 5 +--- crates/handlers/src/views/account/password.rs | 9 ++----- crates/handlers/src/views/index.rs | 5 +--- crates/handlers/src/views/login.rs | 9 ++----- crates/handlers/src/views/logout.rs | 10 ++----- crates/handlers/src/views/reauth.rs | 9 ++----- crates/handlers/src/views/register.rs | 9 ++----- crates/storage-pg/src/user/tests.rs | 6 ++--- 29 files changed, 79 insertions(+), 142 deletions(-) diff --git a/crates/data-model/src/tokens.rs b/crates/data-model/src/tokens.rs index ad8c407e4..5c7acf34c 100644 --- a/crates/data-model/src/tokens.rs +++ b/crates/data-model/src/tokens.rs @@ -190,8 +190,8 @@ impl TokenType { /// use rand::thread_rng; /// use mas_data_model::TokenType::{AccessToken, RefreshToken}; /// - /// AccessToken.generate(thread_rng()); - /// RefreshToken.generate(thread_rng()); + /// AccessToken.generate(&mut thread_rng()); + /// RefreshToken.generate(&mut thread_rng()); /// ``` pub fn generate(self, rng: &mut (impl RngCore + ?Sized)) -> String { let random_part: String = rng diff --git a/crates/handlers/src/app_state.rs b/crates/handlers/src/app_state.rs index 45446e107..4271f8965 100644 --- a/crates/handlers/src/app_state.rs +++ b/crates/handlers/src/app_state.rs @@ -17,16 +17,20 @@ use std::{convert::Infallible, sync::Arc}; use axum::{ async_trait, extract::{FromRef, FromRequestParts}, + response::IntoResponse, }; +use hyper::StatusCode; use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_email::Mailer; use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; use mas_storage::{BoxClock, BoxRng, SystemClock}; +use mas_storage_pg::PgRepository; use mas_templates::Templates; use rand::SeedableRng; use sqlx::PgPool; +use thiserror::Error; use crate::{passwords::PasswordManager, MatrixHomeserver}; @@ -140,3 +144,26 @@ impl FromRequestParts for BoxRng { Ok(Box::new(rng)) } } + +#[derive(Debug, Error)] +#[error(transparent)] +pub struct RepositoryError(#[from] mas_storage_pg::DatabaseError); + +impl IntoResponse for RepositoryError { + fn into_response(self) -> axum::response::Response { + (StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response() + } +} + +#[async_trait] +impl FromRequestParts for PgRepository { + type Rejection = RepositoryError; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + state: &AppState, + ) -> Result { + let repo = PgRepository::from_pool(&state.pool).await?; + Ok(repo) + } +} diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index 279b575d3..f76cda717 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -28,7 +28,6 @@ use mas_storage_pg::PgRepository; use rand::{CryptoRng, RngCore}; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; -use sqlx::PgPool; use thiserror::Error; use zeroize::Zeroizing; @@ -197,11 +196,10 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(password_manager): State, - State(pool): State, + mut repo: PgRepository, State(homeserver): State, Json(input): Json, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; let (session, user) = match input.credentials { Credentials::Password { identifier: Identifier::User { user }, diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 6201b0c64..602b4d80d 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -36,7 +36,6 @@ use mas_storage::{ use mas_storage_pg::PgRepository; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; use serde::{Deserialize, Serialize}; -use sqlx::PgPool; use ulid::Ulid; #[derive(Serialize)] @@ -56,14 +55,12 @@ pub struct Params { pub async fn get( mut rng: BoxRng, clock: BoxClock, - State(pool): State, + mut repo: PgRepository, State(templates): State, cookie_jar: PrivateCookieJar, Path(id): Path, Query(params): Query, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); @@ -120,15 +117,13 @@ pub async fn get( pub async fn post( mut rng: BoxRng, clock: BoxClock, - State(pool): State, + mut repo: PgRepository, State(templates): State, cookie_jar: PrivateCookieJar, Path(id): Path, Query(params): Query, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); cookie_jar.verify_form(&clock, form)?; diff --git a/crates/handlers/src/compat/login_sso_redirect.rs b/crates/handlers/src/compat/login_sso_redirect.rs index a8063141a..d8ef0fb27 100644 --- a/crates/handlers/src/compat/login_sso_redirect.rs +++ b/crates/handlers/src/compat/login_sso_redirect.rs @@ -24,7 +24,6 @@ use mas_storage_pg::PgRepository; use rand::distributions::{Alphanumeric, DistString}; use serde::Deserialize; use serde_with::serde; -use sqlx::PgPool; use thiserror::Error; use url::Url; @@ -60,7 +59,7 @@ impl IntoResponse for RouteError { pub async fn get( mut rng: BoxRng, clock: BoxClock, - State(pool): State, + mut repo: PgRepository, State(url_builder): State, Query(params): Query, ) -> Result { @@ -79,7 +78,6 @@ pub async fn get( } let token = Alphanumeric.sample_string(&mut rng, 32); - let mut repo = PgRepository::from_pool(&pool).await?; let login = repo .compat_sso_login() .add(&mut rng, &clock, token, redirect_url) diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index bfc767faf..e1ef02be5 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{extract::State, response::IntoResponse, Json, TypedHeader}; +use axum::{response::IntoResponse, Json, TypedHeader}; use headers::{authorization::Bearer, Authorization}; use hyper::StatusCode; use mas_data_model::TokenType; @@ -21,7 +21,6 @@ use mas_storage::{ BoxClock, Clock, Repository, }; use mas_storage_pg::PgRepository; -use sqlx::PgPool; use thiserror::Error; use super::MatrixError; @@ -69,11 +68,9 @@ impl IntoResponse for RouteError { pub(crate) async fn post( clock: BoxClock, - State(pool): State, + mut repo: PgRepository, maybe_authorization: Option>>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?; let token = authorization.token(); diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs index 868be9db1..6b90464ea 100644 --- a/crates/handlers/src/compat/refresh.rs +++ b/crates/handlers/src/compat/refresh.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{extract::State, response::IntoResponse, Json}; +use axum::{response::IntoResponse, Json}; use chrono::Duration; use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; @@ -23,7 +23,6 @@ use mas_storage::{ use mas_storage_pg::PgRepository; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DurationMilliSeconds}; -use sqlx::PgPool; use thiserror::Error; use super::MatrixError; @@ -90,11 +89,9 @@ pub struct ResponseBody { pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - State(pool): State, + mut repo: PgRepository, Json(input): Json, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let token_type = TokenType::check(&input.refresh_token)?; if token_type != TokenType::CompatRefreshToken { diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 4d9dcbcdc..48ca55608 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -21,7 +21,10 @@ )] #![warn(clippy::pedantic)] #![allow( - clippy::unused_async // Some axum handlers need that + // Some axum handlers need that + clippy::unused_async, + // Because of how axum handlers work, we sometime have take many arguments + clippy::too_many_arguments, )] use std::{convert::Infallible, sync::Arc, time::Duration}; @@ -41,6 +44,7 @@ use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; use mas_router::{Route, UrlBuilder}; use mas_storage::{BoxClock, BoxRng}; +use mas_storage_pg::PgRepository; use mas_templates::{ErrorContext, Templates}; use passwords::PasswordManager; use sqlx::PgPool; @@ -154,7 +158,7 @@ where Keystore: FromRef, UrlBuilder: FromRef, Arc: FromRef, - PgPool: FromRef, + PgRepository: FromRequestParts, Encrypter: FromRef, HttpClientFactory: FromRef, BoxClock: FromRequestParts, @@ -209,7 +213,7 @@ where ::Error: std::error::Error + Send + Sync, S: Clone + Send + Sync + 'static, UrlBuilder: FromRef, - PgPool: FromRef, + PgRepository: FromRequestParts, MatrixHomeserver: FromRef, PasswordManager: FromRef, BoxClock: FromRequestParts, @@ -254,7 +258,7 @@ where S: Clone + Send + Sync + 'static, UrlBuilder: FromRef, Arc: FromRef, - PgPool: FromRef, + PgRepository: FromRequestParts, Encrypter: FromRef, Templates: FromRef, Mailer: FromRef, @@ -358,7 +362,7 @@ where } #[cfg(test)] -async fn test_state(pool: PgPool) -> Result { +async fn test_state(pool: sqlx::PgPool) -> Result { use mas_email::MailTransport; use crate::passwords::Hasher; diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index 934ba088f..c17fb9f1b 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -32,7 +32,6 @@ use mas_storage::{ use mas_storage_pg::PgRepository; use mas_templates::Templates; use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse}; -use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -82,12 +81,10 @@ pub(crate) async fn get( clock: BoxClock, State(policy_factory): State>, State(templates): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, Path(grant_id): Path, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 607534823..30efcaa36 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -39,7 +39,6 @@ use oauth2_types::{ }; use rand::{distributions::Alphanumeric, Rng}; use serde::Deserialize; -use sqlx::PgPool; use thiserror::Error; use self::{callback::CallbackDestination, complete::GrantCompletionError}; @@ -136,12 +135,10 @@ pub(crate) async fn get( clock: BoxClock, State(policy_factory): State>, State(templates): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, Form(params): Form, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - // First, figure out what client it is let client = repo .oauth2_client() diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index 86c832fb2..c83dca03c 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -34,7 +34,6 @@ use mas_storage::{ }; use mas_storage_pg::PgRepository; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; -use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -78,12 +77,10 @@ pub(crate) async fn get( clock: BoxClock, State(policy_factory): State>, State(templates): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, Path(grant_id): Path, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; @@ -133,13 +130,11 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(policy_factory): State>, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, Path(grant_id): Path, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index d8e64fa06..65e48e064 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -33,7 +33,6 @@ use oauth2_types::{ requests::{IntrospectionRequest, IntrospectionResponse}, scope::ScopeToken, }; -use sqlx::PgPool; use thiserror::Error; use crate::impl_from_error_for_route; @@ -126,12 +125,10 @@ const API_SCOPE: ScopeToken = ScopeToken::from_static("urn:matrix:org.matrix.msc pub(crate) async fn post( clock: BoxClock, State(http_client_factory): State, - State(pool): State, + mut repo: PgRepository, State(encrypter): State, client_authorization: ClientAuthorization, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let client = client_authorization .credentials .fetch(&mut repo) diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index da043b8bd..129f636f6 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -28,7 +28,6 @@ use oauth2_types::{ }, }; use rand::distributions::{Alphanumeric, DistString}; -use sqlx::PgPool; use thiserror::Error; use tracing::info; @@ -109,7 +108,7 @@ impl IntoResponse for RouteError { pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - State(pool): State, + mut repo: PgRepository, State(policy_factory): State>, State(encrypter): State, Json(body): Json, @@ -125,8 +124,6 @@ pub(crate) async fn post( return Err(RouteError::PolicyDenied(res.violations)); } - let mut repo = PgRepository::from_pool(&pool).await?; - let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method { Some( OAuthClientAuthenticationMethod::ClientSecretJwt diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 3fe916d81..ed566261e 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -50,7 +50,6 @@ use oauth2_types::{ }; use serde::Serialize; use serde_with::{serde_as, skip_serializing_none}; -use sqlx::PgPool; use thiserror::Error; use tracing::debug; use url::Url; @@ -164,12 +163,10 @@ pub(crate) async fn post( State(http_client_factory): State, State(key_store): State, State(url_builder): State, - State(pool): State, + mut repo: PgRepository, State(encrypter): State, client_authorization: ClientAuthorization, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let client = client_authorization .credentials .fetch(&mut repo) diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index 9d60ac1fa..eb9e1cc2f 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -37,7 +37,6 @@ use mas_storage_pg::PgRepository; use oauth2_types::scope; use serde::Serialize; use serde_with::skip_serializing_none; -use sqlx::PgPool; use thiserror::Error; use crate::impl_from_error_for_route; @@ -101,12 +100,10 @@ pub async fn get( mut rng: BoxRng, clock: BoxClock, State(url_builder): State, - State(pool): State, + mut repo: PgRepository, State(key_store): State, user_authorization: UserAuthorization, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let session = user_authorization.protected(&mut repo, &clock).await?; let browser_session = repo diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index d66493170..ff47084b5 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -27,7 +27,6 @@ use mas_storage::{ BoxClock, BoxRng, Repository, }; use mas_storage_pg::PgRepository; -use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -61,14 +60,12 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(http_client_factory): State, - State(pool): State, + mut repo: PgRepository, State(url_builder): State, cookie_jar: PrivateCookieJar, Path(provider_id): Path, Query(query): Query, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let provider = repo .upstream_oauth_provider() .lookup(provider_id) diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index fd66af094..b324cfb24 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -35,7 +35,6 @@ use mas_storage::{ use mas_storage_pg::PgRepository; use oauth2_types::errors::ClientErrorCode; use serde::Deserialize; -use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -124,7 +123,7 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(http_client_factory): State, - State(pool): State, + mut repo: PgRepository, State(url_builder): State, State(encrypter): State, State(keystore): State, @@ -132,8 +131,6 @@ pub(crate) async fn get( Path(provider_id): Path, Query(params): Query, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let provider = repo .upstream_oauth_provider() .lookup(provider_id) diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index d318fc3e3..bdd5df1ff 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -35,7 +35,6 @@ use mas_templates::{ UpstreamSuggestLink, }; use serde::Deserialize; -use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -96,12 +95,11 @@ pub(crate) enum FormData { pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, - State(pool): State, + mut repo: PgRepository, State(templates): State, cookie_jar: PrivateCookieJar, Path(link_id): Path, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); let (session_id, _post_auth_action) = sessions_cookie .lookup_link(link_id) @@ -213,12 +211,11 @@ pub(crate) async fn get( pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, Path(link_id): Path, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(&clock, form)?; let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index 1c8c6665e..64218e3a8 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -28,7 +28,6 @@ use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Repository}; use mas_storage_pg::PgRepository; use mas_templates::{EmailAddContext, TemplateContext, Templates}; use serde::Deserialize; -use sqlx::PgPool; use super::start_email_verification; use crate::views::shared::OptionalPostAuthAction; @@ -42,11 +41,9 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -71,14 +68,12 @@ pub(crate) async fn get( pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - State(pool): State, + mut repo: PgRepository, State(mailer): State, cookie_jar: PrivateCookieJar, Query(query): Query, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index 10772b871..fd2f2981f 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -33,7 +33,6 @@ use mas_storage_pg::PgRepository; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; use rand::{distributions::Uniform, Rng}; use serde::Deserialize; -use sqlx::PgPool; use tracing::info; pub mod add; @@ -52,11 +51,9 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; @@ -127,13 +124,11 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, State(mailer): State, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index 644810e54..e330c944f 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -28,7 +28,6 @@ use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Repository}; use mas_storage_pg::PgRepository; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use serde::Deserialize; -use sqlx::PgPool; use ulid::Ulid; use crate::views::shared::OptionalPostAuthAction; @@ -42,13 +41,11 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, Query(query): Query, Path(id): Path, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -85,14 +82,12 @@ pub(crate) async fn get( pub(crate) async fn post( clock: BoxClock, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, Query(query): Query, Path(id): Path, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/views/account/mod.rs b/crates/handlers/src/views/account/mod.rs index 660c14162..76ea5667d 100644 --- a/crates/handlers/src/views/account/mod.rs +++ b/crates/handlers/src/views/account/mod.rs @@ -29,17 +29,14 @@ use mas_storage::{ }; use mas_storage_pg::PgRepository; use mas_templates::{AccountContext, TemplateContext, Templates}; -use sqlx::PgPool; pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index a9f17123b..4fa86eae2 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -33,7 +33,6 @@ use mas_storage_pg::PgRepository; use mas_templates::{EmptyContext, TemplateContext, Templates}; use rand::Rng; use serde::Deserialize; -use sqlx::PgPool; use zeroize::Zeroizing; use crate::passwords::PasswordManager; @@ -49,11 +48,9 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; @@ -89,12 +86,10 @@ pub(crate) async fn post( clock: BoxClock, State(password_manager): State, State(templates): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/views/index.rs b/crates/handlers/src/views/index.rs index ffe500e72..d4322eefd 100644 --- a/crates/handlers/src/views/index.rs +++ b/crates/handlers/src/views/index.rs @@ -23,18 +23,15 @@ use mas_router::UrlBuilder; use mas_storage::{BoxClock, BoxRng}; use mas_storage_pg::PgRepository; use mas_templates::{IndexContext, TemplateContext, Templates}; -use sqlx::PgPool; pub async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, State(url_builder): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); let session = session_info.load_session(&mut repo).await?; diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index d8abcb467..b245b5977 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -34,7 +34,6 @@ use mas_templates::{ }; use rand::{CryptoRng, Rng}; use serde::{Deserialize, Serialize}; -use sqlx::PgPool; use zeroize::Zeroizing; use super::shared::OptionalPostAuthAction; @@ -54,12 +53,10 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -88,13 +85,11 @@ pub(crate) async fn post( clock: BoxClock, State(password_manager): State, State(templates): State, - State(pool): State, + mut repo: PgRepository, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(&clock, form)?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); diff --git a/crates/handlers/src/views/logout.rs b/crates/handlers/src/views/logout.rs index f8491cb95..9cdc93f03 100644 --- a/crates/handlers/src/views/logout.rs +++ b/crates/handlers/src/views/logout.rs @@ -12,10 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{ - extract::{Form, State}, - response::IntoResponse, -}; +use axum::{extract::Form, response::IntoResponse}; use axum_extra::extract::PrivateCookieJar; use mas_axum_utils::{ csrf::{CsrfExt, ProtectedForm}, @@ -25,16 +22,13 @@ use mas_keystore::Encrypter; use mas_router::{PostAuthAction, Route}; use mas_storage::{user::BrowserSessionRepository, BoxClock, Repository}; use mas_storage_pg::PgRepository; -use sqlx::PgPool; pub(crate) async fn post( clock: BoxClock, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, Form(form): Form>>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(&clock, form)?; let (session_info, mut cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index 9c2330a30..ced979020 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -31,7 +31,6 @@ use mas_storage::{ use mas_storage_pg::PgRepository; use mas_templates::{ReauthContext, TemplateContext, Templates}; use serde::Deserialize; -use sqlx::PgPool; use zeroize::Zeroizing; use super::shared::OptionalPostAuthAction; @@ -46,12 +45,10 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -84,13 +81,11 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(password_manager): State, - State(pool): State, + mut repo: PgRepository, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index a8fc7baee..68cf5c493 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -42,7 +42,6 @@ use mas_templates::{ }; use rand::{distributions::Uniform, Rng}; use serde::{Deserialize, Serialize}; -use sqlx::PgPool; use zeroize::Zeroizing; use super::shared::OptionalPostAuthAction; @@ -64,12 +63,10 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -100,13 +97,11 @@ pub(crate) async fn post( State(mailer): State, State(policy_factory): State>, State(templates): State, - State(pool): State, + mut repo: PgRepository, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(&clock, form)?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); diff --git a/crates/storage-pg/src/user/tests.rs b/crates/storage-pg/src/user/tests.rs index 097bca74f..b3b882321 100644 --- a/crates/storage-pg/src/user/tests.rs +++ b/crates/storage-pg/src/user/tests.rs @@ -90,7 +90,7 @@ async fn test_user_email_repo(pool: PgPool) { // The user email should not exist yet assert!(repo .user_email() - .find(&user, &EMAIL) + .find(&user, EMAIL) .await .unwrap() .is_none()); @@ -111,7 +111,7 @@ async fn test_user_email_repo(pool: PgPool) { assert!(repo .user_email() - .find(&user, &EMAIL) + .find(&user, EMAIL) .await .unwrap() .is_some()); @@ -181,7 +181,7 @@ async fn test_user_email_repo(pool: PgPool) { // Reload the user_email let user_email = repo .user_email() - .find(&user, &EMAIL) + .find(&user, EMAIL) .await .unwrap() .expect("user email was not found"); From aa830db9f9b2973d7c00287ad1eace21520b93f4 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 19 Jan 2023 19:10:35 +0100 Subject: [PATCH 33/45] storage: ensure the repository trait can be boxed and define some wrappers to map the errors --- crates/graphql/src/model/upstream_oauth.rs | 12 +- crates/graphql/src/model/users.rs | 4 +- crates/handlers/src/oauth2/token.rs | 3 +- crates/storage-pg/src/repository.rs | 126 ++++++---- crates/storage/src/compat/access_token.rs | 26 +- crates/storage/src/compat/refresh_token.rs | 26 +- crates/storage/src/compat/session.rs | 20 +- crates/storage/src/compat/sso_login.rs | 38 ++- crates/storage/src/lib.rs | 54 ++++ crates/storage/src/oauth2/access_token.rs | 28 ++- .../storage/src/oauth2/authorization_grant.rs | 43 +++- crates/storage/src/oauth2/client.rs | 60 ++++- crates/storage/src/oauth2/refresh_token.rs | 26 +- crates/storage/src/oauth2/session.rs | 23 +- crates/storage/src/repository.rs | 236 +++++++++++++----- crates/storage/src/upstream_oauth2/link.rs | 34 ++- .../storage/src/upstream_oauth2/provider.rs | 25 +- crates/storage/src/upstream_oauth2/session.rs | 33 ++- crates/storage/src/user/email.rs | 55 +++- crates/storage/src/user/mod.rs | 14 +- crates/storage/src/user/password.rs | 15 +- crates/storage/src/user/session.rs | 39 ++- crates/tasks/src/database.rs | 3 +- 23 files changed, 801 insertions(+), 142 deletions(-) diff --git a/crates/graphql/src/model/upstream_oauth.rs b/crates/graphql/src/model/upstream_oauth.rs index 5767f8d4b..d65158c91 100644 --- a/crates/graphql/src/model/upstream_oauth.rs +++ b/crates/graphql/src/model/upstream_oauth.rs @@ -104,10 +104,12 @@ impl UpstreamOAuth2Link { } else { // Fetch on-the-fly let mut repo = PgRepository::from_pool(ctx.data::()?).await?; - repo.upstream_oauth_provider() + let provider = repo + .upstream_oauth_provider() .lookup(self.link.provider_id) .await? - .context("Upstream OAuth 2.0 provider not found")? + .context("Upstream OAuth 2.0 provider not found")?; + provider }; Ok(UpstreamOAuth2Provider::new(provider)) @@ -121,10 +123,12 @@ impl UpstreamOAuth2Link { } else if let Some(user_id) = &self.link.user_id { // Fetch on-the-fly let mut repo = PgRepository::from_pool(ctx.data::()?).await?; - repo.user() + let user = repo + .user() .lookup(*user_id) .await? - .context("User not found")? + .context("User not found")?; + user } else { return Ok(None); }; diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index 3f587eb06..a8036dc89 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -67,7 +67,9 @@ impl User { ) -> Result, async_graphql::Error> { let mut repo = PgRepository::from_pool(ctx.data::()?).await?; - Ok(repo.user_email().get_primary(&self.0).await?.map(UserEmail)) + let mut user_email_repo = repo.user_email(); + + Ok(user_email_repo.get_primary(&self.0).await?.map(UserEmail)) } /// Get the list of compatibility SSO logins, chronologically sorted diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index ed566261e..5b6b7565b 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -426,7 +426,8 @@ async fn refresh_token_grant( .await?; if let Some(access_token_id) = refresh_token.access_token_id { - if let Some(access_token) = repo.oauth2_access_token().lookup(access_token_id).await? { + let access_token = repo.oauth2_access_token().lookup(access_token_id).await?; + if let Some(access_token) = access_token { repo.oauth2_access_token() .revoke(clock, access_token) .await?; diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index 288181a67..540027551 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -12,7 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -use mas_storage::Repository; +use mas_storage::{ + compat::{ + CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, + CompatSsoLoginRepository, + }, + oauth2::{ + OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, + OAuth2RefreshTokenRepository, OAuth2SessionRepository, + }, + upstream_oauth2::{ + UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, + UpstreamOAuthSessionRepository, + }, + user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, + Repository, +}; use sqlx::{PgPool, Postgres, Transaction}; use crate::{ @@ -59,84 +74,95 @@ impl PgRepository { impl Repository for PgRepository { type Error = DatabaseError; - type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; - type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; - type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; - type UserRepository<'c> = PgUserRepository<'c> where Self: 'c; - type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c; - type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; - type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; - type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; - type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c; - type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; - type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c; - type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c; - type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c; - type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c; - type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c; - type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c; - - fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { - PgUpstreamOAuthLinkRepository::new(&mut self.txn) + fn upstream_oauth_link<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgUpstreamOAuthLinkRepository::new(&mut self.txn)) } - fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { - PgUpstreamOAuthProviderRepository::new(&mut self.txn) + fn upstream_oauth_provider<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgUpstreamOAuthProviderRepository::new(&mut self.txn)) } - fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> { - PgUpstreamOAuthSessionRepository::new(&mut self.txn) + fn upstream_oauth_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgUpstreamOAuthSessionRepository::new(&mut self.txn)) } - fn user(&mut self) -> Self::UserRepository<'_> { - PgUserRepository::new(&mut self.txn) + fn user<'c>(&'c mut self) -> Box + 'c> { + Box::new(PgUserRepository::new(&mut self.txn)) } - fn user_email(&mut self) -> Self::UserEmailRepository<'_> { - PgUserEmailRepository::new(&mut self.txn) + fn user_email<'c>(&'c mut self) -> Box + 'c> { + Box::new(PgUserEmailRepository::new(&mut self.txn)) } - fn user_password(&mut self) -> Self::UserPasswordRepository<'_> { - PgUserPasswordRepository::new(&mut self.txn) + fn user_password<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgUserPasswordRepository::new(&mut self.txn)) } - fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> { - PgBrowserSessionRepository::new(&mut self.txn) + fn browser_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgBrowserSessionRepository::new(&mut self.txn)) } - fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { - PgOAuth2ClientRepository::new(&mut self.txn) + fn oauth2_client<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgOAuth2ClientRepository::new(&mut self.txn)) } - fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> { - PgOAuth2AuthorizationGrantRepository::new(&mut self.txn) + fn oauth2_authorization_grant<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgOAuth2AuthorizationGrantRepository::new(&mut self.txn)) } - fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { - PgOAuth2SessionRepository::new(&mut self.txn) + fn oauth2_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgOAuth2SessionRepository::new(&mut self.txn)) } - fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> { - PgOAuth2AccessTokenRepository::new(&mut self.txn) + fn oauth2_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgOAuth2AccessTokenRepository::new(&mut self.txn)) } - fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> { - PgOAuth2RefreshTokenRepository::new(&mut self.txn) + fn oauth2_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgOAuth2RefreshTokenRepository::new(&mut self.txn)) } - fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { - PgCompatSessionRepository::new(&mut self.txn) + fn compat_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgCompatSessionRepository::new(&mut self.txn)) } - fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> { - PgCompatSsoLoginRepository::new(&mut self.txn) + fn compat_sso_login<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgCompatSsoLoginRepository::new(&mut self.txn)) } - fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> { - PgCompatAccessTokenRepository::new(&mut self.txn) + fn compat_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgCompatAccessTokenRepository::new(&mut self.txn)) } - fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> { - PgCompatRefreshTokenRepository::new(&mut self.txn) + fn compat_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgCompatRefreshTokenRepository::new(&mut self.txn)) } } diff --git a/crates/storage/src/compat/access_token.rs b/crates/storage/src/compat/access_token.rs index 32ba1f735..c6d4eb7fe 100644 --- a/crates/storage/src/compat/access_token.rs +++ b/crates/storage/src/compat/access_token.rs @@ -18,7 +18,7 @@ use mas_data_model::{CompatAccessToken, CompatSession}; use rand_core::RngCore; use ulid::Ulid; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait CompatAccessTokenRepository: Send + Sync { @@ -50,3 +50,27 @@ pub trait CompatAccessTokenRepository: Send + Sync { compat_access_token: CompatAccessToken, ) -> Result; } + +repository_impl!(CompatAccessTokenRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + compat_session: &CompatSession, + token: String, + expires_after: Option, + ) -> Result; + + async fn expire( + &mut self, + clock: &dyn Clock, + compat_access_token: CompatAccessToken, + ) -> Result; +); diff --git a/crates/storage/src/compat/refresh_token.rs b/crates/storage/src/compat/refresh_token.rs index 627b59a12..3fd916da9 100644 --- a/crates/storage/src/compat/refresh_token.rs +++ b/crates/storage/src/compat/refresh_token.rs @@ -17,7 +17,7 @@ use mas_data_model::{CompatAccessToken, CompatRefreshToken, CompatSession}; use rand_core::RngCore; use ulid::Ulid; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait CompatRefreshTokenRepository: Send + Sync { @@ -49,3 +49,27 @@ pub trait CompatRefreshTokenRepository: Send + Sync { compat_refresh_token: CompatRefreshToken, ) -> Result; } + +repository_impl!(CompatRefreshTokenRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + compat_session: &CompatSession, + compat_access_token: &CompatAccessToken, + token: String, + ) -> Result; + + async fn consume( + &mut self, + clock: &dyn Clock, + compat_refresh_token: CompatRefreshToken, + ) -> Result; +); diff --git a/crates/storage/src/compat/session.rs b/crates/storage/src/compat/session.rs index 0c5bc125c..f867a332b 100644 --- a/crates/storage/src/compat/session.rs +++ b/crates/storage/src/compat/session.rs @@ -17,7 +17,7 @@ use mas_data_model::{CompatSession, Device, User}; use rand_core::RngCore; use ulid::Ulid; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait CompatSessionRepository: Send + Sync { @@ -42,3 +42,21 @@ pub trait CompatSessionRepository: Send + Sync { compat_session: CompatSession, ) -> Result; } + +repository_impl!(CompatSessionRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + device: Device, + ) -> Result; + + async fn finish( + &mut self, + clock: &dyn Clock, + compat_session: CompatSession, + ) -> Result; +); diff --git a/crates/storage/src/compat/sso_login.rs b/crates/storage/src/compat/sso_login.rs index 1ed3e5d80..a6fa07357 100644 --- a/crates/storage/src/compat/sso_login.rs +++ b/crates/storage/src/compat/sso_login.rs @@ -18,7 +18,7 @@ use rand_core::RngCore; use ulid::Ulid; use url::Url; -use crate::{pagination::Page, Clock, Pagination}; +use crate::{pagination::Page, repository_impl, Clock, Pagination}; #[async_trait] pub trait CompatSsoLoginRepository: Send + Sync { @@ -64,3 +64,39 @@ pub trait CompatSsoLoginRepository: Send + Sync { pagination: Pagination, ) -> Result, Self::Error>; } + +repository_impl!(CompatSsoLoginRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_token( + &mut self, + login_token: &str, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + login_token: String, + redirect_uri: Url, + ) -> Result; + + async fn fulfill( + &mut self, + clock: &dyn Clock, + compat_sso_login: CompatSsoLogin, + compat_session: &CompatSession, + ) -> Result; + + async fn exchange( + &mut self, + clock: &dyn Clock, + compat_sso_login: CompatSsoLogin, + ) -> Result; + + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; +); diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index d5a453726..0cdc4e39b 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -45,5 +45,59 @@ pub use self::{ repository::Repository, }; +pub struct MapErr { + inner: Repository, + mapper: Mapper, +} + +impl MapErr { + fn new(inner: Repository, mapper: Mapper) -> Self { + Self { inner, mapper } + } +} + +#[macro_export] +macro_rules! repository_impl { + ($repo_trait:ident: + $( + async fn $method:ident ( + &mut self + $(, $arg:ident: $arg_ty:ty )* + $(,)? + ) -> Result<$ret_ty:ty, Self::Error>; + )* + ) => { + #[::async_trait::async_trait] + impl $repo_trait for ::std::boxed::Box + where + R: $repo_trait, + { + type Error = ::Error; + + $( + async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> { + (**self).$method ( $($arg),* ).await + } + )* + } + + #[::async_trait::async_trait] + impl $repo_trait for $crate::MapErr + where + R: $repo_trait, + F: FnMut(::Error) -> E + ::std::marker::Send + ::std::marker::Sync, + E: ::std::error::Error + ::std::marker::Send + ::std::marker::Sync, + { + type Error = E; + + $( + async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> { + self.inner.$method ( $($arg),* ).await.map_err(&mut self.mapper) + } + )* + } + }; +} + pub type BoxClock = Box; pub type BoxRng = Box; diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index 1148136f0..8a5362431 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -18,7 +18,7 @@ use mas_data_model::{AccessToken, Session}; use rand_core::RngCore; use ulid::Ulid; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait OAuth2AccessTokenRepository: Send + Sync { @@ -53,3 +53,29 @@ pub trait OAuth2AccessTokenRepository: Send + Sync { /// Cleanup expired access tokens async fn cleanup_expired(&mut self, clock: &dyn Clock) -> Result; } + +repository_impl!(OAuth2AccessTokenRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + session: &Session, + access_token: String, + expires_after: Duration, + ) -> Result; + + async fn revoke( + &mut self, + clock: &dyn Clock, + access_token: AccessToken, + ) -> Result; + + async fn cleanup_expired(&mut self, clock: &dyn Clock) -> Result; +); diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index 1130e6a8a..8852f796b 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -21,7 +21,7 @@ use rand_core::RngCore; use ulid::Ulid; use url::Url; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait OAuth2AuthorizationGrantRepository: Send + Sync { @@ -67,3 +67,44 @@ pub trait OAuth2AuthorizationGrantRepository: Send + Sync { authorization_grant: AuthorizationGrant, ) -> Result; } + +repository_impl!(OAuth2AuthorizationGrantRepository: + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + client: &Client, + redirect_uri: Url, + scope: Scope, + code: Option, + state: Option, + nonce: Option, + max_age: Option, + response_mode: ResponseMode, + response_type_id_token: bool, + requires_consent: bool, + ) -> Result; + + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_code(&mut self, code: &str) + -> Result, Self::Error>; + + async fn fulfill( + &mut self, + clock: &dyn Clock, + session: &Session, + authorization_grant: AuthorizationGrant, + ) -> Result; + + async fn exchange( + &mut self, + clock: &dyn Clock, + authorization_grant: AuthorizationGrant, + ) -> Result; + + async fn give_consent( + &mut self, + authorization_grant: AuthorizationGrant, + ) -> Result; +); diff --git a/crates/storage/src/oauth2/client.rs b/crates/storage/src/oauth2/client.rs index 3c7d7dbb3..98acaaf7e 100644 --- a/crates/storage/src/oauth2/client.rs +++ b/crates/storage/src/oauth2/client.rs @@ -23,7 +23,7 @@ use rand_core::RngCore; use ulid::Ulid; use url::Url; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait OAuth2ClientRepository: Send + Sync { @@ -92,3 +92,61 @@ pub trait OAuth2ClientRepository: Send + Sync { scope: &Scope, ) -> Result<(), Self::Error>; } + +repository_impl!(OAuth2ClientRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn load_batch( + &mut self, + ids: BTreeSet, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + redirect_uris: Vec, + encrypted_client_secret: Option, + grant_types: Vec, + contacts: Vec, + client_name: Option, + logo_uri: Option, + client_uri: Option, + policy_uri: Option, + tos_uri: Option, + jwks_uri: Option, + jwks: Option, + id_token_signed_response_alg: Option, + userinfo_signed_response_alg: Option, + token_endpoint_auth_method: Option, + token_endpoint_auth_signing_alg: Option, + initiate_login_uri: Option, + ) -> Result; + + async fn add_from_config( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + client_id: Ulid, + client_auth_method: OAuthClientAuthenticationMethod, + encrypted_client_secret: Option, + jwks: Option, + jwks_uri: Option, + redirect_uris: Vec, + ) -> Result; + + async fn get_consent_for_user( + &mut self, + client: &Client, + user: &User, + ) -> Result; + + async fn give_consent_for_user( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + client: &Client, + user: &User, + scope: &Scope, + ) -> Result<(), Self::Error>; +); diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 66ec2c328..e8ac63ce6 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -17,7 +17,7 @@ use mas_data_model::{AccessToken, RefreshToken, Session}; use rand_core::RngCore; use ulid::Ulid; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait OAuth2RefreshTokenRepository: Send + Sync { @@ -49,3 +49,27 @@ pub trait OAuth2RefreshTokenRepository: Send + Sync { refresh_token: RefreshToken, ) -> Result; } + +repository_impl!(OAuth2RefreshTokenRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + session: &Session, + access_token: &AccessToken, + refresh_token: String, + ) -> Result; + + async fn consume( + &mut self, + clock: &dyn Clock, + refresh_token: RefreshToken, + ) -> Result; +); diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 3813810b8..f348d9e68 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -17,7 +17,7 @@ use mas_data_model::{AuthorizationGrant, BrowserSession, Session, User}; use rand_core::RngCore; use ulid::Ulid; -use crate::{pagination::Page, Clock, Pagination}; +use crate::{pagination::Page, repository_impl, Clock, Pagination}; #[async_trait] pub trait OAuth2SessionRepository: Send + Sync { @@ -42,3 +42,24 @@ pub trait OAuth2SessionRepository: Send + Sync { pagination: Pagination, ) -> Result, Self::Error>; } + +repository_impl!(OAuth2SessionRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn create_from_grant( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + grant: &AuthorizationGrant, + user_session: &BrowserSession, + ) -> Result; + + async fn finish(&mut self, clock: &dyn Clock, session: Session) + -> Result; + + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; +); diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 55afe41b4..085c06aba 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -26,92 +26,192 @@ use crate::{ UpstreamOAuthSessionRepository, }, user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, + MapErr, }; pub trait Repository: Send { type Error: std::error::Error + Send + Sync + 'static; - type UpstreamOAuthLinkRepository<'c>: UpstreamOAuthLinkRepository + 'c - where - Self: 'c; + fn upstream_oauth_link<'c>( + &'c mut self, + ) -> Box + 'c>; - type UpstreamOAuthProviderRepository<'c>: UpstreamOAuthProviderRepository - + 'c - where - Self: 'c; + fn upstream_oauth_provider<'c>( + &'c mut self, + ) -> Box + 'c>; - type UpstreamOAuthSessionRepository<'c>: UpstreamOAuthSessionRepository - + 'c - where - Self: 'c; + fn upstream_oauth_session<'c>( + &'c mut self, + ) -> Box + 'c>; - type UserRepository<'c>: UserRepository + 'c - where - Self: 'c; + fn user<'c>(&'c mut self) -> Box + 'c>; - type UserEmailRepository<'c>: UserEmailRepository + 'c - where - Self: 'c; + fn user_email<'c>(&'c mut self) -> Box + 'c>; - type UserPasswordRepository<'c>: UserPasswordRepository + 'c - where - Self: 'c; + fn user_password<'c>(&'c mut self) + -> Box + 'c>; - type BrowserSessionRepository<'c>: BrowserSessionRepository + 'c - where - Self: 'c; + fn browser_session<'c>( + &'c mut self, + ) -> Box + 'c>; - type OAuth2ClientRepository<'c>: OAuth2ClientRepository + 'c - where - Self: 'c; + fn oauth2_client<'c>(&'c mut self) + -> Box + 'c>; - type OAuth2AuthorizationGrantRepository<'c>: OAuth2AuthorizationGrantRepository - + 'c - where - Self: 'c; + fn oauth2_authorization_grant<'c>( + &'c mut self, + ) -> Box + 'c>; - type OAuth2SessionRepository<'c>: OAuth2SessionRepository + 'c - where - Self: 'c; + fn oauth2_session<'c>( + &'c mut self, + ) -> Box + 'c>; - type OAuth2AccessTokenRepository<'c>: OAuth2AccessTokenRepository + 'c - where - Self: 'c; + fn oauth2_access_token<'c>( + &'c mut self, + ) -> Box + 'c>; - type OAuth2RefreshTokenRepository<'c>: OAuth2RefreshTokenRepository + 'c - where - Self: 'c; + fn oauth2_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c>; - type CompatSessionRepository<'c>: CompatSessionRepository + 'c - where - Self: 'c; + fn compat_session<'c>( + &'c mut self, + ) -> Box + 'c>; - type CompatSsoLoginRepository<'c>: CompatSsoLoginRepository + 'c - where - Self: 'c; + fn compat_sso_login<'c>( + &'c mut self, + ) -> Box + 'c>; - type CompatAccessTokenRepository<'c>: CompatAccessTokenRepository + 'c - where - Self: 'c; + fn compat_access_token<'c>( + &'c mut self, + ) -> Box + 'c>; - type CompatRefreshTokenRepository<'c>: CompatRefreshTokenRepository + 'c - where - Self: 'c; - - fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; - fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>; - fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>; - fn user(&mut self) -> Self::UserRepository<'_>; - fn user_email(&mut self) -> Self::UserEmailRepository<'_>; - fn user_password(&mut self) -> Self::UserPasswordRepository<'_>; - fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>; - fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>; - fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_>; - fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_>; - fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_>; - fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_>; - fn compat_session(&mut self) -> Self::CompatSessionRepository<'_>; - fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_>; - fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_>; - fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_>; + fn compat_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c>; +} + +impl Repository for crate::MapErr +where + R: Repository, + F: FnMut(R::Error) -> E + Send + Sync, + E: std::error::Error + Send + Sync + 'static, +{ + type Error = E; + + fn upstream_oauth_link<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.upstream_oauth_link(), + &mut self.mapper, + )) + } + + fn upstream_oauth_provider<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.upstream_oauth_provider(), + &mut self.mapper, + )) + } + + fn upstream_oauth_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.upstream_oauth_session(), + &mut self.mapper, + )) + } + + fn user<'c>(&'c mut self) -> Box + 'c> { + Box::new(MapErr::new(self.inner.user(), &mut self.mapper)) + } + + fn user_email<'c>(&'c mut self) -> Box + 'c> { + Box::new(MapErr::new(self.inner.user_email(), &mut self.mapper)) + } + + fn user_password<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.user_password(), &mut self.mapper)) + } + + fn browser_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.browser_session(), &mut self.mapper)) + } + + fn oauth2_client<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.oauth2_client(), &mut self.mapper)) + } + + fn oauth2_authorization_grant<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.oauth2_authorization_grant(), + &mut self.mapper, + )) + } + + fn oauth2_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.oauth2_session(), &mut self.mapper)) + } + + fn oauth2_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.oauth2_access_token(), + &mut self.mapper, + )) + } + + fn oauth2_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.oauth2_refresh_token(), + &mut self.mapper, + )) + } + + fn compat_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.compat_session(), &mut self.mapper)) + } + + fn compat_sso_login<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.compat_sso_login(), &mut self.mapper)) + } + + fn compat_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.compat_access_token(), + &mut self.mapper, + )) + } + + fn compat_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.compat_refresh_token(), + &mut self.mapper, + )) + } } diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index bf9e0aadd..c5e024af7 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -17,11 +17,11 @@ use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User}; use rand_core::RngCore; use ulid::Ulid; -use crate::{pagination::Page, Clock, Pagination}; +use crate::{pagination::Page, repository_impl, Clock, Pagination}; #[async_trait] pub trait UpstreamOAuthLinkRepository: Send + Sync { - type Error; + type Error: std::error::Error + Send + Sync; /// Lookup an upstream OAuth link by its ID async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; @@ -56,3 +56,33 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync { pagination: Pagination, ) -> Result, Self::Error>; } + +repository_impl!(UpstreamOAuthLinkRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_subject( + &mut self, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: &str, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: String, + ) -> Result; + + async fn associate_to_user( + &mut self, + upstream_oauth_link: &UpstreamOAuthLink, + user: &User, + ) -> Result<(), Self::Error>; + + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; +); diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index 521a7e7a0..8aaca0dac 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -19,7 +19,7 @@ use oauth2_types::scope::Scope; use rand_core::RngCore; use ulid::Ulid; -use crate::{pagination::Page, Clock, Pagination}; +use crate::{pagination::Page, repository_impl, Clock, Pagination}; #[async_trait] pub trait UpstreamOAuthProviderRepository: Send + Sync { @@ -51,3 +51,26 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync { /// Get all upstream OAuth providers async fn all(&mut self) -> Result, Self::Error>; } + +repository_impl!(UpstreamOAuthProviderRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + issuer: String, + scope: Scope, + token_endpoint_auth_method: OAuthClientAuthenticationMethod, + token_endpoint_signing_alg: Option, + client_id: String, + encrypted_client_secret: Option + ) -> Result; + + async fn list_paginated( + &mut self, + pagination: Pagination + ) -> Result, Self::Error>; + + async fn all(&mut self) -> Result, Self::Error>; +); diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index f4441b2ab..e878444bd 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -17,7 +17,7 @@ use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, Upstr use rand_core::RngCore; use ulid::Ulid; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait UpstreamOAuthSessionRepository: Send + Sync { @@ -56,3 +56,34 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync { upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, ) -> Result; } + +repository_impl!(UpstreamOAuthSessionRepository: + async fn lookup( + &mut self, + id: Ulid, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + state: String, + code_challenge_verifier: Option, + nonce: String, + ) -> Result; + + async fn complete_with_link( + &mut self, + clock: &dyn Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + upstream_oauth_link: &UpstreamOAuthLink, + id_token: Option, + ) -> Result; + + async fn consume( + &mut self, + clock: &dyn Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + ) -> Result; +); diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs index 65ee465b9..4c8601c2b 100644 --- a/crates/storage/src/user/email.rs +++ b/crates/storage/src/user/email.rs @@ -17,7 +17,7 @@ use mas_data_model::{User, UserEmail, UserEmailVerification}; use rand_core::RngCore; use ulid::Ulid; -use crate::{pagination::Page, Clock, Pagination}; +use crate::{pagination::Page, repository_impl, Clock, Pagination}; #[async_trait] pub trait UserEmailRepository: Send + Sync { @@ -74,3 +74,56 @@ pub trait UserEmailRepository: Send + Sync { verification: UserEmailVerification, ) -> Result; } + +repository_impl!(UserEmailRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + async fn find(&mut self, user: &User, email: &str) -> Result, Self::Error>; + async fn get_primary(&mut self, user: &User) -> Result, Self::Error>; + + async fn all(&mut self, user: &User) -> Result, Self::Error>; + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; + async fn count(&mut self, user: &User) -> Result; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + email: String, + ) -> Result; + async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error>; + + async fn mark_as_verified( + &mut self, + clock: &dyn Clock, + user_email: UserEmail, + ) -> Result; + + async fn set_as_primary(&mut self, user_email: &UserEmail) -> Result<(), Self::Error>; + + async fn add_verification_code( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user_email: &UserEmail, + max_age: chrono::Duration, + code: String, + ) -> Result; + + async fn find_verification_code( + &mut self, + clock: &dyn Clock, + user_email: &UserEmail, + code: &str, + ) -> Result, Self::Error>; + + async fn consume_verification_code( + &mut self, + clock: &dyn Clock, + verification: UserEmailVerification, + ) -> Result; +); diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index b3bd0bc25..49003335d 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -17,7 +17,7 @@ use mas_data_model::User; use rand_core::RngCore; use ulid::Ulid; -use crate::Clock; +use crate::{repository_impl, Clock}; mod email; mod password; @@ -41,3 +41,15 @@ pub trait UserRepository: Send + Sync { ) -> Result; async fn exists(&mut self, username: &str) -> Result; } + +repository_impl!(UserRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + async fn find_by_username(&mut self, username: &str) -> Result, Self::Error>; + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + username: String, + ) -> Result; + async fn exists(&mut self, username: &str) -> Result; +); diff --git a/crates/storage/src/user/password.rs b/crates/storage/src/user/password.rs index 609198b22..06f03f551 100644 --- a/crates/storage/src/user/password.rs +++ b/crates/storage/src/user/password.rs @@ -16,7 +16,7 @@ use async_trait::async_trait; use mas_data_model::{Password, User}; use rand_core::RngCore; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait UserPasswordRepository: Send + Sync { @@ -33,3 +33,16 @@ pub trait UserPasswordRepository: Send + Sync { upgraded_from: Option<&Password>, ) -> Result; } + +repository_impl!(UserPasswordRepository: + async fn active(&mut self, user: &User) -> Result, Self::Error>; + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + version: u16, + hashed_password: String, + upgraded_from: Option<&Password>, + ) -> Result; +); diff --git a/crates/storage/src/user/session.rs b/crates/storage/src/user/session.rs index 5556547c0..0dfc581cc 100644 --- a/crates/storage/src/user/session.rs +++ b/crates/storage/src/user/session.rs @@ -17,7 +17,7 @@ use mas_data_model::{BrowserSession, Password, UpstreamOAuthLink, User}; use rand_core::RngCore; use ulid::Ulid; -use crate::{pagination::Page, Clock, Pagination}; +use crate::{pagination::Page, repository_impl, Clock, Pagination}; #[async_trait] pub trait BrowserSessionRepository: Send + Sync { @@ -58,3 +58,40 @@ pub trait BrowserSessionRepository: Send + Sync { upstream_oauth_link: &UpstreamOAuthLink, ) -> Result; } + +repository_impl!(BrowserSessionRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + ) -> Result; + async fn finish( + &mut self, + clock: &dyn Clock, + user_session: BrowserSession, + ) -> Result; + async fn list_active_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; + async fn count_active(&mut self, user: &User) -> Result; + + async fn authenticate_with_password( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user_session: BrowserSession, + user_password: &Password, + ) -> Result; + + async fn authenticate_with_upstream( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user_session: BrowserSession, + upstream_oauth_link: &UpstreamOAuthLink, + ) -> Result; +); diff --git a/crates/tasks/src/database.rs b/crates/tasks/src/database.rs index 9e31880c8..e7947ce25 100644 --- a/crates/tasks/src/database.rs +++ b/crates/tasks/src/database.rs @@ -35,7 +35,8 @@ impl Task for CleanupExpired { async fn run(&self) { let res = async move { let mut repo = PgRepository::from_pool(&self.0).await?; - repo.oauth2_access_token().cleanup_expired(&self.1).await + let res = repo.oauth2_access_token().cleanup_expired(&self.1).await; + res } .await; From 50825ce660bded4707f682ce9abd090f21593527 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 20 Jan 2023 17:49:16 +0100 Subject: [PATCH 34/45] Box the repository everywhere --- Cargo.lock | 5 +- crates/axum-utils/src/client_authorization.rs | 8 +-- crates/axum-utils/src/session.rs | 6 +-- crates/axum-utils/src/user_authorization.rs | 23 ++++---- crates/cli/src/commands/manage.rs | 6 +-- crates/cli/src/commands/server.rs | 2 +- crates/graphql/Cargo.toml | 3 +- crates/graphql/src/lib.rs | 17 +++--- crates/graphql/src/model/compat_sessions.rs | 9 ++-- crates/graphql/src/model/oauth.rs | 13 +++-- crates/graphql/src/model/upstream_oauth.rs | 9 ++-- crates/graphql/src/model/users.rs | 20 ++++--- crates/handlers/src/app_state.rs | 8 +-- crates/handlers/src/compat/login.rs | 11 ++-- .../handlers/src/compat/login_sso_complete.rs | 11 ++-- .../handlers/src/compat/login_sso_redirect.rs | 7 ++- crates/handlers/src/compat/logout.rs | 7 ++- crates/handlers/src/compat/refresh.rs | 7 ++- crates/handlers/src/graphql.rs | 51 ++++++------------ crates/handlers/src/lib.rs | 13 +++-- .../src/oauth2/authorization/complete.rs | 13 +++-- .../handlers/src/oauth2/authorization/mod.rs | 9 ++-- crates/handlers/src/oauth2/consent.rs | 13 +++-- crates/handlers/src/oauth2/introspection.rs | 9 ++-- crates/handlers/src/oauth2/registration.rs | 7 ++- crates/handlers/src/oauth2/token.rs | 29 +++++------ crates/handlers/src/oauth2/userinfo.rs | 11 ++-- .../handlers/src/upstream_oauth2/authorize.rs | 7 ++- .../handlers/src/upstream_oauth2/callback.rs | 7 ++- crates/handlers/src/upstream_oauth2/link.rs | 13 +++-- .../handlers/src/views/account/emails/add.rs | 13 +++-- .../handlers/src/views/account/emails/mod.rs | 27 +++++----- .../src/views/account/emails/verify.rs | 11 ++-- crates/handlers/src/views/account/mod.rs | 7 ++- crates/handlers/src/views/account/password.rs | 11 ++-- crates/handlers/src/views/index.rs | 7 ++- crates/handlers/src/views/login.rs | 21 ++++---- crates/handlers/src/views/logout.rs | 7 ++- crates/handlers/src/views/reauth.rs | 13 +++-- crates/handlers/src/views/register.rs | 15 +++--- crates/handlers/src/views/shared.rs | 6 +-- crates/storage-pg/Cargo.toml | 1 + crates/storage-pg/src/compat/mod.rs | 8 +-- crates/storage-pg/src/repository.rs | 19 ++++--- crates/storage-pg/src/user/tests.rs | 6 +-- crates/storage/Cargo.toml | 1 + crates/storage/src/lib.rs | 11 ++-- crates/storage/src/repository.rs | 52 ++++++++++++++++++- crates/storage/src/upstream_oauth2/link.rs | 2 +- 49 files changed, 296 insertions(+), 296 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 253809cfc..87dd60dea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2804,11 +2804,10 @@ dependencies = [ "chrono", "mas-data-model", "mas-storage", - "mas-storage-pg", "oauth2-types", "serde", - "sqlx", "thiserror", + "tokio", "tracing", "ulid", "url", @@ -3101,6 +3100,7 @@ version = "0.1.0" dependencies = [ "async-trait", "chrono", + "futures-util", "mas-data-model", "mas-iana", "mas-jose", @@ -3117,6 +3117,7 @@ version = "0.1.0" dependencies = [ "async-trait", "chrono", + "futures-util", "mas-data-model", "mas-iana", "mas-jose", diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index 09090230c..67930c3fa 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -72,10 +72,10 @@ pub enum Credentials { } impl Credentials { - pub async fn fetch<'r, R>(&self, repo: &'r mut R) -> Result, R::Error> - where - R: Repository, - { + pub async fn fetch( + &self, + repo: &mut (impl Repository + ?Sized), + ) -> Result, E> { let client_id = match self { Credentials::None { client_id } | Credentials::ClientSecretBasic { client_id, .. } diff --git a/crates/axum-utils/src/session.rs b/crates/axum-utils/src/session.rs index 719613675..5e9661525 100644 --- a/crates/axum-utils/src/session.rs +++ b/crates/axum-utils/src/session.rs @@ -43,10 +43,10 @@ impl SessionInfo { } /// Load the [`BrowserSession`] from database - pub async fn load_session( + pub async fn load_session( &self, - repo: &mut R, - ) -> Result, R::Error> { + repo: &mut (impl Repository + ?Sized), + ) -> Result, E> { let session_id = if let Some(id) = self.current { id } else { diff --git a/crates/axum-utils/src/user_authorization.rs b/crates/axum-utils/src/user_authorization.rs index 9a5956c9e..2d37c40c5 100644 --- a/crates/axum-utils/src/user_authorization.rs +++ b/crates/axum-utils/src/user_authorization.rs @@ -51,11 +51,10 @@ enum AccessToken { } impl AccessToken { - async fn fetch( + async fn fetch( &self, - repo: &mut R, - ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> - { + repo: &mut (impl Repository + ?Sized), + ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> { let token = match self { AccessToken::Form(t) | AccessToken::Header(t) => t, AccessToken::None => return Err(AuthorizationVerificationError::MissingToken), @@ -85,11 +84,11 @@ pub struct UserAuthorization { impl UserAuthorization { // TODO: take scopes to validate as parameter - pub async fn protected_form( + pub async fn protected_form( self, - repo: &mut R, - clock: &C, - ) -> Result<(Session, F), AuthorizationVerificationError> { + repo: &mut (impl Repository + ?Sized), + clock: &impl Clock, + ) -> Result<(Session, F), AuthorizationVerificationError> { let form = match self.form { Some(f) => f, None => return Err(AuthorizationVerificationError::MissingForm), @@ -105,11 +104,11 @@ impl UserAuthorization { } // TODO: take scopes to validate as parameter - pub async fn protected( + pub async fn protected( self, - repo: &mut R, - clock: &C, - ) -> Result> { + repo: &mut (impl Repository + ?Sized), + clock: &impl Clock, + ) -> Result> { let (token, session) = self.access_token.fetch(repo).await?; if !token.is_valid(clock.now()) || !session.is_valid() { diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 2f3e88528..4e74569a1 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -203,7 +203,7 @@ impl Options { let pool = database_from_config(&database_config).await?; let password_manager = password_manager_from_config(&passwords_config).await?; - let mut repo = PgRepository::from_pool(&pool).await?; + let mut repo = PgRepository::from_pool(&pool).await?.boxed(); let user = repo .user() .find_by_username(username) @@ -234,7 +234,7 @@ impl Options { let config: DatabaseConfig = root.load_config()?; let pool = database_from_config(&config).await?; - let mut repo = PgRepository::from_pool(&pool).await?; + let mut repo = PgRepository::from_pool(&pool).await?.boxed(); let user = repo .user() @@ -262,7 +262,7 @@ impl Options { let pool = database_from_config(&config.database).await?; let encrypter = config.secrets.encrypter(); - let mut repo = PgRepository::from_pool(&pool).await?; + let mut repo = PgRepository::from_pool(&pool).await?.boxed(); for client in config.clients.iter() { let client_id = client.client_id; diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 1a7e39e69..002309536 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -102,7 +102,7 @@ impl Options { watch_templates(&templates).await?; } - let graphql_schema = mas_handlers::graphql_schema(&pool); + let graphql_schema = mas_handlers::graphql_schema(); // Maximum 50 outgoing HTTP requests at a time let http_client_factory = HttpClientFactory::new(50); diff --git a/crates/graphql/Cargo.toml b/crates/graphql/Cargo.toml index 16f7e5b51..1f8bda4ec 100644 --- a/crates/graphql/Cargo.toml +++ b/crates/graphql/Cargo.toml @@ -10,7 +10,7 @@ anyhow = "1.0.68" async-graphql = { version = "5.0.4", features = ["chrono", "url"] } chrono = "0.4.23" serde = { version = "1.0.152", features = ["derive"] } -sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] } +tokio = { version = "1.23.0", features = ["sync"] } thiserror = "1.0.38" tracing = "0.1.37" ulid = "1.0.0" @@ -19,7 +19,6 @@ url = "2.3.1" oauth2-types = { path = "../oauth2-types" } mas-data-model = { path = "../data-model" } mas-storage = { path = "../storage" } -mas-storage-pg = { path = "../storage-pg" } [[bin]] name = "schema" diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index 159387ae3..ca3745653 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -34,11 +34,10 @@ use mas_storage::{ oauth2::OAuth2ClientRepository, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository}, user::{BrowserSessionRepository, UserEmailRepository}, - Pagination, Repository, + BoxRepository, Pagination, }; -use mas_storage_pg::PgRepository; use model::CreationEvent; -use sqlx::PgPool; +use tokio::sync::Mutex; use self::model::{ BrowserSession, Cursor, Node, NodeCursor, NodeType, OAuth2Client, UpstreamOAuth2Link, @@ -94,7 +93,7 @@ impl RootQuery { id: ID, ) -> Result, async_graphql::Error> { let id = NodeType::OAuth2Client.extract_ulid(&id)?; - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let client = repo.oauth2_client().lookup(id).await?; @@ -124,7 +123,7 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let id = NodeType::BrowserSession.extract_ulid(&id)?; let session = ctx.data_opt::().cloned(); - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let Some(session) = session else { return Ok(None) }; let current_user = session.user; @@ -150,7 +149,7 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let id = NodeType::UserEmail.extract_ulid(&id)?; let session = ctx.data_opt::().cloned(); - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let Some(session) = session else { return Ok(None) }; let current_user = session.user; @@ -172,7 +171,7 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?; let session = ctx.data_opt::().cloned(); - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let Some(session) = session else { return Ok(None) }; let current_user = session.user; @@ -192,7 +191,7 @@ impl RootQuery { id: ID, ) -> Result, async_graphql::Error> { let id = NodeType::UpstreamOAuth2Provider.extract_ulid(&id)?; - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let provider = repo.upstream_oauth_provider().lookup(id).await?; @@ -211,7 +210,7 @@ impl RootQuery { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; query( after, diff --git a/crates/graphql/src/model/compat_sessions.rs b/crates/graphql/src/model/compat_sessions.rs index e5cd66bce..38fdd4baa 100644 --- a/crates/graphql/src/model/compat_sessions.rs +++ b/crates/graphql/src/model/compat_sessions.rs @@ -15,9 +15,8 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; use chrono::{DateTime, Utc}; -use mas_storage::{compat::CompatSessionRepository, user::UserRepository, Repository}; -use mas_storage_pg::PgRepository; -use sqlx::PgPool; +use mas_storage::{compat::CompatSessionRepository, user::UserRepository, BoxRepository}; +use tokio::sync::Mutex; use url::Url; use super::{NodeType, User}; @@ -36,7 +35,7 @@ impl CompatSession { /// The user authorized for this session. async fn user(&self, ctx: &Context<'_>) -> Result { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let user = repo .user() .lookup(self.0.user_id) @@ -101,7 +100,7 @@ impl CompatSsoLogin { ) -> Result, async_graphql::Error> { let Some(session_id) = self.0.session_id() else { return Ok(None) }; - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let session = repo .compat_session() .lookup(session_id) diff --git a/crates/graphql/src/model/oauth.rs b/crates/graphql/src/model/oauth.rs index 90a0c6b7f..19612f6d7 100644 --- a/crates/graphql/src/model/oauth.rs +++ b/crates/graphql/src/model/oauth.rs @@ -14,10 +14,9 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; -use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, Repository}; -use mas_storage_pg::PgRepository; +use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, BoxRepository}; use oauth2_types::scope::Scope; -use sqlx::PgPool; +use tokio::sync::Mutex; use ulid::Ulid; use url::Url; @@ -37,7 +36,7 @@ impl OAuth2Session { /// OAuth 2.0 client used by this session. pub async fn client(&self, ctx: &Context<'_>) -> Result { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let client = repo .oauth2_client() .lookup(self.0.client_id) @@ -57,7 +56,7 @@ impl OAuth2Session { &self, ctx: &Context<'_>, ) -> Result { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let browser_session = repo .browser_session() .lookup(self.0.user_session_id) @@ -69,7 +68,7 @@ impl OAuth2Session { /// User authorized for this session. pub async fn user(&self, ctx: &Context<'_>) -> Result { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let browser_session = repo .browser_session() .lookup(self.0.user_session_id) @@ -139,7 +138,7 @@ impl OAuth2Consent { /// OAuth 2.0 client for which the user granted access. pub async fn client(&self, ctx: &Context<'_>) -> Result { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let client = repo .oauth2_client() .lookup(self.client_id) diff --git a/crates/graphql/src/model/upstream_oauth.rs b/crates/graphql/src/model/upstream_oauth.rs index d65158c91..76b3a44a3 100644 --- a/crates/graphql/src/model/upstream_oauth.rs +++ b/crates/graphql/src/model/upstream_oauth.rs @@ -16,10 +16,9 @@ use anyhow::Context as _; use async_graphql::{Context, Object, ID}; use chrono::{DateTime, Utc}; use mas_storage::{ - upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, Repository, + upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, BoxRepository, }; -use mas_storage_pg::PgRepository; -use sqlx::PgPool; +use tokio::sync::Mutex; use super::{NodeType, User}; @@ -103,7 +102,7 @@ impl UpstreamOAuth2Link { provider.clone() } else { // Fetch on-the-fly - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let provider = repo .upstream_oauth_provider() .lookup(self.link.provider_id) @@ -122,7 +121,7 @@ impl UpstreamOAuth2Link { user.clone() } else if let Some(user_id) = &self.link.user_id { // Fetch on-the-fly - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let user = repo .user() .lookup(*user_id) diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index a8036dc89..35c2cae4f 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -22,10 +22,9 @@ use mas_storage::{ oauth2::OAuth2SessionRepository, upstream_oauth2::UpstreamOAuthLinkRepository, user::{BrowserSessionRepository, UserEmailRepository}, - Pagination, Repository, + BoxRepository, Pagination, }; -use mas_storage_pg::PgRepository; -use sqlx::PgPool; +use tokio::sync::Mutex; use super::{ compat_sessions::CompatSsoLogin, BrowserSession, Cursor, NodeCursor, NodeType, OAuth2Session, @@ -65,10 +64,9 @@ impl User { &self, ctx: &Context<'_>, ) -> Result, async_graphql::Error> { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let mut user_email_repo = repo.user_email(); - Ok(user_email_repo.get_primary(&self.0).await?.map(UserEmail)) } @@ -84,7 +82,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; query( after, @@ -131,7 +129,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; query( after, @@ -178,7 +176,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; query( after, @@ -229,7 +227,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; query( after, @@ -276,7 +274,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; query( after, @@ -350,7 +348,7 @@ pub struct UserEmailsPagination(mas_data_model::User); impl UserEmailsPagination { /// Identifies the total count of items in the connection. async fn total_count(&self, ctx: &Context<'_>) -> Result { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let count = repo.user_email().count(&self.0).await?; Ok(count) } diff --git a/crates/handlers/src/app_state.rs b/crates/handlers/src/app_state.rs index 4271f8965..2e826badc 100644 --- a/crates/handlers/src/app_state.rs +++ b/crates/handlers/src/app_state.rs @@ -25,7 +25,7 @@ use mas_email::Mailer; use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; -use mas_storage::{BoxClock, BoxRng, SystemClock}; +use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock}; use mas_storage_pg::PgRepository; use mas_templates::Templates; use rand::SeedableRng; @@ -156,7 +156,7 @@ impl IntoResponse for RepositoryError { } #[async_trait] -impl FromRequestParts for PgRepository { +impl FromRequestParts for BoxRepository { type Rejection = RepositoryError; async fn from_request_parts( @@ -164,6 +164,8 @@ impl FromRequestParts for PgRepository { state: &AppState, ) -> Result { let repo = PgRepository::from_pool(&state.pool).await?; - Ok(repo) + Ok(repo + .map_err(mas_storage::RepositoryError::from_error) + .boxed()) } } diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index f76cda717..07077ea13 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -22,9 +22,8 @@ use mas_storage::{ CompatSsoLoginRepository, }, user::{UserPasswordRepository, UserRepository}, - BoxClock, BoxRng, Clock, Repository, + BoxClock, BoxRepository, BoxRng, Clock, }; -use mas_storage_pg::PgRepository; use rand::{CryptoRng, RngCore}; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; @@ -154,7 +153,7 @@ pub enum RouteError { InvalidLoginToken, } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -196,7 +195,7 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(password_manager): State, - mut repo: PgRepository, + mut repo: BoxRepository, State(homeserver): State, Json(input): Json, ) -> Result { @@ -262,7 +261,7 @@ pub(crate) async fn post( } async fn token_login( - repo: &mut PgRepository, + repo: &mut BoxRepository, clock: &dyn Clock, token: &str, ) -> Result<(CompatSession, User), RouteError> { @@ -331,7 +330,7 @@ async fn user_password_login( mut rng: &mut (impl RngCore + CryptoRng + Send), clock: &impl Clock, password_manager: &PasswordManager, - repo: &mut PgRepository, + repo: &mut BoxRepository, username: String, password: String, ) -> Result<(CompatSession, User), RouteError> { diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 602b4d80d..ba3dee136 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -31,9 +31,8 @@ use mas_keystore::Encrypter; use mas_router::{CompatLoginSsoAction, PostAuthAction, Route}; use mas_storage::{ compat::{CompatSessionRepository, CompatSsoLoginRepository}, - BoxClock, BoxRng, Clock, Repository, + BoxClock, BoxRepository, BoxRng, Clock, }; -use mas_storage_pg::PgRepository; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; use serde::{Deserialize, Serialize}; use ulid::Ulid; @@ -55,7 +54,7 @@ pub struct Params { pub async fn get( mut rng: BoxRng, clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, State(templates): State, cookie_jar: PrivateCookieJar, Path(id): Path, @@ -64,7 +63,7 @@ pub async fn get( let (session_info, cookie_jar) = cookie_jar.session_info(); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session @@ -117,7 +116,7 @@ pub async fn get( pub async fn post( mut rng: BoxRng, clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, State(templates): State, cookie_jar: PrivateCookieJar, Path(id): Path, @@ -127,7 +126,7 @@ pub async fn post( let (session_info, cookie_jar) = cookie_jar.session_info(); cookie_jar.verify_form(&clock, form)?; - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session diff --git a/crates/handlers/src/compat/login_sso_redirect.rs b/crates/handlers/src/compat/login_sso_redirect.rs index d8ef0fb27..da013cf7c 100644 --- a/crates/handlers/src/compat/login_sso_redirect.rs +++ b/crates/handlers/src/compat/login_sso_redirect.rs @@ -19,8 +19,7 @@ use axum::{ }; use hyper::StatusCode; use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder}; -use mas_storage::{compat::CompatSsoLoginRepository, BoxClock, BoxRng, Repository}; -use mas_storage_pg::PgRepository; +use mas_storage::{compat::CompatSsoLoginRepository, BoxClock, BoxRepository, BoxRng}; use rand::distributions::{Alphanumeric, DistString}; use serde::Deserialize; use serde_with::serde; @@ -48,7 +47,7 @@ pub enum RouteError { InvalidRedirectUrl, } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -59,7 +58,7 @@ impl IntoResponse for RouteError { pub async fn get( mut rng: BoxRng, clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, State(url_builder): State, Query(params): Query, ) -> Result { diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index e1ef02be5..096b22de5 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -18,9 +18,8 @@ use hyper::StatusCode; use mas_data_model::TokenType; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, - BoxClock, Clock, Repository, + BoxClock, BoxRepository, Clock, }; -use mas_storage_pg::PgRepository; use thiserror::Error; use super::MatrixError; @@ -41,7 +40,7 @@ pub enum RouteError { InvalidAuthorization, } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -68,7 +67,7 @@ impl IntoResponse for RouteError { pub(crate) async fn post( clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, maybe_authorization: Option>>, ) -> Result { let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?; diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs index 6b90464ea..eb970c570 100644 --- a/crates/handlers/src/compat/refresh.rs +++ b/crates/handlers/src/compat/refresh.rs @@ -18,9 +18,8 @@ use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, - BoxClock, BoxRng, Clock, Repository, + BoxClock, BoxRepository, BoxRng, Clock, }; -use mas_storage_pg::PgRepository; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DurationMilliSeconds}; use thiserror::Error; @@ -69,7 +68,7 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl From for RouteError { fn from(_e: TokenFormatError) -> Self { @@ -89,7 +88,7 @@ pub struct ResponseBody { pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, Json(input): Json, ) -> Result { let token_type = TokenType::check(&input.refresh_token)?; diff --git a/crates/handlers/src/graphql.rs b/crates/handlers/src/graphql.rs index fcc6aa3c3..2d1f7fcc0 100644 --- a/crates/handlers/src/graphql.rs +++ b/crates/handlers/src/graphql.rs @@ -22,20 +22,19 @@ use axum::{ Json, TypedHeader, }; use axum_extra::extract::PrivateCookieJar; -use futures_util::{StreamExt, TryStreamExt}; +use futures_util::TryStreamExt; use headers::{ContentType, HeaderValue}; use hyper::header::CACHE_CONTROL; use mas_axum_utils::{FancyError, SessionInfoExt}; use mas_graphql::Schema; use mas_keystore::Encrypter; -use mas_storage_pg::PgRepository; -use sqlx::PgPool; +use mas_storage::BoxRepository; +use tokio::sync::Mutex; use tracing::{info_span, Instrument}; #[must_use] -pub fn schema(pool: &PgPool) -> Schema { +pub fn schema() -> Schema { mas_graphql::schema_builder() - .data(pool.clone()) .extension(Tracing) .extension(ApolloTracing) .finish() @@ -59,8 +58,8 @@ fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span { } pub async fn post( - State(pool): State, State(schema): State, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, content_type: Option>, body: BodyStream, @@ -68,62 +67,46 @@ pub async fn post( let content_type = content_type.map(|TypedHeader(h)| h.to_string()); let (session_info, _cookie_jar) = cookie_jar.session_info(); - let mut repo = PgRepository::from_pool(&pool).await?; - let maybe_session = session_info.load_session(&mut repo).await?; - repo.cancel().await?; + let maybe_session = session_info.load_session(&mut *repo).await?; - let mut request = async_graphql::http::receive_batch_body( + let mut request = async_graphql::http::receive_body( content_type, body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) .into_async_read(), MultipartOptions::default(), ) - .await?; // XXX: this should probably return another error response? + .await? // XXX: this should probably return another error response? + .data(Mutex::new(repo)); if let Some(session) = maybe_session { request = request.data(session); } - let response = match request { - async_graphql::BatchRequest::Single(request) => { - let span = span_for_graphql_request(&request); - let response = schema.execute(request).instrument(span).await; - async_graphql::BatchResponse::Single(response) - } - async_graphql::BatchRequest::Batch(requests) => async_graphql::BatchResponse::Batch( - futures_util::stream::iter(requests.into_iter()) - .then(|request| { - let span = span_for_graphql_request(&request); - schema.execute(request).instrument(span) - }) - .collect() - .await, - ), - }; + let span = span_for_graphql_request(&request); + let response = schema.execute(request).instrument(span).await; let cache_control = response - .cache_control() + .cache_control .value() .and_then(|v| HeaderValue::from_str(&v).ok()) .map(|h| [(CACHE_CONTROL, h)]); - let headers = response.http_headers(); + let headers = response.http_headers.clone(); Ok((headers, cache_control, Json(response))) } pub async fn get( - State(pool): State, State(schema): State, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, RawQuery(query): RawQuery, ) -> Result { let (session_info, _cookie_jar) = cookie_jar.session_info(); - let mut repo = PgRepository::from_pool(&pool).await?; - let maybe_session = session_info.load_session(&mut repo).await?; - repo.cancel().await?; + let maybe_session = session_info.load_session(&mut *repo).await?; - let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?; + let mut request = + async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(Mutex::new(repo)); if let Some(session) = maybe_session { request = request.data(session); diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 48ca55608..30519f425 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -43,8 +43,7 @@ use mas_http::CorsLayerExt; use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; use mas_router::{Route, UrlBuilder}; -use mas_storage::{BoxClock, BoxRng}; -use mas_storage_pg::PgRepository; +use mas_storage::{BoxClock, BoxRepository, BoxRng}; use mas_templates::{ErrorContext, Templates}; use passwords::PasswordManager; use sqlx::PgPool; @@ -98,7 +97,7 @@ where ::Error: std::error::Error + Send + Sync, S: Clone + Send + Sync + 'static, mas_graphql::Schema: FromRef, - PgPool: FromRef, + BoxRepository: FromRequestParts, Encrypter: FromRef, { let mut router = Router::new().route( @@ -158,7 +157,7 @@ where Keystore: FromRef, UrlBuilder: FromRef, Arc: FromRef, - PgRepository: FromRequestParts, + BoxRepository: FromRequestParts, Encrypter: FromRef, HttpClientFactory: FromRef, BoxClock: FromRequestParts, @@ -213,7 +212,7 @@ where ::Error: std::error::Error + Send + Sync, S: Clone + Send + Sync + 'static, UrlBuilder: FromRef, - PgRepository: FromRequestParts, + BoxRepository: FromRequestParts, MatrixHomeserver: FromRef, PasswordManager: FromRef, BoxClock: FromRequestParts, @@ -258,7 +257,7 @@ where S: Clone + Send + Sync + 'static, UrlBuilder: FromRef, Arc: FromRef, - PgRepository: FromRequestParts, + BoxRepository: FromRequestParts, Encrypter: FromRef, Templates: FromRef, Mailer: FromRef, @@ -401,7 +400,7 @@ async fn test_state(pool: sqlx::PgPool) -> Result { let policy_factory = Arc::new(policy_factory); - let graphql_schema = graphql_schema(&pool); + let graphql_schema = graphql_schema(); let http_client_factory = HttpClientFactory::new(10); diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index c17fb9f1b..91121df98 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -27,9 +27,8 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, }; -use mas_storage_pg::PgRepository; use mas_templates::Templates; use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse}; use thiserror::Error; @@ -69,7 +68,7 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); @@ -81,13 +80,13 @@ pub(crate) async fn get( clock: BoxClock, State(policy_factory): State>, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Path(grant_id): Path, ) -> Result { let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let grant = repo .oauth2_authorization_grant() @@ -147,7 +146,7 @@ pub enum GrantCompletionError { NoSuchClient, } -impl_from_error_for_route!(GrantCompletionError: mas_storage_pg::DatabaseError); +impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError); impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError); impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError); @@ -159,7 +158,7 @@ pub(crate) async fn complete( grant: AuthorizationGrant, browser_session: BrowserSession, policy_factory: &PolicyFactory, - mut repo: PgRepository, + mut repo: BoxRepository, ) -> Result>, GrantCompletionError> { // Verify that the grant is in a pending stage if !grant.stage.is_pending() { diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 30efcaa36..4ce10baa1 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -27,9 +27,8 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, }; -use mas_storage_pg::PgRepository; use mas_templates::Templates; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, @@ -90,7 +89,7 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(self::callback::CallbackDestinationError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); @@ -135,7 +134,7 @@ pub(crate) async fn get( clock: BoxClock, State(policy_factory): State>, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Form(params): Form, ) -> Result { @@ -168,7 +167,7 @@ pub(crate) async fn get( let templates = templates.clone(); let callback_destination = callback_destination.clone(); async move { - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let prompt = params.auth.prompt.as_deref().unwrap_or_default(); // Check if the request/request_uri/registration params are used. If so, reply diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index c83dca03c..f365a9f39 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -30,9 +30,8 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, }; -use mas_storage_pg::PgRepository; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; use thiserror::Error; use ulid::Ulid; @@ -61,7 +60,7 @@ pub enum RouteError { } impl_from_error_for_route!(mas_templates::TemplateError); -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); @@ -77,13 +76,13 @@ pub(crate) async fn get( clock: BoxClock, State(policy_factory): State>, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Path(grant_id): Path, ) -> Result { let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let grant = repo .oauth2_authorization_grant() @@ -130,7 +129,7 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(policy_factory): State>, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Path(grant_id): Path, Form(form): Form>, @@ -139,7 +138,7 @@ pub(crate) async fn post( let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let grant = repo .oauth2_authorization_grant() diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 65e48e064..d0dcd26c7 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -25,9 +25,8 @@ use mas_storage::{ compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository}, user::{BrowserSessionRepository, UserRepository}, - BoxClock, Clock, Repository, + BoxClock, BoxRepository, Clock, }; -use mas_storage_pg::PgRepository; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, requests::{IntrospectionRequest, IntrospectionResponse}, @@ -96,7 +95,7 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl From for RouteError { fn from(_e: TokenFormatError) -> Self { @@ -125,13 +124,13 @@ const API_SCOPE: ScopeToken = ScopeToken::from_static("urn:matrix:org.matrix.msc pub(crate) async fn post( clock: BoxClock, State(http_client_factory): State, - mut repo: PgRepository, + mut repo: BoxRepository, State(encrypter): State, client_authorization: ClientAuthorization, ) -> Result { let client = client_authorization .credentials - .fetch(&mut repo) + .fetch(&mut *repo) .await .unwrap() .ok_or(RouteError::ClientNotFound)?; diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index 129f636f6..650a19ab7 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -19,8 +19,7 @@ use hyper::StatusCode; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_keystore::Encrypter; use mas_policy::{PolicyFactory, Violation}; -use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRng, Repository}; -use mas_storage_pg::PgRepository; +use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRepository, BoxRng}; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, registration::{ @@ -48,7 +47,7 @@ pub(crate) enum RouteError { PolicyDenied(Vec), } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); @@ -108,7 +107,7 @@ impl IntoResponse for RouteError { pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, State(policy_factory): State>, State(encrypter): State, Json(body): Json, diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 5b6b7565b..76943e7e2 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -37,9 +37,8 @@ use mas_storage::{ OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, user::BrowserSessionRepository, - BoxClock, BoxRng, Clock, Repository, + BoxClock, BoxRepository, BoxRng, Clock, }; -use mas_storage_pg::PgRepository; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, pkce::CodeChallengeError, @@ -150,7 +149,7 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_keystore::WrongAlgorithmError); impl_from_error_for_route!(mas_jose::claims::ClaimError); impl_from_error_for_route!(mas_jose::claims::TokenHashError); @@ -163,13 +162,13 @@ pub(crate) async fn post( State(http_client_factory): State, State(key_store): State, State(url_builder): State, - mut repo: PgRepository, + mut repo: BoxRepository, State(encrypter): State, client_authorization: ClientAuthorization, ) -> Result { let client = client_authorization .credentials - .fetch(&mut repo) + .fetch(&mut *repo) .await? .ok_or(RouteError::ClientNotFound)?; @@ -185,7 +184,7 @@ pub(crate) async fn post( let form = client_authorization.form.ok_or(RouteError::BadRequest)?; - let reply = match form { + let (reply, repo) = match form { AccessTokenRequest::AuthorizationCode(grant) => { authorization_code_grant( &mut rng, @@ -206,6 +205,8 @@ pub(crate) async fn post( } }; + repo.save().await?; + let mut headers = HeaderMap::new(); headers.typed_insert(CacheControl::new().with_no_store()); headers.typed_insert(Pragma::no_cache()); @@ -221,8 +222,8 @@ async fn authorization_code_grant( client: &Client, key_store: &Keystore, url_builder: &UrlBuilder, - mut repo: PgRepository, -) -> Result { + mut repo: BoxRepository, +) -> Result<(AccessTokenResponse, BoxRepository), RouteError> { let authz_grant = repo .oauth2_authorization_grant() .find_by_code(&grant.code) @@ -367,9 +368,7 @@ async fn authorization_code_grant( .exchange(clock, authz_grant) .await?; - repo.save().await?; - - Ok(params) + Ok((params, repo)) } async fn refresh_token_grant( @@ -377,8 +376,8 @@ async fn refresh_token_grant( clock: &impl Clock, grant: &RefreshTokenGrant, client: &Client, - mut repo: PgRepository, -) -> Result { + mut repo: BoxRepository, +) -> Result<(AccessTokenResponse, BoxRepository), RouteError> { let refresh_token = repo .oauth2_refresh_token() .find_by_token(&grant.refresh_token) @@ -439,7 +438,5 @@ async fn refresh_token_grant( .with_refresh_token(new_refresh_token.refresh_token) .with_scope(session.scope); - repo.save().await?; - - Ok(params) + Ok((params, repo)) } diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index eb9e1cc2f..e56dafbca 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -31,9 +31,8 @@ use mas_router::UrlBuilder; use mas_storage::{ oauth2::OAuth2ClientRepository, user::{BrowserSessionRepository, UserEmailRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, }; -use mas_storage_pg::PgRepository; use oauth2_types::scope; use serde::Serialize; use serde_with::skip_serializing_none; @@ -65,7 +64,7 @@ pub enum RouteError { #[error("failed to authenticate")] AuthorizationVerificationError( - #[from] AuthorizationVerificationError, + #[from] AuthorizationVerificationError, ), #[error("no suitable key found for signing")] @@ -78,7 +77,7 @@ pub enum RouteError { NoSuchBrowserSession, } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_keystore::WrongAlgorithmError); impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError); @@ -100,11 +99,11 @@ pub async fn get( mut rng: BoxRng, clock: BoxClock, State(url_builder): State, - mut repo: PgRepository, + mut repo: BoxRepository, State(key_store): State, user_authorization: UserAuthorization, ) -> Result { - let session = user_authorization.protected(&mut repo, &clock).await?; + let session = user_authorization.protected(&mut *repo, &clock).await?; let browser_session = repo .browser_session() diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index ff47084b5..8da6231af 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -24,9 +24,8 @@ use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_router::UrlBuilder; use mas_storage::{ upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, }; -use mas_storage_pg::PgRepository; use thiserror::Error; use ulid::Ulid; @@ -45,7 +44,7 @@ pub(crate) enum RouteError { impl_from_error_for_route!(mas_http::ClientInitError); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError); -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -60,7 +59,7 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(http_client_factory): State, - mut repo: PgRepository, + mut repo: BoxRepository, State(url_builder): State, cookie_jar: PrivateCookieJar, Path(provider_id): Path, diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index b324cfb24..bc24c399a 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -30,9 +30,8 @@ use mas_storage::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, }, - BoxClock, BoxRng, Clock, Repository, + BoxClock, BoxRepository, BoxRng, Clock, }; -use mas_storage_pg::PgRepository; use oauth2_types::errors::ClientErrorCode; use serde::Deserialize; use thiserror::Error; @@ -99,7 +98,7 @@ pub(crate) enum RouteError { Internal(Box), } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_http::ClientInitError); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); impl_from_error_for_route!(mas_oidc_client::error::JwksError); @@ -123,7 +122,7 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(http_client_factory): State, - mut repo: PgRepository, + mut repo: BoxRepository, State(url_builder): State, State(encrypter): State, State(keystore): State, diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index bdd5df1ff..89614dcb1 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -27,9 +27,8 @@ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, user::{BrowserSessionRepository, UserRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, }; -use mas_storage_pg::PgRepository; use mas_templates::{ EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink, @@ -72,7 +71,7 @@ pub(crate) enum RouteError { impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError); impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -95,7 +94,7 @@ pub(crate) enum FormData { pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, State(templates): State, cookie_jar: PrivateCookieJar, Path(link_id): Path, @@ -129,7 +128,7 @@ pub(crate) async fn get( let (user_session_info, cookie_jar) = cookie_jar.session_info(); let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); - let maybe_user_session = user_session_info.load_session(&mut repo).await?; + let maybe_user_session = user_session_info.load_session(&mut *repo).await?; let render = match (maybe_user_session, link.user_id) { (Some(session), Some(user_id)) if session.user.id == user_id => { @@ -211,7 +210,7 @@ pub(crate) async fn get( pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Path(link_id): Path, Form(form): Form>, @@ -250,7 +249,7 @@ pub(crate) async fn post( } let (user_session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_user_session = user_session_info.load_session(&mut repo).await?; + let maybe_user_session = user_session_info.load_session(&mut *repo).await?; let session = match (maybe_user_session, link.user_id, form) { (Some(session), None, FormData::Link) => { diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index 64218e3a8..7b89b2d81 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -24,8 +24,7 @@ use mas_axum_utils::{ use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Repository}; -use mas_storage_pg::PgRepository; +use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng}; use mas_templates::{EmailAddContext, TemplateContext, Templates}; use serde::Deserialize; @@ -41,13 +40,13 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session @@ -68,7 +67,7 @@ pub(crate) async fn get( pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, State(mailer): State, cookie_jar: PrivateCookieJar, Query(query): Query, @@ -77,7 +76,7 @@ pub(crate) async fn post( let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session @@ -99,7 +98,7 @@ pub(crate) async fn post( }; start_email_verification( &mailer, - &mut repo, + &mut *repo, &mut rng, &clock, &session.user, diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index fd2f2981f..251e5adb3 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -28,8 +28,7 @@ use mas_data_model::{BrowserSession, User, UserEmail}; use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Clock, Repository}; -use mas_storage_pg::PgRepository; +use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, Repository}; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; use rand::{distributions::Uniform, Rng}; use serde::Deserialize; @@ -51,28 +50,28 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, ) -> Result { let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; if let Some(session) = maybe_session { - render(&mut rng, &clock, templates, session, cookie_jar, &mut repo).await + render(&mut rng, &clock, templates, session, cookie_jar, &mut *repo).await } else { let login = mas_router::Login::default(); Ok((cookie_jar, login.go()).into_response()) } } -async fn render( +async fn render( rng: impl Rng + Send, clock: &impl Clock, templates: Templates, session: BrowserSession, cookie_jar: PrivateCookieJar, - repo: &mut impl Repository, + repo: &mut (impl Repository + ?Sized), ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng); @@ -87,9 +86,9 @@ async fn render( Ok((cookie_jar, Html(content)).into_response()) } -async fn start_email_verification( +async fn start_email_verification( mailer: &Mailer, - repo: &mut impl Repository, + repo: &mut (impl Repository + ?Sized), mut rng: impl Rng + Send, clock: &impl Clock, user: &User, @@ -124,14 +123,14 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, State(mailer): State, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let mut session = if let Some(session) = maybe_session { session @@ -150,7 +149,7 @@ pub(crate) async fn post( .await?; let next = mas_router::AccountVerifyEmail::new(email.id); - start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email) + start_email_verification(&mailer, &mut *repo, &mut rng, &clock, &session.user, email) .await?; repo.save().await?; return Ok((cookie_jar, next.go()).into_response()); @@ -169,7 +168,7 @@ pub(crate) async fn post( } let next = mas_router::AccountVerifyEmail::new(email.id); - start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email) + start_email_verification(&mailer, &mut *repo, &mut rng, &clock, &session.user, email) .await?; repo.save().await?; return Ok((cookie_jar, next.go()).into_response()); @@ -212,7 +211,7 @@ pub(crate) async fn post( templates.clone(), session, cookie_jar, - &mut repo, + &mut *repo, ) .await?; diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index e330c944f..6a701b50a 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -24,8 +24,7 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Repository}; -use mas_storage_pg::PgRepository; +use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng}; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use serde::Deserialize; use ulid::Ulid; @@ -41,7 +40,7 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, Query(query): Query, Path(id): Path, cookie_jar: PrivateCookieJar, @@ -49,7 +48,7 @@ pub(crate) async fn get( let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session @@ -82,7 +81,7 @@ pub(crate) async fn get( pub(crate) async fn post( clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Query(query): Query, Path(id): Path, @@ -91,7 +90,7 @@ pub(crate) async fn post( let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session diff --git a/crates/handlers/src/views/account/mod.rs b/crates/handlers/src/views/account/mod.rs index 76ea5667d..8860c43c1 100644 --- a/crates/handlers/src/views/account/mod.rs +++ b/crates/handlers/src/views/account/mod.rs @@ -25,22 +25,21 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, }; -use mas_storage_pg::PgRepository; use mas_templates::{AccountContext, TemplateContext, Templates}; pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index 4fa86eae2..ddc477798 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -27,9 +27,8 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository}, - BoxClock, BoxRng, Clock, Repository, + BoxClock, BoxRepository, BoxRng, Clock, }; -use mas_storage_pg::PgRepository; use mas_templates::{EmptyContext, TemplateContext, Templates}; use rand::Rng; use serde::Deserialize; @@ -48,12 +47,12 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, ) -> Result { let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; if let Some(session) = maybe_session { render(&mut rng, &clock, templates, session, cookie_jar).await @@ -86,7 +85,7 @@ pub(crate) async fn post( clock: BoxClock, State(password_manager): State, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { @@ -94,7 +93,7 @@ pub(crate) async fn post( let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session diff --git a/crates/handlers/src/views/index.rs b/crates/handlers/src/views/index.rs index d4322eefd..0cfe0d05d 100644 --- a/crates/handlers/src/views/index.rs +++ b/crates/handlers/src/views/index.rs @@ -20,8 +20,7 @@ use axum_extra::extract::PrivateCookieJar; use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt}; use mas_keystore::Encrypter; use mas_router::UrlBuilder; -use mas_storage::{BoxClock, BoxRng}; -use mas_storage_pg::PgRepository; +use mas_storage::{BoxClock, BoxRepository, BoxRng}; use mas_templates::{IndexContext, TemplateContext, Templates}; pub async fn get( @@ -29,12 +28,12 @@ pub async fn get( clock: BoxClock, State(templates): State, State(url_builder): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let session = session_info.load_session(&mut repo).await?; + let session = session_info.load_session(&mut *repo).await?; let ctx = IndexContext::new(url_builder.oidc_discovery()) .maybe_with_session(session) diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index b245b5977..4f9ecfd80 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -26,9 +26,8 @@ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::UpstreamOAuthProviderRepository, user::{BrowserSessionRepository, UserPasswordRepository, UserRepository}, - BoxClock, BoxRng, Clock, Repository, + BoxClock, BoxRepository, BoxRng, Clock, Repository, }; -use mas_storage_pg::PgRepository; use mas_templates::{ FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, }; @@ -53,14 +52,14 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; if maybe_session.is_some() { let reply = query.go_next(); @@ -71,7 +70,7 @@ pub(crate) async fn get( LoginContext::default().with_upstrem_providers(providers), query, csrf_token, - &mut repo, + &mut *repo, &templates, ) .await?; @@ -85,7 +84,7 @@ pub(crate) async fn post( clock: BoxClock, State(password_manager): State, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, @@ -117,7 +116,7 @@ pub(crate) async fn post( .with_upstrem_providers(providers), query, csrf_token, - &mut repo, + &mut *repo, &templates, ) .await?; @@ -127,7 +126,7 @@ pub(crate) async fn post( match login( password_manager, - &mut repo, + &mut *repo, rng, &clock, &form.username, @@ -149,7 +148,7 @@ pub(crate) async fn post( LoginContext::default().with_form_state(state), query, csrf_token, - &mut repo, + &mut *repo, &templates, ) .await?; @@ -162,7 +161,7 @@ pub(crate) async fn post( // TODO: move that logic elsewhere? async fn login( password_manager: PasswordManager, - repo: &mut impl Repository, + repo: &mut (impl Repository + ?Sized), mut rng: impl Rng + CryptoRng + Send, clock: &impl Clock, username: &str, @@ -236,7 +235,7 @@ async fn render( ctx: LoginContext, action: OptionalPostAuthAction, csrf_token: CsrfToken, - repo: &mut impl Repository, + repo: &mut (impl Repository + ?Sized), templates: &Templates, ) -> Result { let next = action.load_context(repo).await?; diff --git a/crates/handlers/src/views/logout.rs b/crates/handlers/src/views/logout.rs index 9cdc93f03..189331fde 100644 --- a/crates/handlers/src/views/logout.rs +++ b/crates/handlers/src/views/logout.rs @@ -20,12 +20,11 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::{PostAuthAction, Route}; -use mas_storage::{user::BrowserSessionRepository, BoxClock, Repository}; -use mas_storage_pg::PgRepository; +use mas_storage::{user::BrowserSessionRepository, BoxClock, BoxRepository}; pub(crate) async fn post( clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Form(form): Form>>, ) -> Result { @@ -33,7 +32,7 @@ pub(crate) async fn post( let (session_info, mut cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; if let Some(session) = maybe_session { repo.browser_session().finish(&clock, session).await?; diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index ced979020..2750711c1 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -26,9 +26,8 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, }; -use mas_storage_pg::PgRepository; use mas_templates::{ReauthContext, TemplateContext, Templates}; use serde::Deserialize; use zeroize::Zeroizing; @@ -45,14 +44,14 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session @@ -64,7 +63,7 @@ pub(crate) async fn get( }; let ctx = ReauthContext::default(); - let next = query.load_context(&mut repo).await?; + let next = query.load_context(&mut *repo).await?; let ctx = if let Some(next) = next { ctx.with_post_action(next) } else { @@ -81,7 +80,7 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(password_manager): State, - mut repo: PgRepository, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, @@ -90,7 +89,7 @@ pub(crate) async fn post( let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index 68cf5c493..467352af2 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -33,9 +33,8 @@ use mas_policy::PolicyFactory; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, Repository, }; -use mas_storage_pg::PgRepository; use mas_templates::{ EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField, TemplateContext, Templates, ToFormState, @@ -63,14 +62,14 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; if maybe_session.is_some() { let reply = query.go_next(); @@ -80,7 +79,7 @@ pub(crate) async fn get( RegisterContext::default(), query, csrf_token, - &mut repo, + &mut *repo, &templates, ) .await?; @@ -97,7 +96,7 @@ pub(crate) async fn post( State(mailer): State, State(policy_factory): State>, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, @@ -175,7 +174,7 @@ pub(crate) async fn post( RegisterContext::default().with_form_state(state), query, csrf_token, - &mut repo, + &mut *repo, &templates, ) .await?; @@ -234,7 +233,7 @@ async fn render( ctx: RegisterContext, action: OptionalPostAuthAction, csrf_token: CsrfToken, - repo: &mut impl Repository, + repo: &mut (impl Repository + ?Sized), templates: &Templates, ) -> Result { let next = action.load_context(repo).await?; diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index db3c33920..b29460842 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -40,9 +40,9 @@ impl OptionalPostAuthAction { self.go_next_or_default(&mas_router::Index) } - pub async fn load_context( - &self, - repo: &mut R, + pub async fn load_context<'a>( + &'a self, + repo: &'a mut (impl Repository + ?Sized), ) -> anyhow::Result> { let Some(action) = self.post_auth_action.clone() else { return Ok(None) }; let ctx = match action { diff --git a/crates/storage-pg/Cargo.toml b/crates/storage-pg/Cargo.toml index fad6e30ee..3373a21fa 100644 --- a/crates/storage-pg/Cargo.toml +++ b/crates/storage-pg/Cargo.toml @@ -13,6 +13,7 @@ serde = { version = "1.0.152", features = ["derive"] } serde_json = "1.0.91" thiserror = "1.0.38" tracing = "0.1.37" +futures-util = "0.3.25" rand = "0.8.5" rand_chacha = "0.3.1" diff --git a/crates/storage-pg/src/compat/mod.rs b/crates/storage-pg/src/compat/mod.rs index dd68e4d5f..9b3407563 100644 --- a/crates/storage-pg/src/compat/mod.rs +++ b/crates/storage-pg/src/compat/mod.rs @@ -103,7 +103,7 @@ mod tests { const SECOND_TOKEN: &str = "second_access_token"; let mut rng = ChaChaRng::seed_from_u64(42); let clock = MockClock::default(); - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); // Create a user let user = repo @@ -139,7 +139,7 @@ mod tests { repo.save().await.unwrap(); { - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); // Adding the same token a second time should conflict assert!(repo .compat_access_token() @@ -156,7 +156,7 @@ mod tests { } // Grab a new repo - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); // Looking up via ID works let token_lookup = repo @@ -223,7 +223,7 @@ mod tests { const REFRESH_TOKEN: &str = "refresh_token"; let mut rng = ChaChaRng::seed_from_u64(42); let clock = MockClock::default(); - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); // Create a user let user = repo diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index 540027551..6448b61a3 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use futures_util::{future::BoxFuture, FutureExt, TryFutureExt}; use mas_storage::{ compat::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, @@ -59,21 +60,19 @@ impl PgRepository { let txn = pool.begin().await?; Ok(PgRepository { txn }) } - - pub async fn save(self) -> Result<(), DatabaseError> { - self.txn.commit().await?; - Ok(()) - } - - pub async fn cancel(self) -> Result<(), DatabaseError> { - self.txn.rollback().await?; - Ok(()) - } } impl Repository for PgRepository { type Error = DatabaseError; + fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { + self.txn.commit().map_err(DatabaseError::from).boxed() + } + + fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { + self.txn.rollback().map_err(DatabaseError::from).boxed() + } + fn upstream_oauth_link<'c>( &'c mut self, ) -> Box + 'c> { diff --git a/crates/storage-pg/src/user/tests.rs b/crates/storage-pg/src/user/tests.rs index b3b882321..7c3eab376 100644 --- a/crates/storage-pg/src/user/tests.rs +++ b/crates/storage-pg/src/user/tests.rs @@ -29,7 +29,7 @@ use crate::PgRepository; async fn test_user_repo(pool: PgPool) { const USERNAME: &str = "john"; - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); let mut rng = ChaChaRng::seed_from_u64(42); let clock = MockClock::default(); @@ -77,7 +77,7 @@ async fn test_user_email_repo(pool: PgPool) { const CODE2: &str = "543210"; const EMAIL: &str = "john@example.com"; - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); let mut rng = ChaChaRng::seed_from_u64(42); let clock = MockClock::default(); @@ -259,7 +259,7 @@ async fn test_user_password_repo(pool: PgPool) { const FIRST_PASSWORD_HASH: &str = "doesntmatter"; const SECOND_PASSWORD_HASH: &str = "alsodoesntmatter"; - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); let mut rng = ChaChaRng::seed_from_u64(42); let clock = MockClock::default(); diff --git a/crates/storage/Cargo.toml b/crates/storage/Cargo.toml index 86ca9f078..cea7b03b0 100644 --- a/crates/storage/Cargo.toml +++ b/crates/storage/Cargo.toml @@ -9,6 +9,7 @@ license = "Apache-2.0" async-trait = "0.1.60" chrono = "0.4.23" thiserror = "1.0.38" +futures-util = "0.3.25" rand_core = "0.6.4" url = "2.3.1" diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index 0cdc4e39b..aa1db0af7 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -28,21 +28,21 @@ clippy::module_name_repetitions )] +use rand_core::CryptoRngCore; + pub mod clock; +pub mod pagination; +pub(crate) mod repository; pub mod compat; pub mod oauth2; -pub mod pagination; -pub(crate) mod repository; pub mod upstream_oauth2; pub mod user; -use rand_core::CryptoRngCore; - pub use self::{ clock::{Clock, SystemClock}, pagination::{Page, Pagination}, - repository::Repository, + repository::{BoxRepository, Repository, RepositoryError}, }; pub struct MapErr { @@ -86,7 +86,6 @@ macro_rules! repository_impl { where R: $repo_trait, F: FnMut(::Error) -> E + ::std::marker::Send + ::std::marker::Sync, - E: ::std::error::Error + ::std::marker::Send + ::std::marker::Sync, { type Error = E; diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 085c06aba..3da64a8c6 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +use futures_util::{future::BoxFuture, FutureExt, TryFutureExt}; +use thiserror::Error; + use crate::{ compat::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, @@ -32,6 +35,23 @@ use crate::{ pub trait Repository: Send { type Error: std::error::Error + Send + Sync + 'static; + fn map_err(self, mapper: Mapper) -> MapErr + where + Self: Sized, + { + MapErr::new(self, mapper) + } + + fn boxed(self) -> BoxRepository + where + Self: Sized + Sync + 'static, + { + Box::new(self) + } + + fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>>; + fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>>; + fn upstream_oauth_link<'c>( &'c mut self, ) -> Box + 'c>; @@ -91,14 +111,44 @@ pub trait Repository: Send { ) -> Box + 'c>; } +/// An opaque, type-erased error +#[derive(Debug, Error)] +#[error(transparent)] +pub struct RepositoryError { + source: Box, +} + +impl RepositoryError { + pub fn from_error(value: E) -> Self + where + E: std::error::Error + Send + Sync + 'static, + { + Self { + source: Box::new(value), + } + } +} + +pub type BoxRepository = + Box + Send + Sync + 'static>; + impl Repository for crate::MapErr where R: Repository, - F: FnMut(R::Error) -> E + Send + Sync, + R::Error: 'static, + F: FnMut(R::Error) -> E + Send + Sync + 'static, E: std::error::Error + Send + Sync + 'static, { type Error = E; + fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { + Box::new(self.inner).save().map_err(self.mapper).boxed() + } + + fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { + Box::new(self.inner).cancel().map_err(self.mapper).boxed() + } + fn upstream_oauth_link<'c>( &'c mut self, ) -> Box + 'c> { diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index c5e024af7..0057f2d6c 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -21,7 +21,7 @@ use crate::{pagination::Page, repository_impl, Clock, Pagination}; #[async_trait] pub trait UpstreamOAuthLinkRepository: Send + Sync { - type Error: std::error::Error + Send + Sync; + type Error; /// Lookup an upstream OAuth link by its ID async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; From aaa6944815696be2a44c8d9b9b1453adc72d4657 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 23 Jan 2023 18:12:50 +0100 Subject: [PATCH 35/45] storage: impl Repository for Box --- .../handlers/src/compat/login_sso_complete.rs | 4 +- crates/handlers/src/graphql.rs | 4 +- .../src/oauth2/authorization/complete.rs | 2 +- .../handlers/src/oauth2/authorization/mod.rs | 2 +- crates/handlers/src/oauth2/consent.rs | 4 +- crates/handlers/src/oauth2/introspection.rs | 2 +- crates/handlers/src/oauth2/token.rs | 2 +- crates/handlers/src/oauth2/userinfo.rs | 2 +- crates/handlers/src/upstream_oauth2/link.rs | 4 +- .../handlers/src/views/account/emails/add.rs | 6 +- .../handlers/src/views/account/emails/mod.rs | 12 +- .../src/views/account/emails/verify.rs | 4 +- crates/handlers/src/views/account/mod.rs | 2 +- crates/handlers/src/views/account/password.rs | 4 +- crates/handlers/src/views/index.rs | 2 +- crates/handlers/src/views/login.rs | 10 +- crates/handlers/src/views/logout.rs | 2 +- crates/handlers/src/views/reauth.rs | 6 +- crates/handlers/src/views/register.rs | 6 +- crates/storage/src/repository.rs | 112 ++++++++++++++++++ 20 files changed, 152 insertions(+), 40 deletions(-) diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index ba3dee136..540287467 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -63,7 +63,7 @@ pub async fn get( let (session_info, cookie_jar) = cookie_jar.session_info(); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -126,7 +126,7 @@ pub async fn post( let (session_info, cookie_jar) = cookie_jar.session_info(); cookie_jar.verify_form(&clock, form)?; - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session diff --git a/crates/handlers/src/graphql.rs b/crates/handlers/src/graphql.rs index 2d1f7fcc0..233c46906 100644 --- a/crates/handlers/src/graphql.rs +++ b/crates/handlers/src/graphql.rs @@ -67,7 +67,7 @@ pub async fn post( let content_type = content_type.map(|TypedHeader(h)| h.to_string()); let (session_info, _cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let mut request = async_graphql::http::receive_body( content_type, @@ -103,7 +103,7 @@ pub async fn get( RawQuery(query): RawQuery, ) -> Result { let (session_info, _cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(Mutex::new(repo)); diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index 91121df98..6b6869da9 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -86,7 +86,7 @@ pub(crate) async fn get( ) -> Result { let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let grant = repo .oauth2_authorization_grant() diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 4ce10baa1..1aa8ce64d 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -167,7 +167,7 @@ pub(crate) async fn get( let templates = templates.clone(); let callback_destination = callback_destination.clone(); async move { - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let prompt = params.auth.prompt.as_deref().unwrap_or_default(); // Check if the request/request_uri/registration params are used. If so, reply diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index f365a9f39..ffaa6106e 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -82,7 +82,7 @@ pub(crate) async fn get( ) -> Result { let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let grant = repo .oauth2_authorization_grant() @@ -138,7 +138,7 @@ pub(crate) async fn post( let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let grant = repo .oauth2_authorization_grant() diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index d0dcd26c7..3b44c511d 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -130,7 +130,7 @@ pub(crate) async fn post( ) -> Result { let client = client_authorization .credentials - .fetch(&mut *repo) + .fetch(&mut repo) .await .unwrap() .ok_or(RouteError::ClientNotFound)?; diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 76943e7e2..682813bf4 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -168,7 +168,7 @@ pub(crate) async fn post( ) -> Result { let client = client_authorization .credentials - .fetch(&mut *repo) + .fetch(&mut repo) .await? .ok_or(RouteError::ClientNotFound)?; diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index e56dafbca..d2d27cd0b 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -103,7 +103,7 @@ pub async fn get( State(key_store): State, user_authorization: UserAuthorization, ) -> Result { - let session = user_authorization.protected(&mut *repo, &clock).await?; + let session = user_authorization.protected(&mut repo, &clock).await?; let browser_session = repo .browser_session() diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 89614dcb1..30d678cf2 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -128,7 +128,7 @@ pub(crate) async fn get( let (user_session_info, cookie_jar) = cookie_jar.session_info(); let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); - let maybe_user_session = user_session_info.load_session(&mut *repo).await?; + let maybe_user_session = user_session_info.load_session(&mut repo).await?; let render = match (maybe_user_session, link.user_id) { (Some(session), Some(user_id)) if session.user.id == user_id => { @@ -249,7 +249,7 @@ pub(crate) async fn post( } let (user_session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_user_session = user_session_info.load_session(&mut *repo).await?; + let maybe_user_session = user_session_info.load_session(&mut repo).await?; let session = match (maybe_user_session, link.user_id, form) { (Some(session), None, FormData::Link) => { diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index 7b89b2d81..e26c9cc1a 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -46,7 +46,7 @@ pub(crate) async fn get( let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -76,7 +76,7 @@ pub(crate) async fn post( let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -98,7 +98,7 @@ pub(crate) async fn post( }; start_email_verification( &mailer, - &mut *repo, + &mut repo, &mut rng, &clock, &session.user, diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index 251e5adb3..830383376 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -55,10 +55,10 @@ pub(crate) async fn get( ) -> Result { let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if let Some(session) = maybe_session { - render(&mut rng, &clock, templates, session, cookie_jar, &mut *repo).await + render(&mut rng, &clock, templates, session, cookie_jar, &mut repo).await } else { let login = mas_router::Login::default(); Ok((cookie_jar, login.go()).into_response()) @@ -130,7 +130,7 @@ pub(crate) async fn post( ) -> Result { let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let mut session = if let Some(session) = maybe_session { session @@ -149,7 +149,7 @@ pub(crate) async fn post( .await?; let next = mas_router::AccountVerifyEmail::new(email.id); - start_email_verification(&mailer, &mut *repo, &mut rng, &clock, &session.user, email) + start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email) .await?; repo.save().await?; return Ok((cookie_jar, next.go()).into_response()); @@ -168,7 +168,7 @@ pub(crate) async fn post( } let next = mas_router::AccountVerifyEmail::new(email.id); - start_email_verification(&mailer, &mut *repo, &mut rng, &clock, &session.user, email) + start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email) .await?; repo.save().await?; return Ok((cookie_jar, next.go()).into_response()); @@ -211,7 +211,7 @@ pub(crate) async fn post( templates.clone(), session, cookie_jar, - &mut *repo, + &mut repo, ) .await?; diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index 6a701b50a..d7f074b8c 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -48,7 +48,7 @@ pub(crate) async fn get( let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -90,7 +90,7 @@ pub(crate) async fn post( let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session diff --git a/crates/handlers/src/views/account/mod.rs b/crates/handlers/src/views/account/mod.rs index 8860c43c1..162b78993 100644 --- a/crates/handlers/src/views/account/mod.rs +++ b/crates/handlers/src/views/account/mod.rs @@ -39,7 +39,7 @@ pub(crate) async fn get( let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index ddc477798..d9e026105 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -52,7 +52,7 @@ pub(crate) async fn get( ) -> Result { let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if let Some(session) = maybe_session { render(&mut rng, &clock, templates, session, cookie_jar).await @@ -93,7 +93,7 @@ pub(crate) async fn post( let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session diff --git a/crates/handlers/src/views/index.rs b/crates/handlers/src/views/index.rs index 0cfe0d05d..7b4be7df0 100644 --- a/crates/handlers/src/views/index.rs +++ b/crates/handlers/src/views/index.rs @@ -33,7 +33,7 @@ pub async fn get( ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let session = session_info.load_session(&mut *repo).await?; + let session = session_info.load_session(&mut repo).await?; let ctx = IndexContext::new(url_builder.oidc_discovery()) .maybe_with_session(session) diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 4f9ecfd80..2836eee35 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -59,7 +59,7 @@ pub(crate) async fn get( let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if maybe_session.is_some() { let reply = query.go_next(); @@ -70,7 +70,7 @@ pub(crate) async fn get( LoginContext::default().with_upstrem_providers(providers), query, csrf_token, - &mut *repo, + &mut repo, &templates, ) .await?; @@ -116,7 +116,7 @@ pub(crate) async fn post( .with_upstrem_providers(providers), query, csrf_token, - &mut *repo, + &mut repo, &templates, ) .await?; @@ -126,7 +126,7 @@ pub(crate) async fn post( match login( password_manager, - &mut *repo, + &mut repo, rng, &clock, &form.username, @@ -148,7 +148,7 @@ pub(crate) async fn post( LoginContext::default().with_form_state(state), query, csrf_token, - &mut *repo, + &mut repo, &templates, ) .await?; diff --git a/crates/handlers/src/views/logout.rs b/crates/handlers/src/views/logout.rs index 189331fde..9b0f3602e 100644 --- a/crates/handlers/src/views/logout.rs +++ b/crates/handlers/src/views/logout.rs @@ -32,7 +32,7 @@ pub(crate) async fn post( let (session_info, mut cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if let Some(session) = maybe_session { repo.browser_session().finish(&clock, session).await?; diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index 2750711c1..12f205d6a 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -51,7 +51,7 @@ pub(crate) async fn get( let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -63,7 +63,7 @@ pub(crate) async fn get( }; let ctx = ReauthContext::default(); - let next = query.load_context(&mut *repo).await?; + let next = query.load_context(&mut repo).await?; let ctx = if let Some(next) = next { ctx.with_post_action(next) } else { @@ -89,7 +89,7 @@ pub(crate) async fn post( let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index 467352af2..ad1ff3784 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -69,7 +69,7 @@ pub(crate) async fn get( let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut *repo).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if maybe_session.is_some() { let reply = query.go_next(); @@ -79,7 +79,7 @@ pub(crate) async fn get( RegisterContext::default(), query, csrf_token, - &mut *repo, + &mut repo, &templates, ) .await?; @@ -174,7 +174,7 @@ pub(crate) async fn post( RegisterContext::default().with_form_state(state), query, csrf_token, - &mut *repo, + &mut repo, &templates, ) .await?; diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 3da64a8c6..d6772b9ad 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -265,3 +265,115 @@ where )) } } + +impl Repository for Box { + type Error = R::Error; + + fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> + where + Self: Sized, + { + // This shouldn't be callable? + unimplemented!() + } + + fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> + where + Self: Sized, + { + // This shouldn't be callable? + unimplemented!() + } + + fn upstream_oauth_link<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).upstream_oauth_link() + } + + fn upstream_oauth_provider<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).upstream_oauth_provider() + } + + fn upstream_oauth_session<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).upstream_oauth_session() + } + + fn user<'c>(&'c mut self) -> Box + 'c> { + (**self).user() + } + + fn user_email<'c>(&'c mut self) -> Box + 'c> { + (**self).user_email() + } + + fn user_password<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).user_password() + } + + fn browser_session<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).browser_session() + } + + fn oauth2_client<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_client() + } + + fn oauth2_authorization_grant<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_authorization_grant() + } + + fn oauth2_session<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_session() + } + + fn oauth2_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_access_token() + } + + fn oauth2_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_refresh_token() + } + + fn compat_session<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).compat_session() + } + + fn compat_sso_login<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).compat_sso_login() + } + + fn compat_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).compat_access_token() + } + + fn compat_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).compat_refresh_token() + } +} From 59ce524586edf3ead0dfac75723c97be0e9c2fe5 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 24 Jan 2023 16:04:18 +0100 Subject: [PATCH 36/45] storage: split the repository trait --- clippy.toml | 2 +- crates/axum-utils/src/client_authorization.rs | 4 +- crates/axum-utils/src/session.rs | 4 +- crates/axum-utils/src/user_authorization.rs | 8 +- crates/cli/src/commands/manage.rs | 2 +- crates/data-model/src/compat/device.rs | 4 +- .../handlers/src/views/account/emails/mod.rs | 8 +- crates/handlers/src/views/login.rs | 6 +- crates/handlers/src/views/register.rs | 4 +- crates/handlers/src/views/shared.rs | 4 +- crates/storage-pg/src/compat/mod.rs | 2 +- crates/storage-pg/src/lib.rs | 2 +- crates/storage-pg/src/repository.rs | 10 +- crates/storage-pg/src/upstream_oauth2/mod.rs | 2 +- crates/storage-pg/src/user/tests.rs | 2 +- crates/storage/src/lib.rs | 19 +- crates/storage/src/repository.rs | 624 ++++++++++-------- crates/tasks/src/database.rs | 2 +- 18 files changed, 401 insertions(+), 308 deletions(-) diff --git a/clippy.toml b/clippy.toml index 61c5c04f2..a300ae05b 100644 --- a/clippy.toml +++ b/clippy.toml @@ -1,4 +1,4 @@ -doc-valid-idents = ["OpenID", "OAuth", ".."] +doc-valid-idents = ["OpenID", "OAuth", "..", "PostgreSQL"] disallowed-methods = [ { path = "rand::thread_rng", reason = "do not create rngs on the fly, pass them as parameters" }, diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index 67930c3fa..6f5b3e270 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -31,7 +31,7 @@ use mas_http::HttpServiceExt; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt}; use mas_keystore::Encrypter; -use mas_storage::{oauth2::OAuth2ClientRepository, Repository}; +use mas_storage::{oauth2::OAuth2ClientRepository, RepositoryAccess}; use serde::{de::DeserializeOwned, Deserialize}; use serde_json::Value; use thiserror::Error; @@ -74,7 +74,7 @@ pub enum Credentials { impl Credentials { pub async fn fetch( &self, - repo: &mut (impl Repository + ?Sized), + repo: &mut (impl RepositoryAccess + ?Sized), ) -> Result, E> { let client_id = match self { Credentials::None { client_id } diff --git a/crates/axum-utils/src/session.rs b/crates/axum-utils/src/session.rs index 5e9661525..c4fece7b0 100644 --- a/crates/axum-utils/src/session.rs +++ b/crates/axum-utils/src/session.rs @@ -14,7 +14,7 @@ use axum_extra::extract::cookie::{Cookie, PrivateCookieJar}; use mas_data_model::BrowserSession; -use mas_storage::{user::BrowserSessionRepository, Repository}; +use mas_storage::{user::BrowserSessionRepository, RepositoryAccess}; use serde::{Deserialize, Serialize}; use ulid::Ulid; @@ -45,7 +45,7 @@ impl SessionInfo { /// Load the [`BrowserSession`] from database pub async fn load_session( &self, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, ) -> Result, E> { let session_id = if let Some(id) = self.current { id diff --git a/crates/axum-utils/src/user_authorization.rs b/crates/axum-utils/src/user_authorization.rs index 2d37c40c5..c9bc537c1 100644 --- a/crates/axum-utils/src/user_authorization.rs +++ b/crates/axum-utils/src/user_authorization.rs @@ -29,7 +29,7 @@ use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode use mas_data_model::Session; use mas_storage::{ oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository}, - Clock, Repository, + Clock, RepositoryAccess, }; use serde::{de::DeserializeOwned, Deserialize}; use thiserror::Error; @@ -53,7 +53,7 @@ enum AccessToken { impl AccessToken { async fn fetch( &self, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> { let token = match self { AccessToken::Form(t) | AccessToken::Header(t) => t, @@ -86,7 +86,7 @@ impl UserAuthorization { // TODO: take scopes to validate as parameter pub async fn protected_form( self, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, clock: &impl Clock, ) -> Result<(Session, F), AuthorizationVerificationError> { let form = match self.form { @@ -106,7 +106,7 @@ impl UserAuthorization { // TODO: take scopes to validate as parameter pub async fn protected( self, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, clock: &impl Clock, ) -> Result> { let (token, session) = self.access_token.fetch(repo).await?; diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 4e74569a1..b685a167d 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -21,7 +21,7 @@ use mas_storage::{ oauth2::OAuth2ClientRepository, upstream_oauth2::UpstreamOAuthProviderRepository, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, - Repository, SystemClock, + Repository, RepositoryAccess, SystemClock, }; use mas_storage_pg::PgRepository; use oauth2_types::scope::Scope; diff --git a/crates/data-model/src/compat/device.rs b/crates/data-model/src/compat/device.rs index 84bdd067e..eebfd9eda 100644 --- a/crates/data-model/src/compat/device.rs +++ b/crates/data-model/src/compat/device.rs @@ -15,7 +15,7 @@ use oauth2_types::scope::ScopeToken; use rand::{ distributions::{Alphanumeric, DistString}, - Rng, + RngCore, }; use serde::Serialize; use thiserror::Error; @@ -48,7 +48,7 @@ impl Device { } /// Generate a random device ID - pub fn generate(rng: &mut R) -> Self { + pub fn generate(rng: &mut R) -> Self { let id: String = Alphanumeric.sample_string(rng, DEVICE_ID_LENGTH); Self { id } } diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index 830383376..ad997e0dc 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -28,7 +28,9 @@ use mas_data_model::{BrowserSession, User, UserEmail}; use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, Repository}; +use mas_storage::{ + user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, +}; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; use rand::{distributions::Uniform, Rng}; use serde::Deserialize; @@ -71,7 +73,7 @@ async fn render( templates: Templates, session: BrowserSession, cookie_jar: PrivateCookieJar, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng); @@ -88,7 +90,7 @@ async fn render( async fn start_email_verification( mailer: &Mailer, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, mut rng: impl Rng + Send, clock: &impl Clock, user: &User, diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 2836eee35..3083eae00 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -26,7 +26,7 @@ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::UpstreamOAuthProviderRepository, user::{BrowserSessionRepository, UserPasswordRepository, UserRepository}, - BoxClock, BoxRepository, BoxRng, Clock, Repository, + BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, }; use mas_templates::{ FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, @@ -161,7 +161,7 @@ pub(crate) async fn post( // TODO: move that logic elsewhere? async fn login( password_manager: PasswordManager, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, mut rng: impl Rng + CryptoRng + Send, clock: &impl Clock, username: &str, @@ -235,7 +235,7 @@ async fn render( ctx: LoginContext, action: OptionalPostAuthAction, csrf_token: CsrfToken, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, templates: &Templates, ) -> Result { let next = action.load_context(repo).await?; diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index ad1ff3784..64e30af72 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -33,7 +33,7 @@ use mas_policy::PolicyFactory; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - BoxClock, BoxRepository, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, RepositoryAccess, }; use mas_templates::{ EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField, @@ -233,7 +233,7 @@ async fn render( ctx: RegisterContext, action: OptionalPostAuthAction, csrf_token: CsrfToken, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, templates: &Templates, ) -> Result { let next = action.load_context(repo).await?; diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index b29460842..69fdf901f 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -18,7 +18,7 @@ use mas_storage::{ compat::CompatSsoLoginRepository, oauth2::OAuth2AuthorizationGrantRepository, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository}, - Repository, + RepositoryAccess, }; use mas_templates::{PostAuthContext, PostAuthContextInner}; use serde::{Deserialize, Serialize}; @@ -42,7 +42,7 @@ impl OptionalPostAuthAction { pub async fn load_context<'a>( &'a self, - repo: &'a mut (impl Repository + ?Sized), + repo: &'a mut impl RepositoryAccess, ) -> anyhow::Result> { let Some(action) = self.post_auth_action.clone() else { return Ok(None) }; let ctx = match action { diff --git a/crates/storage-pg/src/compat/mod.rs b/crates/storage-pg/src/compat/mod.rs index 9b3407563..5d99f3329 100644 --- a/crates/storage-pg/src/compat/mod.rs +++ b/crates/storage-pg/src/compat/mod.rs @@ -32,7 +32,7 @@ mod tests { CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, }, user::UserRepository, - Clock, Repository, + Clock, Repository, RepositoryAccess, }; use rand::SeedableRng; use rand_chacha::ChaChaRng; diff --git a/crates/storage-pg/src/lib.rs b/crates/storage-pg/src/lib.rs index 459c8c3bf..08c89db2e 100644 --- a/crates/storage-pg/src/lib.rs +++ b/crates/storage-pg/src/lib.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Interactions with the database +//! An implementation of the storage traits for a PostgreSQL database #![forbid(unsafe_code)] #![deny( diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index 6448b61a3..0f1fdfb4f 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -27,7 +27,7 @@ use mas_storage::{ UpstreamOAuthSessionRepository, }, user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - Repository, + Repository, RepositoryAccess, RepositoryTransaction, }; use sqlx::{PgPool, Postgres, Transaction}; @@ -62,7 +62,9 @@ impl PgRepository { } } -impl Repository for PgRepository { +impl Repository for PgRepository {} + +impl RepositoryTransaction for PgRepository { type Error = DatabaseError; fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { @@ -72,6 +74,10 @@ impl Repository for PgRepository { fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { self.txn.rollback().map_err(DatabaseError::from).boxed() } +} + +impl RepositoryAccess for PgRepository { + type Error = DatabaseError; fn upstream_oauth_link<'c>( &'c mut self, diff --git a/crates/storage-pg/src/upstream_oauth2/mod.rs b/crates/storage-pg/src/upstream_oauth2/mod.rs index af631f15c..9ff7699eb 100644 --- a/crates/storage-pg/src/upstream_oauth2/mod.rs +++ b/crates/storage-pg/src/upstream_oauth2/mod.rs @@ -31,7 +31,7 @@ mod tests { UpstreamOAuthSessionRepository, }, user::UserRepository, - Pagination, Repository, + Pagination, RepositoryAccess, }; use oauth2_types::scope::{Scope, OPENID}; use rand::SeedableRng; diff --git a/crates/storage-pg/src/user/tests.rs b/crates/storage-pg/src/user/tests.rs index 7c3eab376..29f828aba 100644 --- a/crates/storage-pg/src/user/tests.rs +++ b/crates/storage-pg/src/user/tests.rs @@ -16,7 +16,7 @@ use chrono::Duration; use mas_storage::{ clock::MockClock, user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - Repository, + Repository, RepositoryAccess, }; use rand::SeedableRng; use rand_chacha::ChaChaRng; diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index aa1db0af7..69bc2881b 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Interactions with the database +//! Interactions with the storage backend #![forbid(unsafe_code)] #![deny( @@ -42,20 +42,25 @@ pub mod user; pub use self::{ clock::{Clock, SystemClock}, pagination::{Page, Pagination}, - repository::{BoxRepository, Repository, RepositoryError}, + repository::{ + BoxRepository, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction, + }, }; -pub struct MapErr { - inner: Repository, - mapper: Mapper, +/// A wrapper which is used to map the error type of a repository to another +pub struct MapErr { + inner: R, + mapper: F, } -impl MapErr { - fn new(inner: Repository, mapper: Mapper) -> Self { +impl MapErr { + fn new(inner: R, mapper: F) -> Self { Self { inner, mapper } } } +/// A macro to implement a repository trait for the [`MapErr`] wrapper and for +/// [`Box`] #[macro_export] macro_rules! repository_impl { ($repo_trait:ident: diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index d6772b9ad..f023e469b 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use futures_util::{future::BoxFuture, FutureExt, TryFutureExt}; +use futures_util::future::BoxFuture; use thiserror::Error; use crate::{ @@ -32,83 +32,27 @@ use crate::{ MapErr, }; -pub trait Repository: Send { - type Error: std::error::Error + Send + Sync + 'static; +/// A [`Repository`] helps interacting with the underlying storage backend. +pub trait Repository: + RepositoryAccess + RepositoryTransaction + Send +where + E: std::error::Error + Send + Sync + 'static, +{ + /// Construct a (boxed) typed-erased repository + fn boxed(self) -> BoxRepository + where + Self: Sync + Sized + 'static, + { + Box::new(self) + } + /// Map the error type of all the methods of a [`Repository`] fn map_err(self, mapper: Mapper) -> MapErr where Self: Sized, { MapErr::new(self, mapper) } - - fn boxed(self) -> BoxRepository - where - Self: Sized + Sync + 'static, - { - Box::new(self) - } - - fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>>; - fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>>; - - fn upstream_oauth_link<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn upstream_oauth_provider<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn upstream_oauth_session<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn user<'c>(&'c mut self) -> Box + 'c>; - - fn user_email<'c>(&'c mut self) -> Box + 'c>; - - fn user_password<'c>(&'c mut self) - -> Box + 'c>; - - fn browser_session<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn oauth2_client<'c>(&'c mut self) - -> Box + 'c>; - - fn oauth2_authorization_grant<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn oauth2_session<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn oauth2_access_token<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn oauth2_refresh_token<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn compat_session<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn compat_sso_login<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn compat_access_token<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn compat_refresh_token<'c>( - &'c mut self, - ) -> Box + 'c>; } /// An opaque, type-erased error @@ -119,6 +63,7 @@ pub struct RepositoryError { } impl RepositoryError { + /// Construct a [`RepositoryError`] from any error kind pub fn from_error(value: E) -> Self where E: std::error::Error + Send + Sync + 'static, @@ -129,251 +74,386 @@ impl RepositoryError { } } -pub type BoxRepository = - Box + Send + Sync + 'static>; +/// A type-erased [`Repository`] +pub type BoxRepository = Box + Send + Sync + 'static>; -impl Repository for crate::MapErr -where - R: Repository, - R::Error: 'static, - F: FnMut(R::Error) -> E + Send + Sync + 'static, - E: std::error::Error + Send + Sync + 'static, -{ - type Error = E; +/// A [`RepositoryTransaction`] can be saved or cancelled, after a series +/// of operations. +pub trait RepositoryTransaction { + /// The error type used by the [`Self::save`] and [`Self::cancel`] functions + type Error; - fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { - Box::new(self.inner).save().map_err(self.mapper).boxed() - } + /// Commit the transaction + /// + /// # Errors + /// + /// Returns an error if the underlying storage backend failed to commit the + /// transaction. + fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>>; - fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { - Box::new(self.inner).cancel().map_err(self.mapper).boxed() - } - - fn upstream_oauth_link<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new( - self.inner.upstream_oauth_link(), - &mut self.mapper, - )) - } - - fn upstream_oauth_provider<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new( - self.inner.upstream_oauth_provider(), - &mut self.mapper, - )) - } - - fn upstream_oauth_session<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new( - self.inner.upstream_oauth_session(), - &mut self.mapper, - )) - } - - fn user<'c>(&'c mut self) -> Box + 'c> { - Box::new(MapErr::new(self.inner.user(), &mut self.mapper)) - } - - fn user_email<'c>(&'c mut self) -> Box + 'c> { - Box::new(MapErr::new(self.inner.user_email(), &mut self.mapper)) - } - - fn user_password<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new(self.inner.user_password(), &mut self.mapper)) - } - - fn browser_session<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new(self.inner.browser_session(), &mut self.mapper)) - } - - fn oauth2_client<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new(self.inner.oauth2_client(), &mut self.mapper)) - } - - fn oauth2_authorization_grant<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new( - self.inner.oauth2_authorization_grant(), - &mut self.mapper, - )) - } - - fn oauth2_session<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new(self.inner.oauth2_session(), &mut self.mapper)) - } - - fn oauth2_access_token<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new( - self.inner.oauth2_access_token(), - &mut self.mapper, - )) - } - - fn oauth2_refresh_token<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new( - self.inner.oauth2_refresh_token(), - &mut self.mapper, - )) - } - - fn compat_session<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new(self.inner.compat_session(), &mut self.mapper)) - } - - fn compat_sso_login<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new(self.inner.compat_sso_login(), &mut self.mapper)) - } - - fn compat_access_token<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new( - self.inner.compat_access_token(), - &mut self.mapper, - )) - } - - fn compat_refresh_token<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new( - self.inner.compat_refresh_token(), - &mut self.mapper, - )) - } + /// Rollback the transaction + /// + /// # Errors + /// + /// Returns an error if the underlying storage backend failed to rollback + /// the transaction. + fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>>; } -impl Repository for Box { - type Error = R::Error; - - fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> - where - Self: Sized, - { - // This shouldn't be callable? - unimplemented!() - } - - fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> - where - Self: Sized, - { - // This shouldn't be callable? - unimplemented!() - } +/// Access the various repositories the backend implements. +pub trait RepositoryAccess: Send { + /// The backend-specific error type used by each repository. + type Error: std::error::Error + Send + Sync + 'static; + /// Get an [`UpstreamOAuthLinkRepository`] fn upstream_oauth_link<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).upstream_oauth_link() - } + ) -> Box + 'c>; + /// Get an [`UpstreamOAuthProviderRepository`] fn upstream_oauth_provider<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).upstream_oauth_provider() - } + ) -> Box + 'c>; + /// Get an [`UpstreamOAuthSessionRepository`] fn upstream_oauth_session<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).upstream_oauth_session() - } + ) -> Box + 'c>; - fn user<'c>(&'c mut self) -> Box + 'c> { - (**self).user() - } + /// Get an [`UserRepository`] + fn user<'c>(&'c mut self) -> Box + 'c>; - fn user_email<'c>(&'c mut self) -> Box + 'c> { - (**self).user_email() - } + /// Get an [`UserEmailRepository`] + fn user_email<'c>(&'c mut self) -> Box + 'c>; - fn user_password<'c>( - &'c mut self, - ) -> Box + 'c> { - (**self).user_password() - } + /// Get an [`UserPasswordRepository`] + fn user_password<'c>(&'c mut self) + -> Box + 'c>; + /// Get a [`BrowserSessionRepository`] fn browser_session<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).browser_session() - } + ) -> Box + 'c>; - fn oauth2_client<'c>( - &'c mut self, - ) -> Box + 'c> { - (**self).oauth2_client() - } + /// Get an [`OAuth2ClientRepository`] + fn oauth2_client<'c>(&'c mut self) + -> Box + 'c>; + /// Get an [`OAuth2AuthorizationGrantRepository`] fn oauth2_authorization_grant<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).oauth2_authorization_grant() - } + ) -> Box + 'c>; + /// Get an [`OAuth2SessionRepository`] fn oauth2_session<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).oauth2_session() - } + ) -> Box + 'c>; + /// Get an [`OAuth2AccessTokenRepository`] fn oauth2_access_token<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).oauth2_access_token() - } + ) -> Box + 'c>; + /// Get an [`OAuth2RefreshTokenRepository`] fn oauth2_refresh_token<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).oauth2_refresh_token() - } + ) -> Box + 'c>; + /// Get a [`CompatSessionRepository`] fn compat_session<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).compat_session() - } + ) -> Box + 'c>; + /// Get a [`CompatSsoLoginRepository`] fn compat_sso_login<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).compat_sso_login() - } + ) -> Box + 'c>; + /// Get a [`CompatAccessTokenRepository`] fn compat_access_token<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).compat_access_token() - } + ) -> Box + 'c>; + /// Get a [`CompatRefreshTokenRepository`] fn compat_refresh_token<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).compat_refresh_token() + ) -> Box + 'c>; +} + +/// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and +/// [`Repository`] for the [`MapErr`] wrapper and [`Box`] +mod impls { + use futures_util::{future::BoxFuture, FutureExt, TryFutureExt}; + + use super::RepositoryAccess; + use crate::{ + compat::{ + CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, + CompatSsoLoginRepository, + }, + oauth2::{ + OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, + OAuth2ClientRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, + }, + upstream_oauth2::{ + UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, + UpstreamOAuthSessionRepository, + }, + user::{ + BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository, + }, + MapErr, Repository, RepositoryTransaction, + }; + + // --- Repository --- + impl Repository for MapErr + where + R: Repository + RepositoryAccess + RepositoryTransaction, + F: FnMut(E1) -> E2 + Send + Sync + 'static, + E1: std::error::Error + Send + Sync + 'static, + E2: std::error::Error + Send + Sync + 'static, + { + } + + // --- RepositoryTransaction -- + impl RepositoryTransaction for MapErr + where + R: RepositoryTransaction, + R::Error: 'static, + F: FnMut(R::Error) -> E + Send + Sync + 'static, + E: std::error::Error, + { + type Error = E; + + fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { + Box::new(self.inner).save().map_err(self.mapper).boxed() + } + + fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { + Box::new(self.inner).cancel().map_err(self.mapper).boxed() + } + } + + // --- RepositoryAccess -- + impl RepositoryAccess for MapErr + where + R: RepositoryAccess, + R::Error: 'static, + F: FnMut(R::Error) -> E + Send + Sync + 'static, + E: std::error::Error + Send + Sync + 'static, + { + type Error = E; + + fn upstream_oauth_link<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.upstream_oauth_link(), + &mut self.mapper, + )) + } + + fn upstream_oauth_provider<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.upstream_oauth_provider(), + &mut self.mapper, + )) + } + + fn upstream_oauth_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.upstream_oauth_session(), + &mut self.mapper, + )) + } + + fn user<'c>(&'c mut self) -> Box + 'c> { + Box::new(MapErr::new(self.inner.user(), &mut self.mapper)) + } + + fn user_email<'c>(&'c mut self) -> Box + 'c> { + Box::new(MapErr::new(self.inner.user_email(), &mut self.mapper)) + } + + fn user_password<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.user_password(), &mut self.mapper)) + } + + fn browser_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.browser_session(), &mut self.mapper)) + } + + fn oauth2_client<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.oauth2_client(), &mut self.mapper)) + } + + fn oauth2_authorization_grant<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.oauth2_authorization_grant(), + &mut self.mapper, + )) + } + + fn oauth2_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.oauth2_session(), &mut self.mapper)) + } + + fn oauth2_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.oauth2_access_token(), + &mut self.mapper, + )) + } + + fn oauth2_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.oauth2_refresh_token(), + &mut self.mapper, + )) + } + + fn compat_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.compat_session(), &mut self.mapper)) + } + + fn compat_sso_login<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.compat_sso_login(), &mut self.mapper)) + } + + fn compat_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.compat_access_token(), + &mut self.mapper, + )) + } + + fn compat_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.compat_refresh_token(), + &mut self.mapper, + )) + } + } + + impl RepositoryAccess for Box { + type Error = R::Error; + + fn upstream_oauth_link<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).upstream_oauth_link() + } + + fn upstream_oauth_provider<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).upstream_oauth_provider() + } + + fn upstream_oauth_session<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).upstream_oauth_session() + } + + fn user<'c>(&'c mut self) -> Box + 'c> { + (**self).user() + } + + fn user_email<'c>(&'c mut self) -> Box + 'c> { + (**self).user_email() + } + + fn user_password<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).user_password() + } + + fn browser_session<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).browser_session() + } + + fn oauth2_client<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_client() + } + + fn oauth2_authorization_grant<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_authorization_grant() + } + + fn oauth2_session<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_session() + } + + fn oauth2_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_access_token() + } + + fn oauth2_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_refresh_token() + } + + fn compat_session<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).compat_session() + } + + fn compat_sso_login<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).compat_sso_login() + } + + fn compat_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).compat_access_token() + } + + fn compat_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).compat_refresh_token() + } } } diff --git a/crates/tasks/src/database.rs b/crates/tasks/src/database.rs index e7947ce25..ebade53af 100644 --- a/crates/tasks/src/database.rs +++ b/crates/tasks/src/database.rs @@ -14,7 +14,7 @@ //! Database-related tasks -use mas_storage::{oauth2::OAuth2AccessTokenRepository, Repository, SystemClock}; +use mas_storage::{oauth2::OAuth2AccessTokenRepository, RepositoryAccess, SystemClock}; use mas_storage_pg::PgRepository; use sqlx::{Pool, Postgres}; use tracing::{debug, error, info}; From 2e4b1c5492f37a65215f986a2715339d581a5aad Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 25 Jan 2023 16:09:36 +0100 Subject: [PATCH 37/45] storage: document all the repository traits and methods --- crates/storage-pg/src/compat/access_token.rs | 4 + crates/storage-pg/src/compat/mod.rs | 3 + crates/storage-pg/src/compat/refresh_token.rs | 4 + crates/storage-pg/src/compat/session.rs | 3 + crates/storage-pg/src/compat/sso_login.rs | 4 + crates/storage-pg/src/lib.rs | 48 ++++-- crates/storage-pg/src/oauth2/access_token.rs | 4 + .../src/oauth2/authorization_grant.rs | 4 + crates/storage-pg/src/oauth2/client.rs | 3 + crates/storage-pg/src/oauth2/mod.rs | 5 +- crates/storage-pg/src/oauth2/refresh_token.rs | 4 + crates/storage-pg/src/oauth2/session.rs | 3 + crates/storage-pg/src/repository.rs | 8 + crates/storage-pg/src/tracing.rs | 2 + crates/storage-pg/src/upstream_oauth2/link.rs | 4 + crates/storage-pg/src/upstream_oauth2/mod.rs | 5 + .../src/upstream_oauth2/provider.rs | 4 + .../storage-pg/src/upstream_oauth2/session.rs | 4 + crates/storage-pg/src/user/email.rs | 3 + crates/storage-pg/src/user/mod.rs | 5 + crates/storage-pg/src/user/password.rs | 3 + crates/storage-pg/src/user/session.rs | 4 + crates/storage-pg/src/user/tests.rs | 1 + crates/storage/src/compat/access_token.rs | 45 +++++ crates/storage/src/compat/mod.rs | 2 + crates/storage/src/compat/refresh_token.rs | 46 ++++++ crates/storage/src/compat/session.rs | 37 +++++ crates/storage/src/compat/sso_login.rs | 69 ++++++++ crates/storage/src/lib.rs | 12 +- crates/storage/src/oauth2/access_token.rs | 58 +++++++ .../storage/src/oauth2/authorization_grant.rs | 87 ++++++++++ crates/storage/src/oauth2/client.rs | 102 ++++++++++++ crates/storage/src/oauth2/mod.rs | 4 +- crates/storage/src/oauth2/refresh_token.rs | 50 ++++++ crates/storage/src/oauth2/session.rs | 51 ++++++ crates/storage/src/pagination.rs | 32 ++++ crates/storage/src/upstream_oauth2/link.rs | 59 +++++++ crates/storage/src/upstream_oauth2/mod.rs | 3 + .../storage/src/upstream_oauth2/provider.rs | 45 +++++ crates/storage/src/upstream_oauth2/session.rs | 56 +++++++ crates/storage/src/user/email.rs | 154 ++++++++++++++++++ crates/storage/src/user/mod.rs | 54 ++++++ crates/storage/src/user/password.rs | 31 ++++ crates/storage/src/user/session.rs | 91 +++++++++++ 44 files changed, 1202 insertions(+), 18 deletions(-) diff --git a/crates/storage-pg/src/compat/access_token.rs b/crates/storage-pg/src/compat/access_token.rs index 822c3a8af..70fabac79 100644 --- a/crates/storage-pg/src/compat/access_token.rs +++ b/crates/storage-pg/src/compat/access_token.rs @@ -23,11 +23,15 @@ use uuid::Uuid; use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; +/// An implementation of [`CompatAccessTokenRepository`] for a PostgreSQL +/// connection pub struct PgCompatAccessTokenRepository<'c> { conn: &'c mut PgConnection, } impl<'c> PgCompatAccessTokenRepository<'c> { + /// Create a new [`PgCompatAccessTokenRepository`] from an active PostgreSQL + /// connection pub fn new(conn: &'c mut PgConnection) -> Self { Self { conn } } diff --git a/crates/storage-pg/src/compat/mod.rs b/crates/storage-pg/src/compat/mod.rs index 5d99f3329..ae2a40b25 100644 --- a/crates/storage-pg/src/compat/mod.rs +++ b/crates/storage-pg/src/compat/mod.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! A module containing PostgreSQL implementation of repositories for the +//! compatibility layer + mod access_token; mod refresh_token; mod session; diff --git a/crates/storage-pg/src/compat/refresh_token.rs b/crates/storage-pg/src/compat/refresh_token.rs index 991e14381..0811119a1 100644 --- a/crates/storage-pg/src/compat/refresh_token.rs +++ b/crates/storage-pg/src/compat/refresh_token.rs @@ -25,11 +25,15 @@ use uuid::Uuid; use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; +/// An implementation of [`CompatRefreshTokenRepository`] for a PostgreSQL +/// connection pub struct PgCompatRefreshTokenRepository<'c> { conn: &'c mut PgConnection, } impl<'c> PgCompatRefreshTokenRepository<'c> { + /// Create a new [`PgCompatRefreshTokenRepository`] from an active + /// PostgreSQL connection pub fn new(conn: &'c mut PgConnection) -> Self { Self { conn } } diff --git a/crates/storage-pg/src/compat/session.rs b/crates/storage-pg/src/compat/session.rs index 16208b4f0..283a9a598 100644 --- a/crates/storage-pg/src/compat/session.rs +++ b/crates/storage-pg/src/compat/session.rs @@ -23,11 +23,14 @@ use uuid::Uuid; use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; +/// An implementation of [`CompatSessionRepository`] for a PostgreSQL connection pub struct PgCompatSessionRepository<'c> { conn: &'c mut PgConnection, } impl<'c> PgCompatSessionRepository<'c> { + /// Create a new [`PgCompatSessionRepository`] from an active PostgreSQL + /// connection pub fn new(conn: &'c mut PgConnection) -> Self { Self { conn } } diff --git a/crates/storage-pg/src/compat/sso_login.rs b/crates/storage-pg/src/compat/sso_login.rs index 1b8e0225f..ae9ca083c 100644 --- a/crates/storage-pg/src/compat/sso_login.rs +++ b/crates/storage-pg/src/compat/sso_login.rs @@ -27,11 +27,15 @@ use crate::{ LookupResultExt, }; +/// An implementation of [`CompatSsoLoginRepository`] for a PostgreSQL +/// connection pub struct PgCompatSsoLoginRepository<'c> { conn: &'c mut PgConnection, } impl<'c> PgCompatSsoLoginRepository<'c> { + /// Create a new [`PgCompatSsoLoginRepository`] from an active PostgreSQL + /// connection pub fn new(conn: &'c mut PgConnection) -> Self { Self { conn } } diff --git a/crates/storage-pg/src/lib.rs b/crates/storage-pg/src/lib.rs index 08c89db2e..046cc4b8a 100644 --- a/crates/storage-pg/src/lib.rs +++ b/crates/storage-pg/src/lib.rs @@ -13,25 +13,29 @@ // limitations under the License. //! An implementation of the storage traits for a PostgreSQL database +//! +//! This backend uses [`sqlx`] to interact with the database. Most queries are +//! type-checked, using introspection data recorded in the `sqlx-data.json` +//! file. This file is generated by the `sqlx` CLI tool, and should be updated +//! whenever the database schema changes, or new queries are added. #![forbid(unsafe_code)] #![deny( clippy::all, clippy::str_to_string, clippy::future_not_send, - rustdoc::broken_intra_doc_links + rustdoc::broken_intra_doc_links, + missing_docs )] #![warn(clippy::pedantic)] -#![allow( - clippy::missing_errors_doc, - clippy::missing_panics_doc, - clippy::module_name_repetitions -)] +#![allow(clippy::module_name_repetitions)] use sqlx::{migrate::Migrator, postgres::PgQueryResult}; use thiserror::Error; use ulid::Ulid; +/// An extension trait for [`Result`] which adds a [`to_option`] method, useful +/// for handling "not found" errors from [`sqlx`] trait LookupResultExt { type Output; @@ -57,7 +61,11 @@ impl LookupResultExt for Result { #[error(transparent)] pub enum DatabaseError { /// An error which came from the database itself - Driver(#[from] sqlx::Error), + Driver { + /// The underlying error from the database driver + #[from] + source: sqlx::Error, + }, /// An error which occured while converting the data from the database Inconsistency(#[from] DatabaseInconsistencyError), @@ -66,6 +74,7 @@ pub enum DatabaseError { /// invalid #[error("Invalid database operation")] InvalidOperation { + /// The source of the error, if any #[source] source: Option>, }, @@ -73,7 +82,13 @@ pub enum DatabaseError { /// An error which happens when an operation affects not enough or too many /// rows #[error("Expected {expected} rows to be affected, but {actual} rows were affected")] - RowsAffected { expected: u64, actual: u64 }, + RowsAffected { + /// How many rows were expected to be affected + expected: u64, + + /// How many rows were actually affected + actual: u64, + }, } impl DatabaseError { @@ -100,12 +115,19 @@ impl DatabaseError { } } +/// An error which occured while converting the data from the database #[derive(Debug, Error)] pub struct DatabaseInconsistencyError { + /// The table which was being queried table: &'static str, + + /// The column which was being queried column: Option<&'static str>, + + /// The row which was being queried row: Option, + /// The source of the error #[source] source: Option>, } @@ -125,6 +147,7 @@ impl std::fmt::Display for DatabaseInconsistencyError { } impl DatabaseInconsistencyError { + /// Create a new [`DatabaseInconsistencyError`] for the given table #[must_use] pub(crate) const fn on(table: &'static str) -> Self { Self { @@ -135,18 +158,22 @@ impl DatabaseInconsistencyError { } } + /// Set the column which was being queried #[must_use] pub(crate) const fn column(mut self, column: &'static str) -> Self { self.column = Some(column); self } + /// Set the row which was being queried #[must_use] pub(crate) const fn row(mut self, row: Ulid) -> Self { self.row = Some(row); self } + /// Give the source of the error + #[must_use] pub(crate) fn source( mut self, source: E, @@ -158,11 +185,12 @@ impl DatabaseInconsistencyError { pub mod compat; pub mod oauth2; +pub mod upstream_oauth2; +pub mod user; + pub(crate) mod pagination; pub(crate) mod repository; pub(crate) mod tracing; -pub mod upstream_oauth2; -pub mod user; pub use self::repository::PgRepository; diff --git a/crates/storage-pg/src/oauth2/access_token.rs b/crates/storage-pg/src/oauth2/access_token.rs index ecd5798b0..e809fa53c 100644 --- a/crates/storage-pg/src/oauth2/access_token.rs +++ b/crates/storage-pg/src/oauth2/access_token.rs @@ -23,11 +23,15 @@ use uuid::Uuid; use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; +/// An implementation of [`OAuth2AccessTokenRepository`] for a PostgreSQL +/// connection pub struct PgOAuth2AccessTokenRepository<'c> { conn: &'c mut PgConnection, } impl<'c> PgOAuth2AccessTokenRepository<'c> { + /// Create a new [`PgOAuth2AccessTokenRepository`] from an active PostgreSQL + /// connection pub fn new(conn: &'c mut PgConnection) -> Self { Self { conn } } diff --git a/crates/storage-pg/src/oauth2/authorization_grant.rs b/crates/storage-pg/src/oauth2/authorization_grant.rs index 92116a62d..f62edae30 100644 --- a/crates/storage-pg/src/oauth2/authorization_grant.rs +++ b/crates/storage-pg/src/oauth2/authorization_grant.rs @@ -30,11 +30,15 @@ use uuid::Uuid; use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; +/// An implementation of [`OAuth2AuthorizationGrantRepository`] for a PostgreSQL +/// connection pub struct PgOAuth2AuthorizationGrantRepository<'c> { conn: &'c mut PgConnection, } impl<'c> PgOAuth2AuthorizationGrantRepository<'c> { + /// Create a new [`PgOAuth2AuthorizationGrantRepository`] from an active + /// PostgreSQL connection pub fn new(conn: &'c mut PgConnection) -> Self { Self { conn } } diff --git a/crates/storage-pg/src/oauth2/client.rs b/crates/storage-pg/src/oauth2/client.rs index e17245aad..cc2ed8b86 100644 --- a/crates/storage-pg/src/oauth2/client.rs +++ b/crates/storage-pg/src/oauth2/client.rs @@ -39,11 +39,14 @@ use uuid::Uuid; use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; +/// An implementation of [`OAuth2ClientRepository`] for a PostgreSQL connection pub struct PgOAuth2ClientRepository<'c> { conn: &'c mut PgConnection, } impl<'c> PgOAuth2ClientRepository<'c> { + /// Create a new [`PgOAuth2ClientRepository`] from an active PostgreSQL + /// connection pub fn new(conn: &'c mut PgConnection) -> Self { Self { conn } } diff --git a/crates/storage-pg/src/oauth2/mod.rs b/crates/storage-pg/src/oauth2/mod.rs index edad2beb0..3e4961417 100644 --- a/crates/storage-pg/src/oauth2/mod.rs +++ b/crates/storage-pg/src/oauth2/mod.rs @@ -12,8 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! A module containing the PostgreSQL implementations of the OAuth2-related +//! repositories + mod access_token; -pub mod authorization_grant; +mod authorization_grant; mod client; mod refresh_token; mod session; diff --git a/crates/storage-pg/src/oauth2/refresh_token.rs b/crates/storage-pg/src/oauth2/refresh_token.rs index ba2fa5334..ae723f7c3 100644 --- a/crates/storage-pg/src/oauth2/refresh_token.rs +++ b/crates/storage-pg/src/oauth2/refresh_token.rs @@ -23,11 +23,15 @@ use uuid::Uuid; use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; +/// An implementation of [`OAuth2RefreshTokenRepository`] for a PostgreSQL +/// connection pub struct PgOAuth2RefreshTokenRepository<'c> { conn: &'c mut PgConnection, } impl<'c> PgOAuth2RefreshTokenRepository<'c> { + /// Create a new [`PgOAuth2RefreshTokenRepository`] from an active + /// PostgreSQL connection pub fn new(conn: &'c mut PgConnection) -> Self { Self { conn } } diff --git a/crates/storage-pg/src/oauth2/session.rs b/crates/storage-pg/src/oauth2/session.rs index 891c2278f..aa667f252 100644 --- a/crates/storage-pg/src/oauth2/session.rs +++ b/crates/storage-pg/src/oauth2/session.rs @@ -26,11 +26,14 @@ use crate::{ LookupResultExt, }; +/// An implementation of [`OAuth2SessionRepository`] for a PostgreSQL connection pub struct PgOAuth2SessionRepository<'c> { conn: &'c mut PgConnection, } impl<'c> PgOAuth2SessionRepository<'c> { + /// Create a new [`PgOAuth2SessionRepository`] from an active PostgreSQL + /// connection pub fn new(conn: &'c mut PgConnection) -> Self { Self { conn } } diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index 0f1fdfb4f..da81d3af4 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -51,11 +51,19 @@ use crate::{ DatabaseError, }; +/// An implementation of the [`Repository`] trait backed by a PostgreSQL +/// transaction. pub struct PgRepository { txn: Transaction<'static, Postgres>, } impl PgRepository { + /// Create a new [`PgRepository`] from a PostgreSQL connection pool, + /// starting a transaction. + /// + /// # Errors + /// + /// Returns a [`DatabaseError`] if the transaction could not be started. pub async fn from_pool(pool: &PgPool) -> Result { let txn = pool.begin().await?; Ok(PgRepository { txn }) diff --git a/crates/storage-pg/src/tracing.rs b/crates/storage-pg/src/tracing.rs index 1210816c5..b0bc0b7ff 100644 --- a/crates/storage-pg/src/tracing.rs +++ b/crates/storage-pg/src/tracing.rs @@ -14,6 +14,8 @@ use tracing::Span; +/// An extension trait for [`sqlx::Execute`] that records the SQL statement as +/// `db.statement` in a tracing span pub trait ExecuteExt<'q, DB>: Sized { /// Records the statement as `db.statement` in the current span fn traced(self) -> Self { diff --git a/crates/storage-pg/src/upstream_oauth2/link.rs b/crates/storage-pg/src/upstream_oauth2/link.rs index c38b344b2..0e14f3fd5 100644 --- a/crates/storage-pg/src/upstream_oauth2/link.rs +++ b/crates/storage-pg/src/upstream_oauth2/link.rs @@ -23,11 +23,15 @@ use uuid::Uuid; use crate::{pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, LookupResultExt}; +/// An implementation of [`UpstreamOAuthLinkRepository`] for a PostgreSQL +/// connection pub struct PgUpstreamOAuthLinkRepository<'c> { conn: &'c mut PgConnection, } impl<'c> PgUpstreamOAuthLinkRepository<'c> { + /// Create a new [`PgUpstreamOAuthLinkRepository`] from an active PostgreSQL + /// connection pub fn new(conn: &'c mut PgConnection) -> Self { Self { conn } } diff --git a/crates/storage-pg/src/upstream_oauth2/mod.rs b/crates/storage-pg/src/upstream_oauth2/mod.rs index 9ff7699eb..5bf97514f 100644 --- a/crates/storage-pg/src/upstream_oauth2/mod.rs +++ b/crates/storage-pg/src/upstream_oauth2/mod.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! A module containing the PostgreSQL implementation of the repositories +//! related to the upstream OAuth 2.0 providers + mod link; mod provider; mod session; @@ -178,6 +181,8 @@ mod tests { assert_eq!(links.edges[0].user_id, Some(user.id)); } + /// Test that the pagination works as expected in the upstream OAuth + /// provider repository #[sqlx::test(migrator = "crate::MIGRATOR")] async fn test_provider_repository_pagination(pool: PgPool) { const ISSUER: &str = "https://example.com/"; diff --git a/crates/storage-pg/src/upstream_oauth2/provider.rs b/crates/storage-pg/src/upstream_oauth2/provider.rs index dc1b0c85b..d4ecbe473 100644 --- a/crates/storage-pg/src/upstream_oauth2/provider.rs +++ b/crates/storage-pg/src/upstream_oauth2/provider.rs @@ -28,11 +28,15 @@ use crate::{ LookupResultExt, }; +/// An implementation of [`UpstreamOAuthProviderRepository`] for a PostgreSQL +/// connection pub struct PgUpstreamOAuthProviderRepository<'c> { conn: &'c mut PgConnection, } impl<'c> PgUpstreamOAuthProviderRepository<'c> { + /// Create a new [`PgUpstreamOAuthProviderRepository`] from an active + /// PostgreSQL connection pub fn new(conn: &'c mut PgConnection) -> Self { Self { conn } } diff --git a/crates/storage-pg/src/upstream_oauth2/session.rs b/crates/storage-pg/src/upstream_oauth2/session.rs index 3cdef0c73..5780ab8d3 100644 --- a/crates/storage-pg/src/upstream_oauth2/session.rs +++ b/crates/storage-pg/src/upstream_oauth2/session.rs @@ -26,11 +26,15 @@ use uuid::Uuid; use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; +/// An implementation of [`UpstreamOAuthSessionRepository`] for a PostgreSQL +/// connection pub struct PgUpstreamOAuthSessionRepository<'c> { conn: &'c mut PgConnection, } impl<'c> PgUpstreamOAuthSessionRepository<'c> { + /// Create a new [`PgUpstreamOAuthSessionRepository`] from an active + /// PostgreSQL connection pub fn new(conn: &'c mut PgConnection) -> Self { Self { conn } } diff --git a/crates/storage-pg/src/user/email.rs b/crates/storage-pg/src/user/email.rs index b9d732dec..28b9d3951 100644 --- a/crates/storage-pg/src/user/email.rs +++ b/crates/storage-pg/src/user/email.rs @@ -27,11 +27,14 @@ use crate::{ LookupResultExt, }; +/// An implementation of [`UserEmailRepository`] for a PostgreSQL connection pub struct PgUserEmailRepository<'c> { conn: &'c mut PgConnection, } impl<'c> PgUserEmailRepository<'c> { + /// Create a new [`PgUserEmailRepository`] from an active PostgreSQL + /// connection pub fn new(conn: &'c mut PgConnection) -> Self { Self { conn } } diff --git a/crates/storage-pg/src/user/mod.rs b/crates/storage-pg/src/user/mod.rs index 8ec6170f5..0554c8b25 100644 --- a/crates/storage-pg/src/user/mod.rs +++ b/crates/storage-pg/src/user/mod.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! A module containing the PostgreSQL implementation of the user-related +//! repositories + use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::User; @@ -35,11 +38,13 @@ pub use self::{ session::PgBrowserSessionRepository, }; +/// An implementation of [`UserRepository`] for a PostgreSQL connection pub struct PgUserRepository<'c> { conn: &'c mut PgConnection, } impl<'c> PgUserRepository<'c> { + /// Create a new [`PgUserRepository`] from an active PostgreSQL connection pub fn new(conn: &'c mut PgConnection) -> Self { Self { conn } } diff --git a/crates/storage-pg/src/user/password.rs b/crates/storage-pg/src/user/password.rs index 696b30e77..1dfd90d1c 100644 --- a/crates/storage-pg/src/user/password.rs +++ b/crates/storage-pg/src/user/password.rs @@ -23,11 +23,14 @@ use uuid::Uuid; use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; +/// An implementation of [`UserPasswordRepository`] for a PostgreSQL connection pub struct PgUserPasswordRepository<'c> { conn: &'c mut PgConnection, } impl<'c> PgUserPasswordRepository<'c> { + /// Create a new [`PgUserPasswordRepository`] from an active PostgreSQL + /// connection pub fn new(conn: &'c mut PgConnection) -> Self { Self { conn } } diff --git a/crates/storage-pg/src/user/session.rs b/crates/storage-pg/src/user/session.rs index e5616fffa..ff91726ea 100644 --- a/crates/storage-pg/src/user/session.rs +++ b/crates/storage-pg/src/user/session.rs @@ -26,11 +26,15 @@ use crate::{ LookupResultExt, }; +/// An implementation of [`BrowserSessionRepository`] for a PostgreSQL +/// connection pub struct PgBrowserSessionRepository<'c> { conn: &'c mut PgConnection, } impl<'c> PgBrowserSessionRepository<'c> { + /// Create a new [`PgBrowserSessionRepository`] from an active PostgreSQL + /// connection pub fn new(conn: &'c mut PgConnection) -> Self { Self { conn } } diff --git a/crates/storage-pg/src/user/tests.rs b/crates/storage-pg/src/user/tests.rs index 29f828aba..29ebe19f6 100644 --- a/crates/storage-pg/src/user/tests.rs +++ b/crates/storage-pg/src/user/tests.rs @@ -253,6 +253,7 @@ async fn test_user_email_repo(pool: PgPool) { repo.save().await.unwrap(); } +/// Test the user password repository implementation. #[sqlx::test(migrator = "crate::MIGRATOR")] async fn test_user_password_repo(pool: PgPool) { const USERNAME: &str = "john"; diff --git a/crates/storage/src/compat/access_token.rs b/crates/storage/src/compat/access_token.rs index c6d4eb7fe..c6d3979ee 100644 --- a/crates/storage/src/compat/access_token.rs +++ b/crates/storage/src/compat/access_token.rs @@ -20,20 +20,58 @@ use ulid::Ulid; use crate::{repository_impl, Clock}; +/// A [`CompatAccessTokenRepository`] helps interacting with +/// [`CompatAccessToken`] saved in the storage backend #[async_trait] pub trait CompatAccessTokenRepository: Send + Sync { + /// The error type returned by the repository type Error; /// Lookup a compat access token by its ID + /// + /// Returns the compat access token if it exists, `None` otherwise + /// + /// # Parameters + /// + /// * `id`: The ID of the compat access token to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; /// Find a compat access token by its token + /// + /// Returns the compat access token if found, `None` otherwise + /// + /// # Parameters + /// + /// * `access_token`: The token of the compat access token to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn find_by_token( &mut self, access_token: &str, ) -> Result, Self::Error>; /// Add a new compat access token to the database + /// + /// Returns the newly created compat access token + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `compat_session`: The compat session associated with the access token + /// * `token`: The token of the access token + /// * `expires_after`: The duration after which the access token expires, if + /// specified + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn add( &mut self, rng: &mut (dyn RngCore + Send), @@ -44,6 +82,13 @@ pub trait CompatAccessTokenRepository: Send + Sync { ) -> Result; /// Set the expiration time of the compat access token to now + /// + /// Returns the expired compat access token + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `compat_access_token`: The compat access token to expire async fn expire( &mut self, clock: &dyn Clock, diff --git a/crates/storage/src/compat/mod.rs b/crates/storage/src/compat/mod.rs index 634c04a7f..eb971edd1 100644 --- a/crates/storage/src/compat/mod.rs +++ b/crates/storage/src/compat/mod.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Repositories to interact with entities of the compatibility layer + mod access_token; mod refresh_token; mod session; diff --git a/crates/storage/src/compat/refresh_token.rs b/crates/storage/src/compat/refresh_token.rs index 3fd916da9..c9b3aabe4 100644 --- a/crates/storage/src/compat/refresh_token.rs +++ b/crates/storage/src/compat/refresh_token.rs @@ -19,20 +19,55 @@ use ulid::Ulid; use crate::{repository_impl, Clock}; +/// A [`CompatRefreshTokenRepository`] helps interacting with +/// [`CompatRefreshToken`] saved in the storage backend #[async_trait] pub trait CompatRefreshTokenRepository: Send + Sync { + /// The error type returned by the repository type Error; /// Lookup a compat refresh token by its ID + /// + /// Returns the compat refresh token if it exists, `None` otherwise + /// + /// # Parameters + /// + /// * `id`: The ID of the compat refresh token to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; /// Find a compat refresh token by its token + /// + /// Returns the compat refresh token if found, `None` otherwise + /// + /// # Parameters + /// + /// * `refresh_token`: The token of the compat refresh token to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn find_by_token( &mut self, refresh_token: &str, ) -> Result, Self::Error>; /// Add a new compat refresh token to the database + /// + /// Returns the newly created compat refresh token + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `compat_session`: The compat session associated with this refresh + /// token + /// * `compat_access_token`: The compat access token created alongside this + /// refresh token + /// * `token`: The token of the refresh token async fn add( &mut self, rng: &mut (dyn RngCore + Send), @@ -43,6 +78,17 @@ pub trait CompatRefreshTokenRepository: Send + Sync { ) -> Result; /// Consume a compat refresh token + /// + /// Returns the consumed compat refresh token + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `compat_refresh_token`: The compat refresh token to consume + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn consume( &mut self, clock: &dyn Clock, diff --git a/crates/storage/src/compat/session.rs b/crates/storage/src/compat/session.rs index f867a332b..fb9dea73c 100644 --- a/crates/storage/src/compat/session.rs +++ b/crates/storage/src/compat/session.rs @@ -19,14 +19,40 @@ use ulid::Ulid; use crate::{repository_impl, Clock}; +/// A [`CompatSessionRepository`] helps interacting with +/// [`CompatSessionRepository`] saved in the storage backend #[async_trait] pub trait CompatSessionRepository: Send + Sync { + /// The error type returned by the repository type Error; /// Lookup a compat session by its ID + /// + /// Returns the compat session if it exists, `None` otherwise + /// + /// # Parameters + /// + /// * `id`: The ID of the compat session to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; /// Start a new compat session + /// + /// Returns the newly created compat session + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `user`: The user to create the compat session for + /// * `device`: The device ID of this session + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn add( &mut self, rng: &mut (dyn RngCore + Send), @@ -36,6 +62,17 @@ pub trait CompatSessionRepository: Send + Sync { ) -> Result; /// End a compat session + /// + /// Returns the ended compat session + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `compat_session`: The compat session to end + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn finish( &mut self, clock: &dyn Clock, diff --git a/crates/storage/src/compat/sso_login.rs b/crates/storage/src/compat/sso_login.rs index a6fa07357..7c823d620 100644 --- a/crates/storage/src/compat/sso_login.rs +++ b/crates/storage/src/compat/sso_login.rs @@ -20,20 +20,56 @@ use url::Url; use crate::{pagination::Page, repository_impl, Clock, Pagination}; +/// A [`CompatSsoLoginRepository`] helps interacting with +/// [`CompatSsoLoginRepository`] saved in the storage backend #[async_trait] pub trait CompatSsoLoginRepository: Send + Sync { + /// The error type returned by the repository type Error; /// Lookup a compat SSO login by its ID + /// + /// Returns the compat SSO login if it exists, `None` otherwise + /// + /// # Parameters + /// + /// * `id`: The ID of the compat SSO login to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; /// Find a compat SSO login by its login token + /// + /// Returns the compat SSO login if found, `None` otherwise + /// + /// # Parameters + /// + /// * `login_token`: The login token of the compat SSO login to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn find_by_token( &mut self, login_token: &str, ) -> Result, Self::Error>; /// Start a new compat SSO login token + /// + /// Returns the newly created compat SSO login + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate the timestamps + /// * `login_token`: The login token given to the client + /// * `redirect_uri`: The redirect URI given by the client + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn add( &mut self, rng: &mut (dyn RngCore + Send), @@ -43,6 +79,19 @@ pub trait CompatSsoLoginRepository: Send + Sync { ) -> Result; /// Fulfill a compat SSO login by providing a compat session + /// + /// Returns the fulfilled compat SSO login + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate the timestamps + /// * `compat_sso_login`: The compat SSO login to fulfill + /// * `compat_session`: The compat session to associate with the compat SSO + /// login + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn fulfill( &mut self, clock: &dyn Clock, @@ -51,6 +100,17 @@ pub trait CompatSsoLoginRepository: Send + Sync { ) -> Result; /// Mark a compat SSO login as exchanged + /// + /// Returns the exchanged compat SSO login + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate the timestamps + /// * `compat_sso_login`: The compat SSO login to mark as exchanged + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn exchange( &mut self, clock: &dyn Clock, @@ -58,6 +118,15 @@ pub trait CompatSsoLoginRepository: Send + Sync { ) -> Result; /// Get a paginated list of compat SSO logins for a user + /// + /// # Parameters + /// + /// * `user`: The user to get the compat SSO logins for + /// * `pagination`: The pagination parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn list_paginated( &mut self, user: &User, diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index 69bc2881b..cffec045a 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -19,14 +19,11 @@ clippy::all, clippy::str_to_string, clippy::future_not_send, - rustdoc::broken_intra_doc_links + rustdoc::broken_intra_doc_links, + missing_docs )] #![warn(clippy::pedantic)] -#![allow( - clippy::missing_errors_doc, - clippy::missing_panics_doc, - clippy::module_name_repetitions -)] +#![allow(clippy::module_name_repetitions)] use rand_core::CryptoRngCore; @@ -103,5 +100,8 @@ macro_rules! repository_impl { }; } +/// A boxed [`Clock`] pub type BoxClock = Box; + +/// A boxed random number generator pub type BoxRng = Box; diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index 8a5362431..3fba2399d 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -20,20 +20,57 @@ use ulid::Ulid; use crate::{repository_impl, Clock}; +/// An [`OAuth2AccessTokenRepository`] helps interacting with [`AccessToken`] +/// saved in the storage backend #[async_trait] pub trait OAuth2AccessTokenRepository: Send + Sync { + /// The error type returned by the repository type Error; /// Lookup an access token by its ID + /// + /// Returns the access token if it exists, `None` otherwise + /// + /// # Parameters + /// + /// * `id`: The ID of the access token to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; /// Find an access token by its token + /// + /// Returns the access token if it exists, `None` otherwise + /// + /// # Parameters + /// + /// * `access_token`: The token of the access token to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn find_by_token( &mut self, access_token: &str, ) -> Result, Self::Error>; /// Add a new access token to the database + /// + /// Returns the newly created access token + /// + /// # Parameters + /// + /// * `rng`: A random number generator + /// * `clock`: The clock used to generate timestamps + /// * `session`: The session the access token is associated with + /// * `access_token`: The access token to add + /// * `expires_after`: The duration after which the access token expires + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn add( &mut self, rng: &mut (dyn RngCore + Send), @@ -44,6 +81,17 @@ pub trait OAuth2AccessTokenRepository: Send + Sync { ) -> Result; /// Revoke an access token + /// + /// Returns the revoked access token + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `access_token`: The access token to revoke + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn revoke( &mut self, clock: &dyn Clock, @@ -51,6 +99,16 @@ pub trait OAuth2AccessTokenRepository: Send + Sync { ) -> Result; /// Cleanup expired access tokens + /// + /// Returns the number of access tokens that were cleaned up + /// + /// # Parameters + /// + /// * `clock`: The clock used to get the current time + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn cleanup_expired(&mut self, clock: &dyn Clock) -> Result; } diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index 8852f796b..623ea596d 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -23,10 +23,38 @@ use url::Url; use crate::{repository_impl, Clock}; +/// An [`OAuth2AuthorizationGrantRepository`] helps interacting with +/// [`AuthorizationGrant`] saved in the storage backend #[async_trait] pub trait OAuth2AuthorizationGrantRepository: Send + Sync { + /// The error type returned by the repository type Error; + /// Create a new authorization grant + /// + /// Returns the newly created authorization grant + /// + /// # Parameters + /// + /// * `rng`: A random number generator + /// * `clock`: The clock used to generate timestamps + /// * `client`: The client that requested the authorization grant + /// * `redirect_uri`: The redirect URI the client requested + /// * `scope`: The scope the client requested + /// * `code`: The authorization code used by this grant, if the `code` + /// `response_type` was requested + /// * `state`: The state the client sent, if set + /// * `nonce`: The nonce the client sent, if set + /// * `max_age`: The maximum age since the user last authenticated, if asked + /// by the client + /// * `response_mode`: The response mode the client requested + /// * `response_type_id_token`: Whether the `id_token` `response_type` was + /// requested + /// * `requires_consent`: Whether the client explicitly requested consent + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails #[allow(clippy::too_many_arguments)] async fn add( &mut self, @@ -44,11 +72,47 @@ pub trait OAuth2AuthorizationGrantRepository: Send + Sync { requires_consent: bool, ) -> Result; + /// Lookup an authorization grant by its ID + /// + /// Returns the authorization grant if found, `None` otherwise + /// + /// # Parameters + /// + /// * `id`: The ID of the authorization grant to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + /// Find an authorization grant by its code + /// + /// Returns the authorization grant if found, `None` otherwise + /// + /// # Parameters + /// + /// * `code`: The code of the authorization grant to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn find_by_code(&mut self, code: &str) -> Result, Self::Error>; + /// Fulfill an authorization grant, by giving the [`Session`] that it + /// created + /// + /// Returns the updated authorization grant + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `session`: The session that was created using this authorization grant + /// * `authorization_grant`: The authorization grant to fulfill + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn fulfill( &mut self, clock: &dyn Clock, @@ -56,12 +120,35 @@ pub trait OAuth2AuthorizationGrantRepository: Send + Sync { authorization_grant: AuthorizationGrant, ) -> Result; + /// Mark an authorization grant as exchanged + /// + /// Returns the updated authorization grant + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `authorization_grant`: The authorization grant to mark as exchanged + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn exchange( &mut self, clock: &dyn Clock, authorization_grant: AuthorizationGrant, ) -> Result; + /// Unset the `requires_consent` flag on an authorization grant + /// + /// Returns the updated authorization grant + /// + /// # Parameters + /// + /// * `authorization_grant`: The authorization grant to update + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn give_consent( &mut self, authorization_grant: AuthorizationGrant, diff --git a/crates/storage/src/oauth2/client.rs b/crates/storage/src/oauth2/client.rs index 98acaaf7e..18f0108b7 100644 --- a/crates/storage/src/oauth2/client.rs +++ b/crates/storage/src/oauth2/client.rs @@ -25,22 +25,82 @@ use url::Url; use crate::{repository_impl, Clock}; +/// An [`OAuth2ClientRepository`] helps interacting with [`Client`] saved in the +/// storage backend #[async_trait] pub trait OAuth2ClientRepository: Send + Sync { + /// The error type returned by the repository type Error; + /// Lookup an OAuth2 client by its ID + /// + /// Returns `None` if the client does not exist + /// + /// # Parameters + /// + /// * `id`: The ID of the client to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + /// Find an OAuth2 client by its client ID async fn find_by_client_id(&mut self, client_id: &str) -> Result, Self::Error> { let Ok(id) = client_id.parse() else { return Ok(None) }; self.lookup(id).await } + /// Load a batch of OAuth2 clients by their IDs + /// + /// Returns a map of client IDs to clients. If a client does not exist, it + /// is not present in the map. + /// + /// # Parameters + /// + /// * `ids`: The IDs of the clients to load + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn load_batch( &mut self, ids: BTreeSet, ) -> Result, Self::Error>; + /// Add a new OAuth2 client + /// + /// Returns the client that was added + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `redirect_uris`: The list of redirect URIs used by this client + /// * `encrypted_client_secret`: The encrypted client secret, if any + /// * `grant_types`: The list of grant types this client can use + /// * `contacts`: The list of contacts for this client + /// * `client_name`: The human-readable name of this client, if given + /// * `logo_uri`: The URI of the logo of this client, if given + /// * `client_uri`: The URI of a website of this client, if given + /// * `policy_uri`: The URI of the privacy policy of this client, if given + /// * `tos_uri`: The URI of the terms of service of this client, if given + /// * `jwks_uri`: The URI of the JWKS of this client, if given + /// * `jwks`: The JWKS of this client, if given + /// * `id_token_signed_response_alg`: The algorithm used to sign the ID + /// token + /// * `userinfo_signed_response_alg`: The algorithm used to sign the user + /// info. If none, the user info endpoint will not sign the response + /// * `token_endpoint_auth_method`: The authentication method used by this + /// client when calling the token endpoint + /// * `token_endpoint_auth_signing_alg`: The algorithm used to sign the JWT + /// when using the `client_secret_jwt` or `private_key_jwt` authentication + /// methods + /// * `initiate_login_uri`: The URI used to initiate a login, if given + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails #[allow(clippy::too_many_arguments)] async fn add( &mut self, @@ -64,6 +124,24 @@ pub trait OAuth2ClientRepository: Send + Sync { initiate_login_uri: Option, ) -> Result; + /// Add or replace a client from the configuration + /// + /// Returns the client that was added or replaced + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `client_id`: The client ID + /// * `client_auth_method`: The authentication method this client uses + /// * `encrypted_client_secret`: The encrypted client secret, if any + /// * `jwks`: The client JWKS, if any + /// * `jwks_uri`: The client JWKS URI, if any + /// * `redirect_uris`: The list of redirect URIs used by this client + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails #[allow(clippy::too_many_arguments)] async fn add_from_config( &mut self, @@ -77,12 +155,36 @@ pub trait OAuth2ClientRepository: Send + Sync { redirect_uris: Vec, ) -> Result; + /// Get the list of scopes that the user has given consent for the given + /// client + /// + /// # Parameters + /// + /// * `client`: The client to get the consent for + /// * `user`: The user to get the consent for + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn get_consent_for_user( &mut self, client: &Client, user: &User, ) -> Result; + /// Give consent for a set of scopes for the given client and user + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `client`: The client to give the consent for + /// * `user`: The user to give the consent for + /// * `scope`: The scope to give consent for + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn give_consent_for_user( &mut self, rng: &mut (dyn RngCore + Send), diff --git a/crates/storage/src/oauth2/mod.rs b/crates/storage/src/oauth2/mod.rs index eaa5e3172..75823c277 100644 --- a/crates/storage/src/oauth2/mod.rs +++ b/crates/storage/src/oauth2/mod.rs @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Repositories to interact with entities related to the OAuth 2.0 protocol + mod access_token; -pub mod authorization_grant; +mod authorization_grant; mod client; mod refresh_token; mod session; diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index e8ac63ce6..a0e2c44a0 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -19,20 +19,58 @@ use ulid::Ulid; use crate::{repository_impl, Clock}; +/// An [`OAuth2RefreshTokenRepository`] helps interacting with [`RefreshToken`] +/// saved in the storage backend #[async_trait] pub trait OAuth2RefreshTokenRepository: Send + Sync { + /// The error type returned by the repository type Error; /// Lookup a refresh token by its ID + /// + /// Returns `None` if no [`RefreshToken`] was found + /// + /// # Parameters + /// + /// * `id`: The ID of the [`RefreshToken`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; /// Find a refresh token by its token + /// + /// Returns `None` if no [`RefreshToken`] was found + /// + /// # Parameters + /// + /// * `token`: The token of the [`RefreshToken`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn find_by_token( &mut self, refresh_token: &str, ) -> Result, Self::Error>; /// Add a new refresh token to the database + /// + /// Returns the newly created [`RefreshToken`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `session`: The [`Session`] in which to create the [`RefreshToken`] + /// * `access_token`: The [`AccessToken`] created alongside this + /// [`RefreshToken`] + /// * `refresh_token`: The refresh token to store + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn add( &mut self, rng: &mut (dyn RngCore + Send), @@ -43,6 +81,18 @@ pub trait OAuth2RefreshTokenRepository: Send + Sync { ) -> Result; /// Consume a refresh token + /// + /// Returns the updated [`RefreshToken`] + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `refresh_token`: The [`RefreshToken`] to consume + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails, or if the + /// token was already consumed async fn consume( &mut self, clock: &dyn Clock, diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index f348d9e68..880992a67 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -19,12 +19,41 @@ use ulid::Ulid; use crate::{pagination::Page, repository_impl, Clock, Pagination}; +/// An [`OAuth2SessionRepository`] helps interacting with [`Session`] +/// saved in the storage backend #[async_trait] pub trait OAuth2SessionRepository: Send + Sync { + /// The error type returned by the repository type Error; + /// Lookup an [`Session`] by its ID + /// + /// Returns `None` if no [`Session`] was found + /// + /// # Parameters + /// + /// * `id`: The ID of the [`Session`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + /// Create a new [`Session`] from an [`AuthorizationGrant`] + /// + /// Returns the newly created [`Session`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `grant`: The [`AuthorizationGrant`] to create the [`Session`] from + /// * `user_session`: The [`BrowserSession`] of the user which completed the + /// authorization + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn create_from_grant( &mut self, rng: &mut (dyn RngCore + Send), @@ -33,9 +62,31 @@ pub trait OAuth2SessionRepository: Send + Sync { user_session: &BrowserSession, ) -> Result; + /// Mark a [`Session`] as finished + /// + /// Returns the updated [`Session`] + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `session`: The [`Session`] to mark as finished + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn finish(&mut self, clock: &dyn Clock, session: Session) -> Result; + /// Get a paginated list of [`Session`]s for a [`User`] + /// + /// # Parameters + /// + /// * `user`: The [`User`] to get the [`Session`]s for + /// * `pagination`: The pagination parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn list_paginated( &mut self, user: &User, diff --git a/crates/storage/src/pagination.rs b/crates/storage/src/pagination.rs index 6af456418..d8d8bc1c4 100644 --- a/crates/storage/src/pagination.rs +++ b/crates/storage/src/pagination.rs @@ -22,17 +22,29 @@ use ulid::Ulid; #[error("Either 'first' or 'last' must be specified")] pub struct InvalidPagination; +/// Pagination parameters #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Pagination { + /// The cursor to start from pub before: Option, + + /// The cursor to end at pub after: Option, + + /// The maximum number of items to return pub count: usize, + + /// In which direction to paginate pub direction: PaginationDirection, } +/// The direction to paginate #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PaginationDirection { + /// Paginate forward Forward, + + /// Paginate backward Backward, } @@ -124,13 +136,24 @@ impl Pagination { } } +/// A page of results returned by a paginated query pub struct Page { + /// When paginating forwards, this is true if there are more items after pub has_next_page: bool, + + /// When paginating backwards, this is true if there are more items before pub has_previous_page: bool, + + /// The items in the page pub edges: Vec, } impl Page { + /// Map the items in this page with the given function + /// + /// # Parameters + /// + /// * `f`: The function to map the items with #[must_use] pub fn map(self, f: F) -> Page where @@ -144,6 +167,15 @@ impl Page { } } + /// Try to map the items in this page with the given fallible function + /// + /// # Parameters + /// + /// * `f`: The fallible function to map the items with + /// + /// # Errors + /// + /// Returns the first error encountered while mapping the items pub fn try_map(self, f: F) -> Result, E> where F: FnMut(T) -> Result, diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index 0057f2d6c..9b8a4f1cd 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -19,14 +19,39 @@ use ulid::Ulid; use crate::{pagination::Page, repository_impl, Clock, Pagination}; +/// An [`UpstreamOAuthLinkRepository`] helps interacting with +/// [`UpstreamOAuthLink`] with the storage backend #[async_trait] pub trait UpstreamOAuthLinkRepository: Send + Sync { + /// The error type returned by the repository type Error; /// Lookup an upstream OAuth link by its ID + /// + /// Returns `None` if the link does not exist + /// + /// # Parameters + /// + /// * `id`: The ID of the upstream OAuth link to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; /// Find an upstream OAuth link for a provider by its subject + /// + /// Returns `None` if no matching upstream OAuth link was found + /// + /// # Parameters + /// + /// * `upstream_oauth_provider`: The upstream OAuth provider on which to + /// find the link + /// * `subject`: The subject of the upstream OAuth link to find + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn find_by_subject( &mut self, upstream_oauth_provider: &UpstreamOAuthProvider, @@ -34,6 +59,20 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync { ) -> Result, Self::Error>; /// Add a new upstream OAuth link + /// + /// Returns the newly created upstream OAuth link + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `upsream_oauth_provider`: The upstream OAuth provider for which to + /// create the link + /// * `subject`: The subject of the upstream OAuth link to create + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn add( &mut self, rng: &mut (dyn RngCore + Send), @@ -43,6 +82,17 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync { ) -> Result; /// Associate an upstream OAuth link to a user + /// + /// Returns the updated upstream OAuth link + /// + /// # Parameters + /// + /// * `upstream_oauth_link`: The upstream OAuth link to update + /// * `user`: The user to associate to the upstream OAuth link + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn associate_to_user( &mut self, upstream_oauth_link: &UpstreamOAuthLink, @@ -50,6 +100,15 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync { ) -> Result<(), Self::Error>; /// Get a paginated list of upstream OAuth links on a user + /// + /// # Parameters + /// + /// * `user`: The user for which to get the upstream OAuth links + /// * `pagination`: The pagination parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn list_paginated( &mut self, user: &User, diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index 1648a6448..252217527 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Repositories to interact with entities related to the upstream OAuth 2.0 +//! providers + mod link; mod provider; mod session; diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index 8aaca0dac..663af2c92 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -21,14 +21,47 @@ use ulid::Ulid; use crate::{pagination::Page, repository_impl, Clock, Pagination}; +/// An [`UpstreamOAuthProviderRepository`] helps interacting with +/// [`UpstreamOAuthProvider`] saved in the storage backend #[async_trait] pub trait UpstreamOAuthProviderRepository: Send + Sync { + /// The error type returned by the repository type Error; /// Lookup an upstream OAuth provider by its ID + /// + /// Returns `None` if the provider was not found + /// + /// # Parameters + /// + /// * `id`: The ID of the provider to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; /// Add a new upstream OAuth provider + /// + /// Returns the newly created provider + /// + /// # Parameters + /// + /// * `rng`: A random number generator + /// * `clock`: The clock used to generate timestamps + /// * `issuer`: The OIDC issuer of the provider + /// * `scope`: The scope to request during the authorization flow + /// * `token_endpoint_auth_method`: The token endpoint authentication method + /// * `token_endpoint_auth_signing_alg`: The JWT signing algorithm to use + /// when then `client_secret_jwt` or `private_key_jwt` authentication + /// methods are used + /// * `client_id`: The client ID to use when authenticating to the upstream + /// * `encrypted_client_secret`: The encrypted client secret to use when + /// authenticating to the upstream + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails #[allow(clippy::too_many_arguments)] async fn add( &mut self, @@ -43,12 +76,24 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync { ) -> Result; /// Get a paginated list of upstream OAuth providers + /// + /// # Parameters + /// + /// * `pagination`: The pagination parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn list_paginated( &mut self, pagination: Pagination, ) -> Result, Self::Error>; /// Get all upstream OAuth providers + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn all(&mut self) -> Result, Self::Error>; } diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index e878444bd..2d8f14be7 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -19,17 +19,48 @@ use ulid::Ulid; use crate::{repository_impl, Clock}; +/// An [`UpstreamOAuthSessionRepository`] helps interacting with +/// [`UpstreamOAuthAuthorizationSession`] saved in the storage backend #[async_trait] pub trait UpstreamOAuthSessionRepository: Send + Sync { + /// The error type returned by the repository type Error; /// Lookup a session by its ID + /// + /// Returns `None` if the session does not exist + /// + /// # Parameters + /// + /// * `id`: the ID of the session to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn lookup( &mut self, id: Ulid, ) -> Result, Self::Error>; /// Add a session to the database + /// + /// Returns the newly created session + /// + /// # Parameters + /// + /// * `rng`: the random number generator to use + /// * `clock`: the clock source + /// * `upstream_oauth_provider`: the upstream OAuth provider for which to + /// create the session + /// * `state`: the authorization grant `state` parameter sent to the + /// upstream OAuth provider + /// * `code_challenge_verifier`: the code challenge verifier used in this + /// session, if PKCE is being used + /// * `nonce`: the `nonce` used in this session + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn add( &mut self, rng: &mut (dyn RngCore + Send), @@ -41,6 +72,20 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync { ) -> Result; /// Mark a session as completed and associate the given link + /// + /// Returns the updated session + /// + /// # Parameters + /// + /// * `clock`: the clock source + /// * `upstream_oauth_authorization_session`: the session to update + /// * `upstream_oauth_link`: the link to associate with the session + /// * `id_token`: the ID token returned by the upstream OAuth provider, if + /// present + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn complete_with_link( &mut self, clock: &dyn Clock, @@ -50,6 +95,17 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync { ) -> Result; /// Mark a session as consumed + /// + /// Returns the updated session + /// + /// # Parameters + /// + /// * `clock`: the clock source + /// * `upstream_oauth_authorization_session`: the session to consume + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn consume( &mut self, clock: &dyn Clock, diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs index 4c8601c2b..9ae815348 100644 --- a/crates/storage/src/user/email.rs +++ b/crates/storage/src/user/email.rs @@ -19,22 +19,105 @@ use ulid::Ulid; use crate::{pagination::Page, repository_impl, Clock, Pagination}; +/// A [`UserEmailRepository`] helps interacting with [`UserEmail`] saved in the +/// storage backend #[async_trait] pub trait UserEmailRepository: Send + Sync { + /// The error type returned by the repository type Error; + /// Lookup an [`UserEmail`] by its ID + /// + /// Returns `None` if no [`UserEmail`] was found + /// + /// # Parameters + /// + /// * `id`: The ID of the [`UserEmail`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Lookup an [`UserEmail`] by its email address for a [`User`] + /// + /// Returns `None` if no matching [`UserEmail`] was found + /// + /// # Parameters + /// + /// * `user`: The [`User`] for whom to lookup the [`UserEmail`] + /// * `email`: The email address to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn find(&mut self, user: &User, email: &str) -> Result, Self::Error>; + + /// Get the primary [`UserEmail`] of a [`User`] + /// + /// Returns `None` if no the user has no primary [`UserEmail`] + /// + /// # Parameters + /// + /// * `user`: The [`User`] for whom to lookup the primary [`UserEmail`] + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn get_primary(&mut self, user: &User) -> Result, Self::Error>; + /// Get all [`UserEmail`] of a [`User`] + /// + /// # Parameters + /// + /// * `user`: The [`User`] for whom to lookup the [`UserEmail`] + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn all(&mut self, user: &User) -> Result, Self::Error>; + + /// List [`UserEmail`] of a [`User`] with the given pagination + /// + /// # Parameters + /// + /// * `user`: The [`User`] for whom to lookup the [`UserEmail`] + /// * `pagination`: The pagination parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn list_paginated( &mut self, user: &User, pagination: Pagination, ) -> Result, Self::Error>; + + /// Count the [`UserEmail`] of a [`User`] + /// + /// # Parameters + /// + /// * `user`: The [`User`] for whom to count the [`UserEmail`] + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn count(&mut self, user: &User) -> Result; + /// Create a new [`UserEmail`] for a [`User`] + /// + /// Returns the newly created [`UserEmail`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock to use + /// * `user`: The [`User`] for whom to create the [`UserEmail`] + /// * `email`: The email address of the [`UserEmail`] + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn add( &mut self, rng: &mut (dyn RngCore + Send), @@ -42,16 +125,61 @@ pub trait UserEmailRepository: Send + Sync { user: &User, email: String, ) -> Result; + + /// Delete a [`UserEmail`] + /// + /// # Parameters + /// + /// * `user_email`: The [`UserEmail`] to delete + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error>; + /// Mark a [`UserEmail`] as verified + /// + /// Returns the updated [`UserEmail`] + /// + /// # Parameters + /// + /// * `clock`: The clock to use + /// * `user_email`: The [`UserEmail`] to mark as verified + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn mark_as_verified( &mut self, clock: &dyn Clock, user_email: UserEmail, ) -> Result; + /// Mark a [`UserEmail`] as primary + /// + /// # Parameters + /// + /// * `user_email`: The [`UserEmail`] to mark as primary + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn set_as_primary(&mut self, user_email: &UserEmail) -> Result<(), Self::Error>; + /// Add a [`UserEmailVerification`] for a [`UserEmail`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock to use + /// * `user_email`: The [`UserEmail`] for which to add the + /// [`UserEmailVerification`] + /// * `max_age`: The duration for which the [`UserEmailVerification`] is + /// valid + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn add_verification_code( &mut self, rng: &mut (dyn RngCore + Send), @@ -61,6 +189,20 @@ pub trait UserEmailRepository: Send + Sync { code: String, ) -> Result; + /// Find a [`UserEmailVerification`] for a [`UserEmail`] by its code + /// + /// Returns `None` if no matching [`UserEmailVerification`] was found + /// + /// # Parameters + /// + /// * `clock`: The clock to use + /// * `user_email`: The [`UserEmail`] for which to lookup the + /// [`UserEmailVerification`] + /// * `code`: The code used to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn find_verification_code( &mut self, clock: &dyn Clock, @@ -68,6 +210,18 @@ pub trait UserEmailRepository: Send + Sync { code: &str, ) -> Result, Self::Error>; + /// Consume a [`UserEmailVerification`] + /// + /// Returns the consumed [`UserEmailVerification`] + /// + /// # Parameters + /// + /// * `clock`: The clock to use + /// * `verification`: The [`UserEmailVerification`] to consume + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn consume_verification_code( &mut self, clock: &dyn Clock, diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 49003335d..a611b459c 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Repositories to interact with entities related to user accounts + use async_trait::async_trait; use mas_data_model::User; use rand_core::RngCore; @@ -27,18 +29,70 @@ pub use self::{ email::UserEmailRepository, password::UserPasswordRepository, session::BrowserSessionRepository, }; +/// A [`UserRepository`] helps interacting with [`User`] saved in the storage +/// backend #[async_trait] pub trait UserRepository: Send + Sync { + /// The error type returned by the repository type Error; + /// Lookup a [`User`] by its ID + /// + /// Returns `None` if no [`User`] was found + /// + /// # Parameters + /// + /// * `id`: The ID of the [`User`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Find a [`User`] by its username + /// + /// Returns `None` if no [`User`] was found + /// + /// # Parameters + /// + /// * `username`: The username of the [`User`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn find_by_username(&mut self, username: &str) -> Result, Self::Error>; + + /// Create a new [`User`] + /// + /// Returns the newly created [`User`] + /// + /// # Parameters + /// + /// * `rng`: A random number generator to generate the [`User`] ID + /// * `clock`: The clock used to generate timestamps + /// * `username`: The username of the [`User`] + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn add( &mut self, rng: &mut (dyn RngCore + Send), clock: &dyn Clock, username: String, ) -> Result; + + /// Check if a [`User`] exists + /// + /// Returns `true` if the [`User`] exists, `false` otherwise + /// + /// # Parameters + /// + /// * `username`: The username of the [`User`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn exists(&mut self, username: &str) -> Result; } diff --git a/crates/storage/src/user/password.rs b/crates/storage/src/user/password.rs index 06f03f551..7ef5c7ad8 100644 --- a/crates/storage/src/user/password.rs +++ b/crates/storage/src/user/password.rs @@ -18,11 +18,42 @@ use rand_core::RngCore; use crate::{repository_impl, Clock}; +/// A [`UserPasswordRepository`] helps interacting with [`Password`] saved in +/// the storage backend #[async_trait] pub trait UserPasswordRepository: Send + Sync { + /// The error type returned by the repository type Error; + /// Get the active password for a user + /// + /// Returns `None` if the user has no password set + /// + /// # Parameters + /// + /// * `user`: The user to get the password for + /// + /// # Errors + /// + /// Returns [`Self::Error`] if underlying repository fails async fn active(&mut self, user: &User) -> Result, Self::Error>; + + /// Set a new password for a user + /// + /// Returns the newly created [`Password`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `user`: The user to set the password for + /// * `version`: The version of the hashing scheme used + /// * `hashed_password`: The hashed password + /// * `upgraded_from`: The password this password was upgraded from, if any + /// + /// # Errors + /// + /// Returns [`Self::Error`] if underlying repository fails async fn add( &mut self, rng: &mut (dyn RngCore + Send), diff --git a/crates/storage/src/user/session.rs b/crates/storage/src/user/session.rs index 0dfc581cc..5e9defbec 100644 --- a/crates/storage/src/user/session.rs +++ b/crates/storage/src/user/session.rs @@ -19,29 +19,105 @@ use ulid::Ulid; use crate::{pagination::Page, repository_impl, Clock, Pagination}; +/// A [`BrowserSessionRepository`] helps interacting with [`BrowserSession`] +/// saved in the storage backend #[async_trait] pub trait BrowserSessionRepository: Send + Sync { + /// The error type returned by the repository type Error; + /// Lookup a [`BrowserSession`] by its ID + /// + /// Returns `None` if the session is not found + /// + /// # Parameters + /// + /// * `id`: The ID of the session to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Create a new [`BrowserSession`] for a [`User`] + /// + /// Returns the newly created [`BrowserSession`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `user`: The user to create the session for + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn add( &mut self, rng: &mut (dyn RngCore + Send), clock: &dyn Clock, user: &User, ) -> Result; + + /// Finish a [`BrowserSession`] + /// + /// Returns the finished session + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `user_session`: The session to finish + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn finish( &mut self, clock: &dyn Clock, user_session: BrowserSession, ) -> Result; + + /// List active [`BrowserSession`] for a [`User`] with the given pagination + /// + /// # Parameters + /// + /// * `user`: The user to list the sessions for + /// * `pagination`: The pagination parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn list_active_paginated( &mut self, user: &User, pagination: Pagination, ) -> Result, Self::Error>; + + /// Count active [`BrowserSession`] for a [`User`] + /// + /// # Parameters + /// + /// * `user`: The user to count the sessions for + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn count_active(&mut self, user: &User) -> Result; + /// Authenticate a [`BrowserSession`] with the given [`Password`] + /// + /// Returns the updated [`BrowserSession`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `user_session`: The session to authenticate + /// * `user_password`: The password which was used to authenticate + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn authenticate_with_password( &mut self, rng: &mut (dyn RngCore + Send), @@ -50,6 +126,21 @@ pub trait BrowserSessionRepository: Send + Sync { user_password: &Password, ) -> Result; + /// Authenticate a [`BrowserSession`] with the given [`UpstreamOAuthLink`] + /// + /// Returns the updated [`BrowserSession`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `user_session`: The session to authenticate + /// * `upstream_oauth_link`: The upstream OAuth link which was used to + /// authenticate + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails async fn authenticate_with_upstream( &mut self, rng: &mut (dyn RngCore + Send), From 6f6572dddab182546d2d4366a0674a1bf98278d2 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 25 Jan 2023 17:24:34 +0100 Subject: [PATCH 38/45] storage-pg: write tests for the OAuth2 repositories --- .../src/oauth2/authorization_grant.rs | 24 ++ crates/storage-pg/src/oauth2/mod.rs | 320 ++++++++++++++++++ crates/storage-pg/src/oauth2/session.rs | 7 +- 3 files changed, 348 insertions(+), 3 deletions(-) diff --git a/crates/data-model/src/oauth2/authorization_grant.rs b/crates/data-model/src/oauth2/authorization_grant.rs index 76572f489..5638ca10a 100644 --- a/crates/data-model/src/oauth2/authorization_grant.rs +++ b/crates/data-model/src/oauth2/authorization_grant.rs @@ -120,6 +120,22 @@ impl AuthorizationGrantStage { pub fn is_pending(&self) -> bool { matches!(self, Self::Pending) } + + /// Returns `true` if the authorization grant stage is [`Fulfilled`]. + /// + /// [`Fulfilled`]: AuthorizationGrantStage::Fulfilled + #[must_use] + pub fn is_fulfilled(&self) -> bool { + matches!(self, Self::Fulfilled { .. }) + } + + /// Returns `true` if the authorization grant stage is [`Exchanged`]. + /// + /// [`Exchanged`]: AuthorizationGrantStage::Exchanged + #[must_use] + pub fn is_exchanged(&self) -> bool { + matches!(self, Self::Exchanged { .. }) + } } #[derive(Debug, Clone, PartialEq, Eq, Serialize)] @@ -140,6 +156,14 @@ pub struct AuthorizationGrant { pub requires_consent: bool, } +impl std::ops::Deref for AuthorizationGrant { + type Target = AuthorizationGrantStage; + + fn deref(&self) -> &Self::Target { + &self.stage + } +} + impl AuthorizationGrant { #[must_use] pub fn max_auth_time(&self) -> DateTime { diff --git a/crates/storage-pg/src/oauth2/mod.rs b/crates/storage-pg/src/oauth2/mod.rs index 3e4961417..c0659aa4c 100644 --- a/crates/storage-pg/src/oauth2/mod.rs +++ b/crates/storage-pg/src/oauth2/mod.rs @@ -26,3 +26,323 @@ pub use self::{ authorization_grant::PgOAuth2AuthorizationGrantRepository, client::PgOAuth2ClientRepository, refresh_token::PgOAuth2RefreshTokenRepository, session::PgOAuth2SessionRepository, }; + +#[cfg(test)] +mod tests { + use chrono::Duration; + use mas_data_model::AuthorizationCode; + use mas_storage::{clock::MockClock, Clock, Pagination, Repository}; + use oauth2_types::{ + requests::{GrantType, ResponseMode}, + scope::{Scope, OPENID}, + }; + use rand::SeedableRng; + use rand_chacha::ChaChaRng; + use sqlx::PgPool; + use ulid::Ulid; + + use crate::PgRepository; + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_repositories(pool: PgPool) { + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); + + // Lookup a non-existing client + let client = repo.oauth2_client().lookup(Ulid::nil()).await.unwrap(); + assert_eq!(client, None); + + // Find a non-existing client by client id + let client = repo + .oauth2_client() + .find_by_client_id("some-client-id") + .await + .unwrap(); + assert_eq!(client, None); + + // Create a client + let client = repo + .oauth2_client() + .add( + &mut rng, + &clock, + vec!["https://example.com/redirect".parse().unwrap()], + None, + vec![GrantType::AuthorizationCode], + Vec::new(), // TODO: contacts are not yet saved + // vec!["contact@example.com".to_owned()], + Some("Test client".to_owned()), + Some("https://example.com/logo.png".parse().unwrap()), + Some("https://example.com/".parse().unwrap()), + Some("https://example.com/policy".parse().unwrap()), + Some("https://example.com/tos".parse().unwrap()), + Some("https://example.com/jwks.json".parse().unwrap()), + None, + None, + None, + None, + None, + Some("https://example.com/login".parse().unwrap()), + ) + .await + .unwrap(); + + // Lookup the same client by id + let client_lookup = repo + .oauth2_client() + .lookup(client.id) + .await + .unwrap() + .expect("client not found"); + assert_eq!(client, client_lookup); + + // Find the same client by client id + let client_lookup = repo + .oauth2_client() + .find_by_client_id(&client.client_id) + .await + .unwrap() + .expect("client not found"); + assert_eq!(client, client_lookup); + + // Lookup a non-existing grant + let grant = repo + .oauth2_authorization_grant() + .lookup(Ulid::nil()) + .await + .unwrap(); + assert_eq!(grant, None); + + // Find a non-existing grant by code + let grant = repo + .oauth2_authorization_grant() + .find_by_code("code") + .await + .unwrap(); + assert_eq!(grant, None); + + // Create an authorization grant + let grant = repo + .oauth2_authorization_grant() + .add( + &mut rng, + &clock, + &client, + "https://example.com/redirect".parse().unwrap(), + Scope::from_iter([OPENID]), + Some(AuthorizationCode { + code: "code".to_owned(), + pkce: None, + }), + Some("state".to_owned()), + Some("nonce".to_owned()), + None, + ResponseMode::Query, + true, + false, + ) + .await + .unwrap(); + assert!(grant.is_pending()); + + // Lookup the same grant by id + let grant_lookup = repo + .oauth2_authorization_grant() + .lookup(grant.id) + .await + .unwrap() + .expect("grant not found"); + assert_eq!(grant, grant_lookup); + + // Find the same grant by code + let grant_lookup = repo + .oauth2_authorization_grant() + .find_by_code("code") + .await + .unwrap() + .expect("grant not found"); + assert_eq!(grant, grant_lookup); + + // Create a user and a start a user session + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + let user_session = repo + .browser_session() + .add(&mut rng, &clock, &user) + .await + .unwrap(); + + // Lookup a non-existing session + let session = repo.oauth2_session().lookup(Ulid::nil()).await.unwrap(); + assert_eq!(session, None); + + // Create a session out of the grant + let session = repo + .oauth2_session() + .create_from_grant(&mut rng, &clock, &grant, &user_session) + .await + .unwrap(); + + // Mark the grant as fulfilled + let grant = repo + .oauth2_authorization_grant() + .fulfill(&clock, &session, grant) + .await + .unwrap(); + assert!(grant.is_fulfilled()); + + // Lookup the same session by id + let session_lookup = repo + .oauth2_session() + .lookup(session.id) + .await + .unwrap() + .expect("session not found"); + assert_eq!(session, session_lookup); + + // Mark the grant as exchanged + let grant = repo + .oauth2_authorization_grant() + .exchange(&clock, grant) + .await + .unwrap(); + assert!(grant.is_exchanged()); + + // Lookup a non-existing token + let token = repo + .oauth2_access_token() + .lookup(Ulid::nil()) + .await + .unwrap(); + assert_eq!(token, None); + + // Find a non-existing token + let token = repo + .oauth2_access_token() + .find_by_token("aabbcc") + .await + .unwrap(); + assert_eq!(token, None); + + // Create an access token + let access_token = repo + .oauth2_access_token() + .add( + &mut rng, + &clock, + &session, + "aabbcc".to_owned(), + Duration::minutes(5), + ) + .await + .unwrap(); + + // Lookup the same token by id + let access_token_lookup = repo + .oauth2_access_token() + .lookup(access_token.id) + .await + .unwrap() + .expect("token not found"); + assert_eq!(access_token, access_token_lookup); + + // Find the same token by token + let access_token_lookup = repo + .oauth2_access_token() + .find_by_token("aabbcc") + .await + .unwrap() + .expect("token not found"); + assert_eq!(access_token, access_token_lookup); + + // Lookup a non-existing refresh token + let refresh_token = repo + .oauth2_refresh_token() + .lookup(Ulid::nil()) + .await + .unwrap(); + assert_eq!(refresh_token, None); + + // Find a non-existing refresh token + let refresh_token = repo + .oauth2_refresh_token() + .find_by_token("aabbcc") + .await + .unwrap(); + assert_eq!(refresh_token, None); + + // Create a refresh token + let refresh_token = repo + .oauth2_refresh_token() + .add( + &mut rng, + &clock, + &session, + &access_token, + "aabbcc".to_owned(), + ) + .await + .unwrap(); + + // Lookup the same refresh token by id + let refresh_token_lookup = repo + .oauth2_refresh_token() + .lookup(refresh_token.id) + .await + .unwrap() + .expect("refresh token not found"); + assert_eq!(refresh_token, refresh_token_lookup); + + // Find the same refresh token by token + let refresh_token_lookup = repo + .oauth2_refresh_token() + .find_by_token("aabbcc") + .await + .unwrap() + .expect("refresh token not found"); + assert_eq!(refresh_token, refresh_token_lookup); + + assert!(access_token.is_valid(clock.now())); + clock.advance(Duration::minutes(6)); + assert!(!access_token.is_valid(clock.now())); + + // XXX: we might want to create a new access token + clock.advance(Duration::minutes(-6)); // Go back in time + assert!(access_token.is_valid(clock.now())); + + // Mark the access token as revoked + let access_token = repo + .oauth2_access_token() + .revoke(&clock, access_token) + .await + .unwrap(); + assert!(!access_token.is_valid(clock.now())); + + // Mark the refresh token as consumed + assert!(refresh_token.is_valid()); + let refresh_token = repo + .oauth2_refresh_token() + .consume(&clock, refresh_token) + .await + .unwrap(); + assert!(!refresh_token.is_valid()); + + // Mark the session as finished + assert!(session.is_valid()); + let session = repo.oauth2_session().finish(&clock, session).await.unwrap(); + assert!(!session.is_valid()); + + // The session should appear in the paginated list of sessions for the user + let sessions = repo + .oauth2_session() + .list_paginated(&user, Pagination::first(10)) + .await + .unwrap(); + assert!(!sessions.has_next_page); + assert_eq!(sessions.edges, vec![session]); + } +} diff --git a/crates/storage-pg/src/oauth2/session.rs b/crates/storage-pg/src/oauth2/session.rs index aa667f252..e6168310f 100644 --- a/crates/storage-pg/src/oauth2/session.rs +++ b/crates/storage-pg/src/oauth2/session.rs @@ -232,14 +232,15 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { , user_session_id , oauth2_client_id , scope - , created_at - , finished_at + , os.created_at + , os.finished_at FROM oauth2_sessions os + INNER JOIN user_sessions USING (user_session_id) "#, ); query - .push(" WHERE us.user_id = ") + .push(" WHERE user_id = ") .push_bind(Uuid::from(user.id)) .generate_pagination("oauth2_session_id", pagination); From 45f56748017133544bc19cc93d58082f00f165f1 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 26 Jan 2023 15:51:53 +0100 Subject: [PATCH 39/45] storage-pg: add tests for most remaining repositories Also fixes all the list_paginated() repository methods --- crates/storage-pg/src/compat/mod.rs | 125 +++++++++++++++++++++- crates/storage-pg/src/compat/sso_login.rs | 4 +- crates/storage-pg/src/oauth2/mod.rs | 23 ++++ crates/storage-pg/src/user/email.rs | 2 +- crates/storage-pg/src/user/tests.rs | 12 ++- 5 files changed, 161 insertions(+), 5 deletions(-) diff --git a/crates/storage-pg/src/compat/mod.rs b/crates/storage-pg/src/compat/mod.rs index ae2a40b25..ab4ed33e6 100644 --- a/crates/storage-pg/src/compat/mod.rs +++ b/crates/storage-pg/src/compat/mod.rs @@ -35,11 +35,12 @@ mod tests { CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, }, user::UserRepository, - Clock, Repository, RepositoryAccess, + Clock, Pagination, Repository, RepositoryAccess, }; use rand::SeedableRng; use rand_chacha::ChaChaRng; use sqlx::PgPool; + use ulid::Ulid; use crate::PgRepository; @@ -323,4 +324,126 @@ mod tests { repo.save().await.unwrap(); } + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_compat_sso_login_repository(pool: PgPool) { + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); + + // Create a user + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + + // Lookup an unknown SSO login + let login = repo.compat_sso_login().lookup(Ulid::nil()).await.unwrap(); + assert_eq!(login, None); + + // Lookup an unknown login token + let login = repo + .compat_sso_login() + .find_by_token("login-token") + .await + .unwrap(); + assert_eq!(login, None); + + // Start a new SSO login + let login = repo + .compat_sso_login() + .add( + &mut rng, + &clock, + "login-token".to_owned(), + "https://example.com/callback".parse().unwrap(), + ) + .await + .unwrap(); + assert!(login.is_pending()); + + // Lookup the login by ID + let login_lookup = repo + .compat_sso_login() + .lookup(login.id) + .await + .unwrap() + .expect("login not found"); + assert_eq!(login_lookup, login); + + // Find the login by token + let login_lookup = repo + .compat_sso_login() + .find_by_token("login-token") + .await + .unwrap() + .expect("login not found"); + assert_eq!(login_lookup, login); + + // Exchanging before fulfilling should not work + // Note: It should also not poison the SQL transaction + let res = repo + .compat_sso_login() + .exchange(&clock, login.clone()) + .await; + assert!(res.is_err()); + + // Start a compat session for that user + let device = Device::generate(&mut rng); + let session = repo + .compat_session() + .add(&mut rng, &clock, &user, device) + .await + .unwrap(); + + // Associate the login with the session + let login = repo + .compat_sso_login() + .fulfill(&clock, login, &session) + .await + .unwrap(); + assert!(login.is_fulfilled()); + + // Fulfilling again should not work + // Note: It should also not poison the SQL transaction + let res = repo + .compat_sso_login() + .fulfill(&clock, login.clone(), &session) + .await; + assert!(res.is_err()); + + // Exchange that login + let login = repo + .compat_sso_login() + .exchange(&clock, login) + .await + .unwrap(); + assert!(login.is_exchanged()); + + // Exchange again should not work + // Note: It should also not poison the SQL transaction + let res = repo + .compat_sso_login() + .exchange(&clock, login.clone()) + .await; + assert!(res.is_err()); + + // Fulfilling after exchanging should not work + // Note: It should also not poison the SQL transaction + let res = repo + .compat_sso_login() + .fulfill(&clock, login.clone(), &session) + .await; + assert!(res.is_err()); + + // List the logins for the user + let logins = repo + .compat_sso_login() + .list_paginated(&user, Pagination::first(10)) + .await + .unwrap(); + assert!(!logins.has_next_page); + assert_eq!(logins.edges, vec![login]); + } } diff --git a/crates/storage-pg/src/compat/sso_login.rs b/crates/storage-pg/src/compat/sso_login.rs index ae9ca083c..328bd7890 100644 --- a/crates/storage-pg/src/compat/sso_login.rs +++ b/crates/storage-pg/src/compat/sso_login.rs @@ -323,12 +323,12 @@ impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { , cl.compat_session_id FROM compat_sso_logins cl - INNER JOIN compat_sessions ON compat_session_id + INNER JOIN compat_sessions cs USING (compat_session_id) "#, ); query - .push(" WHERE user_id = ") + .push(" WHERE cs.user_id = ") .push_bind(Uuid::from(user.id)) .generate_pagination("cl.compat_sso_login_id", pagination); diff --git a/crates/storage-pg/src/oauth2/mod.rs b/crates/storage-pg/src/oauth2/mod.rs index c0659aa4c..120fca6cf 100644 --- a/crates/storage-pg/src/oauth2/mod.rs +++ b/crates/storage-pg/src/oauth2/mod.rs @@ -176,6 +176,29 @@ mod tests { .await .unwrap(); + // Lookup the consent the user gave to the client + let consent = repo + .oauth2_client() + .get_consent_for_user(&client, &user) + .await + .unwrap(); + assert!(consent.is_empty()); + + // Give consent to the client + let scope = Scope::from_iter([OPENID]); + repo.oauth2_client() + .give_consent_for_user(&mut rng, &clock, &client, &user, &scope) + .await + .unwrap(); + + // Lookup the consent the user gave to the client + let consent = repo + .oauth2_client() + .get_consent_for_user(&client, &user) + .await + .unwrap(); + assert_eq!(scope, consent); + // Lookup a non-existing session let session = repo.oauth2_session().lookup(Ulid::nil()).await.unwrap(); assert_eq!(session, None); diff --git a/crates/storage-pg/src/user/email.rs b/crates/storage-pg/src/user/email.rs index 28b9d3951..147542d9d 100644 --- a/crates/storage-pg/src/user/email.rs +++ b/crates/storage-pg/src/user/email.rs @@ -249,7 +249,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { query .push(" WHERE user_id = ") .push_bind(Uuid::from(user.id)) - .generate_pagination("ue.user_email_id", pagination); + .generate_pagination("user_email_id", pagination); let edges: Vec = query .build_query_as() diff --git a/crates/storage-pg/src/user/tests.rs b/crates/storage-pg/src/user/tests.rs index 29ebe19f6..9aec949de 100644 --- a/crates/storage-pg/src/user/tests.rs +++ b/crates/storage-pg/src/user/tests.rs @@ -16,7 +16,7 @@ use chrono::Duration; use mas_storage::{ clock::MockClock, user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - Repository, RepositoryAccess, + Pagination, Repository, RepositoryAccess, }; use rand::SeedableRng; use rand_chacha::ChaChaRng; @@ -230,6 +230,16 @@ async fn test_user_email_repo(pool: PgPool) { .unwrap() .is_some()); + // Listing the user emails should work + let emails = repo + .user_email() + .list_paginated(&user, Pagination::first(10)) + .await + .unwrap(); + assert!(!emails.has_next_page); + assert_eq!(emails.edges.len(), 1); + assert_eq!(emails.edges[0], user_email); + // Deleting the user email should work repo.user_email().remove(user_email).await.unwrap(); assert_eq!(repo.user_email().count(&user).await.unwrap(), 0); From 3081140f34255b3a9ba21c5575a123da2d510d5b Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 26 Jan 2023 17:58:03 +0100 Subject: [PATCH 40/45] storage{,-pg}: better documentation of both crates --- crates/storage-pg/src/errors.rs | 144 +++++++++++++++ crates/storage-pg/src/lib.rs | 291 +++++++++++++++++-------------- crates/storage-pg/src/tracing.rs | 2 + crates/storage/src/lib.rs | 185 +++++++++++++------- crates/storage/src/repository.rs | 14 ++ crates/storage/src/utils.rs | 86 +++++++++ docs/development/database.md | 86 ++------- 7 files changed, 544 insertions(+), 264 deletions(-) create mode 100644 crates/storage-pg/src/errors.rs create mode 100644 crates/storage/src/utils.rs diff --git a/crates/storage-pg/src/errors.rs b/crates/storage-pg/src/errors.rs new file mode 100644 index 000000000..ef7afad0f --- /dev/null +++ b/crates/storage-pg/src/errors.rs @@ -0,0 +1,144 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use sqlx::postgres::PgQueryResult; +use thiserror::Error; +use ulid::Ulid; + +/// Generic error when interacting with the database +#[derive(Debug, Error)] +#[error(transparent)] +pub enum DatabaseError { + /// An error which came from the database itself + Driver { + /// The underlying error from the database driver + #[from] + source: sqlx::Error, + }, + + /// An error which occured while converting the data from the database + Inconsistency(#[from] DatabaseInconsistencyError), + + /// An error which happened because the requested database operation is + /// invalid + #[error("Invalid database operation")] + InvalidOperation { + /// The source of the error, if any + #[source] + source: Option>, + }, + + /// An error which happens when an operation affects not enough or too many + /// rows + #[error("Expected {expected} rows to be affected, but {actual} rows were affected")] + RowsAffected { + /// How many rows were expected to be affected + expected: u64, + + /// How many rows were actually affected + actual: u64, + }, +} + +impl DatabaseError { + pub(crate) fn ensure_affected_rows( + result: &PgQueryResult, + expected: u64, + ) -> Result<(), DatabaseError> { + let actual = result.rows_affected(); + if actual == expected { + Ok(()) + } else { + Err(DatabaseError::RowsAffected { expected, actual }) + } + } + + pub(crate) fn to_invalid_operation(e: E) -> Self { + Self::InvalidOperation { + source: Some(Box::new(e)), + } + } + + pub(crate) const fn invalid_operation() -> Self { + Self::InvalidOperation { source: None } + } +} + +/// An error which occured while converting the data from the database +#[derive(Debug, Error)] +pub struct DatabaseInconsistencyError { + /// The table which was being queried + table: &'static str, + + /// The column which was being queried + column: Option<&'static str>, + + /// The row which was being queried + row: Option, + + /// The source of the error + #[source] + source: Option>, +} + +impl std::fmt::Display for DatabaseInconsistencyError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Database inconsistency on table {}", self.table)?; + if let Some(column) = self.column { + write!(f, " column {column}")?; + } + if let Some(row) = self.row { + write!(f, " row {row}")?; + } + + Ok(()) + } +} + +impl DatabaseInconsistencyError { + /// Create a new [`DatabaseInconsistencyError`] for the given table + #[must_use] + pub(crate) const fn on(table: &'static str) -> Self { + Self { + table, + column: None, + row: None, + source: None, + } + } + + /// Set the column which was being queried + #[must_use] + pub(crate) const fn column(mut self, column: &'static str) -> Self { + self.column = Some(column); + self + } + + /// Set the row which was being queried + #[must_use] + pub(crate) const fn row(mut self, row: Ulid) -> Self { + self.row = Some(row); + self + } + + /// Give the source of the error + #[must_use] + pub(crate) fn source( + mut self, + source: E, + ) -> Self { + self.source = Some(Box::new(source)); + self + } +} diff --git a/crates/storage-pg/src/lib.rs b/crates/storage-pg/src/lib.rs index 046cc4b8a..6615a3175 100644 --- a/crates/storage-pg/src/lib.rs +++ b/crates/storage-pg/src/lib.rs @@ -18,6 +18,152 @@ //! type-checked, using introspection data recorded in the `sqlx-data.json` //! file. This file is generated by the `sqlx` CLI tool, and should be updated //! whenever the database schema changes, or new queries are added. +//! +//! # Implementing a new repository +//! +//! When a new repository is defined in [`mas_storage`], it should be +//! implemented here, with the PostgreSQL backend. +//! +//! A typical implementation will look like this: +//! +//! ```rust +//! # use async_trait::async_trait; +//! # use ulid::Ulid; +//! # use rand::RngCore; +//! # use mas_storage::Clock; +//! # use mas_storage_pg::{DatabaseError, ExecuteExt, LookupResultExt}; +//! # use sqlx::PgConnection; +//! # use uuid::Uuid; +//! # +//! # // A fake data structure, usually defined in mas-data-model +//! # #[derive(sqlx::FromRow)] +//! # struct FakeData { +//! # id: Ulid, +//! # } +//! # +//! # // A fake repository trait, usually defined in mas-storage +//! # #[async_trait] +//! # pub trait FakeDataRepository: Send + Sync { +//! # type Error; +//! # async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; +//! # async fn add( +//! # &mut self, +//! # rng: &mut (dyn RngCore + Send), +//! # clock: &dyn Clock, +//! # ) -> Result; +//! # } +//! # +//! /// An implementation of [`FakeDataRepository`] for a PostgreSQL connection +//! pub struct PgFakeDataRepository<'c> { +//! conn: &'c mut PgConnection, +//! } +//! +//! impl<'c> PgFakeDataRepository<'c> { +//! /// Create a new [`FakeDataRepository`] from an active PostgreSQL connection +//! pub fn new(conn: &'c mut PgConnection) -> Self { +//! Self { conn } +//! } +//! } +//! +//! #[derive(sqlx::FromRow)] +//! struct FakeDataLookup { +//! fake_data_id: Uuid, +//! } +//! +//! impl From for FakeData { +//! fn from(value: FakeDataLookup) -> Self { +//! Self { +//! id: value.fake_data_id.into(), +//! } +//! } +//! } +//! +//! #[async_trait] +//! impl<'c> FakeDataRepository for PgFakeDataRepository<'c> { +//! type Error = DatabaseError; +//! +//! #[tracing::instrument( +//! name = "db.fake_data.lookup", +//! skip_all, +//! fields( +//! db.statement, +//! fake_data.id = %id, +//! ), +//! err, +//! )] +//! async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { +//! // Note: here we would use the macro version instead, but it's not possible here in +//! // this documentation example +//! let res: Option = sqlx::query_as( +//! r#" +//! SELECT fake_data_id +//! FROM fake_data +//! WHERE fake_data_id = $1 +//! "#, +//! ) +//! .bind(Uuid::from(id)) +//! .traced() +//! .fetch_one(&mut *self.conn) +//! .await +//! .to_option()?; +//! +//! let Some(res) = res else { return Ok(None) }; +//! +//! Ok(Some(res.into())) +//! } +//! +//! #[tracing::instrument( +//! name = "db.fake_data.add", +//! skip_all, +//! fields( +//! db.statement, +//! fake_data.id, +//! ), +//! err, +//! )] +//! async fn add( +//! &mut self, +//! rng: &mut (dyn RngCore + Send), +//! clock: &dyn Clock, +//! ) -> Result { +//! let created_at = clock.now(); +//! let id = Ulid::from_datetime_with_source(created_at.into(), rng); +//! tracing::Span::current().record("fake_data.id", tracing::field::display(id)); +//! +//! // Note: here we would use the macro version instead, but it's not possible here in +//! // this documentation example +//! sqlx::query( +//! r#" +//! INSERT INTO fake_data (id) +//! VALUES ($1) +//! "#, +//! ) +//! .bind(Uuid::from(id)) +//! .traced() +//! .execute(&mut *self.conn) +//! .await?; +//! +//! Ok(FakeData { +//! id, +//! }) +//! } +//! } +//! ``` +//! +//! A few things to note with the implementation: +//! +//! - All methods are traced, with an explicit, somewhat consistent name. +//! - The SQL statement is included as attribute, by declaring a `db.statement` +//! attribute on the tracing span, and then calling [`ExecuteExt::traced`]. +//! - The IDs are all [`Ulid`], and generated from the clock and the random +//! number generated passed as parameters. The generated IDs are recorded in +//! the span. +//! - The IDs are stored as [`Uuid`] in PostgreSQL, so conversions are required +//! - "Not found" errors are handled by returning `Ok(None)` instead of an +//! error. The [`LookupResultExt::to_option`] method helps to do that. +//! +//! [`Ulid`]: ulid::Ulid +//! [`Uuid`]: uuid::Uuid #![forbid(unsafe_code)] #![deny( @@ -30,17 +176,23 @@ #![warn(clippy::pedantic)] #![allow(clippy::module_name_repetitions)] -use sqlx::{migrate::Migrator, postgres::PgQueryResult}; -use thiserror::Error; -use ulid::Ulid; +use sqlx::migrate::Migrator; /// An extension trait for [`Result`] which adds a [`to_option`] method, useful /// for handling "not found" errors from [`sqlx`] -trait LookupResultExt { +/// +/// [`to_option`]: LookupResultExt::to_option +pub trait LookupResultExt { + /// The output type type Output; /// Transform a [`Result`] from a sqlx query to transform "not found" errors /// into [`None`] + /// + /// # Errors + /// + /// Returns the original error if the error was not a + /// [`sqlx::Error::RowNotFound`] error fn to_option(self) -> Result, sqlx::Error>; } @@ -56,143 +208,18 @@ impl LookupResultExt for Result { } } -/// Generic error when interacting with the database -#[derive(Debug, Error)] -#[error(transparent)] -pub enum DatabaseError { - /// An error which came from the database itself - Driver { - /// The underlying error from the database driver - #[from] - source: sqlx::Error, - }, - - /// An error which occured while converting the data from the database - Inconsistency(#[from] DatabaseInconsistencyError), - - /// An error which happened because the requested database operation is - /// invalid - #[error("Invalid database operation")] - InvalidOperation { - /// The source of the error, if any - #[source] - source: Option>, - }, - - /// An error which happens when an operation affects not enough or too many - /// rows - #[error("Expected {expected} rows to be affected, but {actual} rows were affected")] - RowsAffected { - /// How many rows were expected to be affected - expected: u64, - - /// How many rows were actually affected - actual: u64, - }, -} - -impl DatabaseError { - pub(crate) fn ensure_affected_rows( - result: &PgQueryResult, - expected: u64, - ) -> Result<(), DatabaseError> { - let actual = result.rows_affected(); - if actual == expected { - Ok(()) - } else { - Err(DatabaseError::RowsAffected { expected, actual }) - } - } - - pub(crate) fn to_invalid_operation(e: E) -> Self { - Self::InvalidOperation { - source: Some(Box::new(e)), - } - } - - pub(crate) const fn invalid_operation() -> Self { - Self::InvalidOperation { source: None } - } -} - -/// An error which occured while converting the data from the database -#[derive(Debug, Error)] -pub struct DatabaseInconsistencyError { - /// The table which was being queried - table: &'static str, - - /// The column which was being queried - column: Option<&'static str>, - - /// The row which was being queried - row: Option, - - /// The source of the error - #[source] - source: Option>, -} - -impl std::fmt::Display for DatabaseInconsistencyError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Database inconsistency on table {}", self.table)?; - if let Some(column) = self.column { - write!(f, " column {column}")?; - } - if let Some(row) = self.row { - write!(f, " row {row}")?; - } - - Ok(()) - } -} - -impl DatabaseInconsistencyError { - /// Create a new [`DatabaseInconsistencyError`] for the given table - #[must_use] - pub(crate) const fn on(table: &'static str) -> Self { - Self { - table, - column: None, - row: None, - source: None, - } - } - - /// Set the column which was being queried - #[must_use] - pub(crate) const fn column(mut self, column: &'static str) -> Self { - self.column = Some(column); - self - } - - /// Set the row which was being queried - #[must_use] - pub(crate) const fn row(mut self, row: Ulid) -> Self { - self.row = Some(row); - self - } - - /// Give the source of the error - #[must_use] - pub(crate) fn source( - mut self, - source: E, - ) -> Self { - self.source = Some(Box::new(source)); - self - } -} - pub mod compat; pub mod oauth2; pub mod upstream_oauth2; pub mod user; +mod errors; pub(crate) mod pagination; pub(crate) mod repository; pub(crate) mod tracing; -pub use self::repository::PgRepository; +pub(crate) use self::errors::DatabaseInconsistencyError; +pub use self::{errors::DatabaseError, repository::PgRepository, tracing::ExecuteExt}; /// Embedded migrations, allowing them to run on startup pub static MIGRATOR: Migrator = sqlx::migrate!(); diff --git a/crates/storage-pg/src/tracing.rs b/crates/storage-pg/src/tracing.rs index b0bc0b7ff..853b5d9d9 100644 --- a/crates/storage-pg/src/tracing.rs +++ b/crates/storage-pg/src/tracing.rs @@ -18,11 +18,13 @@ use tracing::Span; /// `db.statement` in a tracing span pub trait ExecuteExt<'q, DB>: Sized { /// Records the statement as `db.statement` in the current span + #[must_use] fn traced(self) -> Self { self.record(&Span::current()) } /// Records the statement as `db.statement` in the given span + #[must_use] fn record(self, span: &Span) -> Self; } diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index cffec045a..0e8458b7d 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -13,6 +13,125 @@ // limitations under the License. //! Interactions with the storage backend +//! +//! This crate provides a set of traits that can be implemented to interact with +//! the storage backend. Those traits are called repositories and are grouped by +//! the type of data they manage. +//! +//! Each of those reposotories can be accessed via the [`RepositoryAccess`] +//! trait. This trait can be wrapped in a [`BoxRepository`] to allow using it +//! without caring about the underlying storage backend, and without carrying +//! around the generic type parameter. +//! +//! This crate also defines a [`Clock`] trait that can be used to abstract the +//! way the current time is retrieved. It has two implementation: +//! [`SystemClock`] that uses the system time and [`MockClock`] which is useful +//! for testing. +//! +//! [`MockClock`]: crate::clock::MockClock +//! +//! # Defining a new repository +//! +//! To define a new repository, you have to: +//! 1. Define a new (async) repository trait, with the methods you need +//! 2. Write an implementation of this trait for each storage backend you want +//! (currently only for [`mas-storage-pg`]) +//! 3. Make it accessible via the [`RepositoryAccess`] trait +//! +//! The repository trait definition should look like this: +//! +//! ```rust +//! # use async_trait::async_trait; +//! # use ulid::Ulid; +//! # use rand_core::RngCore; +//! # use mas_storage::Clock; +//! # +//! # // A fake data structure, usually defined in mas-data-model +//! # struct FakeData { +//! # id: Ulid, +//! # } +//! # +//! # // A fake empty macro, to replace `mas_storage::repository_impl` +//! # macro_rules! repository_impl { ($($tok:tt)*) => {} } +//! +//! #[async_trait] +//! pub trait FakeDataRepository: Send + Sync { +//! /// The error type returned by the repository +//! type Error; +//! +//! /// Lookup a [`FakeData`] by its ID +//! /// +//! /// Returns `None` if no [`FakeData`] was found +//! /// +//! /// # Parameters +//! /// +//! /// * `id`: The ID of the [`FakeData`] to lookup +//! /// +//! /// # Errors +//! /// +//! /// Returns [`Self::Error`] if the underlying repository fails +//! async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; +//! +//! /// Create a new [`FakeData`] +//! /// +//! /// Returns the newly-created [`FakeData`]. +//! /// +//! /// # Parameters +//! /// +//! /// * `rng`: The random number generator to use +//! /// * `clock`: The clock used to generate timestamps +//! /// +//! /// # Errors +//! /// +//! /// Returns [`Self::Error`] if the underlying repository fails +//! async fn add( +//! &mut self, +//! rng: &mut (dyn RngCore + Send), +//! clock: &dyn Clock, +//! ) -> Result; +//! } +//! +//! repository_impl!(FakeDataRepository: +//! async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; +//! async fn add( +//! &mut self, +//! rng: &mut (dyn RngCore + Send), +//! clock: &dyn Clock, +//! ) -> Result; +//! ); +//! ``` +//! +//! Four things to note with the implementation: +//! +//! 1. It defined an assocated error type, and all functions are faillible, +//! and use that error type +//! 2. Lookups return an `Result, Self::Error>`, because 'not found' +//! errors are usually cases that are handled differently +//! 3. Operations that need to record the current type use a [`Clock`] +//! parameter. Operations that need to generate new IDs also use a random +//! number generator. +//! 4. All the methods use an `&mut self`. This is ensures only one operation +//! is done at a time on a single repository instance. +//! +//! Then update the [`RepositoryAccess`] trait to make the new repository +//! available: +//! +//! ```rust +//! # trait FakeDataRepository { +//! # type Error; +//! # } +//! +//! /// Access the various repositories the backend implements. +//! pub trait RepositoryAccess: Send { +//! /// The backend-specific error type used by each repository. +//! type Error: std::error::Error + Send + Sync + 'static; +//! +//! // ...other repositories... +//! +//! /// Get a [`FakeDataRepository`] +//! fn fake_data<'c>(&'c mut self) -> Box + 'c>; +//! } +//! ``` #![forbid(unsafe_code)] #![deny( @@ -25,11 +144,10 @@ #![warn(clippy::pedantic)] #![allow(clippy::module_name_repetitions)] -use rand_core::CryptoRngCore; - pub mod clock; pub mod pagination; pub(crate) mod repository; +mod utils; pub mod compat; pub mod oauth2; @@ -42,66 +160,5 @@ pub use self::{ repository::{ BoxRepository, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction, }, + utils::{BoxClock, BoxRng, MapErr}, }; - -/// A wrapper which is used to map the error type of a repository to another -pub struct MapErr { - inner: R, - mapper: F, -} - -impl MapErr { - fn new(inner: R, mapper: F) -> Self { - Self { inner, mapper } - } -} - -/// A macro to implement a repository trait for the [`MapErr`] wrapper and for -/// [`Box`] -#[macro_export] -macro_rules! repository_impl { - ($repo_trait:ident: - $( - async fn $method:ident ( - &mut self - $(, $arg:ident: $arg_ty:ty )* - $(,)? - ) -> Result<$ret_ty:ty, Self::Error>; - )* - ) => { - #[::async_trait::async_trait] - impl $repo_trait for ::std::boxed::Box - where - R: $repo_trait, - { - type Error = ::Error; - - $( - async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> { - (**self).$method ( $($arg),* ).await - } - )* - } - - #[::async_trait::async_trait] - impl $repo_trait for $crate::MapErr - where - R: $repo_trait, - F: FnMut(::Error) -> E + ::std::marker::Send + ::std::marker::Sync, - { - type Error = E; - - $( - async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> { - self.inner.$method ( $($arg),* ).await.map_err(&mut self.mapper) - } - )* - } - }; -} - -/// A boxed [`Clock`] -pub type BoxClock = Box; - -/// A boxed random number generator -pub type BoxRng = Box; diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index f023e469b..c76e98665 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -101,6 +101,20 @@ pub trait RepositoryTransaction { } /// Access the various repositories the backend implements. +/// +/// All the methods return a boxed trait object, which can be used to access a +/// particular repository. The lifetime of the returned object is bound to the +/// lifetime of the whole repository, so that only one mutable reference to the +/// repository is used at a time. +/// +/// When adding a new repository, you should add a new method to this trait, and +/// update the implementations for [`MapErr`] and [`Box`] below. +/// +/// Note: this used to have generic associated types to avoid boxing all the +/// repository traits, but that was removed because it made almost impossible to +/// box the trait object. This might be a shortcoming of the initial +/// implementation of generic associated types, and might be fixed in the +/// future. pub trait RepositoryAccess: Send { /// The backend-specific error type used by each repository. type Error: std::error::Error + Send + Sync + 'static; diff --git a/crates/storage/src/utils.rs b/crates/storage/src/utils.rs new file mode 100644 index 000000000..44caa23ea --- /dev/null +++ b/crates/storage/src/utils.rs @@ -0,0 +1,86 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Wrappers and useful type aliases + +use rand_core::CryptoRngCore; + +use crate::Clock; + +/// A wrapper which is used to map the error type of a repository to another +pub struct MapErr { + pub(crate) inner: R, + pub(crate) mapper: F, + _private: (), +} + +impl MapErr { + pub(crate) fn new(inner: R, mapper: F) -> Self { + Self { + inner, + mapper, + _private: (), + } + } +} + +/// A boxed [`Clock`] +pub type BoxClock = Box; + +/// A boxed random number generator +pub type BoxRng = Box; + +/// A macro to implement a repository trait for the [`MapErr`] wrapper and for +/// [`Box`] +#[macro_export] +macro_rules! repository_impl { + ($repo_trait:ident: + $( + async fn $method:ident ( + &mut self + $(, $arg:ident: $arg_ty:ty )* + $(,)? + ) -> Result<$ret_ty:ty, Self::Error>; + )* + ) => { + #[::async_trait::async_trait] + impl $repo_trait for ::std::boxed::Box + where + R: $repo_trait, + { + type Error = ::Error; + + $( + async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> { + (**self).$method ( $($arg),* ).await + } + )* + } + + #[::async_trait::async_trait] + impl $repo_trait for $crate::MapErr + where + R: $repo_trait, + F: FnMut(::Error) -> E + ::std::marker::Send + ::std::marker::Sync, + { + type Error = E; + + $( + async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> { + self.inner.$method ( $($arg),* ).await.map_err(&mut self.mapper) + } + )* + } + }; +} diff --git a/docs/development/database.md b/docs/development/database.md index 513fe7eff..80c6aef21 100644 --- a/docs/development/database.md +++ b/docs/development/database.md @@ -3,6 +3,21 @@ Interactions with the database goes through `sqlx`. It provides async database operations with connection pooling, migrations support and compile-time check of queries through macros. +## Writing database interactions + +All database interactions are done through repositoriy traits. Each repository trait usually manages one type of data, defined in the `mas-data-model` crate. + +Defining a new data type and associated repository looks like this: + + - Define new structs in `mas-data-model` crate + - Define the repository trait in `mas-storage` crate + - Make that repository trait available via the `RepositoryAccess` trait in `mas-storage` crate + - Setup the database schema by writing a migration file in `mas-storage-pg` crate + - Implement the new repository trait in `mas-storage-pg` crate + - Write tests for the PostgreSQL implementation in `mas-storage-pg` crate + +Some of those steps are documented in more details in the `mas-storage` and `mas-storage-pg` crates. + ## Compile-time check of queries To be able to check queries, `sqlx` has to introspect the live database. @@ -14,7 +29,7 @@ Preparing this flat file is done through `sqlx-cli`, and should be done everytim # Install the CLI cargo install sqlx-cli --no-default-features --features postgres -cd crates/storage/ # Must be in the mas-storage crate folder +cd crates/storage-pg/ # Must be in the mas-storage-pg crate folder export DATABASE_URL=postgresql:///matrix_auth cargo sqlx prepare ``` @@ -24,75 +39,10 @@ cargo sqlx prepare Migration files live in the `migrations` folder in the `mas-core` crate. ```sh -cd crates/storage/ # Again, in the mas-storage crate folder +cd crates/storage-pg/ # Again, in the mas-storage-pg crate folder export DATABASE_URL=postgresql:///matrix_auth cargo sqlx migrate run # Run pending migrations -cargo sqlx migrate revert # Revert the last migration -cargo sqlx migrate add -r [description] # Add new migration files +cargo sqlx migrate add [description] # Add new migration files ``` Note that migrations are embedded in the final binary and can be run from the service CLI tool. - -## Writing database interactions - -**TODO**: *This section is outdated.* - -A typical interaction with the database look like this: - -```rust -pub async fn lookup_session( - executor: impl Executor<'_, Database = Postgres>, - id: i64, -) -> anyhow::Result { - sqlx::query_as!( - SessionInfo, // Struct that will be filled with the result - r#" - SELECT - s.id, - u.id as user_id, - u.username, - s.active, - s.created_at, - a.created_at as "last_authd_at?" - FROM user_sessions s - INNER JOIN users u - ON s.user_id = u.id - LEFT JOIN user_session_authentications a - ON a.session_id = s.id - WHERE s.id = $1 - ORDER BY a.created_at DESC - LIMIT 1 - "#, - id, // Query parameter - ) - .fetch_one(executor) - .await - // Providing some context when there is an error - .context("could not fetch session") -} -``` - -Note that we pass an `impl Executor` as parameter here. -This allows us to use this function from either a simple connection or from an active transaction. - -The caveat here is that the `executor` can be used only once, so if an interaction needs to do multiple queries, it should probably take an `impl Acquire` to then acquire a transaction and do multiple interactions. - -```rust -pub async fn login( - conn: impl Acquire<'_, Database = Postgres>, - username: &str, - password: String, -) -> Result { - let mut txn = conn.begin().await.context("could not start transaction")?; - // First interaction - let user = lookup_user_by_username(&mut txn, username)?; - // Second interaction - let mut session = start_session(&mut txn, user).await?; - // Third interaction - session.last_authd_at = - Some(authenticate_session(&mut txn, session.id, password).await?); - // Commit the transaction once everything went fine - txn.commit().await.context("could not commit transaction")?; - Ok(session) -} -``` From fec5d20eee008db39d401a2a58f25eb098192a07 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 27 Jan 2023 10:30:59 +0100 Subject: [PATCH 41/45] axum-utils: remove an unnecessary ?Sized bound --- crates/axum-utils/src/client_authorization.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index 6f5b3e270..8dff5cdfe 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -74,7 +74,7 @@ pub enum Credentials { impl Credentials { pub async fn fetch( &self, - repo: &mut (impl RepositoryAccess + ?Sized), + repo: &mut impl RepositoryAccess, ) -> Result, E> { let client_id = match self { Credentials::None { client_id } From f2bc613d5c1967bcbafcf1a8ef8f3074e3e75d61 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 27 Jan 2023 10:47:27 +0100 Subject: [PATCH 42/45] ci: publish docs without pushing to the gh-pages branch --- .github/workflows/docs.yaml | 44 ++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index e78a94ef2..dd4d8e74f 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -1,30 +1,48 @@ -name: Deploy the documentation +name: Build and deploy the documentation on: push: - branches: - - main + branches: [ main ] + pull_request: + branches: [ main ] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true jobs: - pages: - name: GitHub Pages + build: + name: Build the documentation runs-on: ubuntu-latest steps: - name: Checkout the code uses: actions/checkout@v3 - name: Setup mdBook - uses: peaceiris/actions-mdbook@adeb05db28a0c0004681db83893d56c0388ea9ea # v1.1.14 + uses: peaceiris/actions-mdbook@v1.2.0 with: - mdbook-version: '0.4.12' + mdbook-version: '0.4.25' - name: Build the documentation run: mdbook build + - name: Upload GitHub Pages artifacts + uses: actions/upload-pages-artifact@v1.0.7 + + deploy: + name: Deploy the documentation on GitHub Pages + runs-on: ubuntu-latest + needs: build + + permissions: + pages: write + id-token: write + + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + + steps: - name: Deploy to GitHub Pages - uses: peaceiris/actions-gh-pages@de7ea6f8efb354206b205ef54722213d99067935 # v3.8.0 - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - publish_dir: ./target/book - - + id: deployment + uses: actions/deploy-pages@v1.2.3 From f1536b35e38475bac0c4eb1a12d758412d1acbc8 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 27 Jan 2023 11:18:11 +0100 Subject: [PATCH 43/45] ci: fix the docs build --- .github/workflows/docs.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 58f8c0ac8..483466135 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -28,6 +28,8 @@ jobs: - name: Upload GitHub Pages artifacts uses: actions/upload-pages-artifact@v1.0.7 + with: + path: target/book/ deploy: name: Deploy the documentation on GitHub Pages From 855178a6132515affb1c5acf06227590004241d6 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 27 Jan 2023 11:19:48 +0100 Subject: [PATCH 44/45] ci: deploy docs only on push to master --- .github/workflows/docs.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 483466135..9b9a36e37 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -35,6 +35,7 @@ jobs: name: Deploy the documentation on GitHub Pages runs-on: ubuntu-latest needs: build + if: github.ref == 'refs/heads/main' permissions: pages: write From 8b1f64d7937acde92bfe3df20d54c5b272836690 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 27 Jan 2023 11:40:19 +0100 Subject: [PATCH 45/45] docs: link to rustdoc pages from the mdbook --- .github/workflows/docs.yaml | 18 ++++++++++++++++ docs/development/architecture.md | 36 ++++++++++++++++++++++---------- docs/development/database.md | 20 +++++++++++------- 3 files changed, 55 insertions(+), 19 deletions(-) diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 9b9a36e37..125f1f315 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -10,6 +10,10 @@ concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true +env: + CARGO_TERM_COLOR: always + CARGO_NET_GIT_FETCH_WITH_CLI: "true" + jobs: build: name: Build the documentation @@ -18,6 +22,14 @@ jobs: - name: Checkout the code uses: actions/checkout@v3 + - name: Install Rust toolchain + run: | + rustup toolchain install nightly + rustup default nightly + + - name: Setup Rust cache + uses: Swatinem/rust-cache@v2 + - name: Setup mdBook uses: peaceiris/actions-mdbook@v1.2.0 with: @@ -26,6 +38,12 @@ jobs: - name: Build the documentation run: mdbook build + - name: Build rustdoc + run: cargo doc -Zrustdoc-map --workspace --lib --no-deps + + - name: Move the Rust documentation within the mdBook + run: mv target/doc target/book/rustdoc + - name: Upload GitHub Pages artifacts uses: actions/upload-pages-artifact@v1.0.7 with: diff --git a/docs/development/architecture.md b/docs/development/architecture.md index f15b31203..691c881f4 100644 --- a/docs/development/architecture.md +++ b/docs/development/architecture.md @@ -10,17 +10,31 @@ The whole repository is a [Cargo Workspace](https://doc.rust-lang.org/book/ch14- This includes: - `mas-cli`: Command line utility, main entry point - - `mas-config`: Configuration parsing and loading - - `mas-data-model`: Models of objects that live in the database, regardless of the storage backend - - `mas-email`: High-level email sending abstraction - - `mas-handlers`: Main HTTP application logic - - `mas-iana`: Auto-generated enums from IANA registries - - `mas-iana-codegen`: Code generator for the `mas-iana` crate - - `mas-jose`: JWT/JWS/JWE/JWK abstraction - - `mas-static-files`: Frontend static files (CSS/JS). Includes some frontend tooling - - `mas-storage`: Interactions with the database - - `mas-tasks`: Asynchronous task runner and scheduler - - `oauth2-types`: Useful structures and types to deal with OAuth 2.0/OpenID Connect endpoints. This might end up published as a standalone library as it can be useful in other contexts. + - [`mas-config`][mas-config]: Configuration parsing and loading + - [`mas-data-model`][mas-data-model]: Models of objects that live in the database, regardless of the storage backend + - [`mas-email`][mas-email]: High-level email sending abstraction + - [`mas-handlers`][mas-handlers]: Main HTTP application logic + - [`mas-iana`][mas-iana]: Auto-generated enums from IANA registries + - [`mas-iana-codegen`][mas-iana-codegen]: Code generator for the `mas-iana` crate + - [`mas-jose`][mas-jose]: JWT/JWS/JWE/JWK abstraction + - [`mas-static-files`][mas-static-files]: Frontend static files (CSS/JS). Includes some frontend tooling + - [`mas-storage`][mas-storage]: Abstraction of the storage backends + - [`mas-storage-pg`][mas-storage-pg]: Storage backend implementation for a PostgreSQL database + - [`mas-tasks`][mas-tasks]: Asynchronous task runner and scheduler + - [`oauth2-types`][oauth2-types]: Useful structures and types to deal with OAuth 2.0/OpenID Connect endpoints. This might end up published as a standalone library as it can be useful in other contexts. + +[mas-config]: ../rustdoc/mas_config/index.html +[mas-data-model]: ../rustdoc/mas_data_model/index.html +[mas-email]: ../rustdoc/mas_email/index.html +[mas-handlers]: ../rustdoc/mas_handlers/index.html +[mas-iana]: ../rustdoc/mas_iana/index.html +[mas-iana-codegen]: ../rustdoc/mas_iana_codegen/index.html +[mas-jose]: ../rustdoc/mas_jose/index.html +[mas-static-files]: ../rustdoc/mas_static_files/index.html +[mas-storage]: ../rustdoc/mas_storage/index.html +[mas-storage-pg]: ../rustdoc/mas_storage/index.html +[mas-tasks]: ../rustdoc/mas_tasks/index.html +[oauth2-types]: ../rustdoc/oauth2_types/index.html ## Important crates diff --git a/docs/development/database.md b/docs/development/database.md index 80c6aef21..5ffe8a5a8 100644 --- a/docs/development/database.md +++ b/docs/development/database.md @@ -5,18 +5,22 @@ It provides async database operations with connection pooling, migrations suppor ## Writing database interactions -All database interactions are done through repositoriy traits. Each repository trait usually manages one type of data, defined in the `mas-data-model` crate. +All database interactions are done through repositoriy traits. Each repository trait usually manages one type of data, defined in the [`mas-data-model`][mas-data-model] crate. Defining a new data type and associated repository looks like this: - - Define new structs in `mas-data-model` crate - - Define the repository trait in `mas-storage` crate - - Make that repository trait available via the `RepositoryAccess` trait in `mas-storage` crate - - Setup the database schema by writing a migration file in `mas-storage-pg` crate - - Implement the new repository trait in `mas-storage-pg` crate - - Write tests for the PostgreSQL implementation in `mas-storage-pg` crate + - Define new structs in [`mas-data-model`][mas-data-model] crate + - Define the repository trait in [`mas-storage`][mas-storage] crate + - Make that repository trait available via the `RepositoryAccess` trait in [`mas-storage`][mas-storage] crate + - Setup the database schema by writing a migration file in [`mas-storage-pg`][mas-storage-pg] crate + - Implement the new repository trait in [`mas-storage-pg`][mas-storage-pg] crate + - Write tests for the PostgreSQL implementation in [`mas-storage-pg`][mas-storage-pg] crate -Some of those steps are documented in more details in the `mas-storage` and `mas-storage-pg` crates. +Some of those steps are documented in more details in the [`mas-storage`][mas-storage] and [`mas-storage-pg`][mas-storage-pg] crates. + +[mas-data-model]: ../rustdoc/mas_data_model/index.html +[mas-storage]: ../rustdoc/mas_storage/index.html +[mas-storage-pg]: ../rustdoc/mas_storage_pg/index.html ## Compile-time check of queries