diff --git a/src/Servers/Kestrel/Core/src/AnyIPListenOptions.cs b/src/Servers/Kestrel/Core/src/AnyIPListenOptions.cs index fe5e2bb98159..8168166fd1c2 100644 --- a/src/Servers/Kestrel/Core/src/AnyIPListenOptions.cs +++ b/src/Servers/Kestrel/Core/src/AnyIPListenOptions.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using System.IO; using System.Net; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; using Microsoft.Extensions.Logging; @@ -18,14 +19,14 @@ internal AnyIPListenOptions(int port) { } - internal override async Task BindAsync(AddressBindContext context) + internal override async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken) { Debug.Assert(IPEndPoint != null); // when address is 'http://hostname:port', 'http://*:port', or 'http://+:port' try { - await base.BindAsync(context).ConfigureAwait(false); + await base.BindAsync(context, cancellationToken).ConfigureAwait(false); } catch (Exception ex) when (!(ex is IOException)) { @@ -33,7 +34,7 @@ internal override async Task BindAsync(AddressBindContext context) // for machines that do not support IPv6 EndPoint = new IPEndPoint(IPAddress.Any, IPEndPoint.Port); - await base.BindAsync(context).ConfigureAwait(false); + await base.BindAsync(context, cancellationToken).ConfigureAwait(false); } } } diff --git a/src/Servers/Kestrel/Core/src/Internal/AddressBindContext.cs b/src/Servers/Kestrel/Core/src/Internal/AddressBindContext.cs index cb15e3e18502..09a127867b7a 100644 --- a/src/Servers/Kestrel/Core/src/Internal/AddressBindContext.cs +++ b/src/Servers/Kestrel/Core/src/Internal/AddressBindContext.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; @@ -14,7 +15,7 @@ public AddressBindContext( ServerAddressesFeature serverAddressesFeature, KestrelServerOptions serverOptions, ILogger logger, - Func createBinding) + Func createBinding) { ServerAddressesFeature = serverAddressesFeature; ServerOptions = serverOptions; @@ -28,6 +29,6 @@ public AddressBindContext( public KestrelServerOptions ServerOptions { get; } public ILogger Logger { get; } - public Func CreateBinding { get; } + public Func CreateBinding { get; } } } diff --git a/src/Servers/Kestrel/Core/src/Internal/AddressBinder.cs b/src/Servers/Kestrel/Core/src/Internal/AddressBinder.cs index 7635199bc8e0..752ba3efc494 100644 --- a/src/Servers/Kestrel/Core/src/Internal/AddressBinder.cs +++ b/src/Servers/Kestrel/Core/src/Internal/AddressBinder.cs @@ -7,6 +7,7 @@ using System.IO; using System.Linq; using System.Net; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Connections; @@ -20,7 +21,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal { internal class AddressBinder { - public static async Task BindAsync(IEnumerable listenOptions, AddressBindContext context) + public static async Task BindAsync(IEnumerable listenOptions, AddressBindContext context, CancellationToken cancellationToken) { var strategy = CreateStrategy( listenOptions.ToArray(), @@ -32,7 +33,7 @@ public static async Task BindAsync(IEnumerable listenOptions, Add context.ServerOptions.OptionsInUse.Clear(); context.Addresses.Clear(); - await strategy.BindAsync(context).ConfigureAwait(false); + await strategy.BindAsync(context, cancellationToken).ConfigureAwait(false); } private static IStrategy CreateStrategy(ListenOptions[] listenOptions, string[] addresses, bool preferAddresses) @@ -86,11 +87,11 @@ protected internal static bool TryCreateIPEndPoint(BindingAddress address, [NotN return true; } - internal static async Task BindEndpointAsync(ListenOptions endpoint, AddressBindContext context) + internal static async Task BindEndpointAsync(ListenOptions endpoint, AddressBindContext context, CancellationToken cancellationToken) { try { - await context.CreateBinding(endpoint).ConfigureAwait(false); + await context.CreateBinding(endpoint, cancellationToken).ConfigureAwait(false); } catch (AddressInUseException ex) { @@ -144,16 +145,16 @@ internal static ListenOptions ParseAddress(string address, out bool https) private interface IStrategy { - Task BindAsync(AddressBindContext context); + Task BindAsync(AddressBindContext context, CancellationToken cancellationToken); } private class DefaultAddressStrategy : IStrategy { - public async Task BindAsync(AddressBindContext context) + public async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken) { var httpDefault = ParseAddress(Constants.DefaultServerAddress, out _); context.ServerOptions.ApplyEndpointDefaults(httpDefault); - await httpDefault.BindAsync(context).ConfigureAwait(false); + await httpDefault.BindAsync(context, cancellationToken).ConfigureAwait(false); // Conditional https default, only if a cert is available var httpsDefault = ParseAddress(Constants.DefaultServerHttpsAddress, out _); @@ -161,7 +162,7 @@ public async Task BindAsync(AddressBindContext context) if (httpsDefault.IsTls || httpsDefault.TryUseHttps()) { - await httpsDefault.BindAsync(context).ConfigureAwait(false); + await httpsDefault.BindAsync(context, cancellationToken).ConfigureAwait(false); context.Logger.LogDebug(CoreStrings.BindingToDefaultAddresses, Constants.DefaultServerAddress, Constants.DefaultServerHttpsAddress); } @@ -180,12 +181,12 @@ public OverrideWithAddressesStrategy(IReadOnlyCollection addresses) { } - public override Task BindAsync(AddressBindContext context) + public override Task BindAsync(AddressBindContext context, CancellationToken cancellationToken) { var joined = string.Join(", ", _addresses); context.Logger.LogInformation(CoreStrings.OverridingWithPreferHostingUrls, nameof(IServerAddressesFeature.PreferHostingUrls), joined); - return base.BindAsync(context); + return base.BindAsync(context, cancellationToken); } } @@ -199,12 +200,12 @@ public OverrideWithEndpointsStrategy(IReadOnlyCollection endpoint _originalAddresses = originalAddresses; } - public override Task BindAsync(AddressBindContext context) + public override Task BindAsync(AddressBindContext context, CancellationToken cancellationToken) { var joined = string.Join(", ", _originalAddresses); context.Logger.LogWarning(CoreStrings.OverridingWithKestrelOptions, joined); - return base.BindAsync(context); + return base.BindAsync(context, cancellationToken); } } @@ -217,11 +218,11 @@ public EndpointsStrategy(IReadOnlyCollection endpoints) _endpoints = endpoints; } - public virtual async Task BindAsync(AddressBindContext context) + public virtual async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken) { foreach (var endpoint in _endpoints) { - await endpoint.BindAsync(context).ConfigureAwait(false); + await endpoint.BindAsync(context, cancellationToken).ConfigureAwait(false); } } } @@ -235,7 +236,7 @@ public AddressesStrategy(IReadOnlyCollection addresses) _addresses = addresses; } - public virtual async Task BindAsync(AddressBindContext context) + public virtual async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken) { foreach (var address in _addresses) { @@ -247,7 +248,7 @@ public virtual async Task BindAsync(AddressBindContext context) options.UseHttps(); } - await options.BindAsync(context).ConfigureAwait(false); + await options.BindAsync(context, cancellationToken).ConfigureAwait(false); } } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/TransportManager.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/TransportManager.cs index 4392f6618712..470140fdd9a0 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/TransportManager.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/TransportManager.cs @@ -37,19 +37,19 @@ public TransportManager( private ConnectionManager ConnectionManager => _serviceContext.ConnectionManager; private IKestrelTrace Trace => _serviceContext.Log; - public async Task BindAsync(EndPoint endPoint, ConnectionDelegate connectionDelegate, EndpointConfig? endpointConfig) + public async Task BindAsync(EndPoint endPoint, ConnectionDelegate connectionDelegate, EndpointConfig? endpointConfig, CancellationToken cancellationToken) { if (_transportFactory is null) { throw new InvalidOperationException($"Cannot bind with {nameof(ConnectionDelegate)} no {nameof(IConnectionListenerFactory)} is registered."); } - var transport = await _transportFactory.BindAsync(endPoint).ConfigureAwait(false); + var transport = await _transportFactory.BindAsync(endPoint, cancellationToken).ConfigureAwait(false); StartAcceptLoop(new GenericConnectionListener(transport), c => connectionDelegate(c), endpointConfig); return transport.EndPoint; } - public async Task BindAsync(EndPoint endPoint, MultiplexedConnectionDelegate multiplexedConnectionDelegate, ListenOptions listenOptions) + public async Task BindAsync(EndPoint endPoint, MultiplexedConnectionDelegate multiplexedConnectionDelegate, ListenOptions listenOptions, CancellationToken cancellationToken) { if (_multiplexedTransportFactory is null) { @@ -69,7 +69,7 @@ public async Task BindAsync(EndPoint endPoint, MultiplexedConnectionDe features.Set(sslServerAuthenticationOptions); } - var transport = await _multiplexedTransportFactory.BindAsync(endPoint, features).ConfigureAwait(false); + var transport = await _multiplexedTransportFactory.BindAsync(endPoint, features, cancellationToken).ConfigureAwait(false); StartAcceptLoop(new GenericMultiplexedConnectionListener(transport), c => multiplexedConnectionDelegate(c), listenOptions.EndpointConfig); return transport.EndPoint; } diff --git a/src/Servers/Kestrel/Core/src/Internal/KestrelServerImpl.cs b/src/Servers/Kestrel/Core/src/Internal/KestrelServerImpl.cs index 9d429b749e7e..37dc5394a873 100644 --- a/src/Servers/Kestrel/Core/src/Internal/KestrelServerImpl.cs +++ b/src/Servers/Kestrel/Core/src/Internal/KestrelServerImpl.cs @@ -160,7 +160,7 @@ public async Task StartAsync(IHttpApplication application, C ServiceContext.Heartbeat?.Start(); - async Task OnBind(ListenOptions options) + async Task OnBind(ListenOptions options, CancellationToken onBindCancellationToken) { // INVESTIGATE: For some reason, MsQuic needs to bind before // sockets for it to successfully listen. It also seems racy. @@ -177,7 +177,7 @@ async Task OnBind(ListenOptions options) // Add the connection limit middleware multiplexedConnectionDelegate = EnforceConnectionLimit(multiplexedConnectionDelegate, Options.Limits.MaxConcurrentConnections, Trace); - options.EndPoint = await _transportManager.BindAsync(options.EndPoint, multiplexedConnectionDelegate, options).ConfigureAwait(false); + options.EndPoint = await _transportManager.BindAsync(options.EndPoint, multiplexedConnectionDelegate, options, onBindCancellationToken).ConfigureAwait(false); } // Add the HTTP middleware as the terminal connection middleware @@ -197,7 +197,7 @@ async Task OnBind(ListenOptions options) // Add the connection limit middleware connectionDelegate = EnforceConnectionLimit(connectionDelegate, Options.Limits.MaxConcurrentConnections, Trace); - options.EndPoint = await _transportManager.BindAsync(options.EndPoint, connectionDelegate, options.EndpointConfig).ConfigureAwait(false); + options.EndPoint = await _transportManager.BindAsync(options.EndPoint, connectionDelegate, options.EndpointConfig, onBindCancellationToken).ConfigureAwait(false); } } @@ -275,7 +275,7 @@ private async Task BindAsync(CancellationToken cancellationToken) Options.ConfigurationLoader?.Load(); - await AddressBinder.BindAsync(Options.ListenOptions, AddressBindContext!).ConfigureAwait(false); + await AddressBinder.BindAsync(Options.ListenOptions, AddressBindContext!, cancellationToken).ConfigureAwait(false); _configChangedRegistration = reloadToken?.RegisterChangeCallback(TriggerRebind, this); } finally @@ -342,8 +342,7 @@ private async Task RebindAsync() { try { - // TODO: This should probably be canceled by the _stopCts too, but we don't currently support bind cancellation even in StartAsync(). - await listenOption.BindAsync(AddressBindContext!).ConfigureAwait(false); + await listenOption.BindAsync(AddressBindContext!, _stopCts.Token).ConfigureAwait(false); } catch (Exception ex) { diff --git a/src/Servers/Kestrel/Core/src/ListenOptions.cs b/src/Servers/Kestrel/Core/src/ListenOptions.cs index 40d66cebfb04..3f434584d123 100644 --- a/src/Servers/Kestrel/Core/src/ListenOptions.cs +++ b/src/Servers/Kestrel/Core/src/ListenOptions.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Net; using System.Net.Sockets; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Experimental; @@ -176,9 +177,9 @@ MultiplexedConnectionDelegate IMultiplexedConnectionBuilder.Build() return app; } - internal virtual async Task BindAsync(AddressBindContext context) + internal virtual async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken) { - await AddressBinder.BindEndpointAsync(this, context).ConfigureAwait(false); + await AddressBinder.BindEndpointAsync(this, context, cancellationToken).ConfigureAwait(false); context.Addresses.Add(GetDisplayName()); } } diff --git a/src/Servers/Kestrel/Core/src/LocalhostListenOptions.cs b/src/Servers/Kestrel/Core/src/LocalhostListenOptions.cs index 8c9978a3aa38..56581f064b70 100644 --- a/src/Servers/Kestrel/Core/src/LocalhostListenOptions.cs +++ b/src/Servers/Kestrel/Core/src/LocalhostListenOptions.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.IO; using System.Net; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; using Microsoft.Extensions.Logging; @@ -30,16 +31,16 @@ internal override string GetDisplayName() return $"{Scheme}://localhost:{IPEndPoint!.Port}"; } - internal override async Task BindAsync(AddressBindContext context) + internal override async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken) { var exceptions = new List(); try { var v4Options = Clone(IPAddress.Loopback); - await AddressBinder.BindEndpointAsync(v4Options, context).ConfigureAwait(false); + await AddressBinder.BindEndpointAsync(v4Options, context, cancellationToken).ConfigureAwait(false); } - catch (Exception ex) when (!(ex is IOException)) + catch (Exception ex) when (!(ex is IOException or OperationCanceledException)) { context.Logger.LogInformation(0, CoreStrings.NetworkInterfaceBindingFailed, GetDisplayName(), "IPv4 loopback", ex.Message); exceptions.Add(ex); @@ -48,9 +49,9 @@ internal override async Task BindAsync(AddressBindContext context) try { var v6Options = Clone(IPAddress.IPv6Loopback); - await AddressBinder.BindEndpointAsync(v6Options, context).ConfigureAwait(false); + await AddressBinder.BindEndpointAsync(v6Options, context, cancellationToken).ConfigureAwait(false); } - catch (Exception ex) when (!(ex is IOException)) + catch (Exception ex) when (!(ex is IOException or OperationCanceledException)) { context.Logger.LogInformation(0, CoreStrings.NetworkInterfaceBindingFailed, GetDisplayName(), "IPv6 loopback", ex.Message); exceptions.Add(ex); diff --git a/src/Servers/Kestrel/Core/test/AddressBinderTests.cs b/src/Servers/Kestrel/Core/test/AddressBinderTests.cs index 8372c2215609..c63ec93dcbe5 100644 --- a/src/Servers/Kestrel/Core/test/AddressBinderTests.cs +++ b/src/Servers/Kestrel/Core/test/AddressBinderTests.cs @@ -6,6 +6,7 @@ using System.IO; using System.Net; using System.Net.Sockets; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Hosting; @@ -127,7 +128,7 @@ public async Task WrapsAddressInUseExceptionAsIOException() endpoint => throw new AddressInUseException("already in use")); await Assert.ThrowsAsync(() => - AddressBinder.BindAsync(options.ListenOptions, addressBindContext)); + AddressBinder.BindAsync(options.ListenOptions, addressBindContext, CancellationToken.None)); } [Fact] @@ -148,7 +149,7 @@ public void LogsWarningWhenHostingAddressesAreOverridden() logger, endpoint => Task.CompletedTask); - var bindTask = AddressBinder.BindAsync(options.ListenOptions, addressBindContext); + var bindTask = AddressBinder.BindAsync(options.ListenOptions, addressBindContext, CancellationToken.None); Assert.True(bindTask.IsCompletedSuccessfully); var log = Assert.Single(logger.Messages); @@ -176,7 +177,7 @@ public void LogsInformationWhenKestrelAddressesAreOverridden() addressBindContext.ServerAddressesFeature.PreferHostingUrls = true; - var bindTask = AddressBinder.BindAsync(options.ListenOptions, addressBindContext); + var bindTask = AddressBinder.BindAsync(options.ListenOptions, addressBindContext, CancellationToken.None); Assert.True(bindTask.IsCompletedSuccessfully); var log = Assert.Single(logger.Messages); @@ -184,6 +185,27 @@ public void LogsInformationWhenKestrelAddressesAreOverridden() Assert.Equal(CoreStrings.FormatOverridingWithPreferHostingUrls(nameof(addressBindContext.ServerAddressesFeature.PreferHostingUrls), overriddenAddress), log.Message); } + [Fact] + public async Task FlowsCancellationTokenToCreateBinddingCallback() + { + var addresses = new ServerAddressesFeature(); + addresses.InternalCollection.Add("http://localhost:5000"); + var options = new KestrelServerOptions(); + + var addressBindContext = TestContextFactory.CreateAddressBindContext( + addresses, + options, + NullLogger.Instance, + (endpoint, cancellationToken) => + { + cancellationToken.ThrowIfCancellationRequested(); + return Task.CompletedTask; + }); + + await Assert.ThrowsAsync(() => + AddressBinder.BindAsync(options.ListenOptions, addressBindContext, new CancellationToken(true))); + } + [Theory] [InlineData("http://*:80")] [InlineData("http://+:80")] @@ -218,7 +240,7 @@ public async Task FallbackToIPv4WhenIPv6AnyBindFails(string address) return Task.CompletedTask; }); - await AddressBinder.BindAsync(options.ListenOptions, addressBindContext); + await AddressBinder.BindAsync(options.ListenOptions, addressBindContext, CancellationToken.None); Assert.True(ipV4Attempt, "Should have attempted to bind to IPAddress.Any"); Assert.True(ipV6Attempt, "Should have attempted to bind to IPAddress.IPv6Any"); @@ -260,7 +282,7 @@ public async Task DefaultAddressBinderWithoutDevCertButHttpsConfiguredBindsToHtt return Task.CompletedTask; }); - await AddressBinder.BindAsync(options.ListenOptions, addressBindContext); + await AddressBinder.BindAsync(options.ListenOptions, addressBindContext, CancellationToken.None); Assert.Contains(endpoints, e => e.IPEndPoint.Port == 5000 && !e.IsTls); Assert.Contains(endpoints, e => e.IPEndPoint.Port == 5001 && e.IsTls); diff --git a/src/Servers/Kestrel/shared/test/TestContextFactory.cs b/src/Servers/Kestrel/shared/test/TestContextFactory.cs index 902ea7cebaad..09687a3fd7c0 100644 --- a/src/Servers/Kestrel/shared/test/TestContextFactory.cs +++ b/src/Servers/Kestrel/shared/test/TestContextFactory.cs @@ -100,6 +100,21 @@ public static AddressBindContext CreateAddressBindContext( KestrelServerOptions serverOptions, ILogger logger, Func createBinding) + { + var context = new AddressBindContext( + serverAddressesFeature, + serverOptions, + logger, + (listenOptions, cancellationToken) => createBinding(listenOptions)); + + return context; + } + + public static AddressBindContext CreateAddressBindContext( + ServerAddressesFeature serverAddressesFeature, + KestrelServerOptions serverOptions, + ILogger logger, + Func createBinding) { var context = new AddressBindContext( serverAddressesFeature,