Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/flashinfer
Submodule flashinfer updated 55 files
+12 −5 .github/workflows/release_wheel.yml
+4 −1 .gitignore
+1 −1 .release-please-manifest.json
+16 −0 CHANGELOG.md
+80 −39 CMakeLists.txt
+10 −0 cmake/config.cmake
+3 −0 docs/api/python/decode.rst
+13 −0 docs/api/python/norm.rst
+15 −0 docs/api/python/sampling.rst
+6 −6 docs/conf.py
+2 −1 docs/index.rst
+17 −2 docs/installation.rst
+0 −67 include/flashinfer/attention/decode.cuh
+183 −81 include/flashinfer/attention/handler.cuh
+46 −89 include/flashinfer/attention/prefill.cuh
+0 −146 include/flashinfer/decode_attention_decl.cuh
+0 −95 include/flashinfer/prefill_attention_decl.cuh
+109 −99 include/flashinfer/sampling.cuh
+12 −19 include/flashinfer/utils.cuh
+1 −0 python/MANIFEST.in
+63 −28 python/csrc/batch_decode.cu
+18 −8 python/csrc/batch_prefill.cu
+23 −3 python/csrc/flashinfer_ops.cu
+45 −19 python/csrc/flashinfer_ops.h
+43 −0 python/csrc/norm.cu
+4 −4 python/csrc/pytorch_extension_utils.h
+98 −0 python/csrc/sampling.cu
+7 −0 python/flashinfer/__init__.py
+284 −1 python/flashinfer/decode.py
+49 −0 python/flashinfer/norm.py
+4 −2 python/flashinfer/prefill.py
+190 −0 python/flashinfer/sampling.py
+3 −6 python/generate_batch_paged_prefill_inst.py
+40 −16 python/generate_dispatch_inc.py
+12 −3 python/setup.py
+143 −4 python/tests/test_batch_decode_kernels.py
+47 −0 python/tests/test_norm.py
+101 −0 python/tests/test_sampling.py
+3 −4 src/bench_batch_decode.cu
+5 −6 src/bench_cascade.cu
+4 −4 src/bench_sampling.cu
+2 −2 src/bench_single_decode.cu
+2 −1 src/bench_single_prefill.cu
+62 −63 src/cpu_reference.h
+314 −0 src/flashinfer_ops.cuh
+5 −5 src/test_batch_decode.cu
+9 −8 src/test_batch_prefill.cu
+9 −10 src/test_cascade.cu
+2 −2 src/test_page.cu
+1,707 −9 src/test_sampling.cu
+1 −1 src/test_single_decode.cu
+1 −2 src/test_single_prefill.cu
+23 −45 src/tvm_wrapper.cu
+43 −0 src/utils.h
+1 −1 version.txt
6 changes: 3 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -960,13 +960,13 @@ option(USE_FLASHINFER "Build TVM with FlashInfer" OFF)
if (USE_FLASHINFER STREQUAL "ON")
message(STATUS "Build with FlashInfer")
set(FLASHINFER_TVM_BINDING ON)
set(FLASHINFER_TVM_HOME ${PROJECT_SOURCE_DIR})
set(FLASHINFER_ENABLE_FP8 OFF)
set(FLASHINFER_ENABLE_BF16 OFF)
set(FLASHINFER_TVM_SOURCE_DIR ${PROJECT_SOURCE_DIR})
set(FLASHINFER_PREFILL OFF)
set(FLASHINFER_DECODE OFF)
set(FLASHINFER_PAGE OFF)
set(FLASHINFER_CASCADE OFF)
set(FLASHINFER_SAMPLING OFF)
set(FLASHINFER_NORM OFF)
add_subdirectory(3rdparty/flashinfer)
else ()
message(STATUS "Build without FlashInfer")
Expand Down
13 changes: 13 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,19 @@ set(USE_GTEST AUTO)
# Need to have USE_CUDA=ON
set(USE_CUTLASS OFF)

# Whether to enable FlashInfer or not.
set(USE_FLASHINFER OFF)
# Options for FlashInfer kernel compilation.
# NOTE(review): these values are forwarded to the 3rdparty/flashinfer
# subdirectory build; the variable names must match what FlashInfer's
# CMakeLists expects — do not rename locally.
# Reduced-precision (FP8 / BF16) kernel variants, disabled by default.
set(FLASHINFER_ENABLE_FP8 OFF)
set(FLASHINFER_ENABLE_BF16 OFF)
# Each *_GEN_* list selects which kernel template instantiations to compile;
# more entries -> broader runtime support but longer build times.
set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)
set(FLASHINFER_GEN_PAGE_SIZES 16)
set(FLASHINFER_GEN_HEAD_DIMS 128)
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1)
# Quoted booleans: FlashInfer's generator scripts consume these as the
# literal strings "true"/"false", not CMake booleans.
set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false")
# "CASUALS" spelling follows the upstream FlashInfer variable name
# (causal-mask variants) — presumably an upstream typo; keep as-is to match.
set(FLASHINFER_GEN_CASUALS "false" "true")

# Enable to show a summary of TVM options
set(SUMMARIZE OFF)

Expand Down
17 changes: 17 additions & 0 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,23 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
return true;
}

/*!
 * \brief Get the preferred host device from the input device.
 * - For CUDA and ROCm, CUDAHost and ROCMHost will be returned for pinned memory,
 *   since pinned memory reduces copy overhead.
 * - For other devices, CPU is returned as a fallback.
 */
inline Device GetPreferredHostDevice(Device device) {
  switch (device.device_type) {
    case DLDeviceType::kDLCUDA:
      // Pinned host memory paired with the CUDA device.
      return Device{DLDeviceType::kDLCUDAHost, 0};
    case DLDeviceType::kDLROCM:
      // Pinned host memory paired with the ROCm device.
      return Device{DLDeviceType::kDLROCMHost, 0};
    default:
      // Fallback to plain CPU memory for all other device types.
      return Device{DLDeviceType::kDLCPU, 0};
  }
}

} // namespace runtime
} // namespace tvm

Expand Down
Loading