Skip to content

Commit b947e1a

Browse files
authoredJun 19, 2024··
feat(channel): Make channel feature additive (#1574)
* feat(channel): Make channel feature additive * chore(channel): Move add_origin service module to channel module * chore(channel): Move user_agent service to channel module * chore(channel): Move reconnect service to channel module * chore(channel): Move connection service to channel module * chore(channel): Move discover service to channel module * feat(channel): Remove unused Connected implement for BoxedIo * chore(channel): Move channel part of io service to channel module * chore(channel): Move channel part of tls service to channel module * chore(channel): Move connector service to channel module * chore(channel): Move executor service to channel module * chore(channel): Clean up scope of channel service module * chore(channel): Clean up importing in channel service * chore(doc): Update feature flag document about transport and channel
1 parent e2c506a commit b947e1a

22 files changed

+242
-210
lines changed
 

‎examples/Cargo.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -276,13 +276,13 @@ tracing = ["dep:tracing", "dep:tracing-subscriber"]
276276
uds = ["tokio-stream/net", "dep:tower", "dep:hyper", "dep:hyper-util"]
277277
streaming = ["tokio-stream", "dep:h2"]
278278
mock = ["tokio-stream", "dep:tower", "dep:hyper-util"]
279-
tower = ["dep:hyper", "dep:hyper-util", "dep:tower", "dep:http"]
279+
tower = ["dep:hyper", "dep:hyper-util", "dep:tower", "tower?/timeout", "dep:http"]
280280
json-codec = ["dep:serde", "dep:serde_json", "dep:bytes"]
281281
compression = ["tonic/gzip"]
282282
tls = ["tonic/tls"]
283283
tls-rustls = ["dep:hyper", "dep:hyper-util", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls", "dep:pin-project", "dep:http-body-util"]
284284
dynamic-load-balance = ["dep:tower"]
285-
timeout = ["tokio/time", "dep:tower"]
285+
timeout = ["tokio/time", "dep:tower", "tower?/timeout"]
286286
tls-client-auth = ["tonic/tls"]
287287
types = ["dep:tonic-types"]
288288
h2c = ["dep:hyper", "dep:tower", "dep:http", "dep:hyper-util"]

‎tonic/Cargo.toml

+17-10
Original file line numberDiff line numberDiff line change
@@ -26,24 +26,29 @@ version = "0.11.0"
2626
codegen = ["dep:async-trait"]
2727
gzip = ["dep:flate2"]
2828
zstd = ["dep:zstd"]
29-
default = ["transport", "codegen", "prost"]
29+
default = ["channel", "codegen", "prost"]
3030
prost = ["dep:prost"]
3131
tls = ["dep:rustls-pemfile", "transport", "dep:tokio-rustls", "dep:tokio", "tokio?/rt", "tokio?/macros"]
3232
tls-roots = ["tls-roots-common", "dep:rustls-native-certs"]
33-
tls-roots-common = ["tls"]
33+
tls-roots-common = ["tls", "channel"]
3434
tls-webpki-roots = ["tls-roots-common", "dep:webpki-roots"]
3535
router = ["dep:axum"]
3636
transport = [
3737
"router",
3838
"dep:async-stream",
39-
"channel",
4039
"dep:h2",
41-
"dep:hyper", "dep:hyper-util", "dep:hyper-timeout",
40+
"dep:hyper", "dep:hyper-util",
4241
"dep:socket2",
4342
"dep:tokio", "tokio?/macros", "tokio?/net", "tokio?/time",
44-
"dep:tower",
43+
"dep:tower", "tower?/util", "tower?/limit",
44+
]
45+
channel = [
46+
"transport",
47+
"dep:hyper", "hyper?/client",
48+
"dep:hyper-util", "hyper-util?/client-legacy",
49+
"dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/load", "tower?/make",
50+
"dep:hyper-timeout",
4551
]
46-
channel = []
4752

4853
# [[bench]]
4954
# name = "bench_main"
@@ -71,13 +76,12 @@ async-trait = {version = "0.1.13", optional = true}
7176
# transport
7277
async-stream = {version = "0.3", optional = true}
7378
h2 = {version = "0.4", optional = true}
74-
hyper = {version = "1", features = ["full"], optional = true}
75-
hyper-util = { version = ">=0.1.4, <0.2", features = ["full"], optional = true }
76-
hyper-timeout = {version = "0.5", optional = true}
79+
hyper = {version = "1", features = ["http1", "http2", "server"], optional = true}
80+
hyper-util = { version = ">=0.1.4, <0.2", features = ["service", "server-auto", "tokio"], optional = true }
7781
socket2 = { version = ">=0.4.7, <0.6.0", optional = true, features = ["all"] }
7882
tokio = {version = "1", default-features = false, optional = true}
7983
tokio-stream = { version = "0.1", features = ["net"] }
80-
tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true}
84+
tower = {version = "0.4.7", default-features = false, optional = true}
8185
axum = {version = "0.7", default-features = false, optional = true}
8286

8387
# rustls
@@ -90,6 +94,9 @@ webpki-roots = { version = "0.26", optional = true }
9094
flate2 = {version = "1.0", optional = true}
9195
zstd = { version = "0.13.0", optional = true }
9296

97+
# channel
98+
hyper-timeout = {version = "0.5", optional = true}
99+
93100
[dev-dependencies]
94101
bencher = "0.1.5"
95102
quickcheck = "1.0"

‎tonic/src/lib.rs

+3-4
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@
1616
//!
1717
//! # Feature Flags
1818
//!
19-
//! - `transport`: Enables the fully featured, batteries included client and server
20-
//! implementation based on [`hyper`], [`tower`] and [`tokio`]. Enabled by default.
21-
//! - `channel`: Enables just the full featured channel/client portion of the `transport`
22-
//! feature.
19+
//! - `transport`: Enables just the full featured server portion of the `channel` feature.
20+
//! - `channel`: Enables the fully featured, batteries included client and server
21+
//! implementation based on [`hyper`], [`tower`] and [`tokio`]. Enabled by default.
2322
//! - `codegen`: Enables all the required exports and optional dependencies required
2423
//! for [`tonic-build`]. Enabled by default.
2524
//! - `tls`: Enables the `rustls` based TLS options for the `transport` feature. Not

‎tonic/src/status.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -617,8 +617,10 @@ fn find_status_in_source_chain(err: &(dyn Error + 'static)) -> Option<Status> {
617617
// matches the spec of:
618618
// > The service is currently unavailable. This is most likely a transient condition that
619619
// > can be corrected if retried with a backoff.
620-
#[cfg(feature = "transport")]
621-
if let Some(connect) = err.downcast_ref::<crate::transport::ConnectError>() {
620+
#[cfg(feature = "channel")]
621+
if let Some(connect) =
622+
err.downcast_ref::<crate::transport::channel::service::ConnectError>()
623+
{
622624
return Some(Status::unavailable(connect.to_string()));
623625
}
624626

‎tonic/src/transport/channel/endpoint.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
use super::super::service;
1+
#[cfg(feature = "tls")]
2+
use super::service::TlsConnector;
3+
use super::service::{self, Executor, SharedExec};
24
use super::Channel;
35
#[cfg(feature = "tls")]
46
use super::ClientTlsConfig;
5-
#[cfg(feature = "tls")]
6-
use crate::transport::service::TlsConnector;
7-
use crate::transport::{service::SharedExec, Error, Executor};
7+
use crate::transport::Error;
88
use bytes::Bytes;
99
use http::{uri::Uri, HeaderValue};
1010
use hyper::rt;

‎tonic/src/transport/channel/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Client implementation and builder.
22
33
mod endpoint;
4+
pub(crate) mod service;
45
#[cfg(feature = "tls")]
56
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
67
mod tls;
@@ -9,9 +10,8 @@ pub use endpoint::Endpoint;
910
#[cfg(feature = "tls")]
1011
pub use tls::ClientTlsConfig;
1112

12-
use super::service::{Connection, DynamicServiceStream, SharedExec};
13+
use self::service::{Connection, DynamicServiceStream, Executor, SharedExec};
1314
use crate::body::BoxBody;
14-
use crate::transport::Executor;
1515
use bytes::Bytes;
1616
use http::{
1717
uri::{InvalidUri, Uri},

‎tonic/src/transport/service/connection.rs ‎tonic/src/transport/channel/service/connection.rs

+6-7
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
use super::SharedExec;
2-
use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent};
1+
use super::{AddOrigin, Reconnect, SharedExec, UserAgent};
32
use crate::{
43
body::{boxed, BoxBody},
5-
transport::{BoxFuture, Endpoint},
4+
transport::{service::GrpcTimeout, BoxFuture, Endpoint},
65
};
76
use http::Uri;
87
use hyper::rt;
@@ -36,7 +35,7 @@ impl Connection {
3635
C::Future: Unpin + Send,
3736
C::Response: rt::Read + rt::Write + Unpin + Send + 'static,
3837
{
39-
let mut settings: Builder<super::SharedExec> = Builder::new(endpoint.executor.clone())
38+
let mut settings: Builder<SharedExec> = Builder::new(endpoint.executor.clone())
4039
.initial_stream_window_size(endpoint.init_stream_window_size)
4140
.initial_connection_window_size(endpoint.init_connection_window_size)
4241
.keep_alive_interval(endpoint.http2_keep_alive_interval)
@@ -158,12 +157,12 @@ impl tower::Service<http::Request<BoxBody>> for SendRequest {
158157

159158
struct MakeSendRequestService<C> {
160159
connector: C,
161-
executor: super::SharedExec,
162-
settings: Builder<super::SharedExec>,
160+
executor: SharedExec,
161+
settings: Builder<SharedExec>,
163162
}
164163

165164
impl<C> MakeSendRequestService<C> {
166-
fn new(connector: C, executor: SharedExec, settings: Builder<super::SharedExec>) -> Self {
165+
fn new(connector: C, executor: SharedExec, settings: Builder<SharedExec>) -> Self {
167166
Self {
168167
connector,
169168
executor,

‎tonic/src/transport/service/connector.rs ‎tonic/src/transport/channel/service/connector.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
use super::super::BoxFuture;
2-
use super::io::BoxedIo;
1+
use super::BoxedIo;
32
#[cfg(feature = "tls")]
4-
use super::tls::TlsConnector;
3+
use super::TlsConnector;
4+
use crate::transport::BoxFuture;
55
use http::Uri;
66
use std::fmt;
77
use std::task::{Context, Poll};

‎tonic/src/transport/service/discover.rs ‎tonic/src/transport/channel/service/discover.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
use super::connection::Connection;
2-
use crate::transport::Endpoint;
1+
use super::super::{Connection, Endpoint};
32

43
use hyper_util::client::legacy::connect::HttpConnector;
54
use std::{
+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
use std::io::{self, IoSlice};
2+
use std::pin::Pin;
3+
use std::task::{Context, Poll};
4+
5+
use hyper::rt;
6+
use hyper_util::client::legacy::connect::{Connected as HyperConnected, Connection};
7+
8+
pub(in crate::transport) trait Io:
9+
rt::Read + rt::Write + Send + 'static
10+
{
11+
}
12+
13+
impl<T> Io for T where T: rt::Read + rt::Write + Send + 'static {}
14+
15+
pub(crate) struct BoxedIo(Pin<Box<dyn Io>>);
16+
17+
impl BoxedIo {
18+
pub(in crate::transport) fn new<I: Io>(io: I) -> Self {
19+
BoxedIo(Box::pin(io))
20+
}
21+
}
22+
23+
impl Connection for BoxedIo {
24+
fn connected(&self) -> HyperConnected {
25+
HyperConnected::new()
26+
}
27+
}
28+
29+
impl rt::Read for BoxedIo {
30+
fn poll_read(
31+
mut self: Pin<&mut Self>,
32+
cx: &mut Context<'_>,
33+
buf: rt::ReadBufCursor<'_>,
34+
) -> Poll<io::Result<()>> {
35+
Pin::new(&mut self.0).poll_read(cx, buf)
36+
}
37+
}
38+
39+
impl rt::Write for BoxedIo {
40+
fn poll_write(
41+
mut self: Pin<&mut Self>,
42+
cx: &mut Context<'_>,
43+
buf: &[u8],
44+
) -> Poll<io::Result<usize>> {
45+
Pin::new(&mut self.0).poll_write(cx, buf)
46+
}
47+
48+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
49+
Pin::new(&mut self.0).poll_flush(cx)
50+
}
51+
52+
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
53+
Pin::new(&mut self.0).poll_shutdown(cx)
54+
}
55+
56+
fn poll_write_vectored(
57+
mut self: Pin<&mut Self>,
58+
cx: &mut Context<'_>,
59+
bufs: &[IoSlice<'_>],
60+
) -> Poll<Result<usize, io::Error>> {
61+
Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
62+
}
63+
64+
fn is_write_vectored(&self) -> bool {
65+
self.0.is_write_vectored()
66+
}
67+
}
+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
mod add_origin;
2+
use self::add_origin::AddOrigin;
3+
4+
mod user_agent;
5+
use self::user_agent::UserAgent;
6+
7+
mod reconnect;
8+
use self::reconnect::Reconnect;
9+
10+
mod connection;
11+
pub(super) use self::connection::Connection;
12+
13+
mod discover;
14+
pub(super) use self::discover::DynamicServiceStream;
15+
16+
mod io;
17+
use self::io::BoxedIo;
18+
19+
mod connector;
20+
pub(crate) use self::connector::{ConnectError, Connector};
21+
22+
mod executor;
23+
pub(super) use self::executor::{Executor, SharedExec};
24+
25+
#[cfg(feature = "tls")]
26+
mod tls;
27+
#[cfg(feature = "tls")]
28+
pub(super) use self::tls::TlsConnector;
+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
use std::fmt;
2+
use std::io::Cursor;
3+
use std::sync::Arc;
4+
5+
use hyper_util::rt::TokioIo;
6+
use tokio::io::{AsyncRead, AsyncWrite};
7+
use tokio_rustls::{
8+
rustls::{pki_types::ServerName, ClientConfig, RootCertStore},
9+
TlsConnector as RustlsConnector,
10+
};
11+
12+
use super::io::BoxedIo;
13+
use crate::transport::service::tls::{add_certs_from_pem, load_identity, TlsError, ALPN_H2};
14+
use crate::transport::tls::{Certificate, Identity};
15+
16+
#[derive(Clone)]
17+
pub(crate) struct TlsConnector {
18+
config: Arc<ClientConfig>,
19+
domain: Arc<ServerName<'static>>,
20+
assume_http2: bool,
21+
}
22+
23+
impl TlsConnector {
24+
pub(crate) fn new(
25+
ca_certs: Vec<Certificate>,
26+
identity: Option<Identity>,
27+
domain: &str,
28+
assume_http2: bool,
29+
) -> Result<Self, crate::Error> {
30+
let builder = ClientConfig::builder();
31+
let mut roots = RootCertStore::empty();
32+
33+
#[cfg(feature = "tls-roots")]
34+
roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
35+
36+
#[cfg(feature = "tls-webpki-roots")]
37+
roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
38+
39+
for cert in ca_certs {
40+
add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
41+
}
42+
43+
let builder = builder.with_root_certificates(roots);
44+
let mut config = match identity {
45+
Some(identity) => {
46+
let (client_cert, client_key) = load_identity(identity)?;
47+
builder.with_client_auth_cert(client_cert, client_key)?
48+
}
49+
None => builder.with_no_client_auth(),
50+
};
51+
52+
config.alpn_protocols.push(ALPN_H2.into());
53+
Ok(Self {
54+
config: Arc::new(config),
55+
domain: Arc::new(ServerName::try_from(domain)?.to_owned()),
56+
assume_http2,
57+
})
58+
}
59+
60+
pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::Error>
61+
where
62+
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
63+
{
64+
let io = RustlsConnector::from(self.config.clone())
65+
.connect(self.domain.as_ref().to_owned(), io)
66+
.await?;
67+
68+
// Generally we require ALPN to be negotiated, but if the user has
69+
// explicitly set `assume_http2` to true, we'll allow it to be missing.
70+
let (_, session) = io.get_ref();
71+
let alpn_protocol = session.alpn_protocol();
72+
if !(alpn_protocol == Some(ALPN_H2) || self.assume_http2) {
73+
return Err(TlsError::H2NotNegotiated.into());
74+
}
75+
Ok(BoxedIo::new(TokioIo::new(io)))
76+
}
77+
}
78+
79+
impl fmt::Debug for TlsConnector {
80+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81+
f.debug_struct("TlsConnector").finish()
82+
}
83+
}

‎tonic/src/transport/channel/tls.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
use super::service::TlsConnector;
12
use crate::transport::{
2-
service::TlsConnector,
33
tls::{Certificate, Identity},
44
Error,
55
};

‎tonic/src/transport/error.rs

+6
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ struct ErrorImpl {
1515
#[derive(Debug)]
1616
pub(crate) enum Kind {
1717
Transport,
18+
#[cfg(feature = "channel")]
1819
InvalidUri,
20+
#[cfg(feature = "channel")]
1921
InvalidUserAgent,
2022
}
2123

@@ -35,18 +37,22 @@ impl Error {
3537
Error::new(Kind::Transport).with(source)
3638
}
3739

40+
#[cfg(feature = "channel")]
3841
pub(crate) fn new_invalid_uri() -> Self {
3942
Error::new(Kind::InvalidUri)
4043
}
4144

45+
#[cfg(feature = "channel")]
4246
pub(crate) fn new_invalid_user_agent() -> Self {
4347
Error::new(Kind::InvalidUserAgent)
4448
}
4549

4650
fn description(&self) -> &str {
4751
match &self.inner.kind {
4852
Kind::Transport => "transport error",
53+
#[cfg(feature = "channel")]
4954
Kind::InvalidUri => "invalid URI",
55+
#[cfg(feature = "channel")]
5056
Kind::InvalidUserAgent => "user agent is not a valid header value",
5157
}
5258
}

‎tonic/src/transport/mod.rs

+4-5
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
//!
9090
//! [rustls]: https://docs.rs/rustls/0.16.0/rustls/
9191
92+
#[cfg(feature = "channel")]
9293
pub mod channel;
9394
pub mod server;
9495

@@ -106,7 +107,6 @@ pub use self::error::Error;
106107
pub use self::server::Server;
107108
#[doc(inline)]
108109
pub use self::service::grpc_timeout::TimeoutExpired;
109-
pub(crate) use self::service::ConnectError;
110110

111111
#[cfg(feature = "tls")]
112112
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
@@ -116,10 +116,8 @@ pub use hyper::{body::Body, Uri};
116116
#[cfg(feature = "tls")]
117117
pub use tokio_rustls::rustls::pki_types::CertificateDer;
118118

119-
pub(crate) use self::service::executor::Executor;
120-
121-
#[cfg(feature = "tls")]
122-
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
119+
#[cfg(all(feature = "channel", feature = "tls"))]
120+
#[cfg_attr(docsrs, doc(cfg(all(feature = "channel", feature = "tls"))))]
123121
pub use self::channel::ClientTlsConfig;
124122
#[cfg(feature = "tls")]
125123
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
@@ -128,4 +126,5 @@ pub use self::server::ServerTlsConfig;
128126
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
129127
pub use self::tls::Identity;
130128

129+
#[cfg(feature = "channel")]
131130
use crate::service::router::BoxFuture;

‎tonic/src/transport/service/io.rs

-74
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
use crate::transport::server::Connected;
2-
use hyper::rt;
3-
use hyper_util::client::legacy::connect::{Connected as HyperConnected, Connection};
42
use std::io;
53
use std::io::IoSlice;
64
use std::pin::Pin;
@@ -9,78 +7,6 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
97
#[cfg(feature = "tls")]
108
use tokio_rustls::server::TlsStream;
119

12-
pub(in crate::transport) trait Io:
13-
rt::Read + rt::Write + Send + 'static
14-
{
15-
}
16-
17-
impl<T> Io for T where T: rt::Read + rt::Write + Send + 'static {}
18-
19-
pub(crate) struct BoxedIo(Pin<Box<dyn Io>>);
20-
21-
impl BoxedIo {
22-
pub(in crate::transport) fn new<I: Io>(io: I) -> Self {
23-
BoxedIo(Box::pin(io))
24-
}
25-
}
26-
27-
impl Connection for BoxedIo {
28-
fn connected(&self) -> HyperConnected {
29-
HyperConnected::new()
30-
}
31-
}
32-
33-
impl Connected for BoxedIo {
34-
type ConnectInfo = NoneConnectInfo;
35-
36-
fn connect_info(&self) -> Self::ConnectInfo {
37-
NoneConnectInfo
38-
}
39-
}
40-
41-
#[derive(Copy, Clone)]
42-
pub(crate) struct NoneConnectInfo;
43-
44-
impl rt::Read for BoxedIo {
45-
fn poll_read(
46-
mut self: Pin<&mut Self>,
47-
cx: &mut Context<'_>,
48-
buf: rt::ReadBufCursor<'_>,
49-
) -> Poll<io::Result<()>> {
50-
Pin::new(&mut self.0).poll_read(cx, buf)
51-
}
52-
}
53-
54-
impl rt::Write for BoxedIo {
55-
fn poll_write(
56-
mut self: Pin<&mut Self>,
57-
cx: &mut Context<'_>,
58-
buf: &[u8],
59-
) -> Poll<io::Result<usize>> {
60-
Pin::new(&mut self.0).poll_write(cx, buf)
61-
}
62-
63-
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
64-
Pin::new(&mut self.0).poll_flush(cx)
65-
}
66-
67-
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
68-
Pin::new(&mut self.0).poll_shutdown(cx)
69-
}
70-
71-
fn poll_write_vectored(
72-
mut self: Pin<&mut Self>,
73-
cx: &mut Context<'_>,
74-
bufs: &[IoSlice<'_>],
75-
) -> Poll<Result<usize, io::Error>> {
76-
Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
77-
}
78-
79-
fn is_write_vectored(&self) -> bool {
80-
self.0.is_write_vectored()
81-
}
82-
}
83-
8410
pub(crate) enum ServerIo<IO> {
8511
Io(IO),
8612
#[cfg(feature = "tls")]

‎tonic/src/transport/service/mod.rs

+2-16
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,9 @@
1-
mod add_origin;
2-
mod connection;
3-
mod connector;
4-
mod discover;
5-
pub(crate) mod executor;
61
pub(crate) mod grpc_timeout;
72
mod io;
8-
mod reconnect;
93
#[cfg(feature = "tls")]
10-
mod tls;
11-
mod user_agent;
4+
pub(crate) mod tls;
125

13-
pub(crate) use self::add_origin::AddOrigin;
14-
pub(crate) use self::connection::Connection;
15-
pub(crate) use self::connector::ConnectError;
16-
pub(crate) use self::connector::Connector;
17-
pub(crate) use self::discover::DynamicServiceStream;
18-
pub(crate) use self::executor::SharedExec;
196
pub(crate) use self::grpc_timeout::GrpcTimeout;
207
pub(crate) use self::io::ServerIo;
218
#[cfg(feature = "tls")]
22-
pub(crate) use self::tls::{TlsAcceptor, TlsConnector};
23-
pub(crate) use self::user_agent::UserAgent;
9+
pub(crate) use self::tls::TlsAcceptor;

‎tonic/src/transport/service/tls.rs

+9-78
Original file line numberDiff line numberDiff line change
@@ -6,99 +6,29 @@ use std::{
66
use tokio::io::{AsyncRead, AsyncWrite};
77
use tokio_rustls::{
88
rustls::{
9-
pki_types::{CertificateDer, PrivateKeyDer, ServerName},
9+
pki_types::{CertificateDer, PrivateKeyDer},
1010
server::WebPkiClientVerifier,
11-
ClientConfig, RootCertStore, ServerConfig,
11+
RootCertStore, ServerConfig,
1212
},
13-
TlsAcceptor as RustlsAcceptor, TlsConnector as RustlsConnector,
13+
TlsAcceptor as RustlsAcceptor,
1414
};
1515

16-
use super::io::BoxedIo;
1716
use crate::transport::{
1817
server::{Connected, TlsStream},
1918
Certificate, Identity,
2019
};
21-
use hyper_util::rt::TokioIo;
2220

2321
/// h2 alpn in plain format for rustls.
24-
const ALPN_H2: &[u8] = b"h2";
22+
pub(crate) const ALPN_H2: &[u8] = b"h2";
2523

2624
#[derive(Debug)]
27-
enum TlsError {
25+
pub(crate) enum TlsError {
26+
#[cfg(feature = "channel")]
2827
H2NotNegotiated,
2928
CertificateParseError,
3029
PrivateKeyParseError,
3130
}
3231

33-
#[derive(Clone)]
34-
pub(crate) struct TlsConnector {
35-
config: Arc<ClientConfig>,
36-
domain: Arc<ServerName<'static>>,
37-
assume_http2: bool,
38-
}
39-
40-
impl TlsConnector {
41-
pub(crate) fn new(
42-
ca_certs: Vec<Certificate>,
43-
identity: Option<Identity>,
44-
domain: &str,
45-
assume_http2: bool,
46-
) -> Result<Self, crate::Error> {
47-
let builder = ClientConfig::builder();
48-
let mut roots = RootCertStore::empty();
49-
50-
#[cfg(feature = "tls-roots")]
51-
roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
52-
53-
#[cfg(feature = "tls-webpki-roots")]
54-
roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
55-
56-
for cert in ca_certs {
57-
add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
58-
}
59-
60-
let builder = builder.with_root_certificates(roots);
61-
let mut config = match identity {
62-
Some(identity) => {
63-
let (client_cert, client_key) = load_identity(identity)?;
64-
builder.with_client_auth_cert(client_cert, client_key)?
65-
}
66-
None => builder.with_no_client_auth(),
67-
};
68-
69-
config.alpn_protocols.push(ALPN_H2.into());
70-
Ok(Self {
71-
config: Arc::new(config),
72-
domain: Arc::new(ServerName::try_from(domain)?.to_owned()),
73-
assume_http2,
74-
})
75-
}
76-
77-
pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::Error>
78-
where
79-
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
80-
{
81-
let io = RustlsConnector::from(self.config.clone())
82-
.connect(self.domain.as_ref().to_owned(), io)
83-
.await?;
84-
85-
// Generally we require ALPN to be negotiated, but if the user has
86-
// explicitly set `assume_http2` to true, we'll allow it to be missing.
87-
let (_, session) = io.get_ref();
88-
let alpn_protocol = session.alpn_protocol();
89-
if !(alpn_protocol == Some(ALPN_H2) || self.assume_http2) {
90-
return Err(TlsError::H2NotNegotiated.into());
91-
}
92-
Ok(BoxedIo::new(TokioIo::new(io)))
93-
}
94-
}
95-
96-
impl fmt::Debug for TlsConnector {
97-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98-
f.debug_struct("TlsConnector").finish()
99-
}
100-
}
101-
10232
#[derive(Clone)]
10333
pub(crate) struct TlsAcceptor {
10434
inner: Arc<ServerConfig>,
@@ -154,6 +84,7 @@ impl fmt::Debug for TlsAcceptor {
15484
impl fmt::Display for TlsError {
15585
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
15686
match self {
87+
#[cfg(feature = "channel")]
15788
TlsError::H2NotNegotiated => write!(f, "HTTP/2 was not negotiated."),
15889
TlsError::CertificateParseError => write!(f, "Error parsing TLS certificate."),
15990
TlsError::PrivateKeyParseError => write!(
@@ -166,7 +97,7 @@ impl fmt::Display for TlsError {
16697

16798
impl std::error::Error for TlsError {}
16899

169-
fn load_identity(
100+
pub(crate) fn load_identity(
170101
identity: Identity,
171102
) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>), TlsError> {
172103
let cert = rustls_pemfile::certs(&mut Cursor::new(identity.cert))
@@ -180,7 +111,7 @@ fn load_identity(
180111
Ok((cert, key))
181112
}
182113

183-
fn add_certs_from_pem(
114+
pub(crate) fn add_certs_from_pem(
184115
mut certs: &mut dyn std::io::BufRead,
185116
roots: &mut RootCertStore,
186117
) -> Result<(), crate::Error> {

0 commit comments

Comments
 (0)
Please sign in to comment.