diff --git a/tests/integration_tests/tests/load_shed.rs b/tests/integration_tests/tests/load_shed.rs
new file mode 100644
index 000000000..746d6a1be
--- /dev/null
+++ b/tests/integration_tests/tests/load_shed.rs
@@ -0,0 +1,61 @@
+use integration_tests::pb::{test_client, test_server, Input, Output};
+use std::net::SocketAddr;
+use tokio::net::TcpListener;
+use tonic::{transport::Server, Code, Request, Response, Status};
+
+#[tokio::test]
+async fn service_resource_exhausted() {
+ let addr = run_service_in_background(0).await;
+
+ let mut client = test_client::TestClient::connect(format!("http://{}", addr))
+ .await
+ .unwrap();
+
+ let req = Request::new(Input {});
+ let res = client.unary_call(req).await;
+
+ let err = res.unwrap_err();
+ assert_eq!(err.code(), Code::ResourceExhausted);
+}
+
+#[tokio::test]
+async fn service_resource_not_exhausted() {
+ let addr = run_service_in_background(1).await;
+
+ let mut client = test_client::TestClient::connect(format!("http://{}", addr))
+ .await
+ .unwrap();
+
+ let req = Request::new(Input {});
+ let res = client.unary_call(req).await;
+
+ assert!(res.is_ok());
+}
+
+async fn run_service_in_background(concurrency_limit: usize) -> SocketAddr {
+ struct Svc;
+
+ #[tonic::async_trait]
+ impl test_server::Test for Svc {
+ async fn unary_call(&self, _req: Request<Input>) -> Result<Response<Output>, Status> {
+ Ok(Response::new(Output {}))
+ }
+ }
+
+ let svc = test_server::TestServer::new(Svc {});
+
+ let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+ let addr = listener.local_addr().unwrap();
+
+ tokio::spawn(async move {
+ Server::builder()
+ .concurrency_limit_per_connection(concurrency_limit)
+ .load_shed(true)
+ .add_service(svc)
+ .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
+ .await
+ .unwrap();
+ });
+
+ addr
+}
diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml
index a9dca6f24..d1aa14bb6 100644
--- a/tonic/Cargo.toml
+++ b/tonic/Cargo.toml
@@ -39,12 +39,12 @@ server = [
"dep:socket2",
"dep:tokio", "tokio?/macros", "tokio?/net", "tokio?/time",
"tokio-stream/net",
- "dep:tower", "tower?/util", "tower?/limit",
+ "dep:tower", "tower?/util", "tower?/limit", "tower?/load-shed",
]
channel = [
"dep:hyper", "hyper?/client",
"dep:hyper-util", "hyper-util?/client-legacy",
- "dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/limit", "tower?/util",
+ "dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/limit", "tower?/load-shed", "tower?/util",
"dep:tokio", "tokio?/time",
"dep:hyper-timeout",
]
diff --git a/tonic/src/status.rs b/tonic/src/status.rs
index 374b566a0..8ea2a5ec6 100644
--- a/tonic/src/status.rs
+++ b/tonic/src/status.rs
@@ -348,6 +348,18 @@ impl Status {
Err(err) => err,
};
+ // If the load shed middleware is enabled, respond to an overloaded
+ // service with an appropriate gRPC status.
+ #[cfg(feature = "server")]
+ let err = match err.downcast::<tower::load_shed::error::Overloaded>() {
+ Ok(_) => {
+ return Ok(Status::resource_exhausted(
+ "Too many active requests for the connection",
+ ));
+ }
+ Err(err) => err,
+ };
+
if let Some(mut status) = find_status_in_source_chain(&*err) {
status.source = Some(err.into());
return Ok(status);
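With the mapping above in place, a shed request reaches the caller as an ordinary `Status` whose code is `Code::ResourceExhausted`, rather than as a transport-level error. The sketch below shows one way a caller could react to that status; it reuses the generated `test_client` and `Input` types from the integration test earlier in this diff (and assumes tokio with the `time` feature), and the retry count, delay, and `call_with_retry` helper are illustrative assumptions, not part of this change.

```rust
use integration_tests::pb::{test_client::TestClient, Input};
use std::time::Duration;
use tonic::{transport::Channel, Code, Status};

/// Retry a unary call a few times when the server sheds load.
/// Three attempts and a 50 ms pause are arbitrary values for illustration.
async fn call_with_retry(client: &mut TestClient<Channel>) -> Result<(), Status> {
    for _ in 0..3 {
        match client.unary_call(Input {}).await {
            Ok(_) => return Ok(()),
            // The new mapping: an overloaded connection now surfaces as
            // `ResourceExhausted`, so back off briefly and try again.
            Err(status) if status.code() == Code::ResourceExhausted => {
                tokio::time::sleep(Duration::from_millis(50)).await;
            }
            // Any other status is propagated unchanged.
            Err(status) => return Err(status),
        }
    }
    Err(Status::resource_exhausted("server still overloaded after retries"))
}
```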
diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs
index ec4a884b8..c5684e72d 100644
--- a/tonic/src/transport/server/mod.rs
+++ b/tonic/src/transport/server/mod.rs
@@ -66,6 +66,7 @@ use tower::{
layer::util::{Identity, Stack},
layer::Layer,
limit::concurrency::ConcurrencyLimitLayer,
+ load_shed::LoadShedLayer,
util::BoxCloneService,
Service, ServiceBuilder, ServiceExt,
};
@@ -87,6 +88,7 @@ const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(20);
pub struct Server<L = Identity> {
trace_interceptor: Option<TraceInterceptor>,
concurrency_limit: Option<usize>,
+ load_shed: bool,
timeout: Option<Duration>,
#[cfg(feature = "_tls-any")]
tls: Option<TlsAcceptor>,
@@ -111,6 +113,7 @@ impl Default for Server<Identity> {
Self {
trace_interceptor: None,
concurrency_limit: None,
+ load_shed: false,
timeout: None,
#[cfg(feature = "_tls-any")]
tls: None,
@@ -179,6 +182,27 @@ impl<L> Server<L> {
}
}
+ /// Enable or disable load shedding. The default is disabled.
+ ///
+ /// When load shedding is enabled, a request that arrives while the service
+ /// is not ready is immediately rejected with a
+ /// [`resource_exhausted`](https://docs.rs/tonic/latest/tonic/struct.Status.html#method.resource_exhausted)
+ /// error. Without load shedding, such requests are buffered until the
+ /// service becomes ready. This is especially useful in combination with a
+ /// per-connection concurrency limit.
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// # use tonic::transport::Server;
+ /// # use tower_service::Service;
+ /// # let builder = Server::builder();
+ /// builder.load_shed(true);
+ /// ```
+ #[must_use]
+ pub fn load_shed(self, load_shed: bool) -> Self {
+ Server { load_shed, ..self }
+ }
+
/// Set a timeout on for all request handlers.
///
/// # Example
@@ -514,6 +538,7 @@ impl<L> Server<L> {
service_builder: self.service_builder.layer(new_layer),
trace_interceptor: self.trace_interceptor,
concurrency_limit: self.concurrency_limit,
+ load_shed: self.load_shed,
timeout: self.timeout,
#[cfg(feature = "_tls-any")]
tls: self.tls,
@@ -643,6 +668,7 @@ impl<L> Server<L> {
{
let trace_interceptor = self.trace_interceptor.clone();
let concurrency_limit = self.concurrency_limit;
+ let load_shed = self.load_shed;
let init_connection_window_size = self.init_connection_window_size;
let init_stream_window_size = self.init_stream_window_size;
let max_concurrent_streams = self.max_concurrent_streams;
@@ -667,6 +693,7 @@ impl Server {
let mut svc = MakeSvc {
inner: svc,
concurrency_limit,
+ load_shed,
timeout,
trace_interceptor,
_io: PhantomData,
@@ -1047,6 +1074,7 @@ impl<S> fmt::Debug for Svc<S> {
#[derive(Clone)]
struct MakeSvc<S, IO> {
concurrency_limit: Option<usize>,
+ load_shed: bool,
timeout: Option<Duration>,
inner: S,
trace_interceptor: Option<TraceInterceptor>,
@@ -1080,6 +1108,7 @@ where
let svc = ServiceBuilder::new()
.layer(RecoverErrorLayer::new())
+ .option_layer(self.load_shed.then_some(LoadShedLayer::new()))
.option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
.layer_fn(|s| GrpcTimeout::new(s, timeout))
.service(svc);
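For context on the service stack above: the new `load_shed` flag simply places tower's `LoadShedLayer` in front of the optional `ConcurrencyLimitLayer`, so a request that finds the connection's limit exhausted is rejected immediately instead of waiting for capacity. The standalone sketch below reproduces that composition with tower directly to show the mechanism; it assumes a binary depending on tokio (`rt`, `macros`, `time`) and tower with the `load-shed`, `limit`, and `util` features, and is an illustration rather than tonic's internal code.

```rust
use std::time::Duration;
use tower::{load_shed::error::Overloaded, service_fn, Service, ServiceBuilder, ServiceExt};

#[tokio::main]
async fn main() {
    // Same ordering as in `MakeSvc` above: LoadShed wraps ConcurrencyLimit.
    let mut svc = ServiceBuilder::new()
        .load_shed()           // fail fast instead of waiting when not ready
        .concurrency_limit(1)  // a single in-flight request
        .service(service_fn(|_req: ()| async {
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok::<_, std::convert::Infallible>("done")
        }));

    // First request: readiness acquires the single concurrency permit, and
    // the permit stays held until the response future completes.
    let first = svc.ready().await.unwrap().call(());

    // Second request while the permit is held: the concurrency limiter is
    // not ready, so LoadShed rejects the call with an `Overloaded` error
    // instead of buffering it.
    let second = svc.ready().await.unwrap().call(());

    assert!(second.await.unwrap_err().is::<Overloaded>());
    assert_eq!(first.await.unwrap(), "done");
}
```

Inside a tonic server, that same `Overloaded` error is what the status.rs change above converts into `Code::ResourceExhausted`, which is the behavior the `service_resource_exhausted` integration test asserts end to end: with `concurrency_limit_per_connection(0)` the service is never ready, so every call is shed.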