Skip to content

Commit e5faab0

Browse files
authored
Improve perf of Enumerable.Sum/Average/Max/Min for arrays and lists (#64624)
* Improve perf of Enumerable.Sum/Average/Max/Min for arrays and lists

  It's very common to use these terminal functions for quick stats on arrays and lists of values. Just the overhead of enumerating as an enumerable (involving multiple interface dispatches) per iteration is significant, and it's much faster to directly enumerate the contents of the array or the list. In some cases, we can further use vectorization to speed up the processing.

  This change:
  - Adds a helper that does a fast check to see if it can extract a span from an enumerable that's actually an array or a list. It could be augmented to detect other interesting types, but `T[]` and `List<T>` are the most relevant from the data I've seen, and we can fairly quickly do type checks to get the most benefit for a small amount of cost.
  - Uses that helper in the int/long/float/double/decimal overloads of Sum/Average/Min/Max to add a span-based path.
  - Vectorizes Sum for float and double.
  - Vectorizes Average for int, float, and double (the latter two via use of Sum).

* Address PR feedback
1 parent 6b14c1e commit e5faab0

File tree

10 files changed

+775
-316
lines changed

10 files changed

+775
-316
lines changed

src/libraries/System.Linq/src/System.Linq.csproj

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,8 @@
100100
<Reference Include="System.Memory" />
101101
<Reference Include="System.Numerics.Vectors" />
102102
<Reference Include="System.Runtime" />
103+
<Reference Include="System.Runtime.CompilerServices.Unsafe" />
103104
<Reference Include="System.Runtime.Extensions" />
105+
<Reference Include="System.Runtime.InteropServices" />
104106
</ItemGroup>
105107
</Project>

src/libraries/System.Linq/src/System/Linq/Average.cs

Lines changed: 74 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2,16 +2,17 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.Collections.Generic;
5+
using System.Numerics;
56

67
namespace System.Linq
78
{
89
public static partial class Enumerable
910
{
1011
public static double Average(this IEnumerable<int> source)
1112
{
12-
if (source == null)
13+
if (source.TryGetSpan(out ReadOnlySpan<int> span))
1314
{
14-
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
15+
return Average(span);
1516
}
1617

1718
using (IEnumerator<int> e = source.GetEnumerator())
@@ -36,6 +37,38 @@ public static double Average(this IEnumerable<int> source)
3637
}
3738
}
3839

40+
/// <summary>Computes the arithmetic mean of the 32-bit integers in <paramref name="span"/>.</summary>
/// <param name="span">The values to average. An empty span is reported through <c>ThrowHelper.ThrowNoElementsException</c>.</param>
/// <returns>The sum of all elements, accumulated in a <see cref="long"/>, divided by the element count.</returns>
private static double Average(ReadOnlySpan<int> span)
{
    if (span.IsEmpty)
    {
        ThrowHelper.ThrowNoElementsException();
    }

    int index = 0;
    long total = 0;

    // Vector path: taken only when hardware acceleration is available and at least one full
    // Vector<int> worth of elements exists.
    if (Vector.IsHardwareAccelerated && span.Length >= Vector<int>.Count)
    {
        // Highest start index at which a complete Vector<int> can still be read.
        int lastBlockStart = span.Length - Vector<int>.Count;
        Vector<long> laneTotals = default;

        do
        {
            // Widen each block of ints into two vectors of longs and accumulate them lane by lane.
            Vector<int> block = new Vector<int>(span.Slice(index));
            Vector.Widen(block, out Vector<long> lower, out Vector<long> upper);
            laneTotals = laneTotals + lower + upper;
            index += Vector<int>.Count;
        }
        while (index <= lastBlockStart);

        // Collapse the per-lane totals into the scalar running total.
        total += Vector.Sum(laneTotals);
    }

    // Scalar path: handles whatever the vector loop did not consume (or everything, if it never ran).
    while ((uint)index < (uint)span.Length)
    {
        total += span[index];
        index++;
    }

    return (double)total / span.Length;
}
71+
3972
public static double? Average(this IEnumerable<int?> source)
4073
{
4174
if (source == null)
@@ -75,9 +108,9 @@ public static double Average(this IEnumerable<int> source)
75108

76109
public static double Average(this IEnumerable<long> source)
77110
{
78-
if (source == null)
111+
if (source.TryGetSpan(out ReadOnlySpan<long> span))
79112
{
80-
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
113+
return Average(span);
81114
}
82115

83116
using (IEnumerator<long> e = source.GetEnumerator())
@@ -102,6 +135,22 @@ public static double Average(this IEnumerable<long> source)
102135
}
103136
}
104137

138+
/// <summary>Computes the arithmetic mean of the 64-bit integers in <paramref name="span"/>.</summary>
/// <param name="span">The values to average. An empty span is reported through <c>ThrowHelper.ThrowNoElementsException</c>.</param>
/// <returns>The overflow-checked sum of all elements divided by the element count.</returns>
private static double Average(ReadOnlySpan<long> span)
{
    if (span.IsEmpty)
    {
        ThrowHelper.ThrowNoElementsException();
    }

    // Seed the total with the first element, then fold in the rest with overflow checking.
    long total = span[0];
    foreach (long value in span.Slice(1))
    {
        total = checked(total + value);
    }

    return (double)total / span.Length;
}
153+
105154
public static double? Average(this IEnumerable<long?> source)
106155
{
107156
if (source == null)
@@ -141,9 +190,14 @@ public static double Average(this IEnumerable<long> source)
141190

142191
public static float Average(this IEnumerable<float> source)
143192
{
144-
if (source == null)
193+
if (source.TryGetSpan(out ReadOnlySpan<float> span))
145194
{
146-
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
195+
if (span.IsEmpty)
196+
{
197+
ThrowHelper.ThrowNoElementsException();
198+
}
199+
200+
return (float)(Sum(span) / span.Length);
147201
}
148202

149203
using (IEnumerator<float> e = source.GetEnumerator())
@@ -204,9 +258,14 @@ public static float Average(this IEnumerable<float> source)
204258

205259
public static double Average(this IEnumerable<double> source)
206260
{
207-
if (source == null)
261+
if (source.TryGetSpan(out ReadOnlySpan<double> span))
208262
{
209-
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
263+
if (span.IsEmpty)
264+
{
265+
ThrowHelper.ThrowNoElementsException();
266+
}
267+
268+
return Sum(span) / span.Length;
210269
}
211270

212271
using (IEnumerator<double> e = source.GetEnumerator())
@@ -270,9 +329,14 @@ public static double Average(this IEnumerable<double> source)
270329

271330
public static decimal Average(this IEnumerable<decimal> source)
272331
{
273-
if (source == null)
332+
if (source.TryGetSpan(out ReadOnlySpan<decimal> span))
274333
{
275-
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
334+
if (span.IsEmpty)
335+
{
336+
ThrowHelper.ThrowNoElementsException();
337+
}
338+
339+
return Sum(span) / span.Length;
276340
}
277341

278342
using (IEnumerator<decimal> e = source.GetEnumerator())

src/libraries/System.Linq/src/System/Linq/Enumerable.cs

Lines changed: 41 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,11 +2,52 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.Collections.Generic;
5+
using System.Runtime.CompilerServices;
6+
using System.Runtime.InteropServices;
57

68
namespace System.Linq
79
{
810
public static partial class Enumerable
911
{
1012
/// <summary>Returns <paramref name="source"/> typed as <see cref="IEnumerable{T}"/>.</summary>
/// <typeparam name="TSource">The element type of <paramref name="source"/>.</typeparam>
/// <param name="source">The sequence to return; it is passed through as-is, with no validation.</param>
/// <returns>The same reference that was passed in.</returns>
public static IEnumerable<TSource> AsEnumerable<TSource>(this IEnumerable<TSource> source)
{
    return source;
}
13+
14+
/// <summary>Validates that source is not null and then tries to extract a span from the source.</summary>
/// <typeparam name="TSource">The element type; constrained to value types (see the note on the constraint).</typeparam>
/// <param name="source">The sequence to inspect. A null value is reported through <c>ThrowHelper.ThrowArgumentNullException</c>.</param>
/// <param name="span">
/// On a <c>true</c> return, a span over the contents of the array or list behind <paramref name="source"/>;
/// otherwise the default (empty) span.
/// </param>
/// <returns>
/// <c>true</c> when the runtime type of <paramref name="source"/> is exactly <c>TSource[]</c> or
/// <c>List&lt;TSource&gt;</c>; otherwise <c>false</c>.
/// </returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)] // fast type checks that don't add a lot of overhead
private static bool TryGetSpan<TSource>(this IEnumerable<TSource> source, out ReadOnlySpan<TSource> span)
    // The struct constraint is not strictly required. The overheads here can be more substantial when
    // TSource is a reference type and generic implementations are shared, so the constraint forces a
    // conscious decision (paired with sufficient performance testing) before it is ever removed.
    where TSource : struct
{
    if (source is null)
    {
        ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
    }

    // Exact-type comparisons (`GetType() == typeof(...)`) are used instead of `is` to avoid cast helpers.
    // This is measurably cheaper, at the price of missing rare cases where a span would be obtainable
    // (e.g. a uint[] masquerading as an int[]); that tradeoff is accepted. Unsafe.As is applied only
    // after the exact type has been validated, and could become a regular cast if the JIT starts to
    // recognize the pattern. Only the very common types expected to be used frequently with LINQ are
    // checked, so the comparison/branching cost stays small.

    if (source.GetType() == typeof(TSource[]))
    {
        span = Unsafe.As<TSource[]>(source);
        return true;
    }

    if (source.GetType() == typeof(List<TSource>))
    {
        span = CollectionsMarshal.AsSpan(Unsafe.As<List<TSource>>(source));
        return true;
    }

    // Neither an array nor a List<T>: the caller falls back to its enumerator-based path.
    span = default;
    return false;
}
1152
}
1253
}

0 commit comments

Comments
 (0)