From 7ee32e796aca426423742a24d84bb2c97e7194e1 Mon Sep 17 00:00:00 2001 From: Olivier 'reivilibre Date: Thu, 6 Nov 2025 08:17:03 +0000 Subject: [PATCH] Add session limit config to policy data --- crates/cli/src/commands/debug.rs | 9 +++++++-- crates/cli/src/commands/server.rs | 4 +++- crates/cli/src/util.rs | 14 ++++++++++++-- crates/handlers/src/test_utils.rs | 2 +- crates/policy/src/lib.rs | 24 +++++++++++++++--------- 5 files changed, 38 insertions(+), 15 deletions(-) diff --git a/crates/cli/src/commands/debug.rs b/crates/cli/src/commands/debug.rs index bb87c5e81..6da64f95b 100644 --- a/crates/cli/src/commands/debug.rs +++ b/crates/cli/src/commands/debug.rs @@ -9,7 +9,8 @@ use std::process::ExitCode; use clap::Parser; use figment::Figment; use mas_config::{ - ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, PolicyConfig, + ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, ExperimentalConfig, + MatrixConfig, PolicyConfig, }; use mas_storage_pg::PgRepositoryFactory; use tracing::{info, info_span}; @@ -45,8 +46,12 @@ impl Options { PolicyConfig::extract_or_default(figment).map_err(anyhow::Error::from_boxed)?; let matrix_config = MatrixConfig::extract(figment).map_err(anyhow::Error::from_boxed)?; + let experimental_config = + ExperimentalConfig::extract(figment).map_err(anyhow::Error::from_boxed)?; info!("Loading and compiling the policy module"); - let policy_factory = policy_factory_from_config(&config, &matrix_config).await?; + let policy_factory = + policy_factory_from_config(&config, &matrix_config, &experimental_config) + .await?; if with_dynamic_data { let database_config = diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 52465f077..020d24d0f 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -132,7 +132,9 @@ impl Options { // Load and compile the WASM policies (and fallback to the default embedded one) info!("Loading and compiling the policy module"); - let policy_factory = policy_factory_from_config(&config.policy, &config.matrix).await?; + let policy_factory = + policy_factory_from_config(&config.policy, &config.matrix, &config.experimental) + .await?; let policy_factory = Arc::new(policy_factory); load_policy_factory_dynamic_data_continuously( diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index eefbb5b0e..a9b9a3132 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -135,6 +135,7 @@ pub fn test_mailer_in_background(mailer: &Mailer, timeout: Duration) { pub async fn policy_factory_from_config( config: &PolicyConfig, matrix_config: &MatrixConfig, + experimental_config: &ExperimentalConfig, ) -> Result { let policy_file = tokio::fs::File::open(&config.wasm_module) .await @@ -147,8 +148,17 @@ pub async fn policy_factory_from_config( email: config.email_entrypoint.clone(), }; - let data = - mas_policy::Data::new(matrix_config.homeserver.clone()).with_rest(config.data.clone()); + let session_limit_config = + experimental_config + .session_limit + .as_ref() + .map(|c| SessionLimitConfig { + soft_limit: c.soft_limit, + hard_limit: c.hard_limit, + }); + + let data = mas_policy::Data::new(matrix_config.homeserver.clone(), session_limit_config) + .with_rest(config.data.clone()); PolicyFactory::load(policy_file, data, entrypoints) .await diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index df60c5c20..cf0466a9c 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -85,7 +85,7 @@ pub(crate) async fn policy_factory( email: "email/violation".to_owned(), }; - let data = mas_policy::Data::new(server_name.to_owned()).with_rest(data); + let data = mas_policy::Data::new(server_name.to_owned(), None).with_rest(data); let policy_factory = PolicyFactory::load(file, data, entrypoints).await?; let policy_factory = Arc::new(policy_factory); diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index 3a3a23c3f..b5b187e1e 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -9,11 +9,12 @@ pub mod model; use std::sync::Arc; use arc_swap::ArcSwap; -use mas_data_model::Ulid; +use mas_data_model::{SessionLimitConfig, Ulid}; use opa_wasm::{ Runtime, wasmtime::{Config, Engine, Module, OptLevel, Store}, }; +use serde::Serialize; use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt}; @@ -85,18 +86,25 @@ impl Entrypoints { } } -#[derive(Debug)] +#[derive(Serialize, Debug)] pub struct Data { server_name: String, + /// Limits on the number of application sessions that each user can have + session_limit: Option, + + // We will merge this in a custom way, so don't emit as part of the base + #[serde(skip)] rest: Option, } impl Data { #[must_use] - pub fn new(server_name: String) -> Self { + pub fn new(server_name: String, session_limit: Option) -> Self { Self { server_name, + session_limit, + rest: None, } } @@ -108,9 +116,7 @@ impl Data { } fn to_value(&self) -> Result { - let base = serde_json::json!({ - "server_name": self.server_name, - }); + let base = serde_json::to_value(self)?; if let Some(rest) = &self.rest { merge_data(base, rest.clone()) @@ -458,7 +464,7 @@ mod tests { #[tokio::test] async fn test_register() { - let data = Data::new("example.com".to_owned()).with_rest(serde_json::json!({ + let data = Data::new("example.com".to_owned(), None).with_rest(serde_json::json!({ "allowed_domains": ["element.io", "*.element.io"], "banned_domains": ["staging.element.io"], })); @@ -528,7 +534,7 @@ mod tests { #[tokio::test] async fn test_dynamic_data() { - let data = Data::new("example.com".to_owned()); + let data = Data::new("example.com".to_owned(), None); #[allow(clippy::disallowed_types)] let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) @@ -597,7 +603,7 @@ mod tests { #[tokio::test] async fn test_big_dynamic_data() { - let data = Data::new("example.com".to_owned()); + let data = Data::new("example.com".to_owned(), None); #[allow(clippy::disallowed_types)] let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))