Introduce a RepositoryFactory

This commit is contained in:
Quentin Gliech
2025-05-07 17:00:49 +02:00
parent e77b532fa9
commit 90faa72633
6 changed files with 80 additions and 18 deletions

View File

@@ -19,8 +19,8 @@ use mas_keystore::{Encrypter, Keystore};
use mas_matrix::HomeserverConnection;
use mas_policy::{Policy, PolicyFactory};
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng, SystemClock};
use mas_storage_pg::PgRepository;
use mas_storage::{BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng, SystemClock, RepositoryFactory};
use mas_storage_pg::PgRepositoryFactory;
use mas_templates::Templates;
use opentelemetry::{KeyValue, metrics::Histogram};
use rand::SeedableRng;
@@ -31,7 +31,7 @@ use crate::telemetry::METER;
#[derive(Clone)]
pub struct AppState {
pub pool: PgPool,
pub repository_factory: PgRepositoryFactory,
pub templates: Templates,
pub key_store: Keystore,
pub cookie_manager: CookieManager,
@@ -53,7 +53,7 @@ pub struct AppState {
impl AppState {
/// Init the metrics for the app state.
pub fn init_metrics(&mut self) {
let pool = self.pool.clone();
let pool = self.repository_factory.pool();
METER
.i64_observable_up_down_counter("db.connections.usage")
.with_description("The number of connections that are currently in `state` described by the state attribute.")
@@ -66,7 +66,7 @@ impl AppState {
})
.build();
let pool = self.pool.clone();
let pool = self.repository_factory.pool();
METER
.i64_observable_up_down_counter("db.connections.max")
.with_description("The maximum number of open connections allowed.")
@@ -88,14 +88,14 @@ impl AppState {
/// Init the metadata cache in the background
pub fn init_metadata_cache(&self) {
let pool = self.pool.clone();
let factory = self.repository_factory.clone();
let metadata_cache = self.metadata_cache.clone();
let http_client = self.http_client.clone();
tokio::spawn(
LogContext::new("metadata-cache-warmup")
.run(async move || {
let conn = match pool.acquire().await {
let mut repo = match factory.create().await {
Ok(conn) => conn,
Err(e) => {
tracing::error!(
@@ -106,8 +106,6 @@ impl AppState {
}
};
let mut repo = PgRepository::from_conn(conn);
if let Err(e) = metadata_cache
.warm_up_and_run(
&http_client,
@@ -127,9 +125,17 @@ impl AppState {
}
}
// XXX(quenting): we only use this for the healthcheck endpoint, checking the db
// should be part of the repository
impl FromRef<AppState> for PgPool {
fn from_ref(input: &AppState) -> Self {
input.pool.clone()
input.repository_factory.pool()
}
}
impl FromRef<AppState> for BoxRepositoryFactory {
fn from_ref(input: &AppState) -> Self {
input.repository_factory.clone().boxed()
}
}
@@ -359,14 +365,14 @@ impl FromRequestParts<AppState> for RequesterFingerprint {
}
impl FromRequestParts<AppState> for BoxRepository {
type Rejection = ErrorWrapper<mas_storage_pg::DatabaseError>;
type Rejection = ErrorWrapper<mas_storage::RepositoryError>;
async fn from_request_parts(
_parts: &mut axum::http::request::Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
let start = Instant::now();
let repo = PgRepository::from_pool(&state.pool).await?;
let repo = state.repository_factory.create().await?;
// Measure the time it took to create the connection
let duration = start.elapsed();
@@ -376,6 +382,6 @@ impl FromRequestParts<AppState> for BoxRepository {
histogram.record(duration_ms, &[]);
}
Ok(repo.boxed())
Ok(repo)
}
}

View File

@@ -18,7 +18,7 @@ use mas_handlers::{ActivityTracker, CookieManager, Limiter, MetadataCache};
use mas_listener::server::Server;
use mas_router::UrlBuilder;
use mas_storage::SystemClock;
use mas_storage_pg::MIGRATOR;
use mas_storage_pg::{PgRepositoryFactory, MIGRATOR};
use sqlx::migrate::Migrate;
use tracing::{Instrument, info, info_span, warn};
@@ -226,7 +226,7 @@ impl Options {
let state = {
let mut s = AppState {
pool,
repository_factory: PgRepositoryFactory::new(pool),
templates,
key_store,
cookie_manager,

View File

@@ -178,7 +178,11 @@ pub(crate) mod repository;
pub(crate) mod tracing;
pub(crate) use self::errors::DatabaseInconsistencyError;
pub use self::{errors::DatabaseError, repository::PgRepository, tracing::ExecuteExt};
pub use self::{
errors::DatabaseError,
repository::{PgRepository, PgRepositoryFactory},
tracing::ExecuteExt,
};
/// Embedded migrations, allowing them to run on startup
pub static MIGRATOR: Migrator = {

View File

@@ -6,9 +6,11 @@
use std::ops::{Deref, DerefMut};
use async_trait::async_trait;
use futures_util::{FutureExt, TryFutureExt, future::BoxFuture};
use mas_storage::{
BoxRepository, MapErr, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction,
BoxRepository, BoxRepositoryFactory, MapErr, Repository, RepositoryAccess, RepositoryError,
RepositoryFactory, RepositoryTransaction,
app_session::AppSessionRepository,
compat::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
@@ -57,6 +59,43 @@ use crate::{
},
};
/// An implementation of the [`RepositoryFactory`] trait backed by a PostgreSQL
/// connection pool.
#[derive(Clone)]
pub struct PgRepositoryFactory {
pool: PgPool,
}
impl PgRepositoryFactory {
/// Create a new [`PgRepositoryFactory`] from a PostgreSQL connection pool.
#[must_use]
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
/// Box the factory
#[must_use]
pub fn boxed(self) -> BoxRepositoryFactory {
Box::new(self)
}
/// Get the underlying PostgreSQL connection pool
#[must_use]
pub fn pool(&self) -> PgPool {
self.pool.clone()
}
}
#[async_trait]
impl RepositoryFactory for PgRepositoryFactory {
async fn create(&self) -> Result<BoxRepository, RepositoryError> {
Ok(PgRepository::from_pool(&self.pool)
.await
.map_err(RepositoryError::from_error)?
.boxed())
}
}
/// An implementation of the [`Repository`] trait backed by a PostgreSQL
/// transaction.
pub struct PgRepository<C = Transaction<'static, Postgres>> {

View File

@@ -128,7 +128,8 @@ pub use self::{
clock::{Clock, SystemClock},
pagination::{Page, Pagination},
repository::{
BoxRepository, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction,
BoxRepository, BoxRepositoryFactory, Repository, RepositoryAccess, RepositoryError,
RepositoryFactory, RepositoryTransaction,
},
utils::{BoxClock, BoxRng, MapErr},
};

View File

@@ -4,6 +4,7 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use async_trait::async_trait;
use futures_util::future::BoxFuture;
use thiserror::Error;
@@ -29,6 +30,17 @@ use crate::{
},
};
/// A [`RepositoryFactory`] is a factory that can create a [`BoxRepository`]
// XXX(quenting): this could be generic over the repository type, but it's annoying to make it dyn-safe
#[async_trait]
pub trait RepositoryFactory {
/// Create a new [`BoxRepository`]
async fn create(&self) -> Result<BoxRepository, RepositoryError>;
}
/// A type-erased [`RepositoryFactory`]
pub type BoxRepositoryFactory = Box<dyn RepositoryFactory + Send + Sync + 'static>;
/// A [`Repository`] helps interacting with the underlying storage backend.
pub trait Repository<E>:
RepositoryAccess<Error = E> + RepositoryTransaction<Error = E> + Send