Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Request decompression #282

Merged
merged 20 commits into from
Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion tower-http/src/decompression/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ use std::{io, marker::PhantomData, pin::Pin, task::Poll};
use tokio_util::io::StreamReader;

pin_project! {
/// Response body of [`Decompression`].
/// Response body of [`RequestDecompression`] and [`Decompression`].
///
/// [`RequestDecompression`]: super::RequestDecompression
/// [`Decompression`]: super::Decompression
pub struct DecompressionBody<B>
where
Expand Down
54 changes: 52 additions & 2 deletions tower-http/src/decompression/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,51 @@
//! Middleware that decompresses response bodies.
//! Middleware that decompresses request and response bodies.
//!
//! # Example
//! # Examples
//!
//! #### Request
//! ```rust
//! use bytes::BytesMut;
//! use flate2::{write::GzEncoder, Compression};
//! use http::{header, HeaderValue, Request, Response};
//! use http_body::Body as _; // for Body::data
//! use hyper::Body;
//! use std::{error::Error, io::Write};
//! use tower::{Service, ServiceBuilder, service_fn, ServiceExt};
//! use tower_http::{BoxError, decompression::{DecompressionBody, RequestDecompressionLayer}};
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), BoxError> {
//! // A request encoded with gzip coming from some HTTP client.
//! let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
//! encoder.write_all(b"Hello?")?;
//! let request = Request::builder()
//! .header(header::CONTENT_ENCODING, "gzip")
//! .body(Body::from(encoder.finish()?))?;
//!
//! // Our HTTP server
//! let mut server = ServiceBuilder::new()
//! // Automatically decompress request bodies.
//! .layer(RequestDecompressionLayer::new())
//! .service(service_fn(handler));
//!
//! // Send the request, with the gzip encoded body, to our server.
//! let _response = server.ready().await?.call(request).await?;
//!
//! // Handler receives request whose body is decoded when read
//! async fn handler(mut req: Request<DecompressionBody<Body>>) -> Result<Response<Body>, BoxError>{
//! let mut data = BytesMut::new();
//! while let Some(chunk) = req.body_mut().data().await {
//! let chunk = chunk?;
//! data.extend_from_slice(&chunk[..]);
//! }
//! assert_eq!(data.freeze().to_vec(), b"Hello?");
//! Ok(Response::new(Body::from("Hello, World!")))
//! }
//! # Ok(())
//! # }
//! ```
//!
//! #### Response
//! ```rust
//! use bytes::BytesMut;
//! use http::{Request, Response};
Expand Down Expand Up @@ -53,6 +97,8 @@
//! # }
//! ```

mod request;

mod body;
mod future;
mod layer;
Expand All @@ -63,6 +109,10 @@ pub use self::{
service::Decompression,
};

pub use self::request::future::RequestDecompressionFuture;
pub use self::request::layer::RequestDecompressionLayer;
pub use self::request::service::RequestDecompression;
wanderinglethe marked this conversation as resolved.
Show resolved Hide resolved

#[cfg(test)]
mod tests {
use super::*;
Expand Down
95 changes: 95 additions & 0 deletions tower-http/src/decompression/request/future.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use crate::compression_utils::AcceptEncoding;
use crate::BoxError;
use bytes::Buf;
use http::{header, HeaderValue, Response, StatusCode};
use http_body::{combinators::UnsyncBoxBody, Body, Empty};
use pin_project_lite::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;

pin_project! {
#[derive(Debug)]
/// Response future of [`RequestDecompression`]
pub struct RequestDecompressionFuture<F, B, E>
where
F: Future<Output = Result<Response<B>, E>>,
B: Body
{
#[pin]
kind: Kind<F, B, E>,
}
}

pin_project! {
#[derive(Debug)]
#[project = StateProj]
enum Kind<F, B, E>
where
F: Future<Output = Result<Response<B>, E>>,
B: Body
{
Inner {
#[pin]
fut: F
},
Unsupported {
#[pin]
accept: AcceptEncoding
},
}
}

impl<F, B, E> RequestDecompressionFuture<F, B, E>
where
F: Future<Output = Result<Response<B>, E>>,
B: Body,
{
#[must_use]
pub(super) fn unsupported_encoding(accept: AcceptEncoding) -> Self {
Self {
kind: Kind::Unsupported { accept },
}
}

#[must_use]
pub(super) fn inner(fut: F) -> Self {
Self {
kind: Kind::Inner { fut },
}
}
}

impl<F, B, E> Future for RequestDecompressionFuture<F, B, E>
where
F: Future<Output = Result<Response<B>, E>>,
B: Body + Send + 'static,
B::Data: Buf + 'static,
B::Error: Into<BoxError> + 'static,
E: Into<BoxError>,
{
type Output = Result<Response<UnsyncBoxBody<B::Data, BoxError>>, BoxError>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().kind.project() {
StateProj::Inner { fut } => fut
.poll(cx)
.map_ok(|res| res.map(|body| body.map_err(Into::into).boxed_unsync()))
.map_err(Into::into),
StateProj::Unsupported { accept } => {
let res = Response::builder()
.header(
header::ACCEPT_ENCODING,
accept
.to_header_value()
.unwrap_or(HeaderValue::from_static("identity")),
)
.status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
.body(Empty::new().map_err(Into::into).boxed_unsync())
.unwrap();
Poll::Ready(Ok(res))
}
}
}
}
105 changes: 105 additions & 0 deletions tower-http/src/decompression/request/layer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use super::service::RequestDecompression;
use crate::compression_utils::AcceptEncoding;
use tower_layer::Layer;

/// Decompresses request bodies and calls its underlying service.
///
/// Transparently decompresses request bodies based on the `Content-Encoding` header.
/// When the encoding in the `Content-Encoding` header is not accepted an `Unsupported Media Type`
/// status code will be returned with the accepted encodings in the `Accept-Encoding` header.
///
/// Enabling pass-through of unaccepted encodings will not return an `Unsupported Media Type`. But
/// will call the underlying service with the unmodified request if the encoding is not supported.
/// This is disabled by default.
///
/// See the [module docs](crate::decompression) for more details.
#[derive(Debug, Default, Clone)]
pub struct RequestDecompressionLayer {
accept: AcceptEncoding,
pass_through_unaccepted: bool,
}

impl<S> Layer<S> for RequestDecompressionLayer {
type Service = RequestDecompression<S>;

fn layer(&self, service: S) -> Self::Service {
RequestDecompression {
inner: service,
accept: self.accept,
pass_through_unaccepted: self.pass_through_unaccepted,
}
}
}

impl RequestDecompressionLayer {
/// Creates a new `RequestDecompressionLayer`.
pub fn new() -> Self {
Default::default()
}

/// Sets whether to support gzip encoding.
#[cfg(feature = "decompression-gzip")]
pub fn gzip(mut self, enable: bool) -> Self {
self.accept.set_gzip(enable);
self
}

/// Sets whether to support Deflate encoding.
#[cfg(feature = "decompression-deflate")]
pub fn deflate(mut self, enable: bool) -> Self {
self.accept.set_deflate(enable);
self
}

/// Sets whether to support Brotli encoding.
#[cfg(feature = "decompression-br")]
pub fn br(mut self, enable: bool) -> Self {
self.accept.set_br(enable);
self
}

/// Sets whether to support Zstd encoding.
#[cfg(feature = "decompression-zstd")]
pub fn zstd(mut self, enable: bool) -> Self {
self.accept.set_zstd(enable);
self
}

/// Disables support for gzip encoding.
///
/// This method is available even if the `gzip` crate feature is disabled.
pub fn no_gzip(mut self) -> Self {
self.accept.set_gzip(false);
self
}

/// Disables support for Deflate encoding.
///
/// This method is available even if the `deflate` crate feature is disabled.
pub fn no_deflate(mut self) -> Self {
self.accept.set_deflate(false);
self
}

/// Disables support for Brotli encoding.
///
/// This method is available even if the `br` crate feature is disabled.
pub fn no_br(mut self) -> Self {
self.accept.set_br(false);
self
}

/// Disables support for Zstd encoding.
///
/// This method is available even if the `zstd` crate feature is disabled.
pub fn no_zstd(mut self) -> Self {
self.accept.set_zstd(false);
self
}

/// Sets whether to pass through the request even when the encoding is not supported.
pub fn pass_through_unaccepted(mut self, enable: bool) -> Self {
self.pass_through_unaccepted = enable;
self
}
}
109 changes: 109 additions & 0 deletions tower-http/src/decompression/request/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
pub(super) mod future;
pub(super) mod layer;
pub(super) mod service;

#[cfg(test)]
mod tests {
use super::service::RequestDecompression;
use crate::decompression::DecompressionBody;
use bytes::BytesMut;
use flate2::{write::GzEncoder, Compression};
use http::{header, Response, StatusCode};
use http_body::Body as _;
use hyper::{Body, Error, Request, Server};
use std::io::Write;
use std::net::SocketAddr;
use tower::{make::Shared, service_fn, Service, ServiceExt};

#[tokio::test]
async fn decompress_accepted_encoding() {
let req = request_gzip();
let mut svc = RequestDecompression::new(service_fn(assert_request_is_decompressed));
let _ = svc.ready().await.unwrap().call(req).await.unwrap();
}

#[tokio::test]
async fn support_unencoded_body() {
let req = Request::builder().body(Body::from("Hello?")).unwrap();
let mut svc = RequestDecompression::new(service_fn(assert_request_is_decompressed));
let _ = svc.ready().await.unwrap().call(req).await.unwrap();
}

#[tokio::test]
async fn unaccepted_content_encoding_returns_unsupported_media_type() {
let req = request_gzip();
let mut svc = RequestDecompression::new(service_fn(should_not_be_called)).gzip(false);
let res = svc.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(StatusCode::UNSUPPORTED_MEDIA_TYPE, res.status());
}

#[tokio::test]
async fn pass_through_unsupported_encoding_when_enabled() {
let req = request_gzip();
let mut svc = RequestDecompression::new(service_fn(assert_request_is_passed_through))
.pass_through_unaccepted(true)
.gzip(false);
let _ = svc.ready().await.unwrap().call(req).await.unwrap();
}

async fn assert_request_is_decompressed(
req: Request<DecompressionBody<Body>>,
) -> Result<Response<Body>, Error> {
let (parts, mut body) = req.into_parts();
let body = read_body(&mut body).await;

assert_eq!(body, b"Hello?");
assert!(!parts.headers.contains_key(header::CONTENT_ENCODING));

Ok(Response::new(Body::from("Hello, World!")))
}

async fn assert_request_is_passed_through(
req: Request<DecompressionBody<Body>>,
) -> Result<Response<Body>, Error> {
let (parts, mut body) = req.into_parts();
let body = read_body(&mut body).await;

assert_ne!(body, b"Hello?");
assert!(parts.headers.contains_key(header::CONTENT_ENCODING));

Ok(Response::new(Body::empty()))
}

async fn should_not_be_called(
_: Request<DecompressionBody<Body>>,
) -> Result<Response<Body>, Error> {
panic!("Inner service should not be called");
}

fn request_gzip() -> Request<Body> {
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder.write_all(b"Hello?").unwrap();
let body = encoder.finish().unwrap();
Request::builder()
.header(header::CONTENT_ENCODING, "gzip")
.body(Body::from(body))
.unwrap()
}

async fn read_body(body: &mut DecompressionBody<Body>) -> Vec<u8> {
let mut data = BytesMut::new();
while let Some(chunk) = body.data().await {
let chunk = chunk.unwrap();
data.extend_from_slice(&chunk[..]);
}
data.freeze().to_vec()
}

#[allow(dead_code)]
async fn is_compatible_with_hyper() {
let svc = service_fn(assert_request_is_decompressed);
let svc = RequestDecompression::new(svc);

let make_service = Shared::new(svc);

let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
let server = Server::bind(&addr).serve(make_service);
server.await.unwrap();
}
}