diff --git a/crates/cli/src/app_state.rs b/crates/cli/src/app_state.rs index cd4ae44ad..adf257186 100644 --- a/crates/cli/src/app_state.rs +++ b/crates/cli/src/app_state.rs @@ -19,8 +19,8 @@ use mas_keystore::{Encrypter, Keystore}; use mas_matrix::HomeserverConnection; use mas_policy::{Policy, PolicyFactory}; use mas_router::UrlBuilder; -use mas_storage::{BoxClock, BoxRepository, BoxRng, SystemClock}; -use mas_storage_pg::PgRepository; +use mas_storage::{BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng, SystemClock, RepositoryFactory}; +use mas_storage_pg::PgRepositoryFactory; use mas_templates::Templates; use opentelemetry::{KeyValue, metrics::Histogram}; use rand::SeedableRng; @@ -31,7 +31,7 @@ use crate::telemetry::METER; #[derive(Clone)] pub struct AppState { - pub pool: PgPool, + pub repository_factory: PgRepositoryFactory, pub templates: Templates, pub key_store: Keystore, pub cookie_manager: CookieManager, @@ -53,7 +53,7 @@ pub struct AppState { impl AppState { /// Init the metrics for the app state. pub fn init_metrics(&mut self) { - let pool = self.pool.clone(); + let pool = self.repository_factory.pool(); METER .i64_observable_up_down_counter("db.connections.usage") .with_description("The number of connections that are currently in `state` described by the state attribute.") @@ -66,7 +66,7 @@ impl AppState { }) .build(); - let pool = self.pool.clone(); + let pool = self.repository_factory.pool(); METER .i64_observable_up_down_counter("db.connections.max") .with_description("The maximum number of open connections allowed.") @@ -88,14 +88,14 @@ impl AppState { /// Init the metadata cache in the background pub fn init_metadata_cache(&self) { - let pool = self.pool.clone(); + let factory = self.repository_factory.clone(); let metadata_cache = self.metadata_cache.clone(); let http_client = self.http_client.clone(); tokio::spawn( LogContext::new("metadata-cache-warmup") .run(async move || { - let conn = match pool.acquire().await { + let mut repo = match factory.create().await { Ok(conn) => conn, Err(e) => { tracing::error!( @@ -106,8 +106,6 @@ impl AppState { } }; - let mut repo = PgRepository::from_conn(conn); - if let Err(e) = metadata_cache .warm_up_and_run( &http_client, @@ -127,9 +125,17 @@ impl AppState { } } +// XXX(quenting): we only use this for the healthcheck endpoint, checking the db +// should be part of the repository impl FromRef for PgPool { fn from_ref(input: &AppState) -> Self { - input.pool.clone() + input.repository_factory.pool() + } +} + +impl FromRef for BoxRepositoryFactory { + fn from_ref(input: &AppState) -> Self { + input.repository_factory.clone().boxed() } } @@ -359,14 +365,14 @@ impl FromRequestParts for RequesterFingerprint { } impl FromRequestParts for BoxRepository { - type Rejection = ErrorWrapper; + type Rejection = ErrorWrapper; async fn from_request_parts( _parts: &mut axum::http::request::Parts, state: &AppState, ) -> Result { let start = Instant::now(); - let repo = PgRepository::from_pool(&state.pool).await?; + let repo = state.repository_factory.create().await?; // Measure the time it took to create the connection let duration = start.elapsed(); @@ -376,6 +382,6 @@ impl FromRequestParts for BoxRepository { histogram.record(duration_ms, &[]); } - Ok(repo.boxed()) + Ok(repo) } } diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index de94ce7b6..b82086d79 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -18,7 +18,7 @@ use mas_handlers::{ActivityTracker, CookieManager, Limiter, MetadataCache}; use mas_listener::server::Server; use mas_router::UrlBuilder; use mas_storage::SystemClock; -use mas_storage_pg::MIGRATOR; +use mas_storage_pg::{PgRepositoryFactory, MIGRATOR}; use sqlx::migrate::Migrate; use tracing::{Instrument, info, info_span, warn}; @@ -226,7 +226,7 @@ impl Options { let state = { let mut s = AppState { - pool, + repository_factory: PgRepositoryFactory::new(pool), templates, key_store, cookie_manager, diff --git a/crates/storage-pg/src/lib.rs b/crates/storage-pg/src/lib.rs index 8971488a5..ccc1676d8 100644 --- a/crates/storage-pg/src/lib.rs +++ b/crates/storage-pg/src/lib.rs @@ -178,7 +178,11 @@ pub(crate) mod repository; pub(crate) mod tracing; pub(crate) use self::errors::DatabaseInconsistencyError; -pub use self::{errors::DatabaseError, repository::PgRepository, tracing::ExecuteExt}; +pub use self::{ + errors::DatabaseError, + repository::{PgRepository, PgRepositoryFactory}, + tracing::ExecuteExt, +}; /// Embedded migrations, allowing them to run on startup pub static MIGRATOR: Migrator = { diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index 901f1fd45..739422dd3 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -6,9 +6,11 @@ use std::ops::{Deref, DerefMut}; +use async_trait::async_trait; use futures_util::{FutureExt, TryFutureExt, future::BoxFuture}; use mas_storage::{ - BoxRepository, MapErr, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction, + BoxRepository, BoxRepositoryFactory, MapErr, Repository, RepositoryAccess, RepositoryError, + RepositoryFactory, RepositoryTransaction, app_session::AppSessionRepository, compat::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, @@ -57,6 +59,43 @@ use crate::{ }, }; +/// An implementation of the [`RepositoryFactory`] trait backed by a PostgreSQL +/// connection pool. +#[derive(Clone)] +pub struct PgRepositoryFactory { + pool: PgPool, +} + +impl PgRepositoryFactory { + /// Create a new [`PgRepositoryFactory`] from a PostgreSQL connection pool. + #[must_use] + pub fn new(pool: PgPool) -> Self { + Self { pool } + } + + /// Box the factory + #[must_use] + pub fn boxed(self) -> BoxRepositoryFactory { + Box::new(self) + } + + /// Get the underlying PostgreSQL connection pool + #[must_use] + pub fn pool(&self) -> PgPool { + self.pool.clone() + } +} + +#[async_trait] +impl RepositoryFactory for PgRepositoryFactory { + async fn create(&self) -> Result { + Ok(PgRepository::from_pool(&self.pool) + .await + .map_err(RepositoryError::from_error)? + .boxed()) + } +} + /// An implementation of the [`Repository`] trait backed by a PostgreSQL /// transaction. pub struct PgRepository> { diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index 923113a6a..07d8bd97c 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -128,7 +128,8 @@ pub use self::{ clock::{Clock, SystemClock}, pagination::{Page, Pagination}, repository::{ - BoxRepository, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction, + BoxRepository, BoxRepositoryFactory, Repository, RepositoryAccess, RepositoryError, + RepositoryFactory, RepositoryTransaction, }, utils::{BoxClock, BoxRng, MapErr}, }; diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 2f051493c..951c12798 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -4,6 +4,7 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. +use async_trait::async_trait; use futures_util::future::BoxFuture; use thiserror::Error; @@ -29,6 +30,17 @@ use crate::{ }, }; +/// A [`RepositoryFactory`] is a factory that can create a [`BoxRepository`] +// XXX(quenting): this could be generic over the repository type, but it's annoying to make it dyn-safe +#[async_trait] +pub trait RepositoryFactory { + /// Create a new [`BoxRepository`] + async fn create(&self) -> Result; +} + +/// A type-erased [`RepositoryFactory`] +pub type BoxRepositoryFactory = Box; + /// A [`Repository`] helps interacting with the underlying storage backend. pub trait Repository: RepositoryAccess + RepositoryTransaction + Send