Skip to content

Commit 15158af

Browse files
authored
Fix client timeout when error in custom CallCredential (#1388)
1 parent 95ea1dc commit 15158af

File tree

6 files changed

+85
-16
lines changed

6 files changed

+85
-16
lines changed

src/Grpc.Net.Client/Balancer/Internal/SocketConnectivitySubchannelTransport.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
using System.Collections.Generic;
2222
using System.Diagnostics;
2323
using System.IO;
24+
using System.Linq;
2425
using System.Net;
2526
using System.Net.Http;
2627
using System.Net.Sockets;
@@ -51,7 +52,7 @@ internal class SocketConnectivitySubchannelTransport : ISubchannelTransport, IDi
5152
private readonly ILogger _logger;
5253
private readonly Subchannel _subchannel;
5354
private readonly TimeSpan _socketPingInterval;
54-
internal readonly List<(DnsEndPoint EndPoint, Socket Socket, Stream? Stream)> _activeStreams;
55+
private readonly List<(DnsEndPoint EndPoint, Socket Socket, Stream? Stream)> _activeStreams;
5556
private readonly Timer _socketConnectedTimer;
5657

5758
private int _lastEndPointIndex;
@@ -73,6 +74,15 @@ public SocketConnectivitySubchannelTransport(Subchannel subchannel, TimeSpan soc
7374
public DnsEndPoint? CurrentEndPoint => _currentEndPoint;
7475
public bool HasStream { get; }
7576

77+
// For testing. Take a copy under lock for thread-safety.
78+
internal IReadOnlyList<(DnsEndPoint EndPoint, Socket Socket, Stream? Stream)> GetActiveStreams()
79+
{
80+
lock (Lock)
81+
{
82+
return _activeStreams.ToList();
83+
}
84+
}
85+
7686
public void Disconnect()
7787
{
7888
lock (Lock)

src/Grpc.Net.Client/Internal/GrpcCall.cs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -415,18 +415,20 @@ private async Task RunCall(HttpRequestMessage request, TimeSpan? timeout)
415415
{
416416
var (diagnosticSourceEnabled, activity) = InitializeCall(request, timeout);
417417

418-
if (Options.Credentials != null || Channel.CallCredentials?.Count > 0)
419-
{
420-
await ReadCredentials(request).ConfigureAwait(false);
421-
}
422-
423418
// Unset variable to check that FinishCall is called in every code path
424419
bool finished;
425420

426421
Status? status = null;
427422

428423
try
429424
{
425+
// User supplied call credentials could error and so must be run
426+
// inside try/catch to handle errors.
427+
if (Options.Credentials != null || Channel.CallCredentials?.Count > 0)
428+
{
429+
await ReadCredentials(request).ConfigureAwait(false);
430+
}
431+
430432
// Fail early if deadline has already been exceeded
431433
_callCts.Token.ThrowIfCancellationRequested();
432434

test/FunctionalTests/Balancer/ConnectionTests.cs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,16 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte
110110
var balancer = BalancerHelpers.GetInnerLoadBalancer<PickFirstBalancer>(channel)!;
111111
var subchannel = balancer._subchannel!;
112112
var transport = (SocketConnectivitySubchannelTransport)subchannel.Transport;
113-
var activeStreams = transport._activeStreams;
113+
var activeStreams = transport.GetActiveStreams();
114114

115115
// Assert
116116
Assert.AreEqual(HttpHandlerType.SocketsHttpHandler, channel.HttpHandlerType);
117117

118-
await TestHelpers.AssertIsTrueRetryAsync(() => activeStreams.Count == 10, "Wait for connections to start.");
118+
await TestHelpers.AssertIsTrueRetryAsync(() =>
119+
{
120+
activeStreams = transport.GetActiveStreams();
121+
return activeStreams.Count == 10;
122+
}, "Wait for connections to start.");
119123
foreach (var t in activeStreams)
120124
{
121125
Assert.AreEqual(new DnsEndPoint("127.0.0.1", 50051), t.EndPoint);
@@ -134,7 +138,11 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte
134138
Logger.LogInformation($"Done sending gRPC calls");
135139

136140
// Assert
137-
await TestHelpers.AssertIsTrueRetryAsync(() => activeStreams.Count == 11, "Wait for connections to start.");
141+
await TestHelpers.AssertIsTrueRetryAsync(() =>
142+
{
143+
activeStreams = transport.GetActiveStreams();
144+
return activeStreams.Count == 11;
145+
}, "Wait for connections to start.");
138146
Assert.AreEqual(new DnsEndPoint("127.0.0.1", 50051), activeStreams.Last().EndPoint);
139147

140148
tcs.SetResult(null);
@@ -149,7 +157,11 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte
149157

150158
// There are still be 10 HTTP/1.1 connections because they aren't immediately removed
151159
// when the server is shutdown and connectivity is lost.
152-
await TestHelpers.AssertIsTrueRetryAsync(() => activeStreams.Count == 10, "Wait for HTTP/2 connection to end.");
160+
await TestHelpers.AssertIsTrueRetryAsync(() =>
161+
{
162+
activeStreams = transport.GetActiveStreams();
163+
return activeStreams.Count == 10;
164+
}, "Wait for HTTP/2 connection to end.");
153165

154166
grpcWebHandler.HttpVersion = new Version(1, 1);
155167

@@ -160,6 +172,7 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte
160172
Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode);
161173

162174
// Removed by failed call.
175+
activeStreams = transport.GetActiveStreams();
163176
Assert.AreEqual(0, activeStreams.Count);
164177
Assert.AreEqual(ConnectivityState.Idle, channel.State);
165178

@@ -168,6 +181,7 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte
168181
Assert.AreEqual("Balancer", reply.Message);
169182
Assert.AreEqual("127.0.0.1:50052", host);
170183

184+
activeStreams = transport.GetActiveStreams();
171185
Assert.AreEqual(1, activeStreams.Count);
172186
Assert.AreEqual(new DnsEndPoint("127.0.0.1", 50052), activeStreams[0].EndPoint);
173187
}

test/FunctionalTests/Balancer/PickFirstBalancerTests.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte
246246
var balancer = BalancerHelpers.GetInnerLoadBalancer<PickFirstBalancer>(channel)!;
247247
var subchannel = balancer._subchannel!;
248248
var transport = (SocketConnectivitySubchannelTransport)subchannel.Transport;
249-
var activeStreams = transport._activeStreams;
249+
var activeStreams = transport.GetActiveStreams();
250250

251251
// Assert
252252
Assert.AreEqual(2, activeStreams.Count);
@@ -263,6 +263,7 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte
263263

264264
await TestHelpers.AssertIsTrueRetryAsync(() =>
265265
{
266+
activeStreams = transport.GetActiveStreams();
266267
Logger.LogInformation($"Current active stream endpoints: {string.Join(", ", activeStreams.Select(s => s.EndPoint))}");
267268
return activeStreams.Count == 0;
268269
}, "Active streams removed.", Logger).DefaultTimeout();
@@ -271,6 +272,7 @@ await TestHelpers.AssertIsTrueRetryAsync(() =>
271272
Assert.AreEqual("Balancer", reply.Message);
272273
Assert.AreEqual("127.0.0.1:50052", host);
273274

275+
activeStreams = transport.GetActiveStreams();
274276
Assert.AreEqual(1, activeStreams.Count);
275277
Assert.AreEqual(new DnsEndPoint("127.0.0.1", 50052), activeStreams[0].EndPoint);
276278
}

test/Grpc.Net.Client.Tests/CallCredentialTests.cs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
using Grpc.Net.Client.Internal;
2929
using Grpc.Net.Client.Tests.Infrastructure;
3030
using Grpc.Tests.Shared;
31+
using Microsoft.Extensions.DependencyInjection;
32+
using Microsoft.Extensions.Logging;
3133
using Microsoft.Extensions.Logging.Testing;
3234
using NUnit.Framework;
3335

@@ -36,6 +38,37 @@ namespace Grpc.Net.Client.Tests
3638
[TestFixture]
3739
public class CallCredentialTests
3840
{
41+
[Test]
42+
public async Task CallCredentialsWithHttps_WhenAsyncAuthInterceptorThrow_ShouldThrow()
43+
{
44+
// Arrange
45+
var services = new ServiceCollection();
46+
services.AddNUnitLogger();
47+
var loggerFactory = services.BuildServiceProvider().GetRequiredService<ILoggerFactory>();
48+
49+
var httpClient = ClientTestHelpers.CreateTestClient(async request =>
50+
{
51+
var reply = new HelloReply { Message = "Hello world" };
52+
var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout();
53+
return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent);
54+
});
55+
var invoker = HttpClientCallInvokerFactory.Create(httpClient, loggerFactory);
56+
57+
// Act
58+
var expectedException = new Exception("Some AsyncAuthInterceptor Exception");
59+
60+
var callCredentials = CallCredentials.FromInterceptor((context, metadata) =>
61+
{
62+
return Task.FromException(expectedException);
63+
});
64+
65+
var call = invoker.AsyncUnaryCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(credentials: callCredentials), new HelloRequest());
66+
var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout();
67+
68+
// Assert
69+
Assert.AreSame(expectedException, ex.Status.DebugException);
70+
}
71+
3972
[Test]
4073
public async Task CallCredentialsWithHttps_MetadataOnRequest()
4174
{

test/Shared/HttpEventSourceListener.cs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
#endregion
1818

19-
using System;
2019
using System.Diagnostics.Tracing;
2120
using System.Text;
2221
using Microsoft.Extensions.Logging;
@@ -26,7 +25,7 @@ namespace Grpc.Tests.Shared
2625
public sealed class HttpEventSourceListener : EventListener
2726
{
2827
private readonly StringBuilder _messageBuilder = new StringBuilder();
29-
private readonly ILogger _logger;
28+
private readonly ILogger? _logger;
3029
private readonly object _lock = new object();
3130
private bool _disposed;
3231

@@ -43,7 +42,13 @@ protected override void OnEventSourceCreated(EventSource eventSource)
4342
if (eventSource.Name.Contains("System.Net.Quic") ||
4443
eventSource.Name.Contains("System.Net.Http"))
4544
{
46-
EnableEvents(eventSource, EventLevel.LogAlways, EventKeywords.All);
45+
lock (_lock)
46+
{
47+
if (!_disposed)
48+
{
49+
EnableEvents(eventSource, EventLevel.LogAlways, EventKeywords.All);
50+
}
51+
}
4752
}
4853
}
4954

@@ -72,7 +77,10 @@ protected override void OnEventWritten(EventWrittenEventArgs eventData)
7277
{
7378
if (!_disposed)
7479
{
75-
_logger.LogDebug(message);
80+
// EventListener base constructor subscribes to events.
81+
// It is possible to start getting events before the
82+
// super constructor is run and logger is assigned.
83+
_logger?.LogDebug(message);
7684
}
7785
}
7886
}
@@ -90,7 +98,7 @@ public override void Dispose()
9098
{
9199
if (!_disposed)
92100
{
93-
_logger.LogDebug($"Stopping {nameof(HttpEventSourceListener)}.");
101+
_logger?.LogDebug($"Stopping {nameof(HttpEventSourceListener)}.");
94102
_disposed = true;
95103
}
96104
}

0 commit comments

Comments
 (0)