diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index a5db4ea2f..125f1f315 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -1,30 +1,69 @@ -name: Deploy the documentation +name: Build and deploy the documentation on: push: - branches: - - main + branches: [ main ] + pull_request: + branches: [ main ] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + CARGO_TERM_COLOR: always + CARGO_NET_GIT_FETCH_WITH_CLI: "true" jobs: - pages: - name: GitHub Pages + build: + name: Build the documentation runs-on: ubuntu-latest steps: - name: Checkout the code uses: actions/checkout@v3 + - name: Install Rust toolchain + run: | + rustup toolchain install nightly + rustup default nightly + + - name: Setup Rust cache + uses: Swatinem/rust-cache@v2 + - name: Setup mdBook - uses: peaceiris/actions-mdbook@adeb05db28a0c0004681db83893d56c0388ea9ea # v1.1.14 + uses: peaceiris/actions-mdbook@v1.2.0 with: - mdbook-version: '0.4.12' + mdbook-version: '0.4.25' - name: Build the documentation run: mdbook build - - - name: Deploy to GitHub Pages - uses: peaceiris/actions-gh-pages@64b46b4226a4a12da2239ba3ea5aa73e3163c75b # v3.8.0 + + - name: Build rustdoc + run: cargo doc -Zrustdoc-map --workspace --lib --no-deps + + - name: Move the Rust documentation within the mdBook + run: mv target/doc target/book/rustdoc + + - name: Upload GitHub Pages artifacts + uses: actions/upload-pages-artifact@v1.0.7 with: - github_token: ${{ secrets.GITHUB_TOKEN }} - publish_dir: ./target/book + path: target/book/ + deploy: + name: Deploy the documentation on GitHub Pages + runs-on: ubuntu-latest + needs: build + if: github.ref == 'refs/heads/main' + permissions: + pages: write + id-token: write + + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v1.2.3 diff --git a/Cargo.lock b/Cargo.lock index 2bf7dc92e..7ef9d0cc7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2688,7 +2688,6 @@ dependencies = [ "serde_json", "serde_urlencoded", "serde_with", - "sqlx", "thiserror", "tokio", "tower", @@ -2721,6 +2720,7 @@ dependencies = [ "mas-router", "mas-spa", "mas-storage", + "mas-storage-pg", "mas-tasks", "mas-templates", "oauth2-types", @@ -2821,8 +2821,8 @@ dependencies = [ "mas-storage", "oauth2-types", "serde", - "sqlx", "thiserror", + "tokio", "tracing", "ulid", "url", @@ -2859,6 +2859,7 @@ dependencies = [ "mas-policy", "mas-router", "mas-storage", + "mas-storage-pg", "mas-templates", "mime", "oauth2-types", @@ -3112,12 +3113,33 @@ dependencies = [ name = "mas-storage" version = "0.1.0" dependencies = [ + "async-trait", "chrono", + "futures-util", "mas-data-model", "mas-iana", "mas-jose", "oauth2-types", + "rand_core 0.6.4", + "thiserror", + "ulid", + "url", +] + +[[package]] +name = "mas-storage-pg" +version = "0.1.0" +dependencies = [ + "async-trait", + "chrono", + "futures-util", + "mas-data-model", + "mas-iana", + "mas-jose", + "mas-storage", + "oauth2-types", "rand 0.8.5", + "rand_chacha 0.3.1", "serde", "serde_json", "sqlx", @@ -3135,6 +3157,7 @@ dependencies = [ "async-trait", "futures-util", "mas-storage", + "mas-storage-pg", "sqlx", "tokio", "tokio-stream", @@ -5590,8 +5613,7 @@ checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81" [[package]] name = "ulid" version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13a3aaa69b04e5b66cc27309710a569ea23593612387d67daaf102e73aa974fd" +source = "git+https://github.com/dylanhart/ulid-rs.git?rev=0b9295c2db2114cd87aa19abcc1fc00c16b272db#0b9295c2db2114cd87aa19abcc1fc00c16b272db" dependencies = [ "rand 0.8.5", "serde", diff --git a/Cargo.toml b/Cargo.toml index f621be0cc..644ecbf4c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,3 +7,8 @@ opt-level = 3 [profile.dev.package.sqlx-macros] opt-level = 3 + +# Until https://github.com/dylanhart/ulid-rs/pull/56 gets released +[patch.crates-io.ulid] +git = "https://github.com/dylanhart/ulid-rs.git" +rev = "0b9295c2db2114cd87aa19abcc1fc00c16b272db" diff --git a/clippy.toml b/clippy.toml index 61c5c04f2..a300ae05b 100644 --- a/clippy.toml +++ b/clippy.toml @@ -1,4 +1,4 @@ -doc-valid-idents = ["OpenID", "OAuth", ".."] +doc-valid-idents = ["OpenID", "OAuth", "..", "PostgreSQL"] disallowed-methods = [ { path = "rand::thread_rng", reason = "do not create rngs on the fly, pass them as parameters" }, diff --git a/crates/axum-utils/Cargo.toml b/crates/axum-utils/Cargo.toml index ac16c1407..24487a019 100644 --- a/crates/axum-utils/Cargo.toml +++ b/crates/axum-utils/Cargo.toml @@ -21,7 +21,6 @@ serde = "1.0.152" serde_with = "2.1.0" serde_urlencoded = "0.7.1" serde_json = "1.0.91" -sqlx = "0.6.2" thiserror = "1.0.38" tokio = "1.24.1" tower = { version = "0.4.13", features = ["util"] } diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index 00baf4a52..8dff5cdfe 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -31,10 +31,9 @@ use mas_http::HttpServiceExt; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt}; use mas_keystore::Encrypter; -use mas_storage::{oauth2::client::lookup_client_by_client_id, DatabaseError}; +use mas_storage::{oauth2::OAuth2ClientRepository, RepositoryAccess}; use serde::{de::DeserializeOwned, Deserialize}; use serde_json::Value; -use sqlx::PgExecutor; use thiserror::Error; use tower::{Service, ServiceExt}; @@ -73,10 +72,10 @@ pub enum Credentials { } impl Credentials { - pub async fn fetch( + pub async fn fetch( &self, - executor: impl PgExecutor<'_>, - ) -> Result, DatabaseError> { + repo: &mut impl RepositoryAccess, + ) -> Result, E> { let client_id = match self { Credentials::None { client_id } | Credentials::ClientSecretBasic { client_id, .. } @@ -84,7 +83,7 @@ impl Credentials { | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id, }; - lookup_client_by_client_id(executor, client_id).await + repo.oauth2_client().find_by_client_id(client_id).await } #[tracing::instrument(skip_all, err)] diff --git a/crates/axum-utils/src/csrf.rs b/crates/axum-utils/src/csrf.rs index e7037d55a..9886fedb1 100644 --- a/crates/axum-utils/src/csrf.rs +++ b/crates/axum-utils/src/csrf.rs @@ -15,6 +15,7 @@ use axum_extra::extract::cookie::{Cookie, PrivateCookieJar}; use chrono::{DateTime, Duration, Utc}; use data_encoding::{DecodeError, BASE64URL_NOPAD}; +use mas_storage::Clock; use rand::{Rng, RngCore}; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, TimestampSeconds}; @@ -108,22 +109,27 @@ pub struct ProtectedForm { } pub trait CsrfExt { - fn csrf_token(self, now: DateTime, rng: R) -> (CsrfToken, Self) + fn csrf_token(self, clock: &C, rng: R) -> (CsrfToken, Self) where - R: RngCore; - fn verify_form(&self, now: DateTime, form: ProtectedForm) -> Result; + R: RngCore, + C: Clock; + fn verify_form(&self, clock: &C, form: ProtectedForm) -> Result + where + C: Clock; } impl CsrfExt for PrivateCookieJar { - fn csrf_token(self, now: DateTime, rng: R) -> (CsrfToken, Self) + fn csrf_token(self, clock: &C, rng: R) -> (CsrfToken, Self) where R: RngCore, + C: Clock, { let jar = self; let mut cookie = jar.get("csrf").unwrap_or_else(|| Cookie::new("csrf", "")); cookie.set_path("/"); cookie.set_http_only(true); + let now = clock.now(); let new_token = cookie .decode() .ok() @@ -136,10 +142,13 @@ impl CsrfExt for PrivateCookieJar { (new_token, jar) } - fn verify_form(&self, now: DateTime, form: ProtectedForm) -> Result { + fn verify_form(&self, clock: &C, form: ProtectedForm) -> Result + where + C: Clock, + { let cookie = self.get("csrf").ok_or(CsrfError::Missing)?; let token: CsrfToken = cookie.decode()?; - let token = token.verify_expiration(now)?; + let token = token.verify_expiration(clock.now())?; token.verify_form_value(&form.csrf)?; Ok(form.inner) } diff --git a/crates/axum-utils/src/http_client_factory.rs b/crates/axum-utils/src/http_client_factory.rs index 25f5b1558..6eb5f9406 100644 --- a/crates/axum-utils/src/http_client_factory.rs +++ b/crates/axum-utils/src/http_client_factory.rs @@ -56,7 +56,7 @@ impl HttpClientFactory { Ok(layer.layer(client)) } - /// Constructs a new [`HttpService`], suitable for [`mas_oidc_client`] + /// Constructs a new [`HttpService`], suitable for `mas-oidc-client` /// /// # Errors /// diff --git a/crates/axum-utils/src/session.rs b/crates/axum-utils/src/session.rs index a63c22668..c4fece7b0 100644 --- a/crates/axum-utils/src/session.rs +++ b/crates/axum-utils/src/session.rs @@ -14,9 +14,8 @@ use axum_extra::extract::cookie::{Cookie, PrivateCookieJar}; use mas_data_model::BrowserSession; -use mas_storage::{user::lookup_active_session, DatabaseError}; +use mas_storage::{user::BrowserSessionRepository, RepositoryAccess}; use serde::{Deserialize, Serialize}; -use sqlx::{Executor, Postgres}; use ulid::Ulid; use crate::CookieExt; @@ -44,18 +43,24 @@ impl SessionInfo { } /// Load the [`BrowserSession`] from database - pub async fn load_session( + pub async fn load_session( &self, - executor: impl Executor<'_, Database = Postgres>, - ) -> Result, DatabaseError> { + repo: &mut impl RepositoryAccess, + ) -> Result, E> { let session_id = if let Some(id) = self.current { id } else { return Ok(None); }; - let res = lookup_active_session(executor, session_id).await?; - Ok(res) + let maybe_session = repo + .browser_session() + .lookup(session_id) + .await? + // Ensure that the session is still active + .filter(BrowserSession::active); + + Ok(maybe_session) } } diff --git a/crates/axum-utils/src/user_authorization.rs b/crates/axum-utils/src/user_authorization.rs index 923ef34d7..c9bc537c1 100644 --- a/crates/axum-utils/src/user_authorization.rs +++ b/crates/axum-utils/src/user_authorization.rs @@ -27,9 +27,11 @@ use axum::{ use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName}; use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode}; use mas_data_model::Session; -use mas_storage::{oauth2::access_token::lookup_active_access_token, DatabaseError}; +use mas_storage::{ + oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository}, + Clock, RepositoryAccess, +}; use serde::{de::DeserializeOwned, Deserialize}; -use sqlx::PgConnection; use thiserror::Error; #[derive(Debug, Deserialize)] @@ -49,16 +51,24 @@ enum AccessToken { } impl AccessToken { - pub async fn fetch( + async fn fetch( &self, - conn: &mut PgConnection, - ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> { + repo: &mut impl RepositoryAccess, + ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> { let token = match self { AccessToken::Form(t) | AccessToken::Header(t) => t, AccessToken::None => return Err(AuthorizationVerificationError::MissingToken), }; - let (token, session) = lookup_active_access_token(conn, token.as_str()) + let token = repo + .oauth2_access_token() + .find_by_token(token.as_str()) + .await? + .ok_or(AuthorizationVerificationError::InvalidToken)?; + + let session = repo + .oauth2_session() + .lookup(token.session_id) .await? .ok_or(AuthorizationVerificationError::InvalidToken)?; @@ -74,26 +84,36 @@ pub struct UserAuthorization { impl UserAuthorization { // TODO: take scopes to validate as parameter - pub async fn protected_form( + pub async fn protected_form( self, - conn: &mut PgConnection, - ) -> Result<(Session, F), AuthorizationVerificationError> { + repo: &mut impl RepositoryAccess, + clock: &impl Clock, + ) -> Result<(Session, F), AuthorizationVerificationError> { let form = match self.form { Some(f) => f, None => return Err(AuthorizationVerificationError::MissingForm), }; - let (_token, session) = self.access_token.fetch(conn).await?; + let (token, session) = self.access_token.fetch(repo).await?; + + if !token.is_valid(clock.now()) || !session.is_valid() { + return Err(AuthorizationVerificationError::InvalidToken); + } Ok((session, form)) } // TODO: take scopes to validate as parameter - pub async fn protected( + pub async fn protected( self, - conn: &mut PgConnection, - ) -> Result { - let (_token, session) = self.access_token.fetch(conn).await?; + repo: &mut impl RepositoryAccess, + clock: &impl Clock, + ) -> Result> { + let (token, session) = self.access_token.fetch(repo).await?; + + if !token.is_valid(clock.now()) || !session.is_valid() { + return Err(AuthorizationVerificationError::InvalidToken); + } Ok(session) } @@ -107,7 +127,7 @@ pub enum UserAuthorizationError { } #[derive(Debug, Error)] -pub enum AuthorizationVerificationError { +pub enum AuthorizationVerificationError { #[error("missing token")] MissingToken, @@ -118,7 +138,7 @@ pub enum AuthorizationVerificationError { MissingForm, #[error(transparent)] - Internal(#[from] DatabaseError), + Internal(#[from] E), } enum BearerError { @@ -226,7 +246,10 @@ impl IntoResponse for UserAuthorizationError { } } -impl IntoResponse for AuthorizationVerificationError { +impl IntoResponse for AuthorizationVerificationError +where + E: ToString, +{ fn into_response(self) -> Response { match self { Self::MissingForm | Self::MissingToken => { diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 1c8136cdf..1ae92c18c 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -50,6 +50,7 @@ mas-policy = { path = "../policy" } mas-router = { path = "../router" } mas-spa = { path = "../spa" } mas-storage = { path = "../storage" } +mas-storage-pg = { path = "../storage-pg" } mas-tasks = { path = "../tasks" } mas-templates = { path = "../templates" } oauth2-types = { path = "../oauth2-types" } diff --git a/crates/cli/src/commands/config.rs b/crates/cli/src/commands/config.rs index a1d6376e0..f8e3aaa54 100644 --- a/crates/cli/src/commands/config.rs +++ b/crates/cli/src/commands/config.rs @@ -15,7 +15,7 @@ use clap::Parser; use mas_config::{ConfigurationSection, RootConfig}; use rand::SeedableRng; -use tracing::info; +use tracing::{info, info_span}; #[derive(Parser, Debug)] pub(super) struct Options { @@ -40,6 +40,8 @@ impl Options { use Subcommand as SC; match &self.subcommand { SC::Dump => { + let _span = info_span!("cli.config.dump").entered(); + let config: RootConfig = root.load_config()?; serde_yaml::to_writer(std::io::stdout(), &config)?; @@ -47,11 +49,15 @@ impl Options { Ok(()) } SC::Check => { + let _span = info_span!("cli.config.check").entered(); + let _config: RootConfig = root.load_config()?; info!(path = ?root.config, "Configuration file looks good"); Ok(()) } SC::Generate => { + let _span = info_span!("cli.config.generate").entered(); + // XXX: we should disallow SeedableRng::from_entropy let rng = rand_chacha::ChaChaRng::from_entropy(); let config = RootConfig::load_and_generate(rng).await?; diff --git a/crates/cli/src/commands/database.rs b/crates/cli/src/commands/database.rs index 338fdbf91..0e4d68af6 100644 --- a/crates/cli/src/commands/database.rs +++ b/crates/cli/src/commands/database.rs @@ -15,7 +15,8 @@ use anyhow::Context; use clap::Parser; use mas_config::DatabaseConfig; -use mas_storage::MIGRATOR; +use mas_storage_pg::MIGRATOR; +use tracing::{info_span, Instrument}; use crate::util::database_from_config; @@ -33,12 +34,14 @@ enum Subcommand { impl Options { pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> { + let _span = info_span!("cli.database.migrate").entered(); let config: DatabaseConfig = root.load_config()?; let pool = database_from_config(&config).await?; // Run pending migrations MIGRATOR .run(&pool) + .instrument(info_span!("db.migrate")) .await .context("could not run migrations")?; diff --git a/crates/cli/src/commands/debug.rs b/crates/cli/src/commands/debug.rs index 7b2d6cd8f..24f9262ce 100644 --- a/crates/cli/src/commands/debug.rs +++ b/crates/cli/src/commands/debug.rs @@ -19,7 +19,7 @@ use mas_handlers::HttpClientFactory; use mas_http::HttpServiceExt; use tokio::io::AsyncWriteExt; use tower::{Service, ServiceExt}; -use tracing::info; +use tracing::{info, info_span}; use crate::util::policy_factory_from_config; @@ -74,6 +74,7 @@ impl Options { json: false, url, } => { + let _span = info_span!("cli.debug.http").entered(); let mut client = http_client_factory.client("cli-debug-http").await?; let request = hyper::Request::builder() .uri(url) @@ -98,6 +99,7 @@ impl Options { json: true, url, } => { + let _span = info_span!("cli.debug.http").entered(); let mut client = http_client_factory .client("cli-debug-http") .await? @@ -122,6 +124,7 @@ impl Options { } SC::Policy => { + let _span = info_span!("cli.debug.policy").entered(); let config: PolicyConfig = root.load_config()?; info!("Loading and compiling the policy module"); let policy_factory = policy_factory_from_config(&config).await?; diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 5472b78c7..b685a167d 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -18,15 +18,15 @@ use mas_config::{DatabaseConfig, PasswordsConfig, RootConfig}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_router::UrlBuilder; use mas_storage::{ - oauth2::client::{insert_client_from_config, lookup_client, truncate_clients}, - user::{ - add_user_password, lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, - }, - Clock, + oauth2::OAuth2ClientRepository, + upstream_oauth2::UpstreamOAuthProviderRepository, + user::{UserEmailRepository, UserPasswordRepository, UserRepository}, + Repository, RepositoryAccess, SystemClock, }; +use mas_storage_pg::PgRepository; use oauth2_types::scope::Scope; use rand::SeedableRng; -use tracing::{info, warn}; +use tracing::{info, info_span, warn}; use crate::util::{database_from_config, password_manager_from_config}; @@ -147,9 +147,9 @@ enum Subcommand { /// Import clients from config ImportClients { - /// Remove all clients before importing + /// Update existing clients #[arg(long)] - truncate: bool, + update: bool, }, /// Set a user password @@ -188,20 +188,25 @@ impl Options { #[allow(clippy::too_many_lines)] pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> { use Subcommand as SC; - let clock = Clock::default(); + let clock = SystemClock::default(); // XXX: we should disallow SeedableRng::from_entropy let mut rng = rand_chacha::ChaChaRng::from_entropy(); match &self.subcommand { SC::SetPassword { username, password } => { + let _span = + info_span!("cli.manage.set_password", user.username = %username).entered(); + let database_config: DatabaseConfig = root.load_config()?; let passwords_config: PasswordsConfig = root.load_config()?; let pool = database_from_config(&database_config).await?; let password_manager = password_manager_from_config(&passwords_config).await?; - let mut txn = pool.begin().await?; - let user = lookup_user_by_username(&mut txn, username) + let mut repo = PgRepository::from_pool(&pool).await?.boxed(); + let user = repo + .user() + .find_by_username(username) .await? .context("User not found")?; @@ -209,89 +214,96 @@ impl Options { let (version, hashed_password) = password_manager.hash(&mut rng, password).await?; - add_user_password( - &mut txn, - &mut rng, - &clock, - &user, - version, - hashed_password, - None, - ) - .await?; + repo.user_password() + .add(&mut rng, &clock, &user, version, hashed_password, None) + .await?; info!(%user.id, %user.username, "Password changed"); - txn.commit().await?; + repo.save().await?; Ok(()) } SC::VerifyEmail { username, email } => { + let _span = info_span!( + "cli.manage.verify_email", + user.username = username, + user_email.email = email + ) + .entered(); + let config: DatabaseConfig = root.load_config()?; let pool = database_from_config(&config).await?; - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?.boxed(); - let user = lookup_user_by_username(&mut txn, username) + let user = repo + .user() + .find_by_username(username) .await? .context("User not found")?; - let email = lookup_user_email(&mut txn, &user, email) + + let email = repo + .user_email() + .find(&user, email) .await? .context("Email not found")?; - let email = mark_user_email_as_verified(&mut txn, &clock, email).await?; + let email = repo.user_email().mark_as_verified(&clock, email).await?; - txn.commit().await?; + repo.save().await?; info!(?email, "Email marked as verified"); Ok(()) } - SC::ImportClients { truncate } => { + SC::ImportClients { update } => { + let _span = info_span!("cli.manage.import_clients").entered(); + let config: RootConfig = root.load_config()?; let pool = database_from_config(&config.database).await?; let encrypter = config.secrets.encrypter(); - let mut txn = pool.begin().await?; - - if *truncate { - warn!("Removing all clients first"); - truncate_clients(&mut txn).await?; - } + let mut repo = PgRepository::from_pool(&pool).await?.boxed(); for client in config.clients.iter() { let client_id = client.client_id; - let res = lookup_client(&mut txn, client_id).await?; - if res.is_some() { - warn!(%client_id, "Skipping already imported client"); + + let existing = repo.oauth2_client().lookup(client_id).await?.is_some(); + if !update && existing { + warn!(%client_id, "Skipping already imported client. Run with --update to update existing clients."); continue; } - info!(%client_id, "Importing client"); + if existing { + info!(%client_id, "Updating client"); + } else { + info!(%client_id, "Importing client"); + } + let client_secret = client.client_secret(); let client_auth_method = client.client_auth_method(); let jwks = client.jwks(); let jwks_uri = client.jwks_uri(); - let redirect_uris = &client.redirect_uris; // TODO: should be moved somewhere else let encrypted_client_secret = client_secret .map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes())) .transpose()?; - insert_client_from_config( - &mut txn, - &mut rng, - &clock, - client_id, - client_auth_method, - encrypted_client_secret.as_deref(), - jwks, - jwks_uri, - redirect_uris, - ) - .await?; + repo.oauth2_client() + .add_from_config( + &mut rng, + &clock, + client_id, + client_auth_method, + encrypted_client_secret, + jwks.cloned(), + jwks_uri.cloned(), + client.redirect_uris.clone(), + ) + .await?; } - txn.commit().await?; + repo.save().await?; Ok(()) } @@ -304,11 +316,18 @@ impl Options { client_secret, signing_alg, } => { + let _span = info_span!( + "cli.manage.add_oauth_upstream", + upstream_oauth_provider.issuer = issuer, + upstream_oauth_provider.client_id = client_id, + ) + .entered(); + let config: RootConfig = root.load_config()?; let encrypter = config.secrets.encrypter(); let pool = database_from_config(&config.database).await?; let url_builder = UrlBuilder::new(config.http.public_base); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let requires_client_secret = token_endpoint_auth_method.requires_client_secret(); @@ -329,18 +348,19 @@ impl Options { .map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes())) .transpose()?; - let provider = mas_storage::upstream_oauth2::add_provider( - &mut conn, - &mut rng, - &clock, - issuer.clone(), - scope.clone(), - token_endpoint_auth_method, - token_endpoint_signing_alg, - client_id.clone(), - encrypted_client_secret, - ) - .await?; + let provider = repo + .upstream_oauth_provider() + .add( + &mut rng, + &clock, + issuer.clone(), + scope.clone(), + token_endpoint_auth_method, + token_endpoint_signing_alg, + client_id.clone(), + encrypted_client_secret, + ) + .await?; let redirect_uri = url_builder.upstream_oauth_callback(provider.id); let auth_uri = url_builder.upstream_oauth_authorize(provider.id); diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 606547827..002309536 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -21,10 +21,10 @@ use mas_config::RootConfig; use mas_handlers::{AppState, HttpClientFactory, MatrixHomeserver}; use mas_listener::{server::Server, shutdown::ShutdownStream}; use mas_router::UrlBuilder; -use mas_storage::MIGRATOR; +use mas_storage_pg::MIGRATOR; use mas_tasks::TaskQueue; use tokio::signal::unix::SignalKind; -use tracing::{info, warn}; +use tracing::{info, info_span, warn, Instrument}; use crate::util::{ database_from_config, mailer_from_config, password_manager_from_config, @@ -45,6 +45,7 @@ pub(super) struct Options { impl Options { #[allow(clippy::too_many_lines)] pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> { + let span = info_span!("cli.run.init").entered(); let config: RootConfig = root.load_config()?; // Connect to the database @@ -55,6 +56,7 @@ impl Options { info!("Running pending migrations"); MIGRATOR .run(&pool) + .instrument(info_span!("db.migrate")) .await .context("could not run migrations")?; } @@ -100,7 +102,7 @@ impl Options { watch_templates(&templates).await?; } - let graphql_schema = mas_handlers::graphql_schema(&pool); + let graphql_schema = mas_handlers::graphql_schema(); // Maximum 50 outgoing HTTP requests at a time let http_client_factory = HttpClientFactory::new(50); @@ -186,6 +188,8 @@ impl Options { .with_signal(SignalKind::terminate())? .with_signal(SignalKind::interrupt())?; + span.exit(); + mas_listener::server::run_servers(servers, shutdown).await; Ok(()) diff --git a/crates/cli/src/commands/templates.rs b/crates/cli/src/commands/templates.rs index 186a097e9..6f09b7519 100644 --- a/crates/cli/src/commands/templates.rs +++ b/crates/cli/src/commands/templates.rs @@ -14,9 +14,10 @@ use camino::Utf8PathBuf; use clap::Parser; -use mas_storage::Clock; +use mas_storage::{Clock, SystemClock}; use mas_templates::Templates; use rand::SeedableRng; +use tracing::info_span; #[derive(Parser, Debug)] pub(super) struct Options { @@ -38,7 +39,9 @@ impl Options { use Subcommand as SC; match &self.subcommand { SC::Check { path } => { - let clock = Clock::default(); + let _span = info_span!("cli.templates.check").entered(); + + let clock = SystemClock::default(); // XXX: we should disallow SeedableRng::from_entropy let mut rng = rand_chacha::ChaChaRng::from_entropy(); let url_builder = mas_router::UrlBuilder::new("https://example.com/".parse()?); diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index ec525478f..b6485e951 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -110,6 +110,7 @@ pub async fn templates_from_config( Templates::load(config.path.clone(), url_builder.clone()).await } +#[tracing::instrument(name = "db.connect", skip_all, err(Debug))] pub async fn database_from_config(config: &DatabaseConfig) -> Result { let mut options = match &config.options { DatabaseConnectConfig::Uri { uri } => uri diff --git a/crates/config/src/sections/secrets.rs b/crates/config/src/sections/secrets.rs index 3a0f79a41..afe624f46 100644 --- a/crates/config/src/sections/secrets.rs +++ b/crates/config/src/sections/secrets.rs @@ -86,6 +86,7 @@ impl SecretsConfig { /// # Errors /// /// Returns an error when a key could not be imported + #[tracing::instrument(name = "secrets.load", skip_all, err(Debug))] pub async fn key_store(&self) -> anyhow::Result { let mut keys = Vec::with_capacity(self.keys.len()); for item in &self.keys { diff --git a/crates/data-model/Cargo.toml b/crates/data-model/Cargo.toml index 661e9bbcc..b96d7b83a 100644 --- a/crates/data-model/Cargo.toml +++ b/crates/data-model/Cargo.toml @@ -11,8 +11,8 @@ thiserror = "1.0.38" serde = "1.0.152" url = { version = "2.3.1", features = ["serde"] } crc = "3.0.0" +ulid = { version = "1.0.0", features = ["serde"] } rand = "0.8.5" -ulid = "1.0.0" rand_chacha = "0.3.1" mas-iana = { path = "../iana" } diff --git a/crates/data-model/src/compat.rs b/crates/data-model/src/compat/device.rs similarity index 62% rename from crates/data-model/src/compat.rs rename to crates/data-model/src/compat/device.rs index d6f772db3..eebfd9eda 100644 --- a/crates/data-model/src/compat.rs +++ b/crates/data-model/src/compat/device.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,18 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use chrono::{DateTime, Utc}; use oauth2_types::scope::ScopeToken; use rand::{ distributions::{Alphanumeric, DistString}, - Rng, + RngCore, }; use serde::Serialize; use thiserror::Error; -use ulid::Ulid; -use url::Url; - -use crate::User; static DEVICE_ID_LENGTH: usize = 10; @@ -53,7 +48,7 @@ impl Device { } /// Generate a random device ID - pub fn generate(rng: &mut R) -> Self { + pub fn generate(rng: &mut R) -> Self { let id: String = Alphanumeric.sample_string(rng, DEVICE_ID_LENGTH); Self { id } } @@ -81,50 +76,3 @@ impl TryFrom for Device { Ok(Self { id }) } } - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct CompatSession { - pub id: Ulid, - pub user: User, - pub device: Device, - pub created_at: DateTime, - pub finished_at: Option>, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CompatAccessToken { - pub id: Ulid, - pub token: String, - pub created_at: DateTime, - pub expires_at: Option>, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CompatRefreshToken { - pub id: Ulid, - pub token: String, - pub created_at: DateTime, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub enum CompatSsoLoginState { - Pending, - Fulfilled { - fulfilled_at: DateTime, - session: CompatSession, - }, - Exchanged { - fulfilled_at: DateTime, - exchanged_at: DateTime, - session: CompatSession, - }, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct CompatSsoLogin { - pub id: Ulid, - pub redirect_uri: Url, - pub login_token: String, - pub created_at: DateTime, - pub state: CompatSsoLoginState, -} diff --git a/crates/data-model/src/compat/mod.rs b/crates/data-model/src/compat/mod.rs new file mode 100644 index 000000000..d0e560c70 --- /dev/null +++ b/crates/data-model/src/compat/mod.rs @@ -0,0 +1,106 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use ulid::Ulid; + +mod device; +mod session; +mod sso_login; + +pub use self::{ + device::Device, + session::{CompatSession, CompatSessionState}, + sso_login::{CompatSsoLogin, CompatSsoLoginState}, +}; +use crate::InvalidTransitionError; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CompatAccessToken { + pub id: Ulid, + pub session_id: Ulid, + pub token: String, + pub created_at: DateTime, + pub expires_at: Option>, +} + +impl CompatAccessToken { + #[must_use] + pub fn is_valid(&self, now: DateTime) -> bool { + if let Some(expires_at) = self.expires_at { + expires_at > now + } else { + true + } + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub enum CompatRefreshTokenState { + #[default] + Valid, + Consumed { + consumed_at: DateTime, + }, +} + +impl CompatRefreshTokenState { + /// Returns `true` if the compat refresh token state is [`Valid`]. + /// + /// [`Valid`]: CompatRefreshTokenState::Valid + #[must_use] + pub fn is_valid(&self) -> bool { + matches!(self, Self::Valid) + } + + /// Returns `true` if the compat refresh token state is [`Consumed`]. + /// + /// [`Consumed`]: CompatRefreshTokenState::Consumed + #[must_use] + pub fn is_consumed(&self) -> bool { + matches!(self, Self::Consumed { .. }) + } + + pub fn consume(self, consumed_at: DateTime) -> Result { + match self { + Self::Valid => Ok(Self::Consumed { consumed_at }), + Self::Consumed { .. } => Err(InvalidTransitionError), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CompatRefreshToken { + pub id: Ulid, + pub state: CompatRefreshTokenState, + pub session_id: Ulid, + pub access_token_id: Ulid, + pub token: String, + pub created_at: DateTime, +} + +impl std::ops::Deref for CompatRefreshToken { + type Target = CompatRefreshTokenState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + +impl CompatRefreshToken { + pub fn consume(mut self, consumed_at: DateTime) -> Result { + self.state = self.state.consume(consumed_at)?; + Ok(self) + } +} diff --git a/crates/data-model/src/compat/session.rs b/crates/data-model/src/compat/session.rs new file mode 100644 index 000000000..1dbd07228 --- /dev/null +++ b/crates/data-model/src/compat/session.rs @@ -0,0 +1,86 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use serde::Serialize; +use ulid::Ulid; + +use super::Device; +use crate::InvalidTransitionError; + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)] +pub enum CompatSessionState { + #[default] + Valid, + Finished { + finished_at: DateTime, + }, +} + +impl CompatSessionState { + /// Returns `true` if the compta session state is [`Valid`]. + /// + /// [`Valid`]: CompatSessionState::Valid + #[must_use] + pub fn is_valid(&self) -> bool { + matches!(self, Self::Valid) + } + + /// Returns `true` if the compta session state is [`Finished`]. + /// + /// [`Finished`]: CompatSessionState::Finished + #[must_use] + pub fn is_finished(&self) -> bool { + matches!(self, Self::Finished { .. }) + } + + pub fn finish(self, finished_at: DateTime) -> Result { + match self { + Self::Valid => Ok(Self::Finished { finished_at }), + Self::Finished { .. } => Err(InvalidTransitionError), + } + } + + #[must_use] + pub fn finished_at(&self) -> Option> { + match self { + CompatSessionState::Valid => None, + CompatSessionState::Finished { finished_at } => Some(*finished_at), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct CompatSession { + pub id: Ulid, + pub state: CompatSessionState, + pub user_id: Ulid, + pub device: Device, + pub created_at: DateTime, +} + +impl std::ops::Deref for CompatSession { + type Target = CompatSessionState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + +impl CompatSession { + pub fn finish(mut self, finished_at: DateTime) -> Result { + self.state = self.state.finish(finished_at)?; + Ok(self) + } +} diff --git a/crates/data-model/src/compat/sso_login.rs b/crates/data-model/src/compat/sso_login.rs new file mode 100644 index 000000000..ccc7bb370 --- /dev/null +++ b/crates/data-model/src/compat/sso_login.rs @@ -0,0 +1,151 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use serde::Serialize; +use ulid::Ulid; +use url::Url; + +use super::CompatSession; +use crate::InvalidTransitionError; + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)] +pub enum CompatSsoLoginState { + #[default] + Pending, + Fulfilled { + fulfilled_at: DateTime, + session_id: Ulid, + }, + Exchanged { + fulfilled_at: DateTime, + exchanged_at: DateTime, + session_id: Ulid, + }, +} + +impl CompatSsoLoginState { + /// Returns `true` if the compat sso login state is [`Pending`]. + /// + /// [`Pending`]: CompatSsoLoginState::Pending + #[must_use] + pub fn is_pending(&self) -> bool { + matches!(self, Self::Pending) + } + + /// Returns `true` if the compat sso login state is [`Fulfilled`]. + /// + /// [`Fulfilled`]: CompatSsoLoginState::Fulfilled + #[must_use] + pub fn is_fulfilled(&self) -> bool { + matches!(self, Self::Fulfilled { .. }) + } + + /// Returns `true` if the compat sso login state is [`Exchanged`]. + /// + /// [`Exchanged`]: CompatSsoLoginState::Exchanged + #[must_use] + pub fn is_exchanged(&self) -> bool { + matches!(self, Self::Exchanged { .. }) + } + + #[must_use] + pub fn fulfilled_at(&self) -> Option> { + match self { + Self::Pending => None, + Self::Fulfilled { fulfilled_at, .. } | Self::Exchanged { fulfilled_at, .. } => { + Some(*fulfilled_at) + } + } + } + + #[must_use] + pub fn exchanged_at(&self) -> Option> { + match self { + Self::Pending | Self::Fulfilled { .. } => None, + Self::Exchanged { exchanged_at, .. } => Some(*exchanged_at), + } + } + + #[must_use] + pub fn session_id(&self) -> Option { + match self { + Self::Pending => None, + Self::Fulfilled { session_id, .. } | Self::Exchanged { session_id, .. } => { + Some(*session_id) + } + } + } + + pub fn fulfill( + self, + fulfilled_at: DateTime, + session: &CompatSession, + ) -> Result { + match self { + Self::Pending => Ok(Self::Fulfilled { + fulfilled_at, + session_id: session.id, + }), + Self::Fulfilled { .. } | Self::Exchanged { .. } => Err(InvalidTransitionError), + } + } + + pub fn exchange(self, exchanged_at: DateTime) -> Result { + match self { + Self::Fulfilled { + fulfilled_at, + session_id, + } => Ok(Self::Exchanged { + fulfilled_at, + exchanged_at, + session_id, + }), + Self::Pending { .. } | Self::Exchanged { .. } => Err(InvalidTransitionError), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct CompatSsoLogin { + pub id: Ulid, + pub redirect_uri: Url, + pub login_token: String, + pub created_at: DateTime, + pub state: CompatSsoLoginState, +} + +impl std::ops::Deref for CompatSsoLogin { + type Target = CompatSsoLoginState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + +impl CompatSsoLogin { + pub fn fulfill( + mut self, + fulfilled_at: DateTime, + session: &CompatSession, + ) -> Result { + self.state = self.state.fulfill(fulfilled_at, session)?; + Ok(self) + } + + pub fn exchange(mut self, exchanged_at: DateTime) -> Result { + self.state = self.state.exchange(exchanged_at)?; + Ok(self) + } +} diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index d104642ea..bde11fbed 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -23,24 +23,33 @@ clippy::type_repetition_in_bounds )] +use thiserror::Error; + pub(crate) mod compat; pub(crate) mod oauth2; pub(crate) mod tokens; pub(crate) mod upstream_oauth2; pub(crate) mod users; +#[derive(Debug, Error)] +#[error("invalid state transition")] +pub struct InvalidTransitionError; + pub use self::{ compat::{ - CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState, - Device, + CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession, + CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device, }, oauth2::{ AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, - InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, + InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState, + }, + tokens::{ + AccessToken, AccessTokenState, RefreshToken, RefreshTokenState, TokenFormatError, TokenType, }, - tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType}, upstream_oauth2::{ - UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider, + UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, + UpstreamOAuthLink, UpstreamOAuthProvider, }, users::{ Authentication, BrowserSession, Password, User, UserEmail, UserEmailVerification, diff --git a/crates/data-model/src/oauth2/authorization_grant.rs b/crates/data-model/src/oauth2/authorization_grant.rs index cb85a2654..5638ca10a 100644 --- a/crates/data-model/src/oauth2/authorization_grant.rs +++ b/crates/data-model/src/oauth2/authorization_grant.rs @@ -21,11 +21,11 @@ use oauth2_types::{ requests::ResponseMode, }; use serde::Serialize; -use thiserror::Error; use ulid::Ulid; use url::Url; -use super::{client::Client, session::Session}; +use super::session::Session; +use crate::InvalidTransitionError; #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct Pkce { @@ -53,21 +53,17 @@ pub struct AuthorizationCode { pub pkce: Option, } -#[derive(Debug, Error)] -#[error("invalid state transition")] -pub struct InvalidTransitionError; - #[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)] #[serde(tag = "stage", rename_all = "lowercase")] pub enum AuthorizationGrantStage { #[default] Pending, Fulfilled { - session: Session, + session_id: Ulid, fulfilled_at: DateTime, }, Exchanged { - session: Session, + session_id: Ulid, fulfilled_at: DateTime, exchanged_at: DateTime, }, @@ -82,35 +78,35 @@ impl AuthorizationGrantStage { Self::Pending } - pub fn fulfill( + fn fulfill( self, fulfilled_at: DateTime, - session: Session, + session: &Session, ) -> Result { match self { Self::Pending => Ok(Self::Fulfilled { fulfilled_at, - session, + session_id: session.id, }), _ => Err(InvalidTransitionError), } } - pub fn exchange(self, exchanged_at: DateTime) -> Result { + fn exchange(self, exchanged_at: DateTime) -> Result { match self { Self::Fulfilled { fulfilled_at, - session, + session_id, } => Ok(Self::Exchanged { fulfilled_at, exchanged_at, - session, + session_id, }), _ => Err(InvalidTransitionError), } } - pub fn cancel(self, cancelled_at: DateTime) -> Result { + fn cancel(self, cancelled_at: DateTime) -> Result { match self { Self::Pending => Ok(Self::Cancelled { cancelled_at }), _ => Err(InvalidTransitionError), @@ -124,6 +120,22 @@ impl AuthorizationGrantStage { pub fn is_pending(&self) -> bool { matches!(self, Self::Pending) } + + /// Returns `true` if the authorization grant stage is [`Fulfilled`]. + /// + /// [`Fulfilled`]: AuthorizationGrantStage::Fulfilled + #[must_use] + pub fn is_fulfilled(&self) -> bool { + matches!(self, Self::Fulfilled { .. }) + } + + /// Returns `true` if the authorization grant stage is [`Exchanged`]. + /// + /// [`Exchanged`]: AuthorizationGrantStage::Exchanged + #[must_use] + pub fn is_exchanged(&self) -> bool { + matches!(self, Self::Exchanged { .. }) + } } #[derive(Debug, Clone, PartialEq, Eq, Serialize)] @@ -132,7 +144,7 @@ pub struct AuthorizationGrant { #[serde(flatten)] pub stage: AuthorizationGrantStage, pub code: Option, - pub client: Client, + pub client_id: Ulid, pub redirect_uri: Url, pub scope: oauth2_types::scope::Scope, pub state: Option, @@ -144,10 +156,38 @@ pub struct AuthorizationGrant { pub requires_consent: bool, } +impl std::ops::Deref for AuthorizationGrant { + type Target = AuthorizationGrantStage; + + fn deref(&self) -> &Self::Target { + &self.stage + } +} + impl AuthorizationGrant { #[must_use] pub fn max_auth_time(&self) -> DateTime { let max_age: Option = self.max_age.map(|x| x.get().into()); self.created_at - Duration::seconds(max_age.unwrap_or(3600 * 24 * 365)) } + + pub fn exchange(mut self, exchanged_at: DateTime) -> Result { + self.stage = self.stage.exchange(exchanged_at)?; + Ok(self) + } + + pub fn fulfill( + mut self, + fulfilled_at: DateTime, + session: &Session, + ) -> Result { + self.stage = self.stage.fulfill(fulfilled_at, session)?; + Ok(self) + } + + // TODO: this is not used? + pub fn cancel(mut self, canceld_at: DateTime) -> Result { + self.stage = self.stage.cancel(canceld_at)?; + Ok(self) + } } diff --git a/crates/data-model/src/oauth2/mod.rs b/crates/data-model/src/oauth2/mod.rs index ef512260a..bc76b091d 100644 --- a/crates/data-model/src/oauth2/mod.rs +++ b/crates/data-model/src/oauth2/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,5 +19,5 @@ pub(self) mod session; pub use self::{ authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce}, client::{Client, InvalidRedirectUriError, JwksOrJwksUri}, - session::Session, + session::{Session, SessionState}, }; diff --git a/crates/data-model/src/oauth2/session.rs b/crates/data-model/src/oauth2/session.rs index ff222ca83..68ac821cd 100644 --- a/crates/data-model/src/oauth2/session.rs +++ b/crates/data-model/src/oauth2/session.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,17 +12,76 @@ // See the License for the specific language governing permissions and // limitations under the License. +use chrono::{DateTime, Utc}; use oauth2_types::scope::Scope; use serde::Serialize; use ulid::Ulid; -use super::client::Client; -use crate::users::BrowserSession; +use crate::InvalidTransitionError; + +trait T { + type State; +} + +impl T for Session { + type State = SessionState; +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)] +pub enum SessionState { + #[default] + Valid, + Finished { + finished_at: DateTime, + }, +} + +impl SessionState { + /// Returns `true` if the session state is [`Valid`]. + /// + /// [`Valid`]: SessionState::Valid + #[must_use] + pub fn is_valid(&self) -> bool { + matches!(self, Self::Valid) + } + + /// Returns `true` if the session state is [`Finished`]. + /// + /// [`Finished`]: SessionState::Finished + #[must_use] + pub fn is_finished(&self) -> bool { + matches!(self, Self::Finished { .. }) + } + + pub fn finish(self, finished_at: DateTime) -> Result { + match self { + Self::Valid => Ok(Self::Finished { finished_at }), + Self::Finished { .. } => Err(InvalidTransitionError), + } + } +} #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct Session { pub id: Ulid, - pub browser_session: BrowserSession, - pub client: Client, + pub state: SessionState, + pub created_at: DateTime, + pub user_session_id: Ulid, + pub client_id: Ulid, pub scope: Scope, } + +impl std::ops::Deref for Session { + type Target = SessionState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + +impl Session { + pub fn finish(mut self, finished_at: DateTime) -> Result { + self.state = self.state.finish(finished_at)?; + Ok(self) + } +} diff --git a/crates/data-model/src/tokens.rs b/crates/data-model/src/tokens.rs index 93b29f6dd..5c7acf34c 100644 --- a/crates/data-model/src/tokens.rs +++ b/crates/data-model/src/tokens.rs @@ -15,25 +15,135 @@ use chrono::{DateTime, Utc}; use crc::{Crc, CRC_32_ISO_HDLC}; use mas_iana::oauth::OAuthTokenTypeHint; -use rand::{distributions::Alphanumeric, Rng}; +use rand::{distributions::Alphanumeric, Rng, RngCore}; use thiserror::Error; use ulid::Ulid; +use crate::InvalidTransitionError; + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub enum AccessTokenState { + #[default] + Valid, + Revoked { + revoked_at: DateTime, + }, +} + +impl AccessTokenState { + fn revoke(self, revoked_at: DateTime) -> Result { + match self { + Self::Valid => Ok(Self::Revoked { revoked_at }), + Self::Revoked { .. } => Err(InvalidTransitionError), + } + } + + /// Returns `true` if the refresh token state is [`Valid`]. + /// + /// [`Valid`]: AccessTokenState::Valid + #[must_use] + pub fn is_valid(&self) -> bool { + matches!(self, Self::Valid) + } + + /// Returns `true` if the refresh token state is [`Revoked`]. + /// + /// [`Revoked`]: AccessTokenState::Revoked + #[must_use] + pub fn is_revoked(&self) -> bool { + matches!(self, Self::Revoked { .. }) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct AccessToken { pub id: Ulid, - pub jti: String, + pub state: AccessTokenState, + pub session_id: Ulid, pub access_token: String, pub created_at: DateTime, pub expires_at: DateTime, } +impl AccessToken { + #[must_use] + pub fn jti(&self) -> String { + self.id.to_string() + } + + #[must_use] + pub fn is_valid(&self, now: DateTime) -> bool { + self.state.is_valid() && self.expires_at > now + } + + pub fn revoke(mut self, revoked_at: DateTime) -> Result { + self.state = self.state.revoke(revoked_at)?; + Ok(self) + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub enum RefreshTokenState { + #[default] + Valid, + Consumed { + consumed_at: DateTime, + }, +} + +impl RefreshTokenState { + fn consume(self, consumed_at: DateTime) -> Result { + match self { + Self::Valid => Ok(Self::Consumed { consumed_at }), + Self::Consumed { .. } => Err(InvalidTransitionError), + } + } + + /// Returns `true` if the refresh token state is [`Valid`]. + /// + /// [`Valid`]: RefreshTokenState::Valid + #[must_use] + pub fn is_valid(&self) -> bool { + matches!(self, Self::Valid) + } + + /// Returns `true` if the refresh token state is [`Consumed`]. + /// + /// [`Consumed`]: RefreshTokenState::Consumed + #[must_use] + pub fn is_consumed(&self) -> bool { + matches!(self, Self::Consumed { .. }) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct RefreshToken { pub id: Ulid, + pub state: RefreshTokenState, pub refresh_token: String, + pub session_id: Ulid, pub created_at: DateTime, - pub access_token: Option, + pub access_token_id: Option, +} + +impl std::ops::Deref for RefreshToken { + type Target = RefreshTokenState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + +impl RefreshToken { + #[must_use] + pub fn jti(&self) -> String { + self.id.to_string() + } + + pub fn consume(mut self, consumed_at: DateTime) -> Result { + self.state = self.state.consume(consumed_at)?; + Ok(self) + } } /// Type of token to generate or validate @@ -80,10 +190,10 @@ impl TokenType { /// use rand::thread_rng; /// use mas_data_model::TokenType::{AccessToken, RefreshToken}; /// - /// AccessToken.generate(thread_rng()); - /// RefreshToken.generate(thread_rng()); + /// AccessToken.generate(&mut thread_rng()); + /// RefreshToken.generate(&mut thread_rng()); /// ``` - pub fn generate(self, rng: impl Rng) -> String { + pub fn generate(self, rng: &mut (impl RngCore + ?Sized)) -> String { let random_part: String = rng .sample_iter(&Alphanumeric) .take(30) diff --git a/crates/data-model/src/upstream_oauth2/link.rs b/crates/data-model/src/upstream_oauth2/link.rs new file mode 100644 index 000000000..c0699173c --- /dev/null +++ b/crates/data-model/src/upstream_oauth2/link.rs @@ -0,0 +1,26 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use serde::Serialize; +use ulid::Ulid; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct UpstreamOAuthLink { + pub id: Ulid, + pub provider_id: Ulid, + pub user_id: Option, + pub subject: String, + pub created_at: DateTime, +} diff --git a/crates/data-model/src/upstream_oauth2/mod.rs b/crates/data-model/src/upstream_oauth2/mod.rs index 08fbf6c0b..90780a8bf 100644 --- a/crates/data-model/src/upstream_oauth2/mod.rs +++ b/crates/data-model/src/upstream_oauth2/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,55 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use chrono::{DateTime, Utc}; -use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; -use oauth2_types::scope::Scope; -use serde::Serialize; -use ulid::Ulid; +mod link; +mod provider; +mod session; -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct UpstreamOAuthProvider { - pub id: Ulid, - pub issuer: String, - pub scope: Scope, - pub client_id: String, - pub encrypted_client_secret: Option, - pub token_endpoint_signing_alg: Option, - pub token_endpoint_auth_method: OAuthClientAuthenticationMethod, - pub created_at: DateTime, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct UpstreamOAuthLink { - pub id: Ulid, - pub provider_id: Ulid, - pub user_id: Option, - pub subject: String, - pub created_at: DateTime, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct UpstreamOAuthAuthorizationSession { - pub id: Ulid, - pub provider_id: Ulid, - pub link_id: Option, - pub state: String, - pub code_challenge_verifier: Option, - pub nonce: String, - pub created_at: DateTime, - pub completed_at: Option>, - pub consumed_at: Option>, - pub id_token: Option, -} - -impl UpstreamOAuthAuthorizationSession { - #[must_use] - pub const fn completed(&self) -> bool { - self.completed_at.is_some() - } - - #[must_use] - pub const fn consumed(&self) -> bool { - self.consumed_at.is_some() - } -} +pub use self::{ + link::UpstreamOAuthLink, + provider::UpstreamOAuthProvider, + session::{UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState}, +}; diff --git a/crates/data-model/src/upstream_oauth2/provider.rs b/crates/data-model/src/upstream_oauth2/provider.rs new file mode 100644 index 000000000..919b22217 --- /dev/null +++ b/crates/data-model/src/upstream_oauth2/provider.rs @@ -0,0 +1,31 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; +use oauth2_types::scope::Scope; +use serde::Serialize; +use ulid::Ulid; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct UpstreamOAuthProvider { + pub id: Ulid, + pub issuer: String, + pub scope: Scope, + pub client_id: String, + pub encrypted_client_secret: Option, + pub token_endpoint_signing_alg: Option, + pub token_endpoint_auth_method: OAuthClientAuthenticationMethod, + pub created_at: DateTime, +} diff --git a/crates/data-model/src/upstream_oauth2/session.rs b/crates/data-model/src/upstream_oauth2/session.rs new file mode 100644 index 000000000..9ce612668 --- /dev/null +++ b/crates/data-model/src/upstream_oauth2/session.rs @@ -0,0 +1,170 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use serde::Serialize; +use ulid::Ulid; + +use super::UpstreamOAuthLink; +use crate::InvalidTransitionError; + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)] +pub enum UpstreamOAuthAuthorizationSessionState { + #[default] + Pending, + Completed { + completed_at: DateTime, + link_id: Ulid, + id_token: Option, + }, + Consumed { + completed_at: DateTime, + consumed_at: DateTime, + link_id: Ulid, + id_token: Option, + }, +} + +impl UpstreamOAuthAuthorizationSessionState { + pub fn complete( + self, + completed_at: DateTime, + link: &UpstreamOAuthLink, + id_token: Option, + ) -> Result { + match self { + Self::Pending => Ok(Self::Completed { + completed_at, + link_id: link.id, + id_token, + }), + Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError), + } + } + + pub fn consume(self, consumed_at: DateTime) -> Result { + match self { + Self::Completed { + completed_at, + link_id, + id_token, + } => Ok(Self::Consumed { + completed_at, + link_id, + consumed_at, + id_token, + }), + Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError), + } + } + + #[must_use] + pub fn link_id(&self) -> Option { + match self { + Self::Pending => None, + Self::Completed { link_id, .. } | Self::Consumed { link_id, .. } => Some(*link_id), + } + } + + #[must_use] + pub fn completed_at(&self) -> Option> { + match self { + Self::Pending => None, + Self::Completed { completed_at, .. } | Self::Consumed { completed_at, .. } => { + Some(*completed_at) + } + } + } + + #[must_use] + pub fn id_token(&self) -> Option<&str> { + match self { + Self::Pending => None, + Self::Completed { id_token, .. } | Self::Consumed { id_token, .. } => { + id_token.as_deref() + } + } + } + + #[must_use] + pub fn consumed_at(&self) -> Option> { + match self { + Self::Pending | Self::Completed { .. } => None, + Self::Consumed { consumed_at, .. } => Some(*consumed_at), + } + } + + /// Returns `true` if the upstream oauth authorization session state is + /// [`Pending`]. + /// + /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending + #[must_use] + pub fn is_pending(&self) -> bool { + matches!(self, Self::Pending) + } + + /// Returns `true` if the upstream oauth authorization session state is + /// [`Completed`]. + /// + /// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed + #[must_use] + pub fn is_completed(&self) -> bool { + matches!(self, Self::Completed { .. }) + } + + /// Returns `true` if the upstream oauth authorization session state is + /// [`Consumed`]. + /// + /// [`Consumed`]: UpstreamOAuthAuthorizationSessionState::Consumed + #[must_use] + pub fn is_consumed(&self) -> bool { + matches!(self, Self::Consumed { .. }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct UpstreamOAuthAuthorizationSession { + pub id: Ulid, + pub state: UpstreamOAuthAuthorizationSessionState, + pub provider_id: Ulid, + pub state_str: String, + pub code_challenge_verifier: Option, + pub nonce: String, + pub created_at: DateTime, +} + +impl std::ops::Deref for UpstreamOAuthAuthorizationSession { + type Target = UpstreamOAuthAuthorizationSessionState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + +impl UpstreamOAuthAuthorizationSession { + pub fn complete( + mut self, + completed_at: DateTime, + link: &UpstreamOAuthLink, + id_token: Option, + ) -> Result { + self.state = self.state.complete(completed_at, link, id_token)?; + Ok(self) + } + + pub fn consume(mut self, consumed_at: DateTime) -> Result { + self.state = self.state.consume(consumed_at)?; + Ok(self) + } +} diff --git a/crates/data-model/src/users.rs b/crates/data-model/src/users.rs index 4d9c884a0..638ed77e3 100644 --- a/crates/data-model/src/users.rs +++ b/crates/data-model/src/users.rs @@ -22,7 +22,7 @@ pub struct User { pub id: Ulid, pub username: String, pub sub: String, - pub primary_email: Option, + pub primary_user_email_id: Option, } impl User { @@ -32,7 +32,7 @@ impl User { id: Ulid::from_datetime_with_source(now.into(), rng), username: "john".to_owned(), sub: "123-456".to_owned(), - primary_email: None, + primary_user_email_id: None, }] } } @@ -57,10 +57,16 @@ pub struct BrowserSession { pub id: Ulid, pub user: User, pub created_at: DateTime, + pub finished_at: Option>, pub last_authentication: Option, } impl BrowserSession { + #[must_use] + pub fn active(&self) -> bool { + self.finished_at.is_none() + } + #[must_use] pub fn was_authenticated_after(&self, after: DateTime) -> bool { if let Some(auth) = &self.last_authentication { @@ -80,6 +86,7 @@ impl BrowserSession { id: Ulid::from_datetime_with_source(now.into(), rng), user, created_at: now, + finished_at: None, last_authentication: None, }) .collect() @@ -89,6 +96,7 @@ impl BrowserSession { #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct UserEmail { pub id: Ulid, + pub user_id: Ulid, pub email: String, pub created_at: DateTime, pub confirmed_at: Option>, @@ -100,12 +108,14 @@ impl UserEmail { vec![ Self { id: Ulid::from_datetime_with_source(now.into(), rng), + user_id: Ulid::from_datetime_with_source(now.into(), rng), email: "alice@example.com".to_owned(), created_at: now, confirmed_at: Some(now), }, Self { id: Ulid::from_datetime_with_source(now.into(), rng), + user_id: Ulid::from_datetime_with_source(now.into(), rng), email: "bob@example.com".to_owned(), created_at: now, confirmed_at: None, @@ -124,7 +134,7 @@ pub enum UserEmailVerificationState { #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct UserEmailVerification { pub id: Ulid, - pub email: UserEmail, + pub user_email_id: Ulid, pub code: String, pub created_at: DateTime, pub state: UserEmailVerificationState, @@ -152,8 +162,8 @@ impl UserEmailVerification { .into_iter() .map(move |email| Self { id: Ulid::from_datetime_with_source(now.into(), &mut rng), + user_email_id: email.id, code: "123456".to_owned(), - email, created_at: now - Duration::minutes(10), state: state.clone(), }) diff --git a/crates/email/src/mailer.rs b/crates/email/src/mailer.rs index 177cc8d10..979a9f7bb 100644 --- a/crates/email/src/mailer.rs +++ b/crates/email/src/mailer.rs @@ -100,10 +100,10 @@ impl Mailer { /// /// Will return `Err` if the email failed rendering or failed sending #[tracing::instrument( + name = "email.verification.send", skip_all, fields( email.to = %to, - email.from = %self.from, user.id = %context.user().id, user_email_verification.id = %context.verification().id, user_email_verification.code = context.verification().code, @@ -125,6 +125,7 @@ impl Mailer { /// # Errors /// /// Returns an error if the connection failed + #[tracing::instrument(name = "email.test_connection", skip_all, err)] pub async fn test_connection(&self) -> Result<(), crate::transport::Error> { self.transport.test_connection().await } diff --git a/crates/graphql/Cargo.toml b/crates/graphql/Cargo.toml index 3cb5fd69c..f55b781e0 100644 --- a/crates/graphql/Cargo.toml +++ b/crates/graphql/Cargo.toml @@ -10,7 +10,7 @@ anyhow = "1.0.68" async-graphql = { version = "5.0.5", features = ["chrono", "url"] } chrono = "0.4.23" serde = { version = "1.0.152", features = ["derive"] } -sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] } +tokio = { version = "1.23.0", features = ["sync"] } thiserror = "1.0.38" tracing = "0.1.37" ulid = "1.0.0" diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index 1e691a96e..ca3745653 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -30,8 +30,14 @@ use async_graphql::{ connection::{query, Connection, Edge, OpaqueCursor}, Context, Description, EmptyMutation, EmptySubscription, ID, }; +use mas_storage::{ + oauth2::OAuth2ClientRepository, + upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository}, + user::{BrowserSessionRepository, UserEmailRepository}, + BoxRepository, Pagination, +}; use model::CreationEvent; -use sqlx::PgPool; +use tokio::sync::Mutex; use self::model::{ BrowserSession, Cursor, Node, NodeCursor, NodeType, OAuth2Client, UpstreamOAuth2Link, @@ -87,10 +93,9 @@ impl RootQuery { id: ID, ) -> Result, async_graphql::Error> { let id = NodeType::OAuth2Client.extract_ulid(&id)?; - let database = ctx.data::()?; - let mut conn = database.acquire().await?; + let mut repo = ctx.data::>()?.lock().await; - let client = mas_storage::oauth2::client::lookup_client(&mut conn, id).await?; + let client = repo.oauth2_client().lookup(id).await?; Ok(client.map(OAuth2Client)) } @@ -118,13 +123,12 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let id = NodeType::BrowserSession.extract_ulid(&id)?; let session = ctx.data_opt::().cloned(); - let database = ctx.data::()?; - let mut conn = database.acquire().await?; + let mut repo = ctx.data::>()?.lock().await; let Some(session) = session else { return Ok(None) }; let current_user = session.user; - let browser_session = mas_storage::user::lookup_active_session(&mut conn, id).await?; + let browser_session = repo.browser_session().lookup(id).await?; let ret = browser_session.and_then(|browser_session| { if browser_session.user.id == current_user.id { @@ -145,14 +149,16 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let id = NodeType::UserEmail.extract_ulid(&id)?; let session = ctx.data_opt::().cloned(); - let database = ctx.data::()?; - let mut conn = database.acquire().await?; + let mut repo = ctx.data::>()?.lock().await; let Some(session) = session else { return Ok(None) }; let current_user = session.user; - let user_email = - mas_storage::user::lookup_user_email_by_id(&mut conn, ¤t_user, id).await?; + let user_email = repo + .user_email() + .lookup(id) + .await? + .filter(|e| e.user_id == current_user.id); Ok(user_email.map(UserEmail)) } @@ -165,13 +171,12 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?; let session = ctx.data_opt::().cloned(); - let database = ctx.data::()?; - let mut conn = database.acquire().await?; + let mut repo = ctx.data::>()?.lock().await; let Some(session) = session else { return Ok(None) }; let current_user = session.user; - let link = mas_storage::upstream_oauth2::lookup_link(&mut conn, id).await?; + let link = repo.upstream_oauth_link().lookup(id).await?; // Ensure that the link belongs to the current user let link = link.filter(|link| link.user_id == Some(current_user.id)); @@ -186,10 +191,9 @@ impl RootQuery { id: ID, ) -> Result, async_graphql::Error> { let id = NodeType::UpstreamOAuth2Provider.extract_ulid(&id)?; - let database = ctx.data::()?; - let mut conn = database.acquire().await?; + let mut repo = ctx.data::>()?.lock().await; - let provider = mas_storage::upstream_oauth2::lookup_provider(&mut conn, id).await?; + let provider = repo.upstream_oauth_provider().lookup(id).await?; Ok(provider.map(UpstreamOAuth2Provider::new)) } @@ -206,7 +210,7 @@ impl RootQuery { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = ctx.data::>()?.lock().await; query( after, @@ -214,7 +218,6 @@ impl RootQuery { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| { x.extract_for_type(NodeType::UpstreamOAuth2Provider) @@ -225,15 +228,15 @@ impl RootQuery { x.extract_for_type(NodeType::UpstreamOAuth2Provider) }) .transpose()?; + let pagination = Pagination::try_new(before_id, after_id, first, last)?; - let (has_previous_page, has_next_page, edges) = - mas_storage::upstream_oauth2::get_paginated_providers( - &mut conn, before_id, after_id, first, last, - ) + let page = repo + .upstream_oauth_provider() + .list_paginated(pagination) .await?; - let mut connection = Connection::new(has_previous_page, has_next_page); - connection.edges.extend(edges.into_iter().map(|p| { + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + connection.edges.extend(page.edges.into_iter().map(|p| { Edge::new( OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Provider, p.id)), UpstreamOAuth2Provider::new(p), diff --git a/crates/graphql/src/model/compat_sessions.rs b/crates/graphql/src/model/compat_sessions.rs index 5b272b184..38fdd4baa 100644 --- a/crates/graphql/src/model/compat_sessions.rs +++ b/crates/graphql/src/model/compat_sessions.rs @@ -12,9 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use async_graphql::{Description, Object, ID}; +use anyhow::Context as _; +use async_graphql::{Context, Description, Object, ID}; use chrono::{DateTime, Utc}; -use mas_data_model::CompatSsoLoginState; +use mas_storage::{compat::CompatSessionRepository, user::UserRepository, BoxRepository}; +use tokio::sync::Mutex; use url::Url; use super::{NodeType, User}; @@ -32,8 +34,14 @@ impl CompatSession { } /// The user authorized for this session. - async fn user(&self) -> User { - User(self.0.user.clone()) + async fn user(&self, ctx: &Context<'_>) -> Result { + let mut repo = ctx.data::>()?.lock().await; + let user = repo + .user() + .lookup(self.0.user_id) + .await? + .context("Could not load user")?; + Ok(User(user)) } /// The Matrix Device ID of this session. @@ -48,7 +56,7 @@ impl CompatSession { /// When the session ended. pub async fn finished_at(&self) -> Option> { - self.0.finished_at + self.0.finished_at() } } @@ -77,29 +85,28 @@ impl CompatSsoLogin { /// When the login was fulfilled, and the user was redirected back to the /// client. async fn fulfilled_at(&self) -> Option> { - match &self.0.state { - CompatSsoLoginState::Pending => None, - CompatSsoLoginState::Fulfilled { fulfilled_at, .. } - | CompatSsoLoginState::Exchanged { fulfilled_at, .. } => Some(*fulfilled_at), - } + self.0.fulfilled_at() } /// When the client exchanged the login token sent during the redirection. async fn exchanged_at(&self) -> Option> { - match &self.0.state { - CompatSsoLoginState::Pending | CompatSsoLoginState::Fulfilled { .. } => None, - CompatSsoLoginState::Exchanged { exchanged_at, .. } => Some(*exchanged_at), - } + self.0.exchanged_at() } /// The compat session which was started by this login. - async fn session(&self) -> Option { - match &self.0.state { - CompatSsoLoginState::Pending => None, - CompatSsoLoginState::Fulfilled { session, .. } - | CompatSsoLoginState::Exchanged { session, .. } => { - Some(CompatSession(session.clone())) - } - } + async fn session( + &self, + ctx: &Context<'_>, + ) -> Result, async_graphql::Error> { + let Some(session_id) = self.0.session_id() else { return Ok(None) }; + + let mut repo = ctx.data::>()?.lock().await; + let session = repo + .compat_session() + .lookup(session_id) + .await? + .context("Could not load compat session")?; + + Ok(Some(CompatSession(session))) } } diff --git a/crates/graphql/src/model/oauth.rs b/crates/graphql/src/model/oauth.rs index 89598ffa5..19612f6d7 100644 --- a/crates/graphql/src/model/oauth.rs +++ b/crates/graphql/src/model/oauth.rs @@ -14,9 +14,9 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; -use mas_storage::oauth2::client::lookup_client; +use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, BoxRepository}; use oauth2_types::scope::Scope; -use sqlx::PgPool; +use tokio::sync::Mutex; use ulid::Ulid; use url::Url; @@ -35,8 +35,15 @@ impl OAuth2Session { } /// OAuth 2.0 client used by this session. - pub async fn client(&self) -> OAuth2Client { - OAuth2Client(self.0.client.clone()) + pub async fn client(&self, ctx: &Context<'_>) -> Result { + let mut repo = ctx.data::>()?.lock().await; + let client = repo + .oauth2_client() + .lookup(self.0.client_id) + .await? + .context("Could not load client")?; + + Ok(OAuth2Client(client)) } /// Scope granted for this session. @@ -45,13 +52,30 @@ impl OAuth2Session { } /// The browser session which started this OAuth 2.0 session. - pub async fn browser_session(&self) -> BrowserSession { - BrowserSession(self.0.browser_session.clone()) + pub async fn browser_session( + &self, + ctx: &Context<'_>, + ) -> Result { + let mut repo = ctx.data::>()?.lock().await; + let browser_session = repo + .browser_session() + .lookup(self.0.user_session_id) + .await? + .context("Could not load browser session")?; + + Ok(BrowserSession(browser_session)) } /// User authorized for this session. - pub async fn user(&self) -> User { - User(self.0.browser_session.user.clone()) + pub async fn user(&self, ctx: &Context<'_>) -> Result { + let mut repo = ctx.data::>()?.lock().await; + let browser_session = repo + .browser_session() + .lookup(self.0.user_session_id) + .await? + .context("Could not load browser session")?; + + Ok(User(browser_session.user)) } } @@ -114,8 +138,10 @@ impl OAuth2Consent { /// OAuth 2.0 client for which the user granted access. pub async fn client(&self, ctx: &Context<'_>) -> Result { - let mut conn = ctx.data::()?.acquire().await?; - let client = lookup_client(&mut conn, self.client_id) + let mut repo = ctx.data::>()?.lock().await; + let client = repo + .oauth2_client() + .lookup(self.client_id) .await? .context("Could not load client")?; Ok(OAuth2Client(client)) diff --git a/crates/graphql/src/model/upstream_oauth.rs b/crates/graphql/src/model/upstream_oauth.rs index 87164dd4e..76b3a44a3 100644 --- a/crates/graphql/src/model/upstream_oauth.rs +++ b/crates/graphql/src/model/upstream_oauth.rs @@ -15,7 +15,10 @@ use anyhow::Context as _; use async_graphql::{Context, Object, ID}; use chrono::{DateTime, Utc}; -use sqlx::PgPool; +use mas_storage::{ + upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, BoxRepository, +}; +use tokio::sync::Mutex; use super::{NodeType, User}; @@ -99,11 +102,13 @@ impl UpstreamOAuth2Link { provider.clone() } else { // Fetch on-the-fly - let database = ctx.data::()?; - let mut conn = database.acquire().await?; - mas_storage::upstream_oauth2::lookup_provider(&mut conn, self.link.provider_id) + let mut repo = ctx.data::>()?.lock().await; + let provider = repo + .upstream_oauth_provider() + .lookup(self.link.provider_id) .await? - .context("Upstream OAuth 2.0 provider not found")? + .context("Upstream OAuth 2.0 provider not found")?; + provider }; Ok(UpstreamOAuth2Provider::new(provider)) @@ -116,9 +121,13 @@ impl UpstreamOAuth2Link { user.clone() } else if let Some(user_id) = &self.link.user_id { // Fetch on-the-fly - let database = ctx.data::()?; - let mut conn = database.acquire().await?; - mas_storage::user::lookup_user(&mut conn, *user_id).await? + let mut repo = ctx.data::>()?.lock().await; + let user = repo + .user() + .lookup(*user_id) + .await? + .context("User not found")?; + user } else { return Ok(None); }; diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index ad8bfa435..35c2cae4f 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -17,7 +17,14 @@ use async_graphql::{ Context, Description, Object, ID, }; use chrono::{DateTime, Utc}; -use sqlx::PgPool; +use mas_storage::{ + compat::CompatSsoLoginRepository, + oauth2::OAuth2SessionRepository, + upstream_oauth2::UpstreamOAuthLinkRepository, + user::{BrowserSessionRepository, UserEmailRepository}, + BoxRepository, Pagination, +}; +use tokio::sync::Mutex; use super::{ compat_sessions::CompatSsoLogin, BrowserSession, Cursor, NodeCursor, NodeType, OAuth2Session, @@ -53,8 +60,14 @@ impl User { } /// Primary email address of the user. - async fn primary_email(&self) -> Option { - self.0.primary_email.clone().map(UserEmail) + async fn primary_email( + &self, + ctx: &Context<'_>, + ) -> Result, async_graphql::Error> { + let mut repo = ctx.data::>()?.lock().await; + + let mut user_email_repo = repo.user_email(); + Ok(user_email_repo.get_primary(&self.0).await?.map(UserEmail)) } /// Get the list of compatibility SSO logins, chronologically sorted @@ -69,7 +82,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = ctx.data::>()?.lock().await; query( after, @@ -77,22 +90,21 @@ impl User { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| x.extract_for_type(NodeType::CompatSsoLogin)) .transpose()?; let before_id = before .map(|x: OpaqueCursor| x.extract_for_type(NodeType::CompatSsoLogin)) .transpose()?; + let pagination = Pagination::try_new(before_id, after_id, first, last)?; - let (has_previous_page, has_next_page, edges) = - mas_storage::compat::get_paginated_user_compat_sso_logins( - &mut conn, &self.0, before_id, after_id, first, last, - ) + let page = repo + .compat_sso_login() + .list_paginated(&self.0, pagination) .await?; - let mut connection = Connection::new(has_previous_page, has_next_page); - connection.edges.extend(edges.into_iter().map(|u| { + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + connection.edges.extend(page.edges.into_iter().map(|u| { Edge::new( OpaqueCursor(NodeCursor(NodeType::CompatSsoLogin, u.id)), CompatSsoLogin(u), @@ -117,7 +129,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = ctx.data::>()?.lock().await; query( after, @@ -125,22 +137,21 @@ impl User { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| x.extract_for_type(NodeType::BrowserSession)) .transpose()?; let before_id = before .map(|x: OpaqueCursor| x.extract_for_type(NodeType::BrowserSession)) .transpose()?; + let pagination = Pagination::try_new(before_id, after_id, first, last)?; - let (has_previous_page, has_next_page, edges) = - mas_storage::user::get_paginated_user_sessions( - &mut conn, &self.0, before_id, after_id, first, last, - ) + let page = repo + .browser_session() + .list_active_paginated(&self.0, pagination) .await?; - let mut connection = Connection::new(has_previous_page, has_next_page); - connection.edges.extend(edges.into_iter().map(|u| { + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + connection.edges.extend(page.edges.into_iter().map(|u| { Edge::new( OpaqueCursor(NodeCursor(NodeType::BrowserSession, u.id)), BrowserSession(u), @@ -165,7 +176,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = ctx.data::>()?.lock().await; query( after, @@ -173,26 +184,25 @@ impl User { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| x.extract_for_type(NodeType::UserEmail)) .transpose()?; let before_id = before .map(|x: OpaqueCursor| x.extract_for_type(NodeType::UserEmail)) .transpose()?; + let pagination = Pagination::try_new(before_id, after_id, first, last)?; - let (has_previous_page, has_next_page, edges) = - mas_storage::user::get_paginated_user_emails( - &mut conn, &self.0, before_id, after_id, first, last, - ) + let page = repo + .user_email() + .list_paginated(&self.0, pagination) .await?; let mut connection = Connection::with_additional_fields( - has_previous_page, - has_next_page, + page.has_previous_page, + page.has_next_page, UserEmailsPagination(self.0.clone()), ); - connection.edges.extend(edges.into_iter().map(|u| { + connection.edges.extend(page.edges.into_iter().map(|u| { Edge::new( OpaqueCursor(NodeCursor(NodeType::UserEmail, u.id)), UserEmail(u), @@ -217,7 +227,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = ctx.data::>()?.lock().await; query( after, @@ -225,22 +235,21 @@ impl User { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| x.extract_for_type(NodeType::OAuth2Session)) .transpose()?; let before_id = before .map(|x: OpaqueCursor| x.extract_for_type(NodeType::OAuth2Session)) .transpose()?; + let pagination = Pagination::try_new(before_id, after_id, first, last)?; - let (has_previous_page, has_next_page, edges) = - mas_storage::oauth2::get_paginated_user_oauth_sessions( - &mut conn, &self.0, before_id, after_id, first, last, - ) + let page = repo + .oauth2_session() + .list_paginated(&self.0, pagination) .await?; - let mut connection = Connection::new(has_previous_page, has_next_page); - connection.edges.extend(edges.into_iter().map(|s| { + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + connection.edges.extend(page.edges.into_iter().map(|s| { Edge::new( OpaqueCursor(NodeCursor(NodeType::OAuth2Session, s.id)), OAuth2Session(s), @@ -265,7 +274,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = ctx.data::>()?.lock().await; query( after, @@ -273,7 +282,6 @@ impl User { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| { x.extract_for_type(NodeType::UpstreamOAuth2Link) @@ -284,15 +292,15 @@ impl User { x.extract_for_type(NodeType::UpstreamOAuth2Link) }) .transpose()?; + let pagination = Pagination::try_new(before_id, after_id, first, last)?; - let (has_previous_page, has_next_page, edges) = - mas_storage::upstream_oauth2::get_paginated_user_links( - &mut conn, &self.0, before_id, after_id, first, last, - ) + let page = repo + .upstream_oauth_link() + .list_paginated(&self.0, pagination) .await?; - let mut connection = Connection::new(has_previous_page, has_next_page); - connection.edges.extend(edges.into_iter().map(|s| { + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + connection.edges.extend(page.edges.into_iter().map(|s| { Edge::new( OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Link, s.id)), UpstreamOAuth2Link::new(s), @@ -339,9 +347,9 @@ pub struct UserEmailsPagination(mas_data_model::User); #[Object] impl UserEmailsPagination { /// Identifies the total count of items in the connection. - async fn total_count(&self, ctx: &Context<'_>) -> Result { - let mut conn = ctx.data::()?.acquire().await?; - let count = mas_storage::user::count_user_emails(&mut conn, &self.0).await?; + async fn total_count(&self, ctx: &Context<'_>) -> Result { + let mut repo = ctx.data::>()?.lock().await; + let count = repo.user_email().count(&self.0).await?; Ok(count) } } diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml index 23424c0ab..e2d5a507a 100644 --- a/crates/handlers/Cargo.toml +++ b/crates/handlers/Cargo.toml @@ -68,6 +68,7 @@ mas-oidc-client = { path = "../oidc-client" } mas-policy = { path = "../policy" } mas-router = { path = "../router" } mas-storage = { path = "../storage" } +mas-storage-pg = { path = "../storage-pg" } mas-templates = { path = "../templates" } oauth2-types = { path = "../oauth2-types" } diff --git a/crates/handlers/src/app_state.rs b/crates/handlers/src/app_state.rs index c9c650904..2e826badc 100644 --- a/crates/handlers/src/app_state.rs +++ b/crates/handlers/src/app_state.rs @@ -12,16 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; +use std::{convert::Infallible, sync::Arc}; -use axum::extract::FromRef; +use axum::{ + async_trait, + extract::{FromRef, FromRequestParts}, + response::IntoResponse, +}; +use hyper::StatusCode; use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_email::Mailer; use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; +use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock}; +use mas_storage_pg::PgRepository; use mas_templates::Templates; +use rand::SeedableRng; use sqlx::PgPool; +use thiserror::Error; use crate::{passwords::PasswordManager, MatrixHomeserver}; @@ -105,3 +114,58 @@ impl FromRef for PasswordManager { input.password_manager.clone() } } + +#[async_trait] +impl FromRequestParts for BoxClock { + type Rejection = Infallible; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + _state: &AppState, + ) -> Result { + let clock = SystemClock::default(); + Ok(Box::new(clock)) + } +} + +#[async_trait] +impl FromRequestParts for BoxRng { + type Rejection = Infallible; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + _state: &AppState, + ) -> Result { + // This rng is used to source the local rng + #[allow(clippy::disallowed_methods)] + let rng = rand::thread_rng(); + + let rng = rand_chacha::ChaChaRng::from_rng(rng).expect("Failed to seed RNG"); + Ok(Box::new(rng)) + } +} + +#[derive(Debug, Error)] +#[error(transparent)] +pub struct RepositoryError(#[from] mas_storage_pg::DatabaseError); + +impl IntoResponse for RepositoryError { + fn into_response(self) -> axum::response::Response { + (StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response() + } +} + +#[async_trait] +impl FromRequestParts for BoxRepository { + type Rejection = RepositoryError; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + state: &AppState, + ) -> Result { + let repo = PgRepository::from_pool(&state.pool).await?; + Ok(repo + .map_err(mas_storage::RepositoryError::from_error) + .boxed()) + } +} diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index 6d5ca8255..07077ea13 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -15,18 +15,18 @@ use axum::{extract::State, response::IntoResponse, Json}; use chrono::Duration; use hyper::StatusCode; -use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType}; +use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType, User}; use mas_storage::{ compat::{ - add_compat_access_token, add_compat_refresh_token, get_compat_sso_login_by_token, - mark_compat_sso_login_as_exchanged, start_compat_session, + CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, + CompatSsoLoginRepository, }, - user::{add_user_password, lookup_user_by_username, lookup_user_password}, - Clock, + user::{UserPasswordRepository, UserRepository}, + BoxClock, BoxRepository, BoxRng, Clock, }; +use rand::{CryptoRng, RngCore}; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; -use sqlx::{PgPool, Postgres, Transaction}; use thiserror::Error; use zeroize::Zeroizing; @@ -137,6 +137,9 @@ pub enum RouteError { #[error("user not found")] UserNotFound, + #[error("session not found")] + SessionNotFound, + #[error("user has no password")] NoPassword, @@ -150,13 +153,12 @@ pub enum RouteError { InvalidLoginToken, } -impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { - Self::Internal(_) => MatrixError { + Self::Internal(_) | Self::SessionNotFound => MatrixError { errcode: "M_UNKNOWN", error: "Internal server error", status: StatusCode::INTERNAL_SERVER_ERROR, @@ -190,27 +192,37 @@ impl IntoResponse for RouteError { #[tracing::instrument(skip_all, err)] pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(password_manager): State, - State(pool): State, + mut repo: BoxRepository, State(homeserver): State, Json(input): Json, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; - let session = match input.credentials { + let (session, user) = match input.credentials { Credentials::Password { identifier: Identifier::User { user }, password, - } => user_password_login(&password_manager, &mut txn, user, password).await?, + } => { + user_password_login( + &mut rng, + &clock, + &password_manager, + &mut repo, + user, + password, + ) + .await? + } - Credentials::Token { token } => token_login(&mut txn, &clock, &token).await?, + Credentials::Token { token } => token_login(&mut repo, &clock, &token).await?, _ => { return Err(RouteError::Unsupported); } }; - let user_id = format!("@{username}:{homeserver}", username = session.user.username); + let user_id = format!("@{username}:{homeserver}", username = user.username); // If the client asked for a refreshable token, make it expire let expires_in = if input.refresh_token { @@ -221,33 +233,23 @@ pub(crate) async fn post( }; let access_token = TokenType::CompatAccessToken.generate(&mut rng); - let access_token = add_compat_access_token( - &mut txn, - &mut rng, - &clock, - &session, - access_token, - expires_in, - ) - .await?; + let access_token = repo + .compat_access_token() + .add(&mut rng, &clock, &session, access_token, expires_in) + .await?; let refresh_token = if input.refresh_token { let refresh_token = TokenType::CompatRefreshToken.generate(&mut rng); - let refresh_token = add_compat_refresh_token( - &mut txn, - &mut rng, - &clock, - &session, - &access_token, - refresh_token, - ) - .await?; + let refresh_token = repo + .compat_refresh_token() + .add(&mut rng, &clock, &session, &access_token, refresh_token) + .await?; Some(refresh_token.token) } else { None }; - txn.commit().await?; + repo.save().await?; Ok(Json(ResponseBody { access_token: access_token.token, @@ -259,16 +261,18 @@ pub(crate) async fn post( } async fn token_login( - txn: &mut Transaction<'_, Postgres>, - clock: &Clock, + repo: &mut BoxRepository, + clock: &dyn Clock, token: &str, -) -> Result { - let login = get_compat_sso_login_by_token(&mut *txn, token) +) -> Result<(CompatSession, User), RouteError> { + let login = repo + .compat_sso_login() + .find_by_token(token) .await? .ok_or(RouteError::InvalidLoginToken)?; let now = clock.now(); - match login.state { + let session_id = match login.state { CompatSsoLoginState::Pending => { tracing::error!( compat_sso_login.id = %login.id, @@ -277,49 +281,70 @@ async fn token_login( return Err(RouteError::InvalidLoginToken); } CompatSsoLoginState::Fulfilled { - fulfilled_at: fullfilled_at, + fulfilled_at, + session_id, .. } => { - if now > fullfilled_at + Duration::seconds(30) { + if now > fulfilled_at + Duration::seconds(30) { return Err(RouteError::LoginTookTooLong); } + + session_id } - CompatSsoLoginState::Exchanged { exchanged_at, .. } => { + CompatSsoLoginState::Exchanged { + exchanged_at, + session_id, + .. + } => { if now > exchanged_at + Duration::seconds(30) { // TODO: log that session out tracing::error!( compat_sso_login.id = %login.id, + compat_session.id = %session_id, "Login token exchanged a second time more than 30s after" ); } return Err(RouteError::InvalidLoginToken); } - } + }; - let login = mark_compat_sso_login_as_exchanged(&mut *txn, clock, login).await?; + let session = repo + .compat_session() + .lookup(session_id) + .await? + .ok_or(RouteError::SessionNotFound)?; - match login.state { - CompatSsoLoginState::Exchanged { session, .. } => Ok(session), - _ => unreachable!(), - } + let user = repo + .user() + .lookup(session.user_id) + .await? + .ok_or(RouteError::UserNotFound)?; + + repo.compat_sso_login().exchange(clock, login).await?; + + Ok((session, user)) } async fn user_password_login( + mut rng: &mut (impl RngCore + CryptoRng + Send), + clock: &impl Clock, password_manager: &PasswordManager, - txn: &mut Transaction<'_, Postgres>, + repo: &mut BoxRepository, username: String, password: String, -) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - +) -> Result<(CompatSession, User), RouteError> { // Find the user - let user = lookup_user_by_username(&mut *txn, &username) + let user = repo + .user() + .find_by_username(&username) .await? .ok_or(RouteError::UserNotFound)?; // Lookup its password - let user_password = lookup_user_password(&mut *txn, &user) + let user_password = repo + .user_password() + .active(&user) .await? .ok_or(RouteError::NoPassword)?; @@ -338,21 +363,24 @@ async fn user_password_login( if let Some((version, hashed_password)) = new_password_hash { // Save the upgraded password if needed - add_user_password( - &mut *txn, - &mut rng, - &clock, - &user, - version, - hashed_password, - Some(user_password), - ) - .await?; + repo.user_password() + .add( + &mut rng, + clock, + &user, + version, + hashed_password, + Some(&user_password), + ) + .await?; } // Now that the user credentials have been verified, start a new compat session let device = Device::generate(&mut rng); - let session = start_compat_session(&mut *txn, &mut rng, &clock, user, device).await?; + let session = repo + .compat_session() + .add(&mut rng, clock, &user, device) + .await?; - Ok(session) + Ok((session, user)) } diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index e0416cd17..540287467 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -29,10 +29,12 @@ use mas_axum_utils::{ use mas_data_model::Device; use mas_keystore::Encrypter; use mas_router::{CompatLoginSsoAction, PostAuthAction, Route}; -use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id}; +use mas_storage::{ + compat::{CompatSessionRepository, CompatSsoLoginRepository}, + BoxClock, BoxRepository, BoxRng, Clock, +}; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; use serde::{Deserialize, Serialize}; -use sqlx::PgPool; use ulid::Ulid; #[derive(Serialize)] @@ -50,19 +52,18 @@ pub struct Params { } pub async fn get( - State(pool): State, + mut rng: BoxRng, + clock: BoxClock, + mut repo: BoxRepository, State(templates): State, cookie_jar: PrivateCookieJar, Path(id): Path, Query(params): Query, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -80,20 +81,16 @@ pub async fn get( return Ok((cookie_jar, url).into_response()); }; - // TODO: make that more generic - if session - .user - .primary_email - .as_ref() - .and_then(|e| e.confirmed_at) - .is_none() - { + // TODO: make that more generic, check that the email has been confirmed + if session.user.primary_user_email_id.is_none() { let destination = mas_router::AccountAddEmail::default() .and_then(PostAuthAction::continue_compat_sso_login(id)); return Ok((cookie_jar, destination.go()).into_response()); } - let login = get_compat_sso_login_by_id(&mut conn, id) + let login = repo + .compat_sso_login() + .lookup(id) .await? .context("Could not find compat SSO login")?; @@ -117,20 +114,19 @@ pub async fn get( } pub async fn post( - State(pool): State, + mut rng: BoxRng, + clock: BoxClock, + mut repo: BoxRepository, State(templates): State, cookie_jar: PrivateCookieJar, Path(id): Path, Query(params): Query, Form(form): Form>, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); - cookie_jar.verify_form(clock.now(), form)?; + cookie_jar.verify_form(&clock, form)?; - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -149,19 +145,15 @@ pub async fn post( }; // TODO: make that more generic - if session - .user - .primary_email - .as_ref() - .and_then(|e| e.confirmed_at) - .is_none() - { + if session.user.primary_user_email_id.is_none() { let destination = mas_router::AccountAddEmail::default() .and_then(PostAuthAction::continue_compat_sso_login(id)); return Ok((cookie_jar, destination.go()).into_response()); } - let login = get_compat_sso_login_by_id(&mut txn, id) + let login = repo + .compat_sso_login() + .lookup(id) .await? .context("Could not find compat SSO login")?; @@ -193,10 +185,16 @@ pub async fn post( }; let device = Device::generate(&mut rng); - let _login = - fullfill_compat_sso_login(&mut txn, &mut rng, &clock, session.user, login, device).await?; + let compat_session = repo + .compat_session() + .add(&mut rng, &clock, &session.user, device) + .await?; - txn.commit().await?; + repo.compat_sso_login() + .fulfill(&clock, login, &compat_session) + .await?; + + repo.save().await?; Ok((cookie_jar, Redirect::to(redirect_uri.as_str())).into_response()) } diff --git a/crates/handlers/src/compat/login_sso_redirect.rs b/crates/handlers/src/compat/login_sso_redirect.rs index f90862c72..da013cf7c 100644 --- a/crates/handlers/src/compat/login_sso_redirect.rs +++ b/crates/handlers/src/compat/login_sso_redirect.rs @@ -19,11 +19,10 @@ use axum::{ }; use hyper::StatusCode; use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder}; -use mas_storage::compat::insert_compat_sso_login; +use mas_storage::{compat::CompatSsoLoginRepository, BoxClock, BoxRepository, BoxRng}; use rand::distributions::{Alphanumeric, DistString}; use serde::Deserialize; use serde_with::serde; -use sqlx::PgPool; use thiserror::Error; use url::Url; @@ -48,7 +47,7 @@ pub enum RouteError { InvalidRedirectUrl, } -impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_storage::RepositoryError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -56,14 +55,13 @@ impl IntoResponse for RouteError { } } -#[tracing::instrument(skip(pool, url_builder), err)] pub async fn get( - State(pool): State, + mut rng: BoxRng, + clock: BoxClock, + mut repo: BoxRepository, State(url_builder): State, Query(params): Query, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - // Check the redirectUrl parameter let redirect_url = params.redirect_url.ok_or(RouteError::MissingRedirectUrl)?; let redirect_url = Url::parse(&redirect_url).map_err(|_| RouteError::InvalidRedirectUrl)?; @@ -79,8 +77,10 @@ pub async fn get( } let token = Alphanumeric.sample_string(&mut rng, 32); - let mut conn = pool.acquire().await?; - let login = insert_compat_sso_login(&mut conn, &mut rng, &clock, token, redirect_url).await?; + let login = repo + .compat_sso_login() + .add(&mut rng, &clock, token, redirect_url) + .await?; Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.id, params.action))) } diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index 4dca7797f..096b22de5 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{extract::State, response::IntoResponse, Json, TypedHeader}; +use axum::{response::IntoResponse, Json, TypedHeader}; use headers::{authorization::Bearer, Authorization}; use hyper::StatusCode; use mas_data_model::TokenType; -use mas_storage::{compat::compat_logout, Clock}; -use sqlx::PgPool; +use mas_storage::{ + compat::{CompatAccessTokenRepository, CompatSessionRepository}, + BoxClock, BoxRepository, Clock, +}; use thiserror::Error; use super::MatrixError; @@ -36,12 +38,9 @@ pub enum RouteError { #[error("Invalid access token")] InvalidAuthorization, - - #[error("Logout failed")] - LogoutFailed, } -impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_storage::RepositoryError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -56,7 +55,7 @@ impl IntoResponse for RouteError { error: "Missing access token", status: StatusCode::UNAUTHORIZED, }, - Self::InvalidAuthorization | Self::LogoutFailed | Self::TokenFormat(_) => MatrixError { + Self::InvalidAuthorization | Self::TokenFormat(_) => MatrixError { errcode: "M_UNKNOWN_TOKEN", error: "Invalid access token", status: StatusCode::UNAUTHORIZED, @@ -67,12 +66,10 @@ impl IntoResponse for RouteError { } pub(crate) async fn post( - State(pool): State, + clock: BoxClock, + mut repo: BoxRepository, maybe_authorization: Option>>, ) -> Result { - let clock = Clock::default(); - let mut conn = pool.acquire().await?; - let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?; let token = authorization.token(); @@ -82,9 +79,23 @@ pub(crate) async fn post( return Err(RouteError::InvalidAuthorization); } - if !compat_logout(&mut conn, &clock, token).await? { - return Err(RouteError::LogoutFailed); - } + let token = repo + .compat_access_token() + .find_by_token(token) + .await? + .filter(|t| t.is_valid(clock.now())) + .ok_or(RouteError::InvalidAuthorization)?; + + let session = repo + .compat_session() + .lookup(token.session_id) + .await? + .filter(|s| s.is_valid()) + .ok_or(RouteError::InvalidAuthorization)?; + + repo.compat_session().finish(&clock, session).await?; + + repo.save().await?; Ok(Json(serde_json::json!({}))) } diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs index 912a0f1a5..eb970c570 100644 --- a/crates/handlers/src/compat/refresh.rs +++ b/crates/handlers/src/compat/refresh.rs @@ -12,17 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{extract::State, response::IntoResponse, Json}; +use axum::{response::IntoResponse, Json}; use chrono::Duration; use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; -use mas_storage::compat::{ - add_compat_access_token, add_compat_refresh_token, consume_compat_refresh_token, - expire_compat_access_token, lookup_active_compat_refresh_token, +use mas_storage::{ + compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, + BoxClock, BoxRepository, BoxRng, Clock, }; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DurationMilliSeconds}; -use sqlx::PgPool; use thiserror::Error; use super::MatrixError; @@ -40,17 +39,26 @@ pub enum RouteError { #[error("invalid token")] InvalidToken, + + #[error("refresh token already consumed")] + RefreshTokenConsumed, + + #[error("invalid session")] + InvalidSession, + + #[error("unknown session")] + UnknownSession, } impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { - Self::Internal(_) => MatrixError { + Self::Internal(_) | Self::UnknownSession => MatrixError { errcode: "M_UNKNOWN", error: "Internal error", status: StatusCode::INTERNAL_SERVER_ERROR, }, - Self::InvalidToken => MatrixError { + Self::InvalidToken | Self::InvalidSession | Self::RefreshTokenConsumed => MatrixError { errcode: "M_UNKNOWN_TOKEN", error: "Invalid refresh token", status: StatusCode::UNAUTHORIZED, @@ -60,8 +68,7 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl From for RouteError { fn from(_e: TokenFormatError) -> Self { @@ -79,50 +86,79 @@ pub struct ResponseBody { } pub(crate) async fn post( - State(pool): State, + mut rng: BoxRng, + clock: BoxClock, + mut repo: BoxRepository, Json(input): Json, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; - let token_type = TokenType::check(&input.refresh_token)?; if token_type != TokenType::CompatRefreshToken { return Err(RouteError::InvalidToken); } - let (refresh_token, access_token, session) = - lookup_active_compat_refresh_token(&mut txn, &input.refresh_token) - .await? - .ok_or(RouteError::InvalidToken)?; + let refresh_token = repo + .compat_refresh_token() + .find_by_token(&input.refresh_token) + .await? + .ok_or(RouteError::InvalidToken)?; + + if !refresh_token.is_valid() { + return Err(RouteError::RefreshTokenConsumed); + } + + let session = repo + .compat_session() + .lookup(refresh_token.session_id) + .await? + .ok_or(RouteError::UnknownSession)?; + + if !session.is_valid() { + return Err(RouteError::InvalidSession); + } + + let access_token = repo + .compat_access_token() + .lookup(refresh_token.access_token_id) + .await? + .filter(|t| t.is_valid(clock.now())); let new_refresh_token_str = TokenType::CompatRefreshToken.generate(&mut rng); let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng); let expires_in = Duration::minutes(5); - let new_access_token = add_compat_access_token( - &mut txn, - &mut rng, - &clock, - &session, - new_access_token_str, - Some(expires_in), - ) - .await?; - let new_refresh_token = add_compat_refresh_token( - &mut txn, - &mut rng, - &clock, - &session, - &new_access_token, - new_refresh_token_str, - ) - .await?; + let new_access_token = repo + .compat_access_token() + .add( + &mut rng, + &clock, + &session, + new_access_token_str, + Some(expires_in), + ) + .await?; + let new_refresh_token = repo + .compat_refresh_token() + .add( + &mut rng, + &clock, + &session, + &new_access_token, + new_refresh_token_str, + ) + .await?; - consume_compat_refresh_token(&mut txn, &clock, refresh_token).await?; - expire_compat_access_token(&mut txn, &clock, access_token).await?; + repo.compat_refresh_token() + .consume(&clock, refresh_token) + .await?; - txn.commit().await?; + if let Some(access_token) = access_token { + repo.compat_access_token() + .expire(&clock, access_token) + .await?; + } + + repo.save().await?; Ok(Json(ResponseBody { access_token: new_access_token.token, diff --git a/crates/handlers/src/graphql.rs b/crates/handlers/src/graphql.rs index ba6919940..233c46906 100644 --- a/crates/handlers/src/graphql.rs +++ b/crates/handlers/src/graphql.rs @@ -22,19 +22,19 @@ use axum::{ Json, TypedHeader, }; use axum_extra::extract::PrivateCookieJar; -use futures_util::{StreamExt, TryStreamExt}; +use futures_util::TryStreamExt; use headers::{ContentType, HeaderValue}; use hyper::header::CACHE_CONTROL; use mas_axum_utils::{FancyError, SessionInfoExt}; use mas_graphql::Schema; use mas_keystore::Encrypter; -use sqlx::PgPool; +use mas_storage::BoxRepository; +use tokio::sync::Mutex; use tracing::{info_span, Instrument}; #[must_use] -pub fn schema(pool: &PgPool) -> Schema { +pub fn schema() -> Schema { mas_graphql::schema_builder() - .data(pool.clone()) .extension(Tracing) .extension(ApolloTracing) .finish() @@ -58,8 +58,8 @@ fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span { } pub async fn post( - State(pool): State, State(schema): State, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, content_type: Option>, body: BodyStream, @@ -67,58 +67,46 @@ pub async fn post( let content_type = content_type.map(|TypedHeader(h)| h.to_string()); let (session_info, _cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&pool).await?; + let maybe_session = session_info.load_session(&mut repo).await?; - let mut request = async_graphql::http::receive_batch_body( + let mut request = async_graphql::http::receive_body( content_type, body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) .into_async_read(), MultipartOptions::default(), ) - .await?; // XXX: this should probably return another error response? + .await? // XXX: this should probably return another error response? + .data(Mutex::new(repo)); if let Some(session) = maybe_session { request = request.data(session); } - let response = match request { - async_graphql::BatchRequest::Single(request) => { - let span = span_for_graphql_request(&request); - let response = schema.execute(request).instrument(span).await; - async_graphql::BatchResponse::Single(response) - } - async_graphql::BatchRequest::Batch(requests) => async_graphql::BatchResponse::Batch( - futures_util::stream::iter(requests.into_iter()) - .then(|request| { - let span = span_for_graphql_request(&request); - schema.execute(request).instrument(span) - }) - .collect() - .await, - ), - }; + let span = span_for_graphql_request(&request); + let response = schema.execute(request).instrument(span).await; let cache_control = response - .cache_control() + .cache_control .value() .and_then(|v| HeaderValue::from_str(&v).ok()) .map(|h| [(CACHE_CONTROL, h)]); - let headers = response.http_headers(); + let headers = response.http_headers.clone(); Ok((headers, cache_control, Json(response))) } pub async fn get( - State(pool): State, State(schema): State, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, RawQuery(query): RawQuery, ) -> Result { let (session_info, _cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&pool).await?; + let maybe_session = session_info.load_session(&mut repo).await?; - let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?; + let mut request = + async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(Mutex::new(repo)); if let Some(session) = maybe_session { request = request.data(session); diff --git a/crates/handlers/src/health.rs b/crates/handlers/src/health.rs index 6322dffde..10638497b 100644 --- a/crates/handlers/src/health.rs +++ b/crates/handlers/src/health.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ mod tests { use super::*; - #[sqlx::test(migrator = "mas_storage::MIGRATOR")] + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_get_health(pool: PgPool) -> Result<(), anyhow::Error> { let state = crate::test_state(pool).await?; let app = crate::healthcheck_router().with_state(state); diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 501f3176f..30519f425 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -21,14 +21,17 @@ )] #![warn(clippy::pedantic)] #![allow( - clippy::unused_async // Some axum handlers need that + // Some axum handlers need that + clippy::unused_async, + // Because of how axum handlers work, we sometime have take many arguments + clippy::too_many_arguments, )] use std::{convert::Infallible, sync::Arc, time::Duration}; use axum::{ body::{Bytes, HttpBody}, - extract::FromRef, + extract::{FromRef, FromRequestParts}, response::{Html, IntoResponse}, routing::{get, on, post, MethodFilter}, Router, @@ -40,9 +43,9 @@ use mas_http::CorsLayerExt; use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; use mas_router::{Route, UrlBuilder}; +use mas_storage::{BoxClock, BoxRepository, BoxRng}; use mas_templates::{ErrorContext, Templates}; use passwords::PasswordManager; -use rand::SeedableRng; use sqlx::PgPool; use tower::util::AndThenLayer; use tower_http::cors::{Any, CorsLayer}; @@ -94,7 +97,7 @@ where ::Error: std::error::Error + Send + Sync, S: Clone + Send + Sync + 'static, mas_graphql::Schema: FromRef, - PgPool: FromRef, + BoxRepository: FromRequestParts, Encrypter: FromRef, { let mut router = Router::new().route( @@ -116,6 +119,8 @@ where S: Clone + Send + Sync + 'static, Keystore: FromRef, UrlBuilder: FromRef, + BoxClock: FromRequestParts, + BoxRng: FromRequestParts, { Router::new() .route( @@ -152,9 +157,11 @@ where Keystore: FromRef, UrlBuilder: FromRef, Arc: FromRef, - PgPool: FromRef, + BoxRepository: FromRequestParts, Encrypter: FromRef, HttpClientFactory: FromRef, + BoxClock: FromRequestParts, + BoxRng: FromRequestParts, { // All those routes are API-like, with a common CORS layer Router::new() @@ -205,9 +212,11 @@ where ::Error: std::error::Error + Send + Sync, S: Clone + Send + Sync + 'static, UrlBuilder: FromRef, - PgPool: FromRef, + BoxRepository: FromRequestParts, MatrixHomeserver: FromRef, PasswordManager: FromRef, + BoxClock: FromRequestParts, + BoxRng: FromRequestParts, { Router::new() .route( @@ -248,13 +257,15 @@ where S: Clone + Send + Sync + 'static, UrlBuilder: FromRef, Arc: FromRef, - PgPool: FromRef, + BoxRepository: FromRequestParts, Encrypter: FromRef, Templates: FromRef, Mailer: FromRef, Keystore: FromRef, HttpClientFactory: FromRef, PasswordManager: FromRef, + BoxClock: FromRequestParts, + BoxRng: FromRequestParts, { Router::new() .route( @@ -350,7 +361,7 @@ where } #[cfg(test)] -async fn test_state(pool: PgPool) -> Result { +async fn test_state(pool: sqlx::PgPool) -> Result { use mas_email::MailTransport; use crate::passwords::Hasher; @@ -389,7 +400,7 @@ async fn test_state(pool: PgPool) -> Result { let policy_factory = Arc::new(policy_factory); - let graphql_schema = graphql_schema(&pool); + let graphql_schema = graphql_schema(); let http_client_factory = HttpClientFactory::new(10); @@ -407,16 +418,3 @@ async fn test_state(pool: PgPool) -> Result { password_manager, }) } - -// XXX: that should be moved somewhere else -fn clock_and_rng() -> (mas_storage::Clock, rand_chacha::ChaChaRng) { - let clock = mas_storage::Clock::default(); - - // This rng is used to source the local rng - #[allow(clippy::disallowed_methods)] - let rng = rand::thread_rng(); - - let rng = rand_chacha::ChaChaRng::from_rng(rng).expect("Failed to seed RNG"); - - (clock, rng) -} diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index eb4b8889d..6b6869da9 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -25,13 +25,12 @@ use mas_data_model::{AuthorizationGrant, BrowserSession}; use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; -use mas_storage::oauth2::{ - authorization_grant::{derive_session, fulfill_grant, get_grant_by_id}, - consent::fetch_client_consent, +use mas_storage::{ + oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository}, + BoxClock, BoxRepository, BoxRng, }; use mas_templates::Templates; use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse}; -use sqlx::{PgPool, Postgres, Transaction}; use thiserror::Error; use ulid::Ulid; @@ -69,8 +68,7 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); @@ -78,19 +76,21 @@ impl_from_error_for_route!(super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(super::callback::CallbackDestinationError); pub(crate) async fn get( + rng: BoxRng, + clock: BoxClock, State(policy_factory): State>, State(templates): State, - State(pool): State, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Path(grant_id): Path, ) -> Result { - let mut txn = pool.begin().await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; - let grant = get_grant_by_id(&mut txn, grant_id) + let grant = repo + .oauth2_authorization_grant() + .lookup(grant_id) .await? .ok_or(RouteError::NotFound)?; @@ -105,7 +105,7 @@ pub(crate) async fn get( return Ok((cookie_jar, mas_router::Login::and_then(continue_grant).go()).into_response()); }; - match complete(grant, session, &policy_factory, txn).await { + match complete(rng, clock, grant, session, &policy_factory, repo).await { Ok(params) => { let res = callback_destination.go(&templates, params).await?; Ok((cookie_jar, res).into_response()) @@ -121,6 +121,7 @@ pub(crate) async fn get( } Err(GrantCompletionError::NotPending) => Err(RouteError::NotPending), Err(GrantCompletionError::Internal(e)) => Err(RouteError::Internal(e)), + Err(e) => Err(RouteError::Internal(e.into())), } } @@ -140,23 +141,25 @@ pub enum GrantCompletionError { #[error("denied by the policy")] PolicyViolation, + + #[error("failed to load client")] + NoSuchClient, } -impl_from_error_for_route!(GrantCompletionError: sqlx::Error); -impl_from_error_for_route!(GrantCompletionError: mas_storage::DatabaseError); +impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError); impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError); impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError); impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError); pub(crate) async fn complete( + mut rng: BoxRng, + clock: BoxClock, grant: AuthorizationGrant, browser_session: BrowserSession, policy_factory: &PolicyFactory, - mut txn: Transaction<'_, Postgres>, + mut repo: BoxRepository, ) -> Result>, GrantCompletionError> { - let (clock, mut rng) = crate::clock_and_rng(); - // Verify that the grant is in a pending stage if !grant.stage.is_pending() { return Err(GrantCompletionError::NotPending); @@ -164,7 +167,7 @@ pub(crate) async fn complete( // Check if the authentication is fresh enough if !browser_session.was_authenticated_after(grant.max_auth_time()) { - txn.commit().await?; + repo.save().await?; return Err(GrantCompletionError::RequiresReauth); } @@ -178,8 +181,16 @@ pub(crate) async fn complete( return Err(GrantCompletionError::PolicyViolation); } - let current_consent = - fetch_client_consent(&mut txn, &browser_session.user, &grant.client).await?; + let client = repo + .oauth2_client() + .lookup(grant.client_id) + .await? + .ok_or(GrantCompletionError::NoSuchClient)?; + + let current_consent = repo + .oauth2_client() + .get_consent_for_user(&client, &browser_session.user) + .await?; let lacks_consent = grant .scope @@ -188,14 +199,20 @@ pub(crate) async fn complete( // Check if the client lacks consent *or* if consent was explicitely asked if lacks_consent || grant.requires_consent { - txn.commit().await?; + repo.save().await?; return Err(GrantCompletionError::RequiresConsent); } // All good, let's start the session - let session = derive_session(&mut txn, &mut rng, &clock, &grant, browser_session).await?; + let session = repo + .oauth2_session() + .create_from_grant(&mut rng, &clock, &grant, &browser_session) + .await?; - let grant = fulfill_grant(&mut txn, grant, session.clone()).await?; + let grant = repo + .oauth2_authorization_grant() + .fulfill(&clock, &session, grant) + .await?; // Yep! Let's complete the auth now let mut params = AuthorizationResponse::default(); @@ -213,6 +230,6 @@ pub(crate) async fn complete( )); } - txn.commit().await?; + repo.save().await?; Ok(params) } diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 1b999ffc6..1aa8ce64d 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -25,8 +25,9 @@ use mas_data_model::{AuthorizationCode, Pkce}; use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; -use mas_storage::oauth2::{ - authorization_grant::new_authorization_grant, client::lookup_client_by_client_id, +use mas_storage::{ + oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, + BoxClock, BoxRepository, BoxRng, }; use mas_templates::Templates; use oauth2_types::{ @@ -37,7 +38,6 @@ use oauth2_types::{ }; use rand::{distributions::Alphanumeric, Rng}; use serde::Deserialize; -use sqlx::PgPool; use thiserror::Error; use self::{callback::CallbackDestination, complete::GrantCompletionError}; @@ -89,8 +89,7 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(self::callback::CallbackDestinationError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); @@ -131,17 +130,18 @@ fn resolve_response_mode( #[allow(clippy::too_many_lines)] pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(policy_factory): State>, State(templates): State, - State(pool): State, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Form(params): Form, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; - // First, figure out what client it is - let client = lookup_client_by_client_id(&mut txn, ¶ms.auth.client_id) + let client = repo + .oauth2_client() + .find_by_client_id(¶ms.auth.client_id) .await? .ok_or(RouteError::ClientNotFound)?; @@ -167,7 +167,7 @@ pub(crate) async fn get( let templates = templates.clone(); let callback_destination = callback_destination.clone(); async move { - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let prompt = params.auth.prompt.as_deref().unwrap_or_default(); // Check if the request/request_uri/registration params are used. If so, reply @@ -272,23 +272,23 @@ pub(crate) async fn get( let requires_consent = prompt.contains(&Prompt::Consent); - let grant = new_authorization_grant( - &mut txn, - &mut rng, - &clock, - client, - redirect_uri.clone(), - params.auth.scope, - code, - params.auth.state.clone(), - params.auth.nonce, - params.auth.max_age, - None, - response_mode, - response_type.has_id_token(), - requires_consent, - ) - .await?; + let grant = repo + .oauth2_authorization_grant() + .add( + &mut rng, + &clock, + &client, + redirect_uri.clone(), + params.auth.scope, + code, + params.auth.state.clone(), + params.auth.nonce, + params.auth.max_age, + response_mode, + response_type.has_id_token(), + requires_consent, + ) + .await?; let continue_grant = PostAuthAction::continue_grant(grant.id); let res = match maybe_session { @@ -299,7 +299,7 @@ pub(crate) async fn get( } None if prompt.contains(&Prompt::Create) => { // Client asked for a registration, show the registration prompt - txn.commit().await?; + repo.save().await?; mas_router::Register::and_then(continue_grant) .go() @@ -307,7 +307,7 @@ pub(crate) async fn get( } None => { // Other cases where we don't have a session, ask for a login - txn.commit().await?; + repo.save().await?; mas_router::Login::and_then(continue_grant) .go() @@ -320,7 +320,7 @@ pub(crate) async fn get( || prompt.contains(&Prompt::SelectAccount) => { // TODO: better pages here - txn.commit().await?; + repo.save().await?; mas_router::Reauth::and_then(continue_grant) .go() @@ -330,7 +330,15 @@ pub(crate) async fn get( // Else, we immediately try to complete the authorization grant Some(user_session) if prompt.contains(&Prompt::None) => { // With prompt=none, we should get back to the client immediately - match self::complete::complete(grant, user_session, &policy_factory, txn).await + match self::complete::complete( + rng, + clock, + grant, + user_session, + &policy_factory, + repo, + ) + .await { Ok(params) => callback_destination.go(&templates, params).await?, Err(GrantCompletionError::RequiresConsent) => { @@ -357,7 +365,10 @@ pub(crate) async fn get( Err(GrantCompletionError::Internal(e)) => { return Err(RouteError::Internal(e)) } - Err(e @ GrantCompletionError::NotPending) => { + Err( + e @ (GrantCompletionError::NotPending + | GrantCompletionError::NoSuchClient), + ) => { // This should never happen return Err(RouteError::Internal(Box::new(e))); } @@ -366,7 +377,15 @@ pub(crate) async fn get( Some(user_session) => { let grant_id = grant.id; // Else, we show the relevant reauth/consent page if necessary - match self::complete::complete(grant, user_session, &policy_factory, txn).await + match self::complete::complete( + rng, + clock, + grant, + user_session, + &policy_factory, + repo, + ) + .await { Ok(params) => callback_destination.go(&templates, params).await?, Err( @@ -387,7 +406,10 @@ pub(crate) async fn get( Err(GrantCompletionError::Internal(e)) => { return Err(RouteError::Internal(e)) } - Err(e @ GrantCompletionError::NotPending) => { + Err( + e @ (GrantCompletionError::NotPending + | GrantCompletionError::NoSuchClient), + ) => { // This should never happen return Err(RouteError::Internal(Box::new(e))); } diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index 451077837..ffaa6106e 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -28,12 +28,11 @@ use mas_data_model::AuthorizationGrantStage; use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; -use mas_storage::oauth2::{ - authorization_grant::{get_grant_by_id, give_consent_to_grant}, - consent::insert_client_consent, +use mas_storage::{ + oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, + BoxClock, BoxRepository, BoxRng, }; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; -use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -55,11 +54,13 @@ pub enum RouteError { #[error("Policy violation")] PolicyViolation, + + #[error("Failed to load client")] + NoSuchClient, } -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_templates::TemplateError); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); @@ -71,20 +72,21 @@ impl IntoResponse for RouteError { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(policy_factory): State>, State(templates): State, - State(pool): State, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Path(grant_id): Path, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; - let grant = get_grant_by_id(&mut conn, grant_id) + let grant = repo + .oauth2_authorization_grant() + .lookup(grant_id) .await? .ok_or(RouteError::GrantNotFound)?; @@ -93,7 +95,7 @@ pub(crate) async fn get( } if let Some(session) = maybe_session { - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let mut policy = policy_factory.instantiate().await?; let res = policy @@ -124,22 +126,23 @@ pub(crate) async fn get( } pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(policy_factory): State>, - State(pool): State, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Path(grant_id): Path, Form(form): Form>, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; - - cookie_jar.verify_form(clock.now(), form)?; + cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; - let grant = get_grant_by_id(&mut txn, grant_id) + let grant = repo + .oauth2_authorization_grant() + .lookup(grant_id) .await? .ok_or(RouteError::GrantNotFound)?; let next = PostAuthAction::continue_grant(grant_id); @@ -160,6 +163,12 @@ pub(crate) async fn post( return Err(RouteError::PolicyViolation); } + let client = repo + .oauth2_client() + .lookup(grant.client_id) + .await? + .ok_or(RouteError::NoSuchClient)?; + // Do not consent for the "urn:matrix:org.matrix.msc2967.client:device:*" scope let scope_without_device = grant .scope @@ -167,19 +176,21 @@ pub(crate) async fn post( .filter(|s| !s.starts_with("urn:matrix:org.matrix.msc2967.client:device:")) .cloned() .collect(); - insert_client_consent( - &mut txn, - &mut rng, - &clock, - &session.user, - &grant.client, - &scope_without_device, - ) - .await?; + repo.oauth2_client() + .give_consent_for_user( + &mut rng, + &clock, + &client, + &session.user, + &scope_without_device, + ) + .await?; - let _grant = give_consent_to_grant(&mut txn, grant).await?; + repo.oauth2_authorization_grant() + .give_consent(grant) + .await?; - txn.commit().await?; + repo.save().await?; Ok((cookie_jar, next.go_next()).into_response()) } diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index c2e68261b..3b44c511d 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -22,18 +22,16 @@ use mas_data_model::{TokenFormatError, TokenType}; use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint}; use mas_keystore::Encrypter; use mas_storage::{ - compat::{lookup_active_compat_access_token, lookup_active_compat_refresh_token}, - oauth2::{ - access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token, - }, - Clock, + compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, + oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository}, + user::{BrowserSessionRepository, UserRepository}, + BoxClock, BoxRepository, Clock, }; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, requests::{IntrospectionRequest, IntrospectionResponse}, scope::ScopeToken, }; -use sqlx::PgPool; use thiserror::Error; use crate::impl_from_error_for_route; @@ -97,8 +95,7 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl From for RouteError { fn from(_e: TokenFormatError) -> Self { @@ -125,18 +122,17 @@ const API_SCOPE: ScopeToken = ScopeToken::from_static("urn:matrix:org.matrix.msc #[allow(clippy::too_many_lines)] pub(crate) async fn post( + clock: BoxClock, State(http_client_factory): State, - State(pool): State, + mut repo: BoxRepository, State(encrypter): State, client_authorization: ClientAuthorization, ) -> Result { - let clock = Clock::default(); - let mut conn = pool.acquire().await?; - let client = client_authorization .credentials - .fetch(&mut conn) - .await? + .fetch(&mut repo) + .await + .unwrap() .ok_or(RouteError::ClientNotFound)?; let method = match &client.token_endpoint_auth_method { @@ -167,48 +163,103 @@ pub(crate) async fn post( let reply = match token_type { TokenType::AccessToken => { - let (token, session) = lookup_active_access_token(&mut conn, token) + let token = repo + .oauth2_access_token() + .find_by_token(token) .await? + .filter(|t| t.is_valid(clock.now())) + .ok_or(RouteError::UnknownToken)?; + + let session = repo + .oauth2_session() + .lookup(token.session_id) + .await? + .filter(|s| s.is_valid()) + // XXX: is that the right error to bubble up? + .ok_or(RouteError::UnknownToken)?; + + let browser_session = repo + .browser_session() + .lookup(session.user_session_id) + .await? + // XXX: is that the right error to bubble up? .ok_or(RouteError::UnknownToken)?; IntrospectionResponse { active: true, scope: Some(session.scope), - client_id: Some(session.client.client_id), - username: Some(session.browser_session.user.username), + client_id: Some(session.client_id.to_string()), + username: Some(browser_session.user.username), token_type: Some(OAuthTokenTypeHint::AccessToken), exp: Some(token.expires_at), iat: Some(token.created_at), nbf: Some(token.created_at), - sub: Some(session.browser_session.user.sub), + sub: Some(browser_session.user.sub), aud: None, iss: None, - jti: None, + jti: Some(token.jti()), } } + TokenType::RefreshToken => { - let (token, session) = lookup_active_refresh_token(&mut conn, token) + let token = repo + .oauth2_refresh_token() + .find_by_token(token) .await? + .filter(|t| t.is_valid()) + .ok_or(RouteError::UnknownToken)?; + + let session = repo + .oauth2_session() + .lookup(token.session_id) + .await? + .filter(|s| s.is_valid()) + // XXX: is that the right error to bubble up? + .ok_or(RouteError::UnknownToken)?; + + let browser_session = repo + .browser_session() + .lookup(session.user_session_id) + .await? + // XXX: is that the right error to bubble up? .ok_or(RouteError::UnknownToken)?; IntrospectionResponse { active: true, scope: Some(session.scope), - client_id: Some(session.client.client_id), - username: Some(session.browser_session.user.username), + client_id: Some(session.client_id.to_string()), + username: Some(browser_session.user.username), token_type: Some(OAuthTokenTypeHint::RefreshToken), exp: None, iat: Some(token.created_at), nbf: Some(token.created_at), - sub: Some(session.browser_session.user.sub), + sub: Some(browser_session.user.sub), aud: None, iss: None, - jti: None, + jti: Some(token.jti()), } } + TokenType::CompatAccessToken => { - let (token, session) = lookup_active_compat_access_token(&mut conn, &clock, token) + let access_token = repo + .compat_access_token() + .find_by_token(token) .await? + .filter(|t| t.is_valid(clock.now())) + .ok_or(RouteError::UnknownToken)?; + + let session = repo + .compat_session() + .lookup(access_token.session_id) + .await? + .filter(|s| s.is_valid()) + .ok_or(RouteError::UnknownToken)?; + + let user = repo + .user() + .lookup(session.user_id) + .await? + // XXX: is that the right error to bubble up? .ok_or(RouteError::UnknownToken)?; let device_scope = session.device.to_scope_token(); @@ -218,22 +269,39 @@ pub(crate) async fn post( active: true, scope: Some(scope), client_id: Some("legacy".into()), - username: Some(session.user.username), + username: Some(user.username), token_type: Some(OAuthTokenTypeHint::AccessToken), - exp: token.expires_at, - iat: Some(token.created_at), - nbf: Some(token.created_at), - sub: Some(session.user.sub), + exp: access_token.expires_at, + iat: Some(access_token.created_at), + nbf: Some(access_token.created_at), + sub: Some(user.sub), aud: None, iss: None, jti: None, } } + TokenType::CompatRefreshToken => { - let (refresh_token, _access_token, session) = - lookup_active_compat_refresh_token(&mut conn, token) - .await? - .ok_or(RouteError::UnknownToken)?; + let refresh_token = repo + .compat_refresh_token() + .find_by_token(token) + .await? + .filter(|t| t.is_valid()) + .ok_or(RouteError::UnknownToken)?; + + let session = repo + .compat_session() + .lookup(refresh_token.session_id) + .await? + .filter(|s| s.is_valid()) + .ok_or(RouteError::UnknownToken)?; + + let user = repo + .user() + .lookup(session.user_id) + .await? + // XXX: is that the right error to bubble up? + .ok_or(RouteError::UnknownToken)?; let device_scope = session.device.to_scope_token(); let scope = [API_SCOPE, device_scope].into_iter().collect(); @@ -242,12 +310,12 @@ pub(crate) async fn post( active: true, scope: Some(scope), client_id: Some("legacy".into()), - username: Some(session.user.username), + username: Some(user.username), token_type: Some(OAuthTokenTypeHint::RefreshToken), exp: None, iat: Some(refresh_token.created_at), nbf: Some(refresh_token.created_at), - sub: Some(session.user.sub), + sub: Some(user.sub), aud: None, iss: None, jti: None, diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index 25b734cfd..650a19ab7 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -19,7 +19,7 @@ use hyper::StatusCode; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_keystore::Encrypter; use mas_policy::{PolicyFactory, Violation}; -use mas_storage::oauth2::client::insert_client; +use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRepository, BoxRng}; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, registration::{ @@ -27,10 +27,8 @@ use oauth2_types::{ }, }; use rand::distributions::{Alphanumeric, DistString}; -use sqlx::PgPool; use thiserror::Error; use tracing::info; -use ulid::Ulid; use crate::impl_from_error_for_route; @@ -49,7 +47,7 @@ pub(crate) enum RouteError { PolicyDenied(Vec), } -impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); @@ -107,12 +105,13 @@ impl IntoResponse for RouteError { #[tracing::instrument(skip_all, err)] pub(crate) async fn post( - State(pool): State, + mut rng: BoxRng, + clock: BoxClock, + mut repo: BoxRepository, State(policy_factory): State>, State(encrypter): State, Json(body): Json, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); info!(?body, "Client registration"); // Validate the body @@ -124,16 +123,6 @@ pub(crate) async fn post( return Err(RouteError::PolicyDenied(res.violations)); } - // Contacts was checked by the policy - let contacts = metadata.contacts.as_deref().unwrap_or_default(); - - // Grab a txn - let mut txn = pool.begin().await?; - - let now = clock.now(); - // Let's generate a random client ID - let client_id = Ulid::from_datetime_with_source(now.into(), &mut rng); - let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method { Some( OAuthClientAuthenticationMethod::ClientSecretJwt @@ -148,41 +137,42 @@ pub(crate) async fn post( _ => (None, None), }; - insert_client( - &mut txn, - &mut rng, - &clock, - client_id, - metadata.redirect_uris(), - encrypted_client_secret.as_deref(), - //&metadata.response_types(), - metadata.grant_types(), - contacts, - metadata - .client_name - .as_ref() - .map(|l| l.non_localized().as_ref()), - metadata.logo_uri.as_ref().map(Localized::non_localized), - metadata.client_uri.as_ref().map(Localized::non_localized), - metadata.policy_uri.as_ref().map(Localized::non_localized), - metadata.tos_uri.as_ref().map(Localized::non_localized), - metadata.jwks_uri.as_ref(), - metadata.jwks.as_ref(), - // XXX: those might not be right, should be function calls - metadata.id_token_signed_response_alg.as_ref(), - metadata.userinfo_signed_response_alg.as_ref(), - metadata.token_endpoint_auth_method.as_ref(), - metadata.token_endpoint_auth_signing_alg.as_ref(), - metadata.initiate_login_uri.as_ref(), - ) - .await?; + let client = repo + .oauth2_client() + .add( + &mut rng, + &clock, + metadata.redirect_uris().to_vec(), + encrypted_client_secret, + //&metadata.response_types(), + metadata.grant_types().to_vec(), + metadata.contacts.clone().unwrap_or_default(), + metadata + .client_name + .clone() + .map(Localized::to_non_localized), + metadata.logo_uri.clone().map(Localized::to_non_localized), + metadata.client_uri.clone().map(Localized::to_non_localized), + metadata.policy_uri.clone().map(Localized::to_non_localized), + metadata.tos_uri.clone().map(Localized::to_non_localized), + metadata.jwks_uri.clone(), + metadata.jwks.clone(), + // XXX: those might not be right, should be function calls + metadata.id_token_signed_response_alg.clone(), + metadata.userinfo_signed_response_alg.clone(), + metadata.token_endpoint_auth_method.clone(), + metadata.token_endpoint_auth_signing_alg.clone(), + metadata.initiate_login_uri.clone(), + ) + .await?; - txn.commit().await?; + repo.save().await?; let response = ClientRegistrationResponse { - client_id: client_id.to_string(), + client_id: client.client_id, client_secret, - client_id_issued_at: Some(now), + // XXX: we should have a `created_at` field on the clients + client_id_issued_at: Some(client.id.datetime().into()), client_secret_expires_at: None, }; diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index a6d899f4f..682813bf4 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -31,11 +31,13 @@ use mas_jose::{ }; use mas_keystore::{Encrypter, Keystore}; use mas_router::UrlBuilder; -use mas_storage::oauth2::{ - access_token::{add_access_token, revoke_access_token}, - authorization_grant::{exchange_grant, lookup_grant_by_code}, - end_oauth_session, - refresh_token::{add_refresh_token, consume_refresh_token, lookup_active_refresh_token}, +use mas_storage::{ + oauth2::{ + OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, + OAuth2RefreshTokenRepository, OAuth2SessionRepository, + }, + user::BrowserSessionRepository, + BoxClock, BoxRepository, BoxRng, Clock, }; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, @@ -47,7 +49,6 @@ use oauth2_types::{ }; use serde::Serialize; use serde_with::{serde_as, skip_serializing_none}; -use sqlx::{PgPool, Postgres, Transaction}; use thiserror::Error; use tracing::debug; use url::Url; @@ -102,12 +103,21 @@ pub(crate) enum RouteError { #[error("no suitable key found for signing")] InvalidSigningKey, + + #[error("failed to load browser session")] + NoSuchBrowserSession, + + #[error("failed to load oauth session")] + NoSuchOAuthSession, } impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { - Self::Internal(_) | Self::InvalidSigningKey => ( + Self::Internal(_) + | Self::InvalidSigningKey + | Self::NoSuchBrowserSession + | Self::NoSuchOAuthSession => ( StatusCode::INTERNAL_SERVER_ERROR, Json(ClientError::from(ClientErrorCode::ServerError)), ), @@ -139,8 +149,7 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_keystore::WrongAlgorithmError); impl_from_error_for_route!(mas_jose::claims::ClaimError); impl_from_error_for_route!(mas_jose::claims::TokenHashError); @@ -148,18 +157,18 @@ impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError); #[tracing::instrument(skip_all, err)] pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(http_client_factory): State, State(key_store): State, State(url_builder): State, - State(pool): State, + mut repo: BoxRepository, State(encrypter): State, client_authorization: ClientAuthorization, ) -> Result { - let mut txn = pool.begin().await?; - let client = client_authorization .credentials - .fetch(&mut txn) + .fetch(&mut repo) .await? .ok_or(RouteError::ClientNotFound)?; @@ -175,18 +184,29 @@ pub(crate) async fn post( let form = client_authorization.form.ok_or(RouteError::BadRequest)?; - let reply = match form { + let (reply, repo) = match form { AccessTokenRequest::AuthorizationCode(grant) => { - authorization_code_grant(&grant, &client, &key_store, &url_builder, txn).await? + authorization_code_grant( + &mut rng, + &clock, + &grant, + &client, + &key_store, + &url_builder, + repo, + ) + .await? } AccessTokenRequest::RefreshToken(grant) => { - refresh_token_grant(&grant, &client, txn).await? + refresh_token_grant(&mut rng, &clock, &grant, &client, repo).await? } _ => { return Err(RouteError::InvalidGrant); } }; + repo.save().await?; + let mut headers = HeaderMap::new(); headers.typed_insert(CacheControl::new().with_no_store()); headers.typed_insert(Pragma::no_cache()); @@ -196,23 +216,23 @@ pub(crate) async fn post( #[allow(clippy::too_many_lines)] async fn authorization_code_grant( + mut rng: &mut BoxRng, + clock: &impl Clock, grant: &AuthorizationCodeGrant, client: &Client, key_store: &Keystore, url_builder: &UrlBuilder, - mut txn: Transaction<'_, Postgres>, -) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - - // TODO: there is a bunch of unnecessary cloning here - // TODO: handle "not found" cases - let authz_grant = lookup_grant_by_code(&mut txn, &grant.code) + mut repo: BoxRepository, +) -> Result<(AccessTokenResponse, BoxRepository), RouteError> { + let authz_grant = repo + .oauth2_authorization_grant() + .find_by_code(&grant.code) .await? .ok_or(RouteError::GrantNotFound)?; let now = clock.now(); - let session = match authz_grant.stage { + let session_id = match authz_grant.stage { AuthorizationGrantStage::Cancelled { cancelled_at } => { debug!(%cancelled_at, "Authorization grant was cancelled"); return Err(RouteError::InvalidGrant); @@ -220,15 +240,20 @@ async fn authorization_code_grant( AuthorizationGrantStage::Exchanged { exchanged_at, fulfilled_at, - session, + session_id, } => { debug!(%exchanged_at, %fulfilled_at, "Authorization code was already exchanged"); // Ending the session if the token was already exchanged more than 20s ago if now - exchanged_at > Duration::seconds(20) { debug!("Ending potentially compromised session"); - end_oauth_session(&mut txn, &clock, session).await?; - txn.commit().await?; + let session = repo + .oauth2_session() + .lookup(session_id) + .await? + .ok_or(RouteError::NoSuchOAuthSession)?; + repo.oauth2_session().finish(clock, session).await?; + repo.save().await?; } return Err(RouteError::InvalidGrant); @@ -238,7 +263,7 @@ async fn authorization_code_grant( return Err(RouteError::InvalidGrant); } AuthorizationGrantStage::Fulfilled { - ref session, + session_id, fulfilled_at, } => { if now - fulfilled_at > Duration::minutes(10) { @@ -246,14 +271,20 @@ async fn authorization_code_grant( return Err(RouteError::InvalidGrant); } - session + session_id } }; + let session = repo + .oauth2_session() + .lookup(session_id) + .await? + .ok_or(RouteError::NoSuchOAuthSession)?; + // This should never happen, since we looked up in the database using the code let code = authz_grant.code.as_ref().ok_or(RouteError::InvalidGrant)?; - if client.client_id != session.client.client_id { + if client.id != session.client_id { return Err(RouteError::UnauthorizedClient); } @@ -267,31 +298,25 @@ async fn authorization_code_grant( } }; - let browser_session = &session.browser_session; + let browser_session = repo + .browser_session() + .lookup(session.user_session_id) + .await? + .ok_or(RouteError::NoSuchBrowserSession)?; let ttl = Duration::minutes(5); let access_token_str = TokenType::AccessToken.generate(&mut rng); let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); - let access_token = add_access_token( - &mut txn, - &mut rng, - &clock, - session, - access_token_str.clone(), - ttl, - ) - .await?; + let access_token = repo + .oauth2_access_token() + .add(&mut rng, clock, &session, access_token_str, ttl) + .await?; - let _refresh_token = add_refresh_token( - &mut txn, - &mut rng, - &clock, - session, - access_token, - refresh_token_str.clone(), - ) - .await?; + let refresh_token = repo + .oauth2_refresh_token() + .add(&mut rng, clock, &session, &access_token, refresh_token_str) + .await?; let id_token = if session.scope.contains(&scope::OPENID) { let mut claims = HashMap::new(); @@ -317,7 +342,7 @@ async fn authorization_code_grant( .signing_key_for_algorithm(&alg) .ok_or(RouteError::InvalidSigningKey)?; - claims::AT_HASH.insert(&mut claims, hash_token(&alg, &access_token_str)?)?; + claims::AT_HASH.insert(&mut claims, hash_token(&alg, &access_token.access_token)?)?; claims::C_HASH.insert(&mut claims, hash_token(&alg, &grant.code)?)?; let signer = key.params().signing_key_for_alg(&alg)?; @@ -330,34 +355,46 @@ async fn authorization_code_grant( None }; - let mut params = AccessTokenResponse::new(access_token_str) + let mut params = AccessTokenResponse::new(access_token.access_token) .with_expires_in(ttl) - .with_refresh_token(refresh_token_str) + .with_refresh_token(refresh_token.refresh_token) .with_scope(session.scope.clone()); if let Some(id_token) = id_token { params = params.with_id_token(id_token); } - exchange_grant(&mut txn, &clock, authz_grant).await?; + repo.oauth2_authorization_grant() + .exchange(clock, authz_grant) + .await?; - txn.commit().await?; - - Ok(params) + Ok((params, repo)) } async fn refresh_token_grant( + mut rng: &mut BoxRng, + clock: &impl Clock, grant: &RefreshTokenGrant, client: &Client, - mut txn: Transaction<'_, Postgres>, -) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - - let (refresh_token, session) = lookup_active_refresh_token(&mut txn, &grant.refresh_token) + mut repo: BoxRepository, +) -> Result<(AccessTokenResponse, BoxRepository), RouteError> { + let refresh_token = repo + .oauth2_refresh_token() + .find_by_token(&grant.refresh_token) .await? .ok_or(RouteError::InvalidGrant)?; - if client.client_id != session.client.client_id { + let session = repo + .oauth2_session() + .lookup(refresh_token.session_id) + .await? + .ok_or(RouteError::NoSuchOAuthSession)?; + + if !refresh_token.is_valid() || !session.is_valid() { + return Err(RouteError::InvalidGrant); + } + + if client.id != session.client_id { // As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 return Err(RouteError::InvalidGrant); } @@ -366,30 +403,34 @@ async fn refresh_token_grant( let access_token_str = TokenType::AccessToken.generate(&mut rng); let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); - let new_access_token = add_access_token( - &mut txn, - &mut rng, - &clock, - &session, - access_token_str.clone(), - ttl, - ) - .await?; + let new_access_token = repo + .oauth2_access_token() + .add(&mut rng, clock, &session, access_token_str.clone(), ttl) + .await?; - let new_refresh_token = add_refresh_token( - &mut txn, - &mut rng, - &clock, - &session, - new_access_token, - refresh_token_str, - ) - .await?; + let new_refresh_token = repo + .oauth2_refresh_token() + .add( + &mut rng, + clock, + &session, + &new_access_token, + refresh_token_str, + ) + .await?; - consume_refresh_token(&mut txn, &clock, &refresh_token).await?; + let refresh_token = repo + .oauth2_refresh_token() + .consume(clock, refresh_token) + .await?; - if let Some(access_token) = refresh_token.access_token { - revoke_access_token(&mut txn, &clock, access_token).await?; + if let Some(access_token_id) = refresh_token.access_token_id { + let access_token = repo.oauth2_access_token().lookup(access_token_id).await?; + if let Some(access_token) = access_token { + repo.oauth2_access_token() + .revoke(clock, access_token) + .await?; + } } let params = AccessTokenResponse::new(access_token_str) @@ -397,7 +438,5 @@ async fn refresh_token_grant( .with_refresh_token(new_refresh_token.refresh_token) .with_scope(session.scope); - txn.commit().await?; - - Ok(params) + Ok((params, repo)) } diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index 0369739d2..d2d27cd0b 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -28,10 +28,14 @@ use mas_jose::{ }; use mas_keystore::Keystore; use mas_router::UrlBuilder; +use mas_storage::{ + oauth2::OAuth2ClientRepository, + user::{BrowserSessionRepository, UserEmailRepository}, + BoxClock, BoxRepository, BoxRng, +}; use oauth2_types::scope; use serde::Serialize; use serde_with::skip_serializing_none; -use sqlx::PgPool; use thiserror::Error; use crate::impl_from_error_for_route; @@ -59,20 +63,31 @@ pub enum RouteError { Internal(Box), #[error("failed to authenticate")] - AuthorizationVerificationError(#[from] AuthorizationVerificationError), + AuthorizationVerificationError( + #[from] AuthorizationVerificationError, + ), #[error("no suitable key found for signing")] InvalidSigningKey, + + #[error("failed to load client")] + NoSuchClient, + + #[error("failed to load browser session")] + NoSuchBrowserSession, } -impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_keystore::WrongAlgorithmError); impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { - Self::Internal(_) | Self::InvalidSigningKey => { + Self::Internal(_) + | Self::InvalidSigningKey + | Self::NoSuchClient + | Self::NoSuchBrowserSession => { (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response() } Self::AuthorizationVerificationError(_e) => StatusCode::UNAUTHORIZED.into_response(), @@ -81,32 +96,43 @@ impl IntoResponse for RouteError { } pub async fn get( + mut rng: BoxRng, + clock: BoxClock, State(url_builder): State, - State(pool): State, + mut repo: BoxRepository, State(key_store): State, user_authorization: UserAuthorization, ) -> Result { - let (_clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let session = user_authorization.protected(&mut repo, &clock).await?; - let session = user_authorization.protected(&mut conn).await?; + let browser_session = repo + .browser_session() + .lookup(session.user_session_id) + .await? + .ok_or(RouteError::NoSuchBrowserSession)?; - let user = session.browser_session.user; - let mut user_info = UserInfo { - sub: user.sub, - username: user.username, - email: None, - email_verified: None, + let user = browser_session.user; + + let user_email = if session.scope.contains(&scope::EMAIL) { + repo.user_email().get_primary(&user).await? + } else { + None }; - if session.scope.contains(&scope::EMAIL) { - if let Some(email) = user.primary_email { - user_info.email_verified = Some(email.confirmed_at.is_some()); - user_info.email = Some(email.email); - } - } + let user_info = UserInfo { + sub: user.sub.clone(), + username: user.username.clone(), + email_verified: user_email.as_ref().map(|u| u.confirmed_at.is_some()), + email: user_email.map(|u| u.email), + }; - if let Some(alg) = session.client.userinfo_signed_response_alg { + let client = repo + .oauth2_client() + .lookup(session.client_id) + .await? + .ok_or(RouteError::NoSuchClient)?; + + if let Some(alg) = client.userinfo_signed_response_alg { let key = key_store .signing_key_for_algorithm(&alg) .ok_or(RouteError::InvalidSigningKey)?; @@ -117,7 +143,7 @@ pub async fn get( let user_info = SignedUserInfo { iss: url_builder.oidc_issuer().to_string(), - aud: session.client.client_id, + aud: client.client_id, user_info, }; diff --git a/crates/handlers/src/passwords.rs b/crates/handlers/src/passwords.rs index 89326c9a3..4be71a45a 100644 --- a/crates/handlers/src/passwords.rs +++ b/crates/handlers/src/passwords.rs @@ -71,7 +71,7 @@ impl PasswordManager { /// # Errors /// /// Returns an error if the hashing failed - #[tracing::instrument(skip_all)] + #[tracing::instrument(name = "passwords.hash", skip_all)] pub async fn hash( &self, rng: R, @@ -82,13 +82,16 @@ impl PasswordManager { let rng = rand_chacha::ChaChaRng::from_rng(rng)?; let hashers = self.hashers.clone(); let default_hasher_version = self.default_hasher; + let span = tracing::Span::current(); let hashed = tokio::task::spawn_blocking(move || { - let default_hasher = hashers - .get(&default_hasher_version) - .context("Default hasher not found")?; + span.in_scope(move || { + let default_hasher = hashers + .get(&default_hasher_version) + .context("Default hasher not found")?; - default_hasher.hash_blocking(rng, &password) + default_hasher.hash_blocking(rng, &password) + }) }) .await??; @@ -100,7 +103,7 @@ impl PasswordManager { /// # Errors /// /// Returns an error if the password hash verification failed - #[tracing::instrument(skip_all, fields(%scheme))] + #[tracing::instrument(name = "passwords.verify", skip_all, fields(%scheme))] pub async fn verify( &self, scheme: SchemeVersion, @@ -108,10 +111,13 @@ impl PasswordManager { hashed_password: String, ) -> Result<(), anyhow::Error> { let hashers = self.hashers.clone(); + let span = tracing::Span::current(); tokio::task::spawn_blocking(move || { - let hasher = hashers.get(&scheme).context("Hashing scheme not found")?; - hasher.verify_blocking(&hashed_password, &password) + span.in_scope(move || { + let hasher = hashers.get(&scheme).context("Hashing scheme not found")?; + hasher.verify_blocking(&hashed_password, &password) + }) }) .await??; @@ -124,7 +130,7 @@ impl PasswordManager { /// # Errors /// /// Returns an error if the password hash verification failed - #[tracing::instrument(skip_all, fields(%scheme))] + #[tracing::instrument(name = "passwords.verify_and_upgrade", skip_all, fields(%scheme))] pub async fn verify_and_upgrade( &self, rng: R, diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index 787124512..8da6231af 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -22,8 +22,10 @@ use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_keystore::Encrypter; use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_router::UrlBuilder; -use mas_storage::upstream_oauth2::lookup_provider; -use sqlx::PgPool; +use mas_storage::{ + upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, + BoxClock, BoxRepository, BoxRng, +}; use thiserror::Error; use ulid::Ulid; @@ -39,11 +41,10 @@ pub(crate) enum RouteError { Internal(Box), } -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_http::ClientInitError); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -55,18 +56,18 @@ impl IntoResponse for RouteError { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(http_client_factory): State, - State(pool): State, + mut repo: BoxRepository, State(url_builder): State, cookie_jar: PrivateCookieJar, Path(provider_id): Path, Query(query): Query, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - - let mut txn = pool.begin().await?; - - let provider = lookup_provider(&mut txn, provider_id) + let provider = repo + .upstream_oauth_provider() + .lookup(provider_id) .await? .ok_or(RouteError::ProviderNotFound)?; @@ -95,22 +96,23 @@ pub(crate) async fn get( &mut rng, )?; - let session = mas_storage::upstream_oauth2::add_session( - &mut txn, - &mut rng, - &clock, - &provider, - data.state.clone(), - data.code_challenge_verifier, - data.nonce, - ) - .await?; + let session = repo + .upstream_oauth_session() + .add( + &mut rng, + &clock, + &provider, + data.state.clone(), + data.code_challenge_verifier, + data.nonce, + ) + .await?; let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar) .add(session.id, provider.id, data.state, query.post_auth_action) - .save(cookie_jar, clock.now()); + .save(cookie_jar, &clock); - txn.commit().await?; + repo.save().await?; Ok((cookie_jar, Redirect::temporary(url.as_str()))) } diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index ab31641c8..bc24c399a 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -25,12 +25,15 @@ use mas_oidc_client::requests::{ authorization_code::AuthorizationValidationData, jose::JwtVerificationData, }; use mas_router::{Route, UrlBuilder}; -use mas_storage::upstream_oauth2::{ - add_link, complete_session, lookup_link_by_subject, lookup_session, +use mas_storage::{ + upstream_oauth2::{ + UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, + UpstreamOAuthSessionRepository, + }, + BoxClock, BoxRepository, BoxRng, Clock, }; use oauth2_types::errors::ClientErrorCode; use serde::Deserialize; -use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -64,6 +67,9 @@ pub(crate) enum RouteError { #[error("Session not found")] SessionNotFound, + #[error("Provider not found")] + ProviderNotFound, + #[error("Provider mismatch")] ProviderMismatch, @@ -92,9 +98,8 @@ pub(crate) enum RouteError { Internal(Box), } -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_http::ClientInitError); -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); impl_from_error_for_route!(mas_oidc_client::error::JwksError); impl_from_error_for_route!(mas_oidc_client::error::TokenAuthorizationCodeError); @@ -104,6 +109,7 @@ impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { + Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(), Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(), Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(), @@ -113,8 +119,10 @@ impl IntoResponse for RouteError { #[allow(clippy::too_many_lines, clippy::too_many_arguments)] pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(http_client_factory): State, - State(pool): State, + mut repo: BoxRepository, State(url_builder): State, State(encrypter): State, State(keystore): State, @@ -122,30 +130,34 @@ pub(crate) async fn get( Path(provider_id): Path, Query(params): Query, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - - let mut txn = pool.begin().await?; + let provider = repo + .upstream_oauth_provider() + .lookup(provider_id) + .await? + .ok_or(RouteError::ProviderNotFound)?; let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); let (session_id, _post_auth_action) = sessions_cookie .find_session(provider_id, ¶ms.state) .map_err(|_| RouteError::MissingCookie)?; - let (provider, session) = lookup_session(&mut txn, session_id) + let session = repo + .upstream_oauth_session() + .lookup(session_id) .await? .ok_or(RouteError::SessionNotFound)?; - if provider.id != provider_id { + if provider.id != session.provider_id { // The provider in the session cookie should match the one from the URL return Err(RouteError::ProviderMismatch); } - if params.state != session.state { + if params.state != session.state_str { // The state in the session cookie should match the one from the params return Err(RouteError::StateMismatch); } - if session.completed() { + if !session.is_pending() { // The session was already completed return Err(RouteError::AlreadyCompleted); } @@ -194,7 +206,7 @@ pub(crate) async fn get( // TODO: all that should be borrowed let validation_data = AuthorizationValidationData { - state: session.state.clone(), + state: session.state_str.clone(), nonce: session.nonce.clone(), code_challenge_verifier: session.code_challenge_verifier.clone(), redirect_uri, @@ -231,20 +243,29 @@ pub(crate) async fn get( let subject = mas_jose::claims::SUB.extract_required(&mut id_token)?; // Look for an existing link - let maybe_link = lookup_link_by_subject(&mut txn, &provider, &subject).await?; + let maybe_link = repo + .upstream_oauth_link() + .find_by_subject(&provider, &subject) + .await?; let link = if let Some(link) = maybe_link { link } else { - add_link(&mut txn, &mut rng, &clock, &provider, subject).await? + repo.upstream_oauth_link() + .add(&mut rng, &clock, &provider, subject) + .await? }; - let session = complete_session(&mut txn, &clock, session, &link, response.id_token).await?; + let session = repo + .upstream_oauth_session() + .complete_with_link(&clock, session, &link, response.id_token) + .await?; + let cookie_jar = sessions_cookie .add_link_to_session(session.id, link.id)? - .save(cookie_jar, clock.now()); + .save(cookie_jar, &clock); - txn.commit().await?; + repo.save().await?; Ok(( cookie_jar, diff --git a/crates/handlers/src/upstream_oauth2/cookie.rs b/crates/handlers/src/upstream_oauth2/cookie.rs index be1d3edf0..92cfa5650 100644 --- a/crates/handlers/src/upstream_oauth2/cookie.rs +++ b/crates/handlers/src/upstream_oauth2/cookie.rs @@ -18,6 +18,7 @@ use axum_extra::extract::{cookie::Cookie, PrivateCookieJar}; use chrono::{DateTime, Duration, NaiveDateTime, Utc}; use mas_axum_utils::CookieExt; use mas_router::PostAuthAction; +use mas_storage::Clock; use serde::{Deserialize, Serialize}; use thiserror::Error; use time::OffsetDateTime; @@ -65,11 +66,11 @@ impl UpstreamSessions { } /// Save the upstreams sessions to the cookie jar - pub fn save( - self, - cookie_jar: PrivateCookieJar, - now: DateTime, - ) -> PrivateCookieJar { + pub fn save(self, cookie_jar: PrivateCookieJar, clock: &C) -> PrivateCookieJar + where + C: Clock, + { + let now = clock.now(); let this = self.expire(now); let mut cookie = Cookie::named(COOKIE_NAME).encode(&this); cookie.set_path("/"); diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 15c5ac93d..30d678cf2 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -25,17 +25,15 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_storage::{ - upstream_oauth2::{ - associate_link_to_user, consume_session, lookup_link, lookup_session_on_link, - }, - user::{add_user, authenticate_session_with_upstream, lookup_user, start_session}, + upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, + user::{BrowserSessionRepository, UserRepository}, + BoxClock, BoxRepository, BoxRng, }; use mas_templates::{ EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink, }; use serde::Deserialize; -use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -52,6 +50,10 @@ pub(crate) enum RouteError { #[error("Session not found")] SessionNotFound, + /// Couldn't find the user + #[error("User not found")] + UserNotFound, + /// Session was already consumed #[error("Session already consumed")] SessionConsumed, @@ -66,11 +68,10 @@ pub(crate) enum RouteError { Internal(Box), } -impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError); impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -91,48 +92,60 @@ pub(crate) enum FormData { } pub(crate) async fn get( - State(pool): State, + mut rng: BoxRng, + clock: BoxClock, + mut repo: BoxRepository, State(templates): State, cookie_jar: PrivateCookieJar, Path(link_id): Path, ) -> Result { - let mut txn = pool.begin().await?; - let (clock, mut rng) = crate::clock_and_rng(); - let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); let (session_id, _post_auth_action) = sessions_cookie .lookup_link(link_id) .map_err(|_| RouteError::MissingCookie)?; - let link = lookup_link(&mut txn, link_id) + let link = repo + .upstream_oauth_link() + .lookup(link_id) .await? .ok_or(RouteError::LinkNotFound)?; - // This checks that we're in a browser session which is allowed to consume this - // link: the upstream auth session should have been started in this browser. - let upstream_session = lookup_session_on_link(&mut txn, &link, session_id) + let upstream_session = repo + .upstream_oauth_session() + .lookup(session_id) .await? .ok_or(RouteError::SessionNotFound)?; - if upstream_session.consumed() { + // This checks that we're in a browser session which is allowed to consume this + // link: the upstream auth session should have been started in this browser. + if upstream_session.link_id() != Some(link.id) { + return Err(RouteError::SessionNotFound); + } + + if upstream_session.is_consumed() { return Err(RouteError::SessionConsumed); } let (user_session_info, cookie_jar) = cookie_jar.session_info(); - let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); - let maybe_user_session = user_session_info.load_session(&mut txn).await?; + let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); + let maybe_user_session = user_session_info.load_session(&mut repo).await?; let render = match (maybe_user_session, link.user_id) { - (Some(mut session), Some(user_id)) if session.user.id == user_id => { + (Some(session), Some(user_id)) if session.user.id == user_id => { // Session already linked, and link matches the currently logged // user. Mark the session as consumed and renew the authentication. - consume_session(&mut txn, &clock, upstream_session).await?; - authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link) + repo.upstream_oauth_session() + .consume(&clock, upstream_session) + .await?; + + let session = repo + .browser_session() + .authenticate_with_upstream(&mut rng, &clock, session, &link) .await?; cookie_jar = cookie_jar.set_session(&session); - txn.commit().await?; + repo.save().await?; let ctx = EmptyContext .with_session(session) @@ -147,7 +160,11 @@ pub(crate) async fn get( // Session already linked, but link doesn't match the currently // logged user. Suggest logging out of the current user // and logging in with the new one - let user = lookup_user(&mut txn, user_id).await?; + let user = repo + .user() + .lookup(user_id) + .await? + .ok_or(RouteError::UserNotFound)?; let ctx = UpstreamExistingLinkContext::new(user) .with_session(user_session) @@ -167,7 +184,11 @@ pub(crate) async fn get( (None, Some(user_id)) => { // Session linked, but user not logged in: do the login - let user = lookup_user(&mut txn, user_id).await?; + let user = repo + .user() + .lookup(user_id) + .await? + .ok_or(RouteError::UserNotFound)?; let ctx = UpstreamExistingLinkContext::new(user).with_csrf(csrf_token.form_value()); @@ -187,14 +208,14 @@ pub(crate) async fn get( } pub(crate) async fn post( - State(pool): State, + mut rng: BoxRng, + clock: BoxClock, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Path(link_id): Path, Form(form): Form>, ) -> Result { - let mut txn = pool.begin().await?; - let (clock, mut rng) = crate::clock_and_rng(); - let form = cookie_jar.verify_form(clock.now(), form)?; + let form = cookie_jar.verify_form(&clock, form)?; let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); let (session_id, post_auth_action) = sessions_cookie @@ -205,53 +226,77 @@ pub(crate) async fn post( post_auth_action: post_auth_action.cloned(), }; - let link = lookup_link(&mut txn, link_id) + let link = repo + .upstream_oauth_link() + .lookup(link_id) .await? .ok_or(RouteError::LinkNotFound)?; - // This checks that we're in a browser session which is allowed to consume this - // link: the upstream auth session should have been started in this browser. - let upstream_session = lookup_session_on_link(&mut txn, &link, session_id) + let upstream_session = repo + .upstream_oauth_session() + .lookup(session_id) .await? .ok_or(RouteError::SessionNotFound)?; - if upstream_session.consumed() { + // This checks that we're in a browser session which is allowed to consume this + // link: the upstream auth session should have been started in this browser. + if upstream_session.link_id() != Some(link.id) { + return Err(RouteError::SessionNotFound); + } + + if upstream_session.is_consumed() { return Err(RouteError::SessionConsumed); } let (user_session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_user_session = user_session_info.load_session(&mut txn).await?; + let maybe_user_session = user_session_info.load_session(&mut repo).await?; - let mut session = match (maybe_user_session, link.user_id, form) { + let session = match (maybe_user_session, link.user_id, form) { (Some(session), None, FormData::Link) => { - associate_link_to_user(&mut txn, &link, &session.user).await?; + repo.upstream_oauth_link() + .associate_to_user(&link, &session.user) + .await?; + session } (None, Some(user_id), FormData::Login) => { - let user = lookup_user(&mut txn, user_id).await?; - start_session(&mut txn, &mut rng, &clock, user).await? + let user = repo + .user() + .lookup(user_id) + .await? + .ok_or(RouteError::UserNotFound)?; + + repo.browser_session().add(&mut rng, &clock, &user).await? } (None, None, FormData::Register { username }) => { - let user = add_user(&mut txn, &mut rng, &clock, &username).await?; - associate_link_to_user(&mut txn, &link, &user).await?; + let user = repo.user().add(&mut rng, &clock, username).await?; + repo.upstream_oauth_link() + .associate_to_user(&link, &user) + .await?; - start_session(&mut txn, &mut rng, &clock, user).await? + repo.browser_session().add(&mut rng, &clock, &user).await? } _ => return Err(RouteError::InvalidFormAction), }; - consume_session(&mut txn, &clock, upstream_session).await?; - authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link).await?; + repo.upstream_oauth_session() + .consume(&clock, upstream_session) + .await?; + + let session = repo + .browser_session() + .authenticate_with_upstream(&mut rng, &clock, session, &link) + .await?; let cookie_jar = sessions_cookie .consume_link(link_id)? - .save(cookie_jar, clock.now()); + .save(cookie_jar, &clock); let cookie_jar = cookie_jar.set_session(&session); - txn.commit().await?; + repo.save().await?; Ok((cookie_jar, post_auth_action.go_next())) } diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index 06fe7e067..e26c9cc1a 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -24,10 +24,9 @@ use mas_axum_utils::{ use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::user::add_user_email; +use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng}; use mas_templates::{EmailAddContext, TemplateContext, Templates}; use serde::Deserialize; -use sqlx::PgPool; use super::start_email_verification; use crate::views::shared::OptionalPostAuthAction; @@ -38,17 +37,16 @@ pub struct EmailForm { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.begin().await?; - - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -67,19 +65,18 @@ pub(crate) async fn get( } pub(crate) async fn post( - State(pool): State, + mut rng: BoxRng, + clock: BoxClock, + mut repo: BoxRepository, State(mailer): State, cookie_jar: PrivateCookieJar, Query(query): Query, Form(form): Form>, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; - - let form = cookie_jar.verify_form(clock.now(), form)?; + let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -88,7 +85,11 @@ pub(crate) async fn post( return Ok((cookie_jar, login.go()).into_response()); }; - let user_email = add_user_email(&mut txn, &mut rng, &clock, &session.user, form.email).await?; + let user_email = repo + .user_email() + .add(&mut rng, &clock, &session.user, form.email) + .await?; + let next = mas_router::AccountVerifyEmail::new(user_email.id); let next = if let Some(action) = query.post_auth_action { next.and_then(action) @@ -97,7 +98,7 @@ pub(crate) async fn post( }; start_email_verification( &mailer, - &mut txn, + &mut repo, &mut rng, &clock, &session.user, @@ -105,7 +106,7 @@ pub(crate) async fn post( ) .await?; - txn.commit().await?; + repo.save().await?; Ok((cookie_jar, next.go()).into_response()) } diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index 061e360c0..ad997e0dc 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use anyhow::{anyhow, Context}; use axum::{ extract::{Form, State}, response::{Html, IntoResponse, Response}, @@ -28,16 +29,11 @@ use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ - user::{ - add_user_email, add_user_email_verification_code, get_user_email, get_user_emails, - remove_user_email, set_user_email_as_primary, - }, - Clock, + user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, }; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; use rand::{distributions::Uniform, Rng}; use serde::Deserialize; -use sqlx::{PgExecutor, PgPool}; use tracing::info; pub mod add; @@ -53,37 +49,35 @@ pub enum ManagementForm { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - - let mut conn = pool.acquire().await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if let Some(session) = maybe_session { - render(&mut rng, &clock, templates, session, cookie_jar, &mut conn).await + render(&mut rng, &clock, templates, session, cookie_jar, &mut repo).await } else { let login = mas_router::Login::default(); Ok((cookie_jar, login.go()).into_response()) } } -async fn render( +async fn render( rng: impl Rng + Send, - clock: &Clock, + clock: &impl Clock, templates: Templates, session: BrowserSession, cookie_jar: PrivateCookieJar, - executor: impl PgExecutor<'_>, + repo: &mut impl RepositoryAccess, ) -> Result { - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng); - let emails = get_user_emails(executor, &session.user).await?; + let emails = repo.user_email().all(&session.user).await?; let ctx = AccountEmailsContext::new(emails) .with_session(session) @@ -94,11 +88,11 @@ async fn render( Ok((cookie_jar, Html(content)).into_response()) } -async fn start_email_verification( +async fn start_email_verification( mailer: &Mailer, - executor: impl PgExecutor<'_>, + repo: &mut impl RepositoryAccess, mut rng: impl Rng + Send, - clock: &Clock, + clock: &impl Clock, user: &User, user_email: UserEmail, ) -> anyhow::Result<()> { @@ -108,15 +102,10 @@ async fn start_email_verification( let address: Address = user_email.email.parse()?; - let verification = add_user_email_verification_code( - executor, - &mut rng, - clock, - user_email, - Duration::hours(8), - code, - ) - .await?; + let verification = repo + .user_email() + .add_verification_code(&mut rng, clock, &user_email, Duration::hours(8), code) + .await?; // And send the verification email let mailbox = Mailbox::new(Some(user.username.clone()), address); @@ -126,25 +115,24 @@ async fn start_email_verification( mailer.send_verification_email(mailbox, &context).await?; info!( - email.id = %verification.email.id, + email.id = %user_email.id, "Verification email sent" ); Ok(()) } pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: BoxRepository, State(mailer): State, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let mut session = if let Some(session) = maybe_session { session @@ -153,53 +141,69 @@ pub(crate) async fn post( return Ok((cookie_jar, login.go()).into_response()); }; - let form = cookie_jar.verify_form(clock.now(), form)?; + let form = cookie_jar.verify_form(&clock, form)?; match form { ManagementForm::Add { email } => { - let user_email = - add_user_email(&mut txn, &mut rng, &clock, &session.user, email).await?; - let next = mas_router::AccountVerifyEmail::new(user_email.id); - start_email_verification( - &mailer, - &mut txn, - &mut rng, - &clock, - &session.user, - user_email, - ) - .await?; - txn.commit().await?; + let email = repo + .user_email() + .add(&mut rng, &clock, &session.user, email) + .await?; + + let next = mas_router::AccountVerifyEmail::new(email.id); + start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email) + .await?; + repo.save().await?; return Ok((cookie_jar, next.go()).into_response()); } ManagementForm::ResendConfirmation { id } => { let id = id.parse()?; - let user_email = get_user_email(&mut txn, &session.user, id).await?; - let next = mas_router::AccountVerifyEmail::new(user_email.id); - start_email_verification( - &mailer, - &mut txn, - &mut rng, - &clock, - &session.user, - user_email, - ) - .await?; - txn.commit().await?; + let email = repo + .user_email() + .lookup(id) + .await? + .context("Email not found")?; + + if email.user_id != session.user.id { + return Err(anyhow!("Email not found").into()); + } + + let next = mas_router::AccountVerifyEmail::new(email.id); + start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email) + .await?; + repo.save().await?; return Ok((cookie_jar, next.go()).into_response()); } ManagementForm::Remove { id } => { let id = id.parse()?; - let email = get_user_email(&mut txn, &session.user, id).await?; - remove_user_email(&mut txn, email).await?; + let email = repo + .user_email() + .lookup(id) + .await? + .context("Email not found")?; + + if email.user_id != session.user.id { + return Err(anyhow!("Email not found").into()); + } + + repo.user_email().remove(email).await?; } ManagementForm::SetPrimary { id } => { let id = id.parse()?; - let email = get_user_email(&mut txn, &session.user, id).await?; - set_user_email_as_primary(&mut txn, &email).await?; - session.user.primary_email = Some(email); + let email = repo + .user_email() + .lookup(id) + .await? + .context("Email not found")?; + + if email.user_id != session.user.id { + return Err(anyhow!("Email not found").into()); + } + + repo.user_email().set_as_primary(&email).await?; + session.user.primary_user_email_id = Some(email.id); } }; @@ -209,11 +213,11 @@ pub(crate) async fn post( templates.clone(), session, cookie_jar, - &mut txn, + &mut repo, ) .await?; - txn.commit().await?; + repo.save().await?; Ok(reply) } diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index 0ce6503a5..d7f074b8c 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -24,16 +24,9 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{ - user::{ - consume_email_verification, lookup_user_email_by_id, lookup_user_email_verification_code, - mark_user_email_as_verified, set_user_email_as_primary, - }, - Clock, -}; +use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng}; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use serde::Deserialize; -use sqlx::PgPool; use ulid::Ulid; use crate::views::shared::OptionalPostAuthAction; @@ -44,19 +37,18 @@ pub struct CodeForm { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: BoxRepository, Query(query): Query, Path(id): Path, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; - - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -65,8 +57,11 @@ pub(crate) async fn get( return Ok((cookie_jar, login.go()).into_response()); }; - let user_email = lookup_user_email_by_id(&mut conn, &session.user, id) + let user_email = repo + .user_email() + .lookup(id) .await? + .filter(|u| u.user_id == session.user.id) .context("Could not find user email")?; if user_email.confirmed_at.is_some() { @@ -85,19 +80,17 @@ pub(crate) async fn get( } pub(crate) async fn post( - State(pool): State, + clock: BoxClock, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Query(query): Query, Path(id): Path, Form(form): Form>, ) -> Result { - let clock = Clock::default(); - let mut txn = pool.begin().await?; - - let form = cookie_jar.verify_form(clock.now(), form)?; + let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -106,25 +99,33 @@ pub(crate) async fn post( return Ok((cookie_jar, login.go()).into_response()); }; - let email = lookup_user_email_by_id(&mut txn, &session.user, id) + let user_email = repo + .user_email() + .lookup(id) .await? + .filter(|u| u.user_id == session.user.id) .context("Could not find user email")?; - if session.user.primary_email.is_none() { - set_user_email_as_primary(&mut txn, &email).await?; - } - - // TODO: make those 8 hours configurable - let verification = lookup_user_email_verification_code(&mut txn, &clock, email, &form.code) + let verification = repo + .user_email() + .find_verification_code(&clock, &user_email, &form.code) .await? .context("Invalid code")?; // TODO: display nice errors if the code was already consumed or expired - let verification = consume_email_verification(&mut txn, &clock, verification).await?; + repo.user_email() + .consume_verification_code(&clock, verification) + .await?; - let _email = mark_user_email_as_verified(&mut txn, &clock, verification.email).await?; + if session.user.primary_user_email_id.is_none() { + repo.user_email().set_as_primary(&user_email).await?; + } - txn.commit().await?; + repo.user_email() + .mark_as_verified(&clock, user_email) + .await?; + + repo.save().await?; let destination = query.go_next_or_default(&mas_router::AccountEmails); Ok((cookie_jar, destination).into_response()) diff --git a/crates/handlers/src/views/account/mod.rs b/crates/handlers/src/views/account/mod.rs index 87eec0961..162b78993 100644 --- a/crates/handlers/src/views/account/mod.rs +++ b/crates/handlers/src/views/account/mod.rs @@ -23,22 +23,23 @@ use axum_extra::extract::PrivateCookieJar; use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt}; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::user::{count_active_sessions, get_user_emails}; +use mas_storage::{ + user::{BrowserSessionRepository, UserEmailRepository}, + BoxClock, BoxRepository, BoxRng, +}; use mas_templates::{AccountContext, TemplateContext, Templates}; -use sqlx::PgPool; pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; - - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -47,9 +48,9 @@ pub(crate) async fn get( return Ok((cookie_jar, login.go()).into_response()); }; - let active_sessions = count_active_sessions(&mut conn, &session.user).await?; + let active_sessions = repo.browser_session().count_active(&session.user).await?; - let emails = get_user_emails(&mut conn, &session.user).await?; + let emails = repo.user_email().all(&session.user).await?; let ctx = AccountContext::new(active_sessions, emails) .with_session(session) diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index 2ba4b3f8a..d9e026105 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -26,13 +26,12 @@ use mas_data_model::BrowserSession; use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ - user::{add_user_password, authenticate_session_with_password, lookup_user_password}, - Clock, + user::{BrowserSessionRepository, UserPasswordRepository}, + BoxClock, BoxRepository, BoxRng, Clock, }; use mas_templates::{EmptyContext, TemplateContext, Templates}; use rand::Rng; use serde::Deserialize; -use sqlx::PgPool; use zeroize::Zeroizing; use crate::passwords::PasswordManager; @@ -45,16 +44,15 @@ pub struct ChangeForm { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if let Some(session) = maybe_session { render(&mut rng, &clock, templates, session, cookie_jar).await @@ -66,12 +64,12 @@ pub(crate) async fn get( async fn render( rng: impl Rng + Send, - clock: &Clock, + clock: &impl Clock, templates: Templates, session: BrowserSession, cookie_jar: PrivateCookieJar, ) -> Result { - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng); let ctx = EmptyContext .with_session(session) @@ -83,29 +81,30 @@ async fn render( } pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(password_manager): State, State(templates): State, - State(pool): State, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; - - let form = cookie_jar.verify_form(clock.now(), form)?; + let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; - let mut session = if let Some(session) = maybe_session { + let session = if let Some(session) = maybe_session { session } else { let login = mas_router::Login::and_then(mas_router::PostAuthAction::ChangePassword); return Ok((cookie_jar, login.go()).into_response()); }; - let user_password = lookup_user_password(&mut txn, &session.user) + let user_password = repo + .user_password() + .active(&session.user) .await? .context("user has no password")?; @@ -127,23 +126,26 @@ pub(crate) async fn post( } let (version, hashed_password) = password_manager.hash(&mut rng, new_password).await?; - let user_password = add_user_password( - &mut txn, - &mut rng, - &clock, - &session.user, - version, - hashed_password, - None, - ) - .await?; + let user_password = repo + .user_password() + .add( + &mut rng, + &clock, + &session.user, + version, + hashed_password, + None, + ) + .await?; - authenticate_session_with_password(&mut txn, &mut rng, &clock, &mut session, &user_password) + let session = repo + .browser_session() + .authenticate_with_password(&mut rng, &clock, session, &user_password) .await?; let reply = render(&mut rng, &clock, templates.clone(), session, cookie_jar).await?; - txn.commit().await?; + repo.save().await?; Ok(reply) } diff --git a/crates/handlers/src/views/index.rs b/crates/handlers/src/views/index.rs index 2471296da..7b4be7df0 100644 --- a/crates/handlers/src/views/index.rs +++ b/crates/handlers/src/views/index.rs @@ -20,21 +20,20 @@ use axum_extra::extract::PrivateCookieJar; use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt}; use mas_keystore::Encrypter; use mas_router::UrlBuilder; +use mas_storage::{BoxClock, BoxRepository, BoxRng}; use mas_templates::{IndexContext, TemplateContext, Templates}; -use sqlx::PgPool; pub async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, State(url_builder): State, - State(pool): State, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; - - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let session = session_info.load_session(&mut conn).await?; + let session = session_info.load_session(&mut repo).await?; let ctx = IndexContext::new(url_builder.oidc_discovery()) .maybe_with_session(session) diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 24fc17b75..3083eae00 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -24,18 +24,15 @@ use mas_axum_utils::{ use mas_data_model::BrowserSession; use mas_keystore::Encrypter; use mas_storage::{ - user::{ - add_user_password, authenticate_session_with_password, lookup_user_by_username, - lookup_user_password, start_session, - }, - Clock, + upstream_oauth2::UpstreamOAuthProviderRepository, + user::{BrowserSessionRepository, UserPasswordRepository, UserRepository}, + BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, }; use mas_templates::{ FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, }; use rand::{CryptoRng, Rng}; use serde::{Deserialize, Serialize}; -use sqlx::{PgConnection, PgPool}; use zeroize::Zeroizing; use super::shared::OptionalPostAuthAction; @@ -52,29 +49,28 @@ impl ToFormState for LoginForm { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; - - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if maybe_session.is_some() { let reply = query.go_next(); Ok((cookie_jar, reply).into_response()) } else { - let providers = mas_storage::upstream_oauth2::get_providers(&mut conn).await?; + let providers = repo.upstream_oauth_provider().all().await?; let content = render( LoginContext::default().with_upstrem_providers(providers), query, csrf_token, - &mut conn, + &mut repo, &templates, ) .await?; @@ -84,19 +80,18 @@ pub(crate) async fn get( } pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(password_manager): State, State(templates): State, - State(pool): State, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let form = cookie_jar.verify_form(&clock, form)?; - let form = cookie_jar.verify_form(clock.now(), form)?; - - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); // Validate the form let state = { @@ -114,14 +109,14 @@ pub(crate) async fn post( }; if !state.is_valid() { - let providers = mas_storage::upstream_oauth2::get_providers(&mut conn).await?; + let providers = repo.upstream_oauth_provider().all().await?; let content = render( LoginContext::default() .with_form_state(state) .with_upstrem_providers(providers), query, csrf_token, - &mut conn, + &mut repo, &templates, ) .await?; @@ -129,11 +124,9 @@ pub(crate) async fn post( return Ok((cookie_jar, Html(content)).into_response()); } - lookup_user_by_username(&mut conn, &form.username).await?; - match login( password_manager, - &mut conn, + &mut repo, rng, &clock, &form.username, @@ -142,6 +135,8 @@ pub(crate) async fn post( .await { Ok(session_info) => { + repo.save().await?; + let cookie_jar = cookie_jar.set_session(&session_info); let reply = query.go_next(); Ok((cookie_jar, reply).into_response()) @@ -153,7 +148,7 @@ pub(crate) async fn post( LoginContext::default().with_form_state(state), query, csrf_token, - &mut conn, + &mut repo, &templates, ) .await?; @@ -166,21 +161,25 @@ pub(crate) async fn post( // TODO: move that logic elsewhere? async fn login( password_manager: PasswordManager, - conn: &mut PgConnection, + repo: &mut impl RepositoryAccess, mut rng: impl Rng + CryptoRng + Send, - clock: &Clock, + clock: &impl Clock, username: &str, password: &str, ) -> Result { // XXX: we're loosing the error context here // First, lookup the user - let user = lookup_user_by_username(&mut *conn, username) + let user = repo + .user() + .find_by_username(username) .await .map_err(|_e| FormError::Internal)? .ok_or(FormError::InvalidCredentials)?; // And its password - let user_password = lookup_user_password(&mut *conn, &user) + let user_password = repo + .user_password() + .active(&user) .await .map_err(|_e| FormError::Internal)? .ok_or(FormError::InvalidCredentials)?; @@ -200,28 +199,32 @@ async fn login( let user_password = if let Some((version, new_password_hash)) = new_password_hash { // Save the upgraded password - add_user_password( - &mut *conn, - &mut rng, - clock, - &user, - version, - new_password_hash, - Some(user_password), - ) - .await - .map_err(|_| FormError::Internal)? + repo.user_password() + .add( + &mut rng, + clock, + &user, + version, + new_password_hash, + Some(&user_password), + ) + .await + .map_err(|_| FormError::Internal)? } else { user_password }; // Start a new session - let mut user_session = start_session(&mut *conn, &mut rng, clock, user) + let user_session = repo + .browser_session() + .add(&mut rng, clock, &user) .await .map_err(|_| FormError::Internal)?; // And mark it as authenticated by the password - authenticate_session_with_password(&mut *conn, rng, clock, &mut user_session, &user_password) + let user_session = repo + .browser_session() + .authenticate_with_password(&mut rng, clock, user_session, &user_password) .await .map_err(|_| FormError::Internal)?; @@ -232,10 +235,10 @@ async fn render( ctx: LoginContext, action: OptionalPostAuthAction, csrf_token: CsrfToken, - conn: &mut PgConnection, + repo: &mut impl RepositoryAccess, templates: &Templates, ) -> Result { - let next = action.load_context(conn).await?; + let next = action.load_context(repo).await?; let ctx = if let Some(next) = next { ctx.with_post_action(next) } else { diff --git a/crates/handlers/src/views/logout.rs b/crates/handlers/src/views/logout.rs index 07043e64a..9b0f3602e 100644 --- a/crates/handlers/src/views/logout.rs +++ b/crates/handlers/src/views/logout.rs @@ -12,10 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{ - extract::{Form, State}, - response::IntoResponse, -}; +use axum::{extract::Form, response::IntoResponse}; use axum_extra::extract::PrivateCookieJar; use mas_axum_utils::{ csrf::{CsrfExt, ProtectedForm}, @@ -23,29 +20,26 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::{PostAuthAction, Route}; -use mas_storage::{user::end_session, Clock}; -use sqlx::PgPool; +use mas_storage::{user::BrowserSessionRepository, BoxClock, BoxRepository}; pub(crate) async fn post( - State(pool): State, + clock: BoxClock, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Form(form): Form>>, ) -> Result { - let clock = Clock::default(); - let mut txn = pool.begin().await?; - - let form = cookie_jar.verify_form(clock.now(), form)?; + let form = cookie_jar.verify_form(&clock, form)?; let (session_info, mut cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if let Some(session) = maybe_session { - end_session(&mut txn, &clock, &session).await?; + repo.browser_session().finish(&clock, session).await?; cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended()); } - txn.commit().await?; + repo.save().await?; let destination = if let Some(action) = form { action.go_next() diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index 875189a76..12f205d6a 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -24,12 +24,12 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::user::{ - add_user_password, authenticate_session_with_password, lookup_user_password, +use mas_storage::{ + user::{BrowserSessionRepository, UserPasswordRepository}, + BoxClock, BoxRepository, BoxRng, }; use mas_templates::{ReauthContext, TemplateContext, Templates}; use serde::Deserialize; -use sqlx::PgPool; use zeroize::Zeroizing; use super::shared::OptionalPostAuthAction; @@ -41,18 +41,17 @@ pub(crate) struct ReauthForm { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; - - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -64,7 +63,7 @@ pub(crate) async fn get( }; let ctx = ReauthContext::default(); - let next = query.load_context(&mut conn).await?; + let next = query.load_context(&mut repo).await?; let ctx = if let Some(next) = next { ctx.with_post_action(next) } else { @@ -78,22 +77,21 @@ pub(crate) async fn get( } pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(password_manager): State, - State(pool): State, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; - - let form = cookie_jar.verify_form(clock.now(), form)?; + let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; - let mut session = if let Some(session) = maybe_session { + let session = if let Some(session) = maybe_session { session } else { // If there is no session, redirect to the login screen, keeping the @@ -103,7 +101,9 @@ pub(crate) async fn post( }; // Load the user password - let user_password = lookup_user_password(&mut txn, &session.user) + let user_password = repo + .user_password() + .active(&session.user) .await? .context("User has no password")?; @@ -122,25 +122,28 @@ pub(crate) async fn post( let user_password = if let Some((version, new_password_hash)) = new_password_hash { // Save the upgraded password - add_user_password( - &mut *txn, - &mut rng, - &clock, - &session.user, - version, - new_password_hash, - Some(user_password), - ) - .await? + repo.user_password() + .add( + &mut rng, + &clock, + &session.user, + version, + new_password_hash, + Some(&user_password), + ) + .await? } else { user_password }; // Mark the session as authenticated by the password - authenticate_session_with_password(&mut txn, rng, &clock, &mut session, &user_password).await?; + let session = repo + .browser_session() + .authenticate_with_password(&mut rng, &clock, session, &user_password) + .await?; let cookie_jar = cookie_jar.set_session(&session); - txn.commit().await?; + repo.save().await?; let reply = query.go_next(); Ok((cookie_jar, reply).into_response()) diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index 9a12efac2..64e30af72 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -31,9 +31,9 @@ use mas_email::Mailer; use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::Route; -use mas_storage::user::{ - add_user, add_user_email, add_user_email_verification_code, add_user_password, - authenticate_session_with_password, start_session, username_exists, +use mas_storage::{ + user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, + BoxClock, BoxRepository, BoxRng, RepositoryAccess, }; use mas_templates::{ EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField, @@ -41,7 +41,6 @@ use mas_templates::{ }; use rand::{distributions::Uniform, Rng}; use serde::{Deserialize, Serialize}; -use sqlx::{PgConnection, PgPool}; use zeroize::Zeroizing; use super::shared::OptionalPostAuthAction; @@ -60,18 +59,17 @@ impl ToFormState for RegisterForm { } pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; - - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if maybe_session.is_some() { let reply = query.go_next(); @@ -81,7 +79,7 @@ pub(crate) async fn get( RegisterContext::default(), query, csrf_token, - &mut conn, + &mut repo, &templates, ) .await?; @@ -92,21 +90,20 @@ pub(crate) async fn get( #[allow(clippy::too_many_lines, clippy::too_many_arguments)] pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, State(password_manager): State, State(mailer): State, State(policy_factory): State>, State(templates): State, - State(pool): State, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let form = cookie_jar.verify_form(&clock, form)?; - let form = cookie_jar.verify_form(clock.now(), form)?; - - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); // Validate the form let state = { @@ -114,7 +111,7 @@ pub(crate) async fn post( if form.username.is_empty() { state.add_error_on_field(RegisterFormField::Username, FieldError::Required); - } else if username_exists(&mut txn, &form.username).await? { + } else if repo.user().exists(&form.username).await? { state.add_error_on_field(RegisterFormField::Username, FieldError::Exists); } @@ -177,7 +174,7 @@ pub(crate) async fn post( RegisterContext::default().with_form_state(state), query, csrf_token, - &mut txn, + &mut repo, &templates, ) .await?; @@ -185,21 +182,18 @@ pub(crate) async fn post( return Ok((cookie_jar, Html(content)).into_response()); } - let user = add_user(&mut txn, &mut rng, &clock, &form.username).await?; + let user = repo.user().add(&mut rng, &clock, form.username).await?; let password = Zeroizing::new(form.password.into_bytes()); let (version, hashed_password) = password_manager.hash(&mut rng, password).await?; - let user_password = add_user_password( - &mut txn, - &mut rng, - &clock, - &user, - version, - hashed_password, - None, - ) - .await?; + let user_password = repo + .user_password() + .add(&mut rng, &clock, &user, version, hashed_password, None) + .await?; - let user_email = add_user_email(&mut txn, &mut rng, &clock, &user, form.email).await?; + let user_email = repo + .user_email() + .add(&mut rng, &clock, &user, form.email) + .await?; // First, generate a code let range = Uniform::::from(0..1_000_000); @@ -208,15 +202,10 @@ pub(crate) async fn post( let address: Address = user_email.email.parse()?; - let verification = add_user_email_verification_code( - &mut txn, - &mut rng, - &clock, - user_email, - Duration::hours(8), - code, - ) - .await?; + let verification = repo + .user_email() + .add_verification_code(&mut rng, &clock, &user_email, Duration::hours(8), code) + .await?; // And send the verification email let mailbox = Mailbox::new(Some(user.username.clone()), address); @@ -225,14 +214,16 @@ pub(crate) async fn post( mailer.send_verification_email(mailbox, &context).await?; - let next = mas_router::AccountVerifyEmail::new(verification.email.id) - .and_maybe(query.post_auth_action); + let next = mas_router::AccountVerifyEmail::new(user_email.id).and_maybe(query.post_auth_action); - let mut session = start_session(&mut txn, &mut rng, &clock, user).await?; - authenticate_session_with_password(&mut txn, &mut rng, &clock, &mut session, &user_password) + let session = repo.browser_session().add(&mut rng, &clock, &user).await?; + + let session = repo + .browser_session() + .authenticate_with_password(&mut rng, &clock, session, &user_password) .await?; - txn.commit().await?; + repo.save().await?; let cookie_jar = cookie_jar.set_session(&session); Ok((cookie_jar, next.go()).into_response()) @@ -242,10 +233,10 @@ async fn render( ctx: RegisterContext, action: OptionalPostAuthAction, csrf_token: CsrfToken, - conn: &mut PgConnection, + repo: &mut impl RepositoryAccess, templates: &Templates, ) -> Result { - let next = action.load_context(conn).await?; + let next = action.load_context(repo).await?; let ctx = if let Some(next) = next { ctx.with_post_action(next) } else { diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index fcdef3b4a..69fdf901f 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -15,11 +15,13 @@ use anyhow::Context; use mas_router::{PostAuthAction, Route}; use mas_storage::{ - compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id, + compat::CompatSsoLoginRepository, + oauth2::OAuth2AuthorizationGrantRepository, + upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository}, + RepositoryAccess, }; use mas_templates::{PostAuthContext, PostAuthContextInner}; use serde::{Deserialize, Serialize}; -use sqlx::PgConnection; #[derive(Serialize, Deserialize, Default, Debug, Clone)] pub(crate) struct OptionalPostAuthAction { @@ -38,14 +40,16 @@ impl OptionalPostAuthAction { self.go_next_or_default(&mas_router::Index) } - pub async fn load_context( - &self, - conn: &mut PgConnection, + pub async fn load_context<'a>( + &'a self, + repo: &'a mut impl RepositoryAccess, ) -> anyhow::Result> { let Some(action) = self.post_auth_action.clone() else { return Ok(None) }; let ctx = match action { PostAuthAction::ContinueAuthorizationGrant { id } => { - let grant = get_grant_by_id(conn, id) + let grant = repo + .oauth2_authorization_grant() + .lookup(id) .await? .context("Failed to load authorization grant")?; let grant = Box::new(grant); @@ -53,7 +57,9 @@ impl OptionalPostAuthAction { } PostAuthAction::ContinueCompatSsoLogin { id } => { - let login = get_compat_sso_login_by_id(conn, id) + let login = repo + .compat_sso_login() + .lookup(id) .await? .context("Failed to load compat SSO login")?; let login = Box::new(login); @@ -63,14 +69,17 @@ impl OptionalPostAuthAction { PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword, PostAuthAction::LinkUpstream { id } => { - let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, id) + let link = repo + .upstream_oauth_link() + .lookup(id) .await? .context("Failed to load upstream OAuth 2.0 link")?; - let provider = - mas_storage::upstream_oauth2::lookup_provider(&mut *conn, link.provider_id) - .await? - .context("Failed to load upstream OAuth 2.0 provider")?; + let provider = repo + .upstream_oauth_provider() + .lookup(link.provider_id) + .await? + .context("Failed to load upstream OAuth 2.0 provider")?; let provider = Box::new(provider); let link = Box::new(link); diff --git a/crates/keystore/src/lib.rs b/crates/keystore/src/lib.rs index f1a15ee65..0899bcdc8 100644 --- a/crates/keystore/src/lib.rs +++ b/crates/keystore/src/lib.rs @@ -15,12 +15,7 @@ //! A crate to store keys which can then be used to sign and verify JWTs. #![forbid(unsafe_code)] -#![deny( - clippy::all, - clippy::str_to_string, - rustdoc::broken_intra_doc_links, - rustdoc::all -)] +#![deny(clippy::all, clippy::str_to_string, rustdoc::broken_intra_doc_links)] #![warn(clippy::pedantic)] use std::{ops::Deref, sync::Arc}; diff --git a/crates/listener/src/lib.rs b/crates/listener/src/lib.rs index 091d5809a..5aaabee4a 100644 --- a/crates/listener/src/lib.rs +++ b/crates/listener/src/lib.rs @@ -22,6 +22,9 @@ #![warn(clippy::pedantic)] #![allow(clippy::module_name_repetitions)] +//! An utility crate to build flexible [`hyper`] listeners, with optional TLS +//! and proxy protocol support. + use self::{maybe_tls::TlsStreamInfo, proxy_protocol::ProxyProtocolV1Info}; pub mod maybe_tls; diff --git a/crates/listener/src/server.rs b/crates/listener/src/server.rs index c919017fc..2527302f1 100644 --- a/crates/listener/src/server.rs +++ b/crates/listener/src/server.rs @@ -39,7 +39,7 @@ pub struct Server { impl Server { /// # Errors /// - /// Returns an error if the listener couldn't be converted via [`TryInfo`] + /// Returns an error if the listener couldn't be converted via [`TryInto`] pub fn try_new(listener: L, service: S) -> Result where L: TryInto, diff --git a/crates/oauth2-types/src/registration/mod.rs b/crates/oauth2-types/src/registration/mod.rs index 18aa24fa0..0d9589966 100644 --- a/crates/oauth2-types/src/registration/mod.rs +++ b/crates/oauth2-types/src/registration/mod.rs @@ -90,6 +90,11 @@ impl Localized { &self.non_localized } + /// Get the non-localized variant. + pub fn to_non_localized(self) -> T { + self.non_localized + } + /// Get the variant corresponding to the given language, if it exists. pub fn get(&self, language: Option<&LanguageTag>) -> Option<&T> { match language { diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index 27cfb1b32..d895339a1 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -69,7 +69,7 @@ pub struct PolicyFactory { } impl PolicyFactory { - #[tracing::instrument(skip(source), err)] + #[tracing::instrument(name = "policy.load", skip(source), err)] pub async fn load( mut source: impl AsyncRead + std::marker::Unpin, data: serde_json::Value, @@ -108,7 +108,7 @@ impl PolicyFactory { authorization_grant_endpoint, }; - // Try to instanciate + // Try to instantiate factory .instantiate() .await @@ -117,7 +117,7 @@ impl PolicyFactory { Ok(factory) } - #[tracing::instrument(skip(self), err)] + #[tracing::instrument(name = "policy.instantiate", skip_all, err)] pub async fn instantiate(&self) -> Result { let mut store = Store::new(&self.engine, ()); let runtime = Runtime::new(&mut store, &self.module) @@ -189,7 +189,14 @@ pub enum EvaluationError { } impl Policy { - #[tracing::instrument(skip(self, password))] + #[tracing::instrument( + name = "policy.evaluate.register", + skip_all, + fields( + data.username = username, + ), + err, + )] pub async fn evaluate_register( &mut self, username: &str, @@ -234,7 +241,15 @@ impl Policy { Ok(res) } - #[tracing::instrument(skip(self))] + #[tracing::instrument( + name = "policy.evaluate.authorization_grant", + skip_all, + fields( + data.authorization_grant.id = %authorization_grant.id, + data.user.id = %user.id, + ), + err, + )] pub async fn evaluate_authorization_grant( &mut self, authorization_grant: &AuthorizationGrant, diff --git a/crates/spa/src/lib.rs b/crates/spa/src/lib.rs index 58838f28a..fbef53863 100644 --- a/crates/spa/src/lib.rs +++ b/crates/spa/src/lib.rs @@ -21,6 +21,8 @@ )] #![warn(clippy::pedantic)] +//! A crate to help serve single-page apps built by Vite. + mod vite; use std::{future::Future, pin::Pin}; diff --git a/crates/storage-pg/Cargo.toml b/crates/storage-pg/Cargo.toml new file mode 100644 index 000000000..3373a21fa --- /dev/null +++ b/crates/storage-pg/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "mas-storage-pg" +version = "0.1.0" +authors = ["Quentin Gliech "] +edition = "2021" +license = "Apache-2.0" + +[dependencies] +async-trait = "0.1.60" +sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline", "json", "uuid"] } +chrono = { version = "0.4.23", features = ["serde"] } +serde = { version = "1.0.152", features = ["derive"] } +serde_json = "1.0.91" +thiserror = "1.0.38" +tracing = "0.1.37" +futures-util = "0.3.25" + +rand = "0.8.5" +rand_chacha = "0.3.1" +url = { version = "2.3.1", features = ["serde"] } +uuid = "1.2.2" +ulid = { version = "1.0.0", features = ["uuid", "serde"] } + +oauth2-types = { path = "../oauth2-types" } +mas-storage = { path = "../storage" } +mas-data-model = { path = "../data-model" } +mas-iana = { path = "../iana" } +mas-jose = { path = "../jose" } diff --git a/crates/storage/build.rs b/crates/storage-pg/build.rs similarity index 92% rename from crates/storage/build.rs rename to crates/storage-pg/build.rs index dd5b11428..dca71bd6d 100644 --- a/crates/storage/build.rs +++ b/crates/storage-pg/build.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/crates/storage/migrations/20221018142001_init.sql b/crates/storage-pg/migrations/20221018142001_init.sql similarity index 100% rename from crates/storage/migrations/20221018142001_init.sql rename to crates/storage-pg/migrations/20221018142001_init.sql diff --git a/crates/storage/migrations/20221121151402_upstream_oauth.sql b/crates/storage-pg/migrations/20221121151402_upstream_oauth.sql similarity index 100% rename from crates/storage/migrations/20221121151402_upstream_oauth.sql rename to crates/storage-pg/migrations/20221121151402_upstream_oauth.sql diff --git a/crates/storage/migrations/20221213145242_password_schemes.sql b/crates/storage-pg/migrations/20221213145242_password_schemes.sql similarity index 100% rename from crates/storage/migrations/20221213145242_password_schemes.sql rename to crates/storage-pg/migrations/20221213145242_password_schemes.sql diff --git a/crates/storage-pg/sqlx-data.json b/crates/storage-pg/sqlx-data.json new file mode 100644 index 000000000..94527512c --- /dev/null +++ b/crates/storage-pg/sqlx-data.json @@ -0,0 +1,2441 @@ +{ + "db": "PostgreSQL", + "015f7ad7c8d5403ce4dfb71d598fd9af472689d5aef7c1c4b1c594ca57c02237": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz", + "Uuid" + ] + } + }, + "query": "\n UPDATE oauth2_authorization_grants\n SET fulfilled_at = $2\n , oauth2_session_id = $3\n WHERE oauth2_authorization_grant_id = $1\n " + }, + "08d7df347c806ef14b6d0fb031cab041d79ba48528420160e23286369db7af35": { + "describe": { + "columns": [ + { + "name": "user_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "username", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "primary_user_email_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n FROM users\n WHERE user_id = $1\n " + }, + "154e2e4488ff87e09163698750b56a43127cee4e1392785416a586d40a4d9b21": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_provider_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "issuer", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "scope", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "client_id", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "encrypted_client_secret", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "token_endpoint_signing_alg", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "token_endpoint_auth_method", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 7, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + true, + false, + false + ], + "parameters": { + "Left": [] + } + }, + "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n " + }, + "18c3e56c72ef26bd42653c379767ffdd97bb06cb1686dfbf4099f3ad3d7b22c8": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)\n VALUES ($1, $2, $3, $4)\n " + }, + "1a8701f5672de052bb766933f60b93249acc7237b996e8b93cd61b9f69c902ff": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Timestamptz" + ] + } + }, + "query": "\n DELETE FROM oauth2_access_tokens\n WHERE expires_at < $1\n " + }, + "1d372f36c382ab16264cea54537af3544ea6d6d75d10b432b07dbd0dadd2fa4e": { + "describe": { + "columns": [ + { + "name": "user_email_confirmation_code_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "user_email_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "code", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 5, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Text", + "Uuid" + ] + } + }, + "query": "\n SELECT user_email_confirmation_code_id\n , user_email_id\n , code\n , created_at\n , expires_at\n , consumed_at\n FROM user_email_confirmation_codes\n WHERE code = $1\n AND user_email_id = $2\n " + }, + "1ee5cecfafd4726a4ebc08da8a34c09178e6e1e072581c8fca9d3d76967792cb": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Text", + "Text", + "Text", + "Text", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n " + }, + "1f6297fb323e9f2fbfa1c9e3225c0b3037c8c4714533a6240c62275332aa58dc": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n DELETE FROM user_email_confirmation_codes\n WHERE user_email_id = $1\n " + }, + "2564bf6366eb59268c41fb25bb40d0e4e9e1fd1f9ea53b7a359c9025d7304223": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE oauth2_access_tokens\n SET revoked_at = $2\n WHERE oauth2_access_token_id = $1\n " + }, + "3d66f3121b11ce923b9c60609b510a8ca899640e78cc8f5b03168622928ffe94": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n DELETE FROM user_emails\n WHERE user_email_id = $1\n " + }, + "4187907bfc770b2c76f741671d5e672f5c35eed7c9a9e57ff52888b1768a5ed6": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_link_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "upstream_oauth_provider_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "user_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "subject", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 4, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_link_id = $1\n " + }, + "4192c1144c0ea530cf1aa77993a38e94cd5cf8b5c42cb037efb7917c6fc44a1d": { + "describe": { + "columns": [ + { + "name": "user_email_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "user_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "email", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "confirmed_at", + "ordinal": 4, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT user_email_id\n , user_id\n , email\n , created_at\n , confirmed_at\n FROM user_emails\n\n WHERE user_email_id = $1\n " + }, + "41c1aafbd338c24476f27d342cf80eef7de2836e85b078232d143d6712fc2be4": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at)\n VALUES ($1, $2, $3)\n " + }, + "432e199b0d47fe299d840c91159726c0a4f89f65b4dc3e33ddad58aabf6b148b": { + "describe": { + "columns": [ + { + "name": "compat_refresh_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + }, + { + "name": "compat_access_token_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT compat_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , compat_session_id\n , compat_access_token_id\n\n FROM compat_refresh_tokens\n\n WHERE refresh_token = $1\n " + }, + "446a8d7bd8532a751810401adfab924dc20785c91770ed43d62df2e590e8da71": { + "describe": { + "columns": [ + { + "name": "user_password_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "hashed_password", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "version", + "ordinal": 2, + "type_info": "Int4" + }, + { + "name": "upgraded_from_id", + "ordinal": 3, + "type_info": "Uuid" + }, + { + "name": "created_at", + "ordinal": 4, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT up.user_password_id\n , up.hashed_password\n , up.version\n , up.upgraded_from_id\n , up.created_at\n FROM user_passwords up\n WHERE up.user_id = $1\n ORDER BY up.created_at DESC\n LIMIT 1\n " + }, + "477f79556e5777b38feb85013b4f04dbb8230e4b0b0bcc45f669d7b8d0b91db4": { + "describe": { + "columns": [ + { + "name": "oauth2_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "revoked_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_session_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT oauth2_access_token_id\n , access_token\n , created_at\n , expires_at\n , revoked_at\n , oauth2_session_id\n\n FROM oauth2_access_tokens\n\n WHERE access_token = $1\n " + }, + "478f0ad710da8bfd803c6cddd982bc504d1b6bd0f5283de53c8c7b1b4b7dafd4": { + "describe": { + "columns": [ + { + "name": "compat_sso_login_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "login_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "redirect_uri", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "fulfilled_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "exchanged_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 6, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + true, + true + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT compat_sso_login_id\n , login_token\n , redirect_uri\n , created_at\n , fulfilled_at\n , exchanged_at\n , compat_session_id\n\n FROM compat_sso_logins\n WHERE login_token = $1\n " + }, + "496813daf6f8486353e7f509a64362626daebb0121c3c9420b96e2d8157f1e07": { + "describe": { + "columns": [ + { + "name": "oauth2_authorization_grant_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "created_at", + "ordinal": 1, + "type_info": "Timestamptz" + }, + { + "name": "cancelled_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "fulfilled_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "exchanged_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "scope", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "state", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "redirect_uri", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "response_mode", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "nonce", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "max_age", + "ordinal": 10, + "type_info": "Int4" + }, + { + "name": "oauth2_client_id", + "ordinal": 11, + "type_info": "Uuid" + }, + { + "name": "authorization_code", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "response_type_code", + "ordinal": 13, + "type_info": "Bool" + }, + { + "name": "response_type_id_token", + "ordinal": 14, + "type_info": "Bool" + }, + { + "name": "code_challenge", + "ordinal": 15, + "type_info": "Text" + }, + { + "name": "code_challenge_method", + "ordinal": 16, + "type_info": "Text" + }, + { + "name": "requires_consent", + "ordinal": 17, + "type_info": "Bool" + }, + { + "name": "oauth2_session_id", + "ordinal": 18, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + true, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT oauth2_authorization_grant_id\n , created_at\n , cancelled_at\n , fulfilled_at\n , exchanged_at\n , scope\n , state\n , redirect_uri\n , response_mode\n , nonce\n , max_age\n , oauth2_client_id\n , authorization_code\n , response_type_code\n , response_type_id_token\n , code_challenge\n , code_challenge_method\n , requires_consent\n , oauth2_session_id\n FROM\n oauth2_authorization_grants\n\n WHERE oauth2_authorization_grant_id = $1\n " + }, + "4d79ce892e4595edb8b801e94fb0cbef28facdfd2e45d1c72c57f47418fbe24b": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE compat_sso_logins\n SET\n compat_session_id = $2,\n fulfilled_at = $3\n WHERE\n compat_sso_login_id = $1\n " + }, + "53ad718642644b47a2d49f768d81bd993088526923769a9147281686c2d47591": { + "describe": { + "columns": [ + { + "name": "compat_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT compat_access_token_id\n , access_token\n , created_at\n , expires_at\n , compat_session_id\n\n FROM compat_access_tokens\n\n WHERE access_token = $1\n " + }, + "583ae9a0db9cd55fa57a179339550f3dab1bfc76f35ad488e1560ea37f7ed029": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO oauth2_sessions\n ( oauth2_session_id\n , user_session_id\n , oauth2_client_id\n , scope\n , created_at\n )\n VALUES ($1, $2, $3, $4, $5)\n " + }, + "5f6b7e38ef9bc3b39deabba277d0255fb8cfb2adaa65f47b78a8fac11d8c91c3": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO upstream_oauth_links (\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n ) VALUES ($1, $2, NULL, $3, $4)\n " + }, + "64e6ea47c2e877c1ebe4338d64d9ad8a6c1c777d1daea024b8ca2e7f0dd75b0f": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Text", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)\n " + }, + "67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_authorization_session_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "upstream_oauth_provider_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "upstream_oauth_link_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "state", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "code_challenge_verifier", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "nonce", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "id_token", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 7, + "type_info": "Timestamptz" + }, + { + "name": "completed_at", + "ordinal": 8, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 9, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false, + true, + false, + true, + false, + true, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n " + }, + "689ffbfc5137ec788e89062ad679bbe6b23a8861c09a7246dc1659c28f12bf8d": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Timestamptz", + "Uuid" + ] + } + }, + "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET consumed_at = $1\n WHERE upstream_oauth_authorization_session_id = $2\n " + }, + "6a3b543ec53ce242866d1e84de26728e6dd275cae745f9c646e3824d859c5384": { + "describe": { + "columns": [ + { + "name": "oauth2_authorization_grant_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "created_at", + "ordinal": 1, + "type_info": "Timestamptz" + }, + { + "name": "cancelled_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "fulfilled_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "exchanged_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "scope", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "state", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "redirect_uri", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "response_mode", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "nonce", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "max_age", + "ordinal": 10, + "type_info": "Int4" + }, + { + "name": "oauth2_client_id", + "ordinal": 11, + "type_info": "Uuid" + }, + { + "name": "authorization_code", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "response_type_code", + "ordinal": 13, + "type_info": "Bool" + }, + { + "name": "response_type_id_token", + "ordinal": 14, + "type_info": "Bool" + }, + { + "name": "code_challenge", + "ordinal": 15, + "type_info": "Text" + }, + { + "name": "code_challenge_method", + "ordinal": 16, + "type_info": "Text" + }, + { + "name": "requires_consent", + "ordinal": 17, + "type_info": "Bool" + }, + { + "name": "oauth2_session_id", + "ordinal": 18, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + true, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT oauth2_authorization_grant_id\n , created_at\n , cancelled_at\n , fulfilled_at\n , exchanged_at\n , scope\n , state\n , redirect_uri\n , response_mode\n , nonce\n , max_age\n , oauth2_client_id\n , authorization_code\n , response_type_code\n , response_type_id_token\n , code_challenge\n , code_challenge_method\n , requires_consent\n , oauth2_session_id\n FROM\n oauth2_authorization_grants\n\n WHERE authorization_code = $1\n " + }, + "6e21e7d816f806da9bb5176931bdb550dee05c44c9d93f53df95fe3b4a840347": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO compat_sso_logins\n (compat_sso_login_id, login_token, redirect_uri, created_at)\n VALUES ($1, $2, $3, $4)\n " + }, + "6f97b5f9ad0d4d15387150bea3839fb7f81015f7ceef61ecaadba64521895cff": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Int4", + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO user_passwords\n (user_password_id, user_id, hashed_password, version, upgraded_from_id, created_at)\n VALUES ($1, $2, $3, $4, $5, $6)\n " + }, + "751d549073d77ded84aea1aaba36d3b130ec71bc592d722eb75b959b80f0b4ff": { + "describe": { + "columns": [ + { + "name": "count!", + "ordinal": 0, + "type_info": "Int8" + } + ], + "nullable": [ + null + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT COUNT(*) as \"count!\"\n FROM user_sessions s\n WHERE s.user_id = $1 AND s.finished_at IS NULL\n " + }, + "77dfa9fae1a9c77b70476d7da19d3313a02886994cfff0690451229fb5ae2f77": { + "describe": { + "columns": [ + { + "name": "compat_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT compat_access_token_id\n , access_token\n , created_at\n , expires_at\n , compat_session_id\n\n FROM compat_access_tokens\n\n WHERE compat_access_token_id = $1\n " + }, + "79295f3d3a75f831e9469aabfa720d381a254d00dbe39fef1e9652029d51b89b": { + "describe": { + "columns": [ + { + "name": "user_session_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "user_session_created_at", + "ordinal": 1, + "type_info": "Timestamptz" + }, + { + "name": "user_session_finished_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "user_id", + "ordinal": 3, + "type_info": "Uuid" + }, + { + "name": "user_username", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "user_primary_user_email_id", + "ordinal": 5, + "type_info": "Uuid" + }, + { + "name": "last_authentication_id?", + "ordinal": 6, + "type_info": "Uuid" + }, + { + "name": "last_authd_at?", + "ordinal": 7, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT s.user_session_id\n , s.created_at AS \"user_session_created_at\"\n , s.finished_at AS \"user_session_finished_at\"\n , u.user_id\n , u.username AS \"user_username\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n , a.user_session_authentication_id AS \"last_authentication_id?\"\n , a.created_at AS \"last_authd_at?\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications a\n USING (user_session_id)\n WHERE s.user_session_id = $1\n ORDER BY a.created_at DESC\n LIMIT 1\n " + }, + "7be139553610ace03193a99fe27fcb4e3d50c90accdaf22ca1cfeefdc9734300": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "UuidArray", + "Uuid", + "TextArray" + ] + } + }, + "query": "\n INSERT INTO oauth2_client_redirect_uris\n (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)\n SELECT id, $2, redirect_uri\n FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)\n " + }, + "7ce387b1b0aaf10e72adde667b19521b66eaafa51f73bf2f95e38b8f3b64a229": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid" + ] + } + }, + "query": "\n UPDATE upstream_oauth_links\n SET user_id = $1\n WHERE upstream_oauth_link_id = $2\n " + }, + "836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c": { + "describe": { + "columns": [ + { + "name": "user_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "username", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "primary_user_email_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n FROM users\n WHERE username = $1\n " + }, + "85499663f1adc7b7439592063f06914089f6243126a177b365bde37db5f6b33d": { + "describe": { + "columns": [ + { + "name": "oauth2_client_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "encrypted_client_secret", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "redirect_uris!", + "ordinal": 2, + "type_info": "TextArray" + }, + { + "name": "grant_type_authorization_code", + "ordinal": 3, + "type_info": "Bool" + }, + { + "name": "grant_type_refresh_token", + "ordinal": 4, + "type_info": "Bool" + }, + { + "name": "client_name", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "logo_uri", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "client_uri", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "policy_uri", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "tos_uri", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "jwks_uri", + "ordinal": 10, + "type_info": "Text" + }, + { + "name": "jwks", + "ordinal": 11, + "type_info": "Jsonb" + }, + { + "name": "id_token_signed_response_alg", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "userinfo_signed_response_alg", + "ordinal": 13, + "type_info": "Text" + }, + { + "name": "token_endpoint_auth_method", + "ordinal": 14, + "type_info": "Text" + }, + { + "name": "token_endpoint_auth_signing_alg", + "ordinal": 15, + "type_info": "Text" + }, + { + "name": "initiate_login_uri", + "ordinal": 16, + "type_info": "Text" + } + ], + "nullable": [ + false, + true, + null, + false, + false, + true, + true, + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "parameters": { + "Left": [ + "UuidArray" + ] + } + }, + "query": "\n SELECT oauth2_client_id\n , encrypted_client_secret\n , ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\"\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n FROM oauth2_clients c\n\n WHERE oauth2_client_id = ANY($1::uuid[])\n " + }, + "8a79c7c392dd930628caadec80c9b2645501475ab4feacbac59ca1bc52b16c3f": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Bool", + "Bool", + "Text", + "Jsonb", + "Text" + ] + } + }, + "query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , token_endpoint_auth_method\n , jwks\n , jwks_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7)\n ON CONFLICT (oauth2_client_id)\n DO\n UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret\n , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code\n , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token\n , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method\n , jwks = EXCLUDED.jwks\n , jwks_uri = EXCLUDED.jwks_uri\n " + }, + "8b7297c263336d70c2b647212b16f7ae39bc5cb1572e3a2dcfcd67f196a1fa39": { + "describe": { + "columns": [ + { + "name": "scope_token", + "ordinal": 0, + "type_info": "Text" + } + ], + "nullable": [ + false + ], + "parameters": { + "Left": [ + "Uuid", + "Uuid" + ] + } + }, + "query": "\n SELECT scope_token\n FROM oauth2_consents\n WHERE user_id = $1 AND oauth2_client_id = $2\n " + }, + "8f7a9fb1f24c24f8dbc3c193df2a742c9ac730ab958587b67297de2d4b843863": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_provider_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "issuer", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "scope", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "client_id", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "encrypted_client_secret", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "token_endpoint_signing_alg", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "token_endpoint_auth_method", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 7, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n " + }, + "90b5512c0c9dc3b3eb6500056cc72f9993216d9b553c2e33a7edec26ffb0fc59": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE user_emails\n SET confirmed_at = $2\n WHERE user_email_id = $1\n " + }, + "90fe32cb9c88a262a682c0db700fef7d69d6ce0be1f930d9f16c50b921a8b819": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO user_emails (user_email_id, user_id, email, created_at)\n VALUES ($1, $2, $3, $4)\n " + }, + "921d77c194609615a7e9a6fd806e9cc17a7927e3e5deb58f3917ceeb9ab4dede": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE user_email_confirmation_codes\n SET consumed_at = $2\n WHERE user_email_confirmation_code_id = $1\n " + }, + "9348d87f9e06b614c7e90bdc93bcf38236766aaf4d894bf768debdff2b59fae2": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE compat_sso_logins\n SET\n exchanged_at = $2\n WHERE\n compat_sso_login_id = $1\n " + }, + "94fd96446b237c87bd6bf741f3c42b37ee751b87b7fcc459602bdf8c46962443": { + "describe": { + "columns": [ + { + "name": "exists!", + "ordinal": 0, + "type_info": "Bool" + } + ], + "nullable": [ + null + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT EXISTS(\n SELECT 1 FROM users WHERE username = $1\n ) AS \"exists!\"\n " + }, + "9a6c197ff4ad80217262d48f8792ce7e16bc5df0677c7cd4ecb4fdbc5ee86395": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "UuidArray", + "Uuid", + "Uuid", + "TextArray", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO oauth2_consents\n (oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)\n SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)\n ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5\n " + }, + "9f7bdc034c618e47e49c467d0d7f5b8c297d055abe248cc876dbc12c5a7dc920": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO compat_refresh_tokens\n (compat_refresh_token_id, compat_session_id,\n compat_access_token_id, refresh_token, created_at)\n VALUES ($1, $2, $3, $4, $5)\n " + }, + "a2f7433f06fb4f6a7ad5ac6c1db18705276bce41e9b19d5d7e910ad4b767fb5e": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO oauth2_refresh_tokens\n (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,\n refresh_token, created_at)\n VALUES\n ($1, $2, $3, $4, $5)\n " + }, + "a300fe99c95679c5664646a6a525c0491829e97db45f3234483872ed38436322": { + "describe": { + "columns": [ + { + "name": "user_email_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "user_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "email", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "confirmed_at", + "ordinal": 4, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT user_email_id\n , user_id\n , email\n , created_at\n , confirmed_at\n FROM user_emails\n\n WHERE user_id = $1\n\n ORDER BY email ASC\n " + }, + "a6fa7811d0a7c62c7cccff96dc82db5b25462fa7669fde1941ccab4712585b20": { + "describe": { + "columns": [ + { + "name": "oauth2_refresh_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_access_token_id", + "ordinal": 4, + "type_info": "Uuid" + }, + { + "name": "oauth2_session_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + true, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n FROM oauth2_refresh_tokens\n\n WHERE oauth2_refresh_token_id = $1\n " + }, + "a7f780528882a2ae66c45435215763eed0582264861436eab3f862e3eb12cab1": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO compat_access_tokens\n (compat_access_token_id, compat_session_id, access_token, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n " + }, + "ab34912b42a48a8b5c8d63e271b99b7d0b690a2471873c6654b1b6cf2079b95c": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE compat_sessions cs\n SET finished_at = $2\n WHERE compat_session_id = $1\n " + }, + "afa86e79e3de2a83265cb0db8549d378a2f11b2a27bbd86d60558318c87eb698": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO oauth2_access_tokens\n (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)\n VALUES\n ($1, $2, $3, $4, $5)\n " + }, + "aff08a8caabeb62f4929e6e901e7ca7c55e284c18c5c1d1e78821dd9bc961412": { + "describe": { + "columns": [ + { + "name": "user_email_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "user_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "email", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "confirmed_at", + "ordinal": 4, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Uuid", + "Text" + ] + } + }, + "query": "\n SELECT user_email_id\n , user_id\n , email\n , created_at\n , confirmed_at\n FROM user_emails\n\n WHERE user_id = $1 AND email = $2\n " + }, + "b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO users (user_id, username, created_at)\n VALUES ($1, $2, $3)\n " + }, + "b515bbfb331e46acd3c0219f09223cc5d8d31cb41287e693dcb82c6e199f7991": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO user_email_confirmation_codes\n (user_email_confirmation_code_id, user_email_id, code, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n " + }, + "b6a6f5386dc89e4bc2ce56d578a29341848fce336d339b6bbf425956f5ed5032": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE oauth2_refresh_tokens\n SET consumed_at = $2\n WHERE oauth2_refresh_token_id = $1\n " + }, + "b700dc3f7d0f86f4904725d8357e34b7e457f857ed37c467c314142877fd5367": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE oauth2_sessions\n SET finished_at = $2\n WHERE oauth2_session_id = $1\n " + }, + "b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz", + "Text", + "Uuid" + ] + } + }, + "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n " + }, + "bbf62633c561706a762089bbab2f76a9ba3e2ed3539ef16accb601fb609c2ec9": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE compat_access_tokens\n SET expires_at = $2\n WHERE compat_access_token_id = $1\n " + }, + "bd1f6daa5fa1b10250c01f8b3fbe451646a9ceeefa6f72b9c4e29b6d05f17641": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n UPDATE users\n SET primary_user_email_id = user_emails.user_email_id\n FROM user_emails\n WHERE user_emails.user_email_id = $1\n AND users.user_id = user_emails.user_id\n " + }, + "c0ed9d70e496433d8686a499055d8a8376459109b6154a2c0c13b28462afa523": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Text", + "Text", + "Text", + "Int4", + "Text", + "Text", + "Text", + "Bool", + "Bool", + "Text", + "Bool", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO oauth2_authorization_grants (\n oauth2_authorization_grant_id,\n oauth2_client_id,\n redirect_uri,\n scope,\n state,\n nonce,\n max_age,\n response_mode,\n code_challenge,\n code_challenge_method,\n response_type_code,\n response_type_id_token,\n authorization_code,\n requires_consent,\n created_at\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)\n " + }, + "c1d90a7f2287ec779c81a521fab19e5ede3fa95484033e0312c30d9b6ecc03f0": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO user_sessions (user_session_id, user_id, created_at)\n VALUES ($1, $2, $3)\n " + }, + "c5e7dbb22488aca427b85b3415bd1f1a1766ff865f2e08a5daa095d2a1ccbd56": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE oauth2_authorization_grants\n SET exchanged_at = $2\n WHERE oauth2_authorization_grant_id = $1\n " + }, + "d0b403e9c843ef19fa5ad60bec32ebf14a1ba0d01681c3836366d3f55e7851f4": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE compat_refresh_tokens\n SET consumed_at = $2\n WHERE compat_refresh_token_id = $1\n " + }, + "d12a513b81b3ef658eae1f0a719933323f28c6ee260b52cafe337dd3d19e865c": { + "describe": { + "columns": [ + { + "name": "count", + "ordinal": 0, + "type_info": "Int8" + } + ], + "nullable": [ + null + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT COUNT(*)\n FROM user_emails\n WHERE user_id = $1\n " + }, + "d83421d4a16f4ad084dd0db5abb56d3688851c36a48a50aa6104e8291e73630d": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n requires_consent = 'f'\n WHERE\n og.oauth2_authorization_grant_id = $1\n " + }, + "db90cbc406a399f5447bd2c1d8018464f83b927dec620353516c0285b76fcf24": { + "describe": { + "columns": [ + { + "name": "oauth2_client_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "encrypted_client_secret", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "redirect_uris!", + "ordinal": 2, + "type_info": "TextArray" + }, + { + "name": "grant_type_authorization_code", + "ordinal": 3, + "type_info": "Bool" + }, + { + "name": "grant_type_refresh_token", + "ordinal": 4, + "type_info": "Bool" + }, + { + "name": "client_name", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "logo_uri", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "client_uri", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "policy_uri", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "tos_uri", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "jwks_uri", + "ordinal": 10, + "type_info": "Text" + }, + { + "name": "jwks", + "ordinal": 11, + "type_info": "Jsonb" + }, + { + "name": "id_token_signed_response_alg", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "userinfo_signed_response_alg", + "ordinal": 13, + "type_info": "Text" + }, + { + "name": "token_endpoint_auth_method", + "ordinal": 14, + "type_info": "Text" + }, + { + "name": "token_endpoint_auth_signing_alg", + "ordinal": 15, + "type_info": "Text" + }, + { + "name": "initiate_login_uri", + "ordinal": 16, + "type_info": "Text" + } + ], + "nullable": [ + false, + true, + null, + false, + false, + true, + true, + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT oauth2_client_id\n , encrypted_client_secret\n , ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\"\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n FROM oauth2_clients c\n\n WHERE oauth2_client_id = $1\n " + }, + "dbf4be84eeff9ea51b00185faae2d453ab449017ed492bf6711dc7fceb630880": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Timestamptz", + "Uuid" + ] + } + }, + "query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n " + }, + "dd16942318bf38d9a245b2c86fedd3cbd6b65e7a13465552d79cd3c022122fd4": { + "describe": { + "columns": [ + { + "name": "oauth2_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "revoked_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_session_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT oauth2_access_token_id\n , access_token\n , created_at\n , expires_at\n , revoked_at\n , oauth2_session_id\n\n FROM oauth2_access_tokens\n\n WHERE oauth2_access_token_id = $1\n " + }, + "ddb22dd9ae9367af65a607e1fdc48b3d9581d67deea0c168f24e02090082bb82": { + "describe": { + "columns": [ + { + "name": "compat_sso_login_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "login_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "redirect_uri", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "fulfilled_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "exchanged_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 6, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + true, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT compat_sso_login_id\n , login_token\n , redirect_uri\n , created_at\n , fulfilled_at\n , exchanged_at\n , compat_session_id\n\n FROM compat_sso_logins\n WHERE compat_sso_login_id = $1\n " + }, + "e35d56de7136d43d0803ec825b0612e4185cef838f105d66f18cb24865e45140": { + "describe": { + "columns": [ + { + "name": "compat_refresh_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + }, + { + "name": "compat_access_token_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT compat_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , compat_session_id\n , compat_access_token_id\n\n FROM compat_refresh_tokens\n\n WHERE compat_refresh_token_id = $1\n " + }, + "e6dc63984aced9e19c20e90e9cd75d6f6d7ade64f782697715ac4da077b2e1fc": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_link_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "upstream_oauth_provider_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "user_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "subject", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 4, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid", + "Text" + ] + } + }, + "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " + }, + "e709869c062ac50248b1f9f8f808cc2f5e7bef58a6c2e42a7bb0c1cb8f508671": { + "describe": { + "columns": [ + { + "name": "oauth2_refresh_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_access_token_id", + "ordinal": 4, + "type_info": "Uuid" + }, + { + "name": "oauth2_session_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + true, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n FROM oauth2_refresh_tokens\n\n WHERE refresh_token = $1\n " + }, + "f0ace1af3775192a555c4ebb59b81183f359771f9f77e5fad759d38d872541d1": { + "describe": { + "columns": [ + { + "name": "oauth2_session_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "user_session_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "oauth2_client_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "scope", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "finished_at", + "ordinal": 5, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT oauth2_session_id\n , user_session_id\n , oauth2_client_id\n , scope\n , created_at\n , finished_at\n FROM oauth2_sessions\n\n WHERE oauth2_session_id = $1\n " + }, + "f3ee06958d827b152c57328caa0a6030c372cb99cdb60e4b75a28afeb5096f45": { + "describe": { + "columns": [ + { + "name": "compat_session_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "device_id", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "user_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "finished_at", + "ordinal": 4, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT compat_session_id\n , device_id\n , user_id\n , created_at\n , finished_at\n FROM compat_sessions\n WHERE compat_session_id = $1\n " + }, + "f5edcd4c306ca8179cdf9d4aab59fbba971b54611c91345849920954dd8089b3": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Bool", + "Bool", + "Text", + "Text", + "Text", + "Text", + "Text", + "Text", + "Jsonb", + "Text", + "Text", + "Text", + "Text", + "Text" + ] + } + }, + "query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)\n " + } +} \ No newline at end of file diff --git a/crates/storage-pg/src/compat/access_token.rs b/crates/storage-pg/src/compat/access_token.rs new file mode 100644 index 000000000..70fabac79 --- /dev/null +++ b/crates/storage-pg/src/compat/access_token.rs @@ -0,0 +1,220 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Duration, Utc}; +use mas_data_model::{CompatAccessToken, CompatSession}; +use mas_storage::{compat::CompatAccessTokenRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +/// An implementation of [`CompatAccessTokenRepository`] for a PostgreSQL +/// connection +pub struct PgCompatAccessTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatAccessTokenRepository<'c> { + /// Create a new [`PgCompatAccessTokenRepository`] from an active PostgreSQL + /// connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct CompatAccessTokenLookup { + compat_access_token_id: Uuid, + access_token: String, + created_at: DateTime, + expires_at: Option>, + compat_session_id: Uuid, +} + +impl From for CompatAccessToken { + fn from(value: CompatAccessTokenLookup) -> Self { + Self { + id: value.compat_access_token_id.into(), + session_id: value.compat_session_id.into(), + token: value.access_token, + created_at: value.created_at, + expires_at: value.expires_at, + } + } +} + +#[async_trait] +impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_access_token.lookup", + skip_all, + fields( + db.statement, + compat_session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatAccessTokenLookup, + r#" + SELECT compat_access_token_id + , access_token + , created_at + , expires_at + , compat_session_id + + FROM compat_access_tokens + + WHERE compat_access_token_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_access_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatAccessTokenLookup, + r#" + SELECT compat_access_token_id + , access_token + , created_at + , expires_at + , compat_session_id + + FROM compat_access_tokens + + WHERE access_token = $1 + "#, + access_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_access_token.add", + skip_all, + fields( + db.statement, + compat_access_token.id, + %compat_session.id, + user.id = %compat_session.user_id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + compat_session: &CompatSession, + token: String, + expires_after: Option, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_access_token.id", tracing::field::display(id)); + + let expires_at = expires_after.map(|expires_after| created_at + expires_after); + + sqlx::query!( + r#" + INSERT INTO compat_access_tokens + (compat_access_token_id, compat_session_id, access_token, created_at, expires_at) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(compat_session.id), + token, + created_at, + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatAccessToken { + id, + session_id: compat_session.id, + token, + created_at, + expires_at, + }) + } + + #[tracing::instrument( + name = "db.compat_access_token.expire", + skip_all, + fields( + db.statement, + %compat_access_token.id, + compat_session.id = %compat_access_token.session_id, + ), + err, + )] + async fn expire( + &mut self, + clock: &dyn Clock, + mut compat_access_token: CompatAccessToken, + ) -> Result { + let expires_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE compat_access_tokens + SET expires_at = $2 + WHERE compat_access_token_id = $1 + "#, + Uuid::from(compat_access_token.id), + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + compat_access_token.expires_at = Some(expires_at); + Ok(compat_access_token) + } +} diff --git a/crates/storage-pg/src/compat/mod.rs b/crates/storage-pg/src/compat/mod.rs new file mode 100644 index 000000000..ab4ed33e6 --- /dev/null +++ b/crates/storage-pg/src/compat/mod.rs @@ -0,0 +1,449 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! A module containing PostgreSQL implementation of repositories for the +//! compatibility layer + +mod access_token; +mod refresh_token; +mod session; +mod sso_login; + +pub use self::{ + access_token::PgCompatAccessTokenRepository, refresh_token::PgCompatRefreshTokenRepository, + session::PgCompatSessionRepository, sso_login::PgCompatSsoLoginRepository, +}; + +#[cfg(test)] +mod tests { + use chrono::Duration; + use mas_data_model::Device; + use mas_storage::{ + clock::MockClock, + compat::{ + CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, + }, + user::UserRepository, + Clock, Pagination, Repository, RepositoryAccess, + }; + use rand::SeedableRng; + use rand_chacha::ChaChaRng; + use sqlx::PgPool; + use ulid::Ulid; + + use crate::PgRepository; + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_session_repository(pool: PgPool) { + const FIRST_TOKEN: &str = "first_access_token"; + const SECOND_TOKEN: &str = "second_access_token"; + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + // Create a user + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + + // Start a compat session for that user + let device = Device::generate(&mut rng); + let device_str = device.as_str().to_owned(); + let session = repo + .compat_session() + .add(&mut rng, &clock, &user, device) + .await + .unwrap(); + assert_eq!(session.user_id, user.id); + assert_eq!(session.device.as_str(), device_str); + assert!(session.is_valid()); + assert!(!session.is_finished()); + + // Lookup the session and check it didn't change + let session_lookup = repo + .compat_session() + .lookup(session.id) + .await + .unwrap() + .expect("compat session not found"); + assert_eq!(session_lookup.id, session.id); + assert_eq!(session_lookup.user_id, user.id); + assert_eq!(session_lookup.device.as_str(), device_str); + assert!(session_lookup.is_valid()); + assert!(!session_lookup.is_finished()); + + // Finish the session + let session = repo.compat_session().finish(&clock, session).await.unwrap(); + assert!(!session.is_valid()); + assert!(session.is_finished()); + + // Reload the session and check again + let session_lookup = repo + .compat_session() + .lookup(session.id) + .await + .unwrap() + .expect("compat session not found"); + assert!(!session_lookup.is_valid()); + assert!(session_lookup.is_finished()); + } + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_access_token_repository(pool: PgPool) { + const FIRST_TOKEN: &str = "first_access_token"; + const SECOND_TOKEN: &str = "second_access_token"; + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); + + // Create a user + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + + // Start a compat session for that user + let device = Device::generate(&mut rng); + let session = repo + .compat_session() + .add(&mut rng, &clock, &user, device) + .await + .unwrap(); + + // Add an access token to that session + let token = repo + .compat_access_token() + .add( + &mut rng, + &clock, + &session, + FIRST_TOKEN.to_owned(), + Some(Duration::minutes(1)), + ) + .await + .unwrap(); + assert_eq!(token.session_id, session.id); + assert_eq!(token.token, FIRST_TOKEN); + + // Commit the txn and grab a new transaction, to test a conflict + repo.save().await.unwrap(); + + { + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); + // Adding the same token a second time should conflict + assert!(repo + .compat_access_token() + .add( + &mut rng, + &clock, + &session, + FIRST_TOKEN.to_owned(), + Some(Duration::minutes(1)), + ) + .await + .is_err()); + repo.cancel().await.unwrap(); + } + + // Grab a new repo + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); + + // Looking up via ID works + let token_lookup = repo + .compat_access_token() + .lookup(token.id) + .await + .unwrap() + .expect("compat access token not found"); + assert_eq!(token.id, token_lookup.id); + assert_eq!(token_lookup.session_id, session.id); + + // Looking up via the token value works + let token_lookup = repo + .compat_access_token() + .find_by_token(FIRST_TOKEN) + .await + .unwrap() + .expect("compat access token not found"); + assert_eq!(token.id, token_lookup.id); + assert_eq!(token_lookup.session_id, session.id); + + // Token is currently valid + assert!(token.is_valid(clock.now())); + + clock.advance(Duration::minutes(1)); + // Token should have expired + assert!(!token.is_valid(clock.now())); + + // Add a second access token, this time without expiration + let token = repo + .compat_access_token() + .add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None) + .await + .unwrap(); + assert_eq!(token.session_id, session.id); + assert_eq!(token.token, SECOND_TOKEN); + + // Token is currently valid + assert!(token.is_valid(clock.now())); + + // Make it expire + repo.compat_access_token() + .expire(&clock, token) + .await + .unwrap(); + + // Reload it + let token = repo + .compat_access_token() + .find_by_token(SECOND_TOKEN) + .await + .unwrap() + .expect("compat access token not found"); + + // Token is not valid anymore + assert!(!token.is_valid(clock.now())); + + repo.save().await.unwrap(); + } + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_refresh_token_repository(pool: PgPool) { + const ACCESS_TOKEN: &str = "access_token"; + const REFRESH_TOKEN: &str = "refresh_token"; + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); + + // Create a user + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + + // Start a compat session for that user + let device = Device::generate(&mut rng); + let session = repo + .compat_session() + .add(&mut rng, &clock, &user, device) + .await + .unwrap(); + + // Add an access token to that session + let access_token = repo + .compat_access_token() + .add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None) + .await + .unwrap(); + + let refresh_token = repo + .compat_refresh_token() + .add( + &mut rng, + &clock, + &session, + &access_token, + REFRESH_TOKEN.to_owned(), + ) + .await + .unwrap(); + assert_eq!(refresh_token.session_id, session.id); + assert_eq!(refresh_token.access_token_id, access_token.id); + assert_eq!(refresh_token.token, REFRESH_TOKEN); + assert!(refresh_token.is_valid()); + assert!(!refresh_token.is_consumed()); + + // Look it up by ID and check everything matches + let refresh_token_lookup = repo + .compat_refresh_token() + .lookup(refresh_token.id) + .await + .unwrap() + .expect("refresh token not found"); + assert_eq!(refresh_token_lookup.id, refresh_token.id); + assert_eq!(refresh_token_lookup.session_id, session.id); + assert_eq!(refresh_token_lookup.access_token_id, access_token.id); + assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN); + assert!(refresh_token_lookup.is_valid()); + assert!(!refresh_token_lookup.is_consumed()); + + // Look it up by token and check everything matches + let refresh_token_lookup = repo + .compat_refresh_token() + .find_by_token(REFRESH_TOKEN) + .await + .unwrap() + .expect("refresh token not found"); + assert_eq!(refresh_token_lookup.id, refresh_token.id); + assert_eq!(refresh_token_lookup.session_id, session.id); + assert_eq!(refresh_token_lookup.access_token_id, access_token.id); + assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN); + assert!(refresh_token_lookup.is_valid()); + assert!(!refresh_token_lookup.is_consumed()); + + // Consume it + let refresh_token = repo + .compat_refresh_token() + .consume(&clock, refresh_token) + .await + .unwrap(); + assert!(!refresh_token.is_valid()); + assert!(refresh_token.is_consumed()); + + // Reload it and check again + let refresh_token_lookup = repo + .compat_refresh_token() + .find_by_token(REFRESH_TOKEN) + .await + .unwrap() + .expect("refresh token not found"); + assert!(!refresh_token_lookup.is_valid()); + assert!(refresh_token_lookup.is_consumed()); + + // Consuming it again should not work + assert!(repo + .compat_refresh_token() + .consume(&clock, refresh_token) + .await + .is_err()); + + repo.save().await.unwrap(); + } + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_compat_sso_login_repository(pool: PgPool) { + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); + + // Create a user + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + + // Lookup an unknown SSO login + let login = repo.compat_sso_login().lookup(Ulid::nil()).await.unwrap(); + assert_eq!(login, None); + + // Lookup an unknown login token + let login = repo + .compat_sso_login() + .find_by_token("login-token") + .await + .unwrap(); + assert_eq!(login, None); + + // Start a new SSO login + let login = repo + .compat_sso_login() + .add( + &mut rng, + &clock, + "login-token".to_owned(), + "https://example.com/callback".parse().unwrap(), + ) + .await + .unwrap(); + assert!(login.is_pending()); + + // Lookup the login by ID + let login_lookup = repo + .compat_sso_login() + .lookup(login.id) + .await + .unwrap() + .expect("login not found"); + assert_eq!(login_lookup, login); + + // Find the login by token + let login_lookup = repo + .compat_sso_login() + .find_by_token("login-token") + .await + .unwrap() + .expect("login not found"); + assert_eq!(login_lookup, login); + + // Exchanging before fulfilling should not work + // Note: It should also not poison the SQL transaction + let res = repo + .compat_sso_login() + .exchange(&clock, login.clone()) + .await; + assert!(res.is_err()); + + // Start a compat session for that user + let device = Device::generate(&mut rng); + let session = repo + .compat_session() + .add(&mut rng, &clock, &user, device) + .await + .unwrap(); + + // Associate the login with the session + let login = repo + .compat_sso_login() + .fulfill(&clock, login, &session) + .await + .unwrap(); + assert!(login.is_fulfilled()); + + // Fulfilling again should not work + // Note: It should also not poison the SQL transaction + let res = repo + .compat_sso_login() + .fulfill(&clock, login.clone(), &session) + .await; + assert!(res.is_err()); + + // Exchange that login + let login = repo + .compat_sso_login() + .exchange(&clock, login) + .await + .unwrap(); + assert!(login.is_exchanged()); + + // Exchange again should not work + // Note: It should also not poison the SQL transaction + let res = repo + .compat_sso_login() + .exchange(&clock, login.clone()) + .await; + assert!(res.is_err()); + + // Fulfilling after exchanging should not work + // Note: It should also not poison the SQL transaction + let res = repo + .compat_sso_login() + .fulfill(&clock, login.clone(), &session) + .await; + assert!(res.is_err()); + + // List the logins for the user + let logins = repo + .compat_sso_login() + .list_paginated(&user, Pagination::first(10)) + .await + .unwrap(); + assert!(!logins.has_next_page); + assert_eq!(logins.edges, vec![login]); + } +} diff --git a/crates/storage-pg/src/compat/refresh_token.rs b/crates/storage-pg/src/compat/refresh_token.rs new file mode 100644 index 000000000..0811119a1 --- /dev/null +++ b/crates/storage-pg/src/compat/refresh_token.rs @@ -0,0 +1,234 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{ + CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession, +}; +use mas_storage::{compat::CompatRefreshTokenRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +/// An implementation of [`CompatRefreshTokenRepository`] for a PostgreSQL +/// connection +pub struct PgCompatRefreshTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatRefreshTokenRepository<'c> { + /// Create a new [`PgCompatRefreshTokenRepository`] from an active + /// PostgreSQL connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct CompatRefreshTokenLookup { + compat_refresh_token_id: Uuid, + refresh_token: String, + created_at: DateTime, + consumed_at: Option>, + compat_access_token_id: Uuid, + compat_session_id: Uuid, +} + +impl From for CompatRefreshToken { + fn from(value: CompatRefreshTokenLookup) -> Self { + let state = match value.consumed_at { + Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at }, + None => CompatRefreshTokenState::Valid, + }; + + Self { + id: value.compat_refresh_token_id.into(), + state, + session_id: value.compat_session_id.into(), + token: value.refresh_token, + created_at: value.created_at, + access_token_id: value.compat_access_token_id.into(), + } + } +} + +#[async_trait] +impl<'c> CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_refresh_token.lookup", + skip_all, + fields( + db.statement, + compat_refresh_token.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatRefreshTokenLookup, + r#" + SELECT compat_refresh_token_id + , refresh_token + , created_at + , consumed_at + , compat_session_id + , compat_access_token_id + + FROM compat_refresh_tokens + + WHERE compat_refresh_token_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_refresh_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatRefreshTokenLookup, + r#" + SELECT compat_refresh_token_id + , refresh_token + , created_at + , consumed_at + , compat_session_id + , compat_access_token_id + + FROM compat_refresh_tokens + + WHERE refresh_token = $1 + "#, + refresh_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_refresh_token.add", + skip_all, + fields( + db.statement, + compat_refresh_token.id, + %compat_session.id, + user.id = %compat_session.user_id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + compat_session: &CompatSession, + compat_access_token: &CompatAccessToken, + token: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO compat_refresh_tokens + (compat_refresh_token_id, compat_session_id, + compat_access_token_id, refresh_token, created_at) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(compat_session.id), + Uuid::from(compat_access_token.id), + token, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatRefreshToken { + id, + state: CompatRefreshTokenState::default(), + session_id: compat_session.id, + access_token_id: compat_access_token.id, + token, + created_at, + }) + } + + #[tracing::instrument( + name = "db.compat_refresh_token.consume", + skip_all, + fields( + db.statement, + %compat_refresh_token.id, + compat_session.id = %compat_refresh_token.session_id, + ), + err, + )] + async fn consume( + &mut self, + clock: &dyn Clock, + compat_refresh_token: CompatRefreshToken, + ) -> Result { + let consumed_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE compat_refresh_tokens + SET consumed_at = $2 + WHERE compat_refresh_token_id = $1 + "#, + Uuid::from(compat_refresh_token.id), + consumed_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + let compat_refresh_token = compat_refresh_token + .consume(consumed_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(compat_refresh_token) + } +} diff --git a/crates/storage-pg/src/compat/session.rs b/crates/storage-pg/src/compat/session.rs new file mode 100644 index 000000000..283a9a598 --- /dev/null +++ b/crates/storage-pg/src/compat/session.rs @@ -0,0 +1,198 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{CompatSession, CompatSessionState, Device, User}; +use mas_storage::{compat::CompatSessionRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; + +/// An implementation of [`CompatSessionRepository`] for a PostgreSQL connection +pub struct PgCompatSessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatSessionRepository<'c> { + /// Create a new [`PgCompatSessionRepository`] from an active PostgreSQL + /// connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct CompatSessionLookup { + compat_session_id: Uuid, + device_id: String, + user_id: Uuid, + created_at: DateTime, + finished_at: Option>, +} + +impl TryFrom for CompatSession { + type Error = DatabaseInconsistencyError; + + fn try_from(value: CompatSessionLookup) -> Result { + let id = value.compat_session_id.into(); + let device = Device::try_from(value.device_id).map_err(|e| { + DatabaseInconsistencyError::on("compat_sessions") + .column("device_id") + .row(id) + .source(e) + })?; + + let state = match value.finished_at { + None => CompatSessionState::Valid, + Some(finished_at) => CompatSessionState::Finished { finished_at }, + }; + + let session = CompatSession { + id, + state, + user_id: value.user_id.into(), + device, + created_at: value.created_at, + }; + + Ok(session) + } +} + +#[async_trait] +impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_session.lookup", + skip_all, + fields( + db.statement, + compat_session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatSessionLookup, + r#" + SELECT compat_session_id + , device_id + , user_id + , created_at + , finished_at + FROM compat_sessions + WHERE compat_session_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.compat_session.add", + skip_all, + fields( + db.statement, + compat_session.id, + %user.id, + %user.username, + compat_session.device.id = device.as_str(), + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + device: Device, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_session.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at) + VALUES ($1, $2, $3, $4) + "#, + Uuid::from(id), + Uuid::from(user.id), + device.as_str(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatSession { + id, + state: CompatSessionState::default(), + user_id: user.id, + device, + created_at, + }) + } + + #[tracing::instrument( + name = "db.compat_session.finish", + skip_all, + fields( + db.statement, + %compat_session.id, + user.id = %compat_session.user_id, + compat_session.device.id = compat_session.device.as_str(), + ), + err, + )] + async fn finish( + &mut self, + clock: &dyn Clock, + compat_session: CompatSession, + ) -> Result { + let finished_at = clock.now(); + + let res = sqlx::query!( + r#" + UPDATE compat_sessions cs + SET finished_at = $2 + WHERE compat_session_id = $1 + "#, + Uuid::from(compat_session.id), + finished_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + let compat_session = compat_session + .finish(finished_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(compat_session) + } +} diff --git a/crates/storage-pg/src/compat/sso_login.rs b/crates/storage-pg/src/compat/sso_login.rs new file mode 100644 index 000000000..328bd7890 --- /dev/null +++ b/crates/storage-pg/src/compat/sso_login.rs @@ -0,0 +1,346 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState, User}; +use mas_storage::{compat::CompatSsoLoginRepository, Clock, Page, Pagination}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use url::Url; +use uuid::Uuid; + +use crate::{ + pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, + LookupResultExt, +}; + +/// An implementation of [`CompatSsoLoginRepository`] for a PostgreSQL +/// connection +pub struct PgCompatSsoLoginRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatSsoLoginRepository<'c> { + /// Create a new [`PgCompatSsoLoginRepository`] from an active PostgreSQL + /// connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct CompatSsoLoginLookup { + compat_sso_login_id: Uuid, + login_token: String, + redirect_uri: String, + created_at: DateTime, + fulfilled_at: Option>, + exchanged_at: Option>, + compat_session_id: Option, +} + +impl TryFrom for CompatSsoLogin { + type Error = DatabaseInconsistencyError; + + fn try_from(res: CompatSsoLoginLookup) -> Result { + let id = res.compat_sso_login_id.into(); + let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| { + DatabaseInconsistencyError::on("compat_sso_logins") + .column("redirect_uri") + .row(id) + .source(e) + })?; + + let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) { + (None, None, None) => CompatSsoLoginState::Pending, + (Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled { + fulfilled_at, + session_id: session_id.into(), + }, + (Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => { + CompatSsoLoginState::Exchanged { + fulfilled_at, + exchanged_at, + session_id: session_id.into(), + } + } + _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)), + }; + + Ok(CompatSsoLogin { + id, + login_token: res.login_token, + redirect_uri, + created_at: res.created_at, + state, + }) + } +} + +#[async_trait] +impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_sso_login.lookup", + skip_all, + fields( + db.statement, + compat_sso_login.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatSsoLoginLookup, + r#" + SELECT compat_sso_login_id + , login_token + , redirect_uri + , created_at + , fulfilled_at + , exchanged_at + , compat_session_id + + FROM compat_sso_logins + WHERE compat_sso_login_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.compat_sso_login.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + login_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatSsoLoginLookup, + r#" + SELECT compat_sso_login_id + , login_token + , redirect_uri + , created_at + , fulfilled_at + , exchanged_at + , compat_session_id + + FROM compat_sso_logins + WHERE login_token = $1 + "#, + login_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.compat_sso_login.add", + skip_all, + fields( + db.statement, + compat_sso_login.id, + compat_sso_login.redirect_uri = %redirect_uri, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + login_token: String, + redirect_uri: Url, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO compat_sso_logins + (compat_sso_login_id, login_token, redirect_uri, created_at) + VALUES ($1, $2, $3, $4) + "#, + Uuid::from(id), + &login_token, + redirect_uri.as_str(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatSsoLogin { + id, + login_token, + redirect_uri, + created_at, + state: CompatSsoLoginState::default(), + }) + } + + #[tracing::instrument( + name = "db.compat_sso_login.fulfill", + skip_all, + fields( + db.statement, + %compat_sso_login.id, + %compat_session.id, + compat_session.device.id = compat_session.device.as_str(), + user.id = %compat_session.user_id, + ), + err, + )] + async fn fulfill( + &mut self, + clock: &dyn Clock, + compat_sso_login: CompatSsoLogin, + compat_session: &CompatSession, + ) -> Result { + let fulfilled_at = clock.now(); + let compat_sso_login = compat_sso_login + .fulfill(fulfilled_at, compat_session) + .map_err(DatabaseError::to_invalid_operation)?; + + let res = sqlx::query!( + r#" + UPDATE compat_sso_logins + SET + compat_session_id = $2, + fulfilled_at = $3 + WHERE + compat_sso_login_id = $1 + "#, + Uuid::from(compat_sso_login.id), + Uuid::from(compat_session.id), + fulfilled_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(compat_sso_login) + } + + #[tracing::instrument( + name = "db.compat_sso_login.exchange", + skip_all, + fields( + db.statement, + %compat_sso_login.id, + ), + err, + )] + async fn exchange( + &mut self, + clock: &dyn Clock, + compat_sso_login: CompatSsoLogin, + ) -> Result { + let exchanged_at = clock.now(); + let compat_sso_login = compat_sso_login + .exchange(exchanged_at) + .map_err(DatabaseError::to_invalid_operation)?; + + let res = sqlx::query!( + r#" + UPDATE compat_sso_logins + SET + exchanged_at = $2 + WHERE + compat_sso_login_id = $1 + "#, + Uuid::from(compat_sso_login.id), + exchanged_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(compat_sso_login) + } + + #[tracing::instrument( + name = "db.compat_sso_login.list_paginated", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + ), + err + )] + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT cl.compat_sso_login_id + , cl.login_token + , cl.redirect_uri + , cl.created_at + , cl.fulfilled_at + , cl.exchanged_at + , cl.compat_session_id + + FROM compat_sso_logins cl + INNER JOIN compat_sessions cs USING (compat_session_id) + "#, + ); + + query + .push(" WHERE cs.user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("cl.compat_sso_login_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination + .process(edges) + .try_map(CompatSsoLogin::try_from)?; + Ok(page) + } +} diff --git a/crates/storage-pg/src/errors.rs b/crates/storage-pg/src/errors.rs new file mode 100644 index 000000000..ef7afad0f --- /dev/null +++ b/crates/storage-pg/src/errors.rs @@ -0,0 +1,144 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use sqlx::postgres::PgQueryResult; +use thiserror::Error; +use ulid::Ulid; + +/// Generic error when interacting with the database +#[derive(Debug, Error)] +#[error(transparent)] +pub enum DatabaseError { + /// An error which came from the database itself + Driver { + /// The underlying error from the database driver + #[from] + source: sqlx::Error, + }, + + /// An error which occured while converting the data from the database + Inconsistency(#[from] DatabaseInconsistencyError), + + /// An error which happened because the requested database operation is + /// invalid + #[error("Invalid database operation")] + InvalidOperation { + /// The source of the error, if any + #[source] + source: Option>, + }, + + /// An error which happens when an operation affects not enough or too many + /// rows + #[error("Expected {expected} rows to be affected, but {actual} rows were affected")] + RowsAffected { + /// How many rows were expected to be affected + expected: u64, + + /// How many rows were actually affected + actual: u64, + }, +} + +impl DatabaseError { + pub(crate) fn ensure_affected_rows( + result: &PgQueryResult, + expected: u64, + ) -> Result<(), DatabaseError> { + let actual = result.rows_affected(); + if actual == expected { + Ok(()) + } else { + Err(DatabaseError::RowsAffected { expected, actual }) + } + } + + pub(crate) fn to_invalid_operation(e: E) -> Self { + Self::InvalidOperation { + source: Some(Box::new(e)), + } + } + + pub(crate) const fn invalid_operation() -> Self { + Self::InvalidOperation { source: None } + } +} + +/// An error which occured while converting the data from the database +#[derive(Debug, Error)] +pub struct DatabaseInconsistencyError { + /// The table which was being queried + table: &'static str, + + /// The column which was being queried + column: Option<&'static str>, + + /// The row which was being queried + row: Option, + + /// The source of the error + #[source] + source: Option>, +} + +impl std::fmt::Display for DatabaseInconsistencyError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Database inconsistency on table {}", self.table)?; + if let Some(column) = self.column { + write!(f, " column {column}")?; + } + if let Some(row) = self.row { + write!(f, " row {row}")?; + } + + Ok(()) + } +} + +impl DatabaseInconsistencyError { + /// Create a new [`DatabaseInconsistencyError`] for the given table + #[must_use] + pub(crate) const fn on(table: &'static str) -> Self { + Self { + table, + column: None, + row: None, + source: None, + } + } + + /// Set the column which was being queried + #[must_use] + pub(crate) const fn column(mut self, column: &'static str) -> Self { + self.column = Some(column); + self + } + + /// Set the row which was being queried + #[must_use] + pub(crate) const fn row(mut self, row: Ulid) -> Self { + self.row = Some(row); + self + } + + /// Give the source of the error + #[must_use] + pub(crate) fn source( + mut self, + source: E, + ) -> Self { + self.source = Some(Box::new(source)); + self + } +} diff --git a/crates/storage-pg/src/lib.rs b/crates/storage-pg/src/lib.rs new file mode 100644 index 000000000..6615a3175 --- /dev/null +++ b/crates/storage-pg/src/lib.rs @@ -0,0 +1,225 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! An implementation of the storage traits for a PostgreSQL database +//! +//! This backend uses [`sqlx`] to interact with the database. Most queries are +//! type-checked, using introspection data recorded in the `sqlx-data.json` +//! file. This file is generated by the `sqlx` CLI tool, and should be updated +//! whenever the database schema changes, or new queries are added. +//! +//! # Implementing a new repository +//! +//! When a new repository is defined in [`mas_storage`], it should be +//! implemented here, with the PostgreSQL backend. +//! +//! A typical implementation will look like this: +//! +//! ```rust +//! # use async_trait::async_trait; +//! # use ulid::Ulid; +//! # use rand::RngCore; +//! # use mas_storage::Clock; +//! # use mas_storage_pg::{DatabaseError, ExecuteExt, LookupResultExt}; +//! # use sqlx::PgConnection; +//! # use uuid::Uuid; +//! # +//! # // A fake data structure, usually defined in mas-data-model +//! # #[derive(sqlx::FromRow)] +//! # struct FakeData { +//! # id: Ulid, +//! # } +//! # +//! # // A fake repository trait, usually defined in mas-storage +//! # #[async_trait] +//! # pub trait FakeDataRepository: Send + Sync { +//! # type Error; +//! # async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; +//! # async fn add( +//! # &mut self, +//! # rng: &mut (dyn RngCore + Send), +//! # clock: &dyn Clock, +//! # ) -> Result; +//! # } +//! # +//! /// An implementation of [`FakeDataRepository`] for a PostgreSQL connection +//! pub struct PgFakeDataRepository<'c> { +//! conn: &'c mut PgConnection, +//! } +//! +//! impl<'c> PgFakeDataRepository<'c> { +//! /// Create a new [`FakeDataRepository`] from an active PostgreSQL connection +//! pub fn new(conn: &'c mut PgConnection) -> Self { +//! Self { conn } +//! } +//! } +//! +//! #[derive(sqlx::FromRow)] +//! struct FakeDataLookup { +//! fake_data_id: Uuid, +//! } +//! +//! impl From for FakeData { +//! fn from(value: FakeDataLookup) -> Self { +//! Self { +//! id: value.fake_data_id.into(), +//! } +//! } +//! } +//! +//! #[async_trait] +//! impl<'c> FakeDataRepository for PgFakeDataRepository<'c> { +//! type Error = DatabaseError; +//! +//! #[tracing::instrument( +//! name = "db.fake_data.lookup", +//! skip_all, +//! fields( +//! db.statement, +//! fake_data.id = %id, +//! ), +//! err, +//! )] +//! async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { +//! // Note: here we would use the macro version instead, but it's not possible here in +//! // this documentation example +//! let res: Option = sqlx::query_as( +//! r#" +//! SELECT fake_data_id +//! FROM fake_data +//! WHERE fake_data_id = $1 +//! "#, +//! ) +//! .bind(Uuid::from(id)) +//! .traced() +//! .fetch_one(&mut *self.conn) +//! .await +//! .to_option()?; +//! +//! let Some(res) = res else { return Ok(None) }; +//! +//! Ok(Some(res.into())) +//! } +//! +//! #[tracing::instrument( +//! name = "db.fake_data.add", +//! skip_all, +//! fields( +//! db.statement, +//! fake_data.id, +//! ), +//! err, +//! )] +//! async fn add( +//! &mut self, +//! rng: &mut (dyn RngCore + Send), +//! clock: &dyn Clock, +//! ) -> Result { +//! let created_at = clock.now(); +//! let id = Ulid::from_datetime_with_source(created_at.into(), rng); +//! tracing::Span::current().record("fake_data.id", tracing::field::display(id)); +//! +//! // Note: here we would use the macro version instead, but it's not possible here in +//! // this documentation example +//! sqlx::query( +//! r#" +//! INSERT INTO fake_data (id) +//! VALUES ($1) +//! "#, +//! ) +//! .bind(Uuid::from(id)) +//! .traced() +//! .execute(&mut *self.conn) +//! .await?; +//! +//! Ok(FakeData { +//! id, +//! }) +//! } +//! } +//! ``` +//! +//! A few things to note with the implementation: +//! +//! - All methods are traced, with an explicit, somewhat consistent name. +//! - The SQL statement is included as attribute, by declaring a `db.statement` +//! attribute on the tracing span, and then calling [`ExecuteExt::traced`]. +//! - The IDs are all [`Ulid`], and generated from the clock and the random +//! number generated passed as parameters. The generated IDs are recorded in +//! the span. +//! - The IDs are stored as [`Uuid`] in PostgreSQL, so conversions are required +//! - "Not found" errors are handled by returning `Ok(None)` instead of an +//! error. The [`LookupResultExt::to_option`] method helps to do that. +//! +//! [`Ulid`]: ulid::Ulid +//! [`Uuid`]: uuid::Uuid + +#![forbid(unsafe_code)] +#![deny( + clippy::all, + clippy::str_to_string, + clippy::future_not_send, + rustdoc::broken_intra_doc_links, + missing_docs +)] +#![warn(clippy::pedantic)] +#![allow(clippy::module_name_repetitions)] + +use sqlx::migrate::Migrator; + +/// An extension trait for [`Result`] which adds a [`to_option`] method, useful +/// for handling "not found" errors from [`sqlx`] +/// +/// [`to_option`]: LookupResultExt::to_option +pub trait LookupResultExt { + /// The output type + type Output; + + /// Transform a [`Result`] from a sqlx query to transform "not found" errors + /// into [`None`] + /// + /// # Errors + /// + /// Returns the original error if the error was not a + /// [`sqlx::Error::RowNotFound`] error + fn to_option(self) -> Result, sqlx::Error>; +} + +impl LookupResultExt for Result { + type Output = T; + + fn to_option(self) -> Result, sqlx::Error> { + match self { + Ok(v) => Ok(Some(v)), + Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(e), + } + } +} + +pub mod compat; +pub mod oauth2; +pub mod upstream_oauth2; +pub mod user; + +mod errors; +pub(crate) mod pagination; +pub(crate) mod repository; +pub(crate) mod tracing; + +pub(crate) use self::errors::DatabaseInconsistencyError; +pub use self::{errors::DatabaseError, repository::PgRepository, tracing::ExecuteExt}; + +/// Embedded migrations, allowing them to run on startup +pub static MIGRATOR: Migrator = sqlx::migrate!(); diff --git a/crates/storage-pg/src/oauth2/access_token.rs b/crates/storage-pg/src/oauth2/access_token.rs new file mode 100644 index 000000000..e809fa53c --- /dev/null +++ b/crates/storage-pg/src/oauth2/access_token.rs @@ -0,0 +1,227 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Duration, Utc}; +use mas_data_model::{AccessToken, AccessTokenState, Session}; +use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +/// An implementation of [`OAuth2AccessTokenRepository`] for a PostgreSQL +/// connection +pub struct PgOAuth2AccessTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2AccessTokenRepository<'c> { + /// Create a new [`PgOAuth2AccessTokenRepository`] from an active PostgreSQL + /// connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct OAuth2AccessTokenLookup { + oauth2_access_token_id: Uuid, + oauth2_session_id: Uuid, + access_token: String, + created_at: DateTime, + expires_at: DateTime, + revoked_at: Option>, +} + +impl From for AccessToken { + fn from(value: OAuth2AccessTokenLookup) -> Self { + let state = match value.revoked_at { + None => AccessTokenState::Valid, + Some(revoked_at) => AccessTokenState::Revoked { revoked_at }, + }; + + Self { + id: value.oauth2_access_token_id.into(), + state, + session_id: value.oauth2_session_id.into(), + access_token: value.access_token, + created_at: value.created_at, + expires_at: value.expires_at, + } + } +} + +#[async_trait] +impl<'c> OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'c> { + type Error = DatabaseError; + + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2AccessTokenLookup, + r#" + SELECT oauth2_access_token_id + , access_token + , created_at + , expires_at + , revoked_at + , oauth2_session_id + + FROM oauth2_access_tokens + + WHERE oauth2_access_token_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.oauth2_access_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2AccessTokenLookup, + r#" + SELECT oauth2_access_token_id + , access_token + , created_at + , expires_at + , revoked_at + , oauth2_session_id + + FROM oauth2_access_tokens + + WHERE access_token = $1 + "#, + access_token, + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.oauth2_access_token.add", + skip_all, + fields( + db.statement, + %session.id, + user_session.id = %session.user_session_id, + client.id = %session.client_id, + access_token.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + session: &Session, + access_token: String, + expires_after: Duration, + ) -> Result { + let created_at = clock.now(); + let expires_at = created_at + expires_after; + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + + tracing::Span::current().record("access_token.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO oauth2_access_tokens + (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at) + VALUES + ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(session.id), + &access_token, + created_at, + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(AccessToken { + id, + state: AccessTokenState::default(), + access_token, + session_id: session.id, + created_at, + expires_at, + }) + } + + async fn revoke( + &mut self, + clock: &dyn Clock, + access_token: AccessToken, + ) -> Result { + let revoked_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_access_tokens + SET revoked_at = $2 + WHERE oauth2_access_token_id = $1 + "#, + Uuid::from(access_token.id), + revoked_at, + ) + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + access_token + .revoke(revoked_at) + .map_err(DatabaseError::to_invalid_operation) + } + + async fn cleanup_expired(&mut self, clock: &dyn Clock) -> Result { + // Cleanup token which expired more than 15 minutes ago + let threshold = clock.now() - Duration::minutes(15); + let res = sqlx::query!( + r#" + DELETE FROM oauth2_access_tokens + WHERE expires_at < $1 + "#, + threshold, + ) + .execute(&mut *self.conn) + .await?; + + Ok(res.rows_affected().try_into().unwrap_or(usize::MAX)) + } +} diff --git a/crates/storage-pg/src/oauth2/authorization_grant.rs b/crates/storage-pg/src/oauth2/authorization_grant.rs new file mode 100644 index 000000000..f62edae30 --- /dev/null +++ b/crates/storage-pg/src/oauth2/authorization_grant.rs @@ -0,0 +1,514 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::num::NonZeroU32; + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{ + AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session, +}; +use mas_iana::oauth::PkceCodeChallengeMethod; +use mas_storage::{oauth2::OAuth2AuthorizationGrantRepository, Clock}; +use oauth2_types::{requests::ResponseMode, scope::Scope}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use url::Url; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; + +/// An implementation of [`OAuth2AuthorizationGrantRepository`] for a PostgreSQL +/// connection +pub struct PgOAuth2AuthorizationGrantRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2AuthorizationGrantRepository<'c> { + /// Create a new [`PgOAuth2AuthorizationGrantRepository`] from an active + /// PostgreSQL connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[allow(clippy::struct_excessive_bools)] +struct GrantLookup { + oauth2_authorization_grant_id: Uuid, + created_at: DateTime, + cancelled_at: Option>, + fulfilled_at: Option>, + exchanged_at: Option>, + scope: String, + state: Option, + nonce: Option, + redirect_uri: String, + response_mode: String, + max_age: Option, + response_type_code: bool, + response_type_id_token: bool, + authorization_code: Option, + code_challenge: Option, + code_challenge_method: Option, + requires_consent: bool, + oauth2_client_id: Uuid, + oauth2_session_id: Option, +} + +impl TryFrom for AuthorizationGrant { + type Error = DatabaseInconsistencyError; + + #[allow(clippy::too_many_lines)] + fn try_from(value: GrantLookup) -> Result { + let id = value.oauth2_authorization_grant_id.into(); + let scope: Scope = value.scope.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("scope") + .row(id) + .source(e) + })?; + + let stage = match ( + value.fulfilled_at, + value.exchanged_at, + value.cancelled_at, + value.oauth2_session_id, + ) { + (None, None, None, None) => AuthorizationGrantStage::Pending, + (Some(fulfilled_at), None, None, Some(session_id)) => { + AuthorizationGrantStage::Fulfilled { + session_id: session_id.into(), + fulfilled_at, + } + } + (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => { + AuthorizationGrantStage::Exchanged { + session_id: session_id.into(), + fulfilled_at, + exchanged_at, + } + } + (None, None, Some(cancelled_at), None) => { + AuthorizationGrantStage::Cancelled { cancelled_at } + } + _ => { + return Err( + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("stage") + .row(id), + ); + } + }; + + let pkce = match (value.code_challenge, value.code_challenge_method) { + (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => { + Some(Pkce { + challenge_method: PkceCodeChallengeMethod::Plain, + challenge, + }) + } + (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce { + challenge_method: PkceCodeChallengeMethod::S256, + challenge, + }), + (None, None) => None, + _ => { + return Err( + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("code_challenge_method") + .row(id), + ); + } + }; + + let code: Option = + match (value.response_type_code, value.authorization_code, pkce) { + (false, None, None) => None, + (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }), + _ => { + return Err( + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("authorization_code") + .row(id), + ); + } + }; + + let redirect_uri = value.redirect_uri.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("redirect_uri") + .row(id) + .source(e) + })?; + + let response_mode = value.response_mode.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("response_mode") + .row(id) + .source(e) + })?; + + let max_age = value + .max_age + .map(u32::try_from) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("max_age") + .row(id) + .source(e) + })? + .map(NonZeroU32::try_from) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("max_age") + .row(id) + .source(e) + })?; + + Ok(AuthorizationGrant { + id, + stage, + client_id: value.oauth2_client_id.into(), + code, + scope, + state: value.state, + nonce: value.nonce, + max_age, + response_mode, + redirect_uri, + created_at: value.created_at, + response_type_id_token: value.response_type_id_token, + requires_consent: value.requires_consent, + }) + } +} + +#[async_trait] +impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.add", + skip_all, + fields( + db.statement, + grant.id, + grant.scope = %scope, + %client.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + client: &Client, + redirect_uri: Url, + scope: Scope, + code: Option, + state: Option, + nonce: Option, + max_age: Option, + response_mode: ResponseMode, + response_type_id_token: bool, + requires_consent: bool, + ) -> Result { + let code_challenge = code + .as_ref() + .and_then(|c| c.pkce.as_ref()) + .map(|p| &p.challenge); + let code_challenge_method = code + .as_ref() + .and_then(|c| c.pkce.as_ref()) + .map(|p| p.challenge_method.to_string()); + // TODO: this conversion is a bit ugly + let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX)); + let code_str = code.as_ref().map(|c| &c.code); + + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("grant.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO oauth2_authorization_grants ( + oauth2_authorization_grant_id, + oauth2_client_id, + redirect_uri, + scope, + state, + nonce, + max_age, + response_mode, + code_challenge, + code_challenge_method, + response_type_code, + response_type_id_token, + authorization_code, + requires_consent, + created_at + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) + "#, + Uuid::from(id), + Uuid::from(client.id), + redirect_uri.to_string(), + scope.to_string(), + state, + nonce, + max_age_i32, + response_mode.to_string(), + code_challenge, + code_challenge_method, + code.is_some(), + response_type_id_token, + code_str, + requires_consent, + created_at, + ) + .execute(&mut *self.conn) + .await?; + + Ok(AuthorizationGrant { + id, + stage: AuthorizationGrantStage::Pending, + code, + redirect_uri, + client_id: client.id, + scope, + state, + nonce, + max_age, + response_mode, + created_at, + response_type_id_token, + requires_consent, + }) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.lookup", + skip_all, + fields( + db.statement, + grant.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + GrantLookup, + r#" + SELECT oauth2_authorization_grant_id + , created_at + , cancelled_at + , fulfilled_at + , exchanged_at + , scope + , state + , redirect_uri + , response_mode + , nonce + , max_age + , oauth2_client_id + , authorization_code + , response_type_code + , response_type_id_token + , code_challenge + , code_challenge_method + , requires_consent + , oauth2_session_id + FROM + oauth2_authorization_grants + + WHERE oauth2_authorization_grant_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.find_by_code", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_code( + &mut self, + code: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + GrantLookup, + r#" + SELECT oauth2_authorization_grant_id + , created_at + , cancelled_at + , fulfilled_at + , exchanged_at + , scope + , state + , redirect_uri + , response_mode + , nonce + , max_age + , oauth2_client_id + , authorization_code + , response_type_code + , response_type_id_token + , code_challenge + , code_challenge_method + , requires_consent + , oauth2_session_id + FROM + oauth2_authorization_grants + + WHERE authorization_code = $1 + "#, + code, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.fulfill", + skip_all, + fields( + db.statement, + %grant.id, + client.id = %grant.client_id, + %session.id, + user_session.id = %session.user_session_id, + ), + err, + )] + async fn fulfill( + &mut self, + clock: &dyn Clock, + session: &Session, + grant: AuthorizationGrant, + ) -> Result { + let fulfilled_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_authorization_grants + SET fulfilled_at = $2 + , oauth2_session_id = $3 + WHERE oauth2_authorization_grant_id = $1 + "#, + Uuid::from(grant.id), + fulfilled_at, + Uuid::from(session.id), + ) + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + // XXX: check affected rows & new methods + let grant = grant + .fulfill(fulfilled_at, session) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(grant) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.exchange", + skip_all, + fields( + db.statement, + %grant.id, + client.id = %grant.client_id, + ), + err, + )] + async fn exchange( + &mut self, + clock: &dyn Clock, + grant: AuthorizationGrant, + ) -> Result { + let exchanged_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_authorization_grants + SET exchanged_at = $2 + WHERE oauth2_authorization_grant_id = $1 + "#, + Uuid::from(grant.id), + exchanged_at, + ) + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + let grant = grant + .exchange(exchanged_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(grant) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.give_consent", + skip_all, + fields( + db.statement, + %grant.id, + client.id = %grant.client_id, + ), + err, + )] + async fn give_consent( + &mut self, + mut grant: AuthorizationGrant, + ) -> Result { + sqlx::query!( + r#" + UPDATE oauth2_authorization_grants AS og + SET + requires_consent = 'f' + WHERE + og.oauth2_authorization_grant_id = $1 + "#, + Uuid::from(grant.id), + ) + .execute(&mut *self.conn) + .await?; + + grant.requires_consent = false; + + Ok(grant) + } +} diff --git a/crates/storage-pg/src/oauth2/client.rs b/crates/storage-pg/src/oauth2/client.rs new file mode 100644 index 000000000..cc2ed8b86 --- /dev/null +++ b/crates/storage-pg/src/oauth2/client.rs @@ -0,0 +1,748 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + collections::{BTreeMap, BTreeSet}, + str::FromStr, + string::ToString, +}; + +use async_trait::async_trait; +use mas_data_model::{Client, JwksOrJwksUri, User}; +use mas_iana::{ + jose::JsonWebSignatureAlg, + oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod}, +}; +use mas_jose::jwk::PublicJsonWebKeySet; +use mas_storage::{oauth2::OAuth2ClientRepository, Clock}; +use oauth2_types::{ + requests::GrantType, + scope::{Scope, ScopeToken}, +}; +use rand::RngCore; +use sqlx::PgConnection; +use tracing::{info_span, Instrument}; +use ulid::Ulid; +use url::Url; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; + +/// An implementation of [`OAuth2ClientRepository`] for a PostgreSQL connection +pub struct PgOAuth2ClientRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2ClientRepository<'c> { + /// Create a new [`PgOAuth2ClientRepository`] from an active PostgreSQL + /// connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +// XXX: response_types & contacts +#[derive(Debug)] +struct OAuth2ClientLookup { + oauth2_client_id: Uuid, + encrypted_client_secret: Option, + redirect_uris: Vec, + // response_types: Vec, + grant_type_authorization_code: bool, + grant_type_refresh_token: bool, + // contacts: Vec, + client_name: Option, + logo_uri: Option, + client_uri: Option, + policy_uri: Option, + tos_uri: Option, + jwks_uri: Option, + jwks: Option, + id_token_signed_response_alg: Option, + userinfo_signed_response_alg: Option, + token_endpoint_auth_method: Option, + token_endpoint_auth_signing_alg: Option, + initiate_login_uri: Option, +} + +impl TryInto for OAuth2ClientLookup { + type Error = DatabaseInconsistencyError; + + #[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing + fn try_into(self) -> Result { + let id = Ulid::from(self.oauth2_client_id); + + let redirect_uris: Result, _> = + self.redirect_uris.iter().map(|s| s.parse()).collect(); + let redirect_uris = redirect_uris.map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("redirect_uris") + .row(id) + .source(e) + })?; + + let response_types = vec![ + OAuthAuthorizationEndpointResponseType::Code, + OAuthAuthorizationEndpointResponseType::IdToken, + OAuthAuthorizationEndpointResponseType::None, + ]; + /* XXX + let response_types: Result, _> = + self.response_types.iter().map(|s| s.parse()).collect(); + let response_types = response_types.map_err(|source| ClientFetchError::ParseField { + field: "response_types", + source, + })?; + */ + + let mut grant_types = Vec::new(); + if self.grant_type_authorization_code { + grant_types.push(GrantType::AuthorizationCode); + } + if self.grant_type_refresh_token { + grant_types.push(GrantType::RefreshToken); + } + + let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("logo_uri") + .row(id) + .source(e) + })?; + + let client_uri = self + .client_uri + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("client_uri") + .row(id) + .source(e) + })?; + + let policy_uri = self + .policy_uri + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("policy_uri") + .row(id) + .source(e) + })?; + + let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("tos_uri") + .row(id) + .source(e) + })?; + + let id_token_signed_response_alg = self + .id_token_signed_response_alg + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("id_token_signed_response_alg") + .row(id) + .source(e) + })?; + + let userinfo_signed_response_alg = self + .userinfo_signed_response_alg + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("userinfo_signed_response_alg") + .row(id) + .source(e) + })?; + + let token_endpoint_auth_method = self + .token_endpoint_auth_method + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("token_endpoint_auth_method") + .row(id) + .source(e) + })?; + + let token_endpoint_auth_signing_alg = self + .token_endpoint_auth_signing_alg + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("token_endpoint_auth_signing_alg") + .row(id) + .source(e) + })?; + + let initiate_login_uri = self + .initiate_login_uri + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("initiate_login_uri") + .row(id) + .source(e) + })?; + + let jwks = match (self.jwks, self.jwks_uri) { + (None, None) => None, + (Some(jwks), None) => { + let jwks = serde_json::from_value(jwks).map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("jwks") + .row(id) + .source(e) + })?; + Some(JwksOrJwksUri::Jwks(jwks)) + } + (None, Some(jwks_uri)) => { + let jwks_uri = jwks_uri.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("jwks_uri") + .row(id) + .source(e) + })?; + + Some(JwksOrJwksUri::JwksUri(jwks_uri)) + } + _ => { + return Err(DatabaseInconsistencyError::on("oauth2_clients") + .column("jwks(_uri)") + .row(id)) + } + }; + + Ok(Client { + id, + client_id: id.to_string(), + encrypted_client_secret: self.encrypted_client_secret, + redirect_uris, + response_types, + grant_types, + // contacts: self.contacts, + contacts: vec![], + client_name: self.client_name, + logo_uri, + client_uri, + policy_uri, + tos_uri, + jwks, + id_token_signed_response_alg, + userinfo_signed_response_alg, + token_endpoint_auth_method, + token_endpoint_auth_signing_alg, + initiate_login_uri, + }) + } +} + +#[async_trait] +impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.oauth2_client.lookup", + skip_all, + fields( + db.statement, + oauth2_client.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2ClientLookup, + r#" + SELECT oauth2_client_id + , encrypted_client_secret + , ARRAY( + SELECT redirect_uri + FROM oauth2_client_redirect_uris r + WHERE r.oauth2_client_id = c.oauth2_client_id + ) AS "redirect_uris!" + , grant_type_authorization_code + , grant_type_refresh_token + , client_name + , logo_uri + , client_uri + , policy_uri + , tos_uri + , jwks_uri + , jwks + , id_token_signed_response_alg + , userinfo_signed_response_alg + , token_endpoint_auth_method + , token_endpoint_auth_signing_alg + , initiate_login_uri + FROM oauth2_clients c + + WHERE oauth2_client_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.oauth2_client.load_batch", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn load_batch( + &mut self, + ids: BTreeSet, + ) -> Result, Self::Error> { + let ids: Vec = ids.into_iter().map(Uuid::from).collect(); + let res = sqlx::query_as!( + OAuth2ClientLookup, + r#" + SELECT oauth2_client_id + , encrypted_client_secret + , ARRAY( + SELECT redirect_uri + FROM oauth2_client_redirect_uris r + WHERE r.oauth2_client_id = c.oauth2_client_id + ) AS "redirect_uris!" + , grant_type_authorization_code + , grant_type_refresh_token + , client_name + , logo_uri + , client_uri + , policy_uri + , tos_uri + , jwks_uri + , jwks + , id_token_signed_response_alg + , userinfo_signed_response_alg + , token_endpoint_auth_method + , token_endpoint_auth_signing_alg + , initiate_login_uri + FROM oauth2_clients c + + WHERE oauth2_client_id = ANY($1::uuid[]) + "#, + &ids, + ) + .traced() + .fetch_all(&mut *self.conn) + .await?; + + res.into_iter() + .map(|r| { + r.try_into() + .map(|c: Client| (c.id, c)) + .map_err(DatabaseError::from) + }) + .collect() + } + + #[tracing::instrument( + name = "db.oauth2_client.add", + skip_all, + fields( + db.statement, + client.id, + client.name = client_name + ), + err, + )] + #[allow(clippy::too_many_lines)] + async fn add( + &mut self, + mut rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + redirect_uris: Vec, + encrypted_client_secret: Option, + grant_types: Vec, + contacts: Vec, + client_name: Option, + logo_uri: Option, + client_uri: Option, + policy_uri: Option, + tos_uri: Option, + jwks_uri: Option, + jwks: Option, + id_token_signed_response_alg: Option, + userinfo_signed_response_alg: Option, + token_endpoint_auth_method: Option, + token_endpoint_auth_signing_alg: Option, + initiate_login_uri: Option, + ) -> Result { + let now = clock.now(); + let id = Ulid::from_datetime_with_source(now.into(), rng); + tracing::Span::current().record("client.id", tracing::field::display(id)); + + let jwks_json = jwks + .as_ref() + .map(serde_json::to_value) + .transpose() + .map_err(DatabaseError::to_invalid_operation)?; + + sqlx::query!( + r#" + INSERT INTO oauth2_clients + ( oauth2_client_id + , encrypted_client_secret + , grant_type_authorization_code + , grant_type_refresh_token + , client_name + , logo_uri + , client_uri + , policy_uri + , tos_uri + , jwks_uri + , jwks + , id_token_signed_response_alg + , userinfo_signed_response_alg + , token_endpoint_auth_method + , token_endpoint_auth_signing_alg + , initiate_login_uri + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + "#, + Uuid::from(id), + encrypted_client_secret, + grant_types.contains(&GrantType::AuthorizationCode), + grant_types.contains(&GrantType::RefreshToken), + client_name, + logo_uri.as_ref().map(Url::as_str), + client_uri.as_ref().map(Url::as_str), + policy_uri.as_ref().map(Url::as_str), + tos_uri.as_ref().map(Url::as_str), + jwks_uri.as_ref().map(Url::as_str), + jwks_json, + id_token_signed_response_alg + .as_ref() + .map(ToString::to_string), + userinfo_signed_response_alg + .as_ref() + .map(ToString::to_string), + token_endpoint_auth_method.as_ref().map(ToString::to_string), + token_endpoint_auth_signing_alg + .as_ref() + .map(ToString::to_string), + initiate_login_uri.as_ref().map(Url::as_str), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + { + let span = info_span!( + "db.oauth2_client.add.redirect_uris", + db.statement = tracing::field::Empty, + client.id = %id, + ); + + let (uri_ids, redirect_uris): (Vec, Vec) = redirect_uris + .iter() + .map(|uri| { + ( + Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), + uri.as_str().to_owned(), + ) + }) + .unzip(); + + sqlx::query!( + r#" + INSERT INTO oauth2_client_redirect_uris + (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) + SELECT id, $2, redirect_uri + FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) + "#, + &uri_ids, + Uuid::from(id), + &redirect_uris, + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + let jwks = match (jwks, jwks_uri) { + (None, None) => None, + (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)), + (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)), + _ => return Err(DatabaseError::invalid_operation()), + }; + + Ok(Client { + id, + client_id: id.to_string(), + encrypted_client_secret, + redirect_uris, + response_types: vec![ + OAuthAuthorizationEndpointResponseType::Code, + OAuthAuthorizationEndpointResponseType::IdToken, + OAuthAuthorizationEndpointResponseType::None, + ], + grant_types, + contacts, + client_name, + logo_uri, + client_uri, + policy_uri, + tos_uri, + jwks, + id_token_signed_response_alg, + userinfo_signed_response_alg, + token_endpoint_auth_method, + token_endpoint_auth_signing_alg, + initiate_login_uri, + }) + } + + #[tracing::instrument( + name = "db.oauth2_client.add_from_config", + skip_all, + fields( + db.statement, + client.id = %client_id, + ), + err, + )] + async fn add_from_config( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + client_id: Ulid, + client_auth_method: OAuthClientAuthenticationMethod, + encrypted_client_secret: Option, + jwks: Option, + jwks_uri: Option, + redirect_uris: Vec, + ) -> Result { + let jwks_json = jwks + .as_ref() + .map(serde_json::to_value) + .transpose() + .map_err(DatabaseError::to_invalid_operation)?; + + let client_auth_method = client_auth_method.to_string(); + + sqlx::query!( + r#" + INSERT INTO oauth2_clients + ( oauth2_client_id + , encrypted_client_secret + , grant_type_authorization_code + , grant_type_refresh_token + , token_endpoint_auth_method + , jwks + , jwks_uri + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (oauth2_client_id) + DO + UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret + , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code + , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token + , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method + , jwks = EXCLUDED.jwks + , jwks_uri = EXCLUDED.jwks_uri + "#, + Uuid::from(client_id), + encrypted_client_secret, + true, + true, + client_auth_method, + jwks_json, + jwks_uri.as_ref().map(Url::as_str), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + { + let span = info_span!( + "db.oauth2_client.add_from_config.redirect_uris", + client.id = %client_id, + db.statement = tracing::field::Empty, + ); + + let now = clock.now(); + let (ids, redirect_uris): (Vec, Vec) = redirect_uris + .iter() + .map(|uri| { + ( + Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut *rng)), + uri.as_str().to_owned(), + ) + }) + .unzip(); + + sqlx::query!( + r#" + INSERT INTO oauth2_client_redirect_uris + (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) + SELECT id, $2, redirect_uri + FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) + "#, + &ids, + Uuid::from(client_id), + &redirect_uris, + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + let jwks = match (jwks, jwks_uri) { + (None, None) => None, + (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)), + (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)), + _ => return Err(DatabaseError::invalid_operation()), + }; + + Ok(Client { + id: client_id, + client_id: client_id.to_string(), + encrypted_client_secret, + redirect_uris, + response_types: vec![ + OAuthAuthorizationEndpointResponseType::Code, + OAuthAuthorizationEndpointResponseType::IdToken, + OAuthAuthorizationEndpointResponseType::None, + ], + grant_types: Vec::new(), + contacts: Vec::new(), + client_name: None, + logo_uri: None, + client_uri: None, + policy_uri: None, + tos_uri: None, + jwks, + id_token_signed_response_alg: None, + userinfo_signed_response_alg: None, + token_endpoint_auth_method: None, + token_endpoint_auth_signing_alg: None, + initiate_login_uri: None, + }) + } + + #[tracing::instrument( + name = "db.oauth2_client.get_consent_for_user", + skip_all, + fields( + db.statement, + %user.id, + %client.id, + ), + err, + )] + async fn get_consent_for_user( + &mut self, + client: &Client, + user: &User, + ) -> Result { + let scope_tokens: Vec = sqlx::query_scalar!( + r#" + SELECT scope_token + FROM oauth2_consents + WHERE user_id = $1 AND oauth2_client_id = $2 + "#, + Uuid::from(user.id), + Uuid::from(client.id), + ) + .fetch_all(&mut *self.conn) + .await?; + + let scope: Result = scope_tokens + .into_iter() + .map(|s| ScopeToken::from_str(&s)) + .collect(); + + let scope = scope.map_err(|e| { + DatabaseInconsistencyError::on("oauth2_consents") + .column("scope_token") + .source(e) + })?; + + Ok(scope) + } + + #[tracing::instrument( + skip_all, + fields( + db.statement, + %user.id, + %client.id, + %scope, + ), + err, + )] + async fn give_consent_for_user( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + client: &Client, + user: &User, + scope: &Scope, + ) -> Result<(), Self::Error> { + let now = clock.now(); + let (tokens, ids): (Vec, Vec) = scope + .iter() + .map(|token| { + ( + token.to_string(), + Uuid::from(Ulid::from_datetime_with_source(now.into(), rng)), + ) + }) + .unzip(); + + sqlx::query!( + r#" + INSERT INTO oauth2_consents + (oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at) + SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token) + ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5 + "#, + &ids, + Uuid::from(user.id), + Uuid::from(client.id), + &tokens, + now, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(()) + } +} diff --git a/crates/storage-pg/src/oauth2/mod.rs b/crates/storage-pg/src/oauth2/mod.rs new file mode 100644 index 000000000..120fca6cf --- /dev/null +++ b/crates/storage-pg/src/oauth2/mod.rs @@ -0,0 +1,371 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! A module containing the PostgreSQL implementations of the OAuth2-related +//! repositories + +mod access_token; +mod authorization_grant; +mod client; +mod refresh_token; +mod session; + +pub use self::{ + access_token::PgOAuth2AccessTokenRepository, + authorization_grant::PgOAuth2AuthorizationGrantRepository, client::PgOAuth2ClientRepository, + refresh_token::PgOAuth2RefreshTokenRepository, session::PgOAuth2SessionRepository, +}; + +#[cfg(test)] +mod tests { + use chrono::Duration; + use mas_data_model::AuthorizationCode; + use mas_storage::{clock::MockClock, Clock, Pagination, Repository}; + use oauth2_types::{ + requests::{GrantType, ResponseMode}, + scope::{Scope, OPENID}, + }; + use rand::SeedableRng; + use rand_chacha::ChaChaRng; + use sqlx::PgPool; + use ulid::Ulid; + + use crate::PgRepository; + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_repositories(pool: PgPool) { + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); + + // Lookup a non-existing client + let client = repo.oauth2_client().lookup(Ulid::nil()).await.unwrap(); + assert_eq!(client, None); + + // Find a non-existing client by client id + let client = repo + .oauth2_client() + .find_by_client_id("some-client-id") + .await + .unwrap(); + assert_eq!(client, None); + + // Create a client + let client = repo + .oauth2_client() + .add( + &mut rng, + &clock, + vec!["https://example.com/redirect".parse().unwrap()], + None, + vec![GrantType::AuthorizationCode], + Vec::new(), // TODO: contacts are not yet saved + // vec!["contact@example.com".to_owned()], + Some("Test client".to_owned()), + Some("https://example.com/logo.png".parse().unwrap()), + Some("https://example.com/".parse().unwrap()), + Some("https://example.com/policy".parse().unwrap()), + Some("https://example.com/tos".parse().unwrap()), + Some("https://example.com/jwks.json".parse().unwrap()), + None, + None, + None, + None, + None, + Some("https://example.com/login".parse().unwrap()), + ) + .await + .unwrap(); + + // Lookup the same client by id + let client_lookup = repo + .oauth2_client() + .lookup(client.id) + .await + .unwrap() + .expect("client not found"); + assert_eq!(client, client_lookup); + + // Find the same client by client id + let client_lookup = repo + .oauth2_client() + .find_by_client_id(&client.client_id) + .await + .unwrap() + .expect("client not found"); + assert_eq!(client, client_lookup); + + // Lookup a non-existing grant + let grant = repo + .oauth2_authorization_grant() + .lookup(Ulid::nil()) + .await + .unwrap(); + assert_eq!(grant, None); + + // Find a non-existing grant by code + let grant = repo + .oauth2_authorization_grant() + .find_by_code("code") + .await + .unwrap(); + assert_eq!(grant, None); + + // Create an authorization grant + let grant = repo + .oauth2_authorization_grant() + .add( + &mut rng, + &clock, + &client, + "https://example.com/redirect".parse().unwrap(), + Scope::from_iter([OPENID]), + Some(AuthorizationCode { + code: "code".to_owned(), + pkce: None, + }), + Some("state".to_owned()), + Some("nonce".to_owned()), + None, + ResponseMode::Query, + true, + false, + ) + .await + .unwrap(); + assert!(grant.is_pending()); + + // Lookup the same grant by id + let grant_lookup = repo + .oauth2_authorization_grant() + .lookup(grant.id) + .await + .unwrap() + .expect("grant not found"); + assert_eq!(grant, grant_lookup); + + // Find the same grant by code + let grant_lookup = repo + .oauth2_authorization_grant() + .find_by_code("code") + .await + .unwrap() + .expect("grant not found"); + assert_eq!(grant, grant_lookup); + + // Create a user and a start a user session + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + let user_session = repo + .browser_session() + .add(&mut rng, &clock, &user) + .await + .unwrap(); + + // Lookup the consent the user gave to the client + let consent = repo + .oauth2_client() + .get_consent_for_user(&client, &user) + .await + .unwrap(); + assert!(consent.is_empty()); + + // Give consent to the client + let scope = Scope::from_iter([OPENID]); + repo.oauth2_client() + .give_consent_for_user(&mut rng, &clock, &client, &user, &scope) + .await + .unwrap(); + + // Lookup the consent the user gave to the client + let consent = repo + .oauth2_client() + .get_consent_for_user(&client, &user) + .await + .unwrap(); + assert_eq!(scope, consent); + + // Lookup a non-existing session + let session = repo.oauth2_session().lookup(Ulid::nil()).await.unwrap(); + assert_eq!(session, None); + + // Create a session out of the grant + let session = repo + .oauth2_session() + .create_from_grant(&mut rng, &clock, &grant, &user_session) + .await + .unwrap(); + + // Mark the grant as fulfilled + let grant = repo + .oauth2_authorization_grant() + .fulfill(&clock, &session, grant) + .await + .unwrap(); + assert!(grant.is_fulfilled()); + + // Lookup the same session by id + let session_lookup = repo + .oauth2_session() + .lookup(session.id) + .await + .unwrap() + .expect("session not found"); + assert_eq!(session, session_lookup); + + // Mark the grant as exchanged + let grant = repo + .oauth2_authorization_grant() + .exchange(&clock, grant) + .await + .unwrap(); + assert!(grant.is_exchanged()); + + // Lookup a non-existing token + let token = repo + .oauth2_access_token() + .lookup(Ulid::nil()) + .await + .unwrap(); + assert_eq!(token, None); + + // Find a non-existing token + let token = repo + .oauth2_access_token() + .find_by_token("aabbcc") + .await + .unwrap(); + assert_eq!(token, None); + + // Create an access token + let access_token = repo + .oauth2_access_token() + .add( + &mut rng, + &clock, + &session, + "aabbcc".to_owned(), + Duration::minutes(5), + ) + .await + .unwrap(); + + // Lookup the same token by id + let access_token_lookup = repo + .oauth2_access_token() + .lookup(access_token.id) + .await + .unwrap() + .expect("token not found"); + assert_eq!(access_token, access_token_lookup); + + // Find the same token by token + let access_token_lookup = repo + .oauth2_access_token() + .find_by_token("aabbcc") + .await + .unwrap() + .expect("token not found"); + assert_eq!(access_token, access_token_lookup); + + // Lookup a non-existing refresh token + let refresh_token = repo + .oauth2_refresh_token() + .lookup(Ulid::nil()) + .await + .unwrap(); + assert_eq!(refresh_token, None); + + // Find a non-existing refresh token + let refresh_token = repo + .oauth2_refresh_token() + .find_by_token("aabbcc") + .await + .unwrap(); + assert_eq!(refresh_token, None); + + // Create a refresh token + let refresh_token = repo + .oauth2_refresh_token() + .add( + &mut rng, + &clock, + &session, + &access_token, + "aabbcc".to_owned(), + ) + .await + .unwrap(); + + // Lookup the same refresh token by id + let refresh_token_lookup = repo + .oauth2_refresh_token() + .lookup(refresh_token.id) + .await + .unwrap() + .expect("refresh token not found"); + assert_eq!(refresh_token, refresh_token_lookup); + + // Find the same refresh token by token + let refresh_token_lookup = repo + .oauth2_refresh_token() + .find_by_token("aabbcc") + .await + .unwrap() + .expect("refresh token not found"); + assert_eq!(refresh_token, refresh_token_lookup); + + assert!(access_token.is_valid(clock.now())); + clock.advance(Duration::minutes(6)); + assert!(!access_token.is_valid(clock.now())); + + // XXX: we might want to create a new access token + clock.advance(Duration::minutes(-6)); // Go back in time + assert!(access_token.is_valid(clock.now())); + + // Mark the access token as revoked + let access_token = repo + .oauth2_access_token() + .revoke(&clock, access_token) + .await + .unwrap(); + assert!(!access_token.is_valid(clock.now())); + + // Mark the refresh token as consumed + assert!(refresh_token.is_valid()); + let refresh_token = repo + .oauth2_refresh_token() + .consume(&clock, refresh_token) + .await + .unwrap(); + assert!(!refresh_token.is_valid()); + + // Mark the session as finished + assert!(session.is_valid()); + let session = repo.oauth2_session().finish(&clock, session).await.unwrap(); + assert!(!session.is_valid()); + + // The session should appear in the paginated list of sessions for the user + let sessions = repo + .oauth2_session() + .list_paginated(&user, Pagination::first(10)) + .await + .unwrap(); + assert!(!sessions.has_next_page); + assert_eq!(sessions.edges, vec![session]); + } +} diff --git a/crates/storage-pg/src/oauth2/refresh_token.rs b/crates/storage-pg/src/oauth2/refresh_token.rs new file mode 100644 index 000000000..ae723f7c3 --- /dev/null +++ b/crates/storage-pg/src/oauth2/refresh_token.rs @@ -0,0 +1,228 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session}; +use mas_storage::{oauth2::OAuth2RefreshTokenRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +/// An implementation of [`OAuth2RefreshTokenRepository`] for a PostgreSQL +/// connection +pub struct PgOAuth2RefreshTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2RefreshTokenRepository<'c> { + /// Create a new [`PgOAuth2RefreshTokenRepository`] from an active + /// PostgreSQL connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct OAuth2RefreshTokenLookup { + oauth2_refresh_token_id: Uuid, + refresh_token: String, + created_at: DateTime, + consumed_at: Option>, + oauth2_access_token_id: Option, + oauth2_session_id: Uuid, +} + +impl From for RefreshToken { + fn from(value: OAuth2RefreshTokenLookup) -> Self { + let state = match value.consumed_at { + None => RefreshTokenState::Valid, + Some(consumed_at) => RefreshTokenState::Consumed { consumed_at }, + }; + + RefreshToken { + id: value.oauth2_refresh_token_id.into(), + state, + session_id: value.oauth2_session_id.into(), + refresh_token: value.refresh_token, + created_at: value.created_at, + access_token_id: value.oauth2_access_token_id.map(Ulid::from), + } + } +} + +#[async_trait] +impl<'c> OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.oauth2_refresh_token.lookup", + skip_all, + fields( + db.statement, + refresh_token.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2RefreshTokenLookup, + r#" + SELECT oauth2_refresh_token_id + , refresh_token + , created_at + , consumed_at + , oauth2_access_token_id + , oauth2_session_id + FROM oauth2_refresh_tokens + + WHERE oauth2_refresh_token_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.oauth2_refresh_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2RefreshTokenLookup, + r#" + SELECT oauth2_refresh_token_id + , refresh_token + , created_at + , consumed_at + , oauth2_access_token_id + , oauth2_session_id + FROM oauth2_refresh_tokens + + WHERE refresh_token = $1 + "#, + refresh_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.oauth2_refresh_token.add", + skip_all, + fields( + db.statement, + %session.id, + user_session.id = %session.user_session_id, + client.id = %session.client_id, + refresh_token.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + session: &Session, + access_token: &AccessToken, + refresh_token: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("refresh_token.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO oauth2_refresh_tokens + (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id, + refresh_token, created_at) + VALUES + ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(session.id), + Uuid::from(access_token.id), + refresh_token, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(RefreshToken { + id, + state: RefreshTokenState::default(), + session_id: session.id, + refresh_token, + access_token_id: Some(access_token.id), + created_at, + }) + } + + #[tracing::instrument( + name = "db.oauth2_refresh_token.consume", + skip_all, + fields( + db.statement, + %refresh_token.id, + session.id = %refresh_token.session_id, + ), + err, + )] + async fn consume( + &mut self, + clock: &dyn Clock, + refresh_token: RefreshToken, + ) -> Result { + let consumed_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_refresh_tokens + SET consumed_at = $2 + WHERE oauth2_refresh_token_id = $1 + "#, + Uuid::from(refresh_token.id), + consumed_at, + ) + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + refresh_token + .consume(consumed_at) + .map_err(DatabaseError::to_invalid_operation) + } +} diff --git a/crates/storage-pg/src/oauth2/session.rs b/crates/storage-pg/src/oauth2/session.rs new file mode 100644 index 000000000..e6168310f --- /dev/null +++ b/crates/storage-pg/src/oauth2/session.rs @@ -0,0 +1,256 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{AuthorizationGrant, BrowserSession, Session, SessionState, User}; +use mas_storage::{oauth2::OAuth2SessionRepository, Clock, Page, Pagination}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, + LookupResultExt, +}; + +/// An implementation of [`OAuth2SessionRepository`] for a PostgreSQL connection +pub struct PgOAuth2SessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2SessionRepository<'c> { + /// Create a new [`PgOAuth2SessionRepository`] from an active PostgreSQL + /// connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct OAuthSessionLookup { + oauth2_session_id: Uuid, + user_session_id: Uuid, + oauth2_client_id: Uuid, + scope: String, + #[allow(dead_code)] + created_at: DateTime, + finished_at: Option>, +} + +impl TryFrom for Session { + type Error = DatabaseInconsistencyError; + + fn try_from(value: OAuthSessionLookup) -> Result { + let id = Ulid::from(value.oauth2_session_id); + let scope = value.scope.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_sessions") + .column("scope") + .row(id) + .source(e) + })?; + + let state = match value.finished_at { + None => SessionState::Valid, + Some(finished_at) => SessionState::Finished { finished_at }, + }; + + Ok(Session { + id, + state, + created_at: value.created_at, + client_id: value.oauth2_client_id.into(), + user_session_id: value.user_session_id.into(), + scope, + }) + } +} + +#[async_trait] +impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.oauth2_session.lookup", + skip_all, + fields( + db.statement, + session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuthSessionLookup, + r#" + SELECT oauth2_session_id + , user_session_id + , oauth2_client_id + , scope + , created_at + , finished_at + FROM oauth2_sessions + + WHERE oauth2_session_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(session) = res else { return Ok(None) }; + + Ok(Some(session.try_into()?)) + } + + #[tracing::instrument( + name = "db.oauth2_session.create_from_grant", + skip_all, + fields( + db.statement, + %user_session.id, + user.id = %user_session.user.id, + %grant.id, + client.id = %grant.client_id, + session.id, + session.scope = %grant.scope, + ), + err, + )] + async fn create_from_grant( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + grant: &AuthorizationGrant, + user_session: &BrowserSession, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("session.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO oauth2_sessions + ( oauth2_session_id + , user_session_id + , oauth2_client_id + , scope + , created_at + ) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(user_session.id), + Uuid::from(grant.client_id), + grant.scope.to_string(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(Session { + id, + state: SessionState::Valid, + created_at, + user_session_id: user_session.id, + client_id: grant.client_id, + scope: grant.scope.clone(), + }) + } + + #[tracing::instrument( + name = "db.oauth2_session.finish", + skip_all, + fields( + db.statement, + %session.id, + %session.scope, + user_session.id = %session.user_session_id, + client.id = %session.client_id, + ), + err, + )] + async fn finish( + &mut self, + clock: &dyn Clock, + session: Session, + ) -> Result { + let finished_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_sessions + SET finished_at = $2 + WHERE oauth2_session_id = $1 + "#, + Uuid::from(session.id), + finished_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + session + .finish(finished_at) + .map_err(DatabaseError::to_invalid_operation) + } + + #[tracing::instrument( + name = "db.oauth2_session.list_paginated", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + ), + err, + )] + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT oauth2_session_id + , user_session_id + , oauth2_client_id + , scope + , os.created_at + , os.finished_at + FROM oauth2_sessions os + INNER JOIN user_sessions USING (user_session_id) + "#, + ); + + query + .push(" WHERE user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("oauth2_session_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination.process(edges).try_map(Session::try_from)?; + Ok(page) + } +} diff --git a/crates/storage-pg/src/pagination.rs b/crates/storage-pg/src/pagination.rs new file mode 100644 index 000000000..97e5220f9 --- /dev/null +++ b/crates/storage-pg/src/pagination.rs @@ -0,0 +1,78 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Utilities to manage paginated queries. + +use mas_storage::{pagination::PaginationDirection, Pagination}; +use sqlx::{Database, QueryBuilder}; +use uuid::Uuid; + +/// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination +/// to a query +pub trait QueryBuilderExt { + /// Add cursor-based pagination to a query, as used in paginated GraphQL + /// connections + fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self; +} + +impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB> +where + DB: Database, + Uuid: sqlx::Type + sqlx::Encode<'a, DB>, + i64: sqlx::Type + sqlx::Encode<'a, DB>, +{ + fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self { + // ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564 + // 1. Start from the greedy query: SELECT * FROM table + + // 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE` + // clause + if let Some(after) = pagination.after { + self.push(" AND ") + .push(id_field) + .push(" > ") + .push_bind(Uuid::from(after)); + } + + // 3. If the before argument is provided, add `id < parsed_cursor` to the + // `WHERE` clause + if let Some(before) = pagination.before { + self.push(" AND ") + .push(id_field) + .push(" < ") + .push_bind(Uuid::from(before)); + } + + match pagination.direction { + // 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the + // query + PaginationDirection::Forward => { + self.push(" ORDER BY ") + .push(id_field) + .push(" ASC LIMIT ") + .push_bind((pagination.count + 1) as i64); + } + // 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the + // query + PaginationDirection::Backward => { + self.push(" ORDER BY ") + .push(id_field) + .push(" DESC LIMIT ") + .push_bind((pagination.count + 1) as i64); + } + }; + + self + } +} diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs new file mode 100644 index 000000000..da81d3af4 --- /dev/null +++ b/crates/storage-pg/src/repository.rs @@ -0,0 +1,181 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use futures_util::{future::BoxFuture, FutureExt, TryFutureExt}; +use mas_storage::{ + compat::{ + CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, + CompatSsoLoginRepository, + }, + oauth2::{ + OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, + OAuth2RefreshTokenRepository, OAuth2SessionRepository, + }, + upstream_oauth2::{ + UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, + UpstreamOAuthSessionRepository, + }, + user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, + Repository, RepositoryAccess, RepositoryTransaction, +}; +use sqlx::{PgPool, Postgres, Transaction}; + +use crate::{ + compat::{ + PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository, + PgCompatSsoLoginRepository, + }, + oauth2::{ + PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository, + PgOAuth2ClientRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository, + }, + upstream_oauth2::{ + PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, + PgUpstreamOAuthSessionRepository, + }, + user::{ + PgBrowserSessionRepository, PgUserEmailRepository, PgUserPasswordRepository, + PgUserRepository, + }, + DatabaseError, +}; + +/// An implementation of the [`Repository`] trait backed by a PostgreSQL +/// transaction. +pub struct PgRepository { + txn: Transaction<'static, Postgres>, +} + +impl PgRepository { + /// Create a new [`PgRepository`] from a PostgreSQL connection pool, + /// starting a transaction. + /// + /// # Errors + /// + /// Returns a [`DatabaseError`] if the transaction could not be started. + pub async fn from_pool(pool: &PgPool) -> Result { + let txn = pool.begin().await?; + Ok(PgRepository { txn }) + } +} + +impl Repository for PgRepository {} + +impl RepositoryTransaction for PgRepository { + type Error = DatabaseError; + + fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { + self.txn.commit().map_err(DatabaseError::from).boxed() + } + + fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { + self.txn.rollback().map_err(DatabaseError::from).boxed() + } +} + +impl RepositoryAccess for PgRepository { + type Error = DatabaseError; + + fn upstream_oauth_link<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgUpstreamOAuthLinkRepository::new(&mut self.txn)) + } + + fn upstream_oauth_provider<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgUpstreamOAuthProviderRepository::new(&mut self.txn)) + } + + fn upstream_oauth_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgUpstreamOAuthSessionRepository::new(&mut self.txn)) + } + + fn user<'c>(&'c mut self) -> Box + 'c> { + Box::new(PgUserRepository::new(&mut self.txn)) + } + + fn user_email<'c>(&'c mut self) -> Box + 'c> { + Box::new(PgUserEmailRepository::new(&mut self.txn)) + } + + fn user_password<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgUserPasswordRepository::new(&mut self.txn)) + } + + fn browser_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgBrowserSessionRepository::new(&mut self.txn)) + } + + fn oauth2_client<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgOAuth2ClientRepository::new(&mut self.txn)) + } + + fn oauth2_authorization_grant<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgOAuth2AuthorizationGrantRepository::new(&mut self.txn)) + } + + fn oauth2_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgOAuth2SessionRepository::new(&mut self.txn)) + } + + fn oauth2_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgOAuth2AccessTokenRepository::new(&mut self.txn)) + } + + fn oauth2_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgOAuth2RefreshTokenRepository::new(&mut self.txn)) + } + + fn compat_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgCompatSessionRepository::new(&mut self.txn)) + } + + fn compat_sso_login<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgCompatSsoLoginRepository::new(&mut self.txn)) + } + + fn compat_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgCompatAccessTokenRepository::new(&mut self.txn)) + } + + fn compat_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgCompatRefreshTokenRepository::new(&mut self.txn)) + } +} diff --git a/crates/storage-pg/src/tracing.rs b/crates/storage-pg/src/tracing.rs new file mode 100644 index 000000000..853b5d9d9 --- /dev/null +++ b/crates/storage-pg/src/tracing.rs @@ -0,0 +1,40 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use tracing::Span; + +/// An extension trait for [`sqlx::Execute`] that records the SQL statement as +/// `db.statement` in a tracing span +pub trait ExecuteExt<'q, DB>: Sized { + /// Records the statement as `db.statement` in the current span + #[must_use] + fn traced(self) -> Self { + self.record(&Span::current()) + } + + /// Records the statement as `db.statement` in the given span + #[must_use] + fn record(self, span: &Span) -> Self; +} + +impl<'q, DB, T> ExecuteExt<'q, DB> for T +where + T: sqlx::Execute<'q, DB>, + DB: sqlx::Database, +{ + fn record(self, span: &Span) -> Self { + span.record("db.statement", self.sql()); + self + } +} diff --git a/crates/storage-pg/src/upstream_oauth2/link.rs b/crates/storage-pg/src/upstream_oauth2/link.rs new file mode 100644 index 000000000..0e14f3fd5 --- /dev/null +++ b/crates/storage-pg/src/upstream_oauth2/link.rs @@ -0,0 +1,266 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User}; +use mas_storage::{upstream_oauth2::UpstreamOAuthLinkRepository, Clock, Page, Pagination}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +/// An implementation of [`UpstreamOAuthLinkRepository`] for a PostgreSQL +/// connection +pub struct PgUpstreamOAuthLinkRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUpstreamOAuthLinkRepository<'c> { + /// Create a new [`PgUpstreamOAuthLinkRepository`] from an active PostgreSQL + /// connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct LinkLookup { + upstream_oauth_link_id: Uuid, + upstream_oauth_provider_id: Uuid, + user_id: Option, + subject: String, + created_at: DateTime, +} + +impl From for UpstreamOAuthLink { + fn from(value: LinkLookup) -> Self { + UpstreamOAuthLink { + id: Ulid::from(value.upstream_oauth_link_id), + provider_id: Ulid::from(value.upstream_oauth_provider_id), + user_id: value.user_id.map(Ulid::from), + subject: value.subject, + created_at: value.created_at, + } + } +} + +#[async_trait] +impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.upstream_oauth_link.lookup", + skip_all, + fields( + db.statement, + upstream_oauth_link.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + LinkLookup, + r#" + SELECT + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + FROM upstream_oauth_links + WHERE upstream_oauth_link_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()? + .map(Into::into); + + Ok(res) + } + + #[tracing::instrument( + name = "db.upstream_oauth_link.find_by_subject", + skip_all, + fields( + db.statement, + upstream_oauth_link.subject = subject, + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, + ), + err, + )] + async fn find_by_subject( + &mut self, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + LinkLookup, + r#" + SELECT + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + FROM upstream_oauth_links + WHERE upstream_oauth_provider_id = $1 + AND subject = $2 + "#, + Uuid::from(upstream_oauth_provider.id), + subject, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()? + .map(Into::into); + + Ok(res) + } + + #[tracing::instrument( + name = "db.upstream_oauth_link.add", + skip_all, + fields( + db.statement, + upstream_oauth_link.id, + upstream_oauth_link.subject = subject, + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO upstream_oauth_links ( + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + ) VALUES ($1, $2, NULL, $3, $4) + "#, + Uuid::from(id), + Uuid::from(upstream_oauth_provider.id), + &subject, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(UpstreamOAuthLink { + id, + provider_id: upstream_oauth_provider.id, + user_id: None, + subject, + created_at, + }) + } + + #[tracing::instrument( + name = "db.upstream_oauth_link.associate_to_user", + skip_all, + fields( + db.statement, + %upstream_oauth_link.id, + %upstream_oauth_link.subject, + %user.id, + %user.username, + ), + err, + )] + async fn associate_to_user( + &mut self, + upstream_oauth_link: &UpstreamOAuthLink, + user: &User, + ) -> Result<(), Self::Error> { + sqlx::query!( + r#" + UPDATE upstream_oauth_links + SET user_id = $1 + WHERE upstream_oauth_link_id = $2 + "#, + Uuid::from(user.id), + Uuid::from(upstream_oauth_link.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(()) + } + + #[tracing::instrument( + name = "db.upstream_oauth_link.list_paginated", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + ), + err + )] + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + FROM upstream_oauth_links + "#, + ); + + query + .push(" WHERE user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("upstream_oauth_link_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination.process(edges).map(UpstreamOAuthLink::from); + Ok(page) + } +} diff --git a/crates/storage-pg/src/upstream_oauth2/mod.rs b/crates/storage-pg/src/upstream_oauth2/mod.rs new file mode 100644 index 000000000..5bf97514f --- /dev/null +++ b/crates/storage-pg/src/upstream_oauth2/mod.rs @@ -0,0 +1,277 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! A module containing the PostgreSQL implementation of the repositories +//! related to the upstream OAuth 2.0 providers + +mod link; +mod provider; +mod session; + +pub use self::{ + link::PgUpstreamOAuthLinkRepository, provider::PgUpstreamOAuthProviderRepository, + session::PgUpstreamOAuthSessionRepository, +}; + +#[cfg(test)] +mod tests { + use chrono::Duration; + use mas_storage::{ + clock::MockClock, + upstream_oauth2::{ + UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, + UpstreamOAuthSessionRepository, + }, + user::UserRepository, + Pagination, RepositoryAccess, + }; + use oauth2_types::scope::{Scope, OPENID}; + use rand::SeedableRng; + use sqlx::PgPool; + + use crate::PgRepository; + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_repository(pool: PgPool) { + let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + // The provider list should be empty at the start + let all_providers = repo.upstream_oauth_provider().all().await.unwrap(); + assert!(all_providers.is_empty()); + + // Let's add a provider + let provider = repo + .upstream_oauth_provider() + .add( + &mut rng, + &clock, + "https://example.com/".to_owned(), + Scope::from_iter([OPENID]), + mas_iana::oauth::OAuthClientAuthenticationMethod::None, + None, + "client-id".to_owned(), + None, + ) + .await + .unwrap(); + + // Look it up in the database + let provider = repo + .upstream_oauth_provider() + .lookup(provider.id) + .await + .unwrap() + .expect("provider to be found in the database"); + assert_eq!(provider.issuer, "https://example.com/"); + assert_eq!(provider.client_id, "client-id"); + + // Start a session + let session = repo + .upstream_oauth_session() + .add( + &mut rng, + &clock, + &provider, + "some-state".to_owned(), + None, + "some-nonce".to_owned(), + ) + .await + .unwrap(); + + // Look it up in the database + let session = repo + .upstream_oauth_session() + .lookup(session.id) + .await + .unwrap() + .expect("session to be found in the database"); + assert_eq!(session.provider_id, provider.id); + assert_eq!(session.link_id(), None); + assert!(session.is_pending()); + assert!(!session.is_completed()); + assert!(!session.is_consumed()); + + // Create a link + let link = repo + .upstream_oauth_link() + .add(&mut rng, &clock, &provider, "a-subject".to_owned()) + .await + .unwrap(); + + // We can look it up by its ID + repo.upstream_oauth_link() + .lookup(link.id) + .await + .unwrap() + .expect("link to be found in database"); + + // or by its subject + let link = repo + .upstream_oauth_link() + .find_by_subject(&provider, "a-subject") + .await + .unwrap() + .expect("link to be found in database"); + assert_eq!(link.subject, "a-subject"); + assert_eq!(link.provider_id, provider.id); + + let session = repo + .upstream_oauth_session() + .complete_with_link(&clock, session, &link, None) + .await + .unwrap(); + // Reload the session + let session = repo + .upstream_oauth_session() + .lookup(session.id) + .await + .unwrap() + .expect("session to be found in the database"); + assert!(session.is_completed()); + assert!(!session.is_consumed()); + assert_eq!(session.link_id(), Some(link.id)); + + let session = repo + .upstream_oauth_session() + .consume(&clock, session) + .await + .unwrap(); + // Reload the session + let session = repo + .upstream_oauth_session() + .lookup(session.id) + .await + .unwrap() + .expect("session to be found in the database"); + assert!(session.is_consumed()); + + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + repo.upstream_oauth_link() + .associate_to_user(&link, &user) + .await + .unwrap(); + + let links = repo + .upstream_oauth_link() + .list_paginated(&user, Pagination::first(10)) + .await + .unwrap(); + assert!(!links.has_previous_page); + assert!(!links.has_next_page); + assert_eq!(links.edges.len(), 1); + assert_eq!(links.edges[0].id, link.id); + assert_eq!(links.edges[0].user_id, Some(user.id)); + } + + /// Test that the pagination works as expected in the upstream OAuth + /// provider repository + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_provider_repository_pagination(pool: PgPool) { + const ISSUER: &str = "https://example.com/"; + let scope = Scope::from_iter([OPENID]); + + let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + let mut ids = Vec::with_capacity(20); + // Create 20 providers + for idx in 0..20 { + let client_id = format!("client-{idx}"); + let provider = repo + .upstream_oauth_provider() + .add( + &mut rng, + &clock, + ISSUER.to_owned(), + scope.clone(), + mas_iana::oauth::OAuthClientAuthenticationMethod::None, + None, + client_id, + None, + ) + .await + .unwrap(); + ids.push(provider.id); + clock.advance(Duration::seconds(10)); + } + + // Lookup the first 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(Pagination::first(10)) + .await + .unwrap(); + + // It returned the first 10 items + assert!(page.has_next_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[..10]); + + // Lookup the next 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(Pagination::first(10).after(ids[9])) + .await + .unwrap(); + + // It returned the next 10 items + assert!(!page.has_next_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[10..]); + + // Lookup the last 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(Pagination::last(10)) + .await + .unwrap(); + + // It returned the last 10 items + assert!(page.has_previous_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[10..]); + + // Lookup the previous 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(Pagination::last(10).before(ids[10])) + .await + .unwrap(); + + // It returned the previous 10 items + assert!(!page.has_previous_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[..10]); + + // Lookup 10 items between two IDs + let page = repo + .upstream_oauth_provider() + .list_paginated(Pagination::first(10).after(ids[5]).before(ids[8])) + .await + .unwrap(); + + // It returned the items in between + assert!(!page.has_next_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[6..8]); + } +} diff --git a/crates/storage-pg/src/upstream_oauth2/provider.rs b/crates/storage-pg/src/upstream_oauth2/provider.rs new file mode 100644 index 000000000..d4ecbe473 --- /dev/null +++ b/crates/storage-pg/src/upstream_oauth2/provider.rs @@ -0,0 +1,277 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::UpstreamOAuthProvider; +use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; +use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Clock, Page, Pagination}; +use oauth2_types::scope::Scope; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, + LookupResultExt, +}; + +/// An implementation of [`UpstreamOAuthProviderRepository`] for a PostgreSQL +/// connection +pub struct PgUpstreamOAuthProviderRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUpstreamOAuthProviderRepository<'c> { + /// Create a new [`PgUpstreamOAuthProviderRepository`] from an active + /// PostgreSQL connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct ProviderLookup { + upstream_oauth_provider_id: Uuid, + issuer: String, + scope: String, + client_id: String, + encrypted_client_secret: Option, + token_endpoint_signing_alg: Option, + token_endpoint_auth_method: String, + created_at: DateTime, +} + +impl TryFrom for UpstreamOAuthProvider { + type Error = DatabaseInconsistencyError; + fn try_from(value: ProviderLookup) -> Result { + let id = value.upstream_oauth_provider_id.into(); + let scope = value.scope.parse().map_err(|e| { + DatabaseInconsistencyError::on("upstream_oauth_providers") + .column("scope") + .row(id) + .source(e) + })?; + let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| { + DatabaseInconsistencyError::on("upstream_oauth_providers") + .column("token_endpoint_auth_method") + .row(id) + .source(e) + })?; + let token_endpoint_signing_alg = value + .token_endpoint_signing_alg + .map(|x| x.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("upstream_oauth_providers") + .column("token_endpoint_signing_alg") + .row(id) + .source(e) + })?; + + Ok(UpstreamOAuthProvider { + id, + issuer: value.issuer, + scope, + client_id: value.client_id, + encrypted_client_secret: value.encrypted_client_secret, + token_endpoint_auth_method, + token_endpoint_signing_alg, + created_at: value.created_at, + }) + } +} + +#[async_trait] +impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.upstream_oauth_provider.lookup", + skip_all, + fields( + db.statement, + upstream_oauth_provider.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + ProviderLookup, + r#" + SELECT + upstream_oauth_provider_id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at + FROM upstream_oauth_providers + WHERE upstream_oauth_provider_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let res = res + .map(UpstreamOAuthProvider::try_from) + .transpose() + .map_err(DatabaseError::from)?; + + Ok(res) + } + + #[tracing::instrument( + name = "db.upstream_oauth_provider.add", + skip_all, + fields( + db.statement, + upstream_oauth_provider.id, + upstream_oauth_provider.issuer = %issuer, + upstream_oauth_provider.client_id = %client_id, + ), + err, + )] + #[allow(clippy::too_many_arguments)] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + issuer: String, + scope: Scope, + token_endpoint_auth_method: OAuthClientAuthenticationMethod, + token_endpoint_signing_alg: Option, + client_id: String, + encrypted_client_secret: Option, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO upstream_oauth_providers ( + upstream_oauth_provider_id, + issuer, + scope, + token_endpoint_auth_method, + token_endpoint_signing_alg, + client_id, + encrypted_client_secret, + created_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + "#, + Uuid::from(id), + &issuer, + scope.to_string(), + token_endpoint_auth_method.to_string(), + token_endpoint_signing_alg.as_ref().map(ToString::to_string), + &client_id, + encrypted_client_secret.as_deref(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(UpstreamOAuthProvider { + id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at, + }) + } + + #[tracing::instrument( + name = "db.upstream_oauth_provider.list_paginated", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn list_paginated( + &mut self, + pagination: Pagination, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT + upstream_oauth_provider_id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at + FROM upstream_oauth_providers + WHERE 1 = 1 + "#, + ); + + query.generate_pagination("upstream_oauth_provider_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination.process(edges).try_map(TryInto::try_into)?; + Ok(page) + } + + #[tracing::instrument( + name = "db.upstream_oauth_provider.all", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn all(&mut self) -> Result, Self::Error> { + let res = sqlx::query_as!( + ProviderLookup, + r#" + SELECT + upstream_oauth_provider_id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at + FROM upstream_oauth_providers + "#, + ) + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let res: Result, _> = res.into_iter().map(TryInto::try_into).collect(); + Ok(res?) + } +} diff --git a/crates/storage-pg/src/upstream_oauth2/session.rs b/crates/storage-pg/src/upstream_oauth2/session.rs new file mode 100644 index 000000000..5780ab8d3 --- /dev/null +++ b/crates/storage-pg/src/upstream_oauth2/session.rs @@ -0,0 +1,290 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{ + UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, + UpstreamOAuthProvider, +}; +use mas_storage::{upstream_oauth2::UpstreamOAuthSessionRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; + +/// An implementation of [`UpstreamOAuthSessionRepository`] for a PostgreSQL +/// connection +pub struct PgUpstreamOAuthSessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUpstreamOAuthSessionRepository<'c> { + /// Create a new [`PgUpstreamOAuthSessionRepository`] from an active + /// PostgreSQL connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct SessionLookup { + upstream_oauth_authorization_session_id: Uuid, + upstream_oauth_provider_id: Uuid, + upstream_oauth_link_id: Option, + state: String, + code_challenge_verifier: Option, + nonce: String, + id_token: Option, + created_at: DateTime, + completed_at: Option>, + consumed_at: Option>, +} + +impl TryFrom for UpstreamOAuthAuthorizationSession { + type Error = DatabaseInconsistencyError; + + fn try_from(value: SessionLookup) -> Result { + let id = value.upstream_oauth_authorization_session_id.into(); + let state = match ( + value.upstream_oauth_link_id, + value.id_token, + value.completed_at, + value.consumed_at, + ) { + (None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending, + (Some(link_id), id_token, Some(completed_at), None) => { + UpstreamOAuthAuthorizationSessionState::Completed { + completed_at, + link_id: link_id.into(), + id_token, + } + } + (Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => { + UpstreamOAuthAuthorizationSessionState::Consumed { + completed_at, + link_id: link_id.into(), + id_token, + consumed_at, + } + } + _ => { + return Err( + DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id), + ) + } + }; + + Ok(Self { + id, + provider_id: value.upstream_oauth_provider_id.into(), + state_str: value.state, + nonce: value.nonce, + code_challenge_verifier: value.code_challenge_verifier, + created_at: value.created_at, + state, + }) + } +} + +#[async_trait] +impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.lookup", + skip_all, + fields( + db.statement, + upstream_oauth_provider.id = %id, + ), + err, + )] + async fn lookup( + &mut self, + id: Ulid, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + SessionLookup, + r#" + SELECT + upstream_oauth_authorization_session_id, + upstream_oauth_provider_id, + upstream_oauth_link_id, + state, + code_challenge_verifier, + nonce, + id_token, + created_at, + completed_at, + consumed_at + FROM upstream_oauth_authorization_sessions + WHERE upstream_oauth_authorization_session_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.add", + skip_all, + fields( + db.statement, + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, + upstream_oauth_authorization_session.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + state_str: String, + code_challenge_verifier: Option, + nonce: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record( + "upstream_oauth_authorization_session.id", + tracing::field::display(id), + ); + + sqlx::query!( + r#" + INSERT INTO upstream_oauth_authorization_sessions ( + upstream_oauth_authorization_session_id, + upstream_oauth_provider_id, + state, + code_challenge_verifier, + nonce, + created_at, + completed_at, + consumed_at, + id_token + ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL) + "#, + Uuid::from(id), + Uuid::from(upstream_oauth_provider.id), + &state_str, + code_challenge_verifier.as_deref(), + nonce, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(UpstreamOAuthAuthorizationSession { + id, + state: UpstreamOAuthAuthorizationSessionState::default(), + provider_id: upstream_oauth_provider.id, + state_str, + code_challenge_verifier, + nonce, + created_at, + }) + } + + #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.complete_with_link", + skip_all, + fields( + db.statement, + %upstream_oauth_authorization_session.id, + %upstream_oauth_link.id, + ), + err, + )] + async fn complete_with_link( + &mut self, + clock: &dyn Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + upstream_oauth_link: &UpstreamOAuthLink, + id_token: Option, + ) -> Result { + let completed_at = clock.now(); + + sqlx::query!( + r#" + UPDATE upstream_oauth_authorization_sessions + SET upstream_oauth_link_id = $1, + completed_at = $2, + id_token = $3 + WHERE upstream_oauth_authorization_session_id = $4 + "#, + Uuid::from(upstream_oauth_link.id), + completed_at, + id_token, + Uuid::from(upstream_oauth_authorization_session.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + let upstream_oauth_authorization_session = upstream_oauth_authorization_session + .complete(completed_at, upstream_oauth_link, id_token) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(upstream_oauth_authorization_session) + } + + /// Mark a session as consumed + #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.consume", + skip_all, + fields( + db.statement, + %upstream_oauth_authorization_session.id, + ), + err, + )] + async fn consume( + &mut self, + clock: &dyn Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + ) -> Result { + let consumed_at = clock.now(); + sqlx::query!( + r#" + UPDATE upstream_oauth_authorization_sessions + SET consumed_at = $1 + WHERE upstream_oauth_authorization_session_id = $2 + "#, + consumed_at, + Uuid::from(upstream_oauth_authorization_session.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + let upstream_oauth_authorization_session = upstream_oauth_authorization_session + .consume(consumed_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(upstream_oauth_authorization_session) + } +} diff --git a/crates/storage-pg/src/user/email.rs b/crates/storage-pg/src/user/email.rs new file mode 100644 index 000000000..147542d9d --- /dev/null +++ b/crates/storage-pg/src/user/email.rs @@ -0,0 +1,557 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{User, UserEmail, UserEmailVerification, UserEmailVerificationState}; +use mas_storage::{user::UserEmailRepository, Clock, Page, Pagination}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use tracing::{info_span, Instrument}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, + LookupResultExt, +}; + +/// An implementation of [`UserEmailRepository`] for a PostgreSQL connection +pub struct PgUserEmailRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUserEmailRepository<'c> { + /// Create a new [`PgUserEmailRepository`] from an active PostgreSQL + /// connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(Debug, Clone, sqlx::FromRow)] +struct UserEmailLookup { + user_email_id: Uuid, + user_id: Uuid, + email: String, + created_at: DateTime, + confirmed_at: Option>, +} + +impl From for UserEmail { + fn from(e: UserEmailLookup) -> UserEmail { + UserEmail { + id: e.user_email_id.into(), + user_id: e.user_id.into(), + email: e.email, + created_at: e.created_at, + confirmed_at: e.confirmed_at, + } + } +} + +struct UserEmailConfirmationCodeLookup { + user_email_confirmation_code_id: Uuid, + user_email_id: Uuid, + code: String, + created_at: DateTime, + expires_at: DateTime, + consumed_at: Option>, +} + +impl UserEmailConfirmationCodeLookup { + fn into_verification(self, clock: &dyn Clock) -> UserEmailVerification { + let now = clock.now(); + let state = if let Some(when) = self.consumed_at { + UserEmailVerificationState::AlreadyUsed { when } + } else if self.expires_at < now { + UserEmailVerificationState::Expired { + when: self.expires_at, + } + } else { + UserEmailVerificationState::Valid + }; + + UserEmailVerification { + id: self.user_email_confirmation_code_id.into(), + user_email_id: self.user_email_id.into(), + code: self.code, + state, + created_at: self.created_at, + } + } +} + +#[async_trait] +impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.user_email.lookup", + skip_all, + fields( + db.statement, + user_email.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserEmailLookup, + r#" + SELECT user_email_id + , user_id + , email + , created_at + , confirmed_at + FROM user_emails + + WHERE user_email_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(user_email) = res else { return Ok(None) }; + + Ok(Some(user_email.into())) + } + + #[tracing::instrument( + name = "db.user_email.find", + skip_all, + fields( + db.statement, + %user.id, + user_email.email = email, + ), + err, + )] + async fn find(&mut self, user: &User, email: &str) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserEmailLookup, + r#" + SELECT user_email_id + , user_id + , email + , created_at + , confirmed_at + FROM user_emails + + WHERE user_id = $1 AND email = $2 + "#, + Uuid::from(user.id), + email, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(user_email) = res else { return Ok(None) }; + + Ok(Some(user_email.into())) + } + + #[tracing::instrument( + name = "db.user_email.get_primary", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn get_primary(&mut self, user: &User) -> Result, Self::Error> { + let Some(id) = user.primary_user_email_id else { return Ok(None) }; + + let user_email = self.lookup(id).await?.ok_or_else(|| { + DatabaseInconsistencyError::on("users") + .column("primary_user_email_id") + .row(user.id) + })?; + + Ok(Some(user_email)) + } + + #[tracing::instrument( + name = "db.user_email.all", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn all(&mut self, user: &User) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserEmailLookup, + r#" + SELECT user_email_id + , user_id + , email + , created_at + , confirmed_at + FROM user_emails + + WHERE user_id = $1 + + ORDER BY email ASC + "#, + Uuid::from(user.id), + ) + .traced() + .fetch_all(&mut *self.conn) + .await?; + + Ok(res.into_iter().map(Into::into).collect()) + } + + #[tracing::instrument( + name = "db.user_email.list_paginated", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, DatabaseError> { + let mut query = QueryBuilder::new( + r#" + SELECT user_email_id + , user_id + , email + , created_at + , confirmed_at + FROM user_emails + "#, + ); + + query + .push(" WHERE user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("user_email_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination.process(edges).map(UserEmail::from); + Ok(page) + } + + #[tracing::instrument( + name = "db.user_email.count", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn count(&mut self, user: &User) -> Result { + let res = sqlx::query_scalar!( + r#" + SELECT COUNT(*) + FROM user_emails + WHERE user_id = $1 + "#, + Uuid::from(user.id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + let res = res.unwrap_or_default(); + + Ok(res + .try_into() + .map_err(DatabaseError::to_invalid_operation)?) + } + + #[tracing::instrument( + name = "db.user_email.add", + skip_all, + fields( + db.statement, + %user.id, + user_email.id, + user_email.email = email, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + email: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_email.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO user_emails (user_email_id, user_id, email, created_at) + VALUES ($1, $2, $3, $4) + "#, + Uuid::from(id), + Uuid::from(user.id), + &email, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(UserEmail { + id, + user_id: user.id, + email, + created_at, + confirmed_at: None, + }) + } + + #[tracing::instrument( + name = "db.user_email.remove", + skip_all, + fields( + db.statement, + user.id = %user_email.user_id, + %user_email.id, + %user_email.email, + ), + err, + )] + async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> { + let span = info_span!( + "db.user_email.remove.codes", + db.statement = tracing::field::Empty + ); + sqlx::query!( + r#" + DELETE FROM user_email_confirmation_codes + WHERE user_email_id = $1 + "#, + Uuid::from(user_email.id), + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + + let res = sqlx::query!( + r#" + DELETE FROM user_emails + WHERE user_email_id = $1 + "#, + Uuid::from(user_email.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(()) + } + + async fn mark_as_verified( + &mut self, + clock: &dyn Clock, + mut user_email: UserEmail, + ) -> Result { + let confirmed_at = clock.now(); + sqlx::query!( + r#" + UPDATE user_emails + SET confirmed_at = $2 + WHERE user_email_id = $1 + "#, + Uuid::from(user_email.id), + confirmed_at, + ) + .execute(&mut *self.conn) + .await?; + + user_email.confirmed_at = Some(confirmed_at); + Ok(user_email) + } + + async fn set_as_primary(&mut self, user_email: &UserEmail) -> Result<(), Self::Error> { + sqlx::query!( + r#" + UPDATE users + SET primary_user_email_id = user_emails.user_email_id + FROM user_emails + WHERE user_emails.user_email_id = $1 + AND users.user_id = user_emails.user_id + "#, + Uuid::from(user_email.id), + ) + .execute(&mut *self.conn) + .await?; + + Ok(()) + } + + #[tracing::instrument( + name = "db.user_email.add_verification_code", + skip_all, + fields( + db.statement, + %user_email.id, + %user_email.email, + user_email_verification.id, + user_email_verification.code = code, + ), + err, + )] + async fn add_verification_code( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user_email: &UserEmail, + max_age: chrono::Duration, + code: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id)); + let expires_at = created_at + max_age; + + sqlx::query!( + r#" + INSERT INTO user_email_confirmation_codes + (user_email_confirmation_code_id, user_email_id, code, created_at, expires_at) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(user_email.id), + code, + created_at, + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + let verification = UserEmailVerification { + id, + user_email_id: user_email.id, + code, + created_at, + state: UserEmailVerificationState::Valid, + }; + + Ok(verification) + } + + #[tracing::instrument( + name = "db.user_email.find_verification_code", + skip_all, + fields( + db.statement, + %user_email.id, + user.id = %user_email.user_id, + ), + err, + )] + async fn find_verification_code( + &mut self, + clock: &dyn Clock, + user_email: &UserEmail, + code: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserEmailConfirmationCodeLookup, + r#" + SELECT user_email_confirmation_code_id + , user_email_id + , code + , created_at + , expires_at + , consumed_at + FROM user_email_confirmation_codes + WHERE code = $1 + AND user_email_id = $2 + "#, + code, + Uuid::from(user_email.id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into_verification(clock))) + } + + #[tracing::instrument( + name = "db.user_email.consume_verification_code", + skip_all, + fields( + db.statement, + %user_email_verification.id, + user_email.id = %user_email_verification.user_email_id, + ), + err, + )] + async fn consume_verification_code( + &mut self, + clock: &dyn Clock, + mut user_email_verification: UserEmailVerification, + ) -> Result { + if !matches!( + user_email_verification.state, + UserEmailVerificationState::Valid + ) { + return Err(DatabaseError::invalid_operation()); + } + + let consumed_at = clock.now(); + + sqlx::query!( + r#" + UPDATE user_email_confirmation_codes + SET consumed_at = $2 + WHERE user_email_confirmation_code_id = $1 + "#, + Uuid::from(user_email_verification.id), + consumed_at + ) + .traced() + .execute(&mut *self.conn) + .await?; + + user_email_verification.state = + UserEmailVerificationState::AlreadyUsed { when: consumed_at }; + + Ok(user_email_verification) + } +} diff --git a/crates/storage-pg/src/user/mod.rs b/crates/storage-pg/src/user/mod.rs new file mode 100644 index 000000000..0554c8b25 --- /dev/null +++ b/crates/storage-pg/src/user/mod.rs @@ -0,0 +1,208 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! A module containing the PostgreSQL implementation of the user-related +//! repositories + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::User; +use mas_storage::{user::UserRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +mod email; +mod password; +mod session; + +#[cfg(test)] +mod tests; + +pub use self::{ + email::PgUserEmailRepository, password::PgUserPasswordRepository, + session::PgBrowserSessionRepository, +}; + +/// An implementation of [`UserRepository`] for a PostgreSQL connection +pub struct PgUserRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUserRepository<'c> { + /// Create a new [`PgUserRepository`] from an active PostgreSQL connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(Debug, Clone)] +struct UserLookup { + user_id: Uuid, + username: String, + primary_user_email_id: Option, + + #[allow(dead_code)] + created_at: DateTime, +} + +impl From for User { + fn from(value: UserLookup) -> Self { + let id = value.user_id.into(); + Self { + id, + username: value.username, + sub: id.to_string(), + primary_user_email_id: value.primary_user_email_id.map(Into::into), + } + } +} + +#[async_trait] +impl<'c> UserRepository for PgUserRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.user.lookup", + skip_all, + fields( + db.statement, + user.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserLookup, + r#" + SELECT user_id + , username + , primary_user_email_id + , created_at + FROM users + WHERE user_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.user.find_by_username", + skip_all, + fields( + db.statement, + user.username = username, + ), + err, + )] + async fn find_by_username(&mut self, username: &str) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserLookup, + r#" + SELECT user_id + , username + , primary_user_email_id + , created_at + FROM users + WHERE username = $1 + "#, + username, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.user.add", + skip_all, + fields( + db.statement, + user.username = username, + user.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + username: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO users (user_id, username, created_at) + VALUES ($1, $2, $3) + "#, + Uuid::from(id), + username, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(User { + id, + username, + sub: id.to_string(), + primary_user_email_id: None, + }) + } + + #[tracing::instrument( + name = "db.user.exists", + skip_all, + fields( + db.statement, + user.username = username, + ), + err, + )] + async fn exists(&mut self, username: &str) -> Result { + let exists = sqlx::query_scalar!( + r#" + SELECT EXISTS( + SELECT 1 FROM users WHERE username = $1 + ) AS "exists!" + "#, + username + ) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + Ok(exists) + } +} diff --git a/crates/storage-pg/src/user/password.rs b/crates/storage-pg/src/user/password.rs new file mode 100644 index 000000000..1dfd90d1c --- /dev/null +++ b/crates/storage-pg/src/user/password.rs @@ -0,0 +1,158 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{Password, User}; +use mas_storage::{user::UserPasswordRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; + +/// An implementation of [`UserPasswordRepository`] for a PostgreSQL connection +pub struct PgUserPasswordRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUserPasswordRepository<'c> { + /// Create a new [`PgUserPasswordRepository`] from an active PostgreSQL + /// connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct UserPasswordLookup { + user_password_id: Uuid, + hashed_password: String, + version: i32, + upgraded_from_id: Option, + created_at: DateTime, +} + +#[async_trait] +impl<'c> UserPasswordRepository for PgUserPasswordRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.user_password.active", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + ), + err, + )] + async fn active(&mut self, user: &User) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserPasswordLookup, + r#" + SELECT up.user_password_id + , up.hashed_password + , up.version + , up.upgraded_from_id + , up.created_at + FROM user_passwords up + WHERE up.user_id = $1 + ORDER BY up.created_at DESC + LIMIT 1 + "#, + Uuid::from(user.id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + let id = Ulid::from(res.user_password_id); + + let version = res.version.try_into().map_err(|e| { + DatabaseInconsistencyError::on("user_passwords") + .column("version") + .row(id) + .source(e) + })?; + + let upgraded_from_id = res.upgraded_from_id.map(Ulid::from); + let created_at = res.created_at; + let hashed_password = res.hashed_password; + + Ok(Some(Password { + id, + hashed_password, + version, + upgraded_from_id, + created_at, + })) + } + + #[tracing::instrument( + name = "db.user_password.add", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + user_password.id, + user_password.version = version, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + version: u16, + hashed_password: String, + upgraded_from: Option<&Password>, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_password.id", tracing::field::display(id)); + + let upgraded_from_id = upgraded_from.map(|p| p.id); + + sqlx::query!( + r#" + INSERT INTO user_passwords + (user_password_id, user_id, hashed_password, version, upgraded_from_id, created_at) + VALUES ($1, $2, $3, $4, $5, $6) + "#, + Uuid::from(id), + Uuid::from(user.id), + hashed_password, + i32::from(version), + upgraded_from_id.map(Uuid::from), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(Password { + id, + hashed_password, + version, + upgraded_from_id, + created_at, + }) + } +} diff --git a/crates/storage-pg/src/user/session.rs b/crates/storage-pg/src/user/session.rs new file mode 100644 index 000000000..ff91726ea --- /dev/null +++ b/crates/storage-pg/src/user/session.rs @@ -0,0 +1,379 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User}; +use mas_storage::{user::BrowserSessionRepository, Clock, Page, Pagination}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, + LookupResultExt, +}; + +/// An implementation of [`BrowserSessionRepository`] for a PostgreSQL +/// connection +pub struct PgBrowserSessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgBrowserSessionRepository<'c> { + /// Create a new [`PgBrowserSessionRepository`] from an active PostgreSQL + /// connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct SessionLookup { + user_session_id: Uuid, + user_session_created_at: DateTime, + user_session_finished_at: Option>, + user_id: Uuid, + user_username: String, + user_primary_user_email_id: Option, + last_authentication_id: Option, + last_authd_at: Option>, +} + +impl TryFrom for BrowserSession { + type Error = DatabaseInconsistencyError; + + fn try_from(value: SessionLookup) -> Result { + let id = Ulid::from(value.user_id); + let user = User { + id, + username: value.user_username, + sub: id.to_string(), + primary_user_email_id: value.user_primary_user_email_id.map(Into::into), + }; + + let last_authentication = match (value.last_authentication_id, value.last_authd_at) { + (Some(id), Some(created_at)) => Some(Authentication { + id: id.into(), + created_at, + }), + (None, None) => None, + _ => { + return Err(DatabaseInconsistencyError::on( + "user_session_authentications", + )) + } + }; + + Ok(BrowserSession { + id: value.user_session_id.into(), + user, + created_at: value.user_session_created_at, + finished_at: value.user_session_finished_at, + last_authentication, + }) + } +} + +#[async_trait] +impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.browser_session.lookup", + skip_all, + fields( + db.statement, + user_session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + SessionLookup, + r#" + SELECT s.user_session_id + , s.created_at AS "user_session_created_at" + , s.finished_at AS "user_session_finished_at" + , u.user_id + , u.username AS "user_username" + , u.primary_user_email_id AS "user_primary_user_email_id" + , a.user_session_authentication_id AS "last_authentication_id?" + , a.created_at AS "last_authd_at?" + FROM user_sessions s + INNER JOIN users u + USING (user_id) + LEFT JOIN user_session_authentications a + USING (user_session_id) + WHERE s.user_session_id = $1 + ORDER BY a.created_at DESC + LIMIT 1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.browser_session.add", + skip_all, + fields( + db.statement, + %user.id, + user_session.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_session.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO user_sessions (user_session_id, user_id, created_at) + VALUES ($1, $2, $3) + "#, + Uuid::from(id), + Uuid::from(user.id), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + let session = BrowserSession { + id, + // XXX + user: user.clone(), + created_at, + finished_at: None, + last_authentication: None, + }; + + Ok(session) + } + + #[tracing::instrument( + name = "db.browser_session.finish", + skip_all, + fields( + db.statement, + %user_session.id, + ), + err, + )] + async fn finish( + &mut self, + clock: &dyn Clock, + mut user_session: BrowserSession, + ) -> Result { + let finished_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE user_sessions + SET finished_at = $1 + WHERE user_session_id = $2 + "#, + finished_at, + Uuid::from(user_session.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + user_session.finished_at = Some(finished_at); + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(user_session) + } + + #[tracing::instrument( + name = "db.browser_session.list_active_paginated", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn list_active_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error> { + // TODO: ordering of last authentication is wrong + let mut query = QueryBuilder::new( + r#" + SELECT DISTINCT ON (s.user_session_id) + s.user_session_id, + u.user_id, + u.username, + s.created_at, + a.user_session_authentication_id AS "last_authentication_id", + a.created_at AS "last_authd_at", + FROM user_sessions s + INNER JOIN users u + USING (user_id) + LEFT JOIN user_session_authentications a + USING (user_session_id) + "#, + ); + + query + .push(" WHERE s.finished_at IS NULL AND s.user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("s.user_session_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination + .process(edges) + .try_map(BrowserSession::try_from)?; + Ok(page) + } + + #[tracing::instrument( + name = "db.browser_session.count_active", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn count_active(&mut self, user: &User) -> Result { + let res = sqlx::query_scalar!( + r#" + SELECT COUNT(*) as "count!" + FROM user_sessions s + WHERE s.user_id = $1 AND s.finished_at IS NULL + "#, + Uuid::from(user.id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + res.try_into().map_err(DatabaseError::to_invalid_operation) + } + + #[tracing::instrument( + name = "db.browser_session.authenticate_with_password", + skip_all, + fields( + db.statement, + %user_session.id, + %user_password.id, + user_session_authentication.id, + ), + err, + )] + async fn authenticate_with_password( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + mut user_session: BrowserSession, + user_password: &Password, + ) -> Result { + let _user_password = user_password; + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record( + "user_session_authentication.id", + tracing::field::display(id), + ); + + sqlx::query!( + r#" + INSERT INTO user_session_authentications + (user_session_authentication_id, user_session_id, created_at) + VALUES ($1, $2, $3) + "#, + Uuid::from(id), + Uuid::from(user_session.id), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + user_session.last_authentication = Some(Authentication { id, created_at }); + + Ok(user_session) + } + + #[tracing::instrument( + name = "db.browser_session.authenticate_with_upstream", + skip_all, + fields( + db.statement, + %user_session.id, + %upstream_oauth_link.id, + user_session_authentication.id, + ), + err, + )] + async fn authenticate_with_upstream( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + mut user_session: BrowserSession, + upstream_oauth_link: &UpstreamOAuthLink, + ) -> Result { + let _upstream_oauth_link = upstream_oauth_link; + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record( + "user_session_authentication.id", + tracing::field::display(id), + ); + + sqlx::query!( + r#" + INSERT INTO user_session_authentications + (user_session_authentication_id, user_session_id, created_at) + VALUES ($1, $2, $3) + "#, + Uuid::from(id), + Uuid::from(user_session.id), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + user_session.last_authentication = Some(Authentication { id, created_at }); + + Ok(user_session) + } +} diff --git a/crates/storage-pg/src/user/tests.rs b/crates/storage-pg/src/user/tests.rs new file mode 100644 index 000000000..9aec949de --- /dev/null +++ b/crates/storage-pg/src/user/tests.rs @@ -0,0 +1,407 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::Duration; +use mas_storage::{ + clock::MockClock, + user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, + Pagination, Repository, RepositoryAccess, +}; +use rand::SeedableRng; +use rand_chacha::ChaChaRng; +use sqlx::PgPool; + +use crate::PgRepository; + +/// Test the user repository, by adding and looking up a user +#[sqlx::test(migrator = "crate::MIGRATOR")] +async fn test_user_repo(pool: PgPool) { + const USERNAME: &str = "john"; + + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + + // Initially, the user shouldn't exist + assert!(!repo.user().exists(USERNAME).await.unwrap()); + assert!(repo + .user() + .find_by_username(USERNAME) + .await + .unwrap() + .is_none()); + + // Adding the user should work + let user = repo + .user() + .add(&mut rng, &clock, USERNAME.to_owned()) + .await + .unwrap(); + + // And now it should exist + assert!(repo.user().exists(USERNAME).await.unwrap()); + assert!(repo + .user() + .find_by_username(USERNAME) + .await + .unwrap() + .is_some()); + assert!(repo.user().lookup(user.id).await.unwrap().is_some()); + + // Adding a second time should give a conflict + assert!(repo + .user() + .add(&mut rng, &clock, USERNAME.to_owned()) + .await + .is_err()); + + repo.save().await.unwrap(); +} + +/// Test the user email repository, by trying out most of its methods +#[sqlx::test(migrator = "crate::MIGRATOR")] +async fn test_user_email_repo(pool: PgPool) { + const USERNAME: &str = "john"; + const CODE: &str = "012345"; + const CODE2: &str = "543210"; + const EMAIL: &str = "john@example.com"; + + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + + let user = repo + .user() + .add(&mut rng, &clock, USERNAME.to_owned()) + .await + .unwrap(); + + // The user email should not exist yet + assert!(repo + .user_email() + .find(&user, EMAIL) + .await + .unwrap() + .is_none()); + + assert_eq!(repo.user_email().count(&user).await.unwrap(), 0); + + let user_email = repo + .user_email() + .add(&mut rng, &clock, &user, EMAIL.to_owned()) + .await + .unwrap(); + + assert_eq!(user_email.user_id, user.id); + assert_eq!(user_email.email, EMAIL); + assert!(user_email.confirmed_at.is_none()); + + assert_eq!(repo.user_email().count(&user).await.unwrap(), 1); + + assert!(repo + .user_email() + .find(&user, EMAIL) + .await + .unwrap() + .is_some()); + + let user_email = repo + .user_email() + .lookup(user_email.id) + .await + .unwrap() + .expect("user email was not found"); + + assert_eq!(user_email.user_id, user.id); + assert_eq!(user_email.email, EMAIL); + + let verification = repo + .user_email() + .add_verification_code( + &mut rng, + &clock, + &user_email, + Duration::hours(8), + CODE.to_owned(), + ) + .await + .unwrap(); + + let verification_id = verification.id; + assert_eq!(verification.user_email_id, user_email.id); + assert_eq!(verification.code, CODE); + + // A single user email can have multiple verification at the same time + let _verification2 = repo + .user_email() + .add_verification_code( + &mut rng, + &clock, + &user_email, + Duration::hours(8), + CODE2.to_owned(), + ) + .await + .unwrap(); + + let verification = repo + .user_email() + .find_verification_code(&clock, &user_email, CODE) + .await + .unwrap() + .expect("user email verification was not found"); + + assert_eq!(verification.id, verification_id); + assert_eq!(verification.user_email_id, user_email.id); + assert_eq!(verification.code, CODE); + + // Consuming the verification code + repo.user_email() + .consume_verification_code(&clock, verification) + .await + .unwrap(); + + // Mark the email as verified + repo.user_email() + .mark_as_verified(&clock, user_email) + .await + .unwrap(); + + // Reload the user_email + let user_email = repo + .user_email() + .find(&user, EMAIL) + .await + .unwrap() + .expect("user email was not found"); + + // The email should be marked as verified now + assert!(user_email.confirmed_at.is_some()); + + // Reload the verification + let verification = repo + .user_email() + .find_verification_code(&clock, &user_email, CODE) + .await + .unwrap() + .expect("user email verification was not found"); + + // Consuming a second time should not work + assert!(repo + .user_email() + .consume_verification_code(&clock, verification) + .await + .is_err()); + + // The user shouldn't have a primary email yet + assert!(repo + .user_email() + .get_primary(&user) + .await + .unwrap() + .is_none()); + + repo.user_email().set_as_primary(&user_email).await.unwrap(); + + // Reload the user + let user = repo + .user() + .lookup(user.id) + .await + .unwrap() + .expect("user was not found"); + + // Now it should have one + assert!(repo + .user_email() + .get_primary(&user) + .await + .unwrap() + .is_some()); + + // Listing the user emails should work + let emails = repo + .user_email() + .list_paginated(&user, Pagination::first(10)) + .await + .unwrap(); + assert!(!emails.has_next_page); + assert_eq!(emails.edges.len(), 1); + assert_eq!(emails.edges[0], user_email); + + // Deleting the user email should work + repo.user_email().remove(user_email).await.unwrap(); + assert_eq!(repo.user_email().count(&user).await.unwrap(), 0); + + // Reload the user + let user = repo + .user() + .lookup(user.id) + .await + .unwrap() + .expect("user was not found"); + + // The primary user email should be gone + assert!(repo + .user_email() + .get_primary(&user) + .await + .unwrap() + .is_none()); + + repo.save().await.unwrap(); +} + +/// Test the user password repository implementation. +#[sqlx::test(migrator = "crate::MIGRATOR")] +async fn test_user_password_repo(pool: PgPool) { + const USERNAME: &str = "john"; + const FIRST_PASSWORD_HASH: &str = "doesntmatter"; + const SECOND_PASSWORD_HASH: &str = "alsodoesntmatter"; + + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + + let user = repo + .user() + .add(&mut rng, &clock, USERNAME.to_owned()) + .await + .unwrap(); + + // User should have no active password + assert!(repo.user_password().active(&user).await.unwrap().is_none()); + + // Insert a first password + let first_password = repo + .user_password() + .add( + &mut rng, + &clock, + &user, + 1, + FIRST_PASSWORD_HASH.to_owned(), + None, + ) + .await + .unwrap(); + + // User should now have an active password + let first_password_lookup = repo + .user_password() + .active(&user) + .await + .unwrap() + .expect("user should have an active password"); + + assert_eq!(first_password.id, first_password_lookup.id); + assert_eq!(first_password_lookup.hashed_password, FIRST_PASSWORD_HASH); + assert_eq!(first_password_lookup.version, 1); + assert_eq!(first_password_lookup.upgraded_from_id, None); + + // Getting the last inserted password is based on the clock, so we need to + // advance it + clock.advance(Duration::seconds(10)); + + let second_password = repo + .user_password() + .add( + &mut rng, + &clock, + &user, + 2, + SECOND_PASSWORD_HASH.to_owned(), + Some(&first_password), + ) + .await + .unwrap(); + + // User should now have an active password + let second_password_lookup = repo + .user_password() + .active(&user) + .await + .unwrap() + .expect("user should have an active password"); + + assert_eq!(second_password.id, second_password_lookup.id); + assert_eq!(second_password_lookup.hashed_password, SECOND_PASSWORD_HASH); + assert_eq!(second_password_lookup.version, 2); + assert_eq!( + second_password_lookup.upgraded_from_id, + Some(first_password.id) + ); + + repo.save().await.unwrap(); +} + +#[sqlx::test(migrator = "crate::MIGRATOR")] +async fn test_user_session(pool: PgPool) { + const USERNAME: &str = "john"; + + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + + let user = repo + .user() + .add(&mut rng, &clock, USERNAME.to_owned()) + .await + .unwrap(); + + assert_eq!(repo.browser_session().count_active(&user).await.unwrap(), 0); + + let session = repo + .browser_session() + .add(&mut rng, &clock, &user) + .await + .unwrap(); + assert_eq!(session.user.id, user.id); + assert!(session.finished_at.is_none()); + + assert_eq!(repo.browser_session().count_active(&user).await.unwrap(), 1); + + let session_lookup = repo + .browser_session() + .lookup(session.id) + .await + .unwrap() + .expect("user session not found"); + + assert_eq!(session_lookup.id, session.id); + assert_eq!(session_lookup.user.id, user.id); + assert!(session_lookup.finished_at.is_none()); + + // Finish the session + repo.browser_session() + .finish(&clock, session_lookup) + .await + .unwrap(); + + // The active session counter is back to 0 + assert_eq!(repo.browser_session().count_active(&user).await.unwrap(), 0); + + // Reload the session + let session_lookup = repo + .browser_session() + .lookup(session.id) + .await + .unwrap() + .expect("user session not found"); + + assert_eq!(session_lookup.id, session.id); + assert_eq!(session_lookup.user.id, user.id); + // This time the session is finished + assert!(session_lookup.finished_at.is_some()); +} diff --git a/crates/storage/Cargo.toml b/crates/storage/Cargo.toml index b0ed4c5e5..cea7b03b0 100644 --- a/crates/storage/Cargo.toml +++ b/crates/storage/Cargo.toml @@ -6,18 +6,14 @@ edition = "2021" license = "Apache-2.0" [dependencies] -sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline", "json", "uuid"] } -chrono = { version = "0.4.23", features = ["serde"] } -serde = { version = "1.0.152", features = ["derive"] } -serde_json = "1.0.91" +async-trait = "0.1.60" +chrono = "0.4.23" thiserror = "1.0.38" -tracing = "0.1.37" +futures-util = "0.3.25" -# Password hashing -rand = "0.8.5" -url = { version = "2.3.1", features = ["serde"] } -uuid = "1.2.2" -ulid = { version = "1.0.0", features = ["uuid", "serde"] } +rand_core = "0.6.4" +url = "2.3.1" +ulid = "1.0.0" oauth2-types = { path = "../oauth2-types" } mas-data-model = { path = "../data-model" } diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json deleted file mode 100644 index 1ce99d797..000000000 --- a/crates/storage/sqlx-data.json +++ /dev/null @@ -1,2794 +0,0 @@ -{ - "db": "PostgreSQL", - "05b50b7ae0109063c50fe70e83635a31920e44a7fbaa2b4f07552ba2f83a28d7": { - "describe": { - "columns": [ - { - "name": "oauth2_client_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "encrypted_client_secret", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "redirect_uris!", - "ordinal": 2, - "type_info": "TextArray" - }, - { - "name": "grant_type_authorization_code", - "ordinal": 3, - "type_info": "Bool" - }, - { - "name": "grant_type_refresh_token", - "ordinal": 4, - "type_info": "Bool" - }, - { - "name": "client_name", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "logo_uri", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "client_uri", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "policy_uri", - "ordinal": 8, - "type_info": "Text" - }, - { - "name": "tos_uri", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "jwks_uri", - "ordinal": 10, - "type_info": "Text" - }, - { - "name": "jwks", - "ordinal": 11, - "type_info": "Jsonb" - }, - { - "name": "id_token_signed_response_alg", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "userinfo_signed_response_alg", - "ordinal": 13, - "type_info": "Text" - }, - { - "name": "token_endpoint_auth_method", - "ordinal": 14, - "type_info": "Text" - }, - { - "name": "token_endpoint_auth_signing_alg", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "initiate_login_uri", - "ordinal": 16, - "type_info": "Text" - } - ], - "nullable": [ - false, - true, - null, - false, - false, - true, - true, - true, - true, - true, - true, - true, - true, - true, - true, - true, - true - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n c.oauth2_client_id,\n c.encrypted_client_secret,\n ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\",\n c.grant_type_authorization_code,\n c.grant_type_refresh_token,\n c.client_name,\n c.logo_uri,\n c.client_uri,\n c.policy_uri,\n c.tos_uri,\n c.jwks_uri,\n c.jwks,\n c.id_token_signed_response_alg,\n c.userinfo_signed_response_alg,\n c.token_endpoint_auth_method,\n c.token_endpoint_auth_signing_alg,\n c.initiate_login_uri\n FROM oauth2_clients c\n\n WHERE c.oauth2_client_id = $1\n " - }, - "0af182315b36766eca8e232280986bade0202d1b1d64160a99cd14eadcbfc25b": { - "describe": { - "columns": [ - { - "name": "upstream_oauth_provider_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "issuer", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "scope", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "client_id", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "encrypted_client_secret", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "token_endpoint_signing_alg", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "token_endpoint_auth_method", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 7, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - true, - true, - false, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n " - }, - "0b49cde0b7b79f79ec261502ab89bcffa81f9f5ed2f922a41b1718274b9e3073": { - "describe": { - "columns": [ - { - "name": "user_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "user_username", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "user_email_id?", - "ordinal": 2, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 5, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT\n u.user_id,\n u.username AS user_username,\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n FROM users u\n\n LEFT JOIN user_emails ue\n USING (user_id)\n\n WHERE u.username = $1\n " - }, - "1166343ad1563cb66ab387368f67320a53c34edf388bdb991359ebdf324497d5": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE user_emails\n SET confirmed_at = $2\n WHERE user_email_id = $1\n " - }, - "1e7b1b7e06b5d97d81dc4a8524bb223c3dc7ddbbcce7cc2a142dbfbdd6a2902e": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid" - ] - } - }, - "query": "\n UPDATE upstream_oauth_links\n SET user_id = $1\n WHERE upstream_oauth_link_id = $2\n " - }, - "1eb6d13e75d8f526c2785749a020731c18012f03e07995213acd38ab560ce497": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at)\n VALUES ($1, $2, $3)\n " - }, - "1ee5cecfafd4726a4ebc08da8a34c09178e6e1e072581c8fca9d3d76967792cb": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Text", - "Text", - "Text", - "Text", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n " - }, - "2153118b364a33582e7f598acce3789fcb8d938948a819b15cf0b6d37edf58b2": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Timestamptz", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO compat_access_tokens\n (compat_access_token_id, compat_session_id, access_token, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n " - }, - "24d6154b138a5e9105b996d6447e45a5c208e157f6583b4220cf58813d6f436c": { - "describe": { - "columns": [ - { - "name": "oauth2_authorization_grant_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_created_at", - "ordinal": 1, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_cancelled_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_fulfilled_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_exchanged_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_scope", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_state", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_redirect_uri", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_mode", - "ordinal": 8, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_nonce", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_max_age", - "ordinal": 10, - "type_info": "Int4" - }, - { - "name": "oauth2_client_id", - "ordinal": 11, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_code", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_type_code", - "ordinal": 13, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_response_type_id_token", - "ordinal": 14, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_code_challenge", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_code_challenge_method", - "ordinal": 16, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_requires_consent", - "ordinal": 17, - "type_info": "Bool" - }, - { - "name": "oauth2_session_id?", - "ordinal": 18, - "type_info": "Uuid" - }, - { - "name": "user_session_id?", - "ordinal": 19, - "type_info": "Uuid" - }, - { - "name": "user_session_created_at?", - "ordinal": 20, - "type_info": "Timestamptz" - }, - { - "name": "user_id?", - "ordinal": 21, - "type_info": "Uuid" - }, - { - "name": "user_username?", - "ordinal": 22, - "type_info": "Text" - }, - { - "name": "user_session_last_authentication_id?", - "ordinal": 23, - "type_info": "Uuid" - }, - { - "name": "user_session_last_authentication_created_at?", - "ordinal": 24, - "type_info": "Timestamptz" - }, - { - "name": "user_email_id?", - "ordinal": 25, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 26, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 27, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 28, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - true, - true, - false, - true, - false, - false, - true, - true, - false, - true, - false, - false, - true, - true, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT\n og.oauth2_authorization_grant_id,\n og.created_at AS oauth2_authorization_grant_created_at,\n og.cancelled_at AS oauth2_authorization_grant_cancelled_at,\n og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at,\n og.exchanged_at AS oauth2_authorization_grant_exchanged_at,\n og.scope AS oauth2_authorization_grant_scope,\n og.state AS oauth2_authorization_grant_state,\n og.redirect_uri AS oauth2_authorization_grant_redirect_uri,\n og.response_mode AS oauth2_authorization_grant_response_mode,\n og.nonce AS oauth2_authorization_grant_nonce,\n og.max_age AS oauth2_authorization_grant_max_age,\n og.oauth2_client_id AS oauth2_client_id,\n og.authorization_code AS oauth2_authorization_grant_code,\n og.response_type_code AS oauth2_authorization_grant_response_type_code,\n og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token,\n og.code_challenge AS oauth2_authorization_grant_code_challenge,\n og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method,\n og.requires_consent AS oauth2_authorization_grant_requires_consent,\n os.oauth2_session_id AS \"oauth2_session_id?\",\n us.user_session_id AS \"user_session_id?\",\n us.created_at AS \"user_session_created_at?\",\n u.user_id AS \"user_id?\",\n u.username AS \"user_username?\",\n usa.user_session_authentication_id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n FROM\n oauth2_authorization_grants og\n LEFT JOIN oauth2_sessions os\n USING (oauth2_session_id)\n LEFT JOIN user_sessions us\n USING (user_session_id)\n LEFT JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE og.authorization_code = $1\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " - }, - "262bee715889dc3e608639549600a131e641951ff979634e7c97afc74bbc1605": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE oauth2_authorization_grants\n SET exchanged_at = $2\n WHERE oauth2_authorization_grant_id = $1\n " - }, - "26a9391df9f1128673cdaf431fe8c5e4a83b576ddf7b02d92abfab6deadd4fa2": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Bool", - "Bool", - "Text", - "Jsonb", - "Text" - ] - } - }, - "query": "\n INSERT INTO oauth2_clients\n (oauth2_client_id,\n encrypted_client_secret,\n grant_type_authorization_code,\n grant_type_refresh_token,\n token_endpoint_auth_method,\n jwks,\n jwks_uri)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7)\n " - }, - "27a729b229491d179391b19b634f07291312bd238380c5a7ea0f60e9b71dfb14": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO users (user_id, username, created_at)\n VALUES ($1, $2, $3)\n " - }, - "2ca7b990c11e84db62fb7887a2bc3410ec1eee2f6a0ec124db36575111970ca9": { - "describe": { - "columns": [ - { - "name": "upstream_oauth_authorization_session_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "upstream_oauth_provider_id", - "ordinal": 1, - "type_info": "Uuid" - }, - { - "name": "upstream_oauth_link_id", - "ordinal": 2, - "type_info": "Uuid" - }, - { - "name": "state", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "code_challenge_verifier", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "nonce", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "id_token", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 7, - "type_info": "Timestamptz" - }, - { - "name": "completed_at", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "consumed_at", - "ordinal": 9, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - false, - true, - false, - true, - false, - true, - true - ], - "parameters": { - "Left": [ - "Uuid", - "Uuid" - ] - } - }, - "query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n AND upstream_oauth_link_id = $2\n " - }, - "2e581d57db471b96091860cd0252361d16332deeffabab0dace405ead55324be": { - "describe": { - "columns": [ - { - "name": "compat_access_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "compat_access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_access_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "compat_access_token_expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id", - "ordinal": 4, - "type_info": "Uuid" - }, - { - "name": "compat_session_created_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_finished_at", - "ordinal": 6, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "user_id!", - "ordinal": 8, - "type_info": "Uuid" - }, - { - "name": "user_username!", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "user_email_id?", - "ordinal": 10, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 11, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 12, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 13, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - true, - false, - false, - true, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text", - "Timestamptz" - ] - } - }, - "query": "\n SELECT\n ct.compat_access_token_id,\n ct.access_token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n ct.expires_at AS \"compat_access_token_expires_at\",\n cs.compat_session_id,\n cs.created_at AS \"compat_session_created_at\",\n cs.finished_at AS \"compat_session_finished_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.user_id AS \"user_id!\",\n u.username AS \"user_username!\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM compat_access_tokens ct\n INNER JOIN compat_sessions cs\n USING (compat_session_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE ct.access_token = $1\n AND (ct.expires_at < $2 OR ct.expires_at IS NULL)\n AND cs.finished_at IS NULL \n " - }, - "2e756fe7be50128c0acc5f79df3a084230e9ca13cd45bd0858f97e59da20006e": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE compat_sso_logins\n SET\n exchanged_at = $2\n WHERE\n compat_sso_login_id = $1\n " - }, - "2fb8f1aef96571a6f3f6260d7836de99ff24ba1947747e08b0e8d64097507442": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz", - "Text", - "Uuid" - ] - } - }, - "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n " - }, - "360466ff599c67c9af2ac75399c0b536a22c1178972a0172b707bcc81d47357b": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Uuid", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO compat_refresh_tokens\n (compat_refresh_token_id, compat_session_id,\n compat_access_token_id, refresh_token, created_at)\n VALUES ($1, $2, $3, $4, $5)\n " - }, - "3a19b087ae9e4dab770f102de1cb62628525fc72c7b052e1c146161ab088c02b": { - "describe": { - "columns": [ - { - "name": "user_email_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "user_email", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "user_email_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at", - "ordinal": 3, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Uuid", - "Text" - ] - } - }, - "query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n AND ue.email = $2\n " - }, - "3df0838b660466f69ee681337fe6753133748defb715e53c8381badcc3e8bca9": { - "describe": { - "columns": [ - { - "name": "user_session_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "user_id", - "ordinal": 1, - "type_info": "Uuid" - }, - { - "name": "username", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "last_authentication_id?", - "ordinal": 4, - "type_info": "Uuid" - }, - { - "name": "last_authd_at?", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "user_email_id?", - "ordinal": 6, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 9, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n s.user_session_id,\n u.user_id,\n u.username,\n s.created_at,\n a.user_session_authentication_id AS \"last_authentication_id?\",\n a.created_at AS \"last_authd_at?\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications a\n USING (user_session_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n WHERE s.user_session_id = $1 AND s.finished_at IS NULL\n ORDER BY a.created_at DESC\n LIMIT 1\n " - }, - "3e8f862ed05ce3e58c181ac6e0bd71e0a6a88419611af6f4117d14d9c36cb1ef": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO user_emails (user_email_id, user_id, email, created_at)\n VALUES ($1, $2, $3, $4)\n " - }, - "42bfb0de5bbea2d580f1ff2322255731a4a5655ba80fc2dba0b55a0add8c55c0": { - "describe": { - "columns": [ - { - "name": "compat_sso_login_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "compat_sso_login_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_sso_login_redirect_uri", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "compat_sso_login_created_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_sso_login_fulfilled_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "compat_sso_login_exchanged_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id?", - "ordinal": 6, - "type_info": "Uuid" - }, - { - "name": "compat_session_created_at?", - "ordinal": 7, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_finished_at?", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id?", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "user_id?", - "ordinal": 10, - "type_info": "Uuid" - }, - { - "name": "user_username?", - "ordinal": 11, - "type_info": "Text" - }, - { - "name": "user_email_id?", - "ordinal": 12, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 13, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 14, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 15, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - true, - true, - false, - false, - true, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n cl.compat_sso_login_id,\n cl.login_token AS \"compat_sso_login_token\",\n cl.redirect_uri AS \"compat_sso_login_redirect_uri\",\n cl.created_at AS \"compat_sso_login_created_at\",\n cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\",\n cl.exchanged_at AS \"compat_sso_login_exchanged_at\",\n cs.compat_session_id AS \"compat_session_id?\",\n cs.created_at AS \"compat_session_created_at?\",\n cs.finished_at AS \"compat_session_finished_at?\",\n cs.device_id AS \"compat_session_device_id?\",\n u.user_id AS \"user_id?\",\n u.username AS \"user_username?\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n FROM compat_sso_logins cl\n LEFT JOIN compat_sessions cs\n USING (compat_session_id)\n LEFT JOIN users u\n USING (user_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n WHERE cl.compat_sso_login_id = $1\n " - }, - "43a5cafbdc8037e9fb779812a0793cf0859902aa0dc8d25d4c33d231d3d1118b": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Timestamptz", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO oauth2_access_tokens\n (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)\n VALUES\n ($1, $2, $3, $4, $5)\n " - }, - "4693f2b9b3d51ff4a05e233b6667161ebc97f331d96bf5f1c61069e1c8492105": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "UuidArray", - "Uuid", - "TextArray" - ] - } - }, - "query": "\n INSERT INTO oauth2_client_redirect_uris\n (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)\n SELECT id, $2, redirect_uri\n FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)\n " - }, - "46c5ae7052504bfd7b94f20e61b9cf92570779a794bccda23dd654fb8523f340": { - "describe": { - "columns": [ - { - "name": "fulfilled_at!: DateTime", - "ordinal": 0, - "type_info": "Timestamptz" - } - ], - "nullable": [ - true - ], - "parameters": { - "Left": [ - "Uuid", - "Uuid" - ] - } - }, - "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n oauth2_session_id = os.oauth2_session_id,\n fulfilled_at = os.created_at\n FROM oauth2_sessions os\n WHERE\n og.oauth2_authorization_grant_id = $1\n AND os.oauth2_session_id = $2\n RETURNING fulfilled_at AS \"fulfilled_at!: DateTime\"\n " - }, - "47d4048365144c7bfc14790dfb8fa7f862d2952075a68cd5e90ac76d9e6d1388": { - "describe": { - "columns": [ - { - "name": "upstream_oauth_link_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "upstream_oauth_provider_id", - "ordinal": 1, - "type_info": "Uuid" - }, - { - "name": "user_id", - "ordinal": 2, - "type_info": "Uuid" - }, - { - "name": "subject", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 4, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - false, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_link_id = $1\n " - }, - "4f8ec19f3f1bfe0268fe102a24e5a9fa542e77eccbebdce65e6deb1c197adf36": { - "describe": { - "columns": [ - { - "name": "oauth2_access_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "oauth2_access_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_access_token_expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_session_id!", - "ordinal": 4, - "type_info": "Uuid" - }, - { - "name": "oauth2_client_id!", - "ordinal": 5, - "type_info": "Uuid" - }, - { - "name": "scope!", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "user_session_id!", - "ordinal": 7, - "type_info": "Uuid" - }, - { - "name": "user_session_created_at!", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "user_id!", - "ordinal": 9, - "type_info": "Uuid" - }, - { - "name": "user_username!", - "ordinal": 10, - "type_info": "Text" - }, - { - "name": "user_session_last_authentication_id?", - "ordinal": 11, - "type_info": "Uuid" - }, - { - "name": "user_session_last_authentication_created_at?", - "ordinal": 12, - "type_info": "Timestamptz" - }, - { - "name": "user_email_id?", - "ordinal": 13, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 14, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 15, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 16, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT\n at.oauth2_access_token_id,\n at.access_token AS \"oauth2_access_token\",\n at.created_at AS \"oauth2_access_token_created_at\",\n at.expires_at AS \"oauth2_access_token_expires_at\",\n os.oauth2_session_id AS \"oauth2_session_id!\",\n os.oauth2_client_id AS \"oauth2_client_id!\",\n os.scope AS \"scope!\",\n us.user_session_id AS \"user_session_id!\",\n us.created_at AS \"user_session_created_at!\",\n u.user_id AS \"user_id!\",\n u.username AS \"user_username!\",\n usa.user_session_authentication_id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n INNER JOIN user_sessions us\n USING (user_session_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE at.access_token = $1\n AND at.revoked_at IS NULL\n AND os.finished_at IS NULL\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " - }, - "51158bfcaa1a8d8e051bffe7c5ba0369bf53fb162f7622626054e89e68fc07bd": { - "describe": { - "columns": [ - { - "name": "scope_token", - "ordinal": 0, - "type_info": "Text" - } - ], - "nullable": [ - false - ], - "parameters": { - "Left": [ - "Uuid", - "Uuid" - ] - } - }, - "query": "\n SELECT scope_token\n FROM oauth2_consents\n WHERE user_id = $1 AND oauth2_client_id = $2\n " - }, - "559a486756d08d101eb7188ef6637b9d24c024d056795b8121f7f04a7f9db6a3": { - "describe": { - "columns": [ - { - "name": "compat_session_id", - "ordinal": 0, - "type_info": "Uuid" - } - ], - "nullable": [ - false - ], - "parameters": { - "Left": [ - "Text", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE compat_sessions cs\n SET finished_at = $2\n FROM compat_access_tokens ca\n WHERE ca.access_token = $1\n AND ca.compat_session_id = cs.compat_session_id\n AND cs.finished_at IS NULL\n RETURNING cs.compat_session_id\n " - }, - "59439585536bb4e547a6cf58a8bc6ac735f29c225bcbeac7d371f09166789a73": { - "describe": { - "columns": [ - { - "name": "user_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "user_username", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "user_email_id?", - "ordinal": 2, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 5, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n u.user_id,\n u.username AS user_username,\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n FROM users u\n\n LEFT JOIN user_emails ue\n USING (user_id)\n\n WHERE u.user_id = $1\n " - }, - "5b5d5c82da37c6f2d8affacfb02119965c04d1f2a9cc53dbf5bd4c12584969a0": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Timestamptz" - ] - } - }, - "query": "\n DELETE FROM oauth2_access_tokens\n WHERE expires_at < $1\n " - }, - "5ccde09ee3fe43e7b492d73fa67708b5dcb2b7496c4d05bcfcf0ea63c7576d48": { - "describe": { - "columns": [ - { - "name": "user_email_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "user_email", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "user_email_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at", - "ordinal": 3, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n\n ORDER BY ue.email ASC\n " - }, - "60d039442cfa57e187602c0ff5e386e32fb774b5ad2d2f2c616040819b76873e": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE compat_sso_logins\n SET\n compat_session_id = $2,\n fulfilled_at = $3\n WHERE\n compat_sso_login_id = $1\n " - }, - "62d05e8e4317bdb180298737d422e64d161c5ad3813913a6f7d67a53569ea76a": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "UuidArray", - "Uuid", - "Uuid", - "TextArray", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO oauth2_consents\n (oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)\n SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)\n ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5\n " - }, - "64a56818dd16ac6368efe3e34196a77b7feda1eb87b696e0063a51bf50e499e5": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Timestamptz", - "Uuid" - ] - } - }, - "query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n " - }, - "65c7600f1af07cb6ea49d89ae6fbca5374a57c5a866c8aadd7b75ed1d2d1d0cd": { - "describe": { - "columns": [ - { - "name": "upstream_oauth_authorization_session_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "upstream_oauth_provider_id", - "ordinal": 1, - "type_info": "Uuid" - }, - { - "name": "upstream_oauth_link_id", - "ordinal": 2, - "type_info": "Uuid" - }, - { - "name": "state", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "code_challenge_verifier", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "nonce", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "id_token", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 7, - "type_info": "Timestamptz" - }, - { - "name": "completed_at", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "consumed_at", - "ordinal": 9, - "type_info": "Timestamptz" - }, - { - "name": "provider_issuer", - "ordinal": 10, - "type_info": "Text" - }, - { - "name": "provider_scope", - "ordinal": 11, - "type_info": "Text" - }, - { - "name": "provider_client_id", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "provider_encrypted_client_secret", - "ordinal": 13, - "type_info": "Text" - }, - { - "name": "provider_token_endpoint_auth_method", - "ordinal": 14, - "type_info": "Text" - }, - { - "name": "provider_token_endpoint_signing_alg", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "provider_created_at", - "ordinal": 16, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - false, - true, - false, - true, - false, - true, - true, - false, - false, - false, - true, - false, - true, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n ua.upstream_oauth_authorization_session_id,\n ua.upstream_oauth_provider_id,\n ua.upstream_oauth_link_id,\n ua.state,\n ua.code_challenge_verifier,\n ua.nonce,\n ua.id_token,\n ua.created_at,\n ua.completed_at,\n ua.consumed_at,\n up.issuer AS \"provider_issuer\",\n up.scope AS \"provider_scope\",\n up.client_id AS \"provider_client_id\",\n up.encrypted_client_secret AS \"provider_encrypted_client_secret\",\n up.token_endpoint_auth_method AS \"provider_token_endpoint_auth_method\",\n up.token_endpoint_signing_alg AS \"provider_token_endpoint_signing_alg\",\n up.created_at AS \"provider_created_at\"\n FROM upstream_oauth_authorization_sessions ua\n INNER JOIN upstream_oauth_providers up\n USING (upstream_oauth_provider_id)\n WHERE upstream_oauth_authorization_session_id = $1\n " - }, - "6bf0da5ba3dd07b499193a2e0ddeea6e712f9df8f7f28874ff56a952a9f10e54": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE oauth2_access_tokens\n SET revoked_at = $2\n WHERE oauth2_access_token_id = $1\n " - }, - "7262f81a335a984c4051383d2ede7455ff65ed90fbd3151d625f8a21fd26cb05": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Timestamptz", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO user_email_confirmation_codes\n (user_email_confirmation_code_id, user_email_id, code, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n " - }, - "75a16693cabdf57012f741e789b19d0a0f96fcd1e41bb2af92f2991b722cc9f1": { - "describe": { - "columns": [ - { - "name": "oauth2_authorization_grant_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_created_at", - "ordinal": 1, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_cancelled_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_fulfilled_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_exchanged_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_scope", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_state", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_redirect_uri", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_mode", - "ordinal": 8, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_nonce", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_max_age", - "ordinal": 10, - "type_info": "Int4" - }, - { - "name": "oauth2_client_id", - "ordinal": 11, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_code", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_type_code", - "ordinal": 13, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_response_type_id_token", - "ordinal": 14, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_code_challenge", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_code_challenge_method", - "ordinal": 16, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_requires_consent", - "ordinal": 17, - "type_info": "Bool" - }, - { - "name": "oauth2_session_id?", - "ordinal": 18, - "type_info": "Uuid" - }, - { - "name": "user_session_id?", - "ordinal": 19, - "type_info": "Uuid" - }, - { - "name": "user_session_created_at?", - "ordinal": 20, - "type_info": "Timestamptz" - }, - { - "name": "user_id?", - "ordinal": 21, - "type_info": "Uuid" - }, - { - "name": "user_username?", - "ordinal": 22, - "type_info": "Text" - }, - { - "name": "user_session_last_authentication_id?", - "ordinal": 23, - "type_info": "Uuid" - }, - { - "name": "user_session_last_authentication_created_at?", - "ordinal": 24, - "type_info": "Timestamptz" - }, - { - "name": "user_email_id?", - "ordinal": 25, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 26, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 27, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 28, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - true, - true, - false, - true, - false, - false, - true, - true, - false, - true, - false, - false, - true, - true, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n og.oauth2_authorization_grant_id,\n og.created_at AS oauth2_authorization_grant_created_at,\n og.cancelled_at AS oauth2_authorization_grant_cancelled_at,\n og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at,\n og.exchanged_at AS oauth2_authorization_grant_exchanged_at,\n og.scope AS oauth2_authorization_grant_scope,\n og.state AS oauth2_authorization_grant_state,\n og.redirect_uri AS oauth2_authorization_grant_redirect_uri,\n og.response_mode AS oauth2_authorization_grant_response_mode,\n og.nonce AS oauth2_authorization_grant_nonce,\n og.max_age AS oauth2_authorization_grant_max_age,\n og.oauth2_client_id AS oauth2_client_id,\n og.authorization_code AS oauth2_authorization_grant_code,\n og.response_type_code AS oauth2_authorization_grant_response_type_code,\n og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token,\n og.code_challenge AS oauth2_authorization_grant_code_challenge,\n og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method,\n og.requires_consent AS oauth2_authorization_grant_requires_consent,\n os.oauth2_session_id AS \"oauth2_session_id?\",\n us.user_session_id AS \"user_session_id?\",\n us.created_at AS \"user_session_created_at?\",\n u.user_id AS \"user_id?\",\n u.username AS \"user_username?\",\n usa.user_session_authentication_id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n FROM\n oauth2_authorization_grants og\n LEFT JOIN oauth2_sessions os\n USING (oauth2_session_id)\n LEFT JOIN user_sessions us\n USING (user_session_id)\n LEFT JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE og.oauth2_authorization_grant_id = $1\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " - }, - "7756a60c36a64a259f7450d6eb77ee92303638ca374a63f23ac4944ccf9f4436": { - "describe": { - "columns": [ - { - "name": "oauth2_client_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "encrypted_client_secret", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "redirect_uris!", - "ordinal": 2, - "type_info": "TextArray" - }, - { - "name": "grant_type_authorization_code", - "ordinal": 3, - "type_info": "Bool" - }, - { - "name": "grant_type_refresh_token", - "ordinal": 4, - "type_info": "Bool" - }, - { - "name": "client_name", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "logo_uri", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "client_uri", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "policy_uri", - "ordinal": 8, - "type_info": "Text" - }, - { - "name": "tos_uri", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "jwks_uri", - "ordinal": 10, - "type_info": "Text" - }, - { - "name": "jwks", - "ordinal": 11, - "type_info": "Jsonb" - }, - { - "name": "id_token_signed_response_alg", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "userinfo_signed_response_alg", - "ordinal": 13, - "type_info": "Text" - }, - { - "name": "token_endpoint_auth_method", - "ordinal": 14, - "type_info": "Text" - }, - { - "name": "token_endpoint_auth_signing_alg", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "initiate_login_uri", - "ordinal": 16, - "type_info": "Text" - } - ], - "nullable": [ - false, - true, - null, - false, - false, - true, - true, - true, - true, - true, - true, - true, - true, - true, - true, - true, - true - ], - "parameters": { - "Left": [ - "UuidArray" - ] - } - }, - "query": "\n SELECT\n c.oauth2_client_id,\n c.encrypted_client_secret,\n ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\",\n c.grant_type_authorization_code,\n c.grant_type_refresh_token,\n c.client_name,\n c.logo_uri,\n c.client_uri,\n c.policy_uri,\n c.tos_uri,\n c.jwks_uri,\n c.jwks,\n c.id_token_signed_response_alg,\n c.userinfo_signed_response_alg,\n c.token_endpoint_auth_method,\n c.token_endpoint_auth_signing_alg,\n c.initiate_login_uri\n FROM oauth2_clients c\n\n WHERE c.oauth2_client_id = ANY($1::uuid[])\n " - }, - "7cf5ae665b15ba78b01bb1dfa304150a89fd7203f4ee15b0753cb2143049a3dc": { - "describe": { - "columns": [ - { - "name": "oauth2_refresh_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_refresh_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "oauth2_refresh_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_access_token_id?", - "ordinal": 3, - "type_info": "Uuid" - }, - { - "name": "oauth2_access_token?", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "oauth2_access_token_created_at?", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_access_token_expires_at?", - "ordinal": 6, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_session_id!", - "ordinal": 7, - "type_info": "Uuid" - }, - { - "name": "oauth2_client_id!", - "ordinal": 8, - "type_info": "Uuid" - }, - { - "name": "oauth2_session_scope!", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "user_session_id!", - "ordinal": 10, - "type_info": "Uuid" - }, - { - "name": "user_session_created_at!", - "ordinal": 11, - "type_info": "Timestamptz" - }, - { - "name": "user_id!", - "ordinal": 12, - "type_info": "Uuid" - }, - { - "name": "user_username!", - "ordinal": 13, - "type_info": "Text" - }, - { - "name": "user_session_last_authentication_id?", - "ordinal": 14, - "type_info": "Uuid" - }, - { - "name": "user_session_last_authentication_created_at?", - "ordinal": 15, - "type_info": "Timestamptz" - }, - { - "name": "user_email_id?", - "ordinal": 16, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 17, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 18, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 19, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT\n rt.oauth2_refresh_token_id,\n rt.refresh_token AS oauth2_refresh_token,\n rt.created_at AS oauth2_refresh_token_created_at,\n at.oauth2_access_token_id AS \"oauth2_access_token_id?\",\n at.access_token AS \"oauth2_access_token?\",\n at.created_at AS \"oauth2_access_token_created_at?\",\n at.expires_at AS \"oauth2_access_token_expires_at?\",\n os.oauth2_session_id AS \"oauth2_session_id!\",\n os.oauth2_client_id AS \"oauth2_client_id!\",\n os.scope AS \"oauth2_session_scope!\",\n us.user_session_id AS \"user_session_id!\",\n us.created_at AS \"user_session_created_at!\",\n u.user_id AS \"user_id!\",\n u.username AS \"user_username!\",\n usa.user_session_authentication_id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n LEFT JOIN oauth2_access_tokens at\n USING (oauth2_access_token_id)\n INNER JOIN user_sessions us\n USING (user_session_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE rt.refresh_token = $1\n AND rt.consumed_at IS NULL\n AND rt.revoked_at IS NULL\n AND us.finished_at IS NULL\n AND os.finished_at IS NULL\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " - }, - "7d600dd15e9dac72c8071c854799fc2ac69777ade5e2d7d2d944b0dedf8ecdf8": { - "describe": { - "columns": [ - { - "name": "user_email_confirmation_code_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "code", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "consumed_at", - "ordinal": 4, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text", - "Uuid" - ] - } - }, - "query": "\n SELECT\n ec.user_email_confirmation_code_id,\n ec.code,\n ec.created_at,\n ec.expires_at,\n ec.consumed_at\n FROM user_email_confirmation_codes ec\n WHERE ec.code = $1\n AND ec.user_email_id = $2\n " - }, - "7e3247e35ecf5335f0656c53bcde27264a9efb8dccb6246344950614f487dcaf": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE compat_access_tokens\n SET expires_at = $2\n WHERE compat_access_token_id = $1\n " - }, - "819d6472e5bcbd83a83f3a7680e8dc88e77f3970d6beddcf54e8416c880bd496": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n UPDATE users\n SET primary_user_email_id = user_emails.user_email_id\n FROM user_emails\n WHERE user_emails.user_email_id = $1\n AND users.user_id = user_emails.user_id\n " - }, - "874e677f82c221c5bb621c12f293bcef4e70c68c87ec003fcd475bcb994b5a4c": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE oauth2_refresh_tokens\n SET consumed_at = $2\n WHERE oauth2_refresh_token_id = $1\n " - }, - "89e0d338348588831a7a810763a1901073f7a7cb81d51c18bb987a5be10c1202": { - "describe": { - "columns": [ - { - "name": "count", - "ordinal": 0, - "type_info": "Int8" - } - ], - "nullable": [ - null - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT COUNT(*)\n FROM user_emails ue\n WHERE ue.user_id = $1\n " - }, - "99f5f9eb0adc5ec120ed8194cbf6a8545155bef09e6d94d92fb67fd1b14d4f28": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE compat_refresh_tokens\n SET consumed_at = $2\n WHERE compat_refresh_token_id = $1\n " - }, - "9c1ef3114bfe22884d893bb11dc6054421c28cce4bd828cfe6a4ad46c062481a": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE oauth2_sessions\n SET finished_at = $2\n WHERE oauth2_session_id = $1\n " - }, - "9edf5e8a3e00a7cdd8e55b97105df7831ee580096299df4bd6c1ed7c96b95e83": { - "describe": { - "columns": [ - { - "name": "count!", - "ordinal": 0, - "type_info": "Int8" - } - ], - "nullable": [ - null - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT COUNT(*) as \"count!\"\n FROM user_sessions s\n WHERE s.user_id = $1 AND s.finished_at IS NULL\n " - }, - "a1c19d9d7f1522d126787c7f9946ed51cbbd8f27a4947bc371acab3e7bf23267": { - "describe": { - "columns": [ - { - "name": "user_password_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "hashed_password", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "version", - "ordinal": 2, - "type_info": "Int4" - }, - { - "name": "upgraded_from_id", - "ordinal": 3, - "type_info": "Uuid" - }, - { - "name": "created_at", - "ordinal": 4, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - true, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT up.user_password_id\n , up.hashed_password\n , up.version\n , up.upgraded_from_id\n , up.created_at\n FROM user_passwords up\n WHERE up.user_id = $1\n ORDER BY up.created_at DESC\n LIMIT 1\n " - }, - "a5a7dad633396e087239d5629092e4a305908ffce9c2610db07372f719070546": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n requires_consent = 'f'\n WHERE\n og.oauth2_authorization_grant_id = $1\n " - }, - "a8117b4dd167167b477fb4ebda52789e376defbdc67f3d9093aa06308b2f856e": { - "describe": { - "columns": [ - { - "name": "compat_sso_login_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "compat_sso_login_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_sso_login_redirect_uri", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "compat_sso_login_created_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_sso_login_fulfilled_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "compat_sso_login_exchanged_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id?", - "ordinal": 6, - "type_info": "Uuid" - }, - { - "name": "compat_session_created_at?", - "ordinal": 7, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_finished_at?", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id?", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "user_id?", - "ordinal": 10, - "type_info": "Uuid" - }, - { - "name": "user_username?", - "ordinal": 11, - "type_info": "Text" - }, - { - "name": "user_email_id?", - "ordinal": 12, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 13, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 14, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 15, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - true, - true, - false, - false, - true, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT\n cl.compat_sso_login_id,\n cl.login_token AS \"compat_sso_login_token\",\n cl.redirect_uri AS \"compat_sso_login_redirect_uri\",\n cl.created_at AS \"compat_sso_login_created_at\",\n cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\",\n cl.exchanged_at AS \"compat_sso_login_exchanged_at\",\n cs.compat_session_id AS \"compat_session_id?\",\n cs.created_at AS \"compat_session_created_at?\",\n cs.finished_at AS \"compat_session_finished_at?\",\n cs.device_id AS \"compat_session_device_id?\",\n u.user_id AS \"user_id?\",\n u.username AS \"user_username?\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n FROM compat_sso_logins cl\n LEFT JOIN compat_sessions cs\n USING (compat_session_id)\n LEFT JOIN users u\n USING (user_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n WHERE cl.login_token = $1\n " - }, - "af77bad7259175464c5ad57f9662571c17b29552ebb70e4b6022584b41bdff0d": { - "describe": { - "columns": [ - { - "name": "exists!", - "ordinal": 0, - "type_info": "Bool" - } - ], - "nullable": [ - null - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT EXISTS(\n SELECT 1 FROM users WHERE username = $1\n ) AS \"exists!\"\n " - }, - "b5b955169ebe6c399e53b74627c11c8219c0736ef2b5b6b44be568a35fd5389f": { - "describe": { - "columns": [ - { - "name": "user_email_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "user_email", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "user_email_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at", - "ordinal": 3, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Uuid", - "Uuid" - ] - } - }, - "query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n AND ue.user_email_id = $2\n " - }, - "bc768c63a7737818967bc28560de714bbbd262bdf3ab73d297263bb73dcd9f5e": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Timestamptz", - "Uuid" - ] - } - }, - "query": "\n INSERT INTO oauth2_sessions\n (oauth2_session_id, user_session_id, oauth2_client_id, scope, created_at)\n SELECT\n $1,\n $2,\n og.oauth2_client_id,\n og.scope,\n $3\n FROM\n oauth2_authorization_grants og\n WHERE\n og.oauth2_authorization_grant_id = $4\n " - }, - "bd7a4a008851f3f6d7591e3463e4369cee08820af57dcd3faf95f8e9be82857d": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Int4", - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO user_passwords\n (user_password_id, user_id, hashed_password, version, upgraded_from_id, created_at)\n VALUES ($1, $2, $3, $4, $5, $6)\n " - }, - "c52c911bf39ada298bfdc4526028f1b29fdcb6f557b288bb7ea2472b160c8698": { - "describe": { - "columns": [ - { - "name": "compat_refresh_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "compat_refresh_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_refresh_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "compat_access_token_id", - "ordinal": 3, - "type_info": "Uuid" - }, - { - "name": "compat_access_token", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "compat_access_token_created_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_access_token_expires_at", - "ordinal": 6, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id", - "ordinal": 7, - "type_info": "Uuid" - }, - { - "name": "compat_session_created_at", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_finished_at", - "ordinal": 9, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id", - "ordinal": 10, - "type_info": "Text" - }, - { - "name": "user_id", - "ordinal": 11, - "type_info": "Uuid" - }, - { - "name": "user_username!", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "user_email_id?", - "ordinal": 13, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 14, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 15, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 16, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - false, - true, - false, - false, - true, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT\n cr.compat_refresh_token_id,\n cr.refresh_token AS \"compat_refresh_token\",\n cr.created_at AS \"compat_refresh_token_created_at\",\n ct.compat_access_token_id,\n ct.access_token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n ct.expires_at AS \"compat_access_token_expires_at\",\n cs.compat_session_id,\n cs.created_at AS \"compat_session_created_at\",\n cs.finished_at AS \"compat_session_finished_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.user_id,\n u.username AS \"user_username!\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM compat_refresh_tokens cr\n INNER JOIN compat_sessions cs\n USING (compat_session_id)\n INNER JOIN compat_access_tokens ct\n USING (compat_access_token_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE cr.refresh_token = $1\n AND cr.consumed_at IS NULL\n AND cs.finished_at IS NULL\n " - }, - "c88376abdba124ff0487a9a69d2345c7d69d7394f355111ec369cfa6d45fb40f": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Text", - "Text", - "Text", - "Int4", - "Text", - "Text", - "Text", - "Bool", - "Bool", - "Text", - "Bool", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO oauth2_authorization_grants (\n oauth2_authorization_grant_id,\n oauth2_client_id,\n redirect_uri,\n scope,\n state,\n nonce,\n max_age,\n response_mode,\n code_challenge,\n code_challenge_method,\n response_type_code,\n response_type_id_token,\n authorization_code,\n requires_consent,\n created_at\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)\n " - }, - "caf54e4659306a746747aa61906bdb2cb8da51176e90435aa8b9754ebf3e4d60": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)\n VALUES ($1, $2, $3, $4)\n " - }, - "cb8ba981330e58a6c8580f6e394a721df110e1f2206e080434aa821c44c0164b": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [] - } - }, - "query": "TRUNCATE oauth2_client_redirect_uris, oauth2_clients CASCADE" - }, - "cc9e30678d673546efca336ee8e550083eed71459611fa2db52264e51e175901": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Bool", - "Bool", - "Text", - "Text", - "Text", - "Text", - "Text", - "Text", - "Jsonb", - "Text", - "Text", - "Text", - "Text", - "Text" - ] - } - }, - "query": "\n INSERT INTO oauth2_clients\n (oauth2_client_id,\n encrypted_client_secret,\n grant_type_authorization_code,\n grant_type_refresh_token,\n client_name,\n logo_uri,\n client_uri,\n policy_uri,\n tos_uri,\n jwks_uri,\n jwks,\n id_token_signed_response_alg,\n userinfo_signed_response_alg,\n token_endpoint_auth_method,\n token_endpoint_auth_signing_alg,\n initiate_login_uri)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)\n " - }, - "cf00e0ad529bcb5c0640adcfe0880a3560d9739f355b90ca3ba88dd1eaf26565": { - "describe": { - "columns": [ - { - "name": "upstream_oauth_provider_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "issuer", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "scope", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "client_id", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "encrypted_client_secret", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "token_endpoint_signing_alg", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "token_endpoint_auth_method", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 7, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - true, - true, - false, - false - ], - "parameters": { - "Left": [] - } - }, - "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n " - }, - "d1738c27339b81f0844da4bd9b040b9b07a91aa4d9b199b98f24c9cee5709b2b": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO compat_sso_logins\n (compat_sso_login_id, login_token, redirect_uri, created_at)\n VALUES ($1, $2, $3, $4)\n " - }, - "d55a321e8935f4effda29d9620a0f622125cb38472785049ee21c2616a6bd068": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE user_email_confirmation_codes\n SET consumed_at = $2\n WHERE user_email_confirmation_code_id = $1\n " - }, - "d8677b3b6ee594c230fad98c1aa1c6e3d983375bf5b701c7b52468e7f906abf9": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Uuid", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO oauth2_refresh_tokens\n (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,\n refresh_token, created_at)\n VALUES\n ($1, $2, $3, $4, $5)\n " - }, - "e16ac9f75be25ef6873f1851e916df3ea730422409decc0344f7f05ce3c3841f": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n DELETE FROM user_emails\n WHERE user_emails.user_email_id = $1\n " - }, - "e1dc9dd2bf26a341050a53151bf51f7638448ccc2bd458bbdfe87cc22f086313": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO upstream_oauth_links (\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n ) VALUES ($1, $2, NULL, $3, $4)\n " - }, - "e30562e9637d3a723a91adca6336a8d083657ce6d7fe9551fcd6a9d672453d3c": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Timestamptz", - "Uuid" - ] - } - }, - "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET consumed_at = $1\n WHERE upstream_oauth_authorization_session_id = $2\n " - }, - "e446e37d48c8838ef2e0d0fd82f8f7b04893c84ad46747cdf193ebd83755ceb2": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO user_sessions (user_session_id, user_id, created_at)\n VALUES ($1, $2, $3)\n " - }, - "f71cb5761bfc15d8bc3ba7ee49b63fb3c3ea9691745688eb5fd91f4f6e1ec018": { - "describe": { - "columns": [ - { - "name": "upstream_oauth_link_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "upstream_oauth_provider_id", - "ordinal": 1, - "type_info": "Uuid" - }, - { - "name": "user_id", - "ordinal": 2, - "type_info": "Uuid" - }, - { - "name": "subject", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 4, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - false, - false - ], - "parameters": { - "Left": [ - "Uuid", - "Text" - ] - } - }, - "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " - }, - "fb71ac6539039313fd90b29ac943330e54c7b62b2778727726e2f60a554f9c5a": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Text", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)\n " - } -} \ No newline at end of file diff --git a/crates/storage/src/clock.rs b/crates/storage/src/clock.rs new file mode 100644 index 000000000..04c69f25a --- /dev/null +++ b/crates/storage/src/clock.rs @@ -0,0 +1,135 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! A [`Clock`] is a way to get the current date and time. +//! +//! This module defines two implemetation of the [`Clock`] trait: +//! [`SystemClock`] which uses the system time, and a [`MockClock`], which can +//! be used and freely manipulated in tests. + +use std::sync::atomic::AtomicI64; + +use chrono::{DateTime, TimeZone, Utc}; + +/// Represents a clock which can give the current date and time +pub trait Clock: Sync { + /// Get the current date and time + fn now(&self) -> DateTime; +} + +impl Clock for Box { + fn now(&self) -> DateTime { + (**self).now() + } +} + +/// A clock which uses the system time +#[derive(Clone, Default)] +pub struct SystemClock { + _private: (), +} + +impl Clock for SystemClock { + fn now(&self) -> DateTime { + // This is the clock used elsewhere, it's fine to call Utc::now here + #[allow(clippy::disallowed_methods)] + Utc::now() + } +} + +/// A fake clock, which uses a fixed timestamp, and can be advanced with the +/// [`MockClock::advance`] method. +/// +/// ```rust +/// use mas_storage::clock::{Clock, MockClock}; +/// use chrono::Duration; +/// +/// let clock = MockClock::default(); +/// let t1 = clock.now(); +/// let t2 = clock.now(); +/// assert_eq!(t1, t2); +/// +/// clock.advance(Duration::seconds(10)); +/// let t3 = clock.now(); +/// assert_eq!(t2 + Duration::seconds(10), t3); +/// ``` +pub struct MockClock { + timestamp: AtomicI64, +} + +impl Default for MockClock { + fn default() -> Self { + let datetime = Utc.with_ymd_and_hms(2022, 1, 16, 14, 40, 0).unwrap(); + Self::new(datetime) + } +} + +impl MockClock { + /// Create a new clock which starts at the given datetime + #[must_use] + pub fn new(datetime: DateTime) -> Self { + let timestamp = AtomicI64::new(datetime.timestamp()); + Self { timestamp } + } + + /// Move the clock forward by the given amount of time + pub fn advance(&self, duration: chrono::Duration) { + self.timestamp + .fetch_add(duration.num_seconds(), std::sync::atomic::Ordering::Relaxed); + } +} + +impl Clock for MockClock { + fn now(&self) -> DateTime { + let timestamp = self.timestamp.load(std::sync::atomic::Ordering::Relaxed); + chrono::TimeZone::timestamp_opt(&Utc, timestamp, 0).unwrap() + } +} + +#[cfg(test)] +mod tests { + use chrono::Duration; + + use super::*; + + #[test] + fn test_mocked_clock() { + let clock = MockClock::default(); + + // Time should be frozen, and give out the same timestamp on each call + let first = clock.now(); + std::thread::sleep(std::time::Duration::from_millis(10)); + let second = clock.now(); + + assert_eq!(first, second); + + // Clock can be advanced by a fixed duration + clock.advance(Duration::seconds(10)); + let third = clock.now(); + assert_eq!(first + Duration::seconds(10), third); + } + + #[test] + fn test_real_clock() { + let clock = SystemClock::default(); + + // Time should not be frozen + let first = clock.now(); + std::thread::sleep(std::time::Duration::from_millis(10)); + let second = clock.now(); + + assert_ne!(first, second); + assert!(first < second); + } +} diff --git a/crates/storage/src/compat.rs b/crates/storage/src/compat.rs deleted file mode 100644 index 9737fcb9c..000000000 --- a/crates/storage/src/compat.rs +++ /dev/null @@ -1,948 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use chrono::{DateTime, Duration, Utc}; -use mas_data_model::{ - CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState, - Device, User, UserEmail, -}; -use rand::Rng; -use sqlx::{Acquire, PgExecutor, Postgres, QueryBuilder}; -use tracing::{info_span, Instrument}; -use ulid::Ulid; -use url::Url; -use uuid::Uuid; - -use crate::{ - pagination::{process_page, QueryBuilderExt}, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, -}; - -struct CompatAccessTokenLookup { - compat_access_token_id: Uuid, - compat_access_token: String, - compat_access_token_created_at: DateTime, - compat_access_token_expires_at: Option>, - compat_session_id: Uuid, - compat_session_created_at: DateTime, - compat_session_finished_at: Option>, - compat_session_device_id: String, - user_id: Uuid, - user_username: String, - user_email_id: Option, - user_email: Option, - user_email_created_at: Option>, - user_email_confirmed_at: Option>, -} - -#[tracing::instrument(skip_all, err)] -pub async fn lookup_active_compat_access_token( - executor: impl PgExecutor<'_>, - clock: &Clock, - token: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - CompatAccessTokenLookup, - r#" - SELECT - ct.compat_access_token_id, - ct.access_token AS "compat_access_token", - ct.created_at AS "compat_access_token_created_at", - ct.expires_at AS "compat_access_token_expires_at", - cs.compat_session_id, - cs.created_at AS "compat_session_created_at", - cs.finished_at AS "compat_session_finished_at", - cs.device_id AS "compat_session_device_id", - u.user_id AS "user_id!", - u.username AS "user_username!", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" - - FROM compat_access_tokens ct - INNER JOIN compat_sessions cs - USING (compat_session_id) - INNER JOIN users u - USING (user_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id - - WHERE ct.access_token = $1 - AND (ct.expires_at < $2 OR ct.expires_at IS NULL) - AND cs.finished_at IS NULL - "#, - token, - clock.now(), - ) - .fetch_one(executor) - .instrument(info_span!("Fetch compat access token")) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - let token = CompatAccessToken { - id: res.compat_access_token_id.into(), - token: res.compat_access_token, - created_at: res.compat_access_token_created_at, - expires_at: res.compat_access_token_expires_at, - }; - - let user_id = Ulid::from(res.user_id); - let primary_email = match ( - res.user_email_id, - res.user_email, - res.user_email_created_at, - res.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("compat_sessions") - .column("user_id") - .row(user_id) - .into()) - } - }; - - let user = User { - id: user_id, - username: res.user_username, - sub: user_id.to_string(), - primary_email, - }; - - let id = res.compat_session_id.into(); - let device = Device::try_from(res.compat_session_device_id).map_err(|e| { - DatabaseInconsistencyError::on("compat_sessions") - .column("device_id") - .row(id) - .source(e) - })?; - - let session = CompatSession { - id, - user, - device, - created_at: res.compat_session_created_at, - finished_at: res.compat_session_finished_at, - }; - - Ok(Some((token, session))) -} - -pub struct CompatRefreshTokenLookup { - compat_refresh_token_id: Uuid, - compat_refresh_token: String, - compat_refresh_token_created_at: DateTime, - compat_access_token_id: Uuid, - compat_access_token: String, - compat_access_token_created_at: DateTime, - compat_access_token_expires_at: Option>, - compat_session_id: Uuid, - compat_session_created_at: DateTime, - compat_session_finished_at: Option>, - compat_session_device_id: String, - user_id: Uuid, - user_username: String, - user_email_id: Option, - user_email: Option, - user_email_created_at: Option>, - user_email_confirmed_at: Option>, -} - -#[tracing::instrument(skip_all, err)] -#[allow(clippy::type_complexity)] -pub async fn lookup_active_compat_refresh_token( - executor: impl PgExecutor<'_>, - token: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - CompatRefreshTokenLookup, - r#" - SELECT - cr.compat_refresh_token_id, - cr.refresh_token AS "compat_refresh_token", - cr.created_at AS "compat_refresh_token_created_at", - ct.compat_access_token_id, - ct.access_token AS "compat_access_token", - ct.created_at AS "compat_access_token_created_at", - ct.expires_at AS "compat_access_token_expires_at", - cs.compat_session_id, - cs.created_at AS "compat_session_created_at", - cs.finished_at AS "compat_session_finished_at", - cs.device_id AS "compat_session_device_id", - u.user_id, - u.username AS "user_username!", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" - - FROM compat_refresh_tokens cr - INNER JOIN compat_sessions cs - USING (compat_session_id) - INNER JOIN compat_access_tokens ct - USING (compat_access_token_id) - INNER JOIN users u - USING (user_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id - - WHERE cr.refresh_token = $1 - AND cr.consumed_at IS NULL - AND cs.finished_at IS NULL - "#, - token, - ) - .fetch_one(executor) - .instrument(info_span!("Fetch compat refresh token")) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None); }; - - let refresh_token = CompatRefreshToken { - id: res.compat_refresh_token_id.into(), - token: res.compat_refresh_token, - created_at: res.compat_refresh_token_created_at, - }; - - let access_token = CompatAccessToken { - id: res.compat_access_token_id.into(), - token: res.compat_access_token, - created_at: res.compat_access_token_created_at, - expires_at: res.compat_access_token_expires_at, - }; - - let user_id = Ulid::from(res.user_id); - let primary_email = match ( - res.user_email_id, - res.user_email, - res.user_email_created_at, - res.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("users") - .column("primary_user_email_id") - .row(user_id) - .into()) - } - }; - - let user = User { - id: user_id, - username: res.user_username, - sub: user_id.to_string(), - primary_email, - }; - - let session_id = res.compat_session_id.into(); - let device = Device::try_from(res.compat_session_device_id).map_err(|e| { - DatabaseInconsistencyError::on("compat_sessions") - .column("device_id") - .row(session_id) - .source(e) - })?; - - let session = CompatSession { - id: session_id, - user, - device, - created_at: res.compat_session_created_at, - finished_at: res.compat_session_finished_at, - }; - - Ok(Some((refresh_token, access_token, session))) -} - -#[tracing::instrument( - skip_all, - fields( - compat_session.id = %session.id, - compat_session.device.id = session.device.as_str(), - compat_access_token.id, - user.id = %session.user.id, - ), - err, -)] -pub async fn add_compat_access_token( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - session: &CompatSession, - token: String, - expires_after: Option, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("compat_access_token.id", tracing::field::display(id)); - - let expires_at = expires_after.map(|expires_after| created_at + expires_after); - - sqlx::query!( - r#" - INSERT INTO compat_access_tokens - (compat_access_token_id, compat_session_id, access_token, created_at, expires_at) - VALUES ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(session.id), - token, - created_at, - expires_at, - ) - .execute(executor) - .instrument(tracing::info_span!("Insert compat access token")) - .await?; - - Ok(CompatAccessToken { - id, - token, - created_at, - expires_at, - }) -} - -#[tracing::instrument( - skip_all, - fields( - compat_access_token.id = %access_token.id, - ), - err, -)] -pub async fn expire_compat_access_token( - executor: impl PgExecutor<'_>, - clock: &Clock, - access_token: CompatAccessToken, -) -> Result<(), DatabaseError> { - let expires_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE compat_access_tokens - SET expires_at = $2 - WHERE compat_access_token_id = $1 - "#, - Uuid::from(access_token.id), - expires_at, - ) - .execute(executor) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1) -} - -#[tracing::instrument( - skip_all, - fields( - compat_session.id = %session.id, - compat_session.device.id = session.device.as_str(), - compat_access_token.id = %access_token.id, - compat_refresh_token.id, - user.id = %session.user.id, - ), - err, -)] -pub async fn add_compat_refresh_token( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - session: &CompatSession, - access_token: &CompatAccessToken, - token: String, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO compat_refresh_tokens - (compat_refresh_token_id, compat_session_id, - compat_access_token_id, refresh_token, created_at) - VALUES ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(session.id), - Uuid::from(access_token.id), - token, - created_at, - ) - .execute(executor) - .instrument(tracing::info_span!("Insert compat refresh token")) - .await?; - - Ok(CompatRefreshToken { - id, - token, - created_at, - }) -} - -#[tracing::instrument( - skip_all, - fields(compat_session.id), - err, -)] -pub async fn compat_logout( - executor: impl PgExecutor<'_>, - clock: &Clock, - token: &str, -) -> Result { - let finished_at = clock.now(); - // TODO: this does not check for token expiration - let res = sqlx::query_scalar!( - r#" - UPDATE compat_sessions cs - SET finished_at = $2 - FROM compat_access_tokens ca - WHERE ca.access_token = $1 - AND ca.compat_session_id = cs.compat_session_id - AND cs.finished_at IS NULL - RETURNING cs.compat_session_id - "#, - token, - finished_at, - ) - .fetch_one(executor) - .await - .to_option()?; - - if let Some(compat_session_id) = res { - tracing::Span::current().record( - "compat_session.id", - tracing::field::display(compat_session_id), - ); - Ok(true) - } else { - Ok(false) - } -} - -#[tracing::instrument( - skip_all, - fields( - compat_refresh_token.id = %refresh_token.id, - ), - err, -)] -pub async fn consume_compat_refresh_token( - executor: impl PgExecutor<'_>, - clock: &Clock, - refresh_token: CompatRefreshToken, -) -> Result<(), DatabaseError> { - let consumed_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE compat_refresh_tokens - SET consumed_at = $2 - WHERE compat_refresh_token_id = $1 - "#, - Uuid::from(refresh_token.id), - consumed_at, - ) - .execute(executor) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1) -} - -#[tracing::instrument( - skip_all, - fields( - compat_sso_login.id, - compat_sso_login.redirect_uri = %redirect_uri, - ), - err, -)] -pub async fn insert_compat_sso_login( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - login_token: String, - redirect_uri: Url, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO compat_sso_logins - (compat_sso_login_id, login_token, redirect_uri, created_at) - VALUES ($1, $2, $3, $4) - "#, - Uuid::from(id), - &login_token, - redirect_uri.as_str(), - created_at, - ) - .execute(executor) - .instrument(tracing::info_span!("Insert compat SSO login")) - .await?; - - Ok(CompatSsoLogin { - id, - login_token, - redirect_uri, - created_at, - state: CompatSsoLoginState::Pending, - }) -} - -#[derive(sqlx::FromRow)] -struct CompatSsoLoginLookup { - compat_sso_login_id: Uuid, - compat_sso_login_token: String, - compat_sso_login_redirect_uri: String, - compat_sso_login_created_at: DateTime, - compat_sso_login_fulfilled_at: Option>, - compat_sso_login_exchanged_at: Option>, - compat_session_id: Option, - compat_session_created_at: Option>, - compat_session_finished_at: Option>, - compat_session_device_id: Option, - user_id: Option, - user_username: Option, - user_email_id: Option, - user_email: Option, - user_email_created_at: Option>, - user_email_confirmed_at: Option>, -} - -impl TryFrom for CompatSsoLogin { - type Error = DatabaseInconsistencyError; - - fn try_from(res: CompatSsoLoginLookup) -> Result { - let id = res.compat_sso_login_id.into(); - let redirect_uri = Url::parse(&res.compat_sso_login_redirect_uri).map_err(|e| { - DatabaseInconsistencyError::on("compat_sso_logins") - .column("redirect_uri") - .row(id) - .source(e) - })?; - - let primary_email = match ( - res.user_email_id, - res.user_email, - res.user_email_created_at, - res.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("users").column("primary_user_email_id")) - } - }; - - let user = match (res.user_id, res.user_username, primary_email) { - (Some(id), Some(username), primary_email) => { - let id = Ulid::from(id); - Some(User { - id, - username, - sub: id.to_string(), - primary_email, - }) - } - - (None, None, None) => None, - _ => return Err(DatabaseInconsistencyError::on("compat_sessions").column("user_id")), - }; - - let session = match ( - res.compat_session_id, - res.compat_session_device_id, - res.compat_session_created_at, - res.compat_session_finished_at, - user, - ) { - (Some(id), Some(device_id), Some(created_at), finished_at, Some(user)) => { - let id = id.into(); - let device = Device::try_from(device_id).map_err(|e| { - DatabaseInconsistencyError::on("compat_sessions") - .column("device") - .row(id) - .source(e) - })?; - Some(CompatSession { - id, - user, - device, - created_at, - finished_at, - }) - } - (None, None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("compat_sso_logins") - .column("compat_session_id") - .row(id)) - } - }; - - let state = match ( - res.compat_sso_login_fulfilled_at, - res.compat_sso_login_exchanged_at, - session, - ) { - (None, None, None) => CompatSsoLoginState::Pending, - (Some(fulfilled_at), None, Some(session)) => CompatSsoLoginState::Fulfilled { - fulfilled_at, - session, - }, - (Some(fulfilled_at), Some(exchanged_at), Some(session)) => { - CompatSsoLoginState::Exchanged { - fulfilled_at, - exchanged_at, - session, - } - } - _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)), - }; - - Ok(CompatSsoLogin { - id, - login_token: res.compat_sso_login_token, - redirect_uri, - created_at: res.compat_sso_login_created_at, - state, - }) - } -} - -#[tracing::instrument( - skip_all, - fields( - compat_sso_login.id = %id, - ), - err, -)] -pub async fn get_compat_sso_login_by_id( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - CompatSsoLoginLookup, - r#" - SELECT - cl.compat_sso_login_id, - cl.login_token AS "compat_sso_login_token", - cl.redirect_uri AS "compat_sso_login_redirect_uri", - cl.created_at AS "compat_sso_login_created_at", - cl.fulfilled_at AS "compat_sso_login_fulfilled_at", - cl.exchanged_at AS "compat_sso_login_exchanged_at", - cs.compat_session_id AS "compat_session_id?", - cs.created_at AS "compat_session_created_at?", - cs.finished_at AS "compat_session_finished_at?", - cs.device_id AS "compat_session_device_id?", - u.user_id AS "user_id?", - u.username AS "user_username?", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" - FROM compat_sso_logins cl - LEFT JOIN compat_sessions cs - USING (compat_session_id) - LEFT JOIN users u - USING (user_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id - WHERE cl.compat_sso_login_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .instrument(tracing::info_span!("Lookup compat SSO login")) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %user.username, - ), - err, -)] -pub async fn get_paginated_user_compat_sso_logins( - executor: impl PgExecutor<'_>, - user: &User, - before: Option, - after: Option, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), DatabaseError> { - // TODO: this queries too much (like user info) which we probably don't need - // because we already have them - let mut query = QueryBuilder::new( - r#" - SELECT - cl.compat_sso_login_id, - cl.login_token AS "compat_sso_login_token", - cl.redirect_uri AS "compat_sso_login_redirect_uri", - cl.created_at AS "compat_sso_login_created_at", - cl.fulfilled_at AS "compat_sso_login_fulfilled_at", - cl.exchanged_at AS "compat_sso_login_exchanged_at", - cs.compat_session_id AS "compat_session_id", - cs.created_at AS "compat_session_created_at", - cs.finished_at AS "compat_session_finished_at", - cs.device_id AS "compat_session_device_id", - u.user_id AS "user_id", - u.username AS "user_username", - ue.user_email_id AS "user_email_id", - ue.email AS "user_email", - ue.created_at AS "user_email_created_at", - ue.confirmed_at AS "user_email_confirmed_at" - FROM compat_sso_logins cl - LEFT JOIN compat_sessions cs - USING (compat_session_id) - LEFT JOIN users u - USING (user_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id - "#, - ); - - query - .push(" WHERE cs.user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("cl.compat_sso_login_id", before, after, first, last)?; - - let span = info_span!( - "Fetch paginated user compat SSO logins", - db.statement = query.sql() - ); - let page: Vec = query - .build_query_as() - .fetch_all(executor) - .instrument(span) - .await?; - - let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; - - let page: Result, _> = page.into_iter().map(TryInto::try_into).collect(); - Ok((has_previous_page, has_next_page, page?)) -} - -#[tracing::instrument(skip_all, err)] -pub async fn get_compat_sso_login_by_token( - executor: impl PgExecutor<'_>, - token: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - CompatSsoLoginLookup, - r#" - SELECT - cl.compat_sso_login_id, - cl.login_token AS "compat_sso_login_token", - cl.redirect_uri AS "compat_sso_login_redirect_uri", - cl.created_at AS "compat_sso_login_created_at", - cl.fulfilled_at AS "compat_sso_login_fulfilled_at", - cl.exchanged_at AS "compat_sso_login_exchanged_at", - cs.compat_session_id AS "compat_session_id?", - cs.created_at AS "compat_session_created_at?", - cs.finished_at AS "compat_session_finished_at?", - cs.device_id AS "compat_session_device_id?", - u.user_id AS "user_id?", - u.username AS "user_username?", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" - FROM compat_sso_logins cl - LEFT JOIN compat_sessions cs - USING (compat_session_id) - LEFT JOIN users u - USING (user_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id - WHERE cl.login_token = $1 - "#, - token, - ) - .fetch_one(executor) - .instrument(tracing::info_span!("Lookup compat SSO login")) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - compat_session.id, - compat_session.device.id = device.as_str(), - ), - err, -)] -pub async fn start_compat_session( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - user: User, - device: Device, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("compat_session.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at) - VALUES ($1, $2, $3, $4) - "#, - Uuid::from(id), - Uuid::from(user.id), - device.as_str(), - created_at, - ) - .execute(executor) - .await?; - - Ok(CompatSession { - id, - user, - device, - created_at, - finished_at: None, - }) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %compat_sso_login.id, - %compat_sso_login.redirect_uri, - compat_session.id, - compat_session.device.id = device.as_str(), - ), - err, -)] -pub async fn fullfill_compat_sso_login( - conn: impl Acquire<'_, Database = Postgres> + Send, - mut rng: impl Rng + Send, - clock: &Clock, - user: User, - mut compat_sso_login: CompatSsoLogin, - device: Device, -) -> Result { - if !matches!(compat_sso_login.state, CompatSsoLoginState::Pending) { - return Err(DatabaseError::invalid_operation()); - }; - - let mut txn = conn.begin().await?; - - let session = start_compat_session(&mut txn, &mut rng, clock, user, device).await?; - - let fulfilled_at = clock.now(); - sqlx::query!( - r#" - UPDATE compat_sso_logins - SET - compat_session_id = $2, - fulfilled_at = $3 - WHERE - compat_sso_login_id = $1 - "#, - Uuid::from(compat_sso_login.id), - Uuid::from(session.id), - fulfilled_at, - ) - .execute(&mut txn) - .instrument(tracing::info_span!("Update compat SSO login")) - .await?; - - let state = CompatSsoLoginState::Fulfilled { - fulfilled_at, - session, - }; - - compat_sso_login.state = state; - - txn.commit().await?; - - Ok(compat_sso_login) -} - -#[tracing::instrument( - skip_all, - fields( - %compat_sso_login.id, - %compat_sso_login.redirect_uri, - ), - err, -)] -pub async fn mark_compat_sso_login_as_exchanged( - executor: impl PgExecutor<'_>, - clock: &Clock, - mut compat_sso_login: CompatSsoLogin, -) -> Result { - let CompatSsoLoginState::Fulfilled { fulfilled_at, session } = compat_sso_login.state else { - return Err(DatabaseError::invalid_operation()); - }; - - let exchanged_at = clock.now(); - sqlx::query!( - r#" - UPDATE compat_sso_logins - SET - exchanged_at = $2 - WHERE - compat_sso_login_id = $1 - "#, - Uuid::from(compat_sso_login.id), - exchanged_at, - ) - .execute(executor) - .instrument(tracing::info_span!("Update compat SSO login")) - .await?; - - let state = CompatSsoLoginState::Exchanged { - fulfilled_at, - exchanged_at, - session, - }; - compat_sso_login.state = state; - Ok(compat_sso_login) -} diff --git a/crates/storage/src/compat/access_token.rs b/crates/storage/src/compat/access_token.rs new file mode 100644 index 000000000..c6d3979ee --- /dev/null +++ b/crates/storage/src/compat/access_token.rs @@ -0,0 +1,121 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::Duration; +use mas_data_model::{CompatAccessToken, CompatSession}; +use rand_core::RngCore; +use ulid::Ulid; + +use crate::{repository_impl, Clock}; + +/// A [`CompatAccessTokenRepository`] helps interacting with +/// [`CompatAccessToken`] saved in the storage backend +#[async_trait] +pub trait CompatAccessTokenRepository: Send + Sync { + /// The error type returned by the repository + type Error; + + /// Lookup a compat access token by its ID + /// + /// Returns the compat access token if it exists, `None` otherwise + /// + /// # Parameters + /// + /// * `id`: The ID of the compat access token to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Find a compat access token by its token + /// + /// Returns the compat access token if found, `None` otherwise + /// + /// # Parameters + /// + /// * `access_token`: The token of the compat access token to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error>; + + /// Add a new compat access token to the database + /// + /// Returns the newly created compat access token + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `compat_session`: The compat session associated with the access token + /// * `token`: The token of the access token + /// * `expires_after`: The duration after which the access token expires, if + /// specified + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + compat_session: &CompatSession, + token: String, + expires_after: Option, + ) -> Result; + + /// Set the expiration time of the compat access token to now + /// + /// Returns the expired compat access token + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `compat_access_token`: The compat access token to expire + async fn expire( + &mut self, + clock: &dyn Clock, + compat_access_token: CompatAccessToken, + ) -> Result; +} + +repository_impl!(CompatAccessTokenRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + compat_session: &CompatSession, + token: String, + expires_after: Option, + ) -> Result; + + async fn expire( + &mut self, + clock: &dyn Clock, + compat_access_token: CompatAccessToken, + ) -> Result; +); diff --git a/crates/storage/src/compat/mod.rs b/crates/storage/src/compat/mod.rs new file mode 100644 index 000000000..eb971edd1 --- /dev/null +++ b/crates/storage/src/compat/mod.rs @@ -0,0 +1,25 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Repositories to interact with entities of the compatibility layer + +mod access_token; +mod refresh_token; +mod session; +mod sso_login; + +pub use self::{ + access_token::CompatAccessTokenRepository, refresh_token::CompatRefreshTokenRepository, + session::CompatSessionRepository, sso_login::CompatSsoLoginRepository, +}; diff --git a/crates/storage/src/compat/refresh_token.rs b/crates/storage/src/compat/refresh_token.rs new file mode 100644 index 000000000..c9b3aabe4 --- /dev/null +++ b/crates/storage/src/compat/refresh_token.rs @@ -0,0 +1,121 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use mas_data_model::{CompatAccessToken, CompatRefreshToken, CompatSession}; +use rand_core::RngCore; +use ulid::Ulid; + +use crate::{repository_impl, Clock}; + +/// A [`CompatRefreshTokenRepository`] helps interacting with +/// [`CompatRefreshToken`] saved in the storage backend +#[async_trait] +pub trait CompatRefreshTokenRepository: Send + Sync { + /// The error type returned by the repository + type Error; + + /// Lookup a compat refresh token by its ID + /// + /// Returns the compat refresh token if it exists, `None` otherwise + /// + /// # Parameters + /// + /// * `id`: The ID of the compat refresh token to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Find a compat refresh token by its token + /// + /// Returns the compat refresh token if found, `None` otherwise + /// + /// # Parameters + /// + /// * `refresh_token`: The token of the compat refresh token to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error>; + + /// Add a new compat refresh token to the database + /// + /// Returns the newly created compat refresh token + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `compat_session`: The compat session associated with this refresh + /// token + /// * `compat_access_token`: The compat access token created alongside this + /// refresh token + /// * `token`: The token of the refresh token + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + compat_session: &CompatSession, + compat_access_token: &CompatAccessToken, + token: String, + ) -> Result; + + /// Consume a compat refresh token + /// + /// Returns the consumed compat refresh token + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `compat_refresh_token`: The compat refresh token to consume + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn consume( + &mut self, + clock: &dyn Clock, + compat_refresh_token: CompatRefreshToken, + ) -> Result; +} + +repository_impl!(CompatRefreshTokenRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + compat_session: &CompatSession, + compat_access_token: &CompatAccessToken, + token: String, + ) -> Result; + + async fn consume( + &mut self, + clock: &dyn Clock, + compat_refresh_token: CompatRefreshToken, + ) -> Result; +); diff --git a/crates/storage/src/compat/session.rs b/crates/storage/src/compat/session.rs new file mode 100644 index 000000000..fb9dea73c --- /dev/null +++ b/crates/storage/src/compat/session.rs @@ -0,0 +1,99 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use mas_data_model::{CompatSession, Device, User}; +use rand_core::RngCore; +use ulid::Ulid; + +use crate::{repository_impl, Clock}; + +/// A [`CompatSessionRepository`] helps interacting with +/// [`CompatSessionRepository`] saved in the storage backend +#[async_trait] +pub trait CompatSessionRepository: Send + Sync { + /// The error type returned by the repository + type Error; + + /// Lookup a compat session by its ID + /// + /// Returns the compat session if it exists, `None` otherwise + /// + /// # Parameters + /// + /// * `id`: The ID of the compat session to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Start a new compat session + /// + /// Returns the newly created compat session + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `user`: The user to create the compat session for + /// * `device`: The device ID of this session + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + device: Device, + ) -> Result; + + /// End a compat session + /// + /// Returns the ended compat session + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `compat_session`: The compat session to end + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn finish( + &mut self, + clock: &dyn Clock, + compat_session: CompatSession, + ) -> Result; +} + +repository_impl!(CompatSessionRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + device: Device, + ) -> Result; + + async fn finish( + &mut self, + clock: &dyn Clock, + compat_session: CompatSession, + ) -> Result; +); diff --git a/crates/storage/src/compat/sso_login.rs b/crates/storage/src/compat/sso_login.rs new file mode 100644 index 000000000..7c823d620 --- /dev/null +++ b/crates/storage/src/compat/sso_login.rs @@ -0,0 +1,171 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use mas_data_model::{CompatSession, CompatSsoLogin, User}; +use rand_core::RngCore; +use ulid::Ulid; +use url::Url; + +use crate::{pagination::Page, repository_impl, Clock, Pagination}; + +/// A [`CompatSsoLoginRepository`] helps interacting with +/// [`CompatSsoLoginRepository`] saved in the storage backend +#[async_trait] +pub trait CompatSsoLoginRepository: Send + Sync { + /// The error type returned by the repository + type Error; + + /// Lookup a compat SSO login by its ID + /// + /// Returns the compat SSO login if it exists, `None` otherwise + /// + /// # Parameters + /// + /// * `id`: The ID of the compat SSO login to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Find a compat SSO login by its login token + /// + /// Returns the compat SSO login if found, `None` otherwise + /// + /// # Parameters + /// + /// * `login_token`: The login token of the compat SSO login to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn find_by_token( + &mut self, + login_token: &str, + ) -> Result, Self::Error>; + + /// Start a new compat SSO login token + /// + /// Returns the newly created compat SSO login + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate the timestamps + /// * `login_token`: The login token given to the client + /// * `redirect_uri`: The redirect URI given by the client + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + login_token: String, + redirect_uri: Url, + ) -> Result; + + /// Fulfill a compat SSO login by providing a compat session + /// + /// Returns the fulfilled compat SSO login + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate the timestamps + /// * `compat_sso_login`: The compat SSO login to fulfill + /// * `compat_session`: The compat session to associate with the compat SSO + /// login + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn fulfill( + &mut self, + clock: &dyn Clock, + compat_sso_login: CompatSsoLogin, + compat_session: &CompatSession, + ) -> Result; + + /// Mark a compat SSO login as exchanged + /// + /// Returns the exchanged compat SSO login + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate the timestamps + /// * `compat_sso_login`: The compat SSO login to mark as exchanged + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn exchange( + &mut self, + clock: &dyn Clock, + compat_sso_login: CompatSsoLogin, + ) -> Result; + + /// Get a paginated list of compat SSO logins for a user + /// + /// # Parameters + /// + /// * `user`: The user to get the compat SSO logins for + /// * `pagination`: The pagination parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; +} + +repository_impl!(CompatSsoLoginRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_token( + &mut self, + login_token: &str, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + login_token: String, + redirect_uri: Url, + ) -> Result; + + async fn fulfill( + &mut self, + clock: &dyn Clock, + compat_sso_login: CompatSsoLogin, + compat_session: &CompatSession, + ) -> Result; + + async fn exchange( + &mut self, + clock: &dyn Clock, + compat_sso_login: CompatSsoLogin, + ) -> Result; + + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; +); diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index f059e376c..0e8458b7d 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,174 +12,153 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Interactions with the database +//! Interactions with the storage backend +//! +//! This crate provides a set of traits that can be implemented to interact with +//! the storage backend. Those traits are called repositories and are grouped by +//! the type of data they manage. +//! +//! Each of those reposotories can be accessed via the [`RepositoryAccess`] +//! trait. This trait can be wrapped in a [`BoxRepository`] to allow using it +//! without caring about the underlying storage backend, and without carrying +//! around the generic type parameter. +//! +//! This crate also defines a [`Clock`] trait that can be used to abstract the +//! way the current time is retrieved. It has two implementation: +//! [`SystemClock`] that uses the system time and [`MockClock`] which is useful +//! for testing. +//! +//! [`MockClock`]: crate::clock::MockClock +//! +//! # Defining a new repository +//! +//! To define a new repository, you have to: +//! 1. Define a new (async) repository trait, with the methods you need +//! 2. Write an implementation of this trait for each storage backend you want +//! (currently only for [`mas-storage-pg`]) +//! 3. Make it accessible via the [`RepositoryAccess`] trait +//! +//! The repository trait definition should look like this: +//! +//! ```rust +//! # use async_trait::async_trait; +//! # use ulid::Ulid; +//! # use rand_core::RngCore; +//! # use mas_storage::Clock; +//! # +//! # // A fake data structure, usually defined in mas-data-model +//! # struct FakeData { +//! # id: Ulid, +//! # } +//! # +//! # // A fake empty macro, to replace `mas_storage::repository_impl` +//! # macro_rules! repository_impl { ($($tok:tt)*) => {} } +//! +//! #[async_trait] +//! pub trait FakeDataRepository: Send + Sync { +//! /// The error type returned by the repository +//! type Error; +//! +//! /// Lookup a [`FakeData`] by its ID +//! /// +//! /// Returns `None` if no [`FakeData`] was found +//! /// +//! /// # Parameters +//! /// +//! /// * `id`: The ID of the [`FakeData`] to lookup +//! /// +//! /// # Errors +//! /// +//! /// Returns [`Self::Error`] if the underlying repository fails +//! async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; +//! +//! /// Create a new [`FakeData`] +//! /// +//! /// Returns the newly-created [`FakeData`]. +//! /// +//! /// # Parameters +//! /// +//! /// * `rng`: The random number generator to use +//! /// * `clock`: The clock used to generate timestamps +//! /// +//! /// # Errors +//! /// +//! /// Returns [`Self::Error`] if the underlying repository fails +//! async fn add( +//! &mut self, +//! rng: &mut (dyn RngCore + Send), +//! clock: &dyn Clock, +//! ) -> Result; +//! } +//! +//! repository_impl!(FakeDataRepository: +//! async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; +//! async fn add( +//! &mut self, +//! rng: &mut (dyn RngCore + Send), +//! clock: &dyn Clock, +//! ) -> Result; +//! ); +//! ``` +//! +//! Four things to note with the implementation: +//! +//! 1. It defined an assocated error type, and all functions are faillible, +//! and use that error type +//! 2. Lookups return an `Result, Self::Error>`, because 'not found' +//! errors are usually cases that are handled differently +//! 3. Operations that need to record the current type use a [`Clock`] +//! parameter. Operations that need to generate new IDs also use a random +//! number generator. +//! 4. All the methods use an `&mut self`. This is ensures only one operation +//! is done at a time on a single repository instance. +//! +//! Then update the [`RepositoryAccess`] trait to make the new repository +//! available: +//! +//! ```rust +//! # trait FakeDataRepository { +//! # type Error; +//! # } +//! +//! /// Access the various repositories the backend implements. +//! pub trait RepositoryAccess: Send { +//! /// The backend-specific error type used by each repository. +//! type Error: std::error::Error + Send + Sync + 'static; +//! +//! // ...other repositories... +//! +//! /// Get a [`FakeDataRepository`] +//! fn fake_data<'c>(&'c mut self) -> Box + 'c>; +//! } +//! ``` #![forbid(unsafe_code)] #![deny( clippy::all, clippy::str_to_string, clippy::future_not_send, - rustdoc::broken_intra_doc_links + rustdoc::broken_intra_doc_links, + missing_docs )] #![warn(clippy::pedantic)] -#![allow( - clippy::missing_errors_doc, - clippy::missing_panics_doc, - clippy::module_name_repetitions -)] +#![allow(clippy::module_name_repetitions)] -use chrono::{DateTime, Utc}; -use pagination::InvalidPagination; -use sqlx::{migrate::Migrator, postgres::PgQueryResult}; -use thiserror::Error; -use ulid::Ulid; - -trait LookupResultExt { - type Output; - - /// Transform a [`Result`] from a sqlx query to transform "not found" errors - /// into [`None`] - fn to_option(self) -> Result, sqlx::Error>; -} - -impl LookupResultExt for Result { - type Output = T; - - fn to_option(self) -> Result, sqlx::Error> { - match self { - Ok(v) => Ok(Some(v)), - Err(sqlx::Error::RowNotFound) => Ok(None), - Err(e) => Err(e), - } - } -} - -/// Generic error when interacting with the database -#[derive(Debug, Error)] -#[error(transparent)] -pub enum DatabaseError { - /// An error which came from the database itself - Driver(#[from] sqlx::Error), - - /// An error which occured while converting the data from the database - Inconsistency(#[from] DatabaseInconsistencyError), - - /// An error which occured while generating the paginated query - Pagination(#[from] InvalidPagination), - - /// An error which happened because the requested database operation is - /// invalid - #[error("Invalid database operation")] - InvalidOperation { - #[source] - source: Option>, - }, - - /// An error which happens when an operation affects not enough or too many - /// rows - #[error("Expected {expected} rows to be affected, but {actual} rows were affected")] - RowsAffected { expected: u64, actual: u64 }, -} - -impl DatabaseError { - pub(crate) fn ensure_affected_rows( - result: &PgQueryResult, - expected: u64, - ) -> Result<(), DatabaseError> { - let actual = result.rows_affected(); - if actual == expected { - Ok(()) - } else { - Err(DatabaseError::RowsAffected { expected, actual }) - } - } - - pub(crate) fn to_invalid_operation(e: E) -> Self { - Self::InvalidOperation { - source: Some(Box::new(e)), - } - } - - pub(crate) const fn invalid_operation() -> Self { - Self::InvalidOperation { source: None } - } -} - -#[derive(Debug, Error)] -pub struct DatabaseInconsistencyError { - table: &'static str, - column: Option<&'static str>, - row: Option, - - #[source] - source: Option>, -} - -impl std::fmt::Display for DatabaseInconsistencyError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Database inconsistency on table {}", self.table)?; - if let Some(column) = self.column { - write!(f, " column {column}")?; - } - if let Some(row) = self.row { - write!(f, " row {row}")?; - } - - Ok(()) - } -} - -impl DatabaseInconsistencyError { - #[must_use] - pub(crate) const fn on(table: &'static str) -> Self { - Self { - table, - column: None, - row: None, - source: None, - } - } - - #[must_use] - pub(crate) const fn column(mut self, column: &'static str) -> Self { - self.column = Some(column); - self - } - - #[must_use] - pub(crate) const fn row(mut self, row: Ulid) -> Self { - self.row = Some(row); - self - } - - pub(crate) fn source( - mut self, - source: E, - ) -> Self { - self.source = Some(Box::new(source)); - self - } -} - -#[derive(Default, Debug, Clone, Copy)] -pub struct Clock { - _private: (), -} - -impl Clock { - #[must_use] - pub fn now(&self) -> DateTime { - // This is the clock used elsewhere, it's fine to call Utc::now here - #[allow(clippy::disallowed_methods)] - Utc::now() - } -} +pub mod clock; +pub mod pagination; +pub(crate) mod repository; +mod utils; pub mod compat; pub mod oauth2; -pub(crate) mod pagination; pub mod upstream_oauth2; pub mod user; -/// Embedded migrations, allowing them to run on startup -pub static MIGRATOR: Migrator = sqlx::migrate!(); +pub use self::{ + clock::{Clock, SystemClock}, + pagination::{Page, Pagination}, + repository::{ + BoxRepository, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction, + }, + utils::{BoxClock, BoxRng, MapErr}, +}; diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index be9831426..3fba2399d 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,261 +12,128 @@ // See the License for the specific language governing permissions and // limitations under the License. -use chrono::{DateTime, Duration, Utc}; -use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail}; -use rand::Rng; -use sqlx::{PgConnection, PgExecutor}; +use async_trait::async_trait; +use chrono::Duration; +use mas_data_model::{AccessToken, Session}; +use rand_core::RngCore; use ulid::Ulid; -use uuid::Uuid; -use super::client::lookup_client; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError}; +use crate::{repository_impl, Clock}; -#[tracing::instrument( - skip_all, - fields( - %session.id, - client.id = %session.client.id, - user.id = %session.browser_session.user.id, - access_token.id, - ), - err, -)] -pub async fn add_access_token( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - session: &Session, - access_token: String, - expires_after: Duration, -) -> Result { - let created_at = clock.now(); - let expires_at = created_at + expires_after; - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); +/// An [`OAuth2AccessTokenRepository`] helps interacting with [`AccessToken`] +/// saved in the storage backend +#[async_trait] +pub trait OAuth2AccessTokenRepository: Send + Sync { + /// The error type returned by the repository + type Error; - tracing::Span::current().record("access_token.id", tracing::field::display(id)); + /// Lookup an access token by its ID + /// + /// Returns the access token if it exists, `None` otherwise + /// + /// # Parameters + /// + /// * `id`: The ID of the access token to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; - sqlx::query!( - r#" - INSERT INTO oauth2_access_tokens - (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at) - VALUES - ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(session.id), - &access_token, - created_at, - expires_at, - ) - .execute(executor) - .await?; + /// Find an access token by its token + /// + /// Returns the access token if it exists, `None` otherwise + /// + /// # Parameters + /// + /// * `access_token`: The token of the access token to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error>; - Ok(AccessToken { - id, - access_token, - jti: id.to_string(), - created_at, - expires_at, - }) + /// Add a new access token to the database + /// + /// Returns the newly created access token + /// + /// # Parameters + /// + /// * `rng`: A random number generator + /// * `clock`: The clock used to generate timestamps + /// * `session`: The session the access token is associated with + /// * `access_token`: The access token to add + /// * `expires_after`: The duration after which the access token expires + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + session: &Session, + access_token: String, + expires_after: Duration, + ) -> Result; + + /// Revoke an access token + /// + /// Returns the revoked access token + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `access_token`: The access token to revoke + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn revoke( + &mut self, + clock: &dyn Clock, + access_token: AccessToken, + ) -> Result; + + /// Cleanup expired access tokens + /// + /// Returns the number of access tokens that were cleaned up + /// + /// # Parameters + /// + /// * `clock`: The clock used to get the current time + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn cleanup_expired(&mut self, clock: &dyn Clock) -> Result; } -#[derive(Debug)] -pub struct OAuth2AccessTokenLookup { - oauth2_access_token_id: Uuid, - oauth2_access_token: String, - oauth2_access_token_created_at: DateTime, - oauth2_access_token_expires_at: DateTime, - oauth2_session_id: Uuid, - oauth2_client_id: Uuid, - scope: String, - user_session_id: Uuid, - user_session_created_at: DateTime, - user_id: Uuid, - user_username: String, - user_session_last_authentication_id: Option, - user_session_last_authentication_created_at: Option>, - user_email_id: Option, - user_email: Option, - user_email_created_at: Option>, - user_email_confirmed_at: Option>, -} +repository_impl!(OAuth2AccessTokenRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; -#[allow(clippy::too_many_lines)] -pub async fn lookup_active_access_token( - conn: &mut PgConnection, - token: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - OAuth2AccessTokenLookup, - r#" - SELECT - at.oauth2_access_token_id, - at.access_token AS "oauth2_access_token", - at.created_at AS "oauth2_access_token_created_at", - at.expires_at AS "oauth2_access_token_expires_at", - os.oauth2_session_id AS "oauth2_session_id!", - os.oauth2_client_id AS "oauth2_client_id!", - os.scope AS "scope!", - us.user_session_id AS "user_session_id!", - us.created_at AS "user_session_created_at!", - u.user_id AS "user_id!", - u.username AS "user_username!", - usa.user_session_authentication_id AS "user_session_last_authentication_id?", - usa.created_at AS "user_session_last_authentication_created_at?", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error>; - FROM oauth2_access_tokens at - INNER JOIN oauth2_sessions os - USING (oauth2_session_id) - INNER JOIN user_sessions us - USING (user_session_id) - INNER JOIN users u - USING (user_id) - LEFT JOIN user_session_authentications usa - USING (user_session_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + session: &Session, + access_token: String, + expires_after: Duration, + ) -> Result; - WHERE at.access_token = $1 - AND at.revoked_at IS NULL - AND os.finished_at IS NULL + async fn revoke( + &mut self, + clock: &dyn Clock, + access_token: AccessToken, + ) -> Result; - ORDER BY usa.created_at DESC - LIMIT 1 - "#, - token, - ) - .fetch_one(&mut *conn) - .await?; - - let access_token_id = Ulid::from(res.oauth2_access_token_id); - let access_token = AccessToken { - id: access_token_id, - jti: access_token_id.to_string(), - access_token: res.oauth2_access_token, - created_at: res.oauth2_access_token_created_at, - expires_at: res.oauth2_access_token_expires_at, - }; - - let session_id = res.oauth2_session_id.into(); - let client = lookup_client(&mut *conn, res.oauth2_client_id.into()) - .await? - .ok_or_else(|| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("client_id") - .row(session_id) - })?; - - let user_id = Ulid::from(res.user_id); - let primary_email = match ( - res.user_email_id, - res.user_email, - res.user_email_created_at, - res.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("users") - .column("primary_user_email_id") - .row(user_id) - .into()) - } - }; - - let user = User { - id: user_id, - username: res.user_username, - sub: user_id.to_string(), - primary_email, - }; - - let last_authentication = match ( - res.user_session_last_authentication_id, - res.user_session_last_authentication_created_at, - ) { - (None, None) => None, - (Some(id), Some(created_at)) => Some(Authentication { - id: id.into(), - created_at, - }), - _ => return Err(DatabaseInconsistencyError::on("user_session_authentications").into()), - }; - - let browser_session = BrowserSession { - id: res.user_session_id.into(), - created_at: res.user_session_created_at, - user, - last_authentication, - }; - - let scope = res.scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("scope") - .row(session_id) - .source(e) - })?; - - let session = Session { - id: session_id, - client, - browser_session, - scope, - }; - - Ok(Some((access_token, session))) -} - -#[tracing::instrument( - skip_all, - fields(%access_token.id), - err, -)] -pub async fn revoke_access_token( - executor: impl PgExecutor<'_>, - clock: &Clock, - access_token: AccessToken, -) -> Result<(), DatabaseError> { - let revoked_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE oauth2_access_tokens - SET revoked_at = $2 - WHERE oauth2_access_token_id = $1 - "#, - Uuid::from(access_token.id), - revoked_at, - ) - .execute(executor) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1) -} - -pub async fn cleanup_expired( - executor: impl PgExecutor<'_>, - clock: &Clock, -) -> Result { - // Cleanup token which expired more than 15 minutes ago - let threshold = clock.now() - Duration::minutes(15); - let res = sqlx::query!( - r#" - DELETE FROM oauth2_access_tokens - WHERE expires_at < $1 - "#, - threshold, - ) - .execute(executor) - .await?; - - Ok(res.rows_affected()) -} + async fn cleanup_expired(&mut self, clock: &dyn Clock) -> Result; +); diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index e00274cc0..623ea596d 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,692 +14,184 @@ use std::num::NonZeroU32; -use chrono::{DateTime, Utc}; -use mas_data_model::{ - Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession, - Client, Pkce, Session, User, UserEmail, -}; -use mas_iana::oauth::PkceCodeChallengeMethod; +use async_trait::async_trait; +use mas_data_model::{AuthorizationCode, AuthorizationGrant, Client, Session}; use oauth2_types::{requests::ResponseMode, scope::Scope}; -use rand::Rng; -use sqlx::{PgConnection, PgExecutor}; +use rand_core::RngCore; use ulid::Ulid; use url::Url; -use uuid::Uuid; -use super::client::lookup_client; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; +use crate::{repository_impl, Clock}; -#[tracing::instrument( - skip_all, - fields( - %client.id, - grant.id, - ), - err, -)] -#[allow(clippy::too_many_arguments)] -pub async fn new_authorization_grant( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - client: Client, - redirect_uri: Url, - scope: Scope, - code: Option, - state: Option, - nonce: Option, - max_age: Option, - _acr_values: Option, - response_mode: ResponseMode, - response_type_id_token: bool, - requires_consent: bool, -) -> Result { - let code_challenge = code - .as_ref() - .and_then(|c| c.pkce.as_ref()) - .map(|p| &p.challenge); - let code_challenge_method = code - .as_ref() - .and_then(|c| c.pkce.as_ref()) - .map(|p| p.challenge_method.to_string()); - // TODO: this conversion is a bit ugly - let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX)); - let code_str = code.as_ref().map(|c| &c.code); +/// An [`OAuth2AuthorizationGrantRepository`] helps interacting with +/// [`AuthorizationGrant`] saved in the storage backend +#[async_trait] +pub trait OAuth2AuthorizationGrantRepository: Send + Sync { + /// The error type returned by the repository + type Error; - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("grant.id", tracing::field::display(id)); + /// Create a new authorization grant + /// + /// Returns the newly created authorization grant + /// + /// # Parameters + /// + /// * `rng`: A random number generator + /// * `clock`: The clock used to generate timestamps + /// * `client`: The client that requested the authorization grant + /// * `redirect_uri`: The redirect URI the client requested + /// * `scope`: The scope the client requested + /// * `code`: The authorization code used by this grant, if the `code` + /// `response_type` was requested + /// * `state`: The state the client sent, if set + /// * `nonce`: The nonce the client sent, if set + /// * `max_age`: The maximum age since the user last authenticated, if asked + /// by the client + /// * `response_mode`: The response mode the client requested + /// * `response_type_id_token`: Whether the `id_token` `response_type` was + /// requested + /// * `requires_consent`: Whether the client explicitly requested consent + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + #[allow(clippy::too_many_arguments)] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + client: &Client, + redirect_uri: Url, + scope: Scope, + code: Option, + state: Option, + nonce: Option, + max_age: Option, + response_mode: ResponseMode, + response_type_id_token: bool, + requires_consent: bool, + ) -> Result; - sqlx::query!( - r#" - INSERT INTO oauth2_authorization_grants ( - oauth2_authorization_grant_id, - oauth2_client_id, - redirect_uri, - scope, - state, - nonce, - max_age, - response_mode, - code_challenge, - code_challenge_method, - response_type_code, - response_type_id_token, - authorization_code, - requires_consent, - created_at - ) - VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) - "#, - Uuid::from(id), - Uuid::from(client.id), - redirect_uri.to_string(), - scope.to_string(), - state, - nonce, - max_age_i32, - response_mode.to_string(), - code_challenge, - code_challenge_method, - code.is_some(), - response_type_id_token, - code_str, - requires_consent, - created_at, - ) - .execute(executor) - .await?; + /// Lookup an authorization grant by its ID + /// + /// Returns the authorization grant if found, `None` otherwise + /// + /// # Parameters + /// + /// * `id`: The ID of the authorization grant to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; - Ok(AuthorizationGrant { - id, - stage: AuthorizationGrantStage::Pending, - code, - redirect_uri, - client, - scope, - state, - nonce, - max_age, - response_mode, - created_at, - response_type_id_token, - requires_consent, - }) + /// Find an authorization grant by its code + /// + /// Returns the authorization grant if found, `None` otherwise + /// + /// # Parameters + /// + /// * `code`: The code of the authorization grant to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn find_by_code(&mut self, code: &str) + -> Result, Self::Error>; + + /// Fulfill an authorization grant, by giving the [`Session`] that it + /// created + /// + /// Returns the updated authorization grant + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `session`: The session that was created using this authorization grant + /// * `authorization_grant`: The authorization grant to fulfill + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn fulfill( + &mut self, + clock: &dyn Clock, + session: &Session, + authorization_grant: AuthorizationGrant, + ) -> Result; + + /// Mark an authorization grant as exchanged + /// + /// Returns the updated authorization grant + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `authorization_grant`: The authorization grant to mark as exchanged + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn exchange( + &mut self, + clock: &dyn Clock, + authorization_grant: AuthorizationGrant, + ) -> Result; + + /// Unset the `requires_consent` flag on an authorization grant + /// + /// Returns the updated authorization grant + /// + /// # Parameters + /// + /// * `authorization_grant`: The authorization grant to update + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn give_consent( + &mut self, + authorization_grant: AuthorizationGrant, + ) -> Result; } -#[allow(clippy::struct_excessive_bools)] -struct GrantLookup { - oauth2_authorization_grant_id: Uuid, - oauth2_authorization_grant_created_at: DateTime, - oauth2_authorization_grant_cancelled_at: Option>, - oauth2_authorization_grant_fulfilled_at: Option>, - oauth2_authorization_grant_exchanged_at: Option>, - oauth2_authorization_grant_scope: String, - oauth2_authorization_grant_state: Option, - oauth2_authorization_grant_nonce: Option, - oauth2_authorization_grant_redirect_uri: String, - oauth2_authorization_grant_response_mode: String, - oauth2_authorization_grant_max_age: Option, - oauth2_authorization_grant_response_type_code: bool, - oauth2_authorization_grant_response_type_id_token: bool, - oauth2_authorization_grant_code: Option, - oauth2_authorization_grant_code_challenge: Option, - oauth2_authorization_grant_code_challenge_method: Option, - oauth2_authorization_grant_requires_consent: bool, - oauth2_client_id: Uuid, - oauth2_session_id: Option, - user_session_id: Option, - user_session_created_at: Option>, - user_id: Option, - user_username: Option, - user_session_last_authentication_id: Option, - user_session_last_authentication_created_at: Option>, - user_email_id: Option, - user_email: Option, - user_email_created_at: Option>, - user_email_confirmed_at: Option>, -} +repository_impl!(OAuth2AuthorizationGrantRepository: + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + client: &Client, + redirect_uri: Url, + scope: Scope, + code: Option, + state: Option, + nonce: Option, + max_age: Option, + response_mode: ResponseMode, + response_type_id_token: bool, + requires_consent: bool, + ) -> Result; -impl GrantLookup { - #[allow(clippy::too_many_lines)] - async fn into_authorization_grant( - self, - executor: impl PgExecutor<'_>, - ) -> Result { - let id = self.oauth2_authorization_grant_id.into(); - let scope: Scope = self.oauth2_authorization_grant_scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("scope") - .row(id) - .source(e) - })?; + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; - // TODO: don't unwrap - let client = lookup_client(executor, self.oauth2_client_id.into()) - .await? - .ok_or_else(|| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("client_id") - .row(id) - })?; + async fn find_by_code(&mut self, code: &str) + -> Result, Self::Error>; - let last_authentication = match ( - self.user_session_last_authentication_id, - self.user_session_last_authentication_created_at, - ) { - (Some(id), Some(created_at)) => Some(Authentication { - id: id.into(), - created_at, - }), - (None, None) => None, - _ => return Err(DatabaseInconsistencyError::on("user_session_authentications").into()), - }; + async fn fulfill( + &mut self, + clock: &dyn Clock, + session: &Session, + authorization_grant: AuthorizationGrant, + ) -> Result; - let primary_email = match ( - self.user_email_id, - self.user_email, - self.user_email_created_at, - self.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("users") - .column("primary_user_email_id") - .into()) - } - }; + async fn exchange( + &mut self, + clock: &dyn Clock, + authorization_grant: AuthorizationGrant, + ) -> Result; - let session = match ( - self.oauth2_session_id, - self.user_session_id, - self.user_session_created_at, - self.user_id, - self.user_username, - last_authentication, - primary_email, - ) { - ( - Some(session_id), - Some(user_session_id), - Some(user_session_created_at), - Some(user_id), - Some(user_username), - last_authentication, - primary_email, - ) => { - let user_id = Ulid::from(user_id); - let user = User { - id: user_id, - username: user_username, - sub: user_id.to_string(), - primary_email, - }; - - let browser_session = BrowserSession { - id: user_session_id.into(), - user, - created_at: user_session_created_at, - last_authentication, - }; - - let client = client.clone(); - let scope = scope.clone(); - - let session = Session { - id: session_id.into(), - client, - browser_session, - scope, - }; - - Some(session) - } - (None, None, None, None, None, None, None) => None, - _ => { - return Err( - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("oauth2_session_id") - .row(id) - .into(), - ) - } - }; - - let stage = match ( - self.oauth2_authorization_grant_fulfilled_at, - self.oauth2_authorization_grant_exchanged_at, - self.oauth2_authorization_grant_cancelled_at, - session, - ) { - (None, None, None, None) => AuthorizationGrantStage::Pending, - (Some(fulfilled_at), None, None, Some(session)) => AuthorizationGrantStage::Fulfilled { - session, - fulfilled_at, - }, - (Some(fulfilled_at), Some(exchanged_at), None, Some(session)) => { - AuthorizationGrantStage::Exchanged { - session, - fulfilled_at, - exchanged_at, - } - } - (None, None, Some(cancelled_at), None) => { - AuthorizationGrantStage::Cancelled { cancelled_at } - } - _ => { - return Err( - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("stage") - .row(id) - .into(), - ); - } - }; - - let pkce = match ( - self.oauth2_authorization_grant_code_challenge, - self.oauth2_authorization_grant_code_challenge_method, - ) { - (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => { - Some(Pkce { - challenge_method: PkceCodeChallengeMethod::Plain, - challenge, - }) - } - (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce { - challenge_method: PkceCodeChallengeMethod::S256, - challenge, - }), - (None, None) => None, - _ => { - return Err( - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("code_challenge_method") - .row(id) - .into(), - ); - } - }; - - let code: Option = match ( - self.oauth2_authorization_grant_response_type_code, - self.oauth2_authorization_grant_code, - pkce, - ) { - (false, None, None) => None, - (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }), - _ => { - return Err( - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("authorization_code") - .row(id) - .into(), - ); - } - }; - - let redirect_uri = self - .oauth2_authorization_grant_redirect_uri - .parse() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("redirect_uri") - .row(id) - .source(e) - })?; - - let response_mode = self - .oauth2_authorization_grant_response_mode - .parse() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("response_mode") - .row(id) - .source(e) - })?; - - let max_age = self - .oauth2_authorization_grant_max_age - .map(u32::try_from) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("max_age") - .row(id) - .source(e) - })? - .map(NonZeroU32::try_from) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("max_age") - .row(id) - .source(e) - })?; - - Ok(AuthorizationGrant { - id, - stage, - client, - code, - scope, - state: self.oauth2_authorization_grant_state, - nonce: self.oauth2_authorization_grant_nonce, - max_age, - response_mode, - redirect_uri, - created_at: self.oauth2_authorization_grant_created_at, - response_type_id_token: self.oauth2_authorization_grant_response_type_id_token, - requires_consent: self.oauth2_authorization_grant_requires_consent, - }) - } -} - -#[tracing::instrument( - skip_all, - fields(grant.id = %id), - err, -)] -pub async fn get_grant_by_id( - conn: &mut PgConnection, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - GrantLookup, - r#" - SELECT - og.oauth2_authorization_grant_id, - og.created_at AS oauth2_authorization_grant_created_at, - og.cancelled_at AS oauth2_authorization_grant_cancelled_at, - og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at, - og.exchanged_at AS oauth2_authorization_grant_exchanged_at, - og.scope AS oauth2_authorization_grant_scope, - og.state AS oauth2_authorization_grant_state, - og.redirect_uri AS oauth2_authorization_grant_redirect_uri, - og.response_mode AS oauth2_authorization_grant_response_mode, - og.nonce AS oauth2_authorization_grant_nonce, - og.max_age AS oauth2_authorization_grant_max_age, - og.oauth2_client_id AS oauth2_client_id, - og.authorization_code AS oauth2_authorization_grant_code, - og.response_type_code AS oauth2_authorization_grant_response_type_code, - og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token, - og.code_challenge AS oauth2_authorization_grant_code_challenge, - og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method, - og.requires_consent AS oauth2_authorization_grant_requires_consent, - os.oauth2_session_id AS "oauth2_session_id?", - us.user_session_id AS "user_session_id?", - us.created_at AS "user_session_created_at?", - u.user_id AS "user_id?", - u.username AS "user_username?", - usa.user_session_authentication_id AS "user_session_last_authentication_id?", - usa.created_at AS "user_session_last_authentication_created_at?", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" - FROM - oauth2_authorization_grants og - LEFT JOIN oauth2_sessions os - USING (oauth2_session_id) - LEFT JOIN user_sessions us - USING (user_session_id) - LEFT JOIN users u - USING (user_id) - LEFT JOIN user_session_authentications usa - USING (user_session_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id - - WHERE og.oauth2_authorization_grant_id = $1 - - ORDER BY usa.created_at DESC - LIMIT 1 - "#, - Uuid::from(id), - ) - .fetch_one(&mut *conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - let grant = res.into_authorization_grant(&mut *conn).await?; - - Ok(Some(grant)) -} - -#[tracing::instrument(skip_all, err)] -pub async fn lookup_grant_by_code( - conn: &mut PgConnection, - code: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - GrantLookup, - r#" - SELECT - og.oauth2_authorization_grant_id, - og.created_at AS oauth2_authorization_grant_created_at, - og.cancelled_at AS oauth2_authorization_grant_cancelled_at, - og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at, - og.exchanged_at AS oauth2_authorization_grant_exchanged_at, - og.scope AS oauth2_authorization_grant_scope, - og.state AS oauth2_authorization_grant_state, - og.redirect_uri AS oauth2_authorization_grant_redirect_uri, - og.response_mode AS oauth2_authorization_grant_response_mode, - og.nonce AS oauth2_authorization_grant_nonce, - og.max_age AS oauth2_authorization_grant_max_age, - og.oauth2_client_id AS oauth2_client_id, - og.authorization_code AS oauth2_authorization_grant_code, - og.response_type_code AS oauth2_authorization_grant_response_type_code, - og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token, - og.code_challenge AS oauth2_authorization_grant_code_challenge, - og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method, - og.requires_consent AS oauth2_authorization_grant_requires_consent, - os.oauth2_session_id AS "oauth2_session_id?", - us.user_session_id AS "user_session_id?", - us.created_at AS "user_session_created_at?", - u.user_id AS "user_id?", - u.username AS "user_username?", - usa.user_session_authentication_id AS "user_session_last_authentication_id?", - usa.created_at AS "user_session_last_authentication_created_at?", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" - FROM - oauth2_authorization_grants og - LEFT JOIN oauth2_sessions os - USING (oauth2_session_id) - LEFT JOIN user_sessions us - USING (user_session_id) - LEFT JOIN users u - USING (user_id) - LEFT JOIN user_session_authentications usa - USING (user_session_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id - - WHERE og.authorization_code = $1 - - ORDER BY usa.created_at DESC - LIMIT 1 - "#, - code, - ) - .fetch_one(&mut *conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - let grant = res.into_authorization_grant(&mut *conn).await?; - - Ok(Some(grant)) -} - -#[tracing::instrument( - skip_all, - fields( - %grant.id, - client.id = %grant.client.id, - session.id, - user_session.id = %browser_session.id, - user.id = %browser_session.user.id, - ), - err, -)] -pub async fn derive_session( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - grant: &AuthorizationGrant, - browser_session: BrowserSession, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("session.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO oauth2_sessions - (oauth2_session_id, user_session_id, oauth2_client_id, scope, created_at) - SELECT - $1, - $2, - og.oauth2_client_id, - og.scope, - $3 - FROM - oauth2_authorization_grants og - WHERE - og.oauth2_authorization_grant_id = $4 - "#, - Uuid::from(id), - Uuid::from(browser_session.id), - created_at, - Uuid::from(grant.id), - ) - .execute(executor) - .await?; - - Ok(Session { - id, - browser_session, - client: grant.client.clone(), - scope: grant.scope.clone(), - }) -} - -#[tracing::instrument( - skip_all, - fields( - %grant.id, - client.id = %grant.client.id, - %session.id, - user_session.id = %session.browser_session.id, - user.id = %session.browser_session.user.id, - ), - err, -)] -pub async fn fulfill_grant( - executor: impl PgExecutor<'_>, - mut grant: AuthorizationGrant, - session: Session, -) -> Result { - let fulfilled_at = sqlx::query_scalar!( - r#" - UPDATE oauth2_authorization_grants AS og - SET - oauth2_session_id = os.oauth2_session_id, - fulfilled_at = os.created_at - FROM oauth2_sessions os - WHERE - og.oauth2_authorization_grant_id = $1 - AND os.oauth2_session_id = $2 - RETURNING fulfilled_at AS "fulfilled_at!: DateTime" - "#, - Uuid::from(grant.id), - Uuid::from(session.id), - ) - .fetch_one(executor) - .await?; - - grant.stage = grant - .stage - .fulfill(fulfilled_at, session) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(grant) -} - -#[tracing::instrument( - skip_all, - fields( - %grant.id, - client.id = %grant.client.id, - ), - err, -)] -pub async fn give_consent_to_grant( - executor: impl PgExecutor<'_>, - mut grant: AuthorizationGrant, -) -> Result { - sqlx::query!( - r#" - UPDATE oauth2_authorization_grants AS og - SET - requires_consent = 'f' - WHERE - og.oauth2_authorization_grant_id = $1 - "#, - Uuid::from(grant.id), - ) - .execute(executor) - .await?; - - grant.requires_consent = false; - - Ok(grant) -} - -#[tracing::instrument( - skip_all, - fields( - %grant.id, - client.id = %grant.client.id, - ), - err, -)] -pub async fn exchange_grant( - executor: impl PgExecutor<'_>, - clock: &Clock, - mut grant: AuthorizationGrant, -) -> Result { - let exchanged_at = clock.now(); - sqlx::query!( - r#" - UPDATE oauth2_authorization_grants - SET exchanged_at = $2 - WHERE oauth2_authorization_grant_id = $1 - "#, - Uuid::from(grant.id), - exchanged_at, - ) - .execute(executor) - .await?; - - grant.stage = grant - .stage - .exchange(exchanged_at) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(grant) -} + async fn give_consent( + &mut self, + authorization_grant: AuthorizationGrant, + ) -> Result; +); diff --git a/crates/storage/src/oauth2/client.rs b/crates/storage/src/oauth2/client.rs index 6138e377c..18f0108b7 100644 --- a/crates/storage/src/oauth2/client.rs +++ b/crates/storage/src/oauth2/client.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,527 +12,243 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::HashMap, string::ToString}; +use std::collections::{BTreeMap, BTreeSet}; -use mas_data_model::{Client, JwksOrJwksUri}; -use mas_iana::{ - jose::JsonWebSignatureAlg, - oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod}, -}; +use async_trait::async_trait; +use mas_data_model::{Client, User}; +use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_jose::jwk::PublicJsonWebKeySet; -use oauth2_types::requests::GrantType; -use rand::Rng; -use sqlx::{PgConnection, PgExecutor}; +use oauth2_types::{requests::GrantType, scope::Scope}; +use rand_core::RngCore; use ulid::Ulid; use url::Url; -use uuid::Uuid; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; +use crate::{repository_impl, Clock}; -// XXX: response_types & contacts -#[derive(Debug)] -pub struct OAuth2ClientLookup { - oauth2_client_id: Uuid, - encrypted_client_secret: Option, - redirect_uris: Vec, - // response_types: Vec, - grant_type_authorization_code: bool, - grant_type_refresh_token: bool, - // contacts: Vec, - client_name: Option, - logo_uri: Option, - client_uri: Option, - policy_uri: Option, - tos_uri: Option, - jwks_uri: Option, - jwks: Option, - id_token_signed_response_alg: Option, - userinfo_signed_response_alg: Option, - token_endpoint_auth_method: Option, - token_endpoint_auth_signing_alg: Option, - initiate_login_uri: Option, -} +/// An [`OAuth2ClientRepository`] helps interacting with [`Client`] saved in the +/// storage backend +#[async_trait] +pub trait OAuth2ClientRepository: Send + Sync { + /// The error type returned by the repository + type Error; -impl TryInto for OAuth2ClientLookup { - type Error = DatabaseInconsistencyError; + /// Lookup an OAuth2 client by its ID + /// + /// Returns `None` if the client does not exist + /// + /// # Parameters + /// + /// * `id`: The ID of the client to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; - #[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing - fn try_into(self) -> Result { - let id = Ulid::from(self.oauth2_client_id); - - let redirect_uris: Result, _> = - self.redirect_uris.iter().map(|s| s.parse()).collect(); - let redirect_uris = redirect_uris.map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("redirect_uris") - .row(id) - .source(e) - })?; - - let response_types = vec![ - OAuthAuthorizationEndpointResponseType::Code, - OAuthAuthorizationEndpointResponseType::IdToken, - OAuthAuthorizationEndpointResponseType::None, - ]; - /* XXX - let response_types: Result, _> = - self.response_types.iter().map(|s| s.parse()).collect(); - let response_types = response_types.map_err(|source| ClientFetchError::ParseField { - field: "response_types", - source, - })?; - */ - - let mut grant_types = Vec::new(); - if self.grant_type_authorization_code { - grant_types.push(GrantType::AuthorizationCode); - } - if self.grant_type_refresh_token { - grant_types.push(GrantType::RefreshToken); - } - - let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("logo_uri") - .row(id) - .source(e) - })?; - - let client_uri = self - .client_uri - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("client_uri") - .row(id) - .source(e) - })?; - - let policy_uri = self - .policy_uri - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("policy_uri") - .row(id) - .source(e) - })?; - - let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("tos_uri") - .row(id) - .source(e) - })?; - - let id_token_signed_response_alg = self - .id_token_signed_response_alg - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("id_token_signed_response_alg") - .row(id) - .source(e) - })?; - - let userinfo_signed_response_alg = self - .userinfo_signed_response_alg - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("userinfo_signed_response_alg") - .row(id) - .source(e) - })?; - - let token_endpoint_auth_method = self - .token_endpoint_auth_method - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("token_endpoint_auth_method") - .row(id) - .source(e) - })?; - - let token_endpoint_auth_signing_alg = self - .token_endpoint_auth_signing_alg - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("token_endpoint_auth_signing_alg") - .row(id) - .source(e) - })?; - - let initiate_login_uri = self - .initiate_login_uri - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("initiate_login_uri") - .row(id) - .source(e) - })?; - - let jwks = match (self.jwks, self.jwks_uri) { - (None, None) => None, - (Some(jwks), None) => { - let jwks = serde_json::from_value(jwks).map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("jwks") - .row(id) - .source(e) - })?; - Some(JwksOrJwksUri::Jwks(jwks)) - } - (None, Some(jwks_uri)) => { - let jwks_uri = jwks_uri.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("jwks_uri") - .row(id) - .source(e) - })?; - - Some(JwksOrJwksUri::JwksUri(jwks_uri)) - } - _ => { - return Err(DatabaseInconsistencyError::on("oauth2_clients") - .column("jwks(_uri)") - .row(id)) - } - }; - - Ok(Client { - id, - client_id: id.to_string(), - encrypted_client_secret: self.encrypted_client_secret, - redirect_uris, - response_types, - grant_types, - // contacts: self.contacts, - contacts: vec![], - client_name: self.client_name, - logo_uri, - client_uri, - policy_uri, - tos_uri, - jwks, - id_token_signed_response_alg, - userinfo_signed_response_alg, - token_endpoint_auth_method, - token_endpoint_auth_signing_alg, - initiate_login_uri, - }) + /// Find an OAuth2 client by its client ID + async fn find_by_client_id(&mut self, client_id: &str) -> Result, Self::Error> { + let Ok(id) = client_id.parse() else { return Ok(None) }; + self.lookup(id).await } + + /// Load a batch of OAuth2 clients by their IDs + /// + /// Returns a map of client IDs to clients. If a client does not exist, it + /// is not present in the map. + /// + /// # Parameters + /// + /// * `ids`: The IDs of the clients to load + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn load_batch( + &mut self, + ids: BTreeSet, + ) -> Result, Self::Error>; + + /// Add a new OAuth2 client + /// + /// Returns the client that was added + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `redirect_uris`: The list of redirect URIs used by this client + /// * `encrypted_client_secret`: The encrypted client secret, if any + /// * `grant_types`: The list of grant types this client can use + /// * `contacts`: The list of contacts for this client + /// * `client_name`: The human-readable name of this client, if given + /// * `logo_uri`: The URI of the logo of this client, if given + /// * `client_uri`: The URI of a website of this client, if given + /// * `policy_uri`: The URI of the privacy policy of this client, if given + /// * `tos_uri`: The URI of the terms of service of this client, if given + /// * `jwks_uri`: The URI of the JWKS of this client, if given + /// * `jwks`: The JWKS of this client, if given + /// * `id_token_signed_response_alg`: The algorithm used to sign the ID + /// token + /// * `userinfo_signed_response_alg`: The algorithm used to sign the user + /// info. If none, the user info endpoint will not sign the response + /// * `token_endpoint_auth_method`: The authentication method used by this + /// client when calling the token endpoint + /// * `token_endpoint_auth_signing_alg`: The algorithm used to sign the JWT + /// when using the `client_secret_jwt` or `private_key_jwt` authentication + /// methods + /// * `initiate_login_uri`: The URI used to initiate a login, if given + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + #[allow(clippy::too_many_arguments)] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + redirect_uris: Vec, + encrypted_client_secret: Option, + grant_types: Vec, + contacts: Vec, + client_name: Option, + logo_uri: Option, + client_uri: Option, + policy_uri: Option, + tos_uri: Option, + jwks_uri: Option, + jwks: Option, + id_token_signed_response_alg: Option, + userinfo_signed_response_alg: Option, + token_endpoint_auth_method: Option, + token_endpoint_auth_signing_alg: Option, + initiate_login_uri: Option, + ) -> Result; + + /// Add or replace a client from the configuration + /// + /// Returns the client that was added or replaced + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `client_id`: The client ID + /// * `client_auth_method`: The authentication method this client uses + /// * `encrypted_client_secret`: The encrypted client secret, if any + /// * `jwks`: The client JWKS, if any + /// * `jwks_uri`: The client JWKS URI, if any + /// * `redirect_uris`: The list of redirect URIs used by this client + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + #[allow(clippy::too_many_arguments)] + async fn add_from_config( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + client_id: Ulid, + client_auth_method: OAuthClientAuthenticationMethod, + encrypted_client_secret: Option, + jwks: Option, + jwks_uri: Option, + redirect_uris: Vec, + ) -> Result; + + /// Get the list of scopes that the user has given consent for the given + /// client + /// + /// # Parameters + /// + /// * `client`: The client to get the consent for + /// * `user`: The user to get the consent for + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn get_consent_for_user( + &mut self, + client: &Client, + user: &User, + ) -> Result; + + /// Give consent for a set of scopes for the given client and user + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `client`: The client to give the consent for + /// * `user`: The user to give the consent for + /// * `scope`: The scope to give consent for + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn give_consent_for_user( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + client: &Client, + user: &User, + scope: &Scope, + ) -> Result<(), Self::Error>; } -#[tracing::instrument(skip_all, err)] -pub async fn lookup_clients( - executor: impl PgExecutor<'_>, - ids: impl IntoIterator + Send, -) -> Result, DatabaseError> { - let ids: Vec = ids.into_iter().map(Uuid::from).collect(); - let res = sqlx::query_as!( - OAuth2ClientLookup, - r#" - SELECT - c.oauth2_client_id, - c.encrypted_client_secret, - ARRAY( - SELECT redirect_uri - FROM oauth2_client_redirect_uris r - WHERE r.oauth2_client_id = c.oauth2_client_id - ) AS "redirect_uris!", - c.grant_type_authorization_code, - c.grant_type_refresh_token, - c.client_name, - c.logo_uri, - c.client_uri, - c.policy_uri, - c.tos_uri, - c.jwks_uri, - c.jwks, - c.id_token_signed_response_alg, - c.userinfo_signed_response_alg, - c.token_endpoint_auth_method, - c.token_endpoint_auth_signing_alg, - c.initiate_login_uri - FROM oauth2_clients c +repository_impl!(OAuth2ClientRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; - WHERE c.oauth2_client_id = ANY($1::uuid[]) - "#, - &ids, - ) - .fetch_all(executor) - .await?; + async fn load_batch( + &mut self, + ids: BTreeSet, + ) -> Result, Self::Error>; - res.into_iter() - .map(|r| { - r.try_into() - .map(|c: Client| (c.id, c)) - .map_err(DatabaseError::from) - }) - .collect() -} + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + redirect_uris: Vec, + encrypted_client_secret: Option, + grant_types: Vec, + contacts: Vec, + client_name: Option, + logo_uri: Option, + client_uri: Option, + policy_uri: Option, + tos_uri: Option, + jwks_uri: Option, + jwks: Option, + id_token_signed_response_alg: Option, + userinfo_signed_response_alg: Option, + token_endpoint_auth_method: Option, + token_endpoint_auth_signing_alg: Option, + initiate_login_uri: Option, + ) -> Result; -#[tracing::instrument( - skip_all, - fields(client.id = %id), - err, -)] -pub async fn lookup_client( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - OAuth2ClientLookup, - r#" - SELECT - c.oauth2_client_id, - c.encrypted_client_secret, - ARRAY( - SELECT redirect_uri - FROM oauth2_client_redirect_uris r - WHERE r.oauth2_client_id = c.oauth2_client_id - ) AS "redirect_uris!", - c.grant_type_authorization_code, - c.grant_type_refresh_token, - c.client_name, - c.logo_uri, - c.client_uri, - c.policy_uri, - c.tos_uri, - c.jwks_uri, - c.jwks, - c.id_token_signed_response_alg, - c.userinfo_signed_response_alg, - c.token_endpoint_auth_method, - c.token_endpoint_auth_signing_alg, - c.initiate_login_uri - FROM oauth2_clients c + async fn add_from_config( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + client_id: Ulid, + client_auth_method: OAuthClientAuthenticationMethod, + encrypted_client_secret: Option, + jwks: Option, + jwks_uri: Option, + redirect_uris: Vec, + ) -> Result; - WHERE c.oauth2_client_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .await - .to_option()?; + async fn get_consent_for_user( + &mut self, + client: &Client, + user: &User, + ) -> Result; - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) -} - -#[tracing::instrument( - skip_all, - fields(client.id = client_id), - err, -)] -pub async fn lookup_client_by_client_id( - executor: impl PgExecutor<'_>, - client_id: &str, -) -> Result, DatabaseError> { - let Ok(id) = client_id.parse() else { return Ok(None) }; - lookup_client(executor, id).await -} - -#[tracing::instrument( - skip_all, - fields(client.id = %client_id, client.name = client_name), - err, -)] -#[allow(clippy::too_many_arguments)] -pub async fn insert_client( - conn: &mut PgConnection, - mut rng: impl Rng + Send, - clock: &Clock, - client_id: Ulid, - redirect_uris: &[Url], - encrypted_client_secret: Option<&str>, - grant_types: &[GrantType], - _contacts: &[String], - client_name: Option<&str>, - logo_uri: Option<&Url>, - client_uri: Option<&Url>, - policy_uri: Option<&Url>, - tos_uri: Option<&Url>, - jwks_uri: Option<&Url>, - jwks: Option<&PublicJsonWebKeySet>, - id_token_signed_response_alg: Option<&JsonWebSignatureAlg>, - userinfo_signed_response_alg: Option<&JsonWebSignatureAlg>, - token_endpoint_auth_method: Option<&OAuthClientAuthenticationMethod>, - token_endpoint_auth_signing_alg: Option<&JsonWebSignatureAlg>, - initiate_login_uri: Option<&Url>, -) -> Result<(), sqlx::Error> { - let grant_type_authorization_code = grant_types.contains(&GrantType::AuthorizationCode); - let grant_type_refresh_token = grant_types.contains(&GrantType::RefreshToken); - let logo_uri = logo_uri.map(Url::as_str); - let client_uri = client_uri.map(Url::as_str); - let policy_uri = policy_uri.map(Url::as_str); - let tos_uri = tos_uri.map(Url::as_str); - let jwks = jwks.map(serde_json::to_value).transpose().unwrap(); // TODO - let jwks_uri = jwks_uri.map(Url::as_str); - let id_token_signed_response_alg = id_token_signed_response_alg.map(ToString::to_string); - let userinfo_signed_response_alg = userinfo_signed_response_alg.map(ToString::to_string); - let token_endpoint_auth_method = token_endpoint_auth_method.map(ToString::to_string); - let token_endpoint_auth_signing_alg = token_endpoint_auth_signing_alg.map(ToString::to_string); - let initiate_login_uri = initiate_login_uri.map(Url::as_str); - - sqlx::query!( - r#" - INSERT INTO oauth2_clients - (oauth2_client_id, - encrypted_client_secret, - grant_type_authorization_code, - grant_type_refresh_token, - client_name, - logo_uri, - client_uri, - policy_uri, - tos_uri, - jwks_uri, - jwks, - id_token_signed_response_alg, - userinfo_signed_response_alg, - token_endpoint_auth_method, - token_endpoint_auth_signing_alg, - initiate_login_uri) - VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) - "#, - Uuid::from(client_id), - encrypted_client_secret, - grant_type_authorization_code, - grant_type_refresh_token, - client_name, - logo_uri, - client_uri, - policy_uri, - tos_uri, - jwks_uri, - jwks, - id_token_signed_response_alg, - userinfo_signed_response_alg, - token_endpoint_auth_method, - token_endpoint_auth_signing_alg, - initiate_login_uri, - ) - .execute(&mut *conn) - .await?; - - let now = clock.now(); - let (ids, redirect_uris): (Vec, Vec) = redirect_uris - .iter() - .map(|uri| { - ( - Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), - uri.as_str().to_owned(), - ) - }) - .unzip(); - - sqlx::query!( - r#" - INSERT INTO oauth2_client_redirect_uris - (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) - SELECT id, $2, redirect_uri - FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) - "#, - &ids, - Uuid::from(client_id), - &redirect_uris, - ) - .execute(&mut *conn) - .await?; - - Ok(()) -} - -#[allow(clippy::too_many_arguments)] -pub async fn insert_client_from_config( - conn: &mut PgConnection, - mut rng: impl Rng + Send, - clock: &Clock, - client_id: Ulid, - client_auth_method: OAuthClientAuthenticationMethod, - encrypted_client_secret: Option<&str>, - jwks: Option<&PublicJsonWebKeySet>, - jwks_uri: Option<&Url>, - redirect_uris: &[Url], -) -> Result<(), DatabaseError> { - let jwks = jwks - .map(serde_json::to_value) - .transpose() - .map_err(DatabaseError::to_invalid_operation)?; - - let jwks_uri = jwks_uri.map(Url::as_str); - - let client_auth_method = client_auth_method.to_string(); - - sqlx::query!( - r#" - INSERT INTO oauth2_clients - (oauth2_client_id, - encrypted_client_secret, - grant_type_authorization_code, - grant_type_refresh_token, - token_endpoint_auth_method, - jwks, - jwks_uri) - VALUES - ($1, $2, $3, $4, $5, $6, $7) - "#, - Uuid::from(client_id), - encrypted_client_secret, - true, - true, - client_auth_method, - jwks, - jwks_uri, - ) - .execute(&mut *conn) - .await?; - - let now = clock.now(); - let (ids, redirect_uris): (Vec, Vec) = redirect_uris - .iter() - .map(|uri| { - ( - Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), - uri.as_str().to_owned(), - ) - }) - .unzip(); - - sqlx::query!( - r#" - INSERT INTO oauth2_client_redirect_uris - (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) - SELECT id, $2, redirect_uri - FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) - "#, - &ids, - Uuid::from(client_id), - &redirect_uris, - ) - .execute(&mut *conn) - .await?; - - Ok(()) -} - -pub async fn truncate_clients(executor: impl PgExecutor<'_>) -> Result<(), sqlx::Error> { - sqlx::query!("TRUNCATE oauth2_client_redirect_uris, oauth2_clients CASCADE") - .execute(executor) - .await?; - Ok(()) -} + async fn give_consent_for_user( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + client: &Client, + user: &User, + scope: &Scope, + ) -> Result<(), Self::Error>; +); diff --git a/crates/storage/src/oauth2/consent.rs b/crates/storage/src/oauth2/consent.rs deleted file mode 100644 index c1a5080df..000000000 --- a/crates/storage/src/oauth2/consent.rs +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::str::FromStr; - -use mas_data_model::{Client, User}; -use oauth2_types::scope::{Scope, ScopeToken}; -use rand::Rng; -use sqlx::PgExecutor; -use ulid::Ulid; -use uuid::Uuid; - -use crate::{Clock, DatabaseError, DatabaseInconsistencyError}; - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %client.id, - ), - err, -)] -pub async fn fetch_client_consent( - executor: impl PgExecutor<'_>, - user: &User, - client: &Client, -) -> Result { - let scope_tokens: Vec = sqlx::query_scalar!( - r#" - SELECT scope_token - FROM oauth2_consents - WHERE user_id = $1 AND oauth2_client_id = $2 - "#, - Uuid::from(user.id), - Uuid::from(client.id), - ) - .fetch_all(executor) - .await?; - - let scope: Result = scope_tokens - .into_iter() - .map(|s| ScopeToken::from_str(&s)) - .collect(); - - let scope = scope.map_err(|e| { - DatabaseInconsistencyError::on("oauth2_consents") - .column("scope_token") - .source(e) - })?; - - Ok(scope) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %client.id, - %scope, - ), - err, -)] -pub async fn insert_client_consent( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - user: &User, - client: &Client, - scope: &Scope, -) -> Result<(), sqlx::Error> { - let now = clock.now(); - let (tokens, ids): (Vec, Vec) = scope - .iter() - .map(|token| { - ( - token.to_string(), - Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), - ) - }) - .unzip(); - - sqlx::query!( - r#" - INSERT INTO oauth2_consents - (oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at) - SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token) - ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5 - "#, - &ids, - Uuid::from(user.id), - Uuid::from(client.id), - &tokens, - now, - ) - .execute(executor) - .await?; - - Ok(()) -} diff --git a/crates/storage/src/oauth2/mod.rs b/crates/storage/src/oauth2/mod.rs index 81a743633..75823c277 100644 --- a/crates/storage/src/oauth2/mod.rs +++ b/crates/storage/src/oauth2/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,173 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::{BTreeSet, HashMap}; +//! Repositories to interact with entities related to the OAuth 2.0 protocol -use mas_data_model::{BrowserSession, Session, User}; -use sqlx::{PgConnection, PgExecutor, QueryBuilder}; -use tracing::{info_span, Instrument}; -use ulid::Ulid; -use uuid::Uuid; +mod access_token; +mod authorization_grant; +mod client; +mod refresh_token; +mod session; -use self::client::lookup_clients; -use crate::{ - pagination::{process_page, QueryBuilderExt}, - user::lookup_active_session, - Clock, DatabaseError, DatabaseInconsistencyError, +pub use self::{ + access_token::OAuth2AccessTokenRepository, + authorization_grant::OAuth2AuthorizationGrantRepository, client::OAuth2ClientRepository, + refresh_token::OAuth2RefreshTokenRepository, session::OAuth2SessionRepository, }; - -pub mod access_token; -pub mod authorization_grant; -pub mod client; -pub mod consent; -pub mod refresh_token; - -#[tracing::instrument( - skip_all, - fields( - %session.id, - user.id = %session.browser_session.user.id, - user_session.id = %session.browser_session.id, - client.id = %session.client.id, - ), - err, -)] -pub async fn end_oauth_session( - executor: impl PgExecutor<'_>, - clock: &Clock, - session: Session, -) -> Result<(), DatabaseError> { - let finished_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE oauth2_sessions - SET finished_at = $2 - WHERE oauth2_session_id = $1 - "#, - Uuid::from(session.id), - finished_at, - ) - .execute(executor) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1) -} - -#[derive(sqlx::FromRow)] -struct OAuthSessionLookup { - oauth2_session_id: Uuid, - user_session_id: Uuid, - oauth2_client_id: Uuid, - scope: String, -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %user.username, - ), - err, -)] -pub async fn get_paginated_user_oauth_sessions( - conn: &mut PgConnection, - user: &User, - before: Option, - after: Option, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), DatabaseError> { - let mut query = QueryBuilder::new( - r#" - SELECT - os.oauth2_session_id, - os.user_session_id, - os.oauth2_client_id, - os.scope, - os.created_at, - os.finished_at - FROM oauth2_sessions os - LEFT JOIN user_sessions us - USING (user_session_id) - "#, - ); - - query - .push(" WHERE us.user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("oauth2_session_id", before, after, first, last)?; - - let span = info_span!( - "Fetch paginated user oauth sessions", - db.statement = query.sql() - ); - let page: Vec = query - .build_query_as() - .fetch_all(&mut *conn) - .instrument(span) - .await?; - - let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; - - let client_ids: BTreeSet = page - .iter() - .map(|i| Ulid::from(i.oauth2_client_id)) - .collect(); - - let browser_session_ids: BTreeSet = - page.iter().map(|i| Ulid::from(i.user_session_id)).collect(); - - let clients = lookup_clients(&mut *conn, client_ids).await?; - - // TODO: this can generate N queries instead of batching. This is less than - // ideal - let mut browser_sessions: HashMap = HashMap::new(); - for id in browser_session_ids { - let v = lookup_active_session(&mut *conn, id) - .await? - .ok_or_else(|| { - DatabaseInconsistencyError::on("oauth2_sessions").column("user_session_id") - })?; - browser_sessions.insert(id, v); - } - - let page: Result, DatabaseInconsistencyError> = page - .into_iter() - .map(|item| { - let id = Ulid::from(item.oauth2_session_id); - let client = clients - .get(&Ulid::from(item.oauth2_client_id)) - .ok_or_else(|| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("oauth2_client_id") - .row(id) - })? - .clone(); - - let browser_session = browser_sessions - .get(&Ulid::from(item.user_session_id)) - .ok_or_else(|| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("user_session_id") - .row(id) - })? - .clone(); - - let scope = item.scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("scope") - .row(id) - .source(e) - })?; - - Ok(Session { - id: Ulid::from(item.oauth2_session_id), - client, - browser_session, - scope, - }) - }) - .collect(); - - Ok((has_previous_page, has_next_page, page?)) -} diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 74e111c9e..a0e2c44a0 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,272 +12,114 @@ // See the License for the specific language governing permissions and // limitations under the License. -use chrono::{DateTime, Utc}; -use mas_data_model::{ - AccessToken, Authentication, BrowserSession, RefreshToken, Session, User, UserEmail, -}; -use rand::Rng; -use sqlx::{PgConnection, PgExecutor}; +use async_trait::async_trait; +use mas_data_model::{AccessToken, RefreshToken, Session}; +use rand_core::RngCore; use ulid::Ulid; -use uuid::Uuid; -use super::client::lookup_client; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError}; +use crate::{repository_impl, Clock}; -#[tracing::instrument( - skip_all, - fields( - %session.id, - user.id = %session.browser_session.user.id, - user_session.id = %session.browser_session.id, - client.id = %session.client.id, - refresh_token.id, - ), - err, -)] -pub async fn add_refresh_token( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - session: &Session, - access_token: AccessToken, - refresh_token: String, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("refresh_token.id", tracing::field::display(id)); +/// An [`OAuth2RefreshTokenRepository`] helps interacting with [`RefreshToken`] +/// saved in the storage backend +#[async_trait] +pub trait OAuth2RefreshTokenRepository: Send + Sync { + /// The error type returned by the repository + type Error; - sqlx::query!( - r#" - INSERT INTO oauth2_refresh_tokens - (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id, - refresh_token, created_at) - VALUES - ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(session.id), - Uuid::from(access_token.id), - refresh_token, - created_at, - ) - .execute(executor) - .await?; + /// Lookup a refresh token by its ID + /// + /// Returns `None` if no [`RefreshToken`] was found + /// + /// # Parameters + /// + /// * `id`: The ID of the [`RefreshToken`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; - Ok(RefreshToken { - id, - refresh_token, - access_token: Some(access_token), - created_at, - }) + /// Find a refresh token by its token + /// + /// Returns `None` if no [`RefreshToken`] was found + /// + /// # Parameters + /// + /// * `token`: The token of the [`RefreshToken`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error>; + + /// Add a new refresh token to the database + /// + /// Returns the newly created [`RefreshToken`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `session`: The [`Session`] in which to create the [`RefreshToken`] + /// * `access_token`: The [`AccessToken`] created alongside this + /// [`RefreshToken`] + /// * `refresh_token`: The refresh token to store + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + session: &Session, + access_token: &AccessToken, + refresh_token: String, + ) -> Result; + + /// Consume a refresh token + /// + /// Returns the updated [`RefreshToken`] + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `refresh_token`: The [`RefreshToken`] to consume + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails, or if the + /// token was already consumed + async fn consume( + &mut self, + clock: &dyn Clock, + refresh_token: RefreshToken, + ) -> Result; } -struct OAuth2RefreshTokenLookup { - oauth2_refresh_token_id: Uuid, - oauth2_refresh_token: String, - oauth2_refresh_token_created_at: DateTime, - oauth2_access_token_id: Option, - oauth2_access_token: Option, - oauth2_access_token_created_at: Option>, - oauth2_access_token_expires_at: Option>, - oauth2_session_id: Uuid, - oauth2_client_id: Uuid, - oauth2_session_scope: String, - user_session_id: Uuid, - user_session_created_at: DateTime, - user_id: Uuid, - user_username: String, - user_session_last_authentication_id: Option, - user_session_last_authentication_created_at: Option>, - user_email_id: Option, - user_email: Option, - user_email_created_at: Option>, - user_email_confirmed_at: Option>, -} +repository_impl!(OAuth2RefreshTokenRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; -#[tracing::instrument(skip_all, err)] -#[allow(clippy::too_many_lines)] -pub async fn lookup_active_refresh_token( - conn: &mut PgConnection, - token: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - OAuth2RefreshTokenLookup, - r#" - SELECT - rt.oauth2_refresh_token_id, - rt.refresh_token AS oauth2_refresh_token, - rt.created_at AS oauth2_refresh_token_created_at, - at.oauth2_access_token_id AS "oauth2_access_token_id?", - at.access_token AS "oauth2_access_token?", - at.created_at AS "oauth2_access_token_created_at?", - at.expires_at AS "oauth2_access_token_expires_at?", - os.oauth2_session_id AS "oauth2_session_id!", - os.oauth2_client_id AS "oauth2_client_id!", - os.scope AS "oauth2_session_scope!", - us.user_session_id AS "user_session_id!", - us.created_at AS "user_session_created_at!", - u.user_id AS "user_id!", - u.username AS "user_username!", - usa.user_session_authentication_id AS "user_session_last_authentication_id?", - usa.created_at AS "user_session_last_authentication_created_at?", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" - FROM oauth2_refresh_tokens rt - INNER JOIN oauth2_sessions os - USING (oauth2_session_id) - LEFT JOIN oauth2_access_tokens at - USING (oauth2_access_token_id) - INNER JOIN user_sessions us - USING (user_session_id) - INNER JOIN users u - USING (user_id) - LEFT JOIN user_session_authentications usa - USING (user_session_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error>; - WHERE rt.refresh_token = $1 - AND rt.consumed_at IS NULL - AND rt.revoked_at IS NULL - AND us.finished_at IS NULL - AND os.finished_at IS NULL + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + session: &Session, + access_token: &AccessToken, + refresh_token: String, + ) -> Result; - ORDER BY usa.created_at DESC - LIMIT 1 - "#, - token, - ) - .fetch_one(&mut *conn) - .await?; - - let access_token = match ( - res.oauth2_access_token_id, - res.oauth2_access_token, - res.oauth2_access_token_created_at, - res.oauth2_access_token_expires_at, - ) { - (None, None, None, None) => None, - (Some(id), Some(access_token), Some(created_at), Some(expires_at)) => { - let id = Ulid::from(id); - Some(AccessToken { - id, - jti: id.to_string(), - access_token, - created_at, - expires_at, - }) - } - _ => return Err(DatabaseInconsistencyError::on("oauth2_access_tokens").into()), - }; - - let refresh_token = RefreshToken { - id: res.oauth2_refresh_token_id.into(), - refresh_token: res.oauth2_refresh_token, - created_at: res.oauth2_refresh_token_created_at, - access_token, - }; - - let session_id = res.oauth2_session_id.into(); - let client = lookup_client(&mut *conn, res.oauth2_client_id.into()) - .await? - .ok_or_else(|| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("client_id") - .row(session_id) - })?; - - let user_id = Ulid::from(res.user_id); - let primary_email = match ( - res.user_email_id, - res.user_email, - res.user_email_created_at, - res.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("users") - .column("primary_user_email_id") - .row(user_id) - .into()) - } - }; - - let user = User { - id: user_id, - username: res.user_username, - sub: user_id.to_string(), - primary_email, - }; - - let last_authentication = match ( - res.user_session_last_authentication_id, - res.user_session_last_authentication_created_at, - ) { - (None, None) => None, - (Some(id), Some(created_at)) => Some(Authentication { - id: id.into(), - created_at, - }), - _ => return Err(DatabaseInconsistencyError::on("user_session_authentications").into()), - }; - - let browser_session = BrowserSession { - id: res.user_session_id.into(), - created_at: res.user_session_created_at, - user, - last_authentication, - }; - - let scope = res.oauth2_session_scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("scope") - .row(session_id) - .source(e) - })?; - - let session = Session { - id: session_id, - client, - browser_session, - scope, - }; - - Ok(Some((refresh_token, session))) -} - -#[tracing::instrument( - skip_all, - fields( - %refresh_token.id, - ), - err, -)] -pub async fn consume_refresh_token( - executor: impl PgExecutor<'_>, - clock: &Clock, - refresh_token: &RefreshToken, -) -> Result<(), DatabaseError> { - let consumed_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE oauth2_refresh_tokens - SET consumed_at = $2 - WHERE oauth2_refresh_token_id = $1 - "#, - Uuid::from(refresh_token.id), - consumed_at, - ) - .execute(executor) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1) -} + async fn consume( + &mut self, + clock: &dyn Clock, + refresh_token: RefreshToken, + ) -> Result; +); diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs new file mode 100644 index 000000000..880992a67 --- /dev/null +++ b/crates/storage/src/oauth2/session.rs @@ -0,0 +1,116 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use mas_data_model::{AuthorizationGrant, BrowserSession, Session, User}; +use rand_core::RngCore; +use ulid::Ulid; + +use crate::{pagination::Page, repository_impl, Clock, Pagination}; + +/// An [`OAuth2SessionRepository`] helps interacting with [`Session`] +/// saved in the storage backend +#[async_trait] +pub trait OAuth2SessionRepository: Send + Sync { + /// The error type returned by the repository + type Error; + + /// Lookup an [`Session`] by its ID + /// + /// Returns `None` if no [`Session`] was found + /// + /// # Parameters + /// + /// * `id`: The ID of the [`Session`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Create a new [`Session`] from an [`AuthorizationGrant`] + /// + /// Returns the newly created [`Session`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `grant`: The [`AuthorizationGrant`] to create the [`Session`] from + /// * `user_session`: The [`BrowserSession`] of the user which completed the + /// authorization + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn create_from_grant( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + grant: &AuthorizationGrant, + user_session: &BrowserSession, + ) -> Result; + + /// Mark a [`Session`] as finished + /// + /// Returns the updated [`Session`] + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `session`: The [`Session`] to mark as finished + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn finish(&mut self, clock: &dyn Clock, session: Session) + -> Result; + + /// Get a paginated list of [`Session`]s for a [`User`] + /// + /// # Parameters + /// + /// * `user`: The [`User`] to get the [`Session`]s for + /// * `pagination`: The pagination parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; +} + +repository_impl!(OAuth2SessionRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn create_from_grant( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + grant: &AuthorizationGrant, + user_session: &BrowserSession, + ) -> Result; + + async fn finish(&mut self, clock: &dyn Clock, session: Session) + -> Result; + + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; +); diff --git a/crates/storage/src/pagination.rs b/crates/storage/src/pagination.rs index 956556750..d8d8bc1c4 100644 --- a/crates/storage/src/pagination.rs +++ b/crates/storage/src/pagination.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,131 +12,179 @@ // See the License for the specific language governing permissions and // limitations under the License. -use sqlx::{Database, QueryBuilder}; +//! Utilities to manage paginated queries. + use thiserror::Error; use ulid::Ulid; -use uuid::Uuid; +/// An error returned when invalid pagination parameters are provided #[derive(Debug, Error)] #[error("Either 'first' or 'last' must be specified")] pub struct InvalidPagination; -/// Add cursor-based pagination to a query, as used in paginated GraphQL -/// connections -pub fn generate_pagination<'a, DB>( - query: &mut QueryBuilder<'a, DB>, - id_field: &'static str, - before: Option, - after: Option, - first: Option, - last: Option, -) -> Result<(), InvalidPagination> -where - DB: Database, - Uuid: sqlx::Type + sqlx::Encode<'a, DB>, - i64: sqlx::Type + sqlx::Encode<'a, DB>, -{ - // ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564 - // 1. Start from the greedy query: SELECT * FROM table +/// Pagination parameters +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Pagination { + /// The cursor to start from + pub before: Option, - // 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE` - // clause - if let Some(after) = after { - query - .push(" AND ") - .push(id_field) - .push(" > ") - .push_bind(Uuid::from(after)); - } + /// The cursor to end at + pub after: Option, - // 3. If the before argument is provided, add `id < parsed_cursor` to the - // `WHERE` clause - if let Some(before) = before { - query - .push(" AND ") - .push(id_field) - .push(" < ") - .push_bind(Uuid::from(before)); - } + /// The maximum number of items to return + pub count: usize, - // 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to - // the query - if let Some(count) = first { - query - .push(" ORDER BY ") - .push(id_field) - .push(" ASC LIMIT ") - .push_bind((count + 1) as i64); - // 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` - // to the query - } else if let Some(count) = last { - query - .push(" ORDER BY ") - .push(id_field) - .push(" DESC LIMIT ") - .push_bind((count + 1) as i64); - } else { - return Err(InvalidPagination); - } - - Ok(()) + /// In which direction to paginate + pub direction: PaginationDirection, } -/// Process a page returned by a paginated query -pub fn process_page( - mut page: Vec, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), InvalidPagination> { - let limit = match (first, last) { - (Some(count), _) | (_, Some(count)) => count, - _ => return Err(InvalidPagination), - }; +/// The direction to paginate +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PaginationDirection { + /// Paginate forward + Forward, - let is_full = page.len() == (limit + 1); - if is_full { - page.pop(); - } - - let (has_previous_page, has_next_page) = if first.is_some() { - (false, is_full) - } else if last.is_some() { - // 6. If the last argument is provided, I reverse the order of the results - page.reverse(); - (is_full, false) - } else { - unreachable!() - }; - - Ok((has_previous_page, has_next_page, page)) + /// Paginate backward + Backward, } -pub trait QueryBuilderExt { - fn generate_pagination( - &mut self, - id_field: &'static str, +impl Pagination { + /// Creates a new [`Pagination`] from user-provided parameters. + /// + /// # Errors + /// + /// Either `first` or `last` must be provided, else this function will + /// return an [`InvalidPagination`] error. + pub const fn try_new( before: Option, after: Option, first: Option, last: Option, - ) -> Result<&mut Self, InvalidPagination>; -} + ) -> Result { + let (direction, count) = match (first, last) { + (Some(first), _) => (PaginationDirection::Forward, first), + (_, Some(last)) => (PaginationDirection::Backward, last), + (None, None) => return Err(InvalidPagination), + }; -impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB> -where - DB: Database, - Uuid: sqlx::Type + sqlx::Encode<'a, DB>, - i64: sqlx::Type + sqlx::Encode<'a, DB>, -{ - fn generate_pagination( - &mut self, - id_field: &'static str, - before: Option, - after: Option, - first: Option, - last: Option, - ) -> Result<&mut Self, InvalidPagination> { - generate_pagination(self, id_field, before, after, first, last)?; - Ok(self) + Ok(Self { + before, + after, + count, + direction, + }) + } + + /// Creates a [`Pagination`] which gets the first N items + #[must_use] + pub const fn first(first: usize) -> Self { + Self { + before: None, + after: None, + count: first, + direction: PaginationDirection::Forward, + } + } + + /// Creates a [`Pagination`] which gets the last N items + #[must_use] + pub const fn last(last: usize) -> Self { + Self { + before: None, + after: None, + count: last, + direction: PaginationDirection::Backward, + } + } + + /// Get items before the given cursor + #[must_use] + pub const fn before(mut self, id: Ulid) -> Self { + self.before = Some(id); + self + } + + /// Get items after the given cursor + #[must_use] + pub const fn after(mut self, id: Ulid) -> Self { + self.after = Some(id); + self + } + + /// Process a page returned by a paginated query + #[must_use] + pub fn process(&self, mut edges: Vec) -> Page { + let is_full = edges.len() == (self.count + 1); + if is_full { + edges.pop(); + } + + let (has_previous_page, has_next_page) = match self.direction { + PaginationDirection::Forward => (false, is_full), + PaginationDirection::Backward => { + // 6. If the last argument is provided, I reverse the order of the results + edges.reverse(); + (is_full, false) + } + }; + + Page { + has_next_page, + has_previous_page, + edges, + } + } +} + +/// A page of results returned by a paginated query +pub struct Page { + /// When paginating forwards, this is true if there are more items after + pub has_next_page: bool, + + /// When paginating backwards, this is true if there are more items before + pub has_previous_page: bool, + + /// The items in the page + pub edges: Vec, +} + +impl Page { + /// Map the items in this page with the given function + /// + /// # Parameters + /// + /// * `f`: The function to map the items with + #[must_use] + pub fn map(self, f: F) -> Page + where + F: FnMut(T) -> T2, + { + let edges = self.edges.into_iter().map(f).collect(); + Page { + has_next_page: self.has_next_page, + has_previous_page: self.has_previous_page, + edges, + } + } + + /// Try to map the items in this page with the given fallible function + /// + /// # Parameters + /// + /// * `f`: The fallible function to map the items with + /// + /// # Errors + /// + /// Returns the first error encountered while mapping the items + pub fn try_map(self, f: F) -> Result, E> + where + F: FnMut(T) -> Result, + { + let edges: Result, E> = self.edges.into_iter().map(f).collect(); + Ok(Page { + has_next_page: self.has_next_page, + has_previous_page: self.has_previous_page, + edges: edges?, + }) } } diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs new file mode 100644 index 000000000..c76e98665 --- /dev/null +++ b/crates/storage/src/repository.rs @@ -0,0 +1,473 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use futures_util::future::BoxFuture; +use thiserror::Error; + +use crate::{ + compat::{ + CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, + CompatSsoLoginRepository, + }, + oauth2::{ + OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, + OAuth2RefreshTokenRepository, OAuth2SessionRepository, + }, + upstream_oauth2::{ + UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, + UpstreamOAuthSessionRepository, + }, + user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, + MapErr, +}; + +/// A [`Repository`] helps interacting with the underlying storage backend. +pub trait Repository: + RepositoryAccess + RepositoryTransaction + Send +where + E: std::error::Error + Send + Sync + 'static, +{ + /// Construct a (boxed) typed-erased repository + fn boxed(self) -> BoxRepository + where + Self: Sync + Sized + 'static, + { + Box::new(self) + } + + /// Map the error type of all the methods of a [`Repository`] + fn map_err(self, mapper: Mapper) -> MapErr + where + Self: Sized, + { + MapErr::new(self, mapper) + } +} + +/// An opaque, type-erased error +#[derive(Debug, Error)] +#[error(transparent)] +pub struct RepositoryError { + source: Box, +} + +impl RepositoryError { + /// Construct a [`RepositoryError`] from any error kind + pub fn from_error(value: E) -> Self + where + E: std::error::Error + Send + Sync + 'static, + { + Self { + source: Box::new(value), + } + } +} + +/// A type-erased [`Repository`] +pub type BoxRepository = Box + Send + Sync + 'static>; + +/// A [`RepositoryTransaction`] can be saved or cancelled, after a series +/// of operations. +pub trait RepositoryTransaction { + /// The error type used by the [`Self::save`] and [`Self::cancel`] functions + type Error; + + /// Commit the transaction + /// + /// # Errors + /// + /// Returns an error if the underlying storage backend failed to commit the + /// transaction. + fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>>; + + /// Rollback the transaction + /// + /// # Errors + /// + /// Returns an error if the underlying storage backend failed to rollback + /// the transaction. + fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>>; +} + +/// Access the various repositories the backend implements. +/// +/// All the methods return a boxed trait object, which can be used to access a +/// particular repository. The lifetime of the returned object is bound to the +/// lifetime of the whole repository, so that only one mutable reference to the +/// repository is used at a time. +/// +/// When adding a new repository, you should add a new method to this trait, and +/// update the implementations for [`MapErr`] and [`Box`] below. +/// +/// Note: this used to have generic associated types to avoid boxing all the +/// repository traits, but that was removed because it made almost impossible to +/// box the trait object. This might be a shortcoming of the initial +/// implementation of generic associated types, and might be fixed in the +/// future. +pub trait RepositoryAccess: Send { + /// The backend-specific error type used by each repository. + type Error: std::error::Error + Send + Sync + 'static; + + /// Get an [`UpstreamOAuthLinkRepository`] + fn upstream_oauth_link<'c>( + &'c mut self, + ) -> Box + 'c>; + + /// Get an [`UpstreamOAuthProviderRepository`] + fn upstream_oauth_provider<'c>( + &'c mut self, + ) -> Box + 'c>; + + /// Get an [`UpstreamOAuthSessionRepository`] + fn upstream_oauth_session<'c>( + &'c mut self, + ) -> Box + 'c>; + + /// Get an [`UserRepository`] + fn user<'c>(&'c mut self) -> Box + 'c>; + + /// Get an [`UserEmailRepository`] + fn user_email<'c>(&'c mut self) -> Box + 'c>; + + /// Get an [`UserPasswordRepository`] + fn user_password<'c>(&'c mut self) + -> Box + 'c>; + + /// Get a [`BrowserSessionRepository`] + fn browser_session<'c>( + &'c mut self, + ) -> Box + 'c>; + + /// Get an [`OAuth2ClientRepository`] + fn oauth2_client<'c>(&'c mut self) + -> Box + 'c>; + + /// Get an [`OAuth2AuthorizationGrantRepository`] + fn oauth2_authorization_grant<'c>( + &'c mut self, + ) -> Box + 'c>; + + /// Get an [`OAuth2SessionRepository`] + fn oauth2_session<'c>( + &'c mut self, + ) -> Box + 'c>; + + /// Get an [`OAuth2AccessTokenRepository`] + fn oauth2_access_token<'c>( + &'c mut self, + ) -> Box + 'c>; + + /// Get an [`OAuth2RefreshTokenRepository`] + fn oauth2_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c>; + + /// Get a [`CompatSessionRepository`] + fn compat_session<'c>( + &'c mut self, + ) -> Box + 'c>; + + /// Get a [`CompatSsoLoginRepository`] + fn compat_sso_login<'c>( + &'c mut self, + ) -> Box + 'c>; + + /// Get a [`CompatAccessTokenRepository`] + fn compat_access_token<'c>( + &'c mut self, + ) -> Box + 'c>; + + /// Get a [`CompatRefreshTokenRepository`] + fn compat_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c>; +} + +/// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and +/// [`Repository`] for the [`MapErr`] wrapper and [`Box`] +mod impls { + use futures_util::{future::BoxFuture, FutureExt, TryFutureExt}; + + use super::RepositoryAccess; + use crate::{ + compat::{ + CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, + CompatSsoLoginRepository, + }, + oauth2::{ + OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, + OAuth2ClientRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, + }, + upstream_oauth2::{ + UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, + UpstreamOAuthSessionRepository, + }, + user::{ + BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository, + }, + MapErr, Repository, RepositoryTransaction, + }; + + // --- Repository --- + impl Repository for MapErr + where + R: Repository + RepositoryAccess + RepositoryTransaction, + F: FnMut(E1) -> E2 + Send + Sync + 'static, + E1: std::error::Error + Send + Sync + 'static, + E2: std::error::Error + Send + Sync + 'static, + { + } + + // --- RepositoryTransaction -- + impl RepositoryTransaction for MapErr + where + R: RepositoryTransaction, + R::Error: 'static, + F: FnMut(R::Error) -> E + Send + Sync + 'static, + E: std::error::Error, + { + type Error = E; + + fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { + Box::new(self.inner).save().map_err(self.mapper).boxed() + } + + fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { + Box::new(self.inner).cancel().map_err(self.mapper).boxed() + } + } + + // --- RepositoryAccess -- + impl RepositoryAccess for MapErr + where + R: RepositoryAccess, + R::Error: 'static, + F: FnMut(R::Error) -> E + Send + Sync + 'static, + E: std::error::Error + Send + Sync + 'static, + { + type Error = E; + + fn upstream_oauth_link<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.upstream_oauth_link(), + &mut self.mapper, + )) + } + + fn upstream_oauth_provider<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.upstream_oauth_provider(), + &mut self.mapper, + )) + } + + fn upstream_oauth_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.upstream_oauth_session(), + &mut self.mapper, + )) + } + + fn user<'c>(&'c mut self) -> Box + 'c> { + Box::new(MapErr::new(self.inner.user(), &mut self.mapper)) + } + + fn user_email<'c>(&'c mut self) -> Box + 'c> { + Box::new(MapErr::new(self.inner.user_email(), &mut self.mapper)) + } + + fn user_password<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.user_password(), &mut self.mapper)) + } + + fn browser_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.browser_session(), &mut self.mapper)) + } + + fn oauth2_client<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.oauth2_client(), &mut self.mapper)) + } + + fn oauth2_authorization_grant<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.oauth2_authorization_grant(), + &mut self.mapper, + )) + } + + fn oauth2_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.oauth2_session(), &mut self.mapper)) + } + + fn oauth2_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.oauth2_access_token(), + &mut self.mapper, + )) + } + + fn oauth2_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.oauth2_refresh_token(), + &mut self.mapper, + )) + } + + fn compat_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.compat_session(), &mut self.mapper)) + } + + fn compat_sso_login<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.compat_sso_login(), &mut self.mapper)) + } + + fn compat_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.compat_access_token(), + &mut self.mapper, + )) + } + + fn compat_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.compat_refresh_token(), + &mut self.mapper, + )) + } + } + + impl RepositoryAccess for Box { + type Error = R::Error; + + fn upstream_oauth_link<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).upstream_oauth_link() + } + + fn upstream_oauth_provider<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).upstream_oauth_provider() + } + + fn upstream_oauth_session<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).upstream_oauth_session() + } + + fn user<'c>(&'c mut self) -> Box + 'c> { + (**self).user() + } + + fn user_email<'c>(&'c mut self) -> Box + 'c> { + (**self).user_email() + } + + fn user_password<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).user_password() + } + + fn browser_session<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).browser_session() + } + + fn oauth2_client<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_client() + } + + fn oauth2_authorization_grant<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_authorization_grant() + } + + fn oauth2_session<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_session() + } + + fn oauth2_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_access_token() + } + + fn oauth2_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_refresh_token() + } + + fn compat_session<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).compat_session() + } + + fn compat_sso_login<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).compat_sso_login() + } + + fn compat_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).compat_access_token() + } + + fn compat_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).compat_refresh_token() + } + } +} diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index 931b2b7d6..9b8a4f1cd 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,231 +12,136 @@ // See the License for the specific language governing permissions and // limitations under the License. -use chrono::{DateTime, Utc}; +use async_trait::async_trait; use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User}; -use rand::Rng; -use sqlx::{PgExecutor, QueryBuilder}; -use tracing::{info_span, Instrument}; +use rand_core::RngCore; use ulid::Ulid; -use uuid::Uuid; -use crate::{ - pagination::{process_page, QueryBuilderExt}, - Clock, DatabaseError, LookupResultExt, -}; +use crate::{pagination::Page, repository_impl, Clock, Pagination}; -#[derive(sqlx::FromRow)] -struct LinkLookup { - upstream_oauth_link_id: Uuid, - upstream_oauth_provider_id: Uuid, - user_id: Option, - subject: String, - created_at: DateTime, +/// An [`UpstreamOAuthLinkRepository`] helps interacting with +/// [`UpstreamOAuthLink`] with the storage backend +#[async_trait] +pub trait UpstreamOAuthLinkRepository: Send + Sync { + /// The error type returned by the repository + type Error; + + /// Lookup an upstream OAuth link by its ID + /// + /// Returns `None` if the link does not exist + /// + /// # Parameters + /// + /// * `id`: The ID of the upstream OAuth link to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Find an upstream OAuth link for a provider by its subject + /// + /// Returns `None` if no matching upstream OAuth link was found + /// + /// # Parameters + /// + /// * `upstream_oauth_provider`: The upstream OAuth provider on which to + /// find the link + /// * `subject`: The subject of the upstream OAuth link to find + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn find_by_subject( + &mut self, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: &str, + ) -> Result, Self::Error>; + + /// Add a new upstream OAuth link + /// + /// Returns the newly created upstream OAuth link + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `upsream_oauth_provider`: The upstream OAuth provider for which to + /// create the link + /// * `subject`: The subject of the upstream OAuth link to create + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: String, + ) -> Result; + + /// Associate an upstream OAuth link to a user + /// + /// Returns the updated upstream OAuth link + /// + /// # Parameters + /// + /// * `upstream_oauth_link`: The upstream OAuth link to update + /// * `user`: The user to associate to the upstream OAuth link + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn associate_to_user( + &mut self, + upstream_oauth_link: &UpstreamOAuthLink, + user: &User, + ) -> Result<(), Self::Error>; + + /// Get a paginated list of upstream OAuth links on a user + /// + /// # Parameters + /// + /// * `user`: The user for which to get the upstream OAuth links + /// * `pagination`: The pagination parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; } -impl From for UpstreamOAuthLink { - fn from(value: LinkLookup) -> Self { - UpstreamOAuthLink { - id: Ulid::from(value.upstream_oauth_link_id), - provider_id: Ulid::from(value.upstream_oauth_provider_id), - user_id: value.user_id.map(Ulid::from), - subject: value.subject, - created_at: value.created_at, - } - } -} +repository_impl!(UpstreamOAuthLinkRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; -#[tracing::instrument( - skip_all, - fields(upstream_oauth_link.id = %id), - err, -)] -pub async fn lookup_link( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - LinkLookup, - r#" - SELECT - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - FROM upstream_oauth_links - WHERE upstream_oauth_link_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .await - .to_option()? - .map(Into::into); + async fn find_by_subject( + &mut self, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: &str, + ) -> Result, Self::Error>; - Ok(res) -} + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: String, + ) -> Result; -#[tracing::instrument( - skip_all, - fields( - upstream_oauth_link.subject = subject, - %upstream_oauth_provider.id, - %upstream_oauth_provider.issuer, - %upstream_oauth_provider.client_id, - ), - err, -)] -pub async fn lookup_link_by_subject( - executor: impl PgExecutor<'_>, - upstream_oauth_provider: &UpstreamOAuthProvider, - subject: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - LinkLookup, - r#" - SELECT - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - FROM upstream_oauth_links - WHERE upstream_oauth_provider_id = $1 - AND subject = $2 - "#, - Uuid::from(upstream_oauth_provider.id), - subject, - ) - .fetch_one(executor) - .await - .to_option()? - .map(Into::into); + async fn associate_to_user( + &mut self, + upstream_oauth_link: &UpstreamOAuthLink, + user: &User, + ) -> Result<(), Self::Error>; - Ok(res) -} - -#[tracing::instrument( - skip_all, - fields( - upstream_oauth_link.id, - upstream_oauth_link.subject = subject, - %upstream_oauth_provider.id, - %upstream_oauth_provider.issuer, - %upstream_oauth_provider.client_id, - ), - err, -)] -pub async fn add_link( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - upstream_oauth_provider: &UpstreamOAuthProvider, - subject: String, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO upstream_oauth_links ( - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - ) VALUES ($1, $2, NULL, $3, $4) - "#, - Uuid::from(id), - Uuid::from(upstream_oauth_provider.id), - &subject, - created_at, - ) - .execute(executor) - .await?; - - Ok(UpstreamOAuthLink { - id, - provider_id: upstream_oauth_provider.id, - user_id: None, - subject, - created_at, - }) -} - -#[tracing::instrument( - skip_all, - fields( - %upstream_oauth_link.id, - %upstream_oauth_link.subject, - %user.id, - %user.username, - ), - err, -)] -pub async fn associate_link_to_user( - executor: impl PgExecutor<'_>, - upstream_oauth_link: &UpstreamOAuthLink, - user: &User, -) -> Result<(), sqlx::Error> { - sqlx::query!( - r#" - UPDATE upstream_oauth_links - SET user_id = $1 - WHERE upstream_oauth_link_id = $2 - "#, - Uuid::from(user.id), - Uuid::from(upstream_oauth_link.id), - ) - .execute(executor) - .await?; - - Ok(()) -} - -#[tracing::instrument( - skip_all, - fields(%user.id, %user.username), - err -)] -pub async fn get_paginated_user_links( - executor: impl PgExecutor<'_>, - user: &User, - before: Option, - after: Option, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), DatabaseError> { - let mut query = QueryBuilder::new( - r#" - SELECT - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - FROM upstream_oauth_links - "#, - ); - - query - .push(" WHERE user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("upstream_oauth_link_id", before, after, first, last)?; - - let span = info_span!( - "Fetch paginated upstream OAuth 2.0 user links", - db.statement = query.sql() - ); - let page: Vec = query - .build_query_as() - .fetch_all(executor) - .instrument(span) - .await?; - - let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; - - let page: Vec<_> = page.into_iter().map(Into::into).collect(); - Ok((has_previous_page, has_next_page, page)) -} + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; +); diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index 4b1d517a6..252217527 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,17 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Repositories to interact with entities related to the upstream OAuth 2.0 +//! providers + mod link; mod provider; mod session; pub use self::{ - link::{ - add_link, associate_link_to_user, get_paginated_user_links, lookup_link, - lookup_link_by_subject, - }, - provider::{add_provider, get_paginated_providers, get_providers, lookup_provider}, - session::{ - add_session, complete_session, consume_session, lookup_session, lookup_session_on_link, - }, + link::UpstreamOAuthLinkRepository, provider::UpstreamOAuthProviderRepository, + session::UpstreamOAuthSessionRepository, }; diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index 360b9a4af..663af2c92 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,237 +12,110 @@ // See the License for the specific language governing permissions and // limitations under the License. -use chrono::{DateTime, Utc}; +use async_trait::async_trait; use mas_data_model::UpstreamOAuthProvider; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use oauth2_types::scope::Scope; -use rand::Rng; -use sqlx::{PgExecutor, QueryBuilder}; -use tracing::{info_span, Instrument}; +use rand_core::RngCore; use ulid::Ulid; -use uuid::Uuid; -use crate::{ - pagination::{process_page, QueryBuilderExt}, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, -}; +use crate::{pagination::Page, repository_impl, Clock, Pagination}; -#[derive(sqlx::FromRow)] -struct ProviderLookup { - upstream_oauth_provider_id: Uuid, - issuer: String, - scope: String, - client_id: String, - encrypted_client_secret: Option, - token_endpoint_signing_alg: Option, - token_endpoint_auth_method: String, - created_at: DateTime, +/// An [`UpstreamOAuthProviderRepository`] helps interacting with +/// [`UpstreamOAuthProvider`] saved in the storage backend +#[async_trait] +pub trait UpstreamOAuthProviderRepository: Send + Sync { + /// The error type returned by the repository + type Error; + + /// Lookup an upstream OAuth provider by its ID + /// + /// Returns `None` if the provider was not found + /// + /// # Parameters + /// + /// * `id`: The ID of the provider to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Add a new upstream OAuth provider + /// + /// Returns the newly created provider + /// + /// # Parameters + /// + /// * `rng`: A random number generator + /// * `clock`: The clock used to generate timestamps + /// * `issuer`: The OIDC issuer of the provider + /// * `scope`: The scope to request during the authorization flow + /// * `token_endpoint_auth_method`: The token endpoint authentication method + /// * `token_endpoint_auth_signing_alg`: The JWT signing algorithm to use + /// when then `client_secret_jwt` or `private_key_jwt` authentication + /// methods are used + /// * `client_id`: The client ID to use when authenticating to the upstream + /// * `encrypted_client_secret`: The encrypted client secret to use when + /// authenticating to the upstream + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + #[allow(clippy::too_many_arguments)] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + issuer: String, + scope: Scope, + token_endpoint_auth_method: OAuthClientAuthenticationMethod, + token_endpoint_signing_alg: Option, + client_id: String, + encrypted_client_secret: Option, + ) -> Result; + + /// Get a paginated list of upstream OAuth providers + /// + /// # Parameters + /// + /// * `pagination`: The pagination parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn list_paginated( + &mut self, + pagination: Pagination, + ) -> Result, Self::Error>; + + /// Get all upstream OAuth providers + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn all(&mut self) -> Result, Self::Error>; } -impl TryFrom for UpstreamOAuthProvider { - type Error = DatabaseInconsistencyError; - fn try_from(value: ProviderLookup) -> Result { - let id = value.upstream_oauth_provider_id.into(); - let scope = value.scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("scope") - .row(id) - .source(e) - })?; - let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("token_endpoint_auth_method") - .row(id) - .source(e) - })?; - let token_endpoint_signing_alg = value - .token_endpoint_signing_alg - .map(|x| x.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("token_endpoint_signing_alg") - .row(id) - .source(e) - })?; +repository_impl!(UpstreamOAuthProviderRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; - Ok(UpstreamOAuthProvider { - id, - issuer: value.issuer, - scope, - client_id: value.client_id, - encrypted_client_secret: value.encrypted_client_secret, - token_endpoint_auth_method, - token_endpoint_signing_alg, - created_at: value.created_at, - }) - } -} + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + issuer: String, + scope: Scope, + token_endpoint_auth_method: OAuthClientAuthenticationMethod, + token_endpoint_signing_alg: Option, + client_id: String, + encrypted_client_secret: Option + ) -> Result; -#[tracing::instrument( - skip_all, - fields(upstream_oauth_provider.id = %id), - err, -)] -pub async fn lookup_provider( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - ProviderLookup, - r#" - SELECT - upstream_oauth_provider_id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at - FROM upstream_oauth_providers - WHERE upstream_oauth_provider_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .await - .to_option()?; + async fn list_paginated( + &mut self, + pagination: Pagination + ) -> Result, Self::Error>; - let res = res - .map(UpstreamOAuthProvider::try_from) - .transpose() - .map_err(DatabaseError::from)?; - - Ok(res) -} - -#[tracing::instrument( - skip_all, - fields( - upstream_oauth_provider.id, - upstream_oauth_provider.issuer = %issuer, - upstream_oauth_provider.client_id = %client_id, - ), - err, -)] -#[allow(clippy::too_many_arguments)] -pub async fn add_provider( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - issuer: String, - scope: Scope, - token_endpoint_auth_method: OAuthClientAuthenticationMethod, - token_endpoint_signing_alg: Option, - client_id: String, - encrypted_client_secret: Option, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO upstream_oauth_providers ( - upstream_oauth_provider_id, - issuer, - scope, - token_endpoint_auth_method, - token_endpoint_signing_alg, - client_id, - encrypted_client_secret, - created_at - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - "#, - Uuid::from(id), - &issuer, - scope.to_string(), - token_endpoint_auth_method.to_string(), - token_endpoint_signing_alg.as_ref().map(ToString::to_string), - &client_id, - encrypted_client_secret.as_deref(), - created_at, - ) - .execute(executor) - .await?; - - Ok(UpstreamOAuthProvider { - id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at, - }) -} - -#[tracing::instrument(skip_all, err)] -pub async fn get_paginated_providers( - executor: impl PgExecutor<'_>, - before: Option, - after: Option, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), DatabaseError> { - let mut query = QueryBuilder::new( - r#" - SELECT - upstream_oauth_provider_id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at - FROM upstream_oauth_providers - WHERE 1 = 1 - "#, - ); - - query.generate_pagination("upstream_oauth_provider_id", before, after, first, last)?; - - let span = info_span!( - "Fetch paginated upstream OAuth 2.0 providers", - db.statement = query.sql() - ); - let page: Vec = query - .build_query_as() - .fetch_all(executor) - .instrument(span) - .await?; - - let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; - - let page: Result, _> = page.into_iter().map(TryInto::try_into).collect(); - Ok((has_previous_page, has_next_page, page?)) -} - -#[tracing::instrument(skip_all, err)] -pub async fn get_providers( - executor: impl PgExecutor<'_>, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - ProviderLookup, - r#" - SELECT - upstream_oauth_provider_id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at - FROM upstream_oauth_providers - "#, - ) - .fetch_all(executor) - .await?; - - let res: Result, _> = res.into_iter().map(TryInto::try_into).collect(); - Ok(res?) -} + async fn all(&mut self) -> Result, Self::Error>; +); diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index 5e013f241..2d8f14be7 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,327 +12,134 @@ // See the License for the specific language governing permissions and // limitations under the License. -use chrono::{DateTime, Utc}; +use async_trait::async_trait; use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider}; -use rand::Rng; -use sqlx::PgExecutor; +use rand_core::RngCore; use ulid::Ulid; -use uuid::Uuid; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; +use crate::{repository_impl, Clock}; -struct SessionAndProviderLookup { - upstream_oauth_authorization_session_id: Uuid, - upstream_oauth_provider_id: Uuid, - upstream_oauth_link_id: Option, - state: String, - code_challenge_verifier: Option, - nonce: String, - id_token: Option, - created_at: DateTime, - completed_at: Option>, - consumed_at: Option>, - provider_issuer: String, - provider_scope: String, - provider_client_id: String, - provider_encrypted_client_secret: Option, - provider_token_endpoint_auth_method: String, - provider_token_endpoint_signing_alg: Option, - provider_created_at: DateTime, +/// An [`UpstreamOAuthSessionRepository`] helps interacting with +/// [`UpstreamOAuthAuthorizationSession`] saved in the storage backend +#[async_trait] +pub trait UpstreamOAuthSessionRepository: Send + Sync { + /// The error type returned by the repository + type Error; + + /// Lookup a session by its ID + /// + /// Returns `None` if the session does not exist + /// + /// # Parameters + /// + /// * `id`: the ID of the session to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup( + &mut self, + id: Ulid, + ) -> Result, Self::Error>; + + /// Add a session to the database + /// + /// Returns the newly created session + /// + /// # Parameters + /// + /// * `rng`: the random number generator to use + /// * `clock`: the clock source + /// * `upstream_oauth_provider`: the upstream OAuth provider for which to + /// create the session + /// * `state`: the authorization grant `state` parameter sent to the + /// upstream OAuth provider + /// * `code_challenge_verifier`: the code challenge verifier used in this + /// session, if PKCE is being used + /// * `nonce`: the `nonce` used in this session + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + state: String, + code_challenge_verifier: Option, + nonce: String, + ) -> Result; + + /// Mark a session as completed and associate the given link + /// + /// Returns the updated session + /// + /// # Parameters + /// + /// * `clock`: the clock source + /// * `upstream_oauth_authorization_session`: the session to update + /// * `upstream_oauth_link`: the link to associate with the session + /// * `id_token`: the ID token returned by the upstream OAuth provider, if + /// present + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn complete_with_link( + &mut self, + clock: &dyn Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + upstream_oauth_link: &UpstreamOAuthLink, + id_token: Option, + ) -> Result; + + /// Mark a session as consumed + /// + /// Returns the updated session + /// + /// # Parameters + /// + /// * `clock`: the clock source + /// * `upstream_oauth_authorization_session`: the session to consume + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn consume( + &mut self, + clock: &dyn Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + ) -> Result; } -/// Lookup a session and its provider by its ID -#[tracing::instrument( - skip_all, - fields(upstream_oauth_authorization_session.id = %id), - err, -)] -pub async fn lookup_session( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - SessionAndProviderLookup, - r#" - SELECT - ua.upstream_oauth_authorization_session_id, - ua.upstream_oauth_provider_id, - ua.upstream_oauth_link_id, - ua.state, - ua.code_challenge_verifier, - ua.nonce, - ua.id_token, - ua.created_at, - ua.completed_at, - ua.consumed_at, - up.issuer AS "provider_issuer", - up.scope AS "provider_scope", - up.client_id AS "provider_client_id", - up.encrypted_client_secret AS "provider_encrypted_client_secret", - up.token_endpoint_auth_method AS "provider_token_endpoint_auth_method", - up.token_endpoint_signing_alg AS "provider_token_endpoint_signing_alg", - up.created_at AS "provider_created_at" - FROM upstream_oauth_authorization_sessions ua - INNER JOIN upstream_oauth_providers up - USING (upstream_oauth_provider_id) - WHERE upstream_oauth_authorization_session_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .await - .to_option()?; +repository_impl!(UpstreamOAuthSessionRepository: + async fn lookup( + &mut self, + id: Ulid, + ) -> Result, Self::Error>; - let Some(res) = res else { return Ok(None) }; + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + state: String, + code_challenge_verifier: Option, + nonce: String, + ) -> Result; - let id = res.upstream_oauth_provider_id.into(); - let provider = UpstreamOAuthProvider { - id, - issuer: res.provider_issuer, - scope: res.provider_scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("scope") - .row(id) - .source(e) - })?, - client_id: res.provider_client_id, - encrypted_client_secret: res.provider_encrypted_client_secret, - token_endpoint_auth_method: res.provider_token_endpoint_auth_method.parse().map_err( - |e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("token_endpoint_auth_method") - .row(id) - .source(e) - }, - )?, - token_endpoint_signing_alg: res - .provider_token_endpoint_signing_alg - .map(|x| x.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("token_endpoint_signing_alg") - .row(id) - .source(e) - })?, - created_at: res.provider_created_at, - }; + async fn complete_with_link( + &mut self, + clock: &dyn Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + upstream_oauth_link: &UpstreamOAuthLink, + id_token: Option, + ) -> Result; - let session = UpstreamOAuthAuthorizationSession { - id: res.upstream_oauth_authorization_session_id.into(), - provider_id: provider.id, - link_id: res.upstream_oauth_link_id.map(Ulid::from), - state: res.state, - code_challenge_verifier: res.code_challenge_verifier, - nonce: res.nonce, - id_token: res.id_token, - created_at: res.created_at, - completed_at: res.completed_at, - consumed_at: res.consumed_at, - }; - - Ok(Some((provider, session))) -} - -/// Add a session to the database -#[tracing::instrument( - skip_all, - fields( - %upstream_oauth_provider.id, - %upstream_oauth_provider.issuer, - %upstream_oauth_provider.client_id, - upstream_oauth_authorization_session.id, - ), - err, -)] -pub async fn add_session( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - upstream_oauth_provider: &UpstreamOAuthProvider, - state: String, - code_challenge_verifier: Option, - nonce: String, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record( - "upstream_oauth_authorization_session.id", - tracing::field::display(id), - ); - - sqlx::query!( - r#" - INSERT INTO upstream_oauth_authorization_sessions ( - upstream_oauth_authorization_session_id, - upstream_oauth_provider_id, - state, - code_challenge_verifier, - nonce, - created_at, - completed_at, - consumed_at, - id_token - ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL) - "#, - Uuid::from(id), - Uuid::from(upstream_oauth_provider.id), - &state, - code_challenge_verifier.as_deref(), - nonce, - created_at, - ) - .execute(executor) - .await?; - - Ok(UpstreamOAuthAuthorizationSession { - id, - provider_id: upstream_oauth_provider.id, - link_id: None, - state, - code_challenge_verifier, - nonce, - id_token: None, - created_at, - completed_at: None, - consumed_at: None, - }) -} - -/// Mark a session as completed and associate the given link -#[tracing::instrument( - skip_all, - fields( - %upstream_oauth_authorization_session.id, - %upstream_oauth_link.id, - ), - err, -)] -pub async fn complete_session( - executor: impl PgExecutor<'_>, - clock: &Clock, - mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, - upstream_oauth_link: &UpstreamOAuthLink, - id_token: Option, -) -> Result { - let completed_at = clock.now(); - sqlx::query!( - r#" - UPDATE upstream_oauth_authorization_sessions - SET upstream_oauth_link_id = $1, - completed_at = $2, - id_token = $3 - WHERE upstream_oauth_authorization_session_id = $4 - "#, - Uuid::from(upstream_oauth_link.id), - completed_at, - id_token, - Uuid::from(upstream_oauth_authorization_session.id), - ) - .execute(executor) - .await?; - - upstream_oauth_authorization_session.completed_at = Some(completed_at); - upstream_oauth_authorization_session.id_token = id_token; - - Ok(upstream_oauth_authorization_session) -} - -/// Mark a session as consumed -#[tracing::instrument( - skip_all, - fields( - %upstream_oauth_authorization_session.id, - ), - err, -)] -pub async fn consume_session( - executor: impl PgExecutor<'_>, - clock: &Clock, - mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, -) -> Result { - let consumed_at = clock.now(); - sqlx::query!( - r#" - UPDATE upstream_oauth_authorization_sessions - SET consumed_at = $1 - WHERE upstream_oauth_authorization_session_id = $2 - "#, - consumed_at, - Uuid::from(upstream_oauth_authorization_session.id), - ) - .execute(executor) - .await?; - - upstream_oauth_authorization_session.consumed_at = Some(consumed_at); - - Ok(upstream_oauth_authorization_session) -} - -struct SessionLookup { - upstream_oauth_authorization_session_id: Uuid, - upstream_oauth_provider_id: Uuid, - upstream_oauth_link_id: Option, - state: String, - code_challenge_verifier: Option, - nonce: String, - id_token: Option, - created_at: DateTime, - completed_at: Option>, - consumed_at: Option>, -} - -/// Lookup a session, which belongs to a link, by its ID -#[tracing::instrument( - skip_all, - fields( - upstream_oauth_authorization_session.id = %id, - %upstream_oauth_link.id, - ), - err, -)] -pub async fn lookup_session_on_link( - executor: impl PgExecutor<'_>, - upstream_oauth_link: &UpstreamOAuthLink, - id: Ulid, -) -> Result, sqlx::Error> { - let res = sqlx::query_as!( - SessionLookup, - r#" - SELECT - upstream_oauth_authorization_session_id, - upstream_oauth_provider_id, - upstream_oauth_link_id, - state, - code_challenge_verifier, - nonce, - id_token, - created_at, - completed_at, - consumed_at - FROM upstream_oauth_authorization_sessions - WHERE upstream_oauth_authorization_session_id = $1 - AND upstream_oauth_link_id = $2 - "#, - Uuid::from(id), - Uuid::from(upstream_oauth_link.id), - ) - .fetch_one(executor) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(UpstreamOAuthAuthorizationSession { - id: res.upstream_oauth_authorization_session_id.into(), - provider_id: res.upstream_oauth_provider_id.into(), - link_id: res.upstream_oauth_link_id.map(Ulid::from), - state: res.state, - code_challenge_verifier: res.code_challenge_verifier, - nonce: res.nonce, - id_token: res.id_token, - created_at: res.created_at, - completed_at: res.completed_at, - consumed_at: res.consumed_at, - })) -} + async fn consume( + &mut self, + clock: &dyn Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + ) -> Result; +); diff --git a/crates/storage/src/user/authentication.rs b/crates/storage/src/user/authentication.rs deleted file mode 100644 index 546b54a28..000000000 --- a/crates/storage/src/user/authentication.rs +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink}; -use rand::Rng; -use sqlx::PgExecutor; -use ulid::Ulid; -use uuid::Uuid; - -use crate::Clock; - -#[tracing::instrument( - skip_all, - fields( - user.id = %user_session.user.id, - %user_password.id, - %user_session.id, - user_session_authentication.id, - ), - err, -)] -pub async fn authenticate_session_with_password( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - user_session: &mut BrowserSession, - user_password: &Password, -) -> Result<(), sqlx::Error> { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record( - "user_session_authentication.id", - tracing::field::display(id), - ); - - sqlx::query!( - r#" - INSERT INTO user_session_authentications - (user_session_authentication_id, user_session_id, created_at) - VALUES ($1, $2, $3) - "#, - Uuid::from(id), - Uuid::from(user_session.id), - created_at, - ) - .execute(executor) - .await?; - - user_session.last_authentication = Some(Authentication { id, created_at }); - - Ok(()) -} - -#[tracing::instrument( - skip_all, - fields( - user.id = %user_session.user.id, - %upstream_oauth_link.id, - %user_session.id, - user_session_authentication.id, - ), - err, -)] -pub async fn authenticate_session_with_upstream( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - user_session: &mut BrowserSession, - upstream_oauth_link: &UpstreamOAuthLink, -) -> Result<(), sqlx::Error> { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record( - "user_session_authentication.id", - tracing::field::display(id), - ); - - sqlx::query!( - r#" - INSERT INTO user_session_authentications - (user_session_authentication_id, user_session_id, created_at) - VALUES ($1, $2, $3) - "#, - Uuid::from(id), - Uuid::from(user_session.id), - created_at, - ) - .execute(executor) - .await?; - - user_session.last_authentication = Some(Authentication { id, created_at }); - - Ok(()) -} diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs new file mode 100644 index 000000000..9ae815348 --- /dev/null +++ b/crates/storage/src/user/email.rs @@ -0,0 +1,283 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use mas_data_model::{User, UserEmail, UserEmailVerification}; +use rand_core::RngCore; +use ulid::Ulid; + +use crate::{pagination::Page, repository_impl, Clock, Pagination}; + +/// A [`UserEmailRepository`] helps interacting with [`UserEmail`] saved in the +/// storage backend +#[async_trait] +pub trait UserEmailRepository: Send + Sync { + /// The error type returned by the repository + type Error; + + /// Lookup an [`UserEmail`] by its ID + /// + /// Returns `None` if no [`UserEmail`] was found + /// + /// # Parameters + /// + /// * `id`: The ID of the [`UserEmail`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Lookup an [`UserEmail`] by its email address for a [`User`] + /// + /// Returns `None` if no matching [`UserEmail`] was found + /// + /// # Parameters + /// + /// * `user`: The [`User`] for whom to lookup the [`UserEmail`] + /// * `email`: The email address to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn find(&mut self, user: &User, email: &str) -> Result, Self::Error>; + + /// Get the primary [`UserEmail`] of a [`User`] + /// + /// Returns `None` if no the user has no primary [`UserEmail`] + /// + /// # Parameters + /// + /// * `user`: The [`User`] for whom to lookup the primary [`UserEmail`] + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn get_primary(&mut self, user: &User) -> Result, Self::Error>; + + /// Get all [`UserEmail`] of a [`User`] + /// + /// # Parameters + /// + /// * `user`: The [`User`] for whom to lookup the [`UserEmail`] + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn all(&mut self, user: &User) -> Result, Self::Error>; + + /// List [`UserEmail`] of a [`User`] with the given pagination + /// + /// # Parameters + /// + /// * `user`: The [`User`] for whom to lookup the [`UserEmail`] + /// * `pagination`: The pagination parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; + + /// Count the [`UserEmail`] of a [`User`] + /// + /// # Parameters + /// + /// * `user`: The [`User`] for whom to count the [`UserEmail`] + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn count(&mut self, user: &User) -> Result; + + /// Create a new [`UserEmail`] for a [`User`] + /// + /// Returns the newly created [`UserEmail`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock to use + /// * `user`: The [`User`] for whom to create the [`UserEmail`] + /// * `email`: The email address of the [`UserEmail`] + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + email: String, + ) -> Result; + + /// Delete a [`UserEmail`] + /// + /// # Parameters + /// + /// * `user_email`: The [`UserEmail`] to delete + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error>; + + /// Mark a [`UserEmail`] as verified + /// + /// Returns the updated [`UserEmail`] + /// + /// # Parameters + /// + /// * `clock`: The clock to use + /// * `user_email`: The [`UserEmail`] to mark as verified + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn mark_as_verified( + &mut self, + clock: &dyn Clock, + user_email: UserEmail, + ) -> Result; + + /// Mark a [`UserEmail`] as primary + /// + /// # Parameters + /// + /// * `user_email`: The [`UserEmail`] to mark as primary + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn set_as_primary(&mut self, user_email: &UserEmail) -> Result<(), Self::Error>; + + /// Add a [`UserEmailVerification`] for a [`UserEmail`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock to use + /// * `user_email`: The [`UserEmail`] for which to add the + /// [`UserEmailVerification`] + /// * `max_age`: The duration for which the [`UserEmailVerification`] is + /// valid + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn add_verification_code( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user_email: &UserEmail, + max_age: chrono::Duration, + code: String, + ) -> Result; + + /// Find a [`UserEmailVerification`] for a [`UserEmail`] by its code + /// + /// Returns `None` if no matching [`UserEmailVerification`] was found + /// + /// # Parameters + /// + /// * `clock`: The clock to use + /// * `user_email`: The [`UserEmail`] for which to lookup the + /// [`UserEmailVerification`] + /// * `code`: The code used to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn find_verification_code( + &mut self, + clock: &dyn Clock, + user_email: &UserEmail, + code: &str, + ) -> Result, Self::Error>; + + /// Consume a [`UserEmailVerification`] + /// + /// Returns the consumed [`UserEmailVerification`] + /// + /// # Parameters + /// + /// * `clock`: The clock to use + /// * `verification`: The [`UserEmailVerification`] to consume + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn consume_verification_code( + &mut self, + clock: &dyn Clock, + verification: UserEmailVerification, + ) -> Result; +} + +repository_impl!(UserEmailRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + async fn find(&mut self, user: &User, email: &str) -> Result, Self::Error>; + async fn get_primary(&mut self, user: &User) -> Result, Self::Error>; + + async fn all(&mut self, user: &User) -> Result, Self::Error>; + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; + async fn count(&mut self, user: &User) -> Result; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + email: String, + ) -> Result; + async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error>; + + async fn mark_as_verified( + &mut self, + clock: &dyn Clock, + user_email: UserEmail, + ) -> Result; + + async fn set_as_primary(&mut self, user_email: &UserEmail) -> Result<(), Self::Error>; + + async fn add_verification_code( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user_email: &UserEmail, + max_age: chrono::Duration, + code: String, + ) -> Result; + + async fn find_verification_code( + &mut self, + clock: &dyn Clock, + user_email: &UserEmail, + code: &str, + ) -> Result, Self::Error>; + + async fn consume_verification_code( + &mut self, + clock: &dyn Clock, + verification: UserEmailVerification, + ) -> Result; +); diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 1b8c2c61d..a611b459c 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,993 +12,98 @@ // See the License for the specific language governing permissions and // limitations under the License. -use chrono::{DateTime, Utc}; -use mas_data_model::{ - Authentication, BrowserSession, User, UserEmail, UserEmailVerification, - UserEmailVerificationState, -}; -use rand::Rng; -use sqlx::{PgExecutor, QueryBuilder}; -use tracing::{info_span, Instrument}; +//! Repositories to interact with entities related to user accounts + +use async_trait::async_trait; +use mas_data_model::User; +use rand_core::RngCore; use ulid::Ulid; -use uuid::Uuid; -use crate::{ - pagination::{process_page, QueryBuilderExt}, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, -}; +use crate::{repository_impl, Clock}; -mod authentication; +mod email; mod password; +mod session; pub use self::{ - authentication::{authenticate_session_with_password, authenticate_session_with_upstream}, - password::{add_user_password, lookup_user_password}, + email::UserEmailRepository, password::UserPasswordRepository, session::BrowserSessionRepository, }; -#[derive(Debug, Clone)] -struct UserLookup { - user_id: Uuid, - user_username: String, - user_email_id: Option, - user_email: Option, - user_email_created_at: Option>, - user_email_confirmed_at: Option>, +/// A [`UserRepository`] helps interacting with [`User`] saved in the storage +/// backend +#[async_trait] +pub trait UserRepository: Send + Sync { + /// The error type returned by the repository + type Error; + + /// Lookup a [`User`] by its ID + /// + /// Returns `None` if no [`User`] was found + /// + /// # Parameters + /// + /// * `id`: The ID of the [`User`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Find a [`User`] by its username + /// + /// Returns `None` if no [`User`] was found + /// + /// # Parameters + /// + /// * `username`: The username of the [`User`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn find_by_username(&mut self, username: &str) -> Result, Self::Error>; + + /// Create a new [`User`] + /// + /// Returns the newly created [`User`] + /// + /// # Parameters + /// + /// * `rng`: A random number generator to generate the [`User`] ID + /// * `clock`: The clock used to generate timestamps + /// * `username`: The username of the [`User`] + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + username: String, + ) -> Result; + + /// Check if a [`User`] exists + /// + /// Returns `true` if the [`User`] exists, `false` otherwise + /// + /// # Parameters + /// + /// * `username`: The username of the [`User`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn exists(&mut self, username: &str) -> Result; } -#[derive(sqlx::FromRow)] -struct SessionLookup { - user_session_id: Uuid, - user_id: Uuid, - username: String, - created_at: DateTime, - last_authentication_id: Option, - last_authd_at: Option>, - user_email_id: Option, - user_email: Option, - user_email_created_at: Option>, - user_email_confirmed_at: Option>, -} - -impl TryInto for SessionLookup { - type Error = DatabaseInconsistencyError; - - fn try_into(self) -> Result { - let id = Ulid::from(self.user_id); - let primary_email = match ( - self.user_email_id, - self.user_email, - self.user_email_created_at, - self.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("users") - .column("primary_user_email_id") - .row(id)) - } - }; - - let user = User { - id, - username: self.username, - sub: id.to_string(), - primary_email, - }; - - let last_authentication = match (self.last_authentication_id, self.last_authd_at) { - (Some(id), Some(created_at)) => Some(Authentication { - id: id.into(), - created_at, - }), - (None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on( - "user_session_authentications", - )) - } - }; - - Ok(BrowserSession { - id: self.user_session_id.into(), - user, - created_at: self.created_at, - last_authentication, - }) - } -} - -#[tracing::instrument( - skip_all, - fields(user_session.id = %id), - err, -)] -pub async fn lookup_active_session( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - SessionLookup, - r#" - SELECT - s.user_session_id, - u.user_id, - u.username, - s.created_at, - a.user_session_authentication_id AS "last_authentication_id?", - a.created_at AS "last_authd_at?", - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" - FROM user_sessions s - INNER JOIN users u - USING (user_id) - LEFT JOIN user_session_authentications a - USING (user_session_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id - WHERE s.user_session_id = $1 AND s.finished_at IS NULL - ORDER BY a.created_at DESC - LIMIT 1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %user.username, - ), - err, -)] -pub async fn get_paginated_user_sessions( - executor: impl PgExecutor<'_>, - user: &User, - before: Option, - after: Option, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), DatabaseError> { - let mut query = QueryBuilder::new( - r#" - SELECT - s.user_session_id, - u.user_id, - u.username, - s.created_at, - a.user_session_authentication_id AS "last_authentication_id", - a.created_at AS "last_authd_at", - ue.user_email_id AS "user_email_id", - ue.email AS "user_email", - ue.created_at AS "user_email_created_at", - ue.confirmed_at AS "user_email_confirmed_at" - FROM user_sessions s - INNER JOIN users u - USING (user_id) - LEFT JOIN user_session_authentications a - USING (user_session_id) - LEFT JOIN user_emails ue - ON ue.user_email_id = u.primary_user_email_id - "#, - ); - - query - .push(" WHERE s.finished_at IS NULL AND s.user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("s.user_session_id", before, after, first, last)?; - - let span = info_span!("Fetch paginated user emails", db.statement = query.sql()); - let page: Vec = query - .build_query_as() - .fetch_all(executor) - .instrument(span) - .await?; - - let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; - - let page: Result, _> = page.into_iter().map(TryInto::try_into).collect(); - Ok((has_previous_page, has_next_page, page?)) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - user_session.id, - ), - err, -)] -pub async fn start_session( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - user: User, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("user_session.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO user_sessions (user_session_id, user_id, created_at) - VALUES ($1, $2, $3) - "#, - Uuid::from(id), - Uuid::from(user.id), - created_at, - ) - .execute(executor) - .await?; - - let session = BrowserSession { - id, - user, - created_at, - last_authentication: None, - }; - - Ok(session) -} - -#[tracing::instrument( - skip_all, - fields(%user.id), - err, -)] -pub async fn count_active_sessions( - executor: impl PgExecutor<'_>, - user: &User, -) -> Result { - let res = sqlx::query_scalar!( - r#" - SELECT COUNT(*) as "count!" - FROM user_sessions s - WHERE s.user_id = $1 AND s.finished_at IS NULL - "#, - Uuid::from(user.id), - ) - .fetch_one(executor) - .await?; - - Ok(res) -} - -#[tracing::instrument( - skip_all, - fields( - user.username = username, - user.id, - ), - err, -)] -pub async fn add_user( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - username: &str, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("user.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO users (user_id, username, created_at) - VALUES ($1, $2, $3) - "#, - Uuid::from(id), - username, - created_at, - ) - .execute(executor) - .await?; - - Ok(User { - id, - username: username.to_owned(), - sub: id.to_string(), - primary_email: None, - }) -} - -#[tracing::instrument( - skip_all, - fields(%user_session.id), - err, -)] -pub async fn end_session( - executor: impl PgExecutor<'_>, - clock: &Clock, - user_session: &BrowserSession, -) -> Result<(), DatabaseError> { - let now = clock.now(); - let res = sqlx::query!( - r#" - UPDATE user_sessions - SET finished_at = $1 - WHERE user_session_id = $2 - "#, - now, - Uuid::from(user_session.id), - ) - .execute(executor) - .instrument(info_span!("End session")) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1) -} - -#[tracing::instrument( - skip_all, - fields(user.username = username), - err, -)] -pub async fn lookup_user_by_username( - executor: impl PgExecutor<'_>, - username: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - UserLookup, - r#" - SELECT - u.user_id, - u.username AS user_username, - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" - FROM users u - - LEFT JOIN user_emails ue - USING (user_id) - - WHERE u.username = $1 - "#, - username, - ) - .fetch_one(executor) - .instrument(info_span!("Fetch user")) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - let id = Ulid::from(res.user_id); - let primary_email = match ( - res.user_email_id, - res.user_email, - res.user_email_created_at, - res.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("users") - .column("primary_user_email_id") - .row(id) - .into()) - } - }; - - Ok(Some(User { - id, - username: res.user_username, - sub: id.to_string(), - primary_email, - })) -} - -#[tracing::instrument( - skip_all, - fields(user.id = %id), - err, -)] -pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result { - let res = sqlx::query_as!( - UserLookup, - r#" - SELECT - u.user_id, - u.username AS user_username, - ue.user_email_id AS "user_email_id?", - ue.email AS "user_email?", - ue.created_at AS "user_email_created_at?", - ue.confirmed_at AS "user_email_confirmed_at?" - FROM users u - - LEFT JOIN user_emails ue - USING (user_id) - - WHERE u.user_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .instrument(info_span!("Fetch user")) - .await?; - - let id = Ulid::from(res.user_id); - let primary_email = match ( - res.user_email_id, - res.user_email, - res.user_email_created_at, - res.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - id: id.into(), - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on("users") - .column("primary_user_email_id") - .row(id) - .into()) - } - }; - - Ok(User { - id, - username: res.user_username, - sub: id.to_string(), - primary_email, - }) -} - -#[tracing::instrument( - skip_all, - fields(user.username = username), - err, -)] -pub async fn username_exists( - executor: impl PgExecutor<'_>, - username: &str, -) -> Result { - sqlx::query_scalar!( - r#" - SELECT EXISTS( - SELECT 1 FROM users WHERE username = $1 - ) AS "exists!" - "#, - username - ) - .fetch_one(executor) - .await -} - -#[derive(Debug, Clone, sqlx::FromRow)] -struct UserEmailLookup { - user_email_id: Uuid, - user_email: String, - user_email_created_at: DateTime, - user_email_confirmed_at: Option>, -} - -impl From for UserEmail { - fn from(e: UserEmailLookup) -> UserEmail { - UserEmail { - id: e.user_email_id.into(), - email: e.user_email, - created_at: e.user_email_created_at, - confirmed_at: e.user_email_confirmed_at, - } - } -} - -#[tracing::instrument( - skip_all, - fields(%user.id, %user.username), - err, -)] -pub async fn get_user_emails( - executor: impl PgExecutor<'_>, - user: &User, -) -> Result, sqlx::Error> { - let res = sqlx::query_as!( - UserEmailLookup, - r#" - SELECT - ue.user_email_id, - ue.email AS "user_email", - ue.created_at AS "user_email_created_at", - ue.confirmed_at AS "user_email_confirmed_at" - FROM user_emails ue - - WHERE ue.user_id = $1 - - ORDER BY ue.email ASC - "#, - Uuid::from(user.id), - ) - .fetch_all(executor) - .instrument(info_span!("Fetch user emails")) - .await?; - - Ok(res.into_iter().map(Into::into).collect()) -} - -#[tracing::instrument( - skip_all, - fields(%user.id, %user.username), - err, -)] -pub async fn count_user_emails( - executor: impl PgExecutor<'_>, - user: &User, -) -> Result { - let res = sqlx::query_scalar!( - r#" - SELECT COUNT(*) - FROM user_emails ue - WHERE ue.user_id = $1 - "#, - Uuid::from(user.id), - ) - .fetch_one(executor) - .instrument(info_span!("Count user emails")) - .await?; - - Ok(res.unwrap_or_default()) -} - -#[tracing::instrument( - skip_all, - fields(%user.id, %user.username), - err, -)] -pub async fn get_paginated_user_emails( - executor: impl PgExecutor<'_>, - user: &User, - before: Option, - after: Option, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), DatabaseError> { - let mut query = QueryBuilder::new( - r#" - SELECT - ue.user_email_id, - ue.email AS "user_email", - ue.created_at AS "user_email_created_at", - ue.confirmed_at AS "user_email_confirmed_at" - FROM user_emails ue - "#, - ); - - query - .push(" WHERE ue.user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("ue.user_email_id", before, after, first, last)?; - - let span = info_span!("Fetch paginated user sessions", db.statement = query.sql()); - let page: Vec = query - .build_query_as() - .fetch_all(executor) - .instrument(span) - .await?; - - let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; - - Ok(( - has_previous_page, - has_next_page, - page.into_iter().map(Into::into).collect(), - )) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %user.username, - user_email.id = %id, - ), - err, -)] -pub async fn get_user_email( - executor: impl PgExecutor<'_>, - user: &User, - id: Ulid, -) -> Result { - let res = sqlx::query_as!( - UserEmailLookup, - r#" - SELECT - ue.user_email_id, - ue.email AS "user_email", - ue.created_at AS "user_email_created_at", - ue.confirmed_at AS "user_email_confirmed_at" - FROM user_emails ue - - WHERE ue.user_id = $1 - AND ue.user_email_id = $2 - "#, - Uuid::from(user.id), - Uuid::from(id), - ) - .fetch_one(executor) - .instrument(info_span!("Fetch user emails")) - .await?; - - Ok(res.into()) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %user.username, - user_email.id, - user_email.email = %email, - ), - err, -)] -pub async fn add_user_email( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - user: &User, - email: String, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("user_email.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO user_emails (user_email_id, user_id, email, created_at) - VALUES ($1, $2, $3, $4) - "#, - Uuid::from(id), - Uuid::from(user.id), - &email, - created_at, - ) - .execute(executor) - .instrument(info_span!("Add user email")) - .await?; - - Ok(UserEmail { - id, - email, - created_at, - confirmed_at: None, - }) -} - -#[tracing::instrument( - skip_all, - fields( - %user_email.id, - %user_email.email, - ), - err, -)] -pub async fn set_user_email_as_primary( - executor: impl PgExecutor<'_>, - user_email: &UserEmail, -) -> Result<(), sqlx::Error> { - sqlx::query!( - r#" - UPDATE users - SET primary_user_email_id = user_emails.user_email_id - FROM user_emails - WHERE user_emails.user_email_id = $1 - AND users.user_id = user_emails.user_id - "#, - Uuid::from(user_email.id), - ) - .execute(executor) - .instrument(info_span!("Add user email")) - .await?; - - Ok(()) -} - -#[tracing::instrument( - skip_all, - fields( - %user_email.id, - %user_email.email, - ), - err, -)] -pub async fn remove_user_email( - executor: impl PgExecutor<'_>, - user_email: UserEmail, -) -> Result<(), sqlx::Error> { - sqlx::query!( - r#" - DELETE FROM user_emails - WHERE user_emails.user_email_id = $1 - "#, - Uuid::from(user_email.id), - ) - .execute(executor) - .instrument(info_span!("Remove user email")) - .await?; - - Ok(()) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - user_email.email = email, - ), - err, -)] -pub async fn lookup_user_email( - executor: impl PgExecutor<'_>, - user: &User, - email: &str, -) -> Result, sqlx::Error> { - let res = sqlx::query_as!( - UserEmailLookup, - r#" - SELECT - ue.user_email_id, - ue.email AS "user_email", - ue.created_at AS "user_email_created_at", - ue.confirmed_at AS "user_email_confirmed_at" - FROM user_emails ue - - WHERE ue.user_id = $1 - AND ue.email = $2 - "#, - Uuid::from(user.id), - email, - ) - .fetch_one(executor) - .instrument(info_span!("Lookup user email")) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - user_email.id = %id, - ), - err, -)] -pub async fn lookup_user_email_by_id( - executor: impl PgExecutor<'_>, - user: &User, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - UserEmailLookup, - r#" - SELECT - ue.user_email_id, - ue.email AS "user_email", - ue.created_at AS "user_email_created_at", - ue.confirmed_at AS "user_email_confirmed_at" - FROM user_emails ue - - WHERE ue.user_id = $1 - AND ue.user_email_id = $2 - "#, - Uuid::from(user.id), - Uuid::from(id), - ) - .fetch_one(executor) - .instrument(info_span!("Lookup user email")) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) -} - -#[tracing::instrument( - skip_all, - fields(%user_email.id), - err, -)] -pub async fn mark_user_email_as_verified( - executor: impl PgExecutor<'_>, - clock: &Clock, - mut user_email: UserEmail, -) -> Result { - let confirmed_at = clock.now(); - sqlx::query!( - r#" - UPDATE user_emails - SET confirmed_at = $2 - WHERE user_email_id = $1 - "#, - Uuid::from(user_email.id), - confirmed_at, - ) - .execute(executor) - .instrument(info_span!("Confirm user email")) - .await?; - - user_email.confirmed_at = Some(confirmed_at); - - Ok(user_email) -} - -struct UserEmailConfirmationCodeLookup { - user_email_confirmation_code_id: Uuid, - code: String, - created_at: DateTime, - expires_at: DateTime, - consumed_at: Option>, -} - -#[tracing::instrument( - skip_all, - fields(%user_email.id), - err, -)] -pub async fn lookup_user_email_verification_code( - executor: impl PgExecutor<'_>, - clock: &Clock, - user_email: UserEmail, - code: &str, -) -> Result, DatabaseError> { - let now = clock.now(); - - let res = sqlx::query_as!( - UserEmailConfirmationCodeLookup, - r#" - SELECT - ec.user_email_confirmation_code_id, - ec.code, - ec.created_at, - ec.expires_at, - ec.consumed_at - FROM user_email_confirmation_codes ec - WHERE ec.code = $1 - AND ec.user_email_id = $2 - "#, - code, - Uuid::from(user_email.id), - ) - .fetch_one(executor) - .instrument(info_span!("Lookup user email verification")) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - let state = if let Some(when) = res.consumed_at { - UserEmailVerificationState::AlreadyUsed { when } - } else if res.expires_at < now { - UserEmailVerificationState::Expired { - when: res.expires_at, - } - } else { - UserEmailVerificationState::Valid - }; - - Ok(Some(UserEmailVerification { - id: res.user_email_confirmation_code_id.into(), - code: res.code, - email: user_email, - state, - created_at: res.created_at, - })) -} - -#[tracing::instrument( - skip_all, - fields( - %user_email_verification.id, - ), - err, -)] -pub async fn consume_email_verification( - executor: impl PgExecutor<'_>, - clock: &Clock, - mut user_email_verification: UserEmailVerification, -) -> Result { - if !matches!( - user_email_verification.state, - UserEmailVerificationState::Valid - ) { - return Err(DatabaseError::invalid_operation()); - } - - let consumed_at = clock.now(); - - sqlx::query!( - r#" - UPDATE user_email_confirmation_codes - SET consumed_at = $2 - WHERE user_email_confirmation_code_id = $1 - "#, - Uuid::from(user_email_verification.id), - consumed_at - ) - .execute(executor) - .instrument(info_span!("Consume user email verification")) - .await?; - - user_email_verification.state = UserEmailVerificationState::AlreadyUsed { when: consumed_at }; - - Ok(user_email_verification) -} - -#[tracing::instrument( - skip_all, - fields( - %user_email.id, - %user_email.email, - user_email_confirmation.id, - user_email_confirmation.code = code, - ), - err, -)] -pub async fn add_user_email_verification_code( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - user_email: UserEmail, - max_age: chrono::Duration, - code: String, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id)); - let expires_at = created_at + max_age; - - sqlx::query!( - r#" - INSERT INTO user_email_confirmation_codes - (user_email_confirmation_code_id, user_email_id, code, created_at, expires_at) - VALUES ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(user_email.id), - code, - created_at, - expires_at, - ) - .execute(executor) - .instrument(info_span!("Add user email verification code")) - .await?; - - let verification = UserEmailVerification { - id, - email: user_email, - code, - created_at, - state: UserEmailVerificationState::Valid, - }; - - Ok(verification) -} +repository_impl!(UserRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + async fn find_by_username(&mut self, username: &str) -> Result, Self::Error>; + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + username: String, + ) -> Result; + async fn exists(&mut self, username: &str) -> Result; +); diff --git a/crates/storage/src/user/password.rs b/crates/storage/src/user/password.rs index 14ac52226..7ef5c7ad8 100644 --- a/crates/storage/src/user/password.rs +++ b/crates/storage/src/user/password.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,124 +12,68 @@ // See the License for the specific language governing permissions and // limitations under the License. -use chrono::{DateTime, Utc}; +use async_trait::async_trait; use mas_data_model::{Password, User}; -use rand::Rng; -use sqlx::PgExecutor; -use ulid::Ulid; -use uuid::Uuid; +use rand_core::RngCore; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; +use crate::{repository_impl, Clock}; -#[tracing::instrument( - skip_all, - fields( - %user.id, - %user.username, - user_password.id, - user_password.version = version, - ), - err, -)] -pub async fn add_user_password( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - user: &User, - version: u16, - hashed_password: String, - upgraded_from: Option, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("user_password.id", tracing::field::display(id)); +/// A [`UserPasswordRepository`] helps interacting with [`Password`] saved in +/// the storage backend +#[async_trait] +pub trait UserPasswordRepository: Send + Sync { + /// The error type returned by the repository + type Error; - let upgraded_from_id = upgraded_from.map(|p| p.id); + /// Get the active password for a user + /// + /// Returns `None` if the user has no password set + /// + /// # Parameters + /// + /// * `user`: The user to get the password for + /// + /// # Errors + /// + /// Returns [`Self::Error`] if underlying repository fails + async fn active(&mut self, user: &User) -> Result, Self::Error>; - sqlx::query!( - r#" - INSERT INTO user_passwords - (user_password_id, user_id, hashed_password, version, upgraded_from_id, created_at) - VALUES ($1, $2, $3, $4, $5, $6) - "#, - Uuid::from(id), - Uuid::from(user.id), - hashed_password, - i32::from(version), - upgraded_from_id.map(Uuid::from), - created_at, - ) - .execute(executor) - .await?; - - Ok(Password { - id, - hashed_password, - version, - upgraded_from_id, - created_at, - }) + /// Set a new password for a user + /// + /// Returns the newly created [`Password`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `user`: The user to set the password for + /// * `version`: The version of the hashing scheme used + /// * `hashed_password`: The hashed password + /// * `upgraded_from`: The password this password was upgraded from, if any + /// + /// # Errors + /// + /// Returns [`Self::Error`] if underlying repository fails + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + version: u16, + hashed_password: String, + upgraded_from: Option<&Password>, + ) -> Result; } -struct UserPasswordLookup { - user_password_id: Uuid, - hashed_password: String, - version: i32, - upgraded_from_id: Option, - created_at: DateTime, -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %user.username, - ), - err, -)] -pub async fn lookup_user_password( - executor: impl PgExecutor<'_>, - user: &User, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - UserPasswordLookup, - r#" - SELECT up.user_password_id - , up.hashed_password - , up.version - , up.upgraded_from_id - , up.created_at - FROM user_passwords up - WHERE up.user_id = $1 - ORDER BY up.created_at DESC - LIMIT 1 - "#, - Uuid::from(user.id), - ) - .fetch_one(executor) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - let id = Ulid::from(res.user_password_id); - - let version = res.version.try_into().map_err(|e| { - DatabaseInconsistencyError::on("user_passwords") - .column("version") - .row(id) - .source(e) - })?; - - let upgraded_from_id = res.upgraded_from_id.map(Ulid::from); - let created_at = res.created_at; - let hashed_password = res.hashed_password; - - Ok(Some(Password { - id, - hashed_password, - version, - upgraded_from_id, - created_at, - })) -} +repository_impl!(UserPasswordRepository: + async fn active(&mut self, user: &User) -> Result, Self::Error>; + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + version: u16, + hashed_password: String, + upgraded_from: Option<&Password>, + ) -> Result; +); diff --git a/crates/storage/src/user/session.rs b/crates/storage/src/user/session.rs new file mode 100644 index 000000000..5e9defbec --- /dev/null +++ b/crates/storage/src/user/session.rs @@ -0,0 +1,188 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use mas_data_model::{BrowserSession, Password, UpstreamOAuthLink, User}; +use rand_core::RngCore; +use ulid::Ulid; + +use crate::{pagination::Page, repository_impl, Clock, Pagination}; + +/// A [`BrowserSessionRepository`] helps interacting with [`BrowserSession`] +/// saved in the storage backend +#[async_trait] +pub trait BrowserSessionRepository: Send + Sync { + /// The error type returned by the repository + type Error; + + /// Lookup a [`BrowserSession`] by its ID + /// + /// Returns `None` if the session is not found + /// + /// # Parameters + /// + /// * `id`: The ID of the session to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Create a new [`BrowserSession`] for a [`User`] + /// + /// Returns the newly created [`BrowserSession`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `user`: The user to create the session for + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + ) -> Result; + + /// Finish a [`BrowserSession`] + /// + /// Returns the finished session + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `user_session`: The session to finish + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn finish( + &mut self, + clock: &dyn Clock, + user_session: BrowserSession, + ) -> Result; + + /// List active [`BrowserSession`] for a [`User`] with the given pagination + /// + /// # Parameters + /// + /// * `user`: The user to list the sessions for + /// * `pagination`: The pagination parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn list_active_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; + + /// Count active [`BrowserSession`] for a [`User`] + /// + /// # Parameters + /// + /// * `user`: The user to count the sessions for + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn count_active(&mut self, user: &User) -> Result; + + /// Authenticate a [`BrowserSession`] with the given [`Password`] + /// + /// Returns the updated [`BrowserSession`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `user_session`: The session to authenticate + /// * `user_password`: The password which was used to authenticate + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn authenticate_with_password( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user_session: BrowserSession, + user_password: &Password, + ) -> Result; + + /// Authenticate a [`BrowserSession`] with the given [`UpstreamOAuthLink`] + /// + /// Returns the updated [`BrowserSession`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock used to generate timestamps + /// * `user_session`: The session to authenticate + /// * `upstream_oauth_link`: The upstream OAuth link which was used to + /// authenticate + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn authenticate_with_upstream( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user_session: BrowserSession, + upstream_oauth_link: &UpstreamOAuthLink, + ) -> Result; +} + +repository_impl!(BrowserSessionRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + ) -> Result; + async fn finish( + &mut self, + clock: &dyn Clock, + user_session: BrowserSession, + ) -> Result; + async fn list_active_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; + async fn count_active(&mut self, user: &User) -> Result; + + async fn authenticate_with_password( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user_session: BrowserSession, + user_password: &Password, + ) -> Result; + + async fn authenticate_with_upstream( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user_session: BrowserSession, + upstream_oauth_link: &UpstreamOAuthLink, + ) -> Result; +); diff --git a/crates/storage/src/utils.rs b/crates/storage/src/utils.rs new file mode 100644 index 000000000..44caa23ea --- /dev/null +++ b/crates/storage/src/utils.rs @@ -0,0 +1,86 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Wrappers and useful type aliases + +use rand_core::CryptoRngCore; + +use crate::Clock; + +/// A wrapper which is used to map the error type of a repository to another +pub struct MapErr { + pub(crate) inner: R, + pub(crate) mapper: F, + _private: (), +} + +impl MapErr { + pub(crate) fn new(inner: R, mapper: F) -> Self { + Self { + inner, + mapper, + _private: (), + } + } +} + +/// A boxed [`Clock`] +pub type BoxClock = Box; + +/// A boxed random number generator +pub type BoxRng = Box; + +/// A macro to implement a repository trait for the [`MapErr`] wrapper and for +/// [`Box`] +#[macro_export] +macro_rules! repository_impl { + ($repo_trait:ident: + $( + async fn $method:ident ( + &mut self + $(, $arg:ident: $arg_ty:ty )* + $(,)? + ) -> Result<$ret_ty:ty, Self::Error>; + )* + ) => { + #[::async_trait::async_trait] + impl $repo_trait for ::std::boxed::Box + where + R: $repo_trait, + { + type Error = ::Error; + + $( + async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> { + (**self).$method ( $($arg),* ).await + } + )* + } + + #[::async_trait::async_trait] + impl $repo_trait for $crate::MapErr + where + R: $repo_trait, + F: FnMut(::Error) -> E + ::std::marker::Send + ::std::marker::Sync, + { + type Error = E; + + $( + async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> { + self.inner.$method ( $($arg),* ).await.map_err(&mut self.mapper) + } + )* + } + }; +} diff --git a/crates/tasks/Cargo.toml b/crates/tasks/Cargo.toml index 502a5ea37..9abf36f0a 100644 --- a/crates/tasks/Cargo.toml +++ b/crates/tasks/Cargo.toml @@ -14,3 +14,4 @@ tracing = "0.1.37" sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] } mas-storage = { path = "../storage" } +mas-storage-pg = { path = "../storage-pg" } diff --git a/crates/tasks/src/database.rs b/crates/tasks/src/database.rs index 5e72141e5..ebade53af 100644 --- a/crates/tasks/src/database.rs +++ b/crates/tasks/src/database.rs @@ -14,14 +14,15 @@ //! Database-related tasks -use mas_storage::Clock; +use mas_storage::{oauth2::OAuth2AccessTokenRepository, RepositoryAccess, SystemClock}; +use mas_storage_pg::PgRepository; use sqlx::{Pool, Postgres}; use tracing::{debug, error, info}; use super::Task; #[derive(Clone)] -struct CleanupExpired(Pool, Clock); +struct CleanupExpired(Pool, SystemClock); impl std::fmt::Debug for CleanupExpired { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -32,7 +33,13 @@ impl std::fmt::Debug for CleanupExpired { #[async_trait::async_trait] impl Task for CleanupExpired { async fn run(&self) { - let res = mas_storage::oauth2::access_token::cleanup_expired(&self.0, &self.1).await; + let res = async move { + let mut repo = PgRepository::from_pool(&self.0).await?; + let res = repo.oauth2_access_token().cleanup_expired(&self.1).await; + res + } + .await; + match res { Ok(0) => { debug!("no token to clean up"); @@ -51,5 +58,5 @@ impl Task for CleanupExpired { #[must_use] pub fn cleanup_expired(pool: &Pool) -> impl Task + Clone { // XXX: the clock should come from somewhere else - CleanupExpired(pool.clone(), Clock::default()) + CleanupExpired(pool.clone(), SystemClock::default()) } diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index 642a760bc..cd70289b6 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -532,14 +532,14 @@ where { /// Context used by the `account/index.html` template #[derive(Serialize)] pub struct AccountContext { - active_sessions: i64, + active_sessions: usize, emails: Vec, } impl AccountContext { /// Constructs a context for the "my account" page #[must_use] - pub fn new(active_sessions: i64, emails: Vec) -> Self { + pub fn new(active_sessions: usize, emails: Vec) -> Self { Self { active_sessions, emails, @@ -618,6 +618,7 @@ impl TemplateContext for EmailVerificationContext { .map(|user| { let email = UserEmail { id: Ulid::from_datetime_with_source(now.into(), rng), + user_id: user.id, email: "foobar@example.com".to_owned(), created_at: now, confirmed_at: None, @@ -625,8 +626,8 @@ impl TemplateContext for EmailVerificationContext { let verification = UserEmailVerification { id: Ulid::from_datetime_with_source(now.into(), rng), + user_email_id: email.id, code: "123456".to_owned(), - email, created_at: now, state: mas_data_model::UserEmailVerificationState::Valid, }; @@ -684,6 +685,7 @@ impl TemplateContext for EmailVerificationPageContext { { let email = UserEmail { id: Ulid::from_datetime_with_source(now.into(), rng), + user_id: Ulid::from_datetime_with_source(now.into(), rng), email: "foobar@example.com".to_owned(), created_at: now, confirmed_at: None, diff --git a/crates/templates/src/lib.rs b/crates/templates/src/lib.rs index 496bcc26a..938a701a4 100644 --- a/crates/templates/src/lib.rs +++ b/crates/templates/src/lib.rs @@ -96,6 +96,12 @@ impl Templates { } /// Load the templates from the given config + #[tracing::instrument( + name = "templates.load", + skip_all, + fields(%path), + err, + )] pub async fn load( path: Utf8PathBuf, url_builder: UrlBuilder, @@ -110,14 +116,17 @@ impl Templates { async fn load_(path: &Utf8Path, url_builder: UrlBuilder) -> Result { let path = path.to_owned(); + let span = tracing::Span::current(); // This uses blocking I/Os, do that in a blocking task let mut tera = tokio::task::spawn_blocking(move || { - let path = path.canonicalize_utf8()?; - let path = format!("{path}/**/*.{{html,txt,subject}}"); + span.in_scope(move || { + let path = path.canonicalize_utf8()?; + let path = format!("{path}/**/*.{{html,txt,subject}}"); - info!(%path, "Loading templates from filesystem"); - Tera::new(&path) + info!(%path, "Loading templates from filesystem"); + Tera::new(&path) + }) }) .await??; @@ -138,7 +147,13 @@ impl Templates { } /// Reload the templates on disk - pub async fn reload(&self) -> anyhow::Result<()> { + #[tracing::instrument( + name = "templates.reload", + skip_all, + fields(path = %self.path), + err, + )] + pub async fn reload(&self) -> Result<(), TemplateLoadingError> { // Prepare the new Tera instance let new_tera = Self::load_(&self.path, self.url_builder.clone()).await?; diff --git a/docs/development/architecture.md b/docs/development/architecture.md index f15b31203..691c881f4 100644 --- a/docs/development/architecture.md +++ b/docs/development/architecture.md @@ -10,17 +10,31 @@ The whole repository is a [Cargo Workspace](https://doc.rust-lang.org/book/ch14- This includes: - `mas-cli`: Command line utility, main entry point - - `mas-config`: Configuration parsing and loading - - `mas-data-model`: Models of objects that live in the database, regardless of the storage backend - - `mas-email`: High-level email sending abstraction - - `mas-handlers`: Main HTTP application logic - - `mas-iana`: Auto-generated enums from IANA registries - - `mas-iana-codegen`: Code generator for the `mas-iana` crate - - `mas-jose`: JWT/JWS/JWE/JWK abstraction - - `mas-static-files`: Frontend static files (CSS/JS). Includes some frontend tooling - - `mas-storage`: Interactions with the database - - `mas-tasks`: Asynchronous task runner and scheduler - - `oauth2-types`: Useful structures and types to deal with OAuth 2.0/OpenID Connect endpoints. This might end up published as a standalone library as it can be useful in other contexts. + - [`mas-config`][mas-config]: Configuration parsing and loading + - [`mas-data-model`][mas-data-model]: Models of objects that live in the database, regardless of the storage backend + - [`mas-email`][mas-email]: High-level email sending abstraction + - [`mas-handlers`][mas-handlers]: Main HTTP application logic + - [`mas-iana`][mas-iana]: Auto-generated enums from IANA registries + - [`mas-iana-codegen`][mas-iana-codegen]: Code generator for the `mas-iana` crate + - [`mas-jose`][mas-jose]: JWT/JWS/JWE/JWK abstraction + - [`mas-static-files`][mas-static-files]: Frontend static files (CSS/JS). Includes some frontend tooling + - [`mas-storage`][mas-storage]: Abstraction of the storage backends + - [`mas-storage-pg`][mas-storage-pg]: Storage backend implementation for a PostgreSQL database + - [`mas-tasks`][mas-tasks]: Asynchronous task runner and scheduler + - [`oauth2-types`][oauth2-types]: Useful structures and types to deal with OAuth 2.0/OpenID Connect endpoints. This might end up published as a standalone library as it can be useful in other contexts. + +[mas-config]: ../rustdoc/mas_config/index.html +[mas-data-model]: ../rustdoc/mas_data_model/index.html +[mas-email]: ../rustdoc/mas_email/index.html +[mas-handlers]: ../rustdoc/mas_handlers/index.html +[mas-iana]: ../rustdoc/mas_iana/index.html +[mas-iana-codegen]: ../rustdoc/mas_iana_codegen/index.html +[mas-jose]: ../rustdoc/mas_jose/index.html +[mas-static-files]: ../rustdoc/mas_static_files/index.html +[mas-storage]: ../rustdoc/mas_storage/index.html +[mas-storage-pg]: ../rustdoc/mas_storage/index.html +[mas-tasks]: ../rustdoc/mas_tasks/index.html +[oauth2-types]: ../rustdoc/oauth2_types/index.html ## Important crates diff --git a/docs/development/database.md b/docs/development/database.md index e9583124a..5ffe8a5a8 100644 --- a/docs/development/database.md +++ b/docs/development/database.md @@ -3,6 +3,25 @@ Interactions with the database goes through `sqlx`. It provides async database operations with connection pooling, migrations support and compile-time check of queries through macros. +## Writing database interactions + +All database interactions are done through repositoriy traits. Each repository trait usually manages one type of data, defined in the [`mas-data-model`][mas-data-model] crate. + +Defining a new data type and associated repository looks like this: + + - Define new structs in [`mas-data-model`][mas-data-model] crate + - Define the repository trait in [`mas-storage`][mas-storage] crate + - Make that repository trait available via the `RepositoryAccess` trait in [`mas-storage`][mas-storage] crate + - Setup the database schema by writing a migration file in [`mas-storage-pg`][mas-storage-pg] crate + - Implement the new repository trait in [`mas-storage-pg`][mas-storage-pg] crate + - Write tests for the PostgreSQL implementation in [`mas-storage-pg`][mas-storage-pg] crate + +Some of those steps are documented in more details in the [`mas-storage`][mas-storage] and [`mas-storage-pg`][mas-storage-pg] crates. + +[mas-data-model]: ../rustdoc/mas_data_model/index.html +[mas-storage]: ../rustdoc/mas_storage/index.html +[mas-storage-pg]: ../rustdoc/mas_storage_pg/index.html + ## Compile-time check of queries To be able to check queries, `sqlx` has to introspect the live database. @@ -14,7 +33,7 @@ Preparing this flat file is done through `sqlx-cli`, and should be done everytim # Install the CLI cargo install sqlx-cli --no-default-features --features postgres -cd crates/storage/ # Must be in the mas-storage crate folder +cd crates/storage-pg/ # Must be in the mas-storage-pg crate folder export DATABASE_URL=postgresql:///matrix_auth cargo sqlx prepare ``` @@ -24,73 +43,10 @@ cargo sqlx prepare Migration files live in the `migrations` folder in the `mas-core` crate. ```sh -cd crates/storage/ # Again, in the mas-storage crate folder +cd crates/storage-pg/ # Again, in the mas-storage-pg crate folder export DATABASE_URL=postgresql:///matrix_auth cargo sqlx migrate run # Run pending migrations -cargo sqlx migrate revert # Revert the last migration -cargo sqlx migrate add -r [description] # Add new migration files +cargo sqlx migrate add [description] # Add new migration files ``` Note that migrations are embedded in the final binary and can be run from the service CLI tool. - -## Writing database interactions - -A typical interaction with the database look like this: - -```rust -pub async fn lookup_session( - executor: impl Executor<'_, Database = Postgres>, - id: i64, -) -> anyhow::Result { - sqlx::query_as!( - SessionInfo, // Struct that will be filled with the result - r#" - SELECT - s.id, - u.id as user_id, - u.username, - s.active, - s.created_at, - a.created_at as "last_authd_at?" - FROM user_sessions s - INNER JOIN users u - ON s.user_id = u.id - LEFT JOIN user_session_authentications a - ON a.session_id = s.id - WHERE s.id = $1 - ORDER BY a.created_at DESC - LIMIT 1 - "#, - id, // Query parameter - ) - .fetch_one(executor) - .await - // Providing some context when there is an error - .context("could not fetch session") -} -``` - -Note that we pass an `impl Executor` as parameter here. -This allows us to use this function from either a simple connection or from an active transaction. - -The caveat here is that the `executor` can be used only once, so if an interaction needs to do multiple queries, it should probably take an `impl Acquire` to then acquire a transaction and do multiple interactions. - -```rust -pub async fn login( - conn: impl Acquire<'_, Database = Postgres>, - username: &str, - password: String, -) -> Result { - let mut txn = conn.begin().await.context("could not start transaction")?; - // First interaction - let user = lookup_user_by_username(&mut txn, username)?; - // Second interaction - let mut session = start_session(&mut txn, user).await?; - // Third interaction - session.last_authd_at = - Some(authenticate_session(&mut txn, session.id, password).await?); - // Commit the transaction once everything went fine - txn.commit().await.context("could not commit transaction")?; - Ok(session) -} -```