Admin API to dynamically set policy data (#4115)
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -3588,6 +3588,7 @@ name = "mas-policy"
|
||||
version = "0.14.1"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"arc-swap",
|
||||
"mas-data-model",
|
||||
"oauth2-types",
|
||||
"opa-wasm",
|
||||
|
||||
@@ -61,6 +61,7 @@ syn2mas = { path = "./crates/syn2mas", version = "=0.14.1" }
|
||||
version = "0.14.1"
|
||||
features = ["axum", "axum-extra", "axum-json", "axum-query", "macros"]
|
||||
|
||||
# An `Arc` that can be atomically updated
|
||||
[workspace.dependencies.arc-swap]
|
||||
version = "1.7.1"
|
||||
|
||||
|
||||
@@ -203,6 +203,12 @@ impl FromRef<AppState> for Limiter {
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<AppState> for Arc<PolicyFactory> {
|
||||
fn from_ref(input: &AppState) -> Self {
|
||||
input.policy_factory.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<AppState> for Arc<dyn HomeserverConnection> {
|
||||
fn from_ref(input: &AppState) -> Self {
|
||||
Arc::clone(&input.homeserver_connection)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2024, 2025 New Vector Ltd.
|
||||
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
@@ -8,10 +8,14 @@ use std::process::ExitCode;
|
||||
|
||||
use clap::Parser;
|
||||
use figment::Figment;
|
||||
use mas_config::{ConfigurationSection, ConfigurationSectionExt, MatrixConfig, PolicyConfig};
|
||||
use mas_config::{
|
||||
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, PolicyConfig,
|
||||
};
|
||||
use tracing::{info, info_span};
|
||||
|
||||
use crate::util::policy_factory_from_config;
|
||||
use crate::util::{
|
||||
database_pool_from_config, load_policy_factory_dynamic_data, policy_factory_from_config,
|
||||
};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
pub(super) struct Options {
|
||||
@@ -22,7 +26,11 @@ pub(super) struct Options {
|
||||
#[derive(Parser, Debug)]
|
||||
enum Subcommand {
|
||||
/// Check that the policies compile
|
||||
Policy,
|
||||
Policy {
|
||||
/// With dynamic data loaded
|
||||
#[arg(long)]
|
||||
with_dynamic_data: bool,
|
||||
},
|
||||
}
|
||||
|
||||
impl Options {
|
||||
@@ -30,13 +38,19 @@ impl Options {
|
||||
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
|
||||
use Subcommand as SC;
|
||||
match self.subcommand {
|
||||
SC::Policy => {
|
||||
SC::Policy { with_dynamic_data } => {
|
||||
let _span = info_span!("cli.debug.policy").entered();
|
||||
let config = PolicyConfig::extract_or_default(figment)?;
|
||||
let matrix_config = MatrixConfig::extract(figment)?;
|
||||
info!("Loading and compiling the policy module");
|
||||
let policy_factory = policy_factory_from_config(&config, &matrix_config).await?;
|
||||
|
||||
if with_dynamic_data {
|
||||
let database_config = DatabaseConfig::extract(figment)?;
|
||||
let pool = database_pool_from_config(&database_config).await?;
|
||||
load_policy_factory_dynamic_data(&policy_factory, &pool).await?;
|
||||
}
|
||||
|
||||
let _instance = policy_factory.instantiate().await?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,8 @@ use crate::{
|
||||
app_state::AppState,
|
||||
lifecycle::LifecycleManager,
|
||||
util::{
|
||||
database_pool_from_config, homeserver_connection_from_config, mailer_from_config,
|
||||
database_pool_from_config, homeserver_connection_from_config,
|
||||
load_policy_factory_dynamic_data_continuously, mailer_from_config,
|
||||
password_manager_from_config, policy_factory_from_config, site_config_from_config,
|
||||
templates_from_config, test_mailer_in_background,
|
||||
},
|
||||
@@ -129,6 +130,14 @@ impl Options {
|
||||
let policy_factory = policy_factory_from_config(&config.policy, &config.matrix).await?;
|
||||
let policy_factory = Arc::new(policy_factory);
|
||||
|
||||
load_policy_factory_dynamic_data_continuously(
|
||||
&policy_factory,
|
||||
&pool,
|
||||
shutdown.soft_shutdown_token(),
|
||||
shutdown.task_tracker(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let url_builder = UrlBuilder::new(
|
||||
config.http.public_base.clone(),
|
||||
config.http.issuer.clone(),
|
||||
|
||||
@@ -19,11 +19,14 @@ use mas_matrix::{HomeserverConnection, ReadOnlyHomeserverConnection};
|
||||
use mas_matrix_synapse::SynapseConnection;
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::RepositoryAccess;
|
||||
use mas_storage_pg::PgRepository;
|
||||
use mas_templates::{SiteConfigExt, TemplateLoadingError, Templates};
|
||||
use sqlx::{
|
||||
ConnectOptions, PgConnection, PgPool,
|
||||
postgres::{PgConnectOptions, PgPoolOptions},
|
||||
};
|
||||
use tokio_util::{sync::CancellationToken, task::TaskTracker};
|
||||
use tracing::{Instrument, log::LevelFilter};
|
||||
|
||||
pub async fn password_manager_from_config(
|
||||
@@ -377,6 +380,66 @@ pub async fn database_connection_from_config_with_options(
|
||||
.context("could not connect to the database")
|
||||
}
|
||||
|
||||
/// Update the policy factory dynamic data from the database and spawn a task to
|
||||
/// periodically update it
|
||||
// XXX: this could be put somewhere else?
|
||||
pub async fn load_policy_factory_dynamic_data_continuously(
|
||||
policy_factory: &Arc<PolicyFactory>,
|
||||
pool: &PgPool,
|
||||
cancellation_token: CancellationToken,
|
||||
task_tracker: &TaskTracker,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let policy_factory = policy_factory.clone();
|
||||
let pool = pool.clone();
|
||||
|
||||
load_policy_factory_dynamic_data(&policy_factory, &pool).await?;
|
||||
|
||||
task_tracker.spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(60));
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
() = cancellation_token.cancelled() => {
|
||||
return;
|
||||
}
|
||||
_ = interval.tick() => {}
|
||||
}
|
||||
|
||||
if let Err(err) = load_policy_factory_dynamic_data(&policy_factory, &pool).await {
|
||||
tracing::error!(
|
||||
error = ?err,
|
||||
"Failed to load policy factory dynamic data"
|
||||
);
|
||||
cancellation_token.cancel();
|
||||
return;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the policy factory dynamic data from the database
|
||||
#[tracing::instrument(name = "policy.load_dynamic_data", skip_all, err(Debug))]
|
||||
pub async fn load_policy_factory_dynamic_data(
|
||||
policy_factory: &PolicyFactory,
|
||||
pool: &PgPool,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let mut repo = PgRepository::from_pool(pool)
|
||||
.await
|
||||
.context("Failed to acquire database connection")?;
|
||||
|
||||
if let Some(data) = repo.policy_data().get().await? {
|
||||
let id = data.id;
|
||||
let updated = policy_factory.set_dynamic_data(data).await?;
|
||||
if updated {
|
||||
tracing::info!(policy_data.id = %id, "Loaded dynamic policy data from the database");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a clonable, type-erased [`HomeserverConnection`] from the
|
||||
/// configuration
|
||||
pub fn homeserver_connection_from_config(
|
||||
|
||||
@@ -10,6 +10,7 @@ use thiserror::Error;
|
||||
|
||||
pub(crate) mod compat;
|
||||
pub mod oauth2;
|
||||
pub(crate) mod policy_data;
|
||||
mod site_config;
|
||||
pub(crate) mod tokens;
|
||||
pub(crate) mod upstream_oauth2;
|
||||
@@ -32,6 +33,7 @@ pub use self::{
|
||||
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, DeviceCodeGrant,
|
||||
DeviceCodeGrantState, InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState,
|
||||
},
|
||||
policy_data::PolicyData,
|
||||
site_config::{CaptchaConfig, CaptchaService, SessionExpirationConfig, SiteConfig},
|
||||
tokens::{
|
||||
AccessToken, AccessTokenState, RefreshToken, RefreshTokenState, TokenFormatError, TokenType,
|
||||
|
||||
15
crates/data-model/src/policy_data.rs
Normal file
15
crates/data-model/src/policy_data.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::Serialize;
|
||||
use ulid::Ulid;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct PolicyData {
|
||||
pub id: Ulid,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
@@ -22,6 +22,7 @@ use indexmap::IndexMap;
|
||||
use mas_axum_utils::FancyError;
|
||||
use mas_http::CorsLayerExt;
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::{
|
||||
ApiDoc, ApiDocCallback, OAuth2AuthorizationEndpoint, OAuth2TokenEndpoint, Route, SimpleRoute,
|
||||
UrlBuilder,
|
||||
@@ -47,6 +48,11 @@ fn finish(t: TransformOpenApi) -> TransformOpenApi {
|
||||
description: Some("Manage compatibility sessions from legacy clients".to_owned()),
|
||||
..Tag::default()
|
||||
})
|
||||
.tag(Tag {
|
||||
name: "policy-data".to_owned(),
|
||||
description: Some("Manage the dynamic policy data".to_owned()),
|
||||
..Tag::default()
|
||||
})
|
||||
.tag(Tag {
|
||||
name: "oauth2-session".to_owned(),
|
||||
description: Some("Manage OAuth2 sessions".to_owned()),
|
||||
@@ -115,6 +121,7 @@ where
|
||||
CallContext: FromRequestParts<S>,
|
||||
Templates: FromRef<S>,
|
||||
UrlBuilder: FromRef<S>,
|
||||
Arc<PolicyFactory>: FromRef<S>,
|
||||
{
|
||||
// We *always* want to explicitly set the possible responses, beacuse the
|
||||
// infered ones are not necessarily correct
|
||||
|
||||
@@ -534,3 +534,50 @@ impl UpstreamOAuthLink {
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/// The policy data
|
||||
#[derive(Serialize, JsonSchema)]
|
||||
pub struct PolicyData {
|
||||
#[serde(skip)]
|
||||
id: Ulid,
|
||||
|
||||
/// The creation date of the policy data
|
||||
created_at: DateTime<Utc>,
|
||||
|
||||
/// The policy data content
|
||||
data: serde_json::Value,
|
||||
}
|
||||
|
||||
impl From<mas_data_model::PolicyData> for PolicyData {
|
||||
fn from(policy_data: mas_data_model::PolicyData) -> Self {
|
||||
Self {
|
||||
id: policy_data.id,
|
||||
created_at: policy_data.created_at,
|
||||
data: policy_data.data,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Resource for PolicyData {
|
||||
const KIND: &'static str = "policy-data";
|
||||
const PATH: &'static str = "/api/admin/v1/policy-data";
|
||||
|
||||
fn id(&self) -> Ulid {
|
||||
self.id
|
||||
}
|
||||
}
|
||||
|
||||
impl PolicyData {
|
||||
/// Samples of policy data
|
||||
pub fn samples() -> [Self; 1] {
|
||||
[Self {
|
||||
id: Ulid::from_bytes([0x01; 16]),
|
||||
created_at: DateTime::default(),
|
||||
data: serde_json::json!({
|
||||
"hello": "world",
|
||||
"foo": 42,
|
||||
"bar": true
|
||||
}),
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ use aide::axum::{
|
||||
};
|
||||
use axum::extract::{FromRef, FromRequestParts};
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_storage::BoxRng;
|
||||
|
||||
use super::call_context::CallContext;
|
||||
@@ -19,6 +20,7 @@ use crate::passwords::PasswordManager;
|
||||
|
||||
mod compat_sessions;
|
||||
mod oauth2_sessions;
|
||||
mod policy_data;
|
||||
mod upstream_oauth_links;
|
||||
mod user_emails;
|
||||
mod user_sessions;
|
||||
@@ -29,6 +31,7 @@ where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
Arc<dyn HomeserverConnection>: FromRef<S>,
|
||||
PasswordManager: FromRef<S>,
|
||||
Arc<PolicyFactory>: FromRef<S>,
|
||||
BoxRng: FromRequestParts<S>,
|
||||
CallContext: FromRequestParts<S>,
|
||||
{
|
||||
@@ -49,6 +52,21 @@ where
|
||||
"/oauth2-sessions/{id}",
|
||||
get_with(self::oauth2_sessions::get, self::oauth2_sessions::get_doc),
|
||||
)
|
||||
.api_route(
|
||||
"/policy-data",
|
||||
post_with(self::policy_data::set, self::policy_data::set_doc),
|
||||
)
|
||||
.api_route(
|
||||
"/policy-data/latest",
|
||||
get_with(
|
||||
self::policy_data::get_latest,
|
||||
self::policy_data::get_latest_doc,
|
||||
),
|
||||
)
|
||||
.api_route(
|
||||
"/policy-data/{id}",
|
||||
get_with(self::policy_data::get, self::policy_data::get_doc),
|
||||
)
|
||||
.api_route(
|
||||
"/users",
|
||||
get_with(self::users::list, self::users::list_doc)
|
||||
|
||||
153
crates/handlers/src/admin/v1/policy_data/get.rs
Normal file
153
crates/handlers/src/admin/v1/policy_data/get.rs
Normal file
@@ -0,0 +1,153 @@
|
||||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
|
||||
use aide::{OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{
|
||||
admin::{
|
||||
call_context::CallContext,
|
||||
model::PolicyData,
|
||||
params::UlidPathParam,
|
||||
response::{ErrorResponse, SingleResponse},
|
||||
},
|
||||
impl_from_error_for_route,
|
||||
};
|
||||
|
||||
#[derive(Debug, thiserror::Error, OperationIo)]
|
||||
#[aide(output_with = "Json<ErrorResponse>")]
|
||||
pub enum RouteError {
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
|
||||
#[error("Policy data with ID {0} not found")]
|
||||
NotFound(Ulid),
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
operation
|
||||
.id("getPolicyData")
|
||||
.summary("Get policy data by ID")
|
||||
.tag("policy-data")
|
||||
.response_with::<200, Json<SingleResponse<PolicyData>>, _>(|t| {
|
||||
let [sample, ..] = PolicyData::samples();
|
||||
let response = SingleResponse::new_canonical(sample);
|
||||
t.description("Policy data was found").example(response)
|
||||
})
|
||||
.response_with::<404, RouteError, _>(|t| {
|
||||
let response = ErrorResponse::from_error(&RouteError::NotFound(Ulid::nil()));
|
||||
t.description("Policy data was not found").example(response)
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.policy_data.get", skip_all, err)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
id: UlidPathParam,
|
||||
) -> Result<Json<SingleResponse<PolicyData>>, RouteError> {
|
||||
let policy_data = repo
|
||||
.policy_data()
|
||||
.get()
|
||||
.await?
|
||||
.ok_or(RouteError::NotFound(*id))?;
|
||||
|
||||
Ok(Json(SingleResponse::new_canonical(policy_data.into())))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use hyper::{Request, StatusCode};
|
||||
use insta::assert_json_snapshot;
|
||||
use sqlx::PgPool;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
|
||||
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||
async fn test_get(pool: PgPool) {
|
||||
setup();
|
||||
let mut state = TestState::from_pool(pool).await.unwrap();
|
||||
let token = state.token_with_scope("urn:mas:admin").await;
|
||||
|
||||
let mut rng = state.rng();
|
||||
let mut repo = state.repository().await.unwrap();
|
||||
|
||||
let policy_data = repo
|
||||
.policy_data()
|
||||
.set(
|
||||
&mut rng,
|
||||
&state.clock,
|
||||
serde_json::json!({"hello": "world"}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
repo.save().await.unwrap();
|
||||
|
||||
let request = Request::get(format!("/api/admin/v1/policy-data/{}", policy_data.id))
|
||||
.bearer(&token)
|
||||
.empty();
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::OK);
|
||||
let body: serde_json::Value = response.json();
|
||||
assert_json_snapshot!(body, @r###"
|
||||
{
|
||||
"data": {
|
||||
"type": "policy-data",
|
||||
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
|
||||
"attributes": {
|
||||
"created_at": "2022-01-16T14:40:00Z",
|
||||
"data": {
|
||||
"hello": "world"
|
||||
}
|
||||
},
|
||||
"links": {
|
||||
"self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E"
|
||||
}
|
||||
},
|
||||
"links": {
|
||||
"self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E"
|
||||
}
|
||||
}
|
||||
"###);
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||
async fn test_get_not_found(pool: PgPool) {
|
||||
setup();
|
||||
let mut state = TestState::from_pool(pool).await.unwrap();
|
||||
let token = state.token_with_scope("urn:mas:admin").await;
|
||||
|
||||
let request = Request::get(format!("/api/admin/v1/policy-data/{}", Ulid::nil()))
|
||||
.bearer(&token)
|
||||
.empty();
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::NOT_FOUND);
|
||||
let body: serde_json::Value = response.json();
|
||||
assert_json_snapshot!(body, @r###"
|
||||
{
|
||||
"errors": [
|
||||
{
|
||||
"title": "Policy data with ID 00000000000000000000000000 not found"
|
||||
}
|
||||
]
|
||||
}
|
||||
"###);
|
||||
}
|
||||
}
|
||||
149
crates/handlers/src/admin/v1/policy_data/get_latest.rs
Normal file
149
crates/handlers/src/admin/v1/policy_data/get_latest.rs
Normal file
@@ -0,0 +1,149 @@
|
||||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
|
||||
use aide::{OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
|
||||
use crate::{
|
||||
admin::{
|
||||
call_context::CallContext,
|
||||
model::PolicyData,
|
||||
response::{ErrorResponse, SingleResponse},
|
||||
},
|
||||
impl_from_error_for_route,
|
||||
};
|
||||
|
||||
#[derive(Debug, thiserror::Error, OperationIo)]
|
||||
#[aide(output_with = "Json<ErrorResponse>")]
|
||||
pub enum RouteError {
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
|
||||
#[error("No policy data found")]
|
||||
NotFound,
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::NotFound => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
operation
|
||||
.id("getLatestPolicyData")
|
||||
.summary("Get the latest policy data")
|
||||
.tag("policy-data")
|
||||
.response_with::<200, Json<SingleResponse<PolicyData>>, _>(|t| {
|
||||
let [sample, ..] = PolicyData::samples();
|
||||
let response = SingleResponse::new_canonical(sample);
|
||||
t.description("Latest policy data was found")
|
||||
.example(response)
|
||||
})
|
||||
.response_with::<404, RouteError, _>(|t| {
|
||||
let response = ErrorResponse::from_error(&RouteError::NotFound);
|
||||
t.description("No policy data was found").example(response)
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.policy_data.get_latest", skip_all, err)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
) -> Result<Json<SingleResponse<PolicyData>>, RouteError> {
|
||||
let policy_data = repo
|
||||
.policy_data()
|
||||
.get()
|
||||
.await?
|
||||
.ok_or(RouteError::NotFound)?;
|
||||
|
||||
Ok(Json(SingleResponse::new_canonical(policy_data.into())))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use hyper::{Request, StatusCode};
|
||||
use insta::assert_json_snapshot;
|
||||
use sqlx::PgPool;
|
||||
|
||||
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
|
||||
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||
async fn test_get_latest(pool: PgPool) {
|
||||
setup();
|
||||
let mut state = TestState::from_pool(pool).await.unwrap();
|
||||
let token = state.token_with_scope("urn:mas:admin").await;
|
||||
|
||||
let mut rng = state.rng();
|
||||
let mut repo = state.repository().await.unwrap();
|
||||
|
||||
repo.policy_data()
|
||||
.set(
|
||||
&mut rng,
|
||||
&state.clock,
|
||||
serde_json::json!({"hello": "world"}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
repo.save().await.unwrap();
|
||||
|
||||
let request = Request::get("/api/admin/v1/policy-data/latest")
|
||||
.bearer(&token)
|
||||
.empty();
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::OK);
|
||||
let body: serde_json::Value = response.json();
|
||||
assert_json_snapshot!(body, @r###"
|
||||
{
|
||||
"data": {
|
||||
"type": "policy-data",
|
||||
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
|
||||
"attributes": {
|
||||
"created_at": "2022-01-16T14:40:00Z",
|
||||
"data": {
|
||||
"hello": "world"
|
||||
}
|
||||
},
|
||||
"links": {
|
||||
"self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E"
|
||||
}
|
||||
},
|
||||
"links": {
|
||||
"self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E"
|
||||
}
|
||||
}
|
||||
"###);
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||
async fn test_get_no_latest(pool: PgPool) {
|
||||
setup();
|
||||
let mut state = TestState::from_pool(pool).await.unwrap();
|
||||
let token = state.token_with_scope("urn:mas:admin").await;
|
||||
|
||||
let request = Request::get("/api/admin/v1/policy-data/latest")
|
||||
.bearer(&token)
|
||||
.empty();
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::NOT_FOUND);
|
||||
let body: serde_json::Value = response.json();
|
||||
assert_json_snapshot!(body, @r###"
|
||||
{
|
||||
"errors": [
|
||||
{
|
||||
"title": "No policy data found"
|
||||
}
|
||||
]
|
||||
}
|
||||
"###);
|
||||
}
|
||||
}
|
||||
14
crates/handlers/src/admin/v1/policy_data/mod.rs
Normal file
14
crates/handlers/src/admin/v1/policy_data/mod.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
mod get;
|
||||
mod get_latest;
|
||||
mod set;
|
||||
|
||||
pub use self::{
|
||||
get::{doc as get_doc, handler as get},
|
||||
get_latest::{doc as get_latest_doc, handler as get_latest},
|
||||
set::{doc as set_doc, handler as set},
|
||||
};
|
||||
152
crates/handlers/src/admin/v1/policy_data/set.rs
Normal file
152
crates/handlers/src/admin/v1/policy_data/set.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use aide::{NoApi, OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, extract::State, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_storage::BoxRng;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{
|
||||
admin::{
|
||||
call_context::CallContext,
|
||||
model::PolicyData,
|
||||
response::{ErrorResponse, SingleResponse},
|
||||
},
|
||||
impl_from_error_for_route,
|
||||
};
|
||||
|
||||
#[derive(Debug, thiserror::Error, OperationIo)]
|
||||
#[aide(output_with = "Json<ErrorResponse>")]
|
||||
pub enum RouteError {
|
||||
#[error("Failed to instanciate policy with the provided data")]
|
||||
InvalidPolicyData(#[from] mas_policy::LoadError),
|
||||
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let status = match self {
|
||||
RouteError::InvalidPolicyData(_) => StatusCode::BAD_REQUEST,
|
||||
RouteError::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
fn data_example() -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"hello": "world",
|
||||
"foo": 42,
|
||||
"bar": true
|
||||
})
|
||||
}
|
||||
|
||||
/// # JSON payload for the `POST /api/admin/v1/policy-data`
|
||||
#[derive(Deserialize, JsonSchema)]
|
||||
#[serde(rename = "SetPolicyDataRequest")]
|
||||
pub struct SetPolicyDataRequest {
|
||||
#[schemars(example = "data_example")]
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
|
||||
pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
operation
|
||||
.id("setPolicyData")
|
||||
.summary("Set the current policy data")
|
||||
.tag("policy-data")
|
||||
.response_with::<201, Json<SingleResponse<PolicyData>>, _>(|t| {
|
||||
let [sample, ..] = PolicyData::samples();
|
||||
let response = SingleResponse::new_canonical(sample);
|
||||
t.description("Policy data was successfully set")
|
||||
.example(response)
|
||||
})
|
||||
.response_with::<400, Json<ErrorResponse>, _>(|t| {
|
||||
let error = ErrorResponse::from_error(&RouteError::InvalidPolicyData(
|
||||
mas_policy::LoadError::invalid_data_example(),
|
||||
));
|
||||
t.description("Invalid policy data").example(error)
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.policy_data.set", skip_all, err)]
|
||||
pub async fn handler(
|
||||
CallContext {
|
||||
mut repo, clock, ..
|
||||
}: CallContext,
|
||||
NoApi(mut rng): NoApi<BoxRng>,
|
||||
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||
Json(request): Json<SetPolicyDataRequest>,
|
||||
) -> Result<(StatusCode, Json<SingleResponse<PolicyData>>), RouteError> {
|
||||
let policy_data = repo
|
||||
.policy_data()
|
||||
.set(&mut rng, &clock, request.data)
|
||||
.await?;
|
||||
|
||||
// Swap the policy data. This will fail if the policy data is invalid
|
||||
policy_factory.set_dynamic_data(policy_data.clone()).await?;
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
Ok((
|
||||
StatusCode::CREATED,
|
||||
Json(SingleResponse::new_canonical(policy_data.into())),
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use hyper::{Request, StatusCode};
|
||||
use insta::assert_json_snapshot;
|
||||
use sqlx::PgPool;
|
||||
|
||||
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
|
||||
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||
async fn test_create(pool: PgPool) {
|
||||
setup();
|
||||
let mut state = TestState::from_pool(pool).await.unwrap();
|
||||
let token = state.token_with_scope("urn:mas:admin").await;
|
||||
|
||||
let request = Request::post("/api/admin/v1/policy-data")
|
||||
.bearer(&token)
|
||||
.json(serde_json::json!({
|
||||
"data": {
|
||||
"hello": "world"
|
||||
}
|
||||
}));
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::CREATED);
|
||||
let body: serde_json::Value = response.json();
|
||||
assert_json_snapshot!(body, @r###"
|
||||
{
|
||||
"data": {
|
||||
"type": "policy-data",
|
||||
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
|
||||
"attributes": {
|
||||
"created_at": "2022-01-16T14:40:00Z",
|
||||
"data": {
|
||||
"hello": "world"
|
||||
}
|
||||
},
|
||||
"links": {
|
||||
"self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E"
|
||||
}
|
||||
},
|
||||
"links": {
|
||||
"self": "/api/admin/v1/policy-data/01FSHN9AG0MZAA6S4AF7CTV32E"
|
||||
}
|
||||
}
|
||||
"###);
|
||||
}
|
||||
}
|
||||
@@ -58,6 +58,7 @@ impl_from_ref!(mas_templates::Templates);
|
||||
impl_from_ref!(Arc<dyn mas_matrix::HomeserverConnection>);
|
||||
impl_from_ref!(mas_keystore::Keystore);
|
||||
impl_from_ref!(mas_handlers::passwords::PasswordManager);
|
||||
impl_from_ref!(Arc<mas_policy::PolicyFactory>);
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (mut api, _) = mas_handlers::admin_api_router::<DummyState>();
|
||||
|
||||
@@ -514,6 +514,12 @@ impl FromRef<TestState> for SiteConfig {
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<TestState> for Arc<PolicyFactory> {
|
||||
fn from_ref(input: &TestState) -> Self {
|
||||
input.policy_factory.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<TestState> for Arc<dyn HomeserverConnection> {
|
||||
fn from_ref(input: &TestState) -> Self {
|
||||
input.homeserver_connection.clone()
|
||||
|
||||
@@ -13,6 +13,7 @@ workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
arc-swap.workspace = true
|
||||
opa-wasm = "0.1.4"
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
||||
@@ -6,11 +6,14 @@
|
||||
|
||||
pub mod model;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use mas_data_model::Ulid;
|
||||
use opa_wasm::{
|
||||
Runtime,
|
||||
wasmtime::{Config, Engine, Module, OptLevel, Store},
|
||||
};
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
|
||||
@@ -33,10 +36,23 @@ pub enum LoadError {
|
||||
#[error("failed to compile WASM module")]
|
||||
Compilation(#[source] anyhow::Error),
|
||||
|
||||
#[error("invalid policy data")]
|
||||
InvalidData(#[source] anyhow::Error),
|
||||
|
||||
#[error("failed to instantiate a test instance")]
|
||||
Instantiate(#[source] InstantiateError),
|
||||
}
|
||||
|
||||
impl LoadError {
|
||||
/// Creates an example of an invalid data error, used for API response
|
||||
/// documentation
|
||||
#[doc(hidden)]
|
||||
#[must_use]
|
||||
pub fn invalid_data_example() -> Self {
|
||||
Self::InvalidData(anyhow::Error::msg("Failed to merge policy data objects"))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum InstantiateError {
|
||||
#[error("failed to create WASM runtime")]
|
||||
@@ -69,11 +85,10 @@ impl Entrypoints {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
#[derive(Debug)]
|
||||
pub struct Data {
|
||||
server_name: String,
|
||||
|
||||
#[serde(flatten)]
|
||||
rest: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
@@ -91,12 +106,93 @@ impl Data {
|
||||
self.rest = Some(rest);
|
||||
self
|
||||
}
|
||||
|
||||
fn to_value(&self) -> Result<serde_json::Value, anyhow::Error> {
|
||||
let base = serde_json::json!({
|
||||
"server_name": self.server_name,
|
||||
});
|
||||
|
||||
if let Some(rest) = &self.rest {
|
||||
merge_data(base, rest.clone())
|
||||
} else {
|
||||
Ok(base)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn value_kind(value: &serde_json::Value) -> &'static str {
|
||||
match value {
|
||||
serde_json::Value::Object(_) => "object",
|
||||
serde_json::Value::Array(_) => "array",
|
||||
serde_json::Value::String(_) => "string",
|
||||
serde_json::Value::Number(_) => "number",
|
||||
serde_json::Value::Bool(_) => "boolean",
|
||||
serde_json::Value::Null => "null",
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_data(
|
||||
mut left: serde_json::Value,
|
||||
right: serde_json::Value,
|
||||
) -> Result<serde_json::Value, anyhow::Error> {
|
||||
merge_data_rec(&mut left, right)?;
|
||||
Ok(left)
|
||||
}
|
||||
|
||||
fn merge_data_rec(
|
||||
left: &mut serde_json::Value,
|
||||
right: serde_json::Value,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
match (left, right) {
|
||||
(serde_json::Value::Object(left), serde_json::Value::Object(right)) => {
|
||||
for (key, value) in right {
|
||||
if let Some(left_value) = left.get_mut(&key) {
|
||||
merge_data_rec(left_value, value)?;
|
||||
} else {
|
||||
left.insert(key, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
(serde_json::Value::Array(left), serde_json::Value::Array(right)) => {
|
||||
left.extend(right);
|
||||
}
|
||||
// Other values override
|
||||
(serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
|
||||
*left = right;
|
||||
}
|
||||
(serde_json::Value::Bool(left), serde_json::Value::Bool(right)) => {
|
||||
*left = right;
|
||||
}
|
||||
(serde_json::Value::String(left), serde_json::Value::String(right)) => {
|
||||
*left = right;
|
||||
}
|
||||
|
||||
// Null gets overridden by anything
|
||||
(left, right) if left.is_null() => *left = right,
|
||||
|
||||
// Null on the right makes the left value null
|
||||
(left, right) if right.is_null() => *left = right,
|
||||
|
||||
(left, right) => anyhow::bail!(
|
||||
"Cannot merge a {} into a {}",
|
||||
value_kind(&right),
|
||||
value_kind(left),
|
||||
),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct DynamicData {
|
||||
version: Option<Ulid>,
|
||||
merged: serde_json::Value,
|
||||
}
|
||||
|
||||
pub struct PolicyFactory {
|
||||
engine: Engine,
|
||||
module: Module,
|
||||
data: Data,
|
||||
dynamic_data: ArcSwap<DynamicData>,
|
||||
entrypoints: Entrypoints,
|
||||
}
|
||||
|
||||
@@ -124,10 +220,17 @@ impl PolicyFactory {
|
||||
.await?
|
||||
.map_err(LoadError::Compilation)?;
|
||||
|
||||
let merged = data.to_value().map_err(LoadError::InvalidData)?;
|
||||
let dynamic_data = ArcSwap::new(Arc::new(DynamicData {
|
||||
version: None,
|
||||
merged,
|
||||
}));
|
||||
|
||||
let factory = Self {
|
||||
engine,
|
||||
module,
|
||||
data,
|
||||
dynamic_data,
|
||||
entrypoints,
|
||||
};
|
||||
|
||||
@@ -140,8 +243,56 @@ impl PolicyFactory {
|
||||
Ok(factory)
|
||||
}
|
||||
|
||||
/// Set the dynamic data for the policy.
|
||||
///
|
||||
/// The `dynamic_data` object is merged with the static data given when the
|
||||
/// policy was loaded.
|
||||
///
|
||||
/// Returns `true` if the data was updated, `false` if the version
|
||||
/// of the dynamic data was the same as the one we already have.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the data can't be merged with the static data, or if
|
||||
/// the policy can't be instantiated with the new data.
|
||||
pub async fn set_dynamic_data(
|
||||
&self,
|
||||
dynamic_data: mas_data_model::PolicyData,
|
||||
) -> Result<bool, LoadError> {
|
||||
// Check if the version of the dynamic data we have is the same as the one we're
|
||||
// trying to set
|
||||
if self.dynamic_data.load().version == Some(dynamic_data.id) {
|
||||
// Don't do anything if the version is the same
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let static_data = self.data.to_value().map_err(LoadError::InvalidData)?;
|
||||
let merged = merge_data(static_data, dynamic_data.data).map_err(LoadError::InvalidData)?;
|
||||
|
||||
// Try to instantiate with the new data
|
||||
self.instantiate_with_data(&merged)
|
||||
.await
|
||||
.map_err(LoadError::Instantiate)?;
|
||||
|
||||
// If instantiation succeeds, swap the data
|
||||
self.dynamic_data.store(Arc::new(DynamicData {
|
||||
version: Some(dynamic_data.id),
|
||||
merged,
|
||||
}));
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "policy.instantiate", skip_all, err)]
|
||||
pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
|
||||
let data = self.dynamic_data.load();
|
||||
self.instantiate_with_data(&data.merged).await
|
||||
}
|
||||
|
||||
async fn instantiate_with_data(
|
||||
&self,
|
||||
data: &serde_json::Value,
|
||||
) -> Result<Policy, InstantiateError> {
|
||||
let mut store = Store::new(&self.engine, ());
|
||||
let runtime = Runtime::new(&mut store, &self.module)
|
||||
.await
|
||||
@@ -159,7 +310,7 @@ impl PolicyFactory {
|
||||
}
|
||||
|
||||
let instance = runtime
|
||||
.with_data(&mut store, &self.data)
|
||||
.with_data(&mut store, data)
|
||||
.await
|
||||
.map_err(InstantiateError::LoadData)?;
|
||||
|
||||
@@ -273,6 +424,8 @@ impl Policy {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use std::time::SystemTime;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
@@ -344,4 +497,167 @@ mod tests {
|
||||
.unwrap();
|
||||
assert!(!res.valid());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dynamic_data() {
|
||||
let data = Data::new("example.com".to_owned());
|
||||
|
||||
#[allow(clippy::disallowed_types)]
|
||||
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("..")
|
||||
.join("..")
|
||||
.join("policies")
|
||||
.join("policy.wasm");
|
||||
|
||||
let file = tokio::fs::File::open(path).await.unwrap();
|
||||
|
||||
let entrypoints = Entrypoints {
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
};
|
||||
|
||||
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
|
||||
|
||||
let mut policy = factory.instantiate().await.unwrap();
|
||||
|
||||
let res = policy
|
||||
.evaluate_register(RegisterInput {
|
||||
registration_method: RegistrationMethod::Password,
|
||||
username: "hello",
|
||||
email: Some("hello@example.com"),
|
||||
requester: Requester {
|
||||
ip_address: None,
|
||||
user_agent: None,
|
||||
},
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(res.valid());
|
||||
|
||||
// Update the policy data
|
||||
factory
|
||||
.set_dynamic_data(mas_data_model::PolicyData {
|
||||
id: Ulid::nil(),
|
||||
created_at: SystemTime::now().into(),
|
||||
data: serde_json::json!({
|
||||
"emails": {
|
||||
"banned_addresses": {
|
||||
"substrings": ["hello"]
|
||||
}
|
||||
}
|
||||
}),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let mut policy = factory.instantiate().await.unwrap();
|
||||
let res = policy
|
||||
.evaluate_register(RegisterInput {
|
||||
registration_method: RegistrationMethod::Password,
|
||||
username: "hello",
|
||||
email: Some("hello@example.com"),
|
||||
requester: Requester {
|
||||
ip_address: None,
|
||||
user_agent: None,
|
||||
},
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!res.valid());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_big_dynamic_data() {
|
||||
let data = Data::new("example.com".to_owned());
|
||||
|
||||
#[allow(clippy::disallowed_types)]
|
||||
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("..")
|
||||
.join("..")
|
||||
.join("policies")
|
||||
.join("policy.wasm");
|
||||
|
||||
let file = tokio::fs::File::open(path).await.unwrap();
|
||||
|
||||
let entrypoints = Entrypoints {
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
};
|
||||
|
||||
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
|
||||
|
||||
// That is around 1 MB of JSON data. Each element is a 5-digit string, so 8
|
||||
// characters including the quotes and a comma.
|
||||
let data: Vec<String> = (0..(1024 * 1024 / 8))
|
||||
.map(|i| format!("{:05}", i % 100_000))
|
||||
.collect();
|
||||
let json = serde_json::json!({ "emails": { "banned_addresses": { "substrings": data } } });
|
||||
factory
|
||||
.set_dynamic_data(mas_data_model::PolicyData {
|
||||
id: Ulid::nil(),
|
||||
created_at: SystemTime::now().into(),
|
||||
data: json,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Try instantiating the policy, make sure 5-digit numbers are banned from email
|
||||
// addresses
|
||||
let mut policy = factory.instantiate().await.unwrap();
|
||||
let res = policy
|
||||
.evaluate_register(RegisterInput {
|
||||
registration_method: RegistrationMethod::Password,
|
||||
username: "hello",
|
||||
email: Some("12345@example.com"),
|
||||
requester: Requester {
|
||||
ip_address: None,
|
||||
user_agent: None,
|
||||
},
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!res.valid());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge() {
|
||||
use serde_json::json as j;
|
||||
|
||||
// Merging objects
|
||||
let res = merge_data(j!({"hello": "world"}), j!({"foo": "bar"})).unwrap();
|
||||
assert_eq!(res, j!({"hello": "world", "foo": "bar"}));
|
||||
|
||||
// Override a value of the same type
|
||||
let res = merge_data(j!({"hello": "world"}), j!({"hello": "john"})).unwrap();
|
||||
assert_eq!(res, j!({"hello": "john"}));
|
||||
|
||||
let res = merge_data(j!({"hello": true}), j!({"hello": false})).unwrap();
|
||||
assert_eq!(res, j!({"hello": false}));
|
||||
|
||||
let res = merge_data(j!({"hello": 0}), j!({"hello": 42})).unwrap();
|
||||
assert_eq!(res, j!({"hello": 42}));
|
||||
|
||||
// Override a value of a different type
|
||||
merge_data(j!({"hello": "world"}), j!({"hello": 123}))
|
||||
.expect_err("Can't merge different types");
|
||||
|
||||
// Merge arrays
|
||||
let res = merge_data(j!({"hello": ["world"]}), j!({"hello": ["john"]})).unwrap();
|
||||
assert_eq!(res, j!({"hello": ["world", "john"]}));
|
||||
|
||||
// Null overrides a value
|
||||
let res = merge_data(j!({"hello": "world"}), j!({"hello": null})).unwrap();
|
||||
assert_eq!(res, j!({"hello": null}));
|
||||
|
||||
// Null gets overridden by a value
|
||||
let res = merge_data(j!({"hello": null}), j!({"hello": "world"})).unwrap();
|
||||
assert_eq!(res, j!({"hello": "world"}));
|
||||
|
||||
// Objects get deeply merged
|
||||
let res = merge_data(j!({"a": {"b": {"c": "d"}}}), j!({"a": {"b": {"e": "f"}}})).unwrap();
|
||||
assert_eq!(res, j!({"a": {"b": {"c": "d", "e": "f"}}}));
|
||||
}
|
||||
}
|
||||
|
||||
14
crates/storage-pg/.sqlx/query-5006c3e60c98c91a0b0fbb3205373e81d9b75e90929af80961f8b5910873a43e.json
generated
Normal file
14
crates/storage-pg/.sqlx/query-5006c3e60c98c91a0b0fbb3205373e81d9b75e90929af80961f8b5910873a43e.json
generated
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n DELETE FROM policy_data\n WHERE policy_data_id IN (\n SELECT policy_data_id\n FROM policy_data\n ORDER BY policy_data_id DESC\n OFFSET $1\n )\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "5006c3e60c98c91a0b0fbb3205373e81d9b75e90929af80961f8b5910873a43e"
|
||||
}
|
||||
32
crates/storage-pg/.sqlx/query-9fe87eeaf4b7d0ba09b59ddad3476eb57ccb6e4053ab8f4450dd4a9d1f6ba108.json
generated
Normal file
32
crates/storage-pg/.sqlx/query-9fe87eeaf4b7d0ba09b59ddad3476eb57ccb6e4053ab8f4450dd4a9d1f6ba108.json
generated
Normal file
@@ -0,0 +1,32 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT policy_data_id, created_at, data\n FROM policy_data\n ORDER BY policy_data_id DESC\n LIMIT 1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "policy_data_id",
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "data",
|
||||
"type_info": "Jsonb"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": []
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "9fe87eeaf4b7d0ba09b59ddad3476eb57ccb6e4053ab8f4450dd4a9d1f6ba108"
|
||||
}
|
||||
16
crates/storage-pg/.sqlx/query-b6c4f4a23968cba2a82c2b7cfffc05a7ed582c9e5c1f65d27b0686f843ccfe42.json
generated
Normal file
16
crates/storage-pg/.sqlx/query-b6c4f4a23968cba2a82c2b7cfffc05a7ed582c9e5c1f65d27b0686f843ccfe42.json
generated
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO policy_data (policy_data_id, created_at, data)\n VALUES ($1, $2, $3)\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Timestamptz",
|
||||
"Jsonb"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "b6c4f4a23968cba2a82c2b7cfffc05a7ed582c9e5c1f65d27b0686f843ccfe42"
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
-- Copyright 2025 New Vector Ltd.
|
||||
--
|
||||
-- SPDX-License-Identifier: AGPL-3.0-only
|
||||
-- Please see LICENSE in the repository root for full details.
|
||||
|
||||
-- Add a table which stores the latest policy data
|
||||
--
|
||||
-- Every time the policy data is updated, it creates a new row, so that we keep
|
||||
-- an history of the policy data, trace back which version of the data was used
|
||||
-- on each evaluation.
|
||||
CREATE TABLE IF NOT EXISTS policy_data (
|
||||
policy_data_id UUID PRIMARY KEY,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
data JSONB NOT NULL
|
||||
);
|
||||
@@ -173,6 +173,7 @@ mod errors;
|
||||
pub(crate) mod filter;
|
||||
pub(crate) mod iden;
|
||||
pub(crate) mod pagination;
|
||||
pub(crate) mod policy_data;
|
||||
pub(crate) mod repository;
|
||||
pub(crate) mod tracing;
|
||||
|
||||
|
||||
204
crates/storage-pg/src/policy_data.rs
Normal file
204
crates/storage-pg/src/policy_data.rs
Normal file
@@ -0,0 +1,204 @@
|
||||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
//! A module containing the PostgreSQL implementation of the policy data
|
||||
//! storage.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mas_data_model::PolicyData;
|
||||
use mas_storage::{Clock, policy_data::PolicyDataRepository};
|
||||
use rand::RngCore;
|
||||
use serde_json::Value;
|
||||
use sqlx::{PgConnection, types::Json};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{DatabaseError, ExecuteExt};
|
||||
|
||||
/// An implementation of [`PolicyDataRepository`] for a PostgreSQL connection.
|
||||
pub struct PgPolicyDataRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgPolicyDataRepository<'c> {
|
||||
/// Create a new [`PgPolicyDataRepository`] from an active PostgreSQL
|
||||
/// connection.
|
||||
#[must_use]
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
struct PolicyDataLookup {
|
||||
policy_data_id: Uuid,
|
||||
created_at: chrono::DateTime<chrono::Utc>,
|
||||
data: Json<Value>,
|
||||
}
|
||||
|
||||
impl From<PolicyDataLookup> for PolicyData {
|
||||
fn from(value: PolicyDataLookup) -> Self {
|
||||
PolicyData {
|
||||
id: value.policy_data_id.into(),
|
||||
created_at: value.created_at,
|
||||
data: value.data.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl PolicyDataRepository for PgPolicyDataRepository<'_> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.policy_data.get",
|
||||
skip_all,
|
||||
fields(
|
||||
db.query.text,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn get(&mut self) -> Result<Option<PolicyData>, Self::Error> {
|
||||
let row = sqlx::query_as!(
|
||||
PolicyDataLookup,
|
||||
r#"
|
||||
SELECT policy_data_id, created_at, data
|
||||
FROM policy_data
|
||||
ORDER BY policy_data_id DESC
|
||||
LIMIT 1
|
||||
"#
|
||||
)
|
||||
.traced()
|
||||
.fetch_optional(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let Some(row) = row else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
Ok(Some(row.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.policy_data.set",
|
||||
skip_all,
|
||||
fields(
|
||||
db.query.text,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn set(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
data: Value,
|
||||
) -> Result<PolicyData, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO policy_data (policy_data_id, created_at, data)
|
||||
VALUES ($1, $2, $3)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
created_at,
|
||||
data,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(PolicyData {
|
||||
id,
|
||||
created_at,
|
||||
data,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.policy_data.prune",
|
||||
skip_all,
|
||||
fields(
|
||||
db.query.text,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn prune(&mut self, keep: usize) -> Result<usize, Self::Error> {
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
DELETE FROM policy_data
|
||||
WHERE policy_data_id IN (
|
||||
SELECT policy_data_id
|
||||
FROM policy_data
|
||||
ORDER BY policy_data_id DESC
|
||||
OFFSET $1
|
||||
)
|
||||
"#,
|
||||
i64::try_from(keep).map_err(DatabaseError::to_invalid_operation)?
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(res
|
||||
.rows_affected()
|
||||
.try_into()
|
||||
.map_err(DatabaseError::to_invalid_operation)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use mas_storage::{clock::MockClock, policy_data::PolicyDataRepository};
|
||||
use rand::SeedableRng;
|
||||
use rand_chacha::ChaChaRng;
|
||||
use serde_json::json;
|
||||
use sqlx::PgPool;
|
||||
|
||||
use crate::policy_data::PgPolicyDataRepository;
|
||||
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_policy_data(pool: PgPool) {
|
||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||
let clock = MockClock::default();
|
||||
let mut conn = pool.acquire().await.unwrap();
|
||||
let mut repo = PgPolicyDataRepository::new(&mut conn);
|
||||
|
||||
// Get an empty state at first
|
||||
let data = repo.get().await.unwrap();
|
||||
assert_eq!(data, None);
|
||||
|
||||
// Set some data
|
||||
let value1 = json!({"hello": "world"});
|
||||
let policy_data1 = repo.set(&mut rng, &clock, value1.clone()).await.unwrap();
|
||||
assert_eq!(policy_data1.data, value1);
|
||||
|
||||
let data_fetched1 = repo.get().await.unwrap().unwrap();
|
||||
assert_eq!(policy_data1, data_fetched1);
|
||||
|
||||
// Set some new data
|
||||
clock.advance(chrono::Duration::seconds(1));
|
||||
let value2 = json!({"foo": "bar"});
|
||||
let policy_data2 = repo.set(&mut rng, &clock, value2.clone()).await.unwrap();
|
||||
assert_eq!(policy_data2.data, value2);
|
||||
|
||||
// Check the new data is fetched
|
||||
let data_fetched2 = repo.get().await.unwrap().unwrap();
|
||||
assert_eq!(data_fetched2, policy_data2);
|
||||
|
||||
// Prune until the first entry
|
||||
let affected = repo.prune(1).await.unwrap();
|
||||
let data_fetched3 = repo.get().await.unwrap().unwrap();
|
||||
assert_eq!(data_fetched3, policy_data2);
|
||||
assert_eq!(affected, 1);
|
||||
|
||||
// Do a raw query to check the other rows were pruned
|
||||
let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM policy_data")
|
||||
.fetch_one(&mut *conn)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(count, 1);
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@ use mas_storage::{
|
||||
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
|
||||
OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
|
||||
},
|
||||
policy_data::PolicyDataRepository,
|
||||
queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository},
|
||||
upstream_oauth2::{
|
||||
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
|
||||
@@ -40,6 +41,7 @@ use crate::{
|
||||
PgOAuth2ClientRepository, PgOAuth2DeviceCodeGrantRepository,
|
||||
PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
|
||||
},
|
||||
policy_data::PgPolicyDataRepository,
|
||||
queue::{
|
||||
job::PgQueueJobRepository, schedule::PgQueueScheduleRepository,
|
||||
worker::PgQueueWorkerRepository,
|
||||
@@ -283,4 +285,8 @@ where
|
||||
) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> {
|
||||
Box::new(PgQueueScheduleRepository::new(self.conn.as_mut()))
|
||||
}
|
||||
|
||||
fn policy_data<'c>(&'c mut self) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c> {
|
||||
Box::new(PgPolicyDataRepository::new(self.conn.as_mut()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,6 +119,7 @@ mod utils;
|
||||
pub mod app_session;
|
||||
pub mod compat;
|
||||
pub mod oauth2;
|
||||
pub mod policy_data;
|
||||
pub mod queue;
|
||||
pub mod upstream_oauth2;
|
||||
pub mod user;
|
||||
|
||||
76
crates/storage/src/policy_data.rs
Normal file
76
crates/storage/src/policy_data.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
//! Repositories to interact with the policy data saved in the storage backend.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mas_data_model::PolicyData;
|
||||
use rand_core::RngCore;
|
||||
|
||||
use crate::{Clock, repository_impl};
|
||||
|
||||
/// A [`PolicyDataRepository`] helps interacting with the policy data saved in
|
||||
/// the storage backend.
|
||||
#[async_trait]
|
||||
pub trait PolicyDataRepository: Send + Sync {
|
||||
/// The error type returned by the repository
|
||||
type Error;
|
||||
|
||||
/// Get the latest policy data
|
||||
///
|
||||
/// Returns the latest policy data, or `None` if no policy data is
|
||||
/// available.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`Self::Error`] if the underlying repository fails
|
||||
async fn get(&mut self) -> Result<Option<PolicyData>, Self::Error>;
|
||||
|
||||
/// Set the latest policy data
|
||||
///
|
||||
/// Returns the newly created policy data.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `rng`: The random number generator to use
|
||||
/// * `clock`: The clock used to generate the timestamps
|
||||
/// * `data`: The policy data to set
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`Self::Error`] if the underlying repository fails
|
||||
async fn set(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
data: serde_json::Value,
|
||||
) -> Result<PolicyData, Self::Error>;
|
||||
|
||||
/// Prune old policy data
|
||||
///
|
||||
/// Returns the number of entries pruned.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `keep`: the number of old entries to keep
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`Self::Error`] if the underlying repository fails
|
||||
async fn prune(&mut self, keep: usize) -> Result<usize, Self::Error>;
|
||||
}
|
||||
|
||||
repository_impl!(PolicyDataRepository:
|
||||
async fn get(&mut self) -> Result<Option<PolicyData>, Self::Error>;
|
||||
|
||||
async fn set(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
data: serde_json::Value,
|
||||
) -> Result<PolicyData, Self::Error>;
|
||||
|
||||
async fn prune(&mut self, keep: usize) -> Result<usize, Self::Error>;
|
||||
);
|
||||
@@ -506,3 +506,11 @@ impl ExpireInactiveUserSessionsJob {
|
||||
impl InsertableJob for ExpireInactiveUserSessionsJob {
|
||||
const QUEUE_NAME: &'static str = "expire-inactive-user-sessions";
|
||||
}
|
||||
|
||||
/// Prune stale policy data
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PruneStalePolicyDataJob;
|
||||
|
||||
impl InsertableJob for PruneStalePolicyDataJob {
|
||||
const QUEUE_NAME: &'static str = "prune-stale-policy-data";
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ use crate::{
|
||||
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
|
||||
OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
|
||||
},
|
||||
policy_data::PolicyDataRepository,
|
||||
queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository},
|
||||
upstream_oauth2::{
|
||||
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
|
||||
@@ -204,6 +205,9 @@ pub trait RepositoryAccess: Send {
|
||||
fn queue_schedule<'c>(
|
||||
&'c mut self,
|
||||
) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c>;
|
||||
|
||||
/// Get a [`PolicyDataRepository`]
|
||||
fn policy_data<'c>(&'c mut self) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c>;
|
||||
}
|
||||
|
||||
/// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and
|
||||
@@ -224,6 +228,7 @@ mod impls {
|
||||
OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository,
|
||||
OAuth2SessionRepository,
|
||||
},
|
||||
policy_data::PolicyDataRepository,
|
||||
queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository},
|
||||
upstream_oauth2::{
|
||||
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
|
||||
@@ -439,6 +444,12 @@ mod impls {
|
||||
) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> {
|
||||
Box::new(MapErr::new(self.inner.queue_schedule(), &mut self.mapper))
|
||||
}
|
||||
|
||||
fn policy_data<'c>(
|
||||
&'c mut self,
|
||||
) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c> {
|
||||
Box::new(MapErr::new(self.inner.policy_data(), &mut self.mapper))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: RepositoryAccess + ?Sized> RepositoryAccess for Box<R> {
|
||||
@@ -579,5 +590,11 @@ mod impls {
|
||||
) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> {
|
||||
(**self).queue_schedule()
|
||||
}
|
||||
|
||||
fn policy_data<'c>(
|
||||
&'c mut self,
|
||||
) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c> {
|
||||
(**self).policy_data()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
//! Database-related tasks
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mas_storage::queue::CleanupExpiredTokensJob;
|
||||
use mas_storage::queue::{CleanupExpiredTokensJob, PruneStalePolicyDataJob};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::{
|
||||
@@ -38,3 +38,28 @@ impl RunnableJob for CleanupExpiredTokensJob {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RunnableJob for PruneStalePolicyDataJob {
|
||||
#[tracing::instrument(name = "job.prune_stale_policy_data", skip_all, err)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
|
||||
let mut repo = state.repository().await.map_err(JobError::retry)?;
|
||||
|
||||
// Keep the last 10 policy data
|
||||
let count = repo
|
||||
.policy_data()
|
||||
.prune(10)
|
||||
.await
|
||||
.map_err(JobError::retry)?;
|
||||
|
||||
repo.save().await.map_err(JobError::retry)?;
|
||||
|
||||
if count == 0 {
|
||||
debug!("no stale policy data to prune");
|
||||
} else {
|
||||
info!(count, "pruned stale policy data");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,6 +143,7 @@ pub async fn init(
|
||||
.register_handler::<mas_storage::queue::ExpireInactiveCompatSessionsJob>()
|
||||
.register_handler::<mas_storage::queue::ExpireInactiveOAuthSessionsJob>()
|
||||
.register_handler::<mas_storage::queue::ExpireInactiveUserSessionsJob>()
|
||||
.register_handler::<mas_storage::queue::PruneStalePolicyDataJob>()
|
||||
.add_schedule(
|
||||
"cleanup-expired-tokens",
|
||||
"0 0 * * * *".parse()?,
|
||||
@@ -153,6 +154,12 @@ pub async fn init(
|
||||
// Run this job every 15 minutes
|
||||
"30 */15 * * * *".parse()?,
|
||||
mas_storage::queue::ExpireInactiveSessionsJob,
|
||||
)
|
||||
.add_schedule(
|
||||
"prune-stale-policy-data",
|
||||
// Run once a day
|
||||
"0 0 2 * * *".parse()?,
|
||||
mas_storage::queue::PruneStalePolicyDataJob,
|
||||
);
|
||||
|
||||
task_tracker.spawn(worker.run());
|
||||
|
||||
@@ -593,6 +593,208 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/admin/v1/policy-data": {
|
||||
"post": {
|
||||
"tags": [
|
||||
"policy-data"
|
||||
],
|
||||
"summary": "Set the current policy data",
|
||||
"operationId": "setPolicyData",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/SetPolicyDataRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Policy data was successfully set",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/SingleResponse_for_PolicyData"
|
||||
},
|
||||
"example": {
|
||||
"data": {
|
||||
"type": "policy-data",
|
||||
"id": "01040G2081040G2081040G2081",
|
||||
"attributes": {
|
||||
"created_at": "1970-01-01T00:00:00Z",
|
||||
"data": {
|
||||
"hello": "world",
|
||||
"foo": 42,
|
||||
"bar": true
|
||||
}
|
||||
},
|
||||
"links": {
|
||||
"self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081"
|
||||
}
|
||||
},
|
||||
"links": {
|
||||
"self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"description": "Invalid policy data",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ErrorResponse"
|
||||
},
|
||||
"example": {
|
||||
"errors": [
|
||||
{
|
||||
"title": "Failed to instanciate policy with the provided data"
|
||||
},
|
||||
{
|
||||
"title": "invalid policy data"
|
||||
},
|
||||
{
|
||||
"title": "Failed to merge policy data objects"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/admin/v1/policy-data/latest": {
|
||||
"get": {
|
||||
"tags": [
|
||||
"policy-data"
|
||||
],
|
||||
"summary": "Get the latest policy data",
|
||||
"operationId": "getLatestPolicyData",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Latest policy data was found",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/SingleResponse_for_PolicyData"
|
||||
},
|
||||
"example": {
|
||||
"data": {
|
||||
"type": "policy-data",
|
||||
"id": "01040G2081040G2081040G2081",
|
||||
"attributes": {
|
||||
"created_at": "1970-01-01T00:00:00Z",
|
||||
"data": {
|
||||
"hello": "world",
|
||||
"foo": 42,
|
||||
"bar": true
|
||||
}
|
||||
},
|
||||
"links": {
|
||||
"self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081"
|
||||
}
|
||||
},
|
||||
"links": {
|
||||
"self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"description": "No policy data was found",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ErrorResponse"
|
||||
},
|
||||
"example": {
|
||||
"errors": [
|
||||
{
|
||||
"title": "No policy data found"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/admin/v1/policy-data/{id}": {
|
||||
"get": {
|
||||
"tags": [
|
||||
"policy-data"
|
||||
],
|
||||
"summary": "Get policy data by ID",
|
||||
"operationId": "getPolicyData",
|
||||
"parameters": [
|
||||
{
|
||||
"in": "path",
|
||||
"name": "id",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"title": "The ID of the resource",
|
||||
"$ref": "#/components/schemas/ULID"
|
||||
},
|
||||
"style": "simple"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Policy data was found",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/SingleResponse_for_PolicyData"
|
||||
},
|
||||
"example": {
|
||||
"data": {
|
||||
"type": "policy-data",
|
||||
"id": "01040G2081040G2081040G2081",
|
||||
"attributes": {
|
||||
"created_at": "1970-01-01T00:00:00Z",
|
||||
"data": {
|
||||
"hello": "world",
|
||||
"foo": 42,
|
||||
"bar": true
|
||||
}
|
||||
},
|
||||
"links": {
|
||||
"self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081"
|
||||
}
|
||||
},
|
||||
"links": {
|
||||
"self": "/api/admin/v1/policy-data/01040G2081040G2081040G2081"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"description": "Policy data was not found",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ErrorResponse"
|
||||
},
|
||||
"example": {
|
||||
"errors": [
|
||||
{
|
||||
"title": "Policy data with ID 00000000000000000000000000 not found"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/admin/v1/users": {
|
||||
"get": {
|
||||
"tags": [
|
||||
@@ -2666,6 +2868,86 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"SetPolicyDataRequest": {
|
||||
"title": "JSON payload for the `POST /api/admin/v1/policy-data`",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"data"
|
||||
],
|
||||
"properties": {
|
||||
"data": {
|
||||
"examples": [
|
||||
{
|
||||
"hello": "world",
|
||||
"foo": 42,
|
||||
"bar": true
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"SingleResponse_for_PolicyData": {
|
||||
"description": "A top-level response with a single resource",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"data",
|
||||
"links"
|
||||
],
|
||||
"properties": {
|
||||
"data": {
|
||||
"$ref": "#/components/schemas/SingleResource_for_PolicyData"
|
||||
},
|
||||
"links": {
|
||||
"$ref": "#/components/schemas/SelfLinks"
|
||||
}
|
||||
}
|
||||
},
|
||||
"SingleResource_for_PolicyData": {
|
||||
"description": "A single resource, with its type, ID, attributes and related links",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"attributes",
|
||||
"id",
|
||||
"links",
|
||||
"type"
|
||||
],
|
||||
"properties": {
|
||||
"type": {
|
||||
"description": "The type of the resource",
|
||||
"type": "string"
|
||||
},
|
||||
"id": {
|
||||
"description": "The ID of the resource",
|
||||
"$ref": "#/components/schemas/ULID"
|
||||
},
|
||||
"attributes": {
|
||||
"description": "The attributes of the resource",
|
||||
"$ref": "#/components/schemas/PolicyData"
|
||||
},
|
||||
"links": {
|
||||
"description": "Related links",
|
||||
"$ref": "#/components/schemas/SelfLinks"
|
||||
}
|
||||
}
|
||||
},
|
||||
"PolicyData": {
|
||||
"description": "The policy data",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"created_at",
|
||||
"data"
|
||||
],
|
||||
"properties": {
|
||||
"created_at": {
|
||||
"description": "The creation date of the policy data",
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"data": {
|
||||
"description": "The policy data content"
|
||||
}
|
||||
}
|
||||
},
|
||||
"UserFilter": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -3252,6 +3534,10 @@
|
||||
"name": "compat-session",
|
||||
"description": "Manage compatibility sessions from legacy clients"
|
||||
},
|
||||
{
|
||||
"name": "policy-data",
|
||||
"description": "Manage the dynamic policy data"
|
||||
},
|
||||
{
|
||||
"name": "oauth2-session",
|
||||
"description": "Manage OAuth2 sessions"
|
||||
|
||||
Reference in New Issue
Block a user