Skip to content

Commit 77e226c

Browse files
authored
Update utils_parallel_dequant.cuh
1 parent 66d86a6 commit 77e226c

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@
2525

2626
#include <cuda.h>
2727
#include <cuda_fp16.h>
28-
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
2928
#include <cuda_bf16.h>
30-
#endif
3129
#include <cuda_runtime.h>
3230

3331
/*
@@ -70,9 +68,9 @@ constexpr float power_of_two(int n) {
7068
return (n == 0) ? 1.0f : 2.0f * power_of_two(n - 1);
7169
}
7270

73-
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
7471
template<int EXPONENT, int MANTISSA>
7572
__device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bfloat16 Scale) {
73+
#if __CUDA_ARCH__ >= 800
7674
constexpr int BIAS_OFFSET = (int(1) << (8-1)) - (int(1) << (EXPONENT-1));
7775
constexpr float BIAS = power_of_two(BIAS_OFFSET);
7876
__nv_bfloat16* BF16_1 = reinterpret_cast<__nv_bfloat16*>(&PackedBF16Pair);
@@ -82,8 +80,8 @@ __device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bflo
8280
output_bf16_ptr[0] = __hmul( __hmul(*BF16_1,__float2bfloat16(BIAS)), Scale);
8381
output_bf16_ptr[1] = __hmul( __hmul(*BF16_2,__float2bfloat16(BIAS)), Scale);
8482
return output;
85-
}
8683
#endif
84+
}
8785

8886
// MODIFICATION NOTE: to support MSVC
8987
// - u_int32_t __restrict__ Reg[][4] is changed to below.

0 commit comments

Comments
 (0)