Introduce a RepositoryFactory
This commit is contained in:
@@ -19,8 +19,8 @@ use mas_keystore::{Encrypter, Keystore};
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use mas_policy::{Policy, PolicyFactory};
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::{BoxClock, BoxRepository, BoxRng, SystemClock};
|
||||
use mas_storage_pg::PgRepository;
|
||||
use mas_storage::{BoxClock, BoxRepository, BoxRepositoryFactory, BoxRng, SystemClock, RepositoryFactory};
|
||||
use mas_storage_pg::PgRepositoryFactory;
|
||||
use mas_templates::Templates;
|
||||
use opentelemetry::{KeyValue, metrics::Histogram};
|
||||
use rand::SeedableRng;
|
||||
@@ -31,7 +31,7 @@ use crate::telemetry::METER;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub pool: PgPool,
|
||||
pub repository_factory: PgRepositoryFactory,
|
||||
pub templates: Templates,
|
||||
pub key_store: Keystore,
|
||||
pub cookie_manager: CookieManager,
|
||||
@@ -53,7 +53,7 @@ pub struct AppState {
|
||||
impl AppState {
|
||||
/// Init the metrics for the app state.
|
||||
pub fn init_metrics(&mut self) {
|
||||
let pool = self.pool.clone();
|
||||
let pool = self.repository_factory.pool();
|
||||
METER
|
||||
.i64_observable_up_down_counter("db.connections.usage")
|
||||
.with_description("The number of connections that are currently in `state` described by the state attribute.")
|
||||
@@ -66,7 +66,7 @@ impl AppState {
|
||||
})
|
||||
.build();
|
||||
|
||||
let pool = self.pool.clone();
|
||||
let pool = self.repository_factory.pool();
|
||||
METER
|
||||
.i64_observable_up_down_counter("db.connections.max")
|
||||
.with_description("The maximum number of open connections allowed.")
|
||||
@@ -88,14 +88,14 @@ impl AppState {
|
||||
|
||||
/// Init the metadata cache in the background
|
||||
pub fn init_metadata_cache(&self) {
|
||||
let pool = self.pool.clone();
|
||||
let factory = self.repository_factory.clone();
|
||||
let metadata_cache = self.metadata_cache.clone();
|
||||
let http_client = self.http_client.clone();
|
||||
|
||||
tokio::spawn(
|
||||
LogContext::new("metadata-cache-warmup")
|
||||
.run(async move || {
|
||||
let conn = match pool.acquire().await {
|
||||
let mut repo = match factory.create().await {
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
@@ -106,8 +106,6 @@ impl AppState {
|
||||
}
|
||||
};
|
||||
|
||||
let mut repo = PgRepository::from_conn(conn);
|
||||
|
||||
if let Err(e) = metadata_cache
|
||||
.warm_up_and_run(
|
||||
&http_client,
|
||||
@@ -127,9 +125,17 @@ impl AppState {
|
||||
}
|
||||
}
|
||||
|
||||
// XXX(quenting): we only use this for the healthcheck endpoint, checking the db
|
||||
// should be part of the repository
|
||||
impl FromRef<AppState> for PgPool {
|
||||
fn from_ref(input: &AppState) -> Self {
|
||||
input.pool.clone()
|
||||
input.repository_factory.pool()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<AppState> for BoxRepositoryFactory {
|
||||
fn from_ref(input: &AppState) -> Self {
|
||||
input.repository_factory.clone().boxed()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -359,14 +365,14 @@ impl FromRequestParts<AppState> for RequesterFingerprint {
|
||||
}
|
||||
|
||||
impl FromRequestParts<AppState> for BoxRepository {
|
||||
type Rejection = ErrorWrapper<mas_storage_pg::DatabaseError>;
|
||||
type Rejection = ErrorWrapper<mas_storage::RepositoryError>;
|
||||
|
||||
async fn from_request_parts(
|
||||
_parts: &mut axum::http::request::Parts,
|
||||
state: &AppState,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
let start = Instant::now();
|
||||
let repo = PgRepository::from_pool(&state.pool).await?;
|
||||
let repo = state.repository_factory.create().await?;
|
||||
|
||||
// Measure the time it took to create the connection
|
||||
let duration = start.elapsed();
|
||||
@@ -376,6 +382,6 @@ impl FromRequestParts<AppState> for BoxRepository {
|
||||
histogram.record(duration_ms, &[]);
|
||||
}
|
||||
|
||||
Ok(repo.boxed())
|
||||
Ok(repo)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ use mas_handlers::{ActivityTracker, CookieManager, Limiter, MetadataCache};
|
||||
use mas_listener::server::Server;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::SystemClock;
|
||||
use mas_storage_pg::MIGRATOR;
|
||||
use mas_storage_pg::{PgRepositoryFactory, MIGRATOR};
|
||||
use sqlx::migrate::Migrate;
|
||||
use tracing::{Instrument, info, info_span, warn};
|
||||
|
||||
@@ -226,7 +226,7 @@ impl Options {
|
||||
|
||||
let state = {
|
||||
let mut s = AppState {
|
||||
pool,
|
||||
repository_factory: PgRepositoryFactory::new(pool),
|
||||
templates,
|
||||
key_store,
|
||||
cookie_manager,
|
||||
|
||||
@@ -178,7 +178,11 @@ pub(crate) mod repository;
|
||||
pub(crate) mod tracing;
|
||||
|
||||
pub(crate) use self::errors::DatabaseInconsistencyError;
|
||||
pub use self::{errors::DatabaseError, repository::PgRepository, tracing::ExecuteExt};
|
||||
pub use self::{
|
||||
errors::DatabaseError,
|
||||
repository::{PgRepository, PgRepositoryFactory},
|
||||
tracing::ExecuteExt,
|
||||
};
|
||||
|
||||
/// Embedded migrations, allowing them to run on startup
|
||||
pub static MIGRATOR: Migrator = {
|
||||
|
||||
@@ -6,9 +6,11 @@
|
||||
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{FutureExt, TryFutureExt, future::BoxFuture};
|
||||
use mas_storage::{
|
||||
BoxRepository, MapErr, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction,
|
||||
BoxRepository, BoxRepositoryFactory, MapErr, Repository, RepositoryAccess, RepositoryError,
|
||||
RepositoryFactory, RepositoryTransaction,
|
||||
app_session::AppSessionRepository,
|
||||
compat::{
|
||||
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
|
||||
@@ -57,6 +59,43 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
/// An implementation of the [`RepositoryFactory`] trait backed by a PostgreSQL
|
||||
/// connection pool.
|
||||
#[derive(Clone)]
|
||||
pub struct PgRepositoryFactory {
|
||||
pool: PgPool,
|
||||
}
|
||||
|
||||
impl PgRepositoryFactory {
|
||||
/// Create a new [`PgRepositoryFactory`] from a PostgreSQL connection pool.
|
||||
#[must_use]
|
||||
pub fn new(pool: PgPool) -> Self {
|
||||
Self { pool }
|
||||
}
|
||||
|
||||
/// Box the factory
|
||||
#[must_use]
|
||||
pub fn boxed(self) -> BoxRepositoryFactory {
|
||||
Box::new(self)
|
||||
}
|
||||
|
||||
/// Get the underlying PostgreSQL connection pool
|
||||
#[must_use]
|
||||
pub fn pool(&self) -> PgPool {
|
||||
self.pool.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RepositoryFactory for PgRepositoryFactory {
|
||||
async fn create(&self) -> Result<BoxRepository, RepositoryError> {
|
||||
Ok(PgRepository::from_pool(&self.pool)
|
||||
.await
|
||||
.map_err(RepositoryError::from_error)?
|
||||
.boxed())
|
||||
}
|
||||
}
|
||||
|
||||
/// An implementation of the [`Repository`] trait backed by a PostgreSQL
|
||||
/// transaction.
|
||||
pub struct PgRepository<C = Transaction<'static, Postgres>> {
|
||||
|
||||
@@ -128,7 +128,8 @@ pub use self::{
|
||||
clock::{Clock, SystemClock},
|
||||
pagination::{Page, Pagination},
|
||||
repository::{
|
||||
BoxRepository, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction,
|
||||
BoxRepository, BoxRepositoryFactory, Repository, RepositoryAccess, RepositoryError,
|
||||
RepositoryFactory, RepositoryTransaction,
|
||||
},
|
||||
utils::{BoxClock, BoxRng, MapErr},
|
||||
};
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures_util::future::BoxFuture;
|
||||
use thiserror::Error;
|
||||
|
||||
@@ -29,6 +30,17 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
/// A [`RepositoryFactory`] is a factory that can create a [`BoxRepository`]
|
||||
// XXX(quenting): this could be generic over the repository type, but it's annoying to make it dyn-safe
|
||||
#[async_trait]
|
||||
pub trait RepositoryFactory {
|
||||
/// Create a new [`BoxRepository`]
|
||||
async fn create(&self) -> Result<BoxRepository, RepositoryError>;
|
||||
}
|
||||
|
||||
/// A type-erased [`RepositoryFactory`]
|
||||
pub type BoxRepositoryFactory = Box<dyn RepositoryFactory + Send + Sync + 'static>;
|
||||
|
||||
/// A [`Repository`] helps interacting with the underlying storage backend.
|
||||
pub trait Repository<E>:
|
||||
RepositoryAccess<Error = E> + RepositoryTransaction<Error = E> + Send
|
||||
|
||||
Reference in New Issue
Block a user