diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.Cache.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.Cache.cs new file mode 100644 index 00000000000000..4fc86adfbd1565 --- /dev/null +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.Cache.cs @@ -0,0 +1,235 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Collections.Generic; +using System.Collections.Concurrent; +using System.Collections.ObjectModel; +using System.Security.Authentication; +using System.Net.Security; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using Microsoft.Quic; + +namespace System.Net.Quic; + +internal static partial class MsQuicConfiguration +{ + private const int CheckExpiredModulo = 32; + + private const string DisableCacheEnvironmentVariable = "DOTNET_SYSTEM_NET_QUIC_DISABLE_CONFIGURATION_CACHE"; + private const string DisableCacheCtxSwitch = "System.Net.Quic.DisableConfigurationCache"; + + internal static bool ConfigurationCacheEnabled { get; } = GetConfigurationCacheEnabled(); + private static bool GetConfigurationCacheEnabled() + { + // AppContext switch takes precedence + if (AppContext.TryGetSwitch(DisableCacheCtxSwitch, out bool value)) + { + return !value; + } + else + { + // check environment variable + return + Environment.GetEnvironmentVariable(DisableCacheEnvironmentVariable) is string envVar && + !(envVar == "1" || envVar.Equals("true", StringComparison.OrdinalIgnoreCase)); + } + } + + private static readonly ConcurrentDictionary s_configurationCache = new(); + + private readonly struct CacheKey : IEquatable + { + public readonly List CertificateThumbprints; + public readonly QUIC_CREDENTIAL_FLAGS Flags; + public readonly QUIC_SETTINGS Settings; + public readonly List ApplicationProtocols; + public readonly QUIC_ALLOWED_CIPHER_SUITE_FLAGS AllowedCipherSuites; + + public CacheKey(QUIC_SETTINGS settings, QUIC_CREDENTIAL_FLAGS flags, X509Certificate? certificate, ReadOnlyCollection? intermediates, List alpnProtocols, QUIC_ALLOWED_CIPHER_SUITE_FLAGS allowedCipherSuites) + { + CertificateThumbprints = certificate == null ? new List() : new List { certificate.GetCertHash() }; + + if (intermediates != null) + { + foreach (X509Certificate2 intermediate in intermediates) + { + CertificateThumbprints.Add(intermediate.GetCertHash()); + } + } + + Flags = flags; + Settings = settings; + // make defensive copy to prevent modification (the list comes from user code) + ApplicationProtocols = new List(alpnProtocols); + AllowedCipherSuites = allowedCipherSuites; + } + + public override bool Equals(object? obj) => obj is CacheKey key && Equals(key); + + public bool Equals(CacheKey other) + { + if (CertificateThumbprints.Count != other.CertificateThumbprints.Count) + { + return false; + } + + for (int i = 0; i < CertificateThumbprints.Count; i++) + { + if (!CertificateThumbprints[i].AsSpan().SequenceEqual(other.CertificateThumbprints[i])) + { + return false; + } + } + + if (ApplicationProtocols.Count != other.ApplicationProtocols.Count) + { + return false; + } + + for (int i = 0; i < ApplicationProtocols.Count; i++) + { + if (ApplicationProtocols[i] != other.ApplicationProtocols[i]) + { + return false; + } + } + + return + Flags == other.Flags && + Settings.Equals(other.Settings) && + AllowedCipherSuites == other.AllowedCipherSuites; + } + + public override int GetHashCode() + { + HashCode hash = default; + + foreach (var thumbprint in CertificateThumbprints) + { + hash.AddBytes(thumbprint); + } + + hash.Add(Flags); + hash.Add(Settings); + + foreach (var protocol in ApplicationProtocols) + { + hash.AddBytes(protocol.Protocol.Span); + } + + hash.Add(AllowedCipherSuites); + + return hash.ToHashCode(); + } + } + + private static MsQuicConfigurationSafeHandle GetCachedCredentialOrCreate(QUIC_SETTINGS settings, QUIC_CREDENTIAL_FLAGS flags, X509Certificate? certificate, ReadOnlyCollection? intermediates, List alpnProtocols, QUIC_ALLOWED_CIPHER_SUITE_FLAGS allowedCipherSuites) + { + CacheKey key = new CacheKey(settings, flags, certificate, intermediates, alpnProtocols, allowedCipherSuites); + + MsQuicConfigurationSafeHandle? handle; + + if (s_configurationCache.TryGetValue(key, out handle) && handle.TryAddRentCount()) + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(null, $"Found cached MsQuicConfiguration: {handle}."); + } + return handle; + } + + // if we get here, the handle is either not in the cache, or we lost the race between + // TryAddRentCount on this thread and MarkForDispose on another thread doing cache cleanup. + // In either case, we need to create a new handle. + + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(null, $"MsQuicConfiguration not found in cache, creating new."); + } + + handle = CreateInternal(settings, flags, certificate, intermediates, alpnProtocols, allowedCipherSuites); + handle.TryAddRentCount(); // we are the first renter + + MsQuicConfigurationSafeHandle cached; + do + { + cached = s_configurationCache.GetOrAdd(key, handle); + } + // If we get the same handle back, we successfully added it to the cache and we are done. + // If we get a different handle back, we need to increase the rent count. + // If we fail to add the rent count, then the existing/cached handle is in process of + // being removed from the cache and we can try again, eventually either succeeding to add our + // new handle or getting a fresh handle inserted by another thread meanwhile. + while (cached != handle && !cached.TryAddRentCount()); + + if (cached != handle) + { + // we lost a race with another thread to insert new handle into the cache + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(null, $"Discarding MsQuicConfiguration {handle} (preferring cached {cached})."); + } + + // First dispose decrements the rent count we added before attempting the cache insertion + // and second closes the handle + handle.Dispose(); + handle.Dispose(); + Debug.Assert(handle.IsClosed); + + return cached; + } + + // we added a new handle, check if we need to cleanup + var count = s_configurationCache.Count; + if (count % CheckExpiredModulo == 0) + { + // let only one thread perform cleanup at a time + lock (s_configurationCache) + { + // check again, if another thread just cleaned up (and cached count went down) we are unlikely + // to clean anything + if (s_configurationCache.Count >= count) + { + CleanupCache(); + } + } + } + + return handle; + } + + private static void CleanupCache() + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(null, $"Cleaning up MsQuicConfiguration cache, current size: {s_configurationCache.Count}."); + } + + foreach ((CacheKey key, MsQuicConfigurationSafeHandle handle) in s_configurationCache) + { + if (!handle.TryMarkForDispose()) + { + // handle in use + continue; + } + + // the handle is not in use and has been marked such that no new rents can be added. + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(null, $"Removing cached MsQuicConfiguration {handle}."); + } + + bool removed = s_configurationCache.TryRemove(key, out _); + Debug.Assert(removed); + handle.Dispose(); + Debug.Assert(handle.IsClosed); + } + + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(null, $"Cleaning up MsQuicConfiguration cache, new size: {s_configurationCache.Count}."); + } + } +} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.cs index 7e527e41bf9564..e99f1a68ae9ec5 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.cs @@ -11,12 +11,12 @@ namespace System.Net.Quic; -internal static class MsQuicConfiguration +internal static partial class MsQuicConfiguration { private static bool HasPrivateKey(this X509Certificate certificate) => certificate is X509Certificate2 certificate2 && certificate2.Handle != IntPtr.Zero && certificate2.HasPrivateKey; - public static MsQuicSafeHandle Create(QuicClientConnectionOptions options) + public static MsQuicConfigurationSafeHandle Create(QuicClientConnectionOptions options) { SslClientAuthenticationOptions authenticationOptions = options.ClientAuthenticationOptions; @@ -79,7 +79,7 @@ public static MsQuicSafeHandle Create(QuicClientConnectionOptions options) return Create(options, flags, certificate, intermediates, authenticationOptions.ApplicationProtocols, authenticationOptions.CipherSuitesPolicy, authenticationOptions.EncryptionPolicy); } - public static MsQuicSafeHandle Create(QuicServerConnectionOptions options, string? targetHost) + public static MsQuicConfigurationSafeHandle Create(QuicServerConnectionOptions options, string? targetHost) { SslServerAuthenticationOptions authenticationOptions = options.ServerAuthenticationOptions; @@ -117,7 +117,7 @@ public static MsQuicSafeHandle Create(QuicServerConnectionOptions options, strin return Create(options, flags, certificate, intermediates, authenticationOptions.ApplicationProtocols, authenticationOptions.CipherSuitesPolicy, authenticationOptions.EncryptionPolicy); } - private static unsafe MsQuicSafeHandle Create(QuicConnectionOptions options, QUIC_CREDENTIAL_FLAGS flags, X509Certificate? certificate, ReadOnlyCollection? intermediates, List? alpnProtocols, CipherSuitesPolicy? cipherSuitesPolicy, EncryptionPolicy encryptionPolicy) + private static MsQuicConfigurationSafeHandle Create(QuicConnectionOptions options, QUIC_CREDENTIAL_FLAGS flags, X509Certificate? certificate, ReadOnlyCollection? intermediates, List? alpnProtocols, CipherSuitesPolicy? cipherSuitesPolicy, EncryptionPolicy encryptionPolicy) { // Validate options and SSL parameters. if (alpnProtocols is null || alpnProtocols.Count <= 0) @@ -176,6 +176,29 @@ private static unsafe MsQuicSafeHandle Create(QuicConnectionOptions options, QUI : 0; // 0 disables the timeout } + QUIC_ALLOWED_CIPHER_SUITE_FLAGS allowedCipherSuites = QUIC_ALLOWED_CIPHER_SUITE_FLAGS.NONE; + + if (cipherSuitesPolicy != null) + { + flags |= QUIC_CREDENTIAL_FLAGS.SET_ALLOWED_CIPHER_SUITES; + allowedCipherSuites = CipherSuitePolicyToFlags(cipherSuitesPolicy); + } + + if (!MsQuicApi.UsesSChannelBackend) + { + flags |= QUIC_CREDENTIAL_FLAGS.USE_PORTABLE_CERTIFICATES; + } + + if (ConfigurationCacheEnabled) + { + return GetCachedCredentialOrCreate(settings, flags, certificate, intermediates, alpnProtocols, allowedCipherSuites); + } + + return CreateInternal(settings, flags, certificate, intermediates, alpnProtocols, allowedCipherSuites); + } + + private static unsafe MsQuicConfigurationSafeHandle CreateInternal(QUIC_SETTINGS settings, QUIC_CREDENTIAL_FLAGS flags, X509Certificate? certificate, ReadOnlyCollection? intermediates, List alpnProtocols, QUIC_ALLOWED_CIPHER_SUITE_FLAGS allowedCipherSuites) + { QUIC_HANDLE* handle; using MsQuicBuffers msquicBuffers = new MsQuicBuffers(); @@ -183,24 +206,21 @@ private static unsafe MsQuicSafeHandle Create(QuicConnectionOptions options, QUI ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ConfigurationOpen( MsQuicApi.Api.Registration, msquicBuffers.Buffers, - (uint)alpnProtocols.Count, + (uint)msquicBuffers.Count, &settings, (uint)sizeof(QUIC_SETTINGS), (void*)IntPtr.Zero, &handle), "ConfigurationOpen failed"); - MsQuicSafeHandle configurationHandle = new MsQuicSafeHandle(handle, SafeHandleType.Configuration); + MsQuicConfigurationSafeHandle configurationHandle = new MsQuicConfigurationSafeHandle(handle); try { - QUIC_CREDENTIAL_CONFIG config = new QUIC_CREDENTIAL_CONFIG { Flags = flags }; - config.Flags |= (MsQuicApi.UsesSChannelBackend ? QUIC_CREDENTIAL_FLAGS.NONE : QUIC_CREDENTIAL_FLAGS.USE_PORTABLE_CERTIFICATES); - - if (cipherSuitesPolicy != null) + QUIC_CREDENTIAL_CONFIG config = new QUIC_CREDENTIAL_CONFIG { - config.Flags |= QUIC_CREDENTIAL_FLAGS.SET_ALLOWED_CIPHER_SUITES; - config.AllowedCipherSuites = CipherSuitePolicyToFlags(cipherSuitesPolicy); - } + Flags = flags, + AllowedCipherSuites = allowedCipherSuites + }; int status; if (certificate is null) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs index 38a099ed9e49f3..cf7d70a18e08d1 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.Runtime.InteropServices; +using System.Threading; using Microsoft.Quic; namespace System.Net.Quic; @@ -52,7 +53,8 @@ public MsQuicSafeHandle(QUIC_HANDLE* handle, SafeHandleType safeHandleType) SafeHandleType.Stream => MsQuicApi.Api.ApiTable->StreamClose, _ => throw new ArgumentException($"Unexpected value: {safeHandleType}", nameof(safeHandleType)) }, - safeHandleType) { } + safeHandleType) + { } protected override bool ReleaseHandle() { @@ -142,3 +144,46 @@ protected override unsafe bool ReleaseHandle() return true; } } + +internal sealed class MsQuicConfigurationSafeHandle : MsQuicSafeHandle +{ + // MsQuicConfiguration handles are cached, so we need to keep track of the + // number of times a handle is rented. Once we decide to dispose the handle, + // we set the _rentCount to -1. + private volatile int _rentCount; + + public unsafe MsQuicConfigurationSafeHandle(QUIC_HANDLE* handle) + : base(handle, SafeHandleType.Configuration) { } + + public bool TryAddRentCount() + { + int oldCount; + + do + { + oldCount = _rentCount; + if (oldCount < 0) + { + // The handle is already disposed. + return false; + } + } while (Interlocked.CompareExchange(ref _rentCount, oldCount + 1, oldCount) != oldCount); + + return true; + } + + public bool TryMarkForDispose() + { + return Interlocked.CompareExchange(ref _rentCount, -1, 0) == 0; + } + + protected override void Dispose(bool disposing) + { + if (Interlocked.Decrement(ref _rentCount) < 0) + { + // _rentCount is 0 if the handle was never rented (e.g. failure during creation), + // and is -1 when evicted from cache. + base.Dispose(disposing); + } + } +} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs index 5a4f626e2f5465..0846543a6aee66 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs @@ -676,7 +676,6 @@ public async ValueTask DisposeAsync() Debug.Assert(_connectedTcs.IsCompleted); _handle.Dispose(); _shutdownTokenSource.Dispose(); - _configuration?.Dispose(); // Dispose remote certificate only if it hasn't been accessed via getter, in which case the accessing code becomes the owner of the certificate lifetime.