diff --git a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx index 401be8dc707616..e3d5079aafc0bc 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx @@ -129,4 +129,10 @@ The WebSocket failed to negotiate max client window bits. The client requested {0} but the server responded with {1}. + + UseDefaultCredentials, Credentials, Proxy, ClientCertificates, RemoteCertificateValidationCallback and Cookies must not be set on ClientWebSocketOptions when an HttpMessageInvoker instance is also specified. These options should be set on the HttpMessageInvoker's underlying HttpMessageHandler instead. + + + An HttpMessageInvoker instance must be passed to ConnectAsync when using HTTP/2. + diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs index 6a8de0a712f581..c4069223db053f 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs @@ -30,6 +30,14 @@ public sealed class ClientWebSocketOptions private HttpVersionPolicy _versionPolicy = HttpVersionPolicy.RequestVersionOrLower; private bool _collectHttpResponseDetails; + internal bool AreCompatibleWithCustomInvoker() => + !UseDefaultCredentials && + Credentials is null && + (_clientCertificates?.Count ?? 0) == 0 && + RemoteCertificateValidationCallback is null && + Cookies is null && + (Proxy is null || Proxy == WebSocketHandle.DefaultWebProxy.Instance); + internal ClientWebSocketOptions() { } // prevent external instantiation #region HTTP Settings diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index 6bb365012ef779..e27661ea29efff 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -48,9 +48,22 @@ public void Abort() public async Task ConnectAsync(Uri uri, HttpMessageInvoker? invoker, CancellationToken cancellationToken, ClientWebSocketOptions options) { bool disposeHandler = false; - invoker ??= new HttpMessageInvoker(SetupHandler(options, out disposeHandler)); - HttpResponseMessage? response = null; + if (invoker is null) + { + if (options.HttpVersion.Major >= 2 || options.HttpVersionPolicy == HttpVersionPolicy.RequestVersionOrHigher) + { + throw new ArgumentException(SR.net_WebSockets_CustomInvokerRequiredForHttp2, nameof(options)); + } + invoker = new HttpMessageInvoker(SetupHandler(options, out disposeHandler)); + } + else if (!options.AreCompatibleWithCustomInvoker()) + { + // This will not throw if the Proxy is a DefaultWebProxy. + throw new ArgumentException(SR.net_WebSockets_OptionsIncompatibleWithCustomInvoker, nameof(options)); + } + + HttpResponseMessage? response = null; bool disposeResponse = false; // force non-secure request to 1.1 whenever it is possible as HttpClient does @@ -237,12 +250,7 @@ private static SocketsHttpHandler SetupHandler(ClientWebSocketOptions options, o // Create the handler for this request and populate it with all of the options. // Try to use a shared handler rather than creating a new one just for this request, if // the options are compatible. - if (options.Credentials == null && - !options.UseDefaultCredentials && - options.Proxy == null && - options.Cookies == null && - options.RemoteCertificateValidationCallback == null && - (options._clientCertificates?.Count ?? 0) == 0) + if (options.AreCompatibleWithCustomInvoker() && options.Proxy is null) { disposeHandler = false; handler = s_defaultHandler; @@ -518,7 +526,7 @@ private static void ValidateHeader(HttpHeaders headers, string name, string expe } /// Used as a sentinel to indicate that ClientWebSocket should use the system's default proxy. - private sealed class DefaultWebProxy : IWebProxy + internal sealed class DefaultWebProxy : IWebProxy { public static DefaultWebProxy Instance { get; } = new DefaultWebProxy(); public ICredentials? Credentials { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.cs index 70cb6317f36635..d81baa53e4639b 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.cs @@ -16,14 +16,14 @@ public sealed class InvokerAbortTest : AbortTest { public InvokerAbortTest(ITestOutputHelper output) : base(output) { } - protected override HttpMessageInvoker? GetInvoker() => new HttpMessageInvoker(new SocketsHttpHandler()); + protected override bool UseCustomInvoker => true; } public sealed class HttpClientAbortTest : AbortTest { public HttpClientAbortTest(ITestOutputHelper output) : base(output) { } - protected override HttpMessageInvoker? GetInvoker() => new HttpClient(new HttpClientHandler()); + protected override bool UseHttpClient => true; } public class AbortTest : ClientWebSocketTestBase diff --git a/src/libraries/System.Net.WebSockets.Client/tests/CancelTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/CancelTest.cs index 5919f63157d705..534f62cca17d51 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/CancelTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/CancelTest.cs @@ -14,14 +14,14 @@ public sealed class InvokerCancelTest : CancelTest { public InvokerCancelTest(ITestOutputHelper output) : base(output) { } - protected override HttpMessageInvoker? GetInvoker() => new HttpMessageInvoker(new SocketsHttpHandler()); + protected override bool UseCustomInvoker => true; } public sealed class HttpClientCancelTest : CancelTest { public HttpClientCancelTest(ITestOutputHelper output) : base(output) { } - protected override HttpMessageInvoker? GetInvoker() => new HttpClient(new HttpClientHandler()); + protected override bool UseHttpClient => true; } public class CancelTest : ClientWebSocketTestBase diff --git a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs index a32587bc862fc9..0c8b94778a58ea 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; -using System.Net.Test.Common; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -10,13 +9,10 @@ using Xunit; using Xunit.Abstractions; using System.Net.Http; -using System.Net.WebSockets.Client.Tests; +using System.Diagnostics; namespace System.Net.WebSockets.Client.Tests { - /// - /// ClientWebSocket tests that do require a remote server. - /// public class ClientWebSocketTestBase { public static readonly object[][] EchoServers = System.Net.Test.Common.Configuration.WebSockets.EchoServers; @@ -112,7 +108,38 @@ protected static async Task ReceiveEntireMessageAsync(We } } - protected virtual HttpMessageInvoker? GetInvoker() => null; + protected virtual bool UseCustomInvoker => false; + + protected virtual bool UseHttpClient => false; + + protected bool UseSharedHandler => !UseCustomInvoker && !UseHttpClient; + + protected Action? ConfigureCustomHandler; + + internal HttpMessageInvoker? GetInvoker() + { + var handler = new HttpClientHandler(); + + if (PlatformDetection.IsNotBrowser) + { + handler.ServerCertificateCustomValidationCallback = HttpClientHandler.DangerousAcceptAnyServerCertificateValidator; + } + + ConfigureCustomHandler?.Invoke(handler); + + if (UseCustomInvoker) + { + Debug.Assert(!UseHttpClient); + return new HttpMessageInvoker(handler); + } + + if (UseHttpClient) + { + return new HttpClient(handler); + } + + return null; + } protected Task GetConnectedWebSocket(Uri uri, int TimeOutMilliseconds, ITestOutputHelper output) => WebSocketHelper.GetConnectedWebSocket(uri, TimeOutMilliseconds, output, invoker: GetInvoker()); diff --git a/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs index c09df107c8538c..9034b38bd43f46 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs @@ -18,14 +18,14 @@ public sealed class InvokerCloseTest : CloseTest { public InvokerCloseTest(ITestOutputHelper output) : base(output) { } - protected override HttpMessageInvoker? GetInvoker() => new HttpMessageInvoker(new SocketsHttpHandler()); + protected override bool UseCustomInvoker => true; } public sealed class HttpClientCloseTest : CloseTest { public HttpClientCloseTest(ITestOutputHelper output) : base(output) { } - protected override HttpMessageInvoker? GetInvoker() => new HttpClient(new HttpClientHandler()); + protected override bool UseHttpClient => true; } public class CloseTest : ClientWebSocketTestBase diff --git a/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.Http2.cs b/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.Http2.cs index 1ec2f1ee39fa45..bdbbcb75222fa5 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.Http2.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.Http2.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; -using System.IO; using System.Net.Http; using System.Net.Test.Common; using System.Threading; @@ -11,34 +10,59 @@ using Xunit; using Xunit.Abstractions; -using static System.Net.Http.Functional.Tests.TestHelper; - namespace System.Net.WebSockets.Client.Tests { public sealed class InvokerConnectTest_Http2 : ConnectTest_Http2 { public InvokerConnectTest_Http2(ITestOutputHelper output) : base(output) { } - protected override HttpMessageInvoker? GetInvoker() => new HttpClient(new HttpClientHandler()); + protected override bool UseCustomInvoker => true; } public sealed class HttpClientConnectTest_Http2 : ConnectTest_Http2 { public HttpClientConnectTest_Http2(ITestOutputHelper output) : base(output) { } - protected override HttpMessageInvoker? GetInvoker() => new HttpClient(new HttpClientHandler()); + protected override bool UseHttpClient => true; } - public class ConnectTest_Http2 : ClientWebSocketTestBase + public sealed class HttpClientConnectTest_Http2_NoInvoker : ClientWebSocketTestBase { - public ConnectTest_Http2(ITestOutputHelper output) : base(output) { } + public HttpClientConnectTest_Http2_NoInvoker(ITestOutputHelper output) : base(output) { } + + public static IEnumerable ConnectAsync_Http2WithNoInvoker_ThrowsArgumentException_MemberData() + { + yield return Options(options => options.HttpVersion = HttpVersion.Version20); + yield return Options(options => options.HttpVersion = HttpVersion.Version30); + yield return Options(options => options.HttpVersion = new Version(2, 1)); + yield return Options(options => options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher); + static object[] Options(Action configureOptions) => + new object[] { configureOptions }; + } [Theory] - [InlineData(false)] - [InlineData(true)] + [MemberData(nameof(ConnectAsync_Http2WithNoInvoker_ThrowsArgumentException_MemberData))] + [SkipOnPlatform(TestPlatforms.Browser, "HTTP/2 WebSockets aren't supported on Browser")] + public async Task ConnectAsync_Http2WithNoInvoker_ThrowsArgumentException(Action configureOptions) + { + using var ws = new ClientWebSocket(); + configureOptions(ws.Options); + + Task connectTask = ws.ConnectAsync(new Uri("wss://dummy"), CancellationToken.None); + + Assert.Equal(TaskStatus.Faulted, connectTask.Status); + await Assert.ThrowsAsync("options", () => connectTask); + } + } + + public abstract class ConnectTest_Http2 : ClientWebSocketTestBase + { + public ConnectTest_Http2(ITestOutputHelper output) : base(output) { } + + [Fact] [SkipOnPlatform(TestPlatforms.Browser, "System.Net.Sockets is not supported on this platform")] - public async Task ConnectAsync_VersionNotSupported_NoSsl_Throws(bool useHandler) + public async Task ConnectAsync_VersionNotSupported_NoSsl_Throws() { await Http2LoopbackServer.CreateClientAndServerAsync(async uri => { @@ -46,17 +70,10 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async uri => using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) { cws.Options.HttpVersion = HttpVersion.Version20; - cws.Options.HttpVersionPolicy = Http.HttpVersionPolicy.RequestVersionExact; - Task t; - if (useHandler) - { - var handler = new SocketsHttpHandler(); - t = cws.ConnectAsync(uri, new HttpMessageInvoker(handler), cts.Token); - } - else - { - t = cws.ConnectAsync(uri, cts.Token); - } + cws.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact; + + Task t = cws.ConnectAsync(uri, GetInvoker(), cts.Token); + var ex = await Assert.ThrowsAnyAsync(() => t); Assert.IsType(ex.InnerException); Assert.True(ex.InnerException.Data.Contains("SETTINGS_ENABLE_CONNECT_PROTOCOL")); @@ -65,8 +82,7 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async uri => async server => { Http2LoopbackConnection connection = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 0 }); - }, new Http2Options() { WebSocketEndpoint = true, UseSsl = false } - ); + }, new Http2Options() { WebSocketEndpoint = true, UseSsl = false }); } [Fact] @@ -79,10 +95,9 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async uri => using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) { cws.Options.HttpVersion = HttpVersion.Version20; - cws.Options.HttpVersionPolicy = Http.HttpVersionPolicy.RequestVersionExact; - Task t; - var handler = CreateSocketsHttpHandler(allowAllCertificates: true); - t = cws.ConnectAsync(uri, new HttpMessageInvoker(handler), cts.Token); + cws.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact; + + Task t = cws.ConnectAsync(uri, GetInvoker(), cts.Token); var ex = await Assert.ThrowsAnyAsync(() => t); Assert.IsType(ex.InnerException); @@ -92,31 +107,22 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async uri => async server => { Http2LoopbackConnection connection = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 0 }); - }, new Http2Options() { WebSocketEndpoint = true } - ); + }, new Http2Options() { WebSocketEndpoint = true }); } [OuterLoop("Uses external servers", typeof(PlatformDetection), nameof(PlatformDetection.LocalEchoServerIsNotAvailable))] - [Theory] - [MemberData(nameof(SecureEchoServersAndBoolean))] + [Fact] [SkipOnPlatform(TestPlatforms.Browser, "System.Net.Sockets is not supported on this platform")] - public async Task ConnectAsync_Http11Server_DowngradeFail(Uri server, bool useHandler) + public async Task ConnectAsync_Http11Server_DowngradeFail() { using (var cws = new ClientWebSocket()) using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) { cws.Options.HttpVersion = HttpVersion.Version20; - cws.Options.HttpVersionPolicy = Http.HttpVersionPolicy.RequestVersionExact; - Task t; - if (useHandler) - { - var handler = new SocketsHttpHandler(); - t = cws.ConnectAsync(server, new HttpMessageInvoker(handler), cts.Token); - } - else - { - t = cws.ConnectAsync(server, cts.Token); - } + cws.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact; + + Task t = cws.ConnectAsync(Test.Common.Configuration.WebSockets.SecureRemoteEchoServer, GetInvoker(), cts.Token); + var ex = await Assert.ThrowsAnyAsync(() => t); Assert.IsType(ex.InnerException); Assert.True(ex.InnerException.Data.Contains("HTTP2_ENABLED")); @@ -126,34 +132,23 @@ public async Task ConnectAsync_Http11Server_DowngradeFail(Uri server, bool useHa [OuterLoop("Uses external servers", typeof(PlatformDetection), nameof(PlatformDetection.LocalEchoServerIsNotAvailable))] [Theory] - [MemberData(nameof(EchoServersAndBoolean))] + [MemberData(nameof(EchoServers))] [SkipOnPlatform(TestPlatforms.Browser, "System.Net.Sockets is not supported on this platform")] - public async Task ConnectAsync_Http11Server_DowngradeSuccess(Uri server, bool useHandler) + public async Task ConnectAsync_Http11Server_DowngradeSuccess(Uri server) { using (var cws = new ClientWebSocket()) using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) { cws.Options.HttpVersion = HttpVersion.Version20; - cws.Options.HttpVersionPolicy = Http.HttpVersionPolicy.RequestVersionOrLower; - if (useHandler) - { - var handler = new SocketsHttpHandler(); - await cws.ConnectAsync(server, new HttpMessageInvoker(handler), cts.Token); - } - else - { - await cws.ConnectAsync(server, cts.Token); - } + cws.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionOrLower; + await cws.ConnectAsync(server, GetInvoker(), cts.Token); Assert.Equal(WebSocketState.Open, cws.State); } } - - [Theory] - [InlineData(false)] - [InlineData(true)] + [Fact] [SkipOnPlatform(TestPlatforms.Browser, "System.Net.Sockets is not supported on this platform")] - public async Task ConnectAsync_VersionSupported_NoSsl_Success(bool useHandler) + public async Task ConnectAsync_VersionSupported_NoSsl_Success() { await Http2LoopbackServer.CreateClientAndServerAsync(async uri => { @@ -161,16 +156,8 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async uri => using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) { cws.Options.HttpVersion = HttpVersion.Version20; - cws.Options.HttpVersionPolicy = Http.HttpVersionPolicy.RequestVersionExact; - if (useHandler) - { - var handler = new SocketsHttpHandler(); - await cws.ConnectAsync(uri, new HttpMessageInvoker(handler), cts.Token); - } - else - { - await cws.ConnectAsync(uri, cts.Token); - } + cws.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact; + await cws.ConnectAsync(uri, GetInvoker(), cts.Token); } }, async server => @@ -178,8 +165,7 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async uri => Http2LoopbackConnection connection = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 }); (int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false); await connection.SendResponseHeadersAsync(streamId, endStream: false, HttpStatusCode.OK); - }, new Http2Options() { WebSocketEndpoint = true, UseSsl = false } - ); + }, new Http2Options() { WebSocketEndpoint = true, UseSsl = false }); } [Fact] @@ -192,10 +178,8 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async uri => using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) { cws.Options.HttpVersion = HttpVersion.Version20; - cws.Options.HttpVersionPolicy = Http.HttpVersionPolicy.RequestVersionExact; - - var handler = CreateSocketsHttpHandler(allowAllCertificates: true); - await cws.ConnectAsync(uri, new HttpMessageInvoker(handler), cts.Token); + cws.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact; + await cws.ConnectAsync(uri, GetInvoker(), cts.Token); } }, async server => @@ -203,8 +187,39 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async uri => Http2LoopbackConnection connection = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 }); (int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false); await connection.SendResponseHeadersAsync(streamId, endStream: false, HttpStatusCode.OK); - }, new Http2Options() { WebSocketEndpoint = true } - ); + }, new Http2Options() { WebSocketEndpoint = true }); + } + + [Fact] + [SkipOnPlatform(TestPlatforms.Browser, "HTTP/2 WebSockets aren't supported on Browser")] + public async Task ConnectAsync_SameHttp2ConnectionUsedForMultipleWebSocketConnection() + { + await Http2LoopbackServer.CreateClientAndServerAsync(async uri => + { + using var cws1 = new ClientWebSocket(); + cws1.Options.HttpVersion = HttpVersion.Version20; + cws1.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact; + + using var cws2 = new ClientWebSocket(); + cws2.Options.HttpVersion = HttpVersion.Version20; + cws2.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact; + + using var cts = new CancellationTokenSource(TimeOutMilliseconds); + HttpMessageInvoker? invoker = GetInvoker(); + + await cws1.ConnectAsync(uri, invoker, cts.Token); + await cws2.ConnectAsync(uri, invoker, cts.Token); + }, + async server => + { + await using Http2LoopbackConnection connection = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 }); + + (int streamId1, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false); + await connection.SendResponseHeadersAsync(streamId1, endStream: false, HttpStatusCode.OK); + + (int streamId2, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false); + await connection.SendResponseHeadersAsync(streamId2, endStream: false, HttpStatusCode.OK); + }, new Http2Options() { WebSocketEndpoint = true, UseSsl = false }); } } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.cs index fd0eee1265cfa1..3fd63fbbb7c61a 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.cs @@ -3,12 +3,11 @@ using System.Collections.Generic; using System.IO; -using System.Linq; using System.Net.Http; using System.Net.Test.Common; +using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; - using Xunit; using Xunit.Abstractions; @@ -17,14 +16,78 @@ namespace System.Net.WebSockets.Client.Tests public sealed class InvokerConnectTest : ConnectTest { public InvokerConnectTest(ITestOutputHelper output) : base(output) { } - protected override HttpMessageInvoker? GetInvoker() => new HttpMessageInvoker(new SocketsHttpHandler()); + + protected override bool UseCustomInvoker => true; + + public static IEnumerable ConnectAsync_CustomInvokerWithIncompatibleWebSocketOptions_ThrowsArgumentException_MemberData() + { + yield return Throw(options => options.UseDefaultCredentials = true); + yield return NoThrow(options => options.UseDefaultCredentials = false); + yield return Throw(options => options.Credentials = new NetworkCredential()); + yield return Throw(options => options.Proxy = new WebProxy()); + yield return Throw(options => options.ClientCertificates.Add(Test.Common.Configuration.Certificates.GetClientCertificate())); + yield return NoThrow(options => options.ClientCertificates = new X509CertificateCollection()); + yield return Throw(options => options.RemoteCertificateValidationCallback = delegate { return true; }); + yield return Throw(options => options.Cookies = new CookieContainer()); + + // We allow no proxy or the default proxy to be used + yield return NoThrow(options => { }); + yield return NoThrow(options => options.Proxy = null); + + // These options don't conflict with the custom invoker + yield return NoThrow(options => options.HttpVersion = new Version(2, 0)); + yield return NoThrow(options => options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher); + yield return NoThrow(options => options.SetRequestHeader("foo", "bar")); + yield return NoThrow(options => options.AddSubProtocol("foo")); + yield return NoThrow(options => options.KeepAliveInterval = TimeSpan.FromSeconds(42)); + yield return NoThrow(options => options.DangerousDeflateOptions = new WebSocketDeflateOptions()); + yield return NoThrow(options => options.CollectHttpResponseDetails = true); + + static object[] Throw(Action configureOptions) => + new object[] { configureOptions, true }; + + static object[] NoThrow(Action configureOptions) => + new object[] { configureOptions, false }; + } + + [Theory] + [MemberData(nameof(ConnectAsync_CustomInvokerWithIncompatibleWebSocketOptions_ThrowsArgumentException_MemberData))] + [SkipOnPlatform(TestPlatforms.Browser, "Custom invoker is ignored on Browser")] + public async Task ConnectAsync_CustomInvokerWithIncompatibleWebSocketOptions_ThrowsArgumentException(Action configureOptions, bool shouldThrow) + { + using var invoker = new HttpMessageInvoker(new SocketsHttpHandler + { + ConnectCallback = (_, _) => ValueTask.FromException(new Exception("ConnectCallback")) + }); + + using var ws = new ClientWebSocket(); + configureOptions(ws.Options); + + Task connectTask = ws.ConnectAsync(new Uri("wss://dummy"), invoker, CancellationToken.None); + if (shouldThrow) + { + Assert.Equal(TaskStatus.Faulted, connectTask.Status); + await Assert.ThrowsAsync("options", () => connectTask); + } + else + { + WebSocketException ex = await Assert.ThrowsAsync(() => connectTask); + Assert.NotNull(ex.InnerException); + Assert.Contains("ConnectCallback", ex.InnerException.Message); + } + + foreach (X509Certificate cert in ws.Options.ClientCertificates) + { + cert.Dispose(); + } + } } public sealed class HttpClientConnectTest : ConnectTest { public HttpClientConnectTest(ITestOutputHelper output) : base(output) { } - protected override HttpMessageInvoker? GetInvoker() => new HttpClient(new HttpClientHandler()); + protected override bool UseHttpClient => true; } public class ConnectTest : ClientWebSocketTestBase @@ -258,7 +321,13 @@ public async Task ConnectAndCloseAsync_UseProxyServer_ExpectedClosedState(Uri se using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) using (LoopbackProxyServer proxyServer = LoopbackProxyServer.Create()) { - cws.Options.Proxy = new WebProxy(proxyServer.Uri); + ConfigureCustomHandler = handler => handler.Proxy = new WebProxy(proxyServer.Uri); + + if (UseSharedHandler) + { + cws.Options.Proxy = new WebProxy(proxyServer.Uri); + } + await ConnectAsync(cws, server, cts.Token); string expectedCloseStatusDescription = "Client close status"; @@ -267,6 +336,7 @@ public async Task ConnectAndCloseAsync_UseProxyServer_ExpectedClosedState(Uri se Assert.Equal(WebSocketState.Closed, cws.State); Assert.Equal(WebSocketCloseStatus.NormalClosure, cws.CloseStatus); Assert.Equal(expectedCloseStatusDescription, cws.CloseStatusDescription); + Assert.Equal(1, proxyServer.Connections); } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs index 262e45ae414db8..9836f31df16cb7 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs @@ -18,14 +18,14 @@ public sealed class InvokerDeflateTests : DeflateTests { public InvokerDeflateTests(ITestOutputHelper output) : base(output) { } - protected override HttpMessageInvoker? GetInvoker() => new HttpMessageInvoker(new SocketsHttpHandler()); + protected override bool UseCustomInvoker => true; } public sealed class HttpClientDeflateTests : DeflateTests { public HttpClientDeflateTests(ITestOutputHelper output) : base(output) { } - protected override HttpMessageInvoker? GetInvoker() => new HttpClient(new HttpClientHandler()); + protected override bool UseHttpClient => true; } [PlatformSpecific(~TestPlatforms.Browser)] diff --git a/src/libraries/System.Net.WebSockets.Client/tests/SendReceiveTest.Http2.cs b/src/libraries/System.Net.WebSockets.Client/tests/SendReceiveTest.Http2.cs index 5f3be83d5bfb79..ef21a36e44fa8f 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/SendReceiveTest.Http2.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/SendReceiveTest.Http2.cs @@ -3,7 +3,6 @@ using System.Linq; using System.Net.Http; -using System.Net.Sockets; using System.Net.Test.Common; using System.Threading; using System.Threading.Tasks; @@ -11,19 +10,29 @@ using Xunit; using Xunit.Abstractions; -using static System.Net.Http.Functional.Tests.TestHelper; - namespace System.Net.WebSockets.Client.Tests { - public class SendReceiveTest_Http2 : ClientWebSocketTestBase + public sealed class HttpClientSendReceiveTest_Http2 : SendReceiveTest_Http2 + { + public HttpClientSendReceiveTest_Http2(ITestOutputHelper output) : base(output) { } + + protected override bool UseHttpClient => true; + } + + public sealed class InvokerSendReceiveTest_Http2 : SendReceiveTest_Http2 + { + public InvokerSendReceiveTest_Http2(ITestOutputHelper output) : base(output) { } + + protected override bool UseCustomInvoker => true; + } + + public abstract class SendReceiveTest_Http2 : ClientWebSocketTestBase { public SendReceiveTest_Http2(ITestOutputHelper output) : base(output) { } - [Theory] - [InlineData(false)] - [InlineData(true)] + [Fact] [SkipOnPlatform(TestPlatforms.Browser, "System.Net.Sockets is not supported on this platform")] - public async Task ReceiveNoThrowAfterSend_NoSsl(bool useHandler) + public async Task ReceiveNoThrowAfterSend_NoSsl() { var serverMessage = new byte[] { 4, 5, 6 }; await Http2LoopbackServer.CreateClientAndServerAsync(async uri => @@ -32,16 +41,9 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async uri => using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) { cws.Options.HttpVersion = HttpVersion.Version20; - cws.Options.HttpVersionPolicy = Http.HttpVersionPolicy.RequestVersionExact; - if (useHandler) - { - var handler = new SocketsHttpHandler(); - await cws.ConnectAsync(uri, new HttpMessageInvoker(handler), cts.Token); - } - else - { - await cws.ConnectAsync(uri, cts.Token); - } + cws.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact; + + await cws.ConnectAsync(uri, GetInvoker(), cts.Token); await cws.SendAsync(new byte[] { 2, 3, 4 }, WebSocketMessageType.Binary, true, cts.Token); @@ -63,8 +65,7 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async uri => byte[] constructMessage = prefix.Concat(serverMessage).ToArray(); await connection.SendResponseDataAsync(streamId, constructMessage, endStream: false); - }, new Http2Options() { WebSocketEndpoint = true, UseSsl = false } - ); + }, new Http2Options() { WebSocketEndpoint = true, UseSsl = false }); } [Fact] @@ -78,10 +79,9 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async uri => using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) { cws.Options.HttpVersion = HttpVersion.Version20; - cws.Options.HttpVersionPolicy = Http.HttpVersionPolicy.RequestVersionExact; + cws.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact; - var handler = CreateSocketsHttpHandler(allowAllCertificates: true); - await cws.ConnectAsync(uri, new HttpMessageInvoker(handler), cts.Token); + await cws.ConnectAsync(uri, GetInvoker(), cts.Token); await cws.SendAsync(new byte[] { 2, 3, 4 }, WebSocketMessageType.Binary, true, cts.Token); @@ -103,8 +103,7 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async uri => byte[] constructMessage = prefix.Concat(serverMessage).ToArray(); await connection.SendResponseDataAsync(streamId, constructMessage, endStream: false); - }, new Http2Options() { WebSocketEndpoint = true } - ); + }, new Http2Options() { WebSocketEndpoint = true }); } } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/SendReceiveTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/SendReceiveTest.cs index 6597b6f9ec6315..ec3913c02c16c9 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/SendReceiveTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/SendReceiveTest.cs @@ -17,28 +17,28 @@ public sealed class InvokerMemorySendReceiveTest : MemorySendReceiveTest { public InvokerMemorySendReceiveTest(ITestOutputHelper output) : base(output) { } - protected override HttpMessageInvoker? GetInvoker() => new HttpMessageInvoker(new SocketsHttpHandler()); + protected override bool UseCustomInvoker => true; } public sealed class HttpClientMemorySendReceiveTest : MemorySendReceiveTest { public HttpClientMemorySendReceiveTest(ITestOutputHelper output) : base(output) { } - protected override HttpMessageInvoker? GetInvoker() => new HttpClient(new HttpClientHandler()); + protected override bool UseHttpClient => true; } public sealed class InvokerArraySegmentSendReceiveTest : ArraySegmentSendReceiveTest { public InvokerArraySegmentSendReceiveTest(ITestOutputHelper output) : base(output) { } - protected override HttpMessageInvoker? GetInvoker() => new HttpMessageInvoker(new SocketsHttpHandler()); + protected override bool UseCustomInvoker => true; } public sealed class HttpClientArraySegmentSendReceiveTest : ArraySegmentSendReceiveTest { public HttpClientArraySegmentSendReceiveTest(ITestOutputHelper output) : base(output) { } - protected override HttpMessageInvoker? GetInvoker() => new HttpClient(new HttpClientHandler()); + protected override bool UseHttpClient => true; } public class MemorySendReceiveTest : SendReceiveTest