diff --git a/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs b/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs index e481dfda5..7f54895f1 100644 --- a/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs +++ b/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs @@ -529,7 +529,7 @@ protected void HandleUnexpectedError(Exception ex) CommitReason commitReason; // Cancellation token triggered by dispose could throw here. - if (ex is OperationCanceledException && CancellationTokenSource.IsCancellationRequested) + if (ex is OperationCanceledException operationCanceledException && CancellationTokenSource.IsCancellationRequested) { // Cancellation could have been caused by an exceeded deadline. if (IsDeadlineExceeded()) @@ -542,7 +542,21 @@ protected void HandleUnexpectedError(Exception ex) else { commitReason = CommitReason.Canceled; - resolvedCall = CreateStatusCall(Disposed ? GrpcProtocolConstants.CreateDisposeCanceledStatus(ex) : GrpcProtocolConstants.CreateClientCanceledStatus(ex)); + Status status; + if (Disposed) + { + status = GrpcProtocolConstants.CreateDisposeCanceledStatus(exception: null); + } + else + { + // Replace the OCE from CancellationTokenSource with an OCE that has the passed in cancellation token if it is canceled. + if (Options.CancellationToken.IsCancellationRequested && Options.CancellationToken != operationCanceledException.CancellationToken) + { + ex = new OperationCanceledException(Options.CancellationToken); + } + status = GrpcProtocolConstants.CreateClientCanceledStatus(ex); + } + resolvedCall = CreateStatusCall(status); } } else diff --git a/src/Grpc.Net.Client/Internal/Retry/StatusGrpcCall.cs b/src/Grpc.Net.Client/Internal/Retry/StatusGrpcCall.cs index 0b9426aaf..00885b9fc 100644 --- a/src/Grpc.Net.Client/Internal/Retry/StatusGrpcCall.cs +++ b/src/Grpc.Net.Client/Internal/Retry/StatusGrpcCall.cs @@ -56,12 +56,21 @@ public void Dispose() public Task GetResponseAsync() { - return Task.FromException(new RpcException(_status)); + return CreateErrorTask(); } public Task GetResponseHeadersAsync() { - return Task.FromException(new RpcException(_status)); + return CreateErrorTask(); + } + + private Task CreateErrorTask() + { + if (_channel.ThrowOperationCanceledOnCancellation && _status.DebugException is OperationCanceledException ex) + { + return Task.FromException(ex); + } + return Task.FromException(new RpcException(_status)); } public Status GetStatus() diff --git a/test/FunctionalTests/Client/CancellationTests.cs b/test/FunctionalTests/Client/CancellationTests.cs index 0b85d093f..20af72dbd 100644 --- a/test/FunctionalTests/Client/CancellationTests.cs +++ b/test/FunctionalTests/Client/CancellationTests.cs @@ -509,6 +509,39 @@ async Task UnaryMethod(DataMessage request, ServerCallContext conte Assert.AreEqual(StatusCode.Cancelled, call.GetStatus().StatusCode); } + [Test] + public async Task Unary_Retry_CancellationImmediately_TokenMatchesSource() + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + async Task UnaryMethod(DataMessage request, ServerCallContext context) + { + await tcs.Task; + return new DataMessage(); + } + + SetExpectedErrorsFilter(writeContext => + { + return true; + }); + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryMethod); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var channel = CreateChannel(throwOperationCanceledOnCancellation: true, serviceConfig: serviceConfig); + var client = TestClientFactory.Create(channel, method); + + // Act + var cts = new CancellationTokenSource(); + cts.Cancel(); + + var call = client.UnaryCall(new DataMessage(), new CallOptions(cancellationToken: cts.Token)); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(cts.Token, ex.CancellationToken); + Assert.AreEqual(StatusCode.Cancelled, call.GetStatus().StatusCode); + } + [Test] public async Task ServerStreaming_CancellationDuringCall_TokenMatchesSource() {