diff --git a/examples/Cargo.toml b/examples/Cargo.toml index e9c966275..a91704248 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -54,10 +54,18 @@ path = "src/dynamic_load_balance/server.rs" name = "tls-client" path = "src/tls/client.rs" +[[bin]] +name = "tls-client-rustls" +path = "src/tls/client_rustls.rs" + [[bin]] name = "tls-server" path = "src/tls/server.rs" +[[bin]] +name = "tls-server-rustls" +path = "src/tls/server_rustls.rs" + [[bin]] name = "tls-client-auth-server" path = "src/tls_client_auth/server.rs" @@ -214,6 +222,11 @@ tonic-web = { path = "../tonic-web" } # streaming example h2 = "0.3" +tokio-rustls = "*" +hyper-rustls = { version = "0.23", features = ["http2"] } +rustls-pemfile = "*" +tower-http = { version = "0.2", features = ["add-extension"] } + [build-dependencies] tonic-build = { path = "../tonic-build", features = ["prost", "compression"] } diff --git a/examples/src/tls/client_rustls.rs b/examples/src/tls/client_rustls.rs new file mode 100644 index 000000000..f5a7e4d09 --- /dev/null +++ b/examples/src/tls/client_rustls.rs @@ -0,0 +1,85 @@ +//! This examples shows how you can combine `hyper-rustls` and `tonic` to +//! provide a custom `ClientConfig` for the tls configuration. + +pub mod pb { + tonic::include_proto!("/grpc.examples.echo"); +} + +use hyper::{client::HttpConnector, Uri}; +use pb::{echo_client::EchoClient, EchoRequest}; +use tokio_rustls::rustls::{ClientConfig, RootCertStore}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let fd = std::fs::File::open("examples/data/tls/ca.pem")?; + + let mut roots = RootCertStore::empty(); + + let mut buf = std::io::BufReader::new(&fd); + let certs = rustls_pemfile::certs(&mut buf)?; + roots.add_parsable_certificates(&certs); + + let tls = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(roots) + .with_no_client_auth(); + + let mut http = HttpConnector::new(); + http.enforce_http(false); + + // We have to do some wrapping here to map the request type from + // `https://example.com` -> `https://[::1]:50051` because `rustls` + // doesn't accept ip's as `ServerName`. + let connector = tower::ServiceBuilder::new() + .layer_fn(move |s| { + let tls = tls.clone(); + + hyper_rustls::HttpsConnectorBuilder::new() + .with_tls_config(tls) + .https_or_http() + .enable_http2() + .wrap_connector(s) + }) + // Since our cert is signed with `example.com` but we actually want to connect + // to a local server we will override the Uri passed from the `HttpsConnector` + // and map it to the correct `Uri` that will connect us directly to the local server. + .map_request(|_| Uri::from_static("https://[::1]:50051")) + .service(http); + + let client = hyper::Client::builder().build(connector); + + // Hyper expects an absolute `Uri` to allow it to know which server to connect too. + // Currently, tonic's generated code only sets the `path_and_query` section so we + // are going to write a custom tower layer in front of the hyper client to add the + // scheme and authority. + // + // Again, this Uri is `example.com` because our tls certs is signed with this SNI but above + // we actually map this back to `[::1]:50051` before the `Uri` is passed to hyper's `HttpConnector` + // to allow it to correctly establish the tcp connection to the local `tls-server`. + let uri = Uri::from_static("https://example.com"); + let svc = tower::ServiceBuilder::new() + .map_request(move |mut req: http::Request| { + let uri = Uri::builder() + .scheme(uri.scheme().unwrap().clone()) + .authority(uri.authority().unwrap().clone()) + .path_and_query(req.uri().path_and_query().unwrap().clone()) + .build() + .unwrap(); + + *req.uri_mut() = uri; + req + }) + .service(client); + + let mut client = EchoClient::new(svc); + + let request = tonic::Request::new(EchoRequest { + message: "hello".into(), + }); + + let response = client.unary_echo(request).await?; + + println!("RESPONSE={:?}", response); + + Ok(()) +} diff --git a/examples/src/tls/server_rustls.rs b/examples/src/tls/server_rustls.rs new file mode 100644 index 000000000..d8cd2b903 --- /dev/null +++ b/examples/src/tls/server_rustls.rs @@ -0,0 +1,143 @@ +pub mod pb { + tonic::include_proto!("/grpc.examples.echo"); +} + +use futures::Stream; +use hyper::server::conn::Http; +use pb::{EchoRequest, EchoResponse}; +use std::{pin::Pin, sync::Arc}; +use tokio::net::TcpListener; +use tokio_rustls::{ + rustls::{Certificate, PrivateKey, ServerConfig}, + TlsAcceptor, +}; +use tonic::{transport::Server, Request, Response, Status, Streaming}; +use tower_http::ServiceBuilderExt; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let certs = { + let fd = std::fs::File::open("examples/data/tls/server.pem")?; + let mut buf = std::io::BufReader::new(&fd); + rustls_pemfile::certs(&mut buf)? + .into_iter() + .map(Certificate) + .collect() + }; + let key = { + let fd = std::fs::File::open("examples/data/tls/server.key")?; + let mut buf = std::io::BufReader::new(&fd); + rustls_pemfile::pkcs8_private_keys(&mut buf)? + .into_iter() + .map(PrivateKey) + .next() + .unwrap() + + // let key = std::fs::read("examples/data/tls/server.key")?; + // PrivateKey(key) + }; + + let mut tls = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certs, key)?; + tls.alpn_protocols = vec![b"h2".to_vec()]; + + let server = EchoServer::default(); + + let svc = Server::builder() + .add_service(pb::echo_server::EchoServer::new(server)) + .into_service(); + + let mut http = Http::new(); + http.http2_only(true); + + let listener = TcpListener::bind("[::1]:50051").await?; + let tls_acceptor = TlsAcceptor::from(Arc::new(tls)); + + loop { + let (conn, addr) = match listener.accept().await { + Ok(incoming) => incoming, + Err(e) => { + eprintln!("Error accepting connection: {}", e); + continue; + } + }; + + let http = http.clone(); + let tls_acceptor = tls_acceptor.clone(); + let svc = svc.clone(); + + tokio::spawn(async move { + let mut certificates = Vec::new(); + + let conn = tls_acceptor + .accept_with(conn, |info| { + if let Some(certs) = info.peer_certificates() { + for cert in certs { + certificates.push(cert.clone()); + } + } + }) + .await + .unwrap(); + + let svc = tower::ServiceBuilder::new() + .add_extension(Arc::new(ConnInfo { addr, certificates })) + .service(svc); + + http.serve_connection(conn, svc).await.unwrap(); + }); + } +} + +#[derive(Debug)] +struct ConnInfo { + addr: std::net::SocketAddr, + certificates: Vec, +} + +type EchoResult = Result, Status>; +type ResponseStream = Pin> + Send>>; + +#[derive(Default)] +pub struct EchoServer; + +#[tonic::async_trait] +impl pb::echo_server::Echo for EchoServer { + async fn unary_echo(&self, request: Request) -> EchoResult { + let conn_info = request.extensions().get::>().unwrap(); + println!( + "Got a request from: {:?} with certs: {:?}", + conn_info.addr, conn_info.certificates + ); + + let message = request.into_inner().message; + Ok(Response::new(EchoResponse { message })) + } + + type ServerStreamingEchoStream = ResponseStream; + + async fn server_streaming_echo( + &self, + _: Request, + ) -> EchoResult { + Err(Status::unimplemented("not implemented")) + } + + async fn client_streaming_echo( + &self, + _: Request>, + ) -> EchoResult { + Err(Status::unimplemented("not implemented")) + } + + type BidirectionalStreamingEchoStream = ResponseStream; + + async fn bidirectional_streaming_echo( + &self, + _: Request>, + ) -> EchoResult { + Err(Status::unimplemented("not implemented")) + } +}