New logging output (#4424)

This commit is contained in:
Quentin Gliech
2025-04-23 18:44:41 +02:00
committed by GitHub
95 changed files with 1457 additions and 580 deletions

21
Cargo.lock generated
View File

@@ -3159,12 +3159,14 @@ dependencies = [
"dotenvy",
"figment",
"futures-util",
"headers",
"http-body-util",
"hyper",
"ipnetwork",
"itertools 0.14.0",
"listenfd",
"mas-config",
"mas-context",
"mas-data-model",
"mas-email",
"mas-handlers",
@@ -3246,6 +3248,22 @@ dependencies = [
"url",
]
[[package]]
name = "mas-context"
version = "0.15.0"
dependencies = [
"console",
"opentelemetry",
"pin-project-lite",
"quanta",
"tokio",
"tower-layer",
"tower-service",
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
]
[[package]]
name = "mas-data-model"
version = "0.15.0"
@@ -3306,6 +3324,7 @@ dependencies = [
"lettre",
"mas-axum-utils",
"mas-config",
"mas-context",
"mas-data-model",
"mas-http",
"mas-i18n",
@@ -3504,6 +3523,7 @@ dependencies = [
"http-body",
"hyper",
"hyper-util",
"mas-context",
"pin-project-lite",
"rustls-pemfile",
"socket2",
@@ -3674,6 +3694,7 @@ dependencies = [
"async-trait",
"chrono",
"cron",
"mas-context",
"mas-data-model",
"mas-email",
"mas-i18n",

View File

@@ -30,6 +30,7 @@ broken_intra_doc_links = "deny"
mas-axum-utils = { path = "./crates/axum-utils/", version = "=0.15.0" }
mas-cli = { path = "./crates/cli/", version = "=0.15.0" }
mas-config = { path = "./crates/config/", version = "=0.15.0" }
mas-context = { path = "./crates/context/", version = "=0.15.0" }
mas-data-model = { path = "./crates/data-model/", version = "=0.15.0" }
mas-email = { path = "./crates/email/", version = "=0.15.0" }
mas-graphql = { path = "./crates/graphql/", version = "=0.15.0" }
@@ -111,6 +112,10 @@ version = "1.1.9"
[workspace.dependencies.compact_str]
version = "0.9.0"
# Terminal formatting
[workspace.dependencies.console]
version = "0.15.11"
# Time utilities
[workspace.dependencies.chrono]
version = "0.4.40"
@@ -248,6 +253,10 @@ features = ["std"]
version = "0.7.0"
features = ["std"]
# Pin projection
[workspace.dependencies.pin-project-lite]
version = "0.2.16"
# PKCS#1 encoding
[workspace.dependencies.pkcs1]
version = "0.7.5"
@@ -258,6 +267,10 @@ features = ["std"]
version = "0.10.2"
features = ["std", "pkcs5", "encryption"]
# High-precision clock
[workspace.dependencies.quanta]
version = "0.12.5"
# Random values
[workspace.dependencies.rand]
version = "0.8.5"
@@ -374,6 +387,14 @@ features = ["rt"]
version = "0.5.2"
features = ["util"]
# Tower service trait
[workspace.dependencies.tower-service]
version = "0.3.3"
# Tower layer trait
[workspace.dependencies.tower-layer]
version = "0.3.3"
# Tower HTTP layers
[workspace.dependencies.tower-http]
version = "0.6.2"

View File

@@ -28,6 +28,8 @@ use serde::{Deserialize, de::DeserializeOwned};
use serde_json::Value;
use thiserror::Error;
use crate::record_error;
static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
#[derive(Deserialize)]
@@ -97,7 +99,7 @@ impl Credentials {
/// # Errors
///
/// Returns an error if the credentials are invalid.
#[tracing::instrument(skip_all, err)]
#[tracing::instrument(skip_all)]
pub async fn verify(
&self,
http_client: &reqwest::Client,
@@ -144,7 +146,7 @@ impl Credentials {
let jwks = fetch_jwks(http_client, jwks)
.await
.map_err(|_| CredentialsVerificationError::JwksFetchFailed)?;
.map_err(CredentialsVerificationError::JwksFetchFailed)?;
jwt.verify_with_jwks(&jwks)
.map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
@@ -214,7 +216,18 @@ pub enum CredentialsVerificationError {
InvalidAssertionSignature,
#[error("failed to fetch jwks")]
JwksFetchFailed,
JwksFetchFailed(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
}
impl CredentialsVerificationError {
/// Returns true if the error is an internal error, not caused by the client
#[must_use]
pub fn is_internal(&self) -> bool {
matches!(
self,
Self::DecryptionError | Self::InvalidClientConfig | Self::JwksFetchFailed(_)
)
}
}
#[derive(Debug, PartialEq, Eq)]
@@ -231,23 +244,40 @@ impl<F> ClientAuthorization<F> {
}
}
#[derive(Debug)]
#[derive(Debug, Error)]
pub enum ClientAuthorizationError {
#[error("Invalid Authorization header")]
InvalidHeader,
BadForm(FailedToDeserializeForm),
#[error("Could not deserialize request body")]
BadForm(#[source] FailedToDeserializeForm),
#[error("client_id in form ({form:?}) does not match credential ({credential:?})")]
ClientIdMismatch { credential: String, form: String },
#[error("Unsupported client_assertion_type: {client_assertion_type}")]
UnsupportedClientAssertion { client_assertion_type: String },
#[error("No credentials were presented")]
MissingCredentials,
#[error("Invalid request")]
InvalidRequest,
#[error("Invalid client_assertion")]
InvalidAssertion,
#[error(transparent)]
Internal(Box<dyn std::error::Error>),
}
impl IntoResponse for ClientAuthorizationError {
fn into_response(self) -> axum::response::Response {
match self {
let sentry_event_id = record_error!(self, Self::Internal(_));
match &self {
ClientAuthorizationError::InvalidHeader => (
StatusCode::BAD_REQUEST,
sentry_event_id,
Json(ClientError::new(
ClientErrorCode::InvalidRequest,
"Invalid Authorization header",
@@ -256,39 +286,34 @@ impl IntoResponse for ClientAuthorizationError {
ClientAuthorizationError::BadForm(err) => (
StatusCode::BAD_REQUEST,
sentry_event_id,
Json(
ClientError::from(ClientErrorCode::InvalidRequest)
.with_description(format!("{err}")),
),
),
ClientAuthorizationError::ClientIdMismatch { form, credential } => {
let description = format!(
"client_id in form ({form:?}) does not match credential ({credential:?})"
);
(
StatusCode::BAD_REQUEST,
Json(
ClientError::from(ClientErrorCode::InvalidGrant)
.with_description(description),
),
)
}
ClientAuthorizationError::UnsupportedClientAssertion {
client_assertion_type,
} => (
ClientAuthorizationError::ClientIdMismatch { .. } => (
StatusCode::BAD_REQUEST,
sentry_event_id,
Json(
ClientError::from(ClientErrorCode::InvalidRequest).with_description(format!(
"Unsupported client_assertion_type: {client_assertion_type}",
)),
ClientError::from(ClientErrorCode::InvalidGrant)
.with_description(format!("{self}")),
),
),
ClientAuthorizationError::UnsupportedClientAssertion { .. } => (
StatusCode::BAD_REQUEST,
sentry_event_id,
Json(
ClientError::from(ClientErrorCode::InvalidRequest)
.with_description(format!("{self}")),
),
),
ClientAuthorizationError::MissingCredentials => (
StatusCode::BAD_REQUEST,
sentry_event_id,
Json(ClientError::new(
ClientErrorCode::InvalidRequest,
"No credentials were presented",
@@ -297,11 +322,13 @@ impl IntoResponse for ClientAuthorizationError {
ClientAuthorizationError::InvalidRequest => (
StatusCode::BAD_REQUEST,
sentry_event_id,
Json(ClientError::from(ClientErrorCode::InvalidRequest)),
),
ClientAuthorizationError::InvalidAssertion => (
StatusCode::BAD_REQUEST,
sentry_event_id,
Json(ClientError::new(
ClientErrorCode::InvalidRequest,
"Invalid client_assertion",
@@ -310,6 +337,7 @@ impl IntoResponse for ClientAuthorizationError {
ClientAuthorizationError::Internal(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
sentry_event_id,
Json(
ClientError::from(ClientErrorCode::ServerError)
.with_description(format!("{e}")),

View File

@@ -7,6 +7,8 @@
use axum::response::{IntoResponse, Response};
use http::StatusCode;
use crate::record_error;
/// A simple wrapper around an error that implements [`IntoResponse`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
@@ -14,10 +16,16 @@ pub struct ErrorWrapper<T>(#[from] pub T);
impl<T> IntoResponse for ErrorWrapper<T>
where
T: std::error::Error,
T: std::error::Error + 'static,
{
fn into_response(self) -> Response {
// TODO: make this a bit more user friendly
(StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response()
let sentry_event_id = record_error!(self.0);
(
StatusCode::INTERNAL_SERVER_ERROR,
sentry_event_id,
self.0.to_string(),
)
.into_response()
}
}

View File

@@ -54,12 +54,13 @@ impl<E: std::fmt::Debug + std::fmt::Display> From<E> for FancyError {
impl IntoResponse for FancyError {
fn into_response(self) -> Response {
tracing::error!(message = %self.context);
let error = format!("{}", self.context);
let event_id = sentry::capture_message(&error, sentry::Level::Error);
let event_id = SentryEventID::for_last_event();
(
StatusCode::INTERNAL_SERVER_ERROR,
TypedHeader(ContentType::text()),
SentryEventID::from(event_id),
event_id,
Extension(self.context),
error,
)

View File

@@ -13,6 +13,13 @@ use sentry::types::Uuid;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SentryEventID(Uuid);
impl SentryEventID {
/// Create a new Sentry event ID header for the last event on the hub.
pub fn for_last_event() -> Option<Self> {
sentry::last_event_id().map(Self)
}
}
impl From<Uuid> for SentryEventID {
fn from(uuid: Uuid) -> Self {
Self(uuid)
@@ -28,3 +35,31 @@ impl IntoResponseParts for SentryEventID {
Ok(res)
}
}
/// Record an error. It will emit a tracing event with the error level if
/// matches the pattern, warning otherwise. It also returns the Sentry event ID
/// if the error was recorded.
#[macro_export]
macro_rules! record_error {
($error:expr, !) => {{
tracing::warn!(message = &$error as &dyn std::error::Error);
Option::<$crate::sentry::SentryEventID>::None
}};
($error:expr) => {{
tracing::error!(message = &$error as &dyn std::error::Error);
// With the `sentry-tracing` integration, Sentry should have
// captured an error, so let's extract the last event ID from the
// current hub
$crate::sentry::SentryEventID::for_last_event()
}};
($error:expr, $pattern:pat) => {
if let $pattern = $error {
record_error!($error)
} else {
record_error!($error, !)
}
};
}

View File

@@ -27,6 +27,7 @@ dialoguer = { version = "0.11.0", default-features = false, features = [
dotenvy = "0.15.7"
figment.workspace = true
futures-util.workspace = true
headers.workspace = true
http-body-util.workspace = true
hyper.workspace = true
ipnetwork = "0.20.0"
@@ -66,6 +67,7 @@ sentry-tracing.workspace = true
sentry-tower.workspace = true
mas-config.workspace = true
mas-context.workspace = true
mas-data-model.workspace = true
mas-email.workspace = true
mas-handlers.workspace = true

View File

@@ -8,6 +8,7 @@ use std::{convert::Infallible, net::IpAddr, sync::Arc, time::Instant};
use axum::extract::{FromRef, FromRequestParts};
use ipnetwork::IpNetwork;
use mas_context::LogContext;
use mas_data_model::SiteConfig;
use mas_handlers::{
ActivityTracker, BoundActivityTracker, CookieManager, ErrorWrapper, GraphQLSchema, Limiter,
@@ -92,35 +93,36 @@ impl AppState {
let http_client = self.http_client.clone();
tokio::spawn(
async move {
let conn = match pool.acquire().await {
Ok(conn) => conn,
Err(e) => {
LogContext::new("metadata-cache-warmup")
.run(async move || {
let conn = match pool.acquire().await {
Ok(conn) => conn,
Err(e) => {
tracing::error!(
error = &e as &dyn std::error::Error,
"Failed to acquire a database connection"
);
return;
}
};
let mut repo = PgRepository::from_conn(conn);
if let Err(e) = metadata_cache
.warm_up_and_run(
&http_client,
std::time::Duration::from_secs(60 * 15),
&mut repo,
)
.await
{
tracing::error!(
error = &e as &dyn std::error::Error,
"Failed to acquire a database connection"
"Failed to warm up the metadata cache"
);
return;
}
};
let mut repo = PgRepository::from_conn(conn);
if let Err(e) = metadata_cache
.warm_up_and_run(
&http_client,
std::time::Duration::from_secs(60 * 15),
&mut repo,
)
.await
{
tracing::error!(
error = &e as &dyn std::error::Error,
"Failed to warm up the metadata cache"
);
}
}
.instrument(tracing::info_span!("metadata_cache.background_warmup")),
})
.instrument(tracing::info_span!("metadata_cache.background_warmup")),
);
}
}

View File

@@ -143,7 +143,8 @@ impl Options {
prune,
dry_run,
)
.await?;
.await
.context("could not sync the configuration with the database")?;
}
}

View File

@@ -255,7 +255,13 @@ impl Options {
};
repo.into_inner().commit().await?;
info!(?email, "Email added");
info!(
%user.id,
%user.username,
%email.id,
%email.email,
"Email added"
);
Ok(ExitCode::SUCCESS)
}

View File

@@ -13,6 +13,7 @@ use itertools::Itertools;
use mas_config::{
AppConfig, ClientsConfig, ConfigurationSection, ConfigurationSectionExt, UpstreamOAuth2Config,
};
use mas_context::LogContext;
use mas_handlers::{ActivityTracker, CookieManager, Limiter, MetadataCache};
use mas_listener::server::Server;
use mas_router::UrlBuilder;
@@ -112,7 +113,8 @@ impl Options {
false,
false,
)
.await?;
.await
.context("could not sync the configuration with the database")?;
}
// Initialize the key store
@@ -316,11 +318,13 @@ impl Options {
shutdown
.task_tracker()
.spawn(mas_listener::server::run_servers(
servers,
shutdown.soft_shutdown_token(),
shutdown.hard_shutdown_token(),
));
.spawn(LogContext::new("run-servers").run(|| {
mas_listener::server::run_servers(
servers,
shutdown.soft_shutdown_token(),
shutdown.hard_shutdown_token(),
)
}));
let exit_code = shutdown.run().await;

View File

@@ -146,7 +146,8 @@ impl Options {
// Not a dry run — we do want to create the providers in the database
false,
)
.await?;
.await
.context("could not sync the configuration with the database")?;
}
let Either::Left(mut mas_connection) = LockedMasDatabase::try_new(mas_connection)

View File

@@ -91,8 +91,7 @@ async fn try_main() -> anyhow::Result<ExitCode> {
let (log_writer, _guard) = tracing_appender::non_blocking(output);
let fmt_layer = tracing_subscriber::fmt::layer()
.with_writer(log_writer)
.with_file(true)
.with_line_number(true)
.event_format(mas_context::EventFormatter)
.with_ansi(with_ansi);
let filter_layer = EnvFilter::try_from_default_env()
.or_else(|_| EnvFilter::try_new("info"))
@@ -129,9 +128,11 @@ async fn try_main() -> anyhow::Result<ExitCode> {
let sentry_layer = sentry.is_enabled().then(|| {
sentry_tracing::layer().event_filter(|md| {
// All the spans in the handlers module send their data to Sentry themselves, so
// we only create breadcrumbs for them, instead of full events
if md.target().starts_with("mas_handlers::") {
// By default, Sentry records all events as breadcrumbs, except errors.
//
// Because we're emitting error events for 5xx responses, we need to exclude
// them and also record them as breadcrumbs.
if md.name() == "http.server.response" {
EventFilter::Breadcrumb
} else {
sentry_tracing::default_event_filter(md)

View File

@@ -16,12 +16,14 @@ use axum::{
error_handling::HandleErrorLayer,
extract::{FromRef, MatchedPath},
};
use headers::{HeaderMapExt as _, UserAgent};
use hyper::{
Method, Request, Response, StatusCode, Version,
header::{CACHE_CONTROL, HeaderValue, USER_AGENT},
};
use listenfd::ListenFd;
use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp};
use mas_context::LogContext;
use mas_listener::{ConnectionInfo, unix_or_tcp::UnixOrTcpListener};
use mas_router::Route;
use mas_templates::Templates;
@@ -170,6 +172,45 @@ fn on_http_response_labels<B>(res: &Response<B>) -> Vec<KeyValue> {
)]
}
async fn log_response_middleware(
request: axum::extract::Request,
next: axum::middleware::Next,
) -> axum::response::Response {
let user_agent: Option<UserAgent> = request.headers().typed_get();
let user_agent = user_agent.as_ref().map_or("-", |u| u.as_str());
let method = otel_http_method(&request);
let path = request.uri().path().to_owned();
let version = otel_net_protocol_version(&request);
let response = next.run(request).await;
let Some(log_context) = LogContext::current() else {
tracing::error!("Missing log context for request, this is a bug!");
return response;
};
let stats = log_context.stats();
let status_code = response.status();
match status_code.as_u16() {
100..=399 => tracing::info!(
name: "http.server.response",
"\"{method} {path} HTTP/{version}\" {status_code} {user_agent:?} [{stats}]",
),
400..=499 => tracing::warn!(
name: "http.server.response",
"\"{method} {path} HTTP/{version}\" {status_code} {user_agent:?} [{stats}]",
),
500..=599 => tracing::error!(
name: "http.server.response",
"\"{method} {path} HTTP/{version}\" {status_code} {user_agent:?} [{stats}]",
),
_ => { /* This shouldn't happen */ }
}
response
}
pub fn build_router(
state: AppState,
resources: &[HttpResource],
@@ -277,8 +318,12 @@ pub fn build_router(
span.record("otel.status_code", "OK");
}),
)
.layer(SentryHttpLayer::new())
.layer(axum::middleware::from_fn(log_response_middleware))
.layer(mas_context::LogContextLayer::new(|req| {
otel_http_method(req).into()
}))
.layer(NewSentryLayer::new_from_top())
.layer(SentryHttpLayer::with_transaction())
.with_state(state)
}

View File

@@ -62,7 +62,7 @@ fn map_claims_imports(
}
}
#[tracing::instrument(name = "config.sync", skip_all, err(Debug))]
#[tracing::instrument(name = "config.sync", skip_all)]
pub async fn config_sync(
upstream_oauth2_config: UpstreamOAuth2Config,
clients_config: ClientsConfig,
@@ -175,11 +175,11 @@ pub async fn config_sync(
let _span = info_span!("provider", %provider.id).entered();
if existing_enabled_ids.contains(&provider.id) {
info!("Updating provider");
info!(provider.id = %provider.id, "Updating provider");
} else if existing_disabled.contains_key(&provider.id) {
info!("Enabling and updating provider");
info!(provider.id = %provider.id, "Enabling and updating provider");
} else {
info!("Adding provider");
info!(provider.id = %provider.id, "Adding provider");
}
if dry_run {
@@ -252,15 +252,15 @@ pub async fn config_sync(
if discovery_mode.is_disabled() {
if provider.authorization_endpoint.is_none() {
error!("Provider has discovery disabled but no authorization endpoint set");
error!(provider.id = %provider.id, "Provider has discovery disabled but no authorization endpoint set");
}
if provider.token_endpoint.is_none() {
error!("Provider has discovery disabled but no token endpoint set");
error!(provider.id = %provider.id, "Provider has discovery disabled but no token endpoint set");
}
if provider.jwks_uri.is_none() {
warn!("Provider has discovery disabled but no JWKS URI set");
warn!(provider.id = %provider.id, "Provider has discovery disabled but no JWKS URI set");
}
}
@@ -347,9 +347,9 @@ pub async fn config_sync(
for client in clients_config {
let _span = info_span!("client", client.id = %client.client_id).entered();
if existing_ids.contains(&client.client_id) {
info!("Updating client");
info!(client.id = %client.client_id, "Updating client");
} else {
info!("Adding client");
info!(client.id = %client.client_id, "Adding client");
}
if dry_run {

View File

@@ -12,6 +12,7 @@ use mas_config::{
EmailTransportKind, ExperimentalConfig, HomeserverKind, MatrixConfig, PasswordsConfig,
PolicyConfig, TemplatesConfig,
};
use mas_context::LogContext;
use mas_data_model::{SessionExpirationConfig, SiteConfig};
use mas_email::{MailTransport, Mailer};
use mas_handlers::passwords::PasswordManager;
@@ -21,7 +22,7 @@ use mas_policy::PolicyFactory;
use mas_router::UrlBuilder;
use mas_storage::RepositoryAccess;
use mas_storage_pg::PgRepository;
use mas_templates::{SiteConfigExt, TemplateLoadingError, Templates};
use mas_templates::{SiteConfigExt, Templates};
use sqlx::{
ConnectOptions, Executor, PgConnection, PgPool,
postgres::{PgConnectOptions, PgPoolOptions},
@@ -109,20 +110,23 @@ pub fn test_mailer_in_background(mailer: &Mailer, timeout: Duration) {
let mailer = mailer.clone();
let span = tracing::info_span!("cli.test_mailer");
tokio::spawn(async move {
match tokio::time::timeout(timeout, mailer.test_connection()).await {
Ok(Ok(())) => {}
Ok(Err(err)) => {
tracing::warn!(
error = &err as &dyn std::error::Error,
"Could not connect to the mail backend, tasks sending mails may fail!"
);
tokio::spawn(
LogContext::new("mailer-test").run(async move || {
match tokio::time::timeout(timeout, mailer.test_connection()).await {
Ok(Ok(())) => {}
Ok(Err(err)) => {
tracing::warn!(
error = &err as &dyn std::error::Error,
"Could not connect to the mail backend, tasks sending mails may fail!"
);
}
Err(_) => {
tracing::warn!("Timed out while testing the mail backend connection, tasks sending mails may fail!");
}
}
Err(_) => {
tracing::warn!("Timed out while testing the mail backend connection, tasks sending mails may fail!");
}
}
}.instrument(span));
})
.instrument(span)
);
}
pub async fn policy_factory_from_config(
@@ -222,7 +226,7 @@ pub async fn templates_from_config(
config: &TemplatesConfig,
site_config: &SiteConfig,
url_builder: &UrlBuilder,
) -> Result<Templates, TemplateLoadingError> {
) -> Result<Templates, anyhow::Error> {
Templates::load(
config.path.clone(),
url_builder.clone(),
@@ -232,6 +236,7 @@ pub async fn templates_from_config(
site_config.templates_features(),
)
.await
.with_context(|| format!("Failed to load the templates at {}", config.path))
}
fn database_connect_options_from_config(
@@ -331,7 +336,7 @@ fn database_connect_options_from_config(
}
/// Create a database connection pool from the configuration
#[tracing::instrument(name = "db.connect", skip_all, err(Debug))]
#[tracing::instrument(name = "db.connect", skip_all)]
pub async fn database_pool_from_config(config: &DatabaseConfig) -> Result<PgPool, anyhow::Error> {
let options = database_connect_options_from_config(config, &DatabaseConnectOptions::default())?;
PgPoolOptions::new()
@@ -367,7 +372,7 @@ impl Default for DatabaseConnectOptions {
}
/// Create a single database connection from the configuration
#[tracing::instrument(name = "db.connect", skip_all, err(Debug))]
#[tracing::instrument(name = "db.connect", skip_all)]
pub async fn database_connection_from_config(
config: &DatabaseConfig,
) -> Result<PgConnection, anyhow::Error> {
@@ -379,7 +384,7 @@ pub async fn database_connection_from_config(
/// Create a single database connection from the configuration,
/// with specific options.
#[tracing::instrument(name = "db.connect", skip_all, err(Debug))]
#[tracing::instrument(name = "db.connect", skip_all)]
pub async fn database_connection_from_config_with_options(
config: &DatabaseConfig,
options: &DatabaseConnectOptions,
@@ -430,7 +435,7 @@ pub async fn load_policy_factory_dynamic_data_continuously(
}
/// Update the policy factory dynamic data from the database
#[tracing::instrument(name = "policy.load_dynamic_data", skip_all, err(Debug))]
#[tracing::instrument(name = "policy.load_dynamic_data", skip_all)]
pub async fn load_policy_factory_dynamic_data(
policy_factory: &PolicyFactory,
pool: &PgPool,

View File

@@ -70,7 +70,7 @@ impl SecretsConfig {
/// # Errors
///
/// Returns an error when a key could not be imported
#[tracing::instrument(name = "secrets.load", skip_all, err(Debug))]
#[tracing::instrument(name = "secrets.load", skip_all)]
pub async fn key_store(&self) -> anyhow::Result<Keystore> {
let mut keys = Vec::with_capacity(self.keys.len());
for item in &self.keys {

24
crates/context/Cargo.toml Normal file
View File

@@ -0,0 +1,24 @@
[package]
name = "mas-context"
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
publish = false
[lints]
workspace = true
[dependencies]
console.workspace = true
pin-project-lite.workspace = true
quanta.workspace = true
tokio.workspace = true
tower-service.workspace = true
tower-layer.workspace = true
tracing.workspace = true
tracing-subscriber.workspace = true
tracing-opentelemetry.workspace = true
opentelemetry.workspace = true

148
crates/context/src/fmt.rs Normal file
View File

@@ -0,0 +1,148 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use console::{Color, Style};
use opentelemetry::{TraceId, trace::TraceContextExt};
use tracing::{Level, Subscriber};
use tracing_opentelemetry::OtelData;
use tracing_subscriber::{
fmt::{
FormatEvent, FormatFields,
format::{DefaultFields, Writer},
time::{FormatTime, SystemTime},
},
registry::LookupSpan,
};
use crate::LogContext;
/// An event formatter usable by the [`tracing-subscriber`] crate, which
/// includes the log context and the OTEL trace ID.
#[derive(Debug, Default)]
pub struct EventFormatter;
struct FmtLevel<'a> {
level: &'a Level,
ansi: bool,
}
impl<'a> FmtLevel<'a> {
pub(crate) fn new(level: &'a Level, ansi: bool) -> Self {
Self { level, ansi }
}
}
const TRACE_STR: &str = "TRACE";
const DEBUG_STR: &str = "DEBUG";
const INFO_STR: &str = " INFO";
const WARN_STR: &str = " WARN";
const ERROR_STR: &str = "ERROR";
const TRACE_STYLE: Style = Style::new().fg(Color::Magenta);
const DEBUG_STYLE: Style = Style::new().fg(Color::Blue);
const INFO_STYLE: Style = Style::new().fg(Color::Green);
const WARN_STYLE: Style = Style::new().fg(Color::Yellow);
const ERROR_STYLE: Style = Style::new().fg(Color::Red);
impl std::fmt::Display for FmtLevel<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let msg = match *self.level {
Level::TRACE => TRACE_STYLE.force_styling(self.ansi).apply_to(TRACE_STR),
Level::DEBUG => DEBUG_STYLE.force_styling(self.ansi).apply_to(DEBUG_STR),
Level::INFO => INFO_STYLE.force_styling(self.ansi).apply_to(INFO_STR),
Level::WARN => WARN_STYLE.force_styling(self.ansi).apply_to(WARN_STR),
Level::ERROR => ERROR_STYLE.force_styling(self.ansi).apply_to(ERROR_STR),
};
write!(f, "{msg}")
}
}
struct TargetFmt<'a> {
target: &'a str,
line: Option<u32>,
}
impl<'a> TargetFmt<'a> {
pub(crate) fn new(metadata: &tracing::Metadata<'a>) -> Self {
Self {
target: metadata.target(),
line: metadata.line(),
}
}
}
impl std::fmt::Display for TargetFmt<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.target)?;
if let Some(line) = self.line {
write!(f, ":{line}")?;
}
Ok(())
}
}
impl<S, N> FormatEvent<S, N> for EventFormatter
where
S: Subscriber + for<'a> LookupSpan<'a>,
N: for<'writer> FormatFields<'writer> + 'static,
{
fn format_event(
&self,
ctx: &tracing_subscriber::fmt::FmtContext<'_, S, N>,
mut writer: Writer<'_>,
event: &tracing::Event<'_>,
) -> std::fmt::Result {
let ansi = writer.has_ansi_escapes();
let metadata = event.metadata();
SystemTime.format_time(&mut writer)?;
let level = FmtLevel::new(metadata.level(), ansi);
write!(&mut writer, " {level} ")?;
// If there is no explicit 'name' set in the event macro, it will have the
// 'event {filename}:{line}' value. In this case, we want to display the target:
// the module from where it was emitted. In other cases, we want to
// display the explit name of the event we have set.
let style = Style::new().dim().force_styling(ansi);
if metadata.name().starts_with("event ") {
write!(&mut writer, "{} ", style.apply_to(TargetFmt::new(metadata)))?;
} else {
write!(&mut writer, "{} ", style.apply_to(metadata.name()))?;
}
if let Some(log_context) = LogContext::current() {
let log_context = Style::new()
.bold()
.force_styling(ansi)
.apply_to(log_context);
write!(&mut writer, "{log_context} - ")?;
}
let field_fromatter = DefaultFields::new();
field_fromatter.format_fields(writer.by_ref(), event)?;
// If we have a OTEL span, we can add the trace ID to the end of the log line
if let Some(span) = ctx.lookup_current() {
if let Some(otel) = span.extensions().get::<OtelData>() {
// If it is the root span, the trace ID will be in the span builder. Else, it
// will be in the parent OTEL context
let trace_id = otel
.builder
.trace_id
.unwrap_or_else(|| otel.parent_cx.span().span_context().trace_id());
if trace_id != TraceId::INVALID {
let label = Style::new()
.italic()
.force_styling(ansi)
.apply_to("trace.id");
write!(&mut writer, " {label}={trace_id}")?;
}
}
}
writeln!(&mut writer)
}
}

View File

@@ -0,0 +1,59 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::{
pin::Pin,
sync::atomic::Ordering,
task::{Context, Poll},
};
use quanta::Instant;
use tokio::task::futures::TaskLocalFuture;
use crate::LogContext;
pub type LogContextFuture<F> = TaskLocalFuture<crate::LogContext, PollRecordingFuture<F>>;
impl LogContext {
/// Wrap a future with the given log context
pub(crate) fn wrap_future<F: Future>(&self, future: F) -> LogContextFuture<F> {
let future = PollRecordingFuture::new(future);
crate::CURRENT_LOG_CONTEXT.scope(self.clone(), future)
}
}
pin_project_lite::pin_project! {
/// A future which records the elapsed time and the number of polls in the
/// active log context
pub struct PollRecordingFuture<F> {
#[pin]
inner: F,
}
}
impl<F: Future> PollRecordingFuture<F> {
pub(crate) fn new(inner: F) -> Self {
Self { inner }
}
}
impl<F: Future> Future for PollRecordingFuture<F> {
type Output = F::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let start = Instant::now();
let this = self.project();
let result = this.inner.poll(cx);
// Record the number of polls and the time we spent polling the future
let elapsed = start.elapsed().as_nanos().try_into().unwrap_or(u64::MAX);
let _ = crate::CURRENT_LOG_CONTEXT.try_with(|c| {
c.inner.polls.fetch_add(1, Ordering::Relaxed);
c.inner.cpu_time.fetch_add(elapsed, Ordering::Relaxed);
});
result
}
}

View File

@@ -0,0 +1,41 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::borrow::Cow;
use tower_layer::Layer;
use tower_service::Service;
use crate::LogContextService;
/// A layer which creates a log context for each request.
pub struct LogContextLayer<R> {
tagger: fn(&R) -> Cow<'static, str>,
}
impl<R> Clone for LogContextLayer<R> {
fn clone(&self) -> Self {
Self {
tagger: self.tagger,
}
}
}
impl<R> LogContextLayer<R> {
pub fn new(tagger: fn(&R) -> Cow<'static, str>) -> Self {
Self { tagger }
}
}
impl<S, R> Layer<S> for LogContextLayer<R>
where
S: Service<R>,
{
type Service = LogContextService<S, R>;
fn layer(&self, inner: S) -> Self::Service {
LogContextService::new(inner, self.tagger)
}
}

149
crates/context/src/lib.rs Normal file
View File

@@ -0,0 +1,149 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
mod fmt;
mod future;
mod layer;
mod service;
use std::{
borrow::Cow,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::Duration,
};
use quanta::Instant;
use tokio::task_local;
pub use self::{
fmt::EventFormatter,
future::{LogContextFuture, PollRecordingFuture},
layer::LogContextLayer,
service::LogContextService,
};
/// A counter which increments each time we create a new log context
/// It will wrap around if we create more than [`u64::MAX`] contexts
static LOG_CONTEXT_INDEX: AtomicU64 = AtomicU64::new(0);
task_local! {
pub static CURRENT_LOG_CONTEXT: LogContext;
}
/// A log context saves informations about the current task, such as the
/// elapsed time, the number of polls, and the poll time.
#[derive(Clone)]
pub struct LogContext {
inner: Arc<LogContextInner>,
}
struct LogContextInner {
/// A user-defined tag for the log context
tag: Cow<'static, str>,
/// A unique index for the log context
index: u64,
/// The time when the context was created
start: Instant,
/// The number of [`Future::poll`] recorded
polls: AtomicU64,
/// An approximation of the total CPU time spent in the context, in
/// nanoseconds
cpu_time: AtomicU64,
}
impl LogContext {
/// Create a new log context with the given tag
pub fn new(tag: impl Into<Cow<'static, str>>) -> Self {
let tag = tag.into();
let inner = LogContextInner {
tag,
index: LOG_CONTEXT_INDEX.fetch_add(1, Ordering::Relaxed),
start: Instant::now(),
polls: AtomicU64::new(0),
cpu_time: AtomicU64::new(0),
};
Self {
inner: Arc::new(inner),
}
}
/// Get a copy of the current log context, if any
pub fn current() -> Option<Self> {
CURRENT_LOG_CONTEXT.try_with(Self::clone).ok()
}
/// Run the async function `f` with the given log context. It will wrap the
/// output future to record poll and CPU statistics.
pub fn run<F: FnOnce() -> Fut, Fut: Future>(&self, f: F) -> LogContextFuture<Fut> {
let future = self.run_sync(f);
self.wrap_future(future)
}
/// Run the sync function `f` with the given log context, recording the CPU
/// time spent.
pub fn run_sync<F: FnOnce() -> R, R>(&self, f: F) -> R {
let start = Instant::now();
let result = CURRENT_LOG_CONTEXT.sync_scope(self.clone(), f);
let elapsed = start.elapsed().as_nanos().try_into().unwrap_or(u64::MAX);
self.inner.cpu_time.fetch_add(elapsed, Ordering::Relaxed);
result
}
/// Create a snapshot of the log context statistics
#[must_use]
pub fn stats(&self) -> LogContextStats {
let polls = self.inner.polls.load(Ordering::Relaxed);
let cpu_time = self.inner.cpu_time.load(Ordering::Relaxed);
let cpu_time = Duration::from_nanos(cpu_time);
let elapsed = self.inner.start.elapsed();
LogContextStats {
polls,
cpu_time,
elapsed,
}
}
}
impl std::fmt::Display for LogContext {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let tag = &self.inner.tag;
let index = self.inner.index;
write!(f, "{tag}-{index}")
}
}
/// A snapshot of a log context statistics
#[derive(Debug, Clone, Copy)]
pub struct LogContextStats {
/// How many times the context was polled
pub polls: u64,
/// The approximate CPU time spent in the context
pub cpu_time: Duration,
/// How much time elapsed since the context was created
pub elapsed: Duration,
}
impl std::fmt::Display for LogContextStats {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let polls = self.polls;
#[expect(clippy::cast_precision_loss)]
let cpu_time_ms = self.cpu_time.as_nanos() as f64 / 1_000_000.;
#[expect(clippy::cast_precision_loss)]
let elapsed_ms = self.elapsed.as_nanos() as f64 / 1_000_000.;
write!(
f,
"polls: {polls}, cpu: {cpu_time_ms:.1}ms, elapsed: {elapsed_ms:.1}ms",
)
}
}

View File

@@ -0,0 +1,54 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::{
borrow::Cow,
task::{Context, Poll},
};
use tower_service::Service;
use crate::{LogContext, LogContextFuture};
/// A service which wraps another service and creates a log context for
/// each request.
pub struct LogContextService<S, R> {
inner: S,
tagger: fn(&R) -> Cow<'static, str>,
}
impl<S: Clone, R> Clone for LogContextService<S, R> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
tagger: self.tagger,
}
}
}
impl<S, R> LogContextService<S, R> {
pub fn new(inner: S, tagger: fn(&R) -> Cow<'static, str>) -> Self {
Self { inner, tagger }
}
}
impl<S, R> Service<R> for LogContextService<S, R>
where
S: Service<R>,
{
type Response = S::Response;
type Error = S::Error;
type Future = LogContextFuture<S::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: R) -> Self::Future {
let tag = (self.tagger)(&req);
let log_context = LogContext::new(tag);
log_context.run(|| self.inner.call(req))
}
}

View File

@@ -111,7 +111,6 @@ impl Mailer {
email.to = %to,
email.language = %context.language(),
),
err,
)]
pub async fn send_verification_email(
&self,
@@ -137,7 +136,6 @@ impl Mailer {
user.id = %context.user().id,
user_recovery_session.id = %context.session().id,
),
err,
)]
pub async fn send_recovery_email(
&self,
@@ -154,7 +152,7 @@ impl Mailer {
/// # Errors
///
/// Returns an error if the connection failed
#[tracing::instrument(name = "email.test_connection", skip_all, err)]
#[tracing::instrument(name = "email.test_connection", skip_all)]
pub async fn test_connection(&self) -> Result<(), crate::transport::Error> {
self.transport.test_connection().await
}

View File

@@ -90,6 +90,7 @@ ulid.workspace = true
mas-axum-utils.workspace = true
mas-config.workspace = true
mas-context.workspace = true
mas-data-model.workspace = true
mas-http.workspace = true
mas-i18n.workspace = true

View File

@@ -15,6 +15,7 @@ use axum::{
use axum_extra::TypedHeader;
use headers::{Authorization, authorization::Bearer};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_data_model::{Session, User};
use mas_storage::{BoxClock, BoxRepository, RepositoryError};
use ulid::Ulid;
@@ -69,27 +70,35 @@ pub enum Rejection {
MissingScope,
}
impl Rejection {
fn status_code(&self) -> StatusCode {
match self {
Self::InvalidAuthorizationHeader | Self::MissingAuthorizationHeader => {
StatusCode::BAD_REQUEST
}
Self::UnknownAccessToken
| Self::TokenExpired
| Self::SessionRevoked
| Self::UserLocked
| Self::MissingScope => StatusCode::UNAUTHORIZED,
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
}
}
impl IntoResponse for Rejection {
fn into_response(self) -> Response {
let response = ErrorResponse::from_error(&self);
let status = self.status_code();
(status, Json(response)).into_response()
let sentry_event_id = record_error!(
self,
Self::RepositorySetup(_)
| Self::Repository(_)
| Self::LoadSession(_)
| Self::LoadUser(_)
);
let status = match &self {
Rejection::InvalidAuthorizationHeader | Rejection::MissingAuthorizationHeader => {
StatusCode::BAD_REQUEST
}
Rejection::UnknownAccessToken
| Rejection::TokenExpired
| Rejection::SessionRevoked
| Rejection::UserLocked
| Rejection::MissingScope => StatusCode::UNAUTHORIZED,
Rejection::RepositorySetup(_)
| Rejection::Repository(_)
| Rejection::LoadSession(_)
| Rejection::LoadUser(_) => StatusCode::INTERNAL_SERVER_ERROR,
};
(status, sentry_event_id, Json(response)).into_response()
}
}

View File

@@ -6,6 +6,7 @@
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use ulid::Ulid;
use crate::{
@@ -33,11 +34,13 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let status = match self {
let sentry_event_id = record_error!(self, RouteError::Internal(_));
let status = match &self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -59,7 +62,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.compat_sessions.get", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.compat_sessions.get", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
id: UlidPathParam,

View File

@@ -11,6 +11,7 @@ use axum::{
};
use axum_macros::FromRequestParts;
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_storage::{Page, compat::CompatSessionFilter};
use schemars::JsonSchema;
use serde::Deserialize;
@@ -113,12 +114,14 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let status = match self {
let sentry_event_id = record_error!(self, RouteError::Internal(_));
let status = match &self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::UserNotFound(_) | Self::UserSessionNotFound(_) => StatusCode::NOT_FOUND,
Self::InvalidFilter(_) => StatusCode::BAD_REQUEST,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -153,7 +156,7 @@ Use the `filter[status]` parameter to filter the sessions by their status and `p
})
}
#[tracing::instrument(name = "handler.admin.v1.compat_sessions.list", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.compat_sessions.list", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
Pagination(pagination): Pagination,

View File

@@ -7,6 +7,7 @@
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use ulid::Ulid;
use crate::{
@@ -34,11 +35,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, RouteError::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -60,7 +62,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.oauth2_session.get", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.oauth2_session.get", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
id: UlidPathParam,

View File

@@ -14,6 +14,7 @@ use axum::{
};
use axum_macros::FromRequestParts;
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_storage::{Page, oauth2::OAuth2SessionFilter};
use oauth2_types::scope::{Scope, ScopeToken};
use schemars::JsonSchema;
@@ -167,6 +168,7 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, RouteError::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::UserNotFound(_) | Self::ClientNotFound(_) | Self::UserSessionNotFound(_) => {
@@ -174,7 +176,7 @@ impl IntoResponse for RouteError {
}
Self::InvalidScope(_) | Self::InvalidFilter(_) => StatusCode::BAD_REQUEST,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -213,7 +215,7 @@ Use the `filter[status]` parameter to filter the sessions by their status and `p
})
}
#[tracing::instrument(name = "handler.admin.v1.oauth2_sessions.list", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.oauth2_sessions.list", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
Pagination(pagination): Pagination,

View File

@@ -5,6 +5,7 @@
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use ulid::Ulid;
use crate::{
@@ -32,11 +33,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -56,7 +58,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.policy_data.get", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.policy_data.get", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
id: UlidPathParam,

View File

@@ -5,6 +5,7 @@
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use crate::{
admin::{
@@ -30,11 +31,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -55,7 +57,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.policy_data.get_latest", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.policy_data.get_latest", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
) -> Result<Json<SingleResponse<PolicyData>>, RouteError> {

View File

@@ -7,6 +7,7 @@ use std::sync::Arc;
use aide::{NoApi, OperationIo, transform::TransformOperation};
use axum::{Json, extract::State, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_policy::PolicyFactory;
use mas_storage::BoxRng;
use schemars::JsonSchema;
@@ -36,11 +37,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
RouteError::InvalidPolicyData(_) => StatusCode::BAD_REQUEST,
RouteError::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -79,7 +81,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.policy_data.set", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.policy_data.set", skip_all)]
pub async fn handler(
CallContext {
mut repo, clock, ..

View File

@@ -6,6 +6,7 @@
use aide::{NoApi, OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_storage::BoxRng;
use schemars::JsonSchema;
use serde::Deserialize;
@@ -41,12 +42,13 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::LinkAlreadyExists(_, _) => StatusCode::CONFLICT,
Self::UserNotFound(_) | Self::ProviderNotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -102,7 +104,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.post", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.post", skip_all)]
pub async fn handler(
CallContext {
mut repo, clock, ..

View File

@@ -6,6 +6,7 @@
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use ulid::Ulid;
use crate::{
@@ -28,11 +29,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -49,7 +51,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.delete", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.delete", skip_all)]
pub async fn handler(
CallContext {
mut repo, clock, ..

View File

@@ -6,6 +6,7 @@
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use ulid::Ulid;
use crate::{
@@ -33,11 +34,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_entry_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_entry_id, Json(error)).into_response()
}
}
@@ -59,7 +61,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.get", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.get", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
id: UlidPathParam,

View File

@@ -11,6 +11,7 @@ use axum::{
};
use axum_macros::FromRequestParts;
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_storage::{Page, upstream_oauth2::UpstreamOAuthLinkFilter};
use schemars::JsonSchema;
use serde::Deserialize;
@@ -91,12 +92,13 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::UserNotFound(_) | Self::ProviderNotFound(_) => StatusCode::NOT_FOUND,
Self::InvalidFilter(_) => StatusCode::BAD_REQUEST,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -130,7 +132,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.list", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.list", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
Pagination(pagination): Pagination,

View File

@@ -8,6 +8,7 @@ use std::str::FromStr as _;
use aide::{NoApi, OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_storage::{
BoxRng,
queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
@@ -52,13 +53,14 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::EmailAlreadyInUse(_) => StatusCode::CONFLICT,
Self::EmailNotValid { .. } => StatusCode::BAD_REQUEST,
Self::UserNotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -106,7 +108,7 @@ Note that this endpoint ignores any policy which would normally prevent the emai
})
}
#[tracing::instrument(name = "handler.admin.v1.user_emails.add", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.user_emails.add", skip_all)]
pub async fn handler(
CallContext {
mut repo, clock, ..

View File

@@ -6,6 +6,7 @@
use aide::{NoApi, OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_storage::{
BoxRng,
queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
@@ -32,11 +33,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -52,7 +54,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.user_emails.delete", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.user_emails.delete", skip_all)]
pub async fn handler(
CallContext {
mut repo, clock, ..

View File

@@ -6,6 +6,7 @@
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use ulid::Ulid;
use crate::{
@@ -33,11 +34,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -57,7 +59,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.user_emails.get", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.user_emails.get", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
id: UlidPathParam,

View File

@@ -11,6 +11,7 @@ use axum::{
};
use axum_macros::FromRequestParts;
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_storage::{Page, user::UserEmailFilter};
use schemars::JsonSchema;
use serde::Deserialize;
@@ -78,12 +79,13 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::UserNotFound(_) => StatusCode::NOT_FOUND,
Self::InvalidFilter(_) => StatusCode::BAD_REQUEST,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -116,7 +118,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.user_emails.list", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.user_emails.list", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
Pagination(pagination): Pagination,

View File

@@ -6,6 +6,7 @@
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use ulid::Ulid;
use crate::{
@@ -33,11 +34,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -58,7 +60,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.user_sessions.get", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.user_sessions.get", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
id: UlidPathParam,

View File

@@ -11,6 +11,7 @@ use axum::{
};
use axum_macros::FromRequestParts;
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_storage::{pagination::Page, user::BrowserSessionFilter};
use schemars::JsonSchema;
use serde::Deserialize;
@@ -100,12 +101,13 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::UserNotFound(_) => StatusCode::NOT_FOUND,
Self::InvalidFilter(_) => StatusCode::BAD_REQUEST,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -140,7 +142,7 @@ Use the `filter[status]` parameter to filter the sessions by their status and `p
})
}
#[tracing::instrument(name = "handler.admin.v1.user_sessions.list", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.user_sessions.list", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
Pagination(pagination): Pagination,

View File

@@ -9,6 +9,7 @@ use std::sync::Arc;
use aide::{NoApi, OperationIo, transform::TransformOperation};
use axum::{Json, extract::State, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_matrix::HomeserverConnection;
use mas_storage::{
BoxRng,
@@ -81,12 +82,13 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_) | Self::Homeserver(_));
let status = match self {
Self::Internal(_) | Self::Homeserver(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::UsernameNotValid => StatusCode::BAD_REQUEST,
Self::UserAlreadyExists | Self::UsernameReserved => StatusCode::CONFLICT,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -131,7 +133,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.users.add", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.users.add", skip_all)]
pub async fn handler(
CallContext {
mut repo, clock, ..

View File

@@ -7,6 +7,7 @@
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, extract::Path, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use schemars::JsonSchema;
use serde::Deserialize;
@@ -34,11 +35,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -65,7 +67,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.users.by_username", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.users.by_username", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
Path(UsernamePathParam { username }): Path<UsernamePathParam>,

View File

@@ -7,6 +7,7 @@
use aide::{NoApi, OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_storage::{
BoxRng,
queue::{DeactivateUserJob, QueueJobRepositoryExt as _},
@@ -39,11 +40,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -67,7 +69,7 @@ This invalidates any existing session, and will ask the homeserver to make them
})
}
#[tracing::instrument(name = "handler.admin.v1.users.deactivate", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.users.deactivate", skip_all)]
pub async fn handler(
CallContext {
mut repo, clock, ..
@@ -86,7 +88,7 @@ pub async fn handler(
user = repo.user().lock(&clock, user).await?;
}
info!("Scheduling deactivation of user {}", user.id);
info!(%user.id, "Scheduling deactivation of user");
repo.queue_job()
.schedule_job(&mut rng, &clock, DeactivateUserJob::new(&user, true))
.await?;

View File

@@ -7,6 +7,7 @@
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use ulid::Ulid;
use crate::{
@@ -34,11 +35,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -58,7 +60,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.users.get", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.users.get", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
id: UlidPathParam,

View File

@@ -12,6 +12,7 @@ use axum::{
};
use axum_macros::FromRequestParts;
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_storage::{Page, user::UserFilter};
use schemars::JsonSchema;
use serde::Deserialize;
@@ -95,11 +96,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::InvalidFilter(_) => StatusCode::BAD_REQUEST,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -122,7 +124,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.users.list", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.users.list", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
Pagination(pagination): Pagination,

View File

@@ -7,6 +7,7 @@
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use ulid::Ulid;
use crate::{
@@ -34,11 +35,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -62,7 +64,7 @@ This DOES NOT invalidate any existing session, meaning that all their existing s
})
}
#[tracing::instrument(name = "handler.admin.v1.users.lock", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.users.lock", skip_all)]
pub async fn handler(
CallContext {
mut repo, clock, ..

View File

@@ -7,6 +7,7 @@
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use schemars::JsonSchema;
use serde::Deserialize;
use ulid::Ulid;
@@ -36,11 +37,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -71,7 +73,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.users.set_admin", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.users.set_admin", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
id: UlidPathParam,

View File

@@ -7,6 +7,7 @@
use aide::{NoApi, OperationIo, transform::TransformOperation};
use axum::{Json, extract::State, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_storage::BoxRng;
use schemars::JsonSchema;
use serde::Deserialize;
@@ -43,13 +44,14 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_) | Self::Password(_));
let status = match self {
Self::Internal(_) | Self::Password(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::PasswordAuthDisabled => StatusCode::FORBIDDEN,
Self::PasswordTooWeak => StatusCode::BAD_REQUEST,
Self::NotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -90,7 +92,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.users.set_password", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.users.set_password", skip_all)]
pub async fn handler(
CallContext {
mut repo, clock, ..

View File

@@ -9,6 +9,7 @@ use std::sync::Arc;
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, extract::State, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_matrix::HomeserverConnection;
use ulid::Ulid;
@@ -40,11 +41,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_) | Self::Homeserver(_));
let status = match self {
Self::Internal(_) | Self::Homeserver(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
};
(status, Json(error)).into_response()
(status, sentry_event_id, Json(error)).into_response()
}
}
@@ -66,7 +68,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
})
}
#[tracing::instrument(name = "handler.admin.v1.users.unlock", skip_all, err)]
#[tracing::instrument(name = "handler.admin.v1.users.unlock", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
State(homeserver): State<Arc<dyn HomeserverConnection>>,

View File

@@ -156,7 +156,6 @@ impl Form {
skip_all,
name = "captcha.verify",
fields(captcha.hostname, captcha.challenge_ts, captcha.service),
err
)]
pub async fn verify(
&self,

View File

@@ -10,7 +10,7 @@ use axum::{Json, extract::State, response::IntoResponse};
use axum_extra::typed_header::TypedHeader;
use chrono::Duration;
use hyper::StatusCode;
use mas_axum_utils::sentry::SentryEventID;
use mas_axum_utils::record_error;
use mas_data_model::{
CompatSession, CompatSsoLoginState, Device, SiteConfig, TokenType, User, UserAgent,
};
@@ -210,7 +210,8 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let sentry_event_id =
record_error!(self, Self::Internal(_) | Self::ProvisionDeviceFailed(_));
LOGIN_COUNTER.add(1, &[KeyValue::new(RESULT, "error")]);
let response = match self {
Self::Internal(_) | Self::ProvisionDeviceFailed(_) => MatrixError {
@@ -257,11 +258,11 @@ impl IntoResponse for RouteError {
},
};
(SentryEventID::from(event_id), response).into_response()
(sentry_event_id, response).into_response()
}
}
#[tracing::instrument(name = "handlers.compat.login.post", skip_all, err)]
#[tracing::instrument(name = "handlers.compat.login.post", skip_all)]
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,

View File

@@ -48,7 +48,6 @@ pub struct Params {
name = "handlers.compat.login_sso_complete.get",
fields(compat_sso_login.id = %id),
skip_all,
err,
)]
pub async fn get(
PreferredLanguage(locale): PreferredLanguage,
@@ -121,7 +120,6 @@ pub async fn get(
name = "handlers.compat.login_sso_complete.post",
fields(compat_sso_login.id = %id),
skip_all,
err,
)]
pub async fn post(
mut rng: BoxRng,

View File

@@ -9,7 +9,7 @@ use axum::{
response::IntoResponse,
};
use hyper::StatusCode;
use mas_axum_utils::sentry::SentryEventID;
use mas_axum_utils::record_error;
use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
use mas_storage::{BoxClock, BoxRepository, BoxRng, compat::CompatSsoLoginRepository};
use rand::distributions::{Alphanumeric, DistString};
@@ -43,17 +43,16 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
(
StatusCode::INTERNAL_SERVER_ERROR,
SentryEventID::from(event_id),
format!("{self}"),
)
.into_response()
let sentry_event_id = record_error!(self, Self::Internal(_));
let status_code = match &self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::MissingRedirectUrl | Self::InvalidRedirectUrl => StatusCode::BAD_REQUEST,
};
(status_code, sentry_event_id, format!("{self}")).into_response()
}
}
#[tracing::instrument(name = "handlers.compat.login_sso_redirect.get", skip_all, err)]
#[tracing::instrument(name = "handlers.compat.login_sso_redirect.get", skip_all)]
pub async fn get(
mut rng: BoxRng,
clock: BoxClock,

View File

@@ -10,7 +10,7 @@ use axum::{Json, response::IntoResponse};
use axum_extra::typed_header::TypedHeader;
use headers::{Authorization, authorization::Bearer};
use hyper::StatusCode;
use mas_axum_utils::sentry::SentryEventID;
use mas_axum_utils::record_error;
use mas_data_model::TokenType;
use mas_storage::{
BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
@@ -51,7 +51,7 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
LOGOUT_COUNTER.add(1, &[KeyValue::new(RESULT, "error")]);
let response = match self {
Self::Internal(_) => MatrixError {
@@ -71,11 +71,11 @@ impl IntoResponse for RouteError {
},
};
(SentryEventID::from(event_id), response).into_response()
(sentry_event_id, response).into_response()
}
}
#[tracing::instrument(name = "handlers.compat.logout.post", skip_all, err)]
#[tracing::instrument(name = "handlers.compat.logout.post", skip_all)]
pub(crate) async fn post(
clock: BoxClock,
mut rng: BoxRng,

View File

@@ -14,7 +14,7 @@ use axum::{
response::IntoResponse,
};
use hyper::{StatusCode, header};
use mas_axum_utils::sentry::SentryEventID;
use mas_axum_utils::record_error;
use serde::{Serialize, de::DeserializeOwned};
use thiserror::Error;
@@ -59,7 +59,7 @@ pub enum MatrixJsonBodyRejection {
impl IntoResponse for MatrixJsonBodyRejection {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let sentry_event_id = record_error!(self, !);
let response = match self {
Self::InvalidContentType | Self::ContentTypeNotJson(_) => MatrixError {
errcode: "M_NOT_JSON",
@@ -102,7 +102,7 @@ impl IntoResponse for MatrixJsonBodyRejection {
},
};
(SentryEventID::from(event_id), response).into_response()
(sentry_event_id, response).into_response()
}
}

View File

@@ -7,7 +7,7 @@
use axum::{Json, extract::State, response::IntoResponse};
use chrono::Duration;
use hyper::StatusCode;
use mas_axum_utils::sentry::SentryEventID;
use mas_axum_utils::record_error;
use mas_data_model::{SiteConfig, TokenFormatError, TokenType};
use mas_storage::{
BoxClock, BoxRepository, BoxRng, Clock,
@@ -16,6 +16,7 @@ use mas_storage::{
use serde::{Deserialize, Serialize};
use serde_with::{DurationMilliSeconds, serde_as};
use thiserror::Error;
use ulid::Ulid;
use super::MatrixError;
use crate::{BoundActivityTracker, impl_from_error_for_route};
@@ -31,46 +32,50 @@ pub enum RouteError {
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("invalid token")]
InvalidToken,
InvalidToken(#[from] TokenFormatError),
#[error("refresh token already consumed")]
RefreshTokenConsumed,
#[error("unknown token")]
UnknownToken,
#[error("invalid session")]
InvalidSession,
#[error("invalid token type {0}, expected a compat refresh token")]
InvalidTokenType(TokenType),
#[error("unknown session")]
UnknownSession,
#[error("refresh token already consumed {0}")]
RefreshTokenConsumed(Ulid),
#[error("invalid compat session {0}")]
InvalidSession(Ulid),
#[error("unknown comapt session {0}")]
UnknownSession(Ulid),
}
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_) | Self::UnknownSession(_));
let response = match self {
Self::Internal(_) | Self::UnknownSession => MatrixError {
Self::Internal(_) | Self::UnknownSession(_) => MatrixError {
errcode: "M_UNKNOWN",
error: "Internal error",
status: StatusCode::INTERNAL_SERVER_ERROR,
},
Self::InvalidToken | Self::InvalidSession | Self::RefreshTokenConsumed => MatrixError {
Self::InvalidToken(_)
| Self::UnknownToken
| Self::InvalidTokenType(_)
| Self::InvalidSession(_)
| Self::RefreshTokenConsumed(_) => MatrixError {
errcode: "M_UNKNOWN_TOKEN",
error: "Invalid refresh token",
status: StatusCode::UNAUTHORIZED,
},
};
(SentryEventID::from(event_id), response).into_response()
(sentry_event_id, response).into_response()
}
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl From<TokenFormatError> for RouteError {
fn from(_e: TokenFormatError) -> Self {
Self::InvalidToken
}
}
#[serde_as]
#[derive(Debug, Serialize)]
pub struct ResponseBody {
@@ -80,7 +85,7 @@ pub struct ResponseBody {
expires_in_ms: Duration,
}
#[tracing::instrument(name = "handlers.compat.refresh.post", skip_all, err)]
#[tracing::instrument(name = "handlers.compat.refresh.post", skip_all)]
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
@@ -92,27 +97,27 @@ pub(crate) async fn post(
let token_type = TokenType::check(&input.refresh_token)?;
if token_type != TokenType::CompatRefreshToken {
return Err(RouteError::InvalidToken);
return Err(RouteError::InvalidTokenType(token_type));
}
let refresh_token = repo
.compat_refresh_token()
.find_by_token(&input.refresh_token)
.await?
.ok_or(RouteError::InvalidToken)?;
.ok_or(RouteError::UnknownToken)?;
if !refresh_token.is_valid() {
return Err(RouteError::RefreshTokenConsumed);
return Err(RouteError::RefreshTokenConsumed(refresh_token.id));
}
let session = repo
.compat_session()
.lookup(refresh_token.session_id)
.await?
.ok_or(RouteError::UnknownSession)?;
.ok_or(RouteError::UnknownSession(refresh_token.session_id))?;
if !session.is_valid() {
return Err(RouteError::InvalidSession);
return Err(RouteError::InvalidSession(refresh_token.session_id));
}
activity_tracker

View File

@@ -552,7 +552,7 @@ impl UserMutations {
let user = repo.user().lock(&state.clock(), user).await?;
if deactivate {
info!("Scheduling deactivation of user {}", user.id);
info!(%user.id, "Scheduling deactivation of user");
repo.queue_job()
.schedule_job(&mut rng, &clock, DeactivateUserJob::new(&user, deactivate))
.await?;

View File

@@ -13,7 +13,7 @@ use hyper::StatusCode;
use mas_axum_utils::{
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
sentry::SentryEventID,
record_error,
};
use mas_data_model::AuthorizationGrantStage;
use mas_keystore::Keystore;
@@ -46,11 +46,11 @@ pub enum RouteError {
#[error("Authorization grant not found")]
GrantNotFound,
#[error("Authorization grant already used")]
GrantNotPending,
#[error("Authorization grant {0} already used")]
GrantNotPending(Ulid),
#[error("Failed to load client")]
NoSuchClient,
#[error("Failed to load client {0}")]
NoSuchClient(Ulid),
}
impl_from_error_for_route!(mas_templates::TemplateError);
@@ -64,10 +64,10 @@ impl_from_error_for_route!(super::callback::CallbackDestinationError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_) | Self::NoSuchClient(_));
(
StatusCode::INTERNAL_SERVER_ERROR,
SentryEventID::from(event_id),
sentry_event_id,
self.to_string(),
)
.into_response()
@@ -78,7 +78,6 @@ impl IntoResponse for RouteError {
name = "handlers.oauth2.authorization.consent.get",
fields(grant.id = %grant_id),
skip_all,
err,
)]
pub(crate) async fn get(
mut rng: BoxRng,
@@ -118,10 +117,10 @@ pub(crate) async fn get(
.oauth2_client()
.lookup(grant.client_id)
.await?
.ok_or(RouteError::NoSuchClient)?;
.ok_or(RouteError::NoSuchClient(grant.client_id))?;
if !matches!(grant.stage, AuthorizationGrantStage::Pending) {
return Err(RouteError::GrantNotPending);
return Err(RouteError::GrantNotPending(grant.id));
}
let Some(session) = maybe_session else {
@@ -172,7 +171,6 @@ pub(crate) async fn get(
name = "handlers.oauth2.authorization.consent.post",
fields(grant.id = %grant_id),
skip_all,
err,
)]
pub(crate) async fn post(
mut rng: BoxRng,
@@ -229,7 +227,11 @@ pub(crate) async fn post(
.oauth2_client()
.lookup(grant.client_id)
.await?
.ok_or(RouteError::NoSuchClient)?;
.ok_or(RouteError::NoSuchClient(grant.client_id))?;
if !matches!(grant.stage, AuthorizationGrantStage::Pending) {
return Err(RouteError::GrantNotPending(grant.id));
}
let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {

View File

@@ -9,7 +9,7 @@ use axum::{
response::{IntoResponse, Response},
};
use hyper::StatusCode;
use mas_axum_utils::{SessionInfoExt, cookies::CookieJar, sentry::SentryEventID};
use mas_axum_utils::{SessionInfoExt, cookies::CookieJar, record_error};
use mas_data_model::{AuthorizationCode, Pkce};
use mas_router::{PostAuthAction, UrlBuilder};
use mas_storage::{
@@ -53,7 +53,7 @@ pub enum RouteError {
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
// TODO: better error pages
let response = match self {
RouteError::Internal(e) => {
@@ -75,7 +75,7 @@ impl IntoResponse for RouteError {
.into_response(),
};
(SentryEventID::from(event_id), response).into_response()
(sentry_event_id, response).into_response()
}
}
@@ -122,7 +122,6 @@ fn resolve_response_mode(
name = "handlers.oauth2.authorization.get",
fields(client.id = %params.auth.client_id),
skip_all,
err,
)]
#[allow(clippy::too_many_lines)]
pub(crate) async fn get(
@@ -319,7 +318,7 @@ pub(crate) async fn get(
let response = match res {
Ok(r) => r,
Err(err) => {
tracing::error!(%err);
tracing::error!(message = &err as &dyn std::error::Error);
callback_destination.go(
&templates,
&locale,

View File

@@ -11,7 +11,7 @@ use headers::{CacheControl, Pragma};
use hyper::StatusCode;
use mas_axum_utils::{
client_authorization::{ClientAuthorization, CredentialsVerificationError},
sentry::SentryEventID,
record_error,
};
use mas_data_model::UserAgent;
use mas_keystore::Encrypter;
@@ -24,6 +24,7 @@ use oauth2_types::{
};
use rand::distributions::{Alphanumeric, DistString};
use thiserror::Error;
use ulid::Ulid;
use crate::{BoundActivityTracker, impl_from_error_for_route};
@@ -35,35 +36,46 @@ pub(crate) enum RouteError {
#[error("client not found")]
ClientNotFound,
#[error("client not allowed")]
ClientNotAllowed,
#[error("client {0} is not allowed to use the device code grant")]
ClientNotAllowed(Ulid),
#[error("could not verify client credentials")]
ClientCredentialsVerification(#[from] CredentialsVerificationError),
#[error("invalid client credentials for client {client_id}")]
InvalidClientCredentials {
client_id: Ulid,
#[source]
source: CredentialsVerificationError,
},
#[error("could not verify client credentials for client {client_id}")]
ClientCredentialsVerification {
client_id: Ulid,
#[source]
source: CredentialsVerificationError,
},
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let response = match self {
Self::Internal(_) => (
Self::Internal(_) | Self::ClientCredentialsVerification { .. } => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ClientError::from(ClientErrorCode::ServerError)),
),
Self::ClientNotFound | Self::ClientCredentialsVerification(_) => (
Self::ClientNotFound | Self::InvalidClientCredentials { .. } => (
StatusCode::UNAUTHORIZED,
Json(ClientError::from(ClientErrorCode::InvalidClient)),
),
Self::ClientNotAllowed => (
Self::ClientNotAllowed(_) => (
StatusCode::UNAUTHORIZED,
Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
),
};
(SentryEventID::from(event_id), response).into_response()
(sentry_event_id, response).into_response()
}
}
@@ -71,7 +83,6 @@ impl IntoResponse for RouteError {
name = "handlers.oauth2.device.request.post",
fields(client.id = client_authorization.client_id()),
skip_all,
err,
)]
pub(crate) async fn post(
mut rng: BoxRng,
@@ -94,15 +105,28 @@ pub(crate) async fn post(
let method = client
.token_endpoint_auth_method
.as_ref()
.ok_or(RouteError::ClientNotAllowed)?;
.ok_or(RouteError::ClientNotAllowed(client.id))?;
client_authorization
.credentials
.verify(&http_client, &encrypter, method, &client)
.await?;
.await
.map_err(|err| {
if err.is_internal() {
RouteError::ClientCredentialsVerification {
client_id: client.id,
source: err,
}
} else {
RouteError::InvalidClientCredentials {
client_id: client.id,
source: err,
}
}
})?;
if !client.grant_types.contains(&GrantType::DeviceCode) {
return Err(RouteError::ClientNotAllowed);
return Err(RouteError::ClientNotAllowed(client.id));
}
let scope = client_authorization

View File

@@ -41,6 +41,7 @@ pub(crate) struct ConsentForm {
action: Action,
}
#[tracing::instrument(name = "handlers.oauth2.device.consent.get", skip_all)]
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
@@ -136,6 +137,7 @@ pub(crate) async fn get(
Ok((cookie_jar, Html(rendered)).into_response())
}
#[tracing::instrument(name = "handlers.oauth2.device.consent.post", skip_all)]
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,

View File

@@ -24,7 +24,7 @@ pub struct Params {
code: Option<String>,
}
#[tracing::instrument(name = "handlers.oauth2.device.link.get", skip_all, err)]
#[tracing::instrument(name = "handlers.oauth2.device.link.get", skip_all)]
pub(crate) async fn get(
clock: BoxClock,
mut repo: BoxRepository,

View File

@@ -10,7 +10,7 @@ use axum::{Json, extract::State, http::HeaderValue, response::IntoResponse};
use hyper::{HeaderMap, StatusCode};
use mas_axum_utils::{
client_authorization::{ClientAuthorization, CredentialsVerificationError},
sentry::SentryEventID,
record_error,
};
use mas_data_model::{Device, TokenFormatError, TokenType};
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
@@ -28,6 +28,7 @@ use oauth2_types::{
};
use opentelemetry::{Key, KeyValue, metrics::Counter};
use thiserror::Error;
use ulid::Ulid;
use crate::{ActivityTracker, METER, impl_from_error_for_route};
@@ -53,8 +54,8 @@ pub enum RouteError {
ClientNotFound,
/// The client is not allowed to introspect.
#[error("client is not allowed to introspect")]
NotAllowed,
#[error("client {0} is not allowed to introspect")]
NotAllowed(Ulid),
/// The token type is not the one expected.
#[error("unexpected token type")]
@@ -73,30 +74,30 @@ pub enum RouteError {
InvalidToken(TokenType),
/// The OAuth session is not valid.
#[error("invalid oauth session")]
InvalidOAuthSession,
#[error("invalid oauth session {0}")]
InvalidOAuthSession(Ulid),
/// The OAuth session could not be found in the database.
#[error("unknown oauth session")]
CantLoadOAuthSession,
#[error("unknown oauth session {0}")]
CantLoadOAuthSession(Ulid),
/// The compat session is not valid.
#[error("invalid compat session")]
InvalidCompatSession,
#[error("invalid compat session {0}")]
InvalidCompatSession(Ulid),
/// The compat session could not be found in the database.
#[error("unknown compat session")]
CantLoadCompatSession,
#[error("unknown compat session {0}")]
CantLoadCompatSession(Ulid),
/// The Device ID in the compat session can't be encoded as a scope
#[error("device ID contains characters that are not allowed in a scope")]
CantEncodeDeviceID(#[from] mas_data_model::ToScopeTokenError),
#[error("invalid user")]
InvalidUser,
#[error("invalid user {0}")]
InvalidUser(Ulid),
#[error("unknown user")]
CantLoadUser,
#[error("unknown user {0}")]
CantLoadUser(Ulid),
#[error("bad request")]
BadRequest,
@@ -107,12 +108,19 @@ pub enum RouteError {
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let sentry_event_id = record_error!(
self,
Self::Internal(_)
| Self::CantLoadCompatSession(_)
| Self::CantLoadOAuthSession(_)
| Self::CantLoadUser(_)
);
let response = match self {
e @ (Self::Internal(_)
| Self::CantLoadCompatSession
| Self::CantLoadOAuthSession
| Self::CantLoadUser) => (
| Self::CantLoadCompatSession(_)
| Self::CantLoadOAuthSession(_)
| Self::CantLoadUser(_)) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(
ClientError::from(ClientErrorCode::ServerError).with_description(e.to_string()),
@@ -136,9 +144,9 @@ impl IntoResponse for RouteError {
Self::UnknownToken(_)
| Self::UnexpectedTokenType
| Self::InvalidToken(_)
| Self::InvalidUser
| Self::InvalidCompatSession
| Self::InvalidOAuthSession
| Self::InvalidUser(_)
| Self::InvalidCompatSession(_)
| Self::InvalidOAuthSession(_)
| Self::InvalidTokenFormat(_)
| Self::CantEncodeDeviceID(_) => {
INTROSPECTION_COUNTER.add(1, &[KeyValue::new(ACTIVE.clone(), false)]);
@@ -146,11 +154,12 @@ impl IntoResponse for RouteError {
Json(INACTIVE).into_response()
}
Self::NotAllowed => (
Self::NotAllowed(_) => (
StatusCode::UNAUTHORIZED,
Json(ClientError::from(ClientErrorCode::AccessDenied)),
)
.into_response(),
Self::BadRequest => (
StatusCode::BAD_REQUEST,
Json(ClientError::from(ClientErrorCode::InvalidRequest)),
@@ -158,7 +167,7 @@ impl IntoResponse for RouteError {
.into_response(),
};
(SentryEventID::from(event_id), response).into_response()
(sentry_event_id, response).into_response()
}
}
@@ -188,7 +197,6 @@ const SYNAPSE_ADMIN_SCOPE: ScopeToken = ScopeToken::from_static("urn:synapse:adm
name = "handlers.oauth2.introspection.post",
fields(client.id = client_authorization.client_id()),
skip_all,
err,
)]
#[allow(clippy::too_many_lines)]
pub(crate) async fn post(
@@ -208,7 +216,7 @@ pub(crate) async fn post(
let method = match &client.token_endpoint_auth_method {
None | Some(OAuthClientAuthenticationMethod::None) => {
return Err(RouteError::NotAllowed);
return Err(RouteError::NotAllowed(client.id));
}
Some(c) => c,
};
@@ -259,10 +267,10 @@ pub(crate) async fn post(
.oauth2_session()
.lookup(access_token.session_id)
.await?
.ok_or(RouteError::InvalidOAuthSession)?;
.ok_or(RouteError::CantLoadOAuthSession(access_token.session_id))?;
if !session.is_valid() {
return Err(RouteError::InvalidOAuthSession);
return Err(RouteError::InvalidOAuthSession(session.id));
}
// If this is the first time we're using this token, mark it as used
@@ -280,10 +288,10 @@ pub(crate) async fn post(
.user()
.lookup(user_id)
.await?
.ok_or(RouteError::CantLoadUser)?;
.ok_or(RouteError::CantLoadUser(user_id))?;
if !user.is_valid() {
return Err(RouteError::InvalidUser);
return Err(RouteError::InvalidUser(user.id));
}
(Some(user.sub), Some(user.username))
@@ -338,10 +346,10 @@ pub(crate) async fn post(
.oauth2_session()
.lookup(refresh_token.session_id)
.await?
.ok_or(RouteError::CantLoadOAuthSession)?;
.ok_or(RouteError::CantLoadOAuthSession(refresh_token.session_id))?;
if !session.is_valid() {
return Err(RouteError::InvalidOAuthSession);
return Err(RouteError::InvalidOAuthSession(session.id));
}
// The session might not have a user on it (for Client Credentials grants for
@@ -351,10 +359,10 @@ pub(crate) async fn post(
.user()
.lookup(user_id)
.await?
.ok_or(RouteError::CantLoadUser)?;
.ok_or(RouteError::CantLoadUser(user_id))?;
if !user.is_valid() {
return Err(RouteError::InvalidUser);
return Err(RouteError::InvalidUser(user.id));
}
(Some(user.sub), Some(user.username))
@@ -407,20 +415,20 @@ pub(crate) async fn post(
.compat_session()
.lookup(access_token.session_id)
.await?
.ok_or(RouteError::CantLoadCompatSession)?;
.ok_or(RouteError::CantLoadCompatSession(access_token.session_id))?;
if !session.is_valid() {
return Err(RouteError::InvalidCompatSession);
return Err(RouteError::InvalidCompatSession(session.id));
}
let user = repo
.user()
.lookup(session.user_id)
.await?
.ok_or(RouteError::CantLoadUser)?;
.ok_or(RouteError::CantLoadUser(session.user_id))?;
if !user.is_valid() {
return Err(RouteError::InvalidUser)?;
return Err(RouteError::InvalidUser(user.id))?;
}
// Grant the synapse admin scope if the session has the admin flag set.
@@ -491,20 +499,20 @@ pub(crate) async fn post(
.compat_session()
.lookup(refresh_token.session_id)
.await?
.ok_or(RouteError::CantLoadCompatSession)?;
.ok_or(RouteError::CantLoadCompatSession(refresh_token.session_id))?;
if !session.is_valid() {
return Err(RouteError::InvalidCompatSession);
return Err(RouteError::InvalidCompatSession(session.id));
}
let user = repo
.user()
.lookup(session.user_id)
.await?
.ok_or(RouteError::CantLoadUser)?;
.ok_or(RouteError::CantLoadUser(session.user_id))?;
if !user.is_valid() {
return Err(RouteError::InvalidUser)?;
return Err(RouteError::InvalidUser(user.id))?;
}
// Grant the synapse admin scope if the session has the admin flag set.

View File

@@ -9,10 +9,10 @@ use std::sync::LazyLock;
use axum::{Json, extract::State, response::IntoResponse};
use axum_extra::TypedHeader;
use hyper::StatusCode;
use mas_axum_utils::sentry::SentryEventID;
use mas_axum_utils::record_error;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_keystore::Encrypter;
use mas_policy::{Policy, Violation};
use mas_policy::{EvaluationResult, Policy};
use mas_storage::{BoxClock, BoxRepository, BoxRng, oauth2::OAuth2ClientRepository};
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
@@ -55,8 +55,8 @@ pub(crate) enum RouteError {
#[error("{0} is a public suffix, not a valid domain")]
UrlIsPublicSuffix(&'static str),
#[error("denied by the policy: {0:?}")]
PolicyDenied(Vec<Violation>),
#[error("client registration denied by the policy: {0}")]
PolicyDenied(EvaluationResult),
}
impl_from_error_for_route!(mas_storage::RepositoryError);
@@ -67,7 +67,7 @@ impl_from_error_for_route!(serde_json::Error);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "denied")]);
@@ -143,15 +143,20 @@ impl IntoResponse for RouteError {
// For policy violations, we return an `invalid_client_metadata` error with the details
// of the violations in most cases. If a violation includes `redirect_uri` in the
// message, we return an `invalid_redirect_uri` error instead.
Self::PolicyDenied(violations) => {
Self::PolicyDenied(evaluation) => {
// TODO: detect them better
let code = if violations.iter().any(|v| v.msg.contains("redirect_uri")) {
let code = if evaluation
.violations
.iter()
.any(|v| v.msg.contains("redirect_uri"))
{
ClientErrorCode::InvalidRedirectUri
} else {
ClientErrorCode::InvalidClientMetadata
};
let collected = &violations
let collected = &evaluation
.violations
.iter()
.map(|v| v.msg.clone())
.collect::<Vec<String>>();
@@ -165,7 +170,7 @@ impl IntoResponse for RouteError {
}
};
(SentryEventID::from(event_id), response).into_response()
(sentry_event_id, response).into_response()
}
}
@@ -207,7 +212,7 @@ fn localised_url_has_public_suffix(url: &Localized<Url>) -> bool {
url.iter().any(|(_lang, url)| host_is_public_suffix(url))
}
#[tracing::instrument(name = "handlers.oauth2.registration.post", skip_all, err)]
#[tracing::instrument(name = "handlers.oauth2.registration.post", skip_all)]
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
@@ -282,7 +287,7 @@ pub(crate) async fn post(
})
.await?;
if !res.valid() {
return Err(RouteError::PolicyDenied(res.violations));
return Err(RouteError::PolicyDenied(res));
}
let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method {

View File

@@ -8,7 +8,7 @@ use axum::{Json, extract::State, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::{
client_authorization::{ClientAuthorization, CredentialsVerificationError},
sentry::SentryEventID,
record_error,
};
use mas_data_model::TokenType;
use mas_iana::oauth::OAuthTokenTypeHint;
@@ -22,6 +22,7 @@ use oauth2_types::{
requests::RevocationRequest,
};
use thiserror::Error;
use ulid::Ulid;
use crate::{BoundActivityTracker, impl_from_error_for_route};
@@ -39,8 +40,19 @@ pub(crate) enum RouteError {
#[error("client not allowed")]
ClientNotAllowed,
#[error("could not verify client credentials")]
ClientCredentialsVerification(#[from] CredentialsVerificationError),
#[error("invalid client credentials for client {client_id}")]
InvalidClientCredentials {
client_id: Ulid,
#[source]
source: CredentialsVerificationError,
},
#[error("could not verify client credentials for client {client_id}")]
ClientCredentialsVerification {
client_id: Ulid,
#[source]
source: CredentialsVerificationError,
},
#[error("client is unauthorized")]
UnauthorizedClient,
@@ -54,9 +66,9 @@ pub(crate) enum RouteError {
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let response = match self {
Self::Internal(_) => (
Self::Internal(_) | Self::ClientCredentialsVerification { .. } => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ClientError::from(ClientErrorCode::ServerError)),
)
@@ -68,7 +80,7 @@ impl IntoResponse for RouteError {
)
.into_response(),
Self::ClientNotFound | Self::ClientCredentialsVerification(_) => (
Self::ClientNotFound | Self::InvalidClientCredentials { .. } => (
StatusCode::UNAUTHORIZED,
Json(ClientError::from(ClientErrorCode::InvalidClient)),
)
@@ -90,7 +102,7 @@ impl IntoResponse for RouteError {
Self::UnknownToken => StatusCode::OK.into_response(),
};
(SentryEventID::from(event_id), response).into_response()
(sentry_event_id, response).into_response()
}
}
@@ -106,7 +118,6 @@ impl From<mas_data_model::TokenFormatError> for RouteError {
name = "handlers.oauth2.revoke.post",
fields(client.id = client_authorization.client_id()),
skip_all,
err,
)]
pub(crate) async fn post(
clock: BoxClock,
@@ -131,7 +142,20 @@ pub(crate) async fn post(
client_authorization
.credentials
.verify(&http_client, &encrypter, method, &client)
.await?;
.await
.map_err(|err| {
if err.is_internal() {
RouteError::ClientCredentialsVerification {
client_id: client.id,
source: err,
}
} else {
RouteError::InvalidClientCredentials {
client_id: client.id,
source: err,
}
}
})?;
let Some(form) = client_authorization.form else {
return Err(RouteError::BadRequest);

View File

@@ -13,7 +13,7 @@ use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma};
use hyper::StatusCode;
use mas_axum_utils::{
client_authorization::{ClientAuthorization, CredentialsVerificationError},
sentry::SentryEventID,
record_error,
};
use mas_data_model::{
AuthorizationGrantStage, Client, Device, DeviceCodeGrantState, SiteConfig, TokenType, UserAgent,
@@ -42,7 +42,7 @@ use oauth2_types::{
};
use opentelemetry::{Key, KeyValue, metrics::Counter};
use thiserror::Error;
use tracing::{debug, info};
use tracing::{debug, info, warn};
use ulid::Ulid;
use super::{generate_id_token, generate_token_pair};
@@ -72,17 +72,28 @@ pub(crate) enum RouteError {
#[error("client not found")]
ClientNotFound,
#[error("client not allowed")]
ClientNotAllowed,
#[error("client not allowed to use the token endpoint: {0}")]
ClientNotAllowed(Ulid),
#[error("could not verify client credentials")]
ClientCredentialsVerification(#[from] CredentialsVerificationError),
#[error("invalid client credentials for client {client_id}")]
InvalidClientCredentials {
client_id: Ulid,
#[source]
source: CredentialsVerificationError,
},
#[error("could not verify client credentials for client {client_id}")]
ClientCredentialsVerification {
client_id: Ulid,
#[source]
source: CredentialsVerificationError,
},
#[error("grant not found")]
GrantNotFound,
#[error("invalid grant")]
InvalidGrant,
#[error("invalid grant {0}")]
InvalidGrant(Ulid),
#[error("refresh token not found")]
RefreshTokenNotFound,
@@ -96,20 +107,23 @@ pub(crate) enum RouteError {
#[error("client id mismatch: expected {expected}, got {actual}")]
ClientIDMismatch { expected: Ulid, actual: Ulid },
#[error("policy denied the request")]
DeniedByPolicy(Vec<mas_policy::Violation>),
#[error("policy denied the request: {0}")]
DeniedByPolicy(mas_policy::EvaluationResult),
#[error("unsupported grant type")]
UnsupportedGrantType,
#[error("unauthorized client")]
UnauthorizedClient,
#[error("client {0} is not authorized to use this grant type")]
UnauthorizedClient(Ulid),
#[error("failed to load browser session")]
NoSuchBrowserSession,
#[error("unexpected client {was} (expected {expected})")]
UnexptectedClient { was: Ulid, expected: Ulid },
#[error("failed to load oauth session")]
NoSuchOAuthSession,
#[error("failed to load browser session {0}")]
NoSuchBrowserSession(Ulid),
#[error("failed to load oauth session {0}")]
NoSuchOAuthSession(Ulid),
#[error(
"failed to load the next refresh token ({next:?}) from the previous one ({previous:?})"
@@ -145,14 +159,25 @@ pub(crate) enum RouteError {
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let sentry_event_id = record_error!(
self,
Self::Internal(_)
| Self::ClientCredentialsVerification { .. }
| Self::NoSuchBrowserSession(_)
| Self::NoSuchOAuthSession(_)
| Self::ProvisionDeviceFailed(_)
| Self::NoSuchNextRefreshToken { .. }
| Self::NoSuchNextAccessToken { .. }
| Self::NoAccessTokenOnRefreshToken { .. }
);
TOKEN_REQUEST_COUNTER.add(1, &[KeyValue::new(RESULT, "error")]);
let response = match self {
Self::Internal(_)
| Self::NoSuchBrowserSession
| Self::NoSuchOAuthSession
| Self::ClientCredentialsVerification { .. }
| Self::NoSuchBrowserSession(_)
| Self::NoSuchOAuthSession(_)
| Self::ProvisionDeviceFailed(_)
| Self::NoSuchNextRefreshToken { .. }
| Self::NoSuchNextAccessToken { .. }
@@ -160,10 +185,12 @@ impl IntoResponse for RouteError {
StatusCode::INTERNAL_SERVER_ERROR,
Json(ClientError::from(ClientErrorCode::ServerError)),
),
Self::BadRequest => (
StatusCode::BAD_REQUEST,
Json(ClientError::from(ClientErrorCode::InvalidRequest)),
),
Self::PkceVerification(err) => (
StatusCode::BAD_REQUEST,
Json(
@@ -171,19 +198,25 @@ impl IntoResponse for RouteError {
.with_description(format!("PKCE verification failed: {err}")),
),
),
Self::ClientNotFound | Self::ClientCredentialsVerification(_) => (
Self::ClientNotFound | Self::InvalidClientCredentials { .. } => (
StatusCode::UNAUTHORIZED,
Json(ClientError::from(ClientErrorCode::InvalidClient)),
),
Self::ClientNotAllowed | Self::UnauthorizedClient => (
Self::ClientNotAllowed(_)
| Self::UnauthorizedClient(_)
| Self::UnexptectedClient { .. } => (
StatusCode::UNAUTHORIZED,
Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
),
Self::DeniedByPolicy(violations) => (
Self::DeniedByPolicy(evaluation) => (
StatusCode::FORBIDDEN,
Json(
ClientError::from(ClientErrorCode::InvalidScope).with_description(
violations
evaluation
.violations
.into_iter()
.map(|violation| violation.msg)
.collect::<Vec<_>>()
@@ -191,19 +224,23 @@ impl IntoResponse for RouteError {
),
),
),
Self::DeviceCodeRejected => (
StatusCode::FORBIDDEN,
Json(ClientError::from(ClientErrorCode::AccessDenied)),
),
Self::DeviceCodeExpired => (
StatusCode::FORBIDDEN,
Json(ClientError::from(ClientErrorCode::ExpiredToken)),
),
Self::DeviceCodePending => (
StatusCode::FORBIDDEN,
Json(ClientError::from(ClientErrorCode::AuthorizationPending)),
),
Self::InvalidGrant
Self::InvalidGrant(_)
| Self::DeviceCodeExchanged
| Self::RefreshTokenNotFound
| Self::RefreshTokenInvalid(_)
@@ -213,13 +250,14 @@ impl IntoResponse for RouteError {
StatusCode::BAD_REQUEST,
Json(ClientError::from(ClientErrorCode::InvalidGrant)),
),
Self::UnsupportedGrantType => (
StatusCode::BAD_REQUEST,
Json(ClientError::from(ClientErrorCode::UnsupportedGrantType)),
),
};
(SentryEventID::from(event_id), response).into_response()
(sentry_event_id, response).into_response()
}
}
@@ -231,7 +269,6 @@ impl_from_error_for_route!(super::IdTokenSignatureError);
name = "handlers.oauth2.token.post",
fields(client.id = client_authorization.client_id()),
skip_all,
err,
)]
pub(crate) async fn post(
mut rng: BoxRng,
@@ -258,12 +295,27 @@ pub(crate) async fn post(
let method = client
.token_endpoint_auth_method
.as_ref()
.ok_or(RouteError::ClientNotAllowed)?;
.ok_or(RouteError::ClientNotAllowed(client.id))?;
client_authorization
.credentials
.verify(&http_client, &encrypter, method, &client)
.await?;
.await
.map_err(|err| {
// Classify the error differntly, depending on whether it's an 'internal' error,
// or just because the client presented invalid credentials.
if err.is_internal() {
RouteError::ClientCredentialsVerification {
client_id: client.id,
source: err,
}
} else {
RouteError::InvalidClientCredentials {
client_id: client.id,
source: err,
}
}
})?;
let form = client_authorization.form.ok_or(RouteError::BadRequest)?;
@@ -367,7 +419,7 @@ async fn authorization_code_grant(
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
// Check that the client is allowed to use this grant type
if !client.grant_types.contains(&GrantType::AuthorizationCode) {
return Err(RouteError::UnauthorizedClient);
return Err(RouteError::UnauthorizedClient(client.id));
}
let authz_grant = repo
@@ -381,40 +433,43 @@ async fn authorization_code_grant(
let session_id = match authz_grant.stage {
AuthorizationGrantStage::Cancelled { cancelled_at } => {
debug!(%cancelled_at, "Authorization grant was cancelled");
return Err(RouteError::InvalidGrant);
return Err(RouteError::InvalidGrant(authz_grant.id));
}
AuthorizationGrantStage::Exchanged {
exchanged_at,
fulfilled_at,
session_id,
} => {
debug!(%exchanged_at, %fulfilled_at, "Authorization code was already exchanged");
warn!(%exchanged_at, %fulfilled_at, "Authorization code was already exchanged");
// Ending the session if the token was already exchanged more than 20s ago
if now - exchanged_at > Duration::microseconds(20 * 1000 * 1000) {
debug!("Ending potentially compromised session");
warn!(oauth_session.id = %session_id, "Ending potentially compromised session");
let session = repo
.oauth2_session()
.lookup(session_id)
.await?
.ok_or(RouteError::NoSuchOAuthSession)?;
.ok_or(RouteError::NoSuchOAuthSession(session_id))?;
//if !session.is_finished() {
repo.oauth2_session().finish(clock, session).await?;
repo.save().await?;
//}
}
return Err(RouteError::InvalidGrant);
return Err(RouteError::InvalidGrant(authz_grant.id));
}
AuthorizationGrantStage::Pending => {
debug!("Authorization grant has not been fulfilled yet");
return Err(RouteError::InvalidGrant);
warn!("Authorization grant has not been fulfilled yet");
return Err(RouteError::InvalidGrant(authz_grant.id));
}
AuthorizationGrantStage::Fulfilled {
session_id,
fulfilled_at,
} => {
if now - fulfilled_at > Duration::microseconds(10 * 60 * 1000 * 1000) {
debug!("Code exchange took more than 10 minutes");
return Err(RouteError::InvalidGrant);
warn!("Code exchange took more than 10 minutes");
return Err(RouteError::InvalidGrant(authz_grant.id));
}
session_id
@@ -425,7 +480,7 @@ async fn authorization_code_grant(
.oauth2_session()
.lookup(session_id)
.await?
.ok_or(RouteError::NoSuchOAuthSession)?;
.ok_or(RouteError::NoSuchOAuthSession(session_id))?;
if let Some(user_agent) = user_agent {
session = repo
@@ -435,10 +490,16 @@ async fn authorization_code_grant(
}
// This should never happen, since we looked up in the database using the code
let code = authz_grant.code.as_ref().ok_or(RouteError::InvalidGrant)?;
let code = authz_grant
.code
.as_ref()
.ok_or(RouteError::InvalidGrant(authz_grant.id))?;
if client.id != session.client_id {
return Err(RouteError::UnauthorizedClient);
return Err(RouteError::UnexptectedClient {
was: client.id,
expected: session.client_id,
});
}
match (code.pkce.as_ref(), grant.code_verifier.as_ref()) {
@@ -453,14 +514,14 @@ async fn authorization_code_grant(
let Some(user_session_id) = session.user_session_id else {
tracing::warn!("No user session associated with this OAuth2 session");
return Err(RouteError::InvalidGrant);
return Err(RouteError::InvalidGrant(authz_grant.id));
};
let browser_session = repo
.browser_session()
.lookup(user_session_id)
.await?
.ok_or(RouteError::NoSuchBrowserSession)?;
.ok_or(RouteError::NoSuchBrowserSession(user_session_id))?;
let last_authentication = repo
.browser_session()
@@ -539,7 +600,7 @@ async fn refresh_token_grant(
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
// Check that the client is allowed to use this grant type
if !client.grant_types.contains(&GrantType::RefreshToken) {
return Err(RouteError::UnauthorizedClient);
return Err(RouteError::UnauthorizedClient(client.id));
}
let refresh_token = repo
@@ -552,7 +613,7 @@ async fn refresh_token_grant(
.oauth2_session()
.lookup(refresh_token.session_id)
.await?
.ok_or(RouteError::NoSuchOAuthSession)?;
.ok_or(RouteError::NoSuchOAuthSession(refresh_token.session_id))?;
// Let's for now record the user agent on each refresh, that should be
// responsive enough and not too much of a burden on the database.
@@ -692,7 +753,7 @@ async fn client_credentials_grant(
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
// Check that the client is allowed to use this grant type
if !client.grant_types.contains(&GrantType::ClientCredentials) {
return Err(RouteError::UnauthorizedClient);
return Err(RouteError::UnauthorizedClient(client.id));
}
// Default to an empty scope if none is provided
@@ -715,7 +776,7 @@ async fn client_credentials_grant(
})
.await?;
if !res.valid() {
return Err(RouteError::DeniedByPolicy(res.violations));
return Err(RouteError::DeniedByPolicy(res));
}
// Start the session
@@ -771,7 +832,7 @@ async fn device_code_grant(
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
// Check that the client is allowed to use this grant type
if !client.grant_types.contains(&GrantType::DeviceCode) {
return Err(RouteError::UnauthorizedClient);
return Err(RouteError::UnauthorizedClient(client.id));
}
let grant = repo
@@ -804,14 +865,14 @@ async fn device_code_grant(
}
DeviceCodeGrantState::Fulfilled {
browser_session_id, ..
} => browser_session_id,
} => *browser_session_id,
};
let browser_session = repo
.browser_session()
.lookup(*browser_session_id)
.lookup(browser_session_id)
.await?
.ok_or(RouteError::NoSuchBrowserSession)?;
.ok_or(RouteError::NoSuchBrowserSession(browser_session_id))?;
// Start the session
let mut session = repo

View File

@@ -12,7 +12,7 @@ use axum::{
use hyper::StatusCode;
use mas_axum_utils::{
jwt::JwtResponse,
sentry::SentryEventID,
record_error,
user_authorization::{AuthorizationVerificationError, UserAuthorization},
};
use mas_jose::{
@@ -25,6 +25,7 @@ use mas_storage::{BoxClock, BoxRepository, BoxRng, oauth2::OAuth2ClientRepositor
use serde::Serialize;
use serde_with::skip_serializing_none;
use thiserror::Error;
use ulid::Ulid;
use crate::{BoundActivityTracker, impl_from_error_for_route};
@@ -59,11 +60,11 @@ pub enum RouteError {
#[error("no suitable key found for signing")]
InvalidSigningKey,
#[error("failed to load client")]
NoSuchClient,
#[error("failed to load client {0}")]
NoSuchClient(Ulid),
#[error("failed to load user")]
NoSuchUser,
#[error("failed to load user {0}")]
NoSuchUser(Ulid),
}
impl_from_error_for_route!(mas_storage::RepositoryError);
@@ -72,9 +73,18 @@ impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let sentry_event_id = record_error!(
self,
Self::Internal(_)
| Self::InvalidSigningKey
| Self::NoSuchClient(_)
| Self::NoSuchUser(_)
);
let response = match self {
Self::Internal(_) | Self::InvalidSigningKey | Self::NoSuchClient | Self::NoSuchUser => {
Self::Internal(_)
| Self::InvalidSigningKey
| Self::NoSuchClient(_)
| Self::NoSuchUser(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
}
Self::AuthorizationVerificationError(_) | Self::Unauthorized => {
@@ -82,11 +92,11 @@ impl IntoResponse for RouteError {
}
};
(SentryEventID::from(event_id), response).into_response()
(sentry_event_id, response).into_response()
}
}
#[tracing::instrument(name = "handlers.oauth2.userinfo.get", skip_all, err)]
#[tracing::instrument(name = "handlers.oauth2.userinfo.get", skip_all)]
pub async fn get(
mut rng: BoxRng,
clock: BoxClock,
@@ -116,7 +126,7 @@ pub async fn get(
.user()
.lookup(user_id)
.await?
.ok_or(RouteError::NoSuchUser)?;
.ok_or(RouteError::NoSuchUser(user_id))?;
let user_info = UserInfo {
sub: user.sub.clone(),
@@ -127,7 +137,7 @@ pub async fn get(
.oauth2_client()
.lookup(session.client_id)
.await?
.ok_or(RouteError::NoSuchClient)?;
.ok_or(RouteError::NoSuchClient(session.client_id))?;
repo.save().await?;

View File

@@ -9,7 +9,7 @@ use axum::{
response::{IntoResponse, Redirect},
};
use hyper::StatusCode;
use mas_axum_utils::{cookies::CookieJar, sentry::SentryEventID};
use mas_axum_utils::{cookies::CookieJar, record_error};
use mas_data_model::UpstreamOAuthProvider;
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
use mas_router::UrlBuilder;
@@ -41,13 +41,13 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let response = match self {
Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
};
(SentryEventID::from(event_id), response).into_response()
(sentry_event_id, response).into_response()
}
}
@@ -55,7 +55,6 @@ impl IntoResponse for RouteError {
name = "handlers.upstream_oauth2.authorize.get",
fields(upstream_oauth_provider.id = %provider_id),
skip_all,
err,
)]
pub(crate) async fn get(
mut rng: BoxRng,

View File

@@ -6,6 +6,7 @@
use std::{collections::HashMap, sync::Arc};
use mas_context::LogContext;
use mas_data_model::{
UpstreamOAuthProvider, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderPkceMode,
};
@@ -164,7 +165,7 @@ impl MetadataCache {
///
/// This spawns a background task that will refresh the cache at the given
/// interval.
#[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all, err)]
#[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all)]
pub async fn warm_up_and_run<R: RepositoryAccess>(
&self,
client: &reqwest::Client,
@@ -197,12 +198,14 @@ impl MetadataCache {
loop {
// Re-fetch the known metadata at the given interval
tokio::time::sleep(interval).await;
cache.refresh_all(&client).await;
LogContext::new("metadata-cache-refresh")
.run(|| cache.refresh_all(&client))
.await;
}
}))
}
#[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all, err)]
#[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all)]
async fn fetch(
&self,
client: &reqwest::Client,
@@ -234,7 +237,7 @@ impl MetadataCache {
}
/// Get the metadata for the given issuer.
#[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all, err)]
#[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all)]
pub async fn get(
&self,
client: &reqwest::Client,

View File

@@ -13,7 +13,7 @@ use axum::{
response::{Html, IntoResponse, Response},
};
use hyper::StatusCode;
use mas_axum_utils::{cookies::CookieJar, sentry::SentryEventID};
use mas_axum_utils::{cookies::CookieJar, record_error};
use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderResponseMode};
use mas_jose::claims::TokenHash;
use mas_keystore::{Encrypter, Keystore};
@@ -153,7 +153,7 @@ impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let response = match self {
Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
@@ -161,7 +161,7 @@ impl IntoResponse for RouteError {
e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
};
(SentryEventID::from(event_id), response).into_response()
(sentry_event_id, response).into_response()
}
}
@@ -169,7 +169,6 @@ impl IntoResponse for RouteError {
name = "handlers.upstream_oauth2.callback.handler",
fields(upstream_oauth_provider.id = %provider_id),
skip_all,
err,
)]
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
pub(crate) async fn handler(

View File

@@ -17,7 +17,7 @@ use mas_axum_utils::{
FancyError, SessionInfoExt,
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
sentry::SentryEventID,
record_error,
};
use mas_data_model::UserAgent;
use mas_jose::jwt::Jwt;
@@ -77,16 +77,16 @@ pub(crate) enum RouteError {
LinkNotFound,
/// Couldn't find the session on the link
#[error("Session not found")]
SessionNotFound,
#[error("Session {0} not found")]
SessionNotFound(Ulid),
/// Couldn't find the user
#[error("User not found")]
UserNotFound,
#[error("User {0} not found")]
UserNotFound(Ulid),
/// Couldn't find upstream provider
#[error("Upstream provider not found")]
ProviderNotFound,
#[error("Upstream provider {0} not found")]
ProviderNotFound(Ulid),
/// Required attribute rendered to an empty string
#[error("Template {template:?} rendered to an empty string")]
@@ -104,8 +104,8 @@ pub(crate) enum RouteError {
},
/// Session was already consumed
#[error("Session already consumed")]
SessionConsumed,
#[error("Session {0} already consumed")]
SessionConsumed(Ulid),
#[error("Missing session cookie")]
MissingCookie,
@@ -129,14 +129,23 @@ impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let sentry_event_id = record_error!(
self,
Self::Internal(_)
| Self::RequiredAttributeEmpty { .. }
| Self::RequiredAttributeRender { .. }
| Self::SessionNotFound(_)
| Self::ProviderNotFound(_)
| Self::UserNotFound(_)
| Self::HomeserverConnection(_)
);
let response = match self {
Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
Self::Internal(e) => FancyError::from(e).into_response(),
e => FancyError::from(e).into_response(),
};
(SentryEventID::from(event_id), response).into_response()
(sentry_event_id, response).into_response()
}
}
@@ -209,7 +218,6 @@ impl ToFormState for FormData {
name = "handlers.upstream_oauth2.link.get",
fields(upstream_oauth_link.id = %link_id),
skip_all,
err,
)]
pub(crate) async fn get(
mut rng: BoxRng,
@@ -245,16 +253,16 @@ pub(crate) async fn get(
.upstream_oauth_session()
.lookup(session_id)
.await?
.ok_or(RouteError::SessionNotFound)?;
.ok_or(RouteError::SessionNotFound(session_id))?;
// This checks that we're in a browser session which is allowed to consume this
// link: the upstream auth session should have been started in this browser.
if upstream_session.link_id() != Some(link.id) {
return Err(RouteError::SessionNotFound);
return Err(RouteError::SessionNotFound(session_id));
}
if upstream_session.is_consumed() {
return Err(RouteError::SessionConsumed);
return Err(RouteError::SessionConsumed(session_id));
}
let (user_session_info, cookie_jar) = cookie_jar.session_info();
@@ -289,7 +297,7 @@ pub(crate) async fn get(
.user()
.lookup(user_id)
.await?
.ok_or(RouteError::UserNotFound)?;
.ok_or(RouteError::UserNotFound(user_id))?;
let ctx = UpstreamExistingLinkContext::new(user)
.with_session(user_session)
@@ -315,7 +323,7 @@ pub(crate) async fn get(
.user()
.lookup(user_id)
.await?
.ok_or(RouteError::UserNotFound)?;
.ok_or(RouteError::UserNotFound(user_id))?;
// Check that the user is not locked or deactivated
if user.deactivated_at.is_some() {
@@ -377,7 +385,7 @@ pub(crate) async fn get(
.upstream_oauth_provider()
.lookup(link.provider_id)
.await?
.ok_or(RouteError::ProviderNotFound)?;
.ok_or(RouteError::ProviderNotFound(link.provider_id))?;
let ctx = UpstreamRegister::new(link.clone(), provider.clone());
@@ -543,7 +551,6 @@ pub(crate) async fn get(
name = "handlers.upstream_oauth2.link.post",
fields(upstream_oauth_link.id = %link_id),
skip_all,
err,
)]
pub(crate) async fn post(
mut rng: BoxRng,
@@ -583,16 +590,16 @@ pub(crate) async fn post(
.upstream_oauth_session()
.lookup(session_id)
.await?
.ok_or(RouteError::SessionNotFound)?;
.ok_or(RouteError::SessionNotFound(session_id))?;
// This checks that we're in a browser session which is allowed to consume this
// link: the upstream auth session should have been started in this browser.
if upstream_session.link_id() != Some(link.id) {
return Err(RouteError::SessionNotFound);
return Err(RouteError::SessionNotFound(session_id));
}
if upstream_session.is_consumed() {
return Err(RouteError::SessionConsumed);
return Err(RouteError::SessionConsumed(session_id));
}
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
@@ -637,7 +644,7 @@ pub(crate) async fn post(
.upstream_oauth_provider()
.lookup(link.provider_id)
.await?
.ok_or(RouteError::ProviderNotFound)?;
.ok_or(RouteError::ProviderNotFound(link.provider_id))?;
// Let's try to import the claims from the ID token
let env = environment();

View File

@@ -25,7 +25,7 @@ pub struct Params {
action: Option<mas_router::AccountAction>,
}
#[tracing::instrument(name = "handlers.views.app.get", skip_all, err)]
#[tracing::instrument(name = "handlers.views.app.get", skip_all)]
pub async fn get(
PreferredLanguage(locale): PreferredLanguage,
State(templates): State<Templates>,
@@ -74,7 +74,7 @@ pub async fn get(
/// Like `get`, but allow anonymous access.
/// Used for a subset of the account management paths.
/// Needed for e.g. account recovery.
#[tracing::instrument(name = "handlers.views.app.get_anonymous", skip_all, err)]
#[tracing::instrument(name = "handlers.views.app.get_anonymous", skip_all)]
pub async fn get_anonymous(
PreferredLanguage(locale): PreferredLanguage,
State(templates): State<Templates>,

View File

@@ -19,7 +19,7 @@ use crate::{
session::{SessionOrFallback, load_session_or_fallback},
};
#[tracing::instrument(name = "handlers.views.index.get", skip_all, err)]
#[tracing::instrument(name = "handlers.views.index.get", skip_all)]
pub async fn get(
mut rng: BoxRng,
clock: BoxClock,

View File

@@ -61,7 +61,7 @@ impl ToFormState for LoginForm {
type Field = LoginFormField;
}
#[tracing::instrument(name = "handlers.views.login.get", skip_all, err)]
#[tracing::instrument(name = "handlers.views.login.get", skip_all)]
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
@@ -127,7 +127,7 @@ pub(crate) async fn get(
.await
}
#[tracing::instrument(name = "handlers.views.login.post", skip_all, err)]
#[tracing::instrument(name = "handlers.views.login.post", skip_all)]
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,

View File

@@ -18,7 +18,7 @@ use mas_storage::{BoxClock, BoxRepository, user::BrowserSessionRepository};
use crate::BoundActivityTracker;
#[tracing::instrument(name = "handlers.views.logout.post", skip_all, err)]
#[tracing::instrument(name = "handlers.views.logout.post", skip_all)]
pub(crate) async fn post(
clock: BoxClock,
mut repo: BoxRepository,

View File

@@ -20,7 +20,7 @@ mod cookie;
pub(crate) mod password;
pub(crate) mod steps;
#[tracing::instrument(name = "handlers.views.register.get", skip_all, err)]
#[tracing::instrument(name = "handlers.views.register.get", skip_all)]
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,

View File

@@ -66,7 +66,7 @@ pub struct QueryParams {
action: OptionalPostAuthAction,
}
#[tracing::instrument(name = "handlers.views.password_register.get", skip_all, err)]
#[tracing::instrument(name = "handlers.views.password_register.get", skip_all)]
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
@@ -118,7 +118,7 @@ pub(crate) async fn get(
Ok((cookie_jar, Html(content)).into_response())
}
#[tracing::instrument(name = "handlers.views.password_register.post", skip_all, err)]
#[tracing::instrument(name = "handlers.views.password_register.post", skip_all)]
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
pub(crate) async fn post(
mut rng: BoxRng,

View File

@@ -49,7 +49,6 @@ impl ToFormState for DisplayNameForm {
name = "handlers.views.register.steps.display_name.get",
fields(user_registration.id = %id),
skip_all,
err,
)]
pub(crate) async fn get(
mut rng: BoxRng,
@@ -100,7 +99,6 @@ pub(crate) async fn get(
name = "handlers.views.register.steps.display_name.post",
fields(user_registration.id = %id),
skip_all,
err,
)]
pub(crate) async fn post(
mut rng: BoxRng,

View File

@@ -42,7 +42,6 @@ static PASSWORD_REGISTER_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
name = "handlers.views.register.steps.finish.get",
fields(user_registration.id = %id),
skip_all,
err,
)]
pub(crate) async fn get(
mut rng: BoxRng,

View File

@@ -37,7 +37,6 @@ impl ToFormState for CodeForm {
name = "handlers.views.register.steps.verify_email.get",
fields(user_registration.id = %id),
skip_all,
err,
)]
pub(crate) async fn get(
mut rng: BoxRng,
@@ -104,7 +103,6 @@ pub(crate) async fn get(
name = "handlers.views.account_email_verify.post",
fields(user_email.id = %id),
skip_all,
err,
)]
pub(crate) async fn post(
clock: BoxClock,

View File

@@ -17,7 +17,7 @@ futures-util.workspace = true
http-body.workspace = true
hyper = { workspace = true, features = ["server"] }
hyper-util.workspace = true
pin-project-lite = "0.2.16"
pin-project-lite.workspace = true
socket2 = "0.5.9"
thiserror.workspace = true
tokio.workspace = true
@@ -27,6 +27,8 @@ tower.workspace = true
tower-http.workspace = true
tracing.workspace = true
mas-context.workspace = true
[dev-dependencies]
anyhow.workspace = true
rustls-pemfile = "2.2.0"

View File

@@ -18,6 +18,7 @@ use hyper_util::{
server::conn::auto::Connection,
service::TowerToHyperService,
};
use mas_context::LogContext;
use pin_project_lite::pin_project;
use thiserror::Error;
use tokio_rustls::rustls::ServerConfig;
@@ -107,12 +108,6 @@ impl<S> Server<S> {
#[derive(Debug, Error)]
#[non_exhaustive]
enum AcceptError {
#[error("failed to accept connection from the underlying socket")]
Socket {
#[source]
source: std::io::Error,
},
#[error("failed to complete the TLS handshake")]
TlsHandshake {
#[source]
@@ -133,10 +128,6 @@ enum AcceptError {
}
impl AcceptError {
fn socket(source: std::io::Error) -> Self {
Self::Socket { source }
}
fn tls_handshake(source: std::io::Error) -> Self {
Self::TlsHandshake { source }
}
@@ -164,7 +155,6 @@ impl AcceptError {
network.peer.address,
network.peer.port,
),
err,
)]
async fn accept<S, B>(
maybe_proxy_acceptor: &MaybeProxyAcceptor,
@@ -357,12 +347,16 @@ pub async fn run_servers<S, B>(
// Poll on the JoinSet to collect connections to serve
res = accept_tasks.join_next(), if !accept_tasks.is_empty() => {
match res {
Some(Ok(Ok(connection))) => {
tracing::trace!("Accepted connection");
let conn = AbortableConnection::new(connection, soft_shutdown_token.child_token());
connection_tasks.spawn(conn);
Some(Ok(Some(connection))) => {
let token = soft_shutdown_token.child_token();
connection_tasks.spawn(LogContext::new("http-serve").run(async move || {
tracing::debug!("Accepted connection");
if let Err(e) = AbortableConnection::new(connection, token).await {
tracing::warn!(error = &*e as &dyn std::error::Error, "Failed to serve connection");
}
}));
},
Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
Some(Ok(None)) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
None => tracing::error!("Join set was polled even though it was empty"),
}
@@ -371,8 +365,7 @@ pub async fn run_servers<S, B>(
// Poll on the JoinSet to collect finished connections
res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
match res {
Some(Ok(Ok(()))) => tracing::trace!("Connection finished"),
Some(Ok(Err(e))) => tracing::error!(error = &*e as &dyn std::error::Error, "Error while serving connection"),
Some(Ok(())) => { /* Connection finished, any errors should be logged in in the spawned task */ },
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
None => tracing::error!("Join set was polled even though it was empty"),
}
@@ -385,11 +378,23 @@ pub async fn run_servers<S, B>(
// Spawn the connection in the set, so we don't have to wait for the handshake to
// accept the next connection. This allows us to keep track of active connections
// and waiting on them for a graceful shutdown
accept_tasks.spawn(async move {
let (maybe_proxy_acceptor, maybe_tls_acceptor, service, peer_addr, stream) = res
.map_err(AcceptError::socket)?;
accept(&maybe_proxy_acceptor, &maybe_tls_acceptor, peer_addr, stream, service).await
});
accept_tasks.spawn(LogContext::new("http-accept").run(async move || {
let (maybe_proxy_acceptor, maybe_tls_acceptor, service, peer_addr, stream) = match res {
Ok(res) => res,
Err(e) => {
tracing::warn!(error = &e as &dyn std::error::Error, "Failed to accept connection from the underlying socket");
return None;
}
};
match accept(&maybe_proxy_acceptor, &maybe_tls_acceptor, peer_addr, stream, service).await {
Ok(connection) => Some(connection),
Err(e) => {
tracing::warn!(error = &e as &dyn std::error::Error, "Failed to accept connection");
None
}
}
}));
},
};
}
@@ -409,12 +414,16 @@ pub async fn run_servers<S, B>(
// Poll on the JoinSet to collect connections to serve
res = accept_tasks.join_next(), if !accept_tasks.is_empty() => {
match res {
Some(Ok(Ok(connection))) => {
tracing::trace!("Accepted connection");
let conn = AbortableConnection::new(connection, soft_shutdown_token.child_token());
connection_tasks.spawn(conn);
Some(Ok(Some(connection))) => {
let token = soft_shutdown_token.child_token();
connection_tasks.spawn(LogContext::new("http-serve").run(async || {
tracing::debug!("Accepted connection");
if let Err(e) = AbortableConnection::new(connection, token).await {
tracing::warn!(error = &*e as &dyn std::error::Error, "Failed to serve connection");
}
}));
}
Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
Some(Ok(None)) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
None => tracing::error!("Join set was polled even though it was empty"),
}
@@ -423,8 +432,7 @@ pub async fn run_servers<S, B>(
// Poll on the JoinSet to collect finished connections
res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
match res {
Some(Ok(Ok(()))) => tracing::trace!("Connection finished"),
Some(Ok(Err(e))) => tracing::error!(error = &*e as &dyn std::error::Error, "Error while serving connection"),
Some(Ok(())) => { /* Connection finished, any errors should be logged in in the spawned task */ },
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
None => tracing::error!("Join set was polled even though it was empty"),
}

View File

@@ -197,7 +197,7 @@ pub struct PolicyFactory {
}
impl PolicyFactory {
#[tracing::instrument(name = "policy.load", skip(source), err)]
#[tracing::instrument(name = "policy.load", skip(source))]
pub async fn load(
mut source: impl AsyncRead + std::marker::Unpin,
data: Data,
@@ -283,7 +283,7 @@ impl PolicyFactory {
Ok(true)
}
#[tracing::instrument(name = "policy.instantiate", skip_all, err)]
#[tracing::instrument(name = "policy.instantiate", skip_all)]
pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
let data = self.dynamic_data.load();
self.instantiate_with_data(&data.merged).await
@@ -342,7 +342,6 @@ impl Policy {
fields(
%input.email,
),
err,
)]
pub async fn evaluate_email(
&mut self,
@@ -364,7 +363,6 @@ impl Policy {
input.username = input.username,
input.email = input.email,
),
err,
)]
pub async fn evaluate_register(
&mut self,
@@ -402,7 +400,6 @@ impl Policy {
%input.scope,
%input.client.id,
),
err,
)]
pub async fn evaluate_authorization_grant(
&mut self,

View File

@@ -30,6 +30,7 @@ ulid.workspace = true
serde.workspace = true
serde_json.workspace = true
mas-context.workspace = true
mas-data-model.workspace = true
mas-email.workspace = true
mas-i18n.workspace = true

View File

@@ -17,7 +17,7 @@ use crate::{
#[async_trait]
impl RunnableJob for CleanupExpiredTokensJob {
#[tracing::instrument(name = "job.cleanup_expired_tokens", skip_all, err)]
#[tracing::instrument(name = "job.cleanup_expired_tokens", skip_all)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
let clock = state.clock();
let mut repo = state.repository().await.map_err(JobError::retry)?;
@@ -41,7 +41,7 @@ impl RunnableJob for CleanupExpiredTokensJob {
#[async_trait]
impl RunnableJob for PruneStalePolicyDataJob {
#[tracing::instrument(name = "job.prune_stale_policy_data", skip_all, err)]
#[tracing::instrument(name = "job.prune_stale_policy_data", skip_all)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
let mut repo = state.repository().await.map_err(JobError::retry)?;

View File

@@ -23,7 +23,6 @@ impl RunnableJob for VerifyEmailJob {
name = "job.verify_email",
fields(user_email.id = %self.user_email_id()),
skip_all,
err,
)]
async fn run(&self, _state: &State, _context: JobContext) -> Result<(), JobError> {
// This job was for the old email verification flow, which has been replaced.
@@ -39,7 +38,6 @@ impl RunnableJob for SendEmailAuthenticationCodeJob {
name = "job.send_email_authentication_code",
fields(user_email_authentication.id = %self.user_email_authentication_id()),
skip_all,
err,
)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
let clock = state.clock();

View File

@@ -36,7 +36,6 @@ impl RunnableJob for ProvisionUserJob {
name = "job.provision_user"
fields(user.id = %self.user_id()),
skip_all,
err,
)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
let matrix = state.matrix_connection();
@@ -103,7 +102,6 @@ impl RunnableJob for ProvisionDeviceJob {
device.id = %self.device_id(),
),
skip_all,
err,
)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
let mut repo = state.repository().await.map_err(JobError::retry)?;
@@ -140,7 +138,6 @@ impl RunnableJob for DeleteDeviceJob {
device.id = %self.device_id(),
),
skip_all,
err,
)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
let mut rng = state.rng();
@@ -172,7 +169,6 @@ impl RunnableJob for SyncDevicesJob {
name = "job.sync_devices",
fields(user.id = %self.user_id()),
skip_all,
err,
)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
let matrix = state.matrix_connection();

View File

@@ -8,6 +8,7 @@ use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use cron::Schedule;
use mas_context::LogContext;
use mas_storage::{
Clock, RepositoryAccess, RepositoryError,
queue::{InsertableJob, Job, JobMetadata, Worker},
@@ -183,7 +184,7 @@ fn retry_delay(attempt: usize) -> Duration {
Duration::milliseconds(2_i64.saturating_pow(attempt) * 5_000)
}
type JobResult = Result<(), JobError>;
type JobResult = (std::time::Duration, Result<(), JobError>);
type JobFactory = Arc<dyn Fn(JobPayload) -> Box<dyn RunnableJob> + Send + Sync>;
struct ScheduleDefinition {
@@ -252,7 +253,7 @@ impl QueueWorker {
.await
.map_err(QueueRunnerError::CommitTransaction)?;
tracing::info!("Registered worker");
tracing::info!(worker.id = %registration.id, "Registered worker");
let now = clock.now();
let wakeup_reason = METER
@@ -337,7 +338,9 @@ impl QueueWorker {
self.setup_schedules().await?;
while !self.cancellation_token.is_cancelled() {
self.run_loop().await?;
LogContext::new("worker-run-loop")
.run(|| self.run_loop())
.await?;
}
self.shutdown().await?;
@@ -345,7 +348,7 @@ impl QueueWorker {
Ok(())
}
#[tracing::instrument(name = "worker.setup_schedules", skip_all, err)]
#[tracing::instrument(name = "worker.setup_schedules", skip_all)]
pub async fn setup_schedules(&mut self) -> Result<(), QueueRunnerError> {
let schedules: Vec<_> = self.schedules.iter().map(|s| s.schedule_name).collect();
@@ -369,7 +372,7 @@ impl QueueWorker {
Ok(())
}
#[tracing::instrument(name = "worker.run_loop", skip_all, err)]
#[tracing::instrument(name = "worker.run_loop", skip_all)]
async fn run_loop(&mut self) -> Result<(), QueueRunnerError> {
self.wait_until_wakeup().await?;
@@ -390,7 +393,7 @@ impl QueueWorker {
Ok(())
}
#[tracing::instrument(name = "worker.shutdown", skip_all, err)]
#[tracing::instrument(name = "worker.shutdown", skip_all)]
async fn shutdown(&mut self) -> Result<(), QueueRunnerError> {
tracing::info!("Shutting down worker");
@@ -435,7 +438,7 @@ impl QueueWorker {
Ok(())
}
#[tracing::instrument(name = "worker.wait_until_wakeup", skip_all, err)]
#[tracing::instrument(name = "worker.wait_until_wakeup", skip_all)]
async fn wait_until_wakeup(&mut self) -> Result<(), QueueRunnerError> {
// This is to make sure we wake up every second to do the maintenance tasks
// We add a little bit of random jitter to the duration, so that we don't get
@@ -484,7 +487,6 @@ impl QueueWorker {
name = "worker.tick",
skip_all,
fields(worker.id = %self.registration.id),
err,
)]
async fn tick(&mut self) -> Result<(), QueueRunnerError> {
tracing::debug!("Tick");
@@ -583,7 +585,7 @@ impl QueueWorker {
Ok(())
}
#[tracing::instrument(name = "worker.perform_leader_duties", skip_all, err)]
#[tracing::instrument(name = "worker.perform_leader_duties", skip_all)]
async fn perform_leader_duties(&mut self) -> Result<(), QueueRunnerError> {
// This should have been checked by the caller, but better safe than sorry
if !self.am_i_leader {
@@ -771,16 +773,86 @@ impl JobTracker {
fn spawn_job(&mut self, state: State, context: JobContext, payload: JobPayload) {
let factory = self.factories.get(context.queue_name.as_str()).cloned();
let task = {
let log_context = LogContext::new(format!("job-{}", context.queue_name));
let context = context.clone();
let span = context.span();
async move {
// We should never crash, but in case we do, we do that in the task and
// don't crash the worker
let job = factory.expect("unknown job factory")(payload);
tracing::info!("Running job");
job.run(&state, context).await
}
.instrument(span)
log_context
.run(async move || {
// We should never crash, but in case we do, we do that in the task and
// don't crash the worker
let job = factory.expect("unknown job factory")(payload);
tracing::info!(
job.id = %context.id,
job.queue.name = %context.queue_name,
job.attempt = %context.attempt,
"Running job"
);
let result = job.run(&state, context.clone()).await;
let Some(log_context) = LogContext::current() else {
// This should never happen, but if it does it's fine: we're recovering fine
// from panics in those tasks
panic!("Missing log context, this should never happen");
};
let context_stats = log_context.stats();
// We log the result here so that it's attached to the right span & log context
match &result {
Ok(()) => {
tracing::info!(
job.id = %context.id,
job.queue.name = %context.queue_name,
job.attempt = %context.attempt,
"Job completed [{context_stats}]"
);
}
Err(JobError {
decision: JobErrorDecision::Fail,
error,
}) => {
tracing::error!(
error = &**error as &dyn std::error::Error,
job.id = %context.id,
job.queue.name = %context.queue_name,
job.attempt = %context.attempt,
"Job failed, not retrying [{context_stats}]"
);
}
Err(JobError {
decision: JobErrorDecision::Retry,
error,
}) if context.attempt < MAX_ATTEMPTS => {
let delay = retry_delay(context.attempt);
tracing::warn!(
error = &**error as &dyn std::error::Error,
job.id = %context.id,
job.queue.name = %context.queue_name,
job.attempt = %context.attempt,
"Job failed, will retry in {}s [{context_stats}]",
delay.num_seconds()
);
}
Err(JobError {
decision: JobErrorDecision::Retry,
error,
}) => {
tracing::error!(
error = &**error as &dyn std::error::Error,
job.id = %context.id,
job.queue.name = %context.queue_name,
job.attempt = %context.attempt,
"Job failed too many times, abandonning [{context_stats}]"
);
}
}
(context_stats.elapsed, result)
})
.instrument(span)
};
self.in_flight_jobs.add(
@@ -837,15 +909,10 @@ impl JobTracker {
}
}
// XXX: the time measurement isn't accurate, as it would include the
// time spent between the task finishing, and us processing the result.
// It's fine for now, as it at least gives us an idea of how many tasks
// we run, and what their status is
while let Some(result) = self.last_join_result.take() {
match result {
// The job succeeded
Ok((id, Ok(()))) => {
// The job succeeded. The logging and time measurement is already done in the task
Ok((id, (elapsed, Ok(())))) => {
let context = self
.job_contexts
.remove(&id)
@@ -856,22 +923,9 @@ impl JobTracker {
&[KeyValue::new("job.queue.name", context.queue_name.clone())],
);
let elapsed = context
.start
.elapsed()
.as_millis()
.try_into()
.unwrap_or(u64::MAX);
tracing::info!(
job.id = %context.id,
job.queue.name = %context.queue_name,
job.attempt = %context.attempt,
job.elapsed = format!("{elapsed}ms"),
"Job completed"
);
let elapsed_ms = elapsed.as_millis().try_into().unwrap_or(u64::MAX);
self.job_processing_time.record(
elapsed,
elapsed_ms,
&[
KeyValue::new("job.queue.name", context.queue_name),
KeyValue::new("job.result", "success"),
@@ -883,8 +937,8 @@ impl JobTracker {
.await?;
}
// The job failed
Ok((id, Err(e))) => {
// The job failed. The logging and time measurement is already done in the task
Ok((id, (elapsed, Err(e)))) => {
let context = self
.job_contexts
.remove(&id)
@@ -900,26 +954,11 @@ impl JobTracker {
.mark_as_failed(clock, context.id, &reason)
.await?;
let elapsed = context
.start
.elapsed()
.as_millis()
.try_into()
.unwrap_or(u64::MAX);
let elapsed_ms = elapsed.as_millis().try_into().unwrap_or(u64::MAX);
match e.decision {
JobErrorDecision::Fail => {
tracing::error!(
error = &e as &dyn std::error::Error,
job.id = %context.id,
job.queue.name = %context.queue_name,
job.attempt = %context.attempt,
job.elapsed = format!("{elapsed}ms"),
"Job failed, not retrying"
);
self.job_processing_time.record(
elapsed,
elapsed_ms,
&[
KeyValue::new("job.queue.name", context.queue_name),
KeyValue::new("job.result", "failed"),
@@ -928,50 +967,31 @@ impl JobTracker {
);
}
JobErrorDecision::Retry if context.attempt < MAX_ATTEMPTS => {
self.job_processing_time.record(
elapsed_ms,
&[
KeyValue::new("job.queue.name", context.queue_name),
KeyValue::new("job.result", "failed"),
KeyValue::new("job.decision", "retry"),
],
);
let delay = retry_delay(context.attempt);
repo.queue_job()
.retry(&mut *rng, clock, context.id, delay)
.await?;
}
JobErrorDecision::Retry => {
if context.attempt < MAX_ATTEMPTS {
let delay = retry_delay(context.attempt);
tracing::warn!(
error = &e as &dyn std::error::Error,
job.id = %context.id,
job.queue.name = %context.queue_name,
job.attempt = %context.attempt,
job.elapsed = format!("{elapsed}ms"),
"Job failed, will retry in {}s",
delay.num_seconds()
);
self.job_processing_time.record(
elapsed,
&[
KeyValue::new("job.queue.name", context.queue_name),
KeyValue::new("job.result", "failed"),
KeyValue::new("job.decision", "retry"),
],
);
repo.queue_job()
.retry(&mut *rng, clock, context.id, delay)
.await?;
} else {
tracing::error!(
error = &e as &dyn std::error::Error,
job.id = %context.id,
job.queue.name = %context.queue_name,
job.attempt = %context.attempt,
job.elapsed = format!("{elapsed}ms"),
"Job failed too many times, abandonning"
);
self.job_processing_time.record(
elapsed,
&[
KeyValue::new("job.queue.name", context.queue_name),
KeyValue::new("job.result", "failed"),
KeyValue::new("job.decision", "abandon"),
],
);
}
self.job_processing_time.record(
elapsed_ms,
&[
KeyValue::new("job.queue.name", context.queue_name),
KeyValue::new("job.result", "failed"),
KeyValue::new("job.decision", "abandon"),
],
);
}
}
}
@@ -989,6 +1009,8 @@ impl JobTracker {
&[KeyValue::new("job.queue.name", context.queue_name.clone())],
);
// This measurement is not accurate as it includes the time processing the jobs,
// but it's fine, it's only for panicked tasks
let elapsed = context
.start
.elapsed()
@@ -1003,7 +1025,7 @@ impl JobTracker {
if context.attempt < MAX_ATTEMPTS {
let delay = retry_delay(context.attempt);
tracing::warn!(
tracing::error!(
error = &e as &dyn std::error::Error,
job.id = %context.id,
job.queue.name = %context.queue_name,

View File

@@ -32,7 +32,6 @@ impl RunnableJob for SendAccountRecoveryEmailsJob {
user_recovery_session.email,
),
skip_all,
err,
)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
let clock = state.clock();

View File

@@ -27,7 +27,6 @@ impl RunnableJob for DeactivateUserJob {
name = "job.deactivate_user"
fields(user.id = %self.user_id(), erase = %self.hs_erase()),
skip_all,
err,
)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
let clock = state.clock();
@@ -118,7 +117,6 @@ impl RunnableJob for ReactivateUserJob {
name = "job.reactivate_user",
fields(user.id = %self.user_id()),
skip_all,
err,
)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
let matrix = state.matrix_connection();

View File

@@ -138,7 +138,6 @@ impl Templates {
name = "templates.load",
skip_all,
fields(%path),
err,
)]
pub async fn load(
path: Utf8PathBuf,
@@ -258,7 +257,6 @@ impl Templates {
name = "templates.reload",
skip_all,
fields(path = %self.path),
err,
)]
pub async fn reload(&self) -> Result<(), TemplateLoadingError> {
let (translator, environment) = Self::load_(

View File

@@ -19,4 +19,4 @@ tower.workspace = true
opentelemetry.workspace = true
opentelemetry-http.workspace = true
opentelemetry-semantic-conventions.workspace = true
pin-project-lite = "0.2.16"
pin-project-lite.workspace = true