Box all the figment errors to avoid large enum differences

This commit is contained in:
Quentin Gliech
2025-07-16 19:08:49 +02:00
parent a51a697013
commit 62dcab9f75
24 changed files with 245 additions and 151 deletions

View File

@@ -72,7 +72,7 @@ impl Options {
SC::Dump { output } => {
let _span = info_span!("cli.config.dump").entered();
let config = RootConfig::extract(figment)?;
let config = RootConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
let config = serde_yaml::to_string(&config)?;
if let Some(output) = output {
@@ -88,7 +88,7 @@ impl Options {
SC::Check => {
let _span = info_span!("cli.config.check").entered();
let _config = RootConfig::extract(figment)?;
let _config = RootConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
info!("Configuration file looks good");
}
@@ -105,7 +105,8 @@ impl Options {
if !synapse_config.is_empty() {
info!("Adjusting MAS config to match Synapse config from {synapse_config:?}");
let synapse_config = syn2mas::synapse_config::Config::load(&synapse_config)?;
let synapse_config = syn2mas::synapse_config::Config::load(&synapse_config)
.map_err(anyhow::Error::from_boxed)?;
config = synapse_config.adjust_mas_config(config, &mut rng, clock.now());
}
@@ -121,7 +122,7 @@ impl Options {
}
SC::Sync { prune, dry_run } => {
let config = SyncConfig::extract(figment)?;
let config = SyncConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
let clock = SystemClock::default();
let encrypter = config.secrets.encrypter().await?;

View File

@@ -30,7 +30,8 @@ enum Subcommand {
impl Options {
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
let _span = info_span!("cli.database.migrate").entered();
let config = DatabaseConfig::extract_or_default(figment)?;
let config =
DatabaseConfig::extract_or_default(figment).map_err(anyhow::Error::from_boxed)?;
let mut conn = database_connection_from_config(&config).await?;
// Run pending migrations

View File

@@ -41,13 +41,16 @@ impl Options {
match self.subcommand {
SC::Policy { with_dynamic_data } => {
let _span = info_span!("cli.debug.policy").entered();
let config = PolicyConfig::extract_or_default(figment)?;
let matrix_config = MatrixConfig::extract(figment)?;
let config =
PolicyConfig::extract_or_default(figment).map_err(anyhow::Error::from_boxed)?;
let matrix_config =
MatrixConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
info!("Loading and compiling the policy module");
let policy_factory = policy_factory_from_config(&config, &matrix_config).await?;
if with_dynamic_data {
let database_config = DatabaseConfig::extract(figment)?;
let database_config =
DatabaseConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
let pool = database_pool_from_config(&database_config).await?;
let repository_factory = PgRepositoryFactory::new(pool.clone());
load_policy_factory_dynamic_data(&policy_factory, &repository_factory).await?;

View File

@@ -33,7 +33,7 @@ impl Options {
"💡 Running diagnostics, make sure that both MAS and Synapse are running, and that MAS is using the same configuration files as this tool."
);
let config = RootConfig::extract(figment)?;
let config = RootConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
// We'll need an HTTP client
let http_client = mas_http::reqwest_client();

View File

@@ -219,8 +219,10 @@ impl Options {
let _span =
info_span!("cli.manage.set_password", user.username = %username).entered();
let database_config = DatabaseConfig::extract_or_default(figment)?;
let passwords_config = PasswordsConfig::extract_or_default(figment)?;
let database_config = DatabaseConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let passwords_config = PasswordsConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let mut conn = database_connection_from_config(&database_config).await?;
let password_manager = password_manager_from_config(&passwords_config).await?;
@@ -260,7 +262,8 @@ impl Options {
)
.entered();
let database_config = DatabaseConfig::extract_or_default(figment)?;
let database_config = DatabaseConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let mut conn = database_connection_from_config(&database_config).await?;
let txn = conn.begin().await?;
let mut repo = PgRepository::from_conn(txn);
@@ -314,7 +317,8 @@ impl Options {
admin,
device_id,
} => {
let database_config = DatabaseConfig::extract_or_default(figment)?;
let database_config = DatabaseConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let mut conn = database_connection_from_config(&database_config).await?;
let txn = conn.begin().await?;
let mut repo = PgRepository::from_conn(txn);
@@ -372,7 +376,8 @@ impl Options {
(Some(_), true) => unreachable!(), // This should be handled by the clap group
};
let database_config = DatabaseConfig::extract_or_default(figment)?;
let database_config = DatabaseConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let mut conn = database_connection_from_config(&database_config).await?;
let txn = conn.begin().await?;
let mut repo = PgRepository::from_conn(txn);
@@ -399,7 +404,8 @@ impl Options {
SC::ProvisionAllUsers => {
let _span = info_span!("cli.manage.provision_all_users").entered();
let database_config = DatabaseConfig::extract_or_default(figment)?;
let database_config = DatabaseConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let mut conn = database_connection_from_config(&database_config).await?;
let mut txn = conn.begin().await?;
@@ -425,7 +431,8 @@ impl Options {
SC::KillSessions { username, dry_run } => {
let _span =
info_span!("cli.manage.kill_sessions", user.username = username).entered();
let database_config = DatabaseConfig::extract_or_default(figment)?;
let database_config = DatabaseConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let mut conn = database_connection_from_config(&database_config).await?;
let txn = conn.begin().await?;
let mut repo = PgRepository::from_conn(txn);
@@ -497,7 +504,8 @@ impl Options {
deactivate,
} => {
let _span = info_span!("cli.manage.lock_user", user.username = username).entered();
let config = DatabaseConfig::extract_or_default(figment)?;
let config = DatabaseConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let mut conn = database_connection_from_config(&config).await?;
let txn = conn.begin().await?;
let mut repo = PgRepository::from_conn(txn);
@@ -529,7 +537,8 @@ impl Options {
SC::UnlockUser { username } => {
let _span = info_span!("cli.manage.lock_user", user.username = username).entered();
let config = DatabaseConfig::extract_or_default(figment)?;
let config = DatabaseConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let mut conn = database_connection_from_config(&config).await?;
let txn = conn.begin().await?;
let mut repo = PgRepository::from_conn(txn);
@@ -562,9 +571,12 @@ impl Options {
ignore_password_complexity,
} => {
let http_client = mas_http::reqwest_client();
let password_config = PasswordsConfig::extract_or_default(figment)?;
let database_config = DatabaseConfig::extract_or_default(figment)?;
let matrix_config = MatrixConfig::extract(figment)?;
let password_config = PasswordsConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let database_config = DatabaseConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let matrix_config =
MatrixConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
let password_manager = password_manager_from_config(&password_config).await?;
let homeserver = homeserver_connection_from_config(&matrix_config, http_client);

View File

@@ -59,7 +59,7 @@ impl Options {
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
let span = info_span!("cli.run.init").entered();
let mut shutdown = LifecycleManager::new()?;
let config = AppConfig::extract(figment)?;
let config = AppConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
info!(version = crate::VERSION, "Starting up");
@@ -101,8 +101,10 @@ impl Options {
} else {
// Sync the configuration with the database
let mut conn = pool.acquire().await?;
let clients_config = ClientsConfig::extract_or_default(figment)?;
let upstream_oauth2_config = UpstreamOAuth2Config::extract_or_default(figment)?;
let clients_config =
ClientsConfig::extract_or_default(figment).map_err(anyhow::Error::from_boxed)?;
let upstream_oauth2_config = UpstreamOAuth2Config::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
crate::sync::config_sync(
upstream_oauth2_config,

View File

@@ -96,6 +96,7 @@ impl Options {
}
let synapse_config = synapse_config::Config::load(&self.synapse_configuration_files)
.map_err(anyhow::Error::from_boxed)
.context("Failed to load Synapse configuration")?;
// Establish a connection to Synapse's Postgres database
@@ -111,7 +112,8 @@ impl Options {
.await
.context("could not connect to Synapse Postgres database")?;
let config = DatabaseConfig::extract_or_default(figment)?;
let config =
DatabaseConfig::extract_or_default(figment).map_err(anyhow::Error::from_boxed)?;
let mut mas_connection = database_connection_from_config_with_options(
&config,
@@ -131,7 +133,7 @@ impl Options {
// First perform a config sync
// This is crucial to ensure we register upstream OAuth providers
// in the MAS database
let config = SyncConfig::extract(figment)?;
let config = SyncConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
let clock = SystemClock::default();
let encrypter = config.secrets.encrypter().await?;
@@ -213,7 +215,8 @@ impl Options {
Subcommand::Migrate { dry_run } => {
let provider_id_mappings: HashMap<String, Uuid> = {
let mas_oauth2 = UpstreamOAuth2Config::extract_or_default(figment)?;
let mas_oauth2 = UpstreamOAuth2Config::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
mas_oauth2
.providers
@@ -252,7 +255,8 @@ impl Options {
let occasional_progress_logger_task =
tokio::spawn(occasional_progress_logger(progress.clone()));
let mas_matrix = MatrixConfig::extract(figment)?;
let mas_matrix =
MatrixConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
syn2mas::migrate(
reader,
writer,

View File

@@ -37,13 +37,20 @@ impl Options {
SC::Check => {
let _span = info_span!("cli.templates.check").entered();
let template_config = TemplatesConfig::extract_or_default(figment)?;
let branding_config = BrandingConfig::extract_or_default(figment)?;
let matrix_config = MatrixConfig::extract(figment)?;
let experimental_config = ExperimentalConfig::extract_or_default(figment)?;
let password_config = PasswordsConfig::extract_or_default(figment)?;
let account_config = AccountConfig::extract_or_default(figment)?;
let captcha_config = CaptchaConfig::extract_or_default(figment)?;
let template_config = TemplatesConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let branding_config = BrandingConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let matrix_config =
MatrixConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
let experimental_config = ExperimentalConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let password_config = PasswordsConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let account_config = AccountConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let captcha_config = CaptchaConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let clock = SystemClock::default();
// XXX: we should disallow SeedableRng::from_entropy

View File

@@ -29,7 +29,7 @@ impl Options {
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
let shutdown = LifecycleManager::new()?;
let span = info_span!("cli.worker.init").entered();
let config = AppConfig::extract(figment)?;
let config = AppConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
// Connect to the database
info!("Connecting to the database");

View File

@@ -115,8 +115,9 @@ async fn try_main() -> anyhow::Result<ExitCode> {
// Load the base configuration files
let figment = opts.figment();
let telemetry_config =
TelemetryConfig::extract_or_default(&figment).context("Failed to load telemetry config")?;
let telemetry_config = TelemetryConfig::extract_or_default(&figment)
.map_err(anyhow::Error::from_boxed)
.context("Failed to load telemetry config")?;
// Setup Sentry
let sentry = sentry::init((

View File

@@ -51,7 +51,10 @@ impl CaptchaConfig {
impl ConfigurationSection for CaptchaConfig {
const PATH: Option<&'static str> = Some("captcha");
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::Error> {
fn validate(
&self,
figment: &figment::Figment,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
let metadata = figment.find_metadata(Self::PATH.unwrap());
let error_on_field = |mut error: figment::error::Error, field: &'static str| {
@@ -67,11 +70,11 @@ impl ConfigurationSection for CaptchaConfig {
if let Some(CaptchaServiceKind::RecaptchaV2) = self.service {
if self.site_key.is_none() {
return Err(missing_field("site_key"));
return Err(missing_field("site_key").into());
}
if self.secret_key.is_none() {
return Err(missing_field("secret_key"));
return Err(missing_field("secret_key").into());
}
}

View File

@@ -6,7 +6,6 @@
use std::ops::Deref;
use figment::Figment;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::jwk::PublicJsonWebKeySet;
use schemars::JsonSchema;
@@ -104,7 +103,7 @@ pub struct ClientConfig {
}
impl ClientConfig {
fn validate(&self) -> Result<(), figment::error::Error> {
fn validate(&self) -> Result<(), Box<figment::error::Error>> {
let auth_method = self.client_auth_method;
match self.client_auth_method {
ClientAuthMethodConfig::PrivateKeyJwt => {
@@ -112,20 +111,20 @@ impl ClientConfig {
let error = figment::error::Error::custom(
"jwks or jwks_uri is required for private_key_jwt",
);
return Err(error.with_path("client_auth_method"));
return Err(Box::new(error.with_path("client_auth_method")));
}
if self.jwks.is_some() && self.jwks_uri.is_some() {
let error =
figment::error::Error::custom("jwks and jwks_uri are mutually exclusive");
return Err(error.with_path("jwks"));
return Err(Box::new(error.with_path("jwks")));
}
if self.client_secret.is_some() {
let error = figment::error::Error::custom(
"client_secret is not allowed with private_key_jwt",
);
return Err(error.with_path("client_secret"));
return Err(Box::new(error.with_path("client_secret")));
}
}
@@ -136,21 +135,21 @@ impl ClientConfig {
let error = figment::error::Error::custom(format!(
"client_secret is required for {auth_method}"
));
return Err(error.with_path("client_auth_method"));
return Err(Box::new(error.with_path("client_auth_method")));
}
if self.jwks.is_some() {
let error = figment::error::Error::custom(format!(
"jwks is not allowed with {auth_method}"
));
return Err(error.with_path("jwks"));
return Err(Box::new(error.with_path("jwks")));
}
if self.jwks_uri.is_some() {
let error = figment::error::Error::custom(format!(
"jwks_uri is not allowed with {auth_method}"
));
return Err(error.with_path("jwks_uri"));
return Err(Box::new(error.with_path("jwks_uri")));
}
}
@@ -159,21 +158,21 @@ impl ClientConfig {
let error = figment::error::Error::custom(
"client_secret is not allowed with none authentication method",
);
return Err(error.with_path("client_secret"));
return Err(Box::new(error.with_path("client_secret")));
}
if self.jwks.is_some() {
let error = figment::error::Error::custom(
"jwks is not allowed with none authentication method",
);
return Err(error);
return Err(Box::new(error));
}
if self.jwks_uri.is_some() {
let error = figment::error::Error::custom(
"jwks_uri is not allowed with none authentication method",
);
return Err(error);
return Err(Box::new(error));
}
}
}
@@ -232,7 +231,10 @@ impl IntoIterator for ClientsConfig {
impl ConfigurationSection for ClientsConfig {
const PATH: Option<&'static str> = Some("clients");
fn validate(&self, figment: &Figment) -> Result<(), figment::error::Error> {
fn validate(
&self,
figment: &figment::Figment,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
for (index, client) in self.0.iter().enumerate() {
client.validate().map_err(|mut err| {
// Save the error location information in the error

View File

@@ -222,13 +222,16 @@ pub struct DatabaseConfig {
impl ConfigurationSection for DatabaseConfig {
const PATH: Option<&'static str> = Some("database");
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::error::Error> {
fn validate(
&self,
figment: &figment::Figment,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
let metadata = figment.find_metadata(Self::PATH.unwrap());
let annotate = |mut error: figment::Error| {
error.metadata = metadata.cloned();
error.profile = Some(figment::Profile::Default);
error.path = vec![Self::PATH.unwrap().to_owned()];
Err(error)
error
};
// Check that the user did not specify both `uri` and the split options at the
@@ -241,37 +244,41 @@ impl ConfigurationSection for DatabaseConfig {
|| self.database.is_some();
if self.uri.is_some() && has_split_options {
return annotate(figment::error::Error::from(
return Err(annotate(figment::error::Error::from(
"uri must not be specified if host, port, socket, username, password, or database are specified".to_owned(),
));
)).into());
}
if self.ssl_ca.is_some() && self.ssl_ca_file.is_some() {
return annotate(figment::error::Error::from(
return Err(annotate(figment::error::Error::from(
"ssl_ca must not be specified if ssl_ca_file is specified".to_owned(),
));
))
.into());
}
if self.ssl_certificate.is_some() && self.ssl_certificate_file.is_some() {
return annotate(figment::error::Error::from(
return Err(annotate(figment::error::Error::from(
"ssl_certificate must not be specified if ssl_certificate_file is specified"
.to_owned(),
));
))
.into());
}
if self.ssl_key.is_some() && self.ssl_key_file.is_some() {
return annotate(figment::error::Error::from(
return Err(annotate(figment::error::Error::from(
"ssl_key must not be specified if ssl_key_file is specified".to_owned(),
));
))
.into());
}
if (self.ssl_key.is_some() || self.ssl_key_file.is_some())
^ (self.ssl_certificate.is_some() || self.ssl_certificate_file.is_some())
{
return annotate(figment::error::Error::from(
return Err(annotate(figment::error::Error::from(
"both a ssl_certificate and a ssl_key must be set at the same time or none of them"
.to_owned(),
));
))
.into());
}
Ok(())

View File

@@ -175,7 +175,10 @@ impl Default for EmailConfig {
impl ConfigurationSection for EmailConfig {
const PATH: Option<&'static str> = Some("email");
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::error::Error> {
fn validate(
&self,
figment: &figment::Figment,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
let metadata = figment.find_metadata(Self::PATH.unwrap());
let error_on_field = |mut error: figment::error::Error, field: &'static str| {
@@ -201,29 +204,29 @@ impl ConfigurationSection for EmailConfig {
EmailTransportKind::Smtp => {
if let Err(e) = Mailbox::from_str(&self.from) {
return Err(error_on_field(figment::error::Error::custom(e), "from"));
return Err(error_on_field(figment::error::Error::custom(e), "from").into());
}
if let Err(e) = Mailbox::from_str(&self.reply_to) {
return Err(error_on_field(figment::error::Error::custom(e), "reply_to"));
return Err(error_on_field(figment::error::Error::custom(e), "reply_to").into());
}
match (self.username.is_some(), self.password.is_some()) {
(true, true) | (false, false) => {}
(true, false) => {
return Err(missing_field("password"));
return Err(missing_field("password").into());
}
(false, true) => {
return Err(missing_field("username"));
return Err(missing_field("username").into());
}
}
if self.mode.is_none() {
return Err(missing_field("mode"));
return Err(missing_field("mode").into());
}
if self.hostname.is_none() {
return Err(missing_field("hostname"));
return Err(missing_field("hostname").into());
}
if self.command.is_some() {
@@ -239,7 +242,8 @@ impl ConfigurationSection for EmailConfig {
"username",
"password",
],
));
)
.into());
}
}
@@ -247,35 +251,35 @@ impl ConfigurationSection for EmailConfig {
let expected_fields = &["from", "reply_to", "transport", "command"];
if let Err(e) = Mailbox::from_str(&self.from) {
return Err(error_on_field(figment::error::Error::custom(e), "from"));
return Err(error_on_field(figment::error::Error::custom(e), "from").into());
}
if let Err(e) = Mailbox::from_str(&self.reply_to) {
return Err(error_on_field(figment::error::Error::custom(e), "reply_to"));
return Err(error_on_field(figment::error::Error::custom(e), "reply_to").into());
}
if self.command.is_none() {
return Err(missing_field("command"));
return Err(missing_field("command").into());
}
if self.mode.is_some() {
return Err(unexpected_field("mode", expected_fields));
return Err(unexpected_field("mode", expected_fields).into());
}
if self.hostname.is_some() {
return Err(unexpected_field("hostname", expected_fields));
return Err(unexpected_field("hostname", expected_fields).into());
}
if self.port.is_some() {
return Err(unexpected_field("port", expected_fields));
return Err(unexpected_field("port", expected_fields).into());
}
if self.username.is_some() {
return Err(unexpected_field("username", expected_fields));
return Err(unexpected_field("username", expected_fields).into());
}
if self.password.is_some() {
return Err(unexpected_field("password", expected_fields));
return Err(unexpected_field("password", expected_fields).into());
}
}
}

View File

@@ -412,7 +412,10 @@ impl Default for HttpConfig {
impl ConfigurationSection for HttpConfig {
const PATH: Option<&'static str> = Some("http");
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::Error> {
fn validate(
&self,
figment: &figment::Figment,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
for (index, listener) in self.listeners.iter().enumerate() {
let annotate = |mut error: figment::Error| {
error.metadata = figment
@@ -424,49 +427,57 @@ impl ConfigurationSection for HttpConfig {
"listeners".to_owned(),
index.to_string(),
];
Err(error)
error
};
if listener.resources.is_empty() {
return annotate(figment::Error::from("listener has no resources".to_owned()));
return Err(
annotate(figment::Error::from("listener has no resources".to_owned())).into(),
);
}
if listener.binds.is_empty() {
return annotate(figment::Error::from(
return Err(annotate(figment::Error::from(
"listener does not bind to any address".to_owned(),
));
))
.into());
}
if let Some(tls_config) = &listener.tls {
if tls_config.certificate.is_some() && tls_config.certificate_file.is_some() {
return annotate(figment::Error::from(
return Err(annotate(figment::Error::from(
"Only one of `certificate` or `certificate_file` can be set at a time"
.to_owned(),
));
))
.into());
}
if tls_config.certificate.is_none() && tls_config.certificate_file.is_none() {
return annotate(figment::Error::from(
return Err(annotate(figment::Error::from(
"TLS configuration is missing a certificate".to_owned(),
));
))
.into());
}
if tls_config.key.is_some() && tls_config.key_file.is_some() {
return annotate(figment::Error::from(
return Err(annotate(figment::Error::from(
"Only one of `key` or `key_file` can be set at a time".to_owned(),
));
))
.into());
}
if tls_config.key.is_none() && tls_config.key_file.is_none() {
return annotate(figment::Error::from(
return Err(annotate(figment::Error::from(
"TLS configuration is missing a private key".to_owned(),
));
))
.into());
}
if tls_config.password.is_some() && tls_config.password_file.is_some() {
return annotate(figment::Error::from(
return Err(annotate(figment::Error::from(
"Only one of `password` or `password_file` can be set at a time".to_owned(),
));
))
.into());
}
}
}

View File

@@ -130,7 +130,10 @@ pub struct RootConfig {
}
impl ConfigurationSection for RootConfig {
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::Error> {
fn validate(
&self,
figment: &figment::Figment,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
self.clients.validate(figment)?;
self.http.validate(figment)?;
self.database.validate(figment)?;
@@ -249,7 +252,10 @@ pub struct AppConfig {
}
impl ConfigurationSection for AppConfig {
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::Error> {
fn validate(
&self,
figment: &figment::Figment,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
self.http.validate(figment)?;
self.database.validate(figment)?;
self.templates.validate(figment)?;
@@ -285,7 +291,10 @@ pub struct SyncConfig {
}
impl ConfigurationSection for SyncConfig {
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::Error> {
fn validate(
&self,
figment: &figment::Figment,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
self.database.validate(figment)?;
self.secrets.validate(figment)?;
self.clients.validate(figment)?;

View File

@@ -72,12 +72,15 @@ impl Default for PasswordsConfig {
impl ConfigurationSection for PasswordsConfig {
const PATH: Option<&'static str> = Some("passwords");
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::Error> {
fn validate(
&self,
figment: &figment::Figment,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
let annotate = |mut error: figment::Error| {
error.metadata = figment.find_metadata(Self::PATH.unwrap()).cloned();
error.profile = Some(figment::Profile::Default);
error.path = vec![Self::PATH.unwrap().to_owned()];
Err(error)
error
};
if !self.enabled {
@@ -86,16 +89,18 @@ impl ConfigurationSection for PasswordsConfig {
}
if self.schemes.is_empty() {
return annotate(figment::Error::from(
return Err(annotate(figment::Error::from(
"Requires at least one password scheme in the config".to_owned(),
));
))
.into());
}
for scheme in &self.schemes {
if scheme.secret.is_some() && scheme.secret_file.is_some() {
return annotate(figment::Error::from(
return Err(annotate(figment::Error::from(
"Cannot specify both `secret` and `secret_file`".to_owned(),
));
))
.into());
}
}

View File

@@ -117,7 +117,10 @@ pub struct RateLimiterConfiguration {
impl ConfigurationSection for RateLimitingConfig {
const PATH: Option<&'static str> = Some("rate_limiting");
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::Error> {
fn validate(
&self,
figment: &figment::Figment,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
let metadata = figment.find_metadata(Self::PATH.unwrap());
let error_on_field = |mut error: figment::error::Error, field: &'static str| {
@@ -154,25 +157,21 @@ impl ConfigurationSection for RateLimitingConfig {
};
if let Some(error) = error_on_limiter(&self.account_recovery.per_ip) {
return Err(error_on_nested_field(error, "account_recovery", "per_ip"));
return Err(error_on_nested_field(error, "account_recovery", "per_ip").into());
}
if let Some(error) = error_on_limiter(&self.account_recovery.per_address) {
return Err(error_on_nested_field(
error,
"account_recovery",
"per_address",
));
return Err(error_on_nested_field(error, "account_recovery", "per_address").into());
}
if let Some(error) = error_on_limiter(&self.registration) {
return Err(error_on_field(error, "registration"));
return Err(error_on_field(error, "registration").into());
}
if let Some(error) = error_on_limiter(&self.login.per_ip) {
return Err(error_on_nested_field(error, "login", "per_ip"));
return Err(error_on_nested_field(error, "login", "per_ip").into());
}
if let Some(error) = error_on_limiter(&self.login.per_account) {
return Err(error_on_nested_field(error, "login", "per_account"));
return Err(error_on_nested_field(error, "login", "per_account").into());
}
Ok(())

View File

@@ -194,13 +194,17 @@ impl TelemetryConfig {
impl ConfigurationSection for TelemetryConfig {
const PATH: Option<&'static str> = Some("telemetry");
fn validate(&self, _figment: &figment::Figment) -> Result<(), figment::Error> {
fn validate(
&self,
_figment: &figment::Figment,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
if let Some(sample_rate) = self.sentry.sample_rate {
if !(0.0..=1.0).contains(&sample_rate) {
return Err(figment::error::Error::custom(
"Sentry sample rate must be between 0.0 and 1.0",
)
.with_path("sentry.sample_rate"));
.with_path("sentry.sample_rate")
.into());
}
}
@@ -209,7 +213,8 @@ impl ConfigurationSection for TelemetryConfig {
return Err(figment::error::Error::custom(
"Sentry sample rate must be between 0.0 and 1.0",
)
.with_path("sentry.traces_sample_rate"));
.with_path("sentry.traces_sample_rate")
.into());
}
}
@@ -218,7 +223,8 @@ impl ConfigurationSection for TelemetryConfig {
return Err(figment::error::Error::custom(
"Tracing sample rate must be between 0.0 and 1.0",
)
.with_path("tracing.sample_rate"));
.with_path("tracing.sample_rate")
.into());
}
}

View File

@@ -33,7 +33,10 @@ impl UpstreamOAuth2Config {
impl ConfigurationSection for UpstreamOAuth2Config {
const PATH: Option<&'static str> = Some("upstream_oauth2");
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::Error> {
fn validate(
&self,
figment: &figment::Figment,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
for (index, provider) in self.providers.iter().enumerate() {
let annotate = |mut error: figment::Error| {
error.metadata = figment
@@ -45,15 +48,16 @@ impl ConfigurationSection for UpstreamOAuth2Config {
"providers".to_owned(),
index.to_string(),
];
Err(error)
error
};
if !matches!(provider.discovery_mode, DiscoveryMode::Disabled)
&& provider.issuer.is_none()
{
return annotate(figment::Error::custom(
return Err(annotate(figment::Error::custom(
"The `issuer` field is required when discovery is enabled",
));
))
.into());
}
match provider.token_endpoint_auth_method {
@@ -61,16 +65,16 @@ impl ConfigurationSection for UpstreamOAuth2Config {
| TokenAuthMethod::PrivateKeyJwt
| TokenAuthMethod::SignInWithApple => {
if provider.client_secret.is_some() {
return annotate(figment::Error::custom(
return Err(annotate(figment::Error::custom(
"Unexpected field `client_secret` for the selected authentication method",
));
)).into());
}
}
TokenAuthMethod::ClientSecretBasic
| TokenAuthMethod::ClientSecretPost
| TokenAuthMethod::ClientSecretJwt => {
if provider.client_secret.is_none() {
return annotate(figment::Error::missing_field("client_secret"));
return Err(annotate(figment::Error::missing_field("client_secret")).into());
}
}
}
@@ -81,16 +85,17 @@ impl ConfigurationSection for UpstreamOAuth2Config {
| TokenAuthMethod::ClientSecretPost
| TokenAuthMethod::SignInWithApple => {
if provider.token_endpoint_auth_signing_alg.is_some() {
return annotate(figment::Error::custom(
return Err(annotate(figment::Error::custom(
"Unexpected field `token_endpoint_auth_signing_alg` for the selected authentication method",
));
)).into());
}
}
TokenAuthMethod::ClientSecretJwt | TokenAuthMethod::PrivateKeyJwt => {
if provider.token_endpoint_auth_signing_alg.is_none() {
return annotate(figment::Error::missing_field(
return Err(annotate(figment::Error::missing_field(
"token_endpoint_auth_signing_alg",
));
))
.into());
}
}
}
@@ -98,15 +103,17 @@ impl ConfigurationSection for UpstreamOAuth2Config {
match provider.token_endpoint_auth_method {
TokenAuthMethod::SignInWithApple => {
if provider.sign_in_with_apple.is_none() {
return annotate(figment::Error::missing_field("sign_in_with_apple"));
return Err(
annotate(figment::Error::missing_field("sign_in_with_apple")).into(),
);
}
}
_ => {
if provider.sign_in_with_apple.is_some() {
return annotate(figment::Error::custom(
return Err(annotate(figment::Error::custom(
"Unexpected field `sign_in_with_apple` for the selected authentication method",
));
)).into());
}
}
}

View File

@@ -4,7 +4,7 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use figment::{Figment, error::Error as FigmentError};
use figment::Figment;
use serde::de::DeserializeOwned;
/// Trait implemented by all configuration section to help loading specific part
@@ -18,7 +18,10 @@ pub trait ConfigurationSection: Sized + DeserializeOwned {
/// # Errors
///
/// Returns an error if the configuration is invalid
fn validate(&self, _figment: &Figment) -> Result<(), FigmentError> {
fn validate(
&self,
_figment: &Figment,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
Ok(())
}
@@ -27,7 +30,9 @@ pub trait ConfigurationSection: Sized + DeserializeOwned {
/// # Errors
///
/// Returns an error if the configuration could not be loaded
fn extract(figment: &Figment) -> Result<Self, FigmentError> {
fn extract(
figment: &Figment,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync + 'static>> {
let this: Self = if let Some(path) = Self::PATH {
figment.extract_inner(path)?
} else {
@@ -49,7 +54,9 @@ pub trait ConfigurationSectionExt: ConfigurationSection + Default {
/// # Errors
///
/// Returns an error if the configuration section is invalid.
fn extract_or_default(figment: &Figment) -> Result<Self, figment::Error> {
fn extract_or_default(
figment: &Figment,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync + 'static>> {
let this: Self = if let Some(path) = Self::PATH {
// If the configuration section is not present, we return the default value
if !figment.contains(path) {

View File

@@ -78,7 +78,7 @@ pub struct EndOAuth2SessionInput {
/// The payload of the `endOauth2Session` mutation.
pub enum EndOAuth2SessionPayload {
NotFound,
Ended(mas_data_model::Session),
Ended(Box<mas_data_model::Session>),
}
/// The status of the `endOauth2Session` mutation.
@@ -104,7 +104,7 @@ impl EndOAuth2SessionPayload {
/// Returns the ended session.
async fn oauth2_session(&self) -> Option<OAuth2Session> {
match self {
Self::Ended(session) => Some(OAuth2Session(session.clone())),
Self::Ended(session) => Some(OAuth2Session(*session.clone())),
Self::NotFound => None,
}
}
@@ -126,7 +126,7 @@ pub enum SetOAuth2SessionNamePayload {
NotFound,
/// The session was updated.
Updated(mas_data_model::Session),
Updated(Box<mas_data_model::Session>),
}
/// The status of the `setOauth2SessionName` mutation.
@@ -152,7 +152,7 @@ impl SetOAuth2SessionNamePayload {
/// The session that was updated.
async fn oauth2_session(&self) -> Option<OAuth2Session> {
match self {
Self::Updated(session) => Some(OAuth2Session(session.clone())),
Self::Updated(session) => Some(OAuth2Session(*session.clone())),
Self::NotFound => None,
}
}
@@ -293,7 +293,7 @@ impl OAuth2SessionMutations {
repo.save().await?;
Ok(EndOAuth2SessionPayload::Ended(session))
Ok(EndOAuth2SessionPayload::Ended(Box::new(session)))
}
async fn set_oauth2_session_name(
@@ -343,6 +343,6 @@ impl OAuth2SessionMutations {
repo.save().await?;
Ok(SetOAuth2SessionNamePayload::Updated(session))
Ok(SetOAuth2SessionNamePayload::Updated(Box::new(session)))
}
}

View File

@@ -25,7 +25,7 @@ pub enum Error {
Sqlx(#[from] sqlx::Error),
#[error("failed to load MAS config: {0}")]
MasConfig(#[from] figment::Error),
MasConfig(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("failed to load MAS password config: {0}")]
MasPasswordConfig(#[source] anyhow::Error),
@@ -188,13 +188,13 @@ pub async fn synapse_config_check_against_mas_config(
let mut errors = Vec::new();
let mut warnings = Vec::new();
let mas_passwords = PasswordsConfig::extract_or_default(mas)?;
let mas_passwords = PasswordsConfig::extract_or_default(mas).map_err(Error::MasConfig)?;
let mas_password_schemes = mas_passwords
.load()
.await
.map_err(Error::MasPasswordConfig)?;
let mas_matrix = MatrixConfig::extract(mas)?;
let mas_matrix = MatrixConfig::extract(mas).map_err(Error::MasConfig)?;
// Look for the MAS password hashing scheme that will be used for imported
// Synapse passwords, then check the configuration matches so that Synapse
@@ -230,12 +230,12 @@ pub async fn synapse_config_check_against_mas_config(
});
}
let mas_captcha = CaptchaConfig::extract_or_default(mas)?;
let mas_captcha = CaptchaConfig::extract_or_default(mas).map_err(Error::MasConfig)?;
if synapse.enable_registration_captcha && mas_captcha.service.is_none() {
warnings.push(CheckWarning::ShouldPortRegistrationCaptcha);
}
let mas_branding = BrandingConfig::extract_or_default(mas)?;
let mas_branding = BrandingConfig::extract_or_default(mas).map_err(Error::MasConfig)?;
if synapse.user_consent.is_some() && mas_branding.tos_uri.is_none() {
warnings.push(CheckWarning::ShouldPortUserConsentAsTerms);
}
@@ -295,7 +295,7 @@ pub async fn synapse_database_check(
.await?;
if !oauth_provider_user_counts.is_empty() {
let syn_oauth2 = synapse.all_oidc_providers();
let mas_oauth2 = UpstreamOAuth2Config::extract_or_default(mas)?;
let mas_oauth2 = UpstreamOAuth2Config::extract_or_default(mas).map_err(Error::MasConfig)?;
for row in oauth_provider_user_counts {
// This is a special case of a previous migration attempt to MAS
if row.auth_provider == "oauth-delegated" {

View File

@@ -94,7 +94,9 @@ impl Config {
///
/// - If there is a problem reading any of the files.
/// - If the configuration is not valid.
pub fn load(files: &[Utf8PathBuf]) -> Result<Config, figment::Error> {
pub fn load(
files: &[Utf8PathBuf],
) -> Result<Config, Box<dyn std::error::Error + Send + Sync + 'static>> {
let mut figment = figment::Figment::new();
for file in files {
// TODO this is not exactly correct behaviour — Synapse does not merge anything
@@ -103,7 +105,8 @@ impl Config {
// https://github.com/element-hq/synapse/blob/develop/synapse/config/_base.py?rgh-link-date=2025-01-20T17%3A02%3A56Z#L870
figment = figment.merge(Yaml::file(file));
}
figment.extract::<Config>()
let config = figment.extract::<Config>()?;
Ok(config)
}
/// Returns a map of all OIDC providers from the Synapse configuration.