Skip to content

Commit e11ba83

Browse files
committed
feat: add fp8 source files
1 parent 175ebb2 commit e11ba83

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

CMakeLists.txt

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -174,26 +174,32 @@ endif ()
174174
# FA3 requires CUDA 12.0 or later
175175
if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
176176
# BF16 source files
177-
file(GLOB FA3_BF16_GEN_SRCS
177+
file(GLOB FA3_BF16_GEN_SRCS
178178
"hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
179-
file(GLOB FA3_BF16_GEN_SRCS_
179+
file(GLOB FA3_BF16_GEN_SRCS_
180180
"hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
181181
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
182-
file(GLOB FA3_BF16_GEN_SRCS_
182+
file(GLOB FA3_BF16_GEN_SRCS_
183183
"hopper/instantiations/flash_fwd_*_bf16_*_sm80.cu")
184184
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
185185
# FP16 source files
186-
file(GLOB FA3_FP16_GEN_SRCS
186+
file(GLOB FA3_FP16_GEN_SRCS
187187
"hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
188-
file(GLOB FA3_FP16_GEN_SRCS_
188+
file(GLOB FA3_FP16_GEN_SRCS_
189189
"hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
190190
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
191-
file(GLOB FA3_FP16_GEN_SRCS_
191+
file(GLOB FA3_FP16_GEN_SRCS_
192192
"hopper/instantiations/flash_fwd_*_fp16_*_sm80.cu")
193193
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
194+
# FP8 source files
195+
file(GLOB FA3_FP8_GEN_SRCS
196+
"hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
197+
file(GLOB FA3_FP8_GEN_SRCS_
198+
"hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
199+
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
194200

195201
# TODO add fp8 source files when FP8 is enabled
196-
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS})
202+
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
197203

198204
# For CUDA we set the architectures on a per file basis
199205
if (VLLM_GPU_LANG STREQUAL "CUDA")
@@ -236,9 +242,8 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
236242
FLASHATTENTION_DISABLE_UNEVEN_K
237243
# FLASHATTENTION_DISABLE_LOCAL
238244
FLASHATTENTION_DISABLE_PYBIND
239-
FLASHATTENTION_DISABLE_FP8 # TODO Enable FP8
240245
FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
241246
)
242247
elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
243248
message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.")
244-
endif ()
249+
endif ()

0 commit comments

Comments
 (0)