Skip to content

Commit dc29c17

Browse files
authoredAug 25, 2023
feat(web): Add GrpcWebClientService (#1472)
This adds `grpc-web` support for clients in `tonic-web`. This is done by reusing the server side encoding/decoding but wrapping it in different directions.
1 parent 388b177 commit dc29c17

File tree

5 files changed

+357
-71
lines changed

5 files changed

+357
-71
lines changed
 

‎examples/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ routeguide = ["dep:async-stream", "tokio-stream", "dep:rand", "dep:serde", "dep:
270270
reflection = ["dep:tonic-reflection"]
271271
autoreload = ["tokio-stream/net", "dep:listenfd"]
272272
health = ["dep:tonic-health"]
273-
grpc-web = ["dep:tonic-web", "dep:bytes", "dep:http", "dep:hyper", "dep:tracing-subscriber"]
273+
grpc-web = ["dep:tonic-web", "dep:bytes", "dep:http", "dep:hyper", "dep:tracing-subscriber", "dep:tower"]
274274
tracing = ["dep:tracing", "dep:tracing-subscriber"]
275275
hyper-warp = ["dep:either", "dep:tower", "dep:hyper", "dep:http", "dep:http-body", "dep:warp"]
276276
hyper-warp-multiplex = ["hyper-warp"]

‎examples/src/grpc-web/client.rs

+13-59
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,28 @@
1-
use bytes::{Buf, BufMut, Bytes, BytesMut};
2-
use hello_world::{HelloReply, HelloRequest};
3-
use http::header::{ACCEPT, CONTENT_TYPE};
1+
use hello_world::{greeter_client::GreeterClient, HelloRequest};
2+
use tonic_web::GrpcWebClientLayer;
43

54
pub mod hello_world {
65
tonic::include_proto!("helloworld");
76
}
87

98
#[tokio::main]
109
async fn main() -> Result<(), Box<dyn std::error::Error>> {
11-
let msg = HelloRequest {
12-
name: "Bob".to_string(),
13-
};
10+
// Must use hyper directly...
11+
let client = hyper::Client::builder().build_http();
1412

15-
// a good old http/1.1 request
16-
let request = http::Request::builder()
17-
.version(http::Version::HTTP_11)
18-
.method(http::Method::POST)
19-
.uri("http://127.0.0.1:3000/helloworld.Greeter/SayHello")
20-
.header(CONTENT_TYPE, "application/grpc-web")
21-
.header(ACCEPT, "application/grpc-web")
22-
.body(hyper::Body::from(encode_body(msg)))
23-
.unwrap();
13+
let svc = tower::ServiceBuilder::new()
14+
.layer(GrpcWebClientLayer::new())
15+
.service(client);
2416

25-
let client = hyper::Client::new();
17+
let mut client = GreeterClient::with_origin(svc, "http://127.0.0.1:3000".try_into()?);
2618

27-
let response = client.request(request).await.unwrap();
19+
let request = tonic::Request::new(HelloRequest {
20+
name: "Tonic".into(),
21+
});
2822

29-
assert_eq!(
30-
response.headers().get(CONTENT_TYPE).unwrap(),
31-
"application/grpc-web+proto"
32-
);
23+
let response = client.say_hello(request).await?;
3324

34-
let body = response.into_body();
35-
let reply = decode_body::<HelloReply>(body).await;
36-
37-
println!("REPLY={:?}", reply);
25+
println!("RESPONSE={:?}", response);
3826

3927
Ok(())
4028
}
41-
42-
// one byte for the compression flag plus four bytes for the length
43-
const GRPC_HEADER_SIZE: usize = 5;
44-
45-
fn encode_body<T>(msg: T) -> Bytes
46-
where
47-
T: prost::Message,
48-
{
49-
let msg_len = msg.encoded_len();
50-
let mut buf = BytesMut::with_capacity(GRPC_HEADER_SIZE + msg_len);
51-
52-
// compression flag, 0 means "no compression"
53-
buf.put_u8(0);
54-
buf.put_u32(msg_len as u32);
55-
56-
msg.encode(&mut buf).unwrap();
57-
buf.freeze()
58-
}
59-
60-
async fn decode_body<T>(body: hyper::Body) -> T
61-
where
62-
T: Default + prost::Message,
63-
{
64-
let mut body = hyper::body::to_bytes(body).await.unwrap();
65-
66-
// ignore the compression flag
67-
body.advance(1);
68-
69-
let len = body.get_u32();
70-
#[allow(clippy::let_and_return)]
71-
let msg = T::decode(&mut body.split_to(len as usize)).unwrap();
72-
73-
msg
74-
}

‎tonic-web/src/call.rs

+226-11
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@ use std::task::{Context, Poll};
55
use base64::Engine as _;
66
use bytes::{Buf, BufMut, Bytes, BytesMut};
77
use futures_core::ready;
8-
use http::{header, HeaderMap, HeaderValue};
8+
use http::{header, HeaderMap, HeaderName, HeaderValue};
99
use http_body::{Body, SizeHint};
1010
use pin_project::pin_project;
1111
use tokio_stream::Stream;
1212
use tonic::Status;
1313

1414
use self::content_types::*;
1515

16+
// A grpc header is u8 (flag) + u32 (msg len)
17+
const GRPC_HEADER_SIZE: usize = 1 + 4;
18+
1619
pub(crate) mod content_types {
1720
use http::{header::CONTENT_TYPE, HeaderMap};
1821

@@ -43,8 +46,9 @@ const GRPC_WEB_TRAILERS_BIT: u8 = 0b10000000;
4346

4447
#[derive(Copy, Clone, PartialEq, Debug)]
4548
enum Direction {
46-
Request,
47-
Response,
49+
Decode,
50+
Encode,
51+
Empty,
4852
}
4953

5054
#[derive(Copy, Clone, PartialEq, Debug)]
@@ -53,35 +57,78 @@ pub(crate) enum Encoding {
5357
None,
5458
}
5559

60+
/// HttpBody adapter for the grpc web based services.
61+
#[derive(Debug)]
5662
#[pin_project]
57-
pub(crate) struct GrpcWebCall<B> {
63+
pub struct GrpcWebCall<B> {
5864
#[pin]
5965
inner: B,
6066
buf: BytesMut,
6167
direction: Direction,
6268
encoding: Encoding,
6369
poll_trailers: bool,
70+
client: bool,
71+
trailers: Option<HeaderMap>,
72+
}
73+
74+
impl<B: Default> Default for GrpcWebCall<B> {
75+
fn default() -> Self {
76+
Self {
77+
inner: Default::default(),
78+
buf: Default::default(),
79+
direction: Direction::Empty,
80+
encoding: Encoding::None,
81+
poll_trailers: Default::default(),
82+
client: Default::default(),
83+
trailers: Default::default(),
84+
}
85+
}
6486
}
6587

6688
impl<B> GrpcWebCall<B> {
6789
pub(crate) fn request(inner: B, encoding: Encoding) -> Self {
68-
Self::new(inner, Direction::Request, encoding)
90+
Self::new(inner, Direction::Decode, encoding)
6991
}
7092

7193
pub(crate) fn response(inner: B, encoding: Encoding) -> Self {
72-
Self::new(inner, Direction::Response, encoding)
94+
Self::new(inner, Direction::Encode, encoding)
95+
}
96+
97+
pub(crate) fn client_request(inner: B) -> Self {
98+
Self::new_client(inner, Direction::Encode, Encoding::None)
99+
}
100+
101+
pub(crate) fn client_response(inner: B) -> Self {
102+
Self::new_client(inner, Direction::Decode, Encoding::None)
103+
}
104+
105+
fn new_client(inner: B, direction: Direction, encoding: Encoding) -> Self {
106+
GrpcWebCall {
107+
inner,
108+
buf: BytesMut::with_capacity(match (direction, encoding) {
109+
(Direction::Encode, Encoding::Base64) => BUFFER_SIZE,
110+
_ => 0,
111+
}),
112+
direction,
113+
encoding,
114+
poll_trailers: true,
115+
client: true,
116+
trailers: None,
117+
}
73118
}
74119

75120
fn new(inner: B, direction: Direction, encoding: Encoding) -> Self {
76121
GrpcWebCall {
77122
inner,
78123
buf: BytesMut::with_capacity(match (direction, encoding) {
79-
(Direction::Response, Encoding::Base64) => BUFFER_SIZE,
124+
(Direction::Encode, Encoding::Base64) => BUFFER_SIZE,
80125
_ => 0,
81126
}),
82127
direction,
83128
encoding,
84129
poll_trailers: true,
130+
client: false,
131+
trailers: None,
85132
}
86133
}
87134

@@ -192,20 +239,52 @@ where
192239
type Error = Status;
193240

194241
fn poll_data(
195-
self: Pin<&mut Self>,
242+
mut self: Pin<&mut Self>,
196243
cx: &mut Context<'_>,
197244
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
245+
if self.client && self.direction == Direction::Decode {
246+
let buf = ready!(self.as_mut().poll_decode(cx));
247+
248+
return if let Some(Ok(mut buf)) = buf {
249+
// We found some trailers so extract them since we
250+
// want to return them via `poll_trailers`.
251+
if let Some(len) = find_trailers(&buf[..]) {
252+
// Extract up to len of where the trailers are at
253+
let msg_buf = buf.copy_to_bytes(len);
254+
match decode_trailers_frame(buf) {
255+
Ok(Some(trailers)) => {
256+
self.project().trailers.replace(trailers);
257+
}
258+
Err(e) => return Poll::Ready(Some(Err(e))),
259+
_ => {}
260+
}
261+
262+
if msg_buf.has_remaining() {
263+
return Poll::Ready(Some(Ok(msg_buf)));
264+
} else {
265+
return Poll::Ready(None);
266+
}
267+
}
268+
269+
Poll::Ready(Some(Ok(buf)))
270+
} else {
271+
Poll::Ready(buf)
272+
};
273+
}
274+
198275
match self.direction {
199-
Direction::Request => self.poll_decode(cx),
200-
Direction::Response => self.poll_encode(cx),
276+
Direction::Decode => self.poll_decode(cx),
277+
Direction::Encode => self.poll_encode(cx),
278+
Direction::Empty => Poll::Ready(None),
201279
}
202280
}
203281

204282
fn poll_trailers(
205283
self: Pin<&mut Self>,
206284
_: &mut Context<'_>,
207285
) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
208-
Poll::Ready(Ok(None))
286+
let trailers = self.project().trailers.take();
287+
Poll::Ready(Ok(trailers))
209288
}
210289

211290
fn is_end_stream(&self) -> bool {
@@ -268,6 +347,56 @@ fn encode_trailers(trailers: HeaderMap) -> Vec<u8> {
268347
})
269348
}
270349

350+
fn decode_trailers_frame(mut buf: Bytes) -> Result<Option<HeaderMap>, Status> {
351+
if buf.remaining() < GRPC_HEADER_SIZE {
352+
return Ok(None);
353+
}
354+
355+
buf.get_u8();
356+
buf.get_u32();
357+
358+
let mut map = HeaderMap::new();
359+
let mut temp_buf = buf.clone();
360+
361+
let mut trailers = Vec::new();
362+
let mut cursor_pos = 0;
363+
364+
for (i, b) in buf.iter().enumerate() {
365+
if b == &b'\r' && buf.get(i + 1) == Some(&b'\n') {
366+
let trailer = temp_buf.copy_to_bytes(i - cursor_pos);
367+
cursor_pos = i;
368+
trailers.push(trailer);
369+
if temp_buf.has_remaining() {
370+
temp_buf.get_u8();
371+
temp_buf.get_u8();
372+
}
373+
}
374+
}
375+
376+
for trailer in trailers {
377+
let mut s = trailer.split(|b| b == &b':');
378+
let key = s
379+
.next()
380+
.ok_or_else(|| Status::internal("trailers couldn't parse key"))?;
381+
let value = s
382+
.next()
383+
.ok_or_else(|| Status::internal("trailers couldn't parse value"))?;
384+
385+
let value = value
386+
.split(|b| b == &b'\r')
387+
.next()
388+
.ok_or_else(|| Status::internal("trailers was not escaped"))?;
389+
390+
let header_key = HeaderName::try_from(key)
391+
.map_err(|e| Status::internal(format!("Unable to parse HeaderName: {}", e)))?;
392+
let header_value = HeaderValue::try_from(value)
393+
.map_err(|e| Status::internal(format!("Unable to parse HeaderValue: {}", e)))?;
394+
map.insert(header_key, header_value);
395+
}
396+
397+
Ok(Some(map))
398+
}
399+
271400
fn make_trailers_frame(trailers: HeaderMap) -> Vec<u8> {
272401
let trailers = encode_trailers(trailers);
273402
let len = trailers.len();
@@ -281,6 +410,41 @@ fn make_trailers_frame(trailers: HeaderMap) -> Vec<u8> {
281410
frame
282411
}
283412

413+
/// Search some buffer for grpc-web trailers headers and return
414+
/// its location in the original buf. If `None` is returned we did
415+
/// not find a trailers in this buffer either because its incomplete
416+
/// or the buffer jsut contained grpc message frames.
417+
fn find_trailers(buf: &[u8]) -> Option<usize> {
418+
let mut len = 0;
419+
let mut temp_buf = &buf[..];
420+
421+
loop {
422+
// To check each frame, there must be at least GRPC_HEADER_SIZE
423+
// amount of bytes available otherwise the buffer is incomplete.
424+
if temp_buf.is_empty() || temp_buf.len() < GRPC_HEADER_SIZE {
425+
return None;
426+
}
427+
428+
let header = temp_buf.get_u8();
429+
430+
if header == GRPC_WEB_TRAILERS_BIT {
431+
return Some(len);
432+
}
433+
434+
let msg_len = temp_buf.get_u32();
435+
436+
len += msg_len as usize + 4 + 1;
437+
438+
// If the msg len of a non-grpc-web trailer frame is larger than
439+
// the overall buffer we know within that buffer there are no trailers.
440+
if len > buf.len() {
441+
return None;
442+
}
443+
444+
temp_buf = &buf[len as usize..];
445+
}
446+
}
447+
284448
#[cfg(test)]
285449
mod tests {
286450
use super::*;
@@ -305,4 +469,55 @@ mod tests {
305469
assert_eq!(Encoding::from_accept(&headers), case.1, "{}", case.0);
306470
}
307471
}
472+
473+
#[test]
474+
fn decode_trailers() {
475+
let mut headers = HeaderMap::new();
476+
headers.insert("grpc-status", 0.try_into().unwrap());
477+
headers.insert("grpc-message", "this is a message".try_into().unwrap());
478+
479+
let trailers = make_trailers_frame(headers.clone());
480+
481+
let buf = Bytes::from(trailers);
482+
483+
let map = decode_trailers_frame(buf).unwrap().unwrap();
484+
485+
assert_eq!(headers, map);
486+
}
487+
488+
#[test]
489+
fn find_trailers_non_buffered() {
490+
// Byte version of this:
491+
// b"\x80\0\0\0\x0fgrpc-status:0\r\n"
492+
let buf = vec![
493+
128, 0, 0, 0, 15, 103, 114, 112, 99, 45, 115, 116, 97, 116, 117, 115, 58, 48, 13, 10,
494+
];
495+
496+
let out = find_trailers(&buf[..]);
497+
498+
assert_eq!(out, Some(0));
499+
}
500+
501+
#[test]
502+
fn find_trailers_buffered() {
503+
// Byte version of this:
504+
// b"\0\0\0\0L\n$975738af-1a17-4aea-b887-ed0bbced6093\x1a$da609e9b-f470-4cc0-a691-3fd6a005a436\x80\0\0\0\x0fgrpc-status:0\r\n"
505+
let buf = vec![
506+
0, 0, 0, 0, 76, 10, 36, 57, 55, 53, 55, 51, 56, 97, 102, 45, 49, 97, 49, 55, 45, 52,
507+
97, 101, 97, 45, 98, 56, 56, 55, 45, 101, 100, 48, 98, 98, 99, 101, 100, 54, 48, 57,
508+
51, 26, 36, 100, 97, 54, 48, 57, 101, 57, 98, 45, 102, 52, 55, 48, 45, 52, 99, 99, 48,
509+
45, 97, 54, 57, 49, 45, 51, 102, 100, 54, 97, 48, 48, 53, 97, 52, 51, 54, 128, 0, 0, 0,
510+
15, 103, 114, 112, 99, 45, 115, 116, 97, 116, 117, 115, 58, 48, 13, 10,
511+
];
512+
513+
let out = find_trailers(&buf[..]);
514+
515+
assert_eq!(out, Some(81));
516+
517+
let trailers = decode_trailers_frame(Bytes::copy_from_slice(&buf[81..]))
518+
.unwrap()
519+
.unwrap();
520+
let status = trailers.get("grpc-status").unwrap();
521+
assert_eq!(status.to_str().unwrap(), "0")
522+
}
308523
}

‎tonic-web/src/client.rs

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
use bytes::Bytes;
2+
use futures_core::ready;
3+
use http::header::CONTENT_TYPE;
4+
use http::{Request, Response, Version};
5+
use http_body::Body;
6+
use pin_project::pin_project;
7+
use std::error::Error;
8+
use std::future::Future;
9+
use std::pin::Pin;
10+
use std::task::{Context, Poll};
11+
use tower_layer::Layer;
12+
use tower_service::Service;
13+
use tracing::debug;
14+
15+
use crate::call::content_types::GRPC_WEB;
16+
use crate::call::GrpcWebCall;
17+
18+
/// Layer implementing the grpc-web protocol for clients.
19+
#[derive(Debug, Clone)]
20+
pub struct GrpcWebClientLayer {
21+
_priv: (),
22+
}
23+
24+
impl GrpcWebClientLayer {
25+
/// Create a new grpc-web for clients layer.
26+
pub fn new() -> GrpcWebClientLayer {
27+
Self { _priv: () }
28+
}
29+
}
30+
31+
impl Default for GrpcWebClientLayer {
32+
fn default() -> Self {
33+
Self::new()
34+
}
35+
}
36+
37+
impl<S> Layer<S> for GrpcWebClientLayer {
38+
type Service = GrpcWebClientService<S>;
39+
40+
fn layer(&self, inner: S) -> Self::Service {
41+
GrpcWebClientService::new(inner)
42+
}
43+
}
44+
45+
/// A [`Service`] that wraps some inner http service that will
46+
/// coerce requests coming from [`tonic::client::Grpc`] into proper
47+
/// `grpc-web` requests.
48+
#[derive(Debug, Clone)]
49+
pub struct GrpcWebClientService<S> {
50+
inner: S,
51+
}
52+
53+
impl<S> GrpcWebClientService<S> {
54+
/// Create a new grpc-web for clients service.
55+
pub fn new(inner: S) -> Self {
56+
Self { inner }
57+
}
58+
}
59+
60+
impl<S, B1, B2> Service<Request<B1>> for GrpcWebClientService<S>
61+
where
62+
S: Service<Request<GrpcWebCall<B1>>, Response = Response<B2>>,
63+
B1: Body,
64+
B2: Body<Data = Bytes>,
65+
B2::Error: Error,
66+
{
67+
type Response = Response<GrpcWebCall<B2>>;
68+
type Error = S::Error;
69+
type Future = ResponseFuture<S::Future>;
70+
71+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
72+
self.inner.poll_ready(cx)
73+
}
74+
75+
fn call(&mut self, mut req: Request<B1>) -> Self::Future {
76+
if req.version() == Version::HTTP_2 {
77+
debug!("coercing HTTP2 request to HTTP1.1");
78+
79+
*req.version_mut() = Version::HTTP_11;
80+
}
81+
82+
req.headers_mut()
83+
.insert(CONTENT_TYPE, GRPC_WEB.try_into().unwrap());
84+
85+
let req = req.map(GrpcWebCall::client_request);
86+
87+
let fut = self.inner.call(req);
88+
89+
ResponseFuture { inner: fut }
90+
}
91+
}
92+
93+
/// Response future for the [`GrpcWebService`].
94+
#[allow(missing_debug_implementations)]
95+
#[pin_project]
96+
#[must_use = "futures do nothing unless polled"]
97+
pub struct ResponseFuture<F> {
98+
#[pin]
99+
inner: F,
100+
}
101+
102+
impl<F, B, E> Future for ResponseFuture<F>
103+
where
104+
B: Body<Data = Bytes>,
105+
F: Future<Output = Result<Response<B>, E>>,
106+
{
107+
type Output = Result<Response<GrpcWebCall<B>>, E>;
108+
109+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
110+
let res = ready!(self.project().inner.poll(cx));
111+
112+
Poll::Ready(res.map(|r| r.map(GrpcWebCall::client_response)))
113+
}
114+
}

‎tonic-web/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,13 @@
9797
#![doc(html_root_url = "https://docs.rs/tonic-web/0.9.2")]
9898
#![doc(issue_tracker_base_url = "https://github.com/hyperium/tonic/issues/")]
9999

100+
pub use call::GrpcWebCall;
101+
pub use client::{GrpcWebClientLayer, GrpcWebClientService};
100102
pub use layer::GrpcWebLayer;
101103
pub use service::{GrpcWebService, ResponseFuture};
102104

103105
mod call;
106+
mod client;
104107
mod layer;
105108
mod service;
106109

0 commit comments

Comments
 (0)
Please sign in to comment.