diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/ConditionalWeakTable.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/ConditionalWeakTable.cs index 2318211cb7e9f3..eb6af90b5a5043 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/ConditionalWeakTable.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/ConditionalWeakTable.cs @@ -48,10 +48,7 @@ public ConditionalWeakTable() /// If the key is not found, contains default(TValue). /// /// Returns "true" if key was found, "false" otherwise. - /// - /// The key may get garbage collected during the TryGetValue operation. If so, TryGetValue - /// may at its discretion, return "false" and set "value" to the default (as if the key was not present.) - /// + /// is . public bool TryGetValue(TKey key, [MaybeNullWhen(false)] out TValue value) { if (key is null) @@ -65,12 +62,8 @@ public bool TryGetValue(TKey key, [MaybeNullWhen(false)] out TValue value) /// Adds a key to the table. /// key to add. May not be null. /// value to associate with key. - /// - /// If the key is already entered into the dictionary, this method throws an exception. - /// The key may get garbage collected during the Add() operation. If so, Add() - /// has the right to consider any prior entries successfully removed and add a new entry without - /// throwing an exception. - /// + /// is . + /// is already entered into the dictionary. public void Add(TKey key, TValue value) { if (key is null) @@ -94,6 +87,7 @@ public void Add(TKey key, TValue value) /// The key to add. /// The key's property value. /// true if the key/value pair was added; false if the table already contained the key. + /// is . public bool TryAdd(TKey key, TValue value) { if (key is null) @@ -117,6 +111,7 @@ public bool TryAdd(TKey key, TValue value) /// Adds the key and value if the key doesn't exist, or updates the existing key's value if it does exist. /// key to add or update. May not be null. /// value to associate with key. + /// is . public void AddOrUpdate(TKey key, TValue value) { if (key is null) @@ -141,13 +136,9 @@ public void AddOrUpdate(TKey key, TValue value) } /// Removes a key and its value from the table. - /// key to remove. May not be null. - /// true if the key is found and removed. Returns false if the key was not in the dictionary. - /// - /// The key may get garbage collected during the Remove() operation. If so, - /// Remove() will not fail or throw, however, the return value can be either true or false - /// depending on who wins the race. - /// + /// The key to remove. + /// if the key is found and removed; otherwise, . + /// is . public bool Remove(TKey key) { if (key is null) @@ -157,7 +148,25 @@ public bool Remove(TKey key) lock (_lock) { - return _container.Remove(key); + return _container.Remove(key, out _); + } + } + + /// Removes a key and its value from the table, and returns the removed value if it was present. + /// The key to remove. + /// value removed from the table, if it was present. + /// if the key is found and removed; otherwise, . + /// is . + public bool Remove(TKey key, [MaybeNullWhen(false)] out TValue value) + { + if (key is null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); + } + + lock (_lock) + { + return _container.Remove(key, out value); } } @@ -692,17 +701,19 @@ internal void RemoveAllKeys() } /// Removes the specified key from the table, if it exists. - internal bool Remove(TKey key) + internal bool Remove(TKey key, [MaybeNullWhen(false)] out TValue value) { VerifyIntegrity(); - int entryIndex = FindEntry(key, out _); + int entryIndex = FindEntry(key, out object? valueObject); if (entryIndex != -1) { RemoveIndex(entryIndex); + value = Unsafe.As(valueObject); return true; } + value = null; return false; } diff --git a/src/libraries/System.Runtime/ref/System.Runtime.cs b/src/libraries/System.Runtime/ref/System.Runtime.cs index b3ead528d620cb..88b39b898bf838 100644 --- a/src/libraries/System.Runtime/ref/System.Runtime.cs +++ b/src/libraries/System.Runtime/ref/System.Runtime.cs @@ -13284,6 +13284,7 @@ public void Clear() { } [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] public TValue GetValue(TKey key, System.Runtime.CompilerServices.ConditionalWeakTable.CreateValueCallback createValueCallback) { throw null; } public bool Remove(TKey key) { throw null; } + public bool Remove(TKey key, [System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out TValue value) { throw null; } System.Collections.Generic.IEnumerator> System.Collections.Generic.IEnumerable>.GetEnumerator() { throw null; } System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { throw null; } public bool TryAdd(TKey key, TValue value) { throw null; } diff --git a/src/libraries/System.Runtime/tests/System.Runtime.Tests/System/Runtime/CompilerServices/ConditionalWeakTableTests.cs b/src/libraries/System.Runtime/tests/System.Runtime.Tests/System/Runtime/CompilerServices/ConditionalWeakTableTests.cs index c4f70e4ab37687..06487e90aec3e8 100644 --- a/src/libraries/System.Runtime/tests/System.Runtime.Tests/System/Runtime/CompilerServices/ConditionalWeakTableTests.cs +++ b/src/libraries/System.Runtime/tests/System.Runtime.Tests/System/Runtime/CompilerServices/ConditionalWeakTableTests.cs @@ -22,6 +22,7 @@ public static void InvalidArgs_Throws() AssertExtensions.Throws("key", () => cwt.Add(null, new object())); // null key AssertExtensions.Throws("key", () => cwt.TryGetValue(null, out ignored)); // null key AssertExtensions.Throws("key", () => cwt.Remove(null)); // null key + AssertExtensions.Throws("key", () => cwt.Remove(null, out _)); // null key AssertExtensions.Throws("key", () => cwt.GetOrAdd(null, new object())); // null key AssertExtensions.Throws("key", () => cwt.GetOrAdd(null, k => new object())); // null key AssertExtensions.Throws("key", () => cwt.GetOrAdd(null, (k, a) => new object(), 42)); // null key @@ -171,6 +172,39 @@ public static void AddMany_ThenRemoveAll(int numObjects) } } + [Theory] + [InlineData(1)] + [InlineData(100)] + public static void AddMany_ThenRemoveAll_ValidateRemovedValue(int numObjects) + { + object[] keys = Enumerable.Range(0, numObjects).Select(_ => new object()).ToArray(); + object[] values = Enumerable.Range(0, numObjects).Select(_ => new object()).ToArray(); + var cwt = new ConditionalWeakTable(); + + for (int i = 0; i < numObjects; i++) + { + cwt.Add(keys[i], values[i]); + } + + for (int i = 0; i < numObjects; i++) + { + Assert.Same(values[i], cwt.GetValue(keys[i], _ => new object())); + } + + for (int i = 0; i < numObjects; i++) + { + Assert.True(cwt.Remove(keys[i], out var value)); + Assert.False(cwt.Remove(keys[i], out _)); + Assert.Same(values[i], value); + } + + for (int i = 0; i < numObjects; i++) + { + object ignored; + Assert.False(cwt.TryGetValue(keys[i], out ignored)); + } + } + [Theory] [InlineData(100)] public static void AddRemoveIteratively(int numObjects) @@ -188,6 +222,24 @@ public static void AddRemoveIteratively(int numObjects) } } + [Theory] + [InlineData(100)] + public static void AddRemoveIteratively_ValidateRemovedValue(int numObjects) + { + object[] keys = Enumerable.Range(0, numObjects).Select(_ => new object()).ToArray(); + object[] values = Enumerable.Range(0, numObjects).Select(_ => new object()).ToArray(); + var cwt = new ConditionalWeakTable(); + + for (int i = 0; i < numObjects; i++) + { + cwt.Add(keys[i], values[i]); + Assert.Same(values[i], cwt.GetValue(keys[i], _ => new object())); + Assert.True(cwt.Remove(keys[i], out var value)); + Assert.False(cwt.Remove(keys[i], out _)); + Assert.Same(values[i], value); + } + } + [Fact] public static void Concurrent_AddMany_DropReferences() // no asserts, just making nothing throws { @@ -218,6 +270,26 @@ public static void Concurrent_Add_Read_Remove_DifferentObjects() }); } + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public static void Concurrent_Add_Read_Remove_ValidateRemovedValue_DifferentObjects() + { + var cwt = new ConditionalWeakTable(); + DateTime end = DateTime.UtcNow + TimeSpan.FromSeconds(0.25); + Parallel.For(0, Environment.ProcessorCount, i => + { + while (DateTime.UtcNow < end) + { + object key = new object(); + object value = new object(); + cwt.Add(key, value); + Assert.Same(value, cwt.GetValue(key, _ => new object())); + Assert.True(cwt.Remove(key, out var removedValue)); + Assert.False(cwt.Remove(key, out _)); + Assert.Same(value, removedValue); + } + }); + } + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] public static void Concurrent_GetOrAdd_Add_Read_Remove_DifferentObjects() {