Skip to content

Commit fc5ac13

Browse files
authored
Merge pull request #1932 from sayantn/fmaddsub
Use SIMD intrinsics for `vfmaddsubph` and `vfmsubaddph`
2 parents 1d61f54 + b486cc9 commit fc5ac13

File tree

1 file changed

+29 -10 lines changed

1 file changed

+29 -10 lines changed

crates/core_arch/src/x86/avx512fp16.rs

Lines changed: 29 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -7184,7 +7184,11 @@ pub fn _mm_maskz_fnmsub_round_sh<const ROUNDING: i32>(
71847184
#[cfg_attr(test, assert_instr(vfmaddsub))]
71857185
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
71867186
pub fn _mm_fmaddsub_ph(a: __m128h, b: __m128h, c: __m128h) -> __m128h {
7187-
unsafe { vfmaddsubph_128(a, b, c) }
7187+
unsafe {
7188+
let add = simd_fma(a, b, c);
7189+
let sub = simd_fma(a, b, simd_neg(c));
7190+
simd_shuffle!(sub, add, [0, 9, 2, 11, 4, 13, 6, 15])
7191+
}
71887192
}
71897193

71907194
/// Multiply packed half-precision (16-bit) floating-point elements in a and b, alternatively add and
@@ -7235,7 +7239,15 @@ pub fn _mm_maskz_fmaddsub_ph(k: __mmask8, a: __m128h, b: __m128h, c: __m128h) ->
72357239
#[cfg_attr(test, assert_instr(vfmaddsub))]
72367240
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
72377241
pub fn _mm256_fmaddsub_ph(a: __m256h, b: __m256h, c: __m256h) -> __m256h {
7238-
unsafe { vfmaddsubph_256(a, b, c) }
7242+
unsafe {
7243+
let add = simd_fma(a, b, c);
7244+
let sub = simd_fma(a, b, simd_neg(c));
7245+
simd_shuffle!(
7246+
sub,
7247+
add,
7248+
[0, 17, 2, 19, 4, 21, 6, 23, 8, 25, 10, 27, 12, 29, 14, 31]
7249+
)
7250+
}
72397251
}
72407252

72417253
/// Multiply packed half-precision (16-bit) floating-point elements in a and b, alternatively add and
@@ -7286,7 +7298,18 @@ pub fn _mm256_maskz_fmaddsub_ph(k: __mmask16, a: __m256h, b: __m256h, c: __m256h
72867298
#[cfg_attr(test, assert_instr(vfmaddsub))]
72877299
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
72887300
pub fn _mm512_fmaddsub_ph(a: __m512h, b: __m512h, c: __m512h) -> __m512h {
7289-
_mm512_fmaddsub_round_ph::<_MM_FROUND_CUR_DIRECTION>(a, b, c)
7301+
unsafe {
7302+
let add = simd_fma(a, b, c);
7303+
let sub = simd_fma(a, b, simd_neg(c));
7304+
simd_shuffle!(
7305+
sub,
7306+
add,
7307+
[
7308+
0, 33, 2, 35, 4, 37, 6, 39, 8, 41, 10, 43, 12, 45, 14, 47, 16, 49, 18, 51, 20, 53,
7309+
22, 55, 24, 57, 26, 59, 28, 61, 30, 63
7310+
]
7311+
)
7312+
}
72907313
}
72917314

72927315
/// Multiply packed half-precision (16-bit) floating-point elements in a and b, alternatively add and
@@ -7459,7 +7482,7 @@ pub fn _mm512_maskz_fmaddsub_round_ph<const ROUNDING: i32>(
74597482
#[cfg_attr(test, assert_instr(vfmsubadd))]
74607483
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
74617484
pub fn _mm_fmsubadd_ph(a: __m128h, b: __m128h, c: __m128h) -> __m128h {
7462-
unsafe { vfmaddsubph_128(a, b, simd_neg(c)) }
7485+
_mm_fmaddsub_ph(a, b, unsafe { simd_neg(c) })
74637486
}
74647487

74657488
/// Multiply packed half-precision (16-bit) floating-point elements in a and b, alternatively subtract
@@ -7510,7 +7533,7 @@ pub fn _mm_maskz_fmsubadd_ph(k: __mmask8, a: __m128h, b: __m128h, c: __m128h) ->
75107533
#[cfg_attr(test, assert_instr(vfmsubadd))]
75117534
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
75127535
pub fn _mm256_fmsubadd_ph(a: __m256h, b: __m256h, c: __m256h) -> __m256h {
7513-
unsafe { vfmaddsubph_256(a, b, simd_neg(c)) }
7536+
_mm256_fmaddsub_ph(a, b, unsafe { simd_neg(c) })
75147537
}
75157538

75167539
/// Multiply packed half-precision (16-bit) floating-point elements in a and b, alternatively subtract
@@ -7561,7 +7584,7 @@ pub fn _mm256_maskz_fmsubadd_ph(k: __mmask16, a: __m256h, b: __m256h, c: __m256h
75617584
#[cfg_attr(test, assert_instr(vfmsubadd))]
75627585
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
75637586
pub fn _mm512_fmsubadd_ph(a: __m512h, b: __m512h, c: __m512h) -> __m512h {
7564-
_mm512_fmsubadd_round_ph::<_MM_FROUND_CUR_DIRECTION>(a, b, c)
7587+
_mm512_fmaddsub_ph(a, b, unsafe { simd_neg(c) })
75657588
}
75667589

75677590
/// Multiply packed half-precision (16-bit) floating-point elements in a and b, alternatively subtract
@@ -16409,10 +16432,6 @@ unsafe extern "C" {
1640916432
#[link_name = "llvm.x86.avx512fp16.vfmadd.f16"]
1641016433
fn vfmaddsh(a: f16, b: f16, c: f16, rounding: i32) -> f16;
1641116434

16412-
#[link_name = "llvm.x86.avx512fp16.vfmaddsub.ph.128"]
16413-
fn vfmaddsubph_128(a: __m128h, b: __m128h, c: __m128h) -> __m128h;
16414-
#[link_name = "llvm.x86.avx512fp16.vfmaddsub.ph.256"]
16415-
fn vfmaddsubph_256(a: __m256h, b: __m256h, c: __m256h) -> __m256h;
1641616435
#[link_name = "llvm.x86.avx512fp16.vfmaddsub.ph.512"]
1641716436
fn vfmaddsubph_512(a: __m512h, b: __m512h, c: __m512h, rounding: i32) -> __m512h;
1641816437

0 commit comments

Comments (0)