@@ -11,7 +11,10 @@ use std::{
11
11
} ;
12
12
use tokio_stream:: { Stream , StreamExt } ;
13
13
14
+ use fuse:: Fuse ;
15
+
14
16
pub ( super ) const BUFFER_SIZE : usize = 8 * 1024 ;
17
+ const YIELD_THRESHOLD : usize = 32 * 1024 ;
15
18
16
19
pub ( crate ) fn encode_server < T , U > (
17
20
encoder : T ,
24
27
T : Encoder < Error = Status > ,
25
28
U : Stream < Item = Result < T :: Item , Status > > ,
26
29
{
27
- let stream = encode (
30
+ let stream = EncodedBytes :: new (
28
31
encoder,
29
32
source,
30
33
compression_encoding,
45
48
T : Encoder < Error = Status > ,
46
49
U : Stream < Item = T :: Item > ,
47
50
{
48
- let stream = encode (
51
+ let stream = EncodedBytes :: new (
49
52
encoder,
50
53
source. map ( Ok ) ,
51
54
compression_encoding,
@@ -55,44 +58,115 @@ where
55
58
EncodeBody :: new_client ( stream)
56
59
}
57
60
58
- fn encode < T , U > (
59
- mut encoder : T ,
60
- source : U ,
61
+ /// Combinator for efficient encoding of messages into reasonably sized buffers.
62
+ /// EncodedBytes encodes ready messages from its delegate stream into a BytesMut,
63
+ /// splitting off and yielding a buffer when either:
64
+ /// * The delegate stream polls as not ready, or
65
+ /// * The encoded buffer surpasses YIELD_THRESHOLD.
66
+ #[ pin_project( project = EncodedBytesProj ) ]
67
+ #[ derive( Debug ) ]
68
+ pub ( crate ) struct EncodedBytes < T , U >
69
+ where
70
+ T : Encoder < Error = Status > ,
71
+ U : Stream < Item = Result < T :: Item , Status > > ,
72
+ {
73
+ #[ pin]
74
+ source : Fuse < U > ,
75
+ encoder : T ,
61
76
compression_encoding : Option < CompressionEncoding > ,
62
- compression_override : SingleMessageCompressionOverride ,
63
77
max_message_size : Option < usize > ,
64
- ) -> impl Stream < Item = Result < Bytes , Status > >
78
+ buf : BytesMut ,
79
+ uncompression_buf : BytesMut ,
80
+ }
81
+
82
+ impl < T , U > EncodedBytes < T , U >
65
83
where
66
84
T : Encoder < Error = Status > ,
67
85
U : Stream < Item = Result < T :: Item , Status > > ,
68
86
{
69
- let mut buf = BytesMut :: with_capacity ( BUFFER_SIZE ) ;
87
+ fn new (
88
+ encoder : T ,
89
+ source : U ,
90
+ compression_encoding : Option < CompressionEncoding > ,
91
+ compression_override : SingleMessageCompressionOverride ,
92
+ max_message_size : Option < usize > ,
93
+ ) -> Self {
94
+ let buf = BytesMut :: with_capacity ( BUFFER_SIZE ) ;
70
95
71
- let compression_encoding = if compression_override == SingleMessageCompressionOverride :: Disable
72
- {
73
- None
74
- } else {
75
- compression_encoding
76
- } ;
96
+ let compression_encoding =
97
+ if compression_override == SingleMessageCompressionOverride :: Disable {
98
+ None
99
+ } else {
100
+ compression_encoding
101
+ } ;
77
102
78
- let mut uncompression_buf = if compression_encoding. is_some ( ) {
79
- BytesMut :: with_capacity ( BUFFER_SIZE )
80
- } else {
81
- BytesMut :: new ( )
82
- } ;
103
+ let uncompression_buf = if compression_encoding. is_some ( ) {
104
+ BytesMut :: with_capacity ( BUFFER_SIZE )
105
+ } else {
106
+ BytesMut :: new ( )
107
+ } ;
83
108
84
- source. map ( move |result| {
85
- let item = result?;
109
+ return EncodedBytes {
110
+ source : Fuse :: new ( source) ,
111
+ encoder,
112
+ compression_encoding,
113
+ max_message_size,
114
+ buf,
115
+ uncompression_buf,
116
+ } ;
117
+ }
118
+ }
86
119
87
- encode_item (
88
- & mut encoder,
89
- & mut buf,
90
- & mut uncompression_buf,
120
+ impl < T , U > Stream for EncodedBytes < T , U >
121
+ where
122
+ T : Encoder < Error = Status > ,
123
+ U : Stream < Item = Result < T :: Item , Status > > ,
124
+ {
125
+ type Item = Result < Bytes , Status > ;
126
+
127
+ fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
128
+ let EncodedBytesProj {
129
+ mut source,
130
+ encoder,
91
131
compression_encoding,
92
132
max_message_size,
93
- item,
94
- )
95
- } )
133
+ buf,
134
+ uncompression_buf,
135
+ } = self . project ( ) ;
136
+
137
+ loop {
138
+ match source. as_mut ( ) . poll_next ( cx) {
139
+ Poll :: Pending if buf. is_empty ( ) => {
140
+ return Poll :: Pending ;
141
+ }
142
+ Poll :: Ready ( None ) if buf. is_empty ( ) => {
143
+ return Poll :: Ready ( None ) ;
144
+ }
145
+ Poll :: Pending | Poll :: Ready ( None ) => {
146
+ return Poll :: Ready ( Some ( Ok ( buf. split_to ( buf. len ( ) ) . freeze ( ) ) ) ) ;
147
+ }
148
+ Poll :: Ready ( Some ( Ok ( item) ) ) => {
149
+ if let Err ( status) = encode_item (
150
+ encoder,
151
+ buf,
152
+ uncompression_buf,
153
+ * compression_encoding,
154
+ * max_message_size,
155
+ item,
156
+ ) {
157
+ return Poll :: Ready ( Some ( Err ( status) ) ) ;
158
+ }
159
+
160
+ if buf. len ( ) >= YIELD_THRESHOLD {
161
+ return Poll :: Ready ( Some ( Ok ( buf. split_to ( buf. len ( ) ) . freeze ( ) ) ) ) ;
162
+ }
163
+ }
164
+ Poll :: Ready ( Some ( Err ( status) ) ) => {
165
+ return Poll :: Ready ( Some ( Err ( status) ) ) ;
166
+ }
167
+ }
168
+ }
169
+ }
96
170
}
97
171
98
172
fn encode_item < T > (
@@ -102,10 +176,12 @@ fn encode_item<T>(
102
176
compression_encoding : Option < CompressionEncoding > ,
103
177
max_message_size : Option < usize > ,
104
178
item : T :: Item ,
105
- ) -> Result < Bytes , Status >
179
+ ) -> Result < ( ) , Status >
106
180
where
107
181
T : Encoder < Error = Status > ,
108
182
{
183
+ let offset = buf. len ( ) ;
184
+
109
185
buf. reserve ( HEADER_SIZE ) ;
110
186
unsafe {
111
187
buf. advance_mut ( HEADER_SIZE ) ;
@@ -129,14 +205,14 @@ where
129
205
}
130
206
131
207
// now that we know length, we can write the header
132
- finish_encoding ( compression_encoding, max_message_size, buf)
208
+ finish_encoding ( compression_encoding, max_message_size, & mut buf[ offset.. ] )
133
209
}
134
210
135
211
fn finish_encoding (
136
212
compression_encoding : Option < CompressionEncoding > ,
137
213
max_message_size : Option < usize > ,
138
- buf : & mut BytesMut ,
139
- ) -> Result < Bytes , Status > {
214
+ buf : & mut [ u8 ] ,
215
+ ) -> Result < ( ) , Status > {
140
216
let len = buf. len ( ) - HEADER_SIZE ;
141
217
let limit = max_message_size. unwrap_or ( DEFAULT_MAX_SEND_MESSAGE_SIZE ) ;
142
218
if len > limit {
@@ -160,7 +236,7 @@ fn finish_encoding(
160
236
buf. put_u32 ( len as u32 ) ;
161
237
}
162
238
163
- Ok ( buf . split_to ( len + HEADER_SIZE ) . freeze ( ) )
239
+ Ok ( ( ) )
164
240
}
165
241
166
242
#[ derive( Debug ) ]
@@ -269,3 +345,57 @@ where
269
345
Poll :: Ready ( self . project ( ) . state . trailers ( ) )
270
346
}
271
347
}
348
+
349
+ mod fuse {
350
+ use std:: {
351
+ pin:: Pin ,
352
+ task:: { ready, Context , Poll } ,
353
+ } ;
354
+
355
+ use tokio_stream:: Stream ;
356
+
357
+ /// Stream for the [`fuse`](super::StreamExt::fuse) method.
358
+ #[ derive( Debug ) ]
359
+ #[ pin_project:: pin_project]
360
+ #[ must_use = "streams do nothing unless polled" ]
361
+ pub ( crate ) struct Fuse < St > {
362
+ #[ pin]
363
+ stream : St ,
364
+ done : bool ,
365
+ }
366
+
367
+ impl < St > Fuse < St > {
368
+ pub ( crate ) fn new ( stream : St ) -> Self {
369
+ Self {
370
+ stream,
371
+ done : false ,
372
+ }
373
+ }
374
+ }
375
+
376
+ impl < S : Stream > Stream for Fuse < S > {
377
+ type Item = S :: Item ;
378
+
379
+ fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < S :: Item > > {
380
+ let this = self . project ( ) ;
381
+
382
+ if * this. done {
383
+ return Poll :: Ready ( None ) ;
384
+ }
385
+
386
+ let item = ready ! ( this. stream. poll_next( cx) ) ;
387
+ if item. is_none ( ) {
388
+ * this. done = true ;
389
+ }
390
+ Poll :: Ready ( item)
391
+ }
392
+
393
+ fn size_hint ( & self ) -> ( usize , Option < usize > ) {
394
+ if self . done {
395
+ ( 0 , Some ( 0 ) )
396
+ } else {
397
+ self . stream . size_hint ( )
398
+ }
399
+ }
400
+ }
401
+ }
0 commit comments