Merge pull request #733 from matrix-org/quenting/storage-repository

Repository pattern
This commit is contained in:
Quentin Gliech
2023-01-27 11:59:08 +01:00
committed by GitHub
146 changed files with 15675 additions and 9358 deletions

View File

@@ -1,30 +1,69 @@
name: Deploy the documentation
name: Build and deploy the documentation
on:
push:
branches:
- main
branches: [ main ]
pull_request:
branches: [ main ]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
env:
CARGO_TERM_COLOR: always
CARGO_NET_GIT_FETCH_WITH_CLI: "true"
jobs:
pages:
name: GitHub Pages
build:
name: Build the documentation
runs-on: ubuntu-latest
steps:
- name: Checkout the code
uses: actions/checkout@v3
- name: Install Rust toolchain
run: |
rustup toolchain install nightly
rustup default nightly
- name: Setup Rust cache
uses: Swatinem/rust-cache@v2
- name: Setup mdBook
uses: peaceiris/actions-mdbook@adeb05db28a0c0004681db83893d56c0388ea9ea # v1.1.14
uses: peaceiris/actions-mdbook@v1.2.0
with:
mdbook-version: '0.4.12'
mdbook-version: '0.4.25'
- name: Build the documentation
run: mdbook build
- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@64b46b4226a4a12da2239ba3ea5aa73e3163c75b # v3.8.0
- name: Build rustdoc
run: cargo doc -Zrustdoc-map --workspace --lib --no-deps
- name: Move the Rust documentation within the mdBook
run: mv target/doc target/book/rustdoc
- name: Upload GitHub Pages artifacts
uses: actions/upload-pages-artifact@v1.0.7
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: ./target/book
path: target/book/
deploy:
name: Deploy the documentation on GitHub Pages
runs-on: ubuntu-latest
needs: build
if: github.ref == 'refs/heads/main'
permissions:
pages: write
id-token: write
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v1.2.3

30
Cargo.lock generated
View File

@@ -2688,7 +2688,6 @@ dependencies = [
"serde_json",
"serde_urlencoded",
"serde_with",
"sqlx",
"thiserror",
"tokio",
"tower",
@@ -2721,6 +2720,7 @@ dependencies = [
"mas-router",
"mas-spa",
"mas-storage",
"mas-storage-pg",
"mas-tasks",
"mas-templates",
"oauth2-types",
@@ -2821,8 +2821,8 @@ dependencies = [
"mas-storage",
"oauth2-types",
"serde",
"sqlx",
"thiserror",
"tokio",
"tracing",
"ulid",
"url",
@@ -2859,6 +2859,7 @@ dependencies = [
"mas-policy",
"mas-router",
"mas-storage",
"mas-storage-pg",
"mas-templates",
"mime",
"oauth2-types",
@@ -3112,12 +3113,33 @@ dependencies = [
name = "mas-storage"
version = "0.1.0"
dependencies = [
"async-trait",
"chrono",
"futures-util",
"mas-data-model",
"mas-iana",
"mas-jose",
"oauth2-types",
"rand_core 0.6.4",
"thiserror",
"ulid",
"url",
]
[[package]]
name = "mas-storage-pg"
version = "0.1.0"
dependencies = [
"async-trait",
"chrono",
"futures-util",
"mas-data-model",
"mas-iana",
"mas-jose",
"mas-storage",
"oauth2-types",
"rand 0.8.5",
"rand_chacha 0.3.1",
"serde",
"serde_json",
"sqlx",
@@ -3135,6 +3157,7 @@ dependencies = [
"async-trait",
"futures-util",
"mas-storage",
"mas-storage-pg",
"sqlx",
"tokio",
"tokio-stream",
@@ -5590,8 +5613,7 @@ checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81"
[[package]]
name = "ulid"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13a3aaa69b04e5b66cc27309710a569ea23593612387d67daaf102e73aa974fd"
source = "git+https://github.com/dylanhart/ulid-rs.git?rev=0b9295c2db2114cd87aa19abcc1fc00c16b272db#0b9295c2db2114cd87aa19abcc1fc00c16b272db"
dependencies = [
"rand 0.8.5",
"serde",

View File

@@ -7,3 +7,8 @@ opt-level = 3
[profile.dev.package.sqlx-macros]
opt-level = 3
# Until https://github.com/dylanhart/ulid-rs/pull/56 gets released
[patch.crates-io.ulid]
git = "https://github.com/dylanhart/ulid-rs.git"
rev = "0b9295c2db2114cd87aa19abcc1fc00c16b272db"

View File

@@ -1,4 +1,4 @@
doc-valid-idents = ["OpenID", "OAuth", ".."]
doc-valid-idents = ["OpenID", "OAuth", "..", "PostgreSQL"]
disallowed-methods = [
{ path = "rand::thread_rng", reason = "do not create rngs on the fly, pass them as parameters" },

View File

@@ -21,7 +21,6 @@ serde = "1.0.152"
serde_with = "2.1.0"
serde_urlencoded = "0.7.1"
serde_json = "1.0.91"
sqlx = "0.6.2"
thiserror = "1.0.38"
tokio = "1.24.1"
tower = { version = "0.4.13", features = ["util"] }

View File

@@ -31,10 +31,9 @@ use mas_http::HttpServiceExt;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
use mas_keystore::Encrypter;
use mas_storage::{oauth2::client::lookup_client_by_client_id, DatabaseError};
use mas_storage::{oauth2::OAuth2ClientRepository, RepositoryAccess};
use serde::{de::DeserializeOwned, Deserialize};
use serde_json::Value;
use sqlx::PgExecutor;
use thiserror::Error;
use tower::{Service, ServiceExt};
@@ -73,10 +72,10 @@ pub enum Credentials {
}
impl Credentials {
pub async fn fetch(
pub async fn fetch<E>(
&self,
executor: impl PgExecutor<'_>,
) -> Result<Option<Client>, DatabaseError> {
repo: &mut impl RepositoryAccess<Error = E>,
) -> Result<Option<Client>, E> {
let client_id = match self {
Credentials::None { client_id }
| Credentials::ClientSecretBasic { client_id, .. }
@@ -84,7 +83,7 @@ impl Credentials {
| Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
};
lookup_client_by_client_id(executor, client_id).await
repo.oauth2_client().find_by_client_id(client_id).await
}
#[tracing::instrument(skip_all, err)]

View File

@@ -15,6 +15,7 @@
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
use chrono::{DateTime, Duration, Utc};
use data_encoding::{DecodeError, BASE64URL_NOPAD};
use mas_storage::Clock;
use rand::{Rng, RngCore};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, TimestampSeconds};
@@ -108,22 +109,27 @@ pub struct ProtectedForm<T> {
}
pub trait CsrfExt {
fn csrf_token<R>(self, now: DateTime<Utc>, rng: R) -> (CsrfToken, Self)
fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
where
R: RngCore;
fn verify_form<T>(&self, now: DateTime<Utc>, form: ProtectedForm<T>) -> Result<T, CsrfError>;
R: RngCore,
C: Clock;
fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
where
C: Clock;
}
impl<K> CsrfExt for PrivateCookieJar<K> {
fn csrf_token<R>(self, now: DateTime<Utc>, rng: R) -> (CsrfToken, Self)
fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
where
R: RngCore,
C: Clock,
{
let jar = self;
let mut cookie = jar.get("csrf").unwrap_or_else(|| Cookie::new("csrf", ""));
cookie.set_path("/");
cookie.set_http_only(true);
let now = clock.now();
let new_token = cookie
.decode()
.ok()
@@ -136,10 +142,13 @@ impl<K> CsrfExt for PrivateCookieJar<K> {
(new_token, jar)
}
fn verify_form<T>(&self, now: DateTime<Utc>, form: ProtectedForm<T>) -> Result<T, CsrfError> {
fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
where
C: Clock,
{
let cookie = self.get("csrf").ok_or(CsrfError::Missing)?;
let token: CsrfToken = cookie.decode()?;
let token = token.verify_expiration(now)?;
let token = token.verify_expiration(clock.now())?;
token.verify_form_value(&form.csrf)?;
Ok(form.inner)
}

View File

@@ -56,7 +56,7 @@ impl HttpClientFactory {
Ok(layer.layer(client))
}
/// Constructs a new [`HttpService`], suitable for [`mas_oidc_client`]
/// Constructs a new [`HttpService`], suitable for `mas-oidc-client`
///
/// # Errors
///

View File

@@ -14,9 +14,8 @@
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
use mas_data_model::BrowserSession;
use mas_storage::{user::lookup_active_session, DatabaseError};
use mas_storage::{user::BrowserSessionRepository, RepositoryAccess};
use serde::{Deserialize, Serialize};
use sqlx::{Executor, Postgres};
use ulid::Ulid;
use crate::CookieExt;
@@ -44,18 +43,24 @@ impl SessionInfo {
}
/// Load the [`BrowserSession`] from database
pub async fn load_session(
pub async fn load_session<E>(
&self,
executor: impl Executor<'_, Database = Postgres>,
) -> Result<Option<BrowserSession>, DatabaseError> {
repo: &mut impl RepositoryAccess<Error = E>,
) -> Result<Option<BrowserSession>, E> {
let session_id = if let Some(id) = self.current {
id
} else {
return Ok(None);
};
let res = lookup_active_session(executor, session_id).await?;
Ok(res)
let maybe_session = repo
.browser_session()
.lookup(session_id)
.await?
// Ensure that the session is still active
.filter(BrowserSession::active);
Ok(maybe_session)
}
}

View File

@@ -27,9 +27,11 @@ use axum::{
use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName};
use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode};
use mas_data_model::Session;
use mas_storage::{oauth2::access_token::lookup_active_access_token, DatabaseError};
use mas_storage::{
oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository},
Clock, RepositoryAccess,
};
use serde::{de::DeserializeOwned, Deserialize};
use sqlx::PgConnection;
use thiserror::Error;
#[derive(Debug, Deserialize)]
@@ -49,16 +51,24 @@ enum AccessToken {
}
impl AccessToken {
pub async fn fetch(
async fn fetch<E>(
&self,
conn: &mut PgConnection,
) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> {
repo: &mut impl RepositoryAccess<Error = E>,
) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError<E>> {
let token = match self {
AccessToken::Form(t) | AccessToken::Header(t) => t,
AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
};
let (token, session) = lookup_active_access_token(conn, token.as_str())
let token = repo
.oauth2_access_token()
.find_by_token(token.as_str())
.await?
.ok_or(AuthorizationVerificationError::InvalidToken)?;
let session = repo
.oauth2_session()
.lookup(token.session_id)
.await?
.ok_or(AuthorizationVerificationError::InvalidToken)?;
@@ -74,26 +84,36 @@ pub struct UserAuthorization<F = ()> {
impl<F: Send> UserAuthorization<F> {
// TODO: take scopes to validate as parameter
pub async fn protected_form(
pub async fn protected_form<E>(
self,
conn: &mut PgConnection,
) -> Result<(Session, F), AuthorizationVerificationError> {
repo: &mut impl RepositoryAccess<Error = E>,
clock: &impl Clock,
) -> Result<(Session, F), AuthorizationVerificationError<E>> {
let form = match self.form {
Some(f) => f,
None => return Err(AuthorizationVerificationError::MissingForm),
};
let (_token, session) = self.access_token.fetch(conn).await?;
let (token, session) = self.access_token.fetch(repo).await?;
if !token.is_valid(clock.now()) || !session.is_valid() {
return Err(AuthorizationVerificationError::InvalidToken);
}
Ok((session, form))
}
// TODO: take scopes to validate as parameter
pub async fn protected(
pub async fn protected<E>(
self,
conn: &mut PgConnection,
) -> Result<Session, AuthorizationVerificationError> {
let (_token, session) = self.access_token.fetch(conn).await?;
repo: &mut impl RepositoryAccess<Error = E>,
clock: &impl Clock,
) -> Result<Session, AuthorizationVerificationError<E>> {
let (token, session) = self.access_token.fetch(repo).await?;
if !token.is_valid(clock.now()) || !session.is_valid() {
return Err(AuthorizationVerificationError::InvalidToken);
}
Ok(session)
}
@@ -107,7 +127,7 @@ pub enum UserAuthorizationError {
}
#[derive(Debug, Error)]
pub enum AuthorizationVerificationError {
pub enum AuthorizationVerificationError<E> {
#[error("missing token")]
MissingToken,
@@ -118,7 +138,7 @@ pub enum AuthorizationVerificationError {
MissingForm,
#[error(transparent)]
Internal(#[from] DatabaseError),
Internal(#[from] E),
}
enum BearerError {
@@ -226,7 +246,10 @@ impl IntoResponse for UserAuthorizationError {
}
}
impl IntoResponse for AuthorizationVerificationError {
impl<E> IntoResponse for AuthorizationVerificationError<E>
where
E: ToString,
{
fn into_response(self) -> Response {
match self {
Self::MissingForm | Self::MissingToken => {

View File

@@ -50,6 +50,7 @@ mas-policy = { path = "../policy" }
mas-router = { path = "../router" }
mas-spa = { path = "../spa" }
mas-storage = { path = "../storage" }
mas-storage-pg = { path = "../storage-pg" }
mas-tasks = { path = "../tasks" }
mas-templates = { path = "../templates" }
oauth2-types = { path = "../oauth2-types" }

View File

@@ -15,7 +15,7 @@
use clap::Parser;
use mas_config::{ConfigurationSection, RootConfig};
use rand::SeedableRng;
use tracing::info;
use tracing::{info, info_span};
#[derive(Parser, Debug)]
pub(super) struct Options {
@@ -40,6 +40,8 @@ impl Options {
use Subcommand as SC;
match &self.subcommand {
SC::Dump => {
let _span = info_span!("cli.config.dump").entered();
let config: RootConfig = root.load_config()?;
serde_yaml::to_writer(std::io::stdout(), &config)?;
@@ -47,11 +49,15 @@ impl Options {
Ok(())
}
SC::Check => {
let _span = info_span!("cli.config.check").entered();
let _config: RootConfig = root.load_config()?;
info!(path = ?root.config, "Configuration file looks good");
Ok(())
}
SC::Generate => {
let _span = info_span!("cli.config.generate").entered();
// XXX: we should disallow SeedableRng::from_entropy
let rng = rand_chacha::ChaChaRng::from_entropy();
let config = RootConfig::load_and_generate(rng).await?;

View File

@@ -15,7 +15,8 @@
use anyhow::Context;
use clap::Parser;
use mas_config::DatabaseConfig;
use mas_storage::MIGRATOR;
use mas_storage_pg::MIGRATOR;
use tracing::{info_span, Instrument};
use crate::util::database_from_config;
@@ -33,12 +34,14 @@ enum Subcommand {
impl Options {
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
let _span = info_span!("cli.database.migrate").entered();
let config: DatabaseConfig = root.load_config()?;
let pool = database_from_config(&config).await?;
// Run pending migrations
MIGRATOR
.run(&pool)
.instrument(info_span!("db.migrate"))
.await
.context("could not run migrations")?;

View File

@@ -19,7 +19,7 @@ use mas_handlers::HttpClientFactory;
use mas_http::HttpServiceExt;
use tokio::io::AsyncWriteExt;
use tower::{Service, ServiceExt};
use tracing::info;
use tracing::{info, info_span};
use crate::util::policy_factory_from_config;
@@ -74,6 +74,7 @@ impl Options {
json: false,
url,
} => {
let _span = info_span!("cli.debug.http").entered();
let mut client = http_client_factory.client("cli-debug-http").await?;
let request = hyper::Request::builder()
.uri(url)
@@ -98,6 +99,7 @@ impl Options {
json: true,
url,
} => {
let _span = info_span!("cli.debug.http").entered();
let mut client = http_client_factory
.client("cli-debug-http")
.await?
@@ -122,6 +124,7 @@ impl Options {
}
SC::Policy => {
let _span = info_span!("cli.debug.policy").entered();
let config: PolicyConfig = root.load_config()?;
info!("Loading and compiling the policy module");
let policy_factory = policy_factory_from_config(&config).await?;

View File

@@ -18,15 +18,15 @@ use mas_config::{DatabaseConfig, PasswordsConfig, RootConfig};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_router::UrlBuilder;
use mas_storage::{
oauth2::client::{insert_client_from_config, lookup_client, truncate_clients},
user::{
add_user_password, lookup_user_by_username, lookup_user_email, mark_user_email_as_verified,
},
Clock,
oauth2::OAuth2ClientRepository,
upstream_oauth2::UpstreamOAuthProviderRepository,
user::{UserEmailRepository, UserPasswordRepository, UserRepository},
Repository, RepositoryAccess, SystemClock,
};
use mas_storage_pg::PgRepository;
use oauth2_types::scope::Scope;
use rand::SeedableRng;
use tracing::{info, warn};
use tracing::{info, info_span, warn};
use crate::util::{database_from_config, password_manager_from_config};
@@ -147,9 +147,9 @@ enum Subcommand {
/// Import clients from config
ImportClients {
/// Remove all clients before importing
/// Update existing clients
#[arg(long)]
truncate: bool,
update: bool,
},
/// Set a user password
@@ -188,20 +188,25 @@ impl Options {
#[allow(clippy::too_many_lines)]
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
use Subcommand as SC;
let clock = Clock::default();
let clock = SystemClock::default();
// XXX: we should disallow SeedableRng::from_entropy
let mut rng = rand_chacha::ChaChaRng::from_entropy();
match &self.subcommand {
SC::SetPassword { username, password } => {
let _span =
info_span!("cli.manage.set_password", user.username = %username).entered();
let database_config: DatabaseConfig = root.load_config()?;
let passwords_config: PasswordsConfig = root.load_config()?;
let pool = database_from_config(&database_config).await?;
let password_manager = password_manager_from_config(&passwords_config).await?;
let mut txn = pool.begin().await?;
let user = lookup_user_by_username(&mut txn, username)
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
let user = repo
.user()
.find_by_username(username)
.await?
.context("User not found")?;
@@ -209,89 +214,96 @@ impl Options {
let (version, hashed_password) = password_manager.hash(&mut rng, password).await?;
add_user_password(
&mut txn,
&mut rng,
&clock,
&user,
version,
hashed_password,
None,
)
.await?;
repo.user_password()
.add(&mut rng, &clock, &user, version, hashed_password, None)
.await?;
info!(%user.id, %user.username, "Password changed");
txn.commit().await?;
repo.save().await?;
Ok(())
}
SC::VerifyEmail { username, email } => {
let _span = info_span!(
"cli.manage.verify_email",
user.username = username,
user_email.email = email
)
.entered();
let config: DatabaseConfig = root.load_config()?;
let pool = database_from_config(&config).await?;
let mut txn = pool.begin().await?;
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
let user = lookup_user_by_username(&mut txn, username)
let user = repo
.user()
.find_by_username(username)
.await?
.context("User not found")?;
let email = lookup_user_email(&mut txn, &user, email)
let email = repo
.user_email()
.find(&user, email)
.await?
.context("Email not found")?;
let email = mark_user_email_as_verified(&mut txn, &clock, email).await?;
let email = repo.user_email().mark_as_verified(&clock, email).await?;
txn.commit().await?;
repo.save().await?;
info!(?email, "Email marked as verified");
Ok(())
}
SC::ImportClients { truncate } => {
SC::ImportClients { update } => {
let _span = info_span!("cli.manage.import_clients").entered();
let config: RootConfig = root.load_config()?;
let pool = database_from_config(&config.database).await?;
let encrypter = config.secrets.encrypter();
let mut txn = pool.begin().await?;
if *truncate {
warn!("Removing all clients first");
truncate_clients(&mut txn).await?;
}
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
for client in config.clients.iter() {
let client_id = client.client_id;
let res = lookup_client(&mut txn, client_id).await?;
if res.is_some() {
warn!(%client_id, "Skipping already imported client");
let existing = repo.oauth2_client().lookup(client_id).await?.is_some();
if !update && existing {
warn!(%client_id, "Skipping already imported client. Run with --update to update existing clients.");
continue;
}
info!(%client_id, "Importing client");
if existing {
info!(%client_id, "Updating client");
} else {
info!(%client_id, "Importing client");
}
let client_secret = client.client_secret();
let client_auth_method = client.client_auth_method();
let jwks = client.jwks();
let jwks_uri = client.jwks_uri();
let redirect_uris = &client.redirect_uris;
// TODO: should be moved somewhere else
let encrypted_client_secret = client_secret
.map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes()))
.transpose()?;
insert_client_from_config(
&mut txn,
&mut rng,
&clock,
client_id,
client_auth_method,
encrypted_client_secret.as_deref(),
jwks,
jwks_uri,
redirect_uris,
)
.await?;
repo.oauth2_client()
.add_from_config(
&mut rng,
&clock,
client_id,
client_auth_method,
encrypted_client_secret,
jwks.cloned(),
jwks_uri.cloned(),
client.redirect_uris.clone(),
)
.await?;
}
txn.commit().await?;
repo.save().await?;
Ok(())
}
@@ -304,11 +316,18 @@ impl Options {
client_secret,
signing_alg,
} => {
let _span = info_span!(
"cli.manage.add_oauth_upstream",
upstream_oauth_provider.issuer = issuer,
upstream_oauth_provider.client_id = client_id,
)
.entered();
let config: RootConfig = root.load_config()?;
let encrypter = config.secrets.encrypter();
let pool = database_from_config(&config.database).await?;
let url_builder = UrlBuilder::new(config.http.public_base);
let mut conn = pool.acquire().await?;
let mut repo = PgRepository::from_pool(&pool).await?;
let requires_client_secret = token_endpoint_auth_method.requires_client_secret();
@@ -329,18 +348,19 @@ impl Options {
.map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes()))
.transpose()?;
let provider = mas_storage::upstream_oauth2::add_provider(
&mut conn,
&mut rng,
&clock,
issuer.clone(),
scope.clone(),
token_endpoint_auth_method,
token_endpoint_signing_alg,
client_id.clone(),
encrypted_client_secret,
)
.await?;
let provider = repo
.upstream_oauth_provider()
.add(
&mut rng,
&clock,
issuer.clone(),
scope.clone(),
token_endpoint_auth_method,
token_endpoint_signing_alg,
client_id.clone(),
encrypted_client_secret,
)
.await?;
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
let auth_uri = url_builder.upstream_oauth_authorize(provider.id);

View File

@@ -1,4 +1,4 @@
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -21,10 +21,10 @@ use mas_config::RootConfig;
use mas_handlers::{AppState, HttpClientFactory, MatrixHomeserver};
use mas_listener::{server::Server, shutdown::ShutdownStream};
use mas_router::UrlBuilder;
use mas_storage::MIGRATOR;
use mas_storage_pg::MIGRATOR;
use mas_tasks::TaskQueue;
use tokio::signal::unix::SignalKind;
use tracing::{info, warn};
use tracing::{info, info_span, warn, Instrument};
use crate::util::{
database_from_config, mailer_from_config, password_manager_from_config,
@@ -45,6 +45,7 @@ pub(super) struct Options {
impl Options {
#[allow(clippy::too_many_lines)]
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
let span = info_span!("cli.run.init").entered();
let config: RootConfig = root.load_config()?;
// Connect to the database
@@ -55,6 +56,7 @@ impl Options {
info!("Running pending migrations");
MIGRATOR
.run(&pool)
.instrument(info_span!("db.migrate"))
.await
.context("could not run migrations")?;
}
@@ -100,7 +102,7 @@ impl Options {
watch_templates(&templates).await?;
}
let graphql_schema = mas_handlers::graphql_schema(&pool);
let graphql_schema = mas_handlers::graphql_schema();
// Maximum 50 outgoing HTTP requests at a time
let http_client_factory = HttpClientFactory::new(50);
@@ -186,6 +188,8 @@ impl Options {
.with_signal(SignalKind::terminate())?
.with_signal(SignalKind::interrupt())?;
span.exit();
mas_listener::server::run_servers(servers, shutdown).await;
Ok(())

View File

@@ -14,9 +14,10 @@
use camino::Utf8PathBuf;
use clap::Parser;
use mas_storage::Clock;
use mas_storage::{Clock, SystemClock};
use mas_templates::Templates;
use rand::SeedableRng;
use tracing::info_span;
#[derive(Parser, Debug)]
pub(super) struct Options {
@@ -38,7 +39,9 @@ impl Options {
use Subcommand as SC;
match &self.subcommand {
SC::Check { path } => {
let clock = Clock::default();
let _span = info_span!("cli.templates.check").entered();
let clock = SystemClock::default();
// XXX: we should disallow SeedableRng::from_entropy
let mut rng = rand_chacha::ChaChaRng::from_entropy();
let url_builder = mas_router::UrlBuilder::new("https://example.com/".parse()?);

View File

@@ -110,6 +110,7 @@ pub async fn templates_from_config(
Templates::load(config.path.clone(), url_builder.clone()).await
}
#[tracing::instrument(name = "db.connect", skip_all, err(Debug))]
pub async fn database_from_config(config: &DatabaseConfig) -> Result<PgPool, anyhow::Error> {
let mut options = match &config.options {
DatabaseConnectConfig::Uri { uri } => uri

View File

@@ -86,6 +86,7 @@ impl SecretsConfig {
/// # Errors
///
/// Returns an error when a key could not be imported
#[tracing::instrument(name = "secrets.load", skip_all, err(Debug))]
pub async fn key_store(&self) -> anyhow::Result<Keystore> {
let mut keys = Vec::with_capacity(self.keys.len());
for item in &self.keys {

View File

@@ -11,8 +11,8 @@ thiserror = "1.0.38"
serde = "1.0.152"
url = { version = "2.3.1", features = ["serde"] }
crc = "3.0.0"
ulid = { version = "1.0.0", features = ["serde"] }
rand = "0.8.5"
ulid = "1.0.0"
rand_chacha = "0.3.1"
mas-iana = { path = "../iana" }

View File

@@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,18 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use chrono::{DateTime, Utc};
use oauth2_types::scope::ScopeToken;
use rand::{
distributions::{Alphanumeric, DistString},
Rng,
RngCore,
};
use serde::Serialize;
use thiserror::Error;
use ulid::Ulid;
use url::Url;
use crate::User;
static DEVICE_ID_LENGTH: usize = 10;
@@ -53,7 +48,7 @@ impl Device {
}
/// Generate a random device ID
pub fn generate<R: Rng + ?Sized>(rng: &mut R) -> Self {
pub fn generate<R: RngCore + ?Sized>(rng: &mut R) -> Self {
let id: String = Alphanumeric.sample_string(rng, DEVICE_ID_LENGTH);
Self { id }
}
@@ -81,50 +76,3 @@ impl TryFrom<String> for Device {
Ok(Self { id })
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct CompatSession {
pub id: Ulid,
pub user: User,
pub device: Device,
pub created_at: DateTime<Utc>,
pub finished_at: Option<DateTime<Utc>>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompatAccessToken {
pub id: Ulid,
pub token: String,
pub created_at: DateTime<Utc>,
pub expires_at: Option<DateTime<Utc>>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompatRefreshToken {
pub id: Ulid,
pub token: String,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub enum CompatSsoLoginState {
Pending,
Fulfilled {
fulfilled_at: DateTime<Utc>,
session: CompatSession,
},
Exchanged {
fulfilled_at: DateTime<Utc>,
exchanged_at: DateTime<Utc>,
session: CompatSession,
},
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct CompatSsoLogin {
pub id: Ulid,
pub redirect_uri: Url,
pub login_token: String,
pub created_at: DateTime<Utc>,
pub state: CompatSsoLoginState,
}

View File

@@ -0,0 +1,106 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use chrono::{DateTime, Utc};
use ulid::Ulid;
mod device;
mod session;
mod sso_login;
pub use self::{
device::Device,
session::{CompatSession, CompatSessionState},
sso_login::{CompatSsoLogin, CompatSsoLoginState},
};
use crate::InvalidTransitionError;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompatAccessToken {
pub id: Ulid,
pub session_id: Ulid,
pub token: String,
pub created_at: DateTime<Utc>,
pub expires_at: Option<DateTime<Utc>>,
}
impl CompatAccessToken {
#[must_use]
pub fn is_valid(&self, now: DateTime<Utc>) -> bool {
if let Some(expires_at) = self.expires_at {
expires_at > now
} else {
true
}
}
}
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum CompatRefreshTokenState {
#[default]
Valid,
Consumed {
consumed_at: DateTime<Utc>,
},
}
impl CompatRefreshTokenState {
/// Returns `true` if the compat refresh token state is [`Valid`].
///
/// [`Valid`]: CompatRefreshTokenState::Valid
#[must_use]
pub fn is_valid(&self) -> bool {
matches!(self, Self::Valid)
}
/// Returns `true` if the compat refresh token state is [`Consumed`].
///
/// [`Consumed`]: CompatRefreshTokenState::Consumed
#[must_use]
pub fn is_consumed(&self) -> bool {
matches!(self, Self::Consumed { .. })
}
pub fn consume(self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
match self {
Self::Valid => Ok(Self::Consumed { consumed_at }),
Self::Consumed { .. } => Err(InvalidTransitionError),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompatRefreshToken {
pub id: Ulid,
pub state: CompatRefreshTokenState,
pub session_id: Ulid,
pub access_token_id: Ulid,
pub token: String,
pub created_at: DateTime<Utc>,
}
impl std::ops::Deref for CompatRefreshToken {
type Target = CompatRefreshTokenState;
fn deref(&self) -> &Self::Target {
&self.state
}
}
impl CompatRefreshToken {
pub fn consume(mut self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
self.state = self.state.consume(consumed_at)?;
Ok(self)
}
}

View File

@@ -0,0 +1,86 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use chrono::{DateTime, Utc};
use serde::Serialize;
use ulid::Ulid;
use super::Device;
use crate::InvalidTransitionError;
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
pub enum CompatSessionState {
#[default]
Valid,
Finished {
finished_at: DateTime<Utc>,
},
}
impl CompatSessionState {
/// Returns `true` if the compta session state is [`Valid`].
///
/// [`Valid`]: CompatSessionState::Valid
#[must_use]
pub fn is_valid(&self) -> bool {
matches!(self, Self::Valid)
}
/// Returns `true` if the compta session state is [`Finished`].
///
/// [`Finished`]: CompatSessionState::Finished
#[must_use]
pub fn is_finished(&self) -> bool {
matches!(self, Self::Finished { .. })
}
pub fn finish(self, finished_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
match self {
Self::Valid => Ok(Self::Finished { finished_at }),
Self::Finished { .. } => Err(InvalidTransitionError),
}
}
#[must_use]
pub fn finished_at(&self) -> Option<DateTime<Utc>> {
match self {
CompatSessionState::Valid => None,
CompatSessionState::Finished { finished_at } => Some(*finished_at),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct CompatSession {
pub id: Ulid,
pub state: CompatSessionState,
pub user_id: Ulid,
pub device: Device,
pub created_at: DateTime<Utc>,
}
impl std::ops::Deref for CompatSession {
type Target = CompatSessionState;
fn deref(&self) -> &Self::Target {
&self.state
}
}
impl CompatSession {
pub fn finish(mut self, finished_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
self.state = self.state.finish(finished_at)?;
Ok(self)
}
}

View File

@@ -0,0 +1,151 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use chrono::{DateTime, Utc};
use serde::Serialize;
use ulid::Ulid;
use url::Url;
use super::CompatSession;
use crate::InvalidTransitionError;
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
pub enum CompatSsoLoginState {
#[default]
Pending,
Fulfilled {
fulfilled_at: DateTime<Utc>,
session_id: Ulid,
},
Exchanged {
fulfilled_at: DateTime<Utc>,
exchanged_at: DateTime<Utc>,
session_id: Ulid,
},
}
impl CompatSsoLoginState {
/// Returns `true` if the compat sso login state is [`Pending`].
///
/// [`Pending`]: CompatSsoLoginState::Pending
#[must_use]
pub fn is_pending(&self) -> bool {
matches!(self, Self::Pending)
}
/// Returns `true` if the compat sso login state is [`Fulfilled`].
///
/// [`Fulfilled`]: CompatSsoLoginState::Fulfilled
#[must_use]
pub fn is_fulfilled(&self) -> bool {
matches!(self, Self::Fulfilled { .. })
}
/// Returns `true` if the compat sso login state is [`Exchanged`].
///
/// [`Exchanged`]: CompatSsoLoginState::Exchanged
#[must_use]
pub fn is_exchanged(&self) -> bool {
matches!(self, Self::Exchanged { .. })
}
#[must_use]
pub fn fulfilled_at(&self) -> Option<DateTime<Utc>> {
match self {
Self::Pending => None,
Self::Fulfilled { fulfilled_at, .. } | Self::Exchanged { fulfilled_at, .. } => {
Some(*fulfilled_at)
}
}
}
#[must_use]
pub fn exchanged_at(&self) -> Option<DateTime<Utc>> {
match self {
Self::Pending | Self::Fulfilled { .. } => None,
Self::Exchanged { exchanged_at, .. } => Some(*exchanged_at),
}
}
#[must_use]
pub fn session_id(&self) -> Option<Ulid> {
match self {
Self::Pending => None,
Self::Fulfilled { session_id, .. } | Self::Exchanged { session_id, .. } => {
Some(*session_id)
}
}
}
pub fn fulfill(
self,
fulfilled_at: DateTime<Utc>,
session: &CompatSession,
) -> Result<Self, InvalidTransitionError> {
match self {
Self::Pending => Ok(Self::Fulfilled {
fulfilled_at,
session_id: session.id,
}),
Self::Fulfilled { .. } | Self::Exchanged { .. } => Err(InvalidTransitionError),
}
}
pub fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
match self {
Self::Fulfilled {
fulfilled_at,
session_id,
} => Ok(Self::Exchanged {
fulfilled_at,
exchanged_at,
session_id,
}),
Self::Pending { .. } | Self::Exchanged { .. } => Err(InvalidTransitionError),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct CompatSsoLogin {
pub id: Ulid,
pub redirect_uri: Url,
pub login_token: String,
pub created_at: DateTime<Utc>,
pub state: CompatSsoLoginState,
}
impl std::ops::Deref for CompatSsoLogin {
type Target = CompatSsoLoginState;
fn deref(&self) -> &Self::Target {
&self.state
}
}
impl CompatSsoLogin {
pub fn fulfill(
mut self,
fulfilled_at: DateTime<Utc>,
session: &CompatSession,
) -> Result<Self, InvalidTransitionError> {
self.state = self.state.fulfill(fulfilled_at, session)?;
Ok(self)
}
pub fn exchange(mut self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
self.state = self.state.exchange(exchanged_at)?;
Ok(self)
}
}

View File

@@ -1,4 +1,4 @@
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -23,24 +23,33 @@
clippy::type_repetition_in_bounds
)]
use thiserror::Error;
pub(crate) mod compat;
pub(crate) mod oauth2;
pub(crate) mod tokens;
pub(crate) mod upstream_oauth2;
pub(crate) mod users;
#[derive(Debug, Error)]
#[error("invalid state transition")]
pub struct InvalidTransitionError;
pub use self::{
compat::{
CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState,
Device,
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device,
},
oauth2::{
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client,
InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session,
InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState,
},
tokens::{
AccessToken, AccessTokenState, RefreshToken, RefreshTokenState, TokenFormatError, TokenType,
},
tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType},
upstream_oauth2::{
UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider,
UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState,
UpstreamOAuthLink, UpstreamOAuthProvider,
},
users::{
Authentication, BrowserSession, Password, User, UserEmail, UserEmailVerification,

View File

@@ -21,11 +21,11 @@ use oauth2_types::{
requests::ResponseMode,
};
use serde::Serialize;
use thiserror::Error;
use ulid::Ulid;
use url::Url;
use super::{client::Client, session::Session};
use super::session::Session;
use crate::InvalidTransitionError;
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct Pkce {
@@ -53,21 +53,17 @@ pub struct AuthorizationCode {
pub pkce: Option<Pkce>,
}
#[derive(Debug, Error)]
#[error("invalid state transition")]
pub struct InvalidTransitionError;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)]
#[serde(tag = "stage", rename_all = "lowercase")]
pub enum AuthorizationGrantStage {
#[default]
Pending,
Fulfilled {
session: Session,
session_id: Ulid,
fulfilled_at: DateTime<Utc>,
},
Exchanged {
session: Session,
session_id: Ulid,
fulfilled_at: DateTime<Utc>,
exchanged_at: DateTime<Utc>,
},
@@ -82,35 +78,35 @@ impl AuthorizationGrantStage {
Self::Pending
}
pub fn fulfill(
fn fulfill(
self,
fulfilled_at: DateTime<Utc>,
session: Session,
session: &Session,
) -> Result<Self, InvalidTransitionError> {
match self {
Self::Pending => Ok(Self::Fulfilled {
fulfilled_at,
session,
session_id: session.id,
}),
_ => Err(InvalidTransitionError),
}
}
pub fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
match self {
Self::Fulfilled {
fulfilled_at,
session,
session_id,
} => Ok(Self::Exchanged {
fulfilled_at,
exchanged_at,
session,
session_id,
}),
_ => Err(InvalidTransitionError),
}
}
pub fn cancel(self, cancelled_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
fn cancel(self, cancelled_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
match self {
Self::Pending => Ok(Self::Cancelled { cancelled_at }),
_ => Err(InvalidTransitionError),
@@ -124,6 +120,22 @@ impl AuthorizationGrantStage {
pub fn is_pending(&self) -> bool {
matches!(self, Self::Pending)
}
/// Returns `true` if the authorization grant stage is [`Fulfilled`].
///
/// [`Fulfilled`]: AuthorizationGrantStage::Fulfilled
#[must_use]
pub fn is_fulfilled(&self) -> bool {
matches!(self, Self::Fulfilled { .. })
}
/// Returns `true` if the authorization grant stage is [`Exchanged`].
///
/// [`Exchanged`]: AuthorizationGrantStage::Exchanged
#[must_use]
pub fn is_exchanged(&self) -> bool {
matches!(self, Self::Exchanged { .. })
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
@@ -132,7 +144,7 @@ pub struct AuthorizationGrant {
#[serde(flatten)]
pub stage: AuthorizationGrantStage,
pub code: Option<AuthorizationCode>,
pub client: Client,
pub client_id: Ulid,
pub redirect_uri: Url,
pub scope: oauth2_types::scope::Scope,
pub state: Option<String>,
@@ -144,10 +156,38 @@ pub struct AuthorizationGrant {
pub requires_consent: bool,
}
impl std::ops::Deref for AuthorizationGrant {
type Target = AuthorizationGrantStage;
fn deref(&self) -> &Self::Target {
&self.stage
}
}
impl AuthorizationGrant {
#[must_use]
pub fn max_auth_time(&self) -> DateTime<Utc> {
let max_age: Option<i64> = self.max_age.map(|x| x.get().into());
self.created_at - Duration::seconds(max_age.unwrap_or(3600 * 24 * 365))
}
pub fn exchange(mut self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
self.stage = self.stage.exchange(exchanged_at)?;
Ok(self)
}
pub fn fulfill(
mut self,
fulfilled_at: DateTime<Utc>,
session: &Session,
) -> Result<Self, InvalidTransitionError> {
self.stage = self.stage.fulfill(fulfilled_at, session)?;
Ok(self)
}
// TODO: this is not used?
pub fn cancel(mut self, canceld_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
self.stage = self.stage.cancel(canceld_at)?;
Ok(self)
}
}

View File

@@ -1,4 +1,4 @@
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -19,5 +19,5 @@ pub(self) mod session;
pub use self::{
authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce},
client::{Client, InvalidRedirectUriError, JwksOrJwksUri},
session::Session,
session::{Session, SessionState},
};

View File

@@ -1,4 +1,4 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,17 +12,76 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use chrono::{DateTime, Utc};
use oauth2_types::scope::Scope;
use serde::Serialize;
use ulid::Ulid;
use super::client::Client;
use crate::users::BrowserSession;
use crate::InvalidTransitionError;
trait T {
type State;
}
impl T for Session {
type State = SessionState;
}
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
pub enum SessionState {
#[default]
Valid,
Finished {
finished_at: DateTime<Utc>,
},
}
impl SessionState {
/// Returns `true` if the session state is [`Valid`].
///
/// [`Valid`]: SessionState::Valid
#[must_use]
pub fn is_valid(&self) -> bool {
matches!(self, Self::Valid)
}
/// Returns `true` if the session state is [`Finished`].
///
/// [`Finished`]: SessionState::Finished
#[must_use]
pub fn is_finished(&self) -> bool {
matches!(self, Self::Finished { .. })
}
pub fn finish(self, finished_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
match self {
Self::Valid => Ok(Self::Finished { finished_at }),
Self::Finished { .. } => Err(InvalidTransitionError),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct Session {
pub id: Ulid,
pub browser_session: BrowserSession,
pub client: Client,
pub state: SessionState,
pub created_at: DateTime<Utc>,
pub user_session_id: Ulid,
pub client_id: Ulid,
pub scope: Scope,
}
impl std::ops::Deref for Session {
type Target = SessionState;
fn deref(&self) -> &Self::Target {
&self.state
}
}
impl Session {
pub fn finish(mut self, finished_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
self.state = self.state.finish(finished_at)?;
Ok(self)
}
}

View File

@@ -15,25 +15,135 @@
use chrono::{DateTime, Utc};
use crc::{Crc, CRC_32_ISO_HDLC};
use mas_iana::oauth::OAuthTokenTypeHint;
use rand::{distributions::Alphanumeric, Rng};
use rand::{distributions::Alphanumeric, Rng, RngCore};
use thiserror::Error;
use ulid::Ulid;
use crate::InvalidTransitionError;
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum AccessTokenState {
#[default]
Valid,
Revoked {
revoked_at: DateTime<Utc>,
},
}
impl AccessTokenState {
fn revoke(self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
match self {
Self::Valid => Ok(Self::Revoked { revoked_at }),
Self::Revoked { .. } => Err(InvalidTransitionError),
}
}
/// Returns `true` if the refresh token state is [`Valid`].
///
/// [`Valid`]: AccessTokenState::Valid
#[must_use]
pub fn is_valid(&self) -> bool {
matches!(self, Self::Valid)
}
/// Returns `true` if the refresh token state is [`Revoked`].
///
/// [`Revoked`]: AccessTokenState::Revoked
#[must_use]
pub fn is_revoked(&self) -> bool {
matches!(self, Self::Revoked { .. })
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AccessToken {
pub id: Ulid,
pub jti: String,
pub state: AccessTokenState,
pub session_id: Ulid,
pub access_token: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
}
impl AccessToken {
#[must_use]
pub fn jti(&self) -> String {
self.id.to_string()
}
#[must_use]
pub fn is_valid(&self, now: DateTime<Utc>) -> bool {
self.state.is_valid() && self.expires_at > now
}
pub fn revoke(mut self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
self.state = self.state.revoke(revoked_at)?;
Ok(self)
}
}
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum RefreshTokenState {
#[default]
Valid,
Consumed {
consumed_at: DateTime<Utc>,
},
}
impl RefreshTokenState {
fn consume(self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
match self {
Self::Valid => Ok(Self::Consumed { consumed_at }),
Self::Consumed { .. } => Err(InvalidTransitionError),
}
}
/// Returns `true` if the refresh token state is [`Valid`].
///
/// [`Valid`]: RefreshTokenState::Valid
#[must_use]
pub fn is_valid(&self) -> bool {
matches!(self, Self::Valid)
}
/// Returns `true` if the refresh token state is [`Consumed`].
///
/// [`Consumed`]: RefreshTokenState::Consumed
#[must_use]
pub fn is_consumed(&self) -> bool {
matches!(self, Self::Consumed { .. })
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RefreshToken {
pub id: Ulid,
pub state: RefreshTokenState,
pub refresh_token: String,
pub session_id: Ulid,
pub created_at: DateTime<Utc>,
pub access_token: Option<AccessToken>,
pub access_token_id: Option<Ulid>,
}
impl std::ops::Deref for RefreshToken {
type Target = RefreshTokenState;
fn deref(&self) -> &Self::Target {
&self.state
}
}
impl RefreshToken {
#[must_use]
pub fn jti(&self) -> String {
self.id.to_string()
}
pub fn consume(mut self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
self.state = self.state.consume(consumed_at)?;
Ok(self)
}
}
/// Type of token to generate or validate
@@ -80,10 +190,10 @@ impl TokenType {
/// use rand::thread_rng;
/// use mas_data_model::TokenType::{AccessToken, RefreshToken};
///
/// AccessToken.generate(thread_rng());
/// RefreshToken.generate(thread_rng());
/// AccessToken.generate(&mut thread_rng());
/// RefreshToken.generate(&mut thread_rng());
/// ```
pub fn generate(self, rng: impl Rng) -> String {
pub fn generate(self, rng: &mut (impl RngCore + ?Sized)) -> String {
let random_part: String = rng
.sample_iter(&Alphanumeric)
.take(30)

View File

@@ -0,0 +1,26 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use chrono::{DateTime, Utc};
use serde::Serialize;
use ulid::Ulid;
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthLink {
pub id: Ulid,
pub provider_id: Ulid,
pub user_id: Option<Ulid>,
pub subject: String,
pub created_at: DateTime<Utc>,
}

View File

@@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,55 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use chrono::{DateTime, Utc};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use oauth2_types::scope::Scope;
use serde::Serialize;
use ulid::Ulid;
mod link;
mod provider;
mod session;
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthProvider {
pub id: Ulid,
pub issuer: String,
pub scope: Scope,
pub client_id: String,
pub encrypted_client_secret: Option<String>,
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthLink {
pub id: Ulid,
pub provider_id: Ulid,
pub user_id: Option<Ulid>,
pub subject: String,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthAuthorizationSession {
pub id: Ulid,
pub provider_id: Ulid,
pub link_id: Option<Ulid>,
pub state: String,
pub code_challenge_verifier: Option<String>,
pub nonce: String,
pub created_at: DateTime<Utc>,
pub completed_at: Option<DateTime<Utc>>,
pub consumed_at: Option<DateTime<Utc>>,
pub id_token: Option<String>,
}
impl UpstreamOAuthAuthorizationSession {
#[must_use]
pub const fn completed(&self) -> bool {
self.completed_at.is_some()
}
#[must_use]
pub const fn consumed(&self) -> bool {
self.consumed_at.is_some()
}
}
pub use self::{
link::UpstreamOAuthLink,
provider::UpstreamOAuthProvider,
session::{UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState},
};

View File

@@ -0,0 +1,31 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use chrono::{DateTime, Utc};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use oauth2_types::scope::Scope;
use serde::Serialize;
use ulid::Ulid;
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthProvider {
pub id: Ulid,
pub issuer: String,
pub scope: Scope,
pub client_id: String,
pub encrypted_client_secret: Option<String>,
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,
pub created_at: DateTime<Utc>,
}

View File

@@ -0,0 +1,170 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use chrono::{DateTime, Utc};
use serde::Serialize;
use ulid::Ulid;
use super::UpstreamOAuthLink;
use crate::InvalidTransitionError;
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
pub enum UpstreamOAuthAuthorizationSessionState {
#[default]
Pending,
Completed {
completed_at: DateTime<Utc>,
link_id: Ulid,
id_token: Option<String>,
},
Consumed {
completed_at: DateTime<Utc>,
consumed_at: DateTime<Utc>,
link_id: Ulid,
id_token: Option<String>,
},
}
impl UpstreamOAuthAuthorizationSessionState {
pub fn complete(
self,
completed_at: DateTime<Utc>,
link: &UpstreamOAuthLink,
id_token: Option<String>,
) -> Result<Self, InvalidTransitionError> {
match self {
Self::Pending => Ok(Self::Completed {
completed_at,
link_id: link.id,
id_token,
}),
Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError),
}
}
pub fn consume(self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
match self {
Self::Completed {
completed_at,
link_id,
id_token,
} => Ok(Self::Consumed {
completed_at,
link_id,
consumed_at,
id_token,
}),
Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError),
}
}
#[must_use]
pub fn link_id(&self) -> Option<Ulid> {
match self {
Self::Pending => None,
Self::Completed { link_id, .. } | Self::Consumed { link_id, .. } => Some(*link_id),
}
}
#[must_use]
pub fn completed_at(&self) -> Option<DateTime<Utc>> {
match self {
Self::Pending => None,
Self::Completed { completed_at, .. } | Self::Consumed { completed_at, .. } => {
Some(*completed_at)
}
}
}
#[must_use]
pub fn id_token(&self) -> Option<&str> {
match self {
Self::Pending => None,
Self::Completed { id_token, .. } | Self::Consumed { id_token, .. } => {
id_token.as_deref()
}
}
}
#[must_use]
pub fn consumed_at(&self) -> Option<DateTime<Utc>> {
match self {
Self::Pending | Self::Completed { .. } => None,
Self::Consumed { consumed_at, .. } => Some(*consumed_at),
}
}
/// Returns `true` if the upstream oauth authorization session state is
/// [`Pending`].
///
/// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
#[must_use]
pub fn is_pending(&self) -> bool {
matches!(self, Self::Pending)
}
/// Returns `true` if the upstream oauth authorization session state is
/// [`Completed`].
///
/// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed
#[must_use]
pub fn is_completed(&self) -> bool {
matches!(self, Self::Completed { .. })
}
/// Returns `true` if the upstream oauth authorization session state is
/// [`Consumed`].
///
/// [`Consumed`]: UpstreamOAuthAuthorizationSessionState::Consumed
#[must_use]
pub fn is_consumed(&self) -> bool {
matches!(self, Self::Consumed { .. })
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthAuthorizationSession {
pub id: Ulid,
pub state: UpstreamOAuthAuthorizationSessionState,
pub provider_id: Ulid,
pub state_str: String,
pub code_challenge_verifier: Option<String>,
pub nonce: String,
pub created_at: DateTime<Utc>,
}
impl std::ops::Deref for UpstreamOAuthAuthorizationSession {
type Target = UpstreamOAuthAuthorizationSessionState;
fn deref(&self) -> &Self::Target {
&self.state
}
}
impl UpstreamOAuthAuthorizationSession {
pub fn complete(
mut self,
completed_at: DateTime<Utc>,
link: &UpstreamOAuthLink,
id_token: Option<String>,
) -> Result<Self, InvalidTransitionError> {
self.state = self.state.complete(completed_at, link, id_token)?;
Ok(self)
}
pub fn consume(mut self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
self.state = self.state.consume(consumed_at)?;
Ok(self)
}
}

View File

@@ -22,7 +22,7 @@ pub struct User {
pub id: Ulid,
pub username: String,
pub sub: String,
pub primary_email: Option<UserEmail>,
pub primary_user_email_id: Option<Ulid>,
}
impl User {
@@ -32,7 +32,7 @@ impl User {
id: Ulid::from_datetime_with_source(now.into(), rng),
username: "john".to_owned(),
sub: "123-456".to_owned(),
primary_email: None,
primary_user_email_id: None,
}]
}
}
@@ -57,10 +57,16 @@ pub struct BrowserSession {
pub id: Ulid,
pub user: User,
pub created_at: DateTime<Utc>,
pub finished_at: Option<DateTime<Utc>>,
pub last_authentication: Option<Authentication>,
}
impl BrowserSession {
#[must_use]
pub fn active(&self) -> bool {
self.finished_at.is_none()
}
#[must_use]
pub fn was_authenticated_after(&self, after: DateTime<Utc>) -> bool {
if let Some(auth) = &self.last_authentication {
@@ -80,6 +86,7 @@ impl BrowserSession {
id: Ulid::from_datetime_with_source(now.into(), rng),
user,
created_at: now,
finished_at: None,
last_authentication: None,
})
.collect()
@@ -89,6 +96,7 @@ impl BrowserSession {
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UserEmail {
pub id: Ulid,
pub user_id: Ulid,
pub email: String,
pub created_at: DateTime<Utc>,
pub confirmed_at: Option<DateTime<Utc>>,
@@ -100,12 +108,14 @@ impl UserEmail {
vec![
Self {
id: Ulid::from_datetime_with_source(now.into(), rng),
user_id: Ulid::from_datetime_with_source(now.into(), rng),
email: "alice@example.com".to_owned(),
created_at: now,
confirmed_at: Some(now),
},
Self {
id: Ulid::from_datetime_with_source(now.into(), rng),
user_id: Ulid::from_datetime_with_source(now.into(), rng),
email: "bob@example.com".to_owned(),
created_at: now,
confirmed_at: None,
@@ -124,7 +134,7 @@ pub enum UserEmailVerificationState {
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UserEmailVerification {
pub id: Ulid,
pub email: UserEmail,
pub user_email_id: Ulid,
pub code: String,
pub created_at: DateTime<Utc>,
pub state: UserEmailVerificationState,
@@ -152,8 +162,8 @@ impl UserEmailVerification {
.into_iter()
.map(move |email| Self {
id: Ulid::from_datetime_with_source(now.into(), &mut rng),
user_email_id: email.id,
code: "123456".to_owned(),
email,
created_at: now - Duration::minutes(10),
state: state.clone(),
})

View File

@@ -100,10 +100,10 @@ impl Mailer {
///
/// Will return `Err` if the email failed rendering or failed sending
#[tracing::instrument(
name = "email.verification.send",
skip_all,
fields(
email.to = %to,
email.from = %self.from,
user.id = %context.user().id,
user_email_verification.id = %context.verification().id,
user_email_verification.code = context.verification().code,
@@ -125,6 +125,7 @@ impl Mailer {
/// # Errors
///
/// Returns an error if the connection failed
#[tracing::instrument(name = "email.test_connection", skip_all, err)]
pub async fn test_connection(&self) -> Result<(), crate::transport::Error> {
self.transport.test_connection().await
}

View File

@@ -10,7 +10,7 @@ anyhow = "1.0.68"
async-graphql = { version = "5.0.5", features = ["chrono", "url"] }
chrono = "0.4.23"
serde = { version = "1.0.152", features = ["derive"] }
sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] }
tokio = { version = "1.23.0", features = ["sync"] }
thiserror = "1.0.38"
tracing = "0.1.37"
ulid = "1.0.0"

View File

@@ -30,8 +30,14 @@ use async_graphql::{
connection::{query, Connection, Edge, OpaqueCursor},
Context, Description, EmptyMutation, EmptySubscription, ID,
};
use mas_storage::{
oauth2::OAuth2ClientRepository,
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository},
user::{BrowserSessionRepository, UserEmailRepository},
BoxRepository, Pagination,
};
use model::CreationEvent;
use sqlx::PgPool;
use tokio::sync::Mutex;
use self::model::{
BrowserSession, Cursor, Node, NodeCursor, NodeType, OAuth2Client, UpstreamOAuth2Link,
@@ -87,10 +93,9 @@ impl RootQuery {
id: ID,
) -> Result<Option<OAuth2Client>, async_graphql::Error> {
let id = NodeType::OAuth2Client.extract_ulid(&id)?;
let database = ctx.data::<PgPool>()?;
let mut conn = database.acquire().await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let client = mas_storage::oauth2::client::lookup_client(&mut conn, id).await?;
let client = repo.oauth2_client().lookup(id).await?;
Ok(client.map(OAuth2Client))
}
@@ -118,13 +123,12 @@ impl RootQuery {
) -> Result<Option<BrowserSession>, async_graphql::Error> {
let id = NodeType::BrowserSession.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let database = ctx.data::<PgPool>()?;
let mut conn = database.acquire().await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let Some(session) = session else { return Ok(None) };
let current_user = session.user;
let browser_session = mas_storage::user::lookup_active_session(&mut conn, id).await?;
let browser_session = repo.browser_session().lookup(id).await?;
let ret = browser_session.and_then(|browser_session| {
if browser_session.user.id == current_user.id {
@@ -145,14 +149,16 @@ impl RootQuery {
) -> Result<Option<UserEmail>, async_graphql::Error> {
let id = NodeType::UserEmail.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let database = ctx.data::<PgPool>()?;
let mut conn = database.acquire().await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let Some(session) = session else { return Ok(None) };
let current_user = session.user;
let user_email =
mas_storage::user::lookup_user_email_by_id(&mut conn, &current_user, id).await?;
let user_email = repo
.user_email()
.lookup(id)
.await?
.filter(|e| e.user_id == current_user.id);
Ok(user_email.map(UserEmail))
}
@@ -165,13 +171,12 @@ impl RootQuery {
) -> Result<Option<UpstreamOAuth2Link>, async_graphql::Error> {
let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let database = ctx.data::<PgPool>()?;
let mut conn = database.acquire().await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let Some(session) = session else { return Ok(None) };
let current_user = session.user;
let link = mas_storage::upstream_oauth2::lookup_link(&mut conn, id).await?;
let link = repo.upstream_oauth_link().lookup(id).await?;
// Ensure that the link belongs to the current user
let link = link.filter(|link| link.user_id == Some(current_user.id));
@@ -186,10 +191,9 @@ impl RootQuery {
id: ID,
) -> Result<Option<UpstreamOAuth2Provider>, async_graphql::Error> {
let id = NodeType::UpstreamOAuth2Provider.extract_ulid(&id)?;
let database = ctx.data::<PgPool>()?;
let mut conn = database.acquire().await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let provider = mas_storage::upstream_oauth2::lookup_provider(&mut conn, id).await?;
let provider = repo.upstream_oauth_provider().lookup(id).await?;
Ok(provider.map(UpstreamOAuth2Provider::new))
}
@@ -206,7 +210,7 @@ impl RootQuery {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UpstreamOAuth2Provider>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query(
after,
@@ -214,7 +218,6 @@ impl RootQuery {
first,
last,
|after, before, first, last| async move {
let mut conn = database.acquire().await?;
let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| {
x.extract_for_type(NodeType::UpstreamOAuth2Provider)
@@ -225,15 +228,15 @@ impl RootQuery {
x.extract_for_type(NodeType::UpstreamOAuth2Provider)
})
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let (has_previous_page, has_next_page, edges) =
mas_storage::upstream_oauth2::get_paginated_providers(
&mut conn, before_id, after_id, first, last,
)
let page = repo
.upstream_oauth_provider()
.list_paginated(pagination)
.await?;
let mut connection = Connection::new(has_previous_page, has_next_page);
connection.edges.extend(edges.into_iter().map(|p| {
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
connection.edges.extend(page.edges.into_iter().map(|p| {
Edge::new(
OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Provider, p.id)),
UpstreamOAuth2Provider::new(p),

View File

@@ -12,9 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use async_graphql::{Description, Object, ID};
use anyhow::Context as _;
use async_graphql::{Context, Description, Object, ID};
use chrono::{DateTime, Utc};
use mas_data_model::CompatSsoLoginState;
use mas_storage::{compat::CompatSessionRepository, user::UserRepository, BoxRepository};
use tokio::sync::Mutex;
use url::Url;
use super::{NodeType, User};
@@ -32,8 +34,14 @@ impl CompatSession {
}
/// The user authorized for this session.
async fn user(&self) -> User {
User(self.0.user.clone())
async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> {
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let user = repo
.user()
.lookup(self.0.user_id)
.await?
.context("Could not load user")?;
Ok(User(user))
}
/// The Matrix Device ID of this session.
@@ -48,7 +56,7 @@ impl CompatSession {
/// When the session ended.
pub async fn finished_at(&self) -> Option<DateTime<Utc>> {
self.0.finished_at
self.0.finished_at()
}
}
@@ -77,29 +85,28 @@ impl CompatSsoLogin {
/// When the login was fulfilled, and the user was redirected back to the
/// client.
async fn fulfilled_at(&self) -> Option<DateTime<Utc>> {
match &self.0.state {
CompatSsoLoginState::Pending => None,
CompatSsoLoginState::Fulfilled { fulfilled_at, .. }
| CompatSsoLoginState::Exchanged { fulfilled_at, .. } => Some(*fulfilled_at),
}
self.0.fulfilled_at()
}
/// When the client exchanged the login token sent during the redirection.
async fn exchanged_at(&self) -> Option<DateTime<Utc>> {
match &self.0.state {
CompatSsoLoginState::Pending | CompatSsoLoginState::Fulfilled { .. } => None,
CompatSsoLoginState::Exchanged { exchanged_at, .. } => Some(*exchanged_at),
}
self.0.exchanged_at()
}
/// The compat session which was started by this login.
async fn session(&self) -> Option<CompatSession> {
match &self.0.state {
CompatSsoLoginState::Pending => None,
CompatSsoLoginState::Fulfilled { session, .. }
| CompatSsoLoginState::Exchanged { session, .. } => {
Some(CompatSession(session.clone()))
}
}
async fn session(
&self,
ctx: &Context<'_>,
) -> Result<Option<CompatSession>, async_graphql::Error> {
let Some(session_id) = self.0.session_id() else { return Ok(None) };
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let session = repo
.compat_session()
.lookup(session_id)
.await?
.context("Could not load compat session")?;
Ok(Some(CompatSession(session)))
}
}

View File

@@ -14,9 +14,9 @@
use anyhow::Context as _;
use async_graphql::{Context, Description, Object, ID};
use mas_storage::oauth2::client::lookup_client;
use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, BoxRepository};
use oauth2_types::scope::Scope;
use sqlx::PgPool;
use tokio::sync::Mutex;
use ulid::Ulid;
use url::Url;
@@ -35,8 +35,15 @@ impl OAuth2Session {
}
/// OAuth 2.0 client used by this session.
pub async fn client(&self) -> OAuth2Client {
OAuth2Client(self.0.client.clone())
pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> {
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let client = repo
.oauth2_client()
.lookup(self.0.client_id)
.await?
.context("Could not load client")?;
Ok(OAuth2Client(client))
}
/// Scope granted for this session.
@@ -45,13 +52,30 @@ impl OAuth2Session {
}
/// The browser session which started this OAuth 2.0 session.
pub async fn browser_session(&self) -> BrowserSession {
BrowserSession(self.0.browser_session.clone())
pub async fn browser_session(
&self,
ctx: &Context<'_>,
) -> Result<BrowserSession, async_graphql::Error> {
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let browser_session = repo
.browser_session()
.lookup(self.0.user_session_id)
.await?
.context("Could not load browser session")?;
Ok(BrowserSession(browser_session))
}
/// User authorized for this session.
pub async fn user(&self) -> User {
User(self.0.browser_session.user.clone())
pub async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> {
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let browser_session = repo
.browser_session()
.lookup(self.0.user_session_id)
.await?
.context("Could not load browser session")?;
Ok(User(browser_session.user))
}
}
@@ -114,8 +138,10 @@ impl OAuth2Consent {
/// OAuth 2.0 client for which the user granted access.
pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> {
let mut conn = ctx.data::<PgPool>()?.acquire().await?;
let client = lookup_client(&mut conn, self.client_id)
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let client = repo
.oauth2_client()
.lookup(self.client_id)
.await?
.context("Could not load client")?;
Ok(OAuth2Client(client))

View File

@@ -15,7 +15,10 @@
use anyhow::Context as _;
use async_graphql::{Context, Object, ID};
use chrono::{DateTime, Utc};
use sqlx::PgPool;
use mas_storage::{
upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, BoxRepository,
};
use tokio::sync::Mutex;
use super::{NodeType, User};
@@ -99,11 +102,13 @@ impl UpstreamOAuth2Link {
provider.clone()
} else {
// Fetch on-the-fly
let database = ctx.data::<PgPool>()?;
let mut conn = database.acquire().await?;
mas_storage::upstream_oauth2::lookup_provider(&mut conn, self.link.provider_id)
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let provider = repo
.upstream_oauth_provider()
.lookup(self.link.provider_id)
.await?
.context("Upstream OAuth 2.0 provider not found")?
.context("Upstream OAuth 2.0 provider not found")?;
provider
};
Ok(UpstreamOAuth2Provider::new(provider))
@@ -116,9 +121,13 @@ impl UpstreamOAuth2Link {
user.clone()
} else if let Some(user_id) = &self.link.user_id {
// Fetch on-the-fly
let database = ctx.data::<PgPool>()?;
let mut conn = database.acquire().await?;
mas_storage::user::lookup_user(&mut conn, *user_id).await?
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let user = repo
.user()
.lookup(*user_id)
.await?
.context("User not found")?;
user
} else {
return Ok(None);
};

View File

@@ -17,7 +17,14 @@ use async_graphql::{
Context, Description, Object, ID,
};
use chrono::{DateTime, Utc};
use sqlx::PgPool;
use mas_storage::{
compat::CompatSsoLoginRepository,
oauth2::OAuth2SessionRepository,
upstream_oauth2::UpstreamOAuthLinkRepository,
user::{BrowserSessionRepository, UserEmailRepository},
BoxRepository, Pagination,
};
use tokio::sync::Mutex;
use super::{
compat_sessions::CompatSsoLogin, BrowserSession, Cursor, NodeCursor, NodeType, OAuth2Session,
@@ -53,8 +60,14 @@ impl User {
}
/// Primary email address of the user.
async fn primary_email(&self) -> Option<UserEmail> {
self.0.primary_email.clone().map(UserEmail)
async fn primary_email(
&self,
ctx: &Context<'_>,
) -> Result<Option<UserEmail>, async_graphql::Error> {
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let mut user_email_repo = repo.user_email();
Ok(user_email_repo.get_primary(&self.0).await?.map(UserEmail))
}
/// Get the list of compatibility SSO logins, chronologically sorted
@@ -69,7 +82,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, CompatSsoLogin>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query(
after,
@@ -77,22 +90,21 @@ impl User {
first,
last,
|after, before, first, last| async move {
let mut conn = database.acquire().await?;
let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::CompatSsoLogin))
.transpose()?;
let before_id = before
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::CompatSsoLogin))
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let (has_previous_page, has_next_page, edges) =
mas_storage::compat::get_paginated_user_compat_sso_logins(
&mut conn, &self.0, before_id, after_id, first, last,
)
let page = repo
.compat_sso_login()
.list_paginated(&self.0, pagination)
.await?;
let mut connection = Connection::new(has_previous_page, has_next_page);
connection.edges.extend(edges.into_iter().map(|u| {
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
connection.edges.extend(page.edges.into_iter().map(|u| {
Edge::new(
OpaqueCursor(NodeCursor(NodeType::CompatSsoLogin, u.id)),
CompatSsoLogin(u),
@@ -117,7 +129,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, BrowserSession>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query(
after,
@@ -125,22 +137,21 @@ impl User {
first,
last,
|after, before, first, last| async move {
let mut conn = database.acquire().await?;
let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::BrowserSession))
.transpose()?;
let before_id = before
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::BrowserSession))
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let (has_previous_page, has_next_page, edges) =
mas_storage::user::get_paginated_user_sessions(
&mut conn, &self.0, before_id, after_id, first, last,
)
let page = repo
.browser_session()
.list_active_paginated(&self.0, pagination)
.await?;
let mut connection = Connection::new(has_previous_page, has_next_page);
connection.edges.extend(edges.into_iter().map(|u| {
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
connection.edges.extend(page.edges.into_iter().map(|u| {
Edge::new(
OpaqueCursor(NodeCursor(NodeType::BrowserSession, u.id)),
BrowserSession(u),
@@ -165,7 +176,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UserEmail, UserEmailsPagination>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query(
after,
@@ -173,26 +184,25 @@ impl User {
first,
last,
|after, before, first, last| async move {
let mut conn = database.acquire().await?;
let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::UserEmail))
.transpose()?;
let before_id = before
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::UserEmail))
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let (has_previous_page, has_next_page, edges) =
mas_storage::user::get_paginated_user_emails(
&mut conn, &self.0, before_id, after_id, first, last,
)
let page = repo
.user_email()
.list_paginated(&self.0, pagination)
.await?;
let mut connection = Connection::with_additional_fields(
has_previous_page,
has_next_page,
page.has_previous_page,
page.has_next_page,
UserEmailsPagination(self.0.clone()),
);
connection.edges.extend(edges.into_iter().map(|u| {
connection.edges.extend(page.edges.into_iter().map(|u| {
Edge::new(
OpaqueCursor(NodeCursor(NodeType::UserEmail, u.id)),
UserEmail(u),
@@ -217,7 +227,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, OAuth2Session>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query(
after,
@@ -225,22 +235,21 @@ impl User {
first,
last,
|after, before, first, last| async move {
let mut conn = database.acquire().await?;
let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::OAuth2Session))
.transpose()?;
let before_id = before
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::OAuth2Session))
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let (has_previous_page, has_next_page, edges) =
mas_storage::oauth2::get_paginated_user_oauth_sessions(
&mut conn, &self.0, before_id, after_id, first, last,
)
let page = repo
.oauth2_session()
.list_paginated(&self.0, pagination)
.await?;
let mut connection = Connection::new(has_previous_page, has_next_page);
connection.edges.extend(edges.into_iter().map(|s| {
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
connection.edges.extend(page.edges.into_iter().map(|s| {
Edge::new(
OpaqueCursor(NodeCursor(NodeType::OAuth2Session, s.id)),
OAuth2Session(s),
@@ -265,7 +274,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UpstreamOAuth2Link>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query(
after,
@@ -273,7 +282,6 @@ impl User {
first,
last,
|after, before, first, last| async move {
let mut conn = database.acquire().await?;
let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| {
x.extract_for_type(NodeType::UpstreamOAuth2Link)
@@ -284,15 +292,15 @@ impl User {
x.extract_for_type(NodeType::UpstreamOAuth2Link)
})
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let (has_previous_page, has_next_page, edges) =
mas_storage::upstream_oauth2::get_paginated_user_links(
&mut conn, &self.0, before_id, after_id, first, last,
)
let page = repo
.upstream_oauth_link()
.list_paginated(&self.0, pagination)
.await?;
let mut connection = Connection::new(has_previous_page, has_next_page);
connection.edges.extend(edges.into_iter().map(|s| {
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
connection.edges.extend(page.edges.into_iter().map(|s| {
Edge::new(
OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Link, s.id)),
UpstreamOAuth2Link::new(s),
@@ -339,9 +347,9 @@ pub struct UserEmailsPagination(mas_data_model::User);
#[Object]
impl UserEmailsPagination {
/// Identifies the total count of items in the connection.
async fn total_count(&self, ctx: &Context<'_>) -> Result<i64, async_graphql::Error> {
let mut conn = ctx.data::<PgPool>()?.acquire().await?;
let count = mas_storage::user::count_user_emails(&mut conn, &self.0).await?;
async fn total_count(&self, ctx: &Context<'_>) -> Result<usize, async_graphql::Error> {
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let count = repo.user_email().count(&self.0).await?;
Ok(count)
}
}

View File

@@ -68,6 +68,7 @@ mas-oidc-client = { path = "../oidc-client" }
mas-policy = { path = "../policy" }
mas-router = { path = "../router" }
mas-storage = { path = "../storage" }
mas-storage-pg = { path = "../storage-pg" }
mas-templates = { path = "../templates" }
oauth2-types = { path = "../oauth2-types" }

View File

@@ -12,16 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use std::{convert::Infallible, sync::Arc};
use axum::extract::FromRef;
use axum::{
async_trait,
extract::{FromRef, FromRequestParts},
response::IntoResponse,
};
use hyper::StatusCode;
use mas_axum_utils::http_client_factory::HttpClientFactory;
use mas_email::Mailer;
use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory;
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock};
use mas_storage_pg::PgRepository;
use mas_templates::Templates;
use rand::SeedableRng;
use sqlx::PgPool;
use thiserror::Error;
use crate::{passwords::PasswordManager, MatrixHomeserver};
@@ -105,3 +114,58 @@ impl FromRef<AppState> for PasswordManager {
input.password_manager.clone()
}
}
#[async_trait]
impl FromRequestParts<AppState> for BoxClock {
type Rejection = Infallible;
async fn from_request_parts(
_parts: &mut axum::http::request::Parts,
_state: &AppState,
) -> Result<Self, Self::Rejection> {
let clock = SystemClock::default();
Ok(Box::new(clock))
}
}
#[async_trait]
impl FromRequestParts<AppState> for BoxRng {
type Rejection = Infallible;
async fn from_request_parts(
_parts: &mut axum::http::request::Parts,
_state: &AppState,
) -> Result<Self, Self::Rejection> {
// This rng is used to source the local rng
#[allow(clippy::disallowed_methods)]
let rng = rand::thread_rng();
let rng = rand_chacha::ChaChaRng::from_rng(rng).expect("Failed to seed RNG");
Ok(Box::new(rng))
}
}
#[derive(Debug, Error)]
#[error(transparent)]
pub struct RepositoryError(#[from] mas_storage_pg::DatabaseError);
impl IntoResponse for RepositoryError {
fn into_response(self) -> axum::response::Response {
(StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response()
}
}
#[async_trait]
impl FromRequestParts<AppState> for BoxRepository {
type Rejection = RepositoryError;
async fn from_request_parts(
_parts: &mut axum::http::request::Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
let repo = PgRepository::from_pool(&state.pool).await?;
Ok(repo
.map_err(mas_storage::RepositoryError::from_error)
.boxed())
}
}

View File

@@ -15,18 +15,18 @@
use axum::{extract::State, response::IntoResponse, Json};
use chrono::Duration;
use hyper::StatusCode;
use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType};
use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType, User};
use mas_storage::{
compat::{
add_compat_access_token, add_compat_refresh_token, get_compat_sso_login_by_token,
mark_compat_sso_login_as_exchanged, start_compat_session,
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
CompatSsoLoginRepository,
},
user::{add_user_password, lookup_user_by_username, lookup_user_password},
Clock,
user::{UserPasswordRepository, UserRepository},
BoxClock, BoxRepository, BoxRng, Clock,
};
use rand::{CryptoRng, RngCore};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds};
use sqlx::{PgPool, Postgres, Transaction};
use thiserror::Error;
use zeroize::Zeroizing;
@@ -137,6 +137,9 @@ pub enum RouteError {
#[error("user not found")]
UserNotFound,
#[error("session not found")]
SessionNotFound,
#[error("user has no password")]
NoPassword,
@@ -150,13 +153,12 @@ pub enum RouteError {
InvalidLoginToken,
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
match self {
Self::Internal(_) => MatrixError {
Self::Internal(_) | Self::SessionNotFound => MatrixError {
errcode: "M_UNKNOWN",
error: "Internal server error",
status: StatusCode::INTERNAL_SERVER_ERROR,
@@ -190,27 +192,37 @@ impl IntoResponse for RouteError {
#[tracing::instrument(skip_all, err)]
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
State(password_manager): State<PasswordManager>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
State(homeserver): State<MatrixHomeserver>,
Json(input): Json<RequestBody>,
) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?;
let session = match input.credentials {
let (session, user) = match input.credentials {
Credentials::Password {
identifier: Identifier::User { user },
password,
} => user_password_login(&password_manager, &mut txn, user, password).await?,
} => {
user_password_login(
&mut rng,
&clock,
&password_manager,
&mut repo,
user,
password,
)
.await?
}
Credentials::Token { token } => token_login(&mut txn, &clock, &token).await?,
Credentials::Token { token } => token_login(&mut repo, &clock, &token).await?,
_ => {
return Err(RouteError::Unsupported);
}
};
let user_id = format!("@{username}:{homeserver}", username = session.user.username);
let user_id = format!("@{username}:{homeserver}", username = user.username);
// If the client asked for a refreshable token, make it expire
let expires_in = if input.refresh_token {
@@ -221,33 +233,23 @@ pub(crate) async fn post(
};
let access_token = TokenType::CompatAccessToken.generate(&mut rng);
let access_token = add_compat_access_token(
&mut txn,
&mut rng,
&clock,
&session,
access_token,
expires_in,
)
.await?;
let access_token = repo
.compat_access_token()
.add(&mut rng, &clock, &session, access_token, expires_in)
.await?;
let refresh_token = if input.refresh_token {
let refresh_token = TokenType::CompatRefreshToken.generate(&mut rng);
let refresh_token = add_compat_refresh_token(
&mut txn,
&mut rng,
&clock,
&session,
&access_token,
refresh_token,
)
.await?;
let refresh_token = repo
.compat_refresh_token()
.add(&mut rng, &clock, &session, &access_token, refresh_token)
.await?;
Some(refresh_token.token)
} else {
None
};
txn.commit().await?;
repo.save().await?;
Ok(Json(ResponseBody {
access_token: access_token.token,
@@ -259,16 +261,18 @@ pub(crate) async fn post(
}
async fn token_login(
txn: &mut Transaction<'_, Postgres>,
clock: &Clock,
repo: &mut BoxRepository,
clock: &dyn Clock,
token: &str,
) -> Result<CompatSession, RouteError> {
let login = get_compat_sso_login_by_token(&mut *txn, token)
) -> Result<(CompatSession, User), RouteError> {
let login = repo
.compat_sso_login()
.find_by_token(token)
.await?
.ok_or(RouteError::InvalidLoginToken)?;
let now = clock.now();
match login.state {
let session_id = match login.state {
CompatSsoLoginState::Pending => {
tracing::error!(
compat_sso_login.id = %login.id,
@@ -277,49 +281,70 @@ async fn token_login(
return Err(RouteError::InvalidLoginToken);
}
CompatSsoLoginState::Fulfilled {
fulfilled_at: fullfilled_at,
fulfilled_at,
session_id,
..
} => {
if now > fullfilled_at + Duration::seconds(30) {
if now > fulfilled_at + Duration::seconds(30) {
return Err(RouteError::LoginTookTooLong);
}
session_id
}
CompatSsoLoginState::Exchanged { exchanged_at, .. } => {
CompatSsoLoginState::Exchanged {
exchanged_at,
session_id,
..
} => {
if now > exchanged_at + Duration::seconds(30) {
// TODO: log that session out
tracing::error!(
compat_sso_login.id = %login.id,
compat_session.id = %session_id,
"Login token exchanged a second time more than 30s after"
);
}
return Err(RouteError::InvalidLoginToken);
}
}
};
let login = mark_compat_sso_login_as_exchanged(&mut *txn, clock, login).await?;
let session = repo
.compat_session()
.lookup(session_id)
.await?
.ok_or(RouteError::SessionNotFound)?;
match login.state {
CompatSsoLoginState::Exchanged { session, .. } => Ok(session),
_ => unreachable!(),
}
let user = repo
.user()
.lookup(session.user_id)
.await?
.ok_or(RouteError::UserNotFound)?;
repo.compat_sso_login().exchange(clock, login).await?;
Ok((session, user))
}
async fn user_password_login(
mut rng: &mut (impl RngCore + CryptoRng + Send),
clock: &impl Clock,
password_manager: &PasswordManager,
txn: &mut Transaction<'_, Postgres>,
repo: &mut BoxRepository,
username: String,
password: String,
) -> Result<CompatSession, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
) -> Result<(CompatSession, User), RouteError> {
// Find the user
let user = lookup_user_by_username(&mut *txn, &username)
let user = repo
.user()
.find_by_username(&username)
.await?
.ok_or(RouteError::UserNotFound)?;
// Lookup its password
let user_password = lookup_user_password(&mut *txn, &user)
let user_password = repo
.user_password()
.active(&user)
.await?
.ok_or(RouteError::NoPassword)?;
@@ -338,21 +363,24 @@ async fn user_password_login(
if let Some((version, hashed_password)) = new_password_hash {
// Save the upgraded password if needed
add_user_password(
&mut *txn,
&mut rng,
&clock,
&user,
version,
hashed_password,
Some(user_password),
)
.await?;
repo.user_password()
.add(
&mut rng,
clock,
&user,
version,
hashed_password,
Some(&user_password),
)
.await?;
}
// Now that the user credentials have been verified, start a new compat session
let device = Device::generate(&mut rng);
let session = start_compat_session(&mut *txn, &mut rng, &clock, user, device).await?;
let session = repo
.compat_session()
.add(&mut rng, clock, &user, device)
.await?;
Ok(session)
Ok((session, user))
}

View File

@@ -29,10 +29,12 @@ use mas_axum_utils::{
use mas_data_model::Device;
use mas_keystore::Encrypter;
use mas_router::{CompatLoginSsoAction, PostAuthAction, Route};
use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id};
use mas_storage::{
compat::{CompatSessionRepository, CompatSsoLoginRepository},
BoxClock, BoxRepository, BoxRng, Clock,
};
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use ulid::Ulid;
#[derive(Serialize)]
@@ -50,19 +52,18 @@ pub struct Params {
}
pub async fn get(
State(pool): State<PgPool>,
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
State(templates): State<Templates>,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(id): Path<Ulid>,
Query(params): Query<Params>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let maybe_session = session_info.load_session(&mut conn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session {
session
@@ -80,20 +81,16 @@ pub async fn get(
return Ok((cookie_jar, url).into_response());
};
// TODO: make that more generic
if session
.user
.primary_email
.as_ref()
.and_then(|e| e.confirmed_at)
.is_none()
{
// TODO: make that more generic, check that the email has been confirmed
if session.user.primary_user_email_id.is_none() {
let destination = mas_router::AccountAddEmail::default()
.and_then(PostAuthAction::continue_compat_sso_login(id));
return Ok((cookie_jar, destination.go()).into_response());
}
let login = get_compat_sso_login_by_id(&mut conn, id)
let login = repo
.compat_sso_login()
.lookup(id)
.await?
.context("Could not find compat SSO login")?;
@@ -117,20 +114,19 @@ pub async fn get(
}
pub async fn post(
State(pool): State<PgPool>,
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
State(templates): State<Templates>,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(id): Path<Ulid>,
Query(params): Query<Params>,
Form(form): Form<ProtectedForm<()>>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?;
let (session_info, cookie_jar) = cookie_jar.session_info();
cookie_jar.verify_form(clock.now(), form)?;
cookie_jar.verify_form(&clock, form)?;
let maybe_session = session_info.load_session(&mut txn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session {
session
@@ -149,19 +145,15 @@ pub async fn post(
};
// TODO: make that more generic
if session
.user
.primary_email
.as_ref()
.and_then(|e| e.confirmed_at)
.is_none()
{
if session.user.primary_user_email_id.is_none() {
let destination = mas_router::AccountAddEmail::default()
.and_then(PostAuthAction::continue_compat_sso_login(id));
return Ok((cookie_jar, destination.go()).into_response());
}
let login = get_compat_sso_login_by_id(&mut txn, id)
let login = repo
.compat_sso_login()
.lookup(id)
.await?
.context("Could not find compat SSO login")?;
@@ -193,10 +185,16 @@ pub async fn post(
};
let device = Device::generate(&mut rng);
let _login =
fullfill_compat_sso_login(&mut txn, &mut rng, &clock, session.user, login, device).await?;
let compat_session = repo
.compat_session()
.add(&mut rng, &clock, &session.user, device)
.await?;
txn.commit().await?;
repo.compat_sso_login()
.fulfill(&clock, login, &compat_session)
.await?;
repo.save().await?;
Ok((cookie_jar, Redirect::to(redirect_uri.as_str())).into_response())
}

View File

@@ -19,11 +19,10 @@ use axum::{
};
use hyper::StatusCode;
use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
use mas_storage::compat::insert_compat_sso_login;
use mas_storage::{compat::CompatSsoLoginRepository, BoxClock, BoxRepository, BoxRng};
use rand::distributions::{Alphanumeric, DistString};
use serde::Deserialize;
use serde_with::serde;
use sqlx::PgPool;
use thiserror::Error;
use url::Url;
@@ -48,7 +47,7 @@ pub enum RouteError {
InvalidRedirectUrl,
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
@@ -56,14 +55,13 @@ impl IntoResponse for RouteError {
}
}
#[tracing::instrument(skip(pool, url_builder), err)]
pub async fn get(
State(pool): State<PgPool>,
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
State(url_builder): State<UrlBuilder>,
Query(params): Query<Params>,
) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
// Check the redirectUrl parameter
let redirect_url = params.redirect_url.ok_or(RouteError::MissingRedirectUrl)?;
let redirect_url = Url::parse(&redirect_url).map_err(|_| RouteError::InvalidRedirectUrl)?;
@@ -79,8 +77,10 @@ pub async fn get(
}
let token = Alphanumeric.sample_string(&mut rng, 32);
let mut conn = pool.acquire().await?;
let login = insert_compat_sso_login(&mut conn, &mut rng, &clock, token, redirect_url).await?;
let login = repo
.compat_sso_login()
.add(&mut rng, &clock, token, redirect_url)
.await?;
Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.id, params.action)))
}

View File

@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use axum::{extract::State, response::IntoResponse, Json, TypedHeader};
use axum::{response::IntoResponse, Json, TypedHeader};
use headers::{authorization::Bearer, Authorization};
use hyper::StatusCode;
use mas_data_model::TokenType;
use mas_storage::{compat::compat_logout, Clock};
use sqlx::PgPool;
use mas_storage::{
compat::{CompatAccessTokenRepository, CompatSessionRepository},
BoxClock, BoxRepository, Clock,
};
use thiserror::Error;
use super::MatrixError;
@@ -36,12 +38,9 @@ pub enum RouteError {
#[error("Invalid access token")]
InvalidAuthorization,
#[error("Logout failed")]
LogoutFailed,
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
@@ -56,7 +55,7 @@ impl IntoResponse for RouteError {
error: "Missing access token",
status: StatusCode::UNAUTHORIZED,
},
Self::InvalidAuthorization | Self::LogoutFailed | Self::TokenFormat(_) => MatrixError {
Self::InvalidAuthorization | Self::TokenFormat(_) => MatrixError {
errcode: "M_UNKNOWN_TOKEN",
error: "Invalid access token",
status: StatusCode::UNAUTHORIZED,
@@ -67,12 +66,10 @@ impl IntoResponse for RouteError {
}
pub(crate) async fn post(
State(pool): State<PgPool>,
clock: BoxClock,
mut repo: BoxRepository,
maybe_authorization: Option<TypedHeader<Authorization<Bearer>>>,
) -> Result<impl IntoResponse, RouteError> {
let clock = Clock::default();
let mut conn = pool.acquire().await?;
let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?;
let token = authorization.token();
@@ -82,9 +79,23 @@ pub(crate) async fn post(
return Err(RouteError::InvalidAuthorization);
}
if !compat_logout(&mut conn, &clock, token).await? {
return Err(RouteError::LogoutFailed);
}
let token = repo
.compat_access_token()
.find_by_token(token)
.await?
.filter(|t| t.is_valid(clock.now()))
.ok_or(RouteError::InvalidAuthorization)?;
let session = repo
.compat_session()
.lookup(token.session_id)
.await?
.filter(|s| s.is_valid())
.ok_or(RouteError::InvalidAuthorization)?;
repo.compat_session().finish(&clock, session).await?;
repo.save().await?;
Ok(Json(serde_json::json!({})))
}

View File

@@ -12,17 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use axum::{extract::State, response::IntoResponse, Json};
use axum::{response::IntoResponse, Json};
use chrono::Duration;
use hyper::StatusCode;
use mas_data_model::{TokenFormatError, TokenType};
use mas_storage::compat::{
add_compat_access_token, add_compat_refresh_token, consume_compat_refresh_token,
expire_compat_access_token, lookup_active_compat_refresh_token,
use mas_storage::{
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
BoxClock, BoxRepository, BoxRng, Clock,
};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DurationMilliSeconds};
use sqlx::PgPool;
use thiserror::Error;
use super::MatrixError;
@@ -40,17 +39,26 @@ pub enum RouteError {
#[error("invalid token")]
InvalidToken,
#[error("refresh token already consumed")]
RefreshTokenConsumed,
#[error("invalid session")]
InvalidSession,
#[error("unknown session")]
UnknownSession,
}
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
match self {
Self::Internal(_) => MatrixError {
Self::Internal(_) | Self::UnknownSession => MatrixError {
errcode: "M_UNKNOWN",
error: "Internal error",
status: StatusCode::INTERNAL_SERVER_ERROR,
},
Self::InvalidToken => MatrixError {
Self::InvalidToken | Self::InvalidSession | Self::RefreshTokenConsumed => MatrixError {
errcode: "M_UNKNOWN_TOKEN",
error: "Invalid refresh token",
status: StatusCode::UNAUTHORIZED,
@@ -60,8 +68,7 @@ impl IntoResponse for RouteError {
}
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl From<TokenFormatError> for RouteError {
fn from(_e: TokenFormatError) -> Self {
@@ -79,50 +86,79 @@ pub struct ResponseBody {
}
pub(crate) async fn post(
State(pool): State<PgPool>,
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
Json(input): Json<RequestBody>,
) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?;
let token_type = TokenType::check(&input.refresh_token)?;
if token_type != TokenType::CompatRefreshToken {
return Err(RouteError::InvalidToken);
}
let (refresh_token, access_token, session) =
lookup_active_compat_refresh_token(&mut txn, &input.refresh_token)
.await?
.ok_or(RouteError::InvalidToken)?;
let refresh_token = repo
.compat_refresh_token()
.find_by_token(&input.refresh_token)
.await?
.ok_or(RouteError::InvalidToken)?;
if !refresh_token.is_valid() {
return Err(RouteError::RefreshTokenConsumed);
}
let session = repo
.compat_session()
.lookup(refresh_token.session_id)
.await?
.ok_or(RouteError::UnknownSession)?;
if !session.is_valid() {
return Err(RouteError::InvalidSession);
}
let access_token = repo
.compat_access_token()
.lookup(refresh_token.access_token_id)
.await?
.filter(|t| t.is_valid(clock.now()));
let new_refresh_token_str = TokenType::CompatRefreshToken.generate(&mut rng);
let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng);
let expires_in = Duration::minutes(5);
let new_access_token = add_compat_access_token(
&mut txn,
&mut rng,
&clock,
&session,
new_access_token_str,
Some(expires_in),
)
.await?;
let new_refresh_token = add_compat_refresh_token(
&mut txn,
&mut rng,
&clock,
&session,
&new_access_token,
new_refresh_token_str,
)
.await?;
let new_access_token = repo
.compat_access_token()
.add(
&mut rng,
&clock,
&session,
new_access_token_str,
Some(expires_in),
)
.await?;
let new_refresh_token = repo
.compat_refresh_token()
.add(
&mut rng,
&clock,
&session,
&new_access_token,
new_refresh_token_str,
)
.await?;
consume_compat_refresh_token(&mut txn, &clock, refresh_token).await?;
expire_compat_access_token(&mut txn, &clock, access_token).await?;
repo.compat_refresh_token()
.consume(&clock, refresh_token)
.await?;
txn.commit().await?;
if let Some(access_token) = access_token {
repo.compat_access_token()
.expire(&clock, access_token)
.await?;
}
repo.save().await?;
Ok(Json(ResponseBody {
access_token: new_access_token.token,

View File

@@ -22,19 +22,19 @@ use axum::{
Json, TypedHeader,
};
use axum_extra::extract::PrivateCookieJar;
use futures_util::{StreamExt, TryStreamExt};
use futures_util::TryStreamExt;
use headers::{ContentType, HeaderValue};
use hyper::header::CACHE_CONTROL;
use mas_axum_utils::{FancyError, SessionInfoExt};
use mas_graphql::Schema;
use mas_keystore::Encrypter;
use sqlx::PgPool;
use mas_storage::BoxRepository;
use tokio::sync::Mutex;
use tracing::{info_span, Instrument};
#[must_use]
pub fn schema(pool: &PgPool) -> Schema {
pub fn schema() -> Schema {
mas_graphql::schema_builder()
.data(pool.clone())
.extension(Tracing)
.extension(ApolloTracing)
.finish()
@@ -58,8 +58,8 @@ fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span {
}
pub async fn post(
State(pool): State<PgPool>,
State(schema): State<Schema>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
content_type: Option<TypedHeader<ContentType>>,
body: BodyStream,
@@ -67,58 +67,46 @@ pub async fn post(
let content_type = content_type.map(|TypedHeader(h)| h.to_string());
let (session_info, _cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&pool).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let mut request = async_graphql::http::receive_batch_body(
let mut request = async_graphql::http::receive_body(
content_type,
body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
.into_async_read(),
MultipartOptions::default(),
)
.await?; // XXX: this should probably return another error response?
.await? // XXX: this should probably return another error response?
.data(Mutex::new(repo));
if let Some(session) = maybe_session {
request = request.data(session);
}
let response = match request {
async_graphql::BatchRequest::Single(request) => {
let span = span_for_graphql_request(&request);
let response = schema.execute(request).instrument(span).await;
async_graphql::BatchResponse::Single(response)
}
async_graphql::BatchRequest::Batch(requests) => async_graphql::BatchResponse::Batch(
futures_util::stream::iter(requests.into_iter())
.then(|request| {
let span = span_for_graphql_request(&request);
schema.execute(request).instrument(span)
})
.collect()
.await,
),
};
let span = span_for_graphql_request(&request);
let response = schema.execute(request).instrument(span).await;
let cache_control = response
.cache_control()
.cache_control
.value()
.and_then(|v| HeaderValue::from_str(&v).ok())
.map(|h| [(CACHE_CONTROL, h)]);
let headers = response.http_headers();
let headers = response.http_headers.clone();
Ok((headers, cache_control, Json(response)))
}
pub async fn get(
State(pool): State<PgPool>,
State(schema): State<Schema>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
RawQuery(query): RawQuery,
) -> Result<impl IntoResponse, FancyError> {
let (session_info, _cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&pool).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?;
let mut request =
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(Mutex::new(repo));
if let Some(session) = maybe_session {
request = request.data(session);

View File

@@ -1,4 +1,4 @@
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@ mod tests {
use super::*;
#[sqlx::test(migrator = "mas_storage::MIGRATOR")]
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_get_health(pool: PgPool) -> Result<(), anyhow::Error> {
let state = crate::test_state(pool).await?;
let app = crate::healthcheck_router().with_state(state);

View File

@@ -21,14 +21,17 @@
)]
#![warn(clippy::pedantic)]
#![allow(
clippy::unused_async // Some axum handlers need that
// Some axum handlers need that
clippy::unused_async,
// Because of how axum handlers work, we sometime have take many arguments
clippy::too_many_arguments,
)]
use std::{convert::Infallible, sync::Arc, time::Duration};
use axum::{
body::{Bytes, HttpBody},
extract::FromRef,
extract::{FromRef, FromRequestParts},
response::{Html, IntoResponse},
routing::{get, on, post, MethodFilter},
Router,
@@ -40,9 +43,9 @@ use mas_http::CorsLayerExt;
use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory;
use mas_router::{Route, UrlBuilder};
use mas_storage::{BoxClock, BoxRepository, BoxRng};
use mas_templates::{ErrorContext, Templates};
use passwords::PasswordManager;
use rand::SeedableRng;
use sqlx::PgPool;
use tower::util::AndThenLayer;
use tower_http::cors::{Any, CorsLayer};
@@ -94,7 +97,7 @@ where
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Clone + Send + Sync + 'static,
mas_graphql::Schema: FromRef<S>,
PgPool: FromRef<S>,
BoxRepository: FromRequestParts<S>,
Encrypter: FromRef<S>,
{
let mut router = Router::new().route(
@@ -116,6 +119,8 @@ where
S: Clone + Send + Sync + 'static,
Keystore: FromRef<S>,
UrlBuilder: FromRef<S>,
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,
{
Router::new()
.route(
@@ -152,9 +157,11 @@ where
Keystore: FromRef<S>,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
PgPool: FromRef<S>,
BoxRepository: FromRequestParts<S>,
Encrypter: FromRef<S>,
HttpClientFactory: FromRef<S>,
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,
{
// All those routes are API-like, with a common CORS layer
Router::new()
@@ -205,9 +212,11 @@ where
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Clone + Send + Sync + 'static,
UrlBuilder: FromRef<S>,
PgPool: FromRef<S>,
BoxRepository: FromRequestParts<S>,
MatrixHomeserver: FromRef<S>,
PasswordManager: FromRef<S>,
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,
{
Router::new()
.route(
@@ -248,13 +257,15 @@ where
S: Clone + Send + Sync + 'static,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
PgPool: FromRef<S>,
BoxRepository: FromRequestParts<S>,
Encrypter: FromRef<S>,
Templates: FromRef<S>,
Mailer: FromRef<S>,
Keystore: FromRef<S>,
HttpClientFactory: FromRef<S>,
PasswordManager: FromRef<S>,
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,
{
Router::new()
.route(
@@ -350,7 +361,7 @@ where
}
#[cfg(test)]
async fn test_state(pool: PgPool) -> Result<AppState, anyhow::Error> {
async fn test_state(pool: sqlx::PgPool) -> Result<AppState, anyhow::Error> {
use mas_email::MailTransport;
use crate::passwords::Hasher;
@@ -389,7 +400,7 @@ async fn test_state(pool: PgPool) -> Result<AppState, anyhow::Error> {
let policy_factory = Arc::new(policy_factory);
let graphql_schema = graphql_schema(&pool);
let graphql_schema = graphql_schema();
let http_client_factory = HttpClientFactory::new(10);
@@ -407,16 +418,3 @@ async fn test_state(pool: PgPool) -> Result<AppState, anyhow::Error> {
password_manager,
})
}
// XXX: that should be moved somewhere else
fn clock_and_rng() -> (mas_storage::Clock, rand_chacha::ChaChaRng) {
let clock = mas_storage::Clock::default();
// This rng is used to source the local rng
#[allow(clippy::disallowed_methods)]
let rng = rand::thread_rng();
let rng = rand_chacha::ChaChaRng::from_rng(rng).expect("Failed to seed RNG");
(clock, rng)
}

View File

@@ -25,13 +25,12 @@ use mas_data_model::{AuthorizationGrant, BrowserSession};
use mas_keystore::Encrypter;
use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route};
use mas_storage::oauth2::{
authorization_grant::{derive_session, fulfill_grant, get_grant_by_id},
consent::fetch_client_consent,
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository},
BoxClock, BoxRepository, BoxRng,
};
use mas_templates::Templates;
use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse};
use sqlx::{PgPool, Postgres, Transaction};
use thiserror::Error;
use ulid::Ulid;
@@ -69,8 +68,7 @@ impl IntoResponse for RouteError {
}
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
@@ -78,19 +76,21 @@ impl_from_error_for_route!(super::callback::IntoCallbackDestinationError);
impl_from_error_for_route!(super::callback::CallbackDestinationError);
pub(crate) async fn get(
rng: BoxRng,
clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(grant_id): Path<Ulid>,
) -> Result<Response, RouteError> {
let mut txn = pool.begin().await?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let grant = get_grant_by_id(&mut txn, grant_id)
let grant = repo
.oauth2_authorization_grant()
.lookup(grant_id)
.await?
.ok_or(RouteError::NotFound)?;
@@ -105,7 +105,7 @@ pub(crate) async fn get(
return Ok((cookie_jar, mas_router::Login::and_then(continue_grant).go()).into_response());
};
match complete(grant, session, &policy_factory, txn).await {
match complete(rng, clock, grant, session, &policy_factory, repo).await {
Ok(params) => {
let res = callback_destination.go(&templates, params).await?;
Ok((cookie_jar, res).into_response())
@@ -121,6 +121,7 @@ pub(crate) async fn get(
}
Err(GrantCompletionError::NotPending) => Err(RouteError::NotPending),
Err(GrantCompletionError::Internal(e)) => Err(RouteError::Internal(e)),
Err(e) => Err(RouteError::Internal(e.into())),
}
}
@@ -140,23 +141,25 @@ pub enum GrantCompletionError {
#[error("denied by the policy")]
PolicyViolation,
#[error("failed to load client")]
NoSuchClient,
}
impl_from_error_for_route!(GrantCompletionError: sqlx::Error);
impl_from_error_for_route!(GrantCompletionError: mas_storage::DatabaseError);
impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError);
impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError);
pub(crate) async fn complete(
mut rng: BoxRng,
clock: BoxClock,
grant: AuthorizationGrant,
browser_session: BrowserSession,
policy_factory: &PolicyFactory,
mut txn: Transaction<'_, Postgres>,
mut repo: BoxRepository,
) -> Result<AuthorizationResponse<Option<AccessTokenResponse>>, GrantCompletionError> {
let (clock, mut rng) = crate::clock_and_rng();
// Verify that the grant is in a pending stage
if !grant.stage.is_pending() {
return Err(GrantCompletionError::NotPending);
@@ -164,7 +167,7 @@ pub(crate) async fn complete(
// Check if the authentication is fresh enough
if !browser_session.was_authenticated_after(grant.max_auth_time()) {
txn.commit().await?;
repo.save().await?;
return Err(GrantCompletionError::RequiresReauth);
}
@@ -178,8 +181,16 @@ pub(crate) async fn complete(
return Err(GrantCompletionError::PolicyViolation);
}
let current_consent =
fetch_client_consent(&mut txn, &browser_session.user, &grant.client).await?;
let client = repo
.oauth2_client()
.lookup(grant.client_id)
.await?
.ok_or(GrantCompletionError::NoSuchClient)?;
let current_consent = repo
.oauth2_client()
.get_consent_for_user(&client, &browser_session.user)
.await?;
let lacks_consent = grant
.scope
@@ -188,14 +199,20 @@ pub(crate) async fn complete(
// Check if the client lacks consent *or* if consent was explicitely asked
if lacks_consent || grant.requires_consent {
txn.commit().await?;
repo.save().await?;
return Err(GrantCompletionError::RequiresConsent);
}
// All good, let's start the session
let session = derive_session(&mut txn, &mut rng, &clock, &grant, browser_session).await?;
let session = repo
.oauth2_session()
.create_from_grant(&mut rng, &clock, &grant, &browser_session)
.await?;
let grant = fulfill_grant(&mut txn, grant, session.clone()).await?;
let grant = repo
.oauth2_authorization_grant()
.fulfill(&clock, &session, grant)
.await?;
// Yep! Let's complete the auth now
let mut params = AuthorizationResponse::default();
@@ -213,6 +230,6 @@ pub(crate) async fn complete(
));
}
txn.commit().await?;
repo.save().await?;
Ok(params)
}

View File

@@ -25,8 +25,9 @@ use mas_data_model::{AuthorizationCode, Pkce};
use mas_keystore::Encrypter;
use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route};
use mas_storage::oauth2::{
authorization_grant::new_authorization_grant, client::lookup_client_by_client_id,
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
BoxClock, BoxRepository, BoxRng,
};
use mas_templates::Templates;
use oauth2_types::{
@@ -37,7 +38,6 @@ use oauth2_types::{
};
use rand::{distributions::Alphanumeric, Rng};
use serde::Deserialize;
use sqlx::PgPool;
use thiserror::Error;
use self::{callback::CallbackDestination, complete::GrantCompletionError};
@@ -89,8 +89,7 @@ impl IntoResponse for RouteError {
}
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(self::callback::CallbackDestinationError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);
@@ -131,17 +130,18 @@ fn resolve_response_mode(
#[allow(clippy::too_many_lines)]
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(params): Form<Params>,
) -> Result<Response, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?;
// First, figure out what client it is
let client = lookup_client_by_client_id(&mut txn, &params.auth.client_id)
let client = repo
.oauth2_client()
.find_by_client_id(&params.auth.client_id)
.await?
.ok_or(RouteError::ClientNotFound)?;
@@ -167,7 +167,7 @@ pub(crate) async fn get(
let templates = templates.clone();
let callback_destination = callback_destination.clone();
async move {
let maybe_session = session_info.load_session(&mut txn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let prompt = params.auth.prompt.as_deref().unwrap_or_default();
// Check if the request/request_uri/registration params are used. If so, reply
@@ -272,23 +272,23 @@ pub(crate) async fn get(
let requires_consent = prompt.contains(&Prompt::Consent);
let grant = new_authorization_grant(
&mut txn,
&mut rng,
&clock,
client,
redirect_uri.clone(),
params.auth.scope,
code,
params.auth.state.clone(),
params.auth.nonce,
params.auth.max_age,
None,
response_mode,
response_type.has_id_token(),
requires_consent,
)
.await?;
let grant = repo
.oauth2_authorization_grant()
.add(
&mut rng,
&clock,
&client,
redirect_uri.clone(),
params.auth.scope,
code,
params.auth.state.clone(),
params.auth.nonce,
params.auth.max_age,
response_mode,
response_type.has_id_token(),
requires_consent,
)
.await?;
let continue_grant = PostAuthAction::continue_grant(grant.id);
let res = match maybe_session {
@@ -299,7 +299,7 @@ pub(crate) async fn get(
}
None if prompt.contains(&Prompt::Create) => {
// Client asked for a registration, show the registration prompt
txn.commit().await?;
repo.save().await?;
mas_router::Register::and_then(continue_grant)
.go()
@@ -307,7 +307,7 @@ pub(crate) async fn get(
}
None => {
// Other cases where we don't have a session, ask for a login
txn.commit().await?;
repo.save().await?;
mas_router::Login::and_then(continue_grant)
.go()
@@ -320,7 +320,7 @@ pub(crate) async fn get(
|| prompt.contains(&Prompt::SelectAccount) =>
{
// TODO: better pages here
txn.commit().await?;
repo.save().await?;
mas_router::Reauth::and_then(continue_grant)
.go()
@@ -330,7 +330,15 @@ pub(crate) async fn get(
// Else, we immediately try to complete the authorization grant
Some(user_session) if prompt.contains(&Prompt::None) => {
// With prompt=none, we should get back to the client immediately
match self::complete::complete(grant, user_session, &policy_factory, txn).await
match self::complete::complete(
rng,
clock,
grant,
user_session,
&policy_factory,
repo,
)
.await
{
Ok(params) => callback_destination.go(&templates, params).await?,
Err(GrantCompletionError::RequiresConsent) => {
@@ -357,7 +365,10 @@ pub(crate) async fn get(
Err(GrantCompletionError::Internal(e)) => {
return Err(RouteError::Internal(e))
}
Err(e @ GrantCompletionError::NotPending) => {
Err(
e @ (GrantCompletionError::NotPending
| GrantCompletionError::NoSuchClient),
) => {
// This should never happen
return Err(RouteError::Internal(Box::new(e)));
}
@@ -366,7 +377,15 @@ pub(crate) async fn get(
Some(user_session) => {
let grant_id = grant.id;
// Else, we show the relevant reauth/consent page if necessary
match self::complete::complete(grant, user_session, &policy_factory, txn).await
match self::complete::complete(
rng,
clock,
grant,
user_session,
&policy_factory,
repo,
)
.await
{
Ok(params) => callback_destination.go(&templates, params).await?,
Err(
@@ -387,7 +406,10 @@ pub(crate) async fn get(
Err(GrantCompletionError::Internal(e)) => {
return Err(RouteError::Internal(e))
}
Err(e @ GrantCompletionError::NotPending) => {
Err(
e @ (GrantCompletionError::NotPending
| GrantCompletionError::NoSuchClient),
) => {
// This should never happen
return Err(RouteError::Internal(Box::new(e)));
}

View File

@@ -28,12 +28,11 @@ use mas_data_model::AuthorizationGrantStage;
use mas_keystore::Encrypter;
use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route};
use mas_storage::oauth2::{
authorization_grant::{get_grant_by_id, give_consent_to_grant},
consent::insert_client_consent,
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
BoxClock, BoxRepository, BoxRng,
};
use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates};
use sqlx::PgPool;
use thiserror::Error;
use ulid::Ulid;
@@ -55,11 +54,13 @@ pub enum RouteError {
#[error("Policy violation")]
PolicyViolation,
#[error("Failed to load client")]
NoSuchClient,
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
@@ -71,20 +72,21 @@ impl IntoResponse for RouteError {
}
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(grant_id): Path<Ulid>,
) -> Result<Response, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let grant = get_grant_by_id(&mut conn, grant_id)
let grant = repo
.oauth2_authorization_grant()
.lookup(grant_id)
.await?
.ok_or(RouteError::GrantNotFound)?;
@@ -93,7 +95,7 @@ pub(crate) async fn get(
}
if let Some(session) = maybe_session {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let mut policy = policy_factory.instantiate().await?;
let res = policy
@@ -124,22 +126,23 @@ pub(crate) async fn get(
}
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(grant_id): Path<Ulid>,
Form(form): Form<ProtectedForm<()>>,
) -> Result<Response, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?;
cookie_jar.verify_form(clock.now(), form)?;
cookie_jar.verify_form(&clock, form)?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let grant = get_grant_by_id(&mut txn, grant_id)
let grant = repo
.oauth2_authorization_grant()
.lookup(grant_id)
.await?
.ok_or(RouteError::GrantNotFound)?;
let next = PostAuthAction::continue_grant(grant_id);
@@ -160,6 +163,12 @@ pub(crate) async fn post(
return Err(RouteError::PolicyViolation);
}
let client = repo
.oauth2_client()
.lookup(grant.client_id)
.await?
.ok_or(RouteError::NoSuchClient)?;
// Do not consent for the "urn:matrix:org.matrix.msc2967.client:device:*" scope
let scope_without_device = grant
.scope
@@ -167,19 +176,21 @@ pub(crate) async fn post(
.filter(|s| !s.starts_with("urn:matrix:org.matrix.msc2967.client:device:"))
.cloned()
.collect();
insert_client_consent(
&mut txn,
&mut rng,
&clock,
&session.user,
&grant.client,
&scope_without_device,
)
.await?;
repo.oauth2_client()
.give_consent_for_user(
&mut rng,
&clock,
&client,
&session.user,
&scope_without_device,
)
.await?;
let _grant = give_consent_to_grant(&mut txn, grant).await?;
repo.oauth2_authorization_grant()
.give_consent(grant)
.await?;
txn.commit().await?;
repo.save().await?;
Ok((cookie_jar, next.go_next()).into_response())
}

View File

@@ -22,18 +22,16 @@ use mas_data_model::{TokenFormatError, TokenType};
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
use mas_keystore::Encrypter;
use mas_storage::{
compat::{lookup_active_compat_access_token, lookup_active_compat_refresh_token},
oauth2::{
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token,
},
Clock,
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository},
user::{BrowserSessionRepository, UserRepository},
BoxClock, BoxRepository, Clock,
};
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
requests::{IntrospectionRequest, IntrospectionResponse},
scope::ScopeToken,
};
use sqlx::PgPool;
use thiserror::Error;
use crate::impl_from_error_for_route;
@@ -97,8 +95,7 @@ impl IntoResponse for RouteError {
}
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl From<TokenFormatError> for RouteError {
fn from(_e: TokenFormatError) -> Self {
@@ -125,18 +122,17 @@ const API_SCOPE: ScopeToken = ScopeToken::from_static("urn:matrix:org.matrix.msc
#[allow(clippy::too_many_lines)]
pub(crate) async fn post(
clock: BoxClock,
State(http_client_factory): State<HttpClientFactory>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
State(encrypter): State<Encrypter>,
client_authorization: ClientAuthorization<IntrospectionRequest>,
) -> Result<impl IntoResponse, RouteError> {
let clock = Clock::default();
let mut conn = pool.acquire().await?;
let client = client_authorization
.credentials
.fetch(&mut conn)
.await?
.fetch(&mut repo)
.await
.unwrap()
.ok_or(RouteError::ClientNotFound)?;
let method = match &client.token_endpoint_auth_method {
@@ -167,48 +163,103 @@ pub(crate) async fn post(
let reply = match token_type {
TokenType::AccessToken => {
let (token, session) = lookup_active_access_token(&mut conn, token)
let token = repo
.oauth2_access_token()
.find_by_token(token)
.await?
.filter(|t| t.is_valid(clock.now()))
.ok_or(RouteError::UnknownToken)?;
let session = repo
.oauth2_session()
.lookup(token.session_id)
.await?
.filter(|s| s.is_valid())
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;
let browser_session = repo
.browser_session()
.lookup(session.user_session_id)
.await?
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;
IntrospectionResponse {
active: true,
scope: Some(session.scope),
client_id: Some(session.client.client_id),
username: Some(session.browser_session.user.username),
client_id: Some(session.client_id.to_string()),
username: Some(browser_session.user.username),
token_type: Some(OAuthTokenTypeHint::AccessToken),
exp: Some(token.expires_at),
iat: Some(token.created_at),
nbf: Some(token.created_at),
sub: Some(session.browser_session.user.sub),
sub: Some(browser_session.user.sub),
aud: None,
iss: None,
jti: None,
jti: Some(token.jti()),
}
}
TokenType::RefreshToken => {
let (token, session) = lookup_active_refresh_token(&mut conn, token)
let token = repo
.oauth2_refresh_token()
.find_by_token(token)
.await?
.filter(|t| t.is_valid())
.ok_or(RouteError::UnknownToken)?;
let session = repo
.oauth2_session()
.lookup(token.session_id)
.await?
.filter(|s| s.is_valid())
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;
let browser_session = repo
.browser_session()
.lookup(session.user_session_id)
.await?
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;
IntrospectionResponse {
active: true,
scope: Some(session.scope),
client_id: Some(session.client.client_id),
username: Some(session.browser_session.user.username),
client_id: Some(session.client_id.to_string()),
username: Some(browser_session.user.username),
token_type: Some(OAuthTokenTypeHint::RefreshToken),
exp: None,
iat: Some(token.created_at),
nbf: Some(token.created_at),
sub: Some(session.browser_session.user.sub),
sub: Some(browser_session.user.sub),
aud: None,
iss: None,
jti: None,
jti: Some(token.jti()),
}
}
TokenType::CompatAccessToken => {
let (token, session) = lookup_active_compat_access_token(&mut conn, &clock, token)
let access_token = repo
.compat_access_token()
.find_by_token(token)
.await?
.filter(|t| t.is_valid(clock.now()))
.ok_or(RouteError::UnknownToken)?;
let session = repo
.compat_session()
.lookup(access_token.session_id)
.await?
.filter(|s| s.is_valid())
.ok_or(RouteError::UnknownToken)?;
let user = repo
.user()
.lookup(session.user_id)
.await?
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;
let device_scope = session.device.to_scope_token();
@@ -218,22 +269,39 @@ pub(crate) async fn post(
active: true,
scope: Some(scope),
client_id: Some("legacy".into()),
username: Some(session.user.username),
username: Some(user.username),
token_type: Some(OAuthTokenTypeHint::AccessToken),
exp: token.expires_at,
iat: Some(token.created_at),
nbf: Some(token.created_at),
sub: Some(session.user.sub),
exp: access_token.expires_at,
iat: Some(access_token.created_at),
nbf: Some(access_token.created_at),
sub: Some(user.sub),
aud: None,
iss: None,
jti: None,
}
}
TokenType::CompatRefreshToken => {
let (refresh_token, _access_token, session) =
lookup_active_compat_refresh_token(&mut conn, token)
.await?
.ok_or(RouteError::UnknownToken)?;
let refresh_token = repo
.compat_refresh_token()
.find_by_token(token)
.await?
.filter(|t| t.is_valid())
.ok_or(RouteError::UnknownToken)?;
let session = repo
.compat_session()
.lookup(refresh_token.session_id)
.await?
.filter(|s| s.is_valid())
.ok_or(RouteError::UnknownToken)?;
let user = repo
.user()
.lookup(session.user_id)
.await?
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;
let device_scope = session.device.to_scope_token();
let scope = [API_SCOPE, device_scope].into_iter().collect();
@@ -242,12 +310,12 @@ pub(crate) async fn post(
active: true,
scope: Some(scope),
client_id: Some("legacy".into()),
username: Some(session.user.username),
username: Some(user.username),
token_type: Some(OAuthTokenTypeHint::RefreshToken),
exp: None,
iat: Some(refresh_token.created_at),
nbf: Some(refresh_token.created_at),
sub: Some(session.user.sub),
sub: Some(user.sub),
aud: None,
iss: None,
jti: None,

View File

@@ -19,7 +19,7 @@ use hyper::StatusCode;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_keystore::Encrypter;
use mas_policy::{PolicyFactory, Violation};
use mas_storage::oauth2::client::insert_client;
use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRepository, BoxRng};
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
registration::{
@@ -27,10 +27,8 @@ use oauth2_types::{
},
};
use rand::distributions::{Alphanumeric, DistString};
use sqlx::PgPool;
use thiserror::Error;
use tracing::info;
use ulid::Ulid;
use crate::impl_from_error_for_route;
@@ -49,7 +47,7 @@ pub(crate) enum RouteError {
PolicyDenied(Vec<Violation>),
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
@@ -107,12 +105,13 @@ impl IntoResponse for RouteError {
#[tracing::instrument(skip_all, err)]
pub(crate) async fn post(
State(pool): State<PgPool>,
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
State(policy_factory): State<Arc<PolicyFactory>>,
State(encrypter): State<Encrypter>,
Json(body): Json<ClientMetadata>,
) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
info!(?body, "Client registration");
// Validate the body
@@ -124,16 +123,6 @@ pub(crate) async fn post(
return Err(RouteError::PolicyDenied(res.violations));
}
// Contacts was checked by the policy
let contacts = metadata.contacts.as_deref().unwrap_or_default();
// Grab a txn
let mut txn = pool.begin().await?;
let now = clock.now();
// Let's generate a random client ID
let client_id = Ulid::from_datetime_with_source(now.into(), &mut rng);
let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method {
Some(
OAuthClientAuthenticationMethod::ClientSecretJwt
@@ -148,41 +137,42 @@ pub(crate) async fn post(
_ => (None, None),
};
insert_client(
&mut txn,
&mut rng,
&clock,
client_id,
metadata.redirect_uris(),
encrypted_client_secret.as_deref(),
//&metadata.response_types(),
metadata.grant_types(),
contacts,
metadata
.client_name
.as_ref()
.map(|l| l.non_localized().as_ref()),
metadata.logo_uri.as_ref().map(Localized::non_localized),
metadata.client_uri.as_ref().map(Localized::non_localized),
metadata.policy_uri.as_ref().map(Localized::non_localized),
metadata.tos_uri.as_ref().map(Localized::non_localized),
metadata.jwks_uri.as_ref(),
metadata.jwks.as_ref(),
// XXX: those might not be right, should be function calls
metadata.id_token_signed_response_alg.as_ref(),
metadata.userinfo_signed_response_alg.as_ref(),
metadata.token_endpoint_auth_method.as_ref(),
metadata.token_endpoint_auth_signing_alg.as_ref(),
metadata.initiate_login_uri.as_ref(),
)
.await?;
let client = repo
.oauth2_client()
.add(
&mut rng,
&clock,
metadata.redirect_uris().to_vec(),
encrypted_client_secret,
//&metadata.response_types(),
metadata.grant_types().to_vec(),
metadata.contacts.clone().unwrap_or_default(),
metadata
.client_name
.clone()
.map(Localized::to_non_localized),
metadata.logo_uri.clone().map(Localized::to_non_localized),
metadata.client_uri.clone().map(Localized::to_non_localized),
metadata.policy_uri.clone().map(Localized::to_non_localized),
metadata.tos_uri.clone().map(Localized::to_non_localized),
metadata.jwks_uri.clone(),
metadata.jwks.clone(),
// XXX: those might not be right, should be function calls
metadata.id_token_signed_response_alg.clone(),
metadata.userinfo_signed_response_alg.clone(),
metadata.token_endpoint_auth_method.clone(),
metadata.token_endpoint_auth_signing_alg.clone(),
metadata.initiate_login_uri.clone(),
)
.await?;
txn.commit().await?;
repo.save().await?;
let response = ClientRegistrationResponse {
client_id: client_id.to_string(),
client_id: client.client_id,
client_secret,
client_id_issued_at: Some(now),
// XXX: we should have a `created_at` field on the clients
client_id_issued_at: Some(client.id.datetime().into()),
client_secret_expires_at: None,
};

View File

@@ -1,4 +1,4 @@
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -31,11 +31,13 @@ use mas_jose::{
};
use mas_keystore::{Encrypter, Keystore};
use mas_router::UrlBuilder;
use mas_storage::oauth2::{
access_token::{add_access_token, revoke_access_token},
authorization_grant::{exchange_grant, lookup_grant_by_code},
end_oauth_session,
refresh_token::{add_refresh_token, consume_refresh_token, lookup_active_refresh_token},
use mas_storage::{
oauth2::{
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository,
OAuth2RefreshTokenRepository, OAuth2SessionRepository,
},
user::BrowserSessionRepository,
BoxClock, BoxRepository, BoxRng, Clock,
};
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
@@ -47,7 +49,6 @@ use oauth2_types::{
};
use serde::Serialize;
use serde_with::{serde_as, skip_serializing_none};
use sqlx::{PgPool, Postgres, Transaction};
use thiserror::Error;
use tracing::debug;
use url::Url;
@@ -102,12 +103,21 @@ pub(crate) enum RouteError {
#[error("no suitable key found for signing")]
InvalidSigningKey,
#[error("failed to load browser session")]
NoSuchBrowserSession,
#[error("failed to load oauth session")]
NoSuchOAuthSession,
}
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
match self {
Self::Internal(_) | Self::InvalidSigningKey => (
Self::Internal(_)
| Self::InvalidSigningKey
| Self::NoSuchBrowserSession
| Self::NoSuchOAuthSession => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ClientError::from(ClientErrorCode::ServerError)),
),
@@ -139,8 +149,7 @@ impl IntoResponse for RouteError {
}
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_keystore::WrongAlgorithmError);
impl_from_error_for_route!(mas_jose::claims::ClaimError);
impl_from_error_for_route!(mas_jose::claims::TokenHashError);
@@ -148,18 +157,18 @@ impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError);
#[tracing::instrument(skip_all, err)]
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
State(http_client_factory): State<HttpClientFactory>,
State(key_store): State<Keystore>,
State(url_builder): State<UrlBuilder>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
State(encrypter): State<Encrypter>,
client_authorization: ClientAuthorization<AccessTokenRequest>,
) -> Result<impl IntoResponse, RouteError> {
let mut txn = pool.begin().await?;
let client = client_authorization
.credentials
.fetch(&mut txn)
.fetch(&mut repo)
.await?
.ok_or(RouteError::ClientNotFound)?;
@@ -175,18 +184,29 @@ pub(crate) async fn post(
let form = client_authorization.form.ok_or(RouteError::BadRequest)?;
let reply = match form {
let (reply, repo) = match form {
AccessTokenRequest::AuthorizationCode(grant) => {
authorization_code_grant(&grant, &client, &key_store, &url_builder, txn).await?
authorization_code_grant(
&mut rng,
&clock,
&grant,
&client,
&key_store,
&url_builder,
repo,
)
.await?
}
AccessTokenRequest::RefreshToken(grant) => {
refresh_token_grant(&grant, &client, txn).await?
refresh_token_grant(&mut rng, &clock, &grant, &client, repo).await?
}
_ => {
return Err(RouteError::InvalidGrant);
}
};
repo.save().await?;
let mut headers = HeaderMap::new();
headers.typed_insert(CacheControl::new().with_no_store());
headers.typed_insert(Pragma::no_cache());
@@ -196,23 +216,23 @@ pub(crate) async fn post(
#[allow(clippy::too_many_lines)]
async fn authorization_code_grant(
mut rng: &mut BoxRng,
clock: &impl Clock,
grant: &AuthorizationCodeGrant,
client: &Client,
key_store: &Keystore,
url_builder: &UrlBuilder,
mut txn: Transaction<'_, Postgres>,
) -> Result<AccessTokenResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
// TODO: there is a bunch of unnecessary cloning here
// TODO: handle "not found" cases
let authz_grant = lookup_grant_by_code(&mut txn, &grant.code)
mut repo: BoxRepository,
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
let authz_grant = repo
.oauth2_authorization_grant()
.find_by_code(&grant.code)
.await?
.ok_or(RouteError::GrantNotFound)?;
let now = clock.now();
let session = match authz_grant.stage {
let session_id = match authz_grant.stage {
AuthorizationGrantStage::Cancelled { cancelled_at } => {
debug!(%cancelled_at, "Authorization grant was cancelled");
return Err(RouteError::InvalidGrant);
@@ -220,15 +240,20 @@ async fn authorization_code_grant(
AuthorizationGrantStage::Exchanged {
exchanged_at,
fulfilled_at,
session,
session_id,
} => {
debug!(%exchanged_at, %fulfilled_at, "Authorization code was already exchanged");
// Ending the session if the token was already exchanged more than 20s ago
if now - exchanged_at > Duration::seconds(20) {
debug!("Ending potentially compromised session");
end_oauth_session(&mut txn, &clock, session).await?;
txn.commit().await?;
let session = repo
.oauth2_session()
.lookup(session_id)
.await?
.ok_or(RouteError::NoSuchOAuthSession)?;
repo.oauth2_session().finish(clock, session).await?;
repo.save().await?;
}
return Err(RouteError::InvalidGrant);
@@ -238,7 +263,7 @@ async fn authorization_code_grant(
return Err(RouteError::InvalidGrant);
}
AuthorizationGrantStage::Fulfilled {
ref session,
session_id,
fulfilled_at,
} => {
if now - fulfilled_at > Duration::minutes(10) {
@@ -246,14 +271,20 @@ async fn authorization_code_grant(
return Err(RouteError::InvalidGrant);
}
session
session_id
}
};
let session = repo
.oauth2_session()
.lookup(session_id)
.await?
.ok_or(RouteError::NoSuchOAuthSession)?;
// This should never happen, since we looked up in the database using the code
let code = authz_grant.code.as_ref().ok_or(RouteError::InvalidGrant)?;
if client.client_id != session.client.client_id {
if client.id != session.client_id {
return Err(RouteError::UnauthorizedClient);
}
@@ -267,31 +298,25 @@ async fn authorization_code_grant(
}
};
let browser_session = &session.browser_session;
let browser_session = repo
.browser_session()
.lookup(session.user_session_id)
.await?
.ok_or(RouteError::NoSuchBrowserSession)?;
let ttl = Duration::minutes(5);
let access_token_str = TokenType::AccessToken.generate(&mut rng);
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
let access_token = add_access_token(
&mut txn,
&mut rng,
&clock,
session,
access_token_str.clone(),
ttl,
)
.await?;
let access_token = repo
.oauth2_access_token()
.add(&mut rng, clock, &session, access_token_str, ttl)
.await?;
let _refresh_token = add_refresh_token(
&mut txn,
&mut rng,
&clock,
session,
access_token,
refresh_token_str.clone(),
)
.await?;
let refresh_token = repo
.oauth2_refresh_token()
.add(&mut rng, clock, &session, &access_token, refresh_token_str)
.await?;
let id_token = if session.scope.contains(&scope::OPENID) {
let mut claims = HashMap::new();
@@ -317,7 +342,7 @@ async fn authorization_code_grant(
.signing_key_for_algorithm(&alg)
.ok_or(RouteError::InvalidSigningKey)?;
claims::AT_HASH.insert(&mut claims, hash_token(&alg, &access_token_str)?)?;
claims::AT_HASH.insert(&mut claims, hash_token(&alg, &access_token.access_token)?)?;
claims::C_HASH.insert(&mut claims, hash_token(&alg, &grant.code)?)?;
let signer = key.params().signing_key_for_alg(&alg)?;
@@ -330,34 +355,46 @@ async fn authorization_code_grant(
None
};
let mut params = AccessTokenResponse::new(access_token_str)
let mut params = AccessTokenResponse::new(access_token.access_token)
.with_expires_in(ttl)
.with_refresh_token(refresh_token_str)
.with_refresh_token(refresh_token.refresh_token)
.with_scope(session.scope.clone());
if let Some(id_token) = id_token {
params = params.with_id_token(id_token);
}
exchange_grant(&mut txn, &clock, authz_grant).await?;
repo.oauth2_authorization_grant()
.exchange(clock, authz_grant)
.await?;
txn.commit().await?;
Ok(params)
Ok((params, repo))
}
async fn refresh_token_grant(
mut rng: &mut BoxRng,
clock: &impl Clock,
grant: &RefreshTokenGrant,
client: &Client,
mut txn: Transaction<'_, Postgres>,
) -> Result<AccessTokenResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
let (refresh_token, session) = lookup_active_refresh_token(&mut txn, &grant.refresh_token)
mut repo: BoxRepository,
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
let refresh_token = repo
.oauth2_refresh_token()
.find_by_token(&grant.refresh_token)
.await?
.ok_or(RouteError::InvalidGrant)?;
if client.client_id != session.client.client_id {
let session = repo
.oauth2_session()
.lookup(refresh_token.session_id)
.await?
.ok_or(RouteError::NoSuchOAuthSession)?;
if !refresh_token.is_valid() || !session.is_valid() {
return Err(RouteError::InvalidGrant);
}
if client.id != session.client_id {
// As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
return Err(RouteError::InvalidGrant);
}
@@ -366,30 +403,34 @@ async fn refresh_token_grant(
let access_token_str = TokenType::AccessToken.generate(&mut rng);
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
let new_access_token = add_access_token(
&mut txn,
&mut rng,
&clock,
&session,
access_token_str.clone(),
ttl,
)
.await?;
let new_access_token = repo
.oauth2_access_token()
.add(&mut rng, clock, &session, access_token_str.clone(), ttl)
.await?;
let new_refresh_token = add_refresh_token(
&mut txn,
&mut rng,
&clock,
&session,
new_access_token,
refresh_token_str,
)
.await?;
let new_refresh_token = repo
.oauth2_refresh_token()
.add(
&mut rng,
clock,
&session,
&new_access_token,
refresh_token_str,
)
.await?;
consume_refresh_token(&mut txn, &clock, &refresh_token).await?;
let refresh_token = repo
.oauth2_refresh_token()
.consume(clock, refresh_token)
.await?;
if let Some(access_token) = refresh_token.access_token {
revoke_access_token(&mut txn, &clock, access_token).await?;
if let Some(access_token_id) = refresh_token.access_token_id {
let access_token = repo.oauth2_access_token().lookup(access_token_id).await?;
if let Some(access_token) = access_token {
repo.oauth2_access_token()
.revoke(clock, access_token)
.await?;
}
}
let params = AccessTokenResponse::new(access_token_str)
@@ -397,7 +438,5 @@ async fn refresh_token_grant(
.with_refresh_token(new_refresh_token.refresh_token)
.with_scope(session.scope);
txn.commit().await?;
Ok(params)
Ok((params, repo))
}

View File

@@ -28,10 +28,14 @@ use mas_jose::{
};
use mas_keystore::Keystore;
use mas_router::UrlBuilder;
use mas_storage::{
oauth2::OAuth2ClientRepository,
user::{BrowserSessionRepository, UserEmailRepository},
BoxClock, BoxRepository, BoxRng,
};
use oauth2_types::scope;
use serde::Serialize;
use serde_with::skip_serializing_none;
use sqlx::PgPool;
use thiserror::Error;
use crate::impl_from_error_for_route;
@@ -59,20 +63,31 @@ pub enum RouteError {
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("failed to authenticate")]
AuthorizationVerificationError(#[from] AuthorizationVerificationError),
AuthorizationVerificationError(
#[from] AuthorizationVerificationError<mas_storage::RepositoryError>,
),
#[error("no suitable key found for signing")]
InvalidSigningKey,
#[error("failed to load client")]
NoSuchClient,
#[error("failed to load browser session")]
NoSuchBrowserSession,
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_keystore::WrongAlgorithmError);
impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
match self {
Self::Internal(_) | Self::InvalidSigningKey => {
Self::Internal(_)
| Self::InvalidSigningKey
| Self::NoSuchClient
| Self::NoSuchBrowserSession => {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
}
Self::AuthorizationVerificationError(_e) => StatusCode::UNAUTHORIZED.into_response(),
@@ -81,32 +96,43 @@ impl IntoResponse for RouteError {
}
pub async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(url_builder): State<UrlBuilder>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
State(key_store): State<Keystore>,
user_authorization: UserAuthorization,
) -> Result<Response, RouteError> {
let (_clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?;
let session = user_authorization.protected(&mut repo, &clock).await?;
let session = user_authorization.protected(&mut conn).await?;
let browser_session = repo
.browser_session()
.lookup(session.user_session_id)
.await?
.ok_or(RouteError::NoSuchBrowserSession)?;
let user = session.browser_session.user;
let mut user_info = UserInfo {
sub: user.sub,
username: user.username,
email: None,
email_verified: None,
let user = browser_session.user;
let user_email = if session.scope.contains(&scope::EMAIL) {
repo.user_email().get_primary(&user).await?
} else {
None
};
if session.scope.contains(&scope::EMAIL) {
if let Some(email) = user.primary_email {
user_info.email_verified = Some(email.confirmed_at.is_some());
user_info.email = Some(email.email);
}
}
let user_info = UserInfo {
sub: user.sub.clone(),
username: user.username.clone(),
email_verified: user_email.as_ref().map(|u| u.confirmed_at.is_some()),
email: user_email.map(|u| u.email),
};
if let Some(alg) = session.client.userinfo_signed_response_alg {
let client = repo
.oauth2_client()
.lookup(session.client_id)
.await?
.ok_or(RouteError::NoSuchClient)?;
if let Some(alg) = client.userinfo_signed_response_alg {
let key = key_store
.signing_key_for_algorithm(&alg)
.ok_or(RouteError::InvalidSigningKey)?;
@@ -117,7 +143,7 @@ pub async fn get(
let user_info = SignedUserInfo {
iss: url_builder.oidc_issuer().to_string(),
aud: session.client.client_id,
aud: client.client_id,
user_info,
};

View File

@@ -71,7 +71,7 @@ impl PasswordManager {
/// # Errors
///
/// Returns an error if the hashing failed
#[tracing::instrument(skip_all)]
#[tracing::instrument(name = "passwords.hash", skip_all)]
pub async fn hash<R: CryptoRng + RngCore + Send>(
&self,
rng: R,
@@ -82,13 +82,16 @@ impl PasswordManager {
let rng = rand_chacha::ChaChaRng::from_rng(rng)?;
let hashers = self.hashers.clone();
let default_hasher_version = self.default_hasher;
let span = tracing::Span::current();
let hashed = tokio::task::spawn_blocking(move || {
let default_hasher = hashers
.get(&default_hasher_version)
.context("Default hasher not found")?;
span.in_scope(move || {
let default_hasher = hashers
.get(&default_hasher_version)
.context("Default hasher not found")?;
default_hasher.hash_blocking(rng, &password)
default_hasher.hash_blocking(rng, &password)
})
})
.await??;
@@ -100,7 +103,7 @@ impl PasswordManager {
/// # Errors
///
/// Returns an error if the password hash verification failed
#[tracing::instrument(skip_all, fields(%scheme))]
#[tracing::instrument(name = "passwords.verify", skip_all, fields(%scheme))]
pub async fn verify(
&self,
scheme: SchemeVersion,
@@ -108,10 +111,13 @@ impl PasswordManager {
hashed_password: String,
) -> Result<(), anyhow::Error> {
let hashers = self.hashers.clone();
let span = tracing::Span::current();
tokio::task::spawn_blocking(move || {
let hasher = hashers.get(&scheme).context("Hashing scheme not found")?;
hasher.verify_blocking(&hashed_password, &password)
span.in_scope(move || {
let hasher = hashers.get(&scheme).context("Hashing scheme not found")?;
hasher.verify_blocking(&hashed_password, &password)
})
})
.await??;
@@ -124,7 +130,7 @@ impl PasswordManager {
/// # Errors
///
/// Returns an error if the password hash verification failed
#[tracing::instrument(skip_all, fields(%scheme))]
#[tracing::instrument(name = "passwords.verify_and_upgrade", skip_all, fields(%scheme))]
pub async fn verify_and_upgrade<R: CryptoRng + RngCore + Send>(
&self,
rng: R,

View File

@@ -22,8 +22,10 @@ use mas_axum_utils::http_client_factory::HttpClientFactory;
use mas_keystore::Encrypter;
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
use mas_router::UrlBuilder;
use mas_storage::upstream_oauth2::lookup_provider;
use sqlx::PgPool;
use mas_storage::{
upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository},
BoxClock, BoxRepository, BoxRng,
};
use thiserror::Error;
use ulid::Ulid;
@@ -39,11 +41,10 @@ pub(crate) enum RouteError {
Internal(Box<dyn std::error::Error>),
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_http::ClientInitError);
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
@@ -55,18 +56,18 @@ impl IntoResponse for RouteError {
}
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(http_client_factory): State<HttpClientFactory>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
State(url_builder): State<UrlBuilder>,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(provider_id): Path<Ulid>,
Query(query): Query<OptionalPostAuthAction>,
) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?;
let provider = lookup_provider(&mut txn, provider_id)
let provider = repo
.upstream_oauth_provider()
.lookup(provider_id)
.await?
.ok_or(RouteError::ProviderNotFound)?;
@@ -95,22 +96,23 @@ pub(crate) async fn get(
&mut rng,
)?;
let session = mas_storage::upstream_oauth2::add_session(
&mut txn,
&mut rng,
&clock,
&provider,
data.state.clone(),
data.code_challenge_verifier,
data.nonce,
)
.await?;
let session = repo
.upstream_oauth_session()
.add(
&mut rng,
&clock,
&provider,
data.state.clone(),
data.code_challenge_verifier,
data.nonce,
)
.await?;
let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar)
.add(session.id, provider.id, data.state, query.post_auth_action)
.save(cookie_jar, clock.now());
.save(cookie_jar, &clock);
txn.commit().await?;
repo.save().await?;
Ok((cookie_jar, Redirect::temporary(url.as_str())))
}

View File

@@ -25,12 +25,15 @@ use mas_oidc_client::requests::{
authorization_code::AuthorizationValidationData, jose::JwtVerificationData,
};
use mas_router::{Route, UrlBuilder};
use mas_storage::upstream_oauth2::{
add_link, complete_session, lookup_link_by_subject, lookup_session,
use mas_storage::{
upstream_oauth2::{
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
UpstreamOAuthSessionRepository,
},
BoxClock, BoxRepository, BoxRng, Clock,
};
use oauth2_types::errors::ClientErrorCode;
use serde::Deserialize;
use sqlx::PgPool;
use thiserror::Error;
use ulid::Ulid;
@@ -64,6 +67,9 @@ pub(crate) enum RouteError {
#[error("Session not found")]
SessionNotFound,
#[error("Provider not found")]
ProviderNotFound,
#[error("Provider mismatch")]
ProviderMismatch,
@@ -92,9 +98,8 @@ pub(crate) enum RouteError {
Internal(Box<dyn std::error::Error>),
}
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_http::ClientInitError);
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
impl_from_error_for_route!(mas_oidc_client::error::JwksError);
impl_from_error_for_route!(mas_oidc_client::error::TokenAuthorizationCodeError);
@@ -104,6 +109,7 @@ impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
match self {
Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
@@ -113,8 +119,10 @@ impl IntoResponse for RouteError {
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(http_client_factory): State<HttpClientFactory>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
State(url_builder): State<UrlBuilder>,
State(encrypter): State<Encrypter>,
State(keystore): State<Keystore>,
@@ -122,30 +130,34 @@ pub(crate) async fn get(
Path(provider_id): Path<Ulid>,
Query(params): Query<QueryParams>,
) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?;
let provider = repo
.upstream_oauth_provider()
.lookup(provider_id)
.await?
.ok_or(RouteError::ProviderNotFound)?;
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
let (session_id, _post_auth_action) = sessions_cookie
.find_session(provider_id, &params.state)
.map_err(|_| RouteError::MissingCookie)?;
let (provider, session) = lookup_session(&mut txn, session_id)
let session = repo
.upstream_oauth_session()
.lookup(session_id)
.await?
.ok_or(RouteError::SessionNotFound)?;
if provider.id != provider_id {
if provider.id != session.provider_id {
// The provider in the session cookie should match the one from the URL
return Err(RouteError::ProviderMismatch);
}
if params.state != session.state {
if params.state != session.state_str {
// The state in the session cookie should match the one from the params
return Err(RouteError::StateMismatch);
}
if session.completed() {
if !session.is_pending() {
// The session was already completed
return Err(RouteError::AlreadyCompleted);
}
@@ -194,7 +206,7 @@ pub(crate) async fn get(
// TODO: all that should be borrowed
let validation_data = AuthorizationValidationData {
state: session.state.clone(),
state: session.state_str.clone(),
nonce: session.nonce.clone(),
code_challenge_verifier: session.code_challenge_verifier.clone(),
redirect_uri,
@@ -231,20 +243,29 @@ pub(crate) async fn get(
let subject = mas_jose::claims::SUB.extract_required(&mut id_token)?;
// Look for an existing link
let maybe_link = lookup_link_by_subject(&mut txn, &provider, &subject).await?;
let maybe_link = repo
.upstream_oauth_link()
.find_by_subject(&provider, &subject)
.await?;
let link = if let Some(link) = maybe_link {
link
} else {
add_link(&mut txn, &mut rng, &clock, &provider, subject).await?
repo.upstream_oauth_link()
.add(&mut rng, &clock, &provider, subject)
.await?
};
let session = complete_session(&mut txn, &clock, session, &link, response.id_token).await?;
let session = repo
.upstream_oauth_session()
.complete_with_link(&clock, session, &link, response.id_token)
.await?;
let cookie_jar = sessions_cookie
.add_link_to_session(session.id, link.id)?
.save(cookie_jar, clock.now());
.save(cookie_jar, &clock);
txn.commit().await?;
repo.save().await?;
Ok((
cookie_jar,

View File

@@ -18,6 +18,7 @@ use axum_extra::extract::{cookie::Cookie, PrivateCookieJar};
use chrono::{DateTime, Duration, NaiveDateTime, Utc};
use mas_axum_utils::CookieExt;
use mas_router::PostAuthAction;
use mas_storage::Clock;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use time::OffsetDateTime;
@@ -65,11 +66,11 @@ impl UpstreamSessions {
}
/// Save the upstreams sessions to the cookie jar
pub fn save<K>(
self,
cookie_jar: PrivateCookieJar<K>,
now: DateTime<Utc>,
) -> PrivateCookieJar<K> {
pub fn save<K, C>(self, cookie_jar: PrivateCookieJar<K>, clock: &C) -> PrivateCookieJar<K>
where
C: Clock,
{
let now = clock.now();
let this = self.expire(now);
let mut cookie = Cookie::named(COOKIE_NAME).encode(&this);
cookie.set_path("/");

View File

@@ -25,17 +25,15 @@ use mas_axum_utils::{
};
use mas_keystore::Encrypter;
use mas_storage::{
upstream_oauth2::{
associate_link_to_user, consume_session, lookup_link, lookup_session_on_link,
},
user::{add_user, authenticate_session_with_upstream, lookup_user, start_session},
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
user::{BrowserSessionRepository, UserRepository},
BoxClock, BoxRepository, BoxRng,
};
use mas_templates::{
EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister,
UpstreamSuggestLink,
};
use serde::Deserialize;
use sqlx::PgPool;
use thiserror::Error;
use ulid::Ulid;
@@ -52,6 +50,10 @@ pub(crate) enum RouteError {
#[error("Session not found")]
SessionNotFound,
/// Couldn't find the user
#[error("User not found")]
UserNotFound,
/// Session was already consumed
#[error("Session already consumed")]
SessionConsumed,
@@ -66,11 +68,10 @@ pub(crate) enum RouteError {
Internal(Box<dyn std::error::Error>),
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
@@ -91,48 +92,60 @@ pub(crate) enum FormData {
}
pub(crate) async fn get(
State(pool): State<PgPool>,
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
State(templates): State<Templates>,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(link_id): Path<Ulid>,
) -> Result<impl IntoResponse, RouteError> {
let mut txn = pool.begin().await?;
let (clock, mut rng) = crate::clock_and_rng();
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
let (session_id, _post_auth_action) = sessions_cookie
.lookup_link(link_id)
.map_err(|_| RouteError::MissingCookie)?;
let link = lookup_link(&mut txn, link_id)
let link = repo
.upstream_oauth_link()
.lookup(link_id)
.await?
.ok_or(RouteError::LinkNotFound)?;
// This checks that we're in a browser session which is allowed to consume this
// link: the upstream auth session should have been started in this browser.
let upstream_session = lookup_session_on_link(&mut txn, &link, session_id)
let upstream_session = repo
.upstream_oauth_session()
.lookup(session_id)
.await?
.ok_or(RouteError::SessionNotFound)?;
if upstream_session.consumed() {
// This checks that we're in a browser session which is allowed to consume this
// link: the upstream auth session should have been started in this browser.
if upstream_session.link_id() != Some(link.id) {
return Err(RouteError::SessionNotFound);
}
if upstream_session.is_consumed() {
return Err(RouteError::SessionConsumed);
}
let (user_session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let maybe_user_session = user_session_info.load_session(&mut txn).await?;
let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let maybe_user_session = user_session_info.load_session(&mut repo).await?;
let render = match (maybe_user_session, link.user_id) {
(Some(mut session), Some(user_id)) if session.user.id == user_id => {
(Some(session), Some(user_id)) if session.user.id == user_id => {
// Session already linked, and link matches the currently logged
// user. Mark the session as consumed and renew the authentication.
consume_session(&mut txn, &clock, upstream_session).await?;
authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link)
repo.upstream_oauth_session()
.consume(&clock, upstream_session)
.await?;
let session = repo
.browser_session()
.authenticate_with_upstream(&mut rng, &clock, session, &link)
.await?;
cookie_jar = cookie_jar.set_session(&session);
txn.commit().await?;
repo.save().await?;
let ctx = EmptyContext
.with_session(session)
@@ -147,7 +160,11 @@ pub(crate) async fn get(
// Session already linked, but link doesn't match the currently
// logged user. Suggest logging out of the current user
// and logging in with the new one
let user = lookup_user(&mut txn, user_id).await?;
let user = repo
.user()
.lookup(user_id)
.await?
.ok_or(RouteError::UserNotFound)?;
let ctx = UpstreamExistingLinkContext::new(user)
.with_session(user_session)
@@ -167,7 +184,11 @@ pub(crate) async fn get(
(None, Some(user_id)) => {
// Session linked, but user not logged in: do the login
let user = lookup_user(&mut txn, user_id).await?;
let user = repo
.user()
.lookup(user_id)
.await?
.ok_or(RouteError::UserNotFound)?;
let ctx = UpstreamExistingLinkContext::new(user).with_csrf(csrf_token.form_value());
@@ -187,14 +208,14 @@ pub(crate) async fn get(
}
pub(crate) async fn post(
State(pool): State<PgPool>,
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(link_id): Path<Ulid>,
Form(form): Form<ProtectedForm<FormData>>,
) -> Result<impl IntoResponse, RouteError> {
let mut txn = pool.begin().await?;
let (clock, mut rng) = crate::clock_and_rng();
let form = cookie_jar.verify_form(clock.now(), form)?;
let form = cookie_jar.verify_form(&clock, form)?;
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
let (session_id, post_auth_action) = sessions_cookie
@@ -205,53 +226,77 @@ pub(crate) async fn post(
post_auth_action: post_auth_action.cloned(),
};
let link = lookup_link(&mut txn, link_id)
let link = repo
.upstream_oauth_link()
.lookup(link_id)
.await?
.ok_or(RouteError::LinkNotFound)?;
// This checks that we're in a browser session which is allowed to consume this
// link: the upstream auth session should have been started in this browser.
let upstream_session = lookup_session_on_link(&mut txn, &link, session_id)
let upstream_session = repo
.upstream_oauth_session()
.lookup(session_id)
.await?
.ok_or(RouteError::SessionNotFound)?;
if upstream_session.consumed() {
// This checks that we're in a browser session which is allowed to consume this
// link: the upstream auth session should have been started in this browser.
if upstream_session.link_id() != Some(link.id) {
return Err(RouteError::SessionNotFound);
}
if upstream_session.is_consumed() {
return Err(RouteError::SessionConsumed);
}
let (user_session_info, cookie_jar) = cookie_jar.session_info();
let maybe_user_session = user_session_info.load_session(&mut txn).await?;
let maybe_user_session = user_session_info.load_session(&mut repo).await?;
let mut session = match (maybe_user_session, link.user_id, form) {
let session = match (maybe_user_session, link.user_id, form) {
(Some(session), None, FormData::Link) => {
associate_link_to_user(&mut txn, &link, &session.user).await?;
repo.upstream_oauth_link()
.associate_to_user(&link, &session.user)
.await?;
session
}
(None, Some(user_id), FormData::Login) => {
let user = lookup_user(&mut txn, user_id).await?;
start_session(&mut txn, &mut rng, &clock, user).await?
let user = repo
.user()
.lookup(user_id)
.await?
.ok_or(RouteError::UserNotFound)?;
repo.browser_session().add(&mut rng, &clock, &user).await?
}
(None, None, FormData::Register { username }) => {
let user = add_user(&mut txn, &mut rng, &clock, &username).await?;
associate_link_to_user(&mut txn, &link, &user).await?;
let user = repo.user().add(&mut rng, &clock, username).await?;
repo.upstream_oauth_link()
.associate_to_user(&link, &user)
.await?;
start_session(&mut txn, &mut rng, &clock, user).await?
repo.browser_session().add(&mut rng, &clock, &user).await?
}
_ => return Err(RouteError::InvalidFormAction),
};
consume_session(&mut txn, &clock, upstream_session).await?;
authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link).await?;
repo.upstream_oauth_session()
.consume(&clock, upstream_session)
.await?;
let session = repo
.browser_session()
.authenticate_with_upstream(&mut rng, &clock, session, &link)
.await?;
let cookie_jar = sessions_cookie
.consume_link(link_id)?
.save(cookie_jar, clock.now());
.save(cookie_jar, &clock);
let cookie_jar = cookie_jar.set_session(&session);
txn.commit().await?;
repo.save().await?;
Ok((cookie_jar, post_auth_action.go_next()))
}

View File

@@ -24,10 +24,9 @@ use mas_axum_utils::{
use mas_email::Mailer;
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::user::add_user_email;
use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng};
use mas_templates::{EmailAddContext, TemplateContext, Templates};
use serde::Deserialize;
use sqlx::PgPool;
use super::start_email_verification;
use crate::views::shared::OptionalPostAuthAction;
@@ -38,17 +37,16 @@ pub struct EmailForm {
}
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.begin().await?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session {
session
@@ -67,19 +65,18 @@ pub(crate) async fn get(
}
pub(crate) async fn post(
State(pool): State<PgPool>,
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
State(mailer): State<Mailer>,
cookie_jar: PrivateCookieJar<Encrypter>,
Query(query): Query<OptionalPostAuthAction>,
Form(form): Form<ProtectedForm<EmailForm>>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(clock.now(), form)?;
let form = cookie_jar.verify_form(&clock, form)?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session {
session
@@ -88,7 +85,11 @@ pub(crate) async fn post(
return Ok((cookie_jar, login.go()).into_response());
};
let user_email = add_user_email(&mut txn, &mut rng, &clock, &session.user, form.email).await?;
let user_email = repo
.user_email()
.add(&mut rng, &clock, &session.user, form.email)
.await?;
let next = mas_router::AccountVerifyEmail::new(user_email.id);
let next = if let Some(action) = query.post_auth_action {
next.and_then(action)
@@ -97,7 +98,7 @@ pub(crate) async fn post(
};
start_email_verification(
&mailer,
&mut txn,
&mut repo,
&mut rng,
&clock,
&session.user,
@@ -105,7 +106,7 @@ pub(crate) async fn post(
)
.await?;
txn.commit().await?;
repo.save().await?;
Ok((cookie_jar, next.go()).into_response())
}

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::{anyhow, Context};
use axum::{
extract::{Form, State},
response::{Html, IntoResponse, Response},
@@ -28,16 +29,11 @@ use mas_email::Mailer;
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{
user::{
add_user_email, add_user_email_verification_code, get_user_email, get_user_emails,
remove_user_email, set_user_email_as_primary,
},
Clock,
user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
};
use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates};
use rand::{distributions::Uniform, Rng};
use serde::Deserialize;
use sqlx::{PgExecutor, PgPool};
use tracing::info;
pub mod add;
@@ -53,37 +49,35 @@ pub enum ManagementForm {
}
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
if let Some(session) = maybe_session {
render(&mut rng, &clock, templates, session, cookie_jar, &mut conn).await
render(&mut rng, &clock, templates, session, cookie_jar, &mut repo).await
} else {
let login = mas_router::Login::default();
Ok((cookie_jar, login.go()).into_response())
}
}
async fn render(
async fn render<E: std::error::Error>(
rng: impl Rng + Send,
clock: &Clock,
clock: &impl Clock,
templates: Templates,
session: BrowserSession,
cookie_jar: PrivateCookieJar<Encrypter>,
executor: impl PgExecutor<'_>,
repo: &mut impl RepositoryAccess<Error = E>,
) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), rng);
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);
let emails = get_user_emails(executor, &session.user).await?;
let emails = repo.user_email().all(&session.user).await?;
let ctx = AccountEmailsContext::new(emails)
.with_session(session)
@@ -94,11 +88,11 @@ async fn render(
Ok((cookie_jar, Html(content)).into_response())
}
async fn start_email_verification(
async fn start_email_verification<E: std::error::Error + Send + Sync + 'static>(
mailer: &Mailer,
executor: impl PgExecutor<'_>,
repo: &mut impl RepositoryAccess<Error = E>,
mut rng: impl Rng + Send,
clock: &Clock,
clock: &impl Clock,
user: &User,
user_email: UserEmail,
) -> anyhow::Result<()> {
@@ -108,15 +102,10 @@ async fn start_email_verification(
let address: Address = user_email.email.parse()?;
let verification = add_user_email_verification_code(
executor,
&mut rng,
clock,
user_email,
Duration::hours(8),
code,
)
.await?;
let verification = repo
.user_email()
.add_verification_code(&mut rng, clock, &user_email, Duration::hours(8), code)
.await?;
// And send the verification email
let mailbox = Mailbox::new(Some(user.username.clone()), address);
@@ -126,25 +115,24 @@ async fn start_email_verification(
mailer.send_verification_email(mailbox, &context).await?;
info!(
email.id = %verification.email.id,
email.id = %user_email.id,
"Verification email sent"
);
Ok(())
}
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
State(mailer): State<Mailer>,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ManagementForm>>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let mut session = if let Some(session) = maybe_session {
session
@@ -153,53 +141,69 @@ pub(crate) async fn post(
return Ok((cookie_jar, login.go()).into_response());
};
let form = cookie_jar.verify_form(clock.now(), form)?;
let form = cookie_jar.verify_form(&clock, form)?;
match form {
ManagementForm::Add { email } => {
let user_email =
add_user_email(&mut txn, &mut rng, &clock, &session.user, email).await?;
let next = mas_router::AccountVerifyEmail::new(user_email.id);
start_email_verification(
&mailer,
&mut txn,
&mut rng,
&clock,
&session.user,
user_email,
)
.await?;
txn.commit().await?;
let email = repo
.user_email()
.add(&mut rng, &clock, &session.user, email)
.await?;
let next = mas_router::AccountVerifyEmail::new(email.id);
start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email)
.await?;
repo.save().await?;
return Ok((cookie_jar, next.go()).into_response());
}
ManagementForm::ResendConfirmation { id } => {
let id = id.parse()?;
let user_email = get_user_email(&mut txn, &session.user, id).await?;
let next = mas_router::AccountVerifyEmail::new(user_email.id);
start_email_verification(
&mailer,
&mut txn,
&mut rng,
&clock,
&session.user,
user_email,
)
.await?;
txn.commit().await?;
let email = repo
.user_email()
.lookup(id)
.await?
.context("Email not found")?;
if email.user_id != session.user.id {
return Err(anyhow!("Email not found").into());
}
let next = mas_router::AccountVerifyEmail::new(email.id);
start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email)
.await?;
repo.save().await?;
return Ok((cookie_jar, next.go()).into_response());
}
ManagementForm::Remove { id } => {
let id = id.parse()?;
let email = get_user_email(&mut txn, &session.user, id).await?;
remove_user_email(&mut txn, email).await?;
let email = repo
.user_email()
.lookup(id)
.await?
.context("Email not found")?;
if email.user_id != session.user.id {
return Err(anyhow!("Email not found").into());
}
repo.user_email().remove(email).await?;
}
ManagementForm::SetPrimary { id } => {
let id = id.parse()?;
let email = get_user_email(&mut txn, &session.user, id).await?;
set_user_email_as_primary(&mut txn, &email).await?;
session.user.primary_email = Some(email);
let email = repo
.user_email()
.lookup(id)
.await?
.context("Email not found")?;
if email.user_id != session.user.id {
return Err(anyhow!("Email not found").into());
}
repo.user_email().set_as_primary(&email).await?;
session.user.primary_user_email_id = Some(email.id);
}
};
@@ -209,11 +213,11 @@ pub(crate) async fn post(
templates.clone(),
session,
cookie_jar,
&mut txn,
&mut repo,
)
.await?;
txn.commit().await?;
repo.save().await?;
Ok(reply)
}

View File

@@ -24,16 +24,9 @@ use mas_axum_utils::{
};
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{
user::{
consume_email_verification, lookup_user_email_by_id, lookup_user_email_verification_code,
mark_user_email_as_verified, set_user_email_as_primary,
},
Clock,
};
use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng};
use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates};
use serde::Deserialize;
use sqlx::PgPool;
use ulid::Ulid;
use crate::views::shared::OptionalPostAuthAction;
@@ -44,19 +37,18 @@ pub struct CodeForm {
}
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>,
Path(id): Path<Ulid>,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session {
session
@@ -65,8 +57,11 @@ pub(crate) async fn get(
return Ok((cookie_jar, login.go()).into_response());
};
let user_email = lookup_user_email_by_id(&mut conn, &session.user, id)
let user_email = repo
.user_email()
.lookup(id)
.await?
.filter(|u| u.user_id == session.user.id)
.context("Could not find user email")?;
if user_email.confirmed_at.is_some() {
@@ -85,19 +80,17 @@ pub(crate) async fn get(
}
pub(crate) async fn post(
State(pool): State<PgPool>,
clock: BoxClock,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Query(query): Query<OptionalPostAuthAction>,
Path(id): Path<Ulid>,
Form(form): Form<ProtectedForm<CodeForm>>,
) -> Result<Response, FancyError> {
let clock = Clock::default();
let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(clock.now(), form)?;
let form = cookie_jar.verify_form(&clock, form)?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session {
session
@@ -106,25 +99,33 @@ pub(crate) async fn post(
return Ok((cookie_jar, login.go()).into_response());
};
let email = lookup_user_email_by_id(&mut txn, &session.user, id)
let user_email = repo
.user_email()
.lookup(id)
.await?
.filter(|u| u.user_id == session.user.id)
.context("Could not find user email")?;
if session.user.primary_email.is_none() {
set_user_email_as_primary(&mut txn, &email).await?;
}
// TODO: make those 8 hours configurable
let verification = lookup_user_email_verification_code(&mut txn, &clock, email, &form.code)
let verification = repo
.user_email()
.find_verification_code(&clock, &user_email, &form.code)
.await?
.context("Invalid code")?;
// TODO: display nice errors if the code was already consumed or expired
let verification = consume_email_verification(&mut txn, &clock, verification).await?;
repo.user_email()
.consume_verification_code(&clock, verification)
.await?;
let _email = mark_user_email_as_verified(&mut txn, &clock, verification.email).await?;
if session.user.primary_user_email_id.is_none() {
repo.user_email().set_as_primary(&user_email).await?;
}
txn.commit().await?;
repo.user_email()
.mark_as_verified(&clock, user_email)
.await?;
repo.save().await?;
let destination = query.go_next_or_default(&mas_router::AccountEmails);
Ok((cookie_jar, destination).into_response())

View File

@@ -23,22 +23,23 @@ use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt};
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::user::{count_active_sessions, get_user_emails};
use mas_storage::{
user::{BrowserSessionRepository, UserEmailRepository},
BoxClock, BoxRepository, BoxRng,
};
use mas_templates::{AccountContext, TemplateContext, Templates};
use sqlx::PgPool;
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session {
session
@@ -47,9 +48,9 @@ pub(crate) async fn get(
return Ok((cookie_jar, login.go()).into_response());
};
let active_sessions = count_active_sessions(&mut conn, &session.user).await?;
let active_sessions = repo.browser_session().count_active(&session.user).await?;
let emails = get_user_emails(&mut conn, &session.user).await?;
let emails = repo.user_email().all(&session.user).await?;
let ctx = AccountContext::new(active_sessions, emails)
.with_session(session)

View File

@@ -26,13 +26,12 @@ use mas_data_model::BrowserSession;
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{
user::{add_user_password, authenticate_session_with_password, lookup_user_password},
Clock,
user::{BrowserSessionRepository, UserPasswordRepository},
BoxClock, BoxRepository, BoxRng, Clock,
};
use mas_templates::{EmptyContext, TemplateContext, Templates};
use rand::Rng;
use serde::Deserialize;
use sqlx::PgPool;
use zeroize::Zeroizing;
use crate::passwords::PasswordManager;
@@ -45,16 +44,15 @@ pub struct ChangeForm {
}
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
if let Some(session) = maybe_session {
render(&mut rng, &clock, templates, session, cookie_jar).await
@@ -66,12 +64,12 @@ pub(crate) async fn get(
async fn render(
rng: impl Rng + Send,
clock: &Clock,
clock: &impl Clock,
templates: Templates,
session: BrowserSession,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), rng);
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);
let ctx = EmptyContext
.with_session(session)
@@ -83,29 +81,30 @@ async fn render(
}
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
State(password_manager): State<PasswordManager>,
State(templates): State<Templates>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ChangeForm>>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(clock.now(), form)?;
let form = cookie_jar.verify_form(&clock, form)?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let mut session = if let Some(session) = maybe_session {
let session = if let Some(session) = maybe_session {
session
} else {
let login = mas_router::Login::and_then(mas_router::PostAuthAction::ChangePassword);
return Ok((cookie_jar, login.go()).into_response());
};
let user_password = lookup_user_password(&mut txn, &session.user)
let user_password = repo
.user_password()
.active(&session.user)
.await?
.context("user has no password")?;
@@ -127,23 +126,26 @@ pub(crate) async fn post(
}
let (version, hashed_password) = password_manager.hash(&mut rng, new_password).await?;
let user_password = add_user_password(
&mut txn,
&mut rng,
&clock,
&session.user,
version,
hashed_password,
None,
)
.await?;
let user_password = repo
.user_password()
.add(
&mut rng,
&clock,
&session.user,
version,
hashed_password,
None,
)
.await?;
authenticate_session_with_password(&mut txn, &mut rng, &clock, &mut session, &user_password)
let session = repo
.browser_session()
.authenticate_with_password(&mut rng, &clock, session, &user_password)
.await?;
let reply = render(&mut rng, &clock, templates.clone(), session, cookie_jar).await?;
txn.commit().await?;
repo.save().await?;
Ok(reply)
}

View File

@@ -20,21 +20,20 @@ use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt};
use mas_keystore::Encrypter;
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng};
use mas_templates::{IndexContext, TemplateContext, Templates};
use sqlx::PgPool;
pub async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<impl IntoResponse, FancyError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let session = session_info.load_session(&mut conn).await?;
let session = session_info.load_session(&mut repo).await?;
let ctx = IndexContext::new(url_builder.oidc_discovery())
.maybe_with_session(session)

View File

@@ -24,18 +24,15 @@ use mas_axum_utils::{
use mas_data_model::BrowserSession;
use mas_keystore::Encrypter;
use mas_storage::{
user::{
add_user_password, authenticate_session_with_password, lookup_user_by_username,
lookup_user_password, start_session,
},
Clock,
upstream_oauth2::UpstreamOAuthProviderRepository,
user::{BrowserSessionRepository, UserPasswordRepository, UserRepository},
BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
};
use mas_templates::{
FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState,
};
use rand::{CryptoRng, Rng};
use serde::{Deserialize, Serialize};
use sqlx::{PgConnection, PgPool};
use zeroize::Zeroizing;
use super::shared::OptionalPostAuthAction;
@@ -52,29 +49,28 @@ impl ToFormState for LoginForm {
}
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
if maybe_session.is_some() {
let reply = query.go_next();
Ok((cookie_jar, reply).into_response())
} else {
let providers = mas_storage::upstream_oauth2::get_providers(&mut conn).await?;
let providers = repo.upstream_oauth_provider().all().await?;
let content = render(
LoginContext::default().with_upstrem_providers(providers),
query,
csrf_token,
&mut conn,
&mut repo,
&templates,
)
.await?;
@@ -84,19 +80,18 @@ pub(crate) async fn get(
}
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
State(password_manager): State<PasswordManager>,
State(templates): State<Templates>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<LoginForm>>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?;
let form = cookie_jar.verify_form(&clock, form)?;
let form = cookie_jar.verify_form(clock.now(), form)?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
// Validate the form
let state = {
@@ -114,14 +109,14 @@ pub(crate) async fn post(
};
if !state.is_valid() {
let providers = mas_storage::upstream_oauth2::get_providers(&mut conn).await?;
let providers = repo.upstream_oauth_provider().all().await?;
let content = render(
LoginContext::default()
.with_form_state(state)
.with_upstrem_providers(providers),
query,
csrf_token,
&mut conn,
&mut repo,
&templates,
)
.await?;
@@ -129,11 +124,9 @@ pub(crate) async fn post(
return Ok((cookie_jar, Html(content)).into_response());
}
lookup_user_by_username(&mut conn, &form.username).await?;
match login(
password_manager,
&mut conn,
&mut repo,
rng,
&clock,
&form.username,
@@ -142,6 +135,8 @@ pub(crate) async fn post(
.await
{
Ok(session_info) => {
repo.save().await?;
let cookie_jar = cookie_jar.set_session(&session_info);
let reply = query.go_next();
Ok((cookie_jar, reply).into_response())
@@ -153,7 +148,7 @@ pub(crate) async fn post(
LoginContext::default().with_form_state(state),
query,
csrf_token,
&mut conn,
&mut repo,
&templates,
)
.await?;
@@ -166,21 +161,25 @@ pub(crate) async fn post(
// TODO: move that logic elsewhere?
async fn login(
password_manager: PasswordManager,
conn: &mut PgConnection,
repo: &mut impl RepositoryAccess,
mut rng: impl Rng + CryptoRng + Send,
clock: &Clock,
clock: &impl Clock,
username: &str,
password: &str,
) -> Result<BrowserSession, FormError> {
// XXX: we're loosing the error context here
// First, lookup the user
let user = lookup_user_by_username(&mut *conn, username)
let user = repo
.user()
.find_by_username(username)
.await
.map_err(|_e| FormError::Internal)?
.ok_or(FormError::InvalidCredentials)?;
// And its password
let user_password = lookup_user_password(&mut *conn, &user)
let user_password = repo
.user_password()
.active(&user)
.await
.map_err(|_e| FormError::Internal)?
.ok_or(FormError::InvalidCredentials)?;
@@ -200,28 +199,32 @@ async fn login(
let user_password = if let Some((version, new_password_hash)) = new_password_hash {
// Save the upgraded password
add_user_password(
&mut *conn,
&mut rng,
clock,
&user,
version,
new_password_hash,
Some(user_password),
)
.await
.map_err(|_| FormError::Internal)?
repo.user_password()
.add(
&mut rng,
clock,
&user,
version,
new_password_hash,
Some(&user_password),
)
.await
.map_err(|_| FormError::Internal)?
} else {
user_password
};
// Start a new session
let mut user_session = start_session(&mut *conn, &mut rng, clock, user)
let user_session = repo
.browser_session()
.add(&mut rng, clock, &user)
.await
.map_err(|_| FormError::Internal)?;
// And mark it as authenticated by the password
authenticate_session_with_password(&mut *conn, rng, clock, &mut user_session, &user_password)
let user_session = repo
.browser_session()
.authenticate_with_password(&mut rng, clock, user_session, &user_password)
.await
.map_err(|_| FormError::Internal)?;
@@ -232,10 +235,10 @@ async fn render(
ctx: LoginContext,
action: OptionalPostAuthAction,
csrf_token: CsrfToken,
conn: &mut PgConnection,
repo: &mut impl RepositoryAccess,
templates: &Templates,
) -> Result<String, FancyError> {
let next = action.load_context(conn).await?;
let next = action.load_context(repo).await?;
let ctx = if let Some(next) = next {
ctx.with_post_action(next)
} else {

View File

@@ -12,10 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use axum::{
extract::{Form, State},
response::IntoResponse,
};
use axum::{extract::Form, response::IntoResponse};
use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm},
@@ -23,29 +20,26 @@ use mas_axum_utils::{
};
use mas_keystore::Encrypter;
use mas_router::{PostAuthAction, Route};
use mas_storage::{user::end_session, Clock};
use sqlx::PgPool;
use mas_storage::{user::BrowserSessionRepository, BoxClock, BoxRepository};
pub(crate) async fn post(
State(pool): State<PgPool>,
clock: BoxClock,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<Option<PostAuthAction>>>,
) -> Result<impl IntoResponse, FancyError> {
let clock = Clock::default();
let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(clock.now(), form)?;
let form = cookie_jar.verify_form(&clock, form)?;
let (session_info, mut cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
if let Some(session) = maybe_session {
end_session(&mut txn, &clock, &session).await?;
repo.browser_session().finish(&clock, session).await?;
cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended());
}
txn.commit().await?;
repo.save().await?;
let destination = if let Some(action) = form {
action.go_next()

View File

@@ -24,12 +24,12 @@ use mas_axum_utils::{
};
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::user::{
add_user_password, authenticate_session_with_password, lookup_user_password,
use mas_storage::{
user::{BrowserSessionRepository, UserPasswordRepository},
BoxClock, BoxRepository, BoxRng,
};
use mas_templates::{ReauthContext, TemplateContext, Templates};
use serde::Deserialize;
use sqlx::PgPool;
use zeroize::Zeroizing;
use super::shared::OptionalPostAuthAction;
@@ -41,18 +41,17 @@ pub(crate) struct ReauthForm {
}
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session {
session
@@ -64,7 +63,7 @@ pub(crate) async fn get(
};
let ctx = ReauthContext::default();
let next = query.load_context(&mut conn).await?;
let next = query.load_context(&mut repo).await?;
let ctx = if let Some(next) = next {
ctx.with_post_action(next)
} else {
@@ -78,22 +77,21 @@ pub(crate) async fn get(
}
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
State(password_manager): State<PasswordManager>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ReauthForm>>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(clock.now(), form)?;
let form = cookie_jar.verify_form(&clock, form)?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let mut session = if let Some(session) = maybe_session {
let session = if let Some(session) = maybe_session {
session
} else {
// If there is no session, redirect to the login screen, keeping the
@@ -103,7 +101,9 @@ pub(crate) async fn post(
};
// Load the user password
let user_password = lookup_user_password(&mut txn, &session.user)
let user_password = repo
.user_password()
.active(&session.user)
.await?
.context("User has no password")?;
@@ -122,25 +122,28 @@ pub(crate) async fn post(
let user_password = if let Some((version, new_password_hash)) = new_password_hash {
// Save the upgraded password
add_user_password(
&mut *txn,
&mut rng,
&clock,
&session.user,
version,
new_password_hash,
Some(user_password),
)
.await?
repo.user_password()
.add(
&mut rng,
&clock,
&session.user,
version,
new_password_hash,
Some(&user_password),
)
.await?
} else {
user_password
};
// Mark the session as authenticated by the password
authenticate_session_with_password(&mut txn, rng, &clock, &mut session, &user_password).await?;
let session = repo
.browser_session()
.authenticate_with_password(&mut rng, &clock, session, &user_password)
.await?;
let cookie_jar = cookie_jar.set_session(&session);
txn.commit().await?;
repo.save().await?;
let reply = query.go_next();
Ok((cookie_jar, reply).into_response())

View File

@@ -31,9 +31,9 @@ use mas_email::Mailer;
use mas_keystore::Encrypter;
use mas_policy::PolicyFactory;
use mas_router::Route;
use mas_storage::user::{
add_user, add_user_email, add_user_email_verification_code, add_user_password,
authenticate_session_with_password, start_session, username_exists,
use mas_storage::{
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
BoxClock, BoxRepository, BoxRng, RepositoryAccess,
};
use mas_templates::{
EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField,
@@ -41,7 +41,6 @@ use mas_templates::{
};
use rand::{distributions::Uniform, Rng};
use serde::{Deserialize, Serialize};
use sqlx::{PgConnection, PgPool};
use zeroize::Zeroizing;
use super::shared::OptionalPostAuthAction;
@@ -60,18 +59,17 @@ impl ToFormState for RegisterForm {
}
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
if maybe_session.is_some() {
let reply = query.go_next();
@@ -81,7 +79,7 @@ pub(crate) async fn get(
RegisterContext::default(),
query,
csrf_token,
&mut conn,
&mut repo,
&templates,
)
.await?;
@@ -92,21 +90,20 @@ pub(crate) async fn get(
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
State(password_manager): State<PasswordManager>,
State(mailer): State<Mailer>,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>,
State(pool): State<PgPool>,
mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<RegisterForm>>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(&clock, form)?;
let form = cookie_jar.verify_form(clock.now(), form)?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
// Validate the form
let state = {
@@ -114,7 +111,7 @@ pub(crate) async fn post(
if form.username.is_empty() {
state.add_error_on_field(RegisterFormField::Username, FieldError::Required);
} else if username_exists(&mut txn, &form.username).await? {
} else if repo.user().exists(&form.username).await? {
state.add_error_on_field(RegisterFormField::Username, FieldError::Exists);
}
@@ -177,7 +174,7 @@ pub(crate) async fn post(
RegisterContext::default().with_form_state(state),
query,
csrf_token,
&mut txn,
&mut repo,
&templates,
)
.await?;
@@ -185,21 +182,18 @@ pub(crate) async fn post(
return Ok((cookie_jar, Html(content)).into_response());
}
let user = add_user(&mut txn, &mut rng, &clock, &form.username).await?;
let user = repo.user().add(&mut rng, &clock, form.username).await?;
let password = Zeroizing::new(form.password.into_bytes());
let (version, hashed_password) = password_manager.hash(&mut rng, password).await?;
let user_password = add_user_password(
&mut txn,
&mut rng,
&clock,
&user,
version,
hashed_password,
None,
)
.await?;
let user_password = repo
.user_password()
.add(&mut rng, &clock, &user, version, hashed_password, None)
.await?;
let user_email = add_user_email(&mut txn, &mut rng, &clock, &user, form.email).await?;
let user_email = repo
.user_email()
.add(&mut rng, &clock, &user, form.email)
.await?;
// First, generate a code
let range = Uniform::<u32>::from(0..1_000_000);
@@ -208,15 +202,10 @@ pub(crate) async fn post(
let address: Address = user_email.email.parse()?;
let verification = add_user_email_verification_code(
&mut txn,
&mut rng,
&clock,
user_email,
Duration::hours(8),
code,
)
.await?;
let verification = repo
.user_email()
.add_verification_code(&mut rng, &clock, &user_email, Duration::hours(8), code)
.await?;
// And send the verification email
let mailbox = Mailbox::new(Some(user.username.clone()), address);
@@ -225,14 +214,16 @@ pub(crate) async fn post(
mailer.send_verification_email(mailbox, &context).await?;
let next = mas_router::AccountVerifyEmail::new(verification.email.id)
.and_maybe(query.post_auth_action);
let next = mas_router::AccountVerifyEmail::new(user_email.id).and_maybe(query.post_auth_action);
let mut session = start_session(&mut txn, &mut rng, &clock, user).await?;
authenticate_session_with_password(&mut txn, &mut rng, &clock, &mut session, &user_password)
let session = repo.browser_session().add(&mut rng, &clock, &user).await?;
let session = repo
.browser_session()
.authenticate_with_password(&mut rng, &clock, session, &user_password)
.await?;
txn.commit().await?;
repo.save().await?;
let cookie_jar = cookie_jar.set_session(&session);
Ok((cookie_jar, next.go()).into_response())
@@ -242,10 +233,10 @@ async fn render(
ctx: RegisterContext,
action: OptionalPostAuthAction,
csrf_token: CsrfToken,
conn: &mut PgConnection,
repo: &mut impl RepositoryAccess,
templates: &Templates,
) -> Result<String, FancyError> {
let next = action.load_context(conn).await?;
let next = action.load_context(repo).await?;
let ctx = if let Some(next) = next {
ctx.with_post_action(next)
} else {

View File

@@ -15,11 +15,13 @@
use anyhow::Context;
use mas_router::{PostAuthAction, Route};
use mas_storage::{
compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id,
compat::CompatSsoLoginRepository,
oauth2::OAuth2AuthorizationGrantRepository,
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository},
RepositoryAccess,
};
use mas_templates::{PostAuthContext, PostAuthContextInner};
use serde::{Deserialize, Serialize};
use sqlx::PgConnection;
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
pub(crate) struct OptionalPostAuthAction {
@@ -38,14 +40,16 @@ impl OptionalPostAuthAction {
self.go_next_or_default(&mas_router::Index)
}
pub async fn load_context(
&self,
conn: &mut PgConnection,
pub async fn load_context<'a>(
&'a self,
repo: &'a mut impl RepositoryAccess,
) -> anyhow::Result<Option<PostAuthContext>> {
let Some(action) = self.post_auth_action.clone() else { return Ok(None) };
let ctx = match action {
PostAuthAction::ContinueAuthorizationGrant { id } => {
let grant = get_grant_by_id(conn, id)
let grant = repo
.oauth2_authorization_grant()
.lookup(id)
.await?
.context("Failed to load authorization grant")?;
let grant = Box::new(grant);
@@ -53,7 +57,9 @@ impl OptionalPostAuthAction {
}
PostAuthAction::ContinueCompatSsoLogin { id } => {
let login = get_compat_sso_login_by_id(conn, id)
let login = repo
.compat_sso_login()
.lookup(id)
.await?
.context("Failed to load compat SSO login")?;
let login = Box::new(login);
@@ -63,14 +69,17 @@ impl OptionalPostAuthAction {
PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword,
PostAuthAction::LinkUpstream { id } => {
let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, id)
let link = repo
.upstream_oauth_link()
.lookup(id)
.await?
.context("Failed to load upstream OAuth 2.0 link")?;
let provider =
mas_storage::upstream_oauth2::lookup_provider(&mut *conn, link.provider_id)
.await?
.context("Failed to load upstream OAuth 2.0 provider")?;
let provider = repo
.upstream_oauth_provider()
.lookup(link.provider_id)
.await?
.context("Failed to load upstream OAuth 2.0 provider")?;
let provider = Box::new(provider);
let link = Box::new(link);

View File

@@ -15,12 +15,7 @@
//! A crate to store keys which can then be used to sign and verify JWTs.
#![forbid(unsafe_code)]
#![deny(
clippy::all,
clippy::str_to_string,
rustdoc::broken_intra_doc_links,
rustdoc::all
)]
#![deny(clippy::all, clippy::str_to_string, rustdoc::broken_intra_doc_links)]
#![warn(clippy::pedantic)]
use std::{ops::Deref, sync::Arc};

View File

@@ -22,6 +22,9 @@
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
//! An utility crate to build flexible [`hyper`] listeners, with optional TLS
//! and proxy protocol support.
use self::{maybe_tls::TlsStreamInfo, proxy_protocol::ProxyProtocolV1Info};
pub mod maybe_tls;

View File

@@ -39,7 +39,7 @@ pub struct Server<S> {
impl<S> Server<S> {
/// # Errors
///
/// Returns an error if the listener couldn't be converted via [`TryInfo`]
/// Returns an error if the listener couldn't be converted via [`TryInto`]
pub fn try_new<L>(listener: L, service: S) -> Result<Self, L::Error>
where
L: TryInto<UnixOrTcpListener>,

View File

@@ -90,6 +90,11 @@ impl<T> Localized<T> {
&self.non_localized
}
/// Get the non-localized variant.
pub fn to_non_localized(self) -> T {
self.non_localized
}
/// Get the variant corresponding to the given language, if it exists.
pub fn get(&self, language: Option<&LanguageTag>) -> Option<&T> {
match language {

View File

@@ -69,7 +69,7 @@ pub struct PolicyFactory {
}
impl PolicyFactory {
#[tracing::instrument(skip(source), err)]
#[tracing::instrument(name = "policy.load", skip(source), err)]
pub async fn load(
mut source: impl AsyncRead + std::marker::Unpin,
data: serde_json::Value,
@@ -108,7 +108,7 @@ impl PolicyFactory {
authorization_grant_endpoint,
};
// Try to instanciate
// Try to instantiate
factory
.instantiate()
.await
@@ -117,7 +117,7 @@ impl PolicyFactory {
Ok(factory)
}
#[tracing::instrument(skip(self), err)]
#[tracing::instrument(name = "policy.instantiate", skip_all, err)]
pub async fn instantiate(&self) -> Result<Policy, InstanciateError> {
let mut store = Store::new(&self.engine, ());
let runtime = Runtime::new(&mut store, &self.module)
@@ -189,7 +189,14 @@ pub enum EvaluationError {
}
impl Policy {
#[tracing::instrument(skip(self, password))]
#[tracing::instrument(
name = "policy.evaluate.register",
skip_all,
fields(
data.username = username,
),
err,
)]
pub async fn evaluate_register(
&mut self,
username: &str,
@@ -234,7 +241,15 @@ impl Policy {
Ok(res)
}
#[tracing::instrument(skip(self))]
#[tracing::instrument(
name = "policy.evaluate.authorization_grant",
skip_all,
fields(
data.authorization_grant.id = %authorization_grant.id,
data.user.id = %user.id,
),
err,
)]
pub async fn evaluate_authorization_grant(
&mut self,
authorization_grant: &AuthorizationGrant,

View File

@@ -21,6 +21,8 @@
)]
#![warn(clippy::pedantic)]
//! A crate to help serve single-page apps built by Vite.
mod vite;
use std::{future::Future, pin::Pin};

View File

@@ -0,0 +1,28 @@
[package]
name = "mas-storage-pg"
version = "0.1.0"
authors = ["Quentin Gliech <quenting@element.io>"]
edition = "2021"
license = "Apache-2.0"
[dependencies]
async-trait = "0.1.60"
sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline", "json", "uuid"] }
chrono = { version = "0.4.23", features = ["serde"] }
serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.91"
thiserror = "1.0.38"
tracing = "0.1.37"
futures-util = "0.3.25"
rand = "0.8.5"
rand_chacha = "0.3.1"
url = { version = "2.3.1", features = ["serde"] }
uuid = "1.2.2"
ulid = { version = "1.0.0", features = ["uuid", "serde"] }
oauth2-types = { path = "../oauth2-types" }
mas-storage = { path = "../storage" }
mas-data-model = { path = "../data-model" }
mas-iana = { path = "../iana" }
mas-jose = { path = "../jose" }

View File

@@ -1,4 +1,4 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,220 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{CompatAccessToken, CompatSession};
use mas_storage::{compat::CompatAccessTokenRepository, Clock};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt};
/// An implementation of [`CompatAccessTokenRepository`] for a PostgreSQL
/// connection
pub struct PgCompatAccessTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatAccessTokenRepository<'c> {
/// Create a new [`PgCompatAccessTokenRepository`] from an active PostgreSQL
/// connection
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct CompatAccessTokenLookup {
compat_access_token_id: Uuid,
access_token: String,
created_at: DateTime<Utc>,
expires_at: Option<DateTime<Utc>>,
compat_session_id: Uuid,
}
impl From<CompatAccessTokenLookup> for CompatAccessToken {
fn from(value: CompatAccessTokenLookup) -> Self {
Self {
id: value.compat_access_token_id.into(),
session_id: value.compat_session_id.into(),
token: value.access_token,
created_at: value.created_at,
expires_at: value.expires_at,
}
}
}
#[async_trait]
impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_access_token.lookup",
skip_all,
fields(
db.statement,
compat_session.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatAccessToken>, Self::Error> {
let res = sqlx::query_as!(
CompatAccessTokenLookup,
r#"
SELECT compat_access_token_id
, access_token
, created_at
, expires_at
, compat_session_id
FROM compat_access_tokens
WHERE compat_access_token_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_access_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
access_token: &str,
) -> Result<Option<CompatAccessToken>, Self::Error> {
let res = sqlx::query_as!(
CompatAccessTokenLookup,
r#"
SELECT compat_access_token_id
, access_token
, created_at
, expires_at
, compat_session_id
FROM compat_access_tokens
WHERE access_token = $1
"#,
access_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_access_token.add",
skip_all,
fields(
db.statement,
compat_access_token.id,
%compat_session.id,
user.id = %compat_session.user_id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
compat_session: &CompatSession,
token: String,
expires_after: Option<Duration>,
) -> Result<CompatAccessToken, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
let expires_at = expires_after.map(|expires_after| created_at + expires_after);
sqlx::query!(
r#"
INSERT INTO compat_access_tokens
(compat_access_token_id, compat_session_id, access_token, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(compat_session.id),
token,
created_at,
expires_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatAccessToken {
id,
session_id: compat_session.id,
token,
created_at,
expires_at,
})
}
#[tracing::instrument(
name = "db.compat_access_token.expire",
skip_all,
fields(
db.statement,
%compat_access_token.id,
compat_session.id = %compat_access_token.session_id,
),
err,
)]
async fn expire(
&mut self,
clock: &dyn Clock,
mut compat_access_token: CompatAccessToken,
) -> Result<CompatAccessToken, Self::Error> {
let expires_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_access_tokens
SET expires_at = $2
WHERE compat_access_token_id = $1
"#,
Uuid::from(compat_access_token.id),
expires_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
compat_access_token.expires_at = Some(expires_at);
Ok(compat_access_token)
}
}

View File

@@ -0,0 +1,449 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! A module containing PostgreSQL implementation of repositories for the
//! compatibility layer
mod access_token;
mod refresh_token;
mod session;
mod sso_login;
pub use self::{
access_token::PgCompatAccessTokenRepository, refresh_token::PgCompatRefreshTokenRepository,
session::PgCompatSessionRepository, sso_login::PgCompatSsoLoginRepository,
};
#[cfg(test)]
mod tests {
use chrono::Duration;
use mas_data_model::Device;
use mas_storage::{
clock::MockClock,
compat::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
},
user::UserRepository,
Clock, Pagination, Repository, RepositoryAccess,
};
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
use sqlx::PgPool;
use ulid::Ulid;
use crate::PgRepository;
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_session_repository(pool: PgPool) {
const FIRST_TOKEN: &str = "first_access_token";
const SECOND_TOKEN: &str = "second_access_token";
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
// Create a user
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
// Start a compat session for that user
let device = Device::generate(&mut rng);
let device_str = device.as_str().to_owned();
let session = repo
.compat_session()
.add(&mut rng, &clock, &user, device)
.await
.unwrap();
assert_eq!(session.user_id, user.id);
assert_eq!(session.device.as_str(), device_str);
assert!(session.is_valid());
assert!(!session.is_finished());
// Lookup the session and check it didn't change
let session_lookup = repo
.compat_session()
.lookup(session.id)
.await
.unwrap()
.expect("compat session not found");
assert_eq!(session_lookup.id, session.id);
assert_eq!(session_lookup.user_id, user.id);
assert_eq!(session_lookup.device.as_str(), device_str);
assert!(session_lookup.is_valid());
assert!(!session_lookup.is_finished());
// Finish the session
let session = repo.compat_session().finish(&clock, session).await.unwrap();
assert!(!session.is_valid());
assert!(session.is_finished());
// Reload the session and check again
let session_lookup = repo
.compat_session()
.lookup(session.id)
.await
.unwrap()
.expect("compat session not found");
assert!(!session_lookup.is_valid());
assert!(session_lookup.is_finished());
}
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_access_token_repository(pool: PgPool) {
const FIRST_TOKEN: &str = "first_access_token";
const SECOND_TOKEN: &str = "second_access_token";
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Create a user
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
// Start a compat session for that user
let device = Device::generate(&mut rng);
let session = repo
.compat_session()
.add(&mut rng, &clock, &user, device)
.await
.unwrap();
// Add an access token to that session
let token = repo
.compat_access_token()
.add(
&mut rng,
&clock,
&session,
FIRST_TOKEN.to_owned(),
Some(Duration::minutes(1)),
)
.await
.unwrap();
assert_eq!(token.session_id, session.id);
assert_eq!(token.token, FIRST_TOKEN);
// Commit the txn and grab a new transaction, to test a conflict
repo.save().await.unwrap();
{
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Adding the same token a second time should conflict
assert!(repo
.compat_access_token()
.add(
&mut rng,
&clock,
&session,
FIRST_TOKEN.to_owned(),
Some(Duration::minutes(1)),
)
.await
.is_err());
repo.cancel().await.unwrap();
}
// Grab a new repo
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Looking up via ID works
let token_lookup = repo
.compat_access_token()
.lookup(token.id)
.await
.unwrap()
.expect("compat access token not found");
assert_eq!(token.id, token_lookup.id);
assert_eq!(token_lookup.session_id, session.id);
// Looking up via the token value works
let token_lookup = repo
.compat_access_token()
.find_by_token(FIRST_TOKEN)
.await
.unwrap()
.expect("compat access token not found");
assert_eq!(token.id, token_lookup.id);
assert_eq!(token_lookup.session_id, session.id);
// Token is currently valid
assert!(token.is_valid(clock.now()));
clock.advance(Duration::minutes(1));
// Token should have expired
assert!(!token.is_valid(clock.now()));
// Add a second access token, this time without expiration
let token = repo
.compat_access_token()
.add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None)
.await
.unwrap();
assert_eq!(token.session_id, session.id);
assert_eq!(token.token, SECOND_TOKEN);
// Token is currently valid
assert!(token.is_valid(clock.now()));
// Make it expire
repo.compat_access_token()
.expire(&clock, token)
.await
.unwrap();
// Reload it
let token = repo
.compat_access_token()
.find_by_token(SECOND_TOKEN)
.await
.unwrap()
.expect("compat access token not found");
// Token is not valid anymore
assert!(!token.is_valid(clock.now()));
repo.save().await.unwrap();
}
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_refresh_token_repository(pool: PgPool) {
const ACCESS_TOKEN: &str = "access_token";
const REFRESH_TOKEN: &str = "refresh_token";
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Create a user
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
// Start a compat session for that user
let device = Device::generate(&mut rng);
let session = repo
.compat_session()
.add(&mut rng, &clock, &user, device)
.await
.unwrap();
// Add an access token to that session
let access_token = repo
.compat_access_token()
.add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None)
.await
.unwrap();
let refresh_token = repo
.compat_refresh_token()
.add(
&mut rng,
&clock,
&session,
&access_token,
REFRESH_TOKEN.to_owned(),
)
.await
.unwrap();
assert_eq!(refresh_token.session_id, session.id);
assert_eq!(refresh_token.access_token_id, access_token.id);
assert_eq!(refresh_token.token, REFRESH_TOKEN);
assert!(refresh_token.is_valid());
assert!(!refresh_token.is_consumed());
// Look it up by ID and check everything matches
let refresh_token_lookup = repo
.compat_refresh_token()
.lookup(refresh_token.id)
.await
.unwrap()
.expect("refresh token not found");
assert_eq!(refresh_token_lookup.id, refresh_token.id);
assert_eq!(refresh_token_lookup.session_id, session.id);
assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
assert!(refresh_token_lookup.is_valid());
assert!(!refresh_token_lookup.is_consumed());
// Look it up by token and check everything matches
let refresh_token_lookup = repo
.compat_refresh_token()
.find_by_token(REFRESH_TOKEN)
.await
.unwrap()
.expect("refresh token not found");
assert_eq!(refresh_token_lookup.id, refresh_token.id);
assert_eq!(refresh_token_lookup.session_id, session.id);
assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
assert!(refresh_token_lookup.is_valid());
assert!(!refresh_token_lookup.is_consumed());
// Consume it
let refresh_token = repo
.compat_refresh_token()
.consume(&clock, refresh_token)
.await
.unwrap();
assert!(!refresh_token.is_valid());
assert!(refresh_token.is_consumed());
// Reload it and check again
let refresh_token_lookup = repo
.compat_refresh_token()
.find_by_token(REFRESH_TOKEN)
.await
.unwrap()
.expect("refresh token not found");
assert!(!refresh_token_lookup.is_valid());
assert!(refresh_token_lookup.is_consumed());
// Consuming it again should not work
assert!(repo
.compat_refresh_token()
.consume(&clock, refresh_token)
.await
.is_err());
repo.save().await.unwrap();
}
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_compat_sso_login_repository(pool: PgPool) {
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Create a user
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
// Lookup an unknown SSO login
let login = repo.compat_sso_login().lookup(Ulid::nil()).await.unwrap();
assert_eq!(login, None);
// Lookup an unknown login token
let login = repo
.compat_sso_login()
.find_by_token("login-token")
.await
.unwrap();
assert_eq!(login, None);
// Start a new SSO login
let login = repo
.compat_sso_login()
.add(
&mut rng,
&clock,
"login-token".to_owned(),
"https://example.com/callback".parse().unwrap(),
)
.await
.unwrap();
assert!(login.is_pending());
// Lookup the login by ID
let login_lookup = repo
.compat_sso_login()
.lookup(login.id)
.await
.unwrap()
.expect("login not found");
assert_eq!(login_lookup, login);
// Find the login by token
let login_lookup = repo
.compat_sso_login()
.find_by_token("login-token")
.await
.unwrap()
.expect("login not found");
assert_eq!(login_lookup, login);
// Exchanging before fulfilling should not work
// Note: It should also not poison the SQL transaction
let res = repo
.compat_sso_login()
.exchange(&clock, login.clone())
.await;
assert!(res.is_err());
// Start a compat session for that user
let device = Device::generate(&mut rng);
let session = repo
.compat_session()
.add(&mut rng, &clock, &user, device)
.await
.unwrap();
// Associate the login with the session
let login = repo
.compat_sso_login()
.fulfill(&clock, login, &session)
.await
.unwrap();
assert!(login.is_fulfilled());
// Fulfilling again should not work
// Note: It should also not poison the SQL transaction
let res = repo
.compat_sso_login()
.fulfill(&clock, login.clone(), &session)
.await;
assert!(res.is_err());
// Exchange that login
let login = repo
.compat_sso_login()
.exchange(&clock, login)
.await
.unwrap();
assert!(login.is_exchanged());
// Exchange again should not work
// Note: It should also not poison the SQL transaction
let res = repo
.compat_sso_login()
.exchange(&clock, login.clone())
.await;
assert!(res.is_err());
// Fulfilling after exchanging should not work
// Note: It should also not poison the SQL transaction
let res = repo
.compat_sso_login()
.fulfill(&clock, login.clone(), &session)
.await;
assert!(res.is_err());
// List the logins for the user
let logins = repo
.compat_sso_login()
.list_paginated(&user, Pagination::first(10))
.await
.unwrap();
assert!(!logins.has_next_page);
assert_eq!(logins.edges, vec![login]);
}
}

View File

@@ -0,0 +1,234 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
};
use mas_storage::{compat::CompatRefreshTokenRepository, Clock};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt};
/// An implementation of [`CompatRefreshTokenRepository`] for a PostgreSQL
/// connection
pub struct PgCompatRefreshTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatRefreshTokenRepository<'c> {
/// Create a new [`PgCompatRefreshTokenRepository`] from an active
/// PostgreSQL connection
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct CompatRefreshTokenLookup {
compat_refresh_token_id: Uuid,
refresh_token: String,
created_at: DateTime<Utc>,
consumed_at: Option<DateTime<Utc>>,
compat_access_token_id: Uuid,
compat_session_id: Uuid,
}
impl From<CompatRefreshTokenLookup> for CompatRefreshToken {
fn from(value: CompatRefreshTokenLookup) -> Self {
let state = match value.consumed_at {
Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at },
None => CompatRefreshTokenState::Valid,
};
Self {
id: value.compat_refresh_token_id.into(),
state,
session_id: value.compat_session_id.into(),
token: value.refresh_token,
created_at: value.created_at,
access_token_id: value.compat_access_token_id.into(),
}
}
}
#[async_trait]
impl<'c> CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_refresh_token.lookup",
skip_all,
fields(
db.statement,
compat_refresh_token.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatRefreshToken>, Self::Error> {
let res = sqlx::query_as!(
CompatRefreshTokenLookup,
r#"
SELECT compat_refresh_token_id
, refresh_token
, created_at
, consumed_at
, compat_session_id
, compat_access_token_id
FROM compat_refresh_tokens
WHERE compat_refresh_token_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_refresh_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
refresh_token: &str,
) -> Result<Option<CompatRefreshToken>, Self::Error> {
let res = sqlx::query_as!(
CompatRefreshTokenLookup,
r#"
SELECT compat_refresh_token_id
, refresh_token
, created_at
, consumed_at
, compat_session_id
, compat_access_token_id
FROM compat_refresh_tokens
WHERE refresh_token = $1
"#,
refresh_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_refresh_token.add",
skip_all,
fields(
db.statement,
compat_refresh_token.id,
%compat_session.id,
user.id = %compat_session.user_id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
compat_session: &CompatSession,
compat_access_token: &CompatAccessToken,
token: String,
) -> Result<CompatRefreshToken, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_refresh_tokens
(compat_refresh_token_id, compat_session_id,
compat_access_token_id, refresh_token, created_at)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(compat_session.id),
Uuid::from(compat_access_token.id),
token,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatRefreshToken {
id,
state: CompatRefreshTokenState::default(),
session_id: compat_session.id,
access_token_id: compat_access_token.id,
token,
created_at,
})
}
#[tracing::instrument(
name = "db.compat_refresh_token.consume",
skip_all,
fields(
db.statement,
%compat_refresh_token.id,
compat_session.id = %compat_refresh_token.session_id,
),
err,
)]
async fn consume(
&mut self,
clock: &dyn Clock,
compat_refresh_token: CompatRefreshToken,
) -> Result<CompatRefreshToken, Self::Error> {
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_refresh_tokens
SET consumed_at = $2
WHERE compat_refresh_token_id = $1
"#,
Uuid::from(compat_refresh_token.id),
consumed_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
let compat_refresh_token = compat_refresh_token
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(compat_refresh_token)
}
}

View File

@@ -0,0 +1,198 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{CompatSession, CompatSessionState, Device, User};
use mas_storage::{compat::CompatSessionRepository, Clock};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
/// An implementation of [`CompatSessionRepository`] for a PostgreSQL connection
pub struct PgCompatSessionRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatSessionRepository<'c> {
/// Create a new [`PgCompatSessionRepository`] from an active PostgreSQL
/// connection
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct CompatSessionLookup {
compat_session_id: Uuid,
device_id: String,
user_id: Uuid,
created_at: DateTime<Utc>,
finished_at: Option<DateTime<Utc>>,
}
impl TryFrom<CompatSessionLookup> for CompatSession {
type Error = DatabaseInconsistencyError;
fn try_from(value: CompatSessionLookup) -> Result<Self, Self::Error> {
let id = value.compat_session_id.into();
let device = Device::try_from(value.device_id).map_err(|e| {
DatabaseInconsistencyError::on("compat_sessions")
.column("device_id")
.row(id)
.source(e)
})?;
let state = match value.finished_at {
None => CompatSessionState::Valid,
Some(finished_at) => CompatSessionState::Finished { finished_at },
};
let session = CompatSession {
id,
state,
user_id: value.user_id.into(),
device,
created_at: value.created_at,
};
Ok(session)
}
}
#[async_trait]
impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_session.lookup",
skip_all,
fields(
db.statement,
compat_session.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
let res = sqlx::query_as!(
CompatSessionLookup,
r#"
SELECT compat_session_id
, device_id
, user_id
, created_at
, finished_at
FROM compat_sessions
WHERE compat_session_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.compat_session.add",
skip_all,
fields(
db.statement,
compat_session.id,
%user.id,
%user.username,
compat_session.device.id = device.as_str(),
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
user: &User,
device: Device,
) -> Result<CompatSession, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_session.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
Uuid::from(user.id),
device.as_str(),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatSession {
id,
state: CompatSessionState::default(),
user_id: user.id,
device,
created_at,
})
}
#[tracing::instrument(
name = "db.compat_session.finish",
skip_all,
fields(
db.statement,
%compat_session.id,
user.id = %compat_session.user_id,
compat_session.device.id = compat_session.device.as_str(),
),
err,
)]
async fn finish(
&mut self,
clock: &dyn Clock,
compat_session: CompatSession,
) -> Result<CompatSession, Self::Error> {
let finished_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_sessions cs
SET finished_at = $2
WHERE compat_session_id = $1
"#,
Uuid::from(compat_session.id),
finished_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
let compat_session = compat_session
.finish(finished_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(compat_session)
}
}

View File

@@ -0,0 +1,346 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState, User};
use mas_storage::{compat::CompatSsoLoginRepository, Clock, Page, Pagination};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::{
pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError,
LookupResultExt,
};
/// An implementation of [`CompatSsoLoginRepository`] for a PostgreSQL
/// connection
pub struct PgCompatSsoLoginRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatSsoLoginRepository<'c> {
/// Create a new [`PgCompatSsoLoginRepository`] from an active PostgreSQL
/// connection
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)]
struct CompatSsoLoginLookup {
compat_sso_login_id: Uuid,
login_token: String,
redirect_uri: String,
created_at: DateTime<Utc>,
fulfilled_at: Option<DateTime<Utc>>,
exchanged_at: Option<DateTime<Utc>>,
compat_session_id: Option<Uuid>,
}
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
type Error = DatabaseInconsistencyError;
fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
let id = res.compat_sso_login_id.into();
let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| {
DatabaseInconsistencyError::on("compat_sso_logins")
.column("redirect_uri")
.row(id)
.source(e)
})?;
let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) {
(None, None, None) => CompatSsoLoginState::Pending,
(Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled {
fulfilled_at,
session_id: session_id.into(),
},
(Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => {
CompatSsoLoginState::Exchanged {
fulfilled_at,
exchanged_at,
session_id: session_id.into(),
}
}
_ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
};
Ok(CompatSsoLogin {
id,
login_token: res.login_token,
redirect_uri,
created_at: res.created_at,
state,
})
}
}
#[async_trait]
impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_sso_login.lookup",
skip_all,
fields(
db.statement,
compat_sso_login.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSsoLogin>, Self::Error> {
let res = sqlx::query_as!(
CompatSsoLoginLookup,
r#"
SELECT compat_sso_login_id
, login_token
, redirect_uri
, created_at
, fulfilled_at
, exchanged_at
, compat_session_id
FROM compat_sso_logins
WHERE compat_sso_login_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.compat_sso_login.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
login_token: &str,
) -> Result<Option<CompatSsoLogin>, Self::Error> {
let res = sqlx::query_as!(
CompatSsoLoginLookup,
r#"
SELECT compat_sso_login_id
, login_token
, redirect_uri
, created_at
, fulfilled_at
, exchanged_at
, compat_session_id
FROM compat_sso_logins
WHERE login_token = $1
"#,
login_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.compat_sso_login.add",
skip_all,
fields(
db.statement,
compat_sso_login.id,
compat_sso_login.redirect_uri = %redirect_uri,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
login_token: String,
redirect_uri: Url,
) -> Result<CompatSsoLogin, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_sso_logins
(compat_sso_login_id, login_token, redirect_uri, created_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
&login_token,
redirect_uri.as_str(),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatSsoLogin {
id,
login_token,
redirect_uri,
created_at,
state: CompatSsoLoginState::default(),
})
}
#[tracing::instrument(
name = "db.compat_sso_login.fulfill",
skip_all,
fields(
db.statement,
%compat_sso_login.id,
%compat_session.id,
compat_session.device.id = compat_session.device.as_str(),
user.id = %compat_session.user_id,
),
err,
)]
async fn fulfill(
&mut self,
clock: &dyn Clock,
compat_sso_login: CompatSsoLogin,
compat_session: &CompatSession,
) -> Result<CompatSsoLogin, Self::Error> {
let fulfilled_at = clock.now();
let compat_sso_login = compat_sso_login
.fulfill(fulfilled_at, compat_session)
.map_err(DatabaseError::to_invalid_operation)?;
let res = sqlx::query!(
r#"
UPDATE compat_sso_logins
SET
compat_session_id = $2,
fulfilled_at = $3
WHERE
compat_sso_login_id = $1
"#,
Uuid::from(compat_sso_login.id),
Uuid::from(compat_session.id),
fulfilled_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(compat_sso_login)
}
#[tracing::instrument(
name = "db.compat_sso_login.exchange",
skip_all,
fields(
db.statement,
%compat_sso_login.id,
),
err,
)]
async fn exchange(
&mut self,
clock: &dyn Clock,
compat_sso_login: CompatSsoLogin,
) -> Result<CompatSsoLogin, Self::Error> {
let exchanged_at = clock.now();
let compat_sso_login = compat_sso_login
.exchange(exchanged_at)
.map_err(DatabaseError::to_invalid_operation)?;
let res = sqlx::query!(
r#"
UPDATE compat_sso_logins
SET
exchanged_at = $2
WHERE
compat_sso_login_id = $1
"#,
Uuid::from(compat_sso_login.id),
exchanged_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(compat_sso_login)
}
#[tracing::instrument(
name = "db.compat_sso_login.list_paginated",
skip_all,
fields(
db.statement,
%user.id,
%user.username,
),
err
)]
async fn list_paginated(
&mut self,
user: &User,
pagination: Pagination,
) -> Result<Page<CompatSsoLogin>, Self::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT cl.compat_sso_login_id
, cl.login_token
, cl.redirect_uri
, cl.created_at
, cl.fulfilled_at
, cl.exchanged_at
, cl.compat_session_id
FROM compat_sso_logins cl
INNER JOIN compat_sessions cs USING (compat_session_id)
"#,
);
query
.push(" WHERE cs.user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("cl.compat_sso_login_id", pagination);
let edges: Vec<CompatSsoLoginLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination
.process(edges)
.try_map(CompatSsoLogin::try_from)?;
Ok(page)
}
}

View File

@@ -0,0 +1,144 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use sqlx::postgres::PgQueryResult;
use thiserror::Error;
use ulid::Ulid;
/// Generic error when interacting with the database
#[derive(Debug, Error)]
#[error(transparent)]
pub enum DatabaseError {
/// An error which came from the database itself
Driver {
/// The underlying error from the database driver
#[from]
source: sqlx::Error,
},
/// An error which occured while converting the data from the database
Inconsistency(#[from] DatabaseInconsistencyError),
/// An error which happened because the requested database operation is
/// invalid
#[error("Invalid database operation")]
InvalidOperation {
/// The source of the error, if any
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
},
/// An error which happens when an operation affects not enough or too many
/// rows
#[error("Expected {expected} rows to be affected, but {actual} rows were affected")]
RowsAffected {
/// How many rows were expected to be affected
expected: u64,
/// How many rows were actually affected
actual: u64,
},
}
impl DatabaseError {
pub(crate) fn ensure_affected_rows(
result: &PgQueryResult,
expected: u64,
) -> Result<(), DatabaseError> {
let actual = result.rows_affected();
if actual == expected {
Ok(())
} else {
Err(DatabaseError::RowsAffected { expected, actual })
}
}
pub(crate) fn to_invalid_operation<E: std::error::Error + Send + Sync + 'static>(e: E) -> Self {
Self::InvalidOperation {
source: Some(Box::new(e)),
}
}
pub(crate) const fn invalid_operation() -> Self {
Self::InvalidOperation { source: None }
}
}
/// An error which occured while converting the data from the database
#[derive(Debug, Error)]
pub struct DatabaseInconsistencyError {
/// The table which was being queried
table: &'static str,
/// The column which was being queried
column: Option<&'static str>,
/// The row which was being queried
row: Option<Ulid>,
/// The source of the error
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
}
impl std::fmt::Display for DatabaseInconsistencyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Database inconsistency on table {}", self.table)?;
if let Some(column) = self.column {
write!(f, " column {column}")?;
}
if let Some(row) = self.row {
write!(f, " row {row}")?;
}
Ok(())
}
}
impl DatabaseInconsistencyError {
/// Create a new [`DatabaseInconsistencyError`] for the given table
#[must_use]
pub(crate) const fn on(table: &'static str) -> Self {
Self {
table,
column: None,
row: None,
source: None,
}
}
/// Set the column which was being queried
#[must_use]
pub(crate) const fn column(mut self, column: &'static str) -> Self {
self.column = Some(column);
self
}
/// Set the row which was being queried
#[must_use]
pub(crate) const fn row(mut self, row: Ulid) -> Self {
self.row = Some(row);
self
}
/// Give the source of the error
#[must_use]
pub(crate) fn source<E: std::error::Error + Send + Sync + 'static>(
mut self,
source: E,
) -> Self {
self.source = Some(Box::new(source));
self
}
}

View File

@@ -0,0 +1,225 @@
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! An implementation of the storage traits for a PostgreSQL database
//!
//! This backend uses [`sqlx`] to interact with the database. Most queries are
//! type-checked, using introspection data recorded in the `sqlx-data.json`
//! file. This file is generated by the `sqlx` CLI tool, and should be updated
//! whenever the database schema changes, or new queries are added.
//!
//! # Implementing a new repository
//!
//! When a new repository is defined in [`mas_storage`], it should be
//! implemented here, with the PostgreSQL backend.
//!
//! A typical implementation will look like this:
//!
//! ```rust
//! # use async_trait::async_trait;
//! # use ulid::Ulid;
//! # use rand::RngCore;
//! # use mas_storage::Clock;
//! # use mas_storage_pg::{DatabaseError, ExecuteExt, LookupResultExt};
//! # use sqlx::PgConnection;
//! # use uuid::Uuid;
//! #
//! # // A fake data structure, usually defined in mas-data-model
//! # #[derive(sqlx::FromRow)]
//! # struct FakeData {
//! # id: Ulid,
//! # }
//! #
//! # // A fake repository trait, usually defined in mas-storage
//! # #[async_trait]
//! # pub trait FakeDataRepository: Send + Sync {
//! # type Error;
//! # async fn lookup(&mut self, id: Ulid) -> Result<Option<FakeData>, Self::Error>;
//! # async fn add(
//! # &mut self,
//! # rng: &mut (dyn RngCore + Send),
//! # clock: &dyn Clock,
//! # ) -> Result<FakeData, Self::Error>;
//! # }
//! #
//! /// An implementation of [`FakeDataRepository`] for a PostgreSQL connection
//! pub struct PgFakeDataRepository<'c> {
//! conn: &'c mut PgConnection,
//! }
//!
//! impl<'c> PgFakeDataRepository<'c> {
//! /// Create a new [`FakeDataRepository`] from an active PostgreSQL connection
//! pub fn new(conn: &'c mut PgConnection) -> Self {
//! Self { conn }
//! }
//! }
//!
//! #[derive(sqlx::FromRow)]
//! struct FakeDataLookup {
//! fake_data_id: Uuid,
//! }
//!
//! impl From<FakeDataLookup> for FakeData {
//! fn from(value: FakeDataLookup) -> Self {
//! Self {
//! id: value.fake_data_id.into(),
//! }
//! }
//! }
//!
//! #[async_trait]
//! impl<'c> FakeDataRepository for PgFakeDataRepository<'c> {
//! type Error = DatabaseError;
//!
//! #[tracing::instrument(
//! name = "db.fake_data.lookup",
//! skip_all,
//! fields(
//! db.statement,
//! fake_data.id = %id,
//! ),
//! err,
//! )]
//! async fn lookup(&mut self, id: Ulid) -> Result<Option<FakeData>, Self::Error> {
//! // Note: here we would use the macro version instead, but it's not possible here in
//! // this documentation example
//! let res: Option<FakeDataLookup> = sqlx::query_as(
//! r#"
//! SELECT fake_data_id
//! FROM fake_data
//! WHERE fake_data_id = $1
//! "#,
//! )
//! .bind(Uuid::from(id))
//! .traced()
//! .fetch_one(&mut *self.conn)
//! .await
//! .to_option()?;
//!
//! let Some(res) = res else { return Ok(None) };
//!
//! Ok(Some(res.into()))
//! }
//!
//! #[tracing::instrument(
//! name = "db.fake_data.add",
//! skip_all,
//! fields(
//! db.statement,
//! fake_data.id,
//! ),
//! err,
//! )]
//! async fn add(
//! &mut self,
//! rng: &mut (dyn RngCore + Send),
//! clock: &dyn Clock,
//! ) -> Result<FakeData, Self::Error> {
//! let created_at = clock.now();
//! let id = Ulid::from_datetime_with_source(created_at.into(), rng);
//! tracing::Span::current().record("fake_data.id", tracing::field::display(id));
//!
//! // Note: here we would use the macro version instead, but it's not possible here in
//! // this documentation example
//! sqlx::query(
//! r#"
//! INSERT INTO fake_data (id)
//! VALUES ($1)
//! "#,
//! )
//! .bind(Uuid::from(id))
//! .traced()
//! .execute(&mut *self.conn)
//! .await?;
//!
//! Ok(FakeData {
//! id,
//! })
//! }
//! }
//! ```
//!
//! A few things to note with the implementation:
//!
//! - All methods are traced, with an explicit, somewhat consistent name.
//! - The SQL statement is included as attribute, by declaring a `db.statement`
//! attribute on the tracing span, and then calling [`ExecuteExt::traced`].
//! - The IDs are all [`Ulid`], and generated from the clock and the random
//! number generated passed as parameters. The generated IDs are recorded in
//! the span.
//! - The IDs are stored as [`Uuid`] in PostgreSQL, so conversions are required
//! - "Not found" errors are handled by returning `Ok(None)` instead of an
//! error. The [`LookupResultExt::to_option`] method helps to do that.
//!
//! [`Ulid`]: ulid::Ulid
//! [`Uuid`]: uuid::Uuid
#![forbid(unsafe_code)]
#![deny(
clippy::all,
clippy::str_to_string,
clippy::future_not_send,
rustdoc::broken_intra_doc_links,
missing_docs
)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
use sqlx::migrate::Migrator;
/// An extension trait for [`Result`] which adds a [`to_option`] method, useful
/// for handling "not found" errors from [`sqlx`]
///
/// [`to_option`]: LookupResultExt::to_option
pub trait LookupResultExt {
/// The output type
type Output;
/// Transform a [`Result`] from a sqlx query to transform "not found" errors
/// into [`None`]
///
/// # Errors
///
/// Returns the original error if the error was not a
/// [`sqlx::Error::RowNotFound`] error
fn to_option(self) -> Result<Option<Self::Output>, sqlx::Error>;
}
impl<T> LookupResultExt for Result<T, sqlx::Error> {
type Output = T;
fn to_option(self) -> Result<Option<Self::Output>, sqlx::Error> {
match self {
Ok(v) => Ok(Some(v)),
Err(sqlx::Error::RowNotFound) => Ok(None),
Err(e) => Err(e),
}
}
}
pub mod compat;
pub mod oauth2;
pub mod upstream_oauth2;
pub mod user;
mod errors;
pub(crate) mod pagination;
pub(crate) mod repository;
pub(crate) mod tracing;
pub(crate) use self::errors::DatabaseInconsistencyError;
pub use self::{errors::DatabaseError, repository::PgRepository, tracing::ExecuteExt};
/// Embedded migrations, allowing them to run on startup
pub static MIGRATOR: Migrator = sqlx::migrate!();

View File

@@ -0,0 +1,227 @@
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{AccessToken, AccessTokenState, Session};
use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt};
/// An implementation of [`OAuth2AccessTokenRepository`] for a PostgreSQL
/// connection
pub struct PgOAuth2AccessTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2AccessTokenRepository<'c> {
/// Create a new [`PgOAuth2AccessTokenRepository`] from an active PostgreSQL
/// connection
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct OAuth2AccessTokenLookup {
oauth2_access_token_id: Uuid,
oauth2_session_id: Uuid,
access_token: String,
created_at: DateTime<Utc>,
expires_at: DateTime<Utc>,
revoked_at: Option<DateTime<Utc>>,
}
impl From<OAuth2AccessTokenLookup> for AccessToken {
fn from(value: OAuth2AccessTokenLookup) -> Self {
let state = match value.revoked_at {
None => AccessTokenState::Valid,
Some(revoked_at) => AccessTokenState::Revoked { revoked_at },
};
Self {
id: value.oauth2_access_token_id.into(),
state,
session_id: value.oauth2_session_id.into(),
access_token: value.access_token,
created_at: value.created_at,
expires_at: value.expires_at,
}
}
}
#[async_trait]
impl<'c> OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'c> {
type Error = DatabaseError;
async fn lookup(&mut self, id: Ulid) -> Result<Option<AccessToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT oauth2_access_token_id
, access_token
, created_at
, expires_at
, revoked_at
, oauth2_session_id
FROM oauth2_access_tokens
WHERE oauth2_access_token_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_access_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
access_token: &str,
) -> Result<Option<AccessToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT oauth2_access_token_id
, access_token
, created_at
, expires_at
, revoked_at
, oauth2_session_id
FROM oauth2_access_tokens
WHERE access_token = $1
"#,
access_token,
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_access_token.add",
skip_all,
fields(
db.statement,
%session.id,
user_session.id = %session.user_session_id,
client.id = %session.client_id,
access_token.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
session: &Session,
access_token: String,
expires_after: Duration,
) -> Result<AccessToken, Self::Error> {
let created_at = clock.now();
let expires_at = created_at + expires_after;
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("access_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_access_tokens
(oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)
VALUES
($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.id),
&access_token,
created_at,
expires_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(AccessToken {
id,
state: AccessTokenState::default(),
access_token,
session_id: session.id,
created_at,
expires_at,
})
}
async fn revoke(
&mut self,
clock: &dyn Clock,
access_token: AccessToken,
) -> Result<AccessToken, Self::Error> {
let revoked_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_access_tokens
SET revoked_at = $2
WHERE oauth2_access_token_id = $1
"#,
Uuid::from(access_token.id),
revoked_at,
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
access_token
.revoke(revoked_at)
.map_err(DatabaseError::to_invalid_operation)
}
async fn cleanup_expired(&mut self, clock: &dyn Clock) -> Result<usize, Self::Error> {
// Cleanup token which expired more than 15 minutes ago
let threshold = clock.now() - Duration::minutes(15);
let res = sqlx::query!(
r#"
DELETE FROM oauth2_access_tokens
WHERE expires_at < $1
"#,
threshold,
)
.execute(&mut *self.conn)
.await?;
Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
}
}

View File

@@ -0,0 +1,514 @@
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::num::NonZeroU32;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
};
use mas_iana::oauth::PkceCodeChallengeMethod;
use mas_storage::{oauth2::OAuth2AuthorizationGrantRepository, Clock};
use oauth2_types::{requests::ResponseMode, scope::Scope};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
/// An implementation of [`OAuth2AuthorizationGrantRepository`] for a PostgreSQL
/// connection
pub struct PgOAuth2AuthorizationGrantRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
/// Create a new [`PgOAuth2AuthorizationGrantRepository`] from an active
/// PostgreSQL connection
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[allow(clippy::struct_excessive_bools)]
struct GrantLookup {
oauth2_authorization_grant_id: Uuid,
created_at: DateTime<Utc>,
cancelled_at: Option<DateTime<Utc>>,
fulfilled_at: Option<DateTime<Utc>>,
exchanged_at: Option<DateTime<Utc>>,
scope: String,
state: Option<String>,
nonce: Option<String>,
redirect_uri: String,
response_mode: String,
max_age: Option<i32>,
response_type_code: bool,
response_type_id_token: bool,
authorization_code: Option<String>,
code_challenge: Option<String>,
code_challenge_method: Option<String>,
requires_consent: bool,
oauth2_client_id: Uuid,
oauth2_session_id: Option<Uuid>,
}
impl TryFrom<GrantLookup> for AuthorizationGrant {
type Error = DatabaseInconsistencyError;
#[allow(clippy::too_many_lines)]
fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
let id = value.oauth2_authorization_grant_id.into();
let scope: Scope = value.scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("scope")
.row(id)
.source(e)
})?;
let stage = match (
value.fulfilled_at,
value.exchanged_at,
value.cancelled_at,
value.oauth2_session_id,
) {
(None, None, None, None) => AuthorizationGrantStage::Pending,
(Some(fulfilled_at), None, None, Some(session_id)) => {
AuthorizationGrantStage::Fulfilled {
session_id: session_id.into(),
fulfilled_at,
}
}
(Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
AuthorizationGrantStage::Exchanged {
session_id: session_id.into(),
fulfilled_at,
exchanged_at,
}
}
(None, None, Some(cancelled_at), None) => {
AuthorizationGrantStage::Cancelled { cancelled_at }
}
_ => {
return Err(
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("stage")
.row(id),
);
}
};
let pkce = match (value.code_challenge, value.code_challenge_method) {
(Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
Some(Pkce {
challenge_method: PkceCodeChallengeMethod::Plain,
challenge,
})
}
(Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
challenge_method: PkceCodeChallengeMethod::S256,
challenge,
}),
(None, None) => None,
_ => {
return Err(
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("code_challenge_method")
.row(id),
);
}
};
let code: Option<AuthorizationCode> =
match (value.response_type_code, value.authorization_code, pkce) {
(false, None, None) => None,
(true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
_ => {
return Err(
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("authorization_code")
.row(id),
);
}
};
let redirect_uri = value.redirect_uri.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("redirect_uri")
.row(id)
.source(e)
})?;
let response_mode = value.response_mode.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("response_mode")
.row(id)
.source(e)
})?;
let max_age = value
.max_age
.map(u32::try_from)
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("max_age")
.row(id)
.source(e)
})?
.map(NonZeroU32::try_from)
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("max_age")
.row(id)
.source(e)
})?;
Ok(AuthorizationGrant {
id,
stage,
client_id: value.oauth2_client_id.into(),
code,
scope,
state: value.state,
nonce: value.nonce,
max_age,
response_mode,
redirect_uri,
created_at: value.created_at,
response_type_id_token: value.response_type_id_token,
requires_consent: value.requires_consent,
})
}
}
#[async_trait]
impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.oauth2_authorization_grant.add",
skip_all,
fields(
db.statement,
grant.id,
grant.scope = %scope,
%client.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
client: &Client,
redirect_uri: Url,
scope: Scope,
code: Option<AuthorizationCode>,
state: Option<String>,
nonce: Option<String>,
max_age: Option<NonZeroU32>,
response_mode: ResponseMode,
response_type_id_token: bool,
requires_consent: bool,
) -> Result<AuthorizationGrant, Self::Error> {
let code_challenge = code
.as_ref()
.and_then(|c| c.pkce.as_ref())
.map(|p| &p.challenge);
let code_challenge_method = code
.as_ref()
.and_then(|c| c.pkce.as_ref())
.map(|p| p.challenge_method.to_string());
// TODO: this conversion is a bit ugly
let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX));
let code_str = code.as_ref().map(|c| &c.code);
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("grant.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_authorization_grants (
oauth2_authorization_grant_id,
oauth2_client_id,
redirect_uri,
scope,
state,
nonce,
max_age,
response_mode,
code_challenge,
code_challenge_method,
response_type_code,
response_type_id_token,
authorization_code,
requires_consent,
created_at
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
"#,
Uuid::from(id),
Uuid::from(client.id),
redirect_uri.to_string(),
scope.to_string(),
state,
nonce,
max_age_i32,
response_mode.to_string(),
code_challenge,
code_challenge_method,
code.is_some(),
response_type_id_token,
code_str,
requires_consent,
created_at,
)
.execute(&mut *self.conn)
.await?;
Ok(AuthorizationGrant {
id,
stage: AuthorizationGrantStage::Pending,
code,
redirect_uri,
client_id: client.id,
scope,
state,
nonce,
max_age,
response_mode,
created_at,
response_type_id_token,
requires_consent,
})
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.lookup",
skip_all,
fields(
db.statement,
grant.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
let res = sqlx::query_as!(
GrantLookup,
r#"
SELECT oauth2_authorization_grant_id
, created_at
, cancelled_at
, fulfilled_at
, exchanged_at
, scope
, state
, redirect_uri
, response_mode
, nonce
, max_age
, oauth2_client_id
, authorization_code
, response_type_code
, response_type_id_token
, code_challenge
, code_challenge_method
, requires_consent
, oauth2_session_id
FROM
oauth2_authorization_grants
WHERE oauth2_authorization_grant_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.find_by_code",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_code(
&mut self,
code: &str,
) -> Result<Option<AuthorizationGrant>, Self::Error> {
let res = sqlx::query_as!(
GrantLookup,
r#"
SELECT oauth2_authorization_grant_id
, created_at
, cancelled_at
, fulfilled_at
, exchanged_at
, scope
, state
, redirect_uri
, response_mode
, nonce
, max_age
, oauth2_client_id
, authorization_code
, response_type_code
, response_type_id_token
, code_challenge
, code_challenge_method
, requires_consent
, oauth2_session_id
FROM
oauth2_authorization_grants
WHERE authorization_code = $1
"#,
code,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.fulfill",
skip_all,
fields(
db.statement,
%grant.id,
client.id = %grant.client_id,
%session.id,
user_session.id = %session.user_session_id,
),
err,
)]
async fn fulfill(
&mut self,
clock: &dyn Clock,
session: &Session,
grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error> {
let fulfilled_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_authorization_grants
SET fulfilled_at = $2
, oauth2_session_id = $3
WHERE oauth2_authorization_grant_id = $1
"#,
Uuid::from(grant.id),
fulfilled_at,
Uuid::from(session.id),
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
// XXX: check affected rows & new methods
let grant = grant
.fulfill(fulfilled_at, session)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(grant)
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.exchange",
skip_all,
fields(
db.statement,
%grant.id,
client.id = %grant.client_id,
),
err,
)]
async fn exchange(
&mut self,
clock: &dyn Clock,
grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error> {
let exchanged_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_authorization_grants
SET exchanged_at = $2
WHERE oauth2_authorization_grant_id = $1
"#,
Uuid::from(grant.id),
exchanged_at,
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
let grant = grant
.exchange(exchanged_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(grant)
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.give_consent",
skip_all,
fields(
db.statement,
%grant.id,
client.id = %grant.client_id,
),
err,
)]
async fn give_consent(
&mut self,
mut grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error> {
sqlx::query!(
r#"
UPDATE oauth2_authorization_grants AS og
SET
requires_consent = 'f'
WHERE
og.oauth2_authorization_grant_id = $1
"#,
Uuid::from(grant.id),
)
.execute(&mut *self.conn)
.await?;
grant.requires_consent = false;
Ok(grant)
}
}

View File

@@ -0,0 +1,748 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
collections::{BTreeMap, BTreeSet},
str::FromStr,
string::ToString,
};
use async_trait::async_trait;
use mas_data_model::{Client, JwksOrJwksUri, User};
use mas_iana::{
jose::JsonWebSignatureAlg,
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
};
use mas_jose::jwk::PublicJsonWebKeySet;
use mas_storage::{oauth2::OAuth2ClientRepository, Clock};
use oauth2_types::{
requests::GrantType,
scope::{Scope, ScopeToken},
};
use rand::RngCore;
use sqlx::PgConnection;
use tracing::{info_span, Instrument};
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
/// An implementation of [`OAuth2ClientRepository`] for a PostgreSQL connection
pub struct PgOAuth2ClientRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2ClientRepository<'c> {
/// Create a new [`PgOAuth2ClientRepository`] from an active PostgreSQL
/// connection
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
// XXX: response_types & contacts
#[derive(Debug)]
struct OAuth2ClientLookup {
oauth2_client_id: Uuid,
encrypted_client_secret: Option<String>,
redirect_uris: Vec<String>,
// response_types: Vec<String>,
grant_type_authorization_code: bool,
grant_type_refresh_token: bool,
// contacts: Vec<String>,
client_name: Option<String>,
logo_uri: Option<String>,
client_uri: Option<String>,
policy_uri: Option<String>,
tos_uri: Option<String>,
jwks_uri: Option<String>,
jwks: Option<serde_json::Value>,
id_token_signed_response_alg: Option<String>,
userinfo_signed_response_alg: Option<String>,
token_endpoint_auth_method: Option<String>,
token_endpoint_auth_signing_alg: Option<String>,
initiate_login_uri: Option<String>,
}
impl TryInto<Client> for OAuth2ClientLookup {
type Error = DatabaseInconsistencyError;
#[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing
fn try_into(self) -> Result<Client, Self::Error> {
let id = Ulid::from(self.oauth2_client_id);
let redirect_uris: Result<Vec<Url>, _> =
self.redirect_uris.iter().map(|s| s.parse()).collect();
let redirect_uris = redirect_uris.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("redirect_uris")
.row(id)
.source(e)
})?;
let response_types = vec![
OAuthAuthorizationEndpointResponseType::Code,
OAuthAuthorizationEndpointResponseType::IdToken,
OAuthAuthorizationEndpointResponseType::None,
];
/* XXX
let response_types: Result<Vec<OAuthAuthorizationEndpointResponseType>, _> =
self.response_types.iter().map(|s| s.parse()).collect();
let response_types = response_types.map_err(|source| ClientFetchError::ParseField {
field: "response_types",
source,
})?;
*/
let mut grant_types = Vec::new();
if self.grant_type_authorization_code {
grant_types.push(GrantType::AuthorizationCode);
}
if self.grant_type_refresh_token {
grant_types.push(GrantType::RefreshToken);
}
let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("logo_uri")
.row(id)
.source(e)
})?;
let client_uri = self
.client_uri
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("client_uri")
.row(id)
.source(e)
})?;
let policy_uri = self
.policy_uri
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("policy_uri")
.row(id)
.source(e)
})?;
let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("tos_uri")
.row(id)
.source(e)
})?;
let id_token_signed_response_alg = self
.id_token_signed_response_alg
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("id_token_signed_response_alg")
.row(id)
.source(e)
})?;
let userinfo_signed_response_alg = self
.userinfo_signed_response_alg
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("userinfo_signed_response_alg")
.row(id)
.source(e)
})?;
let token_endpoint_auth_method = self
.token_endpoint_auth_method
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("token_endpoint_auth_method")
.row(id)
.source(e)
})?;
let token_endpoint_auth_signing_alg = self
.token_endpoint_auth_signing_alg
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("token_endpoint_auth_signing_alg")
.row(id)
.source(e)
})?;
let initiate_login_uri = self
.initiate_login_uri
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("initiate_login_uri")
.row(id)
.source(e)
})?;
let jwks = match (self.jwks, self.jwks_uri) {
(None, None) => None,
(Some(jwks), None) => {
let jwks = serde_json::from_value(jwks).map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("jwks")
.row(id)
.source(e)
})?;
Some(JwksOrJwksUri::Jwks(jwks))
}
(None, Some(jwks_uri)) => {
let jwks_uri = jwks_uri.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("jwks_uri")
.row(id)
.source(e)
})?;
Some(JwksOrJwksUri::JwksUri(jwks_uri))
}
_ => {
return Err(DatabaseInconsistencyError::on("oauth2_clients")
.column("jwks(_uri)")
.row(id))
}
};
Ok(Client {
id,
client_id: id.to_string(),
encrypted_client_secret: self.encrypted_client_secret,
redirect_uris,
response_types,
grant_types,
// contacts: self.contacts,
contacts: vec![],
client_name: self.client_name,
logo_uri,
client_uri,
policy_uri,
tos_uri,
jwks,
id_token_signed_response_alg,
userinfo_signed_response_alg,
token_endpoint_auth_method,
token_endpoint_auth_signing_alg,
initiate_login_uri,
})
}
}
#[async_trait]
impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.oauth2_client.lookup",
skip_all,
fields(
db.statement,
oauth2_client.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error> {
let res = sqlx::query_as!(
OAuth2ClientLookup,
r#"
SELECT oauth2_client_id
, encrypted_client_secret
, ARRAY(
SELECT redirect_uri
FROM oauth2_client_redirect_uris r
WHERE r.oauth2_client_id = c.oauth2_client_id
) AS "redirect_uris!"
, grant_type_authorization_code
, grant_type_refresh_token
, client_name
, logo_uri
, client_uri
, policy_uri
, tos_uri
, jwks_uri
, jwks
, id_token_signed_response_alg
, userinfo_signed_response_alg
, token_endpoint_auth_method
, token_endpoint_auth_signing_alg
, initiate_login_uri
FROM oauth2_clients c
WHERE oauth2_client_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_client.load_batch",
skip_all,
fields(
db.statement,
),
err,
)]
async fn load_batch(
&mut self,
ids: BTreeSet<Ulid>,
) -> Result<BTreeMap<Ulid, Client>, Self::Error> {
let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
let res = sqlx::query_as!(
OAuth2ClientLookup,
r#"
SELECT oauth2_client_id
, encrypted_client_secret
, ARRAY(
SELECT redirect_uri
FROM oauth2_client_redirect_uris r
WHERE r.oauth2_client_id = c.oauth2_client_id
) AS "redirect_uris!"
, grant_type_authorization_code
, grant_type_refresh_token
, client_name
, logo_uri
, client_uri
, policy_uri
, tos_uri
, jwks_uri
, jwks
, id_token_signed_response_alg
, userinfo_signed_response_alg
, token_endpoint_auth_method
, token_endpoint_auth_signing_alg
, initiate_login_uri
FROM oauth2_clients c
WHERE oauth2_client_id = ANY($1::uuid[])
"#,
&ids,
)
.traced()
.fetch_all(&mut *self.conn)
.await?;
res.into_iter()
.map(|r| {
r.try_into()
.map(|c: Client| (c.id, c))
.map_err(DatabaseError::from)
})
.collect()
}
#[tracing::instrument(
name = "db.oauth2_client.add",
skip_all,
fields(
db.statement,
client.id,
client.name = client_name
),
err,
)]
#[allow(clippy::too_many_lines)]
async fn add(
&mut self,
mut rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
redirect_uris: Vec<Url>,
encrypted_client_secret: Option<String>,
grant_types: Vec<GrantType>,
contacts: Vec<String>,
client_name: Option<String>,
logo_uri: Option<Url>,
client_uri: Option<Url>,
policy_uri: Option<Url>,
tos_uri: Option<Url>,
jwks_uri: Option<Url>,
jwks: Option<PublicJsonWebKeySet>,
id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
initiate_login_uri: Option<Url>,
) -> Result<Client, Self::Error> {
let now = clock.now();
let id = Ulid::from_datetime_with_source(now.into(), rng);
tracing::Span::current().record("client.id", tracing::field::display(id));
let jwks_json = jwks
.as_ref()
.map(serde_json::to_value)
.transpose()
.map_err(DatabaseError::to_invalid_operation)?;
sqlx::query!(
r#"
INSERT INTO oauth2_clients
( oauth2_client_id
, encrypted_client_secret
, grant_type_authorization_code
, grant_type_refresh_token
, client_name
, logo_uri
, client_uri
, policy_uri
, tos_uri
, jwks_uri
, jwks
, id_token_signed_response_alg
, userinfo_signed_response_alg
, token_endpoint_auth_method
, token_endpoint_auth_signing_alg
, initiate_login_uri
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
"#,
Uuid::from(id),
encrypted_client_secret,
grant_types.contains(&GrantType::AuthorizationCode),
grant_types.contains(&GrantType::RefreshToken),
client_name,
logo_uri.as_ref().map(Url::as_str),
client_uri.as_ref().map(Url::as_str),
policy_uri.as_ref().map(Url::as_str),
tos_uri.as_ref().map(Url::as_str),
jwks_uri.as_ref().map(Url::as_str),
jwks_json,
id_token_signed_response_alg
.as_ref()
.map(ToString::to_string),
userinfo_signed_response_alg
.as_ref()
.map(ToString::to_string),
token_endpoint_auth_method.as_ref().map(ToString::to_string),
token_endpoint_auth_signing_alg
.as_ref()
.map(ToString::to_string),
initiate_login_uri.as_ref().map(Url::as_str),
)
.traced()
.execute(&mut *self.conn)
.await?;
{
let span = info_span!(
"db.oauth2_client.add.redirect_uris",
db.statement = tracing::field::Empty,
client.id = %id,
);
let (uri_ids, redirect_uris): (Vec<Uuid>, Vec<String>) = redirect_uris
.iter()
.map(|uri| {
(
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
uri.as_str().to_owned(),
)
})
.unzip();
sqlx::query!(
r#"
INSERT INTO oauth2_client_redirect_uris
(oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)
SELECT id, $2, redirect_uri
FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)
"#,
&uri_ids,
Uuid::from(id),
&redirect_uris,
)
.record(&span)
.execute(&mut *self.conn)
.instrument(span)
.await?;
}
let jwks = match (jwks, jwks_uri) {
(None, None) => None,
(Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
(None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
_ => return Err(DatabaseError::invalid_operation()),
};
Ok(Client {
id,
client_id: id.to_string(),
encrypted_client_secret,
redirect_uris,
response_types: vec![
OAuthAuthorizationEndpointResponseType::Code,
OAuthAuthorizationEndpointResponseType::IdToken,
OAuthAuthorizationEndpointResponseType::None,
],
grant_types,
contacts,
client_name,
logo_uri,
client_uri,
policy_uri,
tos_uri,
jwks,
id_token_signed_response_alg,
userinfo_signed_response_alg,
token_endpoint_auth_method,
token_endpoint_auth_signing_alg,
initiate_login_uri,
})
}
#[tracing::instrument(
name = "db.oauth2_client.add_from_config",
skip_all,
fields(
db.statement,
client.id = %client_id,
),
err,
)]
async fn add_from_config(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
client_id: Ulid,
client_auth_method: OAuthClientAuthenticationMethod,
encrypted_client_secret: Option<String>,
jwks: Option<PublicJsonWebKeySet>,
jwks_uri: Option<Url>,
redirect_uris: Vec<Url>,
) -> Result<Client, Self::Error> {
let jwks_json = jwks
.as_ref()
.map(serde_json::to_value)
.transpose()
.map_err(DatabaseError::to_invalid_operation)?;
let client_auth_method = client_auth_method.to_string();
sqlx::query!(
r#"
INSERT INTO oauth2_clients
( oauth2_client_id
, encrypted_client_secret
, grant_type_authorization_code
, grant_type_refresh_token
, token_endpoint_auth_method
, jwks
, jwks_uri
)
VALUES
($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (oauth2_client_id)
DO
UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret
, grant_type_authorization_code = EXCLUDED.grant_type_authorization_code
, grant_type_refresh_token = EXCLUDED.grant_type_refresh_token
, token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method
, jwks = EXCLUDED.jwks
, jwks_uri = EXCLUDED.jwks_uri
"#,
Uuid::from(client_id),
encrypted_client_secret,
true,
true,
client_auth_method,
jwks_json,
jwks_uri.as_ref().map(Url::as_str),
)
.traced()
.execute(&mut *self.conn)
.await?;
{
let span = info_span!(
"db.oauth2_client.add_from_config.redirect_uris",
client.id = %client_id,
db.statement = tracing::field::Empty,
);
let now = clock.now();
let (ids, redirect_uris): (Vec<Uuid>, Vec<String>) = redirect_uris
.iter()
.map(|uri| {
(
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut *rng)),
uri.as_str().to_owned(),
)
})
.unzip();
sqlx::query!(
r#"
INSERT INTO oauth2_client_redirect_uris
(oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)
SELECT id, $2, redirect_uri
FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)
"#,
&ids,
Uuid::from(client_id),
&redirect_uris,
)
.record(&span)
.execute(&mut *self.conn)
.instrument(span)
.await?;
}
let jwks = match (jwks, jwks_uri) {
(None, None) => None,
(Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
(None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
_ => return Err(DatabaseError::invalid_operation()),
};
Ok(Client {
id: client_id,
client_id: client_id.to_string(),
encrypted_client_secret,
redirect_uris,
response_types: vec![
OAuthAuthorizationEndpointResponseType::Code,
OAuthAuthorizationEndpointResponseType::IdToken,
OAuthAuthorizationEndpointResponseType::None,
],
grant_types: Vec::new(),
contacts: Vec::new(),
client_name: None,
logo_uri: None,
client_uri: None,
policy_uri: None,
tos_uri: None,
jwks,
id_token_signed_response_alg: None,
userinfo_signed_response_alg: None,
token_endpoint_auth_method: None,
token_endpoint_auth_signing_alg: None,
initiate_login_uri: None,
})
}
#[tracing::instrument(
name = "db.oauth2_client.get_consent_for_user",
skip_all,
fields(
db.statement,
%user.id,
%client.id,
),
err,
)]
async fn get_consent_for_user(
&mut self,
client: &Client,
user: &User,
) -> Result<Scope, Self::Error> {
let scope_tokens: Vec<String> = sqlx::query_scalar!(
r#"
SELECT scope_token
FROM oauth2_consents
WHERE user_id = $1 AND oauth2_client_id = $2
"#,
Uuid::from(user.id),
Uuid::from(client.id),
)
.fetch_all(&mut *self.conn)
.await?;
let scope: Result<Scope, _> = scope_tokens
.into_iter()
.map(|s| ScopeToken::from_str(&s))
.collect();
let scope = scope.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_consents")
.column("scope_token")
.source(e)
})?;
Ok(scope)
}
#[tracing::instrument(
skip_all,
fields(
db.statement,
%user.id,
%client.id,
%scope,
),
err,
)]
async fn give_consent_for_user(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
client: &Client,
user: &User,
scope: &Scope,
) -> Result<(), Self::Error> {
let now = clock.now();
let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
.iter()
.map(|token| {
(
token.to_string(),
Uuid::from(Ulid::from_datetime_with_source(now.into(), rng)),
)
})
.unzip();
sqlx::query!(
r#"
INSERT INTO oauth2_consents
(oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)
SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)
ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5
"#,
&ids,
Uuid::from(user.id),
Uuid::from(client.id),
&tokens,
now,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(())
}
}

View File

@@ -0,0 +1,371 @@
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! A module containing the PostgreSQL implementations of the OAuth2-related
//! repositories
mod access_token;
mod authorization_grant;
mod client;
mod refresh_token;
mod session;
pub use self::{
access_token::PgOAuth2AccessTokenRepository,
authorization_grant::PgOAuth2AuthorizationGrantRepository, client::PgOAuth2ClientRepository,
refresh_token::PgOAuth2RefreshTokenRepository, session::PgOAuth2SessionRepository,
};
#[cfg(test)]
mod tests {
use chrono::Duration;
use mas_data_model::AuthorizationCode;
use mas_storage::{clock::MockClock, Clock, Pagination, Repository};
use oauth2_types::{
requests::{GrantType, ResponseMode},
scope::{Scope, OPENID},
};
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
use sqlx::PgPool;
use ulid::Ulid;
use crate::PgRepository;
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_repositories(pool: PgPool) {
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Lookup a non-existing client
let client = repo.oauth2_client().lookup(Ulid::nil()).await.unwrap();
assert_eq!(client, None);
// Find a non-existing client by client id
let client = repo
.oauth2_client()
.find_by_client_id("some-client-id")
.await
.unwrap();
assert_eq!(client, None);
// Create a client
let client = repo
.oauth2_client()
.add(
&mut rng,
&clock,
vec!["https://example.com/redirect".parse().unwrap()],
None,
vec![GrantType::AuthorizationCode],
Vec::new(), // TODO: contacts are not yet saved
// vec!["contact@example.com".to_owned()],
Some("Test client".to_owned()),
Some("https://example.com/logo.png".parse().unwrap()),
Some("https://example.com/".parse().unwrap()),
Some("https://example.com/policy".parse().unwrap()),
Some("https://example.com/tos".parse().unwrap()),
Some("https://example.com/jwks.json".parse().unwrap()),
None,
None,
None,
None,
None,
Some("https://example.com/login".parse().unwrap()),
)
.await
.unwrap();
// Lookup the same client by id
let client_lookup = repo
.oauth2_client()
.lookup(client.id)
.await
.unwrap()
.expect("client not found");
assert_eq!(client, client_lookup);
// Find the same client by client id
let client_lookup = repo
.oauth2_client()
.find_by_client_id(&client.client_id)
.await
.unwrap()
.expect("client not found");
assert_eq!(client, client_lookup);
// Lookup a non-existing grant
let grant = repo
.oauth2_authorization_grant()
.lookup(Ulid::nil())
.await
.unwrap();
assert_eq!(grant, None);
// Find a non-existing grant by code
let grant = repo
.oauth2_authorization_grant()
.find_by_code("code")
.await
.unwrap();
assert_eq!(grant, None);
// Create an authorization grant
let grant = repo
.oauth2_authorization_grant()
.add(
&mut rng,
&clock,
&client,
"https://example.com/redirect".parse().unwrap(),
Scope::from_iter([OPENID]),
Some(AuthorizationCode {
code: "code".to_owned(),
pkce: None,
}),
Some("state".to_owned()),
Some("nonce".to_owned()),
None,
ResponseMode::Query,
true,
false,
)
.await
.unwrap();
assert!(grant.is_pending());
// Lookup the same grant by id
let grant_lookup = repo
.oauth2_authorization_grant()
.lookup(grant.id)
.await
.unwrap()
.expect("grant not found");
assert_eq!(grant, grant_lookup);
// Find the same grant by code
let grant_lookup = repo
.oauth2_authorization_grant()
.find_by_code("code")
.await
.unwrap()
.expect("grant not found");
assert_eq!(grant, grant_lookup);
// Create a user and a start a user session
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
let user_session = repo
.browser_session()
.add(&mut rng, &clock, &user)
.await
.unwrap();
// Lookup the consent the user gave to the client
let consent = repo
.oauth2_client()
.get_consent_for_user(&client, &user)
.await
.unwrap();
assert!(consent.is_empty());
// Give consent to the client
let scope = Scope::from_iter([OPENID]);
repo.oauth2_client()
.give_consent_for_user(&mut rng, &clock, &client, &user, &scope)
.await
.unwrap();
// Lookup the consent the user gave to the client
let consent = repo
.oauth2_client()
.get_consent_for_user(&client, &user)
.await
.unwrap();
assert_eq!(scope, consent);
// Lookup a non-existing session
let session = repo.oauth2_session().lookup(Ulid::nil()).await.unwrap();
assert_eq!(session, None);
// Create a session out of the grant
let session = repo
.oauth2_session()
.create_from_grant(&mut rng, &clock, &grant, &user_session)
.await
.unwrap();
// Mark the grant as fulfilled
let grant = repo
.oauth2_authorization_grant()
.fulfill(&clock, &session, grant)
.await
.unwrap();
assert!(grant.is_fulfilled());
// Lookup the same session by id
let session_lookup = repo
.oauth2_session()
.lookup(session.id)
.await
.unwrap()
.expect("session not found");
assert_eq!(session, session_lookup);
// Mark the grant as exchanged
let grant = repo
.oauth2_authorization_grant()
.exchange(&clock, grant)
.await
.unwrap();
assert!(grant.is_exchanged());
// Lookup a non-existing token
let token = repo
.oauth2_access_token()
.lookup(Ulid::nil())
.await
.unwrap();
assert_eq!(token, None);
// Find a non-existing token
let token = repo
.oauth2_access_token()
.find_by_token("aabbcc")
.await
.unwrap();
assert_eq!(token, None);
// Create an access token
let access_token = repo
.oauth2_access_token()
.add(
&mut rng,
&clock,
&session,
"aabbcc".to_owned(),
Duration::minutes(5),
)
.await
.unwrap();
// Lookup the same token by id
let access_token_lookup = repo
.oauth2_access_token()
.lookup(access_token.id)
.await
.unwrap()
.expect("token not found");
assert_eq!(access_token, access_token_lookup);
// Find the same token by token
let access_token_lookup = repo
.oauth2_access_token()
.find_by_token("aabbcc")
.await
.unwrap()
.expect("token not found");
assert_eq!(access_token, access_token_lookup);
// Lookup a non-existing refresh token
let refresh_token = repo
.oauth2_refresh_token()
.lookup(Ulid::nil())
.await
.unwrap();
assert_eq!(refresh_token, None);
// Find a non-existing refresh token
let refresh_token = repo
.oauth2_refresh_token()
.find_by_token("aabbcc")
.await
.unwrap();
assert_eq!(refresh_token, None);
// Create a refresh token
let refresh_token = repo
.oauth2_refresh_token()
.add(
&mut rng,
&clock,
&session,
&access_token,
"aabbcc".to_owned(),
)
.await
.unwrap();
// Lookup the same refresh token by id
let refresh_token_lookup = repo
.oauth2_refresh_token()
.lookup(refresh_token.id)
.await
.unwrap()
.expect("refresh token not found");
assert_eq!(refresh_token, refresh_token_lookup);
// Find the same refresh token by token
let refresh_token_lookup = repo
.oauth2_refresh_token()
.find_by_token("aabbcc")
.await
.unwrap()
.expect("refresh token not found");
assert_eq!(refresh_token, refresh_token_lookup);
assert!(access_token.is_valid(clock.now()));
clock.advance(Duration::minutes(6));
assert!(!access_token.is_valid(clock.now()));
// XXX: we might want to create a new access token
clock.advance(Duration::minutes(-6)); // Go back in time
assert!(access_token.is_valid(clock.now()));
// Mark the access token as revoked
let access_token = repo
.oauth2_access_token()
.revoke(&clock, access_token)
.await
.unwrap();
assert!(!access_token.is_valid(clock.now()));
// Mark the refresh token as consumed
assert!(refresh_token.is_valid());
let refresh_token = repo
.oauth2_refresh_token()
.consume(&clock, refresh_token)
.await
.unwrap();
assert!(!refresh_token.is_valid());
// Mark the session as finished
assert!(session.is_valid());
let session = repo.oauth2_session().finish(&clock, session).await.unwrap();
assert!(!session.is_valid());
// The session should appear in the paginated list of sessions for the user
let sessions = repo
.oauth2_session()
.list_paginated(&user, Pagination::first(10))
.await
.unwrap();
assert!(!sessions.has_next_page);
assert_eq!(sessions.edges, vec![session]);
}
}

View File

@@ -0,0 +1,228 @@
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session};
use mas_storage::{oauth2::OAuth2RefreshTokenRepository, Clock};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt};
/// An implementation of [`OAuth2RefreshTokenRepository`] for a PostgreSQL
/// connection
pub struct PgOAuth2RefreshTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2RefreshTokenRepository<'c> {
/// Create a new [`PgOAuth2RefreshTokenRepository`] from an active
/// PostgreSQL connection
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct OAuth2RefreshTokenLookup {
oauth2_refresh_token_id: Uuid,
refresh_token: String,
created_at: DateTime<Utc>,
consumed_at: Option<DateTime<Utc>>,
oauth2_access_token_id: Option<Uuid>,
oauth2_session_id: Uuid,
}
impl From<OAuth2RefreshTokenLookup> for RefreshToken {
fn from(value: OAuth2RefreshTokenLookup) -> Self {
let state = match value.consumed_at {
None => RefreshTokenState::Valid,
Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
};
RefreshToken {
id: value.oauth2_refresh_token_id.into(),
state,
session_id: value.oauth2_session_id.into(),
refresh_token: value.refresh_token,
created_at: value.created_at,
access_token_id: value.oauth2_access_token_id.map(Ulid::from),
}
}
}
#[async_trait]
impl<'c> OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.oauth2_refresh_token.lookup",
skip_all,
fields(
db.statement,
refresh_token.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<RefreshToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2RefreshTokenLookup,
r#"
SELECT oauth2_refresh_token_id
, refresh_token
, created_at
, consumed_at
, oauth2_access_token_id
, oauth2_session_id
FROM oauth2_refresh_tokens
WHERE oauth2_refresh_token_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_refresh_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
refresh_token: &str,
) -> Result<Option<RefreshToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2RefreshTokenLookup,
r#"
SELECT oauth2_refresh_token_id
, refresh_token
, created_at
, consumed_at
, oauth2_access_token_id
, oauth2_session_id
FROM oauth2_refresh_tokens
WHERE refresh_token = $1
"#,
refresh_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_refresh_token.add",
skip_all,
fields(
db.statement,
%session.id,
user_session.id = %session.user_session_id,
client.id = %session.client_id,
refresh_token.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
session: &Session,
access_token: &AccessToken,
refresh_token: String,
) -> Result<RefreshToken, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_refresh_tokens
(oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,
refresh_token, created_at)
VALUES
($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.id),
Uuid::from(access_token.id),
refresh_token,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(RefreshToken {
id,
state: RefreshTokenState::default(),
session_id: session.id,
refresh_token,
access_token_id: Some(access_token.id),
created_at,
})
}
#[tracing::instrument(
name = "db.oauth2_refresh_token.consume",
skip_all,
fields(
db.statement,
%refresh_token.id,
session.id = %refresh_token.session_id,
),
err,
)]
async fn consume(
&mut self,
clock: &dyn Clock,
refresh_token: RefreshToken,
) -> Result<RefreshToken, Self::Error> {
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_refresh_tokens
SET consumed_at = $2
WHERE oauth2_refresh_token_id = $1
"#,
Uuid::from(refresh_token.id),
consumed_at,
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
refresh_token
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)
}
}

View File

@@ -0,0 +1,256 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{AuthorizationGrant, BrowserSession, Session, SessionState, User};
use mas_storage::{oauth2::OAuth2SessionRepository, Clock, Page, Pagination};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
use uuid::Uuid;
use crate::{
pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError,
LookupResultExt,
};
/// An implementation of [`OAuth2SessionRepository`] for a PostgreSQL connection
pub struct PgOAuth2SessionRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2SessionRepository<'c> {
/// Create a new [`PgOAuth2SessionRepository`] from an active PostgreSQL
/// connection
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)]
struct OAuthSessionLookup {
oauth2_session_id: Uuid,
user_session_id: Uuid,
oauth2_client_id: Uuid,
scope: String,
#[allow(dead_code)]
created_at: DateTime<Utc>,
finished_at: Option<DateTime<Utc>>,
}
impl TryFrom<OAuthSessionLookup> for Session {
type Error = DatabaseInconsistencyError;
fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
let id = Ulid::from(value.oauth2_session_id);
let scope = value.scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_sessions")
.column("scope")
.row(id)
.source(e)
})?;
let state = match value.finished_at {
None => SessionState::Valid,
Some(finished_at) => SessionState::Finished { finished_at },
};
Ok(Session {
id,
state,
created_at: value.created_at,
client_id: value.oauth2_client_id.into(),
user_session_id: value.user_session_id.into(),
scope,
})
}
}
#[async_trait]
impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.oauth2_session.lookup",
skip_all,
fields(
db.statement,
session.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
let res = sqlx::query_as!(
OAuthSessionLookup,
r#"
SELECT oauth2_session_id
, user_session_id
, oauth2_client_id
, scope
, created_at
, finished_at
FROM oauth2_sessions
WHERE oauth2_session_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(session) = res else { return Ok(None) };
Ok(Some(session.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_session.create_from_grant",
skip_all,
fields(
db.statement,
%user_session.id,
user.id = %user_session.user.id,
%grant.id,
client.id = %grant.client_id,
session.id,
session.scope = %grant.scope,
),
err,
)]
async fn create_from_grant(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
grant: &AuthorizationGrant,
user_session: &BrowserSession,
) -> Result<Session, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("session.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_sessions
( oauth2_session_id
, user_session_id
, oauth2_client_id
, scope
, created_at
)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(user_session.id),
Uuid::from(grant.client_id),
grant.scope.to_string(),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(Session {
id,
state: SessionState::Valid,
created_at,
user_session_id: user_session.id,
client_id: grant.client_id,
scope: grant.scope.clone(),
})
}
#[tracing::instrument(
name = "db.oauth2_session.finish",
skip_all,
fields(
db.statement,
%session.id,
%session.scope,
user_session.id = %session.user_session_id,
client.id = %session.client_id,
),
err,
)]
async fn finish(
&mut self,
clock: &dyn Clock,
session: Session,
) -> Result<Session, Self::Error> {
let finished_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_sessions
SET finished_at = $2
WHERE oauth2_session_id = $1
"#,
Uuid::from(session.id),
finished_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
session
.finish(finished_at)
.map_err(DatabaseError::to_invalid_operation)
}
#[tracing::instrument(
name = "db.oauth2_session.list_paginated",
skip_all,
fields(
db.statement,
%user.id,
%user.username,
),
err,
)]
async fn list_paginated(
&mut self,
user: &User,
pagination: Pagination,
) -> Result<Page<Session>, Self::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT oauth2_session_id
, user_session_id
, oauth2_client_id
, scope
, os.created_at
, os.finished_at
FROM oauth2_sessions os
INNER JOIN user_sessions USING (user_session_id)
"#,
);
query
.push(" WHERE user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("oauth2_session_id", pagination);
let edges: Vec<OAuthSessionLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination.process(edges).try_map(Session::try_from)?;
Ok(page)
}
}

View File

@@ -0,0 +1,78 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Utilities to manage paginated queries.
use mas_storage::{pagination::PaginationDirection, Pagination};
use sqlx::{Database, QueryBuilder};
use uuid::Uuid;
/// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination
/// to a query
pub trait QueryBuilderExt {
/// Add cursor-based pagination to a query, as used in paginated GraphQL
/// connections
fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self;
}
impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB>
where
DB: Database,
Uuid: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
i64: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
{
fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self {
// ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564
// 1. Start from the greedy query: SELECT * FROM table
// 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE`
// clause
if let Some(after) = pagination.after {
self.push(" AND ")
.push(id_field)
.push(" > ")
.push_bind(Uuid::from(after));
}
// 3. If the before argument is provided, add `id < parsed_cursor` to the
// `WHERE` clause
if let Some(before) = pagination.before {
self.push(" AND ")
.push(id_field)
.push(" < ")
.push_bind(Uuid::from(before));
}
match pagination.direction {
// 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the
// query
PaginationDirection::Forward => {
self.push(" ORDER BY ")
.push(id_field)
.push(" ASC LIMIT ")
.push_bind((pagination.count + 1) as i64);
}
// 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the
// query
PaginationDirection::Backward => {
self.push(" ORDER BY ")
.push(id_field)
.push(" DESC LIMIT ")
.push_bind((pagination.count + 1) as i64);
}
};
self
}
}

Some files were not shown because too many files have changed in this diff Show More