From 7dfebeda82de8831cfaa7014c29ab645b044e477 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 16 Aug 2022 19:47:44 +0200 Subject: [PATCH] Add a layer to catch HTTP error codes --- Cargo.lock | 1 + crates/http/Cargo.toml | 6 + crates/http/src/ext.rs | 62 ++++++++- crates/http/src/layers/catch_http_codes.rs | 138 +++++++++++++++++++ crates/http/src/layers/mod.rs | 12 +- crates/http/src/lib.rs | 24 +++- crates/http/tests/client_layers.rs | 150 +++++++++++++++++++++ 7 files changed, 375 insertions(+), 18 deletions(-) create mode 100644 crates/http/src/layers/catch_http_codes.rs create mode 100644 crates/http/tests/client_layers.rs diff --git a/Cargo.lock b/Cargo.lock index 80c01bd62..eadebcb82 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2432,6 +2432,7 @@ dependencies = [ name = "mas-http" version = "0.1.0" dependencies = [ + "anyhow", "axum", "bytes 1.2.1", "futures-util", diff --git a/crates/http/Cargo.toml b/crates/http/Cargo.toml index b6f515592..ea7bd83cd 100644 --- a/crates/http/Cargo.toml +++ b/crates/http/Cargo.toml @@ -28,3 +28,9 @@ tower = { version = "0.4.13", features = ["timeout", "limit"] } tower-http = { version = "0.3.4", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors"] } tracing = "0.1.36" tracing-opentelemetry = "0.17.4" + +[dev-dependencies] +anyhow = "1.0.62" +serde = { version = "1.0.142", features = ["derive"] } +tokio = { version = "1.20.1", features = ["macros"] } +tower = { version = "0.4.13", features = ["util"] } diff --git a/crates/http/src/ext.rs b/crates/http/src/ext.rs index 49d4f5194..cab40b400 100644 --- a/crates/http/src/ext.rs +++ b/crates/http/src/ext.rs @@ -12,13 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use http::header::HeaderName; +use std::ops::RangeBounds; + +use http::{header::HeaderName, Request, StatusCode}; use once_cell::sync::OnceCell; -use tower::{layer::util::Stack, ServiceBuilder}; +use tower::{layer::util::Stack, Service, ServiceBuilder}; use tower_http::cors::CorsLayer; use crate::layers::{ body_to_bytes::{BodyToBytes, BodyToBytesLayer}, + catch_http_codes::{CatchHttpCodes, CatchHttpCodesLayer}, form_urlencoded_request::{FormUrlencodedRequest, FormUrlencodedRequestLayer}, json_request::{JsonRequest, JsonRequestLayer}, json_response::{JsonResponse, JsonResponseLayer}, @@ -65,7 +68,7 @@ impl CorsLayerExt for CorsLayer { } } -pub trait ServiceExt: Sized { +pub trait ServiceExt: Sized { fn response_body_to_bytes(self) -> BodyToBytes { BodyToBytes::new(self) } @@ -81,19 +84,54 @@ pub trait ServiceExt: Sized { fn form_urlencoded_request(self) -> FormUrlencodedRequest { FormUrlencodedRequest::new(self) } + + fn catch_http_code(self, status_code: StatusCode, mapper: M) -> CatchHttpCodes + where + M: Clone, + { + self.catch_http_codes(status_code..=status_code, mapper) + } + + fn catch_http_codes(self, bounds: B, mapper: M) -> CatchHttpCodes + where + B: RangeBounds, + M: Clone, + { + CatchHttpCodes::new(self, bounds, mapper) + } } -impl ServiceExt for S {} +impl ServiceExt for S where S: Service> {} pub trait ServiceBuilderExt: Sized { - fn response_to_bytes(self) -> ServiceBuilder>; + fn response_body_to_bytes(self) -> ServiceBuilder>; fn json_response(self) -> ServiceBuilder, L>>; fn json_request(self) -> ServiceBuilder, L>>; fn form_urlencoded_request(self) -> ServiceBuilder, L>>; + + fn catch_http_code( + self, + status_code: StatusCode, + mapper: M, + ) -> ServiceBuilder, L>> + where + M: Clone, + { + self.catch_http_codes(status_code..=status_code, mapper) + } + + fn catch_http_codes( + self, + bounds: B, + mapper: M, + ) -> ServiceBuilder, L>> + where + B: RangeBounds, + M: Clone; } impl ServiceBuilderExt for ServiceBuilder { - fn response_to_bytes(self) -> ServiceBuilder> { + fn response_body_to_bytes(self) -> ServiceBuilder> { self.layer(BodyToBytesLayer::default()) } @@ -108,4 +146,16 @@ impl ServiceBuilderExt for ServiceBuilder { fn form_urlencoded_request(self) -> ServiceBuilder, L>> { self.layer(FormUrlencodedRequestLayer::default()) } + + fn catch_http_codes( + self, + bounds: B, + mapper: M, + ) -> ServiceBuilder, L>> + where + B: RangeBounds, + M: Clone, + { + self.layer(CatchHttpCodesLayer::new(bounds, mapper)) + } } diff --git a/crates/http/src/layers/catch_http_codes.rs b/crates/http/src/layers/catch_http_codes.rs new file mode 100644 index 000000000..22bed60f7 --- /dev/null +++ b/crates/http/src/layers/catch_http_codes.rs @@ -0,0 +1,138 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::ops::{Bound, RangeBounds}; + +use futures_util::FutureExt; +use http::{Request, Response, StatusCode}; +use thiserror::Error; +use tower::{Layer, Service}; + +#[derive(Debug, Error)] +pub enum Error { + #[error(transparent)] + Service { inner: S }, + + #[error("request failed with status {status_code}")] + HttpError { + status_code: StatusCode, + #[source] + inner: E, + }, +} + +impl Error { + fn service(inner: S) -> Self { + Self::Service { inner } + } + + pub fn status_code(&self) -> Option { + match self { + Self::Service { .. } => None, + Self::HttpError { status_code, .. } => Some(*status_code), + } + } +} + +pub struct CatchHttpCodes { + inner: S, + bounds: (Bound, Bound), + mapper: M, +} + +impl CatchHttpCodes { + pub fn new(inner: S, bounds: B, mapper: M) -> Self + where + B: RangeBounds, + M: Clone, + { + let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned()); + Self { + inner, + bounds, + mapper, + } + } +} + +impl Service> for CatchHttpCodes +where + S: Service, Response = Response>, + S::Future: Send + 'static, + M: Fn(Response) -> E + Send + Clone + 'static, +{ + type Error = Error; + type Response = Response; + type Future = futures_util::future::Map< + S::Future, + Box< + dyn Fn(Result) -> Result + + Send + + 'static, + >, + >; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx).map_err(Error::service) + } + + fn call(&mut self, request: Request) -> Self::Future { + let fut = self.inner.call(request); + let bounds = self.bounds; + let mapper = self.mapper.clone(); + + fut.map(Box::new(move |res: Result| { + let response = res.map_err(Error::service)?; + let status_code = response.status(); + + if bounds.contains(&status_code) { + let inner = mapper(response); + Err(Error::HttpError { status_code, inner }) + } else { + Ok(response) + } + })) + } +} + +#[derive(Clone)] +pub struct CatchHttpCodesLayer { + bounds: (Bound, Bound), + mapper: M, +} + +impl CatchHttpCodesLayer { + pub fn new(bounds: B, mapper: M) -> Self + where + B: RangeBounds, + M: Clone, + { + let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned()); + Self { bounds, mapper } + } +} + +impl Layer for CatchHttpCodesLayer +where + M: Clone, +{ + type Service = CatchHttpCodes; + + fn layer(&self, inner: S) -> Self::Service { + CatchHttpCodes::new(inner, self.bounds, self.mapper.clone()) + } +} diff --git a/crates/http/src/layers/mod.rs b/crates/http/src/layers/mod.rs index 5f8537430..09dae8471 100644 --- a/crates/http/src/layers/mod.rs +++ b/crates/http/src/layers/mod.rs @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -pub(crate) mod body_to_bytes; -pub(crate) mod client; -pub(crate) mod form_urlencoded_request; -pub(crate) mod json_request; -pub(crate) mod json_response; +pub mod body_to_bytes; +pub mod catch_http_codes; +pub mod form_urlencoded_request; +pub mod json_request; +pub mod json_response; pub mod otel; + +pub(crate) mod client; pub(crate) mod server; diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs index 27b353615..c737ac075 100644 --- a/crates/http/src/lib.rs +++ b/crates/http/src/lib.rs @@ -35,24 +35,34 @@ use hyper::{ Client, }; use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder}; -use layers::{ - client::ClientResponse, - otel::{TraceDns, TraceLayer}, -}; use thiserror::Error; use tokio::{sync::OnceCell, task::JoinError}; use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt}; +use self::layers::{ + client::ClientResponse, + otel::{TraceDns, TraceLayer}, +}; + mod ext; mod future_service; mod layers; pub use self::{ - ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt}, + ext::{ + set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt, + ServiceExt as HttpServiceExt, + }, future_service::FutureService, layers::{ - body_to_bytes::BodyToBytesLayer, client::ClientLayer, json_request::JsonRequestLayer, - json_response::JsonResponseLayer, otel, server::ServerLayer, + body_to_bytes::{self, BodyToBytes, BodyToBytesLayer}, + catch_http_codes::{self, CatchHttpCodes, CatchHttpCodesLayer}, + client::ClientLayer, + form_urlencoded_request::{self, FormUrlencodedRequest, FormUrlencodedRequestLayer}, + json_request::{self, JsonRequest, JsonRequestLayer}, + json_response::{self, JsonResponse, JsonResponseLayer}, + otel, + server::ServerLayer, }, }; diff --git a/crates/http/tests/client_layers.rs b/crates/http/tests/client_layers.rs new file mode 100644 index 000000000..e68104d1f --- /dev/null +++ b/crates/http/tests/client_layers.rs @@ -0,0 +1,150 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::convert::Infallible; + +use anyhow::{bail, Context}; +use bytes::{Buf, Bytes}; +use headers::{ContentType, HeaderMapExt}; +use http::{header::ACCEPT, HeaderValue, Request, Response, StatusCode}; +use mas_http::HttpServiceBuilderExt; +use serde::Deserialize; +use thiserror::Error; +use tower::{ServiceBuilder, ServiceExt}; + +#[derive(Debug, Error, Deserialize)] +#[error("Error code in response: {error}")] +struct Error { + error: String, +} + +#[tokio::test] +async fn test_http_errors() { + async fn handle(_request: Request) -> Result, Infallible> { + let mut res = Response::new(r#"{"error": "invalid_request"}"#.to_owned()); + *res.status_mut() = StatusCode::BAD_REQUEST; + + Ok(res) + } + + fn mapper(response: Response) -> Error { + serde_json::from_reader(response.into_body().reader()).unwrap() + } + + let svc = ServiceBuilder::new() + .catch_http_code(StatusCode::BAD_REQUEST, mapper) + .response_body_to_bytes() + .service_fn(handle); + + let request = Request::new(hyper::Body::empty()); + + let res = svc.oneshot(request).await; + let err = res.expect_err("the request should fail"); + assert_eq!(err.status_code(), Some(StatusCode::BAD_REQUEST)); +} + +#[tokio::test] +async fn test_json_request_body() { + async fn handle(request: Request) -> Result, anyhow::Error> + where + B: http_body::Body + Send, + B::Error: std::error::Error + Send + Sync + 'static, + { + if request + .headers() + .typed_get::() + .context("Missing Content-Type header")? + != ContentType::json() + { + bail!("Content-Type header is not application/json") + } + + let bytes = hyper::body::to_bytes(request.into_body()).await?; + if bytes.to_vec() != br#"{"hello":"world"}"#.to_vec() { + bail!("Body mismatch") + } + + let res = Response::new(hyper::Body::empty()); + Ok(res) + } + + let svc = ServiceBuilder::new().json_request().service_fn(handle); + + let request = Request::new(serde_json::json!({"hello": "world"})); + + let res = svc.oneshot(request).await; + res.expect("the request should succeed"); +} + +#[tokio::test] +async fn test_json_response_body() { + async fn handle(request: Request) -> Result, anyhow::Error> { + if request + .headers() + .get(ACCEPT) + .context("Missing Accept header")? + != HeaderValue::from_static("application/json") + { + bail!("Accept header is not application/json") + } + + let res = Response::new(r#"{"hello": "world"}"#.to_owned()); + Ok(res) + } + + let svc = ServiceBuilder::new() + .json_response() + .response_body_to_bytes() + .service_fn(handle); + + let request = Request::new(hyper::Body::empty()); + + let res = svc.oneshot(request).await; + let response = res.expect("the request to succeed"); + let body: serde_json::Value = response.into_body(); + assert_eq!(body, serde_json::json!({"hello": "world"})); +} + +#[tokio::test] +async fn test_urlencoded_request_body() { + async fn handle(request: Request) -> Result, anyhow::Error> + where + B: http_body::Body + Send, + B::Error: std::error::Error + Send + Sync + 'static, + { + if request + .headers() + .typed_get::() + .context("Missing Content-Type header")? + != ContentType::form_url_encoded() + { + bail!("Content-Type header is not application/x-form-urlencoded") + } + + let bytes = hyper::body::to_bytes(request.into_body()).await?; + assert_eq!(bytes.to_vec(), br#"hello=world"#.to_vec()); + + let res = Response::new(hyper::Body::empty()); + Ok(res) + } + + let svc = ServiceBuilder::new() + .form_urlencoded_request() + .service_fn(handle); + + let request = Request::new(serde_json::json!({"hello": "world"})); + + let res = svc.oneshot(request).await; + res.expect("the request to succeed"); +}