@@ -174,26 +174,32 @@ endif ()
174174# FA3 requires CUDA 12.0 or later
175175if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
176176 # BF16 source files
177- file (GLOB FA3_BF16_GEN_SRCS
177+ file (GLOB FA3_BF16_GEN_SRCS
178178 "hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu" )
179- file (GLOB FA3_BF16_GEN_SRCS_
179+ file (GLOB FA3_BF16_GEN_SRCS_
180180 "hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu" )
181181 list (APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_} )
182- file (GLOB FA3_BF16_GEN_SRCS_
182+ file (GLOB FA3_BF16_GEN_SRCS_
183183 "hopper/instantiations/flash_fwd_*_bf16_*_sm80.cu" )
184184 list (APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_} )
185185 # FP16 source files
186- file (GLOB FA3_FP16_GEN_SRCS
186+ file (GLOB FA3_FP16_GEN_SRCS
187187 "hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu" )
188- file (GLOB FA3_FP16_GEN_SRCS_
188+ file (GLOB FA3_FP16_GEN_SRCS_
189189 "hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu" )
190190 list (APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_} )
191- file (GLOB FA3_FP16_GEN_SRCS_
191+ file (GLOB FA3_FP16_GEN_SRCS_
192192 "hopper/instantiations/flash_fwd_*_fp16_*_sm80.cu" )
193193 list (APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_} )
194+ # FP8 source files
195+ file (GLOB FA3_FP8_GEN_SRCS
196+ "hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu" )
197+ file (GLOB FA3_FP8_GEN_SRCS_
198+ "hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu" )
199+ list (APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_} )
194200
195201 # TODO add fp8 source files when FP8 is enabled
196- set (FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} )
202+ set (FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS} )
197203
198204 # For CUDA we set the architectures on a per file basis
199205 if (VLLM_GPU_LANG STREQUAL "CUDA" )
@@ -236,9 +242,8 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
236242 FLASHATTENTION_DISABLE_UNEVEN_K
237243 # FLASHATTENTION_DISABLE_LOCAL
238244 FLASHATTENTION_DISABLE_PYBIND
239- FLASHATTENTION_DISABLE_FP8 # TODO Enable FP8
240245 FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
241246 )
242247elseif (${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
243248 message (STATUS "FA3 is disabled because CUDA version is not 12.0 or later." )
244- endif ()
249+ endif ()
0 commit comments