Merge pull request #733 from matrix-org/quenting/storage-repository
Repository pattern
This commit is contained in:
63
.github/workflows/docs.yaml
vendored
63
.github/workflows/docs.yaml
vendored
@@ -1,30 +1,69 @@
|
||||
name: Deploy the documentation
|
||||
name: Build and deploy the documentation
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
CARGO_NET_GIT_FETCH_WITH_CLI: "true"
|
||||
|
||||
jobs:
|
||||
pages:
|
||||
name: GitHub Pages
|
||||
build:
|
||||
name: Build the documentation
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout the code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install Rust toolchain
|
||||
run: |
|
||||
rustup toolchain install nightly
|
||||
rustup default nightly
|
||||
|
||||
- name: Setup Rust cache
|
||||
uses: Swatinem/rust-cache@v2
|
||||
|
||||
- name: Setup mdBook
|
||||
uses: peaceiris/actions-mdbook@adeb05db28a0c0004681db83893d56c0388ea9ea # v1.1.14
|
||||
uses: peaceiris/actions-mdbook@v1.2.0
|
||||
with:
|
||||
mdbook-version: '0.4.12'
|
||||
mdbook-version: '0.4.25'
|
||||
|
||||
- name: Build the documentation
|
||||
run: mdbook build
|
||||
|
||||
- name: Deploy to GitHub Pages
|
||||
uses: peaceiris/actions-gh-pages@64b46b4226a4a12da2239ba3ea5aa73e3163c75b # v3.8.0
|
||||
|
||||
- name: Build rustdoc
|
||||
run: cargo doc -Zrustdoc-map --workspace --lib --no-deps
|
||||
|
||||
- name: Move the Rust documentation within the mdBook
|
||||
run: mv target/doc target/book/rustdoc
|
||||
|
||||
- name: Upload GitHub Pages artifacts
|
||||
uses: actions/upload-pages-artifact@v1.0.7
|
||||
with:
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
publish_dir: ./target/book
|
||||
path: target/book/
|
||||
|
||||
deploy:
|
||||
name: Deploy the documentation on GitHub Pages
|
||||
runs-on: ubuntu-latest
|
||||
needs: build
|
||||
if: github.ref == 'refs/heads/main'
|
||||
|
||||
permissions:
|
||||
pages: write
|
||||
id-token: write
|
||||
|
||||
environment:
|
||||
name: github-pages
|
||||
url: ${{ steps.deployment.outputs.page_url }}
|
||||
|
||||
steps:
|
||||
- name: Deploy to GitHub Pages
|
||||
id: deployment
|
||||
uses: actions/deploy-pages@v1.2.3
|
||||
|
||||
30
Cargo.lock
generated
30
Cargo.lock
generated
@@ -2688,7 +2688,6 @@ dependencies = [
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"serde_with",
|
||||
"sqlx",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tower",
|
||||
@@ -2721,6 +2720,7 @@ dependencies = [
|
||||
"mas-router",
|
||||
"mas-spa",
|
||||
"mas-storage",
|
||||
"mas-storage-pg",
|
||||
"mas-tasks",
|
||||
"mas-templates",
|
||||
"oauth2-types",
|
||||
@@ -2821,8 +2821,8 @@ dependencies = [
|
||||
"mas-storage",
|
||||
"oauth2-types",
|
||||
"serde",
|
||||
"sqlx",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"ulid",
|
||||
"url",
|
||||
@@ -2859,6 +2859,7 @@ dependencies = [
|
||||
"mas-policy",
|
||||
"mas-router",
|
||||
"mas-storage",
|
||||
"mas-storage-pg",
|
||||
"mas-templates",
|
||||
"mime",
|
||||
"oauth2-types",
|
||||
@@ -3112,12 +3113,33 @@ dependencies = [
|
||||
name = "mas-storage"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"chrono",
|
||||
"futures-util",
|
||||
"mas-data-model",
|
||||
"mas-iana",
|
||||
"mas-jose",
|
||||
"oauth2-types",
|
||||
"rand_core 0.6.4",
|
||||
"thiserror",
|
||||
"ulid",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mas-storage-pg"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"chrono",
|
||||
"futures-util",
|
||||
"mas-data-model",
|
||||
"mas-iana",
|
||||
"mas-jose",
|
||||
"mas-storage",
|
||||
"oauth2-types",
|
||||
"rand 0.8.5",
|
||||
"rand_chacha 0.3.1",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sqlx",
|
||||
@@ -3135,6 +3157,7 @@ dependencies = [
|
||||
"async-trait",
|
||||
"futures-util",
|
||||
"mas-storage",
|
||||
"mas-storage-pg",
|
||||
"sqlx",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
@@ -5590,8 +5613,7 @@ checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81"
|
||||
[[package]]
|
||||
name = "ulid"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "13a3aaa69b04e5b66cc27309710a569ea23593612387d67daaf102e73aa974fd"
|
||||
source = "git+https://github.com/dylanhart/ulid-rs.git?rev=0b9295c2db2114cd87aa19abcc1fc00c16b272db#0b9295c2db2114cd87aa19abcc1fc00c16b272db"
|
||||
dependencies = [
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
|
||||
@@ -7,3 +7,8 @@ opt-level = 3
|
||||
|
||||
[profile.dev.package.sqlx-macros]
|
||||
opt-level = 3
|
||||
|
||||
# Until https://github.com/dylanhart/ulid-rs/pull/56 gets released
|
||||
[patch.crates-io.ulid]
|
||||
git = "https://github.com/dylanhart/ulid-rs.git"
|
||||
rev = "0b9295c2db2114cd87aa19abcc1fc00c16b272db"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
doc-valid-idents = ["OpenID", "OAuth", ".."]
|
||||
doc-valid-idents = ["OpenID", "OAuth", "..", "PostgreSQL"]
|
||||
|
||||
disallowed-methods = [
|
||||
{ path = "rand::thread_rng", reason = "do not create rngs on the fly, pass them as parameters" },
|
||||
|
||||
@@ -21,7 +21,6 @@ serde = "1.0.152"
|
||||
serde_with = "2.1.0"
|
||||
serde_urlencoded = "0.7.1"
|
||||
serde_json = "1.0.91"
|
||||
sqlx = "0.6.2"
|
||||
thiserror = "1.0.38"
|
||||
tokio = "1.24.1"
|
||||
tower = { version = "0.4.13", features = ["util"] }
|
||||
|
||||
@@ -31,10 +31,9 @@ use mas_http::HttpServiceExt;
|
||||
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||
use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_storage::{oauth2::client::lookup_client_by_client_id, DatabaseError};
|
||||
use mas_storage::{oauth2::OAuth2ClientRepository, RepositoryAccess};
|
||||
use serde::{de::DeserializeOwned, Deserialize};
|
||||
use serde_json::Value;
|
||||
use sqlx::PgExecutor;
|
||||
use thiserror::Error;
|
||||
use tower::{Service, ServiceExt};
|
||||
|
||||
@@ -73,10 +72,10 @@ pub enum Credentials {
|
||||
}
|
||||
|
||||
impl Credentials {
|
||||
pub async fn fetch(
|
||||
pub async fn fetch<E>(
|
||||
&self,
|
||||
executor: impl PgExecutor<'_>,
|
||||
) -> Result<Option<Client>, DatabaseError> {
|
||||
repo: &mut impl RepositoryAccess<Error = E>,
|
||||
) -> Result<Option<Client>, E> {
|
||||
let client_id = match self {
|
||||
Credentials::None { client_id }
|
||||
| Credentials::ClientSecretBasic { client_id, .. }
|
||||
@@ -84,7 +83,7 @@ impl Credentials {
|
||||
| Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
|
||||
};
|
||||
|
||||
lookup_client_by_client_id(executor, client_id).await
|
||||
repo.oauth2_client().find_by_client_id(client_id).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use data_encoding::{DecodeError, BASE64URL_NOPAD};
|
||||
use mas_storage::Clock;
|
||||
use rand::{Rng, RngCore};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, TimestampSeconds};
|
||||
@@ -108,22 +109,27 @@ pub struct ProtectedForm<T> {
|
||||
}
|
||||
|
||||
pub trait CsrfExt {
|
||||
fn csrf_token<R>(self, now: DateTime<Utc>, rng: R) -> (CsrfToken, Self)
|
||||
fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
|
||||
where
|
||||
R: RngCore;
|
||||
fn verify_form<T>(&self, now: DateTime<Utc>, form: ProtectedForm<T>) -> Result<T, CsrfError>;
|
||||
R: RngCore,
|
||||
C: Clock;
|
||||
fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
|
||||
where
|
||||
C: Clock;
|
||||
}
|
||||
|
||||
impl<K> CsrfExt for PrivateCookieJar<K> {
|
||||
fn csrf_token<R>(self, now: DateTime<Utc>, rng: R) -> (CsrfToken, Self)
|
||||
fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
|
||||
where
|
||||
R: RngCore,
|
||||
C: Clock,
|
||||
{
|
||||
let jar = self;
|
||||
let mut cookie = jar.get("csrf").unwrap_or_else(|| Cookie::new("csrf", ""));
|
||||
cookie.set_path("/");
|
||||
cookie.set_http_only(true);
|
||||
|
||||
let now = clock.now();
|
||||
let new_token = cookie
|
||||
.decode()
|
||||
.ok()
|
||||
@@ -136,10 +142,13 @@ impl<K> CsrfExt for PrivateCookieJar<K> {
|
||||
(new_token, jar)
|
||||
}
|
||||
|
||||
fn verify_form<T>(&self, now: DateTime<Utc>, form: ProtectedForm<T>) -> Result<T, CsrfError> {
|
||||
fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
|
||||
where
|
||||
C: Clock,
|
||||
{
|
||||
let cookie = self.get("csrf").ok_or(CsrfError::Missing)?;
|
||||
let token: CsrfToken = cookie.decode()?;
|
||||
let token = token.verify_expiration(now)?;
|
||||
let token = token.verify_expiration(clock.now())?;
|
||||
token.verify_form_value(&form.csrf)?;
|
||||
Ok(form.inner)
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ impl HttpClientFactory {
|
||||
Ok(layer.layer(client))
|
||||
}
|
||||
|
||||
/// Constructs a new [`HttpService`], suitable for [`mas_oidc_client`]
|
||||
/// Constructs a new [`HttpService`], suitable for `mas-oidc-client`
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
|
||||
@@ -14,9 +14,8 @@
|
||||
|
||||
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
|
||||
use mas_data_model::BrowserSession;
|
||||
use mas_storage::{user::lookup_active_session, DatabaseError};
|
||||
use mas_storage::{user::BrowserSessionRepository, RepositoryAccess};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{Executor, Postgres};
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::CookieExt;
|
||||
@@ -44,18 +43,24 @@ impl SessionInfo {
|
||||
}
|
||||
|
||||
/// Load the [`BrowserSession`] from database
|
||||
pub async fn load_session(
|
||||
pub async fn load_session<E>(
|
||||
&self,
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
) -> Result<Option<BrowserSession>, DatabaseError> {
|
||||
repo: &mut impl RepositoryAccess<Error = E>,
|
||||
) -> Result<Option<BrowserSession>, E> {
|
||||
let session_id = if let Some(id) = self.current {
|
||||
id
|
||||
} else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let res = lookup_active_session(executor, session_id).await?;
|
||||
Ok(res)
|
||||
let maybe_session = repo
|
||||
.browser_session()
|
||||
.lookup(session_id)
|
||||
.await?
|
||||
// Ensure that the session is still active
|
||||
.filter(BrowserSession::active);
|
||||
|
||||
Ok(maybe_session)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -27,9 +27,11 @@ use axum::{
|
||||
use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName};
|
||||
use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode};
|
||||
use mas_data_model::Session;
|
||||
use mas_storage::{oauth2::access_token::lookup_active_access_token, DatabaseError};
|
||||
use mas_storage::{
|
||||
oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository},
|
||||
Clock, RepositoryAccess,
|
||||
};
|
||||
use serde::{de::DeserializeOwned, Deserialize};
|
||||
use sqlx::PgConnection;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -49,16 +51,24 @@ enum AccessToken {
|
||||
}
|
||||
|
||||
impl AccessToken {
|
||||
pub async fn fetch(
|
||||
async fn fetch<E>(
|
||||
&self,
|
||||
conn: &mut PgConnection,
|
||||
) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> {
|
||||
repo: &mut impl RepositoryAccess<Error = E>,
|
||||
) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError<E>> {
|
||||
let token = match self {
|
||||
AccessToken::Form(t) | AccessToken::Header(t) => t,
|
||||
AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
|
||||
};
|
||||
|
||||
let (token, session) = lookup_active_access_token(conn, token.as_str())
|
||||
let token = repo
|
||||
.oauth2_access_token()
|
||||
.find_by_token(token.as_str())
|
||||
.await?
|
||||
.ok_or(AuthorizationVerificationError::InvalidToken)?;
|
||||
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.lookup(token.session_id)
|
||||
.await?
|
||||
.ok_or(AuthorizationVerificationError::InvalidToken)?;
|
||||
|
||||
@@ -74,26 +84,36 @@ pub struct UserAuthorization<F = ()> {
|
||||
|
||||
impl<F: Send> UserAuthorization<F> {
|
||||
// TODO: take scopes to validate as parameter
|
||||
pub async fn protected_form(
|
||||
pub async fn protected_form<E>(
|
||||
self,
|
||||
conn: &mut PgConnection,
|
||||
) -> Result<(Session, F), AuthorizationVerificationError> {
|
||||
repo: &mut impl RepositoryAccess<Error = E>,
|
||||
clock: &impl Clock,
|
||||
) -> Result<(Session, F), AuthorizationVerificationError<E>> {
|
||||
let form = match self.form {
|
||||
Some(f) => f,
|
||||
None => return Err(AuthorizationVerificationError::MissingForm),
|
||||
};
|
||||
|
||||
let (_token, session) = self.access_token.fetch(conn).await?;
|
||||
let (token, session) = self.access_token.fetch(repo).await?;
|
||||
|
||||
if !token.is_valid(clock.now()) || !session.is_valid() {
|
||||
return Err(AuthorizationVerificationError::InvalidToken);
|
||||
}
|
||||
|
||||
Ok((session, form))
|
||||
}
|
||||
|
||||
// TODO: take scopes to validate as parameter
|
||||
pub async fn protected(
|
||||
pub async fn protected<E>(
|
||||
self,
|
||||
conn: &mut PgConnection,
|
||||
) -> Result<Session, AuthorizationVerificationError> {
|
||||
let (_token, session) = self.access_token.fetch(conn).await?;
|
||||
repo: &mut impl RepositoryAccess<Error = E>,
|
||||
clock: &impl Clock,
|
||||
) -> Result<Session, AuthorizationVerificationError<E>> {
|
||||
let (token, session) = self.access_token.fetch(repo).await?;
|
||||
|
||||
if !token.is_valid(clock.now()) || !session.is_valid() {
|
||||
return Err(AuthorizationVerificationError::InvalidToken);
|
||||
}
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
@@ -107,7 +127,7 @@ pub enum UserAuthorizationError {
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthorizationVerificationError {
|
||||
pub enum AuthorizationVerificationError<E> {
|
||||
#[error("missing token")]
|
||||
MissingToken,
|
||||
|
||||
@@ -118,7 +138,7 @@ pub enum AuthorizationVerificationError {
|
||||
MissingForm,
|
||||
|
||||
#[error(transparent)]
|
||||
Internal(#[from] DatabaseError),
|
||||
Internal(#[from] E),
|
||||
}
|
||||
|
||||
enum BearerError {
|
||||
@@ -226,7 +246,10 @@ impl IntoResponse for UserAuthorizationError {
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for AuthorizationVerificationError {
|
||||
impl<E> IntoResponse for AuthorizationVerificationError<E>
|
||||
where
|
||||
E: ToString,
|
||||
{
|
||||
fn into_response(self) -> Response {
|
||||
match self {
|
||||
Self::MissingForm | Self::MissingToken => {
|
||||
|
||||
@@ -50,6 +50,7 @@ mas-policy = { path = "../policy" }
|
||||
mas-router = { path = "../router" }
|
||||
mas-spa = { path = "../spa" }
|
||||
mas-storage = { path = "../storage" }
|
||||
mas-storage-pg = { path = "../storage-pg" }
|
||||
mas-tasks = { path = "../tasks" }
|
||||
mas-templates = { path = "../templates" }
|
||||
oauth2-types = { path = "../oauth2-types" }
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
use clap::Parser;
|
||||
use mas_config::{ConfigurationSection, RootConfig};
|
||||
use rand::SeedableRng;
|
||||
use tracing::info;
|
||||
use tracing::{info, info_span};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
pub(super) struct Options {
|
||||
@@ -40,6 +40,8 @@ impl Options {
|
||||
use Subcommand as SC;
|
||||
match &self.subcommand {
|
||||
SC::Dump => {
|
||||
let _span = info_span!("cli.config.dump").entered();
|
||||
|
||||
let config: RootConfig = root.load_config()?;
|
||||
|
||||
serde_yaml::to_writer(std::io::stdout(), &config)?;
|
||||
@@ -47,11 +49,15 @@ impl Options {
|
||||
Ok(())
|
||||
}
|
||||
SC::Check => {
|
||||
let _span = info_span!("cli.config.check").entered();
|
||||
|
||||
let _config: RootConfig = root.load_config()?;
|
||||
info!(path = ?root.config, "Configuration file looks good");
|
||||
Ok(())
|
||||
}
|
||||
SC::Generate => {
|
||||
let _span = info_span!("cli.config.generate").entered();
|
||||
|
||||
// XXX: we should disallow SeedableRng::from_entropy
|
||||
let rng = rand_chacha::ChaChaRng::from_entropy();
|
||||
let config = RootConfig::load_and_generate(rng).await?;
|
||||
|
||||
@@ -15,7 +15,8 @@
|
||||
use anyhow::Context;
|
||||
use clap::Parser;
|
||||
use mas_config::DatabaseConfig;
|
||||
use mas_storage::MIGRATOR;
|
||||
use mas_storage_pg::MIGRATOR;
|
||||
use tracing::{info_span, Instrument};
|
||||
|
||||
use crate::util::database_from_config;
|
||||
|
||||
@@ -33,12 +34,14 @@ enum Subcommand {
|
||||
|
||||
impl Options {
|
||||
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
|
||||
let _span = info_span!("cli.database.migrate").entered();
|
||||
let config: DatabaseConfig = root.load_config()?;
|
||||
let pool = database_from_config(&config).await?;
|
||||
|
||||
// Run pending migrations
|
||||
MIGRATOR
|
||||
.run(&pool)
|
||||
.instrument(info_span!("db.migrate"))
|
||||
.await
|
||||
.context("could not run migrations")?;
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ use mas_handlers::HttpClientFactory;
|
||||
use mas_http::HttpServiceExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tower::{Service, ServiceExt};
|
||||
use tracing::info;
|
||||
use tracing::{info, info_span};
|
||||
|
||||
use crate::util::policy_factory_from_config;
|
||||
|
||||
@@ -74,6 +74,7 @@ impl Options {
|
||||
json: false,
|
||||
url,
|
||||
} => {
|
||||
let _span = info_span!("cli.debug.http").entered();
|
||||
let mut client = http_client_factory.client("cli-debug-http").await?;
|
||||
let request = hyper::Request::builder()
|
||||
.uri(url)
|
||||
@@ -98,6 +99,7 @@ impl Options {
|
||||
json: true,
|
||||
url,
|
||||
} => {
|
||||
let _span = info_span!("cli.debug.http").entered();
|
||||
let mut client = http_client_factory
|
||||
.client("cli-debug-http")
|
||||
.await?
|
||||
@@ -122,6 +124,7 @@ impl Options {
|
||||
}
|
||||
|
||||
SC::Policy => {
|
||||
let _span = info_span!("cli.debug.policy").entered();
|
||||
let config: PolicyConfig = root.load_config()?;
|
||||
info!("Loading and compiling the policy module");
|
||||
let policy_factory = policy_factory_from_config(&config).await?;
|
||||
|
||||
@@ -18,15 +18,15 @@ use mas_config::{DatabaseConfig, PasswordsConfig, RootConfig};
|
||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::{
|
||||
oauth2::client::{insert_client_from_config, lookup_client, truncate_clients},
|
||||
user::{
|
||||
add_user_password, lookup_user_by_username, lookup_user_email, mark_user_email_as_verified,
|
||||
},
|
||||
Clock,
|
||||
oauth2::OAuth2ClientRepository,
|
||||
upstream_oauth2::UpstreamOAuthProviderRepository,
|
||||
user::{UserEmailRepository, UserPasswordRepository, UserRepository},
|
||||
Repository, RepositoryAccess, SystemClock,
|
||||
};
|
||||
use mas_storage_pg::PgRepository;
|
||||
use oauth2_types::scope::Scope;
|
||||
use rand::SeedableRng;
|
||||
use tracing::{info, warn};
|
||||
use tracing::{info, info_span, warn};
|
||||
|
||||
use crate::util::{database_from_config, password_manager_from_config};
|
||||
|
||||
@@ -147,9 +147,9 @@ enum Subcommand {
|
||||
|
||||
/// Import clients from config
|
||||
ImportClients {
|
||||
/// Remove all clients before importing
|
||||
/// Update existing clients
|
||||
#[arg(long)]
|
||||
truncate: bool,
|
||||
update: bool,
|
||||
},
|
||||
|
||||
/// Set a user password
|
||||
@@ -188,20 +188,25 @@ impl Options {
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
|
||||
use Subcommand as SC;
|
||||
let clock = Clock::default();
|
||||
let clock = SystemClock::default();
|
||||
// XXX: we should disallow SeedableRng::from_entropy
|
||||
let mut rng = rand_chacha::ChaChaRng::from_entropy();
|
||||
|
||||
match &self.subcommand {
|
||||
SC::SetPassword { username, password } => {
|
||||
let _span =
|
||||
info_span!("cli.manage.set_password", user.username = %username).entered();
|
||||
|
||||
let database_config: DatabaseConfig = root.load_config()?;
|
||||
let passwords_config: PasswordsConfig = root.load_config()?;
|
||||
|
||||
let pool = database_from_config(&database_config).await?;
|
||||
let password_manager = password_manager_from_config(&passwords_config).await?;
|
||||
|
||||
let mut txn = pool.begin().await?;
|
||||
let user = lookup_user_by_username(&mut txn, username)
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
let user = repo
|
||||
.user()
|
||||
.find_by_username(username)
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
|
||||
@@ -209,89 +214,96 @@ impl Options {
|
||||
|
||||
let (version, hashed_password) = password_manager.hash(&mut rng, password).await?;
|
||||
|
||||
add_user_password(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&user,
|
||||
version,
|
||||
hashed_password,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
repo.user_password()
|
||||
.add(&mut rng, &clock, &user, version, hashed_password, None)
|
||||
.await?;
|
||||
|
||||
info!(%user.id, %user.username, "Password changed");
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
SC::VerifyEmail { username, email } => {
|
||||
let _span = info_span!(
|
||||
"cli.manage.verify_email",
|
||||
user.username = username,
|
||||
user_email.email = email
|
||||
)
|
||||
.entered();
|
||||
|
||||
let config: DatabaseConfig = root.load_config()?;
|
||||
let pool = database_from_config(&config).await?;
|
||||
let mut txn = pool.begin().await?;
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
|
||||
let user = lookup_user_by_username(&mut txn, username)
|
||||
let user = repo
|
||||
.user()
|
||||
.find_by_username(username)
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
let email = lookup_user_email(&mut txn, &user, email)
|
||||
|
||||
let email = repo
|
||||
.user_email()
|
||||
.find(&user, email)
|
||||
.await?
|
||||
.context("Email not found")?;
|
||||
let email = mark_user_email_as_verified(&mut txn, &clock, email).await?;
|
||||
let email = repo.user_email().mark_as_verified(&clock, email).await?;
|
||||
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
info!(?email, "Email marked as verified");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
SC::ImportClients { truncate } => {
|
||||
SC::ImportClients { update } => {
|
||||
let _span = info_span!("cli.manage.import_clients").entered();
|
||||
|
||||
let config: RootConfig = root.load_config()?;
|
||||
let pool = database_from_config(&config.database).await?;
|
||||
let encrypter = config.secrets.encrypter();
|
||||
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
if *truncate {
|
||||
warn!("Removing all clients first");
|
||||
truncate_clients(&mut txn).await?;
|
||||
}
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
|
||||
for client in config.clients.iter() {
|
||||
let client_id = client.client_id;
|
||||
let res = lookup_client(&mut txn, client_id).await?;
|
||||
if res.is_some() {
|
||||
warn!(%client_id, "Skipping already imported client");
|
||||
|
||||
let existing = repo.oauth2_client().lookup(client_id).await?.is_some();
|
||||
if !update && existing {
|
||||
warn!(%client_id, "Skipping already imported client. Run with --update to update existing clients.");
|
||||
continue;
|
||||
}
|
||||
|
||||
info!(%client_id, "Importing client");
|
||||
if existing {
|
||||
info!(%client_id, "Updating client");
|
||||
} else {
|
||||
info!(%client_id, "Importing client");
|
||||
}
|
||||
|
||||
let client_secret = client.client_secret();
|
||||
let client_auth_method = client.client_auth_method();
|
||||
let jwks = client.jwks();
|
||||
let jwks_uri = client.jwks_uri();
|
||||
let redirect_uris = &client.redirect_uris;
|
||||
|
||||
// TODO: should be moved somewhere else
|
||||
let encrypted_client_secret = client_secret
|
||||
.map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes()))
|
||||
.transpose()?;
|
||||
|
||||
insert_client_from_config(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
client_id,
|
||||
client_auth_method,
|
||||
encrypted_client_secret.as_deref(),
|
||||
jwks,
|
||||
jwks_uri,
|
||||
redirect_uris,
|
||||
)
|
||||
.await?;
|
||||
repo.oauth2_client()
|
||||
.add_from_config(
|
||||
&mut rng,
|
||||
&clock,
|
||||
client_id,
|
||||
client_auth_method,
|
||||
encrypted_client_secret,
|
||||
jwks.cloned(),
|
||||
jwks_uri.cloned(),
|
||||
client.redirect_uris.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -304,11 +316,18 @@ impl Options {
|
||||
client_secret,
|
||||
signing_alg,
|
||||
} => {
|
||||
let _span = info_span!(
|
||||
"cli.manage.add_oauth_upstream",
|
||||
upstream_oauth_provider.issuer = issuer,
|
||||
upstream_oauth_provider.client_id = client_id,
|
||||
)
|
||||
.entered();
|
||||
|
||||
let config: RootConfig = root.load_config()?;
|
||||
let encrypter = config.secrets.encrypter();
|
||||
let pool = database_from_config(&config.database).await?;
|
||||
let url_builder = UrlBuilder::new(config.http.public_base);
|
||||
let mut conn = pool.acquire().await?;
|
||||
let mut repo = PgRepository::from_pool(&pool).await?;
|
||||
|
||||
let requires_client_secret = token_endpoint_auth_method.requires_client_secret();
|
||||
|
||||
@@ -329,18 +348,19 @@ impl Options {
|
||||
.map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes()))
|
||||
.transpose()?;
|
||||
|
||||
let provider = mas_storage::upstream_oauth2::add_provider(
|
||||
&mut conn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
issuer.clone(),
|
||||
scope.clone(),
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_signing_alg,
|
||||
client_id.clone(),
|
||||
encrypted_client_secret,
|
||||
)
|
||||
.await?;
|
||||
let provider = repo
|
||||
.upstream_oauth_provider()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
issuer.clone(),
|
||||
scope.clone(),
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_signing_alg,
|
||||
client_id.clone(),
|
||||
encrypted_client_secret,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
|
||||
let auth_uri = url_builder.upstream_oauth_authorize(provider.id);
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -21,10 +21,10 @@ use mas_config::RootConfig;
|
||||
use mas_handlers::{AppState, HttpClientFactory, MatrixHomeserver};
|
||||
use mas_listener::{server::Server, shutdown::ShutdownStream};
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::MIGRATOR;
|
||||
use mas_storage_pg::MIGRATOR;
|
||||
use mas_tasks::TaskQueue;
|
||||
use tokio::signal::unix::SignalKind;
|
||||
use tracing::{info, warn};
|
||||
use tracing::{info, info_span, warn, Instrument};
|
||||
|
||||
use crate::util::{
|
||||
database_from_config, mailer_from_config, password_manager_from_config,
|
||||
@@ -45,6 +45,7 @@ pub(super) struct Options {
|
||||
impl Options {
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
|
||||
let span = info_span!("cli.run.init").entered();
|
||||
let config: RootConfig = root.load_config()?;
|
||||
|
||||
// Connect to the database
|
||||
@@ -55,6 +56,7 @@ impl Options {
|
||||
info!("Running pending migrations");
|
||||
MIGRATOR
|
||||
.run(&pool)
|
||||
.instrument(info_span!("db.migrate"))
|
||||
.await
|
||||
.context("could not run migrations")?;
|
||||
}
|
||||
@@ -100,7 +102,7 @@ impl Options {
|
||||
watch_templates(&templates).await?;
|
||||
}
|
||||
|
||||
let graphql_schema = mas_handlers::graphql_schema(&pool);
|
||||
let graphql_schema = mas_handlers::graphql_schema();
|
||||
|
||||
// Maximum 50 outgoing HTTP requests at a time
|
||||
let http_client_factory = HttpClientFactory::new(50);
|
||||
@@ -186,6 +188,8 @@ impl Options {
|
||||
.with_signal(SignalKind::terminate())?
|
||||
.with_signal(SignalKind::interrupt())?;
|
||||
|
||||
span.exit();
|
||||
|
||||
mas_listener::server::run_servers(servers, shutdown).await;
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -14,9 +14,10 @@
|
||||
|
||||
use camino::Utf8PathBuf;
|
||||
use clap::Parser;
|
||||
use mas_storage::Clock;
|
||||
use mas_storage::{Clock, SystemClock};
|
||||
use mas_templates::Templates;
|
||||
use rand::SeedableRng;
|
||||
use tracing::info_span;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
pub(super) struct Options {
|
||||
@@ -38,7 +39,9 @@ impl Options {
|
||||
use Subcommand as SC;
|
||||
match &self.subcommand {
|
||||
SC::Check { path } => {
|
||||
let clock = Clock::default();
|
||||
let _span = info_span!("cli.templates.check").entered();
|
||||
|
||||
let clock = SystemClock::default();
|
||||
// XXX: we should disallow SeedableRng::from_entropy
|
||||
let mut rng = rand_chacha::ChaChaRng::from_entropy();
|
||||
let url_builder = mas_router::UrlBuilder::new("https://example.com/".parse()?);
|
||||
|
||||
@@ -110,6 +110,7 @@ pub async fn templates_from_config(
|
||||
Templates::load(config.path.clone(), url_builder.clone()).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "db.connect", skip_all, err(Debug))]
|
||||
pub async fn database_from_config(config: &DatabaseConfig) -> Result<PgPool, anyhow::Error> {
|
||||
let mut options = match &config.options {
|
||||
DatabaseConnectConfig::Uri { uri } => uri
|
||||
|
||||
@@ -86,6 +86,7 @@ impl SecretsConfig {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error when a key could not be imported
|
||||
#[tracing::instrument(name = "secrets.load", skip_all, err(Debug))]
|
||||
pub async fn key_store(&self) -> anyhow::Result<Keystore> {
|
||||
let mut keys = Vec::with_capacity(self.keys.len());
|
||||
for item in &self.keys {
|
||||
|
||||
@@ -11,8 +11,8 @@ thiserror = "1.0.38"
|
||||
serde = "1.0.152"
|
||||
url = { version = "2.3.1", features = ["serde"] }
|
||||
crc = "3.0.0"
|
||||
ulid = { version = "1.0.0", features = ["serde"] }
|
||||
rand = "0.8.5"
|
||||
ulid = "1.0.0"
|
||||
rand_chacha = "0.3.1"
|
||||
|
||||
mas-iana = { path = "../iana" }
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -12,18 +12,13 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use oauth2_types::scope::ScopeToken;
|
||||
use rand::{
|
||||
distributions::{Alphanumeric, DistString},
|
||||
Rng,
|
||||
RngCore,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
|
||||
use crate::User;
|
||||
|
||||
static DEVICE_ID_LENGTH: usize = 10;
|
||||
|
||||
@@ -53,7 +48,7 @@ impl Device {
|
||||
}
|
||||
|
||||
/// Generate a random device ID
|
||||
pub fn generate<R: Rng + ?Sized>(rng: &mut R) -> Self {
|
||||
pub fn generate<R: RngCore + ?Sized>(rng: &mut R) -> Self {
|
||||
let id: String = Alphanumeric.sample_string(rng, DEVICE_ID_LENGTH);
|
||||
Self { id }
|
||||
}
|
||||
@@ -81,50 +76,3 @@ impl TryFrom<String> for Device {
|
||||
Ok(Self { id })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct CompatSession {
|
||||
pub id: Ulid,
|
||||
pub user: User,
|
||||
pub device: Device,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub finished_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CompatAccessToken {
|
||||
pub id: Ulid,
|
||||
pub token: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CompatRefreshToken {
|
||||
pub id: Ulid,
|
||||
pub token: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub enum CompatSsoLoginState {
|
||||
Pending,
|
||||
Fulfilled {
|
||||
fulfilled_at: DateTime<Utc>,
|
||||
session: CompatSession,
|
||||
},
|
||||
Exchanged {
|
||||
fulfilled_at: DateTime<Utc>,
|
||||
exchanged_at: DateTime<Utc>,
|
||||
session: CompatSession,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct CompatSsoLogin {
|
||||
pub id: Ulid,
|
||||
pub redirect_uri: Url,
|
||||
pub login_token: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub state: CompatSsoLoginState,
|
||||
}
|
||||
106
crates/data-model/src/compat/mod.rs
Normal file
106
crates/data-model/src/compat/mod.rs
Normal file
@@ -0,0 +1,106 @@
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use ulid::Ulid;
|
||||
|
||||
mod device;
|
||||
mod session;
|
||||
mod sso_login;
|
||||
|
||||
pub use self::{
|
||||
device::Device,
|
||||
session::{CompatSession, CompatSessionState},
|
||||
sso_login::{CompatSsoLogin, CompatSsoLoginState},
|
||||
};
|
||||
use crate::InvalidTransitionError;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CompatAccessToken {
|
||||
pub id: Ulid,
|
||||
pub session_id: Ulid,
|
||||
pub token: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl CompatAccessToken {
|
||||
#[must_use]
|
||||
pub fn is_valid(&self, now: DateTime<Utc>) -> bool {
|
||||
if let Some(expires_at) = self.expires_at {
|
||||
expires_at > now
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub enum CompatRefreshTokenState {
|
||||
#[default]
|
||||
Valid,
|
||||
Consumed {
|
||||
consumed_at: DateTime<Utc>,
|
||||
},
|
||||
}
|
||||
|
||||
impl CompatRefreshTokenState {
|
||||
/// Returns `true` if the compat refresh token state is [`Valid`].
|
||||
///
|
||||
/// [`Valid`]: CompatRefreshTokenState::Valid
|
||||
#[must_use]
|
||||
pub fn is_valid(&self) -> bool {
|
||||
matches!(self, Self::Valid)
|
||||
}
|
||||
|
||||
/// Returns `true` if the compat refresh token state is [`Consumed`].
|
||||
///
|
||||
/// [`Consumed`]: CompatRefreshTokenState::Consumed
|
||||
#[must_use]
|
||||
pub fn is_consumed(&self) -> bool {
|
||||
matches!(self, Self::Consumed { .. })
|
||||
}
|
||||
|
||||
pub fn consume(self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
match self {
|
||||
Self::Valid => Ok(Self::Consumed { consumed_at }),
|
||||
Self::Consumed { .. } => Err(InvalidTransitionError),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CompatRefreshToken {
|
||||
pub id: Ulid,
|
||||
pub state: CompatRefreshTokenState,
|
||||
pub session_id: Ulid,
|
||||
pub access_token_id: Ulid,
|
||||
pub token: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl std::ops::Deref for CompatRefreshToken {
|
||||
type Target = CompatRefreshTokenState;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.state
|
||||
}
|
||||
}
|
||||
|
||||
impl CompatRefreshToken {
|
||||
pub fn consume(mut self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
self.state = self.state.consume(consumed_at)?;
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
86
crates/data-model/src/compat/session.rs
Normal file
86
crates/data-model/src/compat/session.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::Serialize;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::Device;
|
||||
use crate::InvalidTransitionError;
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
|
||||
pub enum CompatSessionState {
|
||||
#[default]
|
||||
Valid,
|
||||
Finished {
|
||||
finished_at: DateTime<Utc>,
|
||||
},
|
||||
}
|
||||
|
||||
impl CompatSessionState {
|
||||
/// Returns `true` if the compta session state is [`Valid`].
|
||||
///
|
||||
/// [`Valid`]: CompatSessionState::Valid
|
||||
#[must_use]
|
||||
pub fn is_valid(&self) -> bool {
|
||||
matches!(self, Self::Valid)
|
||||
}
|
||||
|
||||
/// Returns `true` if the compta session state is [`Finished`].
|
||||
///
|
||||
/// [`Finished`]: CompatSessionState::Finished
|
||||
#[must_use]
|
||||
pub fn is_finished(&self) -> bool {
|
||||
matches!(self, Self::Finished { .. })
|
||||
}
|
||||
|
||||
pub fn finish(self, finished_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
match self {
|
||||
Self::Valid => Ok(Self::Finished { finished_at }),
|
||||
Self::Finished { .. } => Err(InvalidTransitionError),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn finished_at(&self) -> Option<DateTime<Utc>> {
|
||||
match self {
|
||||
CompatSessionState::Valid => None,
|
||||
CompatSessionState::Finished { finished_at } => Some(*finished_at),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct CompatSession {
|
||||
pub id: Ulid,
|
||||
pub state: CompatSessionState,
|
||||
pub user_id: Ulid,
|
||||
pub device: Device,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl std::ops::Deref for CompatSession {
|
||||
type Target = CompatSessionState;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.state
|
||||
}
|
||||
}
|
||||
|
||||
impl CompatSession {
|
||||
pub fn finish(mut self, finished_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
self.state = self.state.finish(finished_at)?;
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
151
crates/data-model/src/compat/sso_login.rs
Normal file
151
crates/data-model/src/compat/sso_login.rs
Normal file
@@ -0,0 +1,151 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::Serialize;
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
|
||||
use super::CompatSession;
|
||||
use crate::InvalidTransitionError;
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
|
||||
pub enum CompatSsoLoginState {
|
||||
#[default]
|
||||
Pending,
|
||||
Fulfilled {
|
||||
fulfilled_at: DateTime<Utc>,
|
||||
session_id: Ulid,
|
||||
},
|
||||
Exchanged {
|
||||
fulfilled_at: DateTime<Utc>,
|
||||
exchanged_at: DateTime<Utc>,
|
||||
session_id: Ulid,
|
||||
},
|
||||
}
|
||||
|
||||
impl CompatSsoLoginState {
|
||||
/// Returns `true` if the compat sso login state is [`Pending`].
|
||||
///
|
||||
/// [`Pending`]: CompatSsoLoginState::Pending
|
||||
#[must_use]
|
||||
pub fn is_pending(&self) -> bool {
|
||||
matches!(self, Self::Pending)
|
||||
}
|
||||
|
||||
/// Returns `true` if the compat sso login state is [`Fulfilled`].
|
||||
///
|
||||
/// [`Fulfilled`]: CompatSsoLoginState::Fulfilled
|
||||
#[must_use]
|
||||
pub fn is_fulfilled(&self) -> bool {
|
||||
matches!(self, Self::Fulfilled { .. })
|
||||
}
|
||||
|
||||
/// Returns `true` if the compat sso login state is [`Exchanged`].
|
||||
///
|
||||
/// [`Exchanged`]: CompatSsoLoginState::Exchanged
|
||||
#[must_use]
|
||||
pub fn is_exchanged(&self) -> bool {
|
||||
matches!(self, Self::Exchanged { .. })
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn fulfilled_at(&self) -> Option<DateTime<Utc>> {
|
||||
match self {
|
||||
Self::Pending => None,
|
||||
Self::Fulfilled { fulfilled_at, .. } | Self::Exchanged { fulfilled_at, .. } => {
|
||||
Some(*fulfilled_at)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn exchanged_at(&self) -> Option<DateTime<Utc>> {
|
||||
match self {
|
||||
Self::Pending | Self::Fulfilled { .. } => None,
|
||||
Self::Exchanged { exchanged_at, .. } => Some(*exchanged_at),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn session_id(&self) -> Option<Ulid> {
|
||||
match self {
|
||||
Self::Pending => None,
|
||||
Self::Fulfilled { session_id, .. } | Self::Exchanged { session_id, .. } => {
|
||||
Some(*session_id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fulfill(
|
||||
self,
|
||||
fulfilled_at: DateTime<Utc>,
|
||||
session: &CompatSession,
|
||||
) -> Result<Self, InvalidTransitionError> {
|
||||
match self {
|
||||
Self::Pending => Ok(Self::Fulfilled {
|
||||
fulfilled_at,
|
||||
session_id: session.id,
|
||||
}),
|
||||
Self::Fulfilled { .. } | Self::Exchanged { .. } => Err(InvalidTransitionError),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
match self {
|
||||
Self::Fulfilled {
|
||||
fulfilled_at,
|
||||
session_id,
|
||||
} => Ok(Self::Exchanged {
|
||||
fulfilled_at,
|
||||
exchanged_at,
|
||||
session_id,
|
||||
}),
|
||||
Self::Pending { .. } | Self::Exchanged { .. } => Err(InvalidTransitionError),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct CompatSsoLogin {
|
||||
pub id: Ulid,
|
||||
pub redirect_uri: Url,
|
||||
pub login_token: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub state: CompatSsoLoginState,
|
||||
}
|
||||
|
||||
impl std::ops::Deref for CompatSsoLogin {
|
||||
type Target = CompatSsoLoginState;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.state
|
||||
}
|
||||
}
|
||||
|
||||
impl CompatSsoLogin {
|
||||
pub fn fulfill(
|
||||
mut self,
|
||||
fulfilled_at: DateTime<Utc>,
|
||||
session: &CompatSession,
|
||||
) -> Result<Self, InvalidTransitionError> {
|
||||
self.state = self.state.fulfill(fulfilled_at, session)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn exchange(mut self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
self.state = self.state.exchange(exchanged_at)?;
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -23,24 +23,33 @@
|
||||
clippy::type_repetition_in_bounds
|
||||
)]
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
pub(crate) mod compat;
|
||||
pub(crate) mod oauth2;
|
||||
pub(crate) mod tokens;
|
||||
pub(crate) mod upstream_oauth2;
|
||||
pub(crate) mod users;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("invalid state transition")]
|
||||
pub struct InvalidTransitionError;
|
||||
|
||||
pub use self::{
|
||||
compat::{
|
||||
CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState,
|
||||
Device,
|
||||
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
|
||||
CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device,
|
||||
},
|
||||
oauth2::{
|
||||
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client,
|
||||
InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session,
|
||||
InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState,
|
||||
},
|
||||
tokens::{
|
||||
AccessToken, AccessTokenState, RefreshToken, RefreshTokenState, TokenFormatError, TokenType,
|
||||
},
|
||||
tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType},
|
||||
upstream_oauth2::{
|
||||
UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider,
|
||||
UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState,
|
||||
UpstreamOAuthLink, UpstreamOAuthProvider,
|
||||
},
|
||||
users::{
|
||||
Authentication, BrowserSession, Password, User, UserEmail, UserEmailVerification,
|
||||
|
||||
@@ -21,11 +21,11 @@ use oauth2_types::{
|
||||
requests::ResponseMode,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
|
||||
use super::{client::Client, session::Session};
|
||||
use super::session::Session;
|
||||
use crate::InvalidTransitionError;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct Pkce {
|
||||
@@ -53,21 +53,17 @@ pub struct AuthorizationCode {
|
||||
pub pkce: Option<Pkce>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("invalid state transition")]
|
||||
pub struct InvalidTransitionError;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)]
|
||||
#[serde(tag = "stage", rename_all = "lowercase")]
|
||||
pub enum AuthorizationGrantStage {
|
||||
#[default]
|
||||
Pending,
|
||||
Fulfilled {
|
||||
session: Session,
|
||||
session_id: Ulid,
|
||||
fulfilled_at: DateTime<Utc>,
|
||||
},
|
||||
Exchanged {
|
||||
session: Session,
|
||||
session_id: Ulid,
|
||||
fulfilled_at: DateTime<Utc>,
|
||||
exchanged_at: DateTime<Utc>,
|
||||
},
|
||||
@@ -82,35 +78,35 @@ impl AuthorizationGrantStage {
|
||||
Self::Pending
|
||||
}
|
||||
|
||||
pub fn fulfill(
|
||||
fn fulfill(
|
||||
self,
|
||||
fulfilled_at: DateTime<Utc>,
|
||||
session: Session,
|
||||
session: &Session,
|
||||
) -> Result<Self, InvalidTransitionError> {
|
||||
match self {
|
||||
Self::Pending => Ok(Self::Fulfilled {
|
||||
fulfilled_at,
|
||||
session,
|
||||
session_id: session.id,
|
||||
}),
|
||||
_ => Err(InvalidTransitionError),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
match self {
|
||||
Self::Fulfilled {
|
||||
fulfilled_at,
|
||||
session,
|
||||
session_id,
|
||||
} => Ok(Self::Exchanged {
|
||||
fulfilled_at,
|
||||
exchanged_at,
|
||||
session,
|
||||
session_id,
|
||||
}),
|
||||
_ => Err(InvalidTransitionError),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cancel(self, cancelled_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
fn cancel(self, cancelled_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
match self {
|
||||
Self::Pending => Ok(Self::Cancelled { cancelled_at }),
|
||||
_ => Err(InvalidTransitionError),
|
||||
@@ -124,6 +120,22 @@ impl AuthorizationGrantStage {
|
||||
pub fn is_pending(&self) -> bool {
|
||||
matches!(self, Self::Pending)
|
||||
}
|
||||
|
||||
/// Returns `true` if the authorization grant stage is [`Fulfilled`].
|
||||
///
|
||||
/// [`Fulfilled`]: AuthorizationGrantStage::Fulfilled
|
||||
#[must_use]
|
||||
pub fn is_fulfilled(&self) -> bool {
|
||||
matches!(self, Self::Fulfilled { .. })
|
||||
}
|
||||
|
||||
/// Returns `true` if the authorization grant stage is [`Exchanged`].
|
||||
///
|
||||
/// [`Exchanged`]: AuthorizationGrantStage::Exchanged
|
||||
#[must_use]
|
||||
pub fn is_exchanged(&self) -> bool {
|
||||
matches!(self, Self::Exchanged { .. })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
@@ -132,7 +144,7 @@ pub struct AuthorizationGrant {
|
||||
#[serde(flatten)]
|
||||
pub stage: AuthorizationGrantStage,
|
||||
pub code: Option<AuthorizationCode>,
|
||||
pub client: Client,
|
||||
pub client_id: Ulid,
|
||||
pub redirect_uri: Url,
|
||||
pub scope: oauth2_types::scope::Scope,
|
||||
pub state: Option<String>,
|
||||
@@ -144,10 +156,38 @@ pub struct AuthorizationGrant {
|
||||
pub requires_consent: bool,
|
||||
}
|
||||
|
||||
impl std::ops::Deref for AuthorizationGrant {
|
||||
type Target = AuthorizationGrantStage;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.stage
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthorizationGrant {
|
||||
#[must_use]
|
||||
pub fn max_auth_time(&self) -> DateTime<Utc> {
|
||||
let max_age: Option<i64> = self.max_age.map(|x| x.get().into());
|
||||
self.created_at - Duration::seconds(max_age.unwrap_or(3600 * 24 * 365))
|
||||
}
|
||||
|
||||
pub fn exchange(mut self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
self.stage = self.stage.exchange(exchanged_at)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn fulfill(
|
||||
mut self,
|
||||
fulfilled_at: DateTime<Utc>,
|
||||
session: &Session,
|
||||
) -> Result<Self, InvalidTransitionError> {
|
||||
self.stage = self.stage.fulfill(fulfilled_at, session)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
// TODO: this is not used?
|
||||
pub fn cancel(mut self, canceld_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
self.stage = self.stage.cancel(canceld_at)?;
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -19,5 +19,5 @@ pub(self) mod session;
|
||||
pub use self::{
|
||||
authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce},
|
||||
client::{Client, InvalidRedirectUriError, JwksOrJwksUri},
|
||||
session::Session,
|
||||
session::{Session, SessionState},
|
||||
};
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -12,17 +12,76 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use oauth2_types::scope::Scope;
|
||||
use serde::Serialize;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::client::Client;
|
||||
use crate::users::BrowserSession;
|
||||
use crate::InvalidTransitionError;
|
||||
|
||||
trait T {
|
||||
type State;
|
||||
}
|
||||
|
||||
impl T for Session {
|
||||
type State = SessionState;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
|
||||
pub enum SessionState {
|
||||
#[default]
|
||||
Valid,
|
||||
Finished {
|
||||
finished_at: DateTime<Utc>,
|
||||
},
|
||||
}
|
||||
|
||||
impl SessionState {
|
||||
/// Returns `true` if the session state is [`Valid`].
|
||||
///
|
||||
/// [`Valid`]: SessionState::Valid
|
||||
#[must_use]
|
||||
pub fn is_valid(&self) -> bool {
|
||||
matches!(self, Self::Valid)
|
||||
}
|
||||
|
||||
/// Returns `true` if the session state is [`Finished`].
|
||||
///
|
||||
/// [`Finished`]: SessionState::Finished
|
||||
#[must_use]
|
||||
pub fn is_finished(&self) -> bool {
|
||||
matches!(self, Self::Finished { .. })
|
||||
}
|
||||
|
||||
pub fn finish(self, finished_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
match self {
|
||||
Self::Valid => Ok(Self::Finished { finished_at }),
|
||||
Self::Finished { .. } => Err(InvalidTransitionError),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct Session {
|
||||
pub id: Ulid,
|
||||
pub browser_session: BrowserSession,
|
||||
pub client: Client,
|
||||
pub state: SessionState,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub user_session_id: Ulid,
|
||||
pub client_id: Ulid,
|
||||
pub scope: Scope,
|
||||
}
|
||||
|
||||
impl std::ops::Deref for Session {
|
||||
type Target = SessionState;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.state
|
||||
}
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn finish(mut self, finished_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
self.state = self.state.finish(finished_at)?;
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,25 +15,135 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use crc::{Crc, CRC_32_ISO_HDLC};
|
||||
use mas_iana::oauth::OAuthTokenTypeHint;
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use rand::{distributions::Alphanumeric, Rng, RngCore};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::InvalidTransitionError;
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub enum AccessTokenState {
|
||||
#[default]
|
||||
Valid,
|
||||
Revoked {
|
||||
revoked_at: DateTime<Utc>,
|
||||
},
|
||||
}
|
||||
|
||||
impl AccessTokenState {
|
||||
fn revoke(self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
match self {
|
||||
Self::Valid => Ok(Self::Revoked { revoked_at }),
|
||||
Self::Revoked { .. } => Err(InvalidTransitionError),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if the refresh token state is [`Valid`].
|
||||
///
|
||||
/// [`Valid`]: AccessTokenState::Valid
|
||||
#[must_use]
|
||||
pub fn is_valid(&self) -> bool {
|
||||
matches!(self, Self::Valid)
|
||||
}
|
||||
|
||||
/// Returns `true` if the refresh token state is [`Revoked`].
|
||||
///
|
||||
/// [`Revoked`]: AccessTokenState::Revoked
|
||||
#[must_use]
|
||||
pub fn is_revoked(&self) -> bool {
|
||||
matches!(self, Self::Revoked { .. })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct AccessToken {
|
||||
pub id: Ulid,
|
||||
pub jti: String,
|
||||
pub state: AccessTokenState,
|
||||
pub session_id: Ulid,
|
||||
pub access_token: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl AccessToken {
|
||||
#[must_use]
|
||||
pub fn jti(&self) -> String {
|
||||
self.id.to_string()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_valid(&self, now: DateTime<Utc>) -> bool {
|
||||
self.state.is_valid() && self.expires_at > now
|
||||
}
|
||||
|
||||
pub fn revoke(mut self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
self.state = self.state.revoke(revoked_at)?;
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub enum RefreshTokenState {
|
||||
#[default]
|
||||
Valid,
|
||||
Consumed {
|
||||
consumed_at: DateTime<Utc>,
|
||||
},
|
||||
}
|
||||
|
||||
impl RefreshTokenState {
|
||||
fn consume(self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
match self {
|
||||
Self::Valid => Ok(Self::Consumed { consumed_at }),
|
||||
Self::Consumed { .. } => Err(InvalidTransitionError),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if the refresh token state is [`Valid`].
|
||||
///
|
||||
/// [`Valid`]: RefreshTokenState::Valid
|
||||
#[must_use]
|
||||
pub fn is_valid(&self) -> bool {
|
||||
matches!(self, Self::Valid)
|
||||
}
|
||||
|
||||
/// Returns `true` if the refresh token state is [`Consumed`].
|
||||
///
|
||||
/// [`Consumed`]: RefreshTokenState::Consumed
|
||||
#[must_use]
|
||||
pub fn is_consumed(&self) -> bool {
|
||||
matches!(self, Self::Consumed { .. })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RefreshToken {
|
||||
pub id: Ulid,
|
||||
pub state: RefreshTokenState,
|
||||
pub refresh_token: String,
|
||||
pub session_id: Ulid,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub access_token: Option<AccessToken>,
|
||||
pub access_token_id: Option<Ulid>,
|
||||
}
|
||||
|
||||
impl std::ops::Deref for RefreshToken {
|
||||
type Target = RefreshTokenState;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.state
|
||||
}
|
||||
}
|
||||
|
||||
impl RefreshToken {
|
||||
#[must_use]
|
||||
pub fn jti(&self) -> String {
|
||||
self.id.to_string()
|
||||
}
|
||||
|
||||
pub fn consume(mut self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
self.state = self.state.consume(consumed_at)?;
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Type of token to generate or validate
|
||||
@@ -80,10 +190,10 @@ impl TokenType {
|
||||
/// use rand::thread_rng;
|
||||
/// use mas_data_model::TokenType::{AccessToken, RefreshToken};
|
||||
///
|
||||
/// AccessToken.generate(thread_rng());
|
||||
/// RefreshToken.generate(thread_rng());
|
||||
/// AccessToken.generate(&mut thread_rng());
|
||||
/// RefreshToken.generate(&mut thread_rng());
|
||||
/// ```
|
||||
pub fn generate(self, rng: impl Rng) -> String {
|
||||
pub fn generate(self, rng: &mut (impl RngCore + ?Sized)) -> String {
|
||||
let random_part: String = rng
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(30)
|
||||
|
||||
26
crates/data-model/src/upstream_oauth2/link.rs
Normal file
26
crates/data-model/src/upstream_oauth2/link.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::Serialize;
|
||||
use ulid::Ulid;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct UpstreamOAuthLink {
|
||||
pub id: Ulid,
|
||||
pub provider_id: Ulid,
|
||||
pub user_id: Option<Ulid>,
|
||||
pub subject: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -12,55 +12,12 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||
use oauth2_types::scope::Scope;
|
||||
use serde::Serialize;
|
||||
use ulid::Ulid;
|
||||
mod link;
|
||||
mod provider;
|
||||
mod session;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct UpstreamOAuthProvider {
|
||||
pub id: Ulid,
|
||||
pub issuer: String,
|
||||
pub scope: Scope,
|
||||
pub client_id: String,
|
||||
pub encrypted_client_secret: Option<String>,
|
||||
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct UpstreamOAuthLink {
|
||||
pub id: Ulid,
|
||||
pub provider_id: Ulid,
|
||||
pub user_id: Option<Ulid>,
|
||||
pub subject: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct UpstreamOAuthAuthorizationSession {
|
||||
pub id: Ulid,
|
||||
pub provider_id: Ulid,
|
||||
pub link_id: Option<Ulid>,
|
||||
pub state: String,
|
||||
pub code_challenge_verifier: Option<String>,
|
||||
pub nonce: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub completed_at: Option<DateTime<Utc>>,
|
||||
pub consumed_at: Option<DateTime<Utc>>,
|
||||
pub id_token: Option<String>,
|
||||
}
|
||||
|
||||
impl UpstreamOAuthAuthorizationSession {
|
||||
#[must_use]
|
||||
pub const fn completed(&self) -> bool {
|
||||
self.completed_at.is_some()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn consumed(&self) -> bool {
|
||||
self.consumed_at.is_some()
|
||||
}
|
||||
}
|
||||
pub use self::{
|
||||
link::UpstreamOAuthLink,
|
||||
provider::UpstreamOAuthProvider,
|
||||
session::{UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState},
|
||||
};
|
||||
|
||||
31
crates/data-model/src/upstream_oauth2/provider.rs
Normal file
31
crates/data-model/src/upstream_oauth2/provider.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||
use oauth2_types::scope::Scope;
|
||||
use serde::Serialize;
|
||||
use ulid::Ulid;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct UpstreamOAuthProvider {
|
||||
pub id: Ulid,
|
||||
pub issuer: String,
|
||||
pub scope: Scope,
|
||||
pub client_id: String,
|
||||
pub encrypted_client_secret: Option<String>,
|
||||
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
170
crates/data-model/src/upstream_oauth2/session.rs
Normal file
170
crates/data-model/src/upstream_oauth2/session.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::Serialize;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::UpstreamOAuthLink;
|
||||
use crate::InvalidTransitionError;
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
|
||||
pub enum UpstreamOAuthAuthorizationSessionState {
|
||||
#[default]
|
||||
Pending,
|
||||
Completed {
|
||||
completed_at: DateTime<Utc>,
|
||||
link_id: Ulid,
|
||||
id_token: Option<String>,
|
||||
},
|
||||
Consumed {
|
||||
completed_at: DateTime<Utc>,
|
||||
consumed_at: DateTime<Utc>,
|
||||
link_id: Ulid,
|
||||
id_token: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl UpstreamOAuthAuthorizationSessionState {
|
||||
pub fn complete(
|
||||
self,
|
||||
completed_at: DateTime<Utc>,
|
||||
link: &UpstreamOAuthLink,
|
||||
id_token: Option<String>,
|
||||
) -> Result<Self, InvalidTransitionError> {
|
||||
match self {
|
||||
Self::Pending => Ok(Self::Completed {
|
||||
completed_at,
|
||||
link_id: link.id,
|
||||
id_token,
|
||||
}),
|
||||
Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn consume(self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
match self {
|
||||
Self::Completed {
|
||||
completed_at,
|
||||
link_id,
|
||||
id_token,
|
||||
} => Ok(Self::Consumed {
|
||||
completed_at,
|
||||
link_id,
|
||||
consumed_at,
|
||||
id_token,
|
||||
}),
|
||||
Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn link_id(&self) -> Option<Ulid> {
|
||||
match self {
|
||||
Self::Pending => None,
|
||||
Self::Completed { link_id, .. } | Self::Consumed { link_id, .. } => Some(*link_id),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn completed_at(&self) -> Option<DateTime<Utc>> {
|
||||
match self {
|
||||
Self::Pending => None,
|
||||
Self::Completed { completed_at, .. } | Self::Consumed { completed_at, .. } => {
|
||||
Some(*completed_at)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn id_token(&self) -> Option<&str> {
|
||||
match self {
|
||||
Self::Pending => None,
|
||||
Self::Completed { id_token, .. } | Self::Consumed { id_token, .. } => {
|
||||
id_token.as_deref()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn consumed_at(&self) -> Option<DateTime<Utc>> {
|
||||
match self {
|
||||
Self::Pending | Self::Completed { .. } => None,
|
||||
Self::Consumed { consumed_at, .. } => Some(*consumed_at),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if the upstream oauth authorization session state is
|
||||
/// [`Pending`].
|
||||
///
|
||||
/// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
|
||||
#[must_use]
|
||||
pub fn is_pending(&self) -> bool {
|
||||
matches!(self, Self::Pending)
|
||||
}
|
||||
|
||||
/// Returns `true` if the upstream oauth authorization session state is
|
||||
/// [`Completed`].
|
||||
///
|
||||
/// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed
|
||||
#[must_use]
|
||||
pub fn is_completed(&self) -> bool {
|
||||
matches!(self, Self::Completed { .. })
|
||||
}
|
||||
|
||||
/// Returns `true` if the upstream oauth authorization session state is
|
||||
/// [`Consumed`].
|
||||
///
|
||||
/// [`Consumed`]: UpstreamOAuthAuthorizationSessionState::Consumed
|
||||
#[must_use]
|
||||
pub fn is_consumed(&self) -> bool {
|
||||
matches!(self, Self::Consumed { .. })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct UpstreamOAuthAuthorizationSession {
|
||||
pub id: Ulid,
|
||||
pub state: UpstreamOAuthAuthorizationSessionState,
|
||||
pub provider_id: Ulid,
|
||||
pub state_str: String,
|
||||
pub code_challenge_verifier: Option<String>,
|
||||
pub nonce: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl std::ops::Deref for UpstreamOAuthAuthorizationSession {
|
||||
type Target = UpstreamOAuthAuthorizationSessionState;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.state
|
||||
}
|
||||
}
|
||||
|
||||
impl UpstreamOAuthAuthorizationSession {
|
||||
pub fn complete(
|
||||
mut self,
|
||||
completed_at: DateTime<Utc>,
|
||||
link: &UpstreamOAuthLink,
|
||||
id_token: Option<String>,
|
||||
) -> Result<Self, InvalidTransitionError> {
|
||||
self.state = self.state.complete(completed_at, link, id_token)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn consume(mut self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
self.state = self.state.consume(consumed_at)?;
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
@@ -22,7 +22,7 @@ pub struct User {
|
||||
pub id: Ulid,
|
||||
pub username: String,
|
||||
pub sub: String,
|
||||
pub primary_email: Option<UserEmail>,
|
||||
pub primary_user_email_id: Option<Ulid>,
|
||||
}
|
||||
|
||||
impl User {
|
||||
@@ -32,7 +32,7 @@ impl User {
|
||||
id: Ulid::from_datetime_with_source(now.into(), rng),
|
||||
username: "john".to_owned(),
|
||||
sub: "123-456".to_owned(),
|
||||
primary_email: None,
|
||||
primary_user_email_id: None,
|
||||
}]
|
||||
}
|
||||
}
|
||||
@@ -57,10 +57,16 @@ pub struct BrowserSession {
|
||||
pub id: Ulid,
|
||||
pub user: User,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub finished_at: Option<DateTime<Utc>>,
|
||||
pub last_authentication: Option<Authentication>,
|
||||
}
|
||||
|
||||
impl BrowserSession {
|
||||
#[must_use]
|
||||
pub fn active(&self) -> bool {
|
||||
self.finished_at.is_none()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn was_authenticated_after(&self, after: DateTime<Utc>) -> bool {
|
||||
if let Some(auth) = &self.last_authentication {
|
||||
@@ -80,6 +86,7 @@ impl BrowserSession {
|
||||
id: Ulid::from_datetime_with_source(now.into(), rng),
|
||||
user,
|
||||
created_at: now,
|
||||
finished_at: None,
|
||||
last_authentication: None,
|
||||
})
|
||||
.collect()
|
||||
@@ -89,6 +96,7 @@ impl BrowserSession {
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct UserEmail {
|
||||
pub id: Ulid,
|
||||
pub user_id: Ulid,
|
||||
pub email: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub confirmed_at: Option<DateTime<Utc>>,
|
||||
@@ -100,12 +108,14 @@ impl UserEmail {
|
||||
vec![
|
||||
Self {
|
||||
id: Ulid::from_datetime_with_source(now.into(), rng),
|
||||
user_id: Ulid::from_datetime_with_source(now.into(), rng),
|
||||
email: "alice@example.com".to_owned(),
|
||||
created_at: now,
|
||||
confirmed_at: Some(now),
|
||||
},
|
||||
Self {
|
||||
id: Ulid::from_datetime_with_source(now.into(), rng),
|
||||
user_id: Ulid::from_datetime_with_source(now.into(), rng),
|
||||
email: "bob@example.com".to_owned(),
|
||||
created_at: now,
|
||||
confirmed_at: None,
|
||||
@@ -124,7 +134,7 @@ pub enum UserEmailVerificationState {
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct UserEmailVerification {
|
||||
pub id: Ulid,
|
||||
pub email: UserEmail,
|
||||
pub user_email_id: Ulid,
|
||||
pub code: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub state: UserEmailVerificationState,
|
||||
@@ -152,8 +162,8 @@ impl UserEmailVerification {
|
||||
.into_iter()
|
||||
.map(move |email| Self {
|
||||
id: Ulid::from_datetime_with_source(now.into(), &mut rng),
|
||||
user_email_id: email.id,
|
||||
code: "123456".to_owned(),
|
||||
email,
|
||||
created_at: now - Duration::minutes(10),
|
||||
state: state.clone(),
|
||||
})
|
||||
|
||||
@@ -100,10 +100,10 @@ impl Mailer {
|
||||
///
|
||||
/// Will return `Err` if the email failed rendering or failed sending
|
||||
#[tracing::instrument(
|
||||
name = "email.verification.send",
|
||||
skip_all,
|
||||
fields(
|
||||
email.to = %to,
|
||||
email.from = %self.from,
|
||||
user.id = %context.user().id,
|
||||
user_email_verification.id = %context.verification().id,
|
||||
user_email_verification.code = context.verification().code,
|
||||
@@ -125,6 +125,7 @@ impl Mailer {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the connection failed
|
||||
#[tracing::instrument(name = "email.test_connection", skip_all, err)]
|
||||
pub async fn test_connection(&self) -> Result<(), crate::transport::Error> {
|
||||
self.transport.test_connection().await
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ anyhow = "1.0.68"
|
||||
async-graphql = { version = "5.0.5", features = ["chrono", "url"] }
|
||||
chrono = "0.4.23"
|
||||
serde = { version = "1.0.152", features = ["derive"] }
|
||||
sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] }
|
||||
tokio = { version = "1.23.0", features = ["sync"] }
|
||||
thiserror = "1.0.38"
|
||||
tracing = "0.1.37"
|
||||
ulid = "1.0.0"
|
||||
|
||||
@@ -30,8 +30,14 @@ use async_graphql::{
|
||||
connection::{query, Connection, Edge, OpaqueCursor},
|
||||
Context, Description, EmptyMutation, EmptySubscription, ID,
|
||||
};
|
||||
use mas_storage::{
|
||||
oauth2::OAuth2ClientRepository,
|
||||
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository},
|
||||
user::{BrowserSessionRepository, UserEmailRepository},
|
||||
BoxRepository, Pagination,
|
||||
};
|
||||
use model::CreationEvent;
|
||||
use sqlx::PgPool;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use self::model::{
|
||||
BrowserSession, Cursor, Node, NodeCursor, NodeType, OAuth2Client, UpstreamOAuth2Link,
|
||||
@@ -87,10 +93,9 @@ impl RootQuery {
|
||||
id: ID,
|
||||
) -> Result<Option<OAuth2Client>, async_graphql::Error> {
|
||||
let id = NodeType::OAuth2Client.extract_ulid(&id)?;
|
||||
let database = ctx.data::<PgPool>()?;
|
||||
let mut conn = database.acquire().await?;
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
|
||||
let client = mas_storage::oauth2::client::lookup_client(&mut conn, id).await?;
|
||||
let client = repo.oauth2_client().lookup(id).await?;
|
||||
|
||||
Ok(client.map(OAuth2Client))
|
||||
}
|
||||
@@ -118,13 +123,12 @@ impl RootQuery {
|
||||
) -> Result<Option<BrowserSession>, async_graphql::Error> {
|
||||
let id = NodeType::BrowserSession.extract_ulid(&id)?;
|
||||
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
|
||||
let database = ctx.data::<PgPool>()?;
|
||||
let mut conn = database.acquire().await?;
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
|
||||
let Some(session) = session else { return Ok(None) };
|
||||
let current_user = session.user;
|
||||
|
||||
let browser_session = mas_storage::user::lookup_active_session(&mut conn, id).await?;
|
||||
let browser_session = repo.browser_session().lookup(id).await?;
|
||||
|
||||
let ret = browser_session.and_then(|browser_session| {
|
||||
if browser_session.user.id == current_user.id {
|
||||
@@ -145,14 +149,16 @@ impl RootQuery {
|
||||
) -> Result<Option<UserEmail>, async_graphql::Error> {
|
||||
let id = NodeType::UserEmail.extract_ulid(&id)?;
|
||||
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
|
||||
let database = ctx.data::<PgPool>()?;
|
||||
let mut conn = database.acquire().await?;
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
|
||||
let Some(session) = session else { return Ok(None) };
|
||||
let current_user = session.user;
|
||||
|
||||
let user_email =
|
||||
mas_storage::user::lookup_user_email_by_id(&mut conn, ¤t_user, id).await?;
|
||||
let user_email = repo
|
||||
.user_email()
|
||||
.lookup(id)
|
||||
.await?
|
||||
.filter(|e| e.user_id == current_user.id);
|
||||
|
||||
Ok(user_email.map(UserEmail))
|
||||
}
|
||||
@@ -165,13 +171,12 @@ impl RootQuery {
|
||||
) -> Result<Option<UpstreamOAuth2Link>, async_graphql::Error> {
|
||||
let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?;
|
||||
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
|
||||
let database = ctx.data::<PgPool>()?;
|
||||
let mut conn = database.acquire().await?;
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
|
||||
let Some(session) = session else { return Ok(None) };
|
||||
let current_user = session.user;
|
||||
|
||||
let link = mas_storage::upstream_oauth2::lookup_link(&mut conn, id).await?;
|
||||
let link = repo.upstream_oauth_link().lookup(id).await?;
|
||||
|
||||
// Ensure that the link belongs to the current user
|
||||
let link = link.filter(|link| link.user_id == Some(current_user.id));
|
||||
@@ -186,10 +191,9 @@ impl RootQuery {
|
||||
id: ID,
|
||||
) -> Result<Option<UpstreamOAuth2Provider>, async_graphql::Error> {
|
||||
let id = NodeType::UpstreamOAuth2Provider.extract_ulid(&id)?;
|
||||
let database = ctx.data::<PgPool>()?;
|
||||
let mut conn = database.acquire().await?;
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
|
||||
let provider = mas_storage::upstream_oauth2::lookup_provider(&mut conn, id).await?;
|
||||
let provider = repo.upstream_oauth_provider().lookup(id).await?;
|
||||
|
||||
Ok(provider.map(UpstreamOAuth2Provider::new))
|
||||
}
|
||||
@@ -206,7 +210,7 @@ impl RootQuery {
|
||||
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
|
||||
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
|
||||
) -> Result<Connection<Cursor, UpstreamOAuth2Provider>, async_graphql::Error> {
|
||||
let database = ctx.data::<PgPool>()?;
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
|
||||
query(
|
||||
after,
|
||||
@@ -214,7 +218,6 @@ impl RootQuery {
|
||||
first,
|
||||
last,
|
||||
|after, before, first, last| async move {
|
||||
let mut conn = database.acquire().await?;
|
||||
let after_id = after
|
||||
.map(|x: OpaqueCursor<NodeCursor>| {
|
||||
x.extract_for_type(NodeType::UpstreamOAuth2Provider)
|
||||
@@ -225,15 +228,15 @@ impl RootQuery {
|
||||
x.extract_for_type(NodeType::UpstreamOAuth2Provider)
|
||||
})
|
||||
.transpose()?;
|
||||
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||
|
||||
let (has_previous_page, has_next_page, edges) =
|
||||
mas_storage::upstream_oauth2::get_paginated_providers(
|
||||
&mut conn, before_id, after_id, first, last,
|
||||
)
|
||||
let page = repo
|
||||
.upstream_oauth_provider()
|
||||
.list_paginated(pagination)
|
||||
.await?;
|
||||
|
||||
let mut connection = Connection::new(has_previous_page, has_next_page);
|
||||
connection.edges.extend(edges.into_iter().map(|p| {
|
||||
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
|
||||
connection.edges.extend(page.edges.into_iter().map(|p| {
|
||||
Edge::new(
|
||||
OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Provider, p.id)),
|
||||
UpstreamOAuth2Provider::new(p),
|
||||
|
||||
@@ -12,9 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_graphql::{Description, Object, ID};
|
||||
use anyhow::Context as _;
|
||||
use async_graphql::{Context, Description, Object, ID};
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::CompatSsoLoginState;
|
||||
use mas_storage::{compat::CompatSessionRepository, user::UserRepository, BoxRepository};
|
||||
use tokio::sync::Mutex;
|
||||
use url::Url;
|
||||
|
||||
use super::{NodeType, User};
|
||||
@@ -32,8 +34,14 @@ impl CompatSession {
|
||||
}
|
||||
|
||||
/// The user authorized for this session.
|
||||
async fn user(&self) -> User {
|
||||
User(self.0.user.clone())
|
||||
async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> {
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(self.0.user_id)
|
||||
.await?
|
||||
.context("Could not load user")?;
|
||||
Ok(User(user))
|
||||
}
|
||||
|
||||
/// The Matrix Device ID of this session.
|
||||
@@ -48,7 +56,7 @@ impl CompatSession {
|
||||
|
||||
/// When the session ended.
|
||||
pub async fn finished_at(&self) -> Option<DateTime<Utc>> {
|
||||
self.0.finished_at
|
||||
self.0.finished_at()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,29 +85,28 @@ impl CompatSsoLogin {
|
||||
/// When the login was fulfilled, and the user was redirected back to the
|
||||
/// client.
|
||||
async fn fulfilled_at(&self) -> Option<DateTime<Utc>> {
|
||||
match &self.0.state {
|
||||
CompatSsoLoginState::Pending => None,
|
||||
CompatSsoLoginState::Fulfilled { fulfilled_at, .. }
|
||||
| CompatSsoLoginState::Exchanged { fulfilled_at, .. } => Some(*fulfilled_at),
|
||||
}
|
||||
self.0.fulfilled_at()
|
||||
}
|
||||
|
||||
/// When the client exchanged the login token sent during the redirection.
|
||||
async fn exchanged_at(&self) -> Option<DateTime<Utc>> {
|
||||
match &self.0.state {
|
||||
CompatSsoLoginState::Pending | CompatSsoLoginState::Fulfilled { .. } => None,
|
||||
CompatSsoLoginState::Exchanged { exchanged_at, .. } => Some(*exchanged_at),
|
||||
}
|
||||
self.0.exchanged_at()
|
||||
}
|
||||
|
||||
/// The compat session which was started by this login.
|
||||
async fn session(&self) -> Option<CompatSession> {
|
||||
match &self.0.state {
|
||||
CompatSsoLoginState::Pending => None,
|
||||
CompatSsoLoginState::Fulfilled { session, .. }
|
||||
| CompatSsoLoginState::Exchanged { session, .. } => {
|
||||
Some(CompatSession(session.clone()))
|
||||
}
|
||||
}
|
||||
async fn session(
|
||||
&self,
|
||||
ctx: &Context<'_>,
|
||||
) -> Result<Option<CompatSession>, async_graphql::Error> {
|
||||
let Some(session_id) = self.0.session_id() else { return Ok(None) };
|
||||
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.lookup(session_id)
|
||||
.await?
|
||||
.context("Could not load compat session")?;
|
||||
|
||||
Ok(Some(CompatSession(session)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
|
||||
use anyhow::Context as _;
|
||||
use async_graphql::{Context, Description, Object, ID};
|
||||
use mas_storage::oauth2::client::lookup_client;
|
||||
use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, BoxRepository};
|
||||
use oauth2_types::scope::Scope;
|
||||
use sqlx::PgPool;
|
||||
use tokio::sync::Mutex;
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
|
||||
@@ -35,8 +35,15 @@ impl OAuth2Session {
|
||||
}
|
||||
|
||||
/// OAuth 2.0 client used by this session.
|
||||
pub async fn client(&self) -> OAuth2Client {
|
||||
OAuth2Client(self.0.client.clone())
|
||||
pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> {
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
let client = repo
|
||||
.oauth2_client()
|
||||
.lookup(self.0.client_id)
|
||||
.await?
|
||||
.context("Could not load client")?;
|
||||
|
||||
Ok(OAuth2Client(client))
|
||||
}
|
||||
|
||||
/// Scope granted for this session.
|
||||
@@ -45,13 +52,30 @@ impl OAuth2Session {
|
||||
}
|
||||
|
||||
/// The browser session which started this OAuth 2.0 session.
|
||||
pub async fn browser_session(&self) -> BrowserSession {
|
||||
BrowserSession(self.0.browser_session.clone())
|
||||
pub async fn browser_session(
|
||||
&self,
|
||||
ctx: &Context<'_>,
|
||||
) -> Result<BrowserSession, async_graphql::Error> {
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
let browser_session = repo
|
||||
.browser_session()
|
||||
.lookup(self.0.user_session_id)
|
||||
.await?
|
||||
.context("Could not load browser session")?;
|
||||
|
||||
Ok(BrowserSession(browser_session))
|
||||
}
|
||||
|
||||
/// User authorized for this session.
|
||||
pub async fn user(&self) -> User {
|
||||
User(self.0.browser_session.user.clone())
|
||||
pub async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> {
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
let browser_session = repo
|
||||
.browser_session()
|
||||
.lookup(self.0.user_session_id)
|
||||
.await?
|
||||
.context("Could not load browser session")?;
|
||||
|
||||
Ok(User(browser_session.user))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,8 +138,10 @@ impl OAuth2Consent {
|
||||
|
||||
/// OAuth 2.0 client for which the user granted access.
|
||||
pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> {
|
||||
let mut conn = ctx.data::<PgPool>()?.acquire().await?;
|
||||
let client = lookup_client(&mut conn, self.client_id)
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
let client = repo
|
||||
.oauth2_client()
|
||||
.lookup(self.client_id)
|
||||
.await?
|
||||
.context("Could not load client")?;
|
||||
Ok(OAuth2Client(client))
|
||||
|
||||
@@ -15,7 +15,10 @@
|
||||
use anyhow::Context as _;
|
||||
use async_graphql::{Context, Object, ID};
|
||||
use chrono::{DateTime, Utc};
|
||||
use sqlx::PgPool;
|
||||
use mas_storage::{
|
||||
upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, BoxRepository,
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::{NodeType, User};
|
||||
|
||||
@@ -99,11 +102,13 @@ impl UpstreamOAuth2Link {
|
||||
provider.clone()
|
||||
} else {
|
||||
// Fetch on-the-fly
|
||||
let database = ctx.data::<PgPool>()?;
|
||||
let mut conn = database.acquire().await?;
|
||||
mas_storage::upstream_oauth2::lookup_provider(&mut conn, self.link.provider_id)
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
let provider = repo
|
||||
.upstream_oauth_provider()
|
||||
.lookup(self.link.provider_id)
|
||||
.await?
|
||||
.context("Upstream OAuth 2.0 provider not found")?
|
||||
.context("Upstream OAuth 2.0 provider not found")?;
|
||||
provider
|
||||
};
|
||||
|
||||
Ok(UpstreamOAuth2Provider::new(provider))
|
||||
@@ -116,9 +121,13 @@ impl UpstreamOAuth2Link {
|
||||
user.clone()
|
||||
} else if let Some(user_id) = &self.link.user_id {
|
||||
// Fetch on-the-fly
|
||||
let database = ctx.data::<PgPool>()?;
|
||||
let mut conn = database.acquire().await?;
|
||||
mas_storage::user::lookup_user(&mut conn, *user_id).await?
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(*user_id)
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
user
|
||||
} else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
@@ -17,7 +17,14 @@ use async_graphql::{
|
||||
Context, Description, Object, ID,
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use sqlx::PgPool;
|
||||
use mas_storage::{
|
||||
compat::CompatSsoLoginRepository,
|
||||
oauth2::OAuth2SessionRepository,
|
||||
upstream_oauth2::UpstreamOAuthLinkRepository,
|
||||
user::{BrowserSessionRepository, UserEmailRepository},
|
||||
BoxRepository, Pagination,
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::{
|
||||
compat_sessions::CompatSsoLogin, BrowserSession, Cursor, NodeCursor, NodeType, OAuth2Session,
|
||||
@@ -53,8 +60,14 @@ impl User {
|
||||
}
|
||||
|
||||
/// Primary email address of the user.
|
||||
async fn primary_email(&self) -> Option<UserEmail> {
|
||||
self.0.primary_email.clone().map(UserEmail)
|
||||
async fn primary_email(
|
||||
&self,
|
||||
ctx: &Context<'_>,
|
||||
) -> Result<Option<UserEmail>, async_graphql::Error> {
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
|
||||
let mut user_email_repo = repo.user_email();
|
||||
Ok(user_email_repo.get_primary(&self.0).await?.map(UserEmail))
|
||||
}
|
||||
|
||||
/// Get the list of compatibility SSO logins, chronologically sorted
|
||||
@@ -69,7 +82,7 @@ impl User {
|
||||
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
|
||||
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
|
||||
) -> Result<Connection<Cursor, CompatSsoLogin>, async_graphql::Error> {
|
||||
let database = ctx.data::<PgPool>()?;
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
|
||||
query(
|
||||
after,
|
||||
@@ -77,22 +90,21 @@ impl User {
|
||||
first,
|
||||
last,
|
||||
|after, before, first, last| async move {
|
||||
let mut conn = database.acquire().await?;
|
||||
let after_id = after
|
||||
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::CompatSsoLogin))
|
||||
.transpose()?;
|
||||
let before_id = before
|
||||
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::CompatSsoLogin))
|
||||
.transpose()?;
|
||||
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||
|
||||
let (has_previous_page, has_next_page, edges) =
|
||||
mas_storage::compat::get_paginated_user_compat_sso_logins(
|
||||
&mut conn, &self.0, before_id, after_id, first, last,
|
||||
)
|
||||
let page = repo
|
||||
.compat_sso_login()
|
||||
.list_paginated(&self.0, pagination)
|
||||
.await?;
|
||||
|
||||
let mut connection = Connection::new(has_previous_page, has_next_page);
|
||||
connection.edges.extend(edges.into_iter().map(|u| {
|
||||
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
|
||||
connection.edges.extend(page.edges.into_iter().map(|u| {
|
||||
Edge::new(
|
||||
OpaqueCursor(NodeCursor(NodeType::CompatSsoLogin, u.id)),
|
||||
CompatSsoLogin(u),
|
||||
@@ -117,7 +129,7 @@ impl User {
|
||||
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
|
||||
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
|
||||
) -> Result<Connection<Cursor, BrowserSession>, async_graphql::Error> {
|
||||
let database = ctx.data::<PgPool>()?;
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
|
||||
query(
|
||||
after,
|
||||
@@ -125,22 +137,21 @@ impl User {
|
||||
first,
|
||||
last,
|
||||
|after, before, first, last| async move {
|
||||
let mut conn = database.acquire().await?;
|
||||
let after_id = after
|
||||
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::BrowserSession))
|
||||
.transpose()?;
|
||||
let before_id = before
|
||||
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::BrowserSession))
|
||||
.transpose()?;
|
||||
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||
|
||||
let (has_previous_page, has_next_page, edges) =
|
||||
mas_storage::user::get_paginated_user_sessions(
|
||||
&mut conn, &self.0, before_id, after_id, first, last,
|
||||
)
|
||||
let page = repo
|
||||
.browser_session()
|
||||
.list_active_paginated(&self.0, pagination)
|
||||
.await?;
|
||||
|
||||
let mut connection = Connection::new(has_previous_page, has_next_page);
|
||||
connection.edges.extend(edges.into_iter().map(|u| {
|
||||
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
|
||||
connection.edges.extend(page.edges.into_iter().map(|u| {
|
||||
Edge::new(
|
||||
OpaqueCursor(NodeCursor(NodeType::BrowserSession, u.id)),
|
||||
BrowserSession(u),
|
||||
@@ -165,7 +176,7 @@ impl User {
|
||||
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
|
||||
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
|
||||
) -> Result<Connection<Cursor, UserEmail, UserEmailsPagination>, async_graphql::Error> {
|
||||
let database = ctx.data::<PgPool>()?;
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
|
||||
query(
|
||||
after,
|
||||
@@ -173,26 +184,25 @@ impl User {
|
||||
first,
|
||||
last,
|
||||
|after, before, first, last| async move {
|
||||
let mut conn = database.acquire().await?;
|
||||
let after_id = after
|
||||
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::UserEmail))
|
||||
.transpose()?;
|
||||
let before_id = before
|
||||
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::UserEmail))
|
||||
.transpose()?;
|
||||
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||
|
||||
let (has_previous_page, has_next_page, edges) =
|
||||
mas_storage::user::get_paginated_user_emails(
|
||||
&mut conn, &self.0, before_id, after_id, first, last,
|
||||
)
|
||||
let page = repo
|
||||
.user_email()
|
||||
.list_paginated(&self.0, pagination)
|
||||
.await?;
|
||||
|
||||
let mut connection = Connection::with_additional_fields(
|
||||
has_previous_page,
|
||||
has_next_page,
|
||||
page.has_previous_page,
|
||||
page.has_next_page,
|
||||
UserEmailsPagination(self.0.clone()),
|
||||
);
|
||||
connection.edges.extend(edges.into_iter().map(|u| {
|
||||
connection.edges.extend(page.edges.into_iter().map(|u| {
|
||||
Edge::new(
|
||||
OpaqueCursor(NodeCursor(NodeType::UserEmail, u.id)),
|
||||
UserEmail(u),
|
||||
@@ -217,7 +227,7 @@ impl User {
|
||||
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
|
||||
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
|
||||
) -> Result<Connection<Cursor, OAuth2Session>, async_graphql::Error> {
|
||||
let database = ctx.data::<PgPool>()?;
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
|
||||
query(
|
||||
after,
|
||||
@@ -225,22 +235,21 @@ impl User {
|
||||
first,
|
||||
last,
|
||||
|after, before, first, last| async move {
|
||||
let mut conn = database.acquire().await?;
|
||||
let after_id = after
|
||||
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::OAuth2Session))
|
||||
.transpose()?;
|
||||
let before_id = before
|
||||
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::OAuth2Session))
|
||||
.transpose()?;
|
||||
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||
|
||||
let (has_previous_page, has_next_page, edges) =
|
||||
mas_storage::oauth2::get_paginated_user_oauth_sessions(
|
||||
&mut conn, &self.0, before_id, after_id, first, last,
|
||||
)
|
||||
let page = repo
|
||||
.oauth2_session()
|
||||
.list_paginated(&self.0, pagination)
|
||||
.await?;
|
||||
|
||||
let mut connection = Connection::new(has_previous_page, has_next_page);
|
||||
connection.edges.extend(edges.into_iter().map(|s| {
|
||||
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
|
||||
connection.edges.extend(page.edges.into_iter().map(|s| {
|
||||
Edge::new(
|
||||
OpaqueCursor(NodeCursor(NodeType::OAuth2Session, s.id)),
|
||||
OAuth2Session(s),
|
||||
@@ -265,7 +274,7 @@ impl User {
|
||||
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
|
||||
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
|
||||
) -> Result<Connection<Cursor, UpstreamOAuth2Link>, async_graphql::Error> {
|
||||
let database = ctx.data::<PgPool>()?;
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
|
||||
query(
|
||||
after,
|
||||
@@ -273,7 +282,6 @@ impl User {
|
||||
first,
|
||||
last,
|
||||
|after, before, first, last| async move {
|
||||
let mut conn = database.acquire().await?;
|
||||
let after_id = after
|
||||
.map(|x: OpaqueCursor<NodeCursor>| {
|
||||
x.extract_for_type(NodeType::UpstreamOAuth2Link)
|
||||
@@ -284,15 +292,15 @@ impl User {
|
||||
x.extract_for_type(NodeType::UpstreamOAuth2Link)
|
||||
})
|
||||
.transpose()?;
|
||||
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||
|
||||
let (has_previous_page, has_next_page, edges) =
|
||||
mas_storage::upstream_oauth2::get_paginated_user_links(
|
||||
&mut conn, &self.0, before_id, after_id, first, last,
|
||||
)
|
||||
let page = repo
|
||||
.upstream_oauth_link()
|
||||
.list_paginated(&self.0, pagination)
|
||||
.await?;
|
||||
|
||||
let mut connection = Connection::new(has_previous_page, has_next_page);
|
||||
connection.edges.extend(edges.into_iter().map(|s| {
|
||||
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
|
||||
connection.edges.extend(page.edges.into_iter().map(|s| {
|
||||
Edge::new(
|
||||
OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Link, s.id)),
|
||||
UpstreamOAuth2Link::new(s),
|
||||
@@ -339,9 +347,9 @@ pub struct UserEmailsPagination(mas_data_model::User);
|
||||
#[Object]
|
||||
impl UserEmailsPagination {
|
||||
/// Identifies the total count of items in the connection.
|
||||
async fn total_count(&self, ctx: &Context<'_>) -> Result<i64, async_graphql::Error> {
|
||||
let mut conn = ctx.data::<PgPool>()?.acquire().await?;
|
||||
let count = mas_storage::user::count_user_emails(&mut conn, &self.0).await?;
|
||||
async fn total_count(&self, ctx: &Context<'_>) -> Result<usize, async_graphql::Error> {
|
||||
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
|
||||
let count = repo.user_email().count(&self.0).await?;
|
||||
Ok(count)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,6 +68,7 @@ mas-oidc-client = { path = "../oidc-client" }
|
||||
mas-policy = { path = "../policy" }
|
||||
mas-router = { path = "../router" }
|
||||
mas-storage = { path = "../storage" }
|
||||
mas-storage-pg = { path = "../storage-pg" }
|
||||
mas-templates = { path = "../templates" }
|
||||
oauth2-types = { path = "../oauth2-types" }
|
||||
|
||||
|
||||
@@ -12,16 +12,25 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::{convert::Infallible, sync::Arc};
|
||||
|
||||
use axum::extract::FromRef;
|
||||
use axum::{
|
||||
async_trait,
|
||||
extract::{FromRef, FromRequestParts},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::http_client_factory::HttpClientFactory;
|
||||
use mas_email::Mailer;
|
||||
use mas_keystore::{Encrypter, Keystore};
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock};
|
||||
use mas_storage_pg::PgRepository;
|
||||
use mas_templates::Templates;
|
||||
use rand::SeedableRng;
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::{passwords::PasswordManager, MatrixHomeserver};
|
||||
|
||||
@@ -105,3 +114,58 @@ impl FromRef<AppState> for PasswordManager {
|
||||
input.password_manager.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FromRequestParts<AppState> for BoxClock {
|
||||
type Rejection = Infallible;
|
||||
|
||||
async fn from_request_parts(
|
||||
_parts: &mut axum::http::request::Parts,
|
||||
_state: &AppState,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
let clock = SystemClock::default();
|
||||
Ok(Box::new(clock))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FromRequestParts<AppState> for BoxRng {
|
||||
type Rejection = Infallible;
|
||||
|
||||
async fn from_request_parts(
|
||||
_parts: &mut axum::http::request::Parts,
|
||||
_state: &AppState,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
// This rng is used to source the local rng
|
||||
#[allow(clippy::disallowed_methods)]
|
||||
let rng = rand::thread_rng();
|
||||
|
||||
let rng = rand_chacha::ChaChaRng::from_rng(rng).expect("Failed to seed RNG");
|
||||
Ok(Box::new(rng))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error(transparent)]
|
||||
pub struct RepositoryError(#[from] mas_storage_pg::DatabaseError);
|
||||
|
||||
impl IntoResponse for RepositoryError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FromRequestParts<AppState> for BoxRepository {
|
||||
type Rejection = RepositoryError;
|
||||
|
||||
async fn from_request_parts(
|
||||
_parts: &mut axum::http::request::Parts,
|
||||
state: &AppState,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
let repo = PgRepository::from_pool(&state.pool).await?;
|
||||
Ok(repo
|
||||
.map_err(mas_storage::RepositoryError::from_error)
|
||||
.boxed())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,18 +15,18 @@
|
||||
use axum::{extract::State, response::IntoResponse, Json};
|
||||
use chrono::Duration;
|
||||
use hyper::StatusCode;
|
||||
use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType};
|
||||
use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType, User};
|
||||
use mas_storage::{
|
||||
compat::{
|
||||
add_compat_access_token, add_compat_refresh_token, get_compat_sso_login_by_token,
|
||||
mark_compat_sso_login_as_exchanged, start_compat_session,
|
||||
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
|
||||
CompatSsoLoginRepository,
|
||||
},
|
||||
user::{add_user_password, lookup_user_by_username, lookup_user_password},
|
||||
Clock,
|
||||
user::{UserPasswordRepository, UserRepository},
|
||||
BoxClock, BoxRepository, BoxRng, Clock,
|
||||
};
|
||||
use rand::{CryptoRng, RngCore};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds};
|
||||
use sqlx::{PgPool, Postgres, Transaction};
|
||||
use thiserror::Error;
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
@@ -137,6 +137,9 @@ pub enum RouteError {
|
||||
#[error("user not found")]
|
||||
UserNotFound,
|
||||
|
||||
#[error("session not found")]
|
||||
SessionNotFound,
|
||||
|
||||
#[error("user has no password")]
|
||||
NoPassword,
|
||||
|
||||
@@ -150,13 +153,12 @@ pub enum RouteError {
|
||||
InvalidLoginToken,
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
Self::Internal(_) => MatrixError {
|
||||
Self::Internal(_) | Self::SessionNotFound => MatrixError {
|
||||
errcode: "M_UNKNOWN",
|
||||
error: "Internal server error",
|
||||
status: StatusCode::INTERNAL_SERVER_ERROR,
|
||||
@@ -190,27 +192,37 @@ impl IntoResponse for RouteError {
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(password_manager): State<PasswordManager>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
State(homeserver): State<MatrixHomeserver>,
|
||||
Json(input): Json<RequestBody>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut txn = pool.begin().await?;
|
||||
let session = match input.credentials {
|
||||
let (session, user) = match input.credentials {
|
||||
Credentials::Password {
|
||||
identifier: Identifier::User { user },
|
||||
password,
|
||||
} => user_password_login(&password_manager, &mut txn, user, password).await?,
|
||||
} => {
|
||||
user_password_login(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&password_manager,
|
||||
&mut repo,
|
||||
user,
|
||||
password,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
|
||||
Credentials::Token { token } => token_login(&mut txn, &clock, &token).await?,
|
||||
Credentials::Token { token } => token_login(&mut repo, &clock, &token).await?,
|
||||
|
||||
_ => {
|
||||
return Err(RouteError::Unsupported);
|
||||
}
|
||||
};
|
||||
|
||||
let user_id = format!("@{username}:{homeserver}", username = session.user.username);
|
||||
let user_id = format!("@{username}:{homeserver}", username = user.username);
|
||||
|
||||
// If the client asked for a refreshable token, make it expire
|
||||
let expires_in = if input.refresh_token {
|
||||
@@ -221,33 +233,23 @@ pub(crate) async fn post(
|
||||
};
|
||||
|
||||
let access_token = TokenType::CompatAccessToken.generate(&mut rng);
|
||||
let access_token = add_compat_access_token(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
access_token,
|
||||
expires_in,
|
||||
)
|
||||
.await?;
|
||||
let access_token = repo
|
||||
.compat_access_token()
|
||||
.add(&mut rng, &clock, &session, access_token, expires_in)
|
||||
.await?;
|
||||
|
||||
let refresh_token = if input.refresh_token {
|
||||
let refresh_token = TokenType::CompatRefreshToken.generate(&mut rng);
|
||||
let refresh_token = add_compat_refresh_token(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
&access_token,
|
||||
refresh_token,
|
||||
)
|
||||
.await?;
|
||||
let refresh_token = repo
|
||||
.compat_refresh_token()
|
||||
.add(&mut rng, &clock, &session, &access_token, refresh_token)
|
||||
.await?;
|
||||
Some(refresh_token.token)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
Ok(Json(ResponseBody {
|
||||
access_token: access_token.token,
|
||||
@@ -259,16 +261,18 @@ pub(crate) async fn post(
|
||||
}
|
||||
|
||||
async fn token_login(
|
||||
txn: &mut Transaction<'_, Postgres>,
|
||||
clock: &Clock,
|
||||
repo: &mut BoxRepository,
|
||||
clock: &dyn Clock,
|
||||
token: &str,
|
||||
) -> Result<CompatSession, RouteError> {
|
||||
let login = get_compat_sso_login_by_token(&mut *txn, token)
|
||||
) -> Result<(CompatSession, User), RouteError> {
|
||||
let login = repo
|
||||
.compat_sso_login()
|
||||
.find_by_token(token)
|
||||
.await?
|
||||
.ok_or(RouteError::InvalidLoginToken)?;
|
||||
|
||||
let now = clock.now();
|
||||
match login.state {
|
||||
let session_id = match login.state {
|
||||
CompatSsoLoginState::Pending => {
|
||||
tracing::error!(
|
||||
compat_sso_login.id = %login.id,
|
||||
@@ -277,49 +281,70 @@ async fn token_login(
|
||||
return Err(RouteError::InvalidLoginToken);
|
||||
}
|
||||
CompatSsoLoginState::Fulfilled {
|
||||
fulfilled_at: fullfilled_at,
|
||||
fulfilled_at,
|
||||
session_id,
|
||||
..
|
||||
} => {
|
||||
if now > fullfilled_at + Duration::seconds(30) {
|
||||
if now > fulfilled_at + Duration::seconds(30) {
|
||||
return Err(RouteError::LoginTookTooLong);
|
||||
}
|
||||
|
||||
session_id
|
||||
}
|
||||
CompatSsoLoginState::Exchanged { exchanged_at, .. } => {
|
||||
CompatSsoLoginState::Exchanged {
|
||||
exchanged_at,
|
||||
session_id,
|
||||
..
|
||||
} => {
|
||||
if now > exchanged_at + Duration::seconds(30) {
|
||||
// TODO: log that session out
|
||||
tracing::error!(
|
||||
compat_sso_login.id = %login.id,
|
||||
compat_session.id = %session_id,
|
||||
"Login token exchanged a second time more than 30s after"
|
||||
);
|
||||
}
|
||||
|
||||
return Err(RouteError::InvalidLoginToken);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let login = mark_compat_sso_login_as_exchanged(&mut *txn, clock, login).await?;
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.lookup(session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::SessionNotFound)?;
|
||||
|
||||
match login.state {
|
||||
CompatSsoLoginState::Exchanged { session, .. } => Ok(session),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(session.user_id)
|
||||
.await?
|
||||
.ok_or(RouteError::UserNotFound)?;
|
||||
|
||||
repo.compat_sso_login().exchange(clock, login).await?;
|
||||
|
||||
Ok((session, user))
|
||||
}
|
||||
|
||||
async fn user_password_login(
|
||||
mut rng: &mut (impl RngCore + CryptoRng + Send),
|
||||
clock: &impl Clock,
|
||||
password_manager: &PasswordManager,
|
||||
txn: &mut Transaction<'_, Postgres>,
|
||||
repo: &mut BoxRepository,
|
||||
username: String,
|
||||
password: String,
|
||||
) -> Result<CompatSession, RouteError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
|
||||
) -> Result<(CompatSession, User), RouteError> {
|
||||
// Find the user
|
||||
let user = lookup_user_by_username(&mut *txn, &username)
|
||||
let user = repo
|
||||
.user()
|
||||
.find_by_username(&username)
|
||||
.await?
|
||||
.ok_or(RouteError::UserNotFound)?;
|
||||
|
||||
// Lookup its password
|
||||
let user_password = lookup_user_password(&mut *txn, &user)
|
||||
let user_password = repo
|
||||
.user_password()
|
||||
.active(&user)
|
||||
.await?
|
||||
.ok_or(RouteError::NoPassword)?;
|
||||
|
||||
@@ -338,21 +363,24 @@ async fn user_password_login(
|
||||
|
||||
if let Some((version, hashed_password)) = new_password_hash {
|
||||
// Save the upgraded password if needed
|
||||
add_user_password(
|
||||
&mut *txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&user,
|
||||
version,
|
||||
hashed_password,
|
||||
Some(user_password),
|
||||
)
|
||||
.await?;
|
||||
repo.user_password()
|
||||
.add(
|
||||
&mut rng,
|
||||
clock,
|
||||
&user,
|
||||
version,
|
||||
hashed_password,
|
||||
Some(&user_password),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Now that the user credentials have been verified, start a new compat session
|
||||
let device = Device::generate(&mut rng);
|
||||
let session = start_compat_session(&mut *txn, &mut rng, &clock, user, device).await?;
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.add(&mut rng, clock, &user, device)
|
||||
.await?;
|
||||
|
||||
Ok(session)
|
||||
Ok((session, user))
|
||||
}
|
||||
|
||||
@@ -29,10 +29,12 @@ use mas_axum_utils::{
|
||||
use mas_data_model::Device;
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_router::{CompatLoginSsoAction, PostAuthAction, Route};
|
||||
use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id};
|
||||
use mas_storage::{
|
||||
compat::{CompatSessionRepository, CompatSsoLoginRepository},
|
||||
BoxClock, BoxRepository, BoxRng, Clock,
|
||||
};
|
||||
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use ulid::Ulid;
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -50,19 +52,18 @@ pub struct Params {
|
||||
}
|
||||
|
||||
pub async fn get(
|
||||
State(pool): State<PgPool>,
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
mut repo: BoxRepository,
|
||||
State(templates): State<Templates>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Path(id): Path<Ulid>,
|
||||
Query(params): Query<Params>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
|
||||
let maybe_session = session_info.load_session(&mut conn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
let session = if let Some(session) = maybe_session {
|
||||
session
|
||||
@@ -80,20 +81,16 @@ pub async fn get(
|
||||
return Ok((cookie_jar, url).into_response());
|
||||
};
|
||||
|
||||
// TODO: make that more generic
|
||||
if session
|
||||
.user
|
||||
.primary_email
|
||||
.as_ref()
|
||||
.and_then(|e| e.confirmed_at)
|
||||
.is_none()
|
||||
{
|
||||
// TODO: make that more generic, check that the email has been confirmed
|
||||
if session.user.primary_user_email_id.is_none() {
|
||||
let destination = mas_router::AccountAddEmail::default()
|
||||
.and_then(PostAuthAction::continue_compat_sso_login(id));
|
||||
return Ok((cookie_jar, destination.go()).into_response());
|
||||
}
|
||||
|
||||
let login = get_compat_sso_login_by_id(&mut conn, id)
|
||||
let login = repo
|
||||
.compat_sso_login()
|
||||
.lookup(id)
|
||||
.await?
|
||||
.context("Could not find compat SSO login")?;
|
||||
|
||||
@@ -117,20 +114,19 @@ pub async fn get(
|
||||
}
|
||||
|
||||
pub async fn post(
|
||||
State(pool): State<PgPool>,
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
mut repo: BoxRepository,
|
||||
State(templates): State<Templates>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Path(id): Path<Ulid>,
|
||||
Query(params): Query<Params>,
|
||||
Form(form): Form<ProtectedForm<()>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
cookie_jar.verify_form(clock.now(), form)?;
|
||||
cookie_jar.verify_form(&clock, form)?;
|
||||
|
||||
let maybe_session = session_info.load_session(&mut txn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
let session = if let Some(session) = maybe_session {
|
||||
session
|
||||
@@ -149,19 +145,15 @@ pub async fn post(
|
||||
};
|
||||
|
||||
// TODO: make that more generic
|
||||
if session
|
||||
.user
|
||||
.primary_email
|
||||
.as_ref()
|
||||
.and_then(|e| e.confirmed_at)
|
||||
.is_none()
|
||||
{
|
||||
if session.user.primary_user_email_id.is_none() {
|
||||
let destination = mas_router::AccountAddEmail::default()
|
||||
.and_then(PostAuthAction::continue_compat_sso_login(id));
|
||||
return Ok((cookie_jar, destination.go()).into_response());
|
||||
}
|
||||
|
||||
let login = get_compat_sso_login_by_id(&mut txn, id)
|
||||
let login = repo
|
||||
.compat_sso_login()
|
||||
.lookup(id)
|
||||
.await?
|
||||
.context("Could not find compat SSO login")?;
|
||||
|
||||
@@ -193,10 +185,16 @@ pub async fn post(
|
||||
};
|
||||
|
||||
let device = Device::generate(&mut rng);
|
||||
let _login =
|
||||
fullfill_compat_sso_login(&mut txn, &mut rng, &clock, session.user, login, device).await?;
|
||||
let compat_session = repo
|
||||
.compat_session()
|
||||
.add(&mut rng, &clock, &session.user, device)
|
||||
.await?;
|
||||
|
||||
txn.commit().await?;
|
||||
repo.compat_sso_login()
|
||||
.fulfill(&clock, login, &compat_session)
|
||||
.await?;
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
Ok((cookie_jar, Redirect::to(redirect_uri.as_str())).into_response())
|
||||
}
|
||||
|
||||
@@ -19,11 +19,10 @@ use axum::{
|
||||
};
|
||||
use hyper::StatusCode;
|
||||
use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
|
||||
use mas_storage::compat::insert_compat_sso_login;
|
||||
use mas_storage::{compat::CompatSsoLoginRepository, BoxClock, BoxRepository, BoxRng};
|
||||
use rand::distributions::{Alphanumeric, DistString};
|
||||
use serde::Deserialize;
|
||||
use serde_with::serde;
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
use url::Url;
|
||||
|
||||
@@ -48,7 +47,7 @@ pub enum RouteError {
|
||||
InvalidRedirectUrl,
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
@@ -56,14 +55,13 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(pool, url_builder), err)]
|
||||
pub async fn get(
|
||||
State(pool): State<PgPool>,
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
mut repo: BoxRepository,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
Query(params): Query<Params>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
|
||||
// Check the redirectUrl parameter
|
||||
let redirect_url = params.redirect_url.ok_or(RouteError::MissingRedirectUrl)?;
|
||||
let redirect_url = Url::parse(&redirect_url).map_err(|_| RouteError::InvalidRedirectUrl)?;
|
||||
@@ -79,8 +77,10 @@ pub async fn get(
|
||||
}
|
||||
|
||||
let token = Alphanumeric.sample_string(&mut rng, 32);
|
||||
let mut conn = pool.acquire().await?;
|
||||
let login = insert_compat_sso_login(&mut conn, &mut rng, &clock, token, redirect_url).await?;
|
||||
let login = repo
|
||||
.compat_sso_login()
|
||||
.add(&mut rng, &clock, token, redirect_url)
|
||||
.await?;
|
||||
|
||||
Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.id, params.action)))
|
||||
}
|
||||
|
||||
@@ -12,12 +12,14 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use axum::{extract::State, response::IntoResponse, Json, TypedHeader};
|
||||
use axum::{response::IntoResponse, Json, TypedHeader};
|
||||
use headers::{authorization::Bearer, Authorization};
|
||||
use hyper::StatusCode;
|
||||
use mas_data_model::TokenType;
|
||||
use mas_storage::{compat::compat_logout, Clock};
|
||||
use sqlx::PgPool;
|
||||
use mas_storage::{
|
||||
compat::{CompatAccessTokenRepository, CompatSessionRepository},
|
||||
BoxClock, BoxRepository, Clock,
|
||||
};
|
||||
use thiserror::Error;
|
||||
|
||||
use super::MatrixError;
|
||||
@@ -36,12 +38,9 @@ pub enum RouteError {
|
||||
|
||||
#[error("Invalid access token")]
|
||||
InvalidAuthorization,
|
||||
|
||||
#[error("Logout failed")]
|
||||
LogoutFailed,
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
@@ -56,7 +55,7 @@ impl IntoResponse for RouteError {
|
||||
error: "Missing access token",
|
||||
status: StatusCode::UNAUTHORIZED,
|
||||
},
|
||||
Self::InvalidAuthorization | Self::LogoutFailed | Self::TokenFormat(_) => MatrixError {
|
||||
Self::InvalidAuthorization | Self::TokenFormat(_) => MatrixError {
|
||||
errcode: "M_UNKNOWN_TOKEN",
|
||||
error: "Invalid access token",
|
||||
status: StatusCode::UNAUTHORIZED,
|
||||
@@ -67,12 +66,10 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
|
||||
pub(crate) async fn post(
|
||||
State(pool): State<PgPool>,
|
||||
clock: BoxClock,
|
||||
mut repo: BoxRepository,
|
||||
maybe_authorization: Option<TypedHeader<Authorization<Bearer>>>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let clock = Clock::default();
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?;
|
||||
|
||||
let token = authorization.token();
|
||||
@@ -82,9 +79,23 @@ pub(crate) async fn post(
|
||||
return Err(RouteError::InvalidAuthorization);
|
||||
}
|
||||
|
||||
if !compat_logout(&mut conn, &clock, token).await? {
|
||||
return Err(RouteError::LogoutFailed);
|
||||
}
|
||||
let token = repo
|
||||
.compat_access_token()
|
||||
.find_by_token(token)
|
||||
.await?
|
||||
.filter(|t| t.is_valid(clock.now()))
|
||||
.ok_or(RouteError::InvalidAuthorization)?;
|
||||
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.lookup(token.session_id)
|
||||
.await?
|
||||
.filter(|s| s.is_valid())
|
||||
.ok_or(RouteError::InvalidAuthorization)?;
|
||||
|
||||
repo.compat_session().finish(&clock, session).await?;
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
Ok(Json(serde_json::json!({})))
|
||||
}
|
||||
|
||||
@@ -12,17 +12,16 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use axum::{extract::State, response::IntoResponse, Json};
|
||||
use axum::{response::IntoResponse, Json};
|
||||
use chrono::Duration;
|
||||
use hyper::StatusCode;
|
||||
use mas_data_model::{TokenFormatError, TokenType};
|
||||
use mas_storage::compat::{
|
||||
add_compat_access_token, add_compat_refresh_token, consume_compat_refresh_token,
|
||||
expire_compat_access_token, lookup_active_compat_refresh_token,
|
||||
use mas_storage::{
|
||||
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
|
||||
BoxClock, BoxRepository, BoxRng, Clock,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DurationMilliSeconds};
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
|
||||
use super::MatrixError;
|
||||
@@ -40,17 +39,26 @@ pub enum RouteError {
|
||||
|
||||
#[error("invalid token")]
|
||||
InvalidToken,
|
||||
|
||||
#[error("refresh token already consumed")]
|
||||
RefreshTokenConsumed,
|
||||
|
||||
#[error("invalid session")]
|
||||
InvalidSession,
|
||||
|
||||
#[error("unknown session")]
|
||||
UnknownSession,
|
||||
}
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
Self::Internal(_) => MatrixError {
|
||||
Self::Internal(_) | Self::UnknownSession => MatrixError {
|
||||
errcode: "M_UNKNOWN",
|
||||
error: "Internal error",
|
||||
status: StatusCode::INTERNAL_SERVER_ERROR,
|
||||
},
|
||||
Self::InvalidToken => MatrixError {
|
||||
Self::InvalidToken | Self::InvalidSession | Self::RefreshTokenConsumed => MatrixError {
|
||||
errcode: "M_UNKNOWN_TOKEN",
|
||||
error: "Invalid refresh token",
|
||||
status: StatusCode::UNAUTHORIZED,
|
||||
@@ -60,8 +68,7 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
|
||||
impl From<TokenFormatError> for RouteError {
|
||||
fn from(_e: TokenFormatError) -> Self {
|
||||
@@ -79,50 +86,79 @@ pub struct ResponseBody {
|
||||
}
|
||||
|
||||
pub(crate) async fn post(
|
||||
State(pool): State<PgPool>,
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
mut repo: BoxRepository,
|
||||
Json(input): Json<RequestBody>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let token_type = TokenType::check(&input.refresh_token)?;
|
||||
|
||||
if token_type != TokenType::CompatRefreshToken {
|
||||
return Err(RouteError::InvalidToken);
|
||||
}
|
||||
|
||||
let (refresh_token, access_token, session) =
|
||||
lookup_active_compat_refresh_token(&mut txn, &input.refresh_token)
|
||||
.await?
|
||||
.ok_or(RouteError::InvalidToken)?;
|
||||
let refresh_token = repo
|
||||
.compat_refresh_token()
|
||||
.find_by_token(&input.refresh_token)
|
||||
.await?
|
||||
.ok_or(RouteError::InvalidToken)?;
|
||||
|
||||
if !refresh_token.is_valid() {
|
||||
return Err(RouteError::RefreshTokenConsumed);
|
||||
}
|
||||
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.lookup(refresh_token.session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::UnknownSession)?;
|
||||
|
||||
if !session.is_valid() {
|
||||
return Err(RouteError::InvalidSession);
|
||||
}
|
||||
|
||||
let access_token = repo
|
||||
.compat_access_token()
|
||||
.lookup(refresh_token.access_token_id)
|
||||
.await?
|
||||
.filter(|t| t.is_valid(clock.now()));
|
||||
|
||||
let new_refresh_token_str = TokenType::CompatRefreshToken.generate(&mut rng);
|
||||
let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng);
|
||||
|
||||
let expires_in = Duration::minutes(5);
|
||||
let new_access_token = add_compat_access_token(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
new_access_token_str,
|
||||
Some(expires_in),
|
||||
)
|
||||
.await?;
|
||||
let new_refresh_token = add_compat_refresh_token(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
&new_access_token,
|
||||
new_refresh_token_str,
|
||||
)
|
||||
.await?;
|
||||
let new_access_token = repo
|
||||
.compat_access_token()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
new_access_token_str,
|
||||
Some(expires_in),
|
||||
)
|
||||
.await?;
|
||||
let new_refresh_token = repo
|
||||
.compat_refresh_token()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
&new_access_token,
|
||||
new_refresh_token_str,
|
||||
)
|
||||
.await?;
|
||||
|
||||
consume_compat_refresh_token(&mut txn, &clock, refresh_token).await?;
|
||||
expire_compat_access_token(&mut txn, &clock, access_token).await?;
|
||||
repo.compat_refresh_token()
|
||||
.consume(&clock, refresh_token)
|
||||
.await?;
|
||||
|
||||
txn.commit().await?;
|
||||
if let Some(access_token) = access_token {
|
||||
repo.compat_access_token()
|
||||
.expire(&clock, access_token)
|
||||
.await?;
|
||||
}
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
Ok(Json(ResponseBody {
|
||||
access_token: new_access_token.token,
|
||||
|
||||
@@ -22,19 +22,19 @@ use axum::{
|
||||
Json, TypedHeader,
|
||||
};
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use futures_util::{StreamExt, TryStreamExt};
|
||||
use futures_util::TryStreamExt;
|
||||
use headers::{ContentType, HeaderValue};
|
||||
use hyper::header::CACHE_CONTROL;
|
||||
use mas_axum_utils::{FancyError, SessionInfoExt};
|
||||
use mas_graphql::Schema;
|
||||
use mas_keystore::Encrypter;
|
||||
use sqlx::PgPool;
|
||||
use mas_storage::BoxRepository;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{info_span, Instrument};
|
||||
|
||||
#[must_use]
|
||||
pub fn schema(pool: &PgPool) -> Schema {
|
||||
pub fn schema() -> Schema {
|
||||
mas_graphql::schema_builder()
|
||||
.data(pool.clone())
|
||||
.extension(Tracing)
|
||||
.extension(ApolloTracing)
|
||||
.finish()
|
||||
@@ -58,8 +58,8 @@ fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span {
|
||||
}
|
||||
|
||||
pub async fn post(
|
||||
State(pool): State<PgPool>,
|
||||
State(schema): State<Schema>,
|
||||
mut repo: BoxRepository,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
content_type: Option<TypedHeader<ContentType>>,
|
||||
body: BodyStream,
|
||||
@@ -67,58 +67,46 @@ pub async fn post(
|
||||
let content_type = content_type.map(|TypedHeader(h)| h.to_string());
|
||||
|
||||
let (session_info, _cookie_jar) = cookie_jar.session_info();
|
||||
let maybe_session = session_info.load_session(&pool).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
let mut request = async_graphql::http::receive_batch_body(
|
||||
let mut request = async_graphql::http::receive_body(
|
||||
content_type,
|
||||
body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
|
||||
.into_async_read(),
|
||||
MultipartOptions::default(),
|
||||
)
|
||||
.await?; // XXX: this should probably return another error response?
|
||||
.await? // XXX: this should probably return another error response?
|
||||
.data(Mutex::new(repo));
|
||||
|
||||
if let Some(session) = maybe_session {
|
||||
request = request.data(session);
|
||||
}
|
||||
|
||||
let response = match request {
|
||||
async_graphql::BatchRequest::Single(request) => {
|
||||
let span = span_for_graphql_request(&request);
|
||||
let response = schema.execute(request).instrument(span).await;
|
||||
async_graphql::BatchResponse::Single(response)
|
||||
}
|
||||
async_graphql::BatchRequest::Batch(requests) => async_graphql::BatchResponse::Batch(
|
||||
futures_util::stream::iter(requests.into_iter())
|
||||
.then(|request| {
|
||||
let span = span_for_graphql_request(&request);
|
||||
schema.execute(request).instrument(span)
|
||||
})
|
||||
.collect()
|
||||
.await,
|
||||
),
|
||||
};
|
||||
let span = span_for_graphql_request(&request);
|
||||
let response = schema.execute(request).instrument(span).await;
|
||||
|
||||
let cache_control = response
|
||||
.cache_control()
|
||||
.cache_control
|
||||
.value()
|
||||
.and_then(|v| HeaderValue::from_str(&v).ok())
|
||||
.map(|h| [(CACHE_CONTROL, h)]);
|
||||
|
||||
let headers = response.http_headers();
|
||||
let headers = response.http_headers.clone();
|
||||
|
||||
Ok((headers, cache_control, Json(response)))
|
||||
}
|
||||
|
||||
pub async fn get(
|
||||
State(pool): State<PgPool>,
|
||||
State(schema): State<Schema>,
|
||||
mut repo: BoxRepository,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
RawQuery(query): RawQuery,
|
||||
) -> Result<impl IntoResponse, FancyError> {
|
||||
let (session_info, _cookie_jar) = cookie_jar.session_info();
|
||||
let maybe_session = session_info.load_session(&pool).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?;
|
||||
let mut request =
|
||||
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(Mutex::new(repo));
|
||||
|
||||
if let Some(session) = maybe_session {
|
||||
request = request.data(session);
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -36,7 +36,7 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
#[sqlx::test(migrator = "mas_storage::MIGRATOR")]
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||
async fn test_get_health(pool: PgPool) -> Result<(), anyhow::Error> {
|
||||
let state = crate::test_state(pool).await?;
|
||||
let app = crate::healthcheck_router().with_state(state);
|
||||
|
||||
@@ -21,14 +21,17 @@
|
||||
)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(
|
||||
clippy::unused_async // Some axum handlers need that
|
||||
// Some axum handlers need that
|
||||
clippy::unused_async,
|
||||
// Because of how axum handlers work, we sometime have take many arguments
|
||||
clippy::too_many_arguments,
|
||||
)]
|
||||
|
||||
use std::{convert::Infallible, sync::Arc, time::Duration};
|
||||
|
||||
use axum::{
|
||||
body::{Bytes, HttpBody},
|
||||
extract::FromRef,
|
||||
extract::{FromRef, FromRequestParts},
|
||||
response::{Html, IntoResponse},
|
||||
routing::{get, on, post, MethodFilter},
|
||||
Router,
|
||||
@@ -40,9 +43,9 @@ use mas_http::CorsLayerExt;
|
||||
use mas_keystore::{Encrypter, Keystore};
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::{Route, UrlBuilder};
|
||||
use mas_storage::{BoxClock, BoxRepository, BoxRng};
|
||||
use mas_templates::{ErrorContext, Templates};
|
||||
use passwords::PasswordManager;
|
||||
use rand::SeedableRng;
|
||||
use sqlx::PgPool;
|
||||
use tower::util::AndThenLayer;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
@@ -94,7 +97,7 @@ where
|
||||
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
mas_graphql::Schema: FromRef<S>,
|
||||
PgPool: FromRef<S>,
|
||||
BoxRepository: FromRequestParts<S>,
|
||||
Encrypter: FromRef<S>,
|
||||
{
|
||||
let mut router = Router::new().route(
|
||||
@@ -116,6 +119,8 @@ where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
Keystore: FromRef<S>,
|
||||
UrlBuilder: FromRef<S>,
|
||||
BoxClock: FromRequestParts<S>,
|
||||
BoxRng: FromRequestParts<S>,
|
||||
{
|
||||
Router::new()
|
||||
.route(
|
||||
@@ -152,9 +157,11 @@ where
|
||||
Keystore: FromRef<S>,
|
||||
UrlBuilder: FromRef<S>,
|
||||
Arc<PolicyFactory>: FromRef<S>,
|
||||
PgPool: FromRef<S>,
|
||||
BoxRepository: FromRequestParts<S>,
|
||||
Encrypter: FromRef<S>,
|
||||
HttpClientFactory: FromRef<S>,
|
||||
BoxClock: FromRequestParts<S>,
|
||||
BoxRng: FromRequestParts<S>,
|
||||
{
|
||||
// All those routes are API-like, with a common CORS layer
|
||||
Router::new()
|
||||
@@ -205,9 +212,11 @@ where
|
||||
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
UrlBuilder: FromRef<S>,
|
||||
PgPool: FromRef<S>,
|
||||
BoxRepository: FromRequestParts<S>,
|
||||
MatrixHomeserver: FromRef<S>,
|
||||
PasswordManager: FromRef<S>,
|
||||
BoxClock: FromRequestParts<S>,
|
||||
BoxRng: FromRequestParts<S>,
|
||||
{
|
||||
Router::new()
|
||||
.route(
|
||||
@@ -248,13 +257,15 @@ where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
UrlBuilder: FromRef<S>,
|
||||
Arc<PolicyFactory>: FromRef<S>,
|
||||
PgPool: FromRef<S>,
|
||||
BoxRepository: FromRequestParts<S>,
|
||||
Encrypter: FromRef<S>,
|
||||
Templates: FromRef<S>,
|
||||
Mailer: FromRef<S>,
|
||||
Keystore: FromRef<S>,
|
||||
HttpClientFactory: FromRef<S>,
|
||||
PasswordManager: FromRef<S>,
|
||||
BoxClock: FromRequestParts<S>,
|
||||
BoxRng: FromRequestParts<S>,
|
||||
{
|
||||
Router::new()
|
||||
.route(
|
||||
@@ -350,7 +361,7 @@ where
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn test_state(pool: PgPool) -> Result<AppState, anyhow::Error> {
|
||||
async fn test_state(pool: sqlx::PgPool) -> Result<AppState, anyhow::Error> {
|
||||
use mas_email::MailTransport;
|
||||
|
||||
use crate::passwords::Hasher;
|
||||
@@ -389,7 +400,7 @@ async fn test_state(pool: PgPool) -> Result<AppState, anyhow::Error> {
|
||||
|
||||
let policy_factory = Arc::new(policy_factory);
|
||||
|
||||
let graphql_schema = graphql_schema(&pool);
|
||||
let graphql_schema = graphql_schema();
|
||||
|
||||
let http_client_factory = HttpClientFactory::new(10);
|
||||
|
||||
@@ -407,16 +418,3 @@ async fn test_state(pool: PgPool) -> Result<AppState, anyhow::Error> {
|
||||
password_manager,
|
||||
})
|
||||
}
|
||||
|
||||
// XXX: that should be moved somewhere else
|
||||
fn clock_and_rng() -> (mas_storage::Clock, rand_chacha::ChaChaRng) {
|
||||
let clock = mas_storage::Clock::default();
|
||||
|
||||
// This rng is used to source the local rng
|
||||
#[allow(clippy::disallowed_methods)]
|
||||
let rng = rand::thread_rng();
|
||||
|
||||
let rng = rand_chacha::ChaChaRng::from_rng(rng).expect("Failed to seed RNG");
|
||||
|
||||
(clock, rng)
|
||||
}
|
||||
|
||||
@@ -25,13 +25,12 @@ use mas_data_model::{AuthorizationGrant, BrowserSession};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::{PostAuthAction, Route};
|
||||
use mas_storage::oauth2::{
|
||||
authorization_grant::{derive_session, fulfill_grant, get_grant_by_id},
|
||||
consent::fetch_client_consent,
|
||||
use mas_storage::{
|
||||
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository},
|
||||
BoxClock, BoxRepository, BoxRng,
|
||||
};
|
||||
use mas_templates::Templates;
|
||||
use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse};
|
||||
use sqlx::{PgPool, Postgres, Transaction};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
@@ -69,8 +68,7 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_policy::LoadError);
|
||||
impl_from_error_for_route!(mas_policy::InstanciateError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
@@ -78,19 +76,21 @@ impl_from_error_for_route!(super::callback::IntoCallbackDestinationError);
|
||||
impl_from_error_for_route!(super::callback::CallbackDestinationError);
|
||||
|
||||
pub(crate) async fn get(
|
||||
rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||
State(templates): State<Templates>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Path(grant_id): Path<Ulid>,
|
||||
) -> Result<Response, RouteError> {
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let maybe_session = session_info.load_session(&mut txn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
let grant = get_grant_by_id(&mut txn, grant_id)
|
||||
let grant = repo
|
||||
.oauth2_authorization_grant()
|
||||
.lookup(grant_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NotFound)?;
|
||||
|
||||
@@ -105,7 +105,7 @@ pub(crate) async fn get(
|
||||
return Ok((cookie_jar, mas_router::Login::and_then(continue_grant).go()).into_response());
|
||||
};
|
||||
|
||||
match complete(grant, session, &policy_factory, txn).await {
|
||||
match complete(rng, clock, grant, session, &policy_factory, repo).await {
|
||||
Ok(params) => {
|
||||
let res = callback_destination.go(&templates, params).await?;
|
||||
Ok((cookie_jar, res).into_response())
|
||||
@@ -121,6 +121,7 @@ pub(crate) async fn get(
|
||||
}
|
||||
Err(GrantCompletionError::NotPending) => Err(RouteError::NotPending),
|
||||
Err(GrantCompletionError::Internal(e)) => Err(RouteError::Internal(e)),
|
||||
Err(e) => Err(RouteError::Internal(e.into())),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,23 +141,25 @@ pub enum GrantCompletionError {
|
||||
|
||||
#[error("denied by the policy")]
|
||||
PolicyViolation,
|
||||
|
||||
#[error("failed to load client")]
|
||||
NoSuchClient,
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(GrantCompletionError: sqlx::Error);
|
||||
impl_from_error_for_route!(GrantCompletionError: mas_storage::DatabaseError);
|
||||
impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError);
|
||||
impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError);
|
||||
impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError);
|
||||
impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError);
|
||||
|
||||
pub(crate) async fn complete(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
grant: AuthorizationGrant,
|
||||
browser_session: BrowserSession,
|
||||
policy_factory: &PolicyFactory,
|
||||
mut txn: Transaction<'_, Postgres>,
|
||||
mut repo: BoxRepository,
|
||||
) -> Result<AuthorizationResponse<Option<AccessTokenResponse>>, GrantCompletionError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
|
||||
// Verify that the grant is in a pending stage
|
||||
if !grant.stage.is_pending() {
|
||||
return Err(GrantCompletionError::NotPending);
|
||||
@@ -164,7 +167,7 @@ pub(crate) async fn complete(
|
||||
|
||||
// Check if the authentication is fresh enough
|
||||
if !browser_session.was_authenticated_after(grant.max_auth_time()) {
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
return Err(GrantCompletionError::RequiresReauth);
|
||||
}
|
||||
|
||||
@@ -178,8 +181,16 @@ pub(crate) async fn complete(
|
||||
return Err(GrantCompletionError::PolicyViolation);
|
||||
}
|
||||
|
||||
let current_consent =
|
||||
fetch_client_consent(&mut txn, &browser_session.user, &grant.client).await?;
|
||||
let client = repo
|
||||
.oauth2_client()
|
||||
.lookup(grant.client_id)
|
||||
.await?
|
||||
.ok_or(GrantCompletionError::NoSuchClient)?;
|
||||
|
||||
let current_consent = repo
|
||||
.oauth2_client()
|
||||
.get_consent_for_user(&client, &browser_session.user)
|
||||
.await?;
|
||||
|
||||
let lacks_consent = grant
|
||||
.scope
|
||||
@@ -188,14 +199,20 @@ pub(crate) async fn complete(
|
||||
|
||||
// Check if the client lacks consent *or* if consent was explicitely asked
|
||||
if lacks_consent || grant.requires_consent {
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
return Err(GrantCompletionError::RequiresConsent);
|
||||
}
|
||||
|
||||
// All good, let's start the session
|
||||
let session = derive_session(&mut txn, &mut rng, &clock, &grant, browser_session).await?;
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.create_from_grant(&mut rng, &clock, &grant, &browser_session)
|
||||
.await?;
|
||||
|
||||
let grant = fulfill_grant(&mut txn, grant, session.clone()).await?;
|
||||
let grant = repo
|
||||
.oauth2_authorization_grant()
|
||||
.fulfill(&clock, &session, grant)
|
||||
.await?;
|
||||
|
||||
// Yep! Let's complete the auth now
|
||||
let mut params = AuthorizationResponse::default();
|
||||
@@ -213,6 +230,6 @@ pub(crate) async fn complete(
|
||||
));
|
||||
}
|
||||
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
Ok(params)
|
||||
}
|
||||
|
||||
@@ -25,8 +25,9 @@ use mas_data_model::{AuthorizationCode, Pkce};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::{PostAuthAction, Route};
|
||||
use mas_storage::oauth2::{
|
||||
authorization_grant::new_authorization_grant, client::lookup_client_by_client_id,
|
||||
use mas_storage::{
|
||||
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
|
||||
BoxClock, BoxRepository, BoxRng,
|
||||
};
|
||||
use mas_templates::Templates;
|
||||
use oauth2_types::{
|
||||
@@ -37,7 +38,6 @@ use oauth2_types::{
|
||||
};
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use serde::Deserialize;
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
|
||||
use self::{callback::CallbackDestination, complete::GrantCompletionError};
|
||||
@@ -89,8 +89,7 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(self::callback::CallbackDestinationError);
|
||||
impl_from_error_for_route!(mas_policy::LoadError);
|
||||
impl_from_error_for_route!(mas_policy::InstanciateError);
|
||||
@@ -131,17 +130,18 @@ fn resolve_response_mode(
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||
State(templates): State<Templates>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Form(params): Form<Params>,
|
||||
) -> Result<Response, RouteError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
// First, figure out what client it is
|
||||
let client = lookup_client_by_client_id(&mut txn, ¶ms.auth.client_id)
|
||||
let client = repo
|
||||
.oauth2_client()
|
||||
.find_by_client_id(¶ms.auth.client_id)
|
||||
.await?
|
||||
.ok_or(RouteError::ClientNotFound)?;
|
||||
|
||||
@@ -167,7 +167,7 @@ pub(crate) async fn get(
|
||||
let templates = templates.clone();
|
||||
let callback_destination = callback_destination.clone();
|
||||
async move {
|
||||
let maybe_session = session_info.load_session(&mut txn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
let prompt = params.auth.prompt.as_deref().unwrap_or_default();
|
||||
|
||||
// Check if the request/request_uri/registration params are used. If so, reply
|
||||
@@ -272,23 +272,23 @@ pub(crate) async fn get(
|
||||
|
||||
let requires_consent = prompt.contains(&Prompt::Consent);
|
||||
|
||||
let grant = new_authorization_grant(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
client,
|
||||
redirect_uri.clone(),
|
||||
params.auth.scope,
|
||||
code,
|
||||
params.auth.state.clone(),
|
||||
params.auth.nonce,
|
||||
params.auth.max_age,
|
||||
None,
|
||||
response_mode,
|
||||
response_type.has_id_token(),
|
||||
requires_consent,
|
||||
)
|
||||
.await?;
|
||||
let grant = repo
|
||||
.oauth2_authorization_grant()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&client,
|
||||
redirect_uri.clone(),
|
||||
params.auth.scope,
|
||||
code,
|
||||
params.auth.state.clone(),
|
||||
params.auth.nonce,
|
||||
params.auth.max_age,
|
||||
response_mode,
|
||||
response_type.has_id_token(),
|
||||
requires_consent,
|
||||
)
|
||||
.await?;
|
||||
let continue_grant = PostAuthAction::continue_grant(grant.id);
|
||||
|
||||
let res = match maybe_session {
|
||||
@@ -299,7 +299,7 @@ pub(crate) async fn get(
|
||||
}
|
||||
None if prompt.contains(&Prompt::Create) => {
|
||||
// Client asked for a registration, show the registration prompt
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
mas_router::Register::and_then(continue_grant)
|
||||
.go()
|
||||
@@ -307,7 +307,7 @@ pub(crate) async fn get(
|
||||
}
|
||||
None => {
|
||||
// Other cases where we don't have a session, ask for a login
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
mas_router::Login::and_then(continue_grant)
|
||||
.go()
|
||||
@@ -320,7 +320,7 @@ pub(crate) async fn get(
|
||||
|| prompt.contains(&Prompt::SelectAccount) =>
|
||||
{
|
||||
// TODO: better pages here
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
mas_router::Reauth::and_then(continue_grant)
|
||||
.go()
|
||||
@@ -330,7 +330,15 @@ pub(crate) async fn get(
|
||||
// Else, we immediately try to complete the authorization grant
|
||||
Some(user_session) if prompt.contains(&Prompt::None) => {
|
||||
// With prompt=none, we should get back to the client immediately
|
||||
match self::complete::complete(grant, user_session, &policy_factory, txn).await
|
||||
match self::complete::complete(
|
||||
rng,
|
||||
clock,
|
||||
grant,
|
||||
user_session,
|
||||
&policy_factory,
|
||||
repo,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(params) => callback_destination.go(&templates, params).await?,
|
||||
Err(GrantCompletionError::RequiresConsent) => {
|
||||
@@ -357,7 +365,10 @@ pub(crate) async fn get(
|
||||
Err(GrantCompletionError::Internal(e)) => {
|
||||
return Err(RouteError::Internal(e))
|
||||
}
|
||||
Err(e @ GrantCompletionError::NotPending) => {
|
||||
Err(
|
||||
e @ (GrantCompletionError::NotPending
|
||||
| GrantCompletionError::NoSuchClient),
|
||||
) => {
|
||||
// This should never happen
|
||||
return Err(RouteError::Internal(Box::new(e)));
|
||||
}
|
||||
@@ -366,7 +377,15 @@ pub(crate) async fn get(
|
||||
Some(user_session) => {
|
||||
let grant_id = grant.id;
|
||||
// Else, we show the relevant reauth/consent page if necessary
|
||||
match self::complete::complete(grant, user_session, &policy_factory, txn).await
|
||||
match self::complete::complete(
|
||||
rng,
|
||||
clock,
|
||||
grant,
|
||||
user_session,
|
||||
&policy_factory,
|
||||
repo,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(params) => callback_destination.go(&templates, params).await?,
|
||||
Err(
|
||||
@@ -387,7 +406,10 @@ pub(crate) async fn get(
|
||||
Err(GrantCompletionError::Internal(e)) => {
|
||||
return Err(RouteError::Internal(e))
|
||||
}
|
||||
Err(e @ GrantCompletionError::NotPending) => {
|
||||
Err(
|
||||
e @ (GrantCompletionError::NotPending
|
||||
| GrantCompletionError::NoSuchClient),
|
||||
) => {
|
||||
// This should never happen
|
||||
return Err(RouteError::Internal(Box::new(e)));
|
||||
}
|
||||
|
||||
@@ -28,12 +28,11 @@ use mas_data_model::AuthorizationGrantStage;
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::{PostAuthAction, Route};
|
||||
use mas_storage::oauth2::{
|
||||
authorization_grant::{get_grant_by_id, give_consent_to_grant},
|
||||
consent::insert_client_consent,
|
||||
use mas_storage::{
|
||||
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
|
||||
BoxClock, BoxRepository, BoxRng,
|
||||
};
|
||||
use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates};
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
@@ -55,11 +54,13 @@ pub enum RouteError {
|
||||
|
||||
#[error("Policy violation")]
|
||||
PolicyViolation,
|
||||
|
||||
#[error("Failed to load client")]
|
||||
NoSuchClient,
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_templates::TemplateError);
|
||||
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_policy::LoadError);
|
||||
impl_from_error_for_route!(mas_policy::InstanciateError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
@@ -71,20 +72,21 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||
State(templates): State<Templates>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Path(grant_id): Path<Ulid>,
|
||||
) -> Result<Response, RouteError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let maybe_session = session_info.load_session(&mut conn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
let grant = get_grant_by_id(&mut conn, grant_id)
|
||||
let grant = repo
|
||||
.oauth2_authorization_grant()
|
||||
.lookup(grant_id)
|
||||
.await?
|
||||
.ok_or(RouteError::GrantNotFound)?;
|
||||
|
||||
@@ -93,7 +95,7 @@ pub(crate) async fn get(
|
||||
}
|
||||
|
||||
if let Some(session) = maybe_session {
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
|
||||
let mut policy = policy_factory.instantiate().await?;
|
||||
let res = policy
|
||||
@@ -124,22 +126,23 @@ pub(crate) async fn get(
|
||||
}
|
||||
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Path(grant_id): Path<Ulid>,
|
||||
Form(form): Form<ProtectedForm<()>>,
|
||||
) -> Result<Response, RouteError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
cookie_jar.verify_form(clock.now(), form)?;
|
||||
cookie_jar.verify_form(&clock, form)?;
|
||||
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let maybe_session = session_info.load_session(&mut txn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
let grant = get_grant_by_id(&mut txn, grant_id)
|
||||
let grant = repo
|
||||
.oauth2_authorization_grant()
|
||||
.lookup(grant_id)
|
||||
.await?
|
||||
.ok_or(RouteError::GrantNotFound)?;
|
||||
let next = PostAuthAction::continue_grant(grant_id);
|
||||
@@ -160,6 +163,12 @@ pub(crate) async fn post(
|
||||
return Err(RouteError::PolicyViolation);
|
||||
}
|
||||
|
||||
let client = repo
|
||||
.oauth2_client()
|
||||
.lookup(grant.client_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NoSuchClient)?;
|
||||
|
||||
// Do not consent for the "urn:matrix:org.matrix.msc2967.client:device:*" scope
|
||||
let scope_without_device = grant
|
||||
.scope
|
||||
@@ -167,19 +176,21 @@ pub(crate) async fn post(
|
||||
.filter(|s| !s.starts_with("urn:matrix:org.matrix.msc2967.client:device:"))
|
||||
.cloned()
|
||||
.collect();
|
||||
insert_client_consent(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session.user,
|
||||
&grant.client,
|
||||
&scope_without_device,
|
||||
)
|
||||
.await?;
|
||||
repo.oauth2_client()
|
||||
.give_consent_for_user(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&client,
|
||||
&session.user,
|
||||
&scope_without_device,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let _grant = give_consent_to_grant(&mut txn, grant).await?;
|
||||
repo.oauth2_authorization_grant()
|
||||
.give_consent(grant)
|
||||
.await?;
|
||||
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
Ok((cookie_jar, next.go_next()).into_response())
|
||||
}
|
||||
|
||||
@@ -22,18 +22,16 @@ use mas_data_model::{TokenFormatError, TokenType};
|
||||
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_storage::{
|
||||
compat::{lookup_active_compat_access_token, lookup_active_compat_refresh_token},
|
||||
oauth2::{
|
||||
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token,
|
||||
},
|
||||
Clock,
|
||||
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
|
||||
oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository},
|
||||
user::{BrowserSessionRepository, UserRepository},
|
||||
BoxClock, BoxRepository, Clock,
|
||||
};
|
||||
use oauth2_types::{
|
||||
errors::{ClientError, ClientErrorCode},
|
||||
requests::{IntrospectionRequest, IntrospectionResponse},
|
||||
scope::ScopeToken,
|
||||
};
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::impl_from_error_for_route;
|
||||
@@ -97,8 +95,7 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
|
||||
impl From<TokenFormatError> for RouteError {
|
||||
fn from(_e: TokenFormatError) -> Self {
|
||||
@@ -125,18 +122,17 @@ const API_SCOPE: ScopeToken = ScopeToken::from_static("urn:matrix:org.matrix.msc
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub(crate) async fn post(
|
||||
clock: BoxClock,
|
||||
State(http_client_factory): State<HttpClientFactory>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
State(encrypter): State<Encrypter>,
|
||||
client_authorization: ClientAuthorization<IntrospectionRequest>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let clock = Clock::default();
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let client = client_authorization
|
||||
.credentials
|
||||
.fetch(&mut conn)
|
||||
.await?
|
||||
.fetch(&mut repo)
|
||||
.await
|
||||
.unwrap()
|
||||
.ok_or(RouteError::ClientNotFound)?;
|
||||
|
||||
let method = match &client.token_endpoint_auth_method {
|
||||
@@ -167,48 +163,103 @@ pub(crate) async fn post(
|
||||
|
||||
let reply = match token_type {
|
||||
TokenType::AccessToken => {
|
||||
let (token, session) = lookup_active_access_token(&mut conn, token)
|
||||
let token = repo
|
||||
.oauth2_access_token()
|
||||
.find_by_token(token)
|
||||
.await?
|
||||
.filter(|t| t.is_valid(clock.now()))
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.lookup(token.session_id)
|
||||
.await?
|
||||
.filter(|s| s.is_valid())
|
||||
// XXX: is that the right error to bubble up?
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
let browser_session = repo
|
||||
.browser_session()
|
||||
.lookup(session.user_session_id)
|
||||
.await?
|
||||
// XXX: is that the right error to bubble up?
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
IntrospectionResponse {
|
||||
active: true,
|
||||
scope: Some(session.scope),
|
||||
client_id: Some(session.client.client_id),
|
||||
username: Some(session.browser_session.user.username),
|
||||
client_id: Some(session.client_id.to_string()),
|
||||
username: Some(browser_session.user.username),
|
||||
token_type: Some(OAuthTokenTypeHint::AccessToken),
|
||||
exp: Some(token.expires_at),
|
||||
iat: Some(token.created_at),
|
||||
nbf: Some(token.created_at),
|
||||
sub: Some(session.browser_session.user.sub),
|
||||
sub: Some(browser_session.user.sub),
|
||||
aud: None,
|
||||
iss: None,
|
||||
jti: None,
|
||||
jti: Some(token.jti()),
|
||||
}
|
||||
}
|
||||
|
||||
TokenType::RefreshToken => {
|
||||
let (token, session) = lookup_active_refresh_token(&mut conn, token)
|
||||
let token = repo
|
||||
.oauth2_refresh_token()
|
||||
.find_by_token(token)
|
||||
.await?
|
||||
.filter(|t| t.is_valid())
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.lookup(token.session_id)
|
||||
.await?
|
||||
.filter(|s| s.is_valid())
|
||||
// XXX: is that the right error to bubble up?
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
let browser_session = repo
|
||||
.browser_session()
|
||||
.lookup(session.user_session_id)
|
||||
.await?
|
||||
// XXX: is that the right error to bubble up?
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
IntrospectionResponse {
|
||||
active: true,
|
||||
scope: Some(session.scope),
|
||||
client_id: Some(session.client.client_id),
|
||||
username: Some(session.browser_session.user.username),
|
||||
client_id: Some(session.client_id.to_string()),
|
||||
username: Some(browser_session.user.username),
|
||||
token_type: Some(OAuthTokenTypeHint::RefreshToken),
|
||||
exp: None,
|
||||
iat: Some(token.created_at),
|
||||
nbf: Some(token.created_at),
|
||||
sub: Some(session.browser_session.user.sub),
|
||||
sub: Some(browser_session.user.sub),
|
||||
aud: None,
|
||||
iss: None,
|
||||
jti: None,
|
||||
jti: Some(token.jti()),
|
||||
}
|
||||
}
|
||||
|
||||
TokenType::CompatAccessToken => {
|
||||
let (token, session) = lookup_active_compat_access_token(&mut conn, &clock, token)
|
||||
let access_token = repo
|
||||
.compat_access_token()
|
||||
.find_by_token(token)
|
||||
.await?
|
||||
.filter(|t| t.is_valid(clock.now()))
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.lookup(access_token.session_id)
|
||||
.await?
|
||||
.filter(|s| s.is_valid())
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(session.user_id)
|
||||
.await?
|
||||
// XXX: is that the right error to bubble up?
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
let device_scope = session.device.to_scope_token();
|
||||
@@ -218,22 +269,39 @@ pub(crate) async fn post(
|
||||
active: true,
|
||||
scope: Some(scope),
|
||||
client_id: Some("legacy".into()),
|
||||
username: Some(session.user.username),
|
||||
username: Some(user.username),
|
||||
token_type: Some(OAuthTokenTypeHint::AccessToken),
|
||||
exp: token.expires_at,
|
||||
iat: Some(token.created_at),
|
||||
nbf: Some(token.created_at),
|
||||
sub: Some(session.user.sub),
|
||||
exp: access_token.expires_at,
|
||||
iat: Some(access_token.created_at),
|
||||
nbf: Some(access_token.created_at),
|
||||
sub: Some(user.sub),
|
||||
aud: None,
|
||||
iss: None,
|
||||
jti: None,
|
||||
}
|
||||
}
|
||||
|
||||
TokenType::CompatRefreshToken => {
|
||||
let (refresh_token, _access_token, session) =
|
||||
lookup_active_compat_refresh_token(&mut conn, token)
|
||||
.await?
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
let refresh_token = repo
|
||||
.compat_refresh_token()
|
||||
.find_by_token(token)
|
||||
.await?
|
||||
.filter(|t| t.is_valid())
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.lookup(refresh_token.session_id)
|
||||
.await?
|
||||
.filter(|s| s.is_valid())
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(session.user_id)
|
||||
.await?
|
||||
// XXX: is that the right error to bubble up?
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
let device_scope = session.device.to_scope_token();
|
||||
let scope = [API_SCOPE, device_scope].into_iter().collect();
|
||||
@@ -242,12 +310,12 @@ pub(crate) async fn post(
|
||||
active: true,
|
||||
scope: Some(scope),
|
||||
client_id: Some("legacy".into()),
|
||||
username: Some(session.user.username),
|
||||
username: Some(user.username),
|
||||
token_type: Some(OAuthTokenTypeHint::RefreshToken),
|
||||
exp: None,
|
||||
iat: Some(refresh_token.created_at),
|
||||
nbf: Some(refresh_token.created_at),
|
||||
sub: Some(session.user.sub),
|
||||
sub: Some(user.sub),
|
||||
aud: None,
|
||||
iss: None,
|
||||
jti: None,
|
||||
|
||||
@@ -19,7 +19,7 @@ use hyper::StatusCode;
|
||||
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_policy::{PolicyFactory, Violation};
|
||||
use mas_storage::oauth2::client::insert_client;
|
||||
use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRepository, BoxRng};
|
||||
use oauth2_types::{
|
||||
errors::{ClientError, ClientErrorCode},
|
||||
registration::{
|
||||
@@ -27,10 +27,8 @@ use oauth2_types::{
|
||||
},
|
||||
};
|
||||
use rand::distributions::{Alphanumeric, DistString};
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
use tracing::info;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::impl_from_error_for_route;
|
||||
|
||||
@@ -49,7 +47,7 @@ pub(crate) enum RouteError {
|
||||
PolicyDenied(Vec<Violation>),
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_policy::LoadError);
|
||||
impl_from_error_for_route!(mas_policy::InstanciateError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
@@ -107,12 +105,13 @@ impl IntoResponse for RouteError {
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
pub(crate) async fn post(
|
||||
State(pool): State<PgPool>,
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
mut repo: BoxRepository,
|
||||
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||
State(encrypter): State<Encrypter>,
|
||||
Json(body): Json<ClientMetadata>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
info!(?body, "Client registration");
|
||||
|
||||
// Validate the body
|
||||
@@ -124,16 +123,6 @@ pub(crate) async fn post(
|
||||
return Err(RouteError::PolicyDenied(res.violations));
|
||||
}
|
||||
|
||||
// Contacts was checked by the policy
|
||||
let contacts = metadata.contacts.as_deref().unwrap_or_default();
|
||||
|
||||
// Grab a txn
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let now = clock.now();
|
||||
// Let's generate a random client ID
|
||||
let client_id = Ulid::from_datetime_with_source(now.into(), &mut rng);
|
||||
|
||||
let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method {
|
||||
Some(
|
||||
OAuthClientAuthenticationMethod::ClientSecretJwt
|
||||
@@ -148,41 +137,42 @@ pub(crate) async fn post(
|
||||
_ => (None, None),
|
||||
};
|
||||
|
||||
insert_client(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
client_id,
|
||||
metadata.redirect_uris(),
|
||||
encrypted_client_secret.as_deref(),
|
||||
//&metadata.response_types(),
|
||||
metadata.grant_types(),
|
||||
contacts,
|
||||
metadata
|
||||
.client_name
|
||||
.as_ref()
|
||||
.map(|l| l.non_localized().as_ref()),
|
||||
metadata.logo_uri.as_ref().map(Localized::non_localized),
|
||||
metadata.client_uri.as_ref().map(Localized::non_localized),
|
||||
metadata.policy_uri.as_ref().map(Localized::non_localized),
|
||||
metadata.tos_uri.as_ref().map(Localized::non_localized),
|
||||
metadata.jwks_uri.as_ref(),
|
||||
metadata.jwks.as_ref(),
|
||||
// XXX: those might not be right, should be function calls
|
||||
metadata.id_token_signed_response_alg.as_ref(),
|
||||
metadata.userinfo_signed_response_alg.as_ref(),
|
||||
metadata.token_endpoint_auth_method.as_ref(),
|
||||
metadata.token_endpoint_auth_signing_alg.as_ref(),
|
||||
metadata.initiate_login_uri.as_ref(),
|
||||
)
|
||||
.await?;
|
||||
let client = repo
|
||||
.oauth2_client()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
metadata.redirect_uris().to_vec(),
|
||||
encrypted_client_secret,
|
||||
//&metadata.response_types(),
|
||||
metadata.grant_types().to_vec(),
|
||||
metadata.contacts.clone().unwrap_or_default(),
|
||||
metadata
|
||||
.client_name
|
||||
.clone()
|
||||
.map(Localized::to_non_localized),
|
||||
metadata.logo_uri.clone().map(Localized::to_non_localized),
|
||||
metadata.client_uri.clone().map(Localized::to_non_localized),
|
||||
metadata.policy_uri.clone().map(Localized::to_non_localized),
|
||||
metadata.tos_uri.clone().map(Localized::to_non_localized),
|
||||
metadata.jwks_uri.clone(),
|
||||
metadata.jwks.clone(),
|
||||
// XXX: those might not be right, should be function calls
|
||||
metadata.id_token_signed_response_alg.clone(),
|
||||
metadata.userinfo_signed_response_alg.clone(),
|
||||
metadata.token_endpoint_auth_method.clone(),
|
||||
metadata.token_endpoint_auth_signing_alg.clone(),
|
||||
metadata.initiate_login_uri.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
let response = ClientRegistrationResponse {
|
||||
client_id: client_id.to_string(),
|
||||
client_id: client.client_id,
|
||||
client_secret,
|
||||
client_id_issued_at: Some(now),
|
||||
// XXX: we should have a `created_at` field on the clients
|
||||
client_id_issued_at: Some(client.id.datetime().into()),
|
||||
client_secret_expires_at: None,
|
||||
};
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -31,11 +31,13 @@ use mas_jose::{
|
||||
};
|
||||
use mas_keystore::{Encrypter, Keystore};
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::oauth2::{
|
||||
access_token::{add_access_token, revoke_access_token},
|
||||
authorization_grant::{exchange_grant, lookup_grant_by_code},
|
||||
end_oauth_session,
|
||||
refresh_token::{add_refresh_token, consume_refresh_token, lookup_active_refresh_token},
|
||||
use mas_storage::{
|
||||
oauth2::{
|
||||
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository,
|
||||
OAuth2RefreshTokenRepository, OAuth2SessionRepository,
|
||||
},
|
||||
user::BrowserSessionRepository,
|
||||
BoxClock, BoxRepository, BoxRng, Clock,
|
||||
};
|
||||
use oauth2_types::{
|
||||
errors::{ClientError, ClientErrorCode},
|
||||
@@ -47,7 +49,6 @@ use oauth2_types::{
|
||||
};
|
||||
use serde::Serialize;
|
||||
use serde_with::{serde_as, skip_serializing_none};
|
||||
use sqlx::{PgPool, Postgres, Transaction};
|
||||
use thiserror::Error;
|
||||
use tracing::debug;
|
||||
use url::Url;
|
||||
@@ -102,12 +103,21 @@ pub(crate) enum RouteError {
|
||||
|
||||
#[error("no suitable key found for signing")]
|
||||
InvalidSigningKey,
|
||||
|
||||
#[error("failed to load browser session")]
|
||||
NoSuchBrowserSession,
|
||||
|
||||
#[error("failed to load oauth session")]
|
||||
NoSuchOAuthSession,
|
||||
}
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
Self::Internal(_) | Self::InvalidSigningKey => (
|
||||
Self::Internal(_)
|
||||
| Self::InvalidSigningKey
|
||||
| Self::NoSuchBrowserSession
|
||||
| Self::NoSuchOAuthSession => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ClientError::from(ClientErrorCode::ServerError)),
|
||||
),
|
||||
@@ -139,8 +149,7 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_keystore::WrongAlgorithmError);
|
||||
impl_from_error_for_route!(mas_jose::claims::ClaimError);
|
||||
impl_from_error_for_route!(mas_jose::claims::TokenHashError);
|
||||
@@ -148,18 +157,18 @@ impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError);
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(http_client_factory): State<HttpClientFactory>,
|
||||
State(key_store): State<Keystore>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
State(encrypter): State<Encrypter>,
|
||||
client_authorization: ClientAuthorization<AccessTokenRequest>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let client = client_authorization
|
||||
.credentials
|
||||
.fetch(&mut txn)
|
||||
.fetch(&mut repo)
|
||||
.await?
|
||||
.ok_or(RouteError::ClientNotFound)?;
|
||||
|
||||
@@ -175,18 +184,29 @@ pub(crate) async fn post(
|
||||
|
||||
let form = client_authorization.form.ok_or(RouteError::BadRequest)?;
|
||||
|
||||
let reply = match form {
|
||||
let (reply, repo) = match form {
|
||||
AccessTokenRequest::AuthorizationCode(grant) => {
|
||||
authorization_code_grant(&grant, &client, &key_store, &url_builder, txn).await?
|
||||
authorization_code_grant(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&grant,
|
||||
&client,
|
||||
&key_store,
|
||||
&url_builder,
|
||||
repo,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
AccessTokenRequest::RefreshToken(grant) => {
|
||||
refresh_token_grant(&grant, &client, txn).await?
|
||||
refresh_token_grant(&mut rng, &clock, &grant, &client, repo).await?
|
||||
}
|
||||
_ => {
|
||||
return Err(RouteError::InvalidGrant);
|
||||
}
|
||||
};
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.typed_insert(CacheControl::new().with_no_store());
|
||||
headers.typed_insert(Pragma::no_cache());
|
||||
@@ -196,23 +216,23 @@ pub(crate) async fn post(
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn authorization_code_grant(
|
||||
mut rng: &mut BoxRng,
|
||||
clock: &impl Clock,
|
||||
grant: &AuthorizationCodeGrant,
|
||||
client: &Client,
|
||||
key_store: &Keystore,
|
||||
url_builder: &UrlBuilder,
|
||||
mut txn: Transaction<'_, Postgres>,
|
||||
) -> Result<AccessTokenResponse, RouteError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
|
||||
// TODO: there is a bunch of unnecessary cloning here
|
||||
// TODO: handle "not found" cases
|
||||
let authz_grant = lookup_grant_by_code(&mut txn, &grant.code)
|
||||
mut repo: BoxRepository,
|
||||
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
|
||||
let authz_grant = repo
|
||||
.oauth2_authorization_grant()
|
||||
.find_by_code(&grant.code)
|
||||
.await?
|
||||
.ok_or(RouteError::GrantNotFound)?;
|
||||
|
||||
let now = clock.now();
|
||||
|
||||
let session = match authz_grant.stage {
|
||||
let session_id = match authz_grant.stage {
|
||||
AuthorizationGrantStage::Cancelled { cancelled_at } => {
|
||||
debug!(%cancelled_at, "Authorization grant was cancelled");
|
||||
return Err(RouteError::InvalidGrant);
|
||||
@@ -220,15 +240,20 @@ async fn authorization_code_grant(
|
||||
AuthorizationGrantStage::Exchanged {
|
||||
exchanged_at,
|
||||
fulfilled_at,
|
||||
session,
|
||||
session_id,
|
||||
} => {
|
||||
debug!(%exchanged_at, %fulfilled_at, "Authorization code was already exchanged");
|
||||
|
||||
// Ending the session if the token was already exchanged more than 20s ago
|
||||
if now - exchanged_at > Duration::seconds(20) {
|
||||
debug!("Ending potentially compromised session");
|
||||
end_oauth_session(&mut txn, &clock, session).await?;
|
||||
txn.commit().await?;
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.lookup(session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NoSuchOAuthSession)?;
|
||||
repo.oauth2_session().finish(clock, session).await?;
|
||||
repo.save().await?;
|
||||
}
|
||||
|
||||
return Err(RouteError::InvalidGrant);
|
||||
@@ -238,7 +263,7 @@ async fn authorization_code_grant(
|
||||
return Err(RouteError::InvalidGrant);
|
||||
}
|
||||
AuthorizationGrantStage::Fulfilled {
|
||||
ref session,
|
||||
session_id,
|
||||
fulfilled_at,
|
||||
} => {
|
||||
if now - fulfilled_at > Duration::minutes(10) {
|
||||
@@ -246,14 +271,20 @@ async fn authorization_code_grant(
|
||||
return Err(RouteError::InvalidGrant);
|
||||
}
|
||||
|
||||
session
|
||||
session_id
|
||||
}
|
||||
};
|
||||
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.lookup(session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NoSuchOAuthSession)?;
|
||||
|
||||
// This should never happen, since we looked up in the database using the code
|
||||
let code = authz_grant.code.as_ref().ok_or(RouteError::InvalidGrant)?;
|
||||
|
||||
if client.client_id != session.client.client_id {
|
||||
if client.id != session.client_id {
|
||||
return Err(RouteError::UnauthorizedClient);
|
||||
}
|
||||
|
||||
@@ -267,31 +298,25 @@ async fn authorization_code_grant(
|
||||
}
|
||||
};
|
||||
|
||||
let browser_session = &session.browser_session;
|
||||
let browser_session = repo
|
||||
.browser_session()
|
||||
.lookup(session.user_session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NoSuchBrowserSession)?;
|
||||
|
||||
let ttl = Duration::minutes(5);
|
||||
let access_token_str = TokenType::AccessToken.generate(&mut rng);
|
||||
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
|
||||
|
||||
let access_token = add_access_token(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
session,
|
||||
access_token_str.clone(),
|
||||
ttl,
|
||||
)
|
||||
.await?;
|
||||
let access_token = repo
|
||||
.oauth2_access_token()
|
||||
.add(&mut rng, clock, &session, access_token_str, ttl)
|
||||
.await?;
|
||||
|
||||
let _refresh_token = add_refresh_token(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
session,
|
||||
access_token,
|
||||
refresh_token_str.clone(),
|
||||
)
|
||||
.await?;
|
||||
let refresh_token = repo
|
||||
.oauth2_refresh_token()
|
||||
.add(&mut rng, clock, &session, &access_token, refresh_token_str)
|
||||
.await?;
|
||||
|
||||
let id_token = if session.scope.contains(&scope::OPENID) {
|
||||
let mut claims = HashMap::new();
|
||||
@@ -317,7 +342,7 @@ async fn authorization_code_grant(
|
||||
.signing_key_for_algorithm(&alg)
|
||||
.ok_or(RouteError::InvalidSigningKey)?;
|
||||
|
||||
claims::AT_HASH.insert(&mut claims, hash_token(&alg, &access_token_str)?)?;
|
||||
claims::AT_HASH.insert(&mut claims, hash_token(&alg, &access_token.access_token)?)?;
|
||||
claims::C_HASH.insert(&mut claims, hash_token(&alg, &grant.code)?)?;
|
||||
|
||||
let signer = key.params().signing_key_for_alg(&alg)?;
|
||||
@@ -330,34 +355,46 @@ async fn authorization_code_grant(
|
||||
None
|
||||
};
|
||||
|
||||
let mut params = AccessTokenResponse::new(access_token_str)
|
||||
let mut params = AccessTokenResponse::new(access_token.access_token)
|
||||
.with_expires_in(ttl)
|
||||
.with_refresh_token(refresh_token_str)
|
||||
.with_refresh_token(refresh_token.refresh_token)
|
||||
.with_scope(session.scope.clone());
|
||||
|
||||
if let Some(id_token) = id_token {
|
||||
params = params.with_id_token(id_token);
|
||||
}
|
||||
|
||||
exchange_grant(&mut txn, &clock, authz_grant).await?;
|
||||
repo.oauth2_authorization_grant()
|
||||
.exchange(clock, authz_grant)
|
||||
.await?;
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
Ok(params)
|
||||
Ok((params, repo))
|
||||
}
|
||||
|
||||
async fn refresh_token_grant(
|
||||
mut rng: &mut BoxRng,
|
||||
clock: &impl Clock,
|
||||
grant: &RefreshTokenGrant,
|
||||
client: &Client,
|
||||
mut txn: Transaction<'_, Postgres>,
|
||||
) -> Result<AccessTokenResponse, RouteError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
|
||||
let (refresh_token, session) = lookup_active_refresh_token(&mut txn, &grant.refresh_token)
|
||||
mut repo: BoxRepository,
|
||||
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
|
||||
let refresh_token = repo
|
||||
.oauth2_refresh_token()
|
||||
.find_by_token(&grant.refresh_token)
|
||||
.await?
|
||||
.ok_or(RouteError::InvalidGrant)?;
|
||||
|
||||
if client.client_id != session.client.client_id {
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.lookup(refresh_token.session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NoSuchOAuthSession)?;
|
||||
|
||||
if !refresh_token.is_valid() || !session.is_valid() {
|
||||
return Err(RouteError::InvalidGrant);
|
||||
}
|
||||
|
||||
if client.id != session.client_id {
|
||||
// As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
|
||||
return Err(RouteError::InvalidGrant);
|
||||
}
|
||||
@@ -366,30 +403,34 @@ async fn refresh_token_grant(
|
||||
let access_token_str = TokenType::AccessToken.generate(&mut rng);
|
||||
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
|
||||
|
||||
let new_access_token = add_access_token(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
access_token_str.clone(),
|
||||
ttl,
|
||||
)
|
||||
.await?;
|
||||
let new_access_token = repo
|
||||
.oauth2_access_token()
|
||||
.add(&mut rng, clock, &session, access_token_str.clone(), ttl)
|
||||
.await?;
|
||||
|
||||
let new_refresh_token = add_refresh_token(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
new_access_token,
|
||||
refresh_token_str,
|
||||
)
|
||||
.await?;
|
||||
let new_refresh_token = repo
|
||||
.oauth2_refresh_token()
|
||||
.add(
|
||||
&mut rng,
|
||||
clock,
|
||||
&session,
|
||||
&new_access_token,
|
||||
refresh_token_str,
|
||||
)
|
||||
.await?;
|
||||
|
||||
consume_refresh_token(&mut txn, &clock, &refresh_token).await?;
|
||||
let refresh_token = repo
|
||||
.oauth2_refresh_token()
|
||||
.consume(clock, refresh_token)
|
||||
.await?;
|
||||
|
||||
if let Some(access_token) = refresh_token.access_token {
|
||||
revoke_access_token(&mut txn, &clock, access_token).await?;
|
||||
if let Some(access_token_id) = refresh_token.access_token_id {
|
||||
let access_token = repo.oauth2_access_token().lookup(access_token_id).await?;
|
||||
if let Some(access_token) = access_token {
|
||||
repo.oauth2_access_token()
|
||||
.revoke(clock, access_token)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
let params = AccessTokenResponse::new(access_token_str)
|
||||
@@ -397,7 +438,5 @@ async fn refresh_token_grant(
|
||||
.with_refresh_token(new_refresh_token.refresh_token)
|
||||
.with_scope(session.scope);
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
Ok(params)
|
||||
Ok((params, repo))
|
||||
}
|
||||
|
||||
@@ -28,10 +28,14 @@ use mas_jose::{
|
||||
};
|
||||
use mas_keystore::Keystore;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::{
|
||||
oauth2::OAuth2ClientRepository,
|
||||
user::{BrowserSessionRepository, UserEmailRepository},
|
||||
BoxClock, BoxRepository, BoxRng,
|
||||
};
|
||||
use oauth2_types::scope;
|
||||
use serde::Serialize;
|
||||
use serde_with::skip_serializing_none;
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::impl_from_error_for_route;
|
||||
@@ -59,20 +63,31 @@ pub enum RouteError {
|
||||
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
|
||||
#[error("failed to authenticate")]
|
||||
AuthorizationVerificationError(#[from] AuthorizationVerificationError),
|
||||
AuthorizationVerificationError(
|
||||
#[from] AuthorizationVerificationError<mas_storage::RepositoryError>,
|
||||
),
|
||||
|
||||
#[error("no suitable key found for signing")]
|
||||
InvalidSigningKey,
|
||||
|
||||
#[error("failed to load client")]
|
||||
NoSuchClient,
|
||||
|
||||
#[error("failed to load browser session")]
|
||||
NoSuchBrowserSession,
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_keystore::WrongAlgorithmError);
|
||||
impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
Self::Internal(_) | Self::InvalidSigningKey => {
|
||||
Self::Internal(_)
|
||||
| Self::InvalidSigningKey
|
||||
| Self::NoSuchClient
|
||||
| Self::NoSuchBrowserSession => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
|
||||
}
|
||||
Self::AuthorizationVerificationError(_e) => StatusCode::UNAUTHORIZED.into_response(),
|
||||
@@ -81,32 +96,43 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
|
||||
pub async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
State(key_store): State<Keystore>,
|
||||
user_authorization: UserAuthorization,
|
||||
) -> Result<Response, RouteError> {
|
||||
let (_clock, mut rng) = crate::clock_and_rng();
|
||||
let mut conn = pool.acquire().await?;
|
||||
let session = user_authorization.protected(&mut repo, &clock).await?;
|
||||
|
||||
let session = user_authorization.protected(&mut conn).await?;
|
||||
let browser_session = repo
|
||||
.browser_session()
|
||||
.lookup(session.user_session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NoSuchBrowserSession)?;
|
||||
|
||||
let user = session.browser_session.user;
|
||||
let mut user_info = UserInfo {
|
||||
sub: user.sub,
|
||||
username: user.username,
|
||||
email: None,
|
||||
email_verified: None,
|
||||
let user = browser_session.user;
|
||||
|
||||
let user_email = if session.scope.contains(&scope::EMAIL) {
|
||||
repo.user_email().get_primary(&user).await?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if session.scope.contains(&scope::EMAIL) {
|
||||
if let Some(email) = user.primary_email {
|
||||
user_info.email_verified = Some(email.confirmed_at.is_some());
|
||||
user_info.email = Some(email.email);
|
||||
}
|
||||
}
|
||||
let user_info = UserInfo {
|
||||
sub: user.sub.clone(),
|
||||
username: user.username.clone(),
|
||||
email_verified: user_email.as_ref().map(|u| u.confirmed_at.is_some()),
|
||||
email: user_email.map(|u| u.email),
|
||||
};
|
||||
|
||||
if let Some(alg) = session.client.userinfo_signed_response_alg {
|
||||
let client = repo
|
||||
.oauth2_client()
|
||||
.lookup(session.client_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NoSuchClient)?;
|
||||
|
||||
if let Some(alg) = client.userinfo_signed_response_alg {
|
||||
let key = key_store
|
||||
.signing_key_for_algorithm(&alg)
|
||||
.ok_or(RouteError::InvalidSigningKey)?;
|
||||
@@ -117,7 +143,7 @@ pub async fn get(
|
||||
|
||||
let user_info = SignedUserInfo {
|
||||
iss: url_builder.oidc_issuer().to_string(),
|
||||
aud: session.client.client_id,
|
||||
aud: client.client_id,
|
||||
user_info,
|
||||
};
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ impl PasswordManager {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the hashing failed
|
||||
#[tracing::instrument(skip_all)]
|
||||
#[tracing::instrument(name = "passwords.hash", skip_all)]
|
||||
pub async fn hash<R: CryptoRng + RngCore + Send>(
|
||||
&self,
|
||||
rng: R,
|
||||
@@ -82,13 +82,16 @@ impl PasswordManager {
|
||||
let rng = rand_chacha::ChaChaRng::from_rng(rng)?;
|
||||
let hashers = self.hashers.clone();
|
||||
let default_hasher_version = self.default_hasher;
|
||||
let span = tracing::Span::current();
|
||||
|
||||
let hashed = tokio::task::spawn_blocking(move || {
|
||||
let default_hasher = hashers
|
||||
.get(&default_hasher_version)
|
||||
.context("Default hasher not found")?;
|
||||
span.in_scope(move || {
|
||||
let default_hasher = hashers
|
||||
.get(&default_hasher_version)
|
||||
.context("Default hasher not found")?;
|
||||
|
||||
default_hasher.hash_blocking(rng, &password)
|
||||
default_hasher.hash_blocking(rng, &password)
|
||||
})
|
||||
})
|
||||
.await??;
|
||||
|
||||
@@ -100,7 +103,7 @@ impl PasswordManager {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the password hash verification failed
|
||||
#[tracing::instrument(skip_all, fields(%scheme))]
|
||||
#[tracing::instrument(name = "passwords.verify", skip_all, fields(%scheme))]
|
||||
pub async fn verify(
|
||||
&self,
|
||||
scheme: SchemeVersion,
|
||||
@@ -108,10 +111,13 @@ impl PasswordManager {
|
||||
hashed_password: String,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let hashers = self.hashers.clone();
|
||||
let span = tracing::Span::current();
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let hasher = hashers.get(&scheme).context("Hashing scheme not found")?;
|
||||
hasher.verify_blocking(&hashed_password, &password)
|
||||
span.in_scope(move || {
|
||||
let hasher = hashers.get(&scheme).context("Hashing scheme not found")?;
|
||||
hasher.verify_blocking(&hashed_password, &password)
|
||||
})
|
||||
})
|
||||
.await??;
|
||||
|
||||
@@ -124,7 +130,7 @@ impl PasswordManager {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the password hash verification failed
|
||||
#[tracing::instrument(skip_all, fields(%scheme))]
|
||||
#[tracing::instrument(name = "passwords.verify_and_upgrade", skip_all, fields(%scheme))]
|
||||
pub async fn verify_and_upgrade<R: CryptoRng + RngCore + Send>(
|
||||
&self,
|
||||
rng: R,
|
||||
|
||||
@@ -22,8 +22,10 @@ use mas_axum_utils::http_client_factory::HttpClientFactory;
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::upstream_oauth2::lookup_provider;
|
||||
use sqlx::PgPool;
|
||||
use mas_storage::{
|
||||
upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository},
|
||||
BoxClock, BoxRepository, BoxRng,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
@@ -39,11 +41,10 @@ pub(crate) enum RouteError {
|
||||
Internal(Box<dyn std::error::Error>),
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_http::ClientInitError);
|
||||
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
|
||||
impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError);
|
||||
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
@@ -55,18 +56,18 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(http_client_factory): State<HttpClientFactory>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Path(provider_id): Path<Ulid>,
|
||||
Query(query): Query<OptionalPostAuthAction>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let provider = lookup_provider(&mut txn, provider_id)
|
||||
let provider = repo
|
||||
.upstream_oauth_provider()
|
||||
.lookup(provider_id)
|
||||
.await?
|
||||
.ok_or(RouteError::ProviderNotFound)?;
|
||||
|
||||
@@ -95,22 +96,23 @@ pub(crate) async fn get(
|
||||
&mut rng,
|
||||
)?;
|
||||
|
||||
let session = mas_storage::upstream_oauth2::add_session(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&provider,
|
||||
data.state.clone(),
|
||||
data.code_challenge_verifier,
|
||||
data.nonce,
|
||||
)
|
||||
.await?;
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&provider,
|
||||
data.state.clone(),
|
||||
data.code_challenge_verifier,
|
||||
data.nonce,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar)
|
||||
.add(session.id, provider.id, data.state, query.post_auth_action)
|
||||
.save(cookie_jar, clock.now());
|
||||
.save(cookie_jar, &clock);
|
||||
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
Ok((cookie_jar, Redirect::temporary(url.as_str())))
|
||||
}
|
||||
|
||||
@@ -25,12 +25,15 @@ use mas_oidc_client::requests::{
|
||||
authorization_code::AuthorizationValidationData, jose::JwtVerificationData,
|
||||
};
|
||||
use mas_router::{Route, UrlBuilder};
|
||||
use mas_storage::upstream_oauth2::{
|
||||
add_link, complete_session, lookup_link_by_subject, lookup_session,
|
||||
use mas_storage::{
|
||||
upstream_oauth2::{
|
||||
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
|
||||
UpstreamOAuthSessionRepository,
|
||||
},
|
||||
BoxClock, BoxRepository, BoxRng, Clock,
|
||||
};
|
||||
use oauth2_types::errors::ClientErrorCode;
|
||||
use serde::Deserialize;
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
@@ -64,6 +67,9 @@ pub(crate) enum RouteError {
|
||||
#[error("Session not found")]
|
||||
SessionNotFound,
|
||||
|
||||
#[error("Provider not found")]
|
||||
ProviderNotFound,
|
||||
|
||||
#[error("Provider mismatch")]
|
||||
ProviderMismatch,
|
||||
|
||||
@@ -92,9 +98,8 @@ pub(crate) enum RouteError {
|
||||
Internal(Box<dyn std::error::Error>),
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_http::ClientInitError);
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
|
||||
impl_from_error_for_route!(mas_oidc_client::error::JwksError);
|
||||
impl_from_error_for_route!(mas_oidc_client::error::TokenAuthorizationCodeError);
|
||||
@@ -104,6 +109,7 @@ impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
|
||||
Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
|
||||
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
|
||||
e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
|
||||
@@ -113,8 +119,10 @@ impl IntoResponse for RouteError {
|
||||
|
||||
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(http_client_factory): State<HttpClientFactory>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
State(encrypter): State<Encrypter>,
|
||||
State(keystore): State<Keystore>,
|
||||
@@ -122,30 +130,34 @@ pub(crate) async fn get(
|
||||
Path(provider_id): Path<Ulid>,
|
||||
Query(params): Query<QueryParams>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
|
||||
let mut txn = pool.begin().await?;
|
||||
let provider = repo
|
||||
.upstream_oauth_provider()
|
||||
.lookup(provider_id)
|
||||
.await?
|
||||
.ok_or(RouteError::ProviderNotFound)?;
|
||||
|
||||
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
||||
let (session_id, _post_auth_action) = sessions_cookie
|
||||
.find_session(provider_id, ¶ms.state)
|
||||
.map_err(|_| RouteError::MissingCookie)?;
|
||||
|
||||
let (provider, session) = lookup_session(&mut txn, session_id)
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.lookup(session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::SessionNotFound)?;
|
||||
|
||||
if provider.id != provider_id {
|
||||
if provider.id != session.provider_id {
|
||||
// The provider in the session cookie should match the one from the URL
|
||||
return Err(RouteError::ProviderMismatch);
|
||||
}
|
||||
|
||||
if params.state != session.state {
|
||||
if params.state != session.state_str {
|
||||
// The state in the session cookie should match the one from the params
|
||||
return Err(RouteError::StateMismatch);
|
||||
}
|
||||
|
||||
if session.completed() {
|
||||
if !session.is_pending() {
|
||||
// The session was already completed
|
||||
return Err(RouteError::AlreadyCompleted);
|
||||
}
|
||||
@@ -194,7 +206,7 @@ pub(crate) async fn get(
|
||||
|
||||
// TODO: all that should be borrowed
|
||||
let validation_data = AuthorizationValidationData {
|
||||
state: session.state.clone(),
|
||||
state: session.state_str.clone(),
|
||||
nonce: session.nonce.clone(),
|
||||
code_challenge_verifier: session.code_challenge_verifier.clone(),
|
||||
redirect_uri,
|
||||
@@ -231,20 +243,29 @@ pub(crate) async fn get(
|
||||
let subject = mas_jose::claims::SUB.extract_required(&mut id_token)?;
|
||||
|
||||
// Look for an existing link
|
||||
let maybe_link = lookup_link_by_subject(&mut txn, &provider, &subject).await?;
|
||||
let maybe_link = repo
|
||||
.upstream_oauth_link()
|
||||
.find_by_subject(&provider, &subject)
|
||||
.await?;
|
||||
|
||||
let link = if let Some(link) = maybe_link {
|
||||
link
|
||||
} else {
|
||||
add_link(&mut txn, &mut rng, &clock, &provider, subject).await?
|
||||
repo.upstream_oauth_link()
|
||||
.add(&mut rng, &clock, &provider, subject)
|
||||
.await?
|
||||
};
|
||||
|
||||
let session = complete_session(&mut txn, &clock, session, &link, response.id_token).await?;
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.complete_with_link(&clock, session, &link, response.id_token)
|
||||
.await?;
|
||||
|
||||
let cookie_jar = sessions_cookie
|
||||
.add_link_to_session(session.id, link.id)?
|
||||
.save(cookie_jar, clock.now());
|
||||
.save(cookie_jar, &clock);
|
||||
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
Ok((
|
||||
cookie_jar,
|
||||
|
||||
@@ -18,6 +18,7 @@ use axum_extra::extract::{cookie::Cookie, PrivateCookieJar};
|
||||
use chrono::{DateTime, Duration, NaiveDateTime, Utc};
|
||||
use mas_axum_utils::CookieExt;
|
||||
use mas_router::PostAuthAction;
|
||||
use mas_storage::Clock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use time::OffsetDateTime;
|
||||
@@ -65,11 +66,11 @@ impl UpstreamSessions {
|
||||
}
|
||||
|
||||
/// Save the upstreams sessions to the cookie jar
|
||||
pub fn save<K>(
|
||||
self,
|
||||
cookie_jar: PrivateCookieJar<K>,
|
||||
now: DateTime<Utc>,
|
||||
) -> PrivateCookieJar<K> {
|
||||
pub fn save<K, C>(self, cookie_jar: PrivateCookieJar<K>, clock: &C) -> PrivateCookieJar<K>
|
||||
where
|
||||
C: Clock,
|
||||
{
|
||||
let now = clock.now();
|
||||
let this = self.expire(now);
|
||||
let mut cookie = Cookie::named(COOKIE_NAME).encode(&this);
|
||||
cookie.set_path("/");
|
||||
|
||||
@@ -25,17 +25,15 @@ use mas_axum_utils::{
|
||||
};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_storage::{
|
||||
upstream_oauth2::{
|
||||
associate_link_to_user, consume_session, lookup_link, lookup_session_on_link,
|
||||
},
|
||||
user::{add_user, authenticate_session_with_upstream, lookup_user, start_session},
|
||||
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
|
||||
user::{BrowserSessionRepository, UserRepository},
|
||||
BoxClock, BoxRepository, BoxRng,
|
||||
};
|
||||
use mas_templates::{
|
||||
EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister,
|
||||
UpstreamSuggestLink,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
@@ -52,6 +50,10 @@ pub(crate) enum RouteError {
|
||||
#[error("Session not found")]
|
||||
SessionNotFound,
|
||||
|
||||
/// Couldn't find the user
|
||||
#[error("User not found")]
|
||||
UserNotFound,
|
||||
|
||||
/// Session was already consumed
|
||||
#[error("Session already consumed")]
|
||||
SessionConsumed,
|
||||
@@ -66,11 +68,10 @@ pub(crate) enum RouteError {
|
||||
Internal(Box<dyn std::error::Error>),
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_templates::TemplateError);
|
||||
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
|
||||
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
|
||||
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
@@ -91,48 +92,60 @@ pub(crate) enum FormData {
|
||||
}
|
||||
|
||||
pub(crate) async fn get(
|
||||
State(pool): State<PgPool>,
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
mut repo: BoxRepository,
|
||||
State(templates): State<Templates>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Path(link_id): Path<Ulid>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let mut txn = pool.begin().await?;
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
|
||||
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
||||
let (session_id, _post_auth_action) = sessions_cookie
|
||||
.lookup_link(link_id)
|
||||
.map_err(|_| RouteError::MissingCookie)?;
|
||||
|
||||
let link = lookup_link(&mut txn, link_id)
|
||||
let link = repo
|
||||
.upstream_oauth_link()
|
||||
.lookup(link_id)
|
||||
.await?
|
||||
.ok_or(RouteError::LinkNotFound)?;
|
||||
|
||||
// This checks that we're in a browser session which is allowed to consume this
|
||||
// link: the upstream auth session should have been started in this browser.
|
||||
let upstream_session = lookup_session_on_link(&mut txn, &link, session_id)
|
||||
let upstream_session = repo
|
||||
.upstream_oauth_session()
|
||||
.lookup(session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::SessionNotFound)?;
|
||||
|
||||
if upstream_session.consumed() {
|
||||
// This checks that we're in a browser session which is allowed to consume this
|
||||
// link: the upstream auth session should have been started in this browser.
|
||||
if upstream_session.link_id() != Some(link.id) {
|
||||
return Err(RouteError::SessionNotFound);
|
||||
}
|
||||
|
||||
if upstream_session.is_consumed() {
|
||||
return Err(RouteError::SessionConsumed);
|
||||
}
|
||||
|
||||
let (user_session_info, cookie_jar) = cookie_jar.session_info();
|
||||
let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
|
||||
let maybe_user_session = user_session_info.load_session(&mut txn).await?;
|
||||
let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
let maybe_user_session = user_session_info.load_session(&mut repo).await?;
|
||||
|
||||
let render = match (maybe_user_session, link.user_id) {
|
||||
(Some(mut session), Some(user_id)) if session.user.id == user_id => {
|
||||
(Some(session), Some(user_id)) if session.user.id == user_id => {
|
||||
// Session already linked, and link matches the currently logged
|
||||
// user. Mark the session as consumed and renew the authentication.
|
||||
consume_session(&mut txn, &clock, upstream_session).await?;
|
||||
authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link)
|
||||
repo.upstream_oauth_session()
|
||||
.consume(&clock, upstream_session)
|
||||
.await?;
|
||||
|
||||
let session = repo
|
||||
.browser_session()
|
||||
.authenticate_with_upstream(&mut rng, &clock, session, &link)
|
||||
.await?;
|
||||
|
||||
cookie_jar = cookie_jar.set_session(&session);
|
||||
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
let ctx = EmptyContext
|
||||
.with_session(session)
|
||||
@@ -147,7 +160,11 @@ pub(crate) async fn get(
|
||||
// Session already linked, but link doesn't match the currently
|
||||
// logged user. Suggest logging out of the current user
|
||||
// and logging in with the new one
|
||||
let user = lookup_user(&mut txn, user_id).await?;
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(user_id)
|
||||
.await?
|
||||
.ok_or(RouteError::UserNotFound)?;
|
||||
|
||||
let ctx = UpstreamExistingLinkContext::new(user)
|
||||
.with_session(user_session)
|
||||
@@ -167,7 +184,11 @@ pub(crate) async fn get(
|
||||
|
||||
(None, Some(user_id)) => {
|
||||
// Session linked, but user not logged in: do the login
|
||||
let user = lookup_user(&mut txn, user_id).await?;
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(user_id)
|
||||
.await?
|
||||
.ok_or(RouteError::UserNotFound)?;
|
||||
|
||||
let ctx = UpstreamExistingLinkContext::new(user).with_csrf(csrf_token.form_value());
|
||||
|
||||
@@ -187,14 +208,14 @@ pub(crate) async fn get(
|
||||
}
|
||||
|
||||
pub(crate) async fn post(
|
||||
State(pool): State<PgPool>,
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
mut repo: BoxRepository,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Path(link_id): Path<Ulid>,
|
||||
Form(form): Form<ProtectedForm<FormData>>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let mut txn = pool.begin().await?;
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let form = cookie_jar.verify_form(clock.now(), form)?;
|
||||
let form = cookie_jar.verify_form(&clock, form)?;
|
||||
|
||||
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
||||
let (session_id, post_auth_action) = sessions_cookie
|
||||
@@ -205,53 +226,77 @@ pub(crate) async fn post(
|
||||
post_auth_action: post_auth_action.cloned(),
|
||||
};
|
||||
|
||||
let link = lookup_link(&mut txn, link_id)
|
||||
let link = repo
|
||||
.upstream_oauth_link()
|
||||
.lookup(link_id)
|
||||
.await?
|
||||
.ok_or(RouteError::LinkNotFound)?;
|
||||
|
||||
// This checks that we're in a browser session which is allowed to consume this
|
||||
// link: the upstream auth session should have been started in this browser.
|
||||
let upstream_session = lookup_session_on_link(&mut txn, &link, session_id)
|
||||
let upstream_session = repo
|
||||
.upstream_oauth_session()
|
||||
.lookup(session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::SessionNotFound)?;
|
||||
|
||||
if upstream_session.consumed() {
|
||||
// This checks that we're in a browser session which is allowed to consume this
|
||||
// link: the upstream auth session should have been started in this browser.
|
||||
if upstream_session.link_id() != Some(link.id) {
|
||||
return Err(RouteError::SessionNotFound);
|
||||
}
|
||||
|
||||
if upstream_session.is_consumed() {
|
||||
return Err(RouteError::SessionConsumed);
|
||||
}
|
||||
|
||||
let (user_session_info, cookie_jar) = cookie_jar.session_info();
|
||||
let maybe_user_session = user_session_info.load_session(&mut txn).await?;
|
||||
let maybe_user_session = user_session_info.load_session(&mut repo).await?;
|
||||
|
||||
let mut session = match (maybe_user_session, link.user_id, form) {
|
||||
let session = match (maybe_user_session, link.user_id, form) {
|
||||
(Some(session), None, FormData::Link) => {
|
||||
associate_link_to_user(&mut txn, &link, &session.user).await?;
|
||||
repo.upstream_oauth_link()
|
||||
.associate_to_user(&link, &session.user)
|
||||
.await?;
|
||||
|
||||
session
|
||||
}
|
||||
|
||||
(None, Some(user_id), FormData::Login) => {
|
||||
let user = lookup_user(&mut txn, user_id).await?;
|
||||
start_session(&mut txn, &mut rng, &clock, user).await?
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(user_id)
|
||||
.await?
|
||||
.ok_or(RouteError::UserNotFound)?;
|
||||
|
||||
repo.browser_session().add(&mut rng, &clock, &user).await?
|
||||
}
|
||||
|
||||
(None, None, FormData::Register { username }) => {
|
||||
let user = add_user(&mut txn, &mut rng, &clock, &username).await?;
|
||||
associate_link_to_user(&mut txn, &link, &user).await?;
|
||||
let user = repo.user().add(&mut rng, &clock, username).await?;
|
||||
repo.upstream_oauth_link()
|
||||
.associate_to_user(&link, &user)
|
||||
.await?;
|
||||
|
||||
start_session(&mut txn, &mut rng, &clock, user).await?
|
||||
repo.browser_session().add(&mut rng, &clock, &user).await?
|
||||
}
|
||||
|
||||
_ => return Err(RouteError::InvalidFormAction),
|
||||
};
|
||||
|
||||
consume_session(&mut txn, &clock, upstream_session).await?;
|
||||
authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link).await?;
|
||||
repo.upstream_oauth_session()
|
||||
.consume(&clock, upstream_session)
|
||||
.await?;
|
||||
|
||||
let session = repo
|
||||
.browser_session()
|
||||
.authenticate_with_upstream(&mut rng, &clock, session, &link)
|
||||
.await?;
|
||||
|
||||
let cookie_jar = sessions_cookie
|
||||
.consume_link(link_id)?
|
||||
.save(cookie_jar, clock.now());
|
||||
.save(cookie_jar, &clock);
|
||||
let cookie_jar = cookie_jar.set_session(&session);
|
||||
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
Ok((cookie_jar, post_auth_action.go_next()))
|
||||
}
|
||||
|
||||
@@ -24,10 +24,9 @@ use mas_axum_utils::{
|
||||
use mas_email::Mailer;
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_router::Route;
|
||||
use mas_storage::user::add_user_email;
|
||||
use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng};
|
||||
use mas_templates::{EmailAddContext, TemplateContext, Templates};
|
||||
use serde::Deserialize;
|
||||
use sqlx::PgPool;
|
||||
|
||||
use super::start_email_verification;
|
||||
use crate::views::shared::OptionalPostAuthAction;
|
||||
@@ -38,17 +37,16 @@ pub struct EmailForm {
|
||||
}
|
||||
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(templates): State<Templates>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut conn = pool.begin().await?;
|
||||
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let maybe_session = session_info.load_session(&mut conn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
let session = if let Some(session) = maybe_session {
|
||||
session
|
||||
@@ -67,19 +65,18 @@ pub(crate) async fn get(
|
||||
}
|
||||
|
||||
pub(crate) async fn post(
|
||||
State(pool): State<PgPool>,
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
mut repo: BoxRepository,
|
||||
State(mailer): State<Mailer>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Query(query): Query<OptionalPostAuthAction>,
|
||||
Form(form): Form<ProtectedForm<EmailForm>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let form = cookie_jar.verify_form(clock.now(), form)?;
|
||||
let form = cookie_jar.verify_form(&clock, form)?;
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let maybe_session = session_info.load_session(&mut txn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
let session = if let Some(session) = maybe_session {
|
||||
session
|
||||
@@ -88,7 +85,11 @@ pub(crate) async fn post(
|
||||
return Ok((cookie_jar, login.go()).into_response());
|
||||
};
|
||||
|
||||
let user_email = add_user_email(&mut txn, &mut rng, &clock, &session.user, form.email).await?;
|
||||
let user_email = repo
|
||||
.user_email()
|
||||
.add(&mut rng, &clock, &session.user, form.email)
|
||||
.await?;
|
||||
|
||||
let next = mas_router::AccountVerifyEmail::new(user_email.id);
|
||||
let next = if let Some(action) = query.post_auth_action {
|
||||
next.and_then(action)
|
||||
@@ -97,7 +98,7 @@ pub(crate) async fn post(
|
||||
};
|
||||
start_email_verification(
|
||||
&mailer,
|
||||
&mut txn,
|
||||
&mut repo,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session.user,
|
||||
@@ -105,7 +106,7 @@ pub(crate) async fn post(
|
||||
)
|
||||
.await?;
|
||||
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
Ok((cookie_jar, next.go()).into_response())
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use axum::{
|
||||
extract::{Form, State},
|
||||
response::{Html, IntoResponse, Response},
|
||||
@@ -28,16 +29,11 @@ use mas_email::Mailer;
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_router::Route;
|
||||
use mas_storage::{
|
||||
user::{
|
||||
add_user_email, add_user_email_verification_code, get_user_email, get_user_emails,
|
||||
remove_user_email, set_user_email_as_primary,
|
||||
},
|
||||
Clock,
|
||||
user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
|
||||
};
|
||||
use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates};
|
||||
use rand::{distributions::Uniform, Rng};
|
||||
use serde::Deserialize;
|
||||
use sqlx::{PgExecutor, PgPool};
|
||||
use tracing::info;
|
||||
|
||||
pub mod add;
|
||||
@@ -53,37 +49,35 @@ pub enum ManagementForm {
|
||||
}
|
||||
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(templates): State<Templates>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let maybe_session = session_info.load_session(&mut conn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
if let Some(session) = maybe_session {
|
||||
render(&mut rng, &clock, templates, session, cookie_jar, &mut conn).await
|
||||
render(&mut rng, &clock, templates, session, cookie_jar, &mut repo).await
|
||||
} else {
|
||||
let login = mas_router::Login::default();
|
||||
Ok((cookie_jar, login.go()).into_response())
|
||||
}
|
||||
}
|
||||
|
||||
async fn render(
|
||||
async fn render<E: std::error::Error>(
|
||||
rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
clock: &impl Clock,
|
||||
templates: Templates,
|
||||
session: BrowserSession,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
executor: impl PgExecutor<'_>,
|
||||
repo: &mut impl RepositoryAccess<Error = E>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), rng);
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);
|
||||
|
||||
let emails = get_user_emails(executor, &session.user).await?;
|
||||
let emails = repo.user_email().all(&session.user).await?;
|
||||
|
||||
let ctx = AccountEmailsContext::new(emails)
|
||||
.with_session(session)
|
||||
@@ -94,11 +88,11 @@ async fn render(
|
||||
Ok((cookie_jar, Html(content)).into_response())
|
||||
}
|
||||
|
||||
async fn start_email_verification(
|
||||
async fn start_email_verification<E: std::error::Error + Send + Sync + 'static>(
|
||||
mailer: &Mailer,
|
||||
executor: impl PgExecutor<'_>,
|
||||
repo: &mut impl RepositoryAccess<Error = E>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
clock: &impl Clock,
|
||||
user: &User,
|
||||
user_email: UserEmail,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -108,15 +102,10 @@ async fn start_email_verification(
|
||||
|
||||
let address: Address = user_email.email.parse()?;
|
||||
|
||||
let verification = add_user_email_verification_code(
|
||||
executor,
|
||||
&mut rng,
|
||||
clock,
|
||||
user_email,
|
||||
Duration::hours(8),
|
||||
code,
|
||||
)
|
||||
.await?;
|
||||
let verification = repo
|
||||
.user_email()
|
||||
.add_verification_code(&mut rng, clock, &user_email, Duration::hours(8), code)
|
||||
.await?;
|
||||
|
||||
// And send the verification email
|
||||
let mailbox = Mailbox::new(Some(user.username.clone()), address);
|
||||
@@ -126,25 +115,24 @@ async fn start_email_verification(
|
||||
mailer.send_verification_email(mailbox, &context).await?;
|
||||
|
||||
info!(
|
||||
email.id = %verification.email.id,
|
||||
email.id = %user_email.id,
|
||||
"Verification email sent"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(templates): State<Templates>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
State(mailer): State<Mailer>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Form(form): Form<ProtectedForm<ManagementForm>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let maybe_session = session_info.load_session(&mut txn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
let mut session = if let Some(session) = maybe_session {
|
||||
session
|
||||
@@ -153,53 +141,69 @@ pub(crate) async fn post(
|
||||
return Ok((cookie_jar, login.go()).into_response());
|
||||
};
|
||||
|
||||
let form = cookie_jar.verify_form(clock.now(), form)?;
|
||||
let form = cookie_jar.verify_form(&clock, form)?;
|
||||
|
||||
match form {
|
||||
ManagementForm::Add { email } => {
|
||||
let user_email =
|
||||
add_user_email(&mut txn, &mut rng, &clock, &session.user, email).await?;
|
||||
let next = mas_router::AccountVerifyEmail::new(user_email.id);
|
||||
start_email_verification(
|
||||
&mailer,
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session.user,
|
||||
user_email,
|
||||
)
|
||||
.await?;
|
||||
txn.commit().await?;
|
||||
let email = repo
|
||||
.user_email()
|
||||
.add(&mut rng, &clock, &session.user, email)
|
||||
.await?;
|
||||
|
||||
let next = mas_router::AccountVerifyEmail::new(email.id);
|
||||
start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email)
|
||||
.await?;
|
||||
repo.save().await?;
|
||||
return Ok((cookie_jar, next.go()).into_response());
|
||||
}
|
||||
ManagementForm::ResendConfirmation { id } => {
|
||||
let id = id.parse()?;
|
||||
|
||||
let user_email = get_user_email(&mut txn, &session.user, id).await?;
|
||||
let next = mas_router::AccountVerifyEmail::new(user_email.id);
|
||||
start_email_verification(
|
||||
&mailer,
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session.user,
|
||||
user_email,
|
||||
)
|
||||
.await?;
|
||||
txn.commit().await?;
|
||||
let email = repo
|
||||
.user_email()
|
||||
.lookup(id)
|
||||
.await?
|
||||
.context("Email not found")?;
|
||||
|
||||
if email.user_id != session.user.id {
|
||||
return Err(anyhow!("Email not found").into());
|
||||
}
|
||||
|
||||
let next = mas_router::AccountVerifyEmail::new(email.id);
|
||||
start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email)
|
||||
.await?;
|
||||
repo.save().await?;
|
||||
return Ok((cookie_jar, next.go()).into_response());
|
||||
}
|
||||
ManagementForm::Remove { id } => {
|
||||
let id = id.parse()?;
|
||||
|
||||
let email = get_user_email(&mut txn, &session.user, id).await?;
|
||||
remove_user_email(&mut txn, email).await?;
|
||||
let email = repo
|
||||
.user_email()
|
||||
.lookup(id)
|
||||
.await?
|
||||
.context("Email not found")?;
|
||||
|
||||
if email.user_id != session.user.id {
|
||||
return Err(anyhow!("Email not found").into());
|
||||
}
|
||||
|
||||
repo.user_email().remove(email).await?;
|
||||
}
|
||||
ManagementForm::SetPrimary { id } => {
|
||||
let id = id.parse()?;
|
||||
let email = get_user_email(&mut txn, &session.user, id).await?;
|
||||
set_user_email_as_primary(&mut txn, &email).await?;
|
||||
session.user.primary_email = Some(email);
|
||||
let email = repo
|
||||
.user_email()
|
||||
.lookup(id)
|
||||
.await?
|
||||
.context("Email not found")?;
|
||||
|
||||
if email.user_id != session.user.id {
|
||||
return Err(anyhow!("Email not found").into());
|
||||
}
|
||||
|
||||
repo.user_email().set_as_primary(&email).await?;
|
||||
session.user.primary_user_email_id = Some(email.id);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -209,11 +213,11 @@ pub(crate) async fn post(
|
||||
templates.clone(),
|
||||
session,
|
||||
cookie_jar,
|
||||
&mut txn,
|
||||
&mut repo,
|
||||
)
|
||||
.await?;
|
||||
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
Ok(reply)
|
||||
}
|
||||
|
||||
@@ -24,16 +24,9 @@ use mas_axum_utils::{
|
||||
};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_router::Route;
|
||||
use mas_storage::{
|
||||
user::{
|
||||
consume_email_verification, lookup_user_email_by_id, lookup_user_email_verification_code,
|
||||
mark_user_email_as_verified, set_user_email_as_primary,
|
||||
},
|
||||
Clock,
|
||||
};
|
||||
use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng};
|
||||
use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates};
|
||||
use serde::Deserialize;
|
||||
use sqlx::PgPool;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::views::shared::OptionalPostAuthAction;
|
||||
@@ -44,19 +37,18 @@ pub struct CodeForm {
|
||||
}
|
||||
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(templates): State<Templates>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
Query(query): Query<OptionalPostAuthAction>,
|
||||
Path(id): Path<Ulid>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let maybe_session = session_info.load_session(&mut conn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
let session = if let Some(session) = maybe_session {
|
||||
session
|
||||
@@ -65,8 +57,11 @@ pub(crate) async fn get(
|
||||
return Ok((cookie_jar, login.go()).into_response());
|
||||
};
|
||||
|
||||
let user_email = lookup_user_email_by_id(&mut conn, &session.user, id)
|
||||
let user_email = repo
|
||||
.user_email()
|
||||
.lookup(id)
|
||||
.await?
|
||||
.filter(|u| u.user_id == session.user.id)
|
||||
.context("Could not find user email")?;
|
||||
|
||||
if user_email.confirmed_at.is_some() {
|
||||
@@ -85,19 +80,17 @@ pub(crate) async fn get(
|
||||
}
|
||||
|
||||
pub(crate) async fn post(
|
||||
State(pool): State<PgPool>,
|
||||
clock: BoxClock,
|
||||
mut repo: BoxRepository,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Query(query): Query<OptionalPostAuthAction>,
|
||||
Path(id): Path<Ulid>,
|
||||
Form(form): Form<ProtectedForm<CodeForm>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let clock = Clock::default();
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let form = cookie_jar.verify_form(clock.now(), form)?;
|
||||
let form = cookie_jar.verify_form(&clock, form)?;
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let maybe_session = session_info.load_session(&mut txn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
let session = if let Some(session) = maybe_session {
|
||||
session
|
||||
@@ -106,25 +99,33 @@ pub(crate) async fn post(
|
||||
return Ok((cookie_jar, login.go()).into_response());
|
||||
};
|
||||
|
||||
let email = lookup_user_email_by_id(&mut txn, &session.user, id)
|
||||
let user_email = repo
|
||||
.user_email()
|
||||
.lookup(id)
|
||||
.await?
|
||||
.filter(|u| u.user_id == session.user.id)
|
||||
.context("Could not find user email")?;
|
||||
|
||||
if session.user.primary_email.is_none() {
|
||||
set_user_email_as_primary(&mut txn, &email).await?;
|
||||
}
|
||||
|
||||
// TODO: make those 8 hours configurable
|
||||
let verification = lookup_user_email_verification_code(&mut txn, &clock, email, &form.code)
|
||||
let verification = repo
|
||||
.user_email()
|
||||
.find_verification_code(&clock, &user_email, &form.code)
|
||||
.await?
|
||||
.context("Invalid code")?;
|
||||
|
||||
// TODO: display nice errors if the code was already consumed or expired
|
||||
let verification = consume_email_verification(&mut txn, &clock, verification).await?;
|
||||
repo.user_email()
|
||||
.consume_verification_code(&clock, verification)
|
||||
.await?;
|
||||
|
||||
let _email = mark_user_email_as_verified(&mut txn, &clock, verification.email).await?;
|
||||
if session.user.primary_user_email_id.is_none() {
|
||||
repo.user_email().set_as_primary(&user_email).await?;
|
||||
}
|
||||
|
||||
txn.commit().await?;
|
||||
repo.user_email()
|
||||
.mark_as_verified(&clock, user_email)
|
||||
.await?;
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
let destination = query.go_next_or_default(&mas_router::AccountEmails);
|
||||
Ok((cookie_jar, destination).into_response())
|
||||
|
||||
@@ -23,22 +23,23 @@ use axum_extra::extract::PrivateCookieJar;
|
||||
use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_router::Route;
|
||||
use mas_storage::user::{count_active_sessions, get_user_emails};
|
||||
use mas_storage::{
|
||||
user::{BrowserSessionRepository, UserEmailRepository},
|
||||
BoxClock, BoxRepository, BoxRng,
|
||||
};
|
||||
use mas_templates::{AccountContext, TemplateContext, Templates};
|
||||
use sqlx::PgPool;
|
||||
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(templates): State<Templates>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let maybe_session = session_info.load_session(&mut conn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
let session = if let Some(session) = maybe_session {
|
||||
session
|
||||
@@ -47,9 +48,9 @@ pub(crate) async fn get(
|
||||
return Ok((cookie_jar, login.go()).into_response());
|
||||
};
|
||||
|
||||
let active_sessions = count_active_sessions(&mut conn, &session.user).await?;
|
||||
let active_sessions = repo.browser_session().count_active(&session.user).await?;
|
||||
|
||||
let emails = get_user_emails(&mut conn, &session.user).await?;
|
||||
let emails = repo.user_email().all(&session.user).await?;
|
||||
|
||||
let ctx = AccountContext::new(active_sessions, emails)
|
||||
.with_session(session)
|
||||
|
||||
@@ -26,13 +26,12 @@ use mas_data_model::BrowserSession;
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_router::Route;
|
||||
use mas_storage::{
|
||||
user::{add_user_password, authenticate_session_with_password, lookup_user_password},
|
||||
Clock,
|
||||
user::{BrowserSessionRepository, UserPasswordRepository},
|
||||
BoxClock, BoxRepository, BoxRng, Clock,
|
||||
};
|
||||
use mas_templates::{EmptyContext, TemplateContext, Templates};
|
||||
use rand::Rng;
|
||||
use serde::Deserialize;
|
||||
use sqlx::PgPool;
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use crate::passwords::PasswordManager;
|
||||
@@ -45,16 +44,15 @@ pub struct ChangeForm {
|
||||
}
|
||||
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(templates): State<Templates>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let maybe_session = session_info.load_session(&mut conn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
if let Some(session) = maybe_session {
|
||||
render(&mut rng, &clock, templates, session, cookie_jar).await
|
||||
@@ -66,12 +64,12 @@ pub(crate) async fn get(
|
||||
|
||||
async fn render(
|
||||
rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
clock: &impl Clock,
|
||||
templates: Templates,
|
||||
session: BrowserSession,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), rng);
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);
|
||||
|
||||
let ctx = EmptyContext
|
||||
.with_session(session)
|
||||
@@ -83,29 +81,30 @@ async fn render(
|
||||
}
|
||||
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(password_manager): State<PasswordManager>,
|
||||
State(templates): State<Templates>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Form(form): Form<ProtectedForm<ChangeForm>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let form = cookie_jar.verify_form(clock.now(), form)?;
|
||||
let form = cookie_jar.verify_form(&clock, form)?;
|
||||
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let maybe_session = session_info.load_session(&mut txn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
let mut session = if let Some(session) = maybe_session {
|
||||
let session = if let Some(session) = maybe_session {
|
||||
session
|
||||
} else {
|
||||
let login = mas_router::Login::and_then(mas_router::PostAuthAction::ChangePassword);
|
||||
return Ok((cookie_jar, login.go()).into_response());
|
||||
};
|
||||
|
||||
let user_password = lookup_user_password(&mut txn, &session.user)
|
||||
let user_password = repo
|
||||
.user_password()
|
||||
.active(&session.user)
|
||||
.await?
|
||||
.context("user has no password")?;
|
||||
|
||||
@@ -127,23 +126,26 @@ pub(crate) async fn post(
|
||||
}
|
||||
|
||||
let (version, hashed_password) = password_manager.hash(&mut rng, new_password).await?;
|
||||
let user_password = add_user_password(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session.user,
|
||||
version,
|
||||
hashed_password,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
let user_password = repo
|
||||
.user_password()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session.user,
|
||||
version,
|
||||
hashed_password,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
authenticate_session_with_password(&mut txn, &mut rng, &clock, &mut session, &user_password)
|
||||
let session = repo
|
||||
.browser_session()
|
||||
.authenticate_with_password(&mut rng, &clock, session, &user_password)
|
||||
.await?;
|
||||
|
||||
let reply = render(&mut rng, &clock, templates.clone(), session, cookie_jar).await?;
|
||||
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
Ok(reply)
|
||||
}
|
||||
|
||||
@@ -20,21 +20,20 @@ use axum_extra::extract::PrivateCookieJar;
|
||||
use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::{BoxClock, BoxRepository, BoxRng};
|
||||
use mas_templates::{IndexContext, TemplateContext, Templates};
|
||||
use sqlx::PgPool;
|
||||
|
||||
pub async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(templates): State<Templates>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
) -> Result<impl IntoResponse, FancyError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
let session = session_info.load_session(&mut conn).await?;
|
||||
let session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
let ctx = IndexContext::new(url_builder.oidc_discovery())
|
||||
.maybe_with_session(session)
|
||||
|
||||
@@ -24,18 +24,15 @@ use mas_axum_utils::{
|
||||
use mas_data_model::BrowserSession;
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_storage::{
|
||||
user::{
|
||||
add_user_password, authenticate_session_with_password, lookup_user_by_username,
|
||||
lookup_user_password, start_session,
|
||||
},
|
||||
Clock,
|
||||
upstream_oauth2::UpstreamOAuthProviderRepository,
|
||||
user::{BrowserSessionRepository, UserPasswordRepository, UserRepository},
|
||||
BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
|
||||
};
|
||||
use mas_templates::{
|
||||
FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState,
|
||||
};
|
||||
use rand::{CryptoRng, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{PgConnection, PgPool};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use super::shared::OptionalPostAuthAction;
|
||||
@@ -52,29 +49,28 @@ impl ToFormState for LoginForm {
|
||||
}
|
||||
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(templates): State<Templates>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
Query(query): Query<OptionalPostAuthAction>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let maybe_session = session_info.load_session(&mut conn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
if maybe_session.is_some() {
|
||||
let reply = query.go_next();
|
||||
Ok((cookie_jar, reply).into_response())
|
||||
} else {
|
||||
let providers = mas_storage::upstream_oauth2::get_providers(&mut conn).await?;
|
||||
let providers = repo.upstream_oauth_provider().all().await?;
|
||||
let content = render(
|
||||
LoginContext::default().with_upstrem_providers(providers),
|
||||
query,
|
||||
csrf_token,
|
||||
&mut conn,
|
||||
&mut repo,
|
||||
&templates,
|
||||
)
|
||||
.await?;
|
||||
@@ -84,19 +80,18 @@ pub(crate) async fn get(
|
||||
}
|
||||
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(password_manager): State<PasswordManager>,
|
||||
State(templates): State<Templates>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
Query(query): Query<OptionalPostAuthAction>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Form(form): Form<ProtectedForm<LoginForm>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut conn = pool.acquire().await?;
|
||||
let form = cookie_jar.verify_form(&clock, form)?;
|
||||
|
||||
let form = cookie_jar.verify_form(clock.now(), form)?;
|
||||
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
|
||||
// Validate the form
|
||||
let state = {
|
||||
@@ -114,14 +109,14 @@ pub(crate) async fn post(
|
||||
};
|
||||
|
||||
if !state.is_valid() {
|
||||
let providers = mas_storage::upstream_oauth2::get_providers(&mut conn).await?;
|
||||
let providers = repo.upstream_oauth_provider().all().await?;
|
||||
let content = render(
|
||||
LoginContext::default()
|
||||
.with_form_state(state)
|
||||
.with_upstrem_providers(providers),
|
||||
query,
|
||||
csrf_token,
|
||||
&mut conn,
|
||||
&mut repo,
|
||||
&templates,
|
||||
)
|
||||
.await?;
|
||||
@@ -129,11 +124,9 @@ pub(crate) async fn post(
|
||||
return Ok((cookie_jar, Html(content)).into_response());
|
||||
}
|
||||
|
||||
lookup_user_by_username(&mut conn, &form.username).await?;
|
||||
|
||||
match login(
|
||||
password_manager,
|
||||
&mut conn,
|
||||
&mut repo,
|
||||
rng,
|
||||
&clock,
|
||||
&form.username,
|
||||
@@ -142,6 +135,8 @@ pub(crate) async fn post(
|
||||
.await
|
||||
{
|
||||
Ok(session_info) => {
|
||||
repo.save().await?;
|
||||
|
||||
let cookie_jar = cookie_jar.set_session(&session_info);
|
||||
let reply = query.go_next();
|
||||
Ok((cookie_jar, reply).into_response())
|
||||
@@ -153,7 +148,7 @@ pub(crate) async fn post(
|
||||
LoginContext::default().with_form_state(state),
|
||||
query,
|
||||
csrf_token,
|
||||
&mut conn,
|
||||
&mut repo,
|
||||
&templates,
|
||||
)
|
||||
.await?;
|
||||
@@ -166,21 +161,25 @@ pub(crate) async fn post(
|
||||
// TODO: move that logic elsewhere?
|
||||
async fn login(
|
||||
password_manager: PasswordManager,
|
||||
conn: &mut PgConnection,
|
||||
repo: &mut impl RepositoryAccess,
|
||||
mut rng: impl Rng + CryptoRng + Send,
|
||||
clock: &Clock,
|
||||
clock: &impl Clock,
|
||||
username: &str,
|
||||
password: &str,
|
||||
) -> Result<BrowserSession, FormError> {
|
||||
// XXX: we're loosing the error context here
|
||||
// First, lookup the user
|
||||
let user = lookup_user_by_username(&mut *conn, username)
|
||||
let user = repo
|
||||
.user()
|
||||
.find_by_username(username)
|
||||
.await
|
||||
.map_err(|_e| FormError::Internal)?
|
||||
.ok_or(FormError::InvalidCredentials)?;
|
||||
|
||||
// And its password
|
||||
let user_password = lookup_user_password(&mut *conn, &user)
|
||||
let user_password = repo
|
||||
.user_password()
|
||||
.active(&user)
|
||||
.await
|
||||
.map_err(|_e| FormError::Internal)?
|
||||
.ok_or(FormError::InvalidCredentials)?;
|
||||
@@ -200,28 +199,32 @@ async fn login(
|
||||
|
||||
let user_password = if let Some((version, new_password_hash)) = new_password_hash {
|
||||
// Save the upgraded password
|
||||
add_user_password(
|
||||
&mut *conn,
|
||||
&mut rng,
|
||||
clock,
|
||||
&user,
|
||||
version,
|
||||
new_password_hash,
|
||||
Some(user_password),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| FormError::Internal)?
|
||||
repo.user_password()
|
||||
.add(
|
||||
&mut rng,
|
||||
clock,
|
||||
&user,
|
||||
version,
|
||||
new_password_hash,
|
||||
Some(&user_password),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| FormError::Internal)?
|
||||
} else {
|
||||
user_password
|
||||
};
|
||||
|
||||
// Start a new session
|
||||
let mut user_session = start_session(&mut *conn, &mut rng, clock, user)
|
||||
let user_session = repo
|
||||
.browser_session()
|
||||
.add(&mut rng, clock, &user)
|
||||
.await
|
||||
.map_err(|_| FormError::Internal)?;
|
||||
|
||||
// And mark it as authenticated by the password
|
||||
authenticate_session_with_password(&mut *conn, rng, clock, &mut user_session, &user_password)
|
||||
let user_session = repo
|
||||
.browser_session()
|
||||
.authenticate_with_password(&mut rng, clock, user_session, &user_password)
|
||||
.await
|
||||
.map_err(|_| FormError::Internal)?;
|
||||
|
||||
@@ -232,10 +235,10 @@ async fn render(
|
||||
ctx: LoginContext,
|
||||
action: OptionalPostAuthAction,
|
||||
csrf_token: CsrfToken,
|
||||
conn: &mut PgConnection,
|
||||
repo: &mut impl RepositoryAccess,
|
||||
templates: &Templates,
|
||||
) -> Result<String, FancyError> {
|
||||
let next = action.load_context(conn).await?;
|
||||
let next = action.load_context(repo).await?;
|
||||
let ctx = if let Some(next) = next {
|
||||
ctx.with_post_action(next)
|
||||
} else {
|
||||
|
||||
@@ -12,10 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use axum::{
|
||||
extract::{Form, State},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use axum::{extract::Form, response::IntoResponse};
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use mas_axum_utils::{
|
||||
csrf::{CsrfExt, ProtectedForm},
|
||||
@@ -23,29 +20,26 @@ use mas_axum_utils::{
|
||||
};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_router::{PostAuthAction, Route};
|
||||
use mas_storage::{user::end_session, Clock};
|
||||
use sqlx::PgPool;
|
||||
use mas_storage::{user::BrowserSessionRepository, BoxClock, BoxRepository};
|
||||
|
||||
pub(crate) async fn post(
|
||||
State(pool): State<PgPool>,
|
||||
clock: BoxClock,
|
||||
mut repo: BoxRepository,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Form(form): Form<ProtectedForm<Option<PostAuthAction>>>,
|
||||
) -> Result<impl IntoResponse, FancyError> {
|
||||
let clock = Clock::default();
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let form = cookie_jar.verify_form(clock.now(), form)?;
|
||||
let form = cookie_jar.verify_form(&clock, form)?;
|
||||
|
||||
let (session_info, mut cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let maybe_session = session_info.load_session(&mut txn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
if let Some(session) = maybe_session {
|
||||
end_session(&mut txn, &clock, &session).await?;
|
||||
repo.browser_session().finish(&clock, session).await?;
|
||||
cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended());
|
||||
}
|
||||
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
let destination = if let Some(action) = form {
|
||||
action.go_next()
|
||||
|
||||
@@ -24,12 +24,12 @@ use mas_axum_utils::{
|
||||
};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_router::Route;
|
||||
use mas_storage::user::{
|
||||
add_user_password, authenticate_session_with_password, lookup_user_password,
|
||||
use mas_storage::{
|
||||
user::{BrowserSessionRepository, UserPasswordRepository},
|
||||
BoxClock, BoxRepository, BoxRng,
|
||||
};
|
||||
use mas_templates::{ReauthContext, TemplateContext, Templates};
|
||||
use serde::Deserialize;
|
||||
use sqlx::PgPool;
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use super::shared::OptionalPostAuthAction;
|
||||
@@ -41,18 +41,17 @@ pub(crate) struct ReauthForm {
|
||||
}
|
||||
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(templates): State<Templates>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
Query(query): Query<OptionalPostAuthAction>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let maybe_session = session_info.load_session(&mut conn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
let session = if let Some(session) = maybe_session {
|
||||
session
|
||||
@@ -64,7 +63,7 @@ pub(crate) async fn get(
|
||||
};
|
||||
|
||||
let ctx = ReauthContext::default();
|
||||
let next = query.load_context(&mut conn).await?;
|
||||
let next = query.load_context(&mut repo).await?;
|
||||
let ctx = if let Some(next) = next {
|
||||
ctx.with_post_action(next)
|
||||
} else {
|
||||
@@ -78,22 +77,21 @@ pub(crate) async fn get(
|
||||
}
|
||||
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(password_manager): State<PasswordManager>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
Query(query): Query<OptionalPostAuthAction>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Form(form): Form<ProtectedForm<ReauthForm>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let form = cookie_jar.verify_form(clock.now(), form)?;
|
||||
let form = cookie_jar.verify_form(&clock, form)?;
|
||||
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let maybe_session = session_info.load_session(&mut txn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
let mut session = if let Some(session) = maybe_session {
|
||||
let session = if let Some(session) = maybe_session {
|
||||
session
|
||||
} else {
|
||||
// If there is no session, redirect to the login screen, keeping the
|
||||
@@ -103,7 +101,9 @@ pub(crate) async fn post(
|
||||
};
|
||||
|
||||
// Load the user password
|
||||
let user_password = lookup_user_password(&mut txn, &session.user)
|
||||
let user_password = repo
|
||||
.user_password()
|
||||
.active(&session.user)
|
||||
.await?
|
||||
.context("User has no password")?;
|
||||
|
||||
@@ -122,25 +122,28 @@ pub(crate) async fn post(
|
||||
|
||||
let user_password = if let Some((version, new_password_hash)) = new_password_hash {
|
||||
// Save the upgraded password
|
||||
add_user_password(
|
||||
&mut *txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session.user,
|
||||
version,
|
||||
new_password_hash,
|
||||
Some(user_password),
|
||||
)
|
||||
.await?
|
||||
repo.user_password()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session.user,
|
||||
version,
|
||||
new_password_hash,
|
||||
Some(&user_password),
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
user_password
|
||||
};
|
||||
|
||||
// Mark the session as authenticated by the password
|
||||
authenticate_session_with_password(&mut txn, rng, &clock, &mut session, &user_password).await?;
|
||||
let session = repo
|
||||
.browser_session()
|
||||
.authenticate_with_password(&mut rng, &clock, session, &user_password)
|
||||
.await?;
|
||||
|
||||
let cookie_jar = cookie_jar.set_session(&session);
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
let reply = query.go_next();
|
||||
Ok((cookie_jar, reply).into_response())
|
||||
|
||||
@@ -31,9 +31,9 @@ use mas_email::Mailer;
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::Route;
|
||||
use mas_storage::user::{
|
||||
add_user, add_user_email, add_user_email_verification_code, add_user_password,
|
||||
authenticate_session_with_password, start_session, username_exists,
|
||||
use mas_storage::{
|
||||
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
|
||||
BoxClock, BoxRepository, BoxRng, RepositoryAccess,
|
||||
};
|
||||
use mas_templates::{
|
||||
EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField,
|
||||
@@ -41,7 +41,6 @@ use mas_templates::{
|
||||
};
|
||||
use rand::{distributions::Uniform, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{PgConnection, PgPool};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use super::shared::OptionalPostAuthAction;
|
||||
@@ -60,18 +59,17 @@ impl ToFormState for RegisterForm {
|
||||
}
|
||||
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(templates): State<Templates>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
Query(query): Query<OptionalPostAuthAction>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let maybe_session = session_info.load_session(&mut conn).await?;
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
if maybe_session.is_some() {
|
||||
let reply = query.go_next();
|
||||
@@ -81,7 +79,7 @@ pub(crate) async fn get(
|
||||
RegisterContext::default(),
|
||||
query,
|
||||
csrf_token,
|
||||
&mut conn,
|
||||
&mut repo,
|
||||
&templates,
|
||||
)
|
||||
.await?;
|
||||
@@ -92,21 +90,20 @@ pub(crate) async fn get(
|
||||
|
||||
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(password_manager): State<PasswordManager>,
|
||||
State(mailer): State<Mailer>,
|
||||
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||
State(templates): State<Templates>,
|
||||
State(pool): State<PgPool>,
|
||||
mut repo: BoxRepository,
|
||||
Query(query): Query<OptionalPostAuthAction>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Form(form): Form<ProtectedForm<RegisterForm>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut txn = pool.begin().await?;
|
||||
let form = cookie_jar.verify_form(&clock, form)?;
|
||||
|
||||
let form = cookie_jar.verify_form(clock.now(), form)?;
|
||||
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
|
||||
// Validate the form
|
||||
let state = {
|
||||
@@ -114,7 +111,7 @@ pub(crate) async fn post(
|
||||
|
||||
if form.username.is_empty() {
|
||||
state.add_error_on_field(RegisterFormField::Username, FieldError::Required);
|
||||
} else if username_exists(&mut txn, &form.username).await? {
|
||||
} else if repo.user().exists(&form.username).await? {
|
||||
state.add_error_on_field(RegisterFormField::Username, FieldError::Exists);
|
||||
}
|
||||
|
||||
@@ -177,7 +174,7 @@ pub(crate) async fn post(
|
||||
RegisterContext::default().with_form_state(state),
|
||||
query,
|
||||
csrf_token,
|
||||
&mut txn,
|
||||
&mut repo,
|
||||
&templates,
|
||||
)
|
||||
.await?;
|
||||
@@ -185,21 +182,18 @@ pub(crate) async fn post(
|
||||
return Ok((cookie_jar, Html(content)).into_response());
|
||||
}
|
||||
|
||||
let user = add_user(&mut txn, &mut rng, &clock, &form.username).await?;
|
||||
let user = repo.user().add(&mut rng, &clock, form.username).await?;
|
||||
let password = Zeroizing::new(form.password.into_bytes());
|
||||
let (version, hashed_password) = password_manager.hash(&mut rng, password).await?;
|
||||
let user_password = add_user_password(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&user,
|
||||
version,
|
||||
hashed_password,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
let user_password = repo
|
||||
.user_password()
|
||||
.add(&mut rng, &clock, &user, version, hashed_password, None)
|
||||
.await?;
|
||||
|
||||
let user_email = add_user_email(&mut txn, &mut rng, &clock, &user, form.email).await?;
|
||||
let user_email = repo
|
||||
.user_email()
|
||||
.add(&mut rng, &clock, &user, form.email)
|
||||
.await?;
|
||||
|
||||
// First, generate a code
|
||||
let range = Uniform::<u32>::from(0..1_000_000);
|
||||
@@ -208,15 +202,10 @@ pub(crate) async fn post(
|
||||
|
||||
let address: Address = user_email.email.parse()?;
|
||||
|
||||
let verification = add_user_email_verification_code(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
user_email,
|
||||
Duration::hours(8),
|
||||
code,
|
||||
)
|
||||
.await?;
|
||||
let verification = repo
|
||||
.user_email()
|
||||
.add_verification_code(&mut rng, &clock, &user_email, Duration::hours(8), code)
|
||||
.await?;
|
||||
|
||||
// And send the verification email
|
||||
let mailbox = Mailbox::new(Some(user.username.clone()), address);
|
||||
@@ -225,14 +214,16 @@ pub(crate) async fn post(
|
||||
|
||||
mailer.send_verification_email(mailbox, &context).await?;
|
||||
|
||||
let next = mas_router::AccountVerifyEmail::new(verification.email.id)
|
||||
.and_maybe(query.post_auth_action);
|
||||
let next = mas_router::AccountVerifyEmail::new(user_email.id).and_maybe(query.post_auth_action);
|
||||
|
||||
let mut session = start_session(&mut txn, &mut rng, &clock, user).await?;
|
||||
authenticate_session_with_password(&mut txn, &mut rng, &clock, &mut session, &user_password)
|
||||
let session = repo.browser_session().add(&mut rng, &clock, &user).await?;
|
||||
|
||||
let session = repo
|
||||
.browser_session()
|
||||
.authenticate_with_password(&mut rng, &clock, session, &user_password)
|
||||
.await?;
|
||||
|
||||
txn.commit().await?;
|
||||
repo.save().await?;
|
||||
|
||||
let cookie_jar = cookie_jar.set_session(&session);
|
||||
Ok((cookie_jar, next.go()).into_response())
|
||||
@@ -242,10 +233,10 @@ async fn render(
|
||||
ctx: RegisterContext,
|
||||
action: OptionalPostAuthAction,
|
||||
csrf_token: CsrfToken,
|
||||
conn: &mut PgConnection,
|
||||
repo: &mut impl RepositoryAccess,
|
||||
templates: &Templates,
|
||||
) -> Result<String, FancyError> {
|
||||
let next = action.load_context(conn).await?;
|
||||
let next = action.load_context(repo).await?;
|
||||
let ctx = if let Some(next) = next {
|
||||
ctx.with_post_action(next)
|
||||
} else {
|
||||
|
||||
@@ -15,11 +15,13 @@
|
||||
use anyhow::Context;
|
||||
use mas_router::{PostAuthAction, Route};
|
||||
use mas_storage::{
|
||||
compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id,
|
||||
compat::CompatSsoLoginRepository,
|
||||
oauth2::OAuth2AuthorizationGrantRepository,
|
||||
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository},
|
||||
RepositoryAccess,
|
||||
};
|
||||
use mas_templates::{PostAuthContext, PostAuthContextInner};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgConnection;
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
|
||||
pub(crate) struct OptionalPostAuthAction {
|
||||
@@ -38,14 +40,16 @@ impl OptionalPostAuthAction {
|
||||
self.go_next_or_default(&mas_router::Index)
|
||||
}
|
||||
|
||||
pub async fn load_context(
|
||||
&self,
|
||||
conn: &mut PgConnection,
|
||||
pub async fn load_context<'a>(
|
||||
&'a self,
|
||||
repo: &'a mut impl RepositoryAccess,
|
||||
) -> anyhow::Result<Option<PostAuthContext>> {
|
||||
let Some(action) = self.post_auth_action.clone() else { return Ok(None) };
|
||||
let ctx = match action {
|
||||
PostAuthAction::ContinueAuthorizationGrant { id } => {
|
||||
let grant = get_grant_by_id(conn, id)
|
||||
let grant = repo
|
||||
.oauth2_authorization_grant()
|
||||
.lookup(id)
|
||||
.await?
|
||||
.context("Failed to load authorization grant")?;
|
||||
let grant = Box::new(grant);
|
||||
@@ -53,7 +57,9 @@ impl OptionalPostAuthAction {
|
||||
}
|
||||
|
||||
PostAuthAction::ContinueCompatSsoLogin { id } => {
|
||||
let login = get_compat_sso_login_by_id(conn, id)
|
||||
let login = repo
|
||||
.compat_sso_login()
|
||||
.lookup(id)
|
||||
.await?
|
||||
.context("Failed to load compat SSO login")?;
|
||||
let login = Box::new(login);
|
||||
@@ -63,14 +69,17 @@ impl OptionalPostAuthAction {
|
||||
PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword,
|
||||
|
||||
PostAuthAction::LinkUpstream { id } => {
|
||||
let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, id)
|
||||
let link = repo
|
||||
.upstream_oauth_link()
|
||||
.lookup(id)
|
||||
.await?
|
||||
.context("Failed to load upstream OAuth 2.0 link")?;
|
||||
|
||||
let provider =
|
||||
mas_storage::upstream_oauth2::lookup_provider(&mut *conn, link.provider_id)
|
||||
.await?
|
||||
.context("Failed to load upstream OAuth 2.0 provider")?;
|
||||
let provider = repo
|
||||
.upstream_oauth_provider()
|
||||
.lookup(link.provider_id)
|
||||
.await?
|
||||
.context("Failed to load upstream OAuth 2.0 provider")?;
|
||||
|
||||
let provider = Box::new(provider);
|
||||
let link = Box::new(link);
|
||||
|
||||
@@ -15,12 +15,7 @@
|
||||
//! A crate to store keys which can then be used to sign and verify JWTs.
|
||||
|
||||
#![forbid(unsafe_code)]
|
||||
#![deny(
|
||||
clippy::all,
|
||||
clippy::str_to_string,
|
||||
rustdoc::broken_intra_doc_links,
|
||||
rustdoc::all
|
||||
)]
|
||||
#![deny(clippy::all, clippy::str_to_string, rustdoc::broken_intra_doc_links)]
|
||||
#![warn(clippy::pedantic)]
|
||||
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
|
||||
@@ -22,6 +22,9 @@
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
|
||||
//! An utility crate to build flexible [`hyper`] listeners, with optional TLS
|
||||
//! and proxy protocol support.
|
||||
|
||||
use self::{maybe_tls::TlsStreamInfo, proxy_protocol::ProxyProtocolV1Info};
|
||||
|
||||
pub mod maybe_tls;
|
||||
|
||||
@@ -39,7 +39,7 @@ pub struct Server<S> {
|
||||
impl<S> Server<S> {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the listener couldn't be converted via [`TryInfo`]
|
||||
/// Returns an error if the listener couldn't be converted via [`TryInto`]
|
||||
pub fn try_new<L>(listener: L, service: S) -> Result<Self, L::Error>
|
||||
where
|
||||
L: TryInto<UnixOrTcpListener>,
|
||||
|
||||
@@ -90,6 +90,11 @@ impl<T> Localized<T> {
|
||||
&self.non_localized
|
||||
}
|
||||
|
||||
/// Get the non-localized variant.
|
||||
pub fn to_non_localized(self) -> T {
|
||||
self.non_localized
|
||||
}
|
||||
|
||||
/// Get the variant corresponding to the given language, if it exists.
|
||||
pub fn get(&self, language: Option<&LanguageTag>) -> Option<&T> {
|
||||
match language {
|
||||
|
||||
@@ -69,7 +69,7 @@ pub struct PolicyFactory {
|
||||
}
|
||||
|
||||
impl PolicyFactory {
|
||||
#[tracing::instrument(skip(source), err)]
|
||||
#[tracing::instrument(name = "policy.load", skip(source), err)]
|
||||
pub async fn load(
|
||||
mut source: impl AsyncRead + std::marker::Unpin,
|
||||
data: serde_json::Value,
|
||||
@@ -108,7 +108,7 @@ impl PolicyFactory {
|
||||
authorization_grant_endpoint,
|
||||
};
|
||||
|
||||
// Try to instanciate
|
||||
// Try to instantiate
|
||||
factory
|
||||
.instantiate()
|
||||
.await
|
||||
@@ -117,7 +117,7 @@ impl PolicyFactory {
|
||||
Ok(factory)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), err)]
|
||||
#[tracing::instrument(name = "policy.instantiate", skip_all, err)]
|
||||
pub async fn instantiate(&self) -> Result<Policy, InstanciateError> {
|
||||
let mut store = Store::new(&self.engine, ());
|
||||
let runtime = Runtime::new(&mut store, &self.module)
|
||||
@@ -189,7 +189,14 @@ pub enum EvaluationError {
|
||||
}
|
||||
|
||||
impl Policy {
|
||||
#[tracing::instrument(skip(self, password))]
|
||||
#[tracing::instrument(
|
||||
name = "policy.evaluate.register",
|
||||
skip_all,
|
||||
fields(
|
||||
data.username = username,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn evaluate_register(
|
||||
&mut self,
|
||||
username: &str,
|
||||
@@ -234,7 +241,15 @@ impl Policy {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
#[tracing::instrument(
|
||||
name = "policy.evaluate.authorization_grant",
|
||||
skip_all,
|
||||
fields(
|
||||
data.authorization_grant.id = %authorization_grant.id,
|
||||
data.user.id = %user.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn evaluate_authorization_grant(
|
||||
&mut self,
|
||||
authorization_grant: &AuthorizationGrant,
|
||||
|
||||
@@ -21,6 +21,8 @@
|
||||
)]
|
||||
#![warn(clippy::pedantic)]
|
||||
|
||||
//! A crate to help serve single-page apps built by Vite.
|
||||
|
||||
mod vite;
|
||||
|
||||
use std::{future::Future, pin::Pin};
|
||||
|
||||
28
crates/storage-pg/Cargo.toml
Normal file
28
crates/storage-pg/Cargo.toml
Normal file
@@ -0,0 +1,28 @@
|
||||
[package]
|
||||
name = "mas-storage-pg"
|
||||
version = "0.1.0"
|
||||
authors = ["Quentin Gliech <quenting@element.io>"]
|
||||
edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.60"
|
||||
sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline", "json", "uuid"] }
|
||||
chrono = { version = "0.4.23", features = ["serde"] }
|
||||
serde = { version = "1.0.152", features = ["derive"] }
|
||||
serde_json = "1.0.91"
|
||||
thiserror = "1.0.38"
|
||||
tracing = "0.1.37"
|
||||
futures-util = "0.3.25"
|
||||
|
||||
rand = "0.8.5"
|
||||
rand_chacha = "0.3.1"
|
||||
url = { version = "2.3.1", features = ["serde"] }
|
||||
uuid = "1.2.2"
|
||||
ulid = { version = "1.0.0", features = ["uuid", "serde"] }
|
||||
|
||||
oauth2-types = { path = "../oauth2-types" }
|
||||
mas-storage = { path = "../storage" }
|
||||
mas-data-model = { path = "../data-model" }
|
||||
mas-iana = { path = "../iana" }
|
||||
mas-jose = { path = "../jose" }
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
2441
crates/storage-pg/sqlx-data.json
Normal file
2441
crates/storage-pg/sqlx-data.json
Normal file
File diff suppressed because it is too large
Load Diff
220
crates/storage-pg/src/compat/access_token.rs
Normal file
220
crates/storage-pg/src/compat/access_token.rs
Normal file
@@ -0,0 +1,220 @@
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use mas_data_model::{CompatAccessToken, CompatSession};
|
||||
use mas_storage::{compat::CompatAccessTokenRepository, Clock};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt};
|
||||
|
||||
/// An implementation of [`CompatAccessTokenRepository`] for a PostgreSQL
|
||||
/// connection
|
||||
pub struct PgCompatAccessTokenRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgCompatAccessTokenRepository<'c> {
|
||||
/// Create a new [`PgCompatAccessTokenRepository`] from an active PostgreSQL
|
||||
/// connection
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
struct CompatAccessTokenLookup {
|
||||
compat_access_token_id: Uuid,
|
||||
access_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
expires_at: Option<DateTime<Utc>>,
|
||||
compat_session_id: Uuid,
|
||||
}
|
||||
|
||||
impl From<CompatAccessTokenLookup> for CompatAccessToken {
|
||||
fn from(value: CompatAccessTokenLookup) -> Self {
|
||||
Self {
|
||||
id: value.compat_access_token_id.into(),
|
||||
session_id: value.compat_session_id.into(),
|
||||
token: value.access_token,
|
||||
created_at: value.created_at,
|
||||
expires_at: value.expires_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_access_token.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_session.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatAccessToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatAccessTokenLookup,
|
||||
r#"
|
||||
SELECT compat_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, compat_session_id
|
||||
|
||||
FROM compat_access_tokens
|
||||
|
||||
WHERE compat_access_token_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_access_token.find_by_token",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_token(
|
||||
&mut self,
|
||||
access_token: &str,
|
||||
) -> Result<Option<CompatAccessToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatAccessTokenLookup,
|
||||
r#"
|
||||
SELECT compat_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, compat_session_id
|
||||
|
||||
FROM compat_access_tokens
|
||||
|
||||
WHERE access_token = $1
|
||||
"#,
|
||||
access_token,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_access_token.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_access_token.id,
|
||||
%compat_session.id,
|
||||
user.id = %compat_session.user_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
compat_session: &CompatSession,
|
||||
token: String,
|
||||
expires_after: Option<Duration>,
|
||||
) -> Result<CompatAccessToken, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
|
||||
|
||||
let expires_at = expires_after.map(|expires_after| created_at + expires_after);
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO compat_access_tokens
|
||||
(compat_access_token_id, compat_session_id, access_token, created_at, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(compat_session.id),
|
||||
token,
|
||||
created_at,
|
||||
expires_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(CompatAccessToken {
|
||||
id,
|
||||
session_id: compat_session.id,
|
||||
token,
|
||||
created_at,
|
||||
expires_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_access_token.expire",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%compat_access_token.id,
|
||||
compat_session.id = %compat_access_token.session_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn expire(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
mut compat_access_token: CompatAccessToken,
|
||||
) -> Result<CompatAccessToken, Self::Error> {
|
||||
let expires_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_access_tokens
|
||||
SET expires_at = $2
|
||||
WHERE compat_access_token_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_access_token.id),
|
||||
expires_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
compat_access_token.expires_at = Some(expires_at);
|
||||
Ok(compat_access_token)
|
||||
}
|
||||
}
|
||||
449
crates/storage-pg/src/compat/mod.rs
Normal file
449
crates/storage-pg/src/compat/mod.rs
Normal file
@@ -0,0 +1,449 @@
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! A module containing PostgreSQL implementation of repositories for the
|
||||
//! compatibility layer
|
||||
|
||||
mod access_token;
|
||||
mod refresh_token;
|
||||
mod session;
|
||||
mod sso_login;
|
||||
|
||||
pub use self::{
|
||||
access_token::PgCompatAccessTokenRepository, refresh_token::PgCompatRefreshTokenRepository,
|
||||
session::PgCompatSessionRepository, sso_login::PgCompatSsoLoginRepository,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use chrono::Duration;
|
||||
use mas_data_model::Device;
|
||||
use mas_storage::{
|
||||
clock::MockClock,
|
||||
compat::{
|
||||
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
|
||||
},
|
||||
user::UserRepository,
|
||||
Clock, Pagination, Repository, RepositoryAccess,
|
||||
};
|
||||
use rand::SeedableRng;
|
||||
use rand_chacha::ChaChaRng;
|
||||
use sqlx::PgPool;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::PgRepository;
|
||||
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_session_repository(pool: PgPool) {
|
||||
const FIRST_TOKEN: &str = "first_access_token";
|
||||
const SECOND_TOKEN: &str = "second_access_token";
|
||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||
let clock = MockClock::default();
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||
|
||||
// Create a user
|
||||
let user = repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, "john".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Start a compat session for that user
|
||||
let device = Device::generate(&mut rng);
|
||||
let device_str = device.as_str().to_owned();
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.add(&mut rng, &clock, &user, device)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(session.user_id, user.id);
|
||||
assert_eq!(session.device.as_str(), device_str);
|
||||
assert!(session.is_valid());
|
||||
assert!(!session.is_finished());
|
||||
|
||||
// Lookup the session and check it didn't change
|
||||
let session_lookup = repo
|
||||
.compat_session()
|
||||
.lookup(session.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("compat session not found");
|
||||
assert_eq!(session_lookup.id, session.id);
|
||||
assert_eq!(session_lookup.user_id, user.id);
|
||||
assert_eq!(session_lookup.device.as_str(), device_str);
|
||||
assert!(session_lookup.is_valid());
|
||||
assert!(!session_lookup.is_finished());
|
||||
|
||||
// Finish the session
|
||||
let session = repo.compat_session().finish(&clock, session).await.unwrap();
|
||||
assert!(!session.is_valid());
|
||||
assert!(session.is_finished());
|
||||
|
||||
// Reload the session and check again
|
||||
let session_lookup = repo
|
||||
.compat_session()
|
||||
.lookup(session.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("compat session not found");
|
||||
assert!(!session_lookup.is_valid());
|
||||
assert!(session_lookup.is_finished());
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_access_token_repository(pool: PgPool) {
|
||||
const FIRST_TOKEN: &str = "first_access_token";
|
||||
const SECOND_TOKEN: &str = "second_access_token";
|
||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||
let clock = MockClock::default();
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
|
||||
|
||||
// Create a user
|
||||
let user = repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, "john".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Start a compat session for that user
|
||||
let device = Device::generate(&mut rng);
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.add(&mut rng, &clock, &user, device)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Add an access token to that session
|
||||
let token = repo
|
||||
.compat_access_token()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
FIRST_TOKEN.to_owned(),
|
||||
Some(Duration::minutes(1)),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(token.session_id, session.id);
|
||||
assert_eq!(token.token, FIRST_TOKEN);
|
||||
|
||||
// Commit the txn and grab a new transaction, to test a conflict
|
||||
repo.save().await.unwrap();
|
||||
|
||||
{
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
|
||||
// Adding the same token a second time should conflict
|
||||
assert!(repo
|
||||
.compat_access_token()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
FIRST_TOKEN.to_owned(),
|
||||
Some(Duration::minutes(1)),
|
||||
)
|
||||
.await
|
||||
.is_err());
|
||||
repo.cancel().await.unwrap();
|
||||
}
|
||||
|
||||
// Grab a new repo
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
|
||||
|
||||
// Looking up via ID works
|
||||
let token_lookup = repo
|
||||
.compat_access_token()
|
||||
.lookup(token.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("compat access token not found");
|
||||
assert_eq!(token.id, token_lookup.id);
|
||||
assert_eq!(token_lookup.session_id, session.id);
|
||||
|
||||
// Looking up via the token value works
|
||||
let token_lookup = repo
|
||||
.compat_access_token()
|
||||
.find_by_token(FIRST_TOKEN)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("compat access token not found");
|
||||
assert_eq!(token.id, token_lookup.id);
|
||||
assert_eq!(token_lookup.session_id, session.id);
|
||||
|
||||
// Token is currently valid
|
||||
assert!(token.is_valid(clock.now()));
|
||||
|
||||
clock.advance(Duration::minutes(1));
|
||||
// Token should have expired
|
||||
assert!(!token.is_valid(clock.now()));
|
||||
|
||||
// Add a second access token, this time without expiration
|
||||
let token = repo
|
||||
.compat_access_token()
|
||||
.add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(token.session_id, session.id);
|
||||
assert_eq!(token.token, SECOND_TOKEN);
|
||||
|
||||
// Token is currently valid
|
||||
assert!(token.is_valid(clock.now()));
|
||||
|
||||
// Make it expire
|
||||
repo.compat_access_token()
|
||||
.expire(&clock, token)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Reload it
|
||||
let token = repo
|
||||
.compat_access_token()
|
||||
.find_by_token(SECOND_TOKEN)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("compat access token not found");
|
||||
|
||||
// Token is not valid anymore
|
||||
assert!(!token.is_valid(clock.now()));
|
||||
|
||||
repo.save().await.unwrap();
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_refresh_token_repository(pool: PgPool) {
|
||||
const ACCESS_TOKEN: &str = "access_token";
|
||||
const REFRESH_TOKEN: &str = "refresh_token";
|
||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||
let clock = MockClock::default();
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
|
||||
|
||||
// Create a user
|
||||
let user = repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, "john".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Start a compat session for that user
|
||||
let device = Device::generate(&mut rng);
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.add(&mut rng, &clock, &user, device)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Add an access token to that session
|
||||
let access_token = repo
|
||||
.compat_access_token()
|
||||
.add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let refresh_token = repo
|
||||
.compat_refresh_token()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
&access_token,
|
||||
REFRESH_TOKEN.to_owned(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(refresh_token.session_id, session.id);
|
||||
assert_eq!(refresh_token.access_token_id, access_token.id);
|
||||
assert_eq!(refresh_token.token, REFRESH_TOKEN);
|
||||
assert!(refresh_token.is_valid());
|
||||
assert!(!refresh_token.is_consumed());
|
||||
|
||||
// Look it up by ID and check everything matches
|
||||
let refresh_token_lookup = repo
|
||||
.compat_refresh_token()
|
||||
.lookup(refresh_token.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("refresh token not found");
|
||||
assert_eq!(refresh_token_lookup.id, refresh_token.id);
|
||||
assert_eq!(refresh_token_lookup.session_id, session.id);
|
||||
assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
|
||||
assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
|
||||
assert!(refresh_token_lookup.is_valid());
|
||||
assert!(!refresh_token_lookup.is_consumed());
|
||||
|
||||
// Look it up by token and check everything matches
|
||||
let refresh_token_lookup = repo
|
||||
.compat_refresh_token()
|
||||
.find_by_token(REFRESH_TOKEN)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("refresh token not found");
|
||||
assert_eq!(refresh_token_lookup.id, refresh_token.id);
|
||||
assert_eq!(refresh_token_lookup.session_id, session.id);
|
||||
assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
|
||||
assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
|
||||
assert!(refresh_token_lookup.is_valid());
|
||||
assert!(!refresh_token_lookup.is_consumed());
|
||||
|
||||
// Consume it
|
||||
let refresh_token = repo
|
||||
.compat_refresh_token()
|
||||
.consume(&clock, refresh_token)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!refresh_token.is_valid());
|
||||
assert!(refresh_token.is_consumed());
|
||||
|
||||
// Reload it and check again
|
||||
let refresh_token_lookup = repo
|
||||
.compat_refresh_token()
|
||||
.find_by_token(REFRESH_TOKEN)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("refresh token not found");
|
||||
assert!(!refresh_token_lookup.is_valid());
|
||||
assert!(refresh_token_lookup.is_consumed());
|
||||
|
||||
// Consuming it again should not work
|
||||
assert!(repo
|
||||
.compat_refresh_token()
|
||||
.consume(&clock, refresh_token)
|
||||
.await
|
||||
.is_err());
|
||||
|
||||
repo.save().await.unwrap();
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_compat_sso_login_repository(pool: PgPool) {
|
||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||
let clock = MockClock::default();
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
|
||||
|
||||
// Create a user
|
||||
let user = repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, "john".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Lookup an unknown SSO login
|
||||
let login = repo.compat_sso_login().lookup(Ulid::nil()).await.unwrap();
|
||||
assert_eq!(login, None);
|
||||
|
||||
// Lookup an unknown login token
|
||||
let login = repo
|
||||
.compat_sso_login()
|
||||
.find_by_token("login-token")
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(login, None);
|
||||
|
||||
// Start a new SSO login
|
||||
let login = repo
|
||||
.compat_sso_login()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
"login-token".to_owned(),
|
||||
"https://example.com/callback".parse().unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(login.is_pending());
|
||||
|
||||
// Lookup the login by ID
|
||||
let login_lookup = repo
|
||||
.compat_sso_login()
|
||||
.lookup(login.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("login not found");
|
||||
assert_eq!(login_lookup, login);
|
||||
|
||||
// Find the login by token
|
||||
let login_lookup = repo
|
||||
.compat_sso_login()
|
||||
.find_by_token("login-token")
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("login not found");
|
||||
assert_eq!(login_lookup, login);
|
||||
|
||||
// Exchanging before fulfilling should not work
|
||||
// Note: It should also not poison the SQL transaction
|
||||
let res = repo
|
||||
.compat_sso_login()
|
||||
.exchange(&clock, login.clone())
|
||||
.await;
|
||||
assert!(res.is_err());
|
||||
|
||||
// Start a compat session for that user
|
||||
let device = Device::generate(&mut rng);
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.add(&mut rng, &clock, &user, device)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Associate the login with the session
|
||||
let login = repo
|
||||
.compat_sso_login()
|
||||
.fulfill(&clock, login, &session)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(login.is_fulfilled());
|
||||
|
||||
// Fulfilling again should not work
|
||||
// Note: It should also not poison the SQL transaction
|
||||
let res = repo
|
||||
.compat_sso_login()
|
||||
.fulfill(&clock, login.clone(), &session)
|
||||
.await;
|
||||
assert!(res.is_err());
|
||||
|
||||
// Exchange that login
|
||||
let login = repo
|
||||
.compat_sso_login()
|
||||
.exchange(&clock, login)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(login.is_exchanged());
|
||||
|
||||
// Exchange again should not work
|
||||
// Note: It should also not poison the SQL transaction
|
||||
let res = repo
|
||||
.compat_sso_login()
|
||||
.exchange(&clock, login.clone())
|
||||
.await;
|
||||
assert!(res.is_err());
|
||||
|
||||
// Fulfilling after exchanging should not work
|
||||
// Note: It should also not poison the SQL transaction
|
||||
let res = repo
|
||||
.compat_sso_login()
|
||||
.fulfill(&clock, login.clone(), &session)
|
||||
.await;
|
||||
assert!(res.is_err());
|
||||
|
||||
// List the logins for the user
|
||||
let logins = repo
|
||||
.compat_sso_login()
|
||||
.list_paginated(&user, Pagination::first(10))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!logins.has_next_page);
|
||||
assert_eq!(logins.edges, vec![login]);
|
||||
}
|
||||
}
|
||||
234
crates/storage-pg/src/compat/refresh_token.rs
Normal file
234
crates/storage-pg/src/compat/refresh_token.rs
Normal file
@@ -0,0 +1,234 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{
|
||||
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
|
||||
};
|
||||
use mas_storage::{compat::CompatRefreshTokenRepository, Clock};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt};
|
||||
|
||||
/// An implementation of [`CompatRefreshTokenRepository`] for a PostgreSQL
|
||||
/// connection
|
||||
pub struct PgCompatRefreshTokenRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgCompatRefreshTokenRepository<'c> {
|
||||
/// Create a new [`PgCompatRefreshTokenRepository`] from an active
|
||||
/// PostgreSQL connection
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
struct CompatRefreshTokenLookup {
|
||||
compat_refresh_token_id: Uuid,
|
||||
refresh_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
consumed_at: Option<DateTime<Utc>>,
|
||||
compat_access_token_id: Uuid,
|
||||
compat_session_id: Uuid,
|
||||
}
|
||||
|
||||
impl From<CompatRefreshTokenLookup> for CompatRefreshToken {
|
||||
fn from(value: CompatRefreshTokenLookup) -> Self {
|
||||
let state = match value.consumed_at {
|
||||
Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at },
|
||||
None => CompatRefreshTokenState::Valid,
|
||||
};
|
||||
|
||||
Self {
|
||||
id: value.compat_refresh_token_id.into(),
|
||||
state,
|
||||
session_id: value.compat_session_id.into(),
|
||||
token: value.refresh_token,
|
||||
created_at: value.created_at,
|
||||
access_token_id: value.compat_access_token_id.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_refresh_token.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_refresh_token.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatRefreshToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatRefreshTokenLookup,
|
||||
r#"
|
||||
SELECT compat_refresh_token_id
|
||||
, refresh_token
|
||||
, created_at
|
||||
, consumed_at
|
||||
, compat_session_id
|
||||
, compat_access_token_id
|
||||
|
||||
FROM compat_refresh_tokens
|
||||
|
||||
WHERE compat_refresh_token_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_refresh_token.find_by_token",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_token(
|
||||
&mut self,
|
||||
refresh_token: &str,
|
||||
) -> Result<Option<CompatRefreshToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatRefreshTokenLookup,
|
||||
r#"
|
||||
SELECT compat_refresh_token_id
|
||||
, refresh_token
|
||||
, created_at
|
||||
, consumed_at
|
||||
, compat_session_id
|
||||
, compat_access_token_id
|
||||
|
||||
FROM compat_refresh_tokens
|
||||
|
||||
WHERE refresh_token = $1
|
||||
"#,
|
||||
refresh_token,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_refresh_token.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_refresh_token.id,
|
||||
%compat_session.id,
|
||||
user.id = %compat_session.user_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
compat_session: &CompatSession,
|
||||
compat_access_token: &CompatAccessToken,
|
||||
token: String,
|
||||
) -> Result<CompatRefreshToken, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO compat_refresh_tokens
|
||||
(compat_refresh_token_id, compat_session_id,
|
||||
compat_access_token_id, refresh_token, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(compat_session.id),
|
||||
Uuid::from(compat_access_token.id),
|
||||
token,
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(CompatRefreshToken {
|
||||
id,
|
||||
state: CompatRefreshTokenState::default(),
|
||||
session_id: compat_session.id,
|
||||
access_token_id: compat_access_token.id,
|
||||
token,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_refresh_token.consume",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%compat_refresh_token.id,
|
||||
compat_session.id = %compat_refresh_token.session_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn consume(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
compat_refresh_token: CompatRefreshToken,
|
||||
) -> Result<CompatRefreshToken, Self::Error> {
|
||||
let consumed_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_refresh_tokens
|
||||
SET consumed_at = $2
|
||||
WHERE compat_refresh_token_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_refresh_token.id),
|
||||
consumed_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
let compat_refresh_token = compat_refresh_token
|
||||
.consume(consumed_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(compat_refresh_token)
|
||||
}
|
||||
}
|
||||
198
crates/storage-pg/src/compat/session.rs
Normal file
198
crates/storage-pg/src/compat/session.rs
Normal file
@@ -0,0 +1,198 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{CompatSession, CompatSessionState, Device, User};
|
||||
use mas_storage::{compat::CompatSessionRepository, Clock};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
|
||||
|
||||
/// An implementation of [`CompatSessionRepository`] for a PostgreSQL connection
|
||||
pub struct PgCompatSessionRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgCompatSessionRepository<'c> {
|
||||
/// Create a new [`PgCompatSessionRepository`] from an active PostgreSQL
|
||||
/// connection
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
struct CompatSessionLookup {
|
||||
compat_session_id: Uuid,
|
||||
device_id: String,
|
||||
user_id: Uuid,
|
||||
created_at: DateTime<Utc>,
|
||||
finished_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl TryFrom<CompatSessionLookup> for CompatSession {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
fn try_from(value: CompatSessionLookup) -> Result<Self, Self::Error> {
|
||||
let id = value.compat_session_id.into();
|
||||
let device = Device::try_from(value.device_id).map_err(|e| {
|
||||
DatabaseInconsistencyError::on("compat_sessions")
|
||||
.column("device_id")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let state = match value.finished_at {
|
||||
None => CompatSessionState::Valid,
|
||||
Some(finished_at) => CompatSessionState::Finished { finished_at },
|
||||
};
|
||||
|
||||
let session = CompatSession {
|
||||
id,
|
||||
state,
|
||||
user_id: value.user_id.into(),
|
||||
device,
|
||||
created_at: value.created_at,
|
||||
};
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_session.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_session.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatSessionLookup,
|
||||
r#"
|
||||
SELECT compat_session_id
|
||||
, device_id
|
||||
, user_id
|
||||
, created_at
|
||||
, finished_at
|
||||
FROM compat_sessions
|
||||
WHERE compat_session_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_session.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_session.id,
|
||||
%user.id,
|
||||
%user.username,
|
||||
compat_session.device.id = device.as_str(),
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
user: &User,
|
||||
device: Device,
|
||||
) -> Result<CompatSession, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("compat_session.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(user.id),
|
||||
device.as_str(),
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(CompatSession {
|
||||
id,
|
||||
state: CompatSessionState::default(),
|
||||
user_id: user.id,
|
||||
device,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_session.finish",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%compat_session.id,
|
||||
user.id = %compat_session.user_id,
|
||||
compat_session.device.id = compat_session.device.as_str(),
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn finish(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
compat_session: CompatSession,
|
||||
) -> Result<CompatSession, Self::Error> {
|
||||
let finished_at = clock.now();
|
||||
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_sessions cs
|
||||
SET finished_at = $2
|
||||
WHERE compat_session_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_session.id),
|
||||
finished_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
let compat_session = compat_session
|
||||
.finish(finished_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(compat_session)
|
||||
}
|
||||
}
|
||||
346
crates/storage-pg/src/compat/sso_login.rs
Normal file
346
crates/storage-pg/src/compat/sso_login.rs
Normal file
@@ -0,0 +1,346 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState, User};
|
||||
use mas_storage::{compat::CompatSsoLoginRepository, Clock, Page, Pagination};
|
||||
use rand::RngCore;
|
||||
use sqlx::{PgConnection, QueryBuilder};
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError,
|
||||
LookupResultExt,
|
||||
};
|
||||
|
||||
/// An implementation of [`CompatSsoLoginRepository`] for a PostgreSQL
|
||||
/// connection
|
||||
pub struct PgCompatSsoLoginRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgCompatSsoLoginRepository<'c> {
|
||||
/// Create a new [`PgCompatSsoLoginRepository`] from an active PostgreSQL
|
||||
/// connection
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct CompatSsoLoginLookup {
|
||||
compat_sso_login_id: Uuid,
|
||||
login_token: String,
|
||||
redirect_uri: String,
|
||||
created_at: DateTime<Utc>,
|
||||
fulfilled_at: Option<DateTime<Utc>>,
|
||||
exchanged_at: Option<DateTime<Utc>>,
|
||||
compat_session_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
|
||||
let id = res.compat_sso_login_id.into();
|
||||
let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| {
|
||||
DatabaseInconsistencyError::on("compat_sso_logins")
|
||||
.column("redirect_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) {
|
||||
(None, None, None) => CompatSsoLoginState::Pending,
|
||||
(Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled {
|
||||
fulfilled_at,
|
||||
session_id: session_id.into(),
|
||||
},
|
||||
(Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => {
|
||||
CompatSsoLoginState::Exchanged {
|
||||
fulfilled_at,
|
||||
exchanged_at,
|
||||
session_id: session_id.into(),
|
||||
}
|
||||
}
|
||||
_ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
|
||||
};
|
||||
|
||||
Ok(CompatSsoLogin {
|
||||
id,
|
||||
login_token: res.login_token,
|
||||
redirect_uri,
|
||||
created_at: res.created_at,
|
||||
state,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_sso_login.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSsoLogin>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatSsoLoginLookup,
|
||||
r#"
|
||||
SELECT compat_sso_login_id
|
||||
, login_token
|
||||
, redirect_uri
|
||||
, created_at
|
||||
, fulfilled_at
|
||||
, exchanged_at
|
||||
, compat_session_id
|
||||
|
||||
FROM compat_sso_logins
|
||||
WHERE compat_sso_login_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.find_by_token",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_token(
|
||||
&mut self,
|
||||
login_token: &str,
|
||||
) -> Result<Option<CompatSsoLogin>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatSsoLoginLookup,
|
||||
r#"
|
||||
SELECT compat_sso_login_id
|
||||
, login_token
|
||||
, redirect_uri
|
||||
, created_at
|
||||
, fulfilled_at
|
||||
, exchanged_at
|
||||
, compat_session_id
|
||||
|
||||
FROM compat_sso_logins
|
||||
WHERE login_token = $1
|
||||
"#,
|
||||
login_token,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_sso_login.id,
|
||||
compat_sso_login.redirect_uri = %redirect_uri,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
login_token: String,
|
||||
redirect_uri: Url,
|
||||
) -> Result<CompatSsoLogin, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO compat_sso_logins
|
||||
(compat_sso_login_id, login_token, redirect_uri, created_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
&login_token,
|
||||
redirect_uri.as_str(),
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(CompatSsoLogin {
|
||||
id,
|
||||
login_token,
|
||||
redirect_uri,
|
||||
created_at,
|
||||
state: CompatSsoLoginState::default(),
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.fulfill",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%compat_sso_login.id,
|
||||
%compat_session.id,
|
||||
compat_session.device.id = compat_session.device.as_str(),
|
||||
user.id = %compat_session.user_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn fulfill(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
compat_sso_login: CompatSsoLogin,
|
||||
compat_session: &CompatSession,
|
||||
) -> Result<CompatSsoLogin, Self::Error> {
|
||||
let fulfilled_at = clock.now();
|
||||
let compat_sso_login = compat_sso_login
|
||||
.fulfill(fulfilled_at, compat_session)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_sso_logins
|
||||
SET
|
||||
compat_session_id = $2,
|
||||
fulfilled_at = $3
|
||||
WHERE
|
||||
compat_sso_login_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_sso_login.id),
|
||||
Uuid::from(compat_session.id),
|
||||
fulfilled_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
Ok(compat_sso_login)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.exchange",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%compat_sso_login.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn exchange(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
compat_sso_login: CompatSsoLogin,
|
||||
) -> Result<CompatSsoLogin, Self::Error> {
|
||||
let exchanged_at = clock.now();
|
||||
let compat_sso_login = compat_sso_login
|
||||
.exchange(exchanged_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_sso_logins
|
||||
SET
|
||||
exchanged_at = $2
|
||||
WHERE
|
||||
compat_sso_login_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_sso_login.id),
|
||||
exchanged_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
Ok(compat_sso_login)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.list_paginated",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
%user.username,
|
||||
),
|
||||
err
|
||||
)]
|
||||
async fn list_paginated(
|
||||
&mut self,
|
||||
user: &User,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<CompatSsoLogin>, Self::Error> {
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT cl.compat_sso_login_id
|
||||
, cl.login_token
|
||||
, cl.redirect_uri
|
||||
, cl.created_at
|
||||
, cl.fulfilled_at
|
||||
, cl.exchanged_at
|
||||
, cl.compat_session_id
|
||||
|
||||
FROM compat_sso_logins cl
|
||||
INNER JOIN compat_sessions cs USING (compat_session_id)
|
||||
"#,
|
||||
);
|
||||
|
||||
query
|
||||
.push(" WHERE cs.user_id = ")
|
||||
.push_bind(Uuid::from(user.id))
|
||||
.generate_pagination("cl.compat_sso_login_id", pagination);
|
||||
|
||||
let edges: Vec<CompatSsoLoginLookup> = query
|
||||
.build_query_as()
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let page = pagination
|
||||
.process(edges)
|
||||
.try_map(CompatSsoLogin::try_from)?;
|
||||
Ok(page)
|
||||
}
|
||||
}
|
||||
144
crates/storage-pg/src/errors.rs
Normal file
144
crates/storage-pg/src/errors.rs
Normal file
@@ -0,0 +1,144 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use sqlx::postgres::PgQueryResult;
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
/// Generic error when interacting with the database
|
||||
#[derive(Debug, Error)]
|
||||
#[error(transparent)]
|
||||
pub enum DatabaseError {
|
||||
/// An error which came from the database itself
|
||||
Driver {
|
||||
/// The underlying error from the database driver
|
||||
#[from]
|
||||
source: sqlx::Error,
|
||||
},
|
||||
|
||||
/// An error which occured while converting the data from the database
|
||||
Inconsistency(#[from] DatabaseInconsistencyError),
|
||||
|
||||
/// An error which happened because the requested database operation is
|
||||
/// invalid
|
||||
#[error("Invalid database operation")]
|
||||
InvalidOperation {
|
||||
/// The source of the error, if any
|
||||
#[source]
|
||||
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
},
|
||||
|
||||
/// An error which happens when an operation affects not enough or too many
|
||||
/// rows
|
||||
#[error("Expected {expected} rows to be affected, but {actual} rows were affected")]
|
||||
RowsAffected {
|
||||
/// How many rows were expected to be affected
|
||||
expected: u64,
|
||||
|
||||
/// How many rows were actually affected
|
||||
actual: u64,
|
||||
},
|
||||
}
|
||||
|
||||
impl DatabaseError {
|
||||
pub(crate) fn ensure_affected_rows(
|
||||
result: &PgQueryResult,
|
||||
expected: u64,
|
||||
) -> Result<(), DatabaseError> {
|
||||
let actual = result.rows_affected();
|
||||
if actual == expected {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(DatabaseError::RowsAffected { expected, actual })
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn to_invalid_operation<E: std::error::Error + Send + Sync + 'static>(e: E) -> Self {
|
||||
Self::InvalidOperation {
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) const fn invalid_operation() -> Self {
|
||||
Self::InvalidOperation { source: None }
|
||||
}
|
||||
}
|
||||
|
||||
/// An error which occured while converting the data from the database
|
||||
#[derive(Debug, Error)]
|
||||
pub struct DatabaseInconsistencyError {
|
||||
/// The table which was being queried
|
||||
table: &'static str,
|
||||
|
||||
/// The column which was being queried
|
||||
column: Option<&'static str>,
|
||||
|
||||
/// The row which was being queried
|
||||
row: Option<Ulid>,
|
||||
|
||||
/// The source of the error
|
||||
#[source]
|
||||
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DatabaseInconsistencyError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Database inconsistency on table {}", self.table)?;
|
||||
if let Some(column) = self.column {
|
||||
write!(f, " column {column}")?;
|
||||
}
|
||||
if let Some(row) = self.row {
|
||||
write!(f, " row {row}")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl DatabaseInconsistencyError {
|
||||
/// Create a new [`DatabaseInconsistencyError`] for the given table
|
||||
#[must_use]
|
||||
pub(crate) const fn on(table: &'static str) -> Self {
|
||||
Self {
|
||||
table,
|
||||
column: None,
|
||||
row: None,
|
||||
source: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the column which was being queried
|
||||
#[must_use]
|
||||
pub(crate) const fn column(mut self, column: &'static str) -> Self {
|
||||
self.column = Some(column);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the row which was being queried
|
||||
#[must_use]
|
||||
pub(crate) const fn row(mut self, row: Ulid) -> Self {
|
||||
self.row = Some(row);
|
||||
self
|
||||
}
|
||||
|
||||
/// Give the source of the error
|
||||
#[must_use]
|
||||
pub(crate) fn source<E: std::error::Error + Send + Sync + 'static>(
|
||||
mut self,
|
||||
source: E,
|
||||
) -> Self {
|
||||
self.source = Some(Box::new(source));
|
||||
self
|
||||
}
|
||||
}
|
||||
225
crates/storage-pg/src/lib.rs
Normal file
225
crates/storage-pg/src/lib.rs
Normal file
@@ -0,0 +1,225 @@
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! An implementation of the storage traits for a PostgreSQL database
|
||||
//!
|
||||
//! This backend uses [`sqlx`] to interact with the database. Most queries are
|
||||
//! type-checked, using introspection data recorded in the `sqlx-data.json`
|
||||
//! file. This file is generated by the `sqlx` CLI tool, and should be updated
|
||||
//! whenever the database schema changes, or new queries are added.
|
||||
//!
|
||||
//! # Implementing a new repository
|
||||
//!
|
||||
//! When a new repository is defined in [`mas_storage`], it should be
|
||||
//! implemented here, with the PostgreSQL backend.
|
||||
//!
|
||||
//! A typical implementation will look like this:
|
||||
//!
|
||||
//! ```rust
|
||||
//! # use async_trait::async_trait;
|
||||
//! # use ulid::Ulid;
|
||||
//! # use rand::RngCore;
|
||||
//! # use mas_storage::Clock;
|
||||
//! # use mas_storage_pg::{DatabaseError, ExecuteExt, LookupResultExt};
|
||||
//! # use sqlx::PgConnection;
|
||||
//! # use uuid::Uuid;
|
||||
//! #
|
||||
//! # // A fake data structure, usually defined in mas-data-model
|
||||
//! # #[derive(sqlx::FromRow)]
|
||||
//! # struct FakeData {
|
||||
//! # id: Ulid,
|
||||
//! # }
|
||||
//! #
|
||||
//! # // A fake repository trait, usually defined in mas-storage
|
||||
//! # #[async_trait]
|
||||
//! # pub trait FakeDataRepository: Send + Sync {
|
||||
//! # type Error;
|
||||
//! # async fn lookup(&mut self, id: Ulid) -> Result<Option<FakeData>, Self::Error>;
|
||||
//! # async fn add(
|
||||
//! # &mut self,
|
||||
//! # rng: &mut (dyn RngCore + Send),
|
||||
//! # clock: &dyn Clock,
|
||||
//! # ) -> Result<FakeData, Self::Error>;
|
||||
//! # }
|
||||
//! #
|
||||
//! /// An implementation of [`FakeDataRepository`] for a PostgreSQL connection
|
||||
//! pub struct PgFakeDataRepository<'c> {
|
||||
//! conn: &'c mut PgConnection,
|
||||
//! }
|
||||
//!
|
||||
//! impl<'c> PgFakeDataRepository<'c> {
|
||||
//! /// Create a new [`FakeDataRepository`] from an active PostgreSQL connection
|
||||
//! pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
//! Self { conn }
|
||||
//! }
|
||||
//! }
|
||||
//!
|
||||
//! #[derive(sqlx::FromRow)]
|
||||
//! struct FakeDataLookup {
|
||||
//! fake_data_id: Uuid,
|
||||
//! }
|
||||
//!
|
||||
//! impl From<FakeDataLookup> for FakeData {
|
||||
//! fn from(value: FakeDataLookup) -> Self {
|
||||
//! Self {
|
||||
//! id: value.fake_data_id.into(),
|
||||
//! }
|
||||
//! }
|
||||
//! }
|
||||
//!
|
||||
//! #[async_trait]
|
||||
//! impl<'c> FakeDataRepository for PgFakeDataRepository<'c> {
|
||||
//! type Error = DatabaseError;
|
||||
//!
|
||||
//! #[tracing::instrument(
|
||||
//! name = "db.fake_data.lookup",
|
||||
//! skip_all,
|
||||
//! fields(
|
||||
//! db.statement,
|
||||
//! fake_data.id = %id,
|
||||
//! ),
|
||||
//! err,
|
||||
//! )]
|
||||
//! async fn lookup(&mut self, id: Ulid) -> Result<Option<FakeData>, Self::Error> {
|
||||
//! // Note: here we would use the macro version instead, but it's not possible here in
|
||||
//! // this documentation example
|
||||
//! let res: Option<FakeDataLookup> = sqlx::query_as(
|
||||
//! r#"
|
||||
//! SELECT fake_data_id
|
||||
//! FROM fake_data
|
||||
//! WHERE fake_data_id = $1
|
||||
//! "#,
|
||||
//! )
|
||||
//! .bind(Uuid::from(id))
|
||||
//! .traced()
|
||||
//! .fetch_one(&mut *self.conn)
|
||||
//! .await
|
||||
//! .to_option()?;
|
||||
//!
|
||||
//! let Some(res) = res else { return Ok(None) };
|
||||
//!
|
||||
//! Ok(Some(res.into()))
|
||||
//! }
|
||||
//!
|
||||
//! #[tracing::instrument(
|
||||
//! name = "db.fake_data.add",
|
||||
//! skip_all,
|
||||
//! fields(
|
||||
//! db.statement,
|
||||
//! fake_data.id,
|
||||
//! ),
|
||||
//! err,
|
||||
//! )]
|
||||
//! async fn add(
|
||||
//! &mut self,
|
||||
//! rng: &mut (dyn RngCore + Send),
|
||||
//! clock: &dyn Clock,
|
||||
//! ) -> Result<FakeData, Self::Error> {
|
||||
//! let created_at = clock.now();
|
||||
//! let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
//! tracing::Span::current().record("fake_data.id", tracing::field::display(id));
|
||||
//!
|
||||
//! // Note: here we would use the macro version instead, but it's not possible here in
|
||||
//! // this documentation example
|
||||
//! sqlx::query(
|
||||
//! r#"
|
||||
//! INSERT INTO fake_data (id)
|
||||
//! VALUES ($1)
|
||||
//! "#,
|
||||
//! )
|
||||
//! .bind(Uuid::from(id))
|
||||
//! .traced()
|
||||
//! .execute(&mut *self.conn)
|
||||
//! .await?;
|
||||
//!
|
||||
//! Ok(FakeData {
|
||||
//! id,
|
||||
//! })
|
||||
//! }
|
||||
//! }
|
||||
//! ```
|
||||
//!
|
||||
//! A few things to note with the implementation:
|
||||
//!
|
||||
//! - All methods are traced, with an explicit, somewhat consistent name.
|
||||
//! - The SQL statement is included as attribute, by declaring a `db.statement`
|
||||
//! attribute on the tracing span, and then calling [`ExecuteExt::traced`].
|
||||
//! - The IDs are all [`Ulid`], and generated from the clock and the random
|
||||
//! number generated passed as parameters. The generated IDs are recorded in
|
||||
//! the span.
|
||||
//! - The IDs are stored as [`Uuid`] in PostgreSQL, so conversions are required
|
||||
//! - "Not found" errors are handled by returning `Ok(None)` instead of an
|
||||
//! error. The [`LookupResultExt::to_option`] method helps to do that.
|
||||
//!
|
||||
//! [`Ulid`]: ulid::Ulid
|
||||
//! [`Uuid`]: uuid::Uuid
|
||||
|
||||
#![forbid(unsafe_code)]
|
||||
#![deny(
|
||||
clippy::all,
|
||||
clippy::str_to_string,
|
||||
clippy::future_not_send,
|
||||
rustdoc::broken_intra_doc_links,
|
||||
missing_docs
|
||||
)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
|
||||
use sqlx::migrate::Migrator;
|
||||
|
||||
/// An extension trait for [`Result`] which adds a [`to_option`] method, useful
|
||||
/// for handling "not found" errors from [`sqlx`]
|
||||
///
|
||||
/// [`to_option`]: LookupResultExt::to_option
|
||||
pub trait LookupResultExt {
|
||||
/// The output type
|
||||
type Output;
|
||||
|
||||
/// Transform a [`Result`] from a sqlx query to transform "not found" errors
|
||||
/// into [`None`]
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns the original error if the error was not a
|
||||
/// [`sqlx::Error::RowNotFound`] error
|
||||
fn to_option(self) -> Result<Option<Self::Output>, sqlx::Error>;
|
||||
}
|
||||
|
||||
impl<T> LookupResultExt for Result<T, sqlx::Error> {
|
||||
type Output = T;
|
||||
|
||||
fn to_option(self) -> Result<Option<Self::Output>, sqlx::Error> {
|
||||
match self {
|
||||
Ok(v) => Ok(Some(v)),
|
||||
Err(sqlx::Error::RowNotFound) => Ok(None),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub mod compat;
|
||||
pub mod oauth2;
|
||||
pub mod upstream_oauth2;
|
||||
pub mod user;
|
||||
|
||||
mod errors;
|
||||
pub(crate) mod pagination;
|
||||
pub(crate) mod repository;
|
||||
pub(crate) mod tracing;
|
||||
|
||||
pub(crate) use self::errors::DatabaseInconsistencyError;
|
||||
pub use self::{errors::DatabaseError, repository::PgRepository, tracing::ExecuteExt};
|
||||
|
||||
/// Embedded migrations, allowing them to run on startup
|
||||
pub static MIGRATOR: Migrator = sqlx::migrate!();
|
||||
227
crates/storage-pg/src/oauth2/access_token.rs
Normal file
227
crates/storage-pg/src/oauth2/access_token.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use mas_data_model::{AccessToken, AccessTokenState, Session};
|
||||
use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt};
|
||||
|
||||
/// An implementation of [`OAuth2AccessTokenRepository`] for a PostgreSQL
|
||||
/// connection
|
||||
pub struct PgOAuth2AccessTokenRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgOAuth2AccessTokenRepository<'c> {
|
||||
/// Create a new [`PgOAuth2AccessTokenRepository`] from an active PostgreSQL
|
||||
/// connection
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
struct OAuth2AccessTokenLookup {
|
||||
oauth2_access_token_id: Uuid,
|
||||
oauth2_session_id: Uuid,
|
||||
access_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
expires_at: DateTime<Utc>,
|
||||
revoked_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl From<OAuth2AccessTokenLookup> for AccessToken {
|
||||
fn from(value: OAuth2AccessTokenLookup) -> Self {
|
||||
let state = match value.revoked_at {
|
||||
None => AccessTokenState::Valid,
|
||||
Some(revoked_at) => AccessTokenState::Revoked { revoked_at },
|
||||
};
|
||||
|
||||
Self {
|
||||
id: value.oauth2_access_token_id.into(),
|
||||
state,
|
||||
session_id: value.oauth2_session_id.into(),
|
||||
access_token: value.access_token,
|
||||
created_at: value.created_at,
|
||||
expires_at: value.expires_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<AccessToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2AccessTokenLookup,
|
||||
r#"
|
||||
SELECT oauth2_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, revoked_at
|
||||
, oauth2_session_id
|
||||
|
||||
FROM oauth2_access_tokens
|
||||
|
||||
WHERE oauth2_access_token_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_access_token.find_by_token",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_token(
|
||||
&mut self,
|
||||
access_token: &str,
|
||||
) -> Result<Option<AccessToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2AccessTokenLookup,
|
||||
r#"
|
||||
SELECT oauth2_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, revoked_at
|
||||
, oauth2_session_id
|
||||
|
||||
FROM oauth2_access_tokens
|
||||
|
||||
WHERE access_token = $1
|
||||
"#,
|
||||
access_token,
|
||||
)
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_access_token.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%session.id,
|
||||
user_session.id = %session.user_session_id,
|
||||
client.id = %session.client_id,
|
||||
access_token.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
session: &Session,
|
||||
access_token: String,
|
||||
expires_after: Duration,
|
||||
) -> Result<AccessToken, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let expires_at = created_at + expires_after;
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
|
||||
tracing::Span::current().record("access_token.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_access_tokens
|
||||
(oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(session.id),
|
||||
&access_token,
|
||||
created_at,
|
||||
expires_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(AccessToken {
|
||||
id,
|
||||
state: AccessTokenState::default(),
|
||||
access_token,
|
||||
session_id: session.id,
|
||||
created_at,
|
||||
expires_at,
|
||||
})
|
||||
}
|
||||
|
||||
async fn revoke(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
access_token: AccessToken,
|
||||
) -> Result<AccessToken, Self::Error> {
|
||||
let revoked_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_access_tokens
|
||||
SET revoked_at = $2
|
||||
WHERE oauth2_access_token_id = $1
|
||||
"#,
|
||||
Uuid::from(access_token.id),
|
||||
revoked_at,
|
||||
)
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
access_token
|
||||
.revoke(revoked_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)
|
||||
}
|
||||
|
||||
async fn cleanup_expired(&mut self, clock: &dyn Clock) -> Result<usize, Self::Error> {
|
||||
// Cleanup token which expired more than 15 minutes ago
|
||||
let threshold = clock.now() - Duration::minutes(15);
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
DELETE FROM oauth2_access_tokens
|
||||
WHERE expires_at < $1
|
||||
"#,
|
||||
threshold,
|
||||
)
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
|
||||
}
|
||||
}
|
||||
514
crates/storage-pg/src/oauth2/authorization_grant.rs
Normal file
514
crates/storage-pg/src/oauth2/authorization_grant.rs
Normal file
@@ -0,0 +1,514 @@
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::num::NonZeroU32;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{
|
||||
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
|
||||
};
|
||||
use mas_iana::oauth::PkceCodeChallengeMethod;
|
||||
use mas_storage::{oauth2::OAuth2AuthorizationGrantRepository, Clock};
|
||||
use oauth2_types::{requests::ResponseMode, scope::Scope};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
|
||||
|
||||
/// An implementation of [`OAuth2AuthorizationGrantRepository`] for a PostgreSQL
|
||||
/// connection
|
||||
pub struct PgOAuth2AuthorizationGrantRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
|
||||
/// Create a new [`PgOAuth2AuthorizationGrantRepository`] from an active
|
||||
/// PostgreSQL connection
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
struct GrantLookup {
|
||||
oauth2_authorization_grant_id: Uuid,
|
||||
created_at: DateTime<Utc>,
|
||||
cancelled_at: Option<DateTime<Utc>>,
|
||||
fulfilled_at: Option<DateTime<Utc>>,
|
||||
exchanged_at: Option<DateTime<Utc>>,
|
||||
scope: String,
|
||||
state: Option<String>,
|
||||
nonce: Option<String>,
|
||||
redirect_uri: String,
|
||||
response_mode: String,
|
||||
max_age: Option<i32>,
|
||||
response_type_code: bool,
|
||||
response_type_id_token: bool,
|
||||
authorization_code: Option<String>,
|
||||
code_challenge: Option<String>,
|
||||
code_challenge_method: Option<String>,
|
||||
requires_consent: bool,
|
||||
oauth2_client_id: Uuid,
|
||||
oauth2_session_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
impl TryFrom<GrantLookup> for AuthorizationGrant {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
|
||||
let id = value.oauth2_authorization_grant_id.into();
|
||||
let scope: Scope = value.scope.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||
.column("scope")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let stage = match (
|
||||
value.fulfilled_at,
|
||||
value.exchanged_at,
|
||||
value.cancelled_at,
|
||||
value.oauth2_session_id,
|
||||
) {
|
||||
(None, None, None, None) => AuthorizationGrantStage::Pending,
|
||||
(Some(fulfilled_at), None, None, Some(session_id)) => {
|
||||
AuthorizationGrantStage::Fulfilled {
|
||||
session_id: session_id.into(),
|
||||
fulfilled_at,
|
||||
}
|
||||
}
|
||||
(Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
|
||||
AuthorizationGrantStage::Exchanged {
|
||||
session_id: session_id.into(),
|
||||
fulfilled_at,
|
||||
exchanged_at,
|
||||
}
|
||||
}
|
||||
(None, None, Some(cancelled_at), None) => {
|
||||
AuthorizationGrantStage::Cancelled { cancelled_at }
|
||||
}
|
||||
_ => {
|
||||
return Err(
|
||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||
.column("stage")
|
||||
.row(id),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let pkce = match (value.code_challenge, value.code_challenge_method) {
|
||||
(Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
|
||||
Some(Pkce {
|
||||
challenge_method: PkceCodeChallengeMethod::Plain,
|
||||
challenge,
|
||||
})
|
||||
}
|
||||
(Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
|
||||
challenge_method: PkceCodeChallengeMethod::S256,
|
||||
challenge,
|
||||
}),
|
||||
(None, None) => None,
|
||||
_ => {
|
||||
return Err(
|
||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||
.column("code_challenge_method")
|
||||
.row(id),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let code: Option<AuthorizationCode> =
|
||||
match (value.response_type_code, value.authorization_code, pkce) {
|
||||
(false, None, None) => None,
|
||||
(true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
|
||||
_ => {
|
||||
return Err(
|
||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||
.column("authorization_code")
|
||||
.row(id),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let redirect_uri = value.redirect_uri.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||
.column("redirect_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let response_mode = value.response_mode.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||
.column("response_mode")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let max_age = value
|
||||
.max_age
|
||||
.map(u32::try_from)
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||
.column("max_age")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?
|
||||
.map(NonZeroU32::try_from)
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||
.column("max_age")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
Ok(AuthorizationGrant {
|
||||
id,
|
||||
stage,
|
||||
client_id: value.oauth2_client_id.into(),
|
||||
code,
|
||||
scope,
|
||||
state: value.state,
|
||||
nonce: value.nonce,
|
||||
max_age,
|
||||
response_mode,
|
||||
redirect_uri,
|
||||
created_at: value.created_at,
|
||||
response_type_id_token: value.response_type_id_token,
|
||||
requires_consent: value.requires_consent,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_authorization_grant.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
grant.id,
|
||||
grant.scope = %scope,
|
||||
%client.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
client: &Client,
|
||||
redirect_uri: Url,
|
||||
scope: Scope,
|
||||
code: Option<AuthorizationCode>,
|
||||
state: Option<String>,
|
||||
nonce: Option<String>,
|
||||
max_age: Option<NonZeroU32>,
|
||||
response_mode: ResponseMode,
|
||||
response_type_id_token: bool,
|
||||
requires_consent: bool,
|
||||
) -> Result<AuthorizationGrant, Self::Error> {
|
||||
let code_challenge = code
|
||||
.as_ref()
|
||||
.and_then(|c| c.pkce.as_ref())
|
||||
.map(|p| &p.challenge);
|
||||
let code_challenge_method = code
|
||||
.as_ref()
|
||||
.and_then(|c| c.pkce.as_ref())
|
||||
.map(|p| p.challenge_method.to_string());
|
||||
// TODO: this conversion is a bit ugly
|
||||
let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX));
|
||||
let code_str = code.as_ref().map(|c| &c.code);
|
||||
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("grant.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_authorization_grants (
|
||||
oauth2_authorization_grant_id,
|
||||
oauth2_client_id,
|
||||
redirect_uri,
|
||||
scope,
|
||||
state,
|
||||
nonce,
|
||||
max_age,
|
||||
response_mode,
|
||||
code_challenge,
|
||||
code_challenge_method,
|
||||
response_type_code,
|
||||
response_type_id_token,
|
||||
authorization_code,
|
||||
requires_consent,
|
||||
created_at
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(client.id),
|
||||
redirect_uri.to_string(),
|
||||
scope.to_string(),
|
||||
state,
|
||||
nonce,
|
||||
max_age_i32,
|
||||
response_mode.to_string(),
|
||||
code_challenge,
|
||||
code_challenge_method,
|
||||
code.is_some(),
|
||||
response_type_id_token,
|
||||
code_str,
|
||||
requires_consent,
|
||||
created_at,
|
||||
)
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(AuthorizationGrant {
|
||||
id,
|
||||
stage: AuthorizationGrantStage::Pending,
|
||||
code,
|
||||
redirect_uri,
|
||||
client_id: client.id,
|
||||
scope,
|
||||
state,
|
||||
nonce,
|
||||
max_age,
|
||||
response_mode,
|
||||
created_at,
|
||||
response_type_id_token,
|
||||
requires_consent,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_authorization_grant.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
grant.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
GrantLookup,
|
||||
r#"
|
||||
SELECT oauth2_authorization_grant_id
|
||||
, created_at
|
||||
, cancelled_at
|
||||
, fulfilled_at
|
||||
, exchanged_at
|
||||
, scope
|
||||
, state
|
||||
, redirect_uri
|
||||
, response_mode
|
||||
, nonce
|
||||
, max_age
|
||||
, oauth2_client_id
|
||||
, authorization_code
|
||||
, response_type_code
|
||||
, response_type_id_token
|
||||
, code_challenge
|
||||
, code_challenge_method
|
||||
, requires_consent
|
||||
, oauth2_session_id
|
||||
FROM
|
||||
oauth2_authorization_grants
|
||||
|
||||
WHERE oauth2_authorization_grant_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_authorization_grant.find_by_code",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_code(
|
||||
&mut self,
|
||||
code: &str,
|
||||
) -> Result<Option<AuthorizationGrant>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
GrantLookup,
|
||||
r#"
|
||||
SELECT oauth2_authorization_grant_id
|
||||
, created_at
|
||||
, cancelled_at
|
||||
, fulfilled_at
|
||||
, exchanged_at
|
||||
, scope
|
||||
, state
|
||||
, redirect_uri
|
||||
, response_mode
|
||||
, nonce
|
||||
, max_age
|
||||
, oauth2_client_id
|
||||
, authorization_code
|
||||
, response_type_code
|
||||
, response_type_id_token
|
||||
, code_challenge
|
||||
, code_challenge_method
|
||||
, requires_consent
|
||||
, oauth2_session_id
|
||||
FROM
|
||||
oauth2_authorization_grants
|
||||
|
||||
WHERE authorization_code = $1
|
||||
"#,
|
||||
code,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_authorization_grant.fulfill",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%grant.id,
|
||||
client.id = %grant.client_id,
|
||||
%session.id,
|
||||
user_session.id = %session.user_session_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn fulfill(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
session: &Session,
|
||||
grant: AuthorizationGrant,
|
||||
) -> Result<AuthorizationGrant, Self::Error> {
|
||||
let fulfilled_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_authorization_grants
|
||||
SET fulfilled_at = $2
|
||||
, oauth2_session_id = $3
|
||||
WHERE oauth2_authorization_grant_id = $1
|
||||
"#,
|
||||
Uuid::from(grant.id),
|
||||
fulfilled_at,
|
||||
Uuid::from(session.id),
|
||||
)
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
// XXX: check affected rows & new methods
|
||||
let grant = grant
|
||||
.fulfill(fulfilled_at, session)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(grant)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_authorization_grant.exchange",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%grant.id,
|
||||
client.id = %grant.client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn exchange(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
grant: AuthorizationGrant,
|
||||
) -> Result<AuthorizationGrant, Self::Error> {
|
||||
let exchanged_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_authorization_grants
|
||||
SET exchanged_at = $2
|
||||
WHERE oauth2_authorization_grant_id = $1
|
||||
"#,
|
||||
Uuid::from(grant.id),
|
||||
exchanged_at,
|
||||
)
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
let grant = grant
|
||||
.exchange(exchanged_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(grant)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_authorization_grant.give_consent",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%grant.id,
|
||||
client.id = %grant.client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn give_consent(
|
||||
&mut self,
|
||||
mut grant: AuthorizationGrant,
|
||||
) -> Result<AuthorizationGrant, Self::Error> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_authorization_grants AS og
|
||||
SET
|
||||
requires_consent = 'f'
|
||||
WHERE
|
||||
og.oauth2_authorization_grant_id = $1
|
||||
"#,
|
||||
Uuid::from(grant.id),
|
||||
)
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
grant.requires_consent = false;
|
||||
|
||||
Ok(grant)
|
||||
}
|
||||
}
|
||||
748
crates/storage-pg/src/oauth2/client.rs
Normal file
748
crates/storage-pg/src/oauth2/client.rs
Normal file
@@ -0,0 +1,748 @@
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
str::FromStr,
|
||||
string::ToString,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mas_data_model::{Client, JwksOrJwksUri, User};
|
||||
use mas_iana::{
|
||||
jose::JsonWebSignatureAlg,
|
||||
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
|
||||
};
|
||||
use mas_jose::jwk::PublicJsonWebKeySet;
|
||||
use mas_storage::{oauth2::OAuth2ClientRepository, Clock};
|
||||
use oauth2_types::{
|
||||
requests::GrantType,
|
||||
scope::{Scope, ScopeToken},
|
||||
};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use tracing::{info_span, Instrument};
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
|
||||
|
||||
/// An implementation of [`OAuth2ClientRepository`] for a PostgreSQL connection
|
||||
pub struct PgOAuth2ClientRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgOAuth2ClientRepository<'c> {
|
||||
/// Create a new [`PgOAuth2ClientRepository`] from an active PostgreSQL
|
||||
/// connection
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
// XXX: response_types & contacts
|
||||
#[derive(Debug)]
|
||||
struct OAuth2ClientLookup {
|
||||
oauth2_client_id: Uuid,
|
||||
encrypted_client_secret: Option<String>,
|
||||
redirect_uris: Vec<String>,
|
||||
// response_types: Vec<String>,
|
||||
grant_type_authorization_code: bool,
|
||||
grant_type_refresh_token: bool,
|
||||
// contacts: Vec<String>,
|
||||
client_name: Option<String>,
|
||||
logo_uri: Option<String>,
|
||||
client_uri: Option<String>,
|
||||
policy_uri: Option<String>,
|
||||
tos_uri: Option<String>,
|
||||
jwks_uri: Option<String>,
|
||||
jwks: Option<serde_json::Value>,
|
||||
id_token_signed_response_alg: Option<String>,
|
||||
userinfo_signed_response_alg: Option<String>,
|
||||
token_endpoint_auth_method: Option<String>,
|
||||
token_endpoint_auth_signing_alg: Option<String>,
|
||||
initiate_login_uri: Option<String>,
|
||||
}
|
||||
|
||||
impl TryInto<Client> for OAuth2ClientLookup {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
#[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing
|
||||
fn try_into(self) -> Result<Client, Self::Error> {
|
||||
let id = Ulid::from(self.oauth2_client_id);
|
||||
|
||||
let redirect_uris: Result<Vec<Url>, _> =
|
||||
self.redirect_uris.iter().map(|s| s.parse()).collect();
|
||||
let redirect_uris = redirect_uris.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("redirect_uris")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let response_types = vec![
|
||||
OAuthAuthorizationEndpointResponseType::Code,
|
||||
OAuthAuthorizationEndpointResponseType::IdToken,
|
||||
OAuthAuthorizationEndpointResponseType::None,
|
||||
];
|
||||
/* XXX
|
||||
let response_types: Result<Vec<OAuthAuthorizationEndpointResponseType>, _> =
|
||||
self.response_types.iter().map(|s| s.parse()).collect();
|
||||
let response_types = response_types.map_err(|source| ClientFetchError::ParseField {
|
||||
field: "response_types",
|
||||
source,
|
||||
})?;
|
||||
*/
|
||||
|
||||
let mut grant_types = Vec::new();
|
||||
if self.grant_type_authorization_code {
|
||||
grant_types.push(GrantType::AuthorizationCode);
|
||||
}
|
||||
if self.grant_type_refresh_token {
|
||||
grant_types.push(GrantType::RefreshToken);
|
||||
}
|
||||
|
||||
let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("logo_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let client_uri = self
|
||||
.client_uri
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("client_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let policy_uri = self
|
||||
.policy_uri
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("policy_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("tos_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let id_token_signed_response_alg = self
|
||||
.id_token_signed_response_alg
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("id_token_signed_response_alg")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let userinfo_signed_response_alg = self
|
||||
.userinfo_signed_response_alg
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("userinfo_signed_response_alg")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let token_endpoint_auth_method = self
|
||||
.token_endpoint_auth_method
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("token_endpoint_auth_method")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let token_endpoint_auth_signing_alg = self
|
||||
.token_endpoint_auth_signing_alg
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("token_endpoint_auth_signing_alg")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let initiate_login_uri = self
|
||||
.initiate_login_uri
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("initiate_login_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let jwks = match (self.jwks, self.jwks_uri) {
|
||||
(None, None) => None,
|
||||
(Some(jwks), None) => {
|
||||
let jwks = serde_json::from_value(jwks).map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("jwks")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
Some(JwksOrJwksUri::Jwks(jwks))
|
||||
}
|
||||
(None, Some(jwks_uri)) => {
|
||||
let jwks_uri = jwks_uri.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("jwks_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
Some(JwksOrJwksUri::JwksUri(jwks_uri))
|
||||
}
|
||||
_ => {
|
||||
return Err(DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("jwks(_uri)")
|
||||
.row(id))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Client {
|
||||
id,
|
||||
client_id: id.to_string(),
|
||||
encrypted_client_secret: self.encrypted_client_secret,
|
||||
redirect_uris,
|
||||
response_types,
|
||||
grant_types,
|
||||
// contacts: self.contacts,
|
||||
contacts: vec![],
|
||||
client_name: self.client_name,
|
||||
logo_uri,
|
||||
client_uri,
|
||||
policy_uri,
|
||||
tos_uri,
|
||||
jwks,
|
||||
id_token_signed_response_alg,
|
||||
userinfo_signed_response_alg,
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_auth_signing_alg,
|
||||
initiate_login_uri,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_client.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
oauth2_client.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2ClientLookup,
|
||||
r#"
|
||||
SELECT oauth2_client_id
|
||||
, encrypted_client_secret
|
||||
, ARRAY(
|
||||
SELECT redirect_uri
|
||||
FROM oauth2_client_redirect_uris r
|
||||
WHERE r.oauth2_client_id = c.oauth2_client_id
|
||||
) AS "redirect_uris!"
|
||||
, grant_type_authorization_code
|
||||
, grant_type_refresh_token
|
||||
, client_name
|
||||
, logo_uri
|
||||
, client_uri
|
||||
, policy_uri
|
||||
, tos_uri
|
||||
, jwks_uri
|
||||
, jwks
|
||||
, id_token_signed_response_alg
|
||||
, userinfo_signed_response_alg
|
||||
, token_endpoint_auth_method
|
||||
, token_endpoint_auth_signing_alg
|
||||
, initiate_login_uri
|
||||
FROM oauth2_clients c
|
||||
|
||||
WHERE oauth2_client_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_client.load_batch",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn load_batch(
|
||||
&mut self,
|
||||
ids: BTreeSet<Ulid>,
|
||||
) -> Result<BTreeMap<Ulid, Client>, Self::Error> {
|
||||
let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2ClientLookup,
|
||||
r#"
|
||||
SELECT oauth2_client_id
|
||||
, encrypted_client_secret
|
||||
, ARRAY(
|
||||
SELECT redirect_uri
|
||||
FROM oauth2_client_redirect_uris r
|
||||
WHERE r.oauth2_client_id = c.oauth2_client_id
|
||||
) AS "redirect_uris!"
|
||||
, grant_type_authorization_code
|
||||
, grant_type_refresh_token
|
||||
, client_name
|
||||
, logo_uri
|
||||
, client_uri
|
||||
, policy_uri
|
||||
, tos_uri
|
||||
, jwks_uri
|
||||
, jwks
|
||||
, id_token_signed_response_alg
|
||||
, userinfo_signed_response_alg
|
||||
, token_endpoint_auth_method
|
||||
, token_endpoint_auth_signing_alg
|
||||
, initiate_login_uri
|
||||
FROM oauth2_clients c
|
||||
|
||||
WHERE oauth2_client_id = ANY($1::uuid[])
|
||||
"#,
|
||||
&ids,
|
||||
)
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
res.into_iter()
|
||||
.map(|r| {
|
||||
r.try_into()
|
||||
.map(|c: Client| (c.id, c))
|
||||
.map_err(DatabaseError::from)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_client.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
client.id,
|
||||
client.name = client_name
|
||||
),
|
||||
err,
|
||||
)]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
mut rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
redirect_uris: Vec<Url>,
|
||||
encrypted_client_secret: Option<String>,
|
||||
grant_types: Vec<GrantType>,
|
||||
contacts: Vec<String>,
|
||||
client_name: Option<String>,
|
||||
logo_uri: Option<Url>,
|
||||
client_uri: Option<Url>,
|
||||
policy_uri: Option<Url>,
|
||||
tos_uri: Option<Url>,
|
||||
jwks_uri: Option<Url>,
|
||||
jwks: Option<PublicJsonWebKeySet>,
|
||||
id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
|
||||
userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
|
||||
token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
|
||||
token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
initiate_login_uri: Option<Url>,
|
||||
) -> Result<Client, Self::Error> {
|
||||
let now = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(now.into(), rng);
|
||||
tracing::Span::current().record("client.id", tracing::field::display(id));
|
||||
|
||||
let jwks_json = jwks
|
||||
.as_ref()
|
||||
.map(serde_json::to_value)
|
||||
.transpose()
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_clients
|
||||
( oauth2_client_id
|
||||
, encrypted_client_secret
|
||||
, grant_type_authorization_code
|
||||
, grant_type_refresh_token
|
||||
, client_name
|
||||
, logo_uri
|
||||
, client_uri
|
||||
, policy_uri
|
||||
, tos_uri
|
||||
, jwks_uri
|
||||
, jwks
|
||||
, id_token_signed_response_alg
|
||||
, userinfo_signed_response_alg
|
||||
, token_endpoint_auth_method
|
||||
, token_endpoint_auth_signing_alg
|
||||
, initiate_login_uri
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
encrypted_client_secret,
|
||||
grant_types.contains(&GrantType::AuthorizationCode),
|
||||
grant_types.contains(&GrantType::RefreshToken),
|
||||
client_name,
|
||||
logo_uri.as_ref().map(Url::as_str),
|
||||
client_uri.as_ref().map(Url::as_str),
|
||||
policy_uri.as_ref().map(Url::as_str),
|
||||
tos_uri.as_ref().map(Url::as_str),
|
||||
jwks_uri.as_ref().map(Url::as_str),
|
||||
jwks_json,
|
||||
id_token_signed_response_alg
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
userinfo_signed_response_alg
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
token_endpoint_auth_method.as_ref().map(ToString::to_string),
|
||||
token_endpoint_auth_signing_alg
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
initiate_login_uri.as_ref().map(Url::as_str),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
{
|
||||
let span = info_span!(
|
||||
"db.oauth2_client.add.redirect_uris",
|
||||
db.statement = tracing::field::Empty,
|
||||
client.id = %id,
|
||||
);
|
||||
|
||||
let (uri_ids, redirect_uris): (Vec<Uuid>, Vec<String>) = redirect_uris
|
||||
.iter()
|
||||
.map(|uri| {
|
||||
(
|
||||
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
|
||||
uri.as_str().to_owned(),
|
||||
)
|
||||
})
|
||||
.unzip();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_client_redirect_uris
|
||||
(oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)
|
||||
SELECT id, $2, redirect_uri
|
||||
FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)
|
||||
"#,
|
||||
&uri_ids,
|
||||
Uuid::from(id),
|
||||
&redirect_uris,
|
||||
)
|
||||
.record(&span)
|
||||
.execute(&mut *self.conn)
|
||||
.instrument(span)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let jwks = match (jwks, jwks_uri) {
|
||||
(None, None) => None,
|
||||
(Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
|
||||
(None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
|
||||
_ => return Err(DatabaseError::invalid_operation()),
|
||||
};
|
||||
|
||||
Ok(Client {
|
||||
id,
|
||||
client_id: id.to_string(),
|
||||
encrypted_client_secret,
|
||||
redirect_uris,
|
||||
response_types: vec![
|
||||
OAuthAuthorizationEndpointResponseType::Code,
|
||||
OAuthAuthorizationEndpointResponseType::IdToken,
|
||||
OAuthAuthorizationEndpointResponseType::None,
|
||||
],
|
||||
grant_types,
|
||||
contacts,
|
||||
client_name,
|
||||
logo_uri,
|
||||
client_uri,
|
||||
policy_uri,
|
||||
tos_uri,
|
||||
jwks,
|
||||
id_token_signed_response_alg,
|
||||
userinfo_signed_response_alg,
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_auth_signing_alg,
|
||||
initiate_login_uri,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_client.add_from_config",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
client.id = %client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add_from_config(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
client_id: Ulid,
|
||||
client_auth_method: OAuthClientAuthenticationMethod,
|
||||
encrypted_client_secret: Option<String>,
|
||||
jwks: Option<PublicJsonWebKeySet>,
|
||||
jwks_uri: Option<Url>,
|
||||
redirect_uris: Vec<Url>,
|
||||
) -> Result<Client, Self::Error> {
|
||||
let jwks_json = jwks
|
||||
.as_ref()
|
||||
.map(serde_json::to_value)
|
||||
.transpose()
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
let client_auth_method = client_auth_method.to_string();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_clients
|
||||
( oauth2_client_id
|
||||
, encrypted_client_secret
|
||||
, grant_type_authorization_code
|
||||
, grant_type_refresh_token
|
||||
, token_endpoint_auth_method
|
||||
, jwks
|
||||
, jwks_uri
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (oauth2_client_id)
|
||||
DO
|
||||
UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret
|
||||
, grant_type_authorization_code = EXCLUDED.grant_type_authorization_code
|
||||
, grant_type_refresh_token = EXCLUDED.grant_type_refresh_token
|
||||
, token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method
|
||||
, jwks = EXCLUDED.jwks
|
||||
, jwks_uri = EXCLUDED.jwks_uri
|
||||
"#,
|
||||
Uuid::from(client_id),
|
||||
encrypted_client_secret,
|
||||
true,
|
||||
true,
|
||||
client_auth_method,
|
||||
jwks_json,
|
||||
jwks_uri.as_ref().map(Url::as_str),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
{
|
||||
let span = info_span!(
|
||||
"db.oauth2_client.add_from_config.redirect_uris",
|
||||
client.id = %client_id,
|
||||
db.statement = tracing::field::Empty,
|
||||
);
|
||||
|
||||
let now = clock.now();
|
||||
let (ids, redirect_uris): (Vec<Uuid>, Vec<String>) = redirect_uris
|
||||
.iter()
|
||||
.map(|uri| {
|
||||
(
|
||||
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut *rng)),
|
||||
uri.as_str().to_owned(),
|
||||
)
|
||||
})
|
||||
.unzip();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_client_redirect_uris
|
||||
(oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)
|
||||
SELECT id, $2, redirect_uri
|
||||
FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)
|
||||
"#,
|
||||
&ids,
|
||||
Uuid::from(client_id),
|
||||
&redirect_uris,
|
||||
)
|
||||
.record(&span)
|
||||
.execute(&mut *self.conn)
|
||||
.instrument(span)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let jwks = match (jwks, jwks_uri) {
|
||||
(None, None) => None,
|
||||
(Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
|
||||
(None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
|
||||
_ => return Err(DatabaseError::invalid_operation()),
|
||||
};
|
||||
|
||||
Ok(Client {
|
||||
id: client_id,
|
||||
client_id: client_id.to_string(),
|
||||
encrypted_client_secret,
|
||||
redirect_uris,
|
||||
response_types: vec![
|
||||
OAuthAuthorizationEndpointResponseType::Code,
|
||||
OAuthAuthorizationEndpointResponseType::IdToken,
|
||||
OAuthAuthorizationEndpointResponseType::None,
|
||||
],
|
||||
grant_types: Vec::new(),
|
||||
contacts: Vec::new(),
|
||||
client_name: None,
|
||||
logo_uri: None,
|
||||
client_uri: None,
|
||||
policy_uri: None,
|
||||
tos_uri: None,
|
||||
jwks,
|
||||
id_token_signed_response_alg: None,
|
||||
userinfo_signed_response_alg: None,
|
||||
token_endpoint_auth_method: None,
|
||||
token_endpoint_auth_signing_alg: None,
|
||||
initiate_login_uri: None,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_client.get_consent_for_user",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
%client.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn get_consent_for_user(
|
||||
&mut self,
|
||||
client: &Client,
|
||||
user: &User,
|
||||
) -> Result<Scope, Self::Error> {
|
||||
let scope_tokens: Vec<String> = sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT scope_token
|
||||
FROM oauth2_consents
|
||||
WHERE user_id = $1 AND oauth2_client_id = $2
|
||||
"#,
|
||||
Uuid::from(user.id),
|
||||
Uuid::from(client.id),
|
||||
)
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let scope: Result<Scope, _> = scope_tokens
|
||||
.into_iter()
|
||||
.map(|s| ScopeToken::from_str(&s))
|
||||
.collect();
|
||||
|
||||
let scope = scope.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_consents")
|
||||
.column("scope_token")
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
Ok(scope)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
%client.id,
|
||||
%scope,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn give_consent_for_user(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
client: &Client,
|
||||
user: &User,
|
||||
scope: &Scope,
|
||||
) -> Result<(), Self::Error> {
|
||||
let now = clock.now();
|
||||
let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
|
||||
.iter()
|
||||
.map(|token| {
|
||||
(
|
||||
token.to_string(),
|
||||
Uuid::from(Ulid::from_datetime_with_source(now.into(), rng)),
|
||||
)
|
||||
})
|
||||
.unzip();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_consents
|
||||
(oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)
|
||||
SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)
|
||||
ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5
|
||||
"#,
|
||||
&ids,
|
||||
Uuid::from(user.id),
|
||||
Uuid::from(client.id),
|
||||
&tokens,
|
||||
now,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
371
crates/storage-pg/src/oauth2/mod.rs
Normal file
371
crates/storage-pg/src/oauth2/mod.rs
Normal file
@@ -0,0 +1,371 @@
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! A module containing the PostgreSQL implementations of the OAuth2-related
|
||||
//! repositories
|
||||
|
||||
mod access_token;
|
||||
mod authorization_grant;
|
||||
mod client;
|
||||
mod refresh_token;
|
||||
mod session;
|
||||
|
||||
pub use self::{
|
||||
access_token::PgOAuth2AccessTokenRepository,
|
||||
authorization_grant::PgOAuth2AuthorizationGrantRepository, client::PgOAuth2ClientRepository,
|
||||
refresh_token::PgOAuth2RefreshTokenRepository, session::PgOAuth2SessionRepository,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use chrono::Duration;
|
||||
use mas_data_model::AuthorizationCode;
|
||||
use mas_storage::{clock::MockClock, Clock, Pagination, Repository};
|
||||
use oauth2_types::{
|
||||
requests::{GrantType, ResponseMode},
|
||||
scope::{Scope, OPENID},
|
||||
};
|
||||
use rand::SeedableRng;
|
||||
use rand_chacha::ChaChaRng;
|
||||
use sqlx::PgPool;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::PgRepository;
|
||||
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_repositories(pool: PgPool) {
|
||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||
let clock = MockClock::default();
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
|
||||
|
||||
// Lookup a non-existing client
|
||||
let client = repo.oauth2_client().lookup(Ulid::nil()).await.unwrap();
|
||||
assert_eq!(client, None);
|
||||
|
||||
// Find a non-existing client by client id
|
||||
let client = repo
|
||||
.oauth2_client()
|
||||
.find_by_client_id("some-client-id")
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(client, None);
|
||||
|
||||
// Create a client
|
||||
let client = repo
|
||||
.oauth2_client()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
vec!["https://example.com/redirect".parse().unwrap()],
|
||||
None,
|
||||
vec![GrantType::AuthorizationCode],
|
||||
Vec::new(), // TODO: contacts are not yet saved
|
||||
// vec!["contact@example.com".to_owned()],
|
||||
Some("Test client".to_owned()),
|
||||
Some("https://example.com/logo.png".parse().unwrap()),
|
||||
Some("https://example.com/".parse().unwrap()),
|
||||
Some("https://example.com/policy".parse().unwrap()),
|
||||
Some("https://example.com/tos".parse().unwrap()),
|
||||
Some("https://example.com/jwks.json".parse().unwrap()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
Some("https://example.com/login".parse().unwrap()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Lookup the same client by id
|
||||
let client_lookup = repo
|
||||
.oauth2_client()
|
||||
.lookup(client.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("client not found");
|
||||
assert_eq!(client, client_lookup);
|
||||
|
||||
// Find the same client by client id
|
||||
let client_lookup = repo
|
||||
.oauth2_client()
|
||||
.find_by_client_id(&client.client_id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("client not found");
|
||||
assert_eq!(client, client_lookup);
|
||||
|
||||
// Lookup a non-existing grant
|
||||
let grant = repo
|
||||
.oauth2_authorization_grant()
|
||||
.lookup(Ulid::nil())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(grant, None);
|
||||
|
||||
// Find a non-existing grant by code
|
||||
let grant = repo
|
||||
.oauth2_authorization_grant()
|
||||
.find_by_code("code")
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(grant, None);
|
||||
|
||||
// Create an authorization grant
|
||||
let grant = repo
|
||||
.oauth2_authorization_grant()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&client,
|
||||
"https://example.com/redirect".parse().unwrap(),
|
||||
Scope::from_iter([OPENID]),
|
||||
Some(AuthorizationCode {
|
||||
code: "code".to_owned(),
|
||||
pkce: None,
|
||||
}),
|
||||
Some("state".to_owned()),
|
||||
Some("nonce".to_owned()),
|
||||
None,
|
||||
ResponseMode::Query,
|
||||
true,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(grant.is_pending());
|
||||
|
||||
// Lookup the same grant by id
|
||||
let grant_lookup = repo
|
||||
.oauth2_authorization_grant()
|
||||
.lookup(grant.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("grant not found");
|
||||
assert_eq!(grant, grant_lookup);
|
||||
|
||||
// Find the same grant by code
|
||||
let grant_lookup = repo
|
||||
.oauth2_authorization_grant()
|
||||
.find_by_code("code")
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("grant not found");
|
||||
assert_eq!(grant, grant_lookup);
|
||||
|
||||
// Create a user and a start a user session
|
||||
let user = repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, "john".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
let user_session = repo
|
||||
.browser_session()
|
||||
.add(&mut rng, &clock, &user)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Lookup the consent the user gave to the client
|
||||
let consent = repo
|
||||
.oauth2_client()
|
||||
.get_consent_for_user(&client, &user)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(consent.is_empty());
|
||||
|
||||
// Give consent to the client
|
||||
let scope = Scope::from_iter([OPENID]);
|
||||
repo.oauth2_client()
|
||||
.give_consent_for_user(&mut rng, &clock, &client, &user, &scope)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Lookup the consent the user gave to the client
|
||||
let consent = repo
|
||||
.oauth2_client()
|
||||
.get_consent_for_user(&client, &user)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(scope, consent);
|
||||
|
||||
// Lookup a non-existing session
|
||||
let session = repo.oauth2_session().lookup(Ulid::nil()).await.unwrap();
|
||||
assert_eq!(session, None);
|
||||
|
||||
// Create a session out of the grant
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.create_from_grant(&mut rng, &clock, &grant, &user_session)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Mark the grant as fulfilled
|
||||
let grant = repo
|
||||
.oauth2_authorization_grant()
|
||||
.fulfill(&clock, &session, grant)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(grant.is_fulfilled());
|
||||
|
||||
// Lookup the same session by id
|
||||
let session_lookup = repo
|
||||
.oauth2_session()
|
||||
.lookup(session.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("session not found");
|
||||
assert_eq!(session, session_lookup);
|
||||
|
||||
// Mark the grant as exchanged
|
||||
let grant = repo
|
||||
.oauth2_authorization_grant()
|
||||
.exchange(&clock, grant)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(grant.is_exchanged());
|
||||
|
||||
// Lookup a non-existing token
|
||||
let token = repo
|
||||
.oauth2_access_token()
|
||||
.lookup(Ulid::nil())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(token, None);
|
||||
|
||||
// Find a non-existing token
|
||||
let token = repo
|
||||
.oauth2_access_token()
|
||||
.find_by_token("aabbcc")
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(token, None);
|
||||
|
||||
// Create an access token
|
||||
let access_token = repo
|
||||
.oauth2_access_token()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
"aabbcc".to_owned(),
|
||||
Duration::minutes(5),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Lookup the same token by id
|
||||
let access_token_lookup = repo
|
||||
.oauth2_access_token()
|
||||
.lookup(access_token.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("token not found");
|
||||
assert_eq!(access_token, access_token_lookup);
|
||||
|
||||
// Find the same token by token
|
||||
let access_token_lookup = repo
|
||||
.oauth2_access_token()
|
||||
.find_by_token("aabbcc")
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("token not found");
|
||||
assert_eq!(access_token, access_token_lookup);
|
||||
|
||||
// Lookup a non-existing refresh token
|
||||
let refresh_token = repo
|
||||
.oauth2_refresh_token()
|
||||
.lookup(Ulid::nil())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(refresh_token, None);
|
||||
|
||||
// Find a non-existing refresh token
|
||||
let refresh_token = repo
|
||||
.oauth2_refresh_token()
|
||||
.find_by_token("aabbcc")
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(refresh_token, None);
|
||||
|
||||
// Create a refresh token
|
||||
let refresh_token = repo
|
||||
.oauth2_refresh_token()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
&access_token,
|
||||
"aabbcc".to_owned(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Lookup the same refresh token by id
|
||||
let refresh_token_lookup = repo
|
||||
.oauth2_refresh_token()
|
||||
.lookup(refresh_token.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("refresh token not found");
|
||||
assert_eq!(refresh_token, refresh_token_lookup);
|
||||
|
||||
// Find the same refresh token by token
|
||||
let refresh_token_lookup = repo
|
||||
.oauth2_refresh_token()
|
||||
.find_by_token("aabbcc")
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("refresh token not found");
|
||||
assert_eq!(refresh_token, refresh_token_lookup);
|
||||
|
||||
assert!(access_token.is_valid(clock.now()));
|
||||
clock.advance(Duration::minutes(6));
|
||||
assert!(!access_token.is_valid(clock.now()));
|
||||
|
||||
// XXX: we might want to create a new access token
|
||||
clock.advance(Duration::minutes(-6)); // Go back in time
|
||||
assert!(access_token.is_valid(clock.now()));
|
||||
|
||||
// Mark the access token as revoked
|
||||
let access_token = repo
|
||||
.oauth2_access_token()
|
||||
.revoke(&clock, access_token)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!access_token.is_valid(clock.now()));
|
||||
|
||||
// Mark the refresh token as consumed
|
||||
assert!(refresh_token.is_valid());
|
||||
let refresh_token = repo
|
||||
.oauth2_refresh_token()
|
||||
.consume(&clock, refresh_token)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!refresh_token.is_valid());
|
||||
|
||||
// Mark the session as finished
|
||||
assert!(session.is_valid());
|
||||
let session = repo.oauth2_session().finish(&clock, session).await.unwrap();
|
||||
assert!(!session.is_valid());
|
||||
|
||||
// The session should appear in the paginated list of sessions for the user
|
||||
let sessions = repo
|
||||
.oauth2_session()
|
||||
.list_paginated(&user, Pagination::first(10))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!sessions.has_next_page);
|
||||
assert_eq!(sessions.edges, vec![session]);
|
||||
}
|
||||
}
|
||||
228
crates/storage-pg/src/oauth2/refresh_token.rs
Normal file
228
crates/storage-pg/src/oauth2/refresh_token.rs
Normal file
@@ -0,0 +1,228 @@
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session};
|
||||
use mas_storage::{oauth2::OAuth2RefreshTokenRepository, Clock};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt};
|
||||
|
||||
/// An implementation of [`OAuth2RefreshTokenRepository`] for a PostgreSQL
|
||||
/// connection
|
||||
pub struct PgOAuth2RefreshTokenRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgOAuth2RefreshTokenRepository<'c> {
|
||||
/// Create a new [`PgOAuth2RefreshTokenRepository`] from an active
|
||||
/// PostgreSQL connection
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
struct OAuth2RefreshTokenLookup {
|
||||
oauth2_refresh_token_id: Uuid,
|
||||
refresh_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
consumed_at: Option<DateTime<Utc>>,
|
||||
oauth2_access_token_id: Option<Uuid>,
|
||||
oauth2_session_id: Uuid,
|
||||
}
|
||||
|
||||
impl From<OAuth2RefreshTokenLookup> for RefreshToken {
|
||||
fn from(value: OAuth2RefreshTokenLookup) -> Self {
|
||||
let state = match value.consumed_at {
|
||||
None => RefreshTokenState::Valid,
|
||||
Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
|
||||
};
|
||||
|
||||
RefreshToken {
|
||||
id: value.oauth2_refresh_token_id.into(),
|
||||
state,
|
||||
session_id: value.oauth2_session_id.into(),
|
||||
refresh_token: value.refresh_token,
|
||||
created_at: value.created_at,
|
||||
access_token_id: value.oauth2_access_token_id.map(Ulid::from),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_refresh_token.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
refresh_token.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<RefreshToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2RefreshTokenLookup,
|
||||
r#"
|
||||
SELECT oauth2_refresh_token_id
|
||||
, refresh_token
|
||||
, created_at
|
||||
, consumed_at
|
||||
, oauth2_access_token_id
|
||||
, oauth2_session_id
|
||||
FROM oauth2_refresh_tokens
|
||||
|
||||
WHERE oauth2_refresh_token_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_refresh_token.find_by_token",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_token(
|
||||
&mut self,
|
||||
refresh_token: &str,
|
||||
) -> Result<Option<RefreshToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2RefreshTokenLookup,
|
||||
r#"
|
||||
SELECT oauth2_refresh_token_id
|
||||
, refresh_token
|
||||
, created_at
|
||||
, consumed_at
|
||||
, oauth2_access_token_id
|
||||
, oauth2_session_id
|
||||
FROM oauth2_refresh_tokens
|
||||
|
||||
WHERE refresh_token = $1
|
||||
"#,
|
||||
refresh_token,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_refresh_token.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%session.id,
|
||||
user_session.id = %session.user_session_id,
|
||||
client.id = %session.client_id,
|
||||
refresh_token.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
session: &Session,
|
||||
access_token: &AccessToken,
|
||||
refresh_token: String,
|
||||
) -> Result<RefreshToken, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_refresh_tokens
|
||||
(oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,
|
||||
refresh_token, created_at)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(session.id),
|
||||
Uuid::from(access_token.id),
|
||||
refresh_token,
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(RefreshToken {
|
||||
id,
|
||||
state: RefreshTokenState::default(),
|
||||
session_id: session.id,
|
||||
refresh_token,
|
||||
access_token_id: Some(access_token.id),
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_refresh_token.consume",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%refresh_token.id,
|
||||
session.id = %refresh_token.session_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn consume(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
refresh_token: RefreshToken,
|
||||
) -> Result<RefreshToken, Self::Error> {
|
||||
let consumed_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_refresh_tokens
|
||||
SET consumed_at = $2
|
||||
WHERE oauth2_refresh_token_id = $1
|
||||
"#,
|
||||
Uuid::from(refresh_token.id),
|
||||
consumed_at,
|
||||
)
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
refresh_token
|
||||
.consume(consumed_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)
|
||||
}
|
||||
}
|
||||
256
crates/storage-pg/src/oauth2/session.rs
Normal file
256
crates/storage-pg/src/oauth2/session.rs
Normal file
@@ -0,0 +1,256 @@
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{AuthorizationGrant, BrowserSession, Session, SessionState, User};
|
||||
use mas_storage::{oauth2::OAuth2SessionRepository, Clock, Page, Pagination};
|
||||
use rand::RngCore;
|
||||
use sqlx::{PgConnection, QueryBuilder};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError,
|
||||
LookupResultExt,
|
||||
};
|
||||
|
||||
/// An implementation of [`OAuth2SessionRepository`] for a PostgreSQL connection
|
||||
pub struct PgOAuth2SessionRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgOAuth2SessionRepository<'c> {
|
||||
/// Create a new [`PgOAuth2SessionRepository`] from an active PostgreSQL
|
||||
/// connection
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct OAuthSessionLookup {
|
||||
oauth2_session_id: Uuid,
|
||||
user_session_id: Uuid,
|
||||
oauth2_client_id: Uuid,
|
||||
scope: String,
|
||||
#[allow(dead_code)]
|
||||
created_at: DateTime<Utc>,
|
||||
finished_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl TryFrom<OAuthSessionLookup> for Session {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
|
||||
let id = Ulid::from(value.oauth2_session_id);
|
||||
let scope = value.scope.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_sessions")
|
||||
.column("scope")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let state = match value.finished_at {
|
||||
None => SessionState::Valid,
|
||||
Some(finished_at) => SessionState::Finished { finished_at },
|
||||
};
|
||||
|
||||
Ok(Session {
|
||||
id,
|
||||
state,
|
||||
created_at: value.created_at,
|
||||
client_id: value.oauth2_client_id.into(),
|
||||
user_session_id: value.user_session_id.into(),
|
||||
scope,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_session.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
session.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuthSessionLookup,
|
||||
r#"
|
||||
SELECT oauth2_session_id
|
||||
, user_session_id
|
||||
, oauth2_client_id
|
||||
, scope
|
||||
, created_at
|
||||
, finished_at
|
||||
FROM oauth2_sessions
|
||||
|
||||
WHERE oauth2_session_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(session) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(session.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_session.create_from_grant",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user_session.id,
|
||||
user.id = %user_session.user.id,
|
||||
%grant.id,
|
||||
client.id = %grant.client_id,
|
||||
session.id,
|
||||
session.scope = %grant.scope,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn create_from_grant(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
grant: &AuthorizationGrant,
|
||||
user_session: &BrowserSession,
|
||||
) -> Result<Session, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("session.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_sessions
|
||||
( oauth2_session_id
|
||||
, user_session_id
|
||||
, oauth2_client_id
|
||||
, scope
|
||||
, created_at
|
||||
)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(user_session.id),
|
||||
Uuid::from(grant.client_id),
|
||||
grant.scope.to_string(),
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(Session {
|
||||
id,
|
||||
state: SessionState::Valid,
|
||||
created_at,
|
||||
user_session_id: user_session.id,
|
||||
client_id: grant.client_id,
|
||||
scope: grant.scope.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_session.finish",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%session.id,
|
||||
%session.scope,
|
||||
user_session.id = %session.user_session_id,
|
||||
client.id = %session.client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn finish(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
session: Session,
|
||||
) -> Result<Session, Self::Error> {
|
||||
let finished_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_sessions
|
||||
SET finished_at = $2
|
||||
WHERE oauth2_session_id = $1
|
||||
"#,
|
||||
Uuid::from(session.id),
|
||||
finished_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
session
|
||||
.finish(finished_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_session.list_paginated",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
%user.username,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn list_paginated(
|
||||
&mut self,
|
||||
user: &User,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<Session>, Self::Error> {
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT oauth2_session_id
|
||||
, user_session_id
|
||||
, oauth2_client_id
|
||||
, scope
|
||||
, os.created_at
|
||||
, os.finished_at
|
||||
FROM oauth2_sessions os
|
||||
INNER JOIN user_sessions USING (user_session_id)
|
||||
"#,
|
||||
);
|
||||
|
||||
query
|
||||
.push(" WHERE user_id = ")
|
||||
.push_bind(Uuid::from(user.id))
|
||||
.generate_pagination("oauth2_session_id", pagination);
|
||||
|
||||
let edges: Vec<OAuthSessionLookup> = query
|
||||
.build_query_as()
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let page = pagination.process(edges).try_map(Session::try_from)?;
|
||||
Ok(page)
|
||||
}
|
||||
}
|
||||
78
crates/storage-pg/src/pagination.rs
Normal file
78
crates/storage-pg/src/pagination.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Utilities to manage paginated queries.
|
||||
|
||||
use mas_storage::{pagination::PaginationDirection, Pagination};
|
||||
use sqlx::{Database, QueryBuilder};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination
|
||||
/// to a query
|
||||
pub trait QueryBuilderExt {
|
||||
/// Add cursor-based pagination to a query, as used in paginated GraphQL
|
||||
/// connections
|
||||
fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self;
|
||||
}
|
||||
|
||||
impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB>
|
||||
where
|
||||
DB: Database,
|
||||
Uuid: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
|
||||
i64: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
|
||||
{
|
||||
fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self {
|
||||
// ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564
|
||||
// 1. Start from the greedy query: SELECT * FROM table
|
||||
|
||||
// 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE`
|
||||
// clause
|
||||
if let Some(after) = pagination.after {
|
||||
self.push(" AND ")
|
||||
.push(id_field)
|
||||
.push(" > ")
|
||||
.push_bind(Uuid::from(after));
|
||||
}
|
||||
|
||||
// 3. If the before argument is provided, add `id < parsed_cursor` to the
|
||||
// `WHERE` clause
|
||||
if let Some(before) = pagination.before {
|
||||
self.push(" AND ")
|
||||
.push(id_field)
|
||||
.push(" < ")
|
||||
.push_bind(Uuid::from(before));
|
||||
}
|
||||
|
||||
match pagination.direction {
|
||||
// 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the
|
||||
// query
|
||||
PaginationDirection::Forward => {
|
||||
self.push(" ORDER BY ")
|
||||
.push(id_field)
|
||||
.push(" ASC LIMIT ")
|
||||
.push_bind((pagination.count + 1) as i64);
|
||||
}
|
||||
// 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the
|
||||
// query
|
||||
PaginationDirection::Backward => {
|
||||
self.push(" ORDER BY ")
|
||||
.push(id_field)
|
||||
.push(" DESC LIMIT ")
|
||||
.push_bind((pagination.count + 1) as i64);
|
||||
}
|
||||
};
|
||||
|
||||
self
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user