Skip to content

Commit 6acecf4

Browse files
authored
Fix client with custom ConnectCallback (#1542)
1 parent 3ebcc93 commit 6acecf4

File tree

6 files changed

+148
-6
lines changed

6 files changed

+148
-6
lines changed

src/Grpc.Net.Client/Balancer/Internal/BalancerHttpHandler.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,17 @@ internal class BalancerHttpHandler : DelegatingHandler
3535

3636
private readonly ConnectionManager _manager;
3737

38-
public BalancerHttpHandler(HttpMessageHandler innerHandler, ConnectionManager manager)
38+
public BalancerHttpHandler(HttpMessageHandler innerHandler, HttpHandlerType httpHandlerType, ConnectionManager manager)
3939
: base(innerHandler)
4040
{
4141
_manager = manager;
4242

4343
#if NET5_0_OR_GREATER
44-
var socketsHttpHandler = HttpRequestHelpers.GetHttpHandlerType<SocketsHttpHandler>(innerHandler);
45-
if (socketsHttpHandler != null)
44+
if (httpHandlerType == HttpHandlerType.SocketsHttpHandler)
4645
{
46+
var socketsHttpHandler = HttpRequestHelpers.GetHttpHandlerType<SocketsHttpHandler>(innerHandler);
47+
CompatibilityHelpers.Assert(socketsHttpHandler != null, "Should have handler with this handler type.");
48+
4749
socketsHttpHandler.ConnectCallback = OnConnect;
4850
}
4951
#endif

src/Grpc.Net.Client/GrpcChannel.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ private HttpMessageInvoker CreateInternalHttpInvoker(HttpMessageHandler? handler
339339
#endif
340340

341341
#if SUPPORT_LOAD_BALANCING
342-
handler = new BalancerHttpHandler(handler, ConnectionManager);
342+
handler = new BalancerHttpHandler(handler, HttpHandlerType, ConnectionManager);
343343
#endif
344344

345345
// Use HttpMessageInvoker instead of HttpClient because it is faster

test/FunctionalTests/Client/ConnectionTests.cs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,14 @@
1616

1717
#endregion
1818

19+
using System.Net;
20+
using System.Net.Sockets;
1921
using Greet;
2022
using Grpc.AspNetCore.FunctionalTests.Infrastructure;
2123
using Grpc.Core;
2224
using Grpc.Net.Client;
2325
using Grpc.Tests.Shared;
26+
using Microsoft.AspNetCore.Connections.Features;
2427
using NUnit.Framework;
2528

2629
namespace Grpc.AspNetCore.FunctionalTests.Client
@@ -60,7 +63,42 @@ public async Task ALPN_ProtocolDowngradedToHttp1_ThrowErrorFromServer()
6063
#endif
6164
}
6265

63-
#if NET6_0
66+
#if NET5_0_OR_GREATER
67+
[Test]
68+
public async Task UnixDomainSockets()
69+
{
70+
Task<HelloReply> UnaryUds(HelloRequest request, ServerCallContext context)
71+
{
72+
#if NET6_0_OR_GREATER
73+
var endPoint = (UnixDomainSocketEndPoint)context.GetHttpContext().Features.Get<IConnectionSocketFeature>()!.Socket.LocalEndPoint!;
74+
Assert.NotNull(endPoint);
75+
#endif
76+
77+
return Task.FromResult(new HelloReply { Message = "Hello " + request.Name });
78+
}
79+
80+
// Arrange
81+
var method = Fixture.DynamicGrpc.AddUnaryMethod<HelloRequest, HelloReply>(UnaryUds);
82+
83+
var http = Fixture.CreateHandler(TestServerEndpointName.UnixDomainSocket);
84+
85+
var channel = GrpcChannel.ForAddress(http.address, new GrpcChannelOptions
86+
{
87+
LoggerFactory = LoggerFactory,
88+
HttpHandler = http.handler
89+
});
90+
91+
var client = TestClientFactory.Create(channel, method);
92+
93+
// Act
94+
var response = await client.UnaryCall(new HelloRequest { Name = "John" }).ResponseAsync.DefaultTimeout();
95+
96+
// Assert
97+
Assert.AreEqual("Hello John", response.Message);
98+
}
99+
#endif
100+
101+
#if NET6_0_OR_GREATER
64102
[Test]
65103
[RequireHttp3]
66104
public async Task Http3()
@@ -79,6 +117,7 @@ public async Task Http3()
79117
// Act
80118
var response = await client.SayHelloAsync(new HelloRequest { Name = "John" }).ResponseAsync.DefaultTimeout();
81119

120+
// Assert
82121
Assert.AreEqual("Hello John", response.Message);
83122
}
84123
#endif

test/FunctionalTests/Infrastructure/GrpcTestFixture.cs

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#endregion
1818

19+
using System.Net.Sockets;
1920
using Microsoft.AspNetCore.Hosting;
2021
using Microsoft.AspNetCore.Server.Kestrel.Core;
2122
using Microsoft.Extensions.DependencyInjection;
@@ -25,6 +26,7 @@ namespace Grpc.AspNetCore.FunctionalTests.Infrastructure
2526
{
2627
public class GrpcTestFixture<TStartup> : IDisposable where TStartup : class
2728
{
29+
private readonly string _socketPath = Path.Combine(Path.GetTempPath(), "grpc-transporter.tmp");
2830
private readonly InProcessTestServer _server;
2931

3032
public GrpcTestFixture(
@@ -86,6 +88,19 @@ public GrpcTestFixture(
8688
listenOptions.UseHttps(certPath, "1111");
8789
});
8890

91+
#if NET5_0_OR_GREATER
92+
if (File.Exists(_socketPath))
93+
{
94+
File.Delete(_socketPath);
95+
}
96+
97+
urls[TestServerEndpointName.UnixDomainSocket] = _socketPath;
98+
options.ListenUnixSocket(_socketPath, listenOptions =>
99+
{
100+
listenOptions.Protocols = HttpProtocols.Http2;
101+
});
102+
#endif
103+
89104
#if NET6_0_OR_GREATER
90105
if (RequireHttp3Attribute.IsSupported(out _))
91106
{
@@ -145,6 +160,16 @@ public HttpClient CreateClient(TestServerEndpointName? endpointName = null, Dele
145160
RemoteCertificateValidationCallback = (_, __, ___, ____) => true
146161
};
147162

163+
#if NET5_0_OR_GREATER
164+
if (endpointName == TestServerEndpointName.UnixDomainSocket)
165+
{
166+
var udsEndPoint = new UnixDomainSocketEndPoint(_server.GetUrl(endpointName.Value));
167+
var connectionFactory = new UnixDomainSocketConnectionFactory(udsEndPoint);
168+
169+
socketsHttpHandler.ConnectCallback = connectionFactory.ConnectAsync;
170+
}
171+
#endif
172+
148173
HttpClient client;
149174
HttpMessageHandler handler;
150175
if (messageHandler != null)
@@ -175,11 +200,24 @@ public HttpClient CreateClient(TestServerEndpointName? endpointName = null, Dele
175200
client.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher;
176201
#endif
177202
}
178-
client.BaseAddress = new Uri(_server.GetUrl(endpointName.Value));
203+
204+
client.BaseAddress = CalculateBaseAddress(endpointName.Value);
179205

180206
return (client, handler);
181207
}
182208

209+
private Uri CalculateBaseAddress(TestServerEndpointName endpointName)
210+
{
211+
#if NET5_0_OR_GREATER
212+
if (endpointName == TestServerEndpointName.UnixDomainSocket)
213+
{
214+
return new Uri("http://localhost");
215+
}
216+
#endif
217+
218+
return new Uri(_server.GetUrl(endpointName));
219+
}
220+
183221
public Uri GetUrl(TestServerEndpointName endpointName)
184222
{
185223
switch (endpointName)
@@ -192,6 +230,10 @@ public Uri GetUrl(TestServerEndpointName endpointName)
192230
case TestServerEndpointName.Http3WithTls:
193231
#endif
194232
return new Uri(_server.GetUrl(endpointName));
233+
#if NET5_0_OR_GREATER
234+
case TestServerEndpointName.UnixDomainSocket:
235+
return new Uri("http://localhost");
236+
#endif
195237
default:
196238
throw new ArgumentException("Unexpected value: " + endpointName, nameof(endpointName));
197239
}

test/FunctionalTests/Infrastructure/TestServerEndpointName.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ public enum TestServerEndpointName
2424
Http1,
2525
Http2WithTls,
2626
Http1WithTls,
27+
#if NET5_0_OR_GREATER
28+
UnixDomainSocket,
29+
#endif
2730
#if NET6_0_OR_GREATER
2831
Http3WithTls,
2932
#endif
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#region Copyright notice and license
2+
3+
// Copyright 2019 The gRPC Authors
4+
//
5+
// Licensed under the Apache License, Version 2.0 (the "License");
6+
// you may not use this file except in compliance with the License.
7+
// You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
17+
#endregion
18+
19+
#if NET5_0_OR_GREATER
20+
using System;
21+
using System.IO;
22+
using System.Net;
23+
using System.Net.Http;
24+
using System.Net.Sockets;
25+
using System.Threading;
26+
using System.Threading.Tasks;
27+
28+
namespace Grpc.AspNetCore.FunctionalTests.Infrastructure
29+
{
30+
public class UnixDomainSocketConnectionFactory
31+
{
32+
private readonly EndPoint _endPoint;
33+
34+
public UnixDomainSocketConnectionFactory(EndPoint endPoint)
35+
{
36+
_endPoint = endPoint;
37+
}
38+
39+
public async ValueTask<Stream> ConnectAsync(SocketsHttpConnectionContext _, CancellationToken cancellationToken = default)
40+
{
41+
var socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified);
42+
43+
try
44+
{
45+
await socket.ConnectAsync(_endPoint, cancellationToken).ConfigureAwait(false);
46+
return new NetworkStream(socket, true);
47+
}
48+
catch (Exception ex)
49+
{
50+
socket.Dispose();
51+
throw new HttpRequestException($"Error connecting to '{_endPoint}'.", ex);
52+
}
53+
}
54+
}
55+
}
56+
#endif

0 commit comments

Comments
 (0)