|
|
|
|
@@ -6,11 +6,14 @@
|
|
|
|
|
|
|
|
|
|
pub mod model;
|
|
|
|
|
|
|
|
|
|
use std::sync::Arc;
|
|
|
|
|
|
|
|
|
|
use arc_swap::ArcSwap;
|
|
|
|
|
use mas_data_model::Ulid;
|
|
|
|
|
use opa_wasm::{
|
|
|
|
|
Runtime,
|
|
|
|
|
wasmtime::{Config, Engine, Module, OptLevel, Store},
|
|
|
|
|
};
|
|
|
|
|
use serde::Serialize;
|
|
|
|
|
use thiserror::Error;
|
|
|
|
|
use tokio::io::{AsyncRead, AsyncReadExt};
|
|
|
|
|
|
|
|
|
|
@@ -33,10 +36,23 @@ pub enum LoadError {
|
|
|
|
|
#[error("failed to compile WASM module")]
|
|
|
|
|
Compilation(#[source] anyhow::Error),
|
|
|
|
|
|
|
|
|
|
#[error("invalid policy data")]
|
|
|
|
|
InvalidData(#[source] anyhow::Error),
|
|
|
|
|
|
|
|
|
|
#[error("failed to instantiate a test instance")]
|
|
|
|
|
Instantiate(#[source] InstantiateError),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl LoadError {
|
|
|
|
|
/// Creates an example of an invalid data error, used for API response
|
|
|
|
|
/// documentation
|
|
|
|
|
#[doc(hidden)]
|
|
|
|
|
#[must_use]
|
|
|
|
|
pub fn invalid_data_example() -> Self {
|
|
|
|
|
Self::InvalidData(anyhow::Error::msg("Failed to merge policy data objects"))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Error)]
|
|
|
|
|
pub enum InstantiateError {
|
|
|
|
|
#[error("failed to create WASM runtime")]
|
|
|
|
|
@@ -69,11 +85,10 @@ impl Entrypoints {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Serialize, Debug)]
|
|
|
|
|
#[derive(Debug)]
|
|
|
|
|
pub struct Data {
|
|
|
|
|
server_name: String,
|
|
|
|
|
|
|
|
|
|
#[serde(flatten)]
|
|
|
|
|
rest: Option<serde_json::Value>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -91,12 +106,93 @@ impl Data {
|
|
|
|
|
self.rest = Some(rest);
|
|
|
|
|
self
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn to_value(&self) -> Result<serde_json::Value, anyhow::Error> {
|
|
|
|
|
let base = serde_json::json!({
|
|
|
|
|
"server_name": self.server_name,
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
if let Some(rest) = &self.rest {
|
|
|
|
|
merge_data(base, rest.clone())
|
|
|
|
|
} else {
|
|
|
|
|
Ok(base)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn value_kind(value: &serde_json::Value) -> &'static str {
|
|
|
|
|
match value {
|
|
|
|
|
serde_json::Value::Object(_) => "object",
|
|
|
|
|
serde_json::Value::Array(_) => "array",
|
|
|
|
|
serde_json::Value::String(_) => "string",
|
|
|
|
|
serde_json::Value::Number(_) => "number",
|
|
|
|
|
serde_json::Value::Bool(_) => "boolean",
|
|
|
|
|
serde_json::Value::Null => "null",
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn merge_data(
|
|
|
|
|
mut left: serde_json::Value,
|
|
|
|
|
right: serde_json::Value,
|
|
|
|
|
) -> Result<serde_json::Value, anyhow::Error> {
|
|
|
|
|
merge_data_rec(&mut left, right)?;
|
|
|
|
|
Ok(left)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn merge_data_rec(
|
|
|
|
|
left: &mut serde_json::Value,
|
|
|
|
|
right: serde_json::Value,
|
|
|
|
|
) -> Result<(), anyhow::Error> {
|
|
|
|
|
match (left, right) {
|
|
|
|
|
(serde_json::Value::Object(left), serde_json::Value::Object(right)) => {
|
|
|
|
|
for (key, value) in right {
|
|
|
|
|
if let Some(left_value) = left.get_mut(&key) {
|
|
|
|
|
merge_data_rec(left_value, value)?;
|
|
|
|
|
} else {
|
|
|
|
|
left.insert(key, value);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
(serde_json::Value::Array(left), serde_json::Value::Array(right)) => {
|
|
|
|
|
left.extend(right);
|
|
|
|
|
}
|
|
|
|
|
// Other values override
|
|
|
|
|
(serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
|
|
|
|
|
*left = right;
|
|
|
|
|
}
|
|
|
|
|
(serde_json::Value::Bool(left), serde_json::Value::Bool(right)) => {
|
|
|
|
|
*left = right;
|
|
|
|
|
}
|
|
|
|
|
(serde_json::Value::String(left), serde_json::Value::String(right)) => {
|
|
|
|
|
*left = right;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Null gets overridden by anything
|
|
|
|
|
(left, right) if left.is_null() => *left = right,
|
|
|
|
|
|
|
|
|
|
// Null on the right makes the left value null
|
|
|
|
|
(left, right) if right.is_null() => *left = right,
|
|
|
|
|
|
|
|
|
|
(left, right) => anyhow::bail!(
|
|
|
|
|
"Cannot merge a {} into a {}",
|
|
|
|
|
value_kind(&right),
|
|
|
|
|
value_kind(left),
|
|
|
|
|
),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct DynamicData {
|
|
|
|
|
version: Option<Ulid>,
|
|
|
|
|
merged: serde_json::Value,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub struct PolicyFactory {
|
|
|
|
|
engine: Engine,
|
|
|
|
|
module: Module,
|
|
|
|
|
data: Data,
|
|
|
|
|
dynamic_data: ArcSwap<DynamicData>,
|
|
|
|
|
entrypoints: Entrypoints,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -124,10 +220,17 @@ impl PolicyFactory {
|
|
|
|
|
.await?
|
|
|
|
|
.map_err(LoadError::Compilation)?;
|
|
|
|
|
|
|
|
|
|
let merged = data.to_value().map_err(LoadError::InvalidData)?;
|
|
|
|
|
let dynamic_data = ArcSwap::new(Arc::new(DynamicData {
|
|
|
|
|
version: None,
|
|
|
|
|
merged,
|
|
|
|
|
}));
|
|
|
|
|
|
|
|
|
|
let factory = Self {
|
|
|
|
|
engine,
|
|
|
|
|
module,
|
|
|
|
|
data,
|
|
|
|
|
dynamic_data,
|
|
|
|
|
entrypoints,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
@@ -140,8 +243,56 @@ impl PolicyFactory {
|
|
|
|
|
Ok(factory)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Set the dynamic data for the policy.
|
|
|
|
|
///
|
|
|
|
|
/// The `dynamic_data` object is merged with the static data given when the
|
|
|
|
|
/// policy was loaded.
|
|
|
|
|
///
|
|
|
|
|
/// Returns `true` if the data was updated, `false` if the version
|
|
|
|
|
/// of the dynamic data was the same as the one we already have.
|
|
|
|
|
///
|
|
|
|
|
/// # Errors
|
|
|
|
|
///
|
|
|
|
|
/// Returns an error if the data can't be merged with the static data, or if
|
|
|
|
|
/// the policy can't be instantiated with the new data.
|
|
|
|
|
pub async fn set_dynamic_data(
|
|
|
|
|
&self,
|
|
|
|
|
dynamic_data: mas_data_model::PolicyData,
|
|
|
|
|
) -> Result<bool, LoadError> {
|
|
|
|
|
// Check if the version of the dynamic data we have is the same as the one we're
|
|
|
|
|
// trying to set
|
|
|
|
|
if self.dynamic_data.load().version == Some(dynamic_data.id) {
|
|
|
|
|
// Don't do anything if the version is the same
|
|
|
|
|
return Ok(false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let static_data = self.data.to_value().map_err(LoadError::InvalidData)?;
|
|
|
|
|
let merged = merge_data(static_data, dynamic_data.data).map_err(LoadError::InvalidData)?;
|
|
|
|
|
|
|
|
|
|
// Try to instantiate with the new data
|
|
|
|
|
self.instantiate_with_data(&merged)
|
|
|
|
|
.await
|
|
|
|
|
.map_err(LoadError::Instantiate)?;
|
|
|
|
|
|
|
|
|
|
// If instantiation succeeds, swap the data
|
|
|
|
|
self.dynamic_data.store(Arc::new(DynamicData {
|
|
|
|
|
version: Some(dynamic_data.id),
|
|
|
|
|
merged,
|
|
|
|
|
}));
|
|
|
|
|
|
|
|
|
|
Ok(true)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[tracing::instrument(name = "policy.instantiate", skip_all, err)]
|
|
|
|
|
pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
|
|
|
|
|
let data = self.dynamic_data.load();
|
|
|
|
|
self.instantiate_with_data(&data.merged).await
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn instantiate_with_data(
|
|
|
|
|
&self,
|
|
|
|
|
data: &serde_json::Value,
|
|
|
|
|
) -> Result<Policy, InstantiateError> {
|
|
|
|
|
let mut store = Store::new(&self.engine, ());
|
|
|
|
|
let runtime = Runtime::new(&mut store, &self.module)
|
|
|
|
|
.await
|
|
|
|
|
@@ -159,7 +310,7 @@ impl PolicyFactory {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let instance = runtime
|
|
|
|
|
.with_data(&mut store, &self.data)
|
|
|
|
|
.with_data(&mut store, data)
|
|
|
|
|
.await
|
|
|
|
|
.map_err(InstantiateError::LoadData)?;
|
|
|
|
|
|
|
|
|
|
@@ -273,6 +424,8 @@ impl Policy {
|
|
|
|
|
#[cfg(test)]
|
|
|
|
|
mod tests {
|
|
|
|
|
|
|
|
|
|
use std::time::SystemTime;
|
|
|
|
|
|
|
|
|
|
use super::*;
|
|
|
|
|
|
|
|
|
|
#[tokio::test]
|
|
|
|
|
@@ -344,4 +497,167 @@ mod tests {
|
|
|
|
|
.unwrap();
|
|
|
|
|
assert!(!res.valid());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[tokio::test]
|
|
|
|
|
async fn test_dynamic_data() {
|
|
|
|
|
let data = Data::new("example.com".to_owned());
|
|
|
|
|
|
|
|
|
|
#[allow(clippy::disallowed_types)]
|
|
|
|
|
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
|
|
|
|
|
.join("..")
|
|
|
|
|
.join("..")
|
|
|
|
|
.join("policies")
|
|
|
|
|
.join("policy.wasm");
|
|
|
|
|
|
|
|
|
|
let file = tokio::fs::File::open(path).await.unwrap();
|
|
|
|
|
|
|
|
|
|
let entrypoints = Entrypoints {
|
|
|
|
|
register: "register/violation".to_owned(),
|
|
|
|
|
client_registration: "client_registration/violation".to_owned(),
|
|
|
|
|
authorization_grant: "authorization_grant/violation".to_owned(),
|
|
|
|
|
email: "email/violation".to_owned(),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
|
|
|
|
|
|
|
|
|
|
let mut policy = factory.instantiate().await.unwrap();
|
|
|
|
|
|
|
|
|
|
let res = policy
|
|
|
|
|
.evaluate_register(RegisterInput {
|
|
|
|
|
registration_method: RegistrationMethod::Password,
|
|
|
|
|
username: "hello",
|
|
|
|
|
email: Some("hello@example.com"),
|
|
|
|
|
requester: Requester {
|
|
|
|
|
ip_address: None,
|
|
|
|
|
user_agent: None,
|
|
|
|
|
},
|
|
|
|
|
})
|
|
|
|
|
.await
|
|
|
|
|
.unwrap();
|
|
|
|
|
assert!(res.valid());
|
|
|
|
|
|
|
|
|
|
// Update the policy data
|
|
|
|
|
factory
|
|
|
|
|
.set_dynamic_data(mas_data_model::PolicyData {
|
|
|
|
|
id: Ulid::nil(),
|
|
|
|
|
created_at: SystemTime::now().into(),
|
|
|
|
|
data: serde_json::json!({
|
|
|
|
|
"emails": {
|
|
|
|
|
"banned_addresses": {
|
|
|
|
|
"substrings": ["hello"]
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}),
|
|
|
|
|
})
|
|
|
|
|
.await
|
|
|
|
|
.unwrap();
|
|
|
|
|
let mut policy = factory.instantiate().await.unwrap();
|
|
|
|
|
let res = policy
|
|
|
|
|
.evaluate_register(RegisterInput {
|
|
|
|
|
registration_method: RegistrationMethod::Password,
|
|
|
|
|
username: "hello",
|
|
|
|
|
email: Some("hello@example.com"),
|
|
|
|
|
requester: Requester {
|
|
|
|
|
ip_address: None,
|
|
|
|
|
user_agent: None,
|
|
|
|
|
},
|
|
|
|
|
})
|
|
|
|
|
.await
|
|
|
|
|
.unwrap();
|
|
|
|
|
assert!(!res.valid());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[tokio::test]
|
|
|
|
|
async fn test_big_dynamic_data() {
|
|
|
|
|
let data = Data::new("example.com".to_owned());
|
|
|
|
|
|
|
|
|
|
#[allow(clippy::disallowed_types)]
|
|
|
|
|
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
|
|
|
|
|
.join("..")
|
|
|
|
|
.join("..")
|
|
|
|
|
.join("policies")
|
|
|
|
|
.join("policy.wasm");
|
|
|
|
|
|
|
|
|
|
let file = tokio::fs::File::open(path).await.unwrap();
|
|
|
|
|
|
|
|
|
|
let entrypoints = Entrypoints {
|
|
|
|
|
register: "register/violation".to_owned(),
|
|
|
|
|
client_registration: "client_registration/violation".to_owned(),
|
|
|
|
|
authorization_grant: "authorization_grant/violation".to_owned(),
|
|
|
|
|
email: "email/violation".to_owned(),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
|
|
|
|
|
|
|
|
|
|
// That is around 1 MB of JSON data. Each element is a 5-digit string, so 8
|
|
|
|
|
// characters including the quotes and a comma.
|
|
|
|
|
let data: Vec<String> = (0..(1024 * 1024 / 8))
|
|
|
|
|
.map(|i| format!("{:05}", i % 100_000))
|
|
|
|
|
.collect();
|
|
|
|
|
let json = serde_json::json!({ "emails": { "banned_addresses": { "substrings": data } } });
|
|
|
|
|
factory
|
|
|
|
|
.set_dynamic_data(mas_data_model::PolicyData {
|
|
|
|
|
id: Ulid::nil(),
|
|
|
|
|
created_at: SystemTime::now().into(),
|
|
|
|
|
data: json,
|
|
|
|
|
})
|
|
|
|
|
.await
|
|
|
|
|
.unwrap();
|
|
|
|
|
|
|
|
|
|
// Try instantiating the policy, make sure 5-digit numbers are banned from email
|
|
|
|
|
// addresses
|
|
|
|
|
let mut policy = factory.instantiate().await.unwrap();
|
|
|
|
|
let res = policy
|
|
|
|
|
.evaluate_register(RegisterInput {
|
|
|
|
|
registration_method: RegistrationMethod::Password,
|
|
|
|
|
username: "hello",
|
|
|
|
|
email: Some("12345@example.com"),
|
|
|
|
|
requester: Requester {
|
|
|
|
|
ip_address: None,
|
|
|
|
|
user_agent: None,
|
|
|
|
|
},
|
|
|
|
|
})
|
|
|
|
|
.await
|
|
|
|
|
.unwrap();
|
|
|
|
|
assert!(!res.valid());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn test_merge() {
|
|
|
|
|
use serde_json::json as j;
|
|
|
|
|
|
|
|
|
|
// Merging objects
|
|
|
|
|
let res = merge_data(j!({"hello": "world"}), j!({"foo": "bar"})).unwrap();
|
|
|
|
|
assert_eq!(res, j!({"hello": "world", "foo": "bar"}));
|
|
|
|
|
|
|
|
|
|
// Override a value of the same type
|
|
|
|
|
let res = merge_data(j!({"hello": "world"}), j!({"hello": "john"})).unwrap();
|
|
|
|
|
assert_eq!(res, j!({"hello": "john"}));
|
|
|
|
|
|
|
|
|
|
let res = merge_data(j!({"hello": true}), j!({"hello": false})).unwrap();
|
|
|
|
|
assert_eq!(res, j!({"hello": false}));
|
|
|
|
|
|
|
|
|
|
let res = merge_data(j!({"hello": 0}), j!({"hello": 42})).unwrap();
|
|
|
|
|
assert_eq!(res, j!({"hello": 42}));
|
|
|
|
|
|
|
|
|
|
// Override a value of a different type
|
|
|
|
|
merge_data(j!({"hello": "world"}), j!({"hello": 123}))
|
|
|
|
|
.expect_err("Can't merge different types");
|
|
|
|
|
|
|
|
|
|
// Merge arrays
|
|
|
|
|
let res = merge_data(j!({"hello": ["world"]}), j!({"hello": ["john"]})).unwrap();
|
|
|
|
|
assert_eq!(res, j!({"hello": ["world", "john"]}));
|
|
|
|
|
|
|
|
|
|
// Null overrides a value
|
|
|
|
|
let res = merge_data(j!({"hello": "world"}), j!({"hello": null})).unwrap();
|
|
|
|
|
assert_eq!(res, j!({"hello": null}));
|
|
|
|
|
|
|
|
|
|
// Null gets overridden by a value
|
|
|
|
|
let res = merge_data(j!({"hello": null}), j!({"hello": "world"})).unwrap();
|
|
|
|
|
assert_eq!(res, j!({"hello": "world"}));
|
|
|
|
|
|
|
|
|
|
// Objects get deeply merged
|
|
|
|
|
let res = merge_data(j!({"a": {"b": {"c": "d"}}}), j!({"a": {"b": {"e": "f"}}})).unwrap();
|
|
|
|
|
assert_eq!(res, j!({"a": {"b": {"c": "d", "e": "f"}}}));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|