From fe789884ab4dba4e110aa7b1ae6abb54ee2cc394 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 25 Feb 2025 16:21:54 +0100 Subject: [PATCH] policy: allow dynamically setting policy data --- Cargo.lock | 1 + Cargo.toml | 4 + crates/policy/Cargo.toml | 1 + crates/policy/src/lib.rs | 324 +++++++++++++++++++++++++++++++++++- crates/templates/Cargo.toml | 2 +- 5 files changed, 327 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 041b54c00..131f3b74e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3588,6 +3588,7 @@ name = "mas-policy" version = "0.14.1" dependencies = [ "anyhow", + "arc-swap", "mas-data-model", "oauth2-types", "opa-wasm", diff --git a/Cargo.toml b/Cargo.toml index c1dc7569c..72f96c6bf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,10 @@ syn2mas = { path = "./crates/syn2mas", version = "=0.14.1" } version = "0.14.1" features = ["axum", "axum-extra", "axum-json", "axum-query", "macros"] +# An `Arc` that can be atomically updated +[workspace.dependencies.arc-swap] +version = "1.7.1" + # GraphQL server [workspace.dependencies.async-graphql] version = "7.0.15" diff --git a/crates/policy/Cargo.toml b/crates/policy/Cargo.toml index 7212b4991..f1ceabfab 100644 --- a/crates/policy/Cargo.toml +++ b/crates/policy/Cargo.toml @@ -13,6 +13,7 @@ workspace = true [dependencies] anyhow.workspace = true +arc-swap.workspace = true opa-wasm = "0.1.4" serde.workspace = true serde_json.workspace = true diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index 47d54f468..c8b771b6d 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -6,11 +6,14 @@ pub mod model; +use std::sync::Arc; + +use arc_swap::ArcSwap; +use mas_data_model::Ulid; use opa_wasm::{ Runtime, wasmtime::{Config, Engine, Module, OptLevel, Store}, }; -use serde::Serialize; use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt}; @@ -33,10 +36,23 @@ pub enum LoadError { #[error("failed to compile WASM module")] Compilation(#[source] anyhow::Error), + #[error("invalid policy data")] + InvalidData(#[source] anyhow::Error), + #[error("failed to instantiate a test instance")] Instantiate(#[source] InstantiateError), } +impl LoadError { + /// Creates an example of an invalid data error, used for API response + /// documentation + #[doc(hidden)] + #[must_use] + pub fn invalid_data_example() -> Self { + Self::InvalidData(anyhow::Error::msg("Failed to merge policy data objects")) + } +} + #[derive(Debug, Error)] pub enum InstantiateError { #[error("failed to create WASM runtime")] @@ -69,11 +85,10 @@ impl Entrypoints { } } -#[derive(Serialize, Debug)] +#[derive(Debug)] pub struct Data { server_name: String, - #[serde(flatten)] rest: Option, } @@ -91,12 +106,93 @@ impl Data { self.rest = Some(rest); self } + + fn to_value(&self) -> Result { + let base = serde_json::json!({ + "server_name": self.server_name, + }); + + if let Some(rest) = &self.rest { + merge_data(base, rest.clone()) + } else { + Ok(base) + } + } +} + +fn value_kind(value: &serde_json::Value) -> &'static str { + match value { + serde_json::Value::Object(_) => "object", + serde_json::Value::Array(_) => "array", + serde_json::Value::String(_) => "string", + serde_json::Value::Number(_) => "number", + serde_json::Value::Bool(_) => "boolean", + serde_json::Value::Null => "null", + } +} + +fn merge_data( + mut left: serde_json::Value, + right: serde_json::Value, +) -> Result { + merge_data_rec(&mut left, right)?; + Ok(left) +} + +fn merge_data_rec( + left: &mut serde_json::Value, + right: serde_json::Value, +) -> Result<(), anyhow::Error> { + match (left, right) { + (serde_json::Value::Object(left), serde_json::Value::Object(right)) => { + for (key, value) in right { + if let Some(left_value) = left.get_mut(&key) { + merge_data_rec(left_value, value)?; + } else { + left.insert(key, value); + } + } + } + (serde_json::Value::Array(left), serde_json::Value::Array(right)) => { + left.extend(right); + } + // Other values override + (serde_json::Value::Number(left), serde_json::Value::Number(right)) => { + *left = right; + } + (serde_json::Value::Bool(left), serde_json::Value::Bool(right)) => { + *left = right; + } + (serde_json::Value::String(left), serde_json::Value::String(right)) => { + *left = right; + } + + // Null gets overridden by anything + (left, right) if left.is_null() => *left = right, + + // Null on the right makes the left value null + (left, right) if right.is_null() => *left = right, + + (left, right) => anyhow::bail!( + "Cannot merge a {} into a {}", + value_kind(&right), + value_kind(left), + ), + } + + Ok(()) +} + +struct DynamicData { + version: Option, + merged: serde_json::Value, } pub struct PolicyFactory { engine: Engine, module: Module, data: Data, + dynamic_data: ArcSwap, entrypoints: Entrypoints, } @@ -124,10 +220,17 @@ impl PolicyFactory { .await? .map_err(LoadError::Compilation)?; + let merged = data.to_value().map_err(LoadError::InvalidData)?; + let dynamic_data = ArcSwap::new(Arc::new(DynamicData { + version: None, + merged, + })); + let factory = Self { engine, module, data, + dynamic_data, entrypoints, }; @@ -140,8 +243,56 @@ impl PolicyFactory { Ok(factory) } + /// Set the dynamic data for the policy. + /// + /// The `dynamic_data` object is merged with the static data given when the + /// policy was loaded. + /// + /// Returns `true` if the data was updated, `false` if the version + /// of the dynamic data was the same as the one we already have. + /// + /// # Errors + /// + /// Returns an error if the data can't be merged with the static data, or if + /// the policy can't be instantiated with the new data. + pub async fn set_dynamic_data( + &self, + dynamic_data: mas_data_model::PolicyData, + ) -> Result { + // Check if the version of the dynamic data we have is the same as the one we're + // trying to set + if self.dynamic_data.load().version == Some(dynamic_data.id) { + // Don't do anything if the version is the same + return Ok(false); + } + + let static_data = self.data.to_value().map_err(LoadError::InvalidData)?; + let merged = merge_data(static_data, dynamic_data.data).map_err(LoadError::InvalidData)?; + + // Try to instantiate with the new data + self.instantiate_with_data(&merged) + .await + .map_err(LoadError::Instantiate)?; + + // If instantiation succeeds, swap the data + self.dynamic_data.store(Arc::new(DynamicData { + version: Some(dynamic_data.id), + merged, + })); + + Ok(true) + } + #[tracing::instrument(name = "policy.instantiate", skip_all, err)] pub async fn instantiate(&self) -> Result { + let data = self.dynamic_data.load(); + self.instantiate_with_data(&data.merged).await + } + + async fn instantiate_with_data( + &self, + data: &serde_json::Value, + ) -> Result { let mut store = Store::new(&self.engine, ()); let runtime = Runtime::new(&mut store, &self.module) .await @@ -159,7 +310,7 @@ impl PolicyFactory { } let instance = runtime - .with_data(&mut store, &self.data) + .with_data(&mut store, data) .await .map_err(InstantiateError::LoadData)?; @@ -273,6 +424,8 @@ impl Policy { #[cfg(test)] mod tests { + use std::time::SystemTime; + use super::*; #[tokio::test] @@ -344,4 +497,167 @@ mod tests { .unwrap(); assert!(!res.valid()); } + + #[tokio::test] + async fn test_dynamic_data() { + let data = Data::new("example.com".to_owned()); + + #[allow(clippy::disallowed_types)] + let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .join("..") + .join("..") + .join("policies") + .join("policy.wasm"); + + let file = tokio::fs::File::open(path).await.unwrap(); + + let entrypoints = Entrypoints { + register: "register/violation".to_owned(), + client_registration: "client_registration/violation".to_owned(), + authorization_grant: "authorization_grant/violation".to_owned(), + email: "email/violation".to_owned(), + }; + + let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap(); + + let mut policy = factory.instantiate().await.unwrap(); + + let res = policy + .evaluate_register(RegisterInput { + registration_method: RegistrationMethod::Password, + username: "hello", + email: Some("hello@example.com"), + requester: Requester { + ip_address: None, + user_agent: None, + }, + }) + .await + .unwrap(); + assert!(res.valid()); + + // Update the policy data + factory + .set_dynamic_data(mas_data_model::PolicyData { + id: Ulid::nil(), + created_at: SystemTime::now().into(), + data: serde_json::json!({ + "emails": { + "banned_addresses": { + "substrings": ["hello"] + } + } + }), + }) + .await + .unwrap(); + let mut policy = factory.instantiate().await.unwrap(); + let res = policy + .evaluate_register(RegisterInput { + registration_method: RegistrationMethod::Password, + username: "hello", + email: Some("hello@example.com"), + requester: Requester { + ip_address: None, + user_agent: None, + }, + }) + .await + .unwrap(); + assert!(!res.valid()); + } + + #[tokio::test] + async fn test_big_dynamic_data() { + let data = Data::new("example.com".to_owned()); + + #[allow(clippy::disallowed_types)] + let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .join("..") + .join("..") + .join("policies") + .join("policy.wasm"); + + let file = tokio::fs::File::open(path).await.unwrap(); + + let entrypoints = Entrypoints { + register: "register/violation".to_owned(), + client_registration: "client_registration/violation".to_owned(), + authorization_grant: "authorization_grant/violation".to_owned(), + email: "email/violation".to_owned(), + }; + + let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap(); + + // That is around 1 MB of JSON data. Each element is a 5-digit string, so 8 + // characters including the quotes and a comma. + let data: Vec = (0..(1024 * 1024 / 8)) + .map(|i| format!("{:05}", i % 100_000)) + .collect(); + let json = serde_json::json!({ "emails": { "banned_addresses": { "substrings": data } } }); + factory + .set_dynamic_data(mas_data_model::PolicyData { + id: Ulid::nil(), + created_at: SystemTime::now().into(), + data: json, + }) + .await + .unwrap(); + + // Try instantiating the policy, make sure 5-digit numbers are banned from email + // addresses + let mut policy = factory.instantiate().await.unwrap(); + let res = policy + .evaluate_register(RegisterInput { + registration_method: RegistrationMethod::Password, + username: "hello", + email: Some("12345@example.com"), + requester: Requester { + ip_address: None, + user_agent: None, + }, + }) + .await + .unwrap(); + assert!(!res.valid()); + } + + #[test] + fn test_merge() { + use serde_json::json as j; + + // Merging objects + let res = merge_data(j!({"hello": "world"}), j!({"foo": "bar"})).unwrap(); + assert_eq!(res, j!({"hello": "world", "foo": "bar"})); + + // Override a value of the same type + let res = merge_data(j!({"hello": "world"}), j!({"hello": "john"})).unwrap(); + assert_eq!(res, j!({"hello": "john"})); + + let res = merge_data(j!({"hello": true}), j!({"hello": false})).unwrap(); + assert_eq!(res, j!({"hello": false})); + + let res = merge_data(j!({"hello": 0}), j!({"hello": 42})).unwrap(); + assert_eq!(res, j!({"hello": 42})); + + // Override a value of a different type + merge_data(j!({"hello": "world"}), j!({"hello": 123})) + .expect_err("Can't merge different types"); + + // Merge arrays + let res = merge_data(j!({"hello": ["world"]}), j!({"hello": ["john"]})).unwrap(); + assert_eq!(res, j!({"hello": ["world", "john"]})); + + // Null overrides a value + let res = merge_data(j!({"hello": "world"}), j!({"hello": null})).unwrap(); + assert_eq!(res, j!({"hello": null})); + + // Null gets overridden by a value + let res = merge_data(j!({"hello": null}), j!({"hello": "world"})).unwrap(); + assert_eq!(res, j!({"hello": "world"})); + + // Objects get deeply merged + let res = merge_data(j!({"a": {"b": {"c": "d"}}}), j!({"a": {"b": {"e": "f"}}})).unwrap(); + assert_eq!(res, j!({"a": {"b": {"c": "d", "e": "f"}}})); + } } diff --git a/crates/templates/Cargo.toml b/crates/templates/Cargo.toml index 696a517b9..68bbadb1d 100644 --- a/crates/templates/Cargo.toml +++ b/crates/templates/Cargo.toml @@ -12,7 +12,7 @@ publish = false workspace = true [dependencies] -arc-swap = "1.7.1" +arc-swap.workspace = true tracing.workspace = true tokio.workspace = true walkdir = "2.5.0"