Use CancellationToken and a TaskTracker to handle graceful shutdowns
This commit is contained in:
6
Cargo.lock
generated
6
Cargo.lock
generated
@@ -3261,6 +3261,7 @@ dependencies = [
|
||||
"serde_yaml",
|
||||
"sqlx",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tower 0.5.1",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
@@ -3393,6 +3394,7 @@ dependencies = [
|
||||
"thiserror",
|
||||
"time",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tower 0.5.1",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
@@ -3563,7 +3565,6 @@ version = "0.12.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bytes",
|
||||
"event-listener 5.3.1",
|
||||
"futures-util",
|
||||
"http-body",
|
||||
"hyper",
|
||||
@@ -3576,6 +3577,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-test",
|
||||
"tokio-util",
|
||||
"tower 0.5.1",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
@@ -6318,6 +6320,8 @@ dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
"futures-util",
|
||||
"hashbrown 0.14.5",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
@@ -259,6 +259,11 @@ version = "1.0.64"
|
||||
version = "1.40.0"
|
||||
features = ["full"]
|
||||
|
||||
# Useful async utilities
|
||||
[workspace.dependencies.tokio-util]
|
||||
version = "0.7.12"
|
||||
features = ["rt"]
|
||||
|
||||
# Tower services
|
||||
[workspace.dependencies.tower]
|
||||
version = "0.5.1"
|
||||
|
||||
@@ -36,6 +36,7 @@ serde_json.workspace = true
|
||||
serde_yaml = "0.9.34"
|
||||
sqlx.workspace = true
|
||||
tokio.workspace = true
|
||||
tokio-util.workspace = true
|
||||
tower.workspace = true
|
||||
tower-http.workspace = true
|
||||
url.workspace = true
|
||||
@@ -48,12 +49,20 @@ tracing-opentelemetry.workspace = true
|
||||
opentelemetry.workspace = true
|
||||
opentelemetry-http.workspace = true
|
||||
opentelemetry-jaeger-propagator = "0.3.0"
|
||||
opentelemetry-otlp = { version = "0.17.0", default-features = false, features = ["trace", "metrics", "http-proto"] }
|
||||
opentelemetry-otlp = { version = "0.17.0", default-features = false, features = [
|
||||
"trace",
|
||||
"metrics",
|
||||
"http-proto",
|
||||
] }
|
||||
opentelemetry-prometheus = "0.17.0"
|
||||
opentelemetry-resource-detectors = "0.3.0"
|
||||
opentelemetry-semantic-conventions.workspace = true
|
||||
opentelemetry-stdout = { version = "0.5.0", features = ["trace", "metrics"] }
|
||||
opentelemetry_sdk = { version = "0.24.1", features = ["trace", "metrics", "rt-tokio"] }
|
||||
opentelemetry_sdk = { version = "0.24.1", features = [
|
||||
"trace",
|
||||
"metrics",
|
||||
"rt-tokio",
|
||||
] }
|
||||
prometheus = "0.13.4"
|
||||
sentry.workspace = true
|
||||
sentry-tracing.workspace = true
|
||||
|
||||
@@ -14,7 +14,7 @@ use mas_config::{
|
||||
AppConfig, ClientsConfig, ConfigurationSection, ConfigurationSectionExt, UpstreamOAuth2Config,
|
||||
};
|
||||
use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, Limiter, MetadataCache};
|
||||
use mas_listener::{server::Server, shutdown::ShutdownStream};
|
||||
use mas_listener::server::Server;
|
||||
use mas_matrix_synapse::SynapseConnection;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::SystemClock;
|
||||
@@ -24,11 +24,11 @@ use rand::{
|
||||
thread_rng,
|
||||
};
|
||||
use sqlx::migrate::Migrate;
|
||||
use tokio::signal::unix::SignalKind;
|
||||
use tracing::{info, info_span, warn, Instrument};
|
||||
|
||||
use crate::{
|
||||
app_state::AppState,
|
||||
shutdown::ShutdownManager,
|
||||
util::{
|
||||
database_pool_from_config, mailer_from_config, password_manager_from_config,
|
||||
policy_factory_from_config, register_sighup, site_config_from_config,
|
||||
@@ -61,6 +61,7 @@ impl Options {
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
|
||||
let span = info_span!("cli.run.init").entered();
|
||||
let shutdown = ShutdownManager::new()?;
|
||||
let config = AppConfig::extract(figment)?;
|
||||
|
||||
if self.migrate {
|
||||
@@ -173,8 +174,21 @@ impl Options {
|
||||
url_builder.clone(),
|
||||
)
|
||||
.await?;
|
||||
// TODO: grab the handle
|
||||
tokio::spawn(monitor.run());
|
||||
|
||||
// XXX: The monitor from apalis is a bit annoying to use for graceful shutdowns,
|
||||
// ideally we'd just give it a cancellation token
|
||||
let shutdown_future = shutdown.soft_shutdown_token().cancelled_owned();
|
||||
shutdown.task_tracker().spawn(async move {
|
||||
if let Err(e) = monitor
|
||||
.run_with_signal(async move {
|
||||
shutdown_future.await;
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
{
|
||||
tracing::error!(error = &e as &dyn std::error::Error, "Task worker failed");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let listeners_config = config.http.listeners.clone();
|
||||
@@ -186,7 +200,12 @@ impl Options {
|
||||
|
||||
// Initialize the activity tracker
|
||||
// Activity is flushed every minute
|
||||
let activity_tracker = ActivityTracker::new(pool.clone(), Duration::from_secs(60));
|
||||
let activity_tracker = ActivityTracker::new(
|
||||
pool.clone(),
|
||||
Duration::from_secs(60),
|
||||
shutdown.task_tracker(),
|
||||
shutdown.soft_shutdown_token(),
|
||||
);
|
||||
let trusted_proxies = config.http.trusted_proxies.clone();
|
||||
|
||||
// Build a rate limiter.
|
||||
@@ -302,16 +321,17 @@ impl Options {
|
||||
.flatten_ok()
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let shutdown = ShutdownStream::default()
|
||||
.with_timeout(Duration::from_secs(60))
|
||||
.with_signal(SignalKind::terminate())?
|
||||
.with_signal(SignalKind::interrupt())?;
|
||||
|
||||
span.exit();
|
||||
|
||||
mas_listener::server::run_servers(servers, shutdown).await;
|
||||
shutdown
|
||||
.task_tracker()
|
||||
.spawn(mas_listener::server::run_servers(
|
||||
servers,
|
||||
shutdown.soft_shutdown_token(),
|
||||
shutdown.hard_shutdown_token(),
|
||||
));
|
||||
|
||||
state.activity_tracker.shutdown().await;
|
||||
shutdown.run().await;
|
||||
|
||||
Ok(ExitCode::SUCCESS)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ mod app_state;
|
||||
mod commands;
|
||||
mod sentry_transport;
|
||||
mod server;
|
||||
mod shutdown;
|
||||
mod sync;
|
||||
mod telemetry;
|
||||
mod util;
|
||||
|
||||
116
crates/cli/src/shutdown.rs
Normal file
116
crates/cli/src/shutdown.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
// Copyright 2024 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::signal::unix::{Signal, SignalKind};
|
||||
use tokio_util::{sync::CancellationToken, task::TaskTracker};
|
||||
|
||||
/// A helper to manage graceful shutdowns and track tasks that gracefully
|
||||
/// shutdown.
|
||||
///
|
||||
/// It will listen for SIGTERM and SIGINT signals, and will trigger a soft
|
||||
/// shutdown on the first signal, and a hard shutdown on the second signal or
|
||||
/// after a timeout.
|
||||
///
|
||||
/// Users of this manager should use the `soft_shutdown_token` to react to a
|
||||
/// soft shutdown, which should gracefully finish requests and close
|
||||
/// connections, and the `hard_shutdown_token` to react to a hard shutdown,
|
||||
/// which should drop all connections and finish all requests.
|
||||
///
|
||||
/// They should also use the `task_tracker` to make it track things running, so
|
||||
/// that it knows when the soft shutdown has completed.
|
||||
pub struct ShutdownManager {
|
||||
/// Token cancelled when a hard shutdown is requested.
hard_shutdown_token: CancellationToken,
|
||||
/// Token cancelled when a soft shutdown is requested; `new` creates it as
/// a child of `hard_shutdown_token`.
soft_shutdown_token: CancellationToken,
|
||||
/// Tracks the spawned tasks that `run` waits for during a shutdown.
task_tracker: TaskTracker,
|
||||
/// Listener for SIGTERM.
sigterm: Signal,
|
||||
/// Listener for SIGINT.
sigint: Signal,
|
||||
/// How long `run` waits after the soft shutdown before triggering the hard
/// shutdown.
timeout: Duration,
|
||||
}
|
||||
|
||||
impl ShutdownManager {
|
||||
/// Create a new shutdown manager, installing the signal handlers
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the signal handler could not be installed
|
||||
pub fn new() -> Result<Self, std::io::Error> {
|
||||
let hard_shutdown_token = CancellationToken::new();
|
||||
let soft_shutdown_token = hard_shutdown_token.child_token();
|
||||
let sigterm = tokio::signal::unix::signal(SignalKind::terminate())?;
|
||||
let sigint = tokio::signal::unix::signal(SignalKind::interrupt())?;
|
||||
let timeout = Duration::from_secs(60);
|
||||
let task_tracker = TaskTracker::new();
|
||||
|
||||
Ok(Self {
|
||||
hard_shutdown_token,
|
||||
soft_shutdown_token,
|
||||
task_tracker,
|
||||
sigterm,
|
||||
sigint,
|
||||
timeout,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a reference to the task tracker
|
||||
#[must_use]
|
||||
pub fn task_tracker(&self) -> &TaskTracker {
|
||||
&self.task_tracker
|
||||
}
|
||||
|
||||
/// Get a cancellation token that can be used to react to a hard shutdown
|
||||
#[must_use]
|
||||
pub fn hard_shutdown_token(&self) -> CancellationToken {
|
||||
self.hard_shutdown_token.clone()
|
||||
}
|
||||
|
||||
/// Get a cancellation token that can be used to react to a soft shutdown
|
||||
#[must_use]
|
||||
pub fn soft_shutdown_token(&self) -> CancellationToken {
|
||||
self.soft_shutdown_token.clone()
|
||||
}
|
||||
|
||||
/// Run until we finish completely shutting down.
|
||||
pub async fn run(mut self) {
|
||||
// Wait for a first signal and trigger the soft shutdown
|
||||
tokio::select! {
|
||||
_ = self.sigterm.recv() => {
|
||||
tracing::info!("Shutdown signal received (SIGTERM), shutting down");
|
||||
},
|
||||
_ = self.sigint.recv() => {
|
||||
tracing::info!("Shutdown signal received (SIGINT), shutting down");
|
||||
},
|
||||
};
|
||||
|
||||
self.soft_shutdown_token.cancel();
|
||||
self.task_tracker.close();
|
||||
|
||||
// Start the timeout
|
||||
let timeout = tokio::time::sleep(self.timeout);
|
||||
tokio::select! {
|
||||
_ = self.sigterm.recv() => {
|
||||
tracing::warn!("Second shutdown signal received (SIGTERM), abort");
|
||||
},
|
||||
_ = self.sigint.recv() => {
|
||||
tracing::warn!("Second shutdown signal received (SIGINT), abort");
|
||||
},
|
||||
() = timeout => {
|
||||
tracing::warn!("Shutdown timeout reached, abort");
|
||||
},
|
||||
() = self.task_tracker.wait() => {
|
||||
// This is the "happy path", we have gracefully shutdown
|
||||
},
|
||||
}
|
||||
|
||||
self.hard_shutdown_token().cancel();
|
||||
|
||||
// TODO: we may want to have a time out on the task tracker, in case we have
|
||||
// really stuck tasks on it
|
||||
self.task_tracker().wait().await;
|
||||
|
||||
tracing::info!("All tasks are done, exitting");
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,7 @@ workspace = true
|
||||
[dependencies]
|
||||
# Async runtime
|
||||
tokio.workspace = true
|
||||
tokio-util.workspace = true
|
||||
futures-util = "0.3.31"
|
||||
async-trait.workspace = true
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{BrowserSession, CompatSession, Session};
|
||||
use mas_storage::Clock;
|
||||
use sqlx::PgPool;
|
||||
use tokio_util::{sync::CancellationToken, task::TaskTracker};
|
||||
use ulid::Ulid;
|
||||
|
||||
pub use self::bound::Bound;
|
||||
@@ -45,7 +46,6 @@ enum Message {
|
||||
ip: Option<IpAddr>,
|
||||
},
|
||||
Flush(tokio::sync::oneshot::Sender<()>),
|
||||
Shutdown(tokio::sync::oneshot::Sender<()>),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -54,16 +54,29 @@ pub struct ActivityTracker {
|
||||
}
|
||||
|
||||
impl ActivityTracker {
|
||||
/// Create a new activity tracker, spawning the worker.
|
||||
/// Create a new activity tracker
|
||||
///
|
||||
/// It will spawn the background worker and a loop to flush the tracker on
|
||||
/// the task tracker, and both will shut themselves down, flushing one last
|
||||
/// time, when the cancellation token is cancelled.
|
||||
#[must_use]
|
||||
pub fn new(pool: PgPool, flush_interval: std::time::Duration) -> Self {
|
||||
pub fn new(
|
||||
pool: PgPool,
|
||||
flush_interval: std::time::Duration,
|
||||
task_tracker: &TaskTracker,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Self {
|
||||
let worker = Worker::new(pool);
|
||||
let (sender, receiver) = tokio::sync::mpsc::channel(MESSAGE_QUEUE_SIZE);
|
||||
let tracker = ActivityTracker { channel: sender };
|
||||
|
||||
// Spawn the flush loop and the worker
|
||||
tokio::spawn(tracker.clone().flush_loop(flush_interval));
|
||||
tokio::spawn(worker.run(receiver));
|
||||
task_tracker.spawn(
|
||||
tracker
|
||||
.clone()
|
||||
.flush_loop(flush_interval, cancellation_token.clone()),
|
||||
);
|
||||
task_tracker.spawn(worker.run(receiver, cancellation_token));
|
||||
|
||||
tracker
|
||||
}
|
||||
@@ -148,50 +161,47 @@ impl ActivityTracker {
|
||||
match res {
|
||||
Ok(()) => {
|
||||
if let Err(e) = rx.await {
|
||||
tracing::error!("Failed to flush activity tracker: {}", e);
|
||||
tracing::error!(
|
||||
error = &e as &dyn std::error::Error,
|
||||
"Failed to flush activity tracker"
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to flush activity tracker: {}", e);
|
||||
tracing::error!(
|
||||
error = &e as &dyn std::error::Error,
|
||||
"Failed to flush activity tracker"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Regularly flush the activity tracker.
|
||||
async fn flush_loop(self, interval: std::time::Duration) {
|
||||
async fn flush_loop(
|
||||
self,
|
||||
interval: std::time::Duration,
|
||||
cancellation_token: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
|
||||
() = cancellation_token.cancelled() => {
|
||||
// The cancellation token was cancelled, so we should exit
|
||||
return;
|
||||
}
|
||||
|
||||
// First check if the channel is closed, then check if the timer expired
|
||||
() = self.channel.closed() => {
|
||||
// The channel was closed, so we should exit
|
||||
break;
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
() = tokio::time::sleep(interval) => {
|
||||
self.flush().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Shutdown the activity tracker.
|
||||
///
|
||||
/// This will wait for all pending messages to be processed.
|
||||
pub async fn shutdown(&self) {
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
let res = self.channel.send(Message::Shutdown(tx)).await;
|
||||
|
||||
match res {
|
||||
Ok(()) => {
|
||||
if let Err(e) = rx.await {
|
||||
tracing::error!("Failed to shutdown activity tracker: {}", e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to shutdown activity tracker: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ use opentelemetry::{
|
||||
Key,
|
||||
};
|
||||
use sqlx::PgPool;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::activity_tracker::{Message, SessionKind};
|
||||
@@ -88,9 +89,30 @@ impl Worker {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn run(mut self, mut receiver: tokio::sync::mpsc::Receiver<Message>) {
|
||||
let mut shutdown_notifier = None;
|
||||
while let Some(message) = receiver.recv().await {
|
||||
pub(super) async fn run(
|
||||
mut self,
|
||||
mut receiver: tokio::sync::mpsc::Receiver<Message>,
|
||||
cancellation_token: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
let message = tokio::select! {
|
||||
// Because we want the cancellation token to trigger only once,
|
||||
// we check whether we have already closed the channel
|
||||
() = cancellation_token.cancelled(), if !receiver.is_closed() => {
|
||||
// We only close the channel, which will make it flush all
|
||||
// the pending messages
|
||||
receiver.close();
|
||||
tracing::debug!("Shutting down activity tracker");
|
||||
continue;
|
||||
},
|
||||
|
||||
message = receiver.recv() => {
|
||||
// We consumed all the messages, break out of the loop
|
||||
let Some(message) = message else { break };
|
||||
message
|
||||
}
|
||||
};
|
||||
|
||||
match message {
|
||||
Message::Record {
|
||||
kind,
|
||||
@@ -129,37 +151,18 @@ impl Worker {
|
||||
|
||||
record.end_time = date_time.max(record.end_time);
|
||||
}
|
||||
|
||||
Message::Flush(tx) => {
|
||||
self.message_counter.add(1, &[TYPE.string("flush")]);
|
||||
|
||||
self.flush().await;
|
||||
let _ = tx.send(());
|
||||
}
|
||||
Message::Shutdown(tx) => {
|
||||
self.message_counter.add(1, &[TYPE.string("shutdown")]);
|
||||
|
||||
let old_tx = shutdown_notifier.replace(tx);
|
||||
if let Some(old_tx) = old_tx {
|
||||
tracing::warn!("Activity tracker shutdown requested while another shutdown was already in progress");
|
||||
// Still send the shutdown signal to the previous notifier. This means we
|
||||
// send the shutdown signal before we flush the activity tracker, but that
|
||||
// should be fine, since there should not be multiple shutdown requests.
|
||||
let _ = old_tx.send(());
|
||||
}
|
||||
receiver.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush one last time
|
||||
self.flush().await;
|
||||
|
||||
if let Some(shutdown_notifier) = shutdown_notifier {
|
||||
let _ = shutdown_notifier.send(());
|
||||
} else {
|
||||
// This should never happen, since we set the shutdown notifier when we receive
|
||||
// the first shutdown message
|
||||
tracing::warn!("Activity tracker shutdown requested but no shutdown notifier was set");
|
||||
}
|
||||
}
|
||||
|
||||
/// Flush the activity tracker.
|
||||
|
||||
@@ -44,6 +44,10 @@ use rand::SeedableRng;
|
||||
use rand_chacha::ChaChaRng;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use tokio_util::{
|
||||
sync::{CancellationToken, DropGuard},
|
||||
task::TaskTracker,
|
||||
};
|
||||
use tower::{Layer, Service, ServiceExt};
|
||||
use url::Url;
|
||||
|
||||
@@ -105,6 +109,9 @@ pub(crate) struct TestState {
|
||||
pub limiter: Limiter,
|
||||
pub clock: Arc<MockClock>,
|
||||
pub rng: Arc<Mutex<ChaChaRng>>,
|
||||
|
||||
#[allow(dead_code)] // It is used, as it will cancel the CancellationToken when dropped
|
||||
cancellation_drop_guard: Arc<DropGuard>,
|
||||
}
|
||||
|
||||
fn workspace_root() -> camino::Utf8PathBuf {
|
||||
@@ -147,6 +154,9 @@ impl TestState {
|
||||
) -> Result<Self, anyhow::Error> {
|
||||
let workspace_root = workspace_root();
|
||||
|
||||
let task_tracker = TaskTracker::new();
|
||||
let shutdown_token = CancellationToken::new();
|
||||
|
||||
let url_builder = UrlBuilder::new("https://example.com/".parse()?, None, None);
|
||||
|
||||
let templates = Templates::load(
|
||||
@@ -204,8 +214,12 @@ impl TestState {
|
||||
|
||||
let graphql_schema = graphql::schema_builder().data(state).finish();
|
||||
|
||||
let activity_tracker =
|
||||
ActivityTracker::new(pool.clone(), std::time::Duration::from_secs(1));
|
||||
let activity_tracker = ActivityTracker::new(
|
||||
pool.clone(),
|
||||
std::time::Duration::from_secs(60),
|
||||
&task_tracker,
|
||||
shutdown_token.child_token(),
|
||||
);
|
||||
|
||||
let limiter = Limiter::new(&RateLimitingConfig::default()).unwrap();
|
||||
|
||||
@@ -227,6 +241,7 @@ impl TestState {
|
||||
limiter,
|
||||
clock,
|
||||
rng,
|
||||
cancellation_drop_guard: Arc::new(shutdown_token.drop_guard()),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ pub async fn get(
|
||||
cookie_jar: CookieJar,
|
||||
PreferredLanguage(locale): PreferredLanguage,
|
||||
) -> Result<impl IntoResponse, FancyError> {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
let session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
@@ -13,7 +13,6 @@ workspace = true
|
||||
|
||||
[dependencies]
|
||||
bytes.workspace = true
|
||||
event-listener = "5.3.1"
|
||||
futures-util = "0.3.31"
|
||||
http-body.workspace = true
|
||||
hyper = { workspace = true, features = ["server"] }
|
||||
@@ -24,6 +23,7 @@ socket2 = "0.5.7"
|
||||
thiserror.workspace = true
|
||||
tokio.workspace = true
|
||||
tokio-rustls = "0.26.0"
|
||||
tokio-util.workspace = true
|
||||
tower.workspace = true
|
||||
tower-http.workspace = true
|
||||
tracing.workspace = true
|
||||
|
||||
@@ -14,9 +14,9 @@ use std::{
|
||||
|
||||
use anyhow::Context;
|
||||
use hyper::{Request, Response};
|
||||
use mas_listener::{server::Server, shutdown::ShutdownStream, ConnectionInfo};
|
||||
use tokio::signal::unix::SignalKind;
|
||||
use mas_listener::{server::Server, ConnectionInfo};
|
||||
use tokio_rustls::rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tower::service_fn;
|
||||
|
||||
static CA_CERT_PEM: &[u8] = include_bytes!("./certs/ca.pem");
|
||||
@@ -53,12 +53,23 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||
|
||||
tracing::info!("Listening on http://127.0.0.1:3000, http(proxy)://127.0.0.1:3001, https://127.0.0.1:3002 and https(proxy)://127.0.0.1:3003");
|
||||
|
||||
let shutdown = ShutdownStream::default()
|
||||
.with_timeout(Duration::from_secs(1))
|
||||
.with_signal(SignalKind::interrupt())?
|
||||
.with_signal(SignalKind::terminate())?;
|
||||
let hard_shutdown = CancellationToken::new();
|
||||
let soft_shutdown = hard_shutdown.child_token();
|
||||
|
||||
mas_listener::server::run_servers(servers, shutdown).await;
|
||||
{
|
||||
let hard_shutdown = hard_shutdown.clone();
|
||||
let soft_shutdown = soft_shutdown.clone();
|
||||
tokio::spawn(async move {
|
||||
tokio::signal::ctrl_c().await.unwrap();
|
||||
tracing::info!("Ctrl-C received, performing soft-shutdown");
|
||||
soft_shutdown.cancel();
|
||||
tokio::signal::ctrl_c().await.unwrap();
|
||||
tracing::info!("Ctrl-C received again, shutting down");
|
||||
hard_shutdown.cancel();
|
||||
});
|
||||
}
|
||||
|
||||
// `run_servers` takes the soft shutdown token first, then the hard one
mas_listener::server::run_servers(servers, soft_shutdown, hard_shutdown).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ pub mod maybe_tls;
|
||||
pub mod proxy_protocol;
|
||||
pub mod rewind;
|
||||
pub mod server;
|
||||
pub mod shutdown;
|
||||
pub mod unix_or_tcp;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
||||
@@ -7,13 +7,12 @@
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use event_listener::{Event, EventListener};
|
||||
use futures_util::{stream::SelectAll, Stream, StreamExt};
|
||||
use futures_util::{stream::SelectAll, StreamExt};
|
||||
use hyper::{Request, Response};
|
||||
use hyper_util::{
|
||||
rt::{TokioExecutor, TokioIo},
|
||||
@@ -23,6 +22,7 @@ use hyper_util::{
|
||||
use pin_project_lite::pin_project;
|
||||
use thiserror::Error;
|
||||
use tokio_rustls::rustls::ServerConfig;
|
||||
use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};
|
||||
use tower::Service;
|
||||
use tower_http::add_extension::AddExtension;
|
||||
use tracing::Instrument;
|
||||
@@ -84,18 +84,24 @@ impl<S> Server<S> {
|
||||
}
|
||||
|
||||
/// Run a single server
|
||||
pub async fn run<B, SD>(self, shutdown: SD)
|
||||
where
|
||||
pub async fn run<B>(
|
||||
self,
|
||||
soft_shutdown_token: CancellationToken,
|
||||
hard_shutdown_token: CancellationToken,
|
||||
) where
|
||||
S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Clone + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
S::Error: std::error::Error + Send + Sync + 'static,
|
||||
B: http_body::Body + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: std::error::Error + Send + Sync + 'static,
|
||||
SD: Stream + Unpin,
|
||||
SD::Item: std::fmt::Display,
|
||||
{
|
||||
run_servers(std::iter::once(self), shutdown).await;
|
||||
run_servers(
|
||||
std::iter::once(self),
|
||||
soft_shutdown_token,
|
||||
hard_shutdown_token,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -252,18 +258,16 @@ pin_project! {
|
||||
#[pin]
|
||||
connection: C,
|
||||
#[pin]
|
||||
shutdown_listener: EventListener,
|
||||
shutdown_in_progress: Arc<AtomicBool>,
|
||||
cancellation_future: WaitForCancellationFutureOwned,
|
||||
did_start_shutdown: bool,
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> AbortableConnection<C> {
|
||||
fn new(connection: C, shutdown_in_progress: &Arc<AtomicBool>, event: &Arc<Event>) -> Self {
|
||||
fn new(connection: C, cancellation_token: CancellationToken) -> Self {
|
||||
Self {
|
||||
connection,
|
||||
shutdown_listener: event.listen(),
|
||||
shutdown_in_progress: Arc::clone(shutdown_in_progress),
|
||||
cancellation_future: cancellation_token.cancelled_owned(),
|
||||
did_start_shutdown: false,
|
||||
}
|
||||
}
|
||||
@@ -286,19 +290,11 @@ where
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut this = self.project();
|
||||
|
||||
// Poll the shutdown signal, so that wakers get registered.
|
||||
// XXX: I don't think we care about the result of this poll, since it's only
|
||||
// really to register wakers. But I'm not sure if it's safe to
|
||||
// ignore the result.
|
||||
let _ = this.shutdown_listener.poll(cx);
|
||||
|
||||
if !*this.did_start_shutdown
|
||||
&& this
|
||||
.shutdown_in_progress
|
||||
.load(std::sync::atomic::Ordering::Relaxed)
|
||||
{
|
||||
*this.did_start_shutdown = true;
|
||||
this.connection.as_mut().graceful_shutdown();
|
||||
if let Poll::Ready(()) = this.cancellation_future.poll(cx) {
|
||||
if !*this.did_start_shutdown {
|
||||
*this.did_start_shutdown = true;
|
||||
this.connection.as_mut().graceful_shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
this.connection.poll(cx)
|
||||
@@ -306,16 +302,17 @@ where
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn run_servers<S, B, SD>(listeners: impl IntoIterator<Item = Server<S>>, mut shutdown: SD)
|
||||
where
|
||||
pub async fn run_servers<S, B>(
|
||||
listeners: impl IntoIterator<Item = Server<S>>,
|
||||
soft_shutdown_token: CancellationToken,
|
||||
hard_shutdown_token: CancellationToken,
|
||||
) where
|
||||
S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Clone + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
S::Error: std::error::Error + Send + Sync + 'static,
|
||||
B: http_body::Body + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: std::error::Error + Send + Sync + 'static,
|
||||
SD: Stream + Unpin,
|
||||
SD::Item: std::fmt::Display,
|
||||
{
|
||||
// Create a stream of accepted connections out of the listeners
|
||||
let mut accept_stream: SelectAll<_> = listeners
|
||||
@@ -344,19 +341,13 @@ where
|
||||
// A JoinSet which collects connections that are being served
|
||||
let mut connection_tasks = tokio::task::JoinSet::new();
|
||||
|
||||
// A shared atomic boolean to tell all connections to shutdown
|
||||
let shutdown_in_progress = Arc::new(AtomicBool::new(false));
|
||||
let shutdown_event = Arc::new(Event::new());
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
|
||||
// First look for the shutdown signal
|
||||
res = shutdown.next() => {
|
||||
let why = res.map_or_else(|| String::from("???"), |why| format!("{why}"));
|
||||
tracing::info!("Received shutdown signal ({why})");
|
||||
|
||||
() = soft_shutdown_token.cancelled() => {
|
||||
tracing::debug!("Shutting down listeners");
|
||||
break;
|
||||
},
|
||||
|
||||
@@ -365,7 +356,7 @@ where
|
||||
match res {
|
||||
Some(Ok(Ok(connection))) => {
|
||||
tracing::trace!("Accepted connection");
|
||||
let conn = AbortableConnection::new(connection, &shutdown_in_progress, &shutdown_event);
|
||||
let conn = AbortableConnection::new(connection, soft_shutdown_token.child_token());
|
||||
connection_tasks.spawn(conn);
|
||||
},
|
||||
Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
|
||||
@@ -385,9 +376,8 @@ where
|
||||
},
|
||||
|
||||
// Look for connections to accept
|
||||
res = accept_stream.next(), if !accept_stream.is_empty() => {
|
||||
// SAFETY: We shouldn't reach this branch if the stream set is empty
|
||||
let Some(res) = res else { unreachable!() };
|
||||
res = accept_stream.next() => {
|
||||
let Some(res) = res else { continue };
|
||||
|
||||
// Spawn the connection in the set, so we don't have to wait for the handshake to
|
||||
// accept the next connection. This allows us to keep track of active connections
|
||||
@@ -401,10 +391,6 @@ where
|
||||
};
|
||||
}
|
||||
|
||||
// Tell the active connections to shutdown
|
||||
shutdown_in_progress.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||
shutdown_event.notify(usize::MAX);
|
||||
|
||||
// Wait for connections to cleanup
|
||||
if !accept_tasks.is_empty() || !connection_tasks.is_empty() {
|
||||
tracing::info!(
|
||||
@@ -422,7 +408,7 @@ where
|
||||
match res {
|
||||
Some(Ok(Ok(connection))) => {
|
||||
tracing::trace!("Accepted connection");
|
||||
let conn = AbortableConnection::new(connection, &shutdown_in_progress, &shutdown_event);
|
||||
let conn = AbortableConnection::new(connection, soft_shutdown_token.child_token());
|
||||
connection_tasks.spawn(conn);
|
||||
}
|
||||
Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
|
||||
@@ -441,11 +427,10 @@ where
|
||||
}
|
||||
},
|
||||
|
||||
// Handle when we receive the shutdown signal again
|
||||
res = shutdown.next() => {
|
||||
let why = res.map_or_else(|| String::from("???"), |why| format!("{why}"));
|
||||
// Handle when we are asked to hard shutdown
|
||||
() = hard_shutdown_token.cancelled() => {
|
||||
tracing::warn!(
|
||||
"Received shutdown signal again ({why}), forcing shutdown ({active} active connections, {pending} pending connections)",
|
||||
"Forcing shutdown ({active} active connections, {pending} pending connections)",
|
||||
active = connection_tasks.len(),
|
||||
pending = accept_tasks.len(),
|
||||
);
|
||||
@@ -457,5 +442,4 @@ where
|
||||
|
||||
accept_tasks.shutdown().await;
|
||||
connection_tasks.shutdown().await;
|
||||
tracing::info!("Shutdown complete");
|
||||
}
|
||||
|
||||
@@ -1,172 +0,0 @@
|
||||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use std::{fmt::Display, pin::Pin, task::Poll, time::Duration};
|
||||
|
||||
use futures_util::{ready, Future, Stream};
|
||||
use tokio::{
|
||||
signal::unix::{signal, Signal, SignalKind},
|
||||
time::Sleep,
|
||||
};
|
||||
|
||||
/// The reason attached to each item emitted by a `ShutdownStream`.
#[derive(Debug, Clone, Copy)]
|
||||
pub enum ShutdownReason {
|
||||
/// One of the registered signals was received.
Signal(SignalKind),
|
||||
/// The graceful shutdown timer elapsed.
Timeout,
|
||||
}
|
||||
|
||||
fn signal_to_str(kind: SignalKind) -> &'static str {
|
||||
match kind.as_raw_value() {
|
||||
libc::SIGALRM => "SIGALRM",
|
||||
libc::SIGCHLD => "SIGCHLD",
|
||||
libc::SIGHUP => "SIGHUP",
|
||||
libc::SIGINT => "SIGINT",
|
||||
libc::SIGIO => "SIGIO",
|
||||
libc::SIGPIPE => "SIGPIPE",
|
||||
libc::SIGQUIT => "SIGQUIT",
|
||||
libc::SIGTERM => "SIGTERM",
|
||||
libc::SIGUSR1 => "SIGUSR1",
|
||||
libc::SIGUSR2 => "SIGUSR2",
|
||||
_ => "SIG???",
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ShutdownReason {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Signal(s) => signal_to_str(*s).fmt(f),
|
||||
Self::Timeout => "timeout".fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ShutdownStreamState {
|
||||
Waiting,
|
||||
|
||||
Graceful { sleep: Option<Pin<Box<Sleep>>> },
|
||||
|
||||
Done,
|
||||
}
|
||||
|
||||
impl Default for ShutdownStreamState {
|
||||
fn default() -> Self {
|
||||
Self::Waiting
|
||||
}
|
||||
}
|
||||
|
||||
impl ShutdownStreamState {
|
||||
fn is_graceful(&self) -> bool {
|
||||
matches!(self, Self::Graceful { .. })
|
||||
}
|
||||
|
||||
fn is_done(&self) -> bool {
|
||||
matches!(self, Self::Done)
|
||||
}
|
||||
|
||||
fn get_sleep_mut(&mut self) -> Option<&mut Pin<Box<Sleep>>> {
|
||||
match self {
|
||||
Self::Graceful { sleep } => sleep.as_mut(),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A stream which is used to drive a graceful shutdown.
|
||||
///
|
||||
/// It will emit 2 items: one when a first signal is caught, the other when
|
||||
/// either another signal is caught, or after a timeout.
|
||||
#[derive(Default)]
|
||||
pub struct ShutdownStream {
|
||||
/// Where we are in the shutdown sequence.
state: ShutdownStreamState,
|
||||
/// The registered signal listeners, each paired with the kind it listens
/// for.
signals: Vec<(SignalKind, Signal)>,
|
||||
/// Delay between the first item and the timeout item; no timer is started
/// when this is `None`.
timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
impl ShutdownStream {
|
||||
/// Create a default shutdown stream, which listens on SIGINT and SIGTERM,
|
||||
/// with a 60s timeout
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if signal handlers could not be installed
|
||||
pub fn new() -> Result<Self, std::io::Error> {
|
||||
let ret = Self::default()
|
||||
.with_timeout(Duration::from_secs(60))
|
||||
.with_signal(SignalKind::interrupt())?
|
||||
.with_signal(SignalKind::terminate())?;
|
||||
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
/// Add a signal to register
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the signal handler could not be installed
|
||||
pub fn with_signal(mut self, kind: SignalKind) -> Result<Self, std::io::Error> {
|
||||
let signal = signal(kind)?;
|
||||
self.signals.push((kind, signal));
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Self {
|
||||
self.timeout = Some(timeout);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for ShutdownStream {
|
||||
type Item = ShutdownReason;
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
match self.state {
|
||||
ShutdownStreamState::Waiting => (2, Some(2)),
|
||||
ShutdownStreamState::Graceful { .. } => (1, Some(1)),
|
||||
ShutdownStreamState::Done => (0, Some(0)),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_next(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Option<Self::Item>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
if this.state.is_done() {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
|
||||
for (kind, signal) in &mut this.signals {
|
||||
match signal.poll_recv(cx) {
|
||||
Poll::Ready(_) => {
|
||||
// We got a signal
|
||||
if this.state.is_graceful() {
|
||||
// If we was gracefully shutting down, mark it as done
|
||||
this.state = ShutdownStreamState::Done;
|
||||
} else {
|
||||
// Else start the timeout
|
||||
let sleep = this
|
||||
.timeout
|
||||
.map(|duration| Box::pin(tokio::time::sleep(duration)));
|
||||
this.state = ShutdownStreamState::Graceful { sleep };
|
||||
}
|
||||
|
||||
return Poll::Ready(Some(ShutdownReason::Signal(*kind)));
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(timeout) = this.state.get_sleep_mut() {
|
||||
ready!(timeout.as_mut().poll(cx));
|
||||
this.state = ShutdownStreamState::Done;
|
||||
Poll::Ready(Some(ShutdownReason::Timeout))
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user