6060import io .netty .channel .ChannelFutureListener ;
6161import io .netty .channel .ChannelHandlerContext ;
6262import io .netty .channel .ChannelPromise ;
63+ import io .netty .handler .codec .http2 .DecoratingHttp2ConnectionEncoder ;
6364import io .netty .handler .codec .http2 .DecoratingHttp2FrameWriter ;
6465import io .netty .handler .codec .http2 .DefaultHttp2Connection ;
6566import io .netty .handler .codec .http2 .DefaultHttp2ConnectionDecoder ;
8384import io .netty .handler .codec .http2 .Http2Headers ;
8485import io .netty .handler .codec .http2 .Http2HeadersDecoder ;
8586import io .netty .handler .codec .http2 .Http2InboundFrameLogger ;
87+ import io .netty .handler .codec .http2 .Http2LifecycleManager ;
8688import io .netty .handler .codec .http2 .Http2OutboundFrameLogger ;
8789import io .netty .handler .codec .http2 .Http2Settings ;
8890import io .netty .handler .codec .http2 .Http2Stream ;
@@ -125,13 +127,11 @@ class NettyServerHandler extends AbstractNettyHandler {
125127 private final long keepAliveTimeoutInNanos ;
126128 private final long maxConnectionAgeInNanos ;
127129 private final long maxConnectionAgeGraceInNanos ;
128- private final int maxRstCount ;
129- private final long maxRstPeriodNanos ;
130+ private final RstStreamCounter rstStreamCounter ;
130131 private final List <? extends ServerStreamTracer .Factory > streamTracerFactories ;
131132 private final TransportTracer transportTracer ;
132133 private final KeepAliveEnforcer keepAliveEnforcer ;
133134 private final Attributes eagAttributes ;
134- private final Ticker ticker ;
135135 /** Incomplete attributes produced by negotiator. */
136136 private Attributes negotiationAttributes ;
137137 private InternalChannelz .Security securityInfo ;
@@ -149,8 +149,6 @@ class NettyServerHandler extends AbstractNettyHandler {
149149 private ScheduledFuture <?> maxConnectionAgeMonitor ;
150150 @ CheckForNull
151151 private GracefulShutdown gracefulShutdown ;
152- private int rstCount ;
153- private long lastRstNanoTime ;
154152
155153 static NettyServerHandler newHandler (
156154 ServerTransportListener transportListener ,
@@ -251,13 +249,20 @@ static NettyServerHandler newHandler(
251249 final KeepAliveEnforcer keepAliveEnforcer = new KeepAliveEnforcer (
252250 permitKeepAliveWithoutCalls , permitKeepAliveTimeInNanos , TimeUnit .NANOSECONDS );
253251
252+ if (ticker == null ) {
253+ ticker = Ticker .systemTicker ();
254+ }
255+
256+ RstStreamCounter rstStreamCounter
257+ = new RstStreamCounter (maxRstCount , maxRstPeriodNanos , ticker );
254258 // Create the local flow controller configured to auto-refill the connection window.
255259 connection .local ().flowController (
256260 new DefaultHttp2LocalFlowController (connection , DEFAULT_WINDOW_UPDATE_RATIO , true ));
257261 frameWriter = new WriteMonitoringFrameWriter (frameWriter , keepAliveEnforcer );
258262 Http2ConnectionEncoder encoder =
259263 new DefaultHttp2ConnectionEncoder (connection , frameWriter );
260264 encoder = new Http2ControlFrameLimitEncoder (encoder , 10000 );
265+ encoder = new Http2RstCounterEncoder (encoder , rstStreamCounter );
261266 Http2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder (connection , encoder ,
262267 frameReader );
263268
@@ -266,10 +271,6 @@ static NettyServerHandler newHandler(
266271 settings .maxConcurrentStreams (maxStreams );
267272 settings .maxHeaderListSize (maxHeaderListSize );
268273
269- if (ticker == null ) {
270- ticker = Ticker .systemTicker ();
271- }
272-
273274 return new NettyServerHandler (
274275 channelUnused ,
275276 connection ,
@@ -286,8 +287,7 @@ static NettyServerHandler newHandler(
286287 maxConnectionAgeInNanos , maxConnectionAgeGraceInNanos ,
287288 keepAliveEnforcer ,
288289 autoFlowControl ,
289- maxRstCount ,
290- maxRstPeriodNanos ,
290+ rstStreamCounter ,
291291 eagAttributes , ticker );
292292 }
293293
@@ -310,8 +310,7 @@ private NettyServerHandler(
310310 long maxConnectionAgeGraceInNanos ,
311311 final KeepAliveEnforcer keepAliveEnforcer ,
312312 boolean autoFlowControl ,
313- int maxRstCount ,
314- long maxRstPeriodNanos ,
313+ RstStreamCounter rstStreamCounter ,
315314 Attributes eagAttributes ,
316315 Ticker ticker ) {
317316 super (
@@ -363,12 +362,9 @@ public void onStreamClosed(Http2Stream stream) {
363362 this .maxConnectionAgeInNanos = maxConnectionAgeInNanos ;
364363 this .maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos ;
365364 this .keepAliveEnforcer = checkNotNull (keepAliveEnforcer , "keepAliveEnforcer" );
366- this .maxRstCount = maxRstCount ;
367- this .maxRstPeriodNanos = maxRstPeriodNanos ;
365+ this .rstStreamCounter = rstStreamCounter ;
368366 this .eagAttributes = checkNotNull (eagAttributes , "eagAttributes" );
369- this .ticker = checkNotNull (ticker , "ticker" );
370367
371- this .lastRstNanoTime = ticker .read ();
372368 streamKey = encoder .connection ().newKey ();
373369 this .transportListener = checkNotNull (transportListener , "transportListener" );
374370 this .streamTracerFactories = checkNotNull (streamTracerFactories , "streamTracerFactories" );
@@ -575,24 +571,9 @@ private void onDataRead(int streamId, ByteBuf data, int padding, boolean endOfSt
575571 }
576572
577573 private void onRstStreamRead (int streamId , long errorCode ) throws Http2Exception {
578- if (maxRstCount > 0 ) {
579- long now = ticker .read ();
580- if (now - lastRstNanoTime > maxRstPeriodNanos ) {
581- lastRstNanoTime = now ;
582- rstCount = 1 ;
583- } else {
584- rstCount ++;
585- if (rstCount > maxRstCount ) {
586- throw new Http2Exception (Http2Error .ENHANCE_YOUR_CALM , "too_many_rststreams" ) {
587- @ SuppressWarnings ("UnsynchronizedOverridesSynchronized" ) // No memory accesses
588- @ Override
589- public Throwable fillInStackTrace () {
590- // Avoid the CPU cycles, since the resets may be a CPU consumption attack
591- return this ;
592- }
593- };
594- }
595- }
574+ Http2Exception tooManyRstStream = rstStreamCounter .countRstStream ();
575+ if (tooManyRstStream != null ) {
576+ throw tooManyRstStream ;
596577 }
597578
598579 try {
@@ -1180,6 +1161,81 @@ public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2
11801161 }
11811162 }
11821163
1164+ private static final class Http2RstCounterEncoder extends DecoratingHttp2ConnectionEncoder {
1165+ private final RstStreamCounter rstStreamCounter ;
1166+ private Http2LifecycleManager lifecycleManager ;
1167+
1168+ Http2RstCounterEncoder (Http2ConnectionEncoder encoder , RstStreamCounter rstStreamCounter ) {
1169+ super (encoder );
1170+ this .rstStreamCounter = rstStreamCounter ;
1171+ }
1172+
1173+ @ Override
1174+ public void lifecycleManager (Http2LifecycleManager lifecycleManager ) {
1175+ this .lifecycleManager = lifecycleManager ;
1176+ super .lifecycleManager (lifecycleManager );
1177+ }
1178+
1179+ @ Override
1180+ public ChannelFuture writeRstStream (
1181+ ChannelHandlerContext ctx , int streamId , long errorCode , ChannelPromise promise ) {
1182+ ChannelFuture future = super .writeRstStream (ctx , streamId , errorCode , promise );
1183+ // We want to count "induced" RST_STREAM, where the server sent a reset because of a malformed
1184+ // frame.
1185+ boolean normalRst
1186+ = errorCode == Http2Error .NO_ERROR .code () || errorCode == Http2Error .CANCEL .code ();
1187+ if (!normalRst ) {
1188+ Http2Exception tooManyRstStream = rstStreamCounter .countRstStream ();
1189+ if (tooManyRstStream != null ) {
1190+ lifecycleManager .onError (ctx , true , tooManyRstStream );
1191+ ctx .close ();
1192+ }
1193+ }
1194+ return future ;
1195+ }
1196+ }
1197+
1198+ private static final class RstStreamCounter {
1199+ private final int maxRstCount ;
1200+ private final long maxRstPeriodNanos ;
1201+ private final Ticker ticker ;
1202+ private int rstCount ;
1203+ private long lastRstNanoTime ;
1204+
1205+ RstStreamCounter (int maxRstCount , long maxRstPeriodNanos , Ticker ticker ) {
1206+ checkArgument (maxRstCount >= 0 , "maxRstCount must be non-negative: %s" , maxRstCount );
1207+ this .maxRstCount = maxRstCount ;
1208+ this .maxRstPeriodNanos = maxRstPeriodNanos ;
1209+ this .ticker = checkNotNull (ticker , "ticker" );
1210+ this .lastRstNanoTime = ticker .read ();
1211+ }
1212+
1213+ /** Returns non-{@code null} when the connection should be killed by the caller. */
1214+ private Http2Exception countRstStream () {
1215+ if (maxRstCount == 0 ) {
1216+ return null ;
1217+ }
1218+ long now = ticker .read ();
1219+ if (now - lastRstNanoTime > maxRstPeriodNanos ) {
1220+ lastRstNanoTime = now ;
1221+ rstCount = 1 ;
1222+ } else {
1223+ rstCount ++;
1224+ if (rstCount > maxRstCount ) {
1225+ return new Http2Exception (Http2Error .ENHANCE_YOUR_CALM , "too_many_rststreams" ) {
1226+ @ SuppressWarnings ("UnsynchronizedOverridesSynchronized" ) // No memory accesses
1227+ @ Override
1228+ public Throwable fillInStackTrace () {
1229+ // Avoid the CPU cycles, since the resets may be a CPU consumption attack
1230+ return this ;
1231+ }
1232+ };
1233+ }
1234+ }
1235+ return null ;
1236+ }
1237+ }
1238+
11831239 private static class ServerChannelLogger extends ChannelLogger {
11841240 private static final Logger log = Logger .getLogger (ChannelLogger .class .getName ());
11851241
0 commit comments