Use CancellationToken and a TaskTracker to handle graceful shutdowns

This commit is contained in:
Quentin Gliech
2024-10-09 17:48:59 +02:00
parent e88963d172
commit f0e2f6a2f0
16 changed files with 309 additions and 302 deletions

6
Cargo.lock generated
View File

@@ -3261,6 +3261,7 @@ dependencies = [
"serde_yaml",
"sqlx",
"tokio",
"tokio-util",
"tower 0.5.1",
"tower-http",
"tracing",
@@ -3393,6 +3394,7 @@ dependencies = [
"thiserror",
"time",
"tokio",
"tokio-util",
"tower 0.5.1",
"tower-http",
"tracing",
@@ -3563,7 +3565,6 @@ version = "0.12.0"
dependencies = [
"anyhow",
"bytes",
"event-listener 5.3.1",
"futures-util",
"http-body",
"hyper",
@@ -3576,6 +3577,7 @@ dependencies = [
"tokio",
"tokio-rustls",
"tokio-test",
"tokio-util",
"tower 0.5.1",
"tower-http",
"tracing",
@@ -6318,6 +6320,8 @@ dependencies = [
"bytes",
"futures-core",
"futures-sink",
"futures-util",
"hashbrown 0.14.5",
"pin-project-lite",
"tokio",
]

View File

@@ -259,6 +259,11 @@ version = "1.0.64"
version = "1.40.0"
features = ["full"]
# Useful async utilities
[workspace.dependencies.tokio-util]
version = "0.7.12"
features = ["rt"]
# Tower services
[workspace.dependencies.tower]
version = "0.5.1"

View File

@@ -36,6 +36,7 @@ serde_json.workspace = true
serde_yaml = "0.9.34"
sqlx.workspace = true
tokio.workspace = true
tokio-util.workspace = true
tower.workspace = true
tower-http.workspace = true
url.workspace = true
@@ -48,12 +49,20 @@ tracing-opentelemetry.workspace = true
opentelemetry.workspace = true
opentelemetry-http.workspace = true
opentelemetry-jaeger-propagator = "0.3.0"
opentelemetry-otlp = { version = "0.17.0", default-features = false, features = ["trace", "metrics", "http-proto"] }
opentelemetry-otlp = { version = "0.17.0", default-features = false, features = [
"trace",
"metrics",
"http-proto",
] }
opentelemetry-prometheus = "0.17.0"
opentelemetry-resource-detectors = "0.3.0"
opentelemetry-semantic-conventions.workspace = true
opentelemetry-stdout = { version = "0.5.0", features = ["trace", "metrics"] }
opentelemetry_sdk = { version = "0.24.1", features = ["trace", "metrics", "rt-tokio"] }
opentelemetry_sdk = { version = "0.24.1", features = [
"trace",
"metrics",
"rt-tokio",
] }
prometheus = "0.13.4"
sentry.workspace = true
sentry-tracing.workspace = true

View File

@@ -14,7 +14,7 @@ use mas_config::{
AppConfig, ClientsConfig, ConfigurationSection, ConfigurationSectionExt, UpstreamOAuth2Config,
};
use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, Limiter, MetadataCache};
use mas_listener::{server::Server, shutdown::ShutdownStream};
use mas_listener::server::Server;
use mas_matrix_synapse::SynapseConnection;
use mas_router::UrlBuilder;
use mas_storage::SystemClock;
@@ -24,11 +24,11 @@ use rand::{
thread_rng,
};
use sqlx::migrate::Migrate;
use tokio::signal::unix::SignalKind;
use tracing::{info, info_span, warn, Instrument};
use crate::{
app_state::AppState,
shutdown::ShutdownManager,
util::{
database_pool_from_config, mailer_from_config, password_manager_from_config,
policy_factory_from_config, register_sighup, site_config_from_config,
@@ -61,6 +61,7 @@ impl Options {
#[allow(clippy::too_many_lines)]
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
let span = info_span!("cli.run.init").entered();
let shutdown = ShutdownManager::new()?;
let config = AppConfig::extract(figment)?;
if self.migrate {
@@ -173,8 +174,21 @@ impl Options {
url_builder.clone(),
)
.await?;
// TODO: grab the handle
tokio::spawn(monitor.run());
// XXX: The monitor from apalis is a bit annoying to use for graceful shutdowns,
// ideally we'd just give it a cancellation token
let shutdown_future = shutdown.soft_shutdown_token().cancelled_owned();
shutdown.task_tracker().spawn(async move {
if let Err(e) = monitor
.run_with_signal(async move {
shutdown_future.await;
Ok(())
})
.await
{
tracing::error!(error = &e as &dyn std::error::Error, "Task worker failed");
}
});
}
let listeners_config = config.http.listeners.clone();
@@ -186,7 +200,12 @@ impl Options {
// Initialize the activity tracker
// Activity is flushed every minute
let activity_tracker = ActivityTracker::new(pool.clone(), Duration::from_secs(60));
let activity_tracker = ActivityTracker::new(
pool.clone(),
Duration::from_secs(60),
shutdown.task_tracker(),
shutdown.soft_shutdown_token(),
);
let trusted_proxies = config.http.trusted_proxies.clone();
// Build a rate limiter.
@@ -302,16 +321,17 @@ impl Options {
.flatten_ok()
.collect::<Result<Vec<_>, _>>()?;
let shutdown = ShutdownStream::default()
.with_timeout(Duration::from_secs(60))
.with_signal(SignalKind::terminate())?
.with_signal(SignalKind::interrupt())?;
span.exit();
mas_listener::server::run_servers(servers, shutdown).await;
shutdown
.task_tracker()
.spawn(mas_listener::server::run_servers(
servers,
shutdown.soft_shutdown_token(),
shutdown.hard_shutdown_token(),
));
state.activity_tracker.shutdown().await;
shutdown.run().await;
Ok(ExitCode::SUCCESS)
}

View File

@@ -22,6 +22,7 @@ mod app_state;
mod commands;
mod sentry_transport;
mod server;
mod shutdown;
mod sync;
mod telemetry;
mod util;

116
crates/cli/src/shutdown.rs Normal file
View File

@@ -0,0 +1,116 @@
// Copyright 2024 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::time::Duration;
use tokio::signal::unix::{Signal, SignalKind};
use tokio_util::{sync::CancellationToken, task::TaskTracker};
/// A helper to manage graceful shutdowns and track tasks that gracefully
/// shutdown.
///
/// It will listen for SIGTERM and SIGINT signals, and will trigger a soft
/// shutdown on the first signal, and a hard shutdown on the second signal or
/// after a timeout.
///
/// Users of this manager should use the `soft_shutdown_token` to react to a
/// soft shutdown, which should gracefully finish requests and close
/// connections, and the `hard_shutdown_token` to react to a hard shutdown,
/// which should drop all connections and finish all requests.
///
/// They should also use the `task_tracker` to make it track things running, so
/// that it knows when the soft shutdown is over and worked.
pub struct ShutdownManager {
hard_shutdown_token: CancellationToken,
soft_shutdown_token: CancellationToken,
task_tracker: TaskTracker,
sigterm: Signal,
sigint: Signal,
timeout: Duration,
}
impl ShutdownManager {
/// Create a new shutdown manager, installing the signal handlers
///
/// # Errors
///
/// Returns an error if the signal handler could not be installed
pub fn new() -> Result<Self, std::io::Error> {
let hard_shutdown_token = CancellationToken::new();
let soft_shutdown_token = hard_shutdown_token.child_token();
let sigterm = tokio::signal::unix::signal(SignalKind::terminate())?;
let sigint = tokio::signal::unix::signal(SignalKind::interrupt())?;
let timeout = Duration::from_secs(60);
let task_tracker = TaskTracker::new();
Ok(Self {
hard_shutdown_token,
soft_shutdown_token,
task_tracker,
sigterm,
sigint,
timeout,
})
}
/// Get a reference to the task tracker
#[must_use]
pub fn task_tracker(&self) -> &TaskTracker {
&self.task_tracker
}
/// Get a cancellation token that can be used to react to a hard shutdown
#[must_use]
pub fn hard_shutdown_token(&self) -> CancellationToken {
self.hard_shutdown_token.clone()
}
/// Get a cancellation token that can be used to react to a soft shutdown
#[must_use]
pub fn soft_shutdown_token(&self) -> CancellationToken {
self.soft_shutdown_token.clone()
}
/// Run until we finish completely shutting down.
pub async fn run(mut self) {
// Wait for a first signal and trigger the soft shutdown
tokio::select! {
_ = self.sigterm.recv() => {
tracing::info!("Shutdown signal received (SIGTERM), shutting down");
},
_ = self.sigint.recv() => {
tracing::info!("Shutdown signal received (SIGINT), shutting down");
},
};
self.soft_shutdown_token.cancel();
self.task_tracker.close();
// Start the timeout
let timeout = tokio::time::sleep(self.timeout);
tokio::select! {
_ = self.sigterm.recv() => {
tracing::warn!("Second shutdown signal received (SIGTERM), abort");
},
_ = self.sigint.recv() => {
tracing::warn!("Second shutdown signal received (SIGINT), abort");
},
() = timeout => {
tracing::warn!("Shutdown timeout reached, abort");
},
() = self.task_tracker.wait() => {
// This is the "happy path", we have gracefully shutdown
},
}
self.hard_shutdown_token().cancel();
// TODO: we may want to have a time out on the task tracker, in case we have
// really stuck tasks on it
self.task_tracker().wait().await;
tracing::info!("All tasks are done, exitting");
}
}

View File

@@ -14,6 +14,7 @@ workspace = true
[dependencies]
# Async runtime
tokio.workspace = true
tokio-util.workspace = true
futures-util = "0.3.31"
async-trait.workspace = true

View File

@@ -13,6 +13,7 @@ use chrono::{DateTime, Utc};
use mas_data_model::{BrowserSession, CompatSession, Session};
use mas_storage::Clock;
use sqlx::PgPool;
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use ulid::Ulid;
pub use self::bound::Bound;
@@ -45,7 +46,6 @@ enum Message {
ip: Option<IpAddr>,
},
Flush(tokio::sync::oneshot::Sender<()>),
Shutdown(tokio::sync::oneshot::Sender<()>),
}
#[derive(Clone)]
@@ -54,16 +54,29 @@ pub struct ActivityTracker {
}
impl ActivityTracker {
/// Create a new activity tracker, spawning the worker.
/// Create a new activity tracker
///
/// It will spawn the background worker and a loop to flush the tracker on
/// the task tracker, and both will shut themselves down, flushing one last
/// time, when the cancellation token is cancelled.
#[must_use]
pub fn new(pool: PgPool, flush_interval: std::time::Duration) -> Self {
pub fn new(
pool: PgPool,
flush_interval: std::time::Duration,
task_tracker: &TaskTracker,
cancellation_token: CancellationToken,
) -> Self {
let worker = Worker::new(pool);
let (sender, receiver) = tokio::sync::mpsc::channel(MESSAGE_QUEUE_SIZE);
let tracker = ActivityTracker { channel: sender };
// Spawn the flush loop and the worker
tokio::spawn(tracker.clone().flush_loop(flush_interval));
tokio::spawn(worker.run(receiver));
task_tracker.spawn(
tracker
.clone()
.flush_loop(flush_interval, cancellation_token.clone()),
);
task_tracker.spawn(worker.run(receiver, cancellation_token));
tracker
}
@@ -148,50 +161,47 @@ impl ActivityTracker {
match res {
Ok(()) => {
if let Err(e) = rx.await {
tracing::error!("Failed to flush activity tracker: {}", e);
tracing::error!(
error = &e as &dyn std::error::Error,
"Failed to flush activity tracker"
);
}
}
Err(e) => {
tracing::error!("Failed to flush activity tracker: {}", e);
tracing::error!(
error = &e as &dyn std::error::Error,
"Failed to flush activity tracker"
);
}
}
}
/// Regularly flush the activity tracker.
async fn flush_loop(self, interval: std::time::Duration) {
async fn flush_loop(
self,
interval: std::time::Duration,
cancellation_token: CancellationToken,
) {
loop {
tokio::select! {
biased;
() = cancellation_token.cancelled() => {
// The cancellation token was cancelled, so we should exit
return;
}
// First check if the channel is closed, then check if the timer expired
() = self.channel.closed() => {
// The channel was closed, so we should exit
break;
return;
}
() = tokio::time::sleep(interval) => {
self.flush().await;
}
}
}
}
/// Shutdown the activity tracker.
///
/// This will wait for all pending messages to be processed.
pub async fn shutdown(&self) {
let (tx, rx) = tokio::sync::oneshot::channel();
let res = self.channel.send(Message::Shutdown(tx)).await;
match res {
Ok(()) => {
if let Err(e) = rx.await {
tracing::error!("Failed to shutdown activity tracker: {}", e);
}
}
Err(e) => {
tracing::error!("Failed to shutdown activity tracker: {}", e);
}
}
}
}

View File

@@ -13,6 +13,7 @@ use opentelemetry::{
Key,
};
use sqlx::PgPool;
use tokio_util::sync::CancellationToken;
use ulid::Ulid;
use crate::activity_tracker::{Message, SessionKind};
@@ -88,9 +89,30 @@ impl Worker {
}
}
pub(super) async fn run(mut self, mut receiver: tokio::sync::mpsc::Receiver<Message>) {
let mut shutdown_notifier = None;
while let Some(message) = receiver.recv().await {
pub(super) async fn run(
mut self,
mut receiver: tokio::sync::mpsc::Receiver<Message>,
cancellation_token: CancellationToken,
) {
loop {
let message = tokio::select! {
// Because we want the cancellation token to trigger only once,
// we looked whether we closed the channel or not
() = cancellation_token.cancelled(), if !receiver.is_closed() => {
// We only close the channel, which will make it flush all
// the pending messages
receiver.close();
tracing::debug!("Shutting down activity tracker");
continue;
},
message = receiver.recv() => {
// We consumed all the messages, break out of the loop
let Some(message) = message else { break };
message
}
};
match message {
Message::Record {
kind,
@@ -129,37 +151,18 @@ impl Worker {
record.end_time = date_time.max(record.end_time);
}
Message::Flush(tx) => {
self.message_counter.add(1, &[TYPE.string("flush")]);
self.flush().await;
let _ = tx.send(());
}
Message::Shutdown(tx) => {
self.message_counter.add(1, &[TYPE.string("shutdown")]);
let old_tx = shutdown_notifier.replace(tx);
if let Some(old_tx) = old_tx {
tracing::warn!("Activity tracker shutdown requested while another shutdown was already in progress");
// Still send the shutdown signal to the previous notifier. This means we
// send the shutdown signal before we flush the activity tracker, but that
// should be fine, since there should not be multiple shutdown requests.
let _ = old_tx.send(());
}
receiver.close();
}
}
}
// Flush one last time
self.flush().await;
if let Some(shutdown_notifier) = shutdown_notifier {
let _ = shutdown_notifier.send(());
} else {
// This should never happen, since we set the shutdown notifier when we receive
// the first shutdown message
tracing::warn!("Activity tracker shutdown requested but no shutdown notifier was set");
}
}
/// Flush the activity tracker.

View File

@@ -44,6 +44,10 @@ use rand::SeedableRng;
use rand_chacha::ChaChaRng;
use serde::{de::DeserializeOwned, Serialize};
use sqlx::PgPool;
use tokio_util::{
sync::{CancellationToken, DropGuard},
task::TaskTracker,
};
use tower::{Layer, Service, ServiceExt};
use url::Url;
@@ -105,6 +109,9 @@ pub(crate) struct TestState {
pub limiter: Limiter,
pub clock: Arc<MockClock>,
pub rng: Arc<Mutex<ChaChaRng>>,
#[allow(dead_code)] // It is used, as it will cancel the CancellationToken when dropped
cancellation_drop_guard: Arc<DropGuard>,
}
fn workspace_root() -> camino::Utf8PathBuf {
@@ -147,6 +154,9 @@ impl TestState {
) -> Result<Self, anyhow::Error> {
let workspace_root = workspace_root();
let task_tracker = TaskTracker::new();
let shutdown_token = CancellationToken::new();
let url_builder = UrlBuilder::new("https://example.com/".parse()?, None, None);
let templates = Templates::load(
@@ -204,8 +214,12 @@ impl TestState {
let graphql_schema = graphql::schema_builder().data(state).finish();
let activity_tracker =
ActivityTracker::new(pool.clone(), std::time::Duration::from_secs(1));
let activity_tracker = ActivityTracker::new(
pool.clone(),
std::time::Duration::from_secs(60),
&task_tracker,
shutdown_token.child_token(),
);
let limiter = Limiter::new(&RateLimitingConfig::default()).unwrap();
@@ -227,6 +241,7 @@ impl TestState {
limiter,
clock,
rng,
cancellation_drop_guard: Arc::new(shutdown_token.drop_guard()),
})
}

View File

@@ -26,6 +26,7 @@ pub async fn get(
cookie_jar: CookieJar,
PreferredLanguage(locale): PreferredLanguage,
) -> Result<impl IntoResponse, FancyError> {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let session = session_info.load_session(&mut repo).await?;

View File

@@ -13,7 +13,6 @@ workspace = true
[dependencies]
bytes.workspace = true
event-listener = "5.3.1"
futures-util = "0.3.31"
http-body.workspace = true
hyper = { workspace = true, features = ["server"] }
@@ -24,6 +23,7 @@ socket2 = "0.5.7"
thiserror.workspace = true
tokio.workspace = true
tokio-rustls = "0.26.0"
tokio-util.workspace = true
tower.workspace = true
tower-http.workspace = true
tracing.workspace = true

View File

@@ -14,9 +14,9 @@ use std::{
use anyhow::Context;
use hyper::{Request, Response};
use mas_listener::{server::Server, shutdown::ShutdownStream, ConnectionInfo};
use tokio::signal::unix::SignalKind;
use mas_listener::{server::Server, ConnectionInfo};
use tokio_rustls::rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig};
use tokio_util::sync::CancellationToken;
use tower::service_fn;
static CA_CERT_PEM: &[u8] = include_bytes!("./certs/ca.pem");
@@ -53,12 +53,23 @@ async fn main() -> Result<(), anyhow::Error> {
tracing::info!("Listening on http://127.0.0.1:3000, http(proxy)://127.0.0.1:3001, https://127.0.0.1:3002 and https(proxy)://127.0.0.1:3003");
let shutdown = ShutdownStream::default()
.with_timeout(Duration::from_secs(1))
.with_signal(SignalKind::interrupt())?
.with_signal(SignalKind::terminate())?;
let hard_shutdown = CancellationToken::new();
let soft_shutdown = hard_shutdown.child_token();
mas_listener::server::run_servers(servers, shutdown).await;
{
let hard_shutdown = hard_shutdown.clone();
let soft_shutdown = soft_shutdown.clone();
tokio::spawn(async move {
tokio::signal::ctrl_c().await.unwrap();
tracing::info!("Ctrl-C received, performing soft-shutdown");
soft_shutdown.cancel();
tokio::signal::ctrl_c().await.unwrap();
tracing::info!("Ctrl-C received again, shutting down");
hard_shutdown.cancel();
});
}
mas_listener::server::run_servers(servers, hard_shutdown, soft_shutdown).await;
Ok(())
}

View File

@@ -16,7 +16,6 @@ pub mod maybe_tls;
pub mod proxy_protocol;
pub mod rewind;
pub mod server;
pub mod shutdown;
pub mod unix_or_tcp;
#[derive(Debug, Clone)]

View File

@@ -7,13 +7,12 @@
use std::{
future::Future,
pin::Pin,
sync::{atomic::AtomicBool, Arc},
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use event_listener::{Event, EventListener};
use futures_util::{stream::SelectAll, Stream, StreamExt};
use futures_util::{stream::SelectAll, StreamExt};
use hyper::{Request, Response};
use hyper_util::{
rt::{TokioExecutor, TokioIo},
@@ -23,6 +22,7 @@ use hyper_util::{
use pin_project_lite::pin_project;
use thiserror::Error;
use tokio_rustls::rustls::ServerConfig;
use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};
use tower::Service;
use tower_http::add_extension::AddExtension;
use tracing::Instrument;
@@ -84,18 +84,24 @@ impl<S> Server<S> {
}
/// Run a single server
pub async fn run<B, SD>(self, shutdown: SD)
where
pub async fn run<B>(
self,
soft_shutdown_token: CancellationToken,
hard_shutdown_token: CancellationToken,
) where
S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: std::error::Error + Send + Sync + 'static,
B: http_body::Body + Send + 'static,
B::Data: Send,
B::Error: std::error::Error + Send + Sync + 'static,
SD: Stream + Unpin,
SD::Item: std::fmt::Display,
{
run_servers(std::iter::once(self), shutdown).await;
run_servers(
std::iter::once(self),
soft_shutdown_token,
hard_shutdown_token,
)
.await;
}
}
@@ -252,18 +258,16 @@ pin_project! {
#[pin]
connection: C,
#[pin]
shutdown_listener: EventListener,
shutdown_in_progress: Arc<AtomicBool>,
cancellation_future: WaitForCancellationFutureOwned,
did_start_shutdown: bool,
}
}
impl<C> AbortableConnection<C> {
fn new(connection: C, shutdown_in_progress: &Arc<AtomicBool>, event: &Arc<Event>) -> Self {
fn new(connection: C, cancellation_token: CancellationToken) -> Self {
Self {
connection,
shutdown_listener: event.listen(),
shutdown_in_progress: Arc::clone(shutdown_in_progress),
cancellation_future: cancellation_token.cancelled_owned(),
did_start_shutdown: false,
}
}
@@ -286,19 +290,11 @@ where
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
// Poll the shutdown signal, so that wakers get registered.
// XXX: I don't think we care about the result of this poll, since it's only
// really to register wakers. But I'm not sure if it's safe to
// ignore the result.
let _ = this.shutdown_listener.poll(cx);
if !*this.did_start_shutdown
&& this
.shutdown_in_progress
.load(std::sync::atomic::Ordering::Relaxed)
{
*this.did_start_shutdown = true;
this.connection.as_mut().graceful_shutdown();
if let Poll::Ready(()) = this.cancellation_future.poll(cx) {
if !*this.did_start_shutdown {
*this.did_start_shutdown = true;
this.connection.as_mut().graceful_shutdown();
}
}
this.connection.poll(cx)
@@ -306,16 +302,17 @@ where
}
#[allow(clippy::too_many_lines)]
pub async fn run_servers<S, B, SD>(listeners: impl IntoIterator<Item = Server<S>>, mut shutdown: SD)
where
pub async fn run_servers<S, B>(
listeners: impl IntoIterator<Item = Server<S>>,
soft_shutdown_token: CancellationToken,
hard_shutdown_token: CancellationToken,
) where
S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: std::error::Error + Send + Sync + 'static,
B: http_body::Body + Send + 'static,
B::Data: Send,
B::Error: std::error::Error + Send + Sync + 'static,
SD: Stream + Unpin,
SD::Item: std::fmt::Display,
{
// Create a stream of accepted connections out of the listeners
let mut accept_stream: SelectAll<_> = listeners
@@ -344,19 +341,13 @@ where
// A JoinSet which collects connections that are being served
let mut connection_tasks = tokio::task::JoinSet::new();
// A shared atomic boolean to tell all connections to shutdown
let shutdown_in_progress = Arc::new(AtomicBool::new(false));
let shutdown_event = Arc::new(Event::new());
loop {
tokio::select! {
biased;
// First look for the shutdown signal
res = shutdown.next() => {
let why = res.map_or_else(|| String::from("???"), |why| format!("{why}"));
tracing::info!("Received shutdown signal ({why})");
() = soft_shutdown_token.cancelled() => {
tracing::debug!("Shutting down listeners");
break;
},
@@ -365,7 +356,7 @@ where
match res {
Some(Ok(Ok(connection))) => {
tracing::trace!("Accepted connection");
let conn = AbortableConnection::new(connection, &shutdown_in_progress, &shutdown_event);
let conn = AbortableConnection::new(connection, soft_shutdown_token.child_token());
connection_tasks.spawn(conn);
},
Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
@@ -385,9 +376,8 @@ where
},
// Look for connections to accept
res = accept_stream.next(), if !accept_stream.is_empty() => {
// SAFETY: We shouldn't reach this branch if the stream set is empty
let Some(res) = res else { unreachable!() };
res = accept_stream.next() => {
let Some(res) = res else { continue };
// Spawn the connection in the set, so we don't have to wait for the handshake to
// accept the next connection. This allows us to keep track of active connections
@@ -401,10 +391,6 @@ where
};
}
// Tell the active connections to shutdown
shutdown_in_progress.store(true, std::sync::atomic::Ordering::Relaxed);
shutdown_event.notify(usize::MAX);
// Wait for connections to cleanup
if !accept_tasks.is_empty() || !connection_tasks.is_empty() {
tracing::info!(
@@ -422,7 +408,7 @@ where
match res {
Some(Ok(Ok(connection))) => {
tracing::trace!("Accepted connection");
let conn = AbortableConnection::new(connection, &shutdown_in_progress, &shutdown_event);
let conn = AbortableConnection::new(connection, soft_shutdown_token.child_token());
connection_tasks.spawn(conn);
}
Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
@@ -441,11 +427,10 @@ where
}
},
// Handle when we receive the shutdown signal again
res = shutdown.next() => {
let why = res.map_or_else(|| String::from("???"), |why| format!("{why}"));
// Handle when we are asked to hard shutdown
() = hard_shutdown_token.cancelled() => {
tracing::warn!(
"Received shutdown signal again ({why}), forcing shutdown ({active} active connections, {pending} pending connections)",
"Forcing shutdown ({active} active connections, {pending} pending connections)",
active = connection_tasks.len(),
pending = accept_tasks.len(),
);
@@ -457,5 +442,4 @@ where
accept_tasks.shutdown().await;
connection_tasks.shutdown().await;
tracing::info!("Shutdown complete");
}

View File

@@ -1,172 +0,0 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::{fmt::Display, pin::Pin, task::Poll, time::Duration};
use futures_util::{ready, Future, Stream};
use tokio::{
signal::unix::{signal, Signal, SignalKind},
time::Sleep,
};
#[derive(Debug, Clone, Copy)]
pub enum ShutdownReason {
Signal(SignalKind),
Timeout,
}
fn signal_to_str(kind: SignalKind) -> &'static str {
match kind.as_raw_value() {
libc::SIGALRM => "SIGALRM",
libc::SIGCHLD => "SIGCHLD",
libc::SIGHUP => "SIGHUP",
libc::SIGINT => "SIGINT",
libc::SIGIO => "SIGIO",
libc::SIGPIPE => "SIGPIPE",
libc::SIGQUIT => "SIGQUIT",
libc::SIGTERM => "SIGTERM",
libc::SIGUSR1 => "SIGUSR1",
libc::SIGUSR2 => "SIGUSR2",
_ => "SIG???",
}
}
impl Display for ShutdownReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Signal(s) => signal_to_str(*s).fmt(f),
Self::Timeout => "timeout".fmt(f),
}
}
}
pub enum ShutdownStreamState {
Waiting,
Graceful { sleep: Option<Pin<Box<Sleep>>> },
Done,
}
impl Default for ShutdownStreamState {
fn default() -> Self {
Self::Waiting
}
}
impl ShutdownStreamState {
fn is_graceful(&self) -> bool {
matches!(self, Self::Graceful { .. })
}
fn is_done(&self) -> bool {
matches!(self, Self::Done)
}
fn get_sleep_mut(&mut self) -> Option<&mut Pin<Box<Sleep>>> {
match self {
Self::Graceful { sleep } => sleep.as_mut(),
_ => None,
}
}
}
/// A stream which is used to drive a graceful shutdown.
///
/// It will emit 2 items: one when a first signal is caught, the other when
/// either another signal is caught, or after a timeout.
#[derive(Default)]
pub struct ShutdownStream {
state: ShutdownStreamState,
signals: Vec<(SignalKind, Signal)>,
timeout: Option<Duration>,
}
impl ShutdownStream {
/// Create a default shutdown stream, which listens on SIGINT and SIGTERM,
/// with a 60s timeout
///
/// # Errors
///
/// Returns an error if signal handlers could not be installed
pub fn new() -> Result<Self, std::io::Error> {
let ret = Self::default()
.with_timeout(Duration::from_secs(60))
.with_signal(SignalKind::interrupt())?
.with_signal(SignalKind::terminate())?;
Ok(ret)
}
/// Add a signal to register
///
/// # Errors
///
/// Returns an error if the signal handler could not be installed
pub fn with_signal(mut self, kind: SignalKind) -> Result<Self, std::io::Error> {
let signal = signal(kind)?;
self.signals.push((kind, signal));
Ok(self)
}
#[must_use]
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = Some(timeout);
self
}
}
impl Stream for ShutdownStream {
type Item = ShutdownReason;
fn size_hint(&self) -> (usize, Option<usize>) {
match self.state {
ShutdownStreamState::Waiting => (2, Some(2)),
ShutdownStreamState::Graceful { .. } => (1, Some(1)),
ShutdownStreamState::Done => (0, Some(0)),
}
}
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let this = self.get_mut();
if this.state.is_done() {
return Poll::Ready(None);
}
for (kind, signal) in &mut this.signals {
match signal.poll_recv(cx) {
Poll::Ready(_) => {
// We got a signal
if this.state.is_graceful() {
// If we was gracefully shutting down, mark it as done
this.state = ShutdownStreamState::Done;
} else {
// Else start the timeout
let sleep = this
.timeout
.map(|duration| Box::pin(tokio::time::sleep(duration)));
this.state = ShutdownStreamState::Graceful { sleep };
}
return Poll::Ready(Some(ShutdownReason::Signal(*kind)));
}
Poll::Pending => {}
}
}
if let Some(timeout) = this.state.get_sleep_mut() {
ready!(timeout.as_mut().poll(cx));
this.state = ShutdownStreamState::Done;
Poll::Ready(Some(ShutdownReason::Timeout))
} else {
Poll::Pending
}
}
}