diff --git a/src/coreclr/jit/gentree.cpp b/src/coreclr/jit/gentree.cpp index f6089645f8c416..a0ca6ff8aeac46 100644 --- a/src/coreclr/jit/gentree.cpp +++ b/src/coreclr/jit/gentree.cpp @@ -23864,7 +23864,6 @@ GenTree* Compiler::gtNewSimdSumNode(var_types type, GenTree* op1, CorInfoType si #if defined(TARGET_XARCH) assert(!varTypeIsByte(simdBaseType) && !varTypeIsLong(simdBaseType)); - assert(simdSize != 64); // HorizontalAdd combines pairs so we need log2(vectorLength) passes to sum all elements together. unsigned vectorLength = getSIMDVectorLength(simdSize, simdBaseType); @@ -23897,6 +23896,35 @@ GenTree* Compiler::gtNewSimdSumNode(var_types type, GenTree* op1, CorInfoType si intrinsic = NI_SSSE3_HorizontalAdd; } + if (simdSize == 64) + { + assert(IsBaselineVector512IsaSupportedDebugOnly()); + // This is roughly the following managed code: + // ... + // simd64 tmp2 = tmp1; + // tmp3 = tmp2.GetUpper(); + // simd32 tmp4 = Isa.Add(tmp1.GetLower(), tmp2); + // tmp5 = tmp4; + // simd16 tmp6 = tmp4.GetUpper(); + // tmp1 = Isa.Add(tmp1.GetLower(), tmp2); + // ... + // From here on we can treat this as a simd16 reduction + GenTree* op1Dup = fgMakeMultiUse(&op1); + GenTree* op1Lower32 = gtNewSimdGetUpperNode(TYP_SIMD32, op1Dup, simdBaseJitType, simdSize); + GenTree* op1Upper32 = gtNewSimdGetLowerNode(TYP_SIMD32, op1, simdBaseJitType, simdSize); + + simdSize = simdSize / 2; + op1Lower32 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD32, op1Lower32, op1Upper32, simdBaseJitType, simdSize); + haddCount--; + + GenTree* op1Dup32 = fgMakeMultiUse(&op1Lower32); + GenTree* op1Lower16 = gtNewSimdGetUpperNode(TYP_SIMD16, op1Lower32, simdBaseJitType, simdSize); + GenTree* op1Upper16 = gtNewSimdGetLowerNode(TYP_SIMD16, op1Dup32, simdBaseJitType, simdSize); + simdSize = simdSize / 2; + op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, op1Lower16, op1Upper16, simdBaseJitType, simdSize); + haddCount--; + } + for (int i = 0; i < haddCount; i++) { tmp = fgMakeMultiUse(&op1); diff --git a/src/coreclr/jit/hwintrinsiclistxarch.h b/src/coreclr/jit/hwintrinsiclistxarch.h index e1649b2159c55f..522d6df38cabae 100644 --- a/src/coreclr/jit/hwintrinsiclistxarch.h +++ b/src/coreclr/jit/hwintrinsiclistxarch.h @@ -330,6 +330,7 @@ HARDWARE_INTRINSIC(Vector512, StoreAligned, HARDWARE_INTRINSIC(Vector512, StoreAlignedNonTemporal, 64, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_BaseTypeFromFirstArg|HW_Flag_NoCodeGen) HARDWARE_INTRINSIC(Vector512, StoreUnsafe, 64, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_BaseTypeFromFirstArg|HW_Flag_NoCodeGen) HARDWARE_INTRINSIC(Vector512, Subtract, 64, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_NoCodeGen) +HARDWARE_INTRINSIC(Vector512, Sum, 64, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_BaseTypeFromFirstArg|HW_Flag_NoCodeGen) HARDWARE_INTRINSIC(Vector512, ToScalar, 64, 1, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_movss, INS_movsd_simd}, HW_Category_SIMDScalar, HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen|HW_Flag_BaseTypeFromFirstArg) HARDWARE_INTRINSIC(Vector512, WidenLower, 64, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_NoCodeGen|HW_Flag_BaseTypeFromFirstArg) HARDWARE_INTRINSIC(Vector512, WidenUpper, 64, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_NoCodeGen|HW_Flag_BaseTypeFromFirstArg) diff --git a/src/coreclr/jit/hwintrinsicxarch.cpp b/src/coreclr/jit/hwintrinsicxarch.cpp index 162bb4942b54ce..d86963c1ffd125 100644 --- a/src/coreclr/jit/hwintrinsicxarch.cpp +++ b/src/coreclr/jit/hwintrinsicxarch.cpp @@ -2846,6 +2846,7 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic, case NI_Vector128_Sum: case NI_Vector256_Sum: + case NI_Vector512_Sum: { assert(sig->numArgs == 1); var_types simdType = getSIMDTypeForSize(simdSize);