Skip to content

Commit 654289f

Browse files
authoredJun 20, 2024··
feat(transport): Make transport server and channel independent (#1630)
* feat(transport): Make transport server and channel independent * chore(channel): Move BoxFuture to channel module * chore(tls): Move server part of tls service to server module * chore(server): Move io service to server module
1 parent 58b4443 commit 654289f

File tree

20 files changed

+142
-129
lines changed

20 files changed

+142
-129
lines changed
 

‎.github/workflows/CI.yml

+6-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@ jobs:
6363
with:
6464
tool: protoc@${{ env.PROTOC_VERSION }}
6565
- uses: Swatinem/rust-cache@v2
66-
- run: cargo hack udeps --workspace --each-feature ${{ matrix.option }}
66+
- run: cargo hack udeps --workspace --exclude-features tls --each-feature ${{ matrix.option }}
67+
- run: cargo udeps --package tonic --features tls,transport
68+
- run: cargo udeps --package tonic --features tls,server
69+
- run: cargo udeps --package tonic --features tls,channel
6770

6871
check:
6972
runs-on: ${{ matrix.os }}
@@ -81,6 +84,8 @@ jobs:
8184
- uses: Swatinem/rust-cache@v2
8285
- name: Check features
8386
run: cargo hack check --workspace --no-private --each-feature --no-dev-deps
87+
- name: Check tonic feature powerset
88+
run: cargo hack check --package tonic --feature-powerset --depth 2
8489
- name: Check all targets
8590
run: cargo check --workspace --all-targets --all-features
8691

‎tonic/Cargo.toml

+10-8
Original file line numberDiff line numberDiff line change
@@ -26,29 +26,31 @@ version = "0.11.0"
2626
codegen = ["dep:async-trait"]
2727
gzip = ["dep:flate2"]
2828
zstd = ["dep:zstd"]
29-
default = ["channel", "codegen", "prost"]
29+
default = ["transport", "codegen", "prost"]
3030
prost = ["dep:prost"]
31-
tls = ["dep:rustls-pemfile", "transport", "dep:tokio-rustls", "dep:tokio", "tokio?/rt", "tokio?/macros"]
31+
tls = ["dep:rustls-pemfile", "dep:tokio-rustls", "dep:tokio", "tokio?/rt", "tokio?/macros"]
3232
tls-roots = ["tls-roots-common", "dep:rustls-native-certs"]
3333
tls-roots-common = ["tls", "channel"]
3434
tls-webpki-roots = ["tls-roots-common", "dep:webpki-roots"]
3535
router = ["dep:axum"]
36-
transport = [
36+
server = [
3737
"router",
3838
"dep:async-stream",
3939
"dep:h2",
40-
"dep:hyper", "dep:hyper-util",
40+
"dep:hyper", "hyper?/server",
41+
"dep:hyper-util", "hyper-util?/service", "hyper-util?/server-auto",
4142
"dep:socket2",
4243
"dep:tokio", "tokio?/macros", "tokio?/net", "tokio?/time",
4344
"dep:tower", "tower?/util", "tower?/limit",
4445
]
4546
channel = [
46-
"transport",
4747
"dep:hyper", "hyper?/client",
4848
"dep:hyper-util", "hyper-util?/client-legacy",
49-
"dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/load", "tower?/make",
49+
"dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/limit",
50+
"dep:tokio", "tokio?/time",
5051
"dep:hyper-timeout",
5152
]
53+
transport = ["server", "channel"]
5254

5355
# [[bench]]
5456
# name = "bench_main"
@@ -76,8 +78,8 @@ async-trait = {version = "0.1.13", optional = true}
7678
# transport
7779
async-stream = {version = "0.3", optional = true}
7880
h2 = {version = "0.4", optional = true}
79-
hyper = {version = "1", features = ["http1", "http2", "server"], optional = true}
80-
hyper-util = { version = ">=0.1.4, <0.2", features = ["service", "server-auto", "tokio"], optional = true }
81+
hyper = {version = "1", features = ["http1", "http2"], optional = true}
82+
hyper-util = { version = ">=0.1.4, <0.2", features = ["tokio"], optional = true }
8183
socket2 = { version = ">=0.4.7, <0.6.0", optional = true, features = ["all"] }
8284
tokio = {version = "1", default-features = false, optional = true}
8385
tokio-stream = { version = "0.1", features = ["net"] }

‎tonic/src/lib.rs

+5-4
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
//!
1717
//! # Feature Flags
1818
//!
19-
//! - `transport`: Enables just the full featured server portion of the `channel` feature.
20-
//! - `channel`: Enables the fully featured, batteries included client and server
19+
//! - `transport`: Enables the fully featured, batteries included client and server
2120
//! implementation based on [`hyper`], [`tower`] and [`tokio`]. Enabled by default.
21+
//! - `server`: Enables just the full featured server portion of the `transport` feature.
22+
//! - `channel`: Enables just the full featured channel portion of the `transport` feature.
2223
//! - `codegen`: Enables all the required exports and optional dependencies required
2324
//! for [`tonic-build`]. Enabled by default.
2425
//! - `tls`: Enables the `rustls` based TLS options for the `transport` feature. Not
@@ -100,8 +101,8 @@ pub mod metadata;
100101
pub mod server;
101102
pub mod service;
102103

103-
#[cfg(feature = "transport")]
104-
#[cfg_attr(docsrs, doc(cfg(feature = "transport")))]
104+
#[cfg(any(feature = "server", feature = "channel"))]
105+
#[cfg_attr(docsrs, doc(cfg(any(feature = "server", feature = "channel"))))]
105106
pub mod transport;
106107

107108
mod extensions;

‎tonic/src/request.rs

+11-11
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
use crate::metadata::{MetadataMap, MetadataValue};
2-
#[cfg(feature = "transport")]
2+
#[cfg(feature = "server")]
33
use crate::transport::server::TcpConnectInfo;
4-
#[cfg(feature = "tls")]
4+
#[cfg(all(feature = "server", feature = "tls"))]
55
use crate::transport::server::TlsConnectInfo;
66
use http::Extensions;
7-
#[cfg(feature = "transport")]
7+
#[cfg(feature = "server")]
88
use std::net::SocketAddr;
9-
#[cfg(feature = "tls")]
9+
#[cfg(all(feature = "server", feature = "tls"))]
1010
use std::sync::Arc;
1111
use std::time::Duration;
12-
#[cfg(feature = "tls")]
12+
#[cfg(all(feature = "server", feature = "tls"))]
1313
use tokio_rustls::rustls::pki_types::CertificateDer;
1414
use tokio_stream::Stream;
1515

@@ -211,8 +211,8 @@ impl<T> Request<T> {
211211
/// This will return `None` if the `IO` type used
212212
/// does not implement `Connected` or when using a unix domain socket.
213213
/// This currently only works on the server side.
214-
#[cfg(feature = "transport")]
215-
#[cfg_attr(docsrs, doc(cfg(feature = "transport")))]
214+
#[cfg(feature = "server")]
215+
#[cfg_attr(docsrs, doc(cfg(feature = "server")))]
216216
pub fn local_addr(&self) -> Option<SocketAddr> {
217217
let addr = self
218218
.extensions()
@@ -234,8 +234,8 @@ impl<T> Request<T> {
234234
/// This will return `None` if the `IO` type used
235235
/// does not implement `Connected` or when using a unix domain socket.
236236
/// This currently only works on the server side.
237-
#[cfg(feature = "transport")]
238-
#[cfg_attr(docsrs, doc(cfg(feature = "transport")))]
237+
#[cfg(feature = "server")]
238+
#[cfg_attr(docsrs, doc(cfg(feature = "server")))]
239239
pub fn remote_addr(&self) -> Option<SocketAddr> {
240240
let addr = self
241241
.extensions()
@@ -258,8 +258,8 @@ impl<T> Request<T> {
258258
/// and is mostly used for mTLS. This currently only returns
259259
/// `Some` on the server side of the `transport` server with
260260
/// TLS enabled connections.
261-
#[cfg(feature = "tls")]
262-
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
261+
#[cfg(all(feature = "server", feature = "tls"))]
262+
#[cfg_attr(docsrs, doc(all(feature = "server", feature = "tls")))]
263263
pub fn peer_certs(&self) -> Option<Arc<Vec<CertificateDer<'static>>>> {
264264
self.extensions()
265265
.get::<TlsConnectInfo<TcpConnectInfo>>()

‎tonic/src/service/router.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ struct AxumBodyService<S> {
149149
service: S,
150150
}
151151

152-
pub(crate) type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
152+
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
153153

154154
impl<S> Service<Request<axum::body::Body>> for AxumBodyService<S>
155155
where

‎tonic/src/status.rs

+12-13
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,6 @@ impl Status {
305305
Status::new(Code::Unauthenticated, message)
306306
}
307307

308-
#[cfg_attr(not(feature = "transport"), allow(dead_code))]
309308
pub(crate) fn from_error_generic(
310309
err: impl Into<Box<dyn Error + Send + Sync + 'static>>,
311310
) -> Status {
@@ -316,7 +315,6 @@ impl Status {
316315
///
317316
/// Inspects the error source chain for recognizable errors, including statuses, HTTP2, and
318317
/// hyper, and attempts to maps them to a `Status`, or else returns an Unknown `Status`.
319-
#[cfg_attr(not(feature = "transport"), allow(dead_code))]
320318
pub fn from_error(err: Box<dyn Error + Send + Sync + 'static>) -> Status {
321319
Status::try_from_error(err).unwrap_or_else(|err| {
322320
let mut status = Status::new(Code::Unknown, err.to_string());
@@ -342,7 +340,7 @@ impl Status {
342340
Err(err) => err,
343341
};
344342

345-
#[cfg(feature = "transport")]
343+
#[cfg(feature = "server")]
346344
let err = match err.downcast::<h2::Error>() {
347345
Ok(h2) => {
348346
return Ok(Status::from_h2_error(h2));
@@ -359,7 +357,7 @@ impl Status {
359357
}
360358

361359
// FIXME: bubble this into `transport` and expose generic http2 reasons.
362-
#[cfg(feature = "transport")]
360+
#[cfg(feature = "server")]
363361
fn from_h2_error(err: Box<h2::Error>) -> Status {
364362
let code = Self::code_from_h2(&err);
365363

@@ -368,7 +366,7 @@ impl Status {
368366
status
369367
}
370368

371-
#[cfg(feature = "transport")]
369+
#[cfg(feature = "server")]
372370
fn code_from_h2(err: &h2::Error) -> Code {
373371
// See https://github.com/grpc/grpc/blob/3977c30/doc/PROTOCOL-HTTP2.md#errors
374372
match err.reason() {
@@ -388,7 +386,7 @@ impl Status {
388386
}
389387
}
390388

391-
#[cfg(feature = "transport")]
389+
#[cfg(feature = "server")]
392390
fn to_h2_error(&self) -> h2::Error {
393391
// conservatively transform to h2 error codes...
394392
let reason = match self.code {
@@ -404,7 +402,7 @@ impl Status {
404402
///
405403
/// Returns Some if there's a way to handle the error, or None if the information from this
406404
/// hyper error, but perhaps not its source, should be ignored.
407-
#[cfg(feature = "transport")]
405+
#[cfg(any(feature = "server", feature = "channel"))]
408406
fn from_hyper_error(err: &hyper::Error) -> Option<Status> {
409407
// is_timeout results from hyper's keep-alive logic
410408
// (https://docs.rs/hyper/0.14.11/src/hyper/error.rs.html#192-194). Per the grpc spec
@@ -420,6 +418,7 @@ impl Status {
420418
return Some(Status::cancelled(err.to_string()));
421419
}
422420

421+
#[cfg(feature = "server")]
423422
if let Some(h2_err) = err.source().and_then(|e| e.downcast_ref::<h2::Error>()) {
424423
let code = Status::code_from_h2(h2_err);
425424
let status = Self::new(code, format!("h2 protocol error: {}", err));
@@ -607,7 +606,7 @@ fn find_status_in_source_chain(err: &(dyn Error + 'static)) -> Option<Status> {
607606
});
608607
}
609608

610-
#[cfg(feature = "transport")]
609+
#[cfg(feature = "server")]
611610
if let Some(timeout) = err.downcast_ref::<crate::transport::TimeoutExpired>() {
612611
return Some(Status::cancelled(timeout.to_string()));
613612
}
@@ -624,7 +623,7 @@ fn find_status_in_source_chain(err: &(dyn Error + 'static)) -> Option<Status> {
624623
return Some(Status::unavailable(connect.to_string()));
625624
}
626625

627-
#[cfg(feature = "transport")]
626+
#[cfg(any(feature = "server", feature = "channel"))]
628627
if let Some(hyper) = err
629628
.downcast_ref::<hyper::Error>()
630629
.and_then(Status::from_hyper_error)
@@ -671,14 +670,14 @@ fn invalid_header_value_byte<Error: fmt::Display>(err: Error) -> Status {
671670
)
672671
}
673672

674-
#[cfg(feature = "transport")]
673+
#[cfg(feature = "server")]
675674
impl From<h2::Error> for Status {
676675
fn from(err: h2::Error) -> Self {
677676
Status::from_h2_error(Box::new(err))
678677
}
679678
}
680679

681-
#[cfg(feature = "transport")]
680+
#[cfg(feature = "server")]
682681
impl From<Status> for h2::Error {
683682
fn from(status: Status) -> Self {
684683
status.to_h2_error()
@@ -927,7 +926,7 @@ mod tests {
927926
}
928927

929928
#[test]
930-
#[cfg(feature = "transport")]
929+
#[cfg(feature = "server")]
931930
fn from_error_h2() {
932931
use std::error::Error as _;
933932

@@ -944,7 +943,7 @@ mod tests {
944943
}
945944

946945
#[test]
947-
#[cfg(feature = "transport")]
946+
#[cfg(feature = "server")]
948947
fn to_h2_error() {
949948
let orig = Status::new(Code::Cancelled, "stop eet!");
950949
let err = orig.to_h2_error();

‎tonic/src/transport/channel/mod.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ use tower::{
3636
Service,
3737
};
3838

39+
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
3940
type Svc = Either<Connection, BoxService<Request<BoxBody>, Response<BoxBody>, crate::Error>>;
4041

4142
const DEFAULT_BUFFER_SIZE: usize = 1024;
@@ -186,7 +187,7 @@ impl Channel {
186187
D: Discover<Service = Connection> + Unpin + Send + 'static,
187188
D::Error: Into<crate::Error>,
188189
D::Key: Hash + Send + Clone,
189-
E: Executor<crate::transport::BoxFuture<'static, ()>> + Send + Sync + 'static,
190+
E: Executor<BoxFuture<'static, ()>> + Send + Sync + 'static,
190191
{
191192
let svc = Balance::new(discover);
192193

‎tonic/src/transport/channel/service/add_origin.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::transport::BoxFuture;
1+
use crate::transport::channel::BoxFuture;
22
use http::uri::Authority;
33
use http::uri::Scheme;
44
use http::{Request, Uri};

‎tonic/src/transport/channel/service/connection.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use super::{AddOrigin, Reconnect, SharedExec, UserAgent};
22
use crate::{
33
body::{boxed, BoxBody},
4-
transport::{service::GrpcTimeout, BoxFuture, Endpoint},
4+
transport::{channel::BoxFuture, service::GrpcTimeout, Endpoint},
55
};
66
use http::Uri;
77
use hyper::rt;

‎tonic/src/transport/channel/service/connector.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use super::BoxedIo;
22
#[cfg(feature = "tls")]
33
use super::TlsConnector;
4-
use crate::transport::BoxFuture;
4+
use crate::transport::channel::BoxFuture;
55
use http::Uri;
66
use std::fmt;
77
use std::task::{Context, Poll};

‎tonic/src/transport/channel/service/executor.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::transport::BoxFuture;
1+
use crate::transport::channel::BoxFuture;
22
use std::{future::Future, sync::Arc};
33

44
pub(crate) use hyper::rt::Executor;

‎tonic/src/transport/mod.rs

+6-5
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
9292
#[cfg(feature = "channel")]
9393
pub mod channel;
94+
#[cfg(feature = "server")]
9495
pub mod server;
9596

9697
mod error;
@@ -104,13 +105,16 @@ mod tls;
104105
pub use self::channel::{Channel, Endpoint};
105106
pub use self::error::Error;
106107
#[doc(inline)]
108+
#[cfg(feature = "server")]
107109
pub use self::server::Server;
108110
#[doc(inline)]
109111
pub use self::service::grpc_timeout::TimeoutExpired;
110112

111113
#[cfg(feature = "tls")]
112114
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
113115
pub use self::tls::Certificate;
116+
#[cfg(feature = "server")]
117+
#[cfg_attr(docsrs, doc(cfg(feature = "server")))]
114118
pub use axum::{body::Body as AxumBoxBody, Router as AxumRouter};
115119
pub use hyper::{body::Body, Uri};
116120
#[cfg(feature = "tls")]
@@ -119,12 +123,9 @@ pub use tokio_rustls::rustls::pki_types::CertificateDer;
119123
#[cfg(all(feature = "channel", feature = "tls"))]
120124
#[cfg_attr(docsrs, doc(cfg(all(feature = "channel", feature = "tls"))))]
121125
pub use self::channel::ClientTlsConfig;
122-
#[cfg(feature = "tls")]
123-
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
126+
#[cfg(all(feature = "server", feature = "tls"))]
127+
#[cfg_attr(docsrs, doc(all(feature = "server", feature = "tls")))]
124128
pub use self::server::ServerTlsConfig;
125129
#[cfg(feature = "tls")]
126130
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
127131
pub use self::tls::Identity;
128-
129-
#[cfg(feature = "channel")]
130-
use crate::service::router::BoxFuture;

‎tonic/src/transport/server/incoming.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
use super::{Connected, Server};
2-
use crate::transport::service::ServerIo;
1+
use super::{service::ServerIo, Connected, Server};
32
use std::{
43
net::{SocketAddr, TcpListener as StdTcpListener},
54
pin::{pin, Pin},

‎tonic/src/transport/server/mod.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
mod conn;
44
mod incoming;
55
mod recover_error;
6+
mod service;
67
#[cfg(feature = "tls")]
78
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
89
mod tls;
@@ -27,7 +28,7 @@ pub use tls::ServerTlsConfig;
2728
pub use conn::TlsConnectInfo;
2829

2930
#[cfg(feature = "tls")]
30-
use super::service::TlsAcceptor;
31+
use self::service::TlsAcceptor;
3132

3233
#[cfg(unix)]
3334
pub use unix::UdsConnectInfo;
@@ -40,8 +41,8 @@ pub(crate) use tokio_rustls::server::TlsStream;
4041
#[cfg(feature = "tls")]
4142
use crate::transport::Error;
4243

43-
use self::recover_error::RecoverError;
44-
use super::service::{GrpcTimeout, ServerIo};
44+
use self::{recover_error::RecoverError, service::ServerIo};
45+
use super::service::GrpcTimeout;
4546
use crate::body::{boxed, BoxBody};
4647
use crate::server::NamedService;
4748
use bytes::Bytes;
File renamed without changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
mod io;
2+
pub(crate) use self::io::ServerIo;
3+
4+
#[cfg(feature = "tls")]
5+
mod tls;
6+
#[cfg(feature = "tls")]
7+
pub(crate) use self::tls::TlsAcceptor;
+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
use std::{fmt, io::Cursor, sync::Arc};
2+
3+
use tokio::io::{AsyncRead, AsyncWrite};
4+
use tokio_rustls::{
5+
rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig},
6+
TlsAcceptor as RustlsAcceptor,
7+
};
8+
9+
use crate::transport::{
10+
server::{Connected, TlsStream},
11+
service::tls::{add_certs_from_pem, load_identity, ALPN_H2},
12+
Certificate, Identity,
13+
};
14+
15+
#[derive(Clone)]
16+
pub(crate) struct TlsAcceptor {
17+
inner: Arc<ServerConfig>,
18+
}
19+
20+
impl TlsAcceptor {
21+
pub(crate) fn new(
22+
identity: Identity,
23+
client_ca_root: Option<Certificate>,
24+
client_auth_optional: bool,
25+
) -> Result<Self, crate::Error> {
26+
let builder = ServerConfig::builder();
27+
28+
let builder = match client_ca_root {
29+
None => builder.with_no_client_auth(),
30+
Some(cert) => {
31+
let mut roots = RootCertStore::empty();
32+
add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
33+
let verifier = if client_auth_optional {
34+
WebPkiClientVerifier::builder(roots.into()).allow_unauthenticated()
35+
} else {
36+
WebPkiClientVerifier::builder(roots.into())
37+
}
38+
.build()?;
39+
builder.with_client_cert_verifier(verifier)
40+
}
41+
};
42+
43+
let (cert, key) = load_identity(identity)?;
44+
let mut config = builder.with_single_cert(cert, key)?;
45+
46+
config.alpn_protocols.push(ALPN_H2.into());
47+
Ok(Self {
48+
inner: Arc::new(config),
49+
})
50+
}
51+
52+
pub(crate) async fn accept<IO>(&self, io: IO) -> Result<TlsStream<IO>, crate::Error>
53+
where
54+
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
55+
{
56+
let acceptor = RustlsAcceptor::from(self.inner.clone());
57+
acceptor.accept(io).await.map_err(Into::into)
58+
}
59+
}
60+
61+
impl fmt::Debug for TlsAcceptor {
62+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63+
f.debug_struct("TlsAcceptor").finish()
64+
}
65+
}

‎tonic/src/transport/server/tls.rs

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
use crate::transport::{
2-
service::TlsAcceptor,
3-
tls::{Certificate, Identity},
4-
};
51
use std::fmt;
62

3+
use super::service::TlsAcceptor;
4+
use crate::transport::tls::{Certificate, Identity};
5+
76
/// Configures TLS settings for servers.
87
#[derive(Clone, Default)]
98
pub struct ServerTlsConfig {

‎tonic/src/transport/service/mod.rs

-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
pub(crate) mod grpc_timeout;
2-
mod io;
32
#[cfg(feature = "tls")]
43
pub(crate) mod tls;
54

65
pub(crate) use self::grpc_timeout::GrpcTimeout;
7-
pub(crate) use self::io::ServerIo;
8-
#[cfg(feature = "tls")]
9-
pub(crate) use self::tls::TlsAcceptor;

‎tonic/src/transport/service/tls.rs

+5-68
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,11 @@
1-
use std::{
2-
io::Cursor,
3-
{fmt, sync::Arc},
4-
};
1+
use std::{fmt, io::Cursor};
52

6-
use tokio::io::{AsyncRead, AsyncWrite};
7-
use tokio_rustls::{
8-
rustls::{
9-
pki_types::{CertificateDer, PrivateKeyDer},
10-
server::WebPkiClientVerifier,
11-
RootCertStore, ServerConfig,
12-
},
13-
TlsAcceptor as RustlsAcceptor,
3+
use tokio_rustls::rustls::{
4+
pki_types::{CertificateDer, PrivateKeyDer},
5+
RootCertStore,
146
};
157

16-
use crate::transport::{
17-
server::{Connected, TlsStream},
18-
Certificate, Identity,
19-
};
8+
use crate::transport::Identity;
209

2110
/// h2 alpn in plain format for rustls.
2211
pub(crate) const ALPN_H2: &[u8] = b"h2";
@@ -29,58 +18,6 @@ pub(crate) enum TlsError {
2918
PrivateKeyParseError,
3019
}
3120

32-
#[derive(Clone)]
33-
pub(crate) struct TlsAcceptor {
34-
inner: Arc<ServerConfig>,
35-
}
36-
37-
impl TlsAcceptor {
38-
pub(crate) fn new(
39-
identity: Identity,
40-
client_ca_root: Option<Certificate>,
41-
client_auth_optional: bool,
42-
) -> Result<Self, crate::Error> {
43-
let builder = ServerConfig::builder();
44-
45-
let builder = match client_ca_root {
46-
None => builder.with_no_client_auth(),
47-
Some(cert) => {
48-
let mut roots = RootCertStore::empty();
49-
add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
50-
let verifier = if client_auth_optional {
51-
WebPkiClientVerifier::builder(roots.into()).allow_unauthenticated()
52-
} else {
53-
WebPkiClientVerifier::builder(roots.into())
54-
}
55-
.build()?;
56-
builder.with_client_cert_verifier(verifier)
57-
}
58-
};
59-
60-
let (cert, key) = load_identity(identity)?;
61-
let mut config = builder.with_single_cert(cert, key)?;
62-
63-
config.alpn_protocols.push(ALPN_H2.into());
64-
Ok(Self {
65-
inner: Arc::new(config),
66-
})
67-
}
68-
69-
pub(crate) async fn accept<IO>(&self, io: IO) -> Result<TlsStream<IO>, crate::Error>
70-
where
71-
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
72-
{
73-
let acceptor = RustlsAcceptor::from(self.inner.clone());
74-
acceptor.accept(io).await.map_err(Into::into)
75-
}
76-
}
77-
78-
impl fmt::Debug for TlsAcceptor {
79-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80-
f.debug_struct("TlsAcceptor").finish()
81-
}
82-
}
83-
8421
impl fmt::Display for TlsError {
8522
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
8623
match self {

0 commit comments

Comments
 (0)
Please sign in to comment.