Skip to content

Commit c804021

Browse files
authored
Fix gRPC retry calls not unregistering from cancellation token (#1398)
1 parent 15158af commit c804021

File tree

7 files changed

+134
-6
lines changed

7 files changed

+134
-6
lines changed

src/Grpc.Net.Client/Internal/IGrpcCall.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,7 @@ internal interface IGrpcCall<TRequest, TResponse> : IDisposable
4545
void StartDuplexStreaming();
4646

4747
Task WriteClientStreamAsync<TState>(Func<GrpcCall<TRequest, TResponse>, Stream, CallOptions, TState, ValueTask> writeFunc, TState state);
48+
49+
bool Disposed { get; }
4850
}
4951
}

src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,13 @@ private async Task StartCall(Action<GrpcCall<TRequest, TResponse>> startCallFunc
199199
}
200200
}
201201
}
202+
203+
if (CommitedCallTask.IsCompletedSuccessfully() && CommitedCallTask.Result == call)
204+
{
205+
// Wait until the committed call is finished and then clean up the hedging call.
206+
await call.CallTask.ConfigureAwait(false);
207+
Cleanup();
208+
}
202209
}
203210

204211
protected override void OnCommitCall(IGrpcCall<TRequest, TResponse> call)

src/Grpc.Net.Client/Internal/Retry/RetryCall.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,16 @@ private async Task StartRetry(Action<GrpcCall<TRequest, TResponse>> startCallFun
245245
}
246246
finally
247247
{
248+
if (CommitedCallTask.IsCompletedSuccessfully())
249+
{
250+
if (CommitedCallTask.Result is GrpcCall<TRequest, TResponse> call)
251+
{
252+
// Wait until the committed call is finished and then clean up the retry call.
253+
await call.CallTask.ConfigureAwait(false);
254+
Cleanup();
255+
}
256+
}
257+
248258
Log.StoppingRetryWorker(Logger);
249259
}
250260
}

src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ internal abstract partial class RetryCallBase<TRequest, TResponse> : IGrpcCall<T
4343
private readonly TaskCompletionSource<IGrpcCall<TRequest, TResponse>> _commitedCallTcs;
4444
private RetryCallBaseClientStreamReader<TRequest, TResponse>? _retryBaseClientStreamReader;
4545
private RetryCallBaseClientStreamWriter<TRequest, TResponse>? _retryBaseClientStreamWriter;
46-
private CancellationTokenRegistration? _ctsRegistration;
46+
47+
// Internal for unit testing.
48+
internal CancellationTokenRegistration? _ctsRegistration;
4749

4850
protected object Lock { get; } = new object();
4951
protected ILogger Logger { get; }
@@ -52,14 +54,14 @@ internal abstract partial class RetryCallBase<TRequest, TResponse> : IGrpcCall<T
5254
protected int MaxRetryAttempts { get; }
5355
protected CancellationTokenSource CancellationTokenSource { get; }
5456
protected TaskCompletionSource<IGrpcCall<TRequest, TResponse>?>? NewActiveCallTcs { get; set; }
55-
protected bool Disposed { get; private set; }
5657

5758
public GrpcChannel Channel { get; }
5859
public Task<IGrpcCall<TRequest, TResponse>> CommitedCallTask => _commitedCallTcs.Task;
5960
public IAsyncStreamReader<TResponse>? ClientStreamReader => _retryBaseClientStreamReader ??= new RetryCallBaseClientStreamReader<TRequest, TResponse>(this);
6061
public IClientStreamWriter<TRequest>? ClientStreamWriter => _retryBaseClientStreamWriter ??= new RetryCallBaseClientStreamWriter<TRequest, TResponse>(this);
6162
public WriteOptions? ClientStreamWriteOptions { get; internal set; }
6263
public bool ClientStreamComplete { get; set; }
64+
public bool Disposed { get; private set; }
6365

6466
protected int AttemptCount { get; private set; }
6567
protected List<ReadOnlyMemory<byte>> BufferedMessages { get; }
@@ -345,6 +347,16 @@ protected void CommitCall(IGrpcCall<TRequest, TResponse> call, CommitReason comm
345347

346348
NewActiveCallTcs?.SetResult(null);
347349
_commitedCallTcs.SetResult(call);
350+
351+
// If the committed call has finished and cleaned up then it is safe for
352+
// the wrapping retry call to clean up. This is required to unregister
353+
// from the cancellation token and avoid a memory leak.
354+
//
355+
// A committed call that has already cleaned up is likely a StatusGrpcCall.
356+
if (call.Disposed)
357+
{
358+
Cleanup();
359+
}
348360
}
349361
}
350362
}
@@ -406,18 +418,24 @@ protected virtual void Dispose(bool disposing)
406418

407419
if (disposing)
408420
{
409-
_ctsRegistration?.Dispose();
410-
CancellationTokenSource.Cancel();
411-
412421
if (CommitedCallTask.IsCompletedSuccessfully())
413422
{
414423
CommitedCallTask.Result.Dispose();
415424
}
416425

417-
ClearRetryBuffer();
426+
Cleanup();
418427
}
419428
}
420429

430+
protected void Cleanup()
431+
{
432+
_ctsRegistration?.Dispose();
433+
_ctsRegistration = null;
434+
CancellationTokenSource.Cancel();
435+
436+
ClearRetryBuffer();
437+
}
438+
421439
internal bool TryAddToRetryBuffer(ReadOnlyMemory<byte> message)
422440
{
423441
lock (Lock)

src/Grpc.Net.Client/Internal/Retry/StatusGrpcCall.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ internal sealed class StatusGrpcCall<TRequest, TResponse> : IGrpcCall<TRequest,
3838

3939
public IClientStreamWriter<TRequest>? ClientStreamWriter => _clientStreamWriter ??= new StatusClientStreamWriter(_status);
4040
public IAsyncStreamReader<TResponse>? ClientStreamReader => _clientStreamReader ??= new StatusStreamReader(_status);
41+
public bool Disposed => true;
4142

4243
public StatusGrpcCall(Status status)
4344
{

test/FunctionalTests/Client/RetryTests.cs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
using System;
2020
using System.Collections.Generic;
2121
using System.Linq;
22+
using System.Runtime.CompilerServices;
2223
using System.Threading;
2324
using System.Threading.Tasks;
2425
using Google.Protobuf;
@@ -356,6 +357,63 @@ Task<DataMessage> UnaryFailure(DataMessage request, ServerCallContext context)
356357
tcs.SetResult(new DataMessage());
357358
}
358359

360+
[Test]
361+
public async Task ServerStreaming_CancellatonTokenSpecified_TokenUnregisteredAndResourcesReleased()
362+
{
363+
Task FakeServerStreamCall(DataMessage request, IServerStreamWriter<DataMessage> responseStream, ServerCallContext context)
364+
{
365+
return Task.CompletedTask;
366+
}
367+
368+
// Arrange
369+
var method = Fixture.DynamicGrpc.AddServerStreamingMethod<DataMessage, DataMessage>(FakeServerStreamCall);
370+
371+
var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(retryableStatusCodes: new List<StatusCode> { StatusCode.DeadlineExceeded });
372+
var channel = CreateChannel(serviceConfig: serviceConfig);
373+
374+
var references = new List<WeakReference>();
375+
376+
// Checking that token register calls don't build up on CTS and create a memory leak.
377+
var cts = new CancellationTokenSource();
378+
379+
// Act
380+
// Send calls in a different method so there is no chance that a stack reference
381+
// to a gRPC call is still alive after calls are complete.
382+
await MakeCallsAsync(channel, method, references, cts.Token).DefaultTimeout();
383+
384+
// Assert
385+
// There is a race when cleaning up the cancellation token registration.
386+
// Retry a few times to ensure GC is run after unregister.
387+
await TestHelpers.AssertIsTrueRetryAsync(() =>
388+
{
389+
GC.Collect();
390+
GC.WaitForPendingFinalizers();
391+
392+
for (var i = 0; i < references.Count; i++)
393+
{
394+
if (references[i].IsAlive)
395+
{
396+
return false;
397+
}
398+
}
399+
400+
return true;
401+
}, "Assert that retry call resources are released.");
402+
}
403+
404+
[MethodImpl(MethodImplOptions.NoInlining)]
405+
private static async Task MakeCallsAsync(GrpcChannel channel, Method<DataMessage, DataMessage> method, List<WeakReference> references, CancellationToken cancellationToken)
406+
{
407+
var client = TestClientFactory.Create(channel, method);
408+
for (int i = 0; i < 10; i++)
409+
{
410+
var call = client.ServerStreamingCall(new DataMessage(), new CallOptions(cancellationToken: cancellationToken));
411+
references.Add(new WeakReference(call.ResponseStream));
412+
413+
Assert.IsFalse(await call.ResponseStream.MoveNext());
414+
}
415+
}
416+
359417
[TestCase(1)]
360418
[TestCase(20)]
361419
public async Task Unary_AttemptsGreaterThanDefaultClientLimit_LimitedAttemptsMade(int hedgingDelay)

test/Grpc.Net.Client.Tests/Retry/HedgingCallTests.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ public async Task AsyncUnaryCall_CancellationDuringBackoff_CanceledStatus()
331331

332332
// Act
333333
hedgingCall.StartUnary(new HelloRequest());
334+
Assert.IsNotNull(hedgingCall._ctsRegistration);
334335

335336
// Assert
336337
await TestHelpers.AssertIsTrueRetryAsync(() => hedgingCall._activeCalls.Count == 0, "Wait for all calls to fail.").DefaultTimeout();
@@ -340,6 +341,37 @@ public async Task AsyncUnaryCall_CancellationDuringBackoff_CanceledStatus()
340341
var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => hedgingCall.GetResponseAsync()).DefaultTimeout();
341342
Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode);
342343
Assert.AreEqual("Call canceled by the client.", ex.Status.Detail);
344+
Assert.IsNull(hedgingCall._ctsRegistration);
345+
}
346+
347+
[Test]
348+
public async Task AsyncUnaryCall_CancellationTokenSuccess_CleanedUp()
349+
{
350+
// Arrange
351+
var tcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
352+
var httpClient = ClientTestHelpers.CreateTestClient(async request =>
353+
{
354+
await tcs.Task;
355+
356+
var reply = new HelloReply { Message = "Hello world" };
357+
var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout();
358+
return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent);
359+
});
360+
var cts = new CancellationTokenSource();
361+
var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(hedgingDelay: TimeSpan.FromSeconds(10));
362+
var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig);
363+
var hedgingCall = new HedgingCall<HelloRequest, HelloReply>(CreateHedgingPolicy(serviceConfig.MethodConfigs[0].HedgingPolicy), invoker.Channel, ClientTestHelpers.ServiceMethod, new CallOptions(cancellationToken: cts.Token));
364+
365+
// Act
366+
hedgingCall.StartUnary(new HelloRequest());
367+
Assert.IsNotNull(hedgingCall._ctsRegistration);
368+
tcs.SetResult(null);
369+
370+
// Assert
371+
await hedgingCall.GetResponseAsync().DefaultTimeout();
372+
373+
// There is a race between unregistering and GetResponseAsync returning.
374+
await TestHelpers.AssertIsTrueRetryAsync(() => hedgingCall._ctsRegistration == null, "Hedge call CTS unregistered.");
343375
}
344376

345377
[Test]

0 commit comments

Comments
 (0)