New logging output (#4424)
This commit is contained in:
21
Cargo.lock
generated
21
Cargo.lock
generated
@@ -3159,12 +3159,14 @@ dependencies = [
|
||||
"dotenvy",
|
||||
"figment",
|
||||
"futures-util",
|
||||
"headers",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"ipnetwork",
|
||||
"itertools 0.14.0",
|
||||
"listenfd",
|
||||
"mas-config",
|
||||
"mas-context",
|
||||
"mas-data-model",
|
||||
"mas-email",
|
||||
"mas-handlers",
|
||||
@@ -3246,6 +3248,22 @@ dependencies = [
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mas-context"
|
||||
version = "0.15.0"
|
||||
dependencies = [
|
||||
"console",
|
||||
"opentelemetry",
|
||||
"pin-project-lite",
|
||||
"quanta",
|
||||
"tokio",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
"tracing-opentelemetry",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mas-data-model"
|
||||
version = "0.15.0"
|
||||
@@ -3306,6 +3324,7 @@ dependencies = [
|
||||
"lettre",
|
||||
"mas-axum-utils",
|
||||
"mas-config",
|
||||
"mas-context",
|
||||
"mas-data-model",
|
||||
"mas-http",
|
||||
"mas-i18n",
|
||||
@@ -3504,6 +3523,7 @@ dependencies = [
|
||||
"http-body",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"mas-context",
|
||||
"pin-project-lite",
|
||||
"rustls-pemfile",
|
||||
"socket2",
|
||||
@@ -3674,6 +3694,7 @@ dependencies = [
|
||||
"async-trait",
|
||||
"chrono",
|
||||
"cron",
|
||||
"mas-context",
|
||||
"mas-data-model",
|
||||
"mas-email",
|
||||
"mas-i18n",
|
||||
|
||||
21
Cargo.toml
21
Cargo.toml
@@ -30,6 +30,7 @@ broken_intra_doc_links = "deny"
|
||||
mas-axum-utils = { path = "./crates/axum-utils/", version = "=0.15.0" }
|
||||
mas-cli = { path = "./crates/cli/", version = "=0.15.0" }
|
||||
mas-config = { path = "./crates/config/", version = "=0.15.0" }
|
||||
mas-context = { path = "./crates/context/", version = "=0.15.0" }
|
||||
mas-data-model = { path = "./crates/data-model/", version = "=0.15.0" }
|
||||
mas-email = { path = "./crates/email/", version = "=0.15.0" }
|
||||
mas-graphql = { path = "./crates/graphql/", version = "=0.15.0" }
|
||||
@@ -111,6 +112,10 @@ version = "1.1.9"
|
||||
[workspace.dependencies.compact_str]
|
||||
version = "0.9.0"
|
||||
|
||||
# Terminal formatting
|
||||
[workspace.dependencies.console]
|
||||
version = "0.15.11"
|
||||
|
||||
# Time utilities
|
||||
[workspace.dependencies.chrono]
|
||||
version = "0.4.40"
|
||||
@@ -248,6 +253,10 @@ features = ["std"]
|
||||
version = "0.7.0"
|
||||
features = ["std"]
|
||||
|
||||
# Pin projection
|
||||
[workspace.dependencies.pin-project-lite]
|
||||
version = "0.2.16"
|
||||
|
||||
# PKCS#1 encoding
|
||||
[workspace.dependencies.pkcs1]
|
||||
version = "0.7.5"
|
||||
@@ -258,6 +267,10 @@ features = ["std"]
|
||||
version = "0.10.2"
|
||||
features = ["std", "pkcs5", "encryption"]
|
||||
|
||||
# High-precision clock
|
||||
[workspace.dependencies.quanta]
|
||||
version = "0.12.5"
|
||||
|
||||
# Random values
|
||||
[workspace.dependencies.rand]
|
||||
version = "0.8.5"
|
||||
@@ -374,6 +387,14 @@ features = ["rt"]
|
||||
version = "0.5.2"
|
||||
features = ["util"]
|
||||
|
||||
# Tower service trait
|
||||
[workspace.dependencies.tower-service]
|
||||
version = "0.3.3"
|
||||
|
||||
# Tower layer trait
|
||||
[workspace.dependencies.tower-layer]
|
||||
version = "0.3.3"
|
||||
|
||||
# Tower HTTP layers
|
||||
[workspace.dependencies.tower-http]
|
||||
version = "0.6.2"
|
||||
|
||||
@@ -28,6 +28,8 @@ use serde::{Deserialize, de::DeserializeOwned};
|
||||
use serde_json::Value;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::record_error;
|
||||
|
||||
static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -97,7 +99,7 @@ impl Credentials {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the credentials are invalid.
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn verify(
|
||||
&self,
|
||||
http_client: &reqwest::Client,
|
||||
@@ -144,7 +146,7 @@ impl Credentials {
|
||||
|
||||
let jwks = fetch_jwks(http_client, jwks)
|
||||
.await
|
||||
.map_err(|_| CredentialsVerificationError::JwksFetchFailed)?;
|
||||
.map_err(CredentialsVerificationError::JwksFetchFailed)?;
|
||||
|
||||
jwt.verify_with_jwks(&jwks)
|
||||
.map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
|
||||
@@ -214,7 +216,18 @@ pub enum CredentialsVerificationError {
|
||||
InvalidAssertionSignature,
|
||||
|
||||
#[error("failed to fetch jwks")]
|
||||
JwksFetchFailed,
|
||||
JwksFetchFailed(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
}
|
||||
|
||||
impl CredentialsVerificationError {
|
||||
/// Returns true if the error is an internal error, not caused by the client
|
||||
#[must_use]
|
||||
pub fn is_internal(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::DecryptionError | Self::InvalidClientConfig | Self::JwksFetchFailed(_)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
@@ -231,23 +244,40 @@ impl<F> ClientAuthorization<F> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ClientAuthorizationError {
|
||||
#[error("Invalid Authorization header")]
|
||||
InvalidHeader,
|
||||
BadForm(FailedToDeserializeForm),
|
||||
|
||||
#[error("Could not deserialize request body")]
|
||||
BadForm(#[source] FailedToDeserializeForm),
|
||||
|
||||
#[error("client_id in form ({form:?}) does not match credential ({credential:?})")]
|
||||
ClientIdMismatch { credential: String, form: String },
|
||||
|
||||
#[error("Unsupported client_assertion_type: {client_assertion_type}")]
|
||||
UnsupportedClientAssertion { client_assertion_type: String },
|
||||
|
||||
#[error("No credentials were presented")]
|
||||
MissingCredentials,
|
||||
|
||||
#[error("Invalid request")]
|
||||
InvalidRequest,
|
||||
|
||||
#[error("Invalid client_assertion")]
|
||||
InvalidAssertion,
|
||||
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error>),
|
||||
}
|
||||
|
||||
impl IntoResponse for ClientAuthorizationError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
match &self {
|
||||
ClientAuthorizationError::InvalidHeader => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
sentry_event_id,
|
||||
Json(ClientError::new(
|
||||
ClientErrorCode::InvalidRequest,
|
||||
"Invalid Authorization header",
|
||||
@@ -256,39 +286,34 @@ impl IntoResponse for ClientAuthorizationError {
|
||||
|
||||
ClientAuthorizationError::BadForm(err) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
sentry_event_id,
|
||||
Json(
|
||||
ClientError::from(ClientErrorCode::InvalidRequest)
|
||||
.with_description(format!("{err}")),
|
||||
),
|
||||
),
|
||||
|
||||
ClientAuthorizationError::ClientIdMismatch { form, credential } => {
|
||||
let description = format!(
|
||||
"client_id in form ({form:?}) does not match credential ({credential:?})"
|
||||
);
|
||||
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(
|
||||
ClientError::from(ClientErrorCode::InvalidGrant)
|
||||
.with_description(description),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
ClientAuthorizationError::UnsupportedClientAssertion {
|
||||
client_assertion_type,
|
||||
} => (
|
||||
ClientAuthorizationError::ClientIdMismatch { .. } => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
sentry_event_id,
|
||||
Json(
|
||||
ClientError::from(ClientErrorCode::InvalidRequest).with_description(format!(
|
||||
"Unsupported client_assertion_type: {client_assertion_type}",
|
||||
)),
|
||||
ClientError::from(ClientErrorCode::InvalidGrant)
|
||||
.with_description(format!("{self}")),
|
||||
),
|
||||
),
|
||||
|
||||
ClientAuthorizationError::UnsupportedClientAssertion { .. } => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
sentry_event_id,
|
||||
Json(
|
||||
ClientError::from(ClientErrorCode::InvalidRequest)
|
||||
.with_description(format!("{self}")),
|
||||
),
|
||||
),
|
||||
|
||||
ClientAuthorizationError::MissingCredentials => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
sentry_event_id,
|
||||
Json(ClientError::new(
|
||||
ClientErrorCode::InvalidRequest,
|
||||
"No credentials were presented",
|
||||
@@ -297,11 +322,13 @@ impl IntoResponse for ClientAuthorizationError {
|
||||
|
||||
ClientAuthorizationError::InvalidRequest => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
sentry_event_id,
|
||||
Json(ClientError::from(ClientErrorCode::InvalidRequest)),
|
||||
),
|
||||
|
||||
ClientAuthorizationError::InvalidAssertion => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
sentry_event_id,
|
||||
Json(ClientError::new(
|
||||
ClientErrorCode::InvalidRequest,
|
||||
"Invalid client_assertion",
|
||||
@@ -310,6 +337,7 @@ impl IntoResponse for ClientAuthorizationError {
|
||||
|
||||
ClientAuthorizationError::Internal(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
sentry_event_id,
|
||||
Json(
|
||||
ClientError::from(ClientErrorCode::ServerError)
|
||||
.with_description(format!("{e}")),
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use http::StatusCode;
|
||||
|
||||
use crate::record_error;
|
||||
|
||||
/// A simple wrapper around an error that implements [`IntoResponse`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
@@ -14,10 +16,16 @@ pub struct ErrorWrapper<T>(#[from] pub T);
|
||||
|
||||
impl<T> IntoResponse for ErrorWrapper<T>
|
||||
where
|
||||
T: std::error::Error,
|
||||
T: std::error::Error + 'static,
|
||||
{
|
||||
fn into_response(self) -> Response {
|
||||
// TODO: make this a bit more user friendly
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response()
|
||||
let sentry_event_id = record_error!(self.0);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
sentry_event_id,
|
||||
self.0.to_string(),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,12 +54,13 @@ impl<E: std::fmt::Debug + std::fmt::Display> From<E> for FancyError {
|
||||
|
||||
impl IntoResponse for FancyError {
|
||||
fn into_response(self) -> Response {
|
||||
tracing::error!(message = %self.context);
|
||||
let error = format!("{}", self.context);
|
||||
let event_id = sentry::capture_message(&error, sentry::Level::Error);
|
||||
let event_id = SentryEventID::for_last_event();
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
TypedHeader(ContentType::text()),
|
||||
SentryEventID::from(event_id),
|
||||
event_id,
|
||||
Extension(self.context),
|
||||
error,
|
||||
)
|
||||
|
||||
@@ -13,6 +13,13 @@ use sentry::types::Uuid;
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct SentryEventID(Uuid);
|
||||
|
||||
impl SentryEventID {
|
||||
/// Create a new Sentry event ID header for the last event on the hub.
|
||||
pub fn for_last_event() -> Option<Self> {
|
||||
sentry::last_event_id().map(Self)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Uuid> for SentryEventID {
|
||||
fn from(uuid: Uuid) -> Self {
|
||||
Self(uuid)
|
||||
@@ -28,3 +35,31 @@ impl IntoResponseParts for SentryEventID {
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
/// Record an error. It will emit a tracing event with the error level if
|
||||
/// matches the pattern, warning otherwise. It also returns the Sentry event ID
|
||||
/// if the error was recorded.
|
||||
#[macro_export]
|
||||
macro_rules! record_error {
|
||||
($error:expr, !) => {{
|
||||
tracing::warn!(message = &$error as &dyn std::error::Error);
|
||||
Option::<$crate::sentry::SentryEventID>::None
|
||||
}};
|
||||
|
||||
($error:expr) => {{
|
||||
tracing::error!(message = &$error as &dyn std::error::Error);
|
||||
|
||||
// With the `sentry-tracing` integration, Sentry should have
|
||||
// captured an error, so let's extract the last event ID from the
|
||||
// current hub
|
||||
$crate::sentry::SentryEventID::for_last_event()
|
||||
}};
|
||||
|
||||
($error:expr, $pattern:pat) => {
|
||||
if let $pattern = $error {
|
||||
record_error!($error)
|
||||
} else {
|
||||
record_error!($error, !)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ dialoguer = { version = "0.11.0", default-features = false, features = [
|
||||
dotenvy = "0.15.7"
|
||||
figment.workspace = true
|
||||
futures-util.workspace = true
|
||||
headers.workspace = true
|
||||
http-body-util.workspace = true
|
||||
hyper.workspace = true
|
||||
ipnetwork = "0.20.0"
|
||||
@@ -66,6 +67,7 @@ sentry-tracing.workspace = true
|
||||
sentry-tower.workspace = true
|
||||
|
||||
mas-config.workspace = true
|
||||
mas-context.workspace = true
|
||||
mas-data-model.workspace = true
|
||||
mas-email.workspace = true
|
||||
mas-handlers.workspace = true
|
||||
|
||||
@@ -8,6 +8,7 @@ use std::{convert::Infallible, net::IpAddr, sync::Arc, time::Instant};
|
||||
|
||||
use axum::extract::{FromRef, FromRequestParts};
|
||||
use ipnetwork::IpNetwork;
|
||||
use mas_context::LogContext;
|
||||
use mas_data_model::SiteConfig;
|
||||
use mas_handlers::{
|
||||
ActivityTracker, BoundActivityTracker, CookieManager, ErrorWrapper, GraphQLSchema, Limiter,
|
||||
@@ -92,35 +93,36 @@ impl AppState {
|
||||
let http_client = self.http_client.clone();
|
||||
|
||||
tokio::spawn(
|
||||
async move {
|
||||
let conn = match pool.acquire().await {
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
LogContext::new("metadata-cache-warmup")
|
||||
.run(async move || {
|
||||
let conn = match pool.acquire().await {
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
error = &e as &dyn std::error::Error,
|
||||
"Failed to acquire a database connection"
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut repo = PgRepository::from_conn(conn);
|
||||
|
||||
if let Err(e) = metadata_cache
|
||||
.warm_up_and_run(
|
||||
&http_client,
|
||||
std::time::Duration::from_secs(60 * 15),
|
||||
&mut repo,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
error = &e as &dyn std::error::Error,
|
||||
"Failed to acquire a database connection"
|
||||
"Failed to warm up the metadata cache"
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut repo = PgRepository::from_conn(conn);
|
||||
|
||||
if let Err(e) = metadata_cache
|
||||
.warm_up_and_run(
|
||||
&http_client,
|
||||
std::time::Duration::from_secs(60 * 15),
|
||||
&mut repo,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
error = &e as &dyn std::error::Error,
|
||||
"Failed to warm up the metadata cache"
|
||||
);
|
||||
}
|
||||
}
|
||||
.instrument(tracing::info_span!("metadata_cache.background_warmup")),
|
||||
})
|
||||
.instrument(tracing::info_span!("metadata_cache.background_warmup")),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,7 +143,8 @@ impl Options {
|
||||
prune,
|
||||
dry_run,
|
||||
)
|
||||
.await?;
|
||||
.await
|
||||
.context("could not sync the configuration with the database")?;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -255,7 +255,13 @@ impl Options {
|
||||
};
|
||||
|
||||
repo.into_inner().commit().await?;
|
||||
info!(?email, "Email added");
|
||||
info!(
|
||||
%user.id,
|
||||
%user.username,
|
||||
%email.id,
|
||||
%email.email,
|
||||
"Email added"
|
||||
);
|
||||
|
||||
Ok(ExitCode::SUCCESS)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ use itertools::Itertools;
|
||||
use mas_config::{
|
||||
AppConfig, ClientsConfig, ConfigurationSection, ConfigurationSectionExt, UpstreamOAuth2Config,
|
||||
};
|
||||
use mas_context::LogContext;
|
||||
use mas_handlers::{ActivityTracker, CookieManager, Limiter, MetadataCache};
|
||||
use mas_listener::server::Server;
|
||||
use mas_router::UrlBuilder;
|
||||
@@ -112,7 +113,8 @@ impl Options {
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
.await
|
||||
.context("could not sync the configuration with the database")?;
|
||||
}
|
||||
|
||||
// Initialize the key store
|
||||
@@ -316,11 +318,13 @@ impl Options {
|
||||
|
||||
shutdown
|
||||
.task_tracker()
|
||||
.spawn(mas_listener::server::run_servers(
|
||||
servers,
|
||||
shutdown.soft_shutdown_token(),
|
||||
shutdown.hard_shutdown_token(),
|
||||
));
|
||||
.spawn(LogContext::new("run-servers").run(|| {
|
||||
mas_listener::server::run_servers(
|
||||
servers,
|
||||
shutdown.soft_shutdown_token(),
|
||||
shutdown.hard_shutdown_token(),
|
||||
)
|
||||
}));
|
||||
|
||||
let exit_code = shutdown.run().await;
|
||||
|
||||
|
||||
@@ -146,7 +146,8 @@ impl Options {
|
||||
// Not a dry run — we do want to create the providers in the database
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
.await
|
||||
.context("could not sync the configuration with the database")?;
|
||||
}
|
||||
|
||||
let Either::Left(mut mas_connection) = LockedMasDatabase::try_new(mas_connection)
|
||||
|
||||
@@ -91,8 +91,7 @@ async fn try_main() -> anyhow::Result<ExitCode> {
|
||||
let (log_writer, _guard) = tracing_appender::non_blocking(output);
|
||||
let fmt_layer = tracing_subscriber::fmt::layer()
|
||||
.with_writer(log_writer)
|
||||
.with_file(true)
|
||||
.with_line_number(true)
|
||||
.event_format(mas_context::EventFormatter)
|
||||
.with_ansi(with_ansi);
|
||||
let filter_layer = EnvFilter::try_from_default_env()
|
||||
.or_else(|_| EnvFilter::try_new("info"))
|
||||
@@ -129,9 +128,11 @@ async fn try_main() -> anyhow::Result<ExitCode> {
|
||||
|
||||
let sentry_layer = sentry.is_enabled().then(|| {
|
||||
sentry_tracing::layer().event_filter(|md| {
|
||||
// All the spans in the handlers module send their data to Sentry themselves, so
|
||||
// we only create breadcrumbs for them, instead of full events
|
||||
if md.target().starts_with("mas_handlers::") {
|
||||
// By default, Sentry records all events as breadcrumbs, except errors.
|
||||
//
|
||||
// Because we're emitting error events for 5xx responses, we need to exclude
|
||||
// them and also record them as breadcrumbs.
|
||||
if md.name() == "http.server.response" {
|
||||
EventFilter::Breadcrumb
|
||||
} else {
|
||||
sentry_tracing::default_event_filter(md)
|
||||
|
||||
@@ -16,12 +16,14 @@ use axum::{
|
||||
error_handling::HandleErrorLayer,
|
||||
extract::{FromRef, MatchedPath},
|
||||
};
|
||||
use headers::{HeaderMapExt as _, UserAgent};
|
||||
use hyper::{
|
||||
Method, Request, Response, StatusCode, Version,
|
||||
header::{CACHE_CONTROL, HeaderValue, USER_AGENT},
|
||||
};
|
||||
use listenfd::ListenFd;
|
||||
use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp};
|
||||
use mas_context::LogContext;
|
||||
use mas_listener::{ConnectionInfo, unix_or_tcp::UnixOrTcpListener};
|
||||
use mas_router::Route;
|
||||
use mas_templates::Templates;
|
||||
@@ -170,6 +172,45 @@ fn on_http_response_labels<B>(res: &Response<B>) -> Vec<KeyValue> {
|
||||
)]
|
||||
}
|
||||
|
||||
async fn log_response_middleware(
|
||||
request: axum::extract::Request,
|
||||
next: axum::middleware::Next,
|
||||
) -> axum::response::Response {
|
||||
let user_agent: Option<UserAgent> = request.headers().typed_get();
|
||||
let user_agent = user_agent.as_ref().map_or("-", |u| u.as_str());
|
||||
let method = otel_http_method(&request);
|
||||
let path = request.uri().path().to_owned();
|
||||
let version = otel_net_protocol_version(&request);
|
||||
|
||||
let response = next.run(request).await;
|
||||
|
||||
let Some(log_context) = LogContext::current() else {
|
||||
tracing::error!("Missing log context for request, this is a bug!");
|
||||
return response;
|
||||
};
|
||||
|
||||
let stats = log_context.stats();
|
||||
|
||||
let status_code = response.status();
|
||||
match status_code.as_u16() {
|
||||
100..=399 => tracing::info!(
|
||||
name: "http.server.response",
|
||||
"\"{method} {path} HTTP/{version}\" {status_code} {user_agent:?} [{stats}]",
|
||||
),
|
||||
400..=499 => tracing::warn!(
|
||||
name: "http.server.response",
|
||||
"\"{method} {path} HTTP/{version}\" {status_code} {user_agent:?} [{stats}]",
|
||||
),
|
||||
500..=599 => tracing::error!(
|
||||
name: "http.server.response",
|
||||
"\"{method} {path} HTTP/{version}\" {status_code} {user_agent:?} [{stats}]",
|
||||
),
|
||||
_ => { /* This shouldn't happen */ }
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
pub fn build_router(
|
||||
state: AppState,
|
||||
resources: &[HttpResource],
|
||||
@@ -277,8 +318,12 @@ pub fn build_router(
|
||||
span.record("otel.status_code", "OK");
|
||||
}),
|
||||
)
|
||||
.layer(SentryHttpLayer::new())
|
||||
.layer(axum::middleware::from_fn(log_response_middleware))
|
||||
.layer(mas_context::LogContextLayer::new(|req| {
|
||||
otel_http_method(req).into()
|
||||
}))
|
||||
.layer(NewSentryLayer::new_from_top())
|
||||
.layer(SentryHttpLayer::with_transaction())
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ fn map_claims_imports(
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "config.sync", skip_all, err(Debug))]
|
||||
#[tracing::instrument(name = "config.sync", skip_all)]
|
||||
pub async fn config_sync(
|
||||
upstream_oauth2_config: UpstreamOAuth2Config,
|
||||
clients_config: ClientsConfig,
|
||||
@@ -175,11 +175,11 @@ pub async fn config_sync(
|
||||
|
||||
let _span = info_span!("provider", %provider.id).entered();
|
||||
if existing_enabled_ids.contains(&provider.id) {
|
||||
info!("Updating provider");
|
||||
info!(provider.id = %provider.id, "Updating provider");
|
||||
} else if existing_disabled.contains_key(&provider.id) {
|
||||
info!("Enabling and updating provider");
|
||||
info!(provider.id = %provider.id, "Enabling and updating provider");
|
||||
} else {
|
||||
info!("Adding provider");
|
||||
info!(provider.id = %provider.id, "Adding provider");
|
||||
}
|
||||
|
||||
if dry_run {
|
||||
@@ -252,15 +252,15 @@ pub async fn config_sync(
|
||||
|
||||
if discovery_mode.is_disabled() {
|
||||
if provider.authorization_endpoint.is_none() {
|
||||
error!("Provider has discovery disabled but no authorization endpoint set");
|
||||
error!(provider.id = %provider.id, "Provider has discovery disabled but no authorization endpoint set");
|
||||
}
|
||||
|
||||
if provider.token_endpoint.is_none() {
|
||||
error!("Provider has discovery disabled but no token endpoint set");
|
||||
error!(provider.id = %provider.id, "Provider has discovery disabled but no token endpoint set");
|
||||
}
|
||||
|
||||
if provider.jwks_uri.is_none() {
|
||||
warn!("Provider has discovery disabled but no JWKS URI set");
|
||||
warn!(provider.id = %provider.id, "Provider has discovery disabled but no JWKS URI set");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -347,9 +347,9 @@ pub async fn config_sync(
|
||||
for client in clients_config {
|
||||
let _span = info_span!("client", client.id = %client.client_id).entered();
|
||||
if existing_ids.contains(&client.client_id) {
|
||||
info!("Updating client");
|
||||
info!(client.id = %client.client_id, "Updating client");
|
||||
} else {
|
||||
info!("Adding client");
|
||||
info!(client.id = %client.client_id, "Adding client");
|
||||
}
|
||||
|
||||
if dry_run {
|
||||
|
||||
@@ -12,6 +12,7 @@ use mas_config::{
|
||||
EmailTransportKind, ExperimentalConfig, HomeserverKind, MatrixConfig, PasswordsConfig,
|
||||
PolicyConfig, TemplatesConfig,
|
||||
};
|
||||
use mas_context::LogContext;
|
||||
use mas_data_model::{SessionExpirationConfig, SiteConfig};
|
||||
use mas_email::{MailTransport, Mailer};
|
||||
use mas_handlers::passwords::PasswordManager;
|
||||
@@ -21,7 +22,7 @@ use mas_policy::PolicyFactory;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::RepositoryAccess;
|
||||
use mas_storage_pg::PgRepository;
|
||||
use mas_templates::{SiteConfigExt, TemplateLoadingError, Templates};
|
||||
use mas_templates::{SiteConfigExt, Templates};
|
||||
use sqlx::{
|
||||
ConnectOptions, Executor, PgConnection, PgPool,
|
||||
postgres::{PgConnectOptions, PgPoolOptions},
|
||||
@@ -109,20 +110,23 @@ pub fn test_mailer_in_background(mailer: &Mailer, timeout: Duration) {
|
||||
let mailer = mailer.clone();
|
||||
|
||||
let span = tracing::info_span!("cli.test_mailer");
|
||||
tokio::spawn(async move {
|
||||
match tokio::time::timeout(timeout, mailer.test_connection()).await {
|
||||
Ok(Ok(())) => {}
|
||||
Ok(Err(err)) => {
|
||||
tracing::warn!(
|
||||
error = &err as &dyn std::error::Error,
|
||||
"Could not connect to the mail backend, tasks sending mails may fail!"
|
||||
);
|
||||
tokio::spawn(
|
||||
LogContext::new("mailer-test").run(async move || {
|
||||
match tokio::time::timeout(timeout, mailer.test_connection()).await {
|
||||
Ok(Ok(())) => {}
|
||||
Ok(Err(err)) => {
|
||||
tracing::warn!(
|
||||
error = &err as &dyn std::error::Error,
|
||||
"Could not connect to the mail backend, tasks sending mails may fail!"
|
||||
);
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::warn!("Timed out while testing the mail backend connection, tasks sending mails may fail!");
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::warn!("Timed out while testing the mail backend connection, tasks sending mails may fail!");
|
||||
}
|
||||
}
|
||||
}.instrument(span));
|
||||
})
|
||||
.instrument(span)
|
||||
);
|
||||
}
|
||||
|
||||
pub async fn policy_factory_from_config(
|
||||
@@ -222,7 +226,7 @@ pub async fn templates_from_config(
|
||||
config: &TemplatesConfig,
|
||||
site_config: &SiteConfig,
|
||||
url_builder: &UrlBuilder,
|
||||
) -> Result<Templates, TemplateLoadingError> {
|
||||
) -> Result<Templates, anyhow::Error> {
|
||||
Templates::load(
|
||||
config.path.clone(),
|
||||
url_builder.clone(),
|
||||
@@ -232,6 +236,7 @@ pub async fn templates_from_config(
|
||||
site_config.templates_features(),
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("Failed to load the templates at {}", config.path))
|
||||
}
|
||||
|
||||
fn database_connect_options_from_config(
|
||||
@@ -331,7 +336,7 @@ fn database_connect_options_from_config(
|
||||
}
|
||||
|
||||
/// Create a database connection pool from the configuration
|
||||
#[tracing::instrument(name = "db.connect", skip_all, err(Debug))]
|
||||
#[tracing::instrument(name = "db.connect", skip_all)]
|
||||
pub async fn database_pool_from_config(config: &DatabaseConfig) -> Result<PgPool, anyhow::Error> {
|
||||
let options = database_connect_options_from_config(config, &DatabaseConnectOptions::default())?;
|
||||
PgPoolOptions::new()
|
||||
@@ -367,7 +372,7 @@ impl Default for DatabaseConnectOptions {
|
||||
}
|
||||
|
||||
/// Create a single database connection from the configuration
|
||||
#[tracing::instrument(name = "db.connect", skip_all, err(Debug))]
|
||||
#[tracing::instrument(name = "db.connect", skip_all)]
|
||||
pub async fn database_connection_from_config(
|
||||
config: &DatabaseConfig,
|
||||
) -> Result<PgConnection, anyhow::Error> {
|
||||
@@ -379,7 +384,7 @@ pub async fn database_connection_from_config(
|
||||
|
||||
/// Create a single database connection from the configuration,
|
||||
/// with specific options.
|
||||
#[tracing::instrument(name = "db.connect", skip_all, err(Debug))]
|
||||
#[tracing::instrument(name = "db.connect", skip_all)]
|
||||
pub async fn database_connection_from_config_with_options(
|
||||
config: &DatabaseConfig,
|
||||
options: &DatabaseConnectOptions,
|
||||
@@ -430,7 +435,7 @@ pub async fn load_policy_factory_dynamic_data_continuously(
|
||||
}
|
||||
|
||||
/// Update the policy factory dynamic data from the database
|
||||
#[tracing::instrument(name = "policy.load_dynamic_data", skip_all, err(Debug))]
|
||||
#[tracing::instrument(name = "policy.load_dynamic_data", skip_all)]
|
||||
pub async fn load_policy_factory_dynamic_data(
|
||||
policy_factory: &PolicyFactory,
|
||||
pool: &PgPool,
|
||||
|
||||
@@ -70,7 +70,7 @@ impl SecretsConfig {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error when a key could not be imported
|
||||
#[tracing::instrument(name = "secrets.load", skip_all, err(Debug))]
|
||||
#[tracing::instrument(name = "secrets.load", skip_all)]
|
||||
pub async fn key_store(&self) -> anyhow::Result<Keystore> {
|
||||
let mut keys = Vec::with_capacity(self.keys.len());
|
||||
for item in &self.keys {
|
||||
|
||||
24
crates/context/Cargo.toml
Normal file
24
crates/context/Cargo.toml
Normal file
@@ -0,0 +1,24 @@
|
||||
[package]
|
||||
name = "mas-context"
|
||||
version.workspace = true
|
||||
authors.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
homepage.workspace = true
|
||||
repository.workspace = true
|
||||
publish = false
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
console.workspace = true
|
||||
pin-project-lite.workspace = true
|
||||
quanta.workspace = true
|
||||
tokio.workspace = true
|
||||
tower-service.workspace = true
|
||||
tower-layer.workspace = true
|
||||
tracing.workspace = true
|
||||
tracing-subscriber.workspace = true
|
||||
tracing-opentelemetry.workspace = true
|
||||
opentelemetry.workspace = true
|
||||
148
crates/context/src/fmt.rs
Normal file
148
crates/context/src/fmt.rs
Normal file
@@ -0,0 +1,148 @@
|
||||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use console::{Color, Style};
|
||||
use opentelemetry::{TraceId, trace::TraceContextExt};
|
||||
use tracing::{Level, Subscriber};
|
||||
use tracing_opentelemetry::OtelData;
|
||||
use tracing_subscriber::{
|
||||
fmt::{
|
||||
FormatEvent, FormatFields,
|
||||
format::{DefaultFields, Writer},
|
||||
time::{FormatTime, SystemTime},
|
||||
},
|
||||
registry::LookupSpan,
|
||||
};
|
||||
|
||||
use crate::LogContext;
|
||||
|
||||
/// An event formatter usable by the [`tracing-subscriber`] crate, which
|
||||
/// includes the log context and the OTEL trace ID.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct EventFormatter;
|
||||
|
||||
struct FmtLevel<'a> {
|
||||
level: &'a Level,
|
||||
ansi: bool,
|
||||
}
|
||||
|
||||
impl<'a> FmtLevel<'a> {
|
||||
pub(crate) fn new(level: &'a Level, ansi: bool) -> Self {
|
||||
Self { level, ansi }
|
||||
}
|
||||
}
|
||||
|
||||
const TRACE_STR: &str = "TRACE";
|
||||
const DEBUG_STR: &str = "DEBUG";
|
||||
const INFO_STR: &str = " INFO";
|
||||
const WARN_STR: &str = " WARN";
|
||||
const ERROR_STR: &str = "ERROR";
|
||||
|
||||
const TRACE_STYLE: Style = Style::new().fg(Color::Magenta);
|
||||
const DEBUG_STYLE: Style = Style::new().fg(Color::Blue);
|
||||
const INFO_STYLE: Style = Style::new().fg(Color::Green);
|
||||
const WARN_STYLE: Style = Style::new().fg(Color::Yellow);
|
||||
const ERROR_STYLE: Style = Style::new().fg(Color::Red);
|
||||
|
||||
impl std::fmt::Display for FmtLevel<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let msg = match *self.level {
|
||||
Level::TRACE => TRACE_STYLE.force_styling(self.ansi).apply_to(TRACE_STR),
|
||||
Level::DEBUG => DEBUG_STYLE.force_styling(self.ansi).apply_to(DEBUG_STR),
|
||||
Level::INFO => INFO_STYLE.force_styling(self.ansi).apply_to(INFO_STR),
|
||||
Level::WARN => WARN_STYLE.force_styling(self.ansi).apply_to(WARN_STR),
|
||||
Level::ERROR => ERROR_STYLE.force_styling(self.ansi).apply_to(ERROR_STR),
|
||||
};
|
||||
write!(f, "{msg}")
|
||||
}
|
||||
}
|
||||
|
||||
struct TargetFmt<'a> {
|
||||
target: &'a str,
|
||||
line: Option<u32>,
|
||||
}
|
||||
|
||||
impl<'a> TargetFmt<'a> {
|
||||
pub(crate) fn new(metadata: &tracing::Metadata<'a>) -> Self {
|
||||
Self {
|
||||
target: metadata.target(),
|
||||
line: metadata.line(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TargetFmt<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.target)?;
|
||||
if let Some(line) = self.line {
|
||||
write!(f, ":{line}")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, N> FormatEvent<S, N> for EventFormatter
|
||||
where
|
||||
S: Subscriber + for<'a> LookupSpan<'a>,
|
||||
N: for<'writer> FormatFields<'writer> + 'static,
|
||||
{
|
||||
fn format_event(
|
||||
&self,
|
||||
ctx: &tracing_subscriber::fmt::FmtContext<'_, S, N>,
|
||||
mut writer: Writer<'_>,
|
||||
event: &tracing::Event<'_>,
|
||||
) -> std::fmt::Result {
|
||||
let ansi = writer.has_ansi_escapes();
|
||||
let metadata = event.metadata();
|
||||
|
||||
SystemTime.format_time(&mut writer)?;
|
||||
|
||||
let level = FmtLevel::new(metadata.level(), ansi);
|
||||
write!(&mut writer, " {level} ")?;
|
||||
|
||||
// If there is no explicit 'name' set in the event macro, it will have the
|
||||
// 'event {filename}:{line}' value. In this case, we want to display the target:
|
||||
// the module from where it was emitted. In other cases, we want to
|
||||
// display the explit name of the event we have set.
|
||||
let style = Style::new().dim().force_styling(ansi);
|
||||
if metadata.name().starts_with("event ") {
|
||||
write!(&mut writer, "{} ", style.apply_to(TargetFmt::new(metadata)))?;
|
||||
} else {
|
||||
write!(&mut writer, "{} ", style.apply_to(metadata.name()))?;
|
||||
}
|
||||
|
||||
if let Some(log_context) = LogContext::current() {
|
||||
let log_context = Style::new()
|
||||
.bold()
|
||||
.force_styling(ansi)
|
||||
.apply_to(log_context);
|
||||
write!(&mut writer, "{log_context} - ")?;
|
||||
}
|
||||
|
||||
let field_fromatter = DefaultFields::new();
|
||||
field_fromatter.format_fields(writer.by_ref(), event)?;
|
||||
|
||||
// If we have a OTEL span, we can add the trace ID to the end of the log line
|
||||
if let Some(span) = ctx.lookup_current() {
|
||||
if let Some(otel) = span.extensions().get::<OtelData>() {
|
||||
// If it is the root span, the trace ID will be in the span builder. Else, it
|
||||
// will be in the parent OTEL context
|
||||
let trace_id = otel
|
||||
.builder
|
||||
.trace_id
|
||||
.unwrap_or_else(|| otel.parent_cx.span().span_context().trace_id());
|
||||
if trace_id != TraceId::INVALID {
|
||||
let label = Style::new()
|
||||
.italic()
|
||||
.force_styling(ansi)
|
||||
.apply_to("trace.id");
|
||||
write!(&mut writer, " {label}={trace_id}")?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
writeln!(&mut writer)
|
||||
}
|
||||
}
|
||||
59
crates/context/src/future.rs
Normal file
59
crates/context/src/future.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::atomic::Ordering,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use quanta::Instant;
|
||||
use tokio::task::futures::TaskLocalFuture;
|
||||
|
||||
use crate::LogContext;
|
||||
|
||||
pub type LogContextFuture<F> = TaskLocalFuture<crate::LogContext, PollRecordingFuture<F>>;
|
||||
|
||||
impl LogContext {
|
||||
/// Wrap a future with the given log context
|
||||
pub(crate) fn wrap_future<F: Future>(&self, future: F) -> LogContextFuture<F> {
|
||||
let future = PollRecordingFuture::new(future);
|
||||
crate::CURRENT_LOG_CONTEXT.scope(self.clone(), future)
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
/// A future which records the elapsed time and the number of polls in the
|
||||
/// active log context
|
||||
pub struct PollRecordingFuture<F> {
|
||||
#[pin]
|
||||
inner: F,
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future> PollRecordingFuture<F> {
|
||||
pub(crate) fn new(inner: F) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future> Future for PollRecordingFuture<F> {
|
||||
type Output = F::Output;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let start = Instant::now();
|
||||
let this = self.project();
|
||||
let result = this.inner.poll(cx);
|
||||
|
||||
// Record the number of polls and the time we spent polling the future
|
||||
let elapsed = start.elapsed().as_nanos().try_into().unwrap_or(u64::MAX);
|
||||
let _ = crate::CURRENT_LOG_CONTEXT.try_with(|c| {
|
||||
c.inner.polls.fetch_add(1, Ordering::Relaxed);
|
||||
c.inner.cpu_time.fetch_add(elapsed, Ordering::Relaxed);
|
||||
});
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
41
crates/context/src/layer.rs
Normal file
41
crates/context/src/layer.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use std::borrow::Cow;
|
||||
|
||||
use tower_layer::Layer;
|
||||
use tower_service::Service;
|
||||
|
||||
use crate::LogContextService;
|
||||
|
||||
/// A layer which creates a log context for each request.
|
||||
pub struct LogContextLayer<R> {
|
||||
tagger: fn(&R) -> Cow<'static, str>,
|
||||
}
|
||||
|
||||
impl<R> Clone for LogContextLayer<R> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
tagger: self.tagger,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> LogContextLayer<R> {
|
||||
pub fn new(tagger: fn(&R) -> Cow<'static, str>) -> Self {
|
||||
Self { tagger }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, R> Layer<S> for LogContextLayer<R>
|
||||
where
|
||||
S: Service<R>,
|
||||
{
|
||||
type Service = LogContextService<S, R>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
LogContextService::new(inner, self.tagger)
|
||||
}
|
||||
}
|
||||
149
crates/context/src/lib.rs
Normal file
149
crates/context/src/lib.rs
Normal file
@@ -0,0 +1,149 @@
|
||||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
mod fmt;
|
||||
mod future;
|
||||
mod layer;
|
||||
mod service;
|
||||
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicU64, Ordering},
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use quanta::Instant;
|
||||
use tokio::task_local;
|
||||
|
||||
pub use self::{
|
||||
fmt::EventFormatter,
|
||||
future::{LogContextFuture, PollRecordingFuture},
|
||||
layer::LogContextLayer,
|
||||
service::LogContextService,
|
||||
};
|
||||
|
||||
/// A counter which increments each time we create a new log context
|
||||
/// It will wrap around if we create more than [`u64::MAX`] contexts
|
||||
static LOG_CONTEXT_INDEX: AtomicU64 = AtomicU64::new(0);
|
||||
task_local! {
|
||||
pub static CURRENT_LOG_CONTEXT: LogContext;
|
||||
}
|
||||
|
||||
/// A log context saves informations about the current task, such as the
|
||||
/// elapsed time, the number of polls, and the poll time.
|
||||
#[derive(Clone)]
|
||||
pub struct LogContext {
|
||||
inner: Arc<LogContextInner>,
|
||||
}
|
||||
|
||||
struct LogContextInner {
|
||||
/// A user-defined tag for the log context
|
||||
tag: Cow<'static, str>,
|
||||
|
||||
/// A unique index for the log context
|
||||
index: u64,
|
||||
|
||||
/// The time when the context was created
|
||||
start: Instant,
|
||||
|
||||
/// The number of [`Future::poll`] recorded
|
||||
polls: AtomicU64,
|
||||
|
||||
/// An approximation of the total CPU time spent in the context, in
|
||||
/// nanoseconds
|
||||
cpu_time: AtomicU64,
|
||||
}
|
||||
|
||||
impl LogContext {
|
||||
/// Create a new log context with the given tag
|
||||
pub fn new(tag: impl Into<Cow<'static, str>>) -> Self {
|
||||
let tag = tag.into();
|
||||
let inner = LogContextInner {
|
||||
tag,
|
||||
index: LOG_CONTEXT_INDEX.fetch_add(1, Ordering::Relaxed),
|
||||
start: Instant::now(),
|
||||
polls: AtomicU64::new(0),
|
||||
cpu_time: AtomicU64::new(0),
|
||||
};
|
||||
|
||||
Self {
|
||||
inner: Arc::new(inner),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a copy of the current log context, if any
|
||||
pub fn current() -> Option<Self> {
|
||||
CURRENT_LOG_CONTEXT.try_with(Self::clone).ok()
|
||||
}
|
||||
|
||||
/// Run the async function `f` with the given log context. It will wrap the
|
||||
/// output future to record poll and CPU statistics.
|
||||
pub fn run<F: FnOnce() -> Fut, Fut: Future>(&self, f: F) -> LogContextFuture<Fut> {
|
||||
let future = self.run_sync(f);
|
||||
self.wrap_future(future)
|
||||
}
|
||||
|
||||
/// Run the sync function `f` with the given log context, recording the CPU
|
||||
/// time spent.
|
||||
pub fn run_sync<F: FnOnce() -> R, R>(&self, f: F) -> R {
|
||||
let start = Instant::now();
|
||||
let result = CURRENT_LOG_CONTEXT.sync_scope(self.clone(), f);
|
||||
let elapsed = start.elapsed().as_nanos().try_into().unwrap_or(u64::MAX);
|
||||
self.inner.cpu_time.fetch_add(elapsed, Ordering::Relaxed);
|
||||
result
|
||||
}
|
||||
|
||||
/// Create a snapshot of the log context statistics
|
||||
#[must_use]
|
||||
pub fn stats(&self) -> LogContextStats {
|
||||
let polls = self.inner.polls.load(Ordering::Relaxed);
|
||||
let cpu_time = self.inner.cpu_time.load(Ordering::Relaxed);
|
||||
let cpu_time = Duration::from_nanos(cpu_time);
|
||||
let elapsed = self.inner.start.elapsed();
|
||||
LogContextStats {
|
||||
polls,
|
||||
cpu_time,
|
||||
elapsed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LogContext {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let tag = &self.inner.tag;
|
||||
let index = self.inner.index;
|
||||
write!(f, "{tag}-{index}")
|
||||
}
|
||||
}
|
||||
|
||||
/// A snapshot of a log context statistics
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct LogContextStats {
|
||||
/// How many times the context was polled
|
||||
pub polls: u64,
|
||||
|
||||
/// The approximate CPU time spent in the context
|
||||
pub cpu_time: Duration,
|
||||
|
||||
/// How much time elapsed since the context was created
|
||||
pub elapsed: Duration,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LogContextStats {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let polls = self.polls;
|
||||
#[expect(clippy::cast_precision_loss)]
|
||||
let cpu_time_ms = self.cpu_time.as_nanos() as f64 / 1_000_000.;
|
||||
#[expect(clippy::cast_precision_loss)]
|
||||
let elapsed_ms = self.elapsed.as_nanos() as f64 / 1_000_000.;
|
||||
write!(
|
||||
f,
|
||||
"polls: {polls}, cpu: {cpu_time_ms:.1}ms, elapsed: {elapsed_ms:.1}ms",
|
||||
)
|
||||
}
|
||||
}
|
||||
54
crates/context/src/service.rs
Normal file
54
crates/context/src/service.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use tower_service::Service;
|
||||
|
||||
use crate::{LogContext, LogContextFuture};
|
||||
|
||||
/// A service which wraps another service and creates a log context for
|
||||
/// each request.
|
||||
pub struct LogContextService<S, R> {
|
||||
inner: S,
|
||||
tagger: fn(&R) -> Cow<'static, str>,
|
||||
}
|
||||
|
||||
impl<S: Clone, R> Clone for LogContextService<S, R> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
inner: self.inner.clone(),
|
||||
tagger: self.tagger,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, R> LogContextService<S, R> {
|
||||
pub fn new(inner: S, tagger: fn(&R) -> Cow<'static, str>) -> Self {
|
||||
Self { inner, tagger }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, R> Service<R> for LogContextService<S, R>
|
||||
where
|
||||
S: Service<R>,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = S::Error;
|
||||
type Future = LogContextFuture<S::Future>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, req: R) -> Self::Future {
|
||||
let tag = (self.tagger)(&req);
|
||||
let log_context = LogContext::new(tag);
|
||||
log_context.run(|| self.inner.call(req))
|
||||
}
|
||||
}
|
||||
@@ -111,7 +111,6 @@ impl Mailer {
|
||||
email.to = %to,
|
||||
email.language = %context.language(),
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn send_verification_email(
|
||||
&self,
|
||||
@@ -137,7 +136,6 @@ impl Mailer {
|
||||
user.id = %context.user().id,
|
||||
user_recovery_session.id = %context.session().id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn send_recovery_email(
|
||||
&self,
|
||||
@@ -154,7 +152,7 @@ impl Mailer {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the connection failed
|
||||
#[tracing::instrument(name = "email.test_connection", skip_all, err)]
|
||||
#[tracing::instrument(name = "email.test_connection", skip_all)]
|
||||
pub async fn test_connection(&self) -> Result<(), crate::transport::Error> {
|
||||
self.transport.test_connection().await
|
||||
}
|
||||
|
||||
@@ -90,6 +90,7 @@ ulid.workspace = true
|
||||
|
||||
mas-axum-utils.workspace = true
|
||||
mas-config.workspace = true
|
||||
mas-context.workspace = true
|
||||
mas-data-model.workspace = true
|
||||
mas-http.workspace = true
|
||||
mas-i18n.workspace = true
|
||||
|
||||
@@ -15,6 +15,7 @@ use axum::{
|
||||
use axum_extra::TypedHeader;
|
||||
use headers::{Authorization, authorization::Bearer};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_data_model::{Session, User};
|
||||
use mas_storage::{BoxClock, BoxRepository, RepositoryError};
|
||||
use ulid::Ulid;
|
||||
@@ -69,27 +70,35 @@ pub enum Rejection {
|
||||
MissingScope,
|
||||
}
|
||||
|
||||
impl Rejection {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
Self::InvalidAuthorizationHeader | Self::MissingAuthorizationHeader => {
|
||||
StatusCode::BAD_REQUEST
|
||||
}
|
||||
Self::UnknownAccessToken
|
||||
| Self::TokenExpired
|
||||
| Self::SessionRevoked
|
||||
| Self::UserLocked
|
||||
| Self::MissingScope => StatusCode::UNAUTHORIZED,
|
||||
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for Rejection {
|
||||
fn into_response(self) -> Response {
|
||||
let response = ErrorResponse::from_error(&self);
|
||||
let status = self.status_code();
|
||||
(status, Json(response)).into_response()
|
||||
let sentry_event_id = record_error!(
|
||||
self,
|
||||
Self::RepositorySetup(_)
|
||||
| Self::Repository(_)
|
||||
| Self::LoadSession(_)
|
||||
| Self::LoadUser(_)
|
||||
);
|
||||
|
||||
let status = match &self {
|
||||
Rejection::InvalidAuthorizationHeader | Rejection::MissingAuthorizationHeader => {
|
||||
StatusCode::BAD_REQUEST
|
||||
}
|
||||
|
||||
Rejection::UnknownAccessToken
|
||||
| Rejection::TokenExpired
|
||||
| Rejection::SessionRevoked
|
||||
| Rejection::UserLocked
|
||||
| Rejection::MissingScope => StatusCode::UNAUTHORIZED,
|
||||
|
||||
Rejection::RepositorySetup(_)
|
||||
| Rejection::Repository(_)
|
||||
| Rejection::LoadSession(_)
|
||||
| Rejection::LoadUser(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
};
|
||||
|
||||
(status, sentry_event_id, Json(response)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
use aide::{OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{
|
||||
@@ -33,11 +34,13 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let status = match self {
|
||||
let sentry_event_id = record_error!(self, RouteError::Internal(_));
|
||||
let status = match &self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,7 +62,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.compat_sessions.get", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.compat_sessions.get", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
id: UlidPathParam,
|
||||
|
||||
@@ -11,6 +11,7 @@ use axum::{
|
||||
};
|
||||
use axum_macros::FromRequestParts;
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_storage::{Page, compat::CompatSessionFilter};
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
@@ -113,12 +114,14 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let status = match self {
|
||||
let sentry_event_id = record_error!(self, RouteError::Internal(_));
|
||||
let status = match &self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::UserNotFound(_) | Self::UserSessionNotFound(_) => StatusCode::NOT_FOUND,
|
||||
Self::InvalidFilter(_) => StatusCode::BAD_REQUEST,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,7 +156,7 @@ Use the `filter[status]` parameter to filter the sessions by their status and `p
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.compat_sessions.list", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.compat_sessions.list", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
Pagination(pagination): Pagination,
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
use aide::{OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{
|
||||
@@ -34,11 +35,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, RouteError::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,7 +62,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.oauth2_session.get", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.oauth2_session.get", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
id: UlidPathParam,
|
||||
|
||||
@@ -14,6 +14,7 @@ use axum::{
|
||||
};
|
||||
use axum_macros::FromRequestParts;
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_storage::{Page, oauth2::OAuth2SessionFilter};
|
||||
use oauth2_types::scope::{Scope, ScopeToken};
|
||||
use schemars::JsonSchema;
|
||||
@@ -167,6 +168,7 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, RouteError::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::UserNotFound(_) | Self::ClientNotFound(_) | Self::UserSessionNotFound(_) => {
|
||||
@@ -174,7 +176,7 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
Self::InvalidScope(_) | Self::InvalidFilter(_) => StatusCode::BAD_REQUEST,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,7 +215,7 @@ Use the `filter[status]` parameter to filter the sessions by their status and `p
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.oauth2_sessions.list", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.oauth2_sessions.list", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
Pagination(pagination): Pagination,
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
use aide::{OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{
|
||||
@@ -32,11 +33,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,7 +58,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.policy_data.get", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.policy_data.get", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
id: UlidPathParam,
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
use aide::{OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
|
||||
use crate::{
|
||||
admin::{
|
||||
@@ -30,11 +31,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::NotFound => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,7 +57,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.policy_data.get_latest", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.policy_data.get_latest", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
) -> Result<Json<SingleResponse<PolicyData>>, RouteError> {
|
||||
|
||||
@@ -7,6 +7,7 @@ use std::sync::Arc;
|
||||
use aide::{NoApi, OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, extract::State, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_storage::BoxRng;
|
||||
use schemars::JsonSchema;
|
||||
@@ -36,11 +37,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
RouteError::InvalidPolicyData(_) => StatusCode::BAD_REQUEST,
|
||||
RouteError::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,7 +81,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.policy_data.set", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.policy_data.set", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext {
|
||||
mut repo, clock, ..
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
use aide::{NoApi, OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_storage::BoxRng;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
@@ -41,12 +42,13 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::LinkAlreadyExists(_, _) => StatusCode::CONFLICT,
|
||||
Self::UserNotFound(_) | Self::ProviderNotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,7 +104,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.post", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.post", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext {
|
||||
mut repo, clock, ..
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
use aide::{OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{
|
||||
@@ -28,11 +29,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,7 +51,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.delete", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.delete", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext {
|
||||
mut repo, clock, ..
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
use aide::{OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{
|
||||
@@ -33,11 +34,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_entry_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_entry_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,7 +61,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.get", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.get", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
id: UlidPathParam,
|
||||
|
||||
@@ -11,6 +11,7 @@ use axum::{
|
||||
};
|
||||
use axum_macros::FromRequestParts;
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_storage::{Page, upstream_oauth2::UpstreamOAuthLinkFilter};
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
@@ -91,12 +92,13 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::UserNotFound(_) | Self::ProviderNotFound(_) => StatusCode::NOT_FOUND,
|
||||
Self::InvalidFilter(_) => StatusCode::BAD_REQUEST,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,7 +132,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.list", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.list", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
Pagination(pagination): Pagination,
|
||||
|
||||
@@ -8,6 +8,7 @@ use std::str::FromStr as _;
|
||||
use aide::{NoApi, OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_storage::{
|
||||
BoxRng,
|
||||
queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
|
||||
@@ -52,13 +53,14 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::EmailAlreadyInUse(_) => StatusCode::CONFLICT,
|
||||
Self::EmailNotValid { .. } => StatusCode::BAD_REQUEST,
|
||||
Self::UserNotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,7 +108,7 @@ Note that this endpoint ignores any policy which would normally prevent the emai
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.user_emails.add", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.user_emails.add", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext {
|
||||
mut repo, clock, ..
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
use aide::{NoApi, OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_storage::{
|
||||
BoxRng,
|
||||
queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
|
||||
@@ -32,11 +33,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +54,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.user_emails.delete", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.user_emails.delete", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext {
|
||||
mut repo, clock, ..
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
use aide::{OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{
|
||||
@@ -33,11 +34,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,7 +59,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.user_emails.get", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.user_emails.get", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
id: UlidPathParam,
|
||||
|
||||
@@ -11,6 +11,7 @@ use axum::{
|
||||
};
|
||||
use axum_macros::FromRequestParts;
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_storage::{Page, user::UserEmailFilter};
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
@@ -78,12 +79,13 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::UserNotFound(_) => StatusCode::NOT_FOUND,
|
||||
Self::InvalidFilter(_) => StatusCode::BAD_REQUEST,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,7 +118,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.user_emails.list", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.user_emails.list", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
Pagination(pagination): Pagination,
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
use aide::{OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{
|
||||
@@ -33,11 +34,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +60,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.user_sessions.get", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.user_sessions.get", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
id: UlidPathParam,
|
||||
|
||||
@@ -11,6 +11,7 @@ use axum::{
|
||||
};
|
||||
use axum_macros::FromRequestParts;
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_storage::{pagination::Page, user::BrowserSessionFilter};
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
@@ -100,12 +101,13 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::UserNotFound(_) => StatusCode::NOT_FOUND,
|
||||
Self::InvalidFilter(_) => StatusCode::BAD_REQUEST,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,7 +142,7 @@ Use the `filter[status]` parameter to filter the sessions by their status and `p
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.user_sessions.list", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.user_sessions.list", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
Pagination(pagination): Pagination,
|
||||
|
||||
@@ -9,6 +9,7 @@ use std::sync::Arc;
|
||||
use aide::{NoApi, OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, extract::State, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use mas_storage::{
|
||||
BoxRng,
|
||||
@@ -81,12 +82,13 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_) | Self::Homeserver(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) | Self::Homeserver(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::UsernameNotValid => StatusCode::BAD_REQUEST,
|
||||
Self::UserAlreadyExists | Self::UsernameReserved => StatusCode::CONFLICT,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,7 +133,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.add", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.add", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext {
|
||||
mut repo, clock, ..
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
use aide::{OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, extract::Path, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
|
||||
@@ -34,11 +35,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,7 +67,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.by_username", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.by_username", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
Path(UsernamePathParam { username }): Path<UsernamePathParam>,
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
use aide::{NoApi, OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_storage::{
|
||||
BoxRng,
|
||||
queue::{DeactivateUserJob, QueueJobRepositoryExt as _},
|
||||
@@ -39,11 +40,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,7 +69,7 @@ This invalidates any existing session, and will ask the homeserver to make them
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.deactivate", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.deactivate", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext {
|
||||
mut repo, clock, ..
|
||||
@@ -86,7 +88,7 @@ pub async fn handler(
|
||||
user = repo.user().lock(&clock, user).await?;
|
||||
}
|
||||
|
||||
info!("Scheduling deactivation of user {}", user.id);
|
||||
info!(%user.id, "Scheduling deactivation of user");
|
||||
repo.queue_job()
|
||||
.schedule_job(&mut rng, &clock, DeactivateUserJob::new(&user, true))
|
||||
.await?;
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
use aide::{OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{
|
||||
@@ -34,11 +35,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +60,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.get", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.get", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
id: UlidPathParam,
|
||||
|
||||
@@ -12,6 +12,7 @@ use axum::{
|
||||
};
|
||||
use axum_macros::FromRequestParts;
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_storage::{Page, user::UserFilter};
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
@@ -95,11 +96,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::InvalidFilter(_) => StatusCode::BAD_REQUEST,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,7 +124,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.list", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.list", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
Pagination(pagination): Pagination,
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
use aide::{OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{
|
||||
@@ -34,11 +35,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,7 +64,7 @@ This DOES NOT invalidate any existing session, meaning that all their existing s
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.lock", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.lock", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext {
|
||||
mut repo, clock, ..
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
use aide::{OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use ulid::Ulid;
|
||||
@@ -36,11 +37,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +73,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.set_admin", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.set_admin", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
id: UlidPathParam,
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
use aide::{NoApi, OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, extract::State, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_storage::BoxRng;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
@@ -43,13 +44,14 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_) | Self::Password(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) | Self::Password(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::PasswordAuthDisabled => StatusCode::FORBIDDEN,
|
||||
Self::PasswordTooWeak => StatusCode::BAD_REQUEST,
|
||||
Self::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,7 +92,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.set_password", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.set_password", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext {
|
||||
mut repo, clock, ..
|
||||
|
||||
@@ -9,6 +9,7 @@ use std::sync::Arc;
|
||||
use aide::{OperationIo, transform::TransformOperation};
|
||||
use axum::{Json, extract::State, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use ulid::Ulid;
|
||||
|
||||
@@ -40,11 +41,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let error = ErrorResponse::from_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_) | Self::Homeserver(_));
|
||||
let status = match self {
|
||||
Self::Internal(_) | Self::Homeserver(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
};
|
||||
(status, Json(error)).into_response()
|
||||
(status, sentry_event_id, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +68,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.unlock", skip_all, err)]
|
||||
#[tracing::instrument(name = "handler.admin.v1.users.unlock", skip_all)]
|
||||
pub async fn handler(
|
||||
CallContext { mut repo, .. }: CallContext,
|
||||
State(homeserver): State<Arc<dyn HomeserverConnection>>,
|
||||
|
||||
@@ -156,7 +156,6 @@ impl Form {
|
||||
skip_all,
|
||||
name = "captcha.verify",
|
||||
fields(captcha.hostname, captcha.challenge_ts, captcha.service),
|
||||
err
|
||||
)]
|
||||
pub async fn verify(
|
||||
&self,
|
||||
|
||||
@@ -10,7 +10,7 @@ use axum::{Json, extract::State, response::IntoResponse};
|
||||
use axum_extra::typed_header::TypedHeader;
|
||||
use chrono::Duration;
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::sentry::SentryEventID;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_data_model::{
|
||||
CompatSession, CompatSsoLoginState, Device, SiteConfig, TokenType, User, UserAgent,
|
||||
};
|
||||
@@ -210,7 +210,8 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
let sentry_event_id =
|
||||
record_error!(self, Self::Internal(_) | Self::ProvisionDeviceFailed(_));
|
||||
LOGIN_COUNTER.add(1, &[KeyValue::new(RESULT, "error")]);
|
||||
let response = match self {
|
||||
Self::Internal(_) | Self::ProvisionDeviceFailed(_) => MatrixError {
|
||||
@@ -257,11 +258,11 @@ impl IntoResponse for RouteError {
|
||||
},
|
||||
};
|
||||
|
||||
(SentryEventID::from(event_id), response).into_response()
|
||||
(sentry_event_id, response).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handlers.compat.login.post", skip_all, err)]
|
||||
#[tracing::instrument(name = "handlers.compat.login.post", skip_all)]
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
|
||||
@@ -48,7 +48,6 @@ pub struct Params {
|
||||
name = "handlers.compat.login_sso_complete.get",
|
||||
fields(compat_sso_login.id = %id),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
pub async fn get(
|
||||
PreferredLanguage(locale): PreferredLanguage,
|
||||
@@ -121,7 +120,6 @@ pub async fn get(
|
||||
name = "handlers.compat.login_sso_complete.post",
|
||||
fields(compat_sso_login.id = %id),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
pub async fn post(
|
||||
mut rng: BoxRng,
|
||||
|
||||
@@ -9,7 +9,7 @@ use axum::{
|
||||
response::IntoResponse,
|
||||
};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::sentry::SentryEventID;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
|
||||
use mas_storage::{BoxClock, BoxRepository, BoxRng, compat::CompatSsoLoginRepository};
|
||||
use rand::distributions::{Alphanumeric, DistString};
|
||||
@@ -43,17 +43,16 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
SentryEventID::from(event_id),
|
||||
format!("{self}"),
|
||||
)
|
||||
.into_response()
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let status_code = match &self {
|
||||
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::MissingRedirectUrl | Self::InvalidRedirectUrl => StatusCode::BAD_REQUEST,
|
||||
};
|
||||
(status_code, sentry_event_id, format!("{self}")).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handlers.compat.login_sso_redirect.get", skip_all, err)]
|
||||
#[tracing::instrument(name = "handlers.compat.login_sso_redirect.get", skip_all)]
|
||||
pub async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
|
||||
@@ -10,7 +10,7 @@ use axum::{Json, response::IntoResponse};
|
||||
use axum_extra::typed_header::TypedHeader;
|
||||
use headers::{Authorization, authorization::Bearer};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::sentry::SentryEventID;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_data_model::TokenType;
|
||||
use mas_storage::{
|
||||
BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
|
||||
@@ -51,7 +51,7 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
LOGOUT_COUNTER.add(1, &[KeyValue::new(RESULT, "error")]);
|
||||
let response = match self {
|
||||
Self::Internal(_) => MatrixError {
|
||||
@@ -71,11 +71,11 @@ impl IntoResponse for RouteError {
|
||||
},
|
||||
};
|
||||
|
||||
(SentryEventID::from(event_id), response).into_response()
|
||||
(sentry_event_id, response).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handlers.compat.logout.post", skip_all, err)]
|
||||
#[tracing::instrument(name = "handlers.compat.logout.post", skip_all)]
|
||||
pub(crate) async fn post(
|
||||
clock: BoxClock,
|
||||
mut rng: BoxRng,
|
||||
|
||||
@@ -14,7 +14,7 @@ use axum::{
|
||||
response::IntoResponse,
|
||||
};
|
||||
use hyper::{StatusCode, header};
|
||||
use mas_axum_utils::sentry::SentryEventID;
|
||||
use mas_axum_utils::record_error;
|
||||
use serde::{Serialize, de::DeserializeOwned};
|
||||
use thiserror::Error;
|
||||
|
||||
@@ -59,7 +59,7 @@ pub enum MatrixJsonBodyRejection {
|
||||
|
||||
impl IntoResponse for MatrixJsonBodyRejection {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
let sentry_event_id = record_error!(self, !);
|
||||
let response = match self {
|
||||
Self::InvalidContentType | Self::ContentTypeNotJson(_) => MatrixError {
|
||||
errcode: "M_NOT_JSON",
|
||||
@@ -102,7 +102,7 @@ impl IntoResponse for MatrixJsonBodyRejection {
|
||||
},
|
||||
};
|
||||
|
||||
(SentryEventID::from(event_id), response).into_response()
|
||||
(sentry_event_id, response).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
use axum::{Json, extract::State, response::IntoResponse};
|
||||
use chrono::Duration;
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::sentry::SentryEventID;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_data_model::{SiteConfig, TokenFormatError, TokenType};
|
||||
use mas_storage::{
|
||||
BoxClock, BoxRepository, BoxRng, Clock,
|
||||
@@ -16,6 +16,7 @@ use mas_storage::{
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{DurationMilliSeconds, serde_as};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::MatrixError;
|
||||
use crate::{BoundActivityTracker, impl_from_error_for_route};
|
||||
@@ -31,46 +32,50 @@ pub enum RouteError {
|
||||
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
|
||||
#[error("invalid token")]
|
||||
InvalidToken,
|
||||
InvalidToken(#[from] TokenFormatError),
|
||||
|
||||
#[error("refresh token already consumed")]
|
||||
RefreshTokenConsumed,
|
||||
#[error("unknown token")]
|
||||
UnknownToken,
|
||||
|
||||
#[error("invalid session")]
|
||||
InvalidSession,
|
||||
#[error("invalid token type {0}, expected a compat refresh token")]
|
||||
InvalidTokenType(TokenType),
|
||||
|
||||
#[error("unknown session")]
|
||||
UnknownSession,
|
||||
#[error("refresh token already consumed {0}")]
|
||||
RefreshTokenConsumed(Ulid),
|
||||
|
||||
#[error("invalid compat session {0}")]
|
||||
InvalidSession(Ulid),
|
||||
|
||||
#[error("unknown comapt session {0}")]
|
||||
UnknownSession(Ulid),
|
||||
}
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_) | Self::UnknownSession(_));
|
||||
let response = match self {
|
||||
Self::Internal(_) | Self::UnknownSession => MatrixError {
|
||||
Self::Internal(_) | Self::UnknownSession(_) => MatrixError {
|
||||
errcode: "M_UNKNOWN",
|
||||
error: "Internal error",
|
||||
status: StatusCode::INTERNAL_SERVER_ERROR,
|
||||
},
|
||||
Self::InvalidToken | Self::InvalidSession | Self::RefreshTokenConsumed => MatrixError {
|
||||
Self::InvalidToken(_)
|
||||
| Self::UnknownToken
|
||||
| Self::InvalidTokenType(_)
|
||||
| Self::InvalidSession(_)
|
||||
| Self::RefreshTokenConsumed(_) => MatrixError {
|
||||
errcode: "M_UNKNOWN_TOKEN",
|
||||
error: "Invalid refresh token",
|
||||
status: StatusCode::UNAUTHORIZED,
|
||||
},
|
||||
};
|
||||
|
||||
(SentryEventID::from(event_id), response).into_response()
|
||||
(sentry_event_id, response).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
|
||||
impl From<TokenFormatError> for RouteError {
|
||||
fn from(_e: TokenFormatError) -> Self {
|
||||
Self::InvalidToken
|
||||
}
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ResponseBody {
|
||||
@@ -80,7 +85,7 @@ pub struct ResponseBody {
|
||||
expires_in_ms: Duration,
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handlers.compat.refresh.post", skip_all, err)]
|
||||
#[tracing::instrument(name = "handlers.compat.refresh.post", skip_all)]
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
@@ -92,27 +97,27 @@ pub(crate) async fn post(
|
||||
let token_type = TokenType::check(&input.refresh_token)?;
|
||||
|
||||
if token_type != TokenType::CompatRefreshToken {
|
||||
return Err(RouteError::InvalidToken);
|
||||
return Err(RouteError::InvalidTokenType(token_type));
|
||||
}
|
||||
|
||||
let refresh_token = repo
|
||||
.compat_refresh_token()
|
||||
.find_by_token(&input.refresh_token)
|
||||
.await?
|
||||
.ok_or(RouteError::InvalidToken)?;
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
if !refresh_token.is_valid() {
|
||||
return Err(RouteError::RefreshTokenConsumed);
|
||||
return Err(RouteError::RefreshTokenConsumed(refresh_token.id));
|
||||
}
|
||||
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.lookup(refresh_token.session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::UnknownSession)?;
|
||||
.ok_or(RouteError::UnknownSession(refresh_token.session_id))?;
|
||||
|
||||
if !session.is_valid() {
|
||||
return Err(RouteError::InvalidSession);
|
||||
return Err(RouteError::InvalidSession(refresh_token.session_id));
|
||||
}
|
||||
|
||||
activity_tracker
|
||||
|
||||
@@ -552,7 +552,7 @@ impl UserMutations {
|
||||
let user = repo.user().lock(&state.clock(), user).await?;
|
||||
|
||||
if deactivate {
|
||||
info!("Scheduling deactivation of user {}", user.id);
|
||||
info!(%user.id, "Scheduling deactivation of user");
|
||||
repo.queue_job()
|
||||
.schedule_job(&mut rng, &clock, DeactivateUserJob::new(&user, deactivate))
|
||||
.await?;
|
||||
|
||||
@@ -13,7 +13,7 @@ use hyper::StatusCode;
|
||||
use mas_axum_utils::{
|
||||
cookies::CookieJar,
|
||||
csrf::{CsrfExt, ProtectedForm},
|
||||
sentry::SentryEventID,
|
||||
record_error,
|
||||
};
|
||||
use mas_data_model::AuthorizationGrantStage;
|
||||
use mas_keystore::Keystore;
|
||||
@@ -46,11 +46,11 @@ pub enum RouteError {
|
||||
#[error("Authorization grant not found")]
|
||||
GrantNotFound,
|
||||
|
||||
#[error("Authorization grant already used")]
|
||||
GrantNotPending,
|
||||
#[error("Authorization grant {0} already used")]
|
||||
GrantNotPending(Ulid),
|
||||
|
||||
#[error("Failed to load client")]
|
||||
NoSuchClient,
|
||||
#[error("Failed to load client {0}")]
|
||||
NoSuchClient(Ulid),
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(mas_templates::TemplateError);
|
||||
@@ -64,10 +64,10 @@ impl_from_error_for_route!(super::callback::CallbackDestinationError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_) | Self::NoSuchClient(_));
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
SentryEventID::from(event_id),
|
||||
sentry_event_id,
|
||||
self.to_string(),
|
||||
)
|
||||
.into_response()
|
||||
@@ -78,7 +78,6 @@ impl IntoResponse for RouteError {
|
||||
name = "handlers.oauth2.authorization.consent.get",
|
||||
fields(grant.id = %grant_id),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
@@ -118,10 +117,10 @@ pub(crate) async fn get(
|
||||
.oauth2_client()
|
||||
.lookup(grant.client_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NoSuchClient)?;
|
||||
.ok_or(RouteError::NoSuchClient(grant.client_id))?;
|
||||
|
||||
if !matches!(grant.stage, AuthorizationGrantStage::Pending) {
|
||||
return Err(RouteError::GrantNotPending);
|
||||
return Err(RouteError::GrantNotPending(grant.id));
|
||||
}
|
||||
|
||||
let Some(session) = maybe_session else {
|
||||
@@ -172,7 +171,6 @@ pub(crate) async fn get(
|
||||
name = "handlers.oauth2.authorization.consent.post",
|
||||
fields(grant.id = %grant_id),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
@@ -229,7 +227,11 @@ pub(crate) async fn post(
|
||||
.oauth2_client()
|
||||
.lookup(grant.client_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NoSuchClient)?;
|
||||
.ok_or(RouteError::NoSuchClient(grant.client_id))?;
|
||||
|
||||
if !matches!(grant.stage, AuthorizationGrantStage::Pending) {
|
||||
return Err(RouteError::GrantNotPending(grant.id));
|
||||
}
|
||||
|
||||
let res = policy
|
||||
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
|
||||
|
||||
@@ -9,7 +9,7 @@ use axum::{
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::{SessionInfoExt, cookies::CookieJar, sentry::SentryEventID};
|
||||
use mas_axum_utils::{SessionInfoExt, cookies::CookieJar, record_error};
|
||||
use mas_data_model::{AuthorizationCode, Pkce};
|
||||
use mas_router::{PostAuthAction, UrlBuilder};
|
||||
use mas_storage::{
|
||||
@@ -53,7 +53,7 @@ pub enum RouteError {
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
// TODO: better error pages
|
||||
let response = match self {
|
||||
RouteError::Internal(e) => {
|
||||
@@ -75,7 +75,7 @@ impl IntoResponse for RouteError {
|
||||
.into_response(),
|
||||
};
|
||||
|
||||
(SentryEventID::from(event_id), response).into_response()
|
||||
(sentry_event_id, response).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,7 +122,6 @@ fn resolve_response_mode(
|
||||
name = "handlers.oauth2.authorization.get",
|
||||
fields(client.id = %params.auth.client_id),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub(crate) async fn get(
|
||||
@@ -319,7 +318,7 @@ pub(crate) async fn get(
|
||||
let response = match res {
|
||||
Ok(r) => r,
|
||||
Err(err) => {
|
||||
tracing::error!(%err);
|
||||
tracing::error!(message = &err as &dyn std::error::Error);
|
||||
callback_destination.go(
|
||||
&templates,
|
||||
&locale,
|
||||
|
||||
@@ -11,7 +11,7 @@ use headers::{CacheControl, Pragma};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::{
|
||||
client_authorization::{ClientAuthorization, CredentialsVerificationError},
|
||||
sentry::SentryEventID,
|
||||
record_error,
|
||||
};
|
||||
use mas_data_model::UserAgent;
|
||||
use mas_keystore::Encrypter;
|
||||
@@ -24,6 +24,7 @@ use oauth2_types::{
|
||||
};
|
||||
use rand::distributions::{Alphanumeric, DistString};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{BoundActivityTracker, impl_from_error_for_route};
|
||||
|
||||
@@ -35,35 +36,46 @@ pub(crate) enum RouteError {
|
||||
#[error("client not found")]
|
||||
ClientNotFound,
|
||||
|
||||
#[error("client not allowed")]
|
||||
ClientNotAllowed,
|
||||
#[error("client {0} is not allowed to use the device code grant")]
|
||||
ClientNotAllowed(Ulid),
|
||||
|
||||
#[error("could not verify client credentials")]
|
||||
ClientCredentialsVerification(#[from] CredentialsVerificationError),
|
||||
#[error("invalid client credentials for client {client_id}")]
|
||||
InvalidClientCredentials {
|
||||
client_id: Ulid,
|
||||
#[source]
|
||||
source: CredentialsVerificationError,
|
||||
},
|
||||
|
||||
#[error("could not verify client credentials for client {client_id}")]
|
||||
ClientCredentialsVerification {
|
||||
client_id: Ulid,
|
||||
#[source]
|
||||
source: CredentialsVerificationError,
|
||||
},
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
|
||||
let response = match self {
|
||||
Self::Internal(_) => (
|
||||
Self::Internal(_) | Self::ClientCredentialsVerification { .. } => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ClientError::from(ClientErrorCode::ServerError)),
|
||||
),
|
||||
Self::ClientNotFound | Self::ClientCredentialsVerification(_) => (
|
||||
Self::ClientNotFound | Self::InvalidClientCredentials { .. } => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(ClientError::from(ClientErrorCode::InvalidClient)),
|
||||
),
|
||||
Self::ClientNotAllowed => (
|
||||
Self::ClientNotAllowed(_) => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
|
||||
),
|
||||
};
|
||||
|
||||
(SentryEventID::from(event_id), response).into_response()
|
||||
(sentry_event_id, response).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +83,6 @@ impl IntoResponse for RouteError {
|
||||
name = "handlers.oauth2.device.request.post",
|
||||
fields(client.id = client_authorization.client_id()),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
@@ -94,15 +105,28 @@ pub(crate) async fn post(
|
||||
let method = client
|
||||
.token_endpoint_auth_method
|
||||
.as_ref()
|
||||
.ok_or(RouteError::ClientNotAllowed)?;
|
||||
.ok_or(RouteError::ClientNotAllowed(client.id))?;
|
||||
|
||||
client_authorization
|
||||
.credentials
|
||||
.verify(&http_client, &encrypter, method, &client)
|
||||
.await?;
|
||||
.await
|
||||
.map_err(|err| {
|
||||
if err.is_internal() {
|
||||
RouteError::ClientCredentialsVerification {
|
||||
client_id: client.id,
|
||||
source: err,
|
||||
}
|
||||
} else {
|
||||
RouteError::InvalidClientCredentials {
|
||||
client_id: client.id,
|
||||
source: err,
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
if !client.grant_types.contains(&GrantType::DeviceCode) {
|
||||
return Err(RouteError::ClientNotAllowed);
|
||||
return Err(RouteError::ClientNotAllowed(client.id));
|
||||
}
|
||||
|
||||
let scope = client_authorization
|
||||
|
||||
@@ -41,6 +41,7 @@ pub(crate) struct ConsentForm {
|
||||
action: Action,
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handlers.oauth2.device.consent.get", skip_all)]
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
@@ -136,6 +137,7 @@ pub(crate) async fn get(
|
||||
Ok((cookie_jar, Html(rendered)).into_response())
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handlers.oauth2.device.consent.post", skip_all)]
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
|
||||
@@ -24,7 +24,7 @@ pub struct Params {
|
||||
code: Option<String>,
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handlers.oauth2.device.link.get", skip_all, err)]
|
||||
#[tracing::instrument(name = "handlers.oauth2.device.link.get", skip_all)]
|
||||
pub(crate) async fn get(
|
||||
clock: BoxClock,
|
||||
mut repo: BoxRepository,
|
||||
|
||||
@@ -10,7 +10,7 @@ use axum::{Json, extract::State, http::HeaderValue, response::IntoResponse};
|
||||
use hyper::{HeaderMap, StatusCode};
|
||||
use mas_axum_utils::{
|
||||
client_authorization::{ClientAuthorization, CredentialsVerificationError},
|
||||
sentry::SentryEventID,
|
||||
record_error,
|
||||
};
|
||||
use mas_data_model::{Device, TokenFormatError, TokenType};
|
||||
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
|
||||
@@ -28,6 +28,7 @@ use oauth2_types::{
|
||||
};
|
||||
use opentelemetry::{Key, KeyValue, metrics::Counter};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{ActivityTracker, METER, impl_from_error_for_route};
|
||||
|
||||
@@ -53,8 +54,8 @@ pub enum RouteError {
|
||||
ClientNotFound,
|
||||
|
||||
/// The client is not allowed to introspect.
|
||||
#[error("client is not allowed to introspect")]
|
||||
NotAllowed,
|
||||
#[error("client {0} is not allowed to introspect")]
|
||||
NotAllowed(Ulid),
|
||||
|
||||
/// The token type is not the one expected.
|
||||
#[error("unexpected token type")]
|
||||
@@ -73,30 +74,30 @@ pub enum RouteError {
|
||||
InvalidToken(TokenType),
|
||||
|
||||
/// The OAuth session is not valid.
|
||||
#[error("invalid oauth session")]
|
||||
InvalidOAuthSession,
|
||||
#[error("invalid oauth session {0}")]
|
||||
InvalidOAuthSession(Ulid),
|
||||
|
||||
/// The OAuth session could not be found in the database.
|
||||
#[error("unknown oauth session")]
|
||||
CantLoadOAuthSession,
|
||||
#[error("unknown oauth session {0}")]
|
||||
CantLoadOAuthSession(Ulid),
|
||||
|
||||
/// The compat session is not valid.
|
||||
#[error("invalid compat session")]
|
||||
InvalidCompatSession,
|
||||
#[error("invalid compat session {0}")]
|
||||
InvalidCompatSession(Ulid),
|
||||
|
||||
/// The compat session could not be found in the database.
|
||||
#[error("unknown compat session")]
|
||||
CantLoadCompatSession,
|
||||
#[error("unknown compat session {0}")]
|
||||
CantLoadCompatSession(Ulid),
|
||||
|
||||
/// The Device ID in the compat session can't be encoded as a scope
|
||||
#[error("device ID contains characters that are not allowed in a scope")]
|
||||
CantEncodeDeviceID(#[from] mas_data_model::ToScopeTokenError),
|
||||
|
||||
#[error("invalid user")]
|
||||
InvalidUser,
|
||||
#[error("invalid user {0}")]
|
||||
InvalidUser(Ulid),
|
||||
|
||||
#[error("unknown user")]
|
||||
CantLoadUser,
|
||||
#[error("unknown user {0}")]
|
||||
CantLoadUser(Ulid),
|
||||
|
||||
#[error("bad request")]
|
||||
BadRequest,
|
||||
@@ -107,12 +108,19 @@ pub enum RouteError {
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
let sentry_event_id = record_error!(
|
||||
self,
|
||||
Self::Internal(_)
|
||||
| Self::CantLoadCompatSession(_)
|
||||
| Self::CantLoadOAuthSession(_)
|
||||
| Self::CantLoadUser(_)
|
||||
);
|
||||
|
||||
let response = match self {
|
||||
e @ (Self::Internal(_)
|
||||
| Self::CantLoadCompatSession
|
||||
| Self::CantLoadOAuthSession
|
||||
| Self::CantLoadUser) => (
|
||||
| Self::CantLoadCompatSession(_)
|
||||
| Self::CantLoadOAuthSession(_)
|
||||
| Self::CantLoadUser(_)) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(
|
||||
ClientError::from(ClientErrorCode::ServerError).with_description(e.to_string()),
|
||||
@@ -136,9 +144,9 @@ impl IntoResponse for RouteError {
|
||||
Self::UnknownToken(_)
|
||||
| Self::UnexpectedTokenType
|
||||
| Self::InvalidToken(_)
|
||||
| Self::InvalidUser
|
||||
| Self::InvalidCompatSession
|
||||
| Self::InvalidOAuthSession
|
||||
| Self::InvalidUser(_)
|
||||
| Self::InvalidCompatSession(_)
|
||||
| Self::InvalidOAuthSession(_)
|
||||
| Self::InvalidTokenFormat(_)
|
||||
| Self::CantEncodeDeviceID(_) => {
|
||||
INTROSPECTION_COUNTER.add(1, &[KeyValue::new(ACTIVE.clone(), false)]);
|
||||
@@ -146,11 +154,12 @@ impl IntoResponse for RouteError {
|
||||
Json(INACTIVE).into_response()
|
||||
}
|
||||
|
||||
Self::NotAllowed => (
|
||||
Self::NotAllowed(_) => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(ClientError::from(ClientErrorCode::AccessDenied)),
|
||||
)
|
||||
.into_response(),
|
||||
|
||||
Self::BadRequest => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ClientError::from(ClientErrorCode::InvalidRequest)),
|
||||
@@ -158,7 +167,7 @@ impl IntoResponse for RouteError {
|
||||
.into_response(),
|
||||
};
|
||||
|
||||
(SentryEventID::from(event_id), response).into_response()
|
||||
(sentry_event_id, response).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,7 +197,6 @@ const SYNAPSE_ADMIN_SCOPE: ScopeToken = ScopeToken::from_static("urn:synapse:adm
|
||||
name = "handlers.oauth2.introspection.post",
|
||||
fields(client.id = client_authorization.client_id()),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub(crate) async fn post(
|
||||
@@ -208,7 +216,7 @@ pub(crate) async fn post(
|
||||
|
||||
let method = match &client.token_endpoint_auth_method {
|
||||
None | Some(OAuthClientAuthenticationMethod::None) => {
|
||||
return Err(RouteError::NotAllowed);
|
||||
return Err(RouteError::NotAllowed(client.id));
|
||||
}
|
||||
Some(c) => c,
|
||||
};
|
||||
@@ -259,10 +267,10 @@ pub(crate) async fn post(
|
||||
.oauth2_session()
|
||||
.lookup(access_token.session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::InvalidOAuthSession)?;
|
||||
.ok_or(RouteError::CantLoadOAuthSession(access_token.session_id))?;
|
||||
|
||||
if !session.is_valid() {
|
||||
return Err(RouteError::InvalidOAuthSession);
|
||||
return Err(RouteError::InvalidOAuthSession(session.id));
|
||||
}
|
||||
|
||||
// If this is the first time we're using this token, mark it as used
|
||||
@@ -280,10 +288,10 @@ pub(crate) async fn post(
|
||||
.user()
|
||||
.lookup(user_id)
|
||||
.await?
|
||||
.ok_or(RouteError::CantLoadUser)?;
|
||||
.ok_or(RouteError::CantLoadUser(user_id))?;
|
||||
|
||||
if !user.is_valid() {
|
||||
return Err(RouteError::InvalidUser);
|
||||
return Err(RouteError::InvalidUser(user.id));
|
||||
}
|
||||
|
||||
(Some(user.sub), Some(user.username))
|
||||
@@ -338,10 +346,10 @@ pub(crate) async fn post(
|
||||
.oauth2_session()
|
||||
.lookup(refresh_token.session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::CantLoadOAuthSession)?;
|
||||
.ok_or(RouteError::CantLoadOAuthSession(refresh_token.session_id))?;
|
||||
|
||||
if !session.is_valid() {
|
||||
return Err(RouteError::InvalidOAuthSession);
|
||||
return Err(RouteError::InvalidOAuthSession(session.id));
|
||||
}
|
||||
|
||||
// The session might not have a user on it (for Client Credentials grants for
|
||||
@@ -351,10 +359,10 @@ pub(crate) async fn post(
|
||||
.user()
|
||||
.lookup(user_id)
|
||||
.await?
|
||||
.ok_or(RouteError::CantLoadUser)?;
|
||||
.ok_or(RouteError::CantLoadUser(user_id))?;
|
||||
|
||||
if !user.is_valid() {
|
||||
return Err(RouteError::InvalidUser);
|
||||
return Err(RouteError::InvalidUser(user.id));
|
||||
}
|
||||
|
||||
(Some(user.sub), Some(user.username))
|
||||
@@ -407,20 +415,20 @@ pub(crate) async fn post(
|
||||
.compat_session()
|
||||
.lookup(access_token.session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::CantLoadCompatSession)?;
|
||||
.ok_or(RouteError::CantLoadCompatSession(access_token.session_id))?;
|
||||
|
||||
if !session.is_valid() {
|
||||
return Err(RouteError::InvalidCompatSession);
|
||||
return Err(RouteError::InvalidCompatSession(session.id));
|
||||
}
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(session.user_id)
|
||||
.await?
|
||||
.ok_or(RouteError::CantLoadUser)?;
|
||||
.ok_or(RouteError::CantLoadUser(session.user_id))?;
|
||||
|
||||
if !user.is_valid() {
|
||||
return Err(RouteError::InvalidUser)?;
|
||||
return Err(RouteError::InvalidUser(user.id))?;
|
||||
}
|
||||
|
||||
// Grant the synapse admin scope if the session has the admin flag set.
|
||||
@@ -491,20 +499,20 @@ pub(crate) async fn post(
|
||||
.compat_session()
|
||||
.lookup(refresh_token.session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::CantLoadCompatSession)?;
|
||||
.ok_or(RouteError::CantLoadCompatSession(refresh_token.session_id))?;
|
||||
|
||||
if !session.is_valid() {
|
||||
return Err(RouteError::InvalidCompatSession);
|
||||
return Err(RouteError::InvalidCompatSession(session.id));
|
||||
}
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(session.user_id)
|
||||
.await?
|
||||
.ok_or(RouteError::CantLoadUser)?;
|
||||
.ok_or(RouteError::CantLoadUser(session.user_id))?;
|
||||
|
||||
if !user.is_valid() {
|
||||
return Err(RouteError::InvalidUser)?;
|
||||
return Err(RouteError::InvalidUser(user.id))?;
|
||||
}
|
||||
|
||||
// Grant the synapse admin scope if the session has the admin flag set.
|
||||
|
||||
@@ -9,10 +9,10 @@ use std::sync::LazyLock;
|
||||
use axum::{Json, extract::State, response::IntoResponse};
|
||||
use axum_extra::TypedHeader;
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::sentry::SentryEventID;
|
||||
use mas_axum_utils::record_error;
|
||||
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_policy::{Policy, Violation};
|
||||
use mas_policy::{EvaluationResult, Policy};
|
||||
use mas_storage::{BoxClock, BoxRepository, BoxRng, oauth2::OAuth2ClientRepository};
|
||||
use oauth2_types::{
|
||||
errors::{ClientError, ClientErrorCode},
|
||||
@@ -55,8 +55,8 @@ pub(crate) enum RouteError {
|
||||
#[error("{0} is a public suffix, not a valid domain")]
|
||||
UrlIsPublicSuffix(&'static str),
|
||||
|
||||
#[error("denied by the policy: {0:?}")]
|
||||
PolicyDenied(Vec<Violation>),
|
||||
#[error("client registration denied by the policy: {0}")]
|
||||
PolicyDenied(EvaluationResult),
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
@@ -67,7 +67,7 @@ impl_from_error_for_route!(serde_json::Error);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
|
||||
REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "denied")]);
|
||||
|
||||
@@ -143,15 +143,20 @@ impl IntoResponse for RouteError {
|
||||
// For policy violations, we return an `invalid_client_metadata` error with the details
|
||||
// of the violations in most cases. If a violation includes `redirect_uri` in the
|
||||
// message, we return an `invalid_redirect_uri` error instead.
|
||||
Self::PolicyDenied(violations) => {
|
||||
Self::PolicyDenied(evaluation) => {
|
||||
// TODO: detect them better
|
||||
let code = if violations.iter().any(|v| v.msg.contains("redirect_uri")) {
|
||||
let code = if evaluation
|
||||
.violations
|
||||
.iter()
|
||||
.any(|v| v.msg.contains("redirect_uri"))
|
||||
{
|
||||
ClientErrorCode::InvalidRedirectUri
|
||||
} else {
|
||||
ClientErrorCode::InvalidClientMetadata
|
||||
};
|
||||
|
||||
let collected = &violations
|
||||
let collected = &evaluation
|
||||
.violations
|
||||
.iter()
|
||||
.map(|v| v.msg.clone())
|
||||
.collect::<Vec<String>>();
|
||||
@@ -165,7 +170,7 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
};
|
||||
|
||||
(SentryEventID::from(event_id), response).into_response()
|
||||
(sentry_event_id, response).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -207,7 +212,7 @@ fn localised_url_has_public_suffix(url: &Localized<Url>) -> bool {
|
||||
url.iter().any(|(_lang, url)| host_is_public_suffix(url))
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handlers.oauth2.registration.post", skip_all, err)]
|
||||
#[tracing::instrument(name = "handlers.oauth2.registration.post", skip_all)]
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
@@ -282,7 +287,7 @@ pub(crate) async fn post(
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
return Err(RouteError::PolicyDenied(res.violations));
|
||||
return Err(RouteError::PolicyDenied(res));
|
||||
}
|
||||
|
||||
let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method {
|
||||
|
||||
@@ -8,7 +8,7 @@ use axum::{Json, extract::State, response::IntoResponse};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::{
|
||||
client_authorization::{ClientAuthorization, CredentialsVerificationError},
|
||||
sentry::SentryEventID,
|
||||
record_error,
|
||||
};
|
||||
use mas_data_model::TokenType;
|
||||
use mas_iana::oauth::OAuthTokenTypeHint;
|
||||
@@ -22,6 +22,7 @@ use oauth2_types::{
|
||||
requests::RevocationRequest,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{BoundActivityTracker, impl_from_error_for_route};
|
||||
|
||||
@@ -39,8 +40,19 @@ pub(crate) enum RouteError {
|
||||
#[error("client not allowed")]
|
||||
ClientNotAllowed,
|
||||
|
||||
#[error("could not verify client credentials")]
|
||||
ClientCredentialsVerification(#[from] CredentialsVerificationError),
|
||||
#[error("invalid client credentials for client {client_id}")]
|
||||
InvalidClientCredentials {
|
||||
client_id: Ulid,
|
||||
#[source]
|
||||
source: CredentialsVerificationError,
|
||||
},
|
||||
|
||||
#[error("could not verify client credentials for client {client_id}")]
|
||||
ClientCredentialsVerification {
|
||||
client_id: Ulid,
|
||||
#[source]
|
||||
source: CredentialsVerificationError,
|
||||
},
|
||||
|
||||
#[error("client is unauthorized")]
|
||||
UnauthorizedClient,
|
||||
@@ -54,9 +66,9 @@ pub(crate) enum RouteError {
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let response = match self {
|
||||
Self::Internal(_) => (
|
||||
Self::Internal(_) | Self::ClientCredentialsVerification { .. } => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ClientError::from(ClientErrorCode::ServerError)),
|
||||
)
|
||||
@@ -68,7 +80,7 @@ impl IntoResponse for RouteError {
|
||||
)
|
||||
.into_response(),
|
||||
|
||||
Self::ClientNotFound | Self::ClientCredentialsVerification(_) => (
|
||||
Self::ClientNotFound | Self::InvalidClientCredentials { .. } => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(ClientError::from(ClientErrorCode::InvalidClient)),
|
||||
)
|
||||
@@ -90,7 +102,7 @@ impl IntoResponse for RouteError {
|
||||
Self::UnknownToken => StatusCode::OK.into_response(),
|
||||
};
|
||||
|
||||
(SentryEventID::from(event_id), response).into_response()
|
||||
(sentry_event_id, response).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,7 +118,6 @@ impl From<mas_data_model::TokenFormatError> for RouteError {
|
||||
name = "handlers.oauth2.revoke.post",
|
||||
fields(client.id = client_authorization.client_id()),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
pub(crate) async fn post(
|
||||
clock: BoxClock,
|
||||
@@ -131,7 +142,20 @@ pub(crate) async fn post(
|
||||
client_authorization
|
||||
.credentials
|
||||
.verify(&http_client, &encrypter, method, &client)
|
||||
.await?;
|
||||
.await
|
||||
.map_err(|err| {
|
||||
if err.is_internal() {
|
||||
RouteError::ClientCredentialsVerification {
|
||||
client_id: client.id,
|
||||
source: err,
|
||||
}
|
||||
} else {
|
||||
RouteError::InvalidClientCredentials {
|
||||
client_id: client.id,
|
||||
source: err,
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
let Some(form) = client_authorization.form else {
|
||||
return Err(RouteError::BadRequest);
|
||||
|
||||
@@ -13,7 +13,7 @@ use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::{
|
||||
client_authorization::{ClientAuthorization, CredentialsVerificationError},
|
||||
sentry::SentryEventID,
|
||||
record_error,
|
||||
};
|
||||
use mas_data_model::{
|
||||
AuthorizationGrantStage, Client, Device, DeviceCodeGrantState, SiteConfig, TokenType, UserAgent,
|
||||
@@ -42,7 +42,7 @@ use oauth2_types::{
|
||||
};
|
||||
use opentelemetry::{Key, KeyValue, metrics::Counter};
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info};
|
||||
use tracing::{debug, info, warn};
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::{generate_id_token, generate_token_pair};
|
||||
@@ -72,17 +72,28 @@ pub(crate) enum RouteError {
|
||||
#[error("client not found")]
|
||||
ClientNotFound,
|
||||
|
||||
#[error("client not allowed")]
|
||||
ClientNotAllowed,
|
||||
#[error("client not allowed to use the token endpoint: {0}")]
|
||||
ClientNotAllowed(Ulid),
|
||||
|
||||
#[error("could not verify client credentials")]
|
||||
ClientCredentialsVerification(#[from] CredentialsVerificationError),
|
||||
#[error("invalid client credentials for client {client_id}")]
|
||||
InvalidClientCredentials {
|
||||
client_id: Ulid,
|
||||
#[source]
|
||||
source: CredentialsVerificationError,
|
||||
},
|
||||
|
||||
#[error("could not verify client credentials for client {client_id}")]
|
||||
ClientCredentialsVerification {
|
||||
client_id: Ulid,
|
||||
#[source]
|
||||
source: CredentialsVerificationError,
|
||||
},
|
||||
|
||||
#[error("grant not found")]
|
||||
GrantNotFound,
|
||||
|
||||
#[error("invalid grant")]
|
||||
InvalidGrant,
|
||||
#[error("invalid grant {0}")]
|
||||
InvalidGrant(Ulid),
|
||||
|
||||
#[error("refresh token not found")]
|
||||
RefreshTokenNotFound,
|
||||
@@ -96,20 +107,23 @@ pub(crate) enum RouteError {
|
||||
#[error("client id mismatch: expected {expected}, got {actual}")]
|
||||
ClientIDMismatch { expected: Ulid, actual: Ulid },
|
||||
|
||||
#[error("policy denied the request")]
|
||||
DeniedByPolicy(Vec<mas_policy::Violation>),
|
||||
#[error("policy denied the request: {0}")]
|
||||
DeniedByPolicy(mas_policy::EvaluationResult),
|
||||
|
||||
#[error("unsupported grant type")]
|
||||
UnsupportedGrantType,
|
||||
|
||||
#[error("unauthorized client")]
|
||||
UnauthorizedClient,
|
||||
#[error("client {0} is not authorized to use this grant type")]
|
||||
UnauthorizedClient(Ulid),
|
||||
|
||||
#[error("failed to load browser session")]
|
||||
NoSuchBrowserSession,
|
||||
#[error("unexpected client {was} (expected {expected})")]
|
||||
UnexptectedClient { was: Ulid, expected: Ulid },
|
||||
|
||||
#[error("failed to load oauth session")]
|
||||
NoSuchOAuthSession,
|
||||
#[error("failed to load browser session {0}")]
|
||||
NoSuchBrowserSession(Ulid),
|
||||
|
||||
#[error("failed to load oauth session {0}")]
|
||||
NoSuchOAuthSession(Ulid),
|
||||
|
||||
#[error(
|
||||
"failed to load the next refresh token ({next:?}) from the previous one ({previous:?})"
|
||||
@@ -145,14 +159,25 @@ pub(crate) enum RouteError {
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
let sentry_event_id = record_error!(
|
||||
self,
|
||||
Self::Internal(_)
|
||||
| Self::ClientCredentialsVerification { .. }
|
||||
| Self::NoSuchBrowserSession(_)
|
||||
| Self::NoSuchOAuthSession(_)
|
||||
| Self::ProvisionDeviceFailed(_)
|
||||
| Self::NoSuchNextRefreshToken { .. }
|
||||
| Self::NoSuchNextAccessToken { .. }
|
||||
| Self::NoAccessTokenOnRefreshToken { .. }
|
||||
);
|
||||
|
||||
TOKEN_REQUEST_COUNTER.add(1, &[KeyValue::new(RESULT, "error")]);
|
||||
|
||||
let response = match self {
|
||||
Self::Internal(_)
|
||||
| Self::NoSuchBrowserSession
|
||||
| Self::NoSuchOAuthSession
|
||||
| Self::ClientCredentialsVerification { .. }
|
||||
| Self::NoSuchBrowserSession(_)
|
||||
| Self::NoSuchOAuthSession(_)
|
||||
| Self::ProvisionDeviceFailed(_)
|
||||
| Self::NoSuchNextRefreshToken { .. }
|
||||
| Self::NoSuchNextAccessToken { .. }
|
||||
@@ -160,10 +185,12 @@ impl IntoResponse for RouteError {
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ClientError::from(ClientErrorCode::ServerError)),
|
||||
),
|
||||
|
||||
Self::BadRequest => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ClientError::from(ClientErrorCode::InvalidRequest)),
|
||||
),
|
||||
|
||||
Self::PkceVerification(err) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(
|
||||
@@ -171,19 +198,25 @@ impl IntoResponse for RouteError {
|
||||
.with_description(format!("PKCE verification failed: {err}")),
|
||||
),
|
||||
),
|
||||
Self::ClientNotFound | Self::ClientCredentialsVerification(_) => (
|
||||
|
||||
Self::ClientNotFound | Self::InvalidClientCredentials { .. } => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(ClientError::from(ClientErrorCode::InvalidClient)),
|
||||
),
|
||||
Self::ClientNotAllowed | Self::UnauthorizedClient => (
|
||||
|
||||
Self::ClientNotAllowed(_)
|
||||
| Self::UnauthorizedClient(_)
|
||||
| Self::UnexptectedClient { .. } => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
|
||||
),
|
||||
Self::DeniedByPolicy(violations) => (
|
||||
|
||||
Self::DeniedByPolicy(evaluation) => (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(
|
||||
ClientError::from(ClientErrorCode::InvalidScope).with_description(
|
||||
violations
|
||||
evaluation
|
||||
.violations
|
||||
.into_iter()
|
||||
.map(|violation| violation.msg)
|
||||
.collect::<Vec<_>>()
|
||||
@@ -191,19 +224,23 @@ impl IntoResponse for RouteError {
|
||||
),
|
||||
),
|
||||
),
|
||||
|
||||
Self::DeviceCodeRejected => (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ClientError::from(ClientErrorCode::AccessDenied)),
|
||||
),
|
||||
|
||||
Self::DeviceCodeExpired => (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ClientError::from(ClientErrorCode::ExpiredToken)),
|
||||
),
|
||||
|
||||
Self::DeviceCodePending => (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ClientError::from(ClientErrorCode::AuthorizationPending)),
|
||||
),
|
||||
Self::InvalidGrant
|
||||
|
||||
Self::InvalidGrant(_)
|
||||
| Self::DeviceCodeExchanged
|
||||
| Self::RefreshTokenNotFound
|
||||
| Self::RefreshTokenInvalid(_)
|
||||
@@ -213,13 +250,14 @@ impl IntoResponse for RouteError {
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ClientError::from(ClientErrorCode::InvalidGrant)),
|
||||
),
|
||||
|
||||
Self::UnsupportedGrantType => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ClientError::from(ClientErrorCode::UnsupportedGrantType)),
|
||||
),
|
||||
};
|
||||
|
||||
(SentryEventID::from(event_id), response).into_response()
|
||||
(sentry_event_id, response).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -231,7 +269,6 @@ impl_from_error_for_route!(super::IdTokenSignatureError);
|
||||
name = "handlers.oauth2.token.post",
|
||||
fields(client.id = client_authorization.client_id()),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
@@ -258,12 +295,27 @@ pub(crate) async fn post(
|
||||
let method = client
|
||||
.token_endpoint_auth_method
|
||||
.as_ref()
|
||||
.ok_or(RouteError::ClientNotAllowed)?;
|
||||
.ok_or(RouteError::ClientNotAllowed(client.id))?;
|
||||
|
||||
client_authorization
|
||||
.credentials
|
||||
.verify(&http_client, &encrypter, method, &client)
|
||||
.await?;
|
||||
.await
|
||||
.map_err(|err| {
|
||||
// Classify the error differntly, depending on whether it's an 'internal' error,
|
||||
// or just because the client presented invalid credentials.
|
||||
if err.is_internal() {
|
||||
RouteError::ClientCredentialsVerification {
|
||||
client_id: client.id,
|
||||
source: err,
|
||||
}
|
||||
} else {
|
||||
RouteError::InvalidClientCredentials {
|
||||
client_id: client.id,
|
||||
source: err,
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
let form = client_authorization.form.ok_or(RouteError::BadRequest)?;
|
||||
|
||||
@@ -367,7 +419,7 @@ async fn authorization_code_grant(
|
||||
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
|
||||
// Check that the client is allowed to use this grant type
|
||||
if !client.grant_types.contains(&GrantType::AuthorizationCode) {
|
||||
return Err(RouteError::UnauthorizedClient);
|
||||
return Err(RouteError::UnauthorizedClient(client.id));
|
||||
}
|
||||
|
||||
let authz_grant = repo
|
||||
@@ -381,40 +433,43 @@ async fn authorization_code_grant(
|
||||
let session_id = match authz_grant.stage {
|
||||
AuthorizationGrantStage::Cancelled { cancelled_at } => {
|
||||
debug!(%cancelled_at, "Authorization grant was cancelled");
|
||||
return Err(RouteError::InvalidGrant);
|
||||
return Err(RouteError::InvalidGrant(authz_grant.id));
|
||||
}
|
||||
AuthorizationGrantStage::Exchanged {
|
||||
exchanged_at,
|
||||
fulfilled_at,
|
||||
session_id,
|
||||
} => {
|
||||
debug!(%exchanged_at, %fulfilled_at, "Authorization code was already exchanged");
|
||||
warn!(%exchanged_at, %fulfilled_at, "Authorization code was already exchanged");
|
||||
|
||||
// Ending the session if the token was already exchanged more than 20s ago
|
||||
if now - exchanged_at > Duration::microseconds(20 * 1000 * 1000) {
|
||||
debug!("Ending potentially compromised session");
|
||||
warn!(oauth_session.id = %session_id, "Ending potentially compromised session");
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.lookup(session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NoSuchOAuthSession)?;
|
||||
.ok_or(RouteError::NoSuchOAuthSession(session_id))?;
|
||||
|
||||
//if !session.is_finished() {
|
||||
repo.oauth2_session().finish(clock, session).await?;
|
||||
repo.save().await?;
|
||||
//}
|
||||
}
|
||||
|
||||
return Err(RouteError::InvalidGrant);
|
||||
return Err(RouteError::InvalidGrant(authz_grant.id));
|
||||
}
|
||||
AuthorizationGrantStage::Pending => {
|
||||
debug!("Authorization grant has not been fulfilled yet");
|
||||
return Err(RouteError::InvalidGrant);
|
||||
warn!("Authorization grant has not been fulfilled yet");
|
||||
return Err(RouteError::InvalidGrant(authz_grant.id));
|
||||
}
|
||||
AuthorizationGrantStage::Fulfilled {
|
||||
session_id,
|
||||
fulfilled_at,
|
||||
} => {
|
||||
if now - fulfilled_at > Duration::microseconds(10 * 60 * 1000 * 1000) {
|
||||
debug!("Code exchange took more than 10 minutes");
|
||||
return Err(RouteError::InvalidGrant);
|
||||
warn!("Code exchange took more than 10 minutes");
|
||||
return Err(RouteError::InvalidGrant(authz_grant.id));
|
||||
}
|
||||
|
||||
session_id
|
||||
@@ -425,7 +480,7 @@ async fn authorization_code_grant(
|
||||
.oauth2_session()
|
||||
.lookup(session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NoSuchOAuthSession)?;
|
||||
.ok_or(RouteError::NoSuchOAuthSession(session_id))?;
|
||||
|
||||
if let Some(user_agent) = user_agent {
|
||||
session = repo
|
||||
@@ -435,10 +490,16 @@ async fn authorization_code_grant(
|
||||
}
|
||||
|
||||
// This should never happen, since we looked up in the database using the code
|
||||
let code = authz_grant.code.as_ref().ok_or(RouteError::InvalidGrant)?;
|
||||
let code = authz_grant
|
||||
.code
|
||||
.as_ref()
|
||||
.ok_or(RouteError::InvalidGrant(authz_grant.id))?;
|
||||
|
||||
if client.id != session.client_id {
|
||||
return Err(RouteError::UnauthorizedClient);
|
||||
return Err(RouteError::UnexptectedClient {
|
||||
was: client.id,
|
||||
expected: session.client_id,
|
||||
});
|
||||
}
|
||||
|
||||
match (code.pkce.as_ref(), grant.code_verifier.as_ref()) {
|
||||
@@ -453,14 +514,14 @@ async fn authorization_code_grant(
|
||||
|
||||
let Some(user_session_id) = session.user_session_id else {
|
||||
tracing::warn!("No user session associated with this OAuth2 session");
|
||||
return Err(RouteError::InvalidGrant);
|
||||
return Err(RouteError::InvalidGrant(authz_grant.id));
|
||||
};
|
||||
|
||||
let browser_session = repo
|
||||
.browser_session()
|
||||
.lookup(user_session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NoSuchBrowserSession)?;
|
||||
.ok_or(RouteError::NoSuchBrowserSession(user_session_id))?;
|
||||
|
||||
let last_authentication = repo
|
||||
.browser_session()
|
||||
@@ -539,7 +600,7 @@ async fn refresh_token_grant(
|
||||
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
|
||||
// Check that the client is allowed to use this grant type
|
||||
if !client.grant_types.contains(&GrantType::RefreshToken) {
|
||||
return Err(RouteError::UnauthorizedClient);
|
||||
return Err(RouteError::UnauthorizedClient(client.id));
|
||||
}
|
||||
|
||||
let refresh_token = repo
|
||||
@@ -552,7 +613,7 @@ async fn refresh_token_grant(
|
||||
.oauth2_session()
|
||||
.lookup(refresh_token.session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NoSuchOAuthSession)?;
|
||||
.ok_or(RouteError::NoSuchOAuthSession(refresh_token.session_id))?;
|
||||
|
||||
// Let's for now record the user agent on each refresh, that should be
|
||||
// responsive enough and not too much of a burden on the database.
|
||||
@@ -692,7 +753,7 @@ async fn client_credentials_grant(
|
||||
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
|
||||
// Check that the client is allowed to use this grant type
|
||||
if !client.grant_types.contains(&GrantType::ClientCredentials) {
|
||||
return Err(RouteError::UnauthorizedClient);
|
||||
return Err(RouteError::UnauthorizedClient(client.id));
|
||||
}
|
||||
|
||||
// Default to an empty scope if none is provided
|
||||
@@ -715,7 +776,7 @@ async fn client_credentials_grant(
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
return Err(RouteError::DeniedByPolicy(res.violations));
|
||||
return Err(RouteError::DeniedByPolicy(res));
|
||||
}
|
||||
|
||||
// Start the session
|
||||
@@ -771,7 +832,7 @@ async fn device_code_grant(
|
||||
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
|
||||
// Check that the client is allowed to use this grant type
|
||||
if !client.grant_types.contains(&GrantType::DeviceCode) {
|
||||
return Err(RouteError::UnauthorizedClient);
|
||||
return Err(RouteError::UnauthorizedClient(client.id));
|
||||
}
|
||||
|
||||
let grant = repo
|
||||
@@ -804,14 +865,14 @@ async fn device_code_grant(
|
||||
}
|
||||
DeviceCodeGrantState::Fulfilled {
|
||||
browser_session_id, ..
|
||||
} => browser_session_id,
|
||||
} => *browser_session_id,
|
||||
};
|
||||
|
||||
let browser_session = repo
|
||||
.browser_session()
|
||||
.lookup(*browser_session_id)
|
||||
.lookup(browser_session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NoSuchBrowserSession)?;
|
||||
.ok_or(RouteError::NoSuchBrowserSession(browser_session_id))?;
|
||||
|
||||
// Start the session
|
||||
let mut session = repo
|
||||
|
||||
@@ -12,7 +12,7 @@ use axum::{
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::{
|
||||
jwt::JwtResponse,
|
||||
sentry::SentryEventID,
|
||||
record_error,
|
||||
user_authorization::{AuthorizationVerificationError, UserAuthorization},
|
||||
};
|
||||
use mas_jose::{
|
||||
@@ -25,6 +25,7 @@ use mas_storage::{BoxClock, BoxRepository, BoxRng, oauth2::OAuth2ClientRepositor
|
||||
use serde::Serialize;
|
||||
use serde_with::skip_serializing_none;
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{BoundActivityTracker, impl_from_error_for_route};
|
||||
|
||||
@@ -59,11 +60,11 @@ pub enum RouteError {
|
||||
#[error("no suitable key found for signing")]
|
||||
InvalidSigningKey,
|
||||
|
||||
#[error("failed to load client")]
|
||||
NoSuchClient,
|
||||
#[error("failed to load client {0}")]
|
||||
NoSuchClient(Ulid),
|
||||
|
||||
#[error("failed to load user")]
|
||||
NoSuchUser,
|
||||
#[error("failed to load user {0}")]
|
||||
NoSuchUser(Ulid),
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
@@ -72,9 +73,18 @@ impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
let sentry_event_id = record_error!(
|
||||
self,
|
||||
Self::Internal(_)
|
||||
| Self::InvalidSigningKey
|
||||
| Self::NoSuchClient(_)
|
||||
| Self::NoSuchUser(_)
|
||||
);
|
||||
let response = match self {
|
||||
Self::Internal(_) | Self::InvalidSigningKey | Self::NoSuchClient | Self::NoSuchUser => {
|
||||
Self::Internal(_)
|
||||
| Self::InvalidSigningKey
|
||||
| Self::NoSuchClient(_)
|
||||
| Self::NoSuchUser(_) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
|
||||
}
|
||||
Self::AuthorizationVerificationError(_) | Self::Unauthorized => {
|
||||
@@ -82,11 +92,11 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
};
|
||||
|
||||
(SentryEventID::from(event_id), response).into_response()
|
||||
(sentry_event_id, response).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handlers.oauth2.userinfo.get", skip_all, err)]
|
||||
#[tracing::instrument(name = "handlers.oauth2.userinfo.get", skip_all)]
|
||||
pub async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
@@ -116,7 +126,7 @@ pub async fn get(
|
||||
.user()
|
||||
.lookup(user_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NoSuchUser)?;
|
||||
.ok_or(RouteError::NoSuchUser(user_id))?;
|
||||
|
||||
let user_info = UserInfo {
|
||||
sub: user.sub.clone(),
|
||||
@@ -127,7 +137,7 @@ pub async fn get(
|
||||
.oauth2_client()
|
||||
.lookup(session.client_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NoSuchClient)?;
|
||||
.ok_or(RouteError::NoSuchClient(session.client_id))?;
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ use axum::{
|
||||
response::{IntoResponse, Redirect},
|
||||
};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::{cookies::CookieJar, sentry::SentryEventID};
|
||||
use mas_axum_utils::{cookies::CookieJar, record_error};
|
||||
use mas_data_model::UpstreamOAuthProvider;
|
||||
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
|
||||
use mas_router::UrlBuilder;
|
||||
@@ -41,13 +41,13 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let response = match self {
|
||||
Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
|
||||
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
|
||||
};
|
||||
|
||||
(SentryEventID::from(event_id), response).into_response()
|
||||
(sentry_event_id, response).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,7 +55,6 @@ impl IntoResponse for RouteError {
|
||||
name = "handlers.upstream_oauth2.authorize.get",
|
||||
fields(upstream_oauth_provider.id = %provider_id),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use mas_context::LogContext;
|
||||
use mas_data_model::{
|
||||
UpstreamOAuthProvider, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderPkceMode,
|
||||
};
|
||||
@@ -164,7 +165,7 @@ impl MetadataCache {
|
||||
///
|
||||
/// This spawns a background task that will refresh the cache at the given
|
||||
/// interval.
|
||||
#[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all, err)]
|
||||
#[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all)]
|
||||
pub async fn warm_up_and_run<R: RepositoryAccess>(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
@@ -197,12 +198,14 @@ impl MetadataCache {
|
||||
loop {
|
||||
// Re-fetch the known metadata at the given interval
|
||||
tokio::time::sleep(interval).await;
|
||||
cache.refresh_all(&client).await;
|
||||
LogContext::new("metadata-cache-refresh")
|
||||
.run(|| cache.refresh_all(&client))
|
||||
.await;
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all, err)]
|
||||
#[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all)]
|
||||
async fn fetch(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
@@ -234,7 +237,7 @@ impl MetadataCache {
|
||||
}
|
||||
|
||||
/// Get the metadata for the given issuer.
|
||||
#[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all, err)]
|
||||
#[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all)]
|
||||
pub async fn get(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
|
||||
@@ -13,7 +13,7 @@ use axum::{
|
||||
response::{Html, IntoResponse, Response},
|
||||
};
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::{cookies::CookieJar, sentry::SentryEventID};
|
||||
use mas_axum_utils::{cookies::CookieJar, record_error};
|
||||
use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderResponseMode};
|
||||
use mas_jose::claims::TokenHash;
|
||||
use mas_keystore::{Encrypter, Keystore};
|
||||
@@ -153,7 +153,7 @@ impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
let sentry_event_id = record_error!(self, Self::Internal(_));
|
||||
let response = match self {
|
||||
Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
|
||||
Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
|
||||
@@ -161,7 +161,7 @@ impl IntoResponse for RouteError {
|
||||
e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
|
||||
};
|
||||
|
||||
(SentryEventID::from(event_id), response).into_response()
|
||||
(sentry_event_id, response).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,7 +169,6 @@ impl IntoResponse for RouteError {
|
||||
name = "handlers.upstream_oauth2.callback.handler",
|
||||
fields(upstream_oauth_provider.id = %provider_id),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
|
||||
pub(crate) async fn handler(
|
||||
|
||||
@@ -17,7 +17,7 @@ use mas_axum_utils::{
|
||||
FancyError, SessionInfoExt,
|
||||
cookies::CookieJar,
|
||||
csrf::{CsrfExt, ProtectedForm},
|
||||
sentry::SentryEventID,
|
||||
record_error,
|
||||
};
|
||||
use mas_data_model::UserAgent;
|
||||
use mas_jose::jwt::Jwt;
|
||||
@@ -77,16 +77,16 @@ pub(crate) enum RouteError {
|
||||
LinkNotFound,
|
||||
|
||||
/// Couldn't find the session on the link
|
||||
#[error("Session not found")]
|
||||
SessionNotFound,
|
||||
#[error("Session {0} not found")]
|
||||
SessionNotFound(Ulid),
|
||||
|
||||
/// Couldn't find the user
|
||||
#[error("User not found")]
|
||||
UserNotFound,
|
||||
#[error("User {0} not found")]
|
||||
UserNotFound(Ulid),
|
||||
|
||||
/// Couldn't find upstream provider
|
||||
#[error("Upstream provider not found")]
|
||||
ProviderNotFound,
|
||||
#[error("Upstream provider {0} not found")]
|
||||
ProviderNotFound(Ulid),
|
||||
|
||||
/// Required attribute rendered to an empty string
|
||||
#[error("Template {template:?} rendered to an empty string")]
|
||||
@@ -104,8 +104,8 @@ pub(crate) enum RouteError {
|
||||
},
|
||||
|
||||
/// Session was already consumed
|
||||
#[error("Session already consumed")]
|
||||
SessionConsumed,
|
||||
#[error("Session {0} already consumed")]
|
||||
SessionConsumed(Ulid),
|
||||
|
||||
#[error("Missing session cookie")]
|
||||
MissingCookie,
|
||||
@@ -129,14 +129,23 @@ impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
let sentry_event_id = record_error!(
|
||||
self,
|
||||
Self::Internal(_)
|
||||
| Self::RequiredAttributeEmpty { .. }
|
||||
| Self::RequiredAttributeRender { .. }
|
||||
| Self::SessionNotFound(_)
|
||||
| Self::ProviderNotFound(_)
|
||||
| Self::UserNotFound(_)
|
||||
| Self::HomeserverConnection(_)
|
||||
);
|
||||
let response = match self {
|
||||
Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
|
||||
Self::Internal(e) => FancyError::from(e).into_response(),
|
||||
e => FancyError::from(e).into_response(),
|
||||
};
|
||||
|
||||
(SentryEventID::from(event_id), response).into_response()
|
||||
(sentry_event_id, response).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,7 +218,6 @@ impl ToFormState for FormData {
|
||||
name = "handlers.upstream_oauth2.link.get",
|
||||
fields(upstream_oauth_link.id = %link_id),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
@@ -245,16 +253,16 @@ pub(crate) async fn get(
|
||||
.upstream_oauth_session()
|
||||
.lookup(session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::SessionNotFound)?;
|
||||
.ok_or(RouteError::SessionNotFound(session_id))?;
|
||||
|
||||
// This checks that we're in a browser session which is allowed to consume this
|
||||
// link: the upstream auth session should have been started in this browser.
|
||||
if upstream_session.link_id() != Some(link.id) {
|
||||
return Err(RouteError::SessionNotFound);
|
||||
return Err(RouteError::SessionNotFound(session_id));
|
||||
}
|
||||
|
||||
if upstream_session.is_consumed() {
|
||||
return Err(RouteError::SessionConsumed);
|
||||
return Err(RouteError::SessionConsumed(session_id));
|
||||
}
|
||||
|
||||
let (user_session_info, cookie_jar) = cookie_jar.session_info();
|
||||
@@ -289,7 +297,7 @@ pub(crate) async fn get(
|
||||
.user()
|
||||
.lookup(user_id)
|
||||
.await?
|
||||
.ok_or(RouteError::UserNotFound)?;
|
||||
.ok_or(RouteError::UserNotFound(user_id))?;
|
||||
|
||||
let ctx = UpstreamExistingLinkContext::new(user)
|
||||
.with_session(user_session)
|
||||
@@ -315,7 +323,7 @@ pub(crate) async fn get(
|
||||
.user()
|
||||
.lookup(user_id)
|
||||
.await?
|
||||
.ok_or(RouteError::UserNotFound)?;
|
||||
.ok_or(RouteError::UserNotFound(user_id))?;
|
||||
|
||||
// Check that the user is not locked or deactivated
|
||||
if user.deactivated_at.is_some() {
|
||||
@@ -377,7 +385,7 @@ pub(crate) async fn get(
|
||||
.upstream_oauth_provider()
|
||||
.lookup(link.provider_id)
|
||||
.await?
|
||||
.ok_or(RouteError::ProviderNotFound)?;
|
||||
.ok_or(RouteError::ProviderNotFound(link.provider_id))?;
|
||||
|
||||
let ctx = UpstreamRegister::new(link.clone(), provider.clone());
|
||||
|
||||
@@ -543,7 +551,6 @@ pub(crate) async fn get(
|
||||
name = "handlers.upstream_oauth2.link.post",
|
||||
fields(upstream_oauth_link.id = %link_id),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
@@ -583,16 +590,16 @@ pub(crate) async fn post(
|
||||
.upstream_oauth_session()
|
||||
.lookup(session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::SessionNotFound)?;
|
||||
.ok_or(RouteError::SessionNotFound(session_id))?;
|
||||
|
||||
// This checks that we're in a browser session which is allowed to consume this
|
||||
// link: the upstream auth session should have been started in this browser.
|
||||
if upstream_session.link_id() != Some(link.id) {
|
||||
return Err(RouteError::SessionNotFound);
|
||||
return Err(RouteError::SessionNotFound(session_id));
|
||||
}
|
||||
|
||||
if upstream_session.is_consumed() {
|
||||
return Err(RouteError::SessionConsumed);
|
||||
return Err(RouteError::SessionConsumed(session_id));
|
||||
}
|
||||
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
@@ -637,7 +644,7 @@ pub(crate) async fn post(
|
||||
.upstream_oauth_provider()
|
||||
.lookup(link.provider_id)
|
||||
.await?
|
||||
.ok_or(RouteError::ProviderNotFound)?;
|
||||
.ok_or(RouteError::ProviderNotFound(link.provider_id))?;
|
||||
|
||||
// Let's try to import the claims from the ID token
|
||||
let env = environment();
|
||||
|
||||
@@ -25,7 +25,7 @@ pub struct Params {
|
||||
action: Option<mas_router::AccountAction>,
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handlers.views.app.get", skip_all, err)]
|
||||
#[tracing::instrument(name = "handlers.views.app.get", skip_all)]
|
||||
pub async fn get(
|
||||
PreferredLanguage(locale): PreferredLanguage,
|
||||
State(templates): State<Templates>,
|
||||
@@ -74,7 +74,7 @@ pub async fn get(
|
||||
/// Like `get`, but allow anonymous access.
|
||||
/// Used for a subset of the account management paths.
|
||||
/// Needed for e.g. account recovery.
|
||||
#[tracing::instrument(name = "handlers.views.app.get_anonymous", skip_all, err)]
|
||||
#[tracing::instrument(name = "handlers.views.app.get_anonymous", skip_all)]
|
||||
pub async fn get_anonymous(
|
||||
PreferredLanguage(locale): PreferredLanguage,
|
||||
State(templates): State<Templates>,
|
||||
|
||||
@@ -19,7 +19,7 @@ use crate::{
|
||||
session::{SessionOrFallback, load_session_or_fallback},
|
||||
};
|
||||
|
||||
#[tracing::instrument(name = "handlers.views.index.get", skip_all, err)]
|
||||
#[tracing::instrument(name = "handlers.views.index.get", skip_all)]
|
||||
pub async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
|
||||
@@ -61,7 +61,7 @@ impl ToFormState for LoginForm {
|
||||
type Field = LoginFormField;
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handlers.views.login.get", skip_all, err)]
|
||||
#[tracing::instrument(name = "handlers.views.login.get", skip_all)]
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
@@ -127,7 +127,7 @@ pub(crate) async fn get(
|
||||
.await
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handlers.views.login.post", skip_all, err)]
|
||||
#[tracing::instrument(name = "handlers.views.login.post", skip_all)]
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
|
||||
@@ -18,7 +18,7 @@ use mas_storage::{BoxClock, BoxRepository, user::BrowserSessionRepository};
|
||||
|
||||
use crate::BoundActivityTracker;
|
||||
|
||||
#[tracing::instrument(name = "handlers.views.logout.post", skip_all, err)]
|
||||
#[tracing::instrument(name = "handlers.views.logout.post", skip_all)]
|
||||
pub(crate) async fn post(
|
||||
clock: BoxClock,
|
||||
mut repo: BoxRepository,
|
||||
|
||||
@@ -20,7 +20,7 @@ mod cookie;
|
||||
pub(crate) mod password;
|
||||
pub(crate) mod steps;
|
||||
|
||||
#[tracing::instrument(name = "handlers.views.register.get", skip_all, err)]
|
||||
#[tracing::instrument(name = "handlers.views.register.get", skip_all)]
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
|
||||
@@ -66,7 +66,7 @@ pub struct QueryParams {
|
||||
action: OptionalPostAuthAction,
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handlers.views.password_register.get", skip_all, err)]
|
||||
#[tracing::instrument(name = "handlers.views.password_register.get", skip_all)]
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
@@ -118,7 +118,7 @@ pub(crate) async fn get(
|
||||
Ok((cookie_jar, Html(content)).into_response())
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "handlers.views.password_register.post", skip_all, err)]
|
||||
#[tracing::instrument(name = "handlers.views.password_register.post", skip_all)]
|
||||
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
|
||||
@@ -49,7 +49,6 @@ impl ToFormState for DisplayNameForm {
|
||||
name = "handlers.views.register.steps.display_name.get",
|
||||
fields(user_registration.id = %id),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
@@ -100,7 +99,6 @@ pub(crate) async fn get(
|
||||
name = "handlers.views.register.steps.display_name.post",
|
||||
fields(user_registration.id = %id),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
pub(crate) async fn post(
|
||||
mut rng: BoxRng,
|
||||
|
||||
@@ -42,7 +42,6 @@ static PASSWORD_REGISTER_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
|
||||
name = "handlers.views.register.steps.finish.get",
|
||||
fields(user_registration.id = %id),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
|
||||
@@ -37,7 +37,6 @@ impl ToFormState for CodeForm {
|
||||
name = "handlers.views.register.steps.verify_email.get",
|
||||
fields(user_registration.id = %id),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
@@ -104,7 +103,6 @@ pub(crate) async fn get(
|
||||
name = "handlers.views.account_email_verify.post",
|
||||
fields(user_email.id = %id),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
pub(crate) async fn post(
|
||||
clock: BoxClock,
|
||||
|
||||
@@ -17,7 +17,7 @@ futures-util.workspace = true
|
||||
http-body.workspace = true
|
||||
hyper = { workspace = true, features = ["server"] }
|
||||
hyper-util.workspace = true
|
||||
pin-project-lite = "0.2.16"
|
||||
pin-project-lite.workspace = true
|
||||
socket2 = "0.5.9"
|
||||
thiserror.workspace = true
|
||||
tokio.workspace = true
|
||||
@@ -27,6 +27,8 @@ tower.workspace = true
|
||||
tower-http.workspace = true
|
||||
tracing.workspace = true
|
||||
|
||||
mas-context.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow.workspace = true
|
||||
rustls-pemfile = "2.2.0"
|
||||
|
||||
@@ -18,6 +18,7 @@ use hyper_util::{
|
||||
server::conn::auto::Connection,
|
||||
service::TowerToHyperService,
|
||||
};
|
||||
use mas_context::LogContext;
|
||||
use pin_project_lite::pin_project;
|
||||
use thiserror::Error;
|
||||
use tokio_rustls::rustls::ServerConfig;
|
||||
@@ -107,12 +108,6 @@ impl<S> Server<S> {
|
||||
#[derive(Debug, Error)]
|
||||
#[non_exhaustive]
|
||||
enum AcceptError {
|
||||
#[error("failed to accept connection from the underlying socket")]
|
||||
Socket {
|
||||
#[source]
|
||||
source: std::io::Error,
|
||||
},
|
||||
|
||||
#[error("failed to complete the TLS handshake")]
|
||||
TlsHandshake {
|
||||
#[source]
|
||||
@@ -133,10 +128,6 @@ enum AcceptError {
|
||||
}
|
||||
|
||||
impl AcceptError {
|
||||
fn socket(source: std::io::Error) -> Self {
|
||||
Self::Socket { source }
|
||||
}
|
||||
|
||||
fn tls_handshake(source: std::io::Error) -> Self {
|
||||
Self::TlsHandshake { source }
|
||||
}
|
||||
@@ -164,7 +155,6 @@ impl AcceptError {
|
||||
network.peer.address,
|
||||
network.peer.port,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn accept<S, B>(
|
||||
maybe_proxy_acceptor: &MaybeProxyAcceptor,
|
||||
@@ -357,12 +347,16 @@ pub async fn run_servers<S, B>(
|
||||
// Poll on the JoinSet to collect connections to serve
|
||||
res = accept_tasks.join_next(), if !accept_tasks.is_empty() => {
|
||||
match res {
|
||||
Some(Ok(Ok(connection))) => {
|
||||
tracing::trace!("Accepted connection");
|
||||
let conn = AbortableConnection::new(connection, soft_shutdown_token.child_token());
|
||||
connection_tasks.spawn(conn);
|
||||
Some(Ok(Some(connection))) => {
|
||||
let token = soft_shutdown_token.child_token();
|
||||
connection_tasks.spawn(LogContext::new("http-serve").run(async move || {
|
||||
tracing::debug!("Accepted connection");
|
||||
if let Err(e) = AbortableConnection::new(connection, token).await {
|
||||
tracing::warn!(error = &*e as &dyn std::error::Error, "Failed to serve connection");
|
||||
}
|
||||
}));
|
||||
},
|
||||
Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
|
||||
Some(Ok(None)) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
|
||||
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
|
||||
None => tracing::error!("Join set was polled even though it was empty"),
|
||||
}
|
||||
@@ -371,8 +365,7 @@ pub async fn run_servers<S, B>(
|
||||
// Poll on the JoinSet to collect finished connections
|
||||
res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
|
||||
match res {
|
||||
Some(Ok(Ok(()))) => tracing::trace!("Connection finished"),
|
||||
Some(Ok(Err(e))) => tracing::error!(error = &*e as &dyn std::error::Error, "Error while serving connection"),
|
||||
Some(Ok(())) => { /* Connection finished, any errors should be logged in in the spawned task */ },
|
||||
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
|
||||
None => tracing::error!("Join set was polled even though it was empty"),
|
||||
}
|
||||
@@ -385,11 +378,23 @@ pub async fn run_servers<S, B>(
|
||||
// Spawn the connection in the set, so we don't have to wait for the handshake to
|
||||
// accept the next connection. This allows us to keep track of active connections
|
||||
// and waiting on them for a graceful shutdown
|
||||
accept_tasks.spawn(async move {
|
||||
let (maybe_proxy_acceptor, maybe_tls_acceptor, service, peer_addr, stream) = res
|
||||
.map_err(AcceptError::socket)?;
|
||||
accept(&maybe_proxy_acceptor, &maybe_tls_acceptor, peer_addr, stream, service).await
|
||||
});
|
||||
accept_tasks.spawn(LogContext::new("http-accept").run(async move || {
|
||||
let (maybe_proxy_acceptor, maybe_tls_acceptor, service, peer_addr, stream) = match res {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = &e as &dyn std::error::Error, "Failed to accept connection from the underlying socket");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
match accept(&maybe_proxy_acceptor, &maybe_tls_acceptor, peer_addr, stream, service).await {
|
||||
Ok(connection) => Some(connection),
|
||||
Err(e) => {
|
||||
tracing::warn!(error = &e as &dyn std::error::Error, "Failed to accept connection");
|
||||
None
|
||||
}
|
||||
}
|
||||
}));
|
||||
},
|
||||
};
|
||||
}
|
||||
@@ -409,12 +414,16 @@ pub async fn run_servers<S, B>(
|
||||
// Poll on the JoinSet to collect connections to serve
|
||||
res = accept_tasks.join_next(), if !accept_tasks.is_empty() => {
|
||||
match res {
|
||||
Some(Ok(Ok(connection))) => {
|
||||
tracing::trace!("Accepted connection");
|
||||
let conn = AbortableConnection::new(connection, soft_shutdown_token.child_token());
|
||||
connection_tasks.spawn(conn);
|
||||
Some(Ok(Some(connection))) => {
|
||||
let token = soft_shutdown_token.child_token();
|
||||
connection_tasks.spawn(LogContext::new("http-serve").run(async || {
|
||||
tracing::debug!("Accepted connection");
|
||||
if let Err(e) = AbortableConnection::new(connection, token).await {
|
||||
tracing::warn!(error = &*e as &dyn std::error::Error, "Failed to serve connection");
|
||||
}
|
||||
}));
|
||||
}
|
||||
Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
|
||||
Some(Ok(None)) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
|
||||
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
|
||||
None => tracing::error!("Join set was polled even though it was empty"),
|
||||
}
|
||||
@@ -423,8 +432,7 @@ pub async fn run_servers<S, B>(
|
||||
// Poll on the JoinSet to collect finished connections
|
||||
res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
|
||||
match res {
|
||||
Some(Ok(Ok(()))) => tracing::trace!("Connection finished"),
|
||||
Some(Ok(Err(e))) => tracing::error!(error = &*e as &dyn std::error::Error, "Error while serving connection"),
|
||||
Some(Ok(())) => { /* Connection finished, any errors should be logged in in the spawned task */ },
|
||||
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
|
||||
None => tracing::error!("Join set was polled even though it was empty"),
|
||||
}
|
||||
|
||||
@@ -197,7 +197,7 @@ pub struct PolicyFactory {
|
||||
}
|
||||
|
||||
impl PolicyFactory {
|
||||
#[tracing::instrument(name = "policy.load", skip(source), err)]
|
||||
#[tracing::instrument(name = "policy.load", skip(source))]
|
||||
pub async fn load(
|
||||
mut source: impl AsyncRead + std::marker::Unpin,
|
||||
data: Data,
|
||||
@@ -283,7 +283,7 @@ impl PolicyFactory {
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "policy.instantiate", skip_all, err)]
|
||||
#[tracing::instrument(name = "policy.instantiate", skip_all)]
|
||||
pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
|
||||
let data = self.dynamic_data.load();
|
||||
self.instantiate_with_data(&data.merged).await
|
||||
@@ -342,7 +342,6 @@ impl Policy {
|
||||
fields(
|
||||
%input.email,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn evaluate_email(
|
||||
&mut self,
|
||||
@@ -364,7 +363,6 @@ impl Policy {
|
||||
input.username = input.username,
|
||||
input.email = input.email,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn evaluate_register(
|
||||
&mut self,
|
||||
@@ -402,7 +400,6 @@ impl Policy {
|
||||
%input.scope,
|
||||
%input.client.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn evaluate_authorization_grant(
|
||||
&mut self,
|
||||
|
||||
@@ -30,6 +30,7 @@ ulid.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
||||
mas-context.workspace = true
|
||||
mas-data-model.workspace = true
|
||||
mas-email.workspace = true
|
||||
mas-i18n.workspace = true
|
||||
|
||||
@@ -17,7 +17,7 @@ use crate::{
|
||||
|
||||
#[async_trait]
|
||||
impl RunnableJob for CleanupExpiredTokensJob {
|
||||
#[tracing::instrument(name = "job.cleanup_expired_tokens", skip_all, err)]
|
||||
#[tracing::instrument(name = "job.cleanup_expired_tokens", skip_all)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
|
||||
let clock = state.clock();
|
||||
let mut repo = state.repository().await.map_err(JobError::retry)?;
|
||||
@@ -41,7 +41,7 @@ impl RunnableJob for CleanupExpiredTokensJob {
|
||||
|
||||
#[async_trait]
|
||||
impl RunnableJob for PruneStalePolicyDataJob {
|
||||
#[tracing::instrument(name = "job.prune_stale_policy_data", skip_all, err)]
|
||||
#[tracing::instrument(name = "job.prune_stale_policy_data", skip_all)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
|
||||
let mut repo = state.repository().await.map_err(JobError::retry)?;
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ impl RunnableJob for VerifyEmailJob {
|
||||
name = "job.verify_email",
|
||||
fields(user_email.id = %self.user_email_id()),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
async fn run(&self, _state: &State, _context: JobContext) -> Result<(), JobError> {
|
||||
// This job was for the old email verification flow, which has been replaced.
|
||||
@@ -39,7 +38,6 @@ impl RunnableJob for SendEmailAuthenticationCodeJob {
|
||||
name = "job.send_email_authentication_code",
|
||||
fields(user_email_authentication.id = %self.user_email_authentication_id()),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
|
||||
let clock = state.clock();
|
||||
|
||||
@@ -36,7 +36,6 @@ impl RunnableJob for ProvisionUserJob {
|
||||
name = "job.provision_user"
|
||||
fields(user.id = %self.user_id()),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
|
||||
let matrix = state.matrix_connection();
|
||||
@@ -103,7 +102,6 @@ impl RunnableJob for ProvisionDeviceJob {
|
||||
device.id = %self.device_id(),
|
||||
),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
|
||||
let mut repo = state.repository().await.map_err(JobError::retry)?;
|
||||
@@ -140,7 +138,6 @@ impl RunnableJob for DeleteDeviceJob {
|
||||
device.id = %self.device_id(),
|
||||
),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
|
||||
let mut rng = state.rng();
|
||||
@@ -172,7 +169,6 @@ impl RunnableJob for SyncDevicesJob {
|
||||
name = "job.sync_devices",
|
||||
fields(user.id = %self.user_id()),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
|
||||
let matrix = state.matrix_connection();
|
||||
|
||||
@@ -8,6 +8,7 @@ use std::{collections::HashMap, sync::Arc};
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use cron::Schedule;
|
||||
use mas_context::LogContext;
|
||||
use mas_storage::{
|
||||
Clock, RepositoryAccess, RepositoryError,
|
||||
queue::{InsertableJob, Job, JobMetadata, Worker},
|
||||
@@ -183,7 +184,7 @@ fn retry_delay(attempt: usize) -> Duration {
|
||||
Duration::milliseconds(2_i64.saturating_pow(attempt) * 5_000)
|
||||
}
|
||||
|
||||
type JobResult = Result<(), JobError>;
|
||||
type JobResult = (std::time::Duration, Result<(), JobError>);
|
||||
type JobFactory = Arc<dyn Fn(JobPayload) -> Box<dyn RunnableJob> + Send + Sync>;
|
||||
|
||||
struct ScheduleDefinition {
|
||||
@@ -252,7 +253,7 @@ impl QueueWorker {
|
||||
.await
|
||||
.map_err(QueueRunnerError::CommitTransaction)?;
|
||||
|
||||
tracing::info!("Registered worker");
|
||||
tracing::info!(worker.id = %registration.id, "Registered worker");
|
||||
let now = clock.now();
|
||||
|
||||
let wakeup_reason = METER
|
||||
@@ -337,7 +338,9 @@ impl QueueWorker {
|
||||
self.setup_schedules().await?;
|
||||
|
||||
while !self.cancellation_token.is_cancelled() {
|
||||
self.run_loop().await?;
|
||||
LogContext::new("worker-run-loop")
|
||||
.run(|| self.run_loop())
|
||||
.await?;
|
||||
}
|
||||
|
||||
self.shutdown().await?;
|
||||
@@ -345,7 +348,7 @@ impl QueueWorker {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "worker.setup_schedules", skip_all, err)]
|
||||
#[tracing::instrument(name = "worker.setup_schedules", skip_all)]
|
||||
pub async fn setup_schedules(&mut self) -> Result<(), QueueRunnerError> {
|
||||
let schedules: Vec<_> = self.schedules.iter().map(|s| s.schedule_name).collect();
|
||||
|
||||
@@ -369,7 +372,7 @@ impl QueueWorker {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "worker.run_loop", skip_all, err)]
|
||||
#[tracing::instrument(name = "worker.run_loop", skip_all)]
|
||||
async fn run_loop(&mut self) -> Result<(), QueueRunnerError> {
|
||||
self.wait_until_wakeup().await?;
|
||||
|
||||
@@ -390,7 +393,7 @@ impl QueueWorker {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "worker.shutdown", skip_all, err)]
|
||||
#[tracing::instrument(name = "worker.shutdown", skip_all)]
|
||||
async fn shutdown(&mut self) -> Result<(), QueueRunnerError> {
|
||||
tracing::info!("Shutting down worker");
|
||||
|
||||
@@ -435,7 +438,7 @@ impl QueueWorker {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "worker.wait_until_wakeup", skip_all, err)]
|
||||
#[tracing::instrument(name = "worker.wait_until_wakeup", skip_all)]
|
||||
async fn wait_until_wakeup(&mut self) -> Result<(), QueueRunnerError> {
|
||||
// This is to make sure we wake up every second to do the maintenance tasks
|
||||
// We add a little bit of random jitter to the duration, so that we don't get
|
||||
@@ -484,7 +487,6 @@ impl QueueWorker {
|
||||
name = "worker.tick",
|
||||
skip_all,
|
||||
fields(worker.id = %self.registration.id),
|
||||
err,
|
||||
)]
|
||||
async fn tick(&mut self) -> Result<(), QueueRunnerError> {
|
||||
tracing::debug!("Tick");
|
||||
@@ -583,7 +585,7 @@ impl QueueWorker {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "worker.perform_leader_duties", skip_all, err)]
|
||||
#[tracing::instrument(name = "worker.perform_leader_duties", skip_all)]
|
||||
async fn perform_leader_duties(&mut self) -> Result<(), QueueRunnerError> {
|
||||
// This should have been checked by the caller, but better safe than sorry
|
||||
if !self.am_i_leader {
|
||||
@@ -771,16 +773,86 @@ impl JobTracker {
|
||||
fn spawn_job(&mut self, state: State, context: JobContext, payload: JobPayload) {
|
||||
let factory = self.factories.get(context.queue_name.as_str()).cloned();
|
||||
let task = {
|
||||
let log_context = LogContext::new(format!("job-{}", context.queue_name));
|
||||
let context = context.clone();
|
||||
let span = context.span();
|
||||
async move {
|
||||
// We should never crash, but in case we do, we do that in the task and
|
||||
// don't crash the worker
|
||||
let job = factory.expect("unknown job factory")(payload);
|
||||
tracing::info!("Running job");
|
||||
job.run(&state, context).await
|
||||
}
|
||||
.instrument(span)
|
||||
log_context
|
||||
.run(async move || {
|
||||
// We should never crash, but in case we do, we do that in the task and
|
||||
// don't crash the worker
|
||||
let job = factory.expect("unknown job factory")(payload);
|
||||
tracing::info!(
|
||||
job.id = %context.id,
|
||||
job.queue.name = %context.queue_name,
|
||||
job.attempt = %context.attempt,
|
||||
"Running job"
|
||||
);
|
||||
let result = job.run(&state, context.clone()).await;
|
||||
|
||||
let Some(log_context) = LogContext::current() else {
|
||||
// This should never happen, but if it does it's fine: we're recovering fine
|
||||
// from panics in those tasks
|
||||
panic!("Missing log context, this should never happen");
|
||||
};
|
||||
|
||||
let context_stats = log_context.stats();
|
||||
|
||||
// We log the result here so that it's attached to the right span & log context
|
||||
match &result {
|
||||
Ok(()) => {
|
||||
tracing::info!(
|
||||
job.id = %context.id,
|
||||
job.queue.name = %context.queue_name,
|
||||
job.attempt = %context.attempt,
|
||||
"Job completed [{context_stats}]"
|
||||
);
|
||||
}
|
||||
|
||||
Err(JobError {
|
||||
decision: JobErrorDecision::Fail,
|
||||
error,
|
||||
}) => {
|
||||
tracing::error!(
|
||||
error = &**error as &dyn std::error::Error,
|
||||
job.id = %context.id,
|
||||
job.queue.name = %context.queue_name,
|
||||
job.attempt = %context.attempt,
|
||||
"Job failed, not retrying [{context_stats}]"
|
||||
);
|
||||
}
|
||||
|
||||
Err(JobError {
|
||||
decision: JobErrorDecision::Retry,
|
||||
error,
|
||||
}) if context.attempt < MAX_ATTEMPTS => {
|
||||
let delay = retry_delay(context.attempt);
|
||||
tracing::warn!(
|
||||
error = &**error as &dyn std::error::Error,
|
||||
job.id = %context.id,
|
||||
job.queue.name = %context.queue_name,
|
||||
job.attempt = %context.attempt,
|
||||
"Job failed, will retry in {}s [{context_stats}]",
|
||||
delay.num_seconds()
|
||||
);
|
||||
}
|
||||
|
||||
Err(JobError {
|
||||
decision: JobErrorDecision::Retry,
|
||||
error,
|
||||
}) => {
|
||||
tracing::error!(
|
||||
error = &**error as &dyn std::error::Error,
|
||||
job.id = %context.id,
|
||||
job.queue.name = %context.queue_name,
|
||||
job.attempt = %context.attempt,
|
||||
"Job failed too many times, abandonning [{context_stats}]"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
(context_stats.elapsed, result)
|
||||
})
|
||||
.instrument(span)
|
||||
};
|
||||
|
||||
self.in_flight_jobs.add(
|
||||
@@ -837,15 +909,10 @@ impl JobTracker {
|
||||
}
|
||||
}
|
||||
|
||||
// XXX: the time measurement isn't accurate, as it would include the
|
||||
// time spent between the task finishing, and us processing the result.
|
||||
// It's fine for now, as it at least gives us an idea of how many tasks
|
||||
// we run, and what their status is
|
||||
|
||||
while let Some(result) = self.last_join_result.take() {
|
||||
match result {
|
||||
// The job succeeded
|
||||
Ok((id, Ok(()))) => {
|
||||
// The job succeeded. The logging and time measurement is already done in the task
|
||||
Ok((id, (elapsed, Ok(())))) => {
|
||||
let context = self
|
||||
.job_contexts
|
||||
.remove(&id)
|
||||
@@ -856,22 +923,9 @@ impl JobTracker {
|
||||
&[KeyValue::new("job.queue.name", context.queue_name.clone())],
|
||||
);
|
||||
|
||||
let elapsed = context
|
||||
.start
|
||||
.elapsed()
|
||||
.as_millis()
|
||||
.try_into()
|
||||
.unwrap_or(u64::MAX);
|
||||
tracing::info!(
|
||||
job.id = %context.id,
|
||||
job.queue.name = %context.queue_name,
|
||||
job.attempt = %context.attempt,
|
||||
job.elapsed = format!("{elapsed}ms"),
|
||||
"Job completed"
|
||||
);
|
||||
|
||||
let elapsed_ms = elapsed.as_millis().try_into().unwrap_or(u64::MAX);
|
||||
self.job_processing_time.record(
|
||||
elapsed,
|
||||
elapsed_ms,
|
||||
&[
|
||||
KeyValue::new("job.queue.name", context.queue_name),
|
||||
KeyValue::new("job.result", "success"),
|
||||
@@ -883,8 +937,8 @@ impl JobTracker {
|
||||
.await?;
|
||||
}
|
||||
|
||||
// The job failed
|
||||
Ok((id, Err(e))) => {
|
||||
// The job failed. The logging and time measurement is already done in the task
|
||||
Ok((id, (elapsed, Err(e)))) => {
|
||||
let context = self
|
||||
.job_contexts
|
||||
.remove(&id)
|
||||
@@ -900,26 +954,11 @@ impl JobTracker {
|
||||
.mark_as_failed(clock, context.id, &reason)
|
||||
.await?;
|
||||
|
||||
let elapsed = context
|
||||
.start
|
||||
.elapsed()
|
||||
.as_millis()
|
||||
.try_into()
|
||||
.unwrap_or(u64::MAX);
|
||||
|
||||
let elapsed_ms = elapsed.as_millis().try_into().unwrap_or(u64::MAX);
|
||||
match e.decision {
|
||||
JobErrorDecision::Fail => {
|
||||
tracing::error!(
|
||||
error = &e as &dyn std::error::Error,
|
||||
job.id = %context.id,
|
||||
job.queue.name = %context.queue_name,
|
||||
job.attempt = %context.attempt,
|
||||
job.elapsed = format!("{elapsed}ms"),
|
||||
"Job failed, not retrying"
|
||||
);
|
||||
|
||||
self.job_processing_time.record(
|
||||
elapsed,
|
||||
elapsed_ms,
|
||||
&[
|
||||
KeyValue::new("job.queue.name", context.queue_name),
|
||||
KeyValue::new("job.result", "failed"),
|
||||
@@ -928,50 +967,31 @@ impl JobTracker {
|
||||
);
|
||||
}
|
||||
|
||||
JobErrorDecision::Retry if context.attempt < MAX_ATTEMPTS => {
|
||||
self.job_processing_time.record(
|
||||
elapsed_ms,
|
||||
&[
|
||||
KeyValue::new("job.queue.name", context.queue_name),
|
||||
KeyValue::new("job.result", "failed"),
|
||||
KeyValue::new("job.decision", "retry"),
|
||||
],
|
||||
);
|
||||
|
||||
let delay = retry_delay(context.attempt);
|
||||
repo.queue_job()
|
||||
.retry(&mut *rng, clock, context.id, delay)
|
||||
.await?;
|
||||
}
|
||||
|
||||
JobErrorDecision::Retry => {
|
||||
if context.attempt < MAX_ATTEMPTS {
|
||||
let delay = retry_delay(context.attempt);
|
||||
tracing::warn!(
|
||||
error = &e as &dyn std::error::Error,
|
||||
job.id = %context.id,
|
||||
job.queue.name = %context.queue_name,
|
||||
job.attempt = %context.attempt,
|
||||
job.elapsed = format!("{elapsed}ms"),
|
||||
"Job failed, will retry in {}s",
|
||||
delay.num_seconds()
|
||||
);
|
||||
|
||||
self.job_processing_time.record(
|
||||
elapsed,
|
||||
&[
|
||||
KeyValue::new("job.queue.name", context.queue_name),
|
||||
KeyValue::new("job.result", "failed"),
|
||||
KeyValue::new("job.decision", "retry"),
|
||||
],
|
||||
);
|
||||
|
||||
repo.queue_job()
|
||||
.retry(&mut *rng, clock, context.id, delay)
|
||||
.await?;
|
||||
} else {
|
||||
tracing::error!(
|
||||
error = &e as &dyn std::error::Error,
|
||||
job.id = %context.id,
|
||||
job.queue.name = %context.queue_name,
|
||||
job.attempt = %context.attempt,
|
||||
job.elapsed = format!("{elapsed}ms"),
|
||||
"Job failed too many times, abandonning"
|
||||
);
|
||||
|
||||
self.job_processing_time.record(
|
||||
elapsed,
|
||||
&[
|
||||
KeyValue::new("job.queue.name", context.queue_name),
|
||||
KeyValue::new("job.result", "failed"),
|
||||
KeyValue::new("job.decision", "abandon"),
|
||||
],
|
||||
);
|
||||
}
|
||||
self.job_processing_time.record(
|
||||
elapsed_ms,
|
||||
&[
|
||||
KeyValue::new("job.queue.name", context.queue_name),
|
||||
KeyValue::new("job.result", "failed"),
|
||||
KeyValue::new("job.decision", "abandon"),
|
||||
],
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -989,6 +1009,8 @@ impl JobTracker {
|
||||
&[KeyValue::new("job.queue.name", context.queue_name.clone())],
|
||||
);
|
||||
|
||||
// This measurement is not accurate as it includes the time processing the jobs,
|
||||
// but it's fine, it's only for panicked tasks
|
||||
let elapsed = context
|
||||
.start
|
||||
.elapsed()
|
||||
@@ -1003,7 +1025,7 @@ impl JobTracker {
|
||||
|
||||
if context.attempt < MAX_ATTEMPTS {
|
||||
let delay = retry_delay(context.attempt);
|
||||
tracing::warn!(
|
||||
tracing::error!(
|
||||
error = &e as &dyn std::error::Error,
|
||||
job.id = %context.id,
|
||||
job.queue.name = %context.queue_name,
|
||||
|
||||
@@ -32,7 +32,6 @@ impl RunnableJob for SendAccountRecoveryEmailsJob {
|
||||
user_recovery_session.email,
|
||||
),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
|
||||
let clock = state.clock();
|
||||
|
||||
@@ -27,7 +27,6 @@ impl RunnableJob for DeactivateUserJob {
|
||||
name = "job.deactivate_user"
|
||||
fields(user.id = %self.user_id(), erase = %self.hs_erase()),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
|
||||
let clock = state.clock();
|
||||
@@ -118,7 +117,6 @@ impl RunnableJob for ReactivateUserJob {
|
||||
name = "job.reactivate_user",
|
||||
fields(user.id = %self.user_id()),
|
||||
skip_all,
|
||||
err,
|
||||
)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
|
||||
let matrix = state.matrix_connection();
|
||||
|
||||
@@ -138,7 +138,6 @@ impl Templates {
|
||||
name = "templates.load",
|
||||
skip_all,
|
||||
fields(%path),
|
||||
err,
|
||||
)]
|
||||
pub async fn load(
|
||||
path: Utf8PathBuf,
|
||||
@@ -258,7 +257,6 @@ impl Templates {
|
||||
name = "templates.reload",
|
||||
skip_all,
|
||||
fields(path = %self.path),
|
||||
err,
|
||||
)]
|
||||
pub async fn reload(&self) -> Result<(), TemplateLoadingError> {
|
||||
let (translator, environment) = Self::load_(
|
||||
|
||||
@@ -19,4 +19,4 @@ tower.workspace = true
|
||||
opentelemetry.workspace = true
|
||||
opentelemetry-http.workspace = true
|
||||
opentelemetry-semantic-conventions.workspace = true
|
||||
pin-project-lite = "0.2.16"
|
||||
pin-project-lite.workspace = true
|
||||
|
||||
Reference in New Issue
Block a user