Merged
10 changes: 9 additions & 1 deletion cpp/kernels/fmha_v2/README.md
@@ -2,7 +2,7 @@

## Introduction

- FMHA_v2 is just a bunch of Multi-head Attention kernels that weve enabled for known cases. Its not built as a library (cuBLAS, cuDNN, HazyResearch's MHA, etc) that is supposed to deliver good perf for all cases. End users will get access to FMHA through products or libraries, not directly through FMHA_v2.
+ FMHA_v2 is just a bunch of Multi-head Attention kernels that we've enabled for known cases. It's not built as a library (cuBLAS, cuDNN, HazyResearch's MHA, etc) that is supposed to deliver good perf for all cases. End users will get access to FMHA through products or libraries, not directly through FMHA_v2.

## Launch a container to build the code

@@ -80,3 +80,11 @@ Why is the FMHA_v2 slower than public implementation in several cases?
```
Usually, adding new launch configurations suffices. The heuristics of FMHA_v2 are designed to work optimally for known cases. If you encounter an unknown case, first check if FMHA_v2 has a suitable kernel. If there isn't one, feel free to approach us and we'll enable a new configuration.
```

What's the difference between cubins and cu files?

```
Cubins are precompiled binary files (built from the internal fmha_v2 repo) and take up a lot of space; cu files are generated directly from this repo. Most kernels have now been replaced with cu files, and unused cubins have been deleted.
You can modify the code in this repo to change existing kernels or create your own, and run them.
Some kernels still run from cubins. See use_cubin_header (setup.py#L3055) and modify_cubin_header (setup.py#L3413) for details.
```
17 changes: 13 additions & 4 deletions cpp/kernels/fmha_v2/setup.py
@@ -3049,14 +3049,20 @@ def get_kernel_traits_code(specs_names):
return code


# For now, only the Hopper head_size-128 kernels use cubins; all other kernels use cu files.
# Set the `use_cubin_header` condition to false if you have modified the source code of the Hopper (sm90) head_size-128 FMHA kernels.
# This ensures those kernels are recompiled from the updated source rather than loaded from stale precompiled cubins.
def use_cubin_header(kspec):
return kspec.sm == 90 and kspec.head_size == 128
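To illustrate the predicate above: a minimal, hypothetical sketch of how such a check filters kernel specs. The stand-in `KSpec` here carries only the two fields the predicate reads; the real kernel spec in setup.py has many more attributes.

```python
from collections import namedtuple

# Hypothetical, minimal stand-in for the real kernel spec in setup.py.
KSpec = namedtuple('KSpec', ['sm', 'head_size'])

def use_cubin_header(kspec):
    # Only Hopper (sm90) head_size-128 kernels still ship as cubins.
    return kspec.sm == 90 and kspec.head_size == 128

specs = [
    KSpec(sm=90, head_size=128),  # kept as cubin
    KSpec(sm=90, head_size=64),   # built from cu file
    KSpec(sm=80, head_size=128),  # built from cu file
]
cubin_specs = [s for s in specs if use_cubin_header(s)]
print(len(cubin_specs))  # only the sm90 / head_size-128 spec remains
```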


def get_cubin_header(kernel_traits, specs_names):
cubins = []
cubin_lens = []
cubins_dict = {}
cubin_lens_dict = {}
for kspec, fname, lname, kname in specs_names:
# only generate hopper cubin header
- if generate_cu_trtllm and not 'sm90' in kname:
+ if generate_cu_trtllm and not use_cubin_header(kspec):
continue
name = fname.replace('.', '_')
data = 'extern unsigned char cubin_{name}_cubin[];'.format(name=name)
@@ -3209,7 +3215,7 @@ def get_cubin_header(kernel_traits, specs_names):
if generate_cu_trtllm:

def get_lname_from_kname(kname: str) -> str:
- if 'sm90' in kname:
+ if use_cubin_header(kspec):
return 'nullptr'
lname = kname.replace('_kernel', '')
mask_types = [
@@ -3228,7 +3234,7 @@ def get_lname_from_kname(kname: str) -> str:
{cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
- '''.format(**locals()) if 'sm90' in kname else '''\
+ '''.format(**locals()) if use_cubin_header(kspec) else '''\
{{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \
{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, nullptr, \
0, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
@@ -3404,6 +3410,9 @@ def get_lname_from_kname(kname: str) -> str:
return code


# This adds some kernels that still run from cubins.
# The source code of the paged context fmha kernels is not in this repo, but we ship cubins for them.
# The other kernels are kept to pass CI cases.
def modify_cubin_header(cubin_header):
# for paged context fmha cases
target = "#ifndef EXCLUDE_SM_90"
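The body of `modify_cubin_header` is truncated in the diff, but the shown `target` marker suggests a string-splicing approach. A hedged, hypothetical sketch of that idea, using an invented cubin symbol name purely for illustration:

```python
# Hypothetical sketch (not the real, truncated implementation) of patching
# extra cubin entries into the generated header: find the EXCLUDE_SM_90
# guard and splice hand-maintained declarations in right after it.
def modify_cubin_header(cubin_header):
    target = '#ifndef EXCLUDE_SM_90'
    # Invented symbol name for a paged-context FMHA cubin whose source
    # is not in this repo; the real names live in setup.py.
    extra = 'extern unsigned char cubin_fmha_paged_ctx_sm90_cubin[];'
    # Replace only the first occurrence so the guard is not duplicated.
    return cubin_header.replace(target, target + '\n' + extra, 1)

header = '#ifndef EXCLUDE_SM_90\n// sm90 kernel entries\n#endif\n'
patched = modify_cubin_header(header)
print(patched)
```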
2,912 changes: 1,338 additions & 1,574 deletions cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h

Large diffs are not rendered by default.

This file was deleted.