Admin API to dynamically set policy data (#4115)

This commit is contained in:
Quentin Gliech
2025-03-14 10:26:31 +01:00
committed by GitHub
33 changed files with 1694 additions and 11 deletions

1
Cargo.lock generated
View File

@@ -3588,6 +3588,7 @@ name = "mas-policy"
version = "0.14.1"
dependencies = [
"anyhow",
"arc-swap",
"mas-data-model",
"oauth2-types",
"opa-wasm",

View File

@@ -61,6 +61,7 @@ syn2mas = { path = "./crates/syn2mas", version = "=0.14.1" }
version = "0.14.1"
features = ["axum", "axum-extra", "axum-json", "axum-query", "macros"]
# An `Arc` that can be atomically updated
[workspace.dependencies.arc-swap]
version = "1.7.1"

View File

@@ -203,6 +203,12 @@ impl FromRef<AppState> for Limiter {
}
}
impl FromRef<AppState> for Arc<PolicyFactory> {
fn from_ref(input: &AppState) -> Self {
input.policy_factory.clone()
}
}
impl FromRef<AppState> for Arc<dyn HomeserverConnection> {
fn from_ref(input: &AppState) -> Self {
Arc::clone(&input.homeserver_connection)

View File

@@ -1,4 +1,4 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
@@ -8,10 +8,14 @@ use std::process::ExitCode;
use clap::Parser;
use figment::Figment;
use mas_config::{ConfigurationSection, ConfigurationSectionExt, MatrixConfig, PolicyConfig};
use mas_config::{
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, PolicyConfig,
};
use tracing::{info, info_span};
use crate::util::policy_factory_from_config;
use crate::util::{
database_pool_from_config, load_policy_factory_dynamic_data, policy_factory_from_config,
};
#[derive(Parser, Debug)]
pub(super) struct Options {
@@ -22,7 +26,11 @@ pub(super) struct Options {
#[derive(Parser, Debug)]
enum Subcommand {
/// Check that the policies compile
Policy,
Policy {
/// With dynamic data loaded
#[arg(long)]
with_dynamic_data: bool,
},
}
impl Options {
@@ -30,13 +38,19 @@ impl Options {
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
use Subcommand as SC;
match self.subcommand {
SC::Policy => {
SC::Policy { with_dynamic_data } => {
let _span = info_span!("cli.debug.policy").entered();
let config = PolicyConfig::extract_or_default(figment)?;
let matrix_config = MatrixConfig::extract(figment)?;
info!("Loading and compiling the policy module");
let policy_factory = policy_factory_from_config(&config, &matrix_config).await?;
if with_dynamic_data {
let database_config = DatabaseConfig::extract(figment)?;
let pool = database_pool_from_config(&database_config).await?;
load_policy_factory_dynamic_data(&policy_factory, &pool).await?;
}
let _instance = policy_factory.instantiate().await?;
}
}

View File

@@ -25,7 +25,8 @@ use crate::{
app_state::AppState,
lifecycle::LifecycleManager,
util::{
database_pool_from_config, homeserver_connection_from_config, mailer_from_config,
database_pool_from_config, homeserver_connection_from_config,
load_policy_factory_dynamic_data_continuously, mailer_from_config,
password_manager_from_config, policy_factory_from_config, site_config_from_config,
templates_from_config, test_mailer_in_background,
},
@@ -129,6 +130,14 @@ impl Options {
let policy_factory = policy_factory_from_config(&config.policy, &config.matrix).await?;
let policy_factory = Arc::new(policy_factory);
load_policy_factory_dynamic_data_continuously(
&policy_factory,
&pool,
shutdown.soft_shutdown_token(),
shutdown.task_tracker(),
)
.await?;
let url_builder = UrlBuilder::new(
config.http.public_base.clone(),
config.http.issuer.clone(),

View File

@@ -19,11 +19,14 @@ use mas_matrix::{HomeserverConnection, ReadOnlyHomeserverConnection};
use mas_matrix_synapse::SynapseConnection;
use mas_policy::PolicyFactory;
use mas_router::UrlBuilder;
use mas_storage::RepositoryAccess;
use mas_storage_pg::PgRepository;
use mas_templates::{SiteConfigExt, TemplateLoadingError, Templates};
use sqlx::{
ConnectOptions, PgConnection, PgPool,
postgres::{PgConnectOptions, PgPoolOptions},
};
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use tracing::{Instrument, log::LevelFilter};
pub async fn password_manager_from_config(
@@ -377,6 +380,66 @@ pub async fn database_connection_from_config_with_options(
.context("could not connect to the database")
}
/// Update the policy factory dynamic data from the database and spawn a task to
/// periodically update it
// XXX: this could be put somewhere else?
pub async fn load_policy_factory_dynamic_data_continuously(
policy_factory: &Arc<PolicyFactory>,
pool: &PgPool,
cancellation_token: CancellationToken,
task_tracker: &TaskTracker,
) -> Result<(), anyhow::Error> {
let policy_factory = policy_factory.clone();
let pool = pool.clone();
load_policy_factory_dynamic_data(&policy_factory, &pool).await?;
task_tracker.spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
loop {
tokio::select! {
() = cancellation_token.cancelled() => {
return;
}
_ = interval.tick() => {}
}
if let Err(err) = load_policy_factory_dynamic_data(&policy_factory, &pool).await {
tracing::error!(
error = ?err,
"Failed to load policy factory dynamic data"
);
cancellation_token.cancel();
return;
}
}
});
Ok(())
}
/// Update the policy factory dynamic data from the database
#[tracing::instrument(name = "policy.load_dynamic_data", skip_all, err(Debug))]
pub async fn load_policy_factory_dynamic_data(
policy_factory: &PolicyFactory,
pool: &PgPool,
) -> Result<(), anyhow::Error> {
let mut repo = PgRepository::from_pool(pool)
.await
.context("Failed to acquire database connection")?;
if let Some(data) = repo.policy_data().get().await? {
let id = data.id;
let updated = policy_factory.set_dynamic_data(data).await?;
if updated {
tracing::info!(policy_data.id = %id, "Loaded dynamic policy data from the database");
}
}
Ok(())
}
/// Create a clonable, type-erased [`HomeserverConnection`] from the
/// configuration
pub fn homeserver_connection_from_config(

View File

@@ -10,6 +10,7 @@ use thiserror::Error;
pub(crate) mod compat;
pub mod oauth2;
pub(crate) mod policy_data;
mod site_config;
pub(crate) mod tokens;
pub(crate) mod upstream_oauth2;
@@ -32,6 +33,7 @@ pub use self::{
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, DeviceCodeGrant,
DeviceCodeGrantState, InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState,
},
policy_data::PolicyData,
site_config::{CaptchaConfig, CaptchaService, SessionExpirationConfig, SiteConfig},
tokens::{
AccessToken, AccessTokenState, RefreshToken, RefreshTokenState, TokenFormatError, TokenType,

View File

@@ -0,0 +1,15 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use chrono::{DateTime, Utc};
use serde::Serialize;
use ulid::Ulid;
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct PolicyData {
pub id: Ulid,
pub created_at: DateTime<Utc>,
pub data: serde_json::Value,
}

View File

@@ -22,6 +22,7 @@ use indexmap::IndexMap;
use mas_axum_utils::FancyError;
use mas_http::CorsLayerExt;
use mas_matrix::HomeserverConnection;
use mas_policy::PolicyFactory;
use mas_router::{
ApiDoc, ApiDocCallback, OAuth2AuthorizationEndpoint, OAuth2TokenEndpoint, Route, SimpleRoute,
UrlBuilder,
@@ -47,6 +48,11 @@ fn finish(t: TransformOpenApi) -> TransformOpenApi {
description: Some("Manage compatibility sessions from legacy clients".to_owned()),
..Tag::default()
})
.tag(Tag {
name: "policy-data".to_owned(),
description: Some("Manage the dynamic policy data".to_owned()),
..Tag::default()
})
.tag(Tag {
name: "oauth2-session".to_owned(),
description: Some("Manage OAuth2 sessions".to_owned()),
@@ -115,6 +121,7 @@ where
CallContext: FromRequestParts<S>,
Templates: FromRef<S>,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
{
// We *always* want to explicitly set the possible responses, beacuse the
// infered ones are not necessarily correct

View File

@@ -534,3 +534,50 @@ impl UpstreamOAuthLink {
]
}
}
/// The policy data
#[derive(Serialize, JsonSchema)]
pub struct PolicyData {
#[serde(skip)]
id: Ulid,
/// The creation date of the policy data
created_at: DateTime<Utc>,
/// The policy data content
data: serde_json::Value,
}
impl From<mas_data_model::PolicyData> for PolicyData {
fn from(policy_data: mas_data_model::PolicyData) -> Self {
Self {
id: policy_data.id,
created_at: policy_data.created_at,
data: policy_data.data,
}
}
}
impl Resource for PolicyData {
const KIND: &'static str = "policy-data";
const PATH: &'static str = "/api/admin/v1/policy-data";
fn id(&self) -> Ulid {
self.id
}
}
impl PolicyData {
/// Samples of policy data
pub fn samples() -> [Self; 1] {
[Self {
id: Ulid::from_bytes([0x01; 16]),
created_at: DateTime::default(),
data: serde_json::json!({
"hello": "world",
"foo": 42,
"bar": true
}),
}]
}
}

View File

@@ -12,6 +12,7 @@ use aide::axum::{
};
use axum::extract::{FromRef, FromRequestParts};
use mas_matrix::HomeserverConnection;
use mas_policy::PolicyFactory;
use mas_storage::BoxRng;
use super::call_context::CallContext;
@@ -19,6 +20,7 @@ use crate::passwords::PasswordManager;
mod compat_sessions;
mod oauth2_sessions;
mod policy_data;
mod upstream_oauth_links;
mod user_emails;
mod user_sessions;
@@ -29,6 +31,7 @@ where
S: Clone + Send + Sync + 'static,
Arc<dyn HomeserverConnection>: FromRef<S>,
PasswordManager: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
BoxRng: FromRequestParts<S>,
CallContext: FromRequestParts<S>,
{
@@ -49,6 +52,21 @@ where
"/oauth2-sessions/{id}",
get_with(self::oauth2_sessions::get, self::oauth2_sessions::get_doc),
)
.api_route(
"/policy-data",
post_with(self::policy_data::set, self::policy_data::set_doc),
)
.api_route(
"/policy-data/latest",
get_with(
self::policy_data::get_latest,
self::policy_data::get_latest_doc,
),
)
.api_route(
"/policy-data/{id}",
get_with(self::policy_data::get, self::policy_data::get_doc),
)
.api_route(
"/users",
get_with(self::users::list, self::users::list_doc)

View File

@@ -0,0 +1,153 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use ulid::Ulid;
use crate::{
admin::{
call_context::CallContext,
model::PolicyData,
params::UlidPathParam,
response::{ErrorResponse, SingleResponse},
},
impl_from_error_for_route,
};
#[derive(Debug, thiserror::Error, OperationIo)]
#[aide(output_with = "Json<ErrorResponse>")]
pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("Policy data with ID {0} not found")]
NotFound(Ulid),
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
}
}
pub fn doc(operation: TransformOperation) -> TransformOperation {
operation
.id("getPolicyData")
.summary("Get policy data by ID")
.tag("policy-data")
.response_with::<200, Json<SingleResponse<PolicyData>>, _>(|t| {
let [sample, ..] = PolicyData::samples();
let response = SingleResponse::new_canonical(sample);
t.description("Policy data was found").example(response)
})
.response_with::<404, RouteError, _>(|t| {
let response = ErrorResponse::from_error(&RouteError::NotFound(Ulid::nil()));
t.description("Policy data was not found").example(response)
})
}
#[tracing::instrument(name = "handler.admin.v1.policy_data.get", skip_all, err)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
id: UlidPathParam,
) -> Result<Json<SingleResponse<PolicyData>>, RouteError> {
let policy_data = repo
.policy_data()
.get()
.await?
.ok_or(RouteError::NotFound(*id))?;
Ok(Json(SingleResponse::new_canonical(policy_data.into())))
}
#[cfg(test)]
mod tests {
use hyper::{Request, StatusCode};
use insta::assert_json_snapshot;
use sqlx::PgPool;
use ulid::Ulid;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_get(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let mut rng = state.rng();
let mut repo = state.repository().await.unwrap();
let policy_data = repo
.policy_data()
.set(
&mut rng,
&state.clock,
serde_json::json!({"hello": "world"}),
)
.await
.unwrap();
repo.save().await.unwrap();
let request = Request::get(format!("/api/admin/v1/policy-data/{}", policy_data.id))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r###"
{
"data": {
"type": "policy-data",
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"attributes": {
"created_at": "2022-01-16T14:40:00Z",
"data": {
"hello": "world"
}
},
"links": {
"self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E"
}
},
"links": {
"self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
"###);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_get_not_found(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let request = Request::get(format!("/api/admin/v1/policy-data/{}", Ulid::nil()))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::NOT_FOUND);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r###"
{
"errors": [
{
"title": "Policy data with ID 00000000000000000000000000 not found"
}
]
}
"###);
}
}

View File

@@ -0,0 +1,149 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use crate::{
admin::{
call_context::CallContext,
model::PolicyData,
response::{ErrorResponse, SingleResponse},
},
impl_from_error_for_route,
};
#[derive(Debug, thiserror::Error, OperationIo)]
#[aide(output_with = "Json<ErrorResponse>")]
pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("No policy data found")]
NotFound,
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
}
}
pub fn doc(operation: TransformOperation) -> TransformOperation {
operation
.id("getLatestPolicyData")
.summary("Get the latest policy data")
.tag("policy-data")
.response_with::<200, Json<SingleResponse<PolicyData>>, _>(|t| {
let [sample, ..] = PolicyData::samples();
let response = SingleResponse::new_canonical(sample);
t.description("Latest policy data was found")
.example(response)
})
.response_with::<404, RouteError, _>(|t| {
let response = ErrorResponse::from_error(&RouteError::NotFound);
t.description("No policy data was found").example(response)
})
}
#[tracing::instrument(name = "handler.admin.v1.policy_data.get_latest", skip_all, err)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
) -> Result<Json<SingleResponse<PolicyData>>, RouteError> {
let policy_data = repo
.policy_data()
.get()
.await?
.ok_or(RouteError::NotFound)?;
Ok(Json(SingleResponse::new_canonical(policy_data.into())))
}
#[cfg(test)]
mod tests {
use hyper::{Request, StatusCode};
use insta::assert_json_snapshot;
use sqlx::PgPool;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_get_latest(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let mut rng = state.rng();
let mut repo = state.repository().await.unwrap();
repo.policy_data()
.set(
&mut rng,
&state.clock,
serde_json::json!({"hello": "world"}),
)
.await
.unwrap();
repo.save().await.unwrap();
let request = Request::get("/api/admin/v1/policy-data/latest")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r###"
{
"data": {
"type": "policy-data",
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"attributes": {
"created_at": "2022-01-16T14:40:00Z",
"data": {
"hello": "world"
}
},
"links": {
"self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E"
}
},
"links": {
"self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
"###);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_get_no_latest(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let request = Request::get("/api/admin/v1/policy-data/latest")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::NOT_FOUND);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r###"
{
"errors": [
{
"title": "No policy data found"
}
]
}
"###);
}
}

View File

@@ -0,0 +1,14 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
mod get;
mod get_latest;
mod set;
pub use self::{
get::{doc as get_doc, handler as get},
get_latest::{doc as get_latest_doc, handler as get_latest},
set::{doc as set_doc, handler as set},
};

View File

@@ -0,0 +1,152 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
use std::sync::Arc;
use aide::{NoApi, OperationIo, transform::TransformOperation};
use axum::{Json, extract::State, response::IntoResponse};
use hyper::StatusCode;
use mas_policy::PolicyFactory;
use mas_storage::BoxRng;
use schemars::JsonSchema;
use serde::Deserialize;
use crate::{
admin::{
call_context::CallContext,
model::PolicyData,
response::{ErrorResponse, SingleResponse},
},
impl_from_error_for_route,
};
#[derive(Debug, thiserror::Error, OperationIo)]
#[aide(output_with = "Json<ErrorResponse>")]
pub enum RouteError {
#[error("Failed to instanciate policy with the provided data")]
InvalidPolicyData(#[from] mas_policy::LoadError),
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let status = match self {
RouteError::InvalidPolicyData(_) => StatusCode::BAD_REQUEST,
RouteError::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
};
(status, Json(error)).into_response()
}
}
fn data_example() -> serde_json::Value {
serde_json::json!({
"hello": "world",
"foo": 42,
"bar": true
})
}
/// # JSON payload for the `POST /api/admin/v1/policy-data`
#[derive(Deserialize, JsonSchema)]
#[serde(rename = "SetPolicyDataRequest")]
pub struct SetPolicyDataRequest {
#[schemars(example = "data_example")]
pub data: serde_json::Value,
}
pub fn doc(operation: TransformOperation) -> TransformOperation {
operation
.id("setPolicyData")
.summary("Set the current policy data")
.tag("policy-data")
.response_with::<201, Json<SingleResponse<PolicyData>>, _>(|t| {
let [sample, ..] = PolicyData::samples();
let response = SingleResponse::new_canonical(sample);
t.description("Policy data was successfully set")
.example(response)
})
.response_with::<400, Json<ErrorResponse>, _>(|t| {
let error = ErrorResponse::from_error(&RouteError::InvalidPolicyData(
mas_policy::LoadError::invalid_data_example(),
));
t.description("Invalid policy data").example(error)
})
}
#[tracing::instrument(name = "handler.admin.v1.policy_data.set", skip_all, err)]
pub async fn handler(
CallContext {
mut repo, clock, ..
}: CallContext,
NoApi(mut rng): NoApi<BoxRng>,
State(policy_factory): State<Arc<PolicyFactory>>,
Json(request): Json<SetPolicyDataRequest>,
) -> Result<(StatusCode, Json<SingleResponse<PolicyData>>), RouteError> {
let policy_data = repo
.policy_data()
.set(&mut rng, &clock, request.data)
.await?;
// Swap the policy data. This will fail if the policy data is invalid
policy_factory.set_dynamic_data(policy_data.clone()).await?;
repo.save().await?;
Ok((
StatusCode::CREATED,
Json(SingleResponse::new_canonical(policy_data.into())),
))
}
#[cfg(test)]
mod tests {
use hyper::{Request, StatusCode};
use insta::assert_json_snapshot;
use sqlx::PgPool;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_create(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let request = Request::post("/api/admin/v1/policy-data")
.bearer(&token)
.json(serde_json::json!({
"data": {
"hello": "world"
}
}));
let response = state.request(request).await;
response.assert_status(StatusCode::CREATED);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r###"
{
"data": {
"type": "policy-data",
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"attributes": {
"created_at": "2022-01-16T14:40:00Z",
"data": {
"hello": "world"
}
},
"links": {
"self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E"
}
},
"links": {
"self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
"###);
}
}

View File

@@ -58,6 +58,7 @@ impl_from_ref!(mas_templates::Templates);
impl_from_ref!(Arc<dyn mas_matrix::HomeserverConnection>);
impl_from_ref!(mas_keystore::Keystore);
impl_from_ref!(mas_handlers::passwords::PasswordManager);
impl_from_ref!(Arc<mas_policy::PolicyFactory>);
fn main() -> Result<(), Box<dyn std::error::Error>> {
let (mut api, _) = mas_handlers::admin_api_router::<DummyState>();

View File

@@ -514,6 +514,12 @@ impl FromRef<TestState> for SiteConfig {
}
}
impl FromRef<TestState> for Arc<PolicyFactory> {
fn from_ref(input: &TestState) -> Self {
input.policy_factory.clone()
}
}
impl FromRef<TestState> for Arc<dyn HomeserverConnection> {
fn from_ref(input: &TestState) -> Self {
input.homeserver_connection.clone()

View File

@@ -13,6 +13,7 @@ workspace = true
[dependencies]
anyhow.workspace = true
arc-swap.workspace = true
opa-wasm = "0.1.4"
serde.workspace = true
serde_json.workspace = true

View File

@@ -6,11 +6,14 @@
pub mod model;
use std::sync::Arc;
use arc_swap::ArcSwap;
use mas_data_model::Ulid;
use opa_wasm::{
Runtime,
wasmtime::{Config, Engine, Module, OptLevel, Store},
};
use serde::Serialize;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt};
@@ -33,10 +36,23 @@ pub enum LoadError {
#[error("failed to compile WASM module")]
Compilation(#[source] anyhow::Error),
#[error("invalid policy data")]
InvalidData(#[source] anyhow::Error),
#[error("failed to instantiate a test instance")]
Instantiate(#[source] InstantiateError),
}
impl LoadError {
/// Creates an example of an invalid data error, used for API response
/// documentation
#[doc(hidden)]
#[must_use]
pub fn invalid_data_example() -> Self {
Self::InvalidData(anyhow::Error::msg("Failed to merge policy data objects"))
}
}
#[derive(Debug, Error)]
pub enum InstantiateError {
#[error("failed to create WASM runtime")]
@@ -69,11 +85,10 @@ impl Entrypoints {
}
}
#[derive(Serialize, Debug)]
#[derive(Debug)]
pub struct Data {
server_name: String,
#[serde(flatten)]
rest: Option<serde_json::Value>,
}
@@ -91,12 +106,93 @@ impl Data {
self.rest = Some(rest);
self
}
fn to_value(&self) -> Result<serde_json::Value, anyhow::Error> {
let base = serde_json::json!({
"server_name": self.server_name,
});
if let Some(rest) = &self.rest {
merge_data(base, rest.clone())
} else {
Ok(base)
}
}
}
fn value_kind(value: &serde_json::Value) -> &'static str {
match value {
serde_json::Value::Object(_) => "object",
serde_json::Value::Array(_) => "array",
serde_json::Value::String(_) => "string",
serde_json::Value::Number(_) => "number",
serde_json::Value::Bool(_) => "boolean",
serde_json::Value::Null => "null",
}
}
fn merge_data(
mut left: serde_json::Value,
right: serde_json::Value,
) -> Result<serde_json::Value, anyhow::Error> {
merge_data_rec(&mut left, right)?;
Ok(left)
}
fn merge_data_rec(
left: &mut serde_json::Value,
right: serde_json::Value,
) -> Result<(), anyhow::Error> {
match (left, right) {
(serde_json::Value::Object(left), serde_json::Value::Object(right)) => {
for (key, value) in right {
if let Some(left_value) = left.get_mut(&key) {
merge_data_rec(left_value, value)?;
} else {
left.insert(key, value);
}
}
}
(serde_json::Value::Array(left), serde_json::Value::Array(right)) => {
left.extend(right);
}
// Other values override
(serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
*left = right;
}
(serde_json::Value::Bool(left), serde_json::Value::Bool(right)) => {
*left = right;
}
(serde_json::Value::String(left), serde_json::Value::String(right)) => {
*left = right;
}
// Null gets overridden by anything
(left, right) if left.is_null() => *left = right,
// Null on the right makes the left value null
(left, right) if right.is_null() => *left = right,
(left, right) => anyhow::bail!(
"Cannot merge a {} into a {}",
value_kind(&right),
value_kind(left),
),
}
Ok(())
}
struct DynamicData {
version: Option<Ulid>,
merged: serde_json::Value,
}
pub struct PolicyFactory {
engine: Engine,
module: Module,
data: Data,
dynamic_data: ArcSwap<DynamicData>,
entrypoints: Entrypoints,
}
@@ -124,10 +220,17 @@ impl PolicyFactory {
.await?
.map_err(LoadError::Compilation)?;
let merged = data.to_value().map_err(LoadError::InvalidData)?;
let dynamic_data = ArcSwap::new(Arc::new(DynamicData {
version: None,
merged,
}));
let factory = Self {
engine,
module,
data,
dynamic_data,
entrypoints,
};
@@ -140,8 +243,56 @@ impl PolicyFactory {
Ok(factory)
}
/// Set the dynamic data for the policy.
///
/// The `dynamic_data` object is merged with the static data given when the
/// policy was loaded.
///
/// Returns `true` if the data was updated, `false` if the version
/// of the dynamic data was the same as the one we already have.
///
/// # Errors
///
/// Returns an error if the data can't be merged with the static data, or if
/// the policy can't be instantiated with the new data.
pub async fn set_dynamic_data(
&self,
dynamic_data: mas_data_model::PolicyData,
) -> Result<bool, LoadError> {
// Check if the version of the dynamic data we have is the same as the one we're
// trying to set
if self.dynamic_data.load().version == Some(dynamic_data.id) {
// Don't do anything if the version is the same
return Ok(false);
}
let static_data = self.data.to_value().map_err(LoadError::InvalidData)?;
let merged = merge_data(static_data, dynamic_data.data).map_err(LoadError::InvalidData)?;
// Try to instantiate with the new data
self.instantiate_with_data(&merged)
.await
.map_err(LoadError::Instantiate)?;
// If instantiation succeeds, swap the data
self.dynamic_data.store(Arc::new(DynamicData {
version: Some(dynamic_data.id),
merged,
}));
Ok(true)
}
#[tracing::instrument(name = "policy.instantiate", skip_all, err)]
pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
let data = self.dynamic_data.load();
self.instantiate_with_data(&data.merged).await
}
async fn instantiate_with_data(
&self,
data: &serde_json::Value,
) -> Result<Policy, InstantiateError> {
let mut store = Store::new(&self.engine, ());
let runtime = Runtime::new(&mut store, &self.module)
.await
@@ -159,7 +310,7 @@ impl PolicyFactory {
}
let instance = runtime
.with_data(&mut store, &self.data)
.with_data(&mut store, data)
.await
.map_err(InstantiateError::LoadData)?;
@@ -273,6 +424,8 @@ impl Policy {
#[cfg(test)]
mod tests {
use std::time::SystemTime;
use super::*;
#[tokio::test]
@@ -344,4 +497,167 @@ mod tests {
.unwrap();
assert!(!res.valid());
}
#[tokio::test]
async fn test_dynamic_data() {
let data = Data::new("example.com".to_owned());
#[allow(clippy::disallowed_types)]
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
.join("..")
.join("..")
.join("policies")
.join("policy.wasm");
let file = tokio::fs::File::open(path).await.unwrap();
let entrypoints = Entrypoints {
register: "register/violation".to_owned(),
client_registration: "client_registration/violation".to_owned(),
authorization_grant: "authorization_grant/violation".to_owned(),
email: "email/violation".to_owned(),
};
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
let mut policy = factory.instantiate().await.unwrap();
let res = policy
.evaluate_register(RegisterInput {
registration_method: RegistrationMethod::Password,
username: "hello",
email: Some("hello@example.com"),
requester: Requester {
ip_address: None,
user_agent: None,
},
})
.await
.unwrap();
assert!(res.valid());
// Update the policy data
factory
.set_dynamic_data(mas_data_model::PolicyData {
id: Ulid::nil(),
created_at: SystemTime::now().into(),
data: serde_json::json!({
"emails": {
"banned_addresses": {
"substrings": ["hello"]
}
}
}),
})
.await
.unwrap();
let mut policy = factory.instantiate().await.unwrap();
let res = policy
.evaluate_register(RegisterInput {
registration_method: RegistrationMethod::Password,
username: "hello",
email: Some("hello@example.com"),
requester: Requester {
ip_address: None,
user_agent: None,
},
})
.await
.unwrap();
assert!(!res.valid());
}
#[tokio::test]
async fn test_big_dynamic_data() {
let data = Data::new("example.com".to_owned());
#[allow(clippy::disallowed_types)]
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
.join("..")
.join("..")
.join("policies")
.join("policy.wasm");
let file = tokio::fs::File::open(path).await.unwrap();
let entrypoints = Entrypoints {
register: "register/violation".to_owned(),
client_registration: "client_registration/violation".to_owned(),
authorization_grant: "authorization_grant/violation".to_owned(),
email: "email/violation".to_owned(),
};
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
// That is around 1 MB of JSON data. Each element is a 5-digit string, so 8
// characters including the quotes and a comma.
let data: Vec<String> = (0..(1024 * 1024 / 8))
.map(|i| format!("{:05}", i % 100_000))
.collect();
let json = serde_json::json!({ "emails": { "banned_addresses": { "substrings": data } } });
factory
.set_dynamic_data(mas_data_model::PolicyData {
id: Ulid::nil(),
created_at: SystemTime::now().into(),
data: json,
})
.await
.unwrap();
// Try instantiating the policy, make sure 5-digit numbers are banned from email
// addresses
let mut policy = factory.instantiate().await.unwrap();
let res = policy
.evaluate_register(RegisterInput {
registration_method: RegistrationMethod::Password,
username: "hello",
email: Some("12345@example.com"),
requester: Requester {
ip_address: None,
user_agent: None,
},
})
.await
.unwrap();
assert!(!res.valid());
}
#[test]
fn test_merge() {
use serde_json::json as j;
// Merging objects
let res = merge_data(j!({"hello": "world"}), j!({"foo": "bar"})).unwrap();
assert_eq!(res, j!({"hello": "world", "foo": "bar"}));
// Override a value of the same type
let res = merge_data(j!({"hello": "world"}), j!({"hello": "john"})).unwrap();
assert_eq!(res, j!({"hello": "john"}));
let res = merge_data(j!({"hello": true}), j!({"hello": false})).unwrap();
assert_eq!(res, j!({"hello": false}));
let res = merge_data(j!({"hello": 0}), j!({"hello": 42})).unwrap();
assert_eq!(res, j!({"hello": 42}));
// Override a value of a different type
merge_data(j!({"hello": "world"}), j!({"hello": 123}))
.expect_err("Can't merge different types");
// Merge arrays
let res = merge_data(j!({"hello": ["world"]}), j!({"hello": ["john"]})).unwrap();
assert_eq!(res, j!({"hello": ["world", "john"]}));
// Null overrides a value
let res = merge_data(j!({"hello": "world"}), j!({"hello": null})).unwrap();
assert_eq!(res, j!({"hello": null}));
// Null gets overridden by a value
let res = merge_data(j!({"hello": null}), j!({"hello": "world"})).unwrap();
assert_eq!(res, j!({"hello": "world"}));
// Objects get deeply merged
let res = merge_data(j!({"a": {"b": {"c": "d"}}}), j!({"a": {"b": {"e": "f"}}})).unwrap();
assert_eq!(res, j!({"a": {"b": {"c": "d", "e": "f"}}}));
}
}

View File

@@ -0,0 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "\n DELETE FROM policy_data\n WHERE policy_data_id IN (\n SELECT policy_data_id\n FROM policy_data\n ORDER BY policy_data_id DESC\n OFFSET $1\n )\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": []
},
"hash": "5006c3e60c98c91a0b0fbb3205373e81d9b75e90929af80961f8b5910873a43e"
}

View File

@@ -0,0 +1,32 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT policy_data_id, created_at, data\n FROM policy_data\n ORDER BY policy_data_id DESC\n LIMIT 1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "policy_data_id",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 2,
"name": "data",
"type_info": "Jsonb"
}
],
"parameters": {
"Left": []
},
"nullable": [
false,
false,
false
]
},
"hash": "9fe87eeaf4b7d0ba09b59ddad3476eb57ccb6e4053ab8f4450dd4a9d1f6ba108"
}

View File

@@ -0,0 +1,16 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO policy_data (policy_data_id, created_at, data)\n VALUES ($1, $2, $3)\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid",
"Timestamptz",
"Jsonb"
]
},
"nullable": []
},
"hash": "b6c4f4a23968cba2a82c2b7cfffc05a7ed582c9e5c1f65d27b0686f843ccfe42"
}

View File

@@ -0,0 +1,15 @@
-- Copyright 2025 New Vector Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only
-- Please see LICENSE in the repository root for full details.
-- Add a table which stores the latest policy data
--
-- Every time the policy data is updated, it creates a new row, so that we keep
-- an history of the policy data, trace back which version of the data was used
-- on each evaluation.
CREATE TABLE IF NOT EXISTS policy_data (
policy_data_id UUID PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
data JSONB NOT NULL
);

View File

@@ -173,6 +173,7 @@ mod errors;
pub(crate) mod filter;
pub(crate) mod iden;
pub(crate) mod pagination;
pub(crate) mod policy_data;
pub(crate) mod repository;
pub(crate) mod tracing;

View File

@@ -0,0 +1,204 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
//! A module containing the PostgreSQL implementation of the policy data
//! storage.
use async_trait::async_trait;
use mas_data_model::PolicyData;
use mas_storage::{Clock, policy_data::PolicyDataRepository};
use rand::RngCore;
use serde_json::Value;
use sqlx::{PgConnection, types::Json};
use ulid::Ulid;
use uuid::Uuid;
use crate::{DatabaseError, ExecuteExt};
/// An implementation of [`PolicyDataRepository`] for a PostgreSQL connection.
pub struct PgPolicyDataRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgPolicyDataRepository<'c> {
/// Create a new [`PgPolicyDataRepository`] from an active PostgreSQL
/// connection.
#[must_use]
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct PolicyDataLookup {
policy_data_id: Uuid,
created_at: chrono::DateTime<chrono::Utc>,
data: Json<Value>,
}
impl From<PolicyDataLookup> for PolicyData {
fn from(value: PolicyDataLookup) -> Self {
PolicyData {
id: value.policy_data_id.into(),
created_at: value.created_at,
data: value.data.0,
}
}
}
#[async_trait]
impl PolicyDataRepository for PgPolicyDataRepository<'_> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.policy_data.get",
skip_all,
fields(
db.query.text,
),
err,
)]
async fn get(&mut self) -> Result<Option<PolicyData>, Self::Error> {
let row = sqlx::query_as!(
PolicyDataLookup,
r#"
SELECT policy_data_id, created_at, data
FROM policy_data
ORDER BY policy_data_id DESC
LIMIT 1
"#
)
.traced()
.fetch_optional(&mut *self.conn)
.await?;
let Some(row) = row else {
return Ok(None);
};
Ok(Some(row.into()))
}
#[tracing::instrument(
name = "db.policy_data.set",
skip_all,
fields(
db.query.text,
),
err,
)]
async fn set(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
data: Value,
) -> Result<PolicyData, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
sqlx::query!(
r#"
INSERT INTO policy_data (policy_data_id, created_at, data)
VALUES ($1, $2, $3)
"#,
Uuid::from(id),
created_at,
data,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(PolicyData {
id,
created_at,
data,
})
}
#[tracing::instrument(
name = "db.policy_data.prune",
skip_all,
fields(
db.query.text,
),
err,
)]
async fn prune(&mut self, keep: usize) -> Result<usize, Self::Error> {
let res = sqlx::query!(
r#"
DELETE FROM policy_data
WHERE policy_data_id IN (
SELECT policy_data_id
FROM policy_data
ORDER BY policy_data_id DESC
OFFSET $1
)
"#,
i64::try_from(keep).map_err(DatabaseError::to_invalid_operation)?
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(res
.rows_affected()
.try_into()
.map_err(DatabaseError::to_invalid_operation)?)
}
}
#[cfg(test)]
mod tests {
use mas_storage::{clock::MockClock, policy_data::PolicyDataRepository};
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
use serde_json::json;
use sqlx::PgPool;
use crate::policy_data::PgPolicyDataRepository;
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_policy_data(pool: PgPool) {
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
let mut conn = pool.acquire().await.unwrap();
let mut repo = PgPolicyDataRepository::new(&mut conn);
// Get an empty state at first
let data = repo.get().await.unwrap();
assert_eq!(data, None);
// Set some data
let value1 = json!({"hello": "world"});
let policy_data1 = repo.set(&mut rng, &clock, value1.clone()).await.unwrap();
assert_eq!(policy_data1.data, value1);
let data_fetched1 = repo.get().await.unwrap().unwrap();
assert_eq!(policy_data1, data_fetched1);
// Set some new data
clock.advance(chrono::Duration::seconds(1));
let value2 = json!({"foo": "bar"});
let policy_data2 = repo.set(&mut rng, &clock, value2.clone()).await.unwrap();
assert_eq!(policy_data2.data, value2);
// Check the new data is fetched
let data_fetched2 = repo.get().await.unwrap().unwrap();
assert_eq!(data_fetched2, policy_data2);
// Prune until the first entry
let affected = repo.prune(1).await.unwrap();
let data_fetched3 = repo.get().await.unwrap().unwrap();
assert_eq!(data_fetched3, policy_data2);
assert_eq!(affected, 1);
// Do a raw query to check the other rows were pruned
let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM policy_data")
.fetch_one(&mut *conn)
.await
.unwrap();
assert_eq!(count, 1);
}
}

View File

@@ -18,6 +18,7 @@ use mas_storage::{
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
},
policy_data::PolicyDataRepository,
queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository},
upstream_oauth2::{
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
@@ -40,6 +41,7 @@ use crate::{
PgOAuth2ClientRepository, PgOAuth2DeviceCodeGrantRepository,
PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
},
policy_data::PgPolicyDataRepository,
queue::{
job::PgQueueJobRepository, schedule::PgQueueScheduleRepository,
worker::PgQueueWorkerRepository,
@@ -283,4 +285,8 @@ where
) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> {
Box::new(PgQueueScheduleRepository::new(self.conn.as_mut()))
}
fn policy_data<'c>(&'c mut self) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c> {
Box::new(PgPolicyDataRepository::new(self.conn.as_mut()))
}
}

View File

@@ -119,6 +119,7 @@ mod utils;
pub mod app_session;
pub mod compat;
pub mod oauth2;
pub mod policy_data;
pub mod queue;
pub mod upstream_oauth2;
pub mod user;

View File

@@ -0,0 +1,76 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
//! Repositories to interact with the policy data saved in the storage backend.
use async_trait::async_trait;
use mas_data_model::PolicyData;
use rand_core::RngCore;
use crate::{Clock, repository_impl};
/// A [`PolicyDataRepository`] helps interacting with the policy data saved in
/// the storage backend.
#[async_trait]
pub trait PolicyDataRepository: Send + Sync {
/// The error type returned by the repository
type Error;
/// Get the latest policy data
///
/// Returns the latest policy data, or `None` if no policy data is
/// available.
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn get(&mut self) -> Result<Option<PolicyData>, Self::Error>;
/// Set the latest policy data
///
/// Returns the newly created policy data.
///
/// # Parameters
///
/// * `rng`: The random number generator to use
/// * `clock`: The clock used to generate the timestamps
/// * `data`: The policy data to set
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn set(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
data: serde_json::Value,
) -> Result<PolicyData, Self::Error>;
/// Prune old policy data
///
/// Returns the number of entries pruned.
///
/// # Parameters
///
/// * `keep`: the number of old entries to keep
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn prune(&mut self, keep: usize) -> Result<usize, Self::Error>;
}
repository_impl!(PolicyDataRepository:
async fn get(&mut self) -> Result<Option<PolicyData>, Self::Error>;
async fn set(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
data: serde_json::Value,
) -> Result<PolicyData, Self::Error>;
async fn prune(&mut self, keep: usize) -> Result<usize, Self::Error>;
);

View File

@@ -506,3 +506,11 @@ impl ExpireInactiveUserSessionsJob {
impl InsertableJob for ExpireInactiveUserSessionsJob {
const QUEUE_NAME: &'static str = "expire-inactive-user-sessions";
}
/// Prune stale policy data
#[derive(Debug, Serialize, Deserialize)]
pub struct PruneStalePolicyDataJob;
impl InsertableJob for PruneStalePolicyDataJob {
const QUEUE_NAME: &'static str = "prune-stale-policy-data";
}

View File

@@ -17,6 +17,7 @@ use crate::{
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
},
policy_data::PolicyDataRepository,
queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository},
upstream_oauth2::{
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
@@ -204,6 +205,9 @@ pub trait RepositoryAccess: Send {
fn queue_schedule<'c>(
&'c mut self,
) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c>;
/// Get a [`PolicyDataRepository`]
fn policy_data<'c>(&'c mut self) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c>;
}
/// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and
@@ -224,6 +228,7 @@ mod impls {
OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository,
OAuth2SessionRepository,
},
policy_data::PolicyDataRepository,
queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository},
upstream_oauth2::{
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
@@ -439,6 +444,12 @@ mod impls {
) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.queue_schedule(), &mut self.mapper))
}
fn policy_data<'c>(
&'c mut self,
) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.policy_data(), &mut self.mapper))
}
}
impl<R: RepositoryAccess + ?Sized> RepositoryAccess for Box<R> {
@@ -579,5 +590,11 @@ mod impls {
) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> {
(**self).queue_schedule()
}
fn policy_data<'c>(
&'c mut self,
) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c> {
(**self).policy_data()
}
}
}

View File

@@ -7,7 +7,7 @@
//! Database-related tasks
use async_trait::async_trait;
use mas_storage::queue::CleanupExpiredTokensJob;
use mas_storage::queue::{CleanupExpiredTokensJob, PruneStalePolicyDataJob};
use tracing::{debug, info};
use crate::{
@@ -38,3 +38,28 @@ impl RunnableJob for CleanupExpiredTokensJob {
Ok(())
}
}
#[async_trait]
impl RunnableJob for PruneStalePolicyDataJob {
#[tracing::instrument(name = "job.prune_stale_policy_data", skip_all, err)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
let mut repo = state.repository().await.map_err(JobError::retry)?;
// Keep the last 10 policy data
let count = repo
.policy_data()
.prune(10)
.await
.map_err(JobError::retry)?;
repo.save().await.map_err(JobError::retry)?;
if count == 0 {
debug!("no stale policy data to prune");
} else {
info!(count, "pruned stale policy data");
}
Ok(())
}
}

View File

@@ -143,6 +143,7 @@ pub async fn init(
.register_handler::<mas_storage::queue::ExpireInactiveCompatSessionsJob>()
.register_handler::<mas_storage::queue::ExpireInactiveOAuthSessionsJob>()
.register_handler::<mas_storage::queue::ExpireInactiveUserSessionsJob>()
.register_handler::<mas_storage::queue::PruneStalePolicyDataJob>()
.add_schedule(
"cleanup-expired-tokens",
"0 0 * * * *".parse()?,
@@ -153,6 +154,12 @@ pub async fn init(
// Run this job every 15 minutes
"30 */15 * * * *".parse()?,
mas_storage::queue::ExpireInactiveSessionsJob,
)
.add_schedule(
"prune-stale-policy-data",
// Run once a day
"0 0 2 * * *".parse()?,
mas_storage::queue::PruneStalePolicyDataJob,
);
task_tracker.spawn(worker.run());

View File

@@ -593,6 +593,208 @@
}
}
},
"/api/admin/v1/policy-data": {
"post": {
"tags": [
"policy-data"
],
"summary": "Set the current policy data",
"operationId": "setPolicyData",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SetPolicyDataRequest"
}
}
},
"required": true
},
"responses": {
"201": {
"description": "Policy data was successfully set",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SingleResponse_for_PolicyData"
},
"example": {
"data": {
"type": "policy-data",
"id": "01040G2081040G2081040G2081",
"attributes": {
"created_at": "1970-01-01T00:00:00Z",
"data": {
"hello": "world",
"foo": 42,
"bar": true
}
},
"links": {
"self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081"
}
},
"links": {
"self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081"
}
}
}
}
},
"400": {
"description": "Invalid policy data",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"errors": [
{
"title": "Failed to instanciate policy with the provided data"
},
{
"title": "invalid policy data"
},
{
"title": "Failed to merge policy data objects"
}
]
}
}
}
}
}
}
},
"/api/admin/v1/policy-data/latest": {
"get": {
"tags": [
"policy-data"
],
"summary": "Get the latest policy data",
"operationId": "getLatestPolicyData",
"responses": {
"200": {
"description": "Latest policy data was found",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SingleResponse_for_PolicyData"
},
"example": {
"data": {
"type": "policy-data",
"id": "01040G2081040G2081040G2081",
"attributes": {
"created_at": "1970-01-01T00:00:00Z",
"data": {
"hello": "world",
"foo": 42,
"bar": true
}
},
"links": {
"self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081"
}
},
"links": {
"self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081"
}
}
}
}
},
"404": {
"description": "No policy data was found",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"errors": [
{
"title": "No policy data found"
}
]
}
}
}
}
}
}
},
"/api/admin/v1/policy-data/{id}": {
"get": {
"tags": [
"policy-data"
],
"summary": "Get policy data by ID",
"operationId": "getPolicyData",
"parameters": [
{
"in": "path",
"name": "id",
"required": true,
"schema": {
"title": "The ID of the resource",
"$ref": "#/components/schemas/ULID"
},
"style": "simple"
}
],
"responses": {
"200": {
"description": "Policy data was found",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SingleResponse_for_PolicyData"
},
"example": {
"data": {
"type": "policy-data",
"id": "01040G2081040G2081040G2081",
"attributes": {
"created_at": "1970-01-01T00:00:00Z",
"data": {
"hello": "world",
"foo": 42,
"bar": true
}
},
"links": {
"self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081"
}
},
"links": {
"self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081"
}
}
}
}
},
"404": {
"description": "Policy data was not found",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"errors": [
{
"title": "Policy data with ID 00000000000000000000000000 not found"
}
]
}
}
}
}
}
}
},
"/api/admin/v1/users": {
"get": {
"tags": [
@@ -2666,6 +2868,86 @@
}
}
},
"SetPolicyDataRequest": {
"title": "JSON payload for the `POST /api/admin/v1/policy-data`",
"type": "object",
"required": [
"data"
],
"properties": {
"data": {
"examples": [
{
"hello": "world",
"foo": 42,
"bar": true
}
]
}
}
},
"SingleResponse_for_PolicyData": {
"description": "A top-level response with a single resource",
"type": "object",
"required": [
"data",
"links"
],
"properties": {
"data": {
"$ref": "#/components/schemas/SingleResource_for_PolicyData"
},
"links": {
"$ref": "#/components/schemas/SelfLinks"
}
}
},
"SingleResource_for_PolicyData": {
"description": "A single resource, with its type, ID, attributes and related links",
"type": "object",
"required": [
"attributes",
"id",
"links",
"type"
],
"properties": {
"type": {
"description": "The type of the resource",
"type": "string"
},
"id": {
"description": "The ID of the resource",
"$ref": "#/components/schemas/ULID"
},
"attributes": {
"description": "The attributes of the resource",
"$ref": "#/components/schemas/PolicyData"
},
"links": {
"description": "Related links",
"$ref": "#/components/schemas/SelfLinks"
}
}
},
"PolicyData": {
"description": "The policy data",
"type": "object",
"required": [
"created_at",
"data"
],
"properties": {
"created_at": {
"description": "The creation date of the policy data",
"type": "string",
"format": "date-time"
},
"data": {
"description": "The policy data content"
}
}
},
"UserFilter": {
"type": "object",
"properties": {
@@ -3252,6 +3534,10 @@
"name": "compat-session",
"description": "Manage compatibility sessions from legacy clients"
},
{
"name": "policy-data",
"description": "Manage the dynamic policy data"
},
{
"name": "oauth2-session",
"description": "Manage OAuth2 sessions"