diff --git a/Cargo.lock b/Cargo.lock
index 80c01bd62..eadebcb82 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2432,6 +2432,7 @@ dependencies = [
name = "mas-http"
version = "0.1.0"
dependencies = [
+ "anyhow",
"axum",
"bytes 1.2.1",
"futures-util",
diff --git a/crates/http/Cargo.toml b/crates/http/Cargo.toml
index b6f515592..ea7bd83cd 100644
--- a/crates/http/Cargo.toml
+++ b/crates/http/Cargo.toml
@@ -28,3 +28,9 @@ tower = { version = "0.4.13", features = ["timeout", "limit"] }
tower-http = { version = "0.3.4", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors"] }
tracing = "0.1.36"
tracing-opentelemetry = "0.17.4"
+
+[dev-dependencies]
+anyhow = "1.0.62"
+serde = { version = "1.0.142", features = ["derive"] }
+tokio = { version = "1.20.1", features = ["macros"] }
+tower = { version = "0.4.13", features = ["util"] }
diff --git a/crates/http/src/ext.rs b/crates/http/src/ext.rs
index 49d4f5194..cab40b400 100644
--- a/crates/http/src/ext.rs
+++ b/crates/http/src/ext.rs
@@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-use http::header::HeaderName;
+use std::ops::RangeBounds;
+
+use http::{header::HeaderName, Request, StatusCode};
use once_cell::sync::OnceCell;
-use tower::{layer::util::Stack, ServiceBuilder};
+use tower::{layer::util::Stack, Service, ServiceBuilder};
use tower_http::cors::CorsLayer;
use crate::layers::{
body_to_bytes::{BodyToBytes, BodyToBytesLayer},
+ catch_http_codes::{CatchHttpCodes, CatchHttpCodesLayer},
form_urlencoded_request::{FormUrlencodedRequest, FormUrlencodedRequestLayer},
json_request::{JsonRequest, JsonRequestLayer},
json_response::{JsonResponse, JsonResponseLayer},
@@ -65,7 +68,7 @@ impl CorsLayerExt for CorsLayer {
}
}
-pub trait ServiceExt: Sized {
+pub trait ServiceExt
: Sized {
fn response_body_to_bytes(self) -> BodyToBytes {
BodyToBytes::new(self)
}
@@ -81,19 +84,54 @@ pub trait ServiceExt: Sized {
fn form_urlencoded_request(self) -> FormUrlencodedRequest {
FormUrlencodedRequest::new(self)
}
+
+ fn catch_http_code(self, status_code: StatusCode, mapper: M) -> CatchHttpCodes
+ where
+ M: Clone,
+ {
+ self.catch_http_codes(status_code..=status_code, mapper)
+ }
+
+ fn catch_http_codes(self, bounds: B, mapper: M) -> CatchHttpCodes
+ where
+ B: RangeBounds,
+ M: Clone,
+ {
+ CatchHttpCodes::new(self, bounds, mapper)
+ }
}
-impl ServiceExt for S {}
+impl ServiceExt for S where S: Service> {}
pub trait ServiceBuilderExt: Sized {
- fn response_to_bytes(self) -> ServiceBuilder>;
+ fn response_body_to_bytes(self) -> ServiceBuilder>;
fn json_response(self) -> ServiceBuilder, L>>;
fn json_request(self) -> ServiceBuilder, L>>;
fn form_urlencoded_request(self) -> ServiceBuilder, L>>;
+
+ fn catch_http_code(
+ self,
+ status_code: StatusCode,
+ mapper: M,
+ ) -> ServiceBuilder, L>>
+ where
+ M: Clone,
+ {
+ self.catch_http_codes(status_code..=status_code, mapper)
+ }
+
+ fn catch_http_codes(
+ self,
+ bounds: B,
+ mapper: M,
+ ) -> ServiceBuilder, L>>
+ where
+ B: RangeBounds,
+ M: Clone;
}
impl ServiceBuilderExt for ServiceBuilder {
- fn response_to_bytes(self) -> ServiceBuilder> {
+ fn response_body_to_bytes(self) -> ServiceBuilder> {
self.layer(BodyToBytesLayer::default())
}
@@ -108,4 +146,16 @@ impl ServiceBuilderExt for ServiceBuilder {
fn form_urlencoded_request(self) -> ServiceBuilder, L>> {
self.layer(FormUrlencodedRequestLayer::default())
}
+
+ fn catch_http_codes(
+ self,
+ bounds: B,
+ mapper: M,
+ ) -> ServiceBuilder, L>>
+ where
+ B: RangeBounds,
+ M: Clone,
+ {
+ self.layer(CatchHttpCodesLayer::new(bounds, mapper))
+ }
}
diff --git a/crates/http/src/layers/catch_http_codes.rs b/crates/http/src/layers/catch_http_codes.rs
new file mode 100644
index 000000000..22bed60f7
--- /dev/null
+++ b/crates/http/src/layers/catch_http_codes.rs
@@ -0,0 +1,138 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::ops::{Bound, RangeBounds};
+
+use futures_util::FutureExt;
+use http::{Request, Response, StatusCode};
+use thiserror::Error;
+use tower::{Layer, Service};
+
+#[derive(Debug, Error)]
+pub enum Error {
+ #[error(transparent)]
+ Service { inner: S },
+
+ #[error("request failed with status {status_code}")]
+ HttpError {
+ status_code: StatusCode,
+ #[source]
+ inner: E,
+ },
+}
+
+impl Error {
+ fn service(inner: S) -> Self {
+ Self::Service { inner }
+ }
+
+ pub fn status_code(&self) -> Option {
+ match self {
+ Self::Service { .. } => None,
+ Self::HttpError { status_code, .. } => Some(*status_code),
+ }
+ }
+}
+
+pub struct CatchHttpCodes {
+ inner: S,
+ bounds: (Bound, Bound),
+ mapper: M,
+}
+
+impl CatchHttpCodes {
+ pub fn new(inner: S, bounds: B, mapper: M) -> Self
+ where
+ B: RangeBounds,
+ M: Clone,
+ {
+ let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned());
+ Self {
+ inner,
+ bounds,
+ mapper,
+ }
+ }
+}
+
+impl Service> for CatchHttpCodes
+where
+ S: Service, Response = Response>,
+ S::Future: Send + 'static,
+ M: Fn(Response) -> E + Send + Clone + 'static,
+{
+ type Error = Error;
+ type Response = Response;
+ type Future = futures_util::future::Map<
+ S::Future,
+ Box<
+ dyn Fn(Result) -> Result
+ + Send
+ + 'static,
+ >,
+ >;
+
+ fn poll_ready(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll> {
+ self.inner.poll_ready(cx).map_err(Error::service)
+ }
+
+ fn call(&mut self, request: Request) -> Self::Future {
+ let fut = self.inner.call(request);
+ let bounds = self.bounds;
+ let mapper = self.mapper.clone();
+
+ fut.map(Box::new(move |res: Result| {
+ let response = res.map_err(Error::service)?;
+ let status_code = response.status();
+
+ if bounds.contains(&status_code) {
+ let inner = mapper(response);
+ Err(Error::HttpError { status_code, inner })
+ } else {
+ Ok(response)
+ }
+ }))
+ }
+}
+
+#[derive(Clone)]
+pub struct CatchHttpCodesLayer {
+ bounds: (Bound, Bound),
+ mapper: M,
+}
+
+impl CatchHttpCodesLayer {
+ pub fn new(bounds: B, mapper: M) -> Self
+ where
+ B: RangeBounds,
+ M: Clone,
+ {
+ let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned());
+ Self { bounds, mapper }
+ }
+}
+
+impl Layer for CatchHttpCodesLayer
+where
+ M: Clone,
+{
+ type Service = CatchHttpCodes;
+
+ fn layer(&self, inner: S) -> Self::Service {
+ CatchHttpCodes::new(inner, self.bounds, self.mapper.clone())
+ }
+}
diff --git a/crates/http/src/layers/mod.rs b/crates/http/src/layers/mod.rs
index 5f8537430..09dae8471 100644
--- a/crates/http/src/layers/mod.rs
+++ b/crates/http/src/layers/mod.rs
@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-pub(crate) mod body_to_bytes;
-pub(crate) mod client;
-pub(crate) mod form_urlencoded_request;
-pub(crate) mod json_request;
-pub(crate) mod json_response;
+pub mod body_to_bytes;
+pub mod catch_http_codes;
+pub mod form_urlencoded_request;
+pub mod json_request;
+pub mod json_response;
pub mod otel;
+
+pub(crate) mod client;
pub(crate) mod server;
diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs
index 27b353615..c737ac075 100644
--- a/crates/http/src/lib.rs
+++ b/crates/http/src/lib.rs
@@ -35,24 +35,34 @@ use hyper::{
Client,
};
use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder};
-use layers::{
- client::ClientResponse,
- otel::{TraceDns, TraceLayer},
-};
use thiserror::Error;
use tokio::{sync::OnceCell, task::JoinError};
use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt};
+use self::layers::{
+ client::ClientResponse,
+ otel::{TraceDns, TraceLayer},
+};
+
mod ext;
mod future_service;
mod layers;
pub use self::{
- ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt},
+ ext::{
+ set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt,
+ ServiceExt as HttpServiceExt,
+ },
future_service::FutureService,
layers::{
- body_to_bytes::BodyToBytesLayer, client::ClientLayer, json_request::JsonRequestLayer,
- json_response::JsonResponseLayer, otel, server::ServerLayer,
+ body_to_bytes::{self, BodyToBytes, BodyToBytesLayer},
+ catch_http_codes::{self, CatchHttpCodes, CatchHttpCodesLayer},
+ client::ClientLayer,
+ form_urlencoded_request::{self, FormUrlencodedRequest, FormUrlencodedRequestLayer},
+ json_request::{self, JsonRequest, JsonRequestLayer},
+ json_response::{self, JsonResponse, JsonResponseLayer},
+ otel,
+ server::ServerLayer,
},
};
diff --git a/crates/http/tests/client_layers.rs b/crates/http/tests/client_layers.rs
new file mode 100644
index 000000000..e68104d1f
--- /dev/null
+++ b/crates/http/tests/client_layers.rs
@@ -0,0 +1,150 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::convert::Infallible;
+
+use anyhow::{bail, Context};
+use bytes::{Buf, Bytes};
+use headers::{ContentType, HeaderMapExt};
+use http::{header::ACCEPT, HeaderValue, Request, Response, StatusCode};
+use mas_http::HttpServiceBuilderExt;
+use serde::Deserialize;
+use thiserror::Error;
+use tower::{ServiceBuilder, ServiceExt};
+
+#[derive(Debug, Error, Deserialize)]
+#[error("Error code in response: {error}")]
+struct Error {
+ error: String,
+}
+
+#[tokio::test]
+async fn test_http_errors() {
+ async fn handle(_request: Request) -> Result, Infallible> {
+ let mut res = Response::new(r#"{"error": "invalid_request"}"#.to_owned());
+ *res.status_mut() = StatusCode::BAD_REQUEST;
+
+ Ok(res)
+ }
+
+ fn mapper(response: Response) -> Error {
+ serde_json::from_reader(response.into_body().reader()).unwrap()
+ }
+
+ let svc = ServiceBuilder::new()
+ .catch_http_code(StatusCode::BAD_REQUEST, mapper)
+ .response_body_to_bytes()
+ .service_fn(handle);
+
+ let request = Request::new(hyper::Body::empty());
+
+ let res = svc.oneshot(request).await;
+ let err = res.expect_err("the request should fail");
+ assert_eq!(err.status_code(), Some(StatusCode::BAD_REQUEST));
+}
+
+#[tokio::test]
+async fn test_json_request_body() {
+ async fn handle(request: Request) -> Result, anyhow::Error>
+ where
+ B: http_body::Body + Send,
+ B::Error: std::error::Error + Send + Sync + 'static,
+ {
+ if request
+ .headers()
+ .typed_get::()
+ .context("Missing Content-Type header")?
+ != ContentType::json()
+ {
+ bail!("Content-Type header is not application/json")
+ }
+
+ let bytes = hyper::body::to_bytes(request.into_body()).await?;
+ if bytes.to_vec() != br#"{"hello":"world"}"#.to_vec() {
+ bail!("Body mismatch")
+ }
+
+ let res = Response::new(hyper::Body::empty());
+ Ok(res)
+ }
+
+ let svc = ServiceBuilder::new().json_request().service_fn(handle);
+
+ let request = Request::new(serde_json::json!({"hello": "world"}));
+
+ let res = svc.oneshot(request).await;
+ res.expect("the request should succeed");
+}
+
+#[tokio::test]
+async fn test_json_response_body() {
+ async fn handle(request: Request) -> Result, anyhow::Error> {
+ if request
+ .headers()
+ .get(ACCEPT)
+ .context("Missing Accept header")?
+ != HeaderValue::from_static("application/json")
+ {
+ bail!("Accept header is not application/json")
+ }
+
+ let res = Response::new(r#"{"hello": "world"}"#.to_owned());
+ Ok(res)
+ }
+
+ let svc = ServiceBuilder::new()
+ .json_response()
+ .response_body_to_bytes()
+ .service_fn(handle);
+
+ let request = Request::new(hyper::Body::empty());
+
+ let res = svc.oneshot(request).await;
+ let response = res.expect("the request to succeed");
+ let body: serde_json::Value = response.into_body();
+ assert_eq!(body, serde_json::json!({"hello": "world"}));
+}
+
+#[tokio::test]
+async fn test_urlencoded_request_body() {
+ async fn handle(request: Request) -> Result, anyhow::Error>
+ where
+ B: http_body::Body + Send,
+ B::Error: std::error::Error + Send + Sync + 'static,
+ {
+ if request
+ .headers()
+ .typed_get::()
+ .context("Missing Content-Type header")?
+ != ContentType::form_url_encoded()
+ {
+ bail!("Content-Type header is not application/x-form-urlencoded")
+ }
+
+ let bytes = hyper::body::to_bytes(request.into_body()).await?;
+ assert_eq!(bytes.to_vec(), br#"hello=world"#.to_vec());
+
+ let res = Response::new(hyper::Body::empty());
+ Ok(res)
+ }
+
+ let svc = ServiceBuilder::new()
+ .form_urlencoded_request()
+ .service_fn(handle);
+
+ let request = Request::new(serde_json::json!({"hello": "world"}));
+
+ let res = svc.oneshot(request).await;
+ res.expect("the request to succeed");
+}