diff --git a/src/SignalR/common/Shared/MessageBuffer.cs b/src/SignalR/common/Shared/MessageBuffer.cs index 17b9ae170fe0..f08fff86aa40 100644 --- a/src/SignalR/common/Shared/MessageBuffer.cs +++ b/src/SignalR/common/Shared/MessageBuffer.cs @@ -121,15 +121,16 @@ private async Task RunTimer() public ValueTask WriteAsync(SerializedHubMessage hubMessage, CancellationToken cancellationToken) { - return WriteAsyncCore(hubMessage.Message!, hubMessage.GetSerializedMessage(_protocol), cancellationToken); + // Default to HubInvocationMessage as that's the only type we use SerializedHubMessage for currently when Message is null. Should harden this in the future. + return WriteAsyncCore(hubMessage.Message?.GetType() ?? typeof(HubInvocationMessage), hubMessage.GetSerializedMessage(_protocol), cancellationToken); } public ValueTask WriteAsync(HubMessage hubMessage, CancellationToken cancellationToken) { - return WriteAsyncCore(hubMessage, _protocol.GetMessageBytes(hubMessage), cancellationToken); + return WriteAsyncCore(hubMessage.GetType(), _protocol.GetMessageBytes(hubMessage), cancellationToken); } - private async ValueTask WriteAsyncCore(HubMessage hubMessage, ReadOnlyMemory messageBytes, CancellationToken cancellationToken) + private async ValueTask WriteAsyncCore(Type hubMessageType, ReadOnlyMemory messageBytes, CancellationToken cancellationToken) { // TODO: Add backpressure based on message count if (_bufferedByteCount > _bufferLimit) @@ -158,7 +159,7 @@ private async ValueTask WriteAsyncCore(HubMessage hubMessage, ReadO await _writeLock.WaitAsync(cancellationToken: default).ConfigureAwait(false); try { - if (hubMessage is HubInvocationMessage invocationMessage) + if (typeof(HubInvocationMessage).IsAssignableFrom(hubMessageType)) { _totalMessageCount++; _bufferedByteCount += messageBytes.Length; diff --git a/src/SignalR/server/Core/src/SerializedHubMessage.cs b/src/SignalR/server/Core/src/SerializedHubMessage.cs index e355b0329128..9f4327a4cc58 100644 --- a/src/SignalR/server/Core/src/SerializedHubMessage.cs +++ b/src/SignalR/server/Core/src/SerializedHubMessage.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; using Microsoft.AspNetCore.SignalR.Protocol; namespace Microsoft.AspNetCore.SignalR; @@ -40,6 +41,8 @@ public SerializedHubMessage(IReadOnlyList messages) /// The hub message for the cache. This will be serialized with an in to get the message's serialized representation. public SerializedHubMessage(HubMessage message) { + // Type currently only used for invocation messages, we should probably refactor it to be explicit about that e.g. new property for message type? + Debug.Assert(message.GetType().IsAssignableTo(typeof(HubInvocationMessage))); Message = message; } diff --git a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/Internal/MessageBufferTests.cs b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/Internal/MessageBufferTests.cs index a7e87d2b8bda..540ea462e199 100644 --- a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/Internal/MessageBufferTests.cs +++ b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/Internal/MessageBufferTests.cs @@ -2,11 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.IO.Pipelines; +using System.Text.Json; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.InternalTesting; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; -using Microsoft.AspNetCore.InternalTesting; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Time.Testing; @@ -169,6 +170,62 @@ public async Task UnAckedMessageResentOnReconnect() Assert.False(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null))); } + // Regression test for https://github.com/dotnet/aspnetcore/issues/55575 + [Fact] + public async Task UnAckedSerializedMessageResentOnReconnect() + { + var protocol = new JsonHubProtocol(); + var connection = new TestConnectionContext(); + var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + connection.Transport = pipes.Transport; + using var messageBuffer = new MessageBuffer(connection, protocol, bufferLimit: 1000, NullLogger.Instance); + + var invocationMessage = new SerializedHubMessage([new SerializedMessage(protocol.Name, + protocol.GetMessageBytes(new InvocationMessage("method1", [1])))]); + await messageBuffer.WriteAsync(invocationMessage, default); + + var res = await pipes.Application.Input.ReadAsync(); + + var buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message)); + var parsedMessage = Assert.IsType(message); + Assert.Equal("method1", parsedMessage.Target); + Assert.Equal(1, ((JsonElement)Assert.Single(parsedMessage.Arguments)).GetInt32()); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + DuplexPipe.UpdateConnectionPair(ref pipes, connection); + await messageBuffer.ResendAsync(pipes.Transport.Output); + + Assert.True(messageBuffer.ShouldProcessMessage(PingMessage.Instance)); + Assert.True(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null))); + Assert.True(messageBuffer.ShouldProcessMessage(new SequenceMessage(1))); + + res = await pipes.Application.Input.ReadAsync(); + + buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message)); + var seqMessage = Assert.IsType(message); + Assert.Equal(1, seqMessage.SequenceId); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + res = await pipes.Application.Input.ReadAsync(); + + buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message)); + parsedMessage = Assert.IsType(message); + Assert.Equal("method1", parsedMessage.Target); + Assert.Equal(1, ((JsonElement)Assert.Single(parsedMessage.Arguments)).GetInt32()); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + messageBuffer.ShouldProcessMessage(new SequenceMessage(1)); + + Assert.True(messageBuffer.ShouldProcessMessage(PingMessage.Instance)); + Assert.False(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null))); + } + [Fact] public async Task AckedMessageNotResentOnReconnect() { diff --git a/src/SignalR/server/StackExchangeRedis/test/RedisEndToEnd.cs b/src/SignalR/server/StackExchangeRedis/test/RedisEndToEnd.cs index f191334235f1..2b48f26457d2 100644 --- a/src/SignalR/server/StackExchangeRedis/test/RedisEndToEnd.cs +++ b/src/SignalR/server/StackExchangeRedis/test/RedisEndToEnd.cs @@ -1,17 +1,15 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Collections.Generic; -using System.Threading.Tasks; +using System.Net.WebSockets; using Microsoft.AspNetCore.Http.Connections; +using Microsoft.AspNetCore.Http.Connections.Client; +using Microsoft.AspNetCore.InternalTesting; using Microsoft.AspNetCore.SignalR.Client; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.SignalR.Tests; -using Microsoft.AspNetCore.InternalTesting; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; -using Xunit; namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests; @@ -213,7 +211,105 @@ public async Task CanSendAndReceiveUserMessagesUserNameWithPatternIsTreatedAsLit } } - private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory, string userName = null) + [ConditionalTheory] + [SkipIfDockerNotPresent] + [InlineData("messagepack")] + [InlineData("json")] + public async Task StatefulReconnectPreservesMessageFromOtherServer(string protocolName) + { + using (StartVerifiableLog()) + { + var protocol = HubProtocolHelpers.GetHubProtocol(protocolName); + + ClientWebSocket innerWs = null; + WebSocketWrapper ws = null; + TaskCompletionSource reconnectTcs = null; + TaskCompletionSource startedReconnectTcs = null; + + var connection = CreateConnection(_serverFixture.FirstServer.Url + "/stateful", HttpTransportType.WebSockets, protocol, LoggerFactory, + customizeConnection: builder => + { + builder.WithStatefulReconnect(); + builder.Services.Configure(o => + { + // Replace the websocket creation for the first connection so we can make the client think there was an ungraceful closure + // Which will trigger the stateful reconnect flow + o.WebSocketFactory = async (context, token) => + { + if (reconnectTcs is null) + { + reconnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + startedReconnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + } + else + { + startedReconnectTcs.SetResult(); + // We only want to wait on the reconnect, not the initial connection attempt + await reconnectTcs.Task.DefaultTimeout(); + } + + innerWs = new ClientWebSocket(); + ws = new WebSocketWrapper(innerWs); + await innerWs.ConnectAsync(context.Uri, token); + + _ = Task.Run(async () => + { + try + { + while (innerWs.State == WebSocketState.Open) + { + var buffer = new byte[1024]; + var res = await innerWs.ReceiveAsync(buffer, default); + ws.SetReceiveResult((res, buffer.AsMemory(0, res.Count))); + } + } + // Log but ignore receive errors, that likely just means the connection closed + catch (Exception ex) + { + Logger.LogInformation(ex, "Error while reading from inner websocket"); + } + }); + + return ws; + }; + }); + }); + var secondConnection = CreateConnection(_serverFixture.SecondServer.Url + "/stateful", HttpTransportType.WebSockets, protocol, LoggerFactory); + + var tcs = new TaskCompletionSource(); + connection.On("SendToAll", message => tcs.TrySetResult(message)); + + var tcs2 = new TaskCompletionSource(); + secondConnection.On("SendToAll", message => tcs2.TrySetResult(message)); + + await connection.StartAsync().DefaultTimeout(); + await secondConnection.StartAsync().DefaultTimeout(); + + // Close first connection before the second connection sends a message to all clients + await ws.CloseOutputAsync(WebSocketCloseStatus.InternalServerError, statusDescription: null, default); + await startedReconnectTcs.Task.DefaultTimeout(); + + // Send to all clients, since both clients are on different servers this means the backplane will be used + // And we want to test that messages are still preserved for stateful reconnect purposes when a client disconnects + // But is on a different server from the original message sender. + await secondConnection.SendAsync("SendToAll", "test message").DefaultTimeout(); + + // Check that second connection still receives the message + Assert.Equal("test message", await tcs2.Task.DefaultTimeout()); + Assert.False(tcs.Task.IsCompleted); + + // allow first connection to reconnect + reconnectTcs.SetResult(); + + // Check that first connection received the message once it reconnected + Assert.Equal("test message", await tcs.Task.DefaultTimeout()); + + await connection.DisposeAsync().DefaultTimeout(); + } + } + + private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory, string userName = null, + Action customizeConnection = null) { var hubConnectionBuilder = new HubConnectionBuilder() .WithLoggerFactory(loggerFactory) @@ -227,6 +323,8 @@ private static HubConnection CreateConnection(string url, HttpTransportType tran hubConnectionBuilder.Services.AddSingleton(protocol); + customizeConnection?.Invoke(hubConnectionBuilder); + return hubConnectionBuilder.Build(); } @@ -255,4 +353,67 @@ public static IEnumerable TransportTypesAndProtocolTypes } } } + + internal sealed class WebSocketWrapper : WebSocket + { + private readonly WebSocket _inner; + private TaskCompletionSource<(WebSocketReceiveResult, ReadOnlyMemory)> _receiveTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + + public WebSocketWrapper(WebSocket inner) + { + _inner = inner; + } + + public override WebSocketCloseStatus? CloseStatus => _inner.CloseStatus; + + public override string CloseStatusDescription => _inner.CloseStatusDescription; + + public override WebSocketState State => _inner.State; + + public override string SubProtocol => _inner.SubProtocol; + + public override void Abort() + { + _inner.Abort(); + } + + public override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) + { + return _inner.CloseAsync(closeStatus, statusDescription, cancellationToken); + } + + public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) + { + _receiveTcs.TrySetException(new IOException("force reconnect")); + return Task.CompletedTask; + } + + public override void Dispose() + { + _inner.Dispose(); + } + + public void SetReceiveResult((WebSocketReceiveResult, ReadOnlyMemory) result) + { + _receiveTcs.SetResult(result); + } + + public override async Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) + { + var res = await _receiveTcs.Task; + // Handle zero-byte reads + if (buffer.Count == 0) + { + return res.Item1; + } + _receiveTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + res.Item2.CopyTo(buffer); + return res.Item1; + } + + public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + { + return _inner.SendAsync(buffer, messageType, endOfMessage, cancellationToken); + } + } } diff --git a/src/SignalR/server/StackExchangeRedis/test/Startup.cs b/src/SignalR/server/StackExchangeRedis/test/Startup.cs index 3fd461aed98e..1b55bd1cff53 100644 --- a/src/SignalR/server/StackExchangeRedis/test/Startup.cs +++ b/src/SignalR/server/StackExchangeRedis/test/Startup.cs @@ -33,6 +33,7 @@ public void Configure(IApplicationBuilder app) app.UseEndpoints(endpoints => { endpoints.MapHub("/echo"); + endpoints.MapHub("/stateful", o => o.AllowStatefulReconnects = true); }); } diff --git a/src/SignalR/server/StackExchangeRedis/test/StatefulHub.cs b/src/SignalR/server/StackExchangeRedis/test/StatefulHub.cs new file mode 100644 index 000000000000..1efa1d84fcd0 --- /dev/null +++ b/src/SignalR/server/StackExchangeRedis/test/StatefulHub.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests; + +public class StatefulHub : Hub +{ + public Task SendToAll(string message) + { + return Clients.All.SendAsync("SendToAll", message); + } +}