diff --git a/tests/integration_tests/tests/load_shed.rs b/tests/integration_tests/tests/load_shed.rs
new file mode 100644
index 000000000..746d6a1be
--- /dev/null
+++ b/tests/integration_tests/tests/load_shed.rs
@@ -0,0 +1,61 @@
+use integration_tests::pb::{test_client, test_server, Input, Output};
+use std::net::SocketAddr;
+use tokio::net::TcpListener;
+use tonic::{transport::Server, Code, Request, Response, Status};
+
+#[tokio::test]
+async fn service_resource_exhausted() {
+    let addr = run_service_in_background(0).await;
+
+    let mut client = test_client::TestClient::connect(format!("http://{}", addr))
+        .await
+        .unwrap();
+
+    let req = Request::new(Input {});
+    let res = client.unary_call(req).await;
+
+    let err = res.unwrap_err();
+    assert_eq!(err.code(), Code::ResourceExhausted);
+}
+
+#[tokio::test]
+async fn service_resource_not_exhausted() {
+    let addr = run_service_in_background(1).await;
+
+    let mut client = test_client::TestClient::connect(format!("http://{}", addr))
+        .await
+        .unwrap();
+
+    let req = Request::new(Input {});
+    let res = client.unary_call(req).await;
+
+    assert!(res.is_ok());
+}
+
+async fn run_service_in_background(concurrency_limit: usize) -> SocketAddr {
+    struct Svc;
+
+    #[tonic::async_trait]
+    impl test_server::Test for Svc {
+        async fn unary_call(&self, _req: Request<Input>) -> Result<Response<Output>, Status> {
+            Ok(Response::new(Output {}))
+        }
+    }
+
+    let svc = test_server::TestServer::new(Svc {});
+
+    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let addr = listener.local_addr().unwrap();
+
+    tokio::spawn(async move {
+        Server::builder()
+            .concurrency_limit_per_connection(concurrency_limit)
+            .load_shed(true)
+            .add_service(svc)
+            .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
+            .await
+            .unwrap();
+    });
+
+    addr
+}
diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml
index a9dca6f24..d1aa14bb6 100644
--- a/tonic/Cargo.toml
+++ b/tonic/Cargo.toml
@@ -39,12 +39,12 @@ server = [
     "dep:socket2",
     "dep:tokio", "tokio?/macros", "tokio?/net", "tokio?/time",
     "tokio-stream/net",
-    "dep:tower", "tower?/util", "tower?/limit",
+    "dep:tower", "tower?/util", "tower?/limit", "tower?/load-shed",
 ]
 channel = [
     "dep:hyper", "hyper?/client",
     "dep:hyper-util", "hyper-util?/client-legacy",
-    "dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/limit", "tower?/util",
+    "dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/limit", "tower?/load-shed", "tower?/util",
     "dep:tokio", "tokio?/time",
     "dep:hyper-timeout",
 ]
diff --git a/tonic/src/status.rs b/tonic/src/status.rs
index 374b566a0..8ea2a5ec6 100644
--- a/tonic/src/status.rs
+++ b/tonic/src/status.rs
@@ -348,6 +348,18 @@ impl Status {
             Err(err) => err,
         };
 
+        // If the load shed middleware is enabled, respond to
+        // a service overload with an appropriate gRPC status.
+        #[cfg(feature = "server")]
+        let err = match err.downcast::<tower::load_shed::error::Overloaded>() {
+            Ok(_) => {
+                return Ok(Status::resource_exhausted(
+                    "Too many active requests for the connection",
+                ));
+            }
+            Err(err) => err,
+        };
+
         if let Some(mut status) = find_status_in_source_chain(&*err) {
             status.source = Some(err.into());
             return Ok(status);
diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs
index ec4a884b8..c5684e72d 100644
--- a/tonic/src/transport/server/mod.rs
+++ b/tonic/src/transport/server/mod.rs
@@ -66,6 +66,7 @@ use tower::{
     layer::util::{Identity, Stack},
     layer::Layer,
     limit::concurrency::ConcurrencyLimitLayer,
+    load_shed::LoadShedLayer,
     util::BoxCloneService,
     Service, ServiceBuilder, ServiceExt,
 };
@@ -87,6 +88,7 @@ const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(20);
 pub struct Server<L = Identity> {
     trace_interceptor: Option<TraceInterceptor>,
     concurrency_limit: Option<usize>,
+    load_shed: bool,
     timeout: Option<Duration>,
     #[cfg(feature = "_tls-any")]
     tls: Option<TlsAcceptor>,
@@ -111,6 +113,7 @@ impl Default for Server {
         Self {
             trace_interceptor: None,
             concurrency_limit: None,
+            load_shed: false,
             timeout: None,
             #[cfg(feature = "_tls-any")]
             tls: None,
@@ -179,6 +182,27 @@ impl<L> Server<L> {
         }
     }
 
+    /// Enable or disable load shedding. The default is disabled.
+    ///
+    /// When load shedding is enabled, if the service responds with not ready,
+    /// the request will immediately be rejected with a
+    /// [`resource_exhausted`](https://docs.rs/tonic/latest/tonic/struct.Status.html#method.resource_exhausted) error.
+    /// The default is to buffer requests. This is especially useful in combination with
+    /// setting a concurrency limit per connection.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// # use tonic::transport::Server;
+    /// # use tower_service::Service;
+    /// # let builder = Server::builder();
+    /// builder.load_shed(true);
+    /// ```
+    #[must_use]
+    pub fn load_shed(self, load_shed: bool) -> Self {
+        Server { load_shed, ..self }
+    }
+
     /// Set a timeout on for all request handlers.
     ///
     /// # Example
@@ -514,6 +538,7 @@ impl<L> Server<L> {
             service_builder: self.service_builder.layer(new_layer),
             trace_interceptor: self.trace_interceptor,
             concurrency_limit: self.concurrency_limit,
+            load_shed: self.load_shed,
             timeout: self.timeout,
             #[cfg(feature = "_tls-any")]
             tls: self.tls,
@@ -643,6 +668,7 @@ impl<L> Server<L> {
     {
         let trace_interceptor = self.trace_interceptor.clone();
         let concurrency_limit = self.concurrency_limit;
+        let load_shed = self.load_shed;
        let init_connection_window_size = self.init_connection_window_size;
         let init_stream_window_size = self.init_stream_window_size;
         let max_concurrent_streams = self.max_concurrent_streams;
@@ -667,6 +693,7 @@ impl<L> Server<L> {
         let mut svc = MakeSvc {
             inner: svc,
             concurrency_limit,
+            load_shed,
             timeout,
             trace_interceptor,
             _io: PhantomData,
@@ -1047,6 +1074,7 @@ impl<S> fmt::Debug for Svc<S> {
 #[derive(Clone)]
 struct MakeSvc<S, IO> {
     concurrency_limit: Option<usize>,
+    load_shed: bool,
     timeout: Option<Duration>,
     inner: S,
     trace_interceptor: Option<TraceInterceptor>,
@@ -1080,6 +1108,7 @@ where
         let svc = ServiceBuilder::new()
             .layer(RecoverErrorLayer::new())
+            .option_layer(self.load_shed.then_some(LoadShedLayer::new()))
             .option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
             .layer_fn(|s| GrpcTimeout::new(s, timeout))
             .service(svc);
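For reviewers, a client-side sketch of how the new behaviour surfaces (not part of this diff): once `load_shed(true)` is enabled and the per-connection concurrency limit is hit, shed requests reach the client as `Code::ResourceExhausted`, as the tests above assert. The `call_with_retry` helper below is purely illustrative and not part of tonic's API; it assumes the caller wants to retry shed requests with a small backoff.

```rust
use std::time::Duration;

use tonic::{Code, Status};

/// Hypothetical helper (not part of tonic): retry an RPC a few times when the
/// server sheds load and answers with `ResourceExhausted`.
async fn call_with_retry<T, F, Fut>(mut call: F, attempts: u32) -> Result<T, Status>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, Status>>,
{
    let mut last = Status::unknown("no attempts were made");
    for attempt in 0..attempts {
        match call().await {
            // Load-shed rejections surface as ResourceExhausted; back off briefly and retry.
            Err(status) if status.code() == Code::ResourceExhausted => {
                last = status;
                tokio::time::sleep(Duration::from_millis(50 * u64::from(attempt + 1))).await;
            }
            // Success or any other error is returned to the caller as-is.
            other => return other,
        }
    }
    Err(last)
}
```

`ResourceExhausted` is the gRPC code conventionally used for quota and overload conditions and is generally treated as retryable, which is why the `status.rs` change maps tower's `Overloaded` error to it instead of letting it surface as an opaque transport error.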