Skip to content
Closed
241 changes: 182 additions & 59 deletions src/libraries/System.Private.CoreLib/src/System/String.Manipulation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1188,7 +1188,21 @@ public string Replace(char oldChar, char newChar)
// process the remaining elements vectorized too.
// Thus we adjust the pointers so that at least one full vector from the end can be processed.
nuint length = (uint)Length;
if (Vector128.IsHardwareAccelerated && length >= (uint)Vector128<ushort>.Count)
if (Vector512.IsHardwareAccelerated && length >= (uint)Vector512<ushort>.Count)
{
nuint adjust = (length - remainingLength) & ((uint)Vector512<ushort>.Count - 1);
pSrc = ref Unsafe.Subtract(ref pSrc, adjust);
pDst = ref Unsafe.Subtract(ref pDst, adjust);
remainingLength += adjust;
}
else if (Vector256.IsHardwareAccelerated && length >= (uint)Vector256<ushort>.Count)
{
nuint adjust = (length - remainingLength) & ((uint)Vector256<ushort>.Count - 1);
pSrc = ref Unsafe.Subtract(ref pSrc, adjust);
pDst = ref Unsafe.Subtract(ref pDst, adjust);
remainingLength += adjust;
}
else if (Vector128.IsHardwareAccelerated && length >= (uint)Vector128<ushort>.Count)
{
nuint adjust = (length - remainingLength) & ((uint)Vector128<ushort>.Count - 1);
pSrc = ref Unsafe.Subtract(ref pSrc, adjust);
Expand Down Expand Up @@ -1224,35 +1238,7 @@ public string Replace(string oldValue, string? newValue)
}

// Find all occurrences of the oldValue character.
char c = oldValue[0];
int i = 0;

if (PackedSpanHelpers.PackedIndexOfIsSupported && PackedSpanHelpers.CanUsePackedIndexOf(c))
{
while (true)
{
int pos = PackedSpanHelpers.IndexOf(ref Unsafe.Add(ref _firstChar, i), c, Length - i);
if (pos < 0)
{
break;
}
replacementIndices.Append(i + pos);
i += pos + 1;
}
}
else
{
while (true)
{
int pos = SpanHelpers.NonPackedIndexOfChar(ref Unsafe.Add(ref _firstChar, i), c, Length - i);
if (pos < 0)
{
break;
}
replacementIndices.Append(i + pos);
i += pos + 1;
}
}
MakeReplacementSearchVectorized(this, ref replacementIndices, oldValue[0]);
}
else
{
Expand Down Expand Up @@ -1285,6 +1271,91 @@ public string Replace(string oldValue, string? newValue)
return dst;
}

private static void MakeReplacementSearchVectorized(ReadOnlySpan<char> sourceSpan, ref ValueListBuilder<int> replacementIndices, char c)
{
nuint offset = 0;
nuint lengthToExamine = (uint)sourceSpan.Length;
ref char source = ref MemoryMarshal.GetReference(sourceSpan);

if (Vector512.IsHardwareAccelerated && sourceSpan.Length >= (uint)Vector512<ushort>.Count*2)
{
Vector512<ushort> v1 = Vector512.Create((ushort)c);
do
{
Vector512<ushort> vector = Vector512.LoadUnsafe(ref source, offset);

if (Vector512.EqualsAny(vector, v1))
{
// Skip every other bit
ulong mask = (Vector512.Equals(vector, v1)).AsByte().ExtractMostSignificantBits() & 0x5555555555555555;
do
{
uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
replacementIndices.Append((int)(offset + bitPos));
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
}

offset += (nuint)Vector512<ushort>.Count;
} while (offset <= lengthToExamine - (nuint)Vector512<ushort>.Count);
}
else if (Vector256.IsHardwareAccelerated && sourceSpan.Length >= (uint)Vector256<ushort>.Count*2)
{
Vector256<ushort> v1 = Vector256.Create((ushort)c);
do
{
Vector256<ushort> vector = Vector256.LoadUnsafe(ref source, offset);
Vector256<byte> cmp = (Vector256.Equals(vector, v1)).AsByte();

if (cmp != Vector256<byte>.Zero)
{
// Skip every other bit
uint mask = cmp.ExtractMostSignificantBits() & 0x55555555;
do
{
uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
replacementIndices.Append((int)(offset + bitPos));
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
}

offset += (nuint)Vector256<ushort>.Count;
} while (offset <= lengthToExamine - (nuint)Vector256<ushort>.Count);
}
else if (Vector128.IsHardwareAccelerated && sourceSpan.Length >= (uint)Vector128<ushort>.Count*2)
{
Vector128<ushort> v1 = Vector128.Create((ushort)c);
do
{
Vector128<ushort> vector = Vector128.LoadUnsafe(ref source, offset);
Vector128<byte> cmp = (Vector128.Equals(vector, v1)).AsByte();

if (cmp != Vector128<byte>.Zero)
{
// Skip every other bit
uint mask = cmp.ExtractMostSignificantBits() & 0x5555;
do
{
uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
replacementIndices.Append((int)(offset + bitPos));
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
}

offset += (nuint)Vector128<ushort>.Count;
} while (offset <= lengthToExamine - (nuint)Vector128<ushort>.Count);
}
while (offset < lengthToExamine)
{
char curr = Unsafe.Add(ref source, offset);
if (curr == c)
{
replacementIndices.Append((int)offset);
}
offset++;
}
}

private string ReplaceHelper(int oldValueLength, string newValue, ReadOnlySpan<int> indices)
{
Debug.Assert(indices.Length > 0);
Expand Down Expand Up @@ -1899,46 +1970,98 @@ internal static void MakeSeparatorListAny(ReadOnlySpan<char> source, ReadOnlySpa

private static void MakeSeparatorListVectorized(ReadOnlySpan<char> sourceSpan, ref ValueListBuilder<int> sepListBuilder, char c, char c2, char c3)
{
// Redundant test so we won't prejit remainder of this method
// on platforms where it is not supported
if (!Vector128.IsHardwareAccelerated)
Debug.Assert(sourceSpan.Length >= Vector128<ushort>.Count);
nuint lengthToExamine = (uint)sourceSpan.Length;
nuint offset = 0;
ref char source = ref MemoryMarshal.GetReference(sourceSpan);

if (Vector512.IsHardwareAccelerated && lengthToExamine >= (uint)Vector512<ushort>.Count*2)
{
throw new PlatformNotSupportedException();
}
Vector512<ushort> v1 = Vector512.Create((ushort)c);
Vector512<ushort> v2 = Vector512.Create((ushort)c2);
Vector512<ushort> v3 = Vector512.Create((ushort)c3);

Debug.Assert(sourceSpan.Length >= Vector128<ushort>.Count);
do
{
Vector512<ushort> vector = Vector512.LoadUnsafe(ref source, offset);
Vector512<ushort> v1Eq = Vector512.Equals(vector, v1);
Vector512<ushort> v2Eq = Vector512.Equals(vector, v2);
Vector512<ushort> v3Eq = Vector512.Equals(vector, v3);
Vector512<byte> cmp = (v1Eq | v2Eq | v3Eq).AsByte();

nuint offset = 0;
nuint lengthToExamine = (uint)sourceSpan.Length;
if (cmp != Vector512<byte>.Zero)
{
// Skip every other bit
ulong mask = cmp.ExtractMostSignificantBits() & 0x5555555555555555;
do
{
uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
sepListBuilder.Append((int)(offset + bitPos));
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
}

ref char source = ref MemoryMarshal.GetReference(sourceSpan);
offset += (nuint)Vector512<ushort>.Count;
} while (offset <= lengthToExamine - (nuint)Vector512<ushort>.Count);
}
else if (Vector256.IsHardwareAccelerated && lengthToExamine >= (uint)Vector256<ushort>.Count*2)
{
Vector256<ushort> v1 = Vector256.Create((ushort)c);
Vector256<ushort> v2 = Vector256.Create((ushort)c2);
Vector256<ushort> v3 = Vector256.Create((ushort)c3);

do
{
Vector256<ushort> vector = Vector256.LoadUnsafe(ref source, offset);
Vector256<ushort> v1Eq = Vector256.Equals(vector, v1);
Vector256<ushort> v2Eq = Vector256.Equals(vector, v2);
Vector256<ushort> v3Eq = Vector256.Equals(vector, v3);
Vector256<byte> cmp = (v1Eq | v2Eq | v3Eq).AsByte();

Vector128<ushort> v1 = Vector128.Create((ushort)c);
Vector128<ushort> v2 = Vector128.Create((ushort)c2);
Vector128<ushort> v3 = Vector128.Create((ushort)c3);
if (cmp != Vector256<byte>.Zero)
{
// Skip every other bit
uint mask = cmp.ExtractMostSignificantBits() & 0x55555555;
do
{
uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
sepListBuilder.Append((int)(offset + bitPos));
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
}

do
offset += (nuint)Vector256<ushort>.Count;
} while (offset <= lengthToExamine - (nuint)Vector256<ushort>.Count);
}
else if (Vector128.IsHardwareAccelerated)
{
Vector128<ushort> vector = Vector128.LoadUnsafe(ref source, offset);
Vector128<ushort> v1Eq = Vector128.Equals(vector, v1);
Vector128<ushort> v2Eq = Vector128.Equals(vector, v2);
Vector128<ushort> v3Eq = Vector128.Equals(vector, v3);
Vector128<byte> cmp = (v1Eq | v2Eq | v3Eq).AsByte();
Vector128<ushort> v1 = Vector128.Create((ushort)c);
Vector128<ushort> v2 = Vector128.Create((ushort)c2);
Vector128<ushort> v3 = Vector128.Create((ushort)c3);

if (cmp != Vector128<byte>.Zero)
do
{
// Skip every other bit
uint mask = cmp.ExtractMostSignificantBits() & 0x5555;
do
Vector128<ushort> vector = Vector128.LoadUnsafe(ref source, offset);
Vector128<ushort> v1Eq = Vector128.Equals(vector, v1);
Vector128<ushort> v2Eq = Vector128.Equals(vector, v2);
Vector128<ushort> v3Eq = Vector128.Equals(vector, v3);
Vector128<byte> cmp = (v1Eq | v2Eq | v3Eq).AsByte();

if (cmp != Vector128<byte>.Zero)
{
uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
sepListBuilder.Append((int)(offset + bitPos));
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
}
// Skip every other bit
uint mask = cmp.ExtractMostSignificantBits() & 0x5555;
do
{
uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
sepListBuilder.Append((int)(offset + bitPos));
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
}

offset += (nuint)Vector128<ushort>.Count;
} while (offset <= lengthToExamine - (nuint)Vector128<ushort>.Count);
offset += (nuint)Vector128<ushort>.Count;
} while (offset <= lengthToExamine - (nuint)Vector128<ushort>.Count);
}

while (offset < lengthToExamine)
{
Expand Down