Description
Refactor the compute operations in our distance primitives.
The goal is that all compute() methods will dispatch to simd::generic_simd_op(), as is done, for example, in
return simd::generic_simd_op(L2FloatOp<16>{}, a, b, length);
To achieve this, the actual compute operation must be wrapped in a SIMD struct that provides the necessary protocol. One example is L2FloatOp<16>:
ScalableVectorSearch/include/svs/core/distance/euclidean.h
Lines 240 to 259 in 18ba515
template <> struct L2FloatOp<16> : public svs::simd::ConvertToFloat<16> {
    using parent = svs::simd::ConvertToFloat<16>;
    using mask_t = typename parent::mask_t;
    // Here, we can fill-in the shared init, accumulate, combine, and reduce methods.
    static __m512 init() { return _mm512_setzero_ps(); }
    static __m512 accumulate(__m512 accumulator, __m512 a, __m512 b) {
        auto c = _mm512_sub_ps(a, b);
        return _mm512_fmadd_ps(c, c, accumulator);
    }
    static __m512 accumulate(mask_t m, __m512 accumulator, __m512 a, __m512 b) {
        auto c = _mm512_maskz_sub_ps(m, a, b);
        return _mm512_mask3_fmadd_ps(c, c, accumulator, m);
    }
    static __m512 combine(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
    static float reduce(__m512 x) { return _mm512_reduce_add_ps(x); }
};
That is, the struct must provide init(), accumulate(), and combine() operations; the example above also defines reduce() and a masked accumulate() overload.
The chain compute() -> simd::generic_simd_op() -> <actual compute> is already used in many places, but not consistently. One example where AVX intrinsics are still used directly in compute() is
template <size_t N> struct L2Impl<N, float, float, AVX_AVAILABILITY::AVX2> {
But many other examples exist in our distance primitives.
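To make the target shape concrete, here is a minimal, portable sketch of the compute() -> generic op -> op struct chain. It is not the SVS implementation: `Vec`, `load`, `load_partial`, `ToyL2Op`, `toy_generic_op`, and `toy_l2_compute` are all hypothetical names, a 4-float array stands in for a SIMD register, and a lane count stands in for the mask. The unrolling strategy in the driver is likewise an assumption made for illustration.

```cpp
#include <array>
#include <cstddef>

// Hypothetical stand-in for a SIMD register: 4 float lanes in a plain array.
constexpr std::size_t kLanes = 4;
using Vec = std::array<float, kLanes>;

// Hypothetical load helpers (in the real op struct, loading presumably comes
// from the parent class; that is an assumption here).
inline Vec load(const float* p) {
    Vec v{};
    for (std::size_t i = 0; i < kLanes; ++i) v[i] = p[i];
    return v;
}
inline Vec load_partial(const float* p, std::size_t n) {
    Vec v{};  // lanes >= n stay zero
    for (std::size_t i = 0; i < n; ++i) v[i] = p[i];
    return v;
}

// Hypothetical op struct mirroring the shape of L2FloatOp<16>.
struct ToyL2Op {
    static Vec init() { return Vec{}; }
    // Full-width step: acc += (a - b)^2 in every lane.
    static Vec accumulate(Vec acc, const Vec& a, const Vec& b) {
        for (std::size_t i = 0; i < kLanes; ++i) {
            float c = a[i] - b[i];
            acc[i] += c * c;
        }
        return acc;
    }
    // "Masked" step for the tail: only the first `active` lanes contribute.
    static Vec accumulate(std::size_t active, Vec acc, const Vec& a, const Vec& b) {
        for (std::size_t i = 0; i < active; ++i) {
            float c = a[i] - b[i];
            acc[i] += c * c;
        }
        return acc;
    }
    static Vec combine(const Vec& x, const Vec& y) {
        Vec r{};
        for (std::size_t i = 0; i < kLanes; ++i) r[i] = x[i] + y[i];
        return r;
    }
    static float reduce(const Vec& x) {
        float s = 0.0f;
        for (float v : x) s += v;
        return s;
    }
};

// Hypothetical driver playing the role of simd::generic_simd_op().
template <typename Op>
float toy_generic_op(Op, const float* a, const float* b, std::size_t length) {
    Vec acc0 = Op::init();
    Vec acc1 = Op::init();
    std::size_t i = 0;
    // Main loop, unrolled by two with independent accumulators.
    for (; i + 2 * kLanes <= length; i += 2 * kLanes) {
        acc0 = Op::accumulate(acc0, load(a + i), load(b + i));
        acc1 = Op::accumulate(acc1, load(a + i + kLanes), load(b + i + kLanes));
    }
    // At most one remaining full-width block.
    if (i + kLanes <= length) {
        acc0 = Op::accumulate(acc0, load(a + i), load(b + i));
        i += kLanes;
    }
    // Partial tail handled by the masked overload.
    if (i < length) {
        std::size_t rest = length - i;
        acc1 = Op::accumulate(rest, acc1, load_partial(a + i, rest), load_partial(b + i, rest));
    }
    return Op::reduce(Op::combine(acc0, acc1));
}

// After the refactor, compute() contains no intrinsics; it only dispatches.
float toy_l2_compute(const float* a, const float* b, std::size_t length) {
    return toy_generic_op(ToyL2Op{}, a, b, length);
}
```

In this sketch, calling `toy_l2_compute(a, b, 11)` with `a = {1, ..., 11}` and `b` all zeros returns 506, the sum of the squares 1 through 11. The point of the pattern is that a new distance only needs a new op struct, while the loop structure and tail handling live in one place.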