From 1e3d838c99a5aecfb66acb5dba255450a9a6ce93 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 19 Dec 2024 12:12:34 +0100 Subject: [PATCH] Allow longer & shorter usernames, complying with the MXID length spec --- crates/cli/src/commands/debug.rs | 5 +++-- crates/cli/src/commands/server.rs | 2 +- crates/cli/src/util.rs | 6 ++++- crates/handlers/src/graphql/tests.rs | 18 ++++++++++----- crates/handlers/src/oauth2/token.rs | 9 +++++--- crates/handlers/src/test_utils.rs | 15 +++++++++---- crates/policy/src/lib.rs | 33 ++++++++++++++++++++++++---- policies/register/register.rego | 7 ++++-- policies/register/register_test.rego | 20 ++++++++++++++--- 9 files changed, 89 insertions(+), 26 deletions(-) diff --git a/crates/cli/src/commands/debug.rs b/crates/cli/src/commands/debug.rs index 8ffb50306..8768ae7b1 100644 --- a/crates/cli/src/commands/debug.rs +++ b/crates/cli/src/commands/debug.rs @@ -8,7 +8,7 @@ use std::process::ExitCode; use clap::Parser; use figment::Figment; -use mas_config::{ConfigurationSectionExt, PolicyConfig}; +use mas_config::{ConfigurationSection, ConfigurationSectionExt, MatrixConfig, PolicyConfig}; use tracing::{info, info_span}; use crate::util::policy_factory_from_config; @@ -33,8 +33,9 @@ impl Options { SC::Policy => { let _span = info_span!("cli.debug.policy").entered(); let config = PolicyConfig::extract_or_default(figment)?; + let matrix_config = MatrixConfig::extract(figment)?; info!("Loading and compiling the policy module"); - let policy_factory = policy_factory_from_config(&config).await?; + let policy_factory = policy_factory_from_config(&config, &matrix_config).await?; let _instance = policy_factory.instantiate().await?; } diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 9234e45bb..2b867ca86 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -123,7 +123,7 @@ impl Options { // Load and compile the WASM policies (and fallback to the default embedded one) info!("Loading and compiling the policy module"); - let policy_factory = policy_factory_from_config(&config.policy).await?; + let policy_factory = policy_factory_from_config(&config.policy, &config.matrix).await?; let policy_factory = Arc::new(policy_factory); let url_builder = UrlBuilder::new( diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index a4ab8eba6..3d3d8f676 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -101,6 +101,7 @@ pub fn mailer_from_config( pub async fn policy_factory_from_config( config: &PolicyConfig, + matrix_config: &MatrixConfig, ) -> Result { let policy_file = tokio::fs::File::open(&config.wasm_module) .await @@ -113,7 +114,10 @@ pub async fn policy_factory_from_config( email: config.email_entrypoint.clone(), }; - PolicyFactory::load(policy_file, config.data.clone(), entrypoints) + let data = + mas_policy::Data::new(matrix_config.homeserver.clone()).with_rest(config.data.clone()); + + PolicyFactory::load(policy_file, data, entrypoints) .await .context("failed to load the policy") } diff --git a/crates/handlers/src/graphql/tests.rs b/crates/handlers/src/graphql/tests.rs index 1d72eb66b..a830f2448 100644 --- a/crates/handlers/src/graphql/tests.rs +++ b/crates/handlers/src/graphql/tests.rs @@ -469,9 +469,12 @@ async fn test_oauth2_client_credentials(pool: PgPool) { // Now make the client admin and try again let state = { let mut state = state; - state.policy_factory = test_utils::policy_factory(serde_json::json!({ - "admin_clients": [client_id], - })) + state.policy_factory = test_utils::policy_factory( + "example.com", + serde_json::json!({ + "admin_clients": [client_id], + }), + ) .await .unwrap(); state @@ -593,9 +596,12 @@ async fn test_add_user(pool: PgPool) { // Make the client admin let state = { let mut state = state; - state.policy_factory = test_utils::policy_factory(serde_json::json!({ - "admin_clients": [client_id], - })) + state.policy_factory = test_utils::policy_factory( + "example.com", + serde_json::json!({ + "admin_clients": [client_id], + }), + ) .await .unwrap(); state diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index b5c3b0341..ad6a15618 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -1475,9 +1475,12 @@ mod tests { // Now, if we add the client to the admin list in the policy, it should work let state = { let mut state = state; - state.policy_factory = crate::test_utils::policy_factory(serde_json::json!({ - "admin_clients": [client_id] - })) + state.policy_factory = crate::test_utils::policy_factory( + "example.com", + serde_json::json!({ + "admin_clients": [client_id] + }), + ) .await .unwrap(); state diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index f6a16037a..5f7240dab 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -69,6 +69,7 @@ pub(crate) fn setup() { } pub(crate) async fn policy_factory( + server_name: &str, data: serde_json::Value, ) -> Result, anyhow::Error> { let workspace_root = camino::Utf8Path::new(env!("CARGO_MANIFEST_DIR")) @@ -84,6 +85,8 @@ pub(crate) async fn policy_factory( email: "email/violation".to_owned(), }; + let data = mas_policy::Data::new(server_name.to_owned()).with_rest(data); + let policy_factory = PolicyFactory::load(file, data, entrypoints).await?; let policy_factory = Arc::new(policy_factory); Ok(policy_factory) @@ -192,7 +195,8 @@ impl TestState { PasswordManager::disabled() }; - let policy_factory = policy_factory(serde_json::json!({})).await?; + let policy_factory = + policy_factory(&site_config.server_name, serde_json::json!({})).await?; let homeserver_connection = Arc::new(MockHomeserverConnection::new(&site_config.server_name)); @@ -297,9 +301,12 @@ impl TestState { // Make the client admin let state = { let mut state = self.clone(); - state.policy_factory = policy_factory(serde_json::json!({ - "admin_clients": [client_id], - })) + state.policy_factory = policy_factory( + "example.com", + serde_json::json!({ + "admin_clients": [client_id], + }), + ) .await .unwrap(); state diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index 9db450fca..950957692 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -12,6 +12,7 @@ use opa_wasm::{ wasmtime::{Config, Engine, Module, OptLevel, Store}, Runtime, }; +use serde::Serialize; use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt}; @@ -69,10 +70,34 @@ impl Entrypoints { } } +#[derive(Serialize, Debug)] +pub struct Data { + server_name: String, + + #[serde(flatten)] + rest: Option, +} + +impl Data { + #[must_use] + pub fn new(server_name: String) -> Self { + Self { + server_name, + rest: None, + } + } + + #[must_use] + pub fn with_rest(mut self, rest: serde_json::Value) -> Self { + self.rest = Some(rest); + self + } +} + pub struct PolicyFactory { engine: Engine, module: Module, - data: serde_json::Value, + data: Data, entrypoints: Entrypoints, } @@ -80,7 +105,7 @@ impl PolicyFactory { #[tracing::instrument(name = "policy.load", skip(source), err)] pub async fn load( mut source: impl AsyncRead + std::marker::Unpin, - data: serde_json::Value, + data: Data, entrypoints: Entrypoints, ) -> Result { let mut config = Config::default(); @@ -364,10 +389,10 @@ mod tests { #[tokio::test] async fn test_register() { - let data = serde_json::json!({ + let data = Data::new("example.com".to_owned()).with_rest(serde_json::json!({ "allowed_domains": ["element.io", "*.element.io"], "banned_domains": ["staging.element.io"], - }); + })); #[allow(clippy::disallowed_types)] let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) diff --git a/policies/register/register.rego b/policies/register/register.rego index 4507a10ca..1fb400aa5 100644 --- a/policies/register/register.rego +++ b/policies/register/register.rego @@ -13,14 +13,17 @@ allow if { count(violation) == 0 } +mxid(username, server_name) := sprintf("@%s:%s", [username, server_name]) + # METADATA # entrypoint: true violation contains {"field": "username", "msg": "username too short"} if { - count(input.username) <= 2 + count(input.username) == 0 } violation contains {"field": "username", "msg": "username too long"} if { - count(input.username) > 64 + user_id := mxid(input.username, data.server_name) + count(user_id) > 255 } violation contains {"field": "username", "msg": "username contains invalid characters"} if { diff --git a/policies/register/register_test.rego b/policies/register/register_test.rego index 0d270ec26..0199e2ad0 100644 --- a/policies/register/register_test.rego +++ b/policies/register/register_test.rego @@ -42,15 +42,29 @@ test_no_email if { register.allow with input as {"username": "hello", "registration_method": "upstream-oauth2"} } -test_short_username if { - not register.allow with input as {"username": "a", "registration_method": "upstream-oauth2"} +test_empty_username if { + not register.allow with input as {"username": "", "registration_method": "upstream-oauth2"} } test_long_username if { not register.allow with input as { - "username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + "username": concat("", ["a" | some x in numbers.range(1, 249)]), "registration_method": "upstream-oauth2", } + with data.server_name as "matrix.org" + + # This makes a MXID that is exactly 255 characters long + register.allow with input as { + "username": concat("", ["a" | some x in numbers.range(1, 249)]), + "registration_method": "upstream-oauth2", + } + with data.server_name as "a.io" + + not register.allow with input as { + "username": concat("", ["a" | some x in numbers.range(1, 250)]), + "registration_method": "upstream-oauth2", + } + with data.server_name as "a.io" } test_invalid_username if {