Skip to content

Commit d7a4494

Browse files
committed
[Cuda] Updated bfloat16 math defs.
Required to pass `test_cuda_bf16_vectorize_add` in `tests/python/unittest/test_target_codegen_cuda.py`.
1 parent 2a840a3 commit d7a4494

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/target/source/literal/cuda_half_t.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,15 +338,15 @@ __pack_nv_bfloat162(const nv_bfloat16 x, const nv_bfloat16 y) {
338338
// so we define them here to make sure the generated CUDA code
339339
// is valid.
340340
#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \
341-
static inline __device__ __host__ half HALF_MATH_NAME(half x, half y) { \
341+
static inline __device__ __host__ nv_bfloat16 HALF_MATH_NAME(nv_bfloat16 x, nv_bfloat16 y) { \
342342
float tmp_x = __bfloat162float(x); \
343343
float tmp_y = __bfloat162float(y); \
344344
float result = FP32_MATH_NAME(tmp_x, tmp_y); \
345345
return __float2bfloat16(result); \
346346
}
347347
348348
#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \
349-
static inline __device__ __host__ half HALF_MATH_NAME(half x) { \
349+
static inline __device__ __host__ nv_bfloat16 HALF_MATH_NAME(nv_bfloat16 x) { \
350350
float tmp_x = __bfloat162float(x); \
351351
float result = FP32_MATH_NAME(tmp_x); \
352352
return __float2bfloat16(result); \

0 commit comments

Comments
 (0)