Allow longer & shorter usernames, complying with the MXID length spec

This commit is contained in:
Quentin Gliech
2024-12-19 12:12:34 +01:00
parent 961dd68005
commit 1e3d838c99
9 changed files with 89 additions and 26 deletions

View File

@@ -8,7 +8,7 @@ use std::process::ExitCode;
use clap::Parser;
use figment::Figment;
use mas_config::{ConfigurationSectionExt, PolicyConfig};
use mas_config::{ConfigurationSection, ConfigurationSectionExt, MatrixConfig, PolicyConfig};
use tracing::{info, info_span};
use crate::util::policy_factory_from_config;
@@ -33,8 +33,9 @@ impl Options {
SC::Policy => {
let _span = info_span!("cli.debug.policy").entered();
let config = PolicyConfig::extract_or_default(figment)?;
let matrix_config = MatrixConfig::extract(figment)?;
info!("Loading and compiling the policy module");
let policy_factory = policy_factory_from_config(&config).await?;
let policy_factory = policy_factory_from_config(&config, &matrix_config).await?;
let _instance = policy_factory.instantiate().await?;
}

View File

@@ -123,7 +123,7 @@ impl Options {
// Load and compile the WASM policies (and fallback to the default embedded one)
info!("Loading and compiling the policy module");
let policy_factory = policy_factory_from_config(&config.policy).await?;
let policy_factory = policy_factory_from_config(&config.policy, &config.matrix).await?;
let policy_factory = Arc::new(policy_factory);
let url_builder = UrlBuilder::new(

View File

@@ -101,6 +101,7 @@ pub fn mailer_from_config(
pub async fn policy_factory_from_config(
config: &PolicyConfig,
matrix_config: &MatrixConfig,
) -> Result<PolicyFactory, anyhow::Error> {
let policy_file = tokio::fs::File::open(&config.wasm_module)
.await
@@ -113,7 +114,10 @@ pub async fn policy_factory_from_config(
email: config.email_entrypoint.clone(),
};
PolicyFactory::load(policy_file, config.data.clone(), entrypoints)
let data =
mas_policy::Data::new(matrix_config.homeserver.clone()).with_rest(config.data.clone());
PolicyFactory::load(policy_file, data, entrypoints)
.await
.context("failed to load the policy")
}

View File

@@ -469,9 +469,12 @@ async fn test_oauth2_client_credentials(pool: PgPool) {
// Now make the client admin and try again
let state = {
let mut state = state;
state.policy_factory = test_utils::policy_factory(serde_json::json!({
"admin_clients": [client_id],
}))
state.policy_factory = test_utils::policy_factory(
"example.com",
serde_json::json!({
"admin_clients": [client_id],
}),
)
.await
.unwrap();
state
@@ -593,9 +596,12 @@ async fn test_add_user(pool: PgPool) {
// Make the client admin
let state = {
let mut state = state;
state.policy_factory = test_utils::policy_factory(serde_json::json!({
"admin_clients": [client_id],
}))
state.policy_factory = test_utils::policy_factory(
"example.com",
serde_json::json!({
"admin_clients": [client_id],
}),
)
.await
.unwrap();
state

View File

@@ -1475,9 +1475,12 @@ mod tests {
// Now, if we add the client to the admin list in the policy, it should work
let state = {
let mut state = state;
state.policy_factory = crate::test_utils::policy_factory(serde_json::json!({
"admin_clients": [client_id]
}))
state.policy_factory = crate::test_utils::policy_factory(
"example.com",
serde_json::json!({
"admin_clients": [client_id]
}),
)
.await
.unwrap();
state

View File

@@ -69,6 +69,7 @@ pub(crate) fn setup() {
}
pub(crate) async fn policy_factory(
server_name: &str,
data: serde_json::Value,
) -> Result<Arc<PolicyFactory>, anyhow::Error> {
let workspace_root = camino::Utf8Path::new(env!("CARGO_MANIFEST_DIR"))
@@ -84,6 +85,8 @@ pub(crate) async fn policy_factory(
email: "email/violation".to_owned(),
};
let data = mas_policy::Data::new(server_name.to_owned()).with_rest(data);
let policy_factory = PolicyFactory::load(file, data, entrypoints).await?;
let policy_factory = Arc::new(policy_factory);
Ok(policy_factory)
@@ -192,7 +195,8 @@ impl TestState {
PasswordManager::disabled()
};
let policy_factory = policy_factory(serde_json::json!({})).await?;
let policy_factory =
policy_factory(&site_config.server_name, serde_json::json!({})).await?;
let homeserver_connection =
Arc::new(MockHomeserverConnection::new(&site_config.server_name));
@@ -297,9 +301,12 @@ impl TestState {
// Make the client admin
let state = {
let mut state = self.clone();
state.policy_factory = policy_factory(serde_json::json!({
"admin_clients": [client_id],
}))
state.policy_factory = policy_factory(
"example.com",
serde_json::json!({
"admin_clients": [client_id],
}),
)
.await
.unwrap();
state

View File

@@ -12,6 +12,7 @@ use opa_wasm::{
wasmtime::{Config, Engine, Module, OptLevel, Store},
Runtime,
};
use serde::Serialize;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt};
@@ -69,10 +70,34 @@ impl Entrypoints {
}
}
#[derive(Serialize, Debug)]
pub struct Data {
server_name: String,
#[serde(flatten)]
rest: Option<serde_json::Value>,
}
impl Data {
#[must_use]
pub fn new(server_name: String) -> Self {
Self {
server_name,
rest: None,
}
}
#[must_use]
pub fn with_rest(mut self, rest: serde_json::Value) -> Self {
self.rest = Some(rest);
self
}
}
pub struct PolicyFactory {
engine: Engine,
module: Module,
data: serde_json::Value,
data: Data,
entrypoints: Entrypoints,
}
@@ -80,7 +105,7 @@ impl PolicyFactory {
#[tracing::instrument(name = "policy.load", skip(source), err)]
pub async fn load(
mut source: impl AsyncRead + std::marker::Unpin,
data: serde_json::Value,
data: Data,
entrypoints: Entrypoints,
) -> Result<Self, LoadError> {
let mut config = Config::default();
@@ -364,10 +389,10 @@ mod tests {
#[tokio::test]
async fn test_register() {
let data = serde_json::json!({
let data = Data::new("example.com".to_owned()).with_rest(serde_json::json!({
"allowed_domains": ["element.io", "*.element.io"],
"banned_domains": ["staging.element.io"],
});
}));
#[allow(clippy::disallowed_types)]
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))

View File

@@ -13,14 +13,17 @@ allow if {
count(violation) == 0
}
mxid(username, server_name) := sprintf("@%s:%s", [username, server_name])
# METADATA
# entrypoint: true
violation contains {"field": "username", "msg": "username too short"} if {
count(input.username) <= 2
count(input.username) == 0
}
violation contains {"field": "username", "msg": "username too long"} if {
count(input.username) > 64
user_id := mxid(input.username, data.server_name)
count(user_id) > 255
}
violation contains {"field": "username", "msg": "username contains invalid characters"} if {

View File

@@ -42,15 +42,29 @@ test_no_email if {
register.allow with input as {"username": "hello", "registration_method": "upstream-oauth2"}
}
test_short_username if {
not register.allow with input as {"username": "a", "registration_method": "upstream-oauth2"}
test_empty_username if {
not register.allow with input as {"username": "", "registration_method": "upstream-oauth2"}
}
test_long_username if {
not register.allow with input as {
"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
"username": concat("", ["a" | some x in numbers.range(1, 249)]),
"registration_method": "upstream-oauth2",
}
with data.server_name as "matrix.org"
# This makes a MXID that is exactly 255 characters long
register.allow with input as {
"username": concat("", ["a" | some x in numbers.range(1, 249)]),
"registration_method": "upstream-oauth2",
}
with data.server_name as "a.io"
not register.allow with input as {
"username": concat("", ["a" | some x in numbers.range(1, 250)]),
"registration_method": "upstream-oauth2",
}
with data.server_name as "a.io"
}
test_invalid_username if {