diff --git a/Cargo.lock b/Cargo.lock index e4a179e53..86a87f359 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3246,6 +3246,17 @@ dependencies = [ "url", ] +[[package]] +name = "mas-context" +version = "0.15.0-rc.0" +dependencies = [ + "pin-project-lite", + "quanta", + "tokio", + "tower-layer", + "tower-service", +] + [[package]] name = "mas-data-model" version = "0.15.0-rc.0" diff --git a/Cargo.toml b/Cargo.toml index f2bd90196..b6fe86947 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ broken_intra_doc_links = "deny" mas-axum-utils = { path = "./crates/axum-utils/", version = "=0.15.0-rc.0" } mas-cli = { path = "./crates/cli/", version = "=0.15.0-rc.0" } mas-config = { path = "./crates/config/", version = "=0.15.0-rc.0" } +mas-context = { path = "./crates/context/", version = "=0.15.0-rc.0" } mas-data-model = { path = "./crates/data-model/", version = "=0.15.0-rc.0" } mas-email = { path = "./crates/email/", version = "=0.15.0-rc.0" } mas-graphql = { path = "./crates/graphql/", version = "=0.15.0-rc.0" } @@ -248,6 +249,10 @@ features = ["std"] version = "0.7.0" features = ["std"] +# Pin projection +[workspace.dependencies.pin-project-lite] +version = "0.2.16" + # PKCS#1 encoding [workspace.dependencies.pkcs1] version = "0.7.5" @@ -258,6 +263,10 @@ features = ["std"] version = "0.10.2" features = ["std", "pkcs5", "encryption"] +# High-precision clock +[workspace.dependencies.quanta] +version = "0.12.5" + # Random values [workspace.dependencies.rand] version = "0.8.5" @@ -374,6 +383,14 @@ features = ["rt"] version = "0.5.2" features = ["util"] +# Tower service trait +[workspace.dependencies.tower-service] +version = "0.3.3" + +# Tower layer trait +[workspace.dependencies.tower-layer] +version = "0.3.3" + # Tower HTTP layers [workspace.dependencies.tower-http] version = "0.6.2" diff --git a/crates/context/Cargo.toml b/crates/context/Cargo.toml new file mode 100644 index 000000000..b0f422b51 --- /dev/null +++ b/crates/context/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "mas-context" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true +publish = false + +[lints] +workspace = true + +[dependencies] +pin-project-lite.workspace = true +quanta.workspace = true +tokio.workspace = true +tower-service.workspace = true +tower-layer.workspace = true diff --git a/crates/context/src/future.rs b/crates/context/src/future.rs new file mode 100644 index 000000000..9e93af4fa --- /dev/null +++ b/crates/context/src/future.rs @@ -0,0 +1,59 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +use std::{ + pin::Pin, + sync::atomic::Ordering, + task::{Context, Poll}, +}; + +use quanta::Instant; +use tokio::task::futures::TaskLocalFuture; + +use crate::LogContext; + +pub type LogContextFuture = TaskLocalFuture>; + +impl LogContext { + /// Wrap a future with the given log context + pub(crate) fn wrap_future(&self, future: F) -> LogContextFuture { + let future = PollRecordingFuture::new(future); + crate::CURRENT_LOG_CONTEXT.scope(self.clone(), future) + } +} + +pin_project_lite::pin_project! { + /// A future which records the elapsed time and the number of polls in the + /// active log context + pub struct PollRecordingFuture { + #[pin] + inner: F, + } +} + +impl PollRecordingFuture { + pub(crate) fn new(inner: F) -> Self { + Self { inner } + } +} + +impl Future for PollRecordingFuture { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let start = Instant::now(); + let this = self.project(); + let result = this.inner.poll(cx); + + // Record the number of polls and the time we spent polling the future + let elapsed = start.elapsed().as_nanos().try_into().unwrap_or(u64::MAX); + let _ = crate::CURRENT_LOG_CONTEXT.try_with(|c| { + c.inner.polls.fetch_add(1, Ordering::Relaxed); + c.inner.cpu_time.fetch_add(elapsed, Ordering::Relaxed); + }); + + result + } +} diff --git a/crates/context/src/layer.rs b/crates/context/src/layer.rs new file mode 100644 index 000000000..0ce6e3497 --- /dev/null +++ b/crates/context/src/layer.rs @@ -0,0 +1,41 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +use std::borrow::Cow; + +use tower_layer::Layer; +use tower_service::Service; + +use crate::LogContextService; + +/// A layer which creates a log context for each request. +pub struct LogContextLayer { + tagger: fn(&R) -> Cow<'static, str>, +} + +impl Clone for LogContextLayer { + fn clone(&self) -> Self { + Self { + tagger: self.tagger, + } + } +} + +impl LogContextLayer { + pub fn new(tagger: fn(&R) -> Cow<'static, str>) -> Self { + Self { tagger } + } +} + +impl Layer for LogContextLayer +where + S: Service, +{ + type Service = LogContextService; + + fn layer(&self, inner: S) -> Self::Service { + LogContextService::new(inner, self.tagger) + } +} diff --git a/crates/context/src/lib.rs b/crates/context/src/lib.rs new file mode 100644 index 000000000..54cdff095 --- /dev/null +++ b/crates/context/src/lib.rs @@ -0,0 +1,126 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +mod future; +mod layer; +mod service; + +use std::{ + borrow::Cow, + sync::{ + Arc, + atomic::{AtomicU64, Ordering}, + }, +}; + +use quanta::Instant; +use tokio::task_local; + +pub use self::{ + future::{LogContextFuture, PollRecordingFuture}, + layer::LogContextLayer, + service::LogContextService, +}; + +/// A counter which increments each time we create a new log context +/// It will wrap around if we create more than [`u64::MAX`] contexts +static LOG_CONTEXT_INDEX: AtomicU64 = AtomicU64::new(0); +task_local! { + pub static CURRENT_LOG_CONTEXT: LogContext; +} + +/// A log context saves informations about the current task, such as the +/// elapsed time, the number of polls, and the poll time. +#[derive(Clone)] +pub struct LogContext { + inner: Arc, +} + +struct LogContextInner { + /// A user-defined tag for the log context + tag: Cow<'static, str>, + + /// A unique index for the log context + index: u64, + + /// The time when the context was created + start: Instant, + + /// The number of [`Future::poll`] recorded + polls: AtomicU64, + + /// An approximation of the total CPU time spent in the context + cpu_time: AtomicU64, +} + +impl LogContext { + /// Create a new log context with the given tag + pub fn new(tag: impl Into>) -> Self { + let tag = tag.into(); + let inner = LogContextInner { + tag, + index: LOG_CONTEXT_INDEX.fetch_add(1, Ordering::Relaxed), + start: Instant::now(), + polls: AtomicU64::new(0), + cpu_time: AtomicU64::new(0), + }; + + Self { + inner: Arc::new(inner), + } + } + + /// Get a copy of the current log context, if any + pub fn current() -> Option { + CURRENT_LOG_CONTEXT.try_with(Self::clone).ok() + } + + /// Run the async function `f` with the given log context. It will wrap the + /// output future to record poll and CPU statistics. + pub fn run Fut, Fut: Future>(&self, f: F) -> LogContextFuture { + let future = self.run_sync(f); + self.wrap_future(future) + } + + /// Run the sync function `f` with the given log context, recording the CPU + /// time spent. + pub fn run_sync R, R>(&self, f: F) -> R { + let start = Instant::now(); + let result = CURRENT_LOG_CONTEXT.sync_scope(self.clone(), f); + let elapsed = start.elapsed().as_nanos().try_into().unwrap_or(u64::MAX); + self.inner.cpu_time.fetch_add(elapsed, Ordering::Relaxed); + result + } +} + +impl std::fmt::Display for LogContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + #[expect(clippy::cast_precision_loss)] + let elapsed = self.inner.start.elapsed().as_nanos() as f64 / 1_000_000.; + + #[expect(clippy::cast_precision_loss)] + let cpu_time_ms = self.inner.cpu_time.load(Ordering::Relaxed) as f64 / 1_000_000.; + + let polls = self.inner.polls.load(Ordering::Relaxed); + let tag = &self.inner.tag; + let index = self.inner.index; + write!( + f, + "{tag}-{index} ({polls} polls, CPU: {cpu_time_ms:.3} ms, total: {elapsed:.3} ms)" + ) + } +} + +/// A helper which implements `Display` for printing the current log context +#[derive(Debug, Clone, Copy)] +pub struct CurrentLogContext; + +impl std::fmt::Display for CurrentLogContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + CURRENT_LOG_CONTEXT + .try_with(|c| c.fmt(f)) + .unwrap_or_else(|_| "".fmt(f)) + } +} diff --git a/crates/context/src/service.rs b/crates/context/src/service.rs new file mode 100644 index 000000000..98a1d1184 --- /dev/null +++ b/crates/context/src/service.rs @@ -0,0 +1,54 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +use std::{ + borrow::Cow, + task::{Context, Poll}, +}; + +use tower_service::Service; + +use crate::{LogContext, LogContextFuture}; + +/// A service which wraps another service and creates a log context for +/// each request. +pub struct LogContextService { + inner: S, + tagger: fn(&R) -> Cow<'static, str>, +} + +impl Clone for LogContextService { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + tagger: self.tagger, + } + } +} + +impl LogContextService { + pub fn new(inner: S, tagger: fn(&R) -> Cow<'static, str>) -> Self { + Self { inner, tagger } + } +} + +impl Service for LogContextService +where + S: Service, +{ + type Response = S::Response; + type Error = S::Error; + type Future = LogContextFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: R) -> Self::Future { + let tag = (self.tagger)(&req); + let log_context = LogContext::new(tag); + log_context.run(|| self.inner.call(req)) + } +} diff --git a/crates/listener/Cargo.toml b/crates/listener/Cargo.toml index 4f9056c0f..d5178c049 100644 --- a/crates/listener/Cargo.toml +++ b/crates/listener/Cargo.toml @@ -17,7 +17,7 @@ futures-util.workspace = true http-body.workspace = true hyper = { workspace = true, features = ["server"] } hyper-util.workspace = true -pin-project-lite = "0.2.16" +pin-project-lite.workspace = true socket2 = "0.5.9" thiserror.workspace = true tokio.workspace = true diff --git a/crates/tower/Cargo.toml b/crates/tower/Cargo.toml index 978eaa3c1..52ef9da13 100644 --- a/crates/tower/Cargo.toml +++ b/crates/tower/Cargo.toml @@ -19,4 +19,4 @@ tower.workspace = true opentelemetry.workspace = true opentelemetry-http.workspace = true opentelemetry-semantic-conventions.workspace = true -pin-project-lite = "0.2.16" +pin-project-lite.workspace = true