Skip to content

Commit b75e55b

Browse files
authored
[HTTP/3] Abort response stream on dispose if content not finished (#57156)
* Sends abort read/write if H/3 stream is disposed before respective contents are finsihed * Minor tweaks in abort conditions * Prevent reverting SendState from Aborted/ConnectionClosed back to sending state within Send* methods.
1 parent 9a55354 commit b75e55b

File tree

4 files changed

+175
-49
lines changed

4 files changed

+175
-49
lines changed

src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ internal sealed class Http3LoopbackStream : IDisposable
2020
private const int MaximumVarIntBytes = 8;
2121
private const long VarIntMax = (1L << 62) - 1;
2222

23-
private const long DataFrame = 0x0;
24-
private const long HeadersFrame = 0x1;
25-
private const long SettingsFrame = 0x4;
26-
private const long GoAwayFrame = 0x7;
23+
public const long DataFrame = 0x0;
24+
public const long HeadersFrame = 0x1;
25+
public const long SettingsFrame = 0x4;
26+
public const long GoAwayFrame = 0x7;
2727

2828
public const long ControlStream = 0x0;
2929
public const long PushStream = 0x1;

src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ public void Dispose()
8484
if (!_disposed)
8585
{
8686
_disposed = true;
87+
AbortStream();
8788
_stream.Dispose();
8889
DisposeSyncHelper();
8990
}
@@ -94,6 +95,7 @@ public async ValueTask DisposeAsync()
9495
if (!_disposed)
9596
{
9697
_disposed = true;
98+
AbortStream();
9799
await _stream.DisposeAsync().ConfigureAwait(false);
98100
DisposeSyncHelper();
99101
}
@@ -358,6 +360,9 @@ private async Task SendContentAsync(HttpContent content, CancellationToken cance
358360
await content.CopyToAsync(writeStream, null, cancellationToken).ConfigureAwait(false);
359361
}
360362

363+
// Set to 0 to recognize that the whole request body has been sent and therefore there's no need to abort write side in case of a premature disposal.
364+
_requestContentLengthRemaining = 0;
365+
361366
if (_sendBuffer.ActiveLength != 0)
362367
{
363368
// Our initial send buffer, which has our headers, is normally sent out on the first write to the Http3WriteStream.
@@ -1210,6 +1215,20 @@ private async ValueTask<bool> ReadNextDataFrameAsync(HttpResponseMessage respons
12101215
public void Trace(string message, [CallerMemberName] string? memberName = null) =>
12111216
_connection.Trace(StreamId, message, memberName);
12121217

1218+
private void AbortStream()
1219+
{
1220+
// If the request body isn't completed, cancel it now.
1221+
if (_requestContentLengthRemaining != 0) // 0 is used for the end of content writing, -1 is used for unknown Content-Length
1222+
{
1223+
_stream.AbortWrite((long)Http3ErrorCode.RequestCancelled);
1224+
}
1225+
// If the response body isn't completed, cancel it now.
1226+
if (_responseDataPayloadRemaining != -1) // -1 is used for EOF, 0 for consumed DATA frame payload before the next read
1227+
{
1228+
_stream.AbortRead((long)Http3ErrorCode.RequestCancelled);
1229+
}
1230+
}
1231+
12131232
// TODO: it may be possible for Http3RequestStream to implement Stream directly and avoid this allocation.
12141233
private sealed class Http3ReadStream : HttpBaseStream
12151234
{
@@ -1233,36 +1252,42 @@ public Http3ReadStream(Http3RequestStream stream)
12331252

12341253
protected override void Dispose(bool disposing)
12351254
{
1236-
if (_stream != null)
1255+
Http3RequestStream? stream = Interlocked.Exchange(ref _stream, null);
1256+
if (stream is null)
12371257
{
1238-
if (disposing)
1239-
{
1240-
// This will remove the stream from the connection properly.
1241-
_stream.Dispose();
1242-
}
1243-
else
1244-
{
1245-
// We shouldn't be using a managed instance here, but don't have much choice -- we
1246-
// need to remove the stream from the connection's GOAWAY collection.
1247-
_stream._connection.RemoveStream(_stream._stream);
1248-
_stream._connection = null!;
1249-
}
1258+
return;
1259+
}
12501260

1251-
_stream = null;
1252-
_response = null;
1261+
if (disposing)
1262+
{
1263+
// This will remove the stream from the connection properly.
1264+
stream.Dispose();
1265+
}
1266+
else
1267+
{
1268+
// We shouldn't be using a managed instance here, but don't have much choice -- we
1269+
// need to remove the stream from the connection's GOAWAY collection.
1270+
stream._connection.RemoveStream(stream._stream);
1271+
stream._connection = null!;
12531272
}
12541273

1274+
_response = null;
1275+
12551276
base.Dispose(disposing);
12561277
}
12571278

12581279
public override async ValueTask DisposeAsync()
12591280
{
1260-
if (_stream != null)
1281+
Http3RequestStream? stream = Interlocked.Exchange(ref _stream, null);
1282+
if (stream is null)
12611283
{
1262-
await _stream.DisposeAsync().ConfigureAwait(false);
1263-
_stream = null!;
1284+
return;
12641285
}
12651286

1287+
await stream.DisposeAsync().ConfigureAwait(false);
1288+
1289+
_response = null;
1290+
12661291
await base.DisposeAsync().ConfigureAwait(false);
12671292
}
12681293

src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,119 @@ public async Task ReservedFrameType_Throws()
319319
await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
320320
}
321321

322+
[Fact]
323+
public async Task RequestSentResponseDisposed_ThrowsOnServer()
324+
{
325+
byte[] data = Encoding.UTF8.GetBytes(new string('a', 1024));
326+
327+
using Http3LoopbackServer server = CreateHttp3LoopbackServer();
328+
329+
Task serverTask = Task.Run(async () =>
330+
{
331+
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
332+
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
333+
HttpRequestData request = await stream.ReadRequestDataAsync();
334+
await stream.SendResponseHeadersAsync();
335+
336+
Stopwatch sw = Stopwatch.StartNew();
337+
bool hasFailed = false;
338+
while (sw.Elapsed < TimeSpan.FromSeconds(15))
339+
{
340+
try
341+
{
342+
await stream.SendResponseBodyAsync(data, isFinal: false);
343+
}
344+
catch (QuicStreamAbortedException)
345+
{
346+
hasFailed = true;
347+
break;
348+
}
349+
}
350+
Assert.True(hasFailed, $"Expected {nameof(QuicStreamAbortedException)}, instead ran successfully for {sw.Elapsed}");
351+
});
352+
353+
Task clientTask = Task.Run(async () =>
354+
{
355+
using HttpClient client = CreateHttpClient();
356+
using HttpRequestMessage request = new()
357+
{
358+
Method = HttpMethod.Get,
359+
RequestUri = server.Address,
360+
Version = HttpVersion30,
361+
VersionPolicy = HttpVersionPolicy.RequestVersionExact
362+
};
363+
364+
var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
365+
var stream = await response.Content.ReadAsStreamAsync();
366+
byte[] buffer = new byte[512];
367+
for (int i = 0; i < 5; ++i)
368+
{
369+
var count = await stream.ReadAsync(buffer);
370+
}
371+
372+
// We haven't finished reading the whole respose, but we're disposing it, which should turn into an exception on the server-side.
373+
response.Dispose();
374+
await serverTask;
375+
});
376+
377+
await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
378+
}
379+
380+
[Fact]
381+
public async Task RequestSendingResponseDisposed_ThrowsOnServer()
382+
{
383+
byte[] data = Encoding.UTF8.GetBytes(new string('a', 1024));
384+
385+
using Http3LoopbackServer server = CreateHttp3LoopbackServer();
386+
387+
Task serverTask = Task.Run(async () =>
388+
{
389+
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
390+
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
391+
HttpRequestData request = await stream.ReadRequestDataAsync(false);
392+
await stream.SendResponseHeadersAsync();
393+
394+
Stopwatch sw = Stopwatch.StartNew();
395+
bool hasFailed = false;
396+
while (sw.Elapsed < TimeSpan.FromSeconds(15))
397+
{
398+
try
399+
{
400+
var (frameType, payload) = await stream.ReadFrameAsync();
401+
Assert.Equal(Http3LoopbackStream.DataFrame, frameType);
402+
}
403+
catch (QuicStreamAbortedException)
404+
{
405+
hasFailed = true;
406+
break;
407+
}
408+
}
409+
Assert.True(hasFailed, $"Expected {nameof(QuicStreamAbortedException)}, instead ran successfully for {sw.Elapsed}");
410+
});
411+
412+
Task clientTask = Task.Run(async () =>
413+
{
414+
using HttpClient client = CreateHttpClient();
415+
using HttpRequestMessage request = new()
416+
{
417+
Method = HttpMethod.Get,
418+
RequestUri = server.Address,
419+
Version = HttpVersion30,
420+
VersionPolicy = HttpVersionPolicy.RequestVersionExact,
421+
Content = new ByteAtATimeContent(60*4, Task.CompletedTask, new TaskCompletionSource<bool>(), 250)
422+
};
423+
424+
var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
425+
var stream = await response.Content.ReadAsStreamAsync();
426+
427+
// We haven't finished sending the whole request, but we're disposing the response, which should turn into an exception on the server-side.
428+
response.Dispose();
429+
await serverTask;
430+
});
431+
432+
await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
433+
}
434+
322435
[Fact]
323436
public async Task ServerCertificateCustomValidationCallback_Succeeds()
324437
{
@@ -885,7 +998,7 @@ public async Task StatusCodes_ReceiveSuccess(HttpStatusCode statusCode, bool qpa
885998
VersionPolicy = HttpVersionPolicy.RequestVersionExact
886999
};
8871000
HttpResponseMessage response = await client.SendAsync(request).WaitAsync(TimeSpan.FromSeconds(10));
888-
1001+
8891002
Assert.Equal(statusCode, response.StatusCode);
8901003

8911004
await serverTask;

src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ internal override async ValueTask WriteAsync(ReadOnlySequence<byte> buffers, boo
265265
{
266266
ThrowIfDisposed();
267267

268-
using CancellationTokenRegistration registration = HandleWriteStartState(cancellationToken);
268+
using CancellationTokenRegistration registration = HandleWriteStartState(buffers.IsEmpty, cancellationToken);
269269

270270
await SendReadOnlySequenceAsync(buffers, endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE).ConfigureAwait(false);
271271

@@ -281,7 +281,7 @@ internal override async ValueTask WriteAsync(ReadOnlyMemory<ReadOnlyMemory<byte>
281281
{
282282
ThrowIfDisposed();
283283

284-
using CancellationTokenRegistration registration = HandleWriteStartState(cancellationToken);
284+
using CancellationTokenRegistration registration = HandleWriteStartState(buffers.IsEmpty, cancellationToken);
285285

286286
await SendReadOnlyMemoryListAsync(buffers, endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE).ConfigureAwait(false);
287287

@@ -292,20 +292,20 @@ internal override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, bool e
292292
{
293293
ThrowIfDisposed();
294294

295-
using CancellationTokenRegistration registration = HandleWriteStartState(cancellationToken);
295+
using CancellationTokenRegistration registration = HandleWriteStartState(buffer.IsEmpty, cancellationToken);
296296

297297
await SendReadOnlyMemoryAsync(buffer, endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE).ConfigureAwait(false);
298298

299299
HandleWriteCompletedState();
300300
}
301301

302-
private CancellationTokenRegistration HandleWriteStartState(CancellationToken cancellationToken)
302+
private CancellationTokenRegistration HandleWriteStartState(bool emptyBuffer, CancellationToken cancellationToken)
303303
{
304304
if (_state.SendState == SendState.Closed)
305305
{
306306
throw new InvalidOperationException(SR.net_quic_writing_notallowed);
307307
}
308-
else if ( _state.SendState == SendState.Aborted)
308+
if (_state.SendState == SendState.Aborted)
309309
{
310310
if (_state.SendErrorCode != -1)
311311
{
@@ -363,10 +363,14 @@ private CancellationTokenRegistration HandleWriteStartState(CancellationToken ca
363363

364364
throw new OperationCanceledException(SR.net_quic_sending_aborted);
365365
}
366-
else if (_state.SendState == SendState.ConnectionClosed)
366+
if (_state.SendState == SendState.ConnectionClosed)
367367
{
368368
throw GetConnectionAbortedException(_state);
369369
}
370+
371+
// Change the state in the same lock where we check for final states to prevent coming back from Aborted/ConnectionClosed.
372+
Debug.Assert(_state.SendState != SendState.Pending);
373+
_state.SendState = emptyBuffer ? SendState.Finished : SendState.Pending;
370374
}
371375

372376
return registration;
@@ -632,7 +636,10 @@ internal override void Shutdown()
632636

633637
lock (_state)
634638
{
635-
_state.SendState = SendState.Finished;
639+
if (_state.SendState < SendState.Finished)
640+
{
641+
_state.SendState = SendState.Finished;
642+
}
636643
}
637644

638645
// it is ok to send shutdown several times, MsQuic will ignore it
@@ -1157,12 +1164,6 @@ private unsafe ValueTask SendReadOnlyMemoryAsync(
11571164
ReadOnlyMemory<byte> buffer,
11581165
QUIC_SEND_FLAGS flags)
11591166
{
1160-
lock (_state)
1161-
{
1162-
Debug.Assert(_state.SendState != SendState.Pending);
1163-
_state.SendState = buffer.IsEmpty ? SendState.Finished : SendState.Pending;
1164-
}
1165-
11661167
if (buffer.IsEmpty)
11671168
{
11681169
if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN)
@@ -1211,13 +1212,6 @@ private unsafe ValueTask SendReadOnlySequenceAsync(
12111212
ReadOnlySequence<byte> buffers,
12121213
QUIC_SEND_FLAGS flags)
12131214
{
1214-
1215-
lock (_state)
1216-
{
1217-
Debug.Assert(_state.SendState != SendState.Pending);
1218-
_state.SendState = buffers.IsEmpty ? SendState.Finished : SendState.Pending;
1219-
}
1220-
12211215
if (buffers.IsEmpty)
12221216
{
12231217
if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN)
@@ -1281,12 +1275,6 @@ private unsafe ValueTask SendReadOnlyMemoryListAsync(
12811275
ReadOnlyMemory<ReadOnlyMemory<byte>> buffers,
12821276
QUIC_SEND_FLAGS flags)
12831277
{
1284-
lock (_state)
1285-
{
1286-
Debug.Assert(_state.SendState != SendState.Pending);
1287-
_state.SendState = buffers.IsEmpty ? SendState.Finished : SendState.Pending;
1288-
}
1289-
12901278
if (buffers.IsEmpty)
12911279
{
12921280
if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN)

0 commit comments

Comments
 (0)