Skip to content

Commit

Permalink
Request decompression (#282)
Browse files Browse the repository at this point in the history
* Add layer to decompress request bodies

Still needs some refactor to remove duplicate code and needs documentation

* Refactor decompression modules

* Fix incorrect rename

* Add ResponseFuture for RequestDecompression

Which either polls its inner future or returns 415 Unsupported Media Type

* Refactor DecompressionService

* Rollback rename and move of `Decompression`

Refactoring of `decompression` module will be done in a later PR

* Re-add `request` module to the `decompression` module

* Send "identity" encoding when no encodings are accepted

* Add documentation

* Add example

* Fix styling

* Fix some styling of imports and documentation

* Add enable parameter to RequestDecompressionLayer::pass_through_unaccepted

* Cleanup redundant code

* fix imports

* actually fix import

* check for zstd

* zstd

---------

Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
  • Loading branch information
wanderinglethe and davidpdrsn committed Feb 24, 2023
1 parent f3d8528 commit 67130ab
Show file tree
Hide file tree
Showing 6 changed files with 548 additions and 3 deletions.
3 changes: 2 additions & 1 deletion tower-http/src/decompression/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ use std::{io, marker::PhantomData, pin::Pin, task::Poll};
use tokio_util::io::StreamReader;

pin_project! {
/// Response body of [`Decompression`].
/// Response body of [`RequestDecompression`] and [`Decompression`].
///
/// [`RequestDecompression`]: super::RequestDecompression
/// [`Decompression`]: super::Decompression
pub struct DecompressionBody<B>
where
Expand Down
54 changes: 52 additions & 2 deletions tower-http/src/decompression/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,51 @@
//! Middleware that decompresses response bodies.
//! Middleware that decompresses request and response bodies.
//!
//! # Example
//! # Examples
//!
//! #### Request
//! ```rust
//! use bytes::BytesMut;
//! use flate2::{write::GzEncoder, Compression};
//! use http::{header, HeaderValue, Request, Response};
//! use http_body::Body as _; // for Body::data
//! use hyper::Body;
//! use std::{error::Error, io::Write};
//! use tower::{Service, ServiceBuilder, service_fn, ServiceExt};
//! use tower_http::{BoxError, decompression::{DecompressionBody, RequestDecompressionLayer}};
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), BoxError> {
//! // A request encoded with gzip coming from some HTTP client.
//! let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
//! encoder.write_all(b"Hello?")?;
//! let request = Request::builder()
//! .header(header::CONTENT_ENCODING, "gzip")
//! .body(Body::from(encoder.finish()?))?;
//!
//! // Our HTTP server
//! let mut server = ServiceBuilder::new()
//! // Automatically decompress request bodies.
//! .layer(RequestDecompressionLayer::new())
//! .service(service_fn(handler));
//!
//! // Send the request, with the gzip encoded body, to our server.
//! let _response = server.ready().await?.call(request).await?;
//!
//! // Handler receives request whose body is decoded when read
//! async fn handler(mut req: Request<DecompressionBody<Body>>) -> Result<Response<Body>, BoxError>{
//! let mut data = BytesMut::new();
//! while let Some(chunk) = req.body_mut().data().await {
//! let chunk = chunk?;
//! data.extend_from_slice(&chunk[..]);
//! }
//! assert_eq!(data.freeze().to_vec(), b"Hello?");
//! Ok(Response::new(Body::from("Hello, World!")))
//! }
//! # Ok(())
//! # }
//! ```
//!
//! #### Response
//! ```rust
//! use bytes::BytesMut;
//! use http::{Request, Response};
Expand Down Expand Up @@ -53,6 +97,8 @@
//! # }
//! ```

mod request;

mod body;
mod future;
mod layer;
Expand All @@ -63,6 +109,10 @@ pub use self::{
service::Decompression,
};

pub use self::request::future::RequestDecompressionFuture;
pub use self::request::layer::RequestDecompressionLayer;
pub use self::request::service::RequestDecompression;

#[cfg(test)]
mod tests {
use super::*;
Expand Down
95 changes: 95 additions & 0 deletions tower-http/src/decompression/request/future.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use crate::compression_utils::AcceptEncoding;
use crate::BoxError;
use bytes::Buf;
use http::{header, HeaderValue, Response, StatusCode};
use http_body::{combinators::UnsyncBoxBody, Body, Empty};
use pin_project_lite::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;

pin_project! {
#[derive(Debug)]
/// Response future of [`RequestDecompression`]
pub struct RequestDecompressionFuture<F, B, E>
where
F: Future<Output = Result<Response<B>, E>>,
B: Body
{
#[pin]
kind: Kind<F, B, E>,
}
}

pin_project! {
#[derive(Debug)]
#[project = StateProj]
enum Kind<F, B, E>
where
F: Future<Output = Result<Response<B>, E>>,
B: Body
{
Inner {
#[pin]
fut: F
},
Unsupported {
#[pin]
accept: AcceptEncoding
},
}
}

impl<F, B, E> RequestDecompressionFuture<F, B, E>
where
F: Future<Output = Result<Response<B>, E>>,
B: Body,
{
#[must_use]
pub(super) fn unsupported_encoding(accept: AcceptEncoding) -> Self {
Self {
kind: Kind::Unsupported { accept },
}
}

#[must_use]
pub(super) fn inner(fut: F) -> Self {
Self {
kind: Kind::Inner { fut },
}
}
}

impl<F, B, E> Future for RequestDecompressionFuture<F, B, E>
where
F: Future<Output = Result<Response<B>, E>>,
B: Body + Send + 'static,
B::Data: Buf + 'static,
B::Error: Into<BoxError> + 'static,
E: Into<BoxError>,
{
type Output = Result<Response<UnsyncBoxBody<B::Data, BoxError>>, BoxError>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().kind.project() {
StateProj::Inner { fut } => fut
.poll(cx)
.map_ok(|res| res.map(|body| body.map_err(Into::into).boxed_unsync()))
.map_err(Into::into),
StateProj::Unsupported { accept } => {
let res = Response::builder()
.header(
header::ACCEPT_ENCODING,
accept
.to_header_value()
.unwrap_or(HeaderValue::from_static("identity")),
)
.status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
.body(Empty::new().map_err(Into::into).boxed_unsync())
.unwrap();
Poll::Ready(Ok(res))
}
}
}
}
105 changes: 105 additions & 0 deletions tower-http/src/decompression/request/layer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use super::service::RequestDecompression;
use crate::compression_utils::AcceptEncoding;
use tower_layer::Layer;

/// Decompresses request bodies and calls its underlying service.
///
/// Transparently decompresses request bodies based on the `Content-Encoding` header.
/// When the encoding in the `Content-Encoding` header is not accepted an `Unsupported Media Type`
/// status code will be returned with the accepted encodings in the `Accept-Encoding` header.
///
/// Enabling pass-through of unaccepted encodings will not return an `Unsupported Media Type`. But
/// will call the underlying service with the unmodified request if the encoding is not supported.
/// This is disabled by default.
///
/// See the [module docs](crate::decompression) for more details.
#[derive(Debug, Default, Clone)]
pub struct RequestDecompressionLayer {
accept: AcceptEncoding,
pass_through_unaccepted: bool,
}

impl<S> Layer<S> for RequestDecompressionLayer {
type Service = RequestDecompression<S>;

fn layer(&self, service: S) -> Self::Service {
RequestDecompression {
inner: service,
accept: self.accept,
pass_through_unaccepted: self.pass_through_unaccepted,
}
}
}

impl RequestDecompressionLayer {
/// Creates a new `RequestDecompressionLayer`.
pub fn new() -> Self {
Default::default()
}

/// Sets whether to support gzip encoding.
#[cfg(feature = "decompression-gzip")]
pub fn gzip(mut self, enable: bool) -> Self {
self.accept.set_gzip(enable);
self
}

/// Sets whether to support Deflate encoding.
#[cfg(feature = "decompression-deflate")]
pub fn deflate(mut self, enable: bool) -> Self {
self.accept.set_deflate(enable);
self
}

/// Sets whether to support Brotli encoding.
#[cfg(feature = "decompression-br")]
pub fn br(mut self, enable: bool) -> Self {
self.accept.set_br(enable);
self
}

/// Sets whether to support Zstd encoding.
#[cfg(feature = "decompression-zstd")]
pub fn zstd(mut self, enable: bool) -> Self {
self.accept.set_zstd(enable);
self
}

/// Disables support for gzip encoding.
///
/// This method is available even if the `gzip` crate feature is disabled.
pub fn no_gzip(mut self) -> Self {
self.accept.set_gzip(false);
self
}

/// Disables support for Deflate encoding.
///
/// This method is available even if the `deflate` crate feature is disabled.
pub fn no_deflate(mut self) -> Self {
self.accept.set_deflate(false);
self
}

/// Disables support for Brotli encoding.
///
/// This method is available even if the `br` crate feature is disabled.
pub fn no_br(mut self) -> Self {
self.accept.set_br(false);
self
}

/// Disables support for Zstd encoding.
///
/// This method is available even if the `zstd` crate feature is disabled.
pub fn no_zstd(mut self) -> Self {
self.accept.set_zstd(false);
self
}

/// Sets whether to pass through the request even when the encoding is not supported.
pub fn pass_through_unaccepted(mut self, enable: bool) -> Self {
self.pass_through_unaccepted = enable;
self
}
}
109 changes: 109 additions & 0 deletions tower-http/src/decompression/request/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
pub(super) mod future;
pub(super) mod layer;
pub(super) mod service;

#[cfg(test)]
mod tests {
use super::service::RequestDecompression;
use crate::decompression::DecompressionBody;
use bytes::BytesMut;
use flate2::{write::GzEncoder, Compression};
use http::{header, Response, StatusCode};
use http_body::Body as _;
use hyper::{Body, Error, Request, Server};
use std::io::Write;
use std::net::SocketAddr;
use tower::{make::Shared, service_fn, Service, ServiceExt};

#[tokio::test]
async fn decompress_accepted_encoding() {
let req = request_gzip();
let mut svc = RequestDecompression::new(service_fn(assert_request_is_decompressed));
let _ = svc.ready().await.unwrap().call(req).await.unwrap();
}

#[tokio::test]
async fn support_unencoded_body() {
let req = Request::builder().body(Body::from("Hello?")).unwrap();
let mut svc = RequestDecompression::new(service_fn(assert_request_is_decompressed));
let _ = svc.ready().await.unwrap().call(req).await.unwrap();
}

#[tokio::test]
async fn unaccepted_content_encoding_returns_unsupported_media_type() {
let req = request_gzip();
let mut svc = RequestDecompression::new(service_fn(should_not_be_called)).gzip(false);
let res = svc.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(StatusCode::UNSUPPORTED_MEDIA_TYPE, res.status());
}

#[tokio::test]
async fn pass_through_unsupported_encoding_when_enabled() {
let req = request_gzip();
let mut svc = RequestDecompression::new(service_fn(assert_request_is_passed_through))
.pass_through_unaccepted(true)
.gzip(false);
let _ = svc.ready().await.unwrap().call(req).await.unwrap();
}

async fn assert_request_is_decompressed(
req: Request<DecompressionBody<Body>>,
) -> Result<Response<Body>, Error> {
let (parts, mut body) = req.into_parts();
let body = read_body(&mut body).await;

assert_eq!(body, b"Hello?");
assert!(!parts.headers.contains_key(header::CONTENT_ENCODING));

Ok(Response::new(Body::from("Hello, World!")))
}

async fn assert_request_is_passed_through(
req: Request<DecompressionBody<Body>>,
) -> Result<Response<Body>, Error> {
let (parts, mut body) = req.into_parts();
let body = read_body(&mut body).await;

assert_ne!(body, b"Hello?");
assert!(parts.headers.contains_key(header::CONTENT_ENCODING));

Ok(Response::new(Body::empty()))
}

async fn should_not_be_called(
_: Request<DecompressionBody<Body>>,
) -> Result<Response<Body>, Error> {
panic!("Inner service should not be called");
}

fn request_gzip() -> Request<Body> {
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder.write_all(b"Hello?").unwrap();
let body = encoder.finish().unwrap();
Request::builder()
.header(header::CONTENT_ENCODING, "gzip")
.body(Body::from(body))
.unwrap()
}

async fn read_body(body: &mut DecompressionBody<Body>) -> Vec<u8> {
let mut data = BytesMut::new();
while let Some(chunk) = body.data().await {
let chunk = chunk.unwrap();
data.extend_from_slice(&chunk[..]);
}
data.freeze().to_vec()
}

#[allow(dead_code)]
async fn is_compatible_with_hyper() {
let svc = service_fn(assert_request_is_decompressed);
let svc = RequestDecompression::new(svc);

let make_service = Shared::new(svc);

let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
let server = Server::bind(&addr).serve(make_service);
server.await.unwrap();
}
}

0 comments on commit 67130ab

Please sign in to comment.