Allow longer & shorter usernames, complying with the MXID length spec
This commit is contained in:
@@ -8,7 +8,7 @@ use std::process::ExitCode;
|
||||
|
||||
use clap::Parser;
|
||||
use figment::Figment;
|
||||
use mas_config::{ConfigurationSectionExt, PolicyConfig};
|
||||
use mas_config::{ConfigurationSection, ConfigurationSectionExt, MatrixConfig, PolicyConfig};
|
||||
use tracing::{info, info_span};
|
||||
|
||||
use crate::util::policy_factory_from_config;
|
||||
@@ -33,8 +33,9 @@ impl Options {
|
||||
SC::Policy => {
|
||||
let _span = info_span!("cli.debug.policy").entered();
|
||||
let config = PolicyConfig::extract_or_default(figment)?;
|
||||
let matrix_config = MatrixConfig::extract(figment)?;
|
||||
info!("Loading and compiling the policy module");
|
||||
let policy_factory = policy_factory_from_config(&config).await?;
|
||||
let policy_factory = policy_factory_from_config(&config, &matrix_config).await?;
|
||||
|
||||
let _instance = policy_factory.instantiate().await?;
|
||||
}
|
||||
|
||||
@@ -123,7 +123,7 @@ impl Options {
|
||||
|
||||
// Load and compile the WASM policies (and fallback to the default embedded one)
|
||||
info!("Loading and compiling the policy module");
|
||||
let policy_factory = policy_factory_from_config(&config.policy).await?;
|
||||
let policy_factory = policy_factory_from_config(&config.policy, &config.matrix).await?;
|
||||
let policy_factory = Arc::new(policy_factory);
|
||||
|
||||
let url_builder = UrlBuilder::new(
|
||||
|
||||
@@ -101,6 +101,7 @@ pub fn mailer_from_config(
|
||||
|
||||
pub async fn policy_factory_from_config(
|
||||
config: &PolicyConfig,
|
||||
matrix_config: &MatrixConfig,
|
||||
) -> Result<PolicyFactory, anyhow::Error> {
|
||||
let policy_file = tokio::fs::File::open(&config.wasm_module)
|
||||
.await
|
||||
@@ -113,7 +114,10 @@ pub async fn policy_factory_from_config(
|
||||
email: config.email_entrypoint.clone(),
|
||||
};
|
||||
|
||||
PolicyFactory::load(policy_file, config.data.clone(), entrypoints)
|
||||
let data =
|
||||
mas_policy::Data::new(matrix_config.homeserver.clone()).with_rest(config.data.clone());
|
||||
|
||||
PolicyFactory::load(policy_file, data, entrypoints)
|
||||
.await
|
||||
.context("failed to load the policy")
|
||||
}
|
||||
|
||||
@@ -469,9 +469,12 @@ async fn test_oauth2_client_credentials(pool: PgPool) {
|
||||
// Now make the client admin and try again
|
||||
let state = {
|
||||
let mut state = state;
|
||||
state.policy_factory = test_utils::policy_factory(serde_json::json!({
|
||||
state.policy_factory = test_utils::policy_factory(
|
||||
"example.com",
|
||||
serde_json::json!({
|
||||
"admin_clients": [client_id],
|
||||
}))
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
state
|
||||
@@ -593,9 +596,12 @@ async fn test_add_user(pool: PgPool) {
|
||||
// Make the client admin
|
||||
let state = {
|
||||
let mut state = state;
|
||||
state.policy_factory = test_utils::policy_factory(serde_json::json!({
|
||||
state.policy_factory = test_utils::policy_factory(
|
||||
"example.com",
|
||||
serde_json::json!({
|
||||
"admin_clients": [client_id],
|
||||
}))
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
state
|
||||
|
||||
@@ -1475,9 +1475,12 @@ mod tests {
|
||||
// Now, if we add the client to the admin list in the policy, it should work
|
||||
let state = {
|
||||
let mut state = state;
|
||||
state.policy_factory = crate::test_utils::policy_factory(serde_json::json!({
|
||||
state.policy_factory = crate::test_utils::policy_factory(
|
||||
"example.com",
|
||||
serde_json::json!({
|
||||
"admin_clients": [client_id]
|
||||
}))
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
state
|
||||
|
||||
@@ -69,6 +69,7 @@ pub(crate) fn setup() {
|
||||
}
|
||||
|
||||
pub(crate) async fn policy_factory(
|
||||
server_name: &str,
|
||||
data: serde_json::Value,
|
||||
) -> Result<Arc<PolicyFactory>, anyhow::Error> {
|
||||
let workspace_root = camino::Utf8Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
@@ -84,6 +85,8 @@ pub(crate) async fn policy_factory(
|
||||
email: "email/violation".to_owned(),
|
||||
};
|
||||
|
||||
let data = mas_policy::Data::new(server_name.to_owned()).with_rest(data);
|
||||
|
||||
let policy_factory = PolicyFactory::load(file, data, entrypoints).await?;
|
||||
let policy_factory = Arc::new(policy_factory);
|
||||
Ok(policy_factory)
|
||||
@@ -192,7 +195,8 @@ impl TestState {
|
||||
PasswordManager::disabled()
|
||||
};
|
||||
|
||||
let policy_factory = policy_factory(serde_json::json!({})).await?;
|
||||
let policy_factory =
|
||||
policy_factory(&site_config.server_name, serde_json::json!({})).await?;
|
||||
|
||||
let homeserver_connection =
|
||||
Arc::new(MockHomeserverConnection::new(&site_config.server_name));
|
||||
@@ -297,9 +301,12 @@ impl TestState {
|
||||
// Make the client admin
|
||||
let state = {
|
||||
let mut state = self.clone();
|
||||
state.policy_factory = policy_factory(serde_json::json!({
|
||||
state.policy_factory = policy_factory(
|
||||
"example.com",
|
||||
serde_json::json!({
|
||||
"admin_clients": [client_id],
|
||||
}))
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
state
|
||||
|
||||
@@ -12,6 +12,7 @@ use opa_wasm::{
|
||||
wasmtime::{Config, Engine, Module, OptLevel, Store},
|
||||
Runtime,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
|
||||
@@ -69,10 +70,34 @@ impl Entrypoints {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct Data {
|
||||
server_name: String,
|
||||
|
||||
#[serde(flatten)]
|
||||
rest: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl Data {
|
||||
#[must_use]
|
||||
pub fn new(server_name: String) -> Self {
|
||||
Self {
|
||||
server_name,
|
||||
rest: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_rest(mut self, rest: serde_json::Value) -> Self {
|
||||
self.rest = Some(rest);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PolicyFactory {
|
||||
engine: Engine,
|
||||
module: Module,
|
||||
data: serde_json::Value,
|
||||
data: Data,
|
||||
entrypoints: Entrypoints,
|
||||
}
|
||||
|
||||
@@ -80,7 +105,7 @@ impl PolicyFactory {
|
||||
#[tracing::instrument(name = "policy.load", skip(source), err)]
|
||||
pub async fn load(
|
||||
mut source: impl AsyncRead + std::marker::Unpin,
|
||||
data: serde_json::Value,
|
||||
data: Data,
|
||||
entrypoints: Entrypoints,
|
||||
) -> Result<Self, LoadError> {
|
||||
let mut config = Config::default();
|
||||
@@ -364,10 +389,10 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register() {
|
||||
let data = serde_json::json!({
|
||||
let data = Data::new("example.com".to_owned()).with_rest(serde_json::json!({
|
||||
"allowed_domains": ["element.io", "*.element.io"],
|
||||
"banned_domains": ["staging.element.io"],
|
||||
});
|
||||
}));
|
||||
|
||||
#[allow(clippy::disallowed_types)]
|
||||
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
|
||||
@@ -13,14 +13,17 @@ allow if {
|
||||
count(violation) == 0
|
||||
}
|
||||
|
||||
mxid(username, server_name) := sprintf("@%s:%s", [username, server_name])
|
||||
|
||||
# METADATA
|
||||
# entrypoint: true
|
||||
violation contains {"field": "username", "msg": "username too short"} if {
|
||||
count(input.username) <= 2
|
||||
count(input.username) == 0
|
||||
}
|
||||
|
||||
violation contains {"field": "username", "msg": "username too long"} if {
|
||||
count(input.username) > 64
|
||||
user_id := mxid(input.username, data.server_name)
|
||||
count(user_id) > 255
|
||||
}
|
||||
|
||||
violation contains {"field": "username", "msg": "username contains invalid characters"} if {
|
||||
|
||||
@@ -42,15 +42,29 @@ test_no_email if {
|
||||
register.allow with input as {"username": "hello", "registration_method": "upstream-oauth2"}
|
||||
}
|
||||
|
||||
test_short_username if {
|
||||
not register.allow with input as {"username": "a", "registration_method": "upstream-oauth2"}
|
||||
test_empty_username if {
|
||||
not register.allow with input as {"username": "", "registration_method": "upstream-oauth2"}
|
||||
}
|
||||
|
||||
test_long_username if {
|
||||
not register.allow with input as {
|
||||
"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
"username": concat("", ["a" | some x in numbers.range(1, 249)]),
|
||||
"registration_method": "upstream-oauth2",
|
||||
}
|
||||
with data.server_name as "matrix.org"
|
||||
|
||||
# This makes a MXID that is exactly 255 characters long
|
||||
register.allow with input as {
|
||||
"username": concat("", ["a" | some x in numbers.range(1, 249)]),
|
||||
"registration_method": "upstream-oauth2",
|
||||
}
|
||||
with data.server_name as "a.io"
|
||||
|
||||
not register.allow with input as {
|
||||
"username": concat("", ["a" | some x in numbers.range(1, 250)]),
|
||||
"registration_method": "upstream-oauth2",
|
||||
}
|
||||
with data.server_name as "a.io"
|
||||
}
|
||||
|
||||
test_invalid_username if {
|
||||
|
||||
Reference in New Issue
Block a user