diff --git a/src/mscorlib/shared/System/Collections/Generic/Dictionary.cs b/src/mscorlib/shared/System/Collections/Generic/Dictionary.cs index 8792119c0460..aba1ab59f5cc 100644 --- a/src/mscorlib/shared/System/Collections/Generic/Dictionary.cs +++ b/src/mscorlib/shared/System/Collections/Generic/Dictionary.cs @@ -62,6 +62,8 @@ private struct Entry private const string KeyValuePairsName = "KeyValuePairs"; // Do not rename (binary serialization) private const string ComparerName = "Comparer"; // Do not rename (binary serialization) + private static Entry s_nullEntry; + public Dictionary() : this(0, null) { } public Dictionary(int capacity) : this(capacity, null) { } @@ -72,10 +74,11 @@ public Dictionary(int capacity, IEqualityComparer comparer) { if (capacity < 0) ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.capacity); if (capacity > 0) Initialize(capacity); - this.comparer = comparer ?? EqualityComparer.Default; + this.comparer = comparer; - if (this.comparer == EqualityComparer.Default) + if (typeof(TKey) == typeof(string) && (comparer == null || this.comparer == EqualityComparer.Default)) { + // To start, move off default comparer for string which is randomised this.comparer = (IEqualityComparer)NonRandomizedStringEqualityComparer.Default; } } @@ -139,13 +142,7 @@ protected Dictionary(SerializationInfo info, StreamingContext context) HashHelpers.SerializationInfoTable.Add(this, info); } - public IEqualityComparer Comparer - { - get - { - return comparer; - } - } + public IEqualityComparer Comparer => comparer ?? EqualityComparer.Default; public int Count { @@ -210,13 +207,25 @@ public TValue this[TKey key] { get { - int i = FindEntry(key); - if (i >= 0) return entries[i].value; - ThrowHelper.ThrowKeyNotFoundException(key); - return default(TValue); + if (key == null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); + } + + ref Entry entry = ref FindEntry(key, out bool found); + if (!found) + { + ThrowHelper.ThrowKeyNotFoundException(key); + } + return entry.value; } set { + if (key == null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); + } + bool modified = TryInsert(key, value, InsertionBehavior.OverwriteExisting); Debug.Assert(modified); } @@ -224,6 +233,11 @@ public TValue this[TKey key] public void Add(TKey key, TValue value) { + if (key == null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); + } + bool modified = TryInsert(key, value, InsertionBehavior.ThrowOnExisting); Debug.Assert(modified); // If there was an existing key and the Add failed, an exception will already have been thrown. } @@ -235,8 +249,13 @@ void ICollection>.Add(KeyValuePair keyV bool ICollection>.Contains(KeyValuePair keyValuePair) { - int i = FindEntry(keyValuePair.Key); - if (i >= 0 && EqualityComparer.Default.Equals(entries[i].value, keyValuePair.Value)) + if (keyValuePair.Key == null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); + } + + ref Entry entry = ref FindEntry(keyValuePair.Key, out bool found); + if (found && EqualityComparer.Default.Equals(entry.value, keyValuePair.Value)) { return true; } @@ -245,8 +264,13 @@ bool ICollection>.Contains(KeyValuePair bool ICollection>.Remove(KeyValuePair keyValuePair) { - int i = FindEntry(keyValuePair.Key); - if (i >= 0 && EqualityComparer.Default.Equals(entries[i].value, keyValuePair.Value)) + if (keyValuePair.Key == null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); + } + + ref Entry entry = ref FindEntry(keyValuePair.Key, out bool found); + if (found && EqualityComparer.Default.Equals(entry.value, keyValuePair.Value)) { Remove(keyValuePair.Key); return true; @@ -269,7 +293,13 @@ public void Clear() public bool ContainsKey(TKey key) { - return FindEntry(key) >= 0; + if (key == null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); + } + + FindEntry(key, out bool found); + return found; } public bool ContainsValue(TValue value) @@ -283,10 +313,9 @@ public bool ContainsValue(TValue value) } else { - EqualityComparer c = EqualityComparer.Default; for (int i = 0; i < count; i++) { - if (entries[i].hashCode >= 0 && c.Equals(entries[i].value, value)) return true; + if (entries[i].hashCode >= 0 && EqualityComparer.Default.Equals(entries[i].value, value)) return true; } } return false; @@ -338,7 +367,7 @@ public virtual void GetObjectData(SerializationInfo info, StreamingContext conte } info.AddValue(VersionName, version); - info.AddValue(ComparerName, comparer, typeof(IEqualityComparer)); + info.AddValue(ComparerName, comparer ?? EqualityComparer.Default, typeof(IEqualityComparer)); info.AddValue(HashSizeName, buckets == null ? 0 : buckets.Length); // This is the length of the bucket array if (buckets != null) @@ -349,22 +378,36 @@ public virtual void GetObjectData(SerializationInfo info, StreamingContext conte } } - private int FindEntry(TKey key) + private ref Entry FindEntry(TKey key, out bool found) { - if (key == null) - { - ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); - } + Debug.Assert(key != null); + found = true; if (buckets != null) { - int hashCode = comparer.GetHashCode(key) & 0x7FFFFFFF; - for (int i = buckets[hashCode % buckets.Length]; i >= 0; i = entries[i].next) + int hashCode = (comparer == null ? EqualityComparer.Default.GetHashCode(key) : comparer.GetHashCode(key)) & 0x7FFFFFFF; + int targetBucket = hashCode % buckets.Length; + int i = buckets[targetBucket]; + while (i >= 0) { - if (entries[i].hashCode == hashCode && comparer.Equals(entries[i].key, key)) return i; + ref Entry candidateEntry = ref entries[i]; + if (candidateEntry.hashCode == hashCode && ((comparer == null && EqualityComparer.Default.Equals(candidateEntry.key, key)) || (comparer != null && comparer.Equals(candidateEntry.key, key)))) + { + return ref candidateEntry; + } + + i = candidateEntry.next; } } - return -1; + + found = false; + return ref NotFound; + } + + private ref Entry NotFound + { + [MethodImpl(MethodImplOptions.NoInlining)] + get => ref s_nullEntry; } private void Initialize(int capacity) @@ -378,23 +421,22 @@ private void Initialize(int capacity) private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior) { - if (key == null) - { - ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); - } + Debug.Assert(key != null); if (buckets == null) Initialize(0); - int hashCode = comparer.GetHashCode(key) & 0x7FFFFFFF; + int hashCode = (comparer == null ? EqualityComparer.Default.GetHashCode(key) : comparer.GetHashCode(key)) & 0x7FFFFFFF; int targetBucket = hashCode % buckets.Length; int collisionCount = 0; - for (int i = buckets[targetBucket]; i >= 0; i = entries[i].next) + int i = buckets[targetBucket]; + while (i >= 0) { - if (entries[i].hashCode == hashCode && comparer.Equals(entries[i].key, key)) + ref Entry candidateEntry = ref entries[i]; + if (candidateEntry.hashCode == hashCode && ((comparer == null && EqualityComparer.Default.Equals(candidateEntry.key, key)) || (comparer != null && comparer.Equals(candidateEntry.key, key)))) { if (behavior == InsertionBehavior.OverwriteExisting) { - entries[i].value = value; + candidateEntry.value = value; version++; return true; } @@ -406,6 +448,8 @@ private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior) return false; } + + i = candidateEntry.next; collisionCount++; } @@ -427,19 +471,21 @@ private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior) count++; } - entries[index].hashCode = hashCode; - entries[index].next = buckets[targetBucket]; - entries[index].key = key; - entries[index].value = value; + ref Entry entry = ref entries[index]; + entry.hashCode = hashCode; + entry.next = buckets[targetBucket]; + entry.key = key; + entry.value = value; buckets[targetBucket] = index; version++; // If we hit the collision threshold we'll need to switch to the comparer which is using randomized string hashing // i.e. EqualityComparer.Default. - if (collisionCount > HashHelpers.HashCollisionThreshold && comparer is NonRandomizedStringEqualityComparer) + if (typeof(TKey) == typeof(string) && collisionCount > HashHelpers.HashCollisionThreshold && comparer is NonRandomizedStringEqualityComparer) { - comparer = (IEqualityComparer)EqualityComparer.Default; + // Clear comparer back to default, for randomised hashing + comparer = null; Resize(entries.Length, true); } @@ -514,7 +560,7 @@ private void Resize(int newSize, bool forceNewHashCodes) { if (newEntries[i].hashCode != -1) { - newEntries[i].hashCode = (comparer.GetHashCode(newEntries[i].key) & 0x7FFFFFFF); + newEntries[i].hashCode = (comparer == null ? EqualityComparer.Default.GetHashCode(newEntries[i].key) : comparer.GetHashCode(newEntries[i].key)) & 0x7FFFFFFF; } } } @@ -545,15 +591,14 @@ public bool Remove(TKey key) if (buckets != null) { - int hashCode = comparer.GetHashCode(key) & 0x7FFFFFFF; + int hashCode = (comparer == null ? EqualityComparer.Default.GetHashCode(key) : comparer.GetHashCode(key)) & 0x7FFFFFFF; int bucket = hashCode % buckets.Length; int last = -1; int i = buckets[bucket]; while (i >= 0) { ref Entry entry = ref entries[i]; - - if (entry.hashCode == hashCode && comparer.Equals(entry.key, key)) + if (entry.hashCode == hashCode && ((comparer == null && EqualityComparer.Default.Equals(entry.key, key)) || (comparer != null && comparer.Equals(entry.key, key)))) { if (last < 0) { @@ -599,15 +644,14 @@ public bool Remove(TKey key, out TValue value) if (buckets != null) { - int hashCode = comparer.GetHashCode(key) & 0x7FFFFFFF; + int hashCode = (comparer == null ? EqualityComparer.Default.GetHashCode(key) : comparer.GetHashCode(key)) & 0x7FFFFFFF; int bucket = hashCode % buckets.Length; int last = -1; int i = buckets[bucket]; while (i >= 0) { ref Entry entry = ref entries[i]; - - if (entry.hashCode == hashCode && comparer.Equals(entry.key, key)) + if (entry.hashCode == hashCode && ((comparer == null && EqualityComparer.Default.Equals(entry.key, key)) || (comparer != null && comparer.Equals(entry.key, key)))) { if (last < 0) { @@ -647,17 +691,24 @@ public bool Remove(TKey key, out TValue value) public bool TryGetValue(TKey key, out TValue value) { - int i = FindEntry(key); - if (i >= 0) + if (key == null) { - value = entries[i].value; - return true; + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); } - value = default(TValue); - return false; + + value = FindEntry(key, out bool found).value; + return found; } - public bool TryAdd(TKey key, TValue value) => TryInsert(key, value, InsertionBehavior.None); + public bool TryAdd(TKey key, TValue value) + { + if (key == null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); + } + + return TryInsert(key, value, InsertionBehavior.None); + } bool ICollection>.IsReadOnly { @@ -788,11 +839,13 @@ object IDictionary.this[object key] { if (IsCompatibleKey(key)) { - int i = FindEntry((TKey)key); - if (i >= 0) + if (key == null) { - return entries[i].value; + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); } + + ref var entry = ref FindEntry((TKey)key, out bool found); + return found ? (object)entry.value : null; } return null; }