File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed
src/target/source/literal Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -338,15 +338,15 @@ __pack_nv_bfloat162(const nv_bfloat16 x, const nv_bfloat16 y) {
338338// so we define them here to make sure the generated CUDA code
339339// is valid.
340340#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \
341- static inline __device__ __host__ half HALF_MATH_NAME(half x, half y) { \
341+ static inline __device__ __host__ nv_bfloat16 HALF_MATH_NAME(nv_bfloat16 x, nv_bfloat16 y) { \
342342 float tmp_x = __bfloat162float(x); \
343343 float tmp_y = __bfloat162float(y); \
344344 float result = FP32_MATH_NAME(tmp_x, tmp_y); \
345345 return __float2bfloat16(result); \
346346}
347347
348348#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \
349- static inline __device__ __host__ half HALF_MATH_NAME(half x) { \
349+ static inline __device__ __host__ nv_bfloat16 HALF_MATH_NAME(nv_bfloat16 x) { \
350350 float tmp_x = __bfloat162float(x); \
351351 float result = FP32_MATH_NAME(tmp_x); \
352352 return __float2bfloat16(result); \
You can’t perform that action at this time.
0 commit comments