Files
letro-authentication-service/crates/http/src/ext.rs
reivilibre f7366feb1f Improve errors when MAS contacts the Synapse homeserver (#2794)
* Add some drive-by docstrings

* Change text rendering of catch_http_codes::HttpError

Using `#[source]` is unnatural here because it makes it look like
two distinct errors (one being a cause of the other),
when in reality it is just one error, with 2 parts.

Using `Display` formatting for that leads to a more natural error.

* Add constraints to `catch_http_code{,s}` methods

Not strictly required, but does two things:

- documents what kind of function is expected
- provides a small extra amount of type enforcement at the call site,
  rather than later on when you find the result doesn't implement Service

* Add a `catch_http_errors` shorthand

Nothing major, just a quality-of-life improvement so you don't have to
repeatedly write out what an HTTP error is

* Unexpected error page: remove leading whitespace from preformatted 'details' section

The extra whitespace was probably unintentional and makes the error harder to read,
particularly when it wraps onto a new line unnecessarily

* Capture and log Matrix errors received from Synapse

* Drive-by clippy fix: use clamp instead of min().max()

* Convert `err(Display)` to `err(Debug)` for `anyhow::Error`s in matrix-synapse support module
2024-06-07 11:14:04 +00:00

129 lines
4.2 KiB
Rust

// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{ops::RangeBounds, sync::OnceLock};
use http::{header::HeaderName, Request, Response, StatusCode};
use tower::Service;
use tower_http::cors::CorsLayer;
use crate::layers::{
body_to_bytes_response::BodyToBytesResponse, bytes_to_body_request::BytesToBodyRequest,
catch_http_codes::CatchHttpCodes, form_urlencoded_request::FormUrlencodedRequest,
json_request::JsonRequest, json_response::JsonResponse,
};
static PROPAGATOR_HEADERS: OnceLock<Vec<HeaderName>> = OnceLock::new();
/// Notify the CORS layer what opentelemetry propagators are being used. This
/// helps whitelisting headers in CORS requests.
///
/// # Panics
///
/// When called twice
pub fn set_propagator(propagator: &dyn opentelemetry::propagation::TextMapPropagator) {
let headers = propagator
.fields()
.map(|h| HeaderName::try_from(h).unwrap())
.collect();
tracing::debug!(
?headers,
"Headers allowed in CORS requests for trace propagators set"
);
PROPAGATOR_HEADERS
.set(headers)
.expect(concat!(module_path!(), "::set_propagator was called twice"));
}
/// Extension trait adding trace-propagation header support to a CORS layer.
pub trait CorsLayerExt {
    /// Allows the given `headers` in CORS requests.
    ///
    /// The `CorsLayer` implementation in this file additionally allows the
    /// headers registered through `set_propagator`, if any were registered.
    #[must_use]
    fn allow_otel_headers<H: IntoIterator<Item = HeaderName>>(self, headers: H) -> Self;
}
impl CorsLayerExt for CorsLayer {
    /// Allows the caller-supplied `headers` followed by the propagator
    /// headers recorded in `PROPAGATOR_HEADERS` (none if it was never set).
    fn allow_otel_headers<H>(self, headers: H) -> Self
    where
        H: IntoIterator<Item = HeaderName>,
    {
        // Snapshot the registered propagator headers first; an unset static
        // simply contributes nothing.
        let propagated = match PROPAGATOR_HEADERS.get() {
            Some(names) => names.clone(),
            None => Vec::new(),
        };

        // Caller-provided headers come first, propagator headers after.
        let mut allowed: Vec<HeaderName> = headers.into_iter().collect();
        allowed.extend(propagated);

        self.allow_headers(allowed)
    }
}
/// Extension methods for wrapping a service in the layers from
/// `crate::layers`.
///
/// Every method takes `self` by value and returns the matching wrapper type
/// around it.
pub trait ServiceExt<Body>: Sized {
    /// Wraps this service in a [`BytesToBodyRequest`].
    ///
    /// NOTE(review): presumably converts a bytes request payload into a
    /// request body — confirm against the layer's documentation.
    fn request_bytes_to_body(self) -> BytesToBodyRequest<Self> {
        BytesToBodyRequest::new(self)
    }

    /// Adds a layer which collects all the response body into a contiguous
    /// byte buffer.
    /// This makes the response type `Response<Bytes>`.
    fn response_body_to_bytes(self) -> BodyToBytesResponse<Self> {
        BodyToBytesResponse::new(self)
    }

    /// Wraps this service in a [`JsonResponse`] parameterised by `T`.
    ///
    /// NOTE(review): presumably decodes the response body as JSON into `T`
    /// — confirm against the layer's documentation.
    fn json_response<T>(self) -> JsonResponse<Self, T> {
        JsonResponse::new(self)
    }

    /// Wraps this service in a [`JsonRequest`] parameterised by `T`.
    ///
    /// NOTE(review): presumably encodes a `T` request payload as JSON
    /// — confirm against the layer's documentation.
    fn json_request<T>(self) -> JsonRequest<Self, T> {
        JsonRequest::new(self)
    }

    /// Wraps this service in a [`FormUrlencodedRequest`] parameterised by `T`.
    ///
    /// NOTE(review): presumably encodes a `T` request payload as
    /// `application/x-www-form-urlencoded` — confirm against the layer's
    /// documentation.
    fn form_urlencoded_request<T>(self) -> FormUrlencodedRequest<Self, T> {
        FormUrlencodedRequest::new(self)
    }

    /// Catches responses with exactly the given `status_code` and maps them
    /// to an error type using the provided `mapper` function.
    ///
    /// This is [`Self::catch_http_codes`] with a range containing only
    /// `status_code`.
    fn catch_http_code<M, ResBody, E>(
        self,
        status_code: StatusCode,
        mapper: M,
    ) -> CatchHttpCodes<Self, M>
    where
        M: Fn(Response<ResBody>) -> E + Send + Clone + 'static,
    {
        let only_this_code = status_code..=status_code;
        self.catch_http_codes(only_this_code, mapper)
    }

    /// Catches responses whose status code lies within `bounds` and maps
    /// them to an error type using the provided `mapper` function.
    fn catch_http_codes<B, M, ResBody, E>(self, bounds: B, mapper: M) -> CatchHttpCodes<Self, M>
    where
        B: RangeBounds<StatusCode>,
        M: Fn(Response<ResBody>) -> E + Send + Clone + 'static,
    {
        CatchHttpCodes::new(self, bounds, mapper)
    }

    /// Shorthand for [`Self::catch_http_codes`] which catches all client errors
    /// (4xx) and server errors (5xx).
    fn catch_http_errors<M, ResBody, E>(self, mapper: M) -> CatchHttpCodes<Self, M>
    where
        M: Fn(Response<ResBody>) -> E + Send + Clone + 'static,
    {
        // Half-open range: 400 is included, 600 is excluded.
        let first_error = StatusCode::BAD_REQUEST;
        let past_last_error = StatusCode::from_u16(600).unwrap();
        self.catch_http_codes(first_error..past_last_error, mapper)
    }
}
/// Blanket implementation: every [`Service`] accepting `Request<ReqBody>`
/// gets the [`ServiceExt`] methods.
impl<Svc, ReqBody> ServiceExt<ReqBody> for Svc
where
    Svc: Service<Request<ReqBody>>,
{
}