diff --git a/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs b/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs index f684e9b871d0d8..b4ca39ab4536bd 100644 --- a/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs +++ b/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs @@ -628,9 +628,6 @@ static async Task CompleteAsync(Task task, string hostName, ValueStopwatch st } } - private static Task RunAsync(Func func, object arg, CancellationToken cancellationToken) => - Task.Factory.StartNew(func!, arg, cancellationToken, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default); - private static IPHostEntry CreateHostEntryForAddress(IPAddress address) => new IPHostEntry { @@ -656,5 +653,71 @@ private static bool LogFailure(ValueStopwatch stopwatch) NameResolutionTelemetry.Log.AfterResolution(stopwatch, successful: false); return false; } + + /// Mapping from key to current task in flight for that key. + private static readonly Dictionary s_tasks = new Dictionary(); + + /// Queue the function to be invoked asynchronously. + /// + /// Since this is doing synchronous work on a thread pool thread, we want to limit how many threads end up being + /// blocked. We could employ a semaphore to limit overall usage, but a common case is that DNS requests are made + /// for only a handful of endpoints, and a reasonable compromise is to ensure that requests for a given host are + /// serialized. Once the data for that host is cached locally by the OS, the subsequent requests should all complete + /// very quickly, and if the head-of-line request is taking a long time due to the connection to the server, we won't + /// block lots of threads all getting data for that one host. We also still want to issue the request to the OS, rather + /// than having all concurrent requests for the same host share the exact same task, so that any shuffling of the results + /// by the OS to enable round robin is still perceived. + /// + private static Task RunAsync(Func func, object key, CancellationToken cancellationToken) + { + Task? task = null; + + lock (s_tasks) + { + // Get the previous task for this key, if there is one. + s_tasks.TryGetValue(key, out Task? prevTask); + prevTask ??= Task.CompletedTask; + + // Invoke the function in a queued work item when the previous task completes. Note that some callers expect the + // returned task to have the key as the task's AsyncState. + task = prevTask.ContinueWith(delegate + { + Debug.Assert(!Monitor.IsEntered(s_tasks)); + try + { + return func(key); + } + finally + { + // When the work is done, remove this key/task pair from the dictionary if this is still the current task. + // Because the work item is created and stored into both the local and the dictionary while the lock is + // held, and since we take the same lock here, inside this lock it's guaranteed to see the changes + // made by the call site. + lock (s_tasks) + { + ((ICollection>)s_tasks).Remove(new KeyValuePair(key!, task!)); + } + } + }, key, cancellationToken, TaskContinuationOptions.DenyChildAttach, TaskScheduler.Default); + + // If it's possible the task may end up getting canceled, it won't have a chance to remove itself from + // the dictionary if it is canceled, so use a separate continuation to do so. + if (cancellationToken.CanBeCanceled) + { + task.ContinueWith((task, key) => + { + lock (s_tasks) + { + ((ICollection>)s_tasks).Remove(new KeyValuePair(key!, task)); + } + }, key, CancellationToken.None, TaskContinuationOptions.OnlyOnCanceled | TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + } + + // Finally, store the task into the dictionary as the current task for this key. + s_tasks[key] = task; + } + + return task; + } } }