Make the admin API update the local policy data

This commit is contained in:
Quentin Gliech
2025-02-25 16:22:42 +01:00
parent fe789884ab
commit 518a366ee2
7 changed files with 66 additions and 3 deletions

View File

@@ -204,6 +204,12 @@ impl FromRef<AppState> for Limiter {
}
}
impl FromRef<AppState> for Arc<PolicyFactory> {
fn from_ref(input: &AppState) -> Self {
input.policy_factory.clone()
}
}
impl FromRef<AppState> for BoxHomeserverConnection {
fn from_ref(input: &AppState) -> Self {
Box::new(input.homeserver_connection.clone())

View File

@@ -4,6 +4,8 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::sync::Arc;
use aide::{
axum::ApiRouter,
openapi::{OAuth2Flow, OAuth2Flows, OpenApi, SecurityScheme, Server, Tag},
@@ -20,6 +22,7 @@ use indexmap::IndexMap;
use mas_axum_utils::FancyError;
use mas_http::CorsLayerExt;
use mas_matrix::BoxHomeserverConnection;
use mas_policy::PolicyFactory;
use mas_router::{
ApiDoc, ApiDocCallback, OAuth2AuthorizationEndpoint, OAuth2TokenEndpoint, Route, SimpleRoute,
UrlBuilder,
@@ -118,6 +121,7 @@ where
CallContext: FromRequestParts<S>,
Templates: FromRef<S>,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
{
// We *always* want to explicitly set the possible responses, beacuse the
// infered ones are not necessarily correct

View File

@@ -4,12 +4,15 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::sync::Arc;
use aide::axum::{
ApiRouter,
routing::{get_with, post_with},
};
use axum::extract::{FromRef, FromRequestParts};
use mas_matrix::BoxHomeserverConnection;
use mas_policy::PolicyFactory;
use mas_storage::BoxRng;
use super::call_context::CallContext;
@@ -28,6 +31,7 @@ where
S: Clone + Send + Sync + 'static,
BoxHomeserverConnection: FromRef<S>,
PasswordManager: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
BoxRng: FromRequestParts<S>,
CallContext: FromRequestParts<S>,
{

View File

@@ -2,9 +2,12 @@
//
// SPDX-License-Identifier: AGPL-3.0-only
use std::sync::Arc;
use aide::{NoApi, OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use axum::{Json, extract::State, response::IntoResponse};
use hyper::StatusCode;
use mas_policy::PolicyFactory;
use mas_storage::BoxRng;
use schemars::JsonSchema;
use serde::Deserialize;
@@ -21,6 +24,9 @@ use crate::{
#[derive(Debug, thiserror::Error, OperationIo)]
#[aide(output_with = "Json<ErrorResponse>")]
pub enum RouteError {
#[error("Failed to instanciate policy with the provided data")]
InvalidPolicyData(#[from] mas_policy::LoadError),
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
}
@@ -30,7 +36,10 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let status = StatusCode::INTERNAL_SERVER_ERROR;
let status = match self {
RouteError::InvalidPolicyData(_) => StatusCode::BAD_REQUEST,
RouteError::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
};
(status, Json(error)).into_response()
}
}
@@ -62,6 +71,12 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
t.description("Policy data was successfully set")
.example(response)
})
.response_with::<400, Json<ErrorResponse>, _>(|t| {
let error = ErrorResponse::from_error(&RouteError::InvalidPolicyData(
mas_policy::LoadError::invalid_data_example(),
));
t.description("Invalid policy data").example(error)
})
}
#[tracing::instrument(name = "handler.admin.v1.policy_data.set", skip_all, err)]
@@ -70,6 +85,7 @@ pub async fn handler(
mut repo, clock, ..
}: CallContext,
NoApi(mut rng): NoApi<BoxRng>,
State(policy_factory): State<Arc<PolicyFactory>>,
Json(request): Json<SetPolicyDataRequest>,
) -> Result<(StatusCode, Json<SingleResponse<PolicyData>>), RouteError> {
let policy_data = repo
@@ -77,6 +93,9 @@ pub async fn handler(
.set(&mut rng, &clock, request.data)
.await?;
// Swap the policy data. This will fail if the policy data is invalid
policy_factory.set_dynamic_data(policy_data.clone()).await?;
repo.save().await?;
Ok((

View File

@@ -13,7 +13,7 @@
)]
#![warn(clippy::pedantic)]
use std::io::Write;
use std::{io::Write, sync::Arc};
use aide::openapi::{Server, ServerVariable};
use indexmap::IndexMap;
@@ -58,6 +58,7 @@ impl_from_ref!(mas_templates::Templates);
impl_from_ref!(mas_matrix::BoxHomeserverConnection);
impl_from_ref!(mas_keystore::Keystore);
impl_from_ref!(mas_handlers::passwords::PasswordManager);
impl_from_ref!(Arc<mas_policy::PolicyFactory>);
fn main() -> Result<(), Box<dyn std::error::Error>> {
let (mut api, _) = mas_handlers::admin_api_router::<DummyState>();

View File

@@ -513,6 +513,12 @@ impl FromRef<TestState> for SiteConfig {
}
}
impl FromRef<TestState> for Arc<PolicyFactory> {
fn from_ref(input: &TestState) -> Self {
input.policy_factory.clone()
}
}
impl FromRef<TestState> for BoxHomeserverConnection {
fn from_ref(input: &TestState) -> Self {
Box::new(input.homeserver_connection.clone())

View File

@@ -640,6 +640,29 @@
}
}
}
},
"400": {
"description": "Invalid policy data",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"errors": [
{
"title": "Failed to instanciate policy with the provided data"
},
{
"title": "invalid policy data"
},
{
"title": "Failed to merge policy data objects"
}
]
}
}
}
}
}
}