@@ -5,14 +5,17 @@ use std::task::{Context, Poll};
5
5
use base64:: Engine as _;
6
6
use bytes:: { Buf , BufMut , Bytes , BytesMut } ;
7
7
use futures_core:: ready;
8
- use http:: { header, HeaderMap , HeaderValue } ;
8
+ use http:: { header, HeaderMap , HeaderName , HeaderValue } ;
9
9
use http_body:: { Body , SizeHint } ;
10
10
use pin_project:: pin_project;
11
11
use tokio_stream:: Stream ;
12
12
use tonic:: Status ;
13
13
14
14
use self :: content_types:: * ;
15
15
16
+ // A grpc header is u8 (flag) + u32 (msg len)
17
+ const GRPC_HEADER_SIZE : usize = 1 + 4 ;
18
+
16
19
pub ( crate ) mod content_types {
17
20
use http:: { header:: CONTENT_TYPE , HeaderMap } ;
18
21
@@ -43,8 +46,9 @@ const GRPC_WEB_TRAILERS_BIT: u8 = 0b10000000;
43
46
44
47
#[ derive( Copy , Clone , PartialEq , Debug ) ]
45
48
enum Direction {
46
- Request ,
47
- Response ,
49
+ Decode ,
50
+ Encode ,
51
+ Empty ,
48
52
}
49
53
50
54
#[ derive( Copy , Clone , PartialEq , Debug ) ]
@@ -53,35 +57,78 @@ pub(crate) enum Encoding {
53
57
None ,
54
58
}
55
59
60
+ /// HttpBody adapter for the grpc web based services.
61
+ #[ derive( Debug ) ]
56
62
#[ pin_project]
57
- pub ( crate ) struct GrpcWebCall < B > {
63
+ pub struct GrpcWebCall < B > {
58
64
#[ pin]
59
65
inner : B ,
60
66
buf : BytesMut ,
61
67
direction : Direction ,
62
68
encoding : Encoding ,
63
69
poll_trailers : bool ,
70
+ client : bool ,
71
+ trailers : Option < HeaderMap > ,
72
+ }
73
+
74
+ impl < B : Default > Default for GrpcWebCall < B > {
75
+ fn default ( ) -> Self {
76
+ Self {
77
+ inner : Default :: default ( ) ,
78
+ buf : Default :: default ( ) ,
79
+ direction : Direction :: Empty ,
80
+ encoding : Encoding :: None ,
81
+ poll_trailers : Default :: default ( ) ,
82
+ client : Default :: default ( ) ,
83
+ trailers : Default :: default ( ) ,
84
+ }
85
+ }
64
86
}
65
87
66
88
impl < B > GrpcWebCall < B > {
67
89
pub ( crate ) fn request ( inner : B , encoding : Encoding ) -> Self {
68
- Self :: new ( inner, Direction :: Request , encoding)
90
+ Self :: new ( inner, Direction :: Decode , encoding)
69
91
}
70
92
71
93
pub ( crate ) fn response ( inner : B , encoding : Encoding ) -> Self {
72
- Self :: new ( inner, Direction :: Response , encoding)
94
+ Self :: new ( inner, Direction :: Encode , encoding)
95
+ }
96
+
97
+ pub ( crate ) fn client_request ( inner : B ) -> Self {
98
+ Self :: new_client ( inner, Direction :: Encode , Encoding :: None )
99
+ }
100
+
101
+ pub ( crate ) fn client_response ( inner : B ) -> Self {
102
+ Self :: new_client ( inner, Direction :: Decode , Encoding :: None )
103
+ }
104
+
105
+ fn new_client ( inner : B , direction : Direction , encoding : Encoding ) -> Self {
106
+ GrpcWebCall {
107
+ inner,
108
+ buf : BytesMut :: with_capacity ( match ( direction, encoding) {
109
+ ( Direction :: Encode , Encoding :: Base64 ) => BUFFER_SIZE ,
110
+ _ => 0 ,
111
+ } ) ,
112
+ direction,
113
+ encoding,
114
+ poll_trailers : true ,
115
+ client : true ,
116
+ trailers : None ,
117
+ }
73
118
}
74
119
75
120
fn new ( inner : B , direction : Direction , encoding : Encoding ) -> Self {
76
121
GrpcWebCall {
77
122
inner,
78
123
buf : BytesMut :: with_capacity ( match ( direction, encoding) {
79
- ( Direction :: Response , Encoding :: Base64 ) => BUFFER_SIZE ,
124
+ ( Direction :: Encode , Encoding :: Base64 ) => BUFFER_SIZE ,
80
125
_ => 0 ,
81
126
} ) ,
82
127
direction,
83
128
encoding,
84
129
poll_trailers : true ,
130
+ client : false ,
131
+ trailers : None ,
85
132
}
86
133
}
87
134
@@ -192,20 +239,52 @@ where
192
239
type Error = Status ;
193
240
194
241
fn poll_data (
195
- self : Pin < & mut Self > ,
242
+ mut self : Pin < & mut Self > ,
196
243
cx : & mut Context < ' _ > ,
197
244
) -> Poll < Option < Result < Self :: Data , Self :: Error > > > {
245
+ if self . client && self . direction == Direction :: Decode {
246
+ let buf = ready ! ( self . as_mut( ) . poll_decode( cx) ) ;
247
+
248
+ return if let Some ( Ok ( mut buf) ) = buf {
249
+ // We found some trailers so extract them since we
250
+ // want to return them via `poll_trailers`.
251
+ if let Some ( len) = find_trailers ( & buf[ ..] ) {
252
+ // Extract up to len of where the trailers are at
253
+ let msg_buf = buf. copy_to_bytes ( len) ;
254
+ match decode_trailers_frame ( buf) {
255
+ Ok ( Some ( trailers) ) => {
256
+ self . project ( ) . trailers . replace ( trailers) ;
257
+ }
258
+ Err ( e) => return Poll :: Ready ( Some ( Err ( e) ) ) ,
259
+ _ => { }
260
+ }
261
+
262
+ if msg_buf. has_remaining ( ) {
263
+ return Poll :: Ready ( Some ( Ok ( msg_buf) ) ) ;
264
+ } else {
265
+ return Poll :: Ready ( None ) ;
266
+ }
267
+ }
268
+
269
+ Poll :: Ready ( Some ( Ok ( buf) ) )
270
+ } else {
271
+ Poll :: Ready ( buf)
272
+ } ;
273
+ }
274
+
198
275
match self . direction {
199
- Direction :: Request => self . poll_decode ( cx) ,
200
- Direction :: Response => self . poll_encode ( cx) ,
276
+ Direction :: Decode => self . poll_decode ( cx) ,
277
+ Direction :: Encode => self . poll_encode ( cx) ,
278
+ Direction :: Empty => Poll :: Ready ( None ) ,
201
279
}
202
280
}
203
281
204
282
fn poll_trailers (
205
283
self : Pin < & mut Self > ,
206
284
_: & mut Context < ' _ > ,
207
285
) -> Poll < Result < Option < HeaderMap < HeaderValue > > , Self :: Error > > {
208
- Poll :: Ready ( Ok ( None ) )
286
+ let trailers = self . project ( ) . trailers . take ( ) ;
287
+ Poll :: Ready ( Ok ( trailers) )
209
288
}
210
289
211
290
fn is_end_stream ( & self ) -> bool {
@@ -268,6 +347,56 @@ fn encode_trailers(trailers: HeaderMap) -> Vec<u8> {
268
347
} )
269
348
}
270
349
350
+ fn decode_trailers_frame ( mut buf : Bytes ) -> Result < Option < HeaderMap > , Status > {
351
+ if buf. remaining ( ) < GRPC_HEADER_SIZE {
352
+ return Ok ( None ) ;
353
+ }
354
+
355
+ buf. get_u8 ( ) ;
356
+ buf. get_u32 ( ) ;
357
+
358
+ let mut map = HeaderMap :: new ( ) ;
359
+ let mut temp_buf = buf. clone ( ) ;
360
+
361
+ let mut trailers = Vec :: new ( ) ;
362
+ let mut cursor_pos = 0 ;
363
+
364
+ for ( i, b) in buf. iter ( ) . enumerate ( ) {
365
+ if b == & b'\r' && buf. get ( i + 1 ) == Some ( & b'\n' ) {
366
+ let trailer = temp_buf. copy_to_bytes ( i - cursor_pos) ;
367
+ cursor_pos = i;
368
+ trailers. push ( trailer) ;
369
+ if temp_buf. has_remaining ( ) {
370
+ temp_buf. get_u8 ( ) ;
371
+ temp_buf. get_u8 ( ) ;
372
+ }
373
+ }
374
+ }
375
+
376
+ for trailer in trailers {
377
+ let mut s = trailer. split ( |b| b == & b':' ) ;
378
+ let key = s
379
+ . next ( )
380
+ . ok_or_else ( || Status :: internal ( "trailers couldn't parse key" ) ) ?;
381
+ let value = s
382
+ . next ( )
383
+ . ok_or_else ( || Status :: internal ( "trailers couldn't parse value" ) ) ?;
384
+
385
+ let value = value
386
+ . split ( |b| b == & b'\r' )
387
+ . next ( )
388
+ . ok_or_else ( || Status :: internal ( "trailers was not escaped" ) ) ?;
389
+
390
+ let header_key = HeaderName :: try_from ( key)
391
+ . map_err ( |e| Status :: internal ( format ! ( "Unable to parse HeaderName: {}" , e) ) ) ?;
392
+ let header_value = HeaderValue :: try_from ( value)
393
+ . map_err ( |e| Status :: internal ( format ! ( "Unable to parse HeaderValue: {}" , e) ) ) ?;
394
+ map. insert ( header_key, header_value) ;
395
+ }
396
+
397
+ Ok ( Some ( map) )
398
+ }
399
+
271
400
fn make_trailers_frame ( trailers : HeaderMap ) -> Vec < u8 > {
272
401
let trailers = encode_trailers ( trailers) ;
273
402
let len = trailers. len ( ) ;
@@ -281,6 +410,41 @@ fn make_trailers_frame(trailers: HeaderMap) -> Vec<u8> {
281
410
frame
282
411
}
283
412
413
+ /// Search some buffer for grpc-web trailers headers and return
414
+ /// its location in the original buf. If `None` is returned we did
415
+ /// not find a trailers in this buffer either because its incomplete
416
+ /// or the buffer jsut contained grpc message frames.
417
+ fn find_trailers ( buf : & [ u8 ] ) -> Option < usize > {
418
+ let mut len = 0 ;
419
+ let mut temp_buf = & buf[ ..] ;
420
+
421
+ loop {
422
+ // To check each frame, there must be at least GRPC_HEADER_SIZE
423
+ // amount of bytes available otherwise the buffer is incomplete.
424
+ if temp_buf. is_empty ( ) || temp_buf. len ( ) < GRPC_HEADER_SIZE {
425
+ return None ;
426
+ }
427
+
428
+ let header = temp_buf. get_u8 ( ) ;
429
+
430
+ if header == GRPC_WEB_TRAILERS_BIT {
431
+ return Some ( len) ;
432
+ }
433
+
434
+ let msg_len = temp_buf. get_u32 ( ) ;
435
+
436
+ len += msg_len as usize + 4 + 1 ;
437
+
438
+ // If the msg len of a non-grpc-web trailer frame is larger than
439
+ // the overall buffer we know within that buffer there are no trailers.
440
+ if len > buf. len ( ) {
441
+ return None ;
442
+ }
443
+
444
+ temp_buf = & buf[ len as usize ..] ;
445
+ }
446
+ }
447
+
284
448
#[ cfg( test) ]
285
449
mod tests {
286
450
use super :: * ;
@@ -305,4 +469,55 @@ mod tests {
305
469
assert_eq ! ( Encoding :: from_accept( & headers) , case. 1 , "{}" , case. 0 ) ;
306
470
}
307
471
}
472
+
473
+ #[ test]
474
+ fn decode_trailers ( ) {
475
+ let mut headers = HeaderMap :: new ( ) ;
476
+ headers. insert ( "grpc-status" , 0 . try_into ( ) . unwrap ( ) ) ;
477
+ headers. insert ( "grpc-message" , "this is a message" . try_into ( ) . unwrap ( ) ) ;
478
+
479
+ let trailers = make_trailers_frame ( headers. clone ( ) ) ;
480
+
481
+ let buf = Bytes :: from ( trailers) ;
482
+
483
+ let map = decode_trailers_frame ( buf) . unwrap ( ) . unwrap ( ) ;
484
+
485
+ assert_eq ! ( headers, map) ;
486
+ }
487
+
488
+ #[ test]
489
+ fn find_trailers_non_buffered ( ) {
490
+ // Byte version of this:
491
+ // b"\x80\0\0\0\x0fgrpc-status:0\r\n"
492
+ let buf = vec ! [
493
+ 128 , 0 , 0 , 0 , 15 , 103 , 114 , 112 , 99 , 45 , 115 , 116 , 97 , 116 , 117 , 115 , 58 , 48 , 13 , 10 ,
494
+ ] ;
495
+
496
+ let out = find_trailers ( & buf[ ..] ) ;
497
+
498
+ assert_eq ! ( out, Some ( 0 ) ) ;
499
+ }
500
+
501
+ #[ test]
502
+ fn find_trailers_buffered ( ) {
503
+ // Byte version of this:
504
+ // b"\0\0\0\0L\n$975738af-1a17-4aea-b887-ed0bbced6093\x1a$da609e9b-f470-4cc0-a691-3fd6a005a436\x80\0\0\0\x0fgrpc-status:0\r\n"
505
+ let buf = vec ! [
506
+ 0 , 0 , 0 , 0 , 76 , 10 , 36 , 57 , 55 , 53 , 55 , 51 , 56 , 97 , 102 , 45 , 49 , 97 , 49 , 55 , 45 , 52 ,
507
+ 97 , 101 , 97 , 45 , 98 , 56 , 56 , 55 , 45 , 101 , 100 , 48 , 98 , 98 , 99 , 101 , 100 , 54 , 48 , 57 ,
508
+ 51 , 26 , 36 , 100 , 97 , 54 , 48 , 57 , 101 , 57 , 98 , 45 , 102 , 52 , 55 , 48 , 45 , 52 , 99 , 99 , 48 ,
509
+ 45 , 97 , 54 , 57 , 49 , 45 , 51 , 102 , 100 , 54 , 97 , 48 , 48 , 53 , 97 , 52 , 51 , 54 , 128 , 0 , 0 , 0 ,
510
+ 15 , 103 , 114 , 112 , 99 , 45 , 115 , 116 , 97 , 116 , 117 , 115 , 58 , 48 , 13 , 10 ,
511
+ ] ;
512
+
513
+ let out = find_trailers ( & buf[ ..] ) ;
514
+
515
+ assert_eq ! ( out, Some ( 81 ) ) ;
516
+
517
+ let trailers = decode_trailers_frame ( Bytes :: copy_from_slice ( & buf[ 81 ..] ) )
518
+ . unwrap ( )
519
+ . unwrap ( ) ;
520
+ let status = trailers. get ( "grpc-status" ) . unwrap ( ) ;
521
+ assert_eq ! ( status. to_str( ) . unwrap( ) , "0" )
522
+ }
308
523
}
0 commit comments