From 50ac1c66f3974fb47520a54a1e9b761e5cbe875b Mon Sep 17 00:00:00 2001 From: wfurt Date: Wed, 7 Feb 2024 14:29:43 -0800 Subject: [PATCH 1/2] fix SendTo with SocketAsyncEventArgs --- .../src/System/Net/Sockets/Socket.Tasks.cs | 12 ++++- .../src/System/Net/Sockets/Socket.cs | 21 ++++++--- .../Net/Sockets/SocketAsyncEventArgs.cs | 7 ++- .../tests/FunctionalTests/SendTo.cs | 29 ++++++++++++ .../SocketAsyncEventArgsTest.cs | 47 +++++++++++++++++++ 5 files changed, 107 insertions(+), 9 deletions(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs index e1891bef916f4d..3ba24e90cf1019 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs @@ -677,7 +677,6 @@ public ValueTask SendToAsync(ReadOnlyMemory buffer, SocketFlags socke Debug.Assert(saea.BufferList == null); saea.SetBuffer(MemoryMarshal.AsMemory(buffer)); saea.SocketFlags = socketFlags; - saea._socketAddress = null; saea.RemoteEndPoint = remoteEP; saea.WrapExceptionsForNetworkStream = false; return saea.SendToAsync(this, cancellationToken); @@ -709,8 +708,17 @@ public ValueTask SendToAsync(ReadOnlyMemory buffer, SocketFlags socke saea.SetBuffer(MemoryMarshal.AsMemory(buffer)); saea.SocketFlags = socketFlags; saea._socketAddress = socketAddress; + saea.RemoteEndPoint = null; saea.WrapExceptionsForNetworkStream = false; - return saea.SendToAsync(this, cancellationToken); + try + { + return saea.SendToAsync(this, cancellationToken); + } + finally + { + // detach user provided SA so we do not accidentally stomp on it later. + saea._socketAddress = null; + } } /// diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 11b8674d681f38..49f8b13acb44d2 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -3095,14 +3095,23 @@ private bool SendToAsync(SocketAsyncEventArgs e, CancellationToken cancellationT ArgumentNullException.ThrowIfNull(e); EndPoint? endPointSnapshot = e.RemoteEndPoint; - if (e._socketAddress == null) + + // RemoteEndPoint should be set unless somebody used SendTo with their own SA. + // In that case RemoteEndPoint will be null and we take provided SA as given. + if (endPointSnapshot == null && e._socketAddress == null) { - if (endPointSnapshot == null) - { - throw new ArgumentException(SR.Format(SR.InvalidNullArgument, "e.RemoteEndPoint"), nameof(e)); - } + throw new ArgumentException(SR.Format(SR.InvalidNullArgument, "e.RemoteEndPoint"), nameof(e)); + } - // Prepare SocketAddress + if (e._socketAddress != null && endPointSnapshot is IPEndPoint && e._socketAddress.Family == endPointSnapshot?.AddressFamily) + { + // we have matching SocketAddress. Since this is only used internally, it should be ok to override it without + // allocating new one. + ((IPEndPoint)endPointSnapshot).Serialize(e._socketAddress.Buffer.Span); + } + else if (endPointSnapshot != null) + { + // Prepare new SocketAddress e._socketAddress = Serialize(ref endPointSnapshot); } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs index e94d862571a0f8..78dd22e5eda7bf 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs @@ -923,7 +923,12 @@ internal void FinishOperationSyncSuccess(int bytesTransferred, SocketFlags flags case SocketAsyncOperation.ReceiveFrom: // Deal with incoming address. UpdateReceivedSocketAddress(_socketAddress!); - if (_remoteEndPoint != null && !SocketAddressExtensions.Equals(_socketAddress!, _remoteEndPoint)) + if (_remoteEndPoint == null) + { + // detach user provided SA as it was updated in place. + _socketAddress = null; + } + else if (!SocketAddressExtensions.Equals(_socketAddress!, _remoteEndPoint)) { try { diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs index bf0ad146588699..7a3c33b64bf796 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs @@ -173,6 +173,35 @@ public void SendToAsync_NullAsyncEventArgs_Throws_ArgumentNullException() public sealed class SendTo_Task : SendTo { public SendTo_Task(ITestOutputHelper output) : base(output) { } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task SendTo_DifferentEP_Success(bool ipv4) + { + IPAddress address = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback; + IPEndPoint remoteEp = new IPEndPoint(address, 0); + + using Socket receiver1 = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + using Socket receiver2 = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + using Socket sender = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + + receiver1.BindToAnonymousPort(address); + receiver2.BindToAnonymousPort(address); + + byte[] sendBuffer = new byte[32]; + var receiveInternalBuffer = new byte[sendBuffer.Length]; + ArraySegment receiveBuffer = new ArraySegment(receiveInternalBuffer, 0, receiveInternalBuffer.Length); + + + await sender.SendToAsync(sendBuffer, SocketFlags.None, receiver1.LocalEndPoint); + SocketReceiveFromResult result = await ReceiveFromAsync(receiver1, receiveBuffer, remoteEp).WaitAsync(TestSettings.PassingTestTimeout); + Assert.Equal(sendBuffer.Length, result.ReceivedBytes); + + await sender.SendToAsync(sendBuffer, SocketFlags.None, receiver2.LocalEndPoint); + result = await ReceiveFromAsync(receiver2, receiveBuffer, remoteEp).WaitAsync(TestSettings.PassingTestTimeout); + Assert.Equal(sendBuffer.Length, result.ReceivedBytes); + } } public sealed class SendTo_CancellableTask : SendTo diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs index ded34276f322fa..3d865cb864570f 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs @@ -895,5 +895,52 @@ void CreateSocketAsyncEventArgs() // separated out so that JIT doesn't extend li return cwt.Count() == 0; // validate that the cwt becomes empty }, 30_000)); } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task SendTo_DifferentEP_Success(bool ipv4) + { + IPAddress address = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback; + IPEndPoint remoteEp = new IPEndPoint(address, 0); + + using Socket receiver1 = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + using Socket receiver2 = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + using Socket sender = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + + receiver1.BindToAnonymousPort(address); + receiver2.BindToAnonymousPort(address); + + byte[] sendBuffer = new byte[32]; + var receiveInternalBuffer = new byte[sendBuffer.Length]; + ArraySegment receiveBuffer = new ArraySegment(receiveInternalBuffer, 0, receiveInternalBuffer.Length); + + using SocketAsyncEventArgs saea = new SocketAsyncEventArgs(); + ManualResetEventSlim mres = new ManualResetEventSlim(false); + + saea.SetBuffer(sendBuffer); + saea.RemoteEndPoint = receiver1.LocalEndPoint; + saea.Completed += delegate { mres.Set(); }; + if (sender.SendToAsync(saea)) + { + // did not finish synchronously. + mres.Wait(); + } + + SocketReceiveFromResult result = await receiver1.ReceiveFromAsync(receiveBuffer, remoteEp).WaitAsync(TestSettings.PassingTestTimeout); + Assert.Equal(sendBuffer.Length, result.ReceivedBytes); + mres.Reset(); + + + saea.RemoteEndPoint = receiver2.LocalEndPoint; + if (sender.SendToAsync(saea)) + { + // did not finish synchronously. + mres.Wait(); + } + + result = await receiver2.ReceiveFromAsync(receiveBuffer, remoteEp).WaitAsync(TestSettings.PassingTestTimeout); + Assert.Equal(sendBuffer.Length, result.ReceivedBytes); + } } } From 412b66f07efb7e256a986c015aeee23e3822484f Mon Sep 17 00:00:00 2001 From: wfurt Date: Mon, 26 Feb 2024 13:55:07 -0800 Subject: [PATCH 2/2] feedback --- .../System.Net.Sockets/src/System/Net/Sockets/Socket.cs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 49f8b13acb44d2..a8c95005154c93 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -3103,11 +3103,10 @@ private bool SendToAsync(SocketAsyncEventArgs e, CancellationToken cancellationT throw new ArgumentException(SR.Format(SR.InvalidNullArgument, "e.RemoteEndPoint"), nameof(e)); } - if (e._socketAddress != null && endPointSnapshot is IPEndPoint && e._socketAddress.Family == endPointSnapshot?.AddressFamily) + if (e._socketAddress != null && endPointSnapshot is IPEndPoint ipep && e._socketAddress.Family == endPointSnapshot?.AddressFamily) { - // we have matching SocketAddress. Since this is only used internally, it should be ok to override it without - // allocating new one. - ((IPEndPoint)endPointSnapshot).Serialize(e._socketAddress.Buffer.Span); + // we have matching SocketAddress. Since this is only used internally, it is ok to overwrite it without + ipep.Serialize(e._socketAddress.Buffer.Span); } else if (endPointSnapshot != null) {