Skip to content

Commit aff1daf

Browse files
xanatherLucioFranco
andauthoredAug 14, 2023
feat(build): Add optional default unimplemented stubs (#1344)
Co-authored-by: Lucio Franco <luciofranco14@gmail.com>
1 parent 35cf1f2 commit aff1daf

File tree

10 files changed

+306
-9
lines changed

10 files changed

+306
-9
lines changed
 

‎Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ members = [
2424
"tonic-web/tests/integration",
2525
"tests/service_named_result",
2626
"tests/use_arc_self",
27+
"tests/default_stubs",
2728
]
2829
resolver = "2"
2930

‎tests/default_stubs/Cargo.toml

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
[package]
2+
authors = ["Jordan Singh <me@jordansingh.com>"]
3+
edition = "2021"
4+
license = "MIT"
5+
name = "default_stubs"
6+
publish = false
7+
version = "0.1.0"
8+
9+
[dependencies]
10+
futures = "0.3"
11+
tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]}
12+
tokio-stream = {version = "0.1", features = ["net"]}
13+
prost = "0.11"
14+
tonic = {path = "../../tonic"}
15+
16+
[build-dependencies]
17+
tonic-build = {path = "../../tonic-build" }
18+
19+
[package.metadata.cargo-machete]
20+
ignored = ["prost"]

‎tests/default_stubs/build.rs

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
fn main() {
2+
tonic_build::configure()
3+
.compile(&["proto/test.proto"], &["proto"])
4+
.unwrap();
5+
tonic_build::configure()
6+
.generate_default_stubs(true)
7+
.compile(&["proto/test_default.proto"], &["proto"])
8+
.unwrap();
9+
}

‎tests/default_stubs/proto/test.proto

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
syntax = "proto3";
2+
3+
package test;
4+
5+
import "google/protobuf/empty.proto";
6+
7+
service Test {
8+
rpc Unary(google.protobuf.Empty) returns (google.protobuf.Empty);
9+
rpc ServerStream(google.protobuf.Empty) returns (stream google.protobuf.Empty);
10+
rpc ClientStream(stream google.protobuf.Empty) returns (google.protobuf.Empty);
11+
rpc BidirectionalStream(stream google.protobuf.Empty) returns (stream google.protobuf.Empty);
12+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
syntax = "proto3";
2+
3+
package test_default;
4+
5+
import "google/protobuf/empty.proto";
6+
7+
service TestDefault {
8+
rpc Unary(google.protobuf.Empty) returns (google.protobuf.Empty);
9+
rpc ServerStream(google.protobuf.Empty) returns (stream google.protobuf.Empty);
10+
rpc ClientStream(stream google.protobuf.Empty) returns (google.protobuf.Empty);
11+
rpc BidirectionalStream(stream google.protobuf.Empty) returns (stream google.protobuf.Empty);
12+
}

‎tests/default_stubs/src/lib.rs

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#![allow(unused_imports)]
2+
3+
mod test_defaults;
4+
5+
use futures::{Stream, StreamExt};
6+
use std::pin::Pin;
7+
use tonic::{Request, Response, Status, Streaming};
8+
9+
tonic::include_proto!("test");
10+
tonic::include_proto!("test_default");
11+
12+
#[derive(Debug, Default)]
13+
struct Svc;
14+
15+
#[tonic::async_trait]
16+
impl test_server::Test for Svc {
17+
type ServerStreamStream = Pin<Box<dyn Stream<Item = Result<(), Status>> + Send + 'static>>;
18+
type BidirectionalStreamStream =
19+
Pin<Box<dyn Stream<Item = Result<(), Status>> + Send + 'static>>;
20+
21+
async fn unary(&self, _: Request<()>) -> Result<Response<()>, Status> {
22+
Err(Status::permission_denied(""))
23+
}
24+
25+
async fn server_stream(
26+
&self,
27+
_: Request<()>,
28+
) -> Result<Response<Self::ServerStreamStream>, Status> {
29+
Err(Status::permission_denied(""))
30+
}
31+
32+
async fn client_stream(&self, _: Request<Streaming<()>>) -> Result<Response<()>, Status> {
33+
Err(Status::permission_denied(""))
34+
}
35+
36+
async fn bidirectional_stream(
37+
&self,
38+
_: Request<Streaming<()>>,
39+
) -> Result<Response<Self::BidirectionalStreamStream>, Status> {
40+
Err(Status::permission_denied(""))
41+
}
42+
}
43+
44+
#[tonic::async_trait]
45+
impl test_default_server::TestDefault for Svc {
46+
// Default unimplemented stubs provided here.
47+
}
+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
#![allow(unused_imports)]
2+
3+
use crate::*;
4+
use std::net::SocketAddr;
5+
use tokio::net::TcpListener;
6+
use tonic::transport::Server;
7+
8+
#[cfg(test)]
9+
fn echo_requests_iter() -> impl Stream<Item = ()> {
10+
tokio_stream::iter(1..usize::MAX).map(|_| ())
11+
}
12+
13+
#[tokio::test()]
14+
async fn test_default_stubs() {
15+
use tonic::Code;
16+
17+
let addrs = run_services_in_background().await;
18+
19+
// First validate pre-existing functionality (trait has no default implementation, we explicitly return PermissionDenied in lib.rs).
20+
let mut client = test_client::TestClient::connect(format!("http://{}", addrs.0))
21+
.await
22+
.unwrap();
23+
assert_eq!(
24+
client.unary(()).await.unwrap_err().code(),
25+
Code::PermissionDenied
26+
);
27+
assert_eq!(
28+
client.server_stream(()).await.unwrap_err().code(),
29+
Code::PermissionDenied
30+
);
31+
assert_eq!(
32+
client
33+
.client_stream(echo_requests_iter().take(5))
34+
.await
35+
.unwrap_err()
36+
.code(),
37+
Code::PermissionDenied
38+
);
39+
assert_eq!(
40+
client
41+
.bidirectional_stream(echo_requests_iter().take(5))
42+
.await
43+
.unwrap_err()
44+
.code(),
45+
Code::PermissionDenied
46+
);
47+
48+
// Then validate opt-in new functionality (trait has default implementation of returning Unimplemented).
49+
let mut client_default_stubs = test_client::TestClient::connect(format!("http://{}", addrs.1))
50+
.await
51+
.unwrap();
52+
assert_eq!(
53+
client_default_stubs.unary(()).await.unwrap_err().code(),
54+
Code::Unimplemented
55+
);
56+
assert_eq!(
57+
client_default_stubs
58+
.server_stream(())
59+
.await
60+
.unwrap_err()
61+
.code(),
62+
Code::Unimplemented
63+
);
64+
assert_eq!(
65+
client_default_stubs
66+
.client_stream(echo_requests_iter().take(5))
67+
.await
68+
.unwrap_err()
69+
.code(),
70+
Code::Unimplemented
71+
);
72+
assert_eq!(
73+
client_default_stubs
74+
.bidirectional_stream(echo_requests_iter().take(5))
75+
.await
76+
.unwrap_err()
77+
.code(),
78+
Code::Unimplemented
79+
);
80+
}
81+
82+
#[cfg(test)]
83+
async fn run_services_in_background() -> (SocketAddr, SocketAddr) {
84+
let svc = test_server::TestServer::new(Svc {});
85+
let svc_default_stubs = test_default_server::TestDefaultServer::new(Svc {});
86+
87+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
88+
let addr = listener.local_addr().unwrap();
89+
90+
let listener_default_stubs = TcpListener::bind("127.0.0.1:0").await.unwrap();
91+
let addr_default_stubs = listener_default_stubs.local_addr().unwrap();
92+
93+
tokio::spawn(async move {
94+
Server::builder()
95+
.add_service(svc)
96+
.serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
97+
.await
98+
.unwrap();
99+
});
100+
101+
tokio::spawn(async move {
102+
Server::builder()
103+
.add_service(svc_default_stubs)
104+
.serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(
105+
listener_default_stubs,
106+
))
107+
.await
108+
.unwrap();
109+
});
110+
111+
(addr, addr_default_stubs)
112+
}

‎tonic-build/src/code_gen.rs

+9
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub struct CodeGenBuilder {
1313
build_transport: bool,
1414
disable_comments: HashSet<String>,
1515
use_arc_self: bool,
16+
generate_default_stubs: bool,
1617
}
1718

1819
impl CodeGenBuilder {
@@ -64,6 +65,12 @@ impl CodeGenBuilder {
6465
self
6566
}
6667

68+
/// Enable or disable returning automatic unimplemented gRPC error code for generated traits.
69+
pub fn generate_default_stubs(&mut self, generate_default_stubs: bool) -> &mut Self {
70+
self.generate_default_stubs = generate_default_stubs;
71+
self
72+
}
73+
6774
/// Generate client code based on `Service`.
6875
///
6976
/// This takes some `Service` and will generate a `TokenStream` that contains
@@ -93,6 +100,7 @@ impl CodeGenBuilder {
93100
&self.attributes,
94101
&self.disable_comments,
95102
self.use_arc_self,
103+
self.generate_default_stubs,
96104
)
97105
}
98106
}
@@ -106,6 +114,7 @@ impl Default for CodeGenBuilder {
106114
build_transport: true,
107115
disable_comments: HashSet::default(),
108116
use_arc_self: false,
117+
generate_default_stubs: false,
109118
}
110119
}
111120
}

‎tonic-build/src/prost.rs

+14
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ pub fn configure() -> Builder {
4040
emit_rerun_if_changed: std::env::var_os("CARGO").is_some(),
4141
disable_comments: HashSet::default(),
4242
use_arc_self: false,
43+
generate_default_stubs: false,
4344
}
4445
}
4546

@@ -174,6 +175,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
174175
.attributes(self.builder.server_attributes.clone())
175176
.disable_comments(self.builder.disable_comments.clone())
176177
.use_arc_self(self.builder.use_arc_self)
178+
.generate_default_stubs(self.builder.generate_default_stubs)
177179
.generate_server(&service, &self.builder.proto_path);
178180

179181
self.servers.extend(server);
@@ -249,6 +251,7 @@ pub struct Builder {
249251
pub(crate) emit_rerun_if_changed: bool,
250252
pub(crate) disable_comments: HashSet<String>,
251253
pub(crate) use_arc_self: bool,
254+
pub(crate) generate_default_stubs: bool,
252255

253256
out_dir: Option<PathBuf>,
254257
}
@@ -510,6 +513,17 @@ impl Builder {
510513
self
511514
}
512515

516+
/// Enable or disable directing service generation to providing a default implementation for service methods.
517+
/// When this is false all gRPC methods must be explicitly implemented.
518+
/// When this is true any unimplemented service methods will return 'unimplemented' gRPC error code.
519+
/// When this is true all streaming server request RPC types explicitly use tonic::codegen::BoxStream type.
520+
///
521+
/// This defaults to `false`.
522+
pub fn generate_default_stubs(mut self, enable: bool) -> Self {
523+
self.generate_default_stubs = enable;
524+
self
525+
}
526+
513527
/// Compile the .proto files and execute code generation.
514528
pub fn compile(
515529
self,

‎tonic-build/src/server.rs

+70-9
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@ pub(crate) fn generate_internal<T: Service>(
1717
attributes: &Attributes,
1818
disable_comments: &HashSet<String>,
1919
use_arc_self: bool,
20+
generate_default_stubs: bool,
2021
) -> TokenStream {
2122
let methods = generate_methods(
2223
service,
2324
emit_package,
2425
proto_path,
2526
compile_well_known_types,
2627
use_arc_self,
28+
generate_default_stubs,
2729
);
2830

2931
let server_service = quote::format_ident!("{}Server", service.name());
@@ -37,6 +39,7 @@ pub(crate) fn generate_internal<T: Service>(
3739
server_trait.clone(),
3840
disable_comments,
3941
use_arc_self,
42+
generate_default_stubs,
4043
);
4144
let package = if emit_package { service.package() } else { "" };
4245
// Transport based implementations
@@ -214,6 +217,7 @@ fn generate_trait<T: Service>(
214217
server_trait: Ident,
215218
disable_comments: &HashSet<String>,
216219
use_arc_self: bool,
220+
generate_default_stubs: bool,
217221
) -> TokenStream {
218222
let methods = generate_trait_methods(
219223
service,
@@ -222,6 +226,7 @@ fn generate_trait<T: Service>(
222226
compile_well_known_types,
223227
disable_comments,
224228
use_arc_self,
229+
generate_default_stubs,
225230
);
226231
let trait_doc = generate_doc_comment(format!(
227232
" Generated trait containing gRPC methods that should be implemented for use with {}Server.",
@@ -244,6 +249,7 @@ fn generate_trait_methods<T: Service>(
244249
compile_well_known_types: bool,
245250
disable_comments: &HashSet<String>,
246251
use_arc_self: bool,
252+
generate_default_stubs: bool,
247253
) -> TokenStream {
248254
let mut stream = TokenStream::new();
249255

@@ -266,22 +272,53 @@ fn generate_trait_methods<T: Service>(
266272
quote!(&self)
267273
};
268274

269-
let method = match (method.client_streaming(), method.server_streaming()) {
270-
(false, false) => {
275+
let method = match (
276+
method.client_streaming(),
277+
method.server_streaming(),
278+
generate_default_stubs,
279+
) {
280+
(false, false, true) => {
281+
quote! {
282+
#method_doc
283+
async fn #name(#self_param, request: tonic::Request<#req_message>)
284+
-> std::result::Result<tonic::Response<#res_message>, tonic::Status> {
285+
Err(tonic::Status::unimplemented("Not yet implemented"))
286+
}
287+
}
288+
}
289+
(false, false, false) => {
271290
quote! {
272291
#method_doc
273292
async fn #name(#self_param, request: tonic::Request<#req_message>)
274293
-> std::result::Result<tonic::Response<#res_message>, tonic::Status>;
275294
}
276295
}
277-
(true, false) => {
296+
(true, false, true) => {
297+
quote! {
298+
#method_doc
299+
async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
300+
-> std::result::Result<tonic::Response<#res_message>, tonic::Status> {
301+
Err(tonic::Status::unimplemented("Not yet implemented"))
302+
}
303+
}
304+
}
305+
(true, false, false) => {
278306
quote! {
279307
#method_doc
280308
async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
281309
-> std::result::Result<tonic::Response<#res_message>, tonic::Status>;
282310
}
283311
}
284-
(false, true) => {
312+
(false, true, true) => {
313+
quote! {
314+
#method_doc
315+
async fn #name(#self_param, request: tonic::Request<#req_message>)
316+
-> std::result::Result<tonic::Response<BoxStream<#res_message>>, tonic::Status> {
317+
Err(tonic::Status::unimplemented("Not yet implemented"))
318+
}
319+
}
320+
}
321+
(false, true, false) => {
285322
let stream = quote::format_ident!("{}Stream", method.identifier());
286323
let stream_doc = generate_doc_comment(format!(
287324
" Server streaming response type for the {} method.",
@@ -297,7 +334,16 @@ fn generate_trait_methods<T: Service>(
297334
-> std::result::Result<tonic::Response<Self::#stream>, tonic::Status>;
298335
}
299336
}
300-
(true, true) => {
337+
(true, true, true) => {
338+
quote! {
339+
#method_doc
340+
async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
341+
-> std::result::Result<tonic::Response<BoxStream<#res_message>>, tonic::Status> {
342+
Err(tonic::Status::unimplemented("Not yet implemented"))
343+
}
344+
}
345+
}
346+
(true, true, false) => {
301347
let stream = quote::format_ident!("{}Stream", method.identifier());
302348
let stream_doc = generate_doc_comment(format!(
303349
" Server streaming response type for the {} method.",
@@ -341,6 +387,7 @@ fn generate_methods<T: Service>(
341387
proto_path: &str,
342388
compile_well_known_types: bool,
343389
use_arc_self: bool,
390+
generate_default_stubs: bool,
344391
) -> TokenStream {
345392
let mut stream = TokenStream::new();
346393

@@ -367,6 +414,7 @@ fn generate_methods<T: Service>(
367414
ident.clone(),
368415
server_trait,
369416
use_arc_self,
417+
generate_default_stubs,
370418
),
371419
(true, false) => generate_client_streaming(
372420
method,
@@ -384,6 +432,7 @@ fn generate_methods<T: Service>(
384432
ident.clone(),
385433
server_trait,
386434
use_arc_self,
435+
generate_default_stubs,
387436
),
388437
};
389438

@@ -464,14 +513,20 @@ fn generate_server_streaming<T: Method>(
464513
method_ident: Ident,
465514
server_trait: Ident,
466515
use_arc_self: bool,
516+
generate_default_stubs: bool,
467517
) -> TokenStream {
468518
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
469519

470520
let service_ident = quote::format_ident!("{}Svc", method.identifier());
471521

472522
let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
473523

474-
let response_stream = quote::format_ident!("{}Stream", method.identifier());
524+
let response_stream = if !generate_default_stubs {
525+
let stream = quote::format_ident!("{}Stream", method.identifier());
526+
quote!(type ResponseStream = T::#stream)
527+
} else {
528+
quote!(type ResponseStream = BoxStream<#response>)
529+
};
475530

476531
let inner_arg = if use_arc_self {
477532
quote!(inner)
@@ -485,7 +540,7 @@ fn generate_server_streaming<T: Method>(
485540

486541
impl<T: #server_trait> tonic::server::ServerStreamingService<#request> for #service_ident<T> {
487542
type Response = #response;
488-
type ResponseStream = T::#response_stream;
543+
#response_stream;
489544
type Future = BoxFuture<tonic::Response<Self::ResponseStream>, tonic::Status>;
490545

491546
fn call(&mut self, request: tonic::Request<#request>) -> Self::Future {
@@ -585,14 +640,20 @@ fn generate_streaming<T: Method>(
585640
method_ident: Ident,
586641
server_trait: Ident,
587642
use_arc_self: bool,
643+
generate_default_stubs: bool,
588644
) -> TokenStream {
589645
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
590646

591647
let service_ident = quote::format_ident!("{}Svc", method.identifier());
592648

593649
let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
594650

595-
let response_stream = quote::format_ident!("{}Stream", method.identifier());
651+
let response_stream = if !generate_default_stubs {
652+
let stream = quote::format_ident!("{}Stream", method.identifier());
653+
quote!(type ResponseStream = T::#stream)
654+
} else {
655+
quote!(type ResponseStream = BoxStream<#response>)
656+
};
596657

597658
let inner_arg = if use_arc_self {
598659
quote!(inner)
@@ -607,7 +668,7 @@ fn generate_streaming<T: Method>(
607668
impl<T: #server_trait> tonic::server::StreamingService<#request> for #service_ident<T>
608669
{
609670
type Response = #response;
610-
type ResponseStream = T::#response_stream;
671+
#response_stream;
611672
type Future = BoxFuture<tonic::Response<Self::ResponseStream>, tonic::Status>;
612673

613674
fn call(&mut self, request: tonic::Request<tonic::Streaming<#request>>) -> Self::Future {

0 commit comments

Comments
 (0)
Please sign in to comment.