Skip to content

Commit 126681d

Browse files
authored
Fix method discovery when there are duplicate names (#591)
1 parent db7e27d commit 126681d

File tree

7 files changed

+348
-134
lines changed

7 files changed

+348
-134
lines changed

src/Grpc.AspNetCore.Server/Model/Internal/BinderServiceModelProvider.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ public void OnServiceMethodDiscovery(ServiceMethodProviderContext<TService> cont
3737
// Invoke BindService(ServiceBinderBase, BaseType)
3838
if (bindMethodInfo != null)
3939
{
40-
var binder = new ProviderServiceBinder<TService>(context);
40+
// The second parameter is always the service base type
41+
var serviceParameter = bindMethodInfo.GetParameters()[1];
42+
43+
var binder = new ProviderServiceBinder<TService>(context, serviceParameter.ParameterType);
4144

4245
try
4346
{

src/Grpc.AspNetCore.Server/Model/Internal/ProviderServiceBinder.cs

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,46 +18,62 @@
1818

1919
using System;
2020
using System.Collections.Generic;
21+
using System.Reflection;
2122
using Grpc.Core;
2223

2324
namespace Grpc.AspNetCore.Server.Model.Internal
2425
{
2526
internal class ProviderServiceBinder<TService> : ServiceBinderBase where TService : class
2627
{
2728
private readonly ServiceMethodProviderContext<TService> _context;
29+
private readonly Type _declaringType;
2830

29-
internal ProviderServiceBinder(ServiceMethodProviderContext<TService> context)
31+
internal ProviderServiceBinder(ServiceMethodProviderContext<TService> context, Type declaringType)
3032
{
3133
_context = context;
34+
_declaringType = declaringType;
3235
}
3336

3437
public override void AddMethod<TRequest, TResponse>(Method<TRequest, TResponse> method, ClientStreamingServerMethod<TRequest, TResponse> handler)
3538
{
36-
var (invoker, metadata) = CreateModelCore<ClientStreamingServerMethod<TService, TRequest, TResponse>>(method.Name);
39+
var (invoker, metadata) = CreateModelCore<ClientStreamingServerMethod<TService, TRequest, TResponse>>(
40+
method.Name,
41+
new[] { typeof(IAsyncStreamReader<TRequest>), typeof(ServerCallContext) });
42+
3743
_context.AddClientStreamingMethod<TRequest, TResponse>(method, metadata, invoker);
3844
}
3945

4046
public override void AddMethod<TRequest, TResponse>(Method<TRequest, TResponse> method, DuplexStreamingServerMethod<TRequest, TResponse> handler)
4147
{
42-
var (invoker, metadata) = CreateModelCore<DuplexStreamingServerMethod<TService, TRequest, TResponse>>(method.Name);
48+
var (invoker, metadata) = CreateModelCore<DuplexStreamingServerMethod<TService, TRequest, TResponse>>(
49+
method.Name,
50+
new[] { typeof(IAsyncStreamReader<TRequest>), typeof(IServerStreamWriter<TResponse>), typeof(ServerCallContext) });
51+
4352
_context.AddDuplexStreamingMethod<TRequest, TResponse>(method, metadata, invoker);
4453
}
4554

4655
public override void AddMethod<TRequest, TResponse>(Method<TRequest, TResponse> method, ServerStreamingServerMethod<TRequest, TResponse> handler)
4756
{
48-
var (invoker, metadata) = CreateModelCore<ServerStreamingServerMethod<TService, TRequest, TResponse>>(method.Name);
57+
var (invoker, metadata) = CreateModelCore<ServerStreamingServerMethod<TService, TRequest, TResponse>>(
58+
method.Name,
59+
new[] { typeof(TRequest), typeof(IServerStreamWriter<TResponse>), typeof(ServerCallContext) });
60+
4961
_context.AddServerStreamingMethod<TRequest, TResponse>(method, metadata, invoker);
5062
}
5163

5264
public override void AddMethod<TRequest, TResponse>(Method<TRequest, TResponse> method, UnaryServerMethod<TRequest, TResponse> handler)
5365
{
54-
var (invoker, metadata) = CreateModelCore<UnaryServerMethod<TService, TRequest, TResponse>>(method.Name);
66+
var (invoker, metadata) = CreateModelCore<UnaryServerMethod<TService, TRequest, TResponse>>(
67+
method.Name,
68+
new[] { typeof(TRequest), typeof(ServerCallContext) });
69+
5570
_context.AddUnaryMethod<TRequest, TResponse>(method, metadata, invoker);
5671
}
5772

58-
private (TDelegate invoker, List<object> metadata) CreateModelCore<TDelegate>(string methodName) where TDelegate : Delegate
73+
private (TDelegate invoker, List<object> metadata) CreateModelCore<TDelegate>(string methodName, Type[] methodParameters) where TDelegate : Delegate
5974
{
60-
var handlerMethod = typeof(TService).GetMethod(methodName);
75+
var handlerMethod = GetMethod(methodName, methodParameters);
76+
6177
if (handlerMethod == null)
6278
{
6379
throw new InvalidOperationException($"Could not find '{methodName}' on {typeof(TService)}.");
@@ -73,5 +89,40 @@ public override void AddMethod<TRequest, TResponse>(Method<TRequest, TResponse>
7389

7490
return (invoker, metadata);
7591
}
92+
93+
private MethodInfo? GetMethod(string methodName, Type[] methodParameters)
94+
{
95+
Type? currentType = typeof(TService);
96+
while (currentType != null)
97+
{
98+
var matchingMethod = currentType.GetMethod(
99+
methodName,
100+
BindingFlags.Public | BindingFlags.Instance,
101+
binder: null,
102+
types: methodParameters,
103+
modifiers: null);
104+
105+
if (matchingMethod == null)
106+
{
107+
return null;
108+
}
109+
110+
// Validate that the method overrides the virtual method on the base service type.
111+
// If there is a method with the same name it will hide the base method. Ignore it,
112+
// and continue searching on the base type.
113+
if (matchingMethod.IsVirtual)
114+
{
115+
var baseDefinitionMethod = matchingMethod.GetBaseDefinition();
116+
if (baseDefinitionMethod != null && baseDefinitionMethod.DeclaringType == _declaringType)
117+
{
118+
return matchingMethod;
119+
}
120+
}
121+
122+
currentType = currentType.BaseType;
123+
}
124+
125+
return null;
126+
}
76127
}
77128
}

test/Grpc.AspNetCore.Server.Tests/CallHandlerTests.cs

Lines changed: 12 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
using Grpc.AspNetCore.Server.Internal;
2626
using Grpc.AspNetCore.Server.Internal.CallHandlers;
2727
using Grpc.AspNetCore.Server.Tests.Infrastructure;
28+
using Grpc.AspNetCore.Server.Tests.TestObjects;
2829
using Grpc.Core;
2930
using Grpc.Tests.Shared;
3031
using Microsoft.AspNetCore.Http;
@@ -50,7 +51,7 @@ public class CallHandlerTests
5051
public async Task MinRequestBodyDataRateFeature_MethodType_HasRequestBodyDataRate(MethodType methodType, bool hasRequestBodyDataRate)
5152
{
5253
// Arrange
53-
var httpContext = CreateContext();
54+
var httpContext = HttpContextHelpers.CreateContext();
5455
var call = CreateHandler(methodType);
5556

5657
// Act
@@ -67,7 +68,7 @@ public async Task MinRequestBodyDataRateFeature_MethodType_HasRequestBodyDataRat
6768
public async Task MaxRequestBodySizeFeature_MethodType_HasMaxRequestBodySize(MethodType methodType, bool hasMaxRequestBodySize)
6869
{
6970
// Arrange
70-
var httpContext = CreateContext();
71+
var httpContext = HttpContextHelpers.CreateContext();
7172
var call = CreateHandler(methodType);
7273

7374
// Act
@@ -84,7 +85,7 @@ public async Task MaxRequestBodySizeFeature_FeatureIsReadOnly_FailureLogged()
8485
var testSink = new TestSink();
8586
var testLoggerFactory = new TestLoggerFactory(testSink, true);
8687

87-
var httpContext = CreateContext(isMaxRequestBodySizeFeatureReadOnly: true);
88+
var httpContext = HttpContextHelpers.CreateContext(isMaxRequestBodySizeFeatureReadOnly: true);
8889
var call = CreateHandler(MethodType.ClientStreaming, testLoggerFactory);
8990

9091
// Act
@@ -102,7 +103,7 @@ public async Task ContentTypeValidation_InvalidContentType_FailureLogged()
102103
var testSink = new TestSink();
103104
var testLoggerFactory = new TestLoggerFactory(testSink, true);
104105

105-
var httpContext = CreateContext(contentType: "text/plain");
106+
var httpContext = HttpContextHelpers.CreateContext(contentType: "text/plain");
106107
var call = CreateHandler(MethodType.ClientStreaming, testLoggerFactory);
107108

108109
// Act
@@ -121,7 +122,7 @@ public async Task SetResponseTrailers_FeatureMissing_ThrowError()
121122
var testSink = new TestSink();
122123
var testLoggerFactory = new TestLoggerFactory(testSink, true);
123124

124-
var httpContext = CreateContext(skipTrailerFeatureSet: true);
125+
var httpContext = HttpContextHelpers.CreateContext(skipTrailerFeatureSet: true);
125126
var call = CreateHandler(MethodType.ClientStreaming, testLoggerFactory);
126127

127128
// Act
@@ -138,7 +139,7 @@ public async Task ProtocolValidation_InvalidProtocol_FailureLogged()
138139
var testSink = new TestSink();
139140
var testLoggerFactory = new TestLoggerFactory(testSink, true);
140141

141-
var httpContext = CreateContext(protocol: "HTTP/1.1");
142+
var httpContext = HttpContextHelpers.CreateContext(protocol: "HTTP/1.1");
142143
var call = CreateHandler(MethodType.ClientStreaming, testLoggerFactory);
143144

144145
// Act
@@ -157,7 +158,7 @@ public async Task ProtocolValidation_IISHttp2Protocol_Success()
157158
var testSink = new TestSink();
158159
var testLoggerFactory = new TestLoggerFactory(testSink, true);
159160

160-
var httpContext = CreateContext(protocol: GrpcProtocolConstants.Http20Protocol);
161+
var httpContext = HttpContextHelpers.CreateContext(protocol: GrpcProtocolConstants.Http20Protocol);
161162
var call = CreateHandler(MethodType.ClientStreaming, testLoggerFactory);
162163

163164
// Act
@@ -210,125 +211,6 @@ private static ServerCallHandlerBase<TestService, TestMessage, TestMessage> Crea
210211
throw new ArgumentException();
211212
}
212213
}
213-
214-
private static HttpContext CreateContext(
215-
bool isMaxRequestBodySizeFeatureReadOnly = false,
216-
bool skipTrailerFeatureSet = false,
217-
string? protocol = null,
218-
string? contentType = null)
219-
{
220-
var httpContext = new DefaultHttpContext();
221-
var responseFeature = new TestHttpResponseFeature();
222-
var responseBodyFeature = new TestHttpResponseBodyFeature(httpContext.Features.Get<IHttpResponseBodyFeature>(), responseFeature);
223-
224-
httpContext.Request.Protocol = protocol ?? GrpcProtocolConstants.Http2Protocol;
225-
httpContext.Request.ContentType = contentType ?? GrpcProtocolConstants.GrpcContentType;
226-
httpContext.Features.Set<IHttpMinRequestBodyDataRateFeature>(new TestMinRequestBodyDataRateFeature());
227-
httpContext.Features.Set<IHttpMaxRequestBodySizeFeature>(new TestMaxRequestBodySizeFeature(isMaxRequestBodySizeFeatureReadOnly, 100));
228-
httpContext.Features.Set<IHttpResponseFeature>(responseFeature);
229-
httpContext.Features.Set<IHttpResponseBodyFeature>(responseBodyFeature);
230-
if (!skipTrailerFeatureSet)
231-
{
232-
httpContext.Features.Set<IHttpResponseTrailersFeature>(new TestHttpResponseTrailersFeature());
233-
}
234-
235-
return httpContext;
236-
}
237-
}
238-
239-
public class TestService { }
240-
241-
public class TestMessage { }
242-
243-
public class TestHttpResponseBodyFeature : IHttpResponseBodyFeature
244-
{
245-
private readonly IHttpResponseBodyFeature _innerResponseBodyFeature;
246-
private readonly TestHttpResponseFeature _responseFeature;
247-
248-
public Stream Stream => _innerResponseBodyFeature.Stream;
249-
public PipeWriter Writer => _innerResponseBodyFeature.Writer;
250-
251-
public TestHttpResponseBodyFeature(IHttpResponseBodyFeature innerResponseBodyFeature, TestHttpResponseFeature responseFeature)
252-
{
253-
_innerResponseBodyFeature = innerResponseBodyFeature ?? throw new ArgumentNullException(nameof(innerResponseBodyFeature));
254-
_responseFeature = responseFeature ?? throw new ArgumentNullException(nameof(responseFeature));
255-
}
256-
257-
public Task CompleteAsync()
258-
{
259-
return _innerResponseBodyFeature.CompleteAsync();
260-
}
261-
262-
public void DisableBuffering()
263-
{
264-
_innerResponseBodyFeature.DisableBuffering();
265-
}
266-
267-
public Task SendFileAsync(string path, long offset, long? count, CancellationToken cancellationToken = default)
268-
{
269-
return _innerResponseBodyFeature.SendFileAsync(path, offset, count, cancellationToken);
270-
}
271-
272-
public Task StartAsync(CancellationToken cancellationToken = default)
273-
{
274-
_responseFeature.HasStarted = true;
275-
return _innerResponseBodyFeature.StartAsync(cancellationToken);
276-
}
277-
}
278-
279-
public class TestHttpResponseFeature : IHttpResponseFeature
280-
{
281-
public Stream Body { get; set; }
282-
public bool HasStarted { get; internal set; }
283-
public IHeaderDictionary Headers { get; set; }
284-
public string? ReasonPhrase { get; set; }
285-
public int StatusCode { get; set; }
286-
287-
public TestHttpResponseFeature()
288-
{
289-
StatusCode = 200;
290-
Headers = new HeaderDictionary();
291-
Body = Stream.Null;
292-
}
293-
294-
public void OnCompleted(Func<object, Task> callback, object state)
295-
{
296-
}
297-
298-
public void OnStarting(Func<object, Task> callback, object state)
299-
{
300-
HasStarted = true;
301-
}
302-
}
303-
304-
public class TestMinRequestBodyDataRateFeature : IHttpMinRequestBodyDataRateFeature
305-
{
306-
public MinDataRate MinDataRate { get; set; } = new MinDataRate(1, TimeSpan.FromSeconds(5));
307-
}
308-
309-
public class TestMaxRequestBodySizeFeature : IHttpMaxRequestBodySizeFeature
310-
{
311-
public TestMaxRequestBodySizeFeature(bool isReadOnly, long? maxRequestBodySize)
312-
{
313-
IsReadOnly = isReadOnly;
314-
MaxRequestBodySize = maxRequestBodySize;
315-
}
316-
317-
public bool IsReadOnly { get; }
318-
public long? MaxRequestBodySize { get; set; }
319-
}
320-
321-
internal class TestGrpcServiceActivator<TGrpcService> : IGrpcServiceActivator<TGrpcService> where TGrpcService : class, new()
322-
{
323-
public GrpcActivatorHandle<TGrpcService> Create(IServiceProvider serviceProvider)
324-
{
325-
return new GrpcActivatorHandle<TGrpcService>(new TGrpcService(), false, null);
326-
}
327-
328-
public ValueTask ReleaseAsync(GrpcActivatorHandle<TGrpcService> service)
329-
{
330-
return default;
331-
}
332214
}
333215

334216
public class TestServiceProvider : IServiceProvider
@@ -340,4 +222,8 @@ public object GetService(Type serviceType)
340222
throw new NotImplementedException();
341223
}
342224
}
225+
226+
public class TestService { }
227+
228+
public class TestMessage { }
343229
}

0 commit comments

Comments
 (0)