Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3320,37 +3320,61 @@ public static int Count<T>(ref T current, T value, int length) where T : IEquata
public static int CountValueType<T>(ref T current, T value, int length) where T : struct, IEquatable<T>?
{
int count = 0;

ref T end = ref Unsafe.Add(ref current, length);

if (Vector128.IsHardwareAccelerated && length >= Vector128<T>.Count)
{
if (Vector256.IsHardwareAccelerated && length >= Vector256<T>.Count)
{
Vector256<T> targetVector = Vector256.Create(value);
ref T oneVectorAwayFromEndMinus1 = ref Unsafe.Subtract(ref end, Vector256<T>.Count - 1);
ref T oneVectorAwayFromEnd = ref Unsafe.Subtract(ref end, Vector256<T>.Count);
do
{
count += BitOperations.PopCount(Vector256.Equals(Vector256.LoadUnsafe(ref current), targetVector).ExtractMostSignificantBits());
current = ref Unsafe.Add(ref current, Vector256<T>.Count);
}
while (Unsafe.IsAddressLessThan(ref current, ref oneVectorAwayFromEndMinus1));
while (!Unsafe.IsAddressGreaterThan(ref current, ref oneVectorAwayFromEnd));

if (Unsafe.IsAddressLessThan(ref current, ref Unsafe.Subtract(ref end, Vector128<T>.Count - 1)))
// If only a few elements remain, processing them in the scalar loop is cheaper than
// doing bitmask + popcount on the full last vector. Rather than using complicated
// type-based checks or other remainder-count-based logic to determine the exact cut-off,
// a half-vector threshold is used for simplicity (chosen based on benchmarks).
uint remaining = (uint)Unsafe.ByteOffset(ref current, ref end) / (uint)Unsafe.SizeOf<T>();
if (remaining > Vector256<T>.Count / 2)
{
count += BitOperations.PopCount(Vector128.Equals(Vector128.LoadUnsafe(ref current), Vector128.Create(value)).ExtractMostSignificantBits());
current = ref Unsafe.Add(ref current, Vector128<T>.Count);
uint mask = Vector256.Equals(Vector256.LoadUnsafe(ref oneVectorAwayFromEnd), targetVector).ExtractMostSignificantBits();

// The mask contains some elements that may be double-checked, so shift them away in order to get the correct pop-count.
uint overlaps = (uint)Vector256<T>.Count - remaining;
mask >>= (int)overlaps;
count += BitOperations.PopCount(mask);

return count;
}
}
else
{
Vector128<T> targetVector = Vector128.Create(value);
ref T oneVectorAwayFromEndMinus1 = ref Unsafe.Subtract(ref end, Vector128<T>.Count - 1);
ref T oneVectorAwayFromEnd = ref Unsafe.Subtract(ref end, Vector128<T>.Count);
do
{
count += BitOperations.PopCount(Vector128.Equals(Vector128.LoadUnsafe(ref current), targetVector).ExtractMostSignificantBits());
current = ref Unsafe.Add(ref current, Vector128<T>.Count);
}
while (Unsafe.IsAddressLessThan(ref current, ref oneVectorAwayFromEndMinus1));
while (!Unsafe.IsAddressGreaterThan(ref current, ref oneVectorAwayFromEnd));

uint remaining = (uint)Unsafe.ByteOffset(ref current, ref end) / (uint)Unsafe.SizeOf<T>();
if (remaining > Vector128<T>.Count / 2)
{
uint mask = Vector128.Equals(Vector128.LoadUnsafe(ref oneVectorAwayFromEnd), targetVector).ExtractMostSignificantBits();

// The mask contains some elements that may be double-checked, so shift them away in order to get the correct pop-count.
uint overlaps = (uint)Vector128<T>.Count - remaining;
mask >>= (int)overlaps;
count += BitOperations.PopCount(mask);

return count;
}
}
}

Expand Down