diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 6fc19aa2f1c..ae3956b4430 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -6,13 +6,39 @@ # Without approval from a member of this team, PRs cannot be merged to release branches. # * @NVIDIA/trt-llm-release-branch-approval +## TensorRT-LLM Infra +### CI +/jenkins @NVIDIA/trt-llm-ci-infra-devs @NVIDIA/trt-llm-infra-devs +### Setup +/docker @NVIDIA/trt-llm-setup-infra-devs @NVIDIA/trt-llm-infra-devs +### Github workflows +/.github @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs +/.coderabbit.yaml @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs + +## TensorRT-LLM - Docs +/docs @NVIDIA/trt-llm-doc-owners + +## Examples +/examples @NVIDIA/trt-llm-doc-owners + +## TensorRT-LLM - Triton backend +/triton_backend @NVIDIA/trt-llm-triton-backend-devs + # TensorRT-LLM Pytorch backend /tensorrt_llm/_torch @NVIDIA/trt-llm-torch-devs + +## TensorRT-LLM Pytorch - Modules +/tensorrt_llm/_torch/modules @NVIDIA/trt-llm-torch-modules + +## TensorRT-LLM Pytorch Models +/tensorrt_llm/_torch/models @NVIDIA/trt-llm-torch-models-devs +/examples/models @NVIDIA/trt-llm-torch-models-devs @NVIDIA/trt-llm-doc-owners + ## TensorRT-LLM Pytorch backend - runtime /tensorrt_llm/_torch/pyexecutor @NVIDIA/trt-llm-torch-runtime-devs ## TensorRT-LLM Pytorch backend - AutoDeploy flow /tensorrt_llm/_torch/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs -/tensorrt_llm/examples/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs +/examples/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs @NVIDIA/trt-llm-doc-owners ## TensorRT-LLM Pytorch - Speculative Decoding /tensorrt_llm/_torch/speculative @NVIDIA/trt-llm-torch-spec-decoding @@ -31,12 +57,6 @@ /tensorrt_llm/_torch/attention_backend @NVIDIA/trt-llm-torch-attention-devs /tensorrt_llm/_torch/modules/attention.py @NVIDIA/trt-llm-torch-attention-devs -## TensorRT-LLM Pytorch - Modules -/tensorrt_llm/_torch/modules @NVIDIA/trt-llm-torch-modules - - -## TensorRT-LLM Pytorch Models -/tensorrt_llm/_torch/models @NVIDIA/trt-llm-torch-models-devs ### TensorRT-LLM Pytorch - Models - Gemma /tensorrt_llm/_torch/models/modeling_gemma3.py @NVIDIA/trt-llm-torch-models-gemma-devs @NVIDIA/trt-llm-torch-models-devs @@ -108,8 +128,6 @@ /cpp/tensorrt_llm/runtime/loraUtils.cpp @NVIDIA/trt-llm-torch-peft /cpp/tensorrt_llm/runtime/loraUtils.h @NVIDIA/trt-llm-torch-peft -## TensorRT-LLM - Triton backend -/triton_backend @NVIDIA/trt-llm-triton-backend-devs ## TensorRT-LLM trtllm-bench Reviewers /tensorrt_llm/bench @NVIDIA/trtllm-bench-reviewers @@ -121,10 +139,9 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers /tensorrt_llm/executor @NVIDIA/trt-llm-llmapi-devs ## TensorRT-LLM LLM Disaggregated -/examples/disaggregated @NVIDIA/trt-llm-disagg-devs +/examples/disaggregated @NVIDIA/trt-llm-disagg-devs @NVIDIA/trt-llm-doc-owners /tensorrt_llm/disaggregated_params.py @NVIDIA/trt-llm-disagg-devs /tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py @NVIDIA/trt-llm-disagg-devs -/tensorrt_llm/_torch/pyexecutor/py_executor.py @NVIDIA/trt-llm-disagg-devs /cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @NVIDIA/trt-llm-disagg-devs /cpp/tensorrt_llm/batch_manager/cacheFormatter.h @NVIDIA/trt-llm-disagg-devs /cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp @NVIDIA/trt-llm-disagg-devs @@ -135,19 +152,6 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers /cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp @NVIDIA/trt-llm-disagg-devs 
/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h @NVIDIA/trt-llm-disagg-devs -## TensorRT-LLM Infra - -### CI -/jenkins @NVIDIA/trt-llm-ci-infra-devs @NVIDIA/trt-llm-infra-devs -### Setup -/docker @NVIDIA/trt-llm-setup-infra-devs @NVIDIA/trt-llm-infra-devs -### Github workflows -/tensorrt_llm/.github @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs -/tensorrt_llm/.coderabbit.yaml @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs - -## TensorRT-LLM - Docs -/docs @NVIDIA/trt-llm-doc-owners -/examples @NVIDIA/trt-llm-doc-owners # The rule below requires that any PR modifying public APIs must be approved by at least one member # of the NVIDIA/trt-llm-committed-api-review-committee or NVIDIA/trt-llm-noncommitted-api-review-committee team. diff --git a/.github/ISSUE_TEMPLATE/01-installation.yml b/.github/ISSUE_TEMPLATE/01-installation.yml new file mode 100644 index 00000000000..fd24fd93f07 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/01-installation.yml @@ -0,0 +1,66 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/200-installation.yml +name: 🛠️ Installation +description: Report an issue here when you hit errors during installation. +title: "[Installation]: " +labels: ["Installation"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+). +- type: textarea + attributes: + label: System Info + description: | + Please provide the following system information to help us debug your installation issue: + + ```bash + # System information + cat /etc/os-release + nvidia-smi + nvcc --version + python --version + pip list | grep -E "(tensorrt|torch|cuda)" + + # TensorRT-LLM installation method and version + pip show tensorrt_llm + ``` + value: | + **System Information:** + - OS: + - Python version: + - CUDA version: + - GPU model(s): + - Driver version: + - TensorRT version: + - PyTorch version: + - TensorRT-LLM version: + + **Detailed output:** + ```text + Paste the output of the above commands here + ``` + validations: + required: true +- type: textarea + attributes: + label: How you are installing TensorRT-LLM + description: | + Paste the full command you are trying to execute or describe your installation method. + value: | + ```sh + # Installation command or method + pip install tensorrt_llm + ``` +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and checked the [installation documentation](https://nvidia.github.io/TensorRT-LLM/installation/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions. 
+ required: true diff --git a/.github/ISSUE_TEMPLATE/02-new-model.yml b/.github/ISSUE_TEMPLATE/02-new-model.yml new file mode 100644 index 00000000000..688c11866fc --- /dev/null +++ b/.github/ISSUE_TEMPLATE/02-new-model.yml @@ -0,0 +1,41 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/600-new-model.yml +name: 🤗 Support request for a new model from huggingface +description: Submit a proposal/request for a new model from huggingface +title: "[New Model]: " +labels: ["new model"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+). + + #### We also highly recommend you read https://nvidia.github.io/TensorRT-LLM/architecture/add-model.html first to understand how to add a new model. +- type: textarea + attributes: + label: The model to consider. + description: > + A huggingface identifier, pointing to the model, e.g. `meta-llama/Llama-3.1-8B-Instruct` . + validations: + required: true +- type: textarea + attributes: + label: The closest model TensorRT-LLM already supports. + description: > + Here is the list of models already supported by TensorRT-LLM: https://github.com/NVIDIA/TensorRT-LLM/tree/main/tensorrt_llm/models (TRT backend) and https://github.com/NVIDIA/TensorRT-LLM/tree/main/tensorrt_llm/_torch/models (Pytorch backend) . Which model is the most similar to the model you want to add support for? +- type: textarea + attributes: + label: What's your difficulty of supporting the model you want? + description: > + For example, any new operators or new architecture? +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/03-documentation.yml b/.github/ISSUE_TEMPLATE/03-documentation.yml new file mode 100644 index 00000000000..df7643337b7 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/03-documentation.yml @@ -0,0 +1,31 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/100-documentation.yml +name: 📚 Documentation +description: Report an issue related to https://nvidia.github.io/TensorRT-LLM/ +title: "[Doc]: " +labels: ["Documentation"] +assignees: ["nv-guomingz"] + +body: +- type: textarea + attributes: + label: 📚 The doc issue + description: > + A clear and concise description of what content in https://nvidia.github.io/TensorRT-LLM/ is an issue. + validations: + required: true +- type: textarea + attributes: + label: Suggest a potential alternative/fix + description: > + Tell us how we could improve the documentation in this regard. +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... 
+ options: + - label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/04-questions.yml b/.github/ISSUE_TEMPLATE/04-questions.yml new file mode 100644 index 00000000000..75a9416e920 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/04-questions.yml @@ -0,0 +1,62 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/300-usage.yml +name: 💻 Questions +description: Raise an issue here if you don't know how to use TensorRT-LLM. +title: "[Usage]: " +labels: ["question"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+). +- type: textarea + attributes: + label: System Info + description: | + Please provide the following system information to help us debug your usage issue: + + ```bash + # System information + nvidia-smi + python --version + pip show tensorrt_llm + ``` + value: | + **System Information:** + - OS: + - Python version: + - CUDA version: + - GPU model(s): + - Driver version: + - TensorRT-LLM version: + + **Detailed output:** + ```text + Paste the output of the above commands here + ``` + validations: + required: true +- type: textarea + attributes: + label: How would you like to use TensorRT-LLM + description: | + A detailed description of how you want to use TensorRT-LLM. + value: | + I want to run inference of a [specific model](put Hugging Face link here). I don't know how to integrate it with TensorRT-LLM or optimize it for my use case. + + **Specific questions:** + - Model: + - Use case (e.g., chatbot, batch inference, real-time serving): + - Expected throughput/latency requirements: + - Multi-GPU setup needed: +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/05-feature-request.yml b/.github/ISSUE_TEMPLATE/05-feature-request.yml new file mode 100644 index 00000000000..32c1ee43c7b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/05-feature-request.yml @@ -0,0 +1,40 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/500-feature-request.yml +name: 🚀 Feature request +description: Submit a proposal/request for a new TensorRT-LLM feature +title: "[Feature]: " +labels: ["feature request"] +assignees: ["laikhtewari"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+). +- type: textarea + attributes: + label: 🚀 The feature, motivation and pitch + description: > + A clear and concise description of the feature proposal. Please outline the motivation for the proposal. 
Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too. + validations: + required: true +- type: textarea + attributes: + label: Alternatives + description: > + A description of any alternative solutions or features you've considered, if any. +- type: textarea + attributes: + label: Additional context + description: > + Add any other context or screenshots about the feature request. +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/06-bug-report.yml b/.github/ISSUE_TEMPLATE/06-bug-report.yml new file mode 100644 index 00000000000..c41ff62ded3 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/06-bug-report.yml @@ -0,0 +1,191 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/400-bug-report.yml +name: "🐛 Bug Report" +description: Submit a bug report to help us improve TensorRT-LLM +title: "[Bug]: " +labels: [ "bug" ] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+). +- type: markdown + attributes: + value: | + ⚠️ **SECURITY WARNING:** Please review any text you paste to ensure it does not contain sensitive information such as: + - API tokens or keys (e.g., Hugging Face tokens, OpenAI API keys) + - Passwords or authentication credentials + - Private URLs or endpoints + - Personal or confidential data + + Consider redacting or replacing sensitive values with placeholders like `` when sharing configuration or code examples. +- type: textarea + id: system-info + attributes: + label: System Info + description: Please share your system info with us. + placeholder: | + - CPU architecture (e.g., x86_64, aarch64) + - CPU/Host memory size (if known) + - GPU properties + - GPU name (e.g., NVIDIA H100, NVIDIA A100, NVIDIA L40S) + - GPU memory size (if known) + - Clock frequencies used (if applicable) + - Libraries + - TensorRT-LLM branch or tag (e.g., main, v0.7.1) + - TensorRT-LLM commit (if known) + - Versions of TensorRT, Modelopt, CUDA, cuBLAS, etc. used + - Container used (if running TensorRT-LLM in a container) + - NVIDIA driver version + - OS (Ubuntu 24.04, CentOS 8) + - Any other information that may be useful in reproducing the bug + + **Commands to gather system information:** + ```bash + nvidia-smi + nvcc --version + python --version + pip show tensorrt_llm tensorrt torch + ``` + validations: + required: true + +- type: textarea + id: who-can-help + attributes: + label: Who can help? + description: | + To expedite the response to your issue, it would be helpful if you could identify the appropriate person + to tag using the **@** symbol. Here is a general guideline on **whom to tag**. + + Rest assured that all issues are reviewed by the core maintainers. 
If you are unsure about whom to tag, + you can leave it blank, and a core maintainer will make sure to involve the appropriate person. + + Please tag fewer than 3 people. + + Quantization: @Tracin + + Documentation: @juney-nvidia + + Feature request: @laikhtewari + + Performance: @kaiyux + + placeholder: "@Username ..." + +- type: checkboxes + id: information-scripts-examples + attributes: + label: Information + description: 'The problem arises when using:' + options: + - label: "The official example scripts" + - label: "My own modified scripts" + +- type: checkboxes + id: information-tasks + attributes: + label: Tasks + description: "The tasks I am working on are:" + options: + - label: "An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)" + - label: "My own task or dataset (give details below)" + +- type: textarea + id: reproduction + validations: + required: true + attributes: + label: Reproduction + description: | + Please provide a clear and concise description of what the bug is and how to reproduce it. + + If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example: + + ```python + from tensorrt_llm import LLM + from tensorrt_llm.sampling_params import SamplingParams + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct") + + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + ``` + + If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. + + Remember to use code tags to properly format your code. You can refer to the + link https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting for guidance on code formatting. + + Please refrain from using screenshots, as they can be difficult to read and prevent others from copying and pasting your code. + It would be most helpful if we could reproduce your issue by simply copying and pasting your scripts and codes. + + Please set the environment variable `export TLLM_DEBUG_MODE=1` to turn on more logging to help debugging potential issues. + + placeholder: | + Steps to reproduce the behavior: + + 1. + 2. + 3. + + ```python + # Sample code to reproduce the problem + ``` + + ``` + The error message you got, with the full traceback and the error logs. + ``` + +- type: textarea + id: expected-behavior + validations: + required: true + attributes: + label: Expected behavior + description: "Provide a brief summary of the expected behavior of the software. Provide output files or examples if possible." + +- type: textarea + id: actual-behavior + validations: + required: true + attributes: + label: actual behavior + description: "Describe the actual behavior of the software and how it deviates from the expected behavior. Provide output files or examples if possible." 
+ +- type: textarea + id: additional-notes + validations: + required: true + attributes: + label: additional notes + description: "Provide any additional context here you think might be useful for the TensorRT-LLM team to help debug this issue (such as experiments done, potential things to investigate)." + +- type: markdown + attributes: + value: | + ⚠️ Please separate bugs of `transformers`, `pytorch` implementation or usage from bugs of `TensorRT-LLM`. + + - If the error only appears in TensorRT-LLM, please provide the detailed script of how you run `TensorRT-LLM`, also highlight the difference and what you expect. + + Thanks for reporting 🙏! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/07-performance-discussion.yml b/.github/ISSUE_TEMPLATE/07-performance-discussion.yml new file mode 100644 index 00000000000..feb3b025018 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/07-performance-discussion.yml @@ -0,0 +1,74 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/700-performance-discussion.yml +name: ⚡ Discussion on the performance of TensorRT-LLM +description: Submit a proposal/discussion about the performance of TensorRT-LLM +title: "[Performance]: " +labels: ["Performance"] +assignees: ["byshiue", "kaiyux"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue+sort%3Acreated-desc+). +- type: textarea + attributes: + label: Proposal to improve performance + description: > + How do you plan to improve TensorRT-LLM's performance? + validations: + required: false +- type: textarea + attributes: + label: Report of performance regression + description: > + Please provide detailed description of performance comparison to confirm the regression. You may want to run the benchmark script at https://github.com/NVIDIA/TensorRT-LLM/tree/main/benchmarks . + validations: + required: false +- type: textarea + attributes: + label: Misc discussion on performance + description: > + Anything about the performance. + validations: + required: false +- type: textarea + attributes: + label: Your current environment (if you think it is necessary) + description: | + Please provide the following system information to help with performance analysis: + + ```bash + # System information + nvidia-smi + nvcc --version + python --version + pip show tensorrt_llm tensorrt torch + ``` + value: | + **System Information:** + - OS: + - Python version: + - CUDA version: + - GPU model(s): + - Driver version: + - TensorRT version: + - PyTorch version: + - TensorRT-LLM version: + + **Detailed output:** + ```text + Paste the output of the above commands here + ``` + validations: + required: false +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... 
+ options: + - label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/08-RFC.yml b/.github/ISSUE_TEMPLATE/08-RFC.yml new file mode 100644 index 00000000000..20d505171b3 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/08-RFC.yml @@ -0,0 +1,58 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/.github/ISSUE_TEMPLATE/750-RFC.yml +name: 💬 Request for comments (RFC). +description: Ask for feedback on major architectural changes or design choices. +title: "[RFC]: " +labels: ["RFC"] +assignees: ["laikhtewari"] + +body: +- type: markdown + attributes: + value: > + #### Please take a look at previous [RFCs](https://github.com/NVIDIA/TensorRT-LLM/issues?q=label%3ARFC+sort%3Aupdated-desc) for reference. +- type: textarea + attributes: + label: Motivation. + description: > + The motivation of the RFC. + validations: + required: true +- type: textarea + attributes: + label: Proposed Change. + description: > + The proposed change of the RFC. + validations: + required: true +- type: textarea + attributes: + label: Feedback Period. + description: > + The feedback period of the RFC. Usually at least one week. + validations: + required: false +- type: textarea + attributes: + label: CC List. + description: > + The list of people you want to CC. + validations: + required: false +- type: textarea + attributes: + label: Any Other Things. + description: > + Any other things you would like to mention. + validations: + required: false +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! The TensorRT-LLM team reviews RFCs during regular team meetings. Most RFCs can be discussed online, but you can also reach out to the team through GitHub discussions or issues for additional feedback. +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and checked the [documentation](https://nvidia.github.io/TensorRT-LLM/) and [examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for answers to frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml deleted file mode 100644 index 10591e6b23e..00000000000 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ /dev/null @@ -1,114 +0,0 @@ -name: "Bug Report" -description: Submit a bug report to help us improve TensorRT-LLM -labels: [ "bug" ] -body: - - type: textarea - id: system-info - attributes: - label: System Info - description: Please share your system info with us. - placeholder: | - - CPU architecture (e.g., x86_64, aarch64) - - CPU/Host memory size (if known) - - GPU properties - - GPU name (e.g., NVIDIA H100, NVIDIA A100, NVIDIA L40S) - - GPU memory size (if known) - - Clock frequencies used (if applicable) - - Libraries - - TensorRT-LLM branch or tag (e.g., main, v0.7.1) - - TensorRT-LLM commit (if known) - - Versions of TensorRT, Modelopt, CUDA, cuBLAS, etc. used - - Container used (if running TensorRT-LLM in a container) - - NVIDIA driver version - - OS (Ubuntu 24.04, CentOS 8) - - Any other information that may be useful in reproducing the bug - validations: - required: true - - - type: textarea - id: who-can-help - attributes: - label: Who can help? 
- description: | - To expedite the response to your issue, it would be helpful if you could identify the appropriate person - to tag using the **@** symbol. Here is a general guideline on **whom to tag**. - - Rest assured that all issues are reviewed by the core maintainers. If you are unsure about whom to tag, - you can leave it blank, and a core maintainer will make sure to involve the appropriate person. - - Please tag fewer than 3 people. - - Quantization: @Tracin - - Documentation: @juney-nvidia - - Feature request: @ncomly-nvidia - - Performance: @kaiyux - - placeholder: "@Username ..." - - - type: checkboxes - id: information-scripts-examples - attributes: - label: Information - description: 'The problem arises when using:' - options: - - label: "The official example scripts" - - label: "My own modified scripts" - - - type: checkboxes - id: information-tasks - attributes: - label: Tasks - description: "The tasks I am working on are:" - options: - - label: "An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)" - - label: "My own task or dataset (give details below)" - - - type: textarea - id: reproduction - validations: - required: true - attributes: - label: Reproduction - description: | - Kindly share a code example that demonstrates the issue you encountered. It is recommending to provide a code snippet directly. - Additionally, if you have any error messages, or stack traces related to the problem, please include them here. - - Remember to use code tags to properly format your code. You can refer to the - link https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting for guidance on code formatting. - - Please refrain from using screenshots, as they can be difficult to read and prevent others from copying and pasting your code. - It would be most helpful if we could reproduce your issue by simply copying and pasting your scripts and codes. - - placeholder: | - Steps to reproduce the behavior: - - 1. - 2. - 3. - - - type: textarea - id: expected-behavior - validations: - required: true - attributes: - label: Expected behavior - description: "Provide a brief summary of the expected behavior of the software. Provide output files or examples if possible." - - - type: textarea - id: actual-behavior - validations: - required: true - attributes: - label: actual behavior - description: "Describe the actual behavior of the software and how it deviates from the expected behavior. Provide output files or examples if possible." - - - type: textarea - id: additioanl-notes - validations: - required: true - attributes: - label: additional notes - description: "Provide any additional context here you think might be useful for the TensorRT-LLM team to help debug this issue (such as experiments done, potential things to investigate)." 
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 00000000000..93ef69beebe
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1,5 @@
+blank_issues_enabled: false
+contact_links:
+  - name: 🤔 Questions
+    url: https://github.com/NVIDIA/TensorRT-LLM/discussions
+    about: Ask questions and discuss with other TensorRT-LLM community members
diff --git a/README.md b/README.md
index 5ab7fb51b7f..83cad6eb028 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@ TensorRT-LLM
 [![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/)
 [![cuda](https://img.shields.io/badge/cuda-12.9.1-green)](https://developer.nvidia.com/cuda-downloads)
 [![trt](https://img.shields.io/badge/TRT-10.11.0-green)](https://developer.nvidia.com/tensorrt)
-[![version](https://img.shields.io/badge/release-1.0.0rc6-green)](./tensorrt_llm/version.py)
+[![version](https://img.shields.io/badge/release-1.1.0rc0-green)](./tensorrt_llm/version.py)
 [![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)
 
 [Architecture](./docs/source/torch/arch_overview.md)   |   [Performance](./docs/source/performance/perf-overview.md)   |   [Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)   |   [Documentation](./docs/source/)   |   [Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)
diff --git a/benchmarks/cpp/README.md b/benchmarks/cpp/README.md
index 0b89bae6029..ae3287faf06 100644
--- a/benchmarks/cpp/README.md
+++ b/benchmarks/cpp/README.md
@@ -336,7 +336,7 @@ cd cpp/build
 `disaggServerBenchmark` only supports `decoder-only` models. Here is the basic usage:
 ```
-export TRTLLM_USE_MPI_KVCACHE=1
+export TRTLLM_USE_UCX_KVCACHE=1
 mpirun -n ${proc} benchmarks/disaggServerBenchmark --context_engine_dirs ${context_engine_0},${context_engine_1}...,${context_engine_{m-1}} \
 --generation_engine_dirs ${generation_engine_0},${generation_engine_1}...,${generation_engine_{n-1}} --dataset ${dataset_path}
 ```
@@ -344,7 +344,7 @@ This command will launch m context engines and n generation engines. You need to
 for example:
 ```
-export TRTLLM_USE_MPI_KVCACHE=1
+export TRTLLM_USE_UCX_KVCACHE=1
 mpirun -n 7 benchmarks/disaggServerBenchmark --context_engine_dirs ${llama_7b_tp2_pp1_dir},${llama_7b_tp1_pp1_dir} --generation_engine_dirs ${llama_7b_tp1_pp1_dir},${llama_7b_tp2_pp1_dir} --dataset ${dataset_path}
 # need 6 gpus and 7 processes to launch the benchmark.
diff --git a/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h b/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h index ce42493879e..394f7fb7bfa 100644 --- a/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h +++ b/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h @@ -75,27 +75,19 @@ class CreateNewDecoderRequests : Algorithm std::vector> operator()(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests, - runtime::BufferManager const& bufferManager, nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers, - runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream, - SizeType32 maxSequenceLength, SizeType32 beamWidth, OptionalRef medusaBuffers) const; + nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState, + CudaStream const& runtimeStream, CudaStream const& decoderStream, SizeType32 maxSequenceLength, + SizeType32 beamWidth, OptionalRef medusaBuffers) const; [[nodiscard]] std::tuple, std::vector> createDecoderRequests(RequestVector const& finishedContextRequests, TensorPtr const& inputIds, executor::DecodingConfig const& decodingConfig, runtime::decoder::DecoderState& decoderState, - runtime::BufferManager const& bufferManager, nvinfer1::DataType logitsType, - runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, + nvinfer1::DataType logitsType, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, runtime::CudaStream const& runtimeStream, runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength, OptionalRef medusaBuffers) const; private: - //! @brief Initialize the decoder at `batchSlot` with a new `request`. Exposed only for static batching via - //! GptDecoderBatched::newBatch() - static void newRequest(SizeType32 batchSlot, runtime::decoder_batch::Request const& request, - SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig, - runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream, - SizeType32 maxSequenceLength); - //! @brief Setups decoder internal tensors for new speculative decoding request static void newRequestSpeculativeDecoding(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig, diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h index a232230c4ff..09a96a56eee 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h @@ -18,6 +18,7 @@ #include "tensorrt_llm/executor/executor.h" +#include #include #include #include @@ -36,7 +37,8 @@ using BlockPtr = std::shared_ptr; class KVCacheEventManager { public: - explicit KVCacheEventManager(size_t maxKVEventEntries); + explicit KVCacheEventManager(size_t maxKVEventEntries, std::optional attentionDpRank = std::nullopt, + std::optional attentionDpSize = std::nullopt, SizeType32 attentionDpEventsGatherPeriodMs = 5); ~KVCacheEventManager(); KVCacheEventManager(KVCacheEventManager& other) = delete; @@ -61,14 +63,19 @@ class KVCacheEventManager // Worker thread which adds events to mEvents. 
void worker(); + // Thread which exchanges events if attentionDP is enabled + void exchangeAttentionDpThread(); + private: // Add an event to mEventQueue void enqueueEvent(executor::KVCacheEvent&& event); /// @brief Flag to terminate the worker - bool mRun; + std::atomic mRun; /// @brief Worker thread std::thread mWorkerThread; + /// @brief Exchange thread for attention DP events + std::thread mExchangeAttentionDpThread; /// @brief The deque of events std::deque mEvents; @@ -91,6 +98,17 @@ class KVCacheEventManager size_t mMaxSize; /// @brief An auto-incrementing event id counter size_t mEventId; + + /// @brief Attention DP ranks and size + /// If set, we will exchange KV cache events and accumulate on rank 0 + std::optional mAttentionDpRank; + std::optional mAttentionDpSize; + + /// @brief The period in milliseconds to gather attention DP events across rank + SizeType32 mAttentionDpEventsGatherPeriodMs; + + /// @brief MPI communicator for attention DP + std::unique_ptr mMpiComm; }; } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index a0234cbbe49..a49527a6157 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -536,8 +536,7 @@ class WindowBlockManager SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr stream, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, - std::shared_ptr eventManager, bool enableHashKey, bool enablePartialReuse, - bool copyOnPartialReuse); + std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse); ~WindowBlockManager(); @@ -633,11 +632,6 @@ class WindowBlockManager return mAllBlocksById.at(blockId); } - [[nodiscard]] BlockMapIterRange getBlocksByHash(size_t hash) const - { - return mContextBlocksByHash.equal_range(hash); - } - [[nodiscard]] SizeType32 getTokensPerBlock() const noexcept { return mTokensPerBlock; @@ -723,10 +717,6 @@ class WindowBlockManager //! \param blockIds Id of each block. void storeBlocks(std::vector const& blockKeys, std::vector const& blockIds); - void addBlockToHashMap(BlockPtr const& block); - - void removeBlockFromHashMap(BlockPtr const& block); - [[nodiscard]] bool verifyQueueIntegrity(); // Only needed when sliding window attention + paged context fmha are used together. @@ -808,8 +798,6 @@ class WindowBlockManager SizeType32 mTokensPerBlock; // List of all blocks by idx std::vector mAllBlocksById; - // List of all context blocks by hash - BlockMap mContextBlocksByHash; // Dummy block acting as root for BlockToken searches BlockPtr mCachedBlocksRoot; // KV cache type (self or cross) @@ -841,8 +829,6 @@ class WindowBlockManager double mReusedTokens; // Total number of input tokens double mTotalInputTokens; - // Whether or not to maintain a hashmap of blocks. - bool mEnableHashKey; // Whether blocks that are partially matched should be reused. bool mEnablePartialReuse; // Whether partially matched blocks that are already in use should be copied and reused. 
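The `kvCacheEventManager.h` hunk above extends `KVCacheEventManager` with optional attention-DP coordination: a rank, a world size, a gather period, and a dedicated exchange thread, with events accumulated on rank 0. A minimal construction sketch follows; it assumes the stripped `std::optional` template arguments are `SizeType32` and that the rank/size values come from the caller's parallel setup, neither of which is spelled out in this diff.

```cpp
#include <memory>

#include "tensorrt_llm/batch_manager/kvCacheEventManager.h"

using tensorrt_llm::batch_manager::kv_cache_manager::KVCacheEventManager;

// Sketch only: rank/worldSize would normally come from the world/parallel configuration.
std::shared_ptr<KVCacheEventManager> makeEventManager(bool enableAttentionDp, int rank, int worldSize)
{
    if (!enableAttentionDp)
    {
        // Single-rank behaviour: events are only buffered locally, as before this change.
        return std::make_shared<KVCacheEventManager>(/*maxKVEventEntries=*/1024);
    }
    // With attention DP, rank 0 accumulates events gathered from the other ranks every
    // attentionDpEventsGatherPeriodMs milliseconds (5 ms by default, per the new constructor).
    return std::make_shared<KVCacheEventManager>(/*maxKVEventEntries=*/1024,
        /*attentionDpRank=*/rank, /*attentionDpSize=*/worldSize,
        /*attentionDpEventsGatherPeriodMs=*/5);
}
```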
@@ -863,8 +849,8 @@ class BlockManager std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, - std::shared_ptr eventManager = nullptr, bool enableHashKey = false, - bool enablePartialReuse = true, bool copyOnPartialReuse = true); + std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, + bool copyOnPartialReuse = true); BlockManager(BlockManager const&) = delete; BlockManager& operator=(BlockManager const&) = delete; @@ -1081,11 +1067,6 @@ class BlockManager return mWindowBlockManagers.at(windowSize).getBlockById(blockId); } - [[nodiscard]] WindowBlockManager::BlockMapIterRange getBlocksByHash(size_t hash, SizeType32 windowSize) const - { - return mWindowBlockManagers.at(windowSize).getBlocksByHash(hash); - } - [[nodiscard]] SizeType32 getNumPrimaryBlocks() const { return sumWindows([](auto const& manager) { return manager.getNumPrimaryBlocks(); }); @@ -1096,16 +1077,6 @@ class BlockManager return getPool(poolIdx).containsBlockScales; } - void addBlockToHashMap(BlockPtr const& block, SizeType32 windowSize) - { - mWindowBlockManagers.at(windowSize).addBlockToHashMap(block); - } - - void removeBlockFromHashMap(BlockPtr const& block, SizeType32 windowSize) - { - mWindowBlockManagers.at(windowSize).removeBlockFromHashMap(block); - } - //! \brief Store context blocks void storeContextBlocks(GenerationRequest& sequence, LlmRequest const& llmRequest); @@ -1385,8 +1356,8 @@ class KVCacheManager : public BaseKVCacheManager SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, - std::shared_ptr eventManager = nullptr, bool enableHashKey = false, - bool enablePartialReuse = true, bool copyOnpartialReuse = true); + std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, + bool copyOnpartialReuse = true); KVCacheManager(std::vector const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, @@ -1405,8 +1376,8 @@ class KVCacheManager : public BaseKVCacheManager SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, - std::shared_ptr eventManager = nullptr, bool enableHashKey = false, - bool enablePartialReuse = true, bool copyOnpartialReuse = true); + std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, + bool copyOnpartialReuse = true); KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, @@ -1692,8 +1663,6 @@ class KVCacheManager : public BaseKVCacheManager std::unordered_map mSequences; // Whether to cache KV pages for reuse bool mEnableBlockReuse; - // Whether enable finding blocks by their hash, ignored when reuse enabled - bool mEnableHashKey; // Mutex to protect access to mSequences mutable std::mutex mSequencesMtx; // buffers for static tensors, will be created after allocating pools diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h 
b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 0d087d96c0f..e4d13c9e17b 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -828,8 +828,10 @@ class GenericLlmRequest // for enc-dec models, pause means saving generated tokens to prompt but need to re-do encoder phase mState = mEncoderTokens.has_value() || mEncoderInputFeatures ? LlmRequestState::kENCODER_INIT : LlmRequestState::kCONTEXT_INIT; - mContextCurrentPosition = 0; - mPrepopulatedPromptLen = 0; + mContextCurrentPositionTarget = 0; + mContextCurrentPositionDraft = 0; + mPrepopulatedPromptLenTarget = 0; + mPrepopulatedPromptLenDraft = 0; mContextChunkSize = mPromptLen; mSeqSlot.reset(); } @@ -1049,7 +1051,7 @@ class GenericLlmRequest [[nodiscard]] SizeType32 getPrepopulatedPromptLen() const { - return mPrepopulatedPromptLen; + return mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget; } void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen, SizeType32 kvTokensPerBlock) @@ -1066,7 +1068,10 @@ class GenericLlmRequest "Invalid state: prepopulatedPromptLen (%d) >= promptLen (%d) for request %lu", prepopulatedPromptLen, promptLen, mRequestId); TLLM_CHECK(prepopulatedPromptLen < promptLen); - mPrepopulatedPromptLen = prepopulatedPromptLen; + + auto& prePromptLen = mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget; + auto& contextCurrentPosition = mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget; + prePromptLen = prepopulatedPromptLen; if (prepopulatedPromptLen > 0) { @@ -1081,7 +1086,7 @@ class GenericLlmRequest chunkSize = flooredEndPosition - prepopulatedPromptLen; TLLM_CHECK(chunkSize <= getContextChunkSize()); } - setContextCurrentPosition(prepopulatedPromptLen); + contextCurrentPosition = prepopulatedPromptLen; setContextChunkSize(chunkSize); if (!isLastContextChunk()) @@ -1522,14 +1527,15 @@ class GenericLlmRequest void setContextCurrentPosition(SizeType32 contextCurrentPosition) { - mContextCurrentPosition = contextCurrentPosition; + mContextCurrentPositionDraft = contextCurrentPosition; + mContextCurrentPositionTarget = contextCurrentPosition; } /// When chunked, the position of the current chunk is returned. Otherwise, only the beginning /// or end of the context is returned. [[nodiscard]] SizeType32 getContextCurrentPosition() const noexcept { - return mContextCurrentPosition; + return mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget; } /// Return the length of the context that has not yet been processed. @@ -1570,14 +1576,16 @@ class GenericLlmRequest { // The number of cached token is encountered in mContextCurrentPosition, // so the start position of the context is mPrepopulatedPromptLen. - return mContextCurrentPosition == mPrepopulatedPromptLen; + return getContextCurrentPosition() == getPrepopulatedPromptLen(); } /// Move the cursor forward one chunk. When not chunked, move forward to the end of the context. 
void moveToNextContextChunk() { TLLM_CHECK_WITH_INFO(isContextInitState(), "Chunking is only possible during the context phase."); - mContextCurrentPosition += getContextChunkSize(); + + mContextCurrentPositionDraft += getContextChunkSize(); + mContextCurrentPositionTarget += getContextChunkSize(); setContextChunkSize(0); } @@ -1843,6 +1851,16 @@ class GenericLlmRequest return mIsDummyRequest; } + void setUseDraftModel(bool useDraftModel) + { + mUseDraftModel = useDraftModel; + } + + [[nodiscard]] bool useDraftModel() const + { + return mUseDraftModel; + } + RequestIdType mRequestId; SizeType32 mPromptLen; SizeType32 mMaxNewTokens; @@ -1885,7 +1903,8 @@ class GenericLlmRequest // Number of tokens already in KV cache before context phase. // A value > 0 indicates cached KV cache blocks were reused. // Up to inputLen - 1 tokens can be reused. - SizeType32 mPrepopulatedPromptLen{0}; + SizeType32 mPrepopulatedPromptLenTarget{0}; + SizeType32 mPrepopulatedPromptLenDraft{0}; SizeType32 mMaxSentTokenLen; @@ -1916,7 +1935,8 @@ class GenericLlmRequest // The size of the context chunk must be multiple of the KV-Cache block size except the last one. // Value `0` means Chunked-Context is disabled. SizeType32 mContextChunkSize{0}; - SizeType32 mContextCurrentPosition{0}; + SizeType32 mContextCurrentPositionTarget{0}; + SizeType32 mContextCurrentPositionDraft{0}; std::vector mLogProbs; // [beamSize, seqLen] VecLogProbs mCumLogProbs; // [beamSize] @@ -2017,6 +2037,8 @@ class GenericLlmRequest bool mIsDummyRequest{false}; + bool mUseDraftModel{false}; + private: void initialize(VecTokens const& inputTokens, bool outputLogProbs) { @@ -2027,7 +2049,7 @@ class GenericLlmRequest // Scatter the input tokens to other beam mTokens = BeamTokens(mSamplingConfig.beamWidth, inputTokens); - mLastTokens = VecTokens(mSamplingConfig.beamWidth); + mLastTokens = VecTokens(mSamplingConfig.beamWidth, inputTokens.back()); // Init mUniqueTokens VecUniqueTokens uniqueTokens{inputTokens.size()}; @@ -2347,6 +2369,9 @@ class LlmRequest : public GenericLlmRequest void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager); void moveLoraWeightsToGpu(runtime::BufferManager const& manager); + + // Remove LoRA weights and LoRA config tensors + void removeLoraTensors(); }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/include/tensorrt_llm/common/quantization.h b/cpp/include/tensorrt_llm/common/quantization.h index 836faa258fe..50aae114e0c 100644 --- a/cpp/include/tensorrt_llm/common/quantization.h +++ b/cpp/include/tensorrt_llm/common/quantization.h @@ -122,6 +122,16 @@ class QuantMode return QuantMode(BaseType(1u) << 14); } + static constexpr QuantMode w4a8Mxfp4Mxfp8() noexcept + { + return QuantMode(BaseType(1u) << 15); + } + + static constexpr QuantMode w4a16Mxfp4() noexcept + { + return QuantMode(BaseType(1u) << 16); + } + constexpr BaseType value() const noexcept { return mValue; @@ -202,6 +212,16 @@ class QuantMode return isSet(w4a8Mxfp4Fp8()); } + constexpr bool hasW4a8Mxfp4Mxfp8() const noexcept + { + return isSet(w4a8Mxfp4Mxfp8()); + } + + constexpr bool hasW4a16Mxfp4() const noexcept + { + return isSet(w4a16Mxfp4()); + } + constexpr bool hasKvCacheQuant() const noexcept { return hasInt8KvCache() || hasFp8KvCache() || hasFp4KvCache(); @@ -209,7 +229,8 @@ class QuantMode static constexpr QuantMode fromDescription(bool quantizeWeights, bool quantizeActivations, bool perToken, bool perChannel, bool perGroup, bool useInt4Weights, bool useInt8KvCache, bool useFp8KvCache, bool useFp8Qdq, - bool 
useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8) + bool useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8, + bool useW4a8Mxfp4Mxfp8, bool useW4a16Mxfp4) { QuantMode quantMode{}; if (quantizeWeights) @@ -278,25 +299,35 @@ class QuantMode quantMode += w4a8Mxfp4Fp8(); } + if (useW4a8Mxfp4Mxfp8) + { + quantMode += w4a8Mxfp4Mxfp8(); + } + + if (useW4a16Mxfp4) + { + quantMode += w4a16Mxfp4(); + } + return quantMode; } static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false) { - return fromDescription( - true, true, perToken, perChannel, false, false, false, false, false, false, false, false, false, false); + return fromDescription(true, true, perToken, perChannel, false, false, false, false, false, false, false, false, + false, false, false, false); } static constexpr QuantMode useQServe(bool perGroup) { - return fromDescription( - true, true, false, false, perGroup, true, false, false, false, false, true, false, false, false); + return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true, false, false, + false, false, false); } static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false) { return fromDescription(true, false, false, false, perGroup, useInt4Weights, false, false, false, false, false, - false, false, false); + false, false, false, false, false); } static QuantMode const fromQuantAlgo( @@ -353,28 +384,38 @@ class QuantMode } else if (quantAlgo == "FP8") { - quantMode = fromDescription( - false, false, false, false, false, false, false, false, true, false, false, false, false, false); + quantMode = fromDescription(false, false, false, false, false, false, false, false, true, false, false, + false, false, false, false, false); } else if (quantAlgo == "FP8_ROWWISE") { - quantMode = fromDescription( - false, false, true, true, false, false, false, false, false, true, false, false, false, false); + quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true, false, false, + false, false, false, false); } else if (quantAlgo == "FP4") { - quantMode = fromDescription( - false, false, false, false, false, false, false, false, false, false, false, true, false, false); + quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false, + true, false, false, false, false); } else if (quantAlgo == "FP8_BLOCK_SCALES") { - quantMode = fromDescription( - false, false, false, false, false, false, false, false, false, false, false, false, true, false); + quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false, + false, true, false, false, false); } else if (quantAlgo == "W4A8_MXFP4_FP8") { - quantMode = fromDescription( - false, false, false, false, false, false, false, false, false, false, false, false, false, true); + quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false, + false, false, true, false, false); + } + else if (quantAlgo == "W4A8_MXFP4_MXFP8") + { + quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false, + false, false, false, true, false); + } + else if (quantAlgo == "W4A16_MXFP4") + { + quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, true); } if (kvCacheQuantAlgo == "INT8") diff --git 
a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index 6d592654ffd..0a58298c279 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -1001,6 +1001,7 @@ class KvCacheConfig std::optional const& crossKvCacheFraction = std::nullopt, std::optional secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0, bool enablePartialReuse = true, bool copyOnPartialReuse = true, bool useUvm = false, + SizeType32 attentionDpEventsGatherPeriodMs = 5, std::optional const& runtimeDefaults = std::nullopt); [[nodiscard]] bool getEnableBlockReuse() const; @@ -1016,6 +1017,7 @@ class KvCacheConfig [[nodiscard]] std::optional getSecondaryOffloadMinPriority() const; [[nodiscard]] size_t getEventBufferMaxSize() const; [[nodiscard]] bool getUseUvm() const; + [[nodiscard]] SizeType32 getAttentionDpEventsGatherPeriodMs() const; void setEnableBlockReuse(bool enableBlockReuse); void setEnablePartialReuse(bool enablePartialReuse); @@ -1030,6 +1032,7 @@ class KvCacheConfig void setSecondaryOffloadMinPriority(std::optional secondaryOffloadMinPriority); void setEventBufferMaxSize(size_t eventBufferMaxSize); void setUseUvm(bool useUvm); + void setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEventsGatherPeriodMs); void fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults); @@ -1085,6 +1088,9 @@ class KvCacheConfig /// @brief Whether to use UVM for the KV cache. bool mUseUvm; + + /// @brief The period in milliseconds to gather attention DP events across ranks + SizeType32 mAttentionDpEventsGatherPeriodMs; }; /// @brief Configuration class for the runtime perf knobs @@ -1702,6 +1708,12 @@ struct KVCacheUpdatedData explicit KVCacheUpdatedData(IdType blockHash) : blockHash{blockHash} {}; + explicit KVCacheUpdatedData(IdType blockHash, std::optional> cacheLevel, + std::optional> priority) + : blockHash{blockHash} + , cacheLevel{cacheLevel} + , priority{priority} {}; + KVCacheUpdatedData& cacheLevelUpdated(SizeType32 oldValue, SizeType32 newValue) { cacheLevel = KVCacheEventDiff{oldValue, newValue}; @@ -1726,8 +1738,8 @@ using KVCacheEventData = std::variant attentionDpRank = std::nullopt); /// @brief The unique id of this event IdType eventId; @@ -1735,6 +1747,8 @@ struct KVCacheEvent KVCacheEventData data; /// @brief The sliding window size SizeType32 windowSize; + /// @brief The attention DP rank of the event, if applicable + std::optional attentionDpRank; }; /// @brief Exposes a limited set of KV cache manager functionalities diff --git a/cpp/include/tensorrt_llm/executor/serialization.h b/cpp/include/tensorrt_llm/executor/serialization.h index b2ecfc66c84..c370a652350 100644 --- a/cpp/include/tensorrt_llm/executor/serialization.h +++ b/cpp/include/tensorrt_llm/executor/serialization.h @@ -302,6 +302,53 @@ class Serialization [[nodiscard]] static std::vector deserializeRequestStatsPerIterationVec( std::vector& buffer); + // KVCacheEvent deque + [[nodiscard]] static std::vector serialize(std::deque const& kvCacheEvents); + [[nodiscard]] static std::deque deserializeKVCacheEvents(std::vector& buffer); + + // KVCacheEvent + [[nodiscard]] static size_t serializedSize(KVCacheEvent const& event); + static void serialize(KVCacheEvent const& event, std::ostream& os); + [[nodiscard]] static KVCacheEvent deserializeKVCacheEvent(std::istream& is); + + // KVCacheCreatedData + [[nodiscard]] static size_t serializedSize(KVCacheCreatedData const& data); + static void 
serialize(KVCacheCreatedData const& data, std::ostream& os); + [[nodiscard]] static KVCacheCreatedData deserializeKVCacheCreatedData(std::istream& is); + + // KVCacheStoredData + [[nodiscard]] static size_t serializedSize(KVCacheStoredData const& data); + static void serialize(KVCacheStoredData const& data, std::ostream& os); + [[nodiscard]] static KVCacheStoredData deserializeKVCacheStoredData(std::istream& is); + + // KVCacheStoredBlockData + [[nodiscard]] static size_t serializedSize(KVCacheStoredBlockData const& data); + static void serialize(KVCacheStoredBlockData const& data, std::ostream& os); + [[nodiscard]] static KVCacheStoredBlockData deserializeKVCacheStoredBlockData(std::istream& is); + + // KVCacheRemovedData + [[nodiscard]] static size_t serializedSize(KVCacheRemovedData const& data); + static void serialize(KVCacheRemovedData const& data, std::ostream& os); + [[nodiscard]] static KVCacheRemovedData deserializeKVCacheRemovedData(std::istream& is); + + // KVCacheEventDiff + template + [[nodiscard]] static size_t serializedSize(KVCacheEventDiff const& data); + template + static void serialize(KVCacheEventDiff const& data, std::ostream& os); + template + [[nodiscard]] static KVCacheEventDiff deserializeKVCacheEventDiff(std::istream& is); + + // KVCacheUpdateData + [[nodiscard]] static size_t serializedSize(KVCacheUpdatedData const& data); + static void serialize(KVCacheUpdatedData const& data, std::ostream& os); + [[nodiscard]] static KVCacheUpdatedData deserializeKVCacheUpdatedData(std::istream& is); + + // UniqueToken + [[nodiscard]] static size_t serializedSize(tensorrt_llm::runtime::UniqueToken const& token); + static void serialize(tensorrt_llm::runtime::UniqueToken const& token, std::ostream& os); + [[nodiscard]] static tensorrt_llm::runtime::UniqueToken deserializeUniqueToken(std::istream& is); + // String static std::string deserializeString(std::istream& is); diff --git a/cpp/include/tensorrt_llm/runtime/decoderState.h b/cpp/include/tensorrt_llm/runtime/decoderState.h index e4fe9c38010..95d7ff0ffac 100644 --- a/cpp/include/tensorrt_llm/runtime/decoderState.h +++ b/cpp/include/tensorrt_llm/runtime/decoderState.h @@ -51,13 +51,13 @@ class DecoderState DecoderState(); //! @brief Setup buffers for the decoder excluding speculative decoding. - void setup(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, + void setup(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, SizeType32 maxSequenceLength, nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& bufferManager); //! @brief Setup buffers for the cache indirection. //! @details This is used for beam search on pipeline parallel ranks without a decoder. - void setupCacheIndirection(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, + void setupCacheIndirection(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, BufferManager const& bufferManager); //! @brief Setup buffers for speculative decoding. @@ -134,7 +134,7 @@ class DecoderState //! @returns [batchSize, maxAcceptedDraftTokensPerStep], accepted paths packed into continuous tensor, on gpu [[nodiscard]] TensorPtr getAcceptedPackedPaths() const; - [[nodiscard]] SizeType32 getMaxBatchSize() const; + [[nodiscard]] SizeType32 getMaxNumSequences() const; [[nodiscard]] SizeType32 getMaxBeamWidth() const; @@ -173,6 +173,11 @@ class DecoderState //! 
@brief Workspace for beam search in streaming mode. [[nodiscard]] BeamSearchBuffers const& getBeamSearchBuffers() const; + //! @brief Set the beam width for a specific request in the batch. + //! @param batchIdx The index of the request in the batch. + //! @param beamWidth The beam width for the specified request. + void setBeamWidth(SizeType32 batchIdx, SizeType32 beamWidth); + //! @brief Cache indirection input for beam search. [[nodiscard]] TensorPtr getCacheIndirectionInput() const; @@ -187,10 +192,10 @@ class DecoderState //! @param generationSteps The generation steps for all requests in the batch. void setGenerationSteps(std::vector const& generationSteps); - //! @brief Stateful inputs for the decoder. Allocated for maxBatchSize slots. + //! @brief Stateful inputs for the decoder. Allocated for maxNumSequences slots. [[nodiscard]] DecodingInput& getJointDecodingInput() const; - //! @brief Stateful outputs for the decoder. Allocated for maxBatchSize slots. + //! @brief Stateful outputs for the decoder. Allocated for maxNumSequences slots. [[nodiscard]] DecodingOutput& getJointDecodingOutput() const; private: @@ -209,13 +214,13 @@ class DecoderState SizeType32 maxTokensPerEngineStep, ModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& bufferManager); - SizeType32 mMaxBatchSize{}; + SizeType32 mMaxNumSequences{}; SizeType32 mMaxBeamWidth{}; SizeType32 mMaxSequenceLength{}; - //! @brief Stateful inputs for the decoder. Allocated for maxBatchSize slots. + //! @brief Stateful inputs for the decoder. Allocated for maxNumSequences slots. DecodingInputPtr mJointDecodingInput; - //! @brief Stateful outputs for the decoder. Allocated for maxBatchSize slots. + //! @brief Stateful outputs for the decoder. Allocated for maxNumSequences slots. DecodingOutputPtr mJointDecodingOutput; //! @brief Workspace for beam search in streaming mode. 
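A minimal usage sketch for the API additions above, assuming the template arguments elided in the declarations are `std::deque<KVCacheEvent>` / `std::vector<char>` (matching the other `Serialization` helpers in this header) and that both classes live in `tensorrt_llm::executor`; the helper name below is hypothetical:

```cpp
#include <deque>
#include <vector>

#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/serialization.h"

namespace tle = tensorrt_llm::executor;

// Pack the per-rank KV-cache event queue into a flat byte buffer and rebuild
// it on the receiving side (e.g. the rank that aggregates attention-DP events).
void shipEvents(tle::KvCacheConfig& config, std::deque<tle::KVCacheEvent> const& localEvents)
{
    // Gather attention-DP KV-cache events every 10 ms instead of the default 5 ms.
    config.setAttentionDpEventsGatherPeriodMs(10);

    // Serialize the whole deque at once ...
    std::vector<char> buffer = tle::Serialization::serialize(localEvents);

    // ... after shipping `buffer` to the aggregating rank, rebuild the deque.
    std::deque<tle::KVCacheEvent> received = tle::Serialization::deserializeKVCacheEvents(buffer);
    (void) received; // each event now carries its optional attentionDpRank
}
```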
diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoder.h b/cpp/include/tensorrt_llm/runtime/gptDecoder.h index 90690c90fc0..7e0cc1bb56d 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoder.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoder.h @@ -71,7 +71,7 @@ class IGptDecoder = 0; static std::unique_ptr create(executor::DecodingMode const& mode, nvinfer1::DataType dtype, - size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, + size_t maxNumSequences, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, BufferManager::CudaStreamPtr const& stream, std::shared_ptr const& speculativeDecodingModule = nullptr); }; @@ -84,7 +84,7 @@ class GptDecoder : public virtual IGptDecoder using CudaStreamPtr = BufferManager::CudaStreamPtr; using TensorPtr = std::shared_ptr; - GptDecoder(executor::DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, + GptDecoder(executor::DecodingMode const& mode, size_t maxNumSequences, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, CudaStreamPtr const& stream, std::shared_ptr speculativeDecodingModule = nullptr); @@ -114,7 +114,7 @@ class GptDecoder : public virtual IGptDecoder SamplingConfig mSamplingConfig; - size_t mMaxBatchSize; + size_t mMaxNumSequences; size_t mVocabSize; size_t mVocabSizePadded; @@ -122,7 +122,7 @@ class GptDecoder : public virtual IGptDecoder }; inline std::unique_ptr IGptDecoder::create(executor::DecodingMode const& mode, nvinfer1::DataType dtype, - size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, + size_t maxNumSequences, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, BufferManager::CudaStreamPtr const& stream, std::shared_ptr const& speculativeDecodingModule) { @@ -130,10 +130,10 @@ inline std::unique_ptr IGptDecoder::create(executor::DecodingMode c { case nvinfer1::DataType::kFLOAT: return std::make_unique>( - mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule); + mode, maxNumSequences, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule); case nvinfer1::DataType::kHALF: return std::make_unique>( - mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule); + mode, maxNumSequences, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule); default: TLLM_THROW("Unsupported decoder data type: %d. 
Use either kFLOAT or kHALF.", static_cast(dtype)); return nullptr; diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h index d5dfe9b7b19..d0a9e726d13 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h @@ -47,7 +47,7 @@ class GptDecoderBatched : public IGptDecoderBatched explicit GptDecoderBatched(CudaStreamPtr stream); - void setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth, + void setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig) override; void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) override; diff --git a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h index 327af71f8a7..606ba3c98a4 100644 --- a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h +++ b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h @@ -86,7 +86,7 @@ class IGptDecoderBatched using TensorPtr = std::shared_ptr; //! @brief Setup the decoder before calling `forward()` - virtual void setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth, + virtual void setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig) = 0; diff --git a/cpp/include/tensorrt_llm/runtime/request.h b/cpp/include/tensorrt_llm/runtime/request.h index 1861ea84317..e8f851b7d77 100644 --- a/cpp/include/tensorrt_llm/runtime/request.h +++ b/cpp/include/tensorrt_llm/runtime/request.h @@ -31,26 +31,16 @@ class Request using TensorPtr = ITensor::SharedPtr; using BufferPtr = IBuffer::SharedPtr; - explicit Request(TensorConstPtr ids, SizeType32 inputLen, std::optional maxNewTokens = std::nullopt, - std::optional endId = std::nullopt) - : ids{std::move(ids)} - , inputLen(inputLen) - , maxNewTokens{maxNewTokens} - , endId{endId} + explicit Request(SizeType32 inputLen) + : inputLen(inputLen) { } //! Mandatory parameters - TensorConstPtr ids; // The input sequence of token ids, [inputSeqLen], on gpu SizeType32 inputLen; // Input length without draft tokens, increasing with generation steps // optional parameters - std::optional maxNewTokens; // maximum number of tokens to generate for this request - std::optional endId; // end token id SizeType32 generatedTokensPerEngineStep{1}; // - TensorPtr embeddingBias; // [vocabSizePadded], on gpu - TensorPtr badWordsList; // [2, badWordsLength] on gpu - TensorPtr stopWordsList; // [2, stopWordsLength] on gpu //! 
Optional parameters for speculative decoding BufferPtr draftTokens; // [generatedTokensPerEngineStep - 1] on gpu diff --git a/cpp/include/tensorrt_llm/runtime/utils/mpiTags.h b/cpp/include/tensorrt_llm/runtime/utils/mpiTags.h index 4443d422ab8..32c086c84ee 100644 --- a/cpp/include/tensorrt_llm/runtime/utils/mpiTags.h +++ b/cpp/include/tensorrt_llm/runtime/utils/mpiTags.h @@ -68,6 +68,10 @@ enum class MpiTag : int // LogitsThread kSpecDecLogitsId = 129, kSpecDecLogitsData = 1025, + + // KvCacheEventManager + kKvCacheEventSize = 1026, + kKvCacheEvent = 1027 }; } // namespace tensorrt_llm::mpi diff --git a/cpp/kernels/fmha_v2/fmha_test.py b/cpp/kernels/fmha_v2/fmha_test.py index f9f28978e66..d02e3cc31c0 100644 --- a/cpp/kernels/fmha_v2/fmha_test.py +++ b/cpp/kernels/fmha_v2/fmha_test.py @@ -1,7 +1,12 @@ import subprocess import pytest -from cuda import cuda, nvrtc + +try: + from cuda.bindings import driver as cuda + from cuda.bindings import nvrtc +except ImportError: + from cuda import cuda, nvrtc def ASSERT_DRV(err): @@ -50,7 +55,7 @@ def getSMVersion(): ids=["fp16", "bf16", "fp16-fp32", "e4m3"]) @pytest.mark.parametrize('flag', [ "-s-q 128 -paged-kv", "-s-q 63 -paged-kv", "-paged-kv", - "-softcapping-scale-bmm1 30", "-contiguous-q-kv" + "-softcapping-scale-bmm1 30", "-contiguous-q-kv", "-use-attention-sinks" ]) @pytest.mark.parametrize('tiled_kernel', ["", "-force-non-tiled"]) def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel): @@ -117,8 +122,8 @@ def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel): f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -custom-mask -gqa 2 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}", shell=True, check=True) - # alibi and softcapping-scale-bmm1 are mutually exclusive. - if '-softcapping-scale-bmm1' not in flag: + # alibi doesn't work with softcapping-scale-bmm1/use-attention-sinks. + if '-softcapping-scale-bmm1' not in flag and '-use-attention-sinks' not in flag: subprocess.run( f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -alibi -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}", shell=True, diff --git a/cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h b/cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h index 65e56dbf5de..eed6f852da3 100644 --- a/cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h +++ b/cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h @@ -326,9 +326,6 @@ struct Compute uint32_t smem_v = __cvta_generic_to_shared(&shared->smem_v[0]); Compute_tile_o ctile_o(0, smem_v); - // BMM2 epilogue - Tile_o_epilogue tile_o_epilogue(params); - // Mutex between two compute groups. OrderedMutexAccessor mutex_accessor(shared->compute_mutex, warpgroup_id, SYNC_BARRIER); // Notify warpgroup 0 to execute HGMMA first (overlap HGMMA and Softmax Math Instructions). @@ -368,6 +365,9 @@ struct Compute sage_scale_row = head_info.bidb * params.h + head_info.bidh; } + // BMM2 epilogue + Tile_o_epilogue tile_o_epilogue(params, head_info); + int q_step_idx = warpgroup_id; // Compute work. @@ -490,7 +490,7 @@ struct Compute if (valid_run) { // Final step's update. - tile_o_epilogue.scale(ctile_o, p_sum); + tile_o_epilogue.scale(ctile_o, p_max, p_sum); // Store o_tile to gmem. 
gmem_o.store(ctile_o.acc_); } diff --git a/cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h b/cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h index 217e8c08722..99ea1643cd0 100644 --- a/cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h +++ b/cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h @@ -454,7 +454,7 @@ struct Softmax_base #pragma unroll for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++) { - uint32_t const scale = float_to_half2(correction_[mi]); + const uint32_t scale = float_to_half2(correction_[mi]); // Assume only N has multiple MMAs (MMAS_M = 1). // MMAS_N > 1 when N dimension is split. @@ -477,9 +477,15 @@ struct Softmax_base } // BMM1 scale. - uint32_t const scale_bmm1_; + const uint32_t scale_bmm1_; // BMM1 softcapping scale. float const softcapping_scale_bmm1_; + + // The sliding window size. + int const sliding_window_size_; + // The log2 attention chunk size. + int const log2_chunked_attention_size_; + // The thread idx in the warp group. int tidx_; // The col index for the mma thread layout. @@ -487,15 +493,10 @@ struct Softmax_base // The row index for the mma thread layout. int quad_row_; - // The sliding window size. - int const sliding_window_size_; - // The log2 attention chunk size. - int const log2_chunked_attention_size_; - // The packed mask ptr. uint32_t const* packed_mask_ptr_; // The packed mask k-dim stride in bytes; - int64_t const params_packed_mask_stride_in_bytes_; + const int64_t params_packed_mask_stride_in_bytes_; // Unpacked BMM1 output buffer. float elt_[Mma_tile_p::CORES_M][Mma_tile_p::CORES_N * 2]; @@ -1072,20 +1073,53 @@ struct Tile_o_epilogue_base // The MMA tile for the BMM2. using Mma_tile_o = typename Kernel_traits::Mma_tile_o; - template - inline __device__ Tile_o_epilogue_base(Params const& params) + // Apply the exp2f optimization (fuse bmm1_scale and -max into FMAs). + enum + { + EXP2F_OPTIMIZATION = Kernel_traits::EXP2F_OPTIMIZATION + }; + + template + inline __device__ Tile_o_epilogue_base(Params const& params, Block_info& block_info) { - ; // nothing to construct. + has_attention_sink_ = params.attention_sinks != nullptr; + head_idx_ = block_info.bidh; + attention_sink_ = has_attention_sink_ ? params.attention_sinks[block_info.bidh] : 0.f; + // It is only need when the exp2f optimization is enabled, so params.scale_bmm1 is always float. + scale_bmm1_f_ = reinterpret_cast(params.scale_bmm1_d ? *params.scale_bmm1_d : params.scale_bmm1); }; + // The attention sinks. + inline __device__ void add_attention_sink(float& sum, float max) + { + if (has_attention_sink_) + { + // The global max needs to be scaled by the bmm1 scale if exp2f optimization is enabled. + if constexpr (EXP2F_OPTIMIZATION) + { + sum += exp2f(attention_sink_ * M_LOG2E - max * scale_bmm1_f_); + } + else + { + sum += expf(attention_sink_ - max); + } + } + } + // Scale ctile_o output by 1/sum - inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M]) + inline __device__ void scale( + Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M]) { // Final step's update. #pragma unroll for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++) { - global_sum[mi] = global_sum[mi] == 0.f ? 1.f : 1.0f / global_sum[mi]; + // The global sum. + float global_sum_mi = global_sum[mi]; + // Add the attention sink to the global sum. + add_attention_sink(global_sum_mi, global_max[mi]); + // The scale. + float scale = global_sum_mi == 0.f ? 1.f : 1.0f / global_sum_mi; // Assume only N has multiple MMAs (MMAS_M = 1). 
#pragma unroll @@ -1096,12 +1130,21 @@ struct Tile_o_epilogue_base { float& reg0 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi); float& reg1 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi + 1); - reg0 *= global_sum[mi]; - reg1 *= global_sum[mi]; + reg0 *= scale; + reg1 *= scale; } } } } + + // Whether the attention sink is enabled. + bool has_attention_sink_ = false; + // The attention sink value. + float attention_sink_ = 0.f; + // The float scale of bmm1 outputs. + float scale_bmm1_f_ = 1.f; + // The head idx. + int head_idx_ = 0; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1138,14 +1181,21 @@ struct Tile_o_epilogue using Base::Tile_o_epilogue_base; // Scale ctile_o output by 1/sum - inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M]) + inline __device__ void scale( + Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M]) { // Final step's update. #pragma unroll for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++) { - global_sum[mi] = global_sum[mi] == 0.f ? 1.f : 1.0f / global_sum[mi]; - uint32_t const scale = float_to_half2(global_sum[mi]); + // The global sum. + float global_sum_mi = global_sum[mi]; + // Add the attention sink to the global sum. + this->add_attention_sink(global_sum_mi, global_max[mi]); + // The scale. + float scale = global_sum_mi == 0.f ? 1.f : 1.0f / global_sum_mi; + // The scale. + const uint32_t scale_h = float_to_half2(scale); // Assume only N has multiple MMAs (MMAS_M = 1). #pragma unroll @@ -1155,7 +1205,7 @@ struct Tile_o_epilogue for (int ni = 0; ni < Mma_tile_o::CORES_N; ni++) { uint32_t& reg = ctile_o.acc_[0][mma_ni].reg(ni * Mma_tile_o::CORES_M + mi); - reg = hmul2(reg, scale); + reg = hmul2(reg, scale_h); } } } @@ -1215,27 +1265,58 @@ struct Tile_o_epilogue // The MMA tile for the BMM2. using Mma_tile_o = typename Base::Mma_tile_o; + // Apply the exp2f optimization (fuse bmm1_scale and -max into FMAs). + enum + { + EXP2F_OPTIMIZATION = Base::EXP2F_OPTIMIZATION + }; + // Ctor. - template - inline __device__ Tile_o_epilogue(Params const& params) - : Base(params) + template + inline __device__ Tile_o_epilogue(Params const& params, Block_info& block_info) + : Base(params, block_info) , scale_bmm2_(*params.scale_bmm2_d) { } + // Add the attention sink to the global sum. + inline __device__ void add_attention_sink(float& sum, float max) + { + if (this->has_attention_sink_) + { + // The global max needs to be scaled by the bmm1 scale if exp2f optimization is enabled. + // Take the log2f(Traits_o::SOFTMAX_FP_QUANT_SCALE) into account as the same scale has been applied to sum. + float quant_scale_in_log2 = log2f(Traits_o::SOFTMAX_FP_QUANT_SCALE); + if constexpr (EXP2F_OPTIMIZATION) + { + sum += exp2f(this->attention_sink_ * M_LOG2E - max * this->scale_bmm1_f_ + quant_scale_in_log2); + } + else + { + sum += expf(this->attention_sink_ - max + quant_scale_in_log2); + } + } + } + // Scale ctile_o output by 1/sum - inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M]) + inline __device__ void scale( + Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M]) { // Final step's update. #pragma unroll for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++) { + // The global sum. + float global_sum_mi = global_sum[mi]; + // Add the attention sink to the global sum. 
+ add_attention_sink(global_sum_mi, global_max[mi]); #ifdef UNIFIED_EPILOGUE_SCALE // Descaling factor float const scale_bmm2_f_ = reinterpret_cast(scale_bmm2_); - global_sum[mi] = global_sum[mi] == 0.f ? scale_bmm2_f_ : scale_bmm2_f_ / global_sum[mi]; + // The scale. + float scale = global_sum_mi == 0.f ? scale_bmm2_f_ : scale_bmm2_f_ / global_sum_mi; #else - global_sum[mi] = global_sum[mi] == 0.f ? 1.0f : 1.0f / global_sum[mi]; + float scale = global_sum_mi == 0.f ? 1.0f : 1.0f / global_sum_mi; #endif // Assume only N has multiple MMAs (MMAS_M = 1). #pragma unroll @@ -1246,8 +1327,8 @@ struct Tile_o_epilogue { float& reg0 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi); float& reg1 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi + 1); - reg0 *= global_sum[mi]; - reg1 *= global_sum[mi]; + reg0 *= scale; + reg1 *= scale; } } } diff --git a/cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp b/cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp index 6d9811ac071..6cf52fcf4c9 100644 --- a/cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp +++ b/cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp @@ -29,30 +29,33 @@ using Kv_block_array = fmha::Kv_block_array; //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); +void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n, - bool has_alibi); +void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, + int warps_n, bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); +void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_bf16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); +void run_softmax_bf16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi); 
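The extra `attention_sinks` argument threaded through these reference helpers carries a per-head learned logit that joins the softmax normalization without contributing a value row. In base-e form (as in the reference path and the non-exp2f epilogue above, which adds `expf(attention_sink_ - max)` to the row sum), the normalization is:

```latex
% Softmax with a per-head attention sink s_h; m_i is the row max.
% The sink only enlarges the denominator, it has no value vector.
\[
  p_{ij} = \frac{e^{\,x_{ij} - m_i}}{\sum_{k} e^{\,x_{ik} - m_i} + e^{\,s_h - m_i}},
  \qquad m_i = \max_{k} x_{ik}.
\]
```

In the exp2f-optimized warp-specialized epilogue the same term is evaluated in base 2 as `exp2f(attention_sink_ * M_LOG2E - max * scale_bmm1_f_)`, with an extra `log2f(SOFTMAX_FP_QUANT_SCALE)` offset in the FP8 specialization, because the running max and sum are kept pre-scaled there.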
//////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, float softcapping_scale_bmm1, int warps_n, - bool has_alibi); +void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, + float softcapping_scale_bmm1, int warps_n, bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -81,11 +84,11 @@ void run_sage_quant(unsigned int batch_size, unsigned int head_num, unsigned int //////////////////////////////////////////////////////////////////////////////////////////////////// -void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_type const acc_type, +void ground_truth(RefBMM& bmm1, RefBMM& bmm2, const Data_type data_type, const Data_type acc_type, float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1, - void* qkv_d, void* vt_d, void* mask_d, void* p_d, void* s_d, void* tmp_d, void* o_d, void* softmax_sum_d, - void* cu_q_seqlens_d, size_t const b, size_t const s, size_t const h, size_t const d, size_t const dv, - int const runs, int const warps_m, int const warps_n, bool const has_alibi) + void* qkv_d, void* vt_d, void* mask_d, void* attention_sinks_d, void* p_d, void* s_d, void* tmp_d, void* o_d, + void* softmax_sum_d, void* cu_q_seqlens_d, const size_t b, const size_t s, const size_t h, const size_t d, + const size_t dv, int const runs, int const warps_m, int const warps_n, bool const has_alibi) { cudaStream_t stream = 0; @@ -106,28 +109,28 @@ void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_ty // Softmax. 
if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16) { - run_softmax_fp16(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1, - warps_n, has_alibi); + run_softmax_fp16(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, + softcapping_scale_bmm1, warps_n, has_alibi); } else if (data_type == DATA_TYPE_BF16 && acc_type == DATA_TYPE_FP32) { - run_softmax_bf16(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1, - warps_n, has_alibi); + run_softmax_bf16(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, + softcapping_scale_bmm1, warps_n, has_alibi); } else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32) { - run_softmax_fp32(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1, - warps_n, has_alibi); + run_softmax_fp32(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, + softcapping_scale_bmm1, warps_n, has_alibi); } else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32) { - run_softmax_e4m3(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_softmax, - softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax_e4m3(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, + scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); } else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32) { - run_softmax_int8(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_bmm1, scale_softmax, - softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax_int8(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_bmm1, + scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); } else { @@ -179,7 +182,7 @@ static inline void set_params(bert::Fused_multihead_attention_params_v1& params, // types Data_type data_type, Data_type acc_type, // sizes - size_t const b, size_t const s, size_t const h, size_t const d, size_t const packed_mask_stride, + const size_t b, const size_t s, const size_t h, const size_t d, const size_t packed_mask_stride, // device pointers void* qkv_d, void* packed_mask_d, void* o_d, void* p_d, void* s_d, // scale factors @@ -235,17 +238,17 @@ static inline void set_params(bert::Fused_multihead_attention_params_v1& params, //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline void set_params(bert::Fused_multihead_attention_params_v2& params, Launch_params const launch_params, +static inline void set_params(bert::Fused_multihead_attention_params_v2& params, const Launch_params launch_params, // types Data_type data_type, Data_type acc_type, Data_type output_dtype, // attention input layout Attention_input_layout input_layout, // sizes - size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const h_kv, size_t const d, - size_t const dv, size_t const total, const size_t num_grouped_heads, const size_t sliding_window_size, + const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t h_kv, const size_t d, + const size_t dv, const size_t total, const size_t num_grouped_heads, const size_t sliding_window_size, const size_t chunked_attention_size, // paged kv cache block size. - size_t const tokens_per_block, + const size_t tokens_per_block, // device pointers void* qkv_packed_d, // contiguous q. 
@@ -261,8 +264,10 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params, // offsets for different blocks in terms of the start address. int32_t* paged_block_offsets, // mask input. - void* packed_mask_d, void* cu_mask_rows_d, void* cu_kv_seqlens_d, void* cu_q_seqlens_d, void* o_packed_d, void* p_d, - void* s_d, void* softmax_stats_d, void* scale_bmm2_d, + void* packed_mask_d, void* cu_mask_rows_d, + // attention sinks. + void* attention_sinks_d, void* cu_kv_seqlens_d, void* cu_q_seqlens_d, void* o_packed_d, void* p_d, void* s_d, + void* softmax_stats_d, void* scale_bmm2_d, // scale factors float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1, // flags @@ -329,6 +334,9 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params, // The N dimension has to be aligned. params.packed_mask_stride_in_bytes = (align_to(int64_t(s_kv), int64_t(fmha::FLASH_ATTEN_MASK_N_ALIGNMENT))) / 8; + // Attention sinks. + params.attention_sinks = reinterpret_cast(attention_sinks_d); + #if defined(STORE_P) params.p_ptr = p_d; params.p_stride_in_bytes = get_size_in_bytes(b * h * s_kv, acc_type); @@ -412,13 +420,13 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params, //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline void determine_launch_params(Launch_params& launch_params, Data_type data_type, int sm, size_t const s, - size_t const d, Attention_mask_type const attention_mask_type, Attention_input_layout const input_layout, +static inline void determine_launch_params(Launch_params& launch_params, Data_type data_type, int sm, const size_t s, + const size_t d, const Attention_mask_type attention_mask_type, const Attention_input_layout input_layout, bool const interleaved, bool const ignore_b1opt, bool const force_unroll, bool const use_tma, bool const force_non_flash_attention, bool const force_non_warp_specialization, bool const force_non_granular_tiling, bool const force_fp32_acc, // device props - cudaDeviceProp const props) + const cudaDeviceProp props) { // Set launch params to choose kernels @@ -573,6 +581,9 @@ int main(int argc, char** argv) // SageAttention block sizes int sage_block_size_q = 0, sage_block_size_k = 0, sage_block_size_v = 0; + // Use attention sinks (added to the denominator of softmax) + bool use_attention_sinks = false; + // Read the parameters from the command-line. for (int ii = 1; ii < argc; ++ii) { @@ -865,13 +876,16 @@ int main(int argc, char** argv) { sage_block_size_v = strtol(argv[ii], nullptr, 10); } + else if (!strcmp(argv[ii], "-use-attention-sinks")) + { + use_attention_sinks = true; + } else { fprintf(stderr, "Unrecognized option: %s. Aborting!\n", argv[ii]); return -1; } } - if (save_softmax == true) { if (input_layout != Attention_input_layout::CONTIGUOUS_Q_KV) @@ -1043,11 +1057,11 @@ int main(int argc, char** argv) force_non_granular_tiling, force_fp32_acc, props); // The Q, K and V matrices are packed into one big matrix of size S x B x H x 3 x D. - size_t const qkv_size = s * b * h * (2 * d + dv); + const size_t qkv_size = s * b * h * (2 * d + dv); // Allocate on the host. float* qkv_h = (float*) malloc(qkv_size * sizeof(float)); // The size in bytes. - size_t const qkv_size_in_bytes = get_size_in_bytes(qkv_size, data_type); + const size_t qkv_size_in_bytes = get_size_in_bytes(qkv_size, data_type); // Allocate on the device. 
void *qkv_sbh3d_d = nullptr, *qkv_bsh3d_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&qkv_sbh3d_d, qkv_size_in_bytes)); @@ -1057,7 +1071,7 @@ int main(int argc, char** argv) // The shape is [B, 2, S, H, D]. const size_t kv_size = b * s * h_kv * (d + dv); // The size in bytes. - size_t const kv_size_in_bytes = get_size_in_bytes(kv_size, data_type); + const size_t kv_size_in_bytes = get_size_in_bytes(kv_size, data_type); // Allocate on the host. void* contiguous_kv_h = malloc(kv_size_in_bytes); // Memset the buffer. @@ -1071,13 +1085,13 @@ int main(int argc, char** argv) void** kv_cache_ptrs_h = nullptr; void* kv_cache_pool_ptr = nullptr; int32_t *kv_cache_block_offsets_h, *kv_cache_block_offsets_d = nullptr; - size_t const max_blocks_per_seq = (s + tokens_per_block - 1) / tokens_per_block; - size_t const num_total_blocks = b * 2 * max_blocks_per_seq; + const size_t max_blocks_per_seq = (s + tokens_per_block - 1) / tokens_per_block; + const size_t num_total_blocks = b * 2 * max_blocks_per_seq; kv_cache_ptrs_h = (void**) malloc(num_total_blocks * sizeof(void*)); kv_cache_block_offsets_h = (int32_t*) malloc(num_total_blocks * sizeof(int32_t)); - size_t const paged_kv_block_size_in_bytes = get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type); + const size_t paged_kv_block_size_in_bytes = get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type); FMHA_CHECK_CUDA(cudaMalloc((void**) (&kv_cache_block_offsets_d), num_total_blocks * sizeof(int32_t))); - size_t const kv_cache_pool_sz + const size_t kv_cache_pool_sz = get_size_in_bytes(num_total_blocks * tokens_per_block * h_kv * (d + dv) / 2, data_type); FMHA_CHECK_CUDA(cudaMalloc((void**) (&kv_cache_pool_ptr), kv_cache_pool_sz)); size_t ptr_index = 0; @@ -1104,7 +1118,7 @@ int main(int argc, char** argv) // Q will always be [B, S, H, Dh] with paged kv cache. void* q_d; - size_t const q_size = s * b * h * d; + const size_t q_size = s * b * h * d; FMHA_CHECK_CUDA(cudaMalloc(&q_d, get_size_in_bytes(q_size, data_type))); // K has [B, S, H_kv, D] with separate kv cache. @@ -1122,11 +1136,11 @@ int main(int argc, char** argv) FMHA_CHECK_CUDA(cudaMalloc(&scale_bmm2_d, sizeof(uint32_t))); // The mask for dropout or any mask patterns. - size_t const mask_size = s * b * s; + const size_t mask_size = s * b * s; // Allocate on the host. float* mask_h = (float*) malloc(mask_size * sizeof(float)); // The size in bytes. - size_t const mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8); + const size_t mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8); // Allocate on the device. void* mask_d = nullptr; if (!skip_checks) @@ -1158,7 +1172,7 @@ int main(int argc, char** argv) v1 ? 1 : 2); // The number of threads per CTA. - size_t const threads_per_cta = warps_m * warps_n * warps_k * 32; + const size_t threads_per_cta = warps_m * warps_n * warps_k * 32; // The number of mmas in the M dimension. We use one uint32_t per MMA in the M dimension. size_t mmas_m = (s + 16 * warps_m - 1) / (16 * warps_m); // The number of mmas in the N dimension. @@ -1182,7 +1196,7 @@ int main(int argc, char** argv) packed_mask_size = b * mmas_m * mmas_n * threads_per_cta; } // The size in bytes. - size_t const packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t); + const size_t packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t); // Allocate on the host. uint32_t* packed_mask_h = (uint32_t*) malloc(packed_mask_size_in_bytes); // Set it to 0 (indicates that all elements are valid). 
@@ -1190,12 +1204,30 @@ int main(int argc, char** argv) // Allocate on the device. void* packed_mask_d = nullptr; + // The size of the attention sinks. + const size_t attention_sinks_size_in_bytes = h * sizeof(float); + + // The attention sinks. + void* attention_sinks_d = nullptr; + if (use_attention_sinks) + { + // Allocate on the host. + float* attention_sinks_h = (float*) malloc(attention_sinks_size_in_bytes); + // Randomly initialize the attention sinks. + random_init("attention_sinks", attention_sinks_h, 1, h, 1, false, 5.f, 1.f, verbose); + // Allocate on the device. + FMHA_CHECK_CUDA(cudaMalloc(&attention_sinks_d, attention_sinks_size_in_bytes)); + // Copy from the host to the device. + FMHA_CHECK_CUDA( + cudaMemcpy(attention_sinks_d, attention_sinks_h, attention_sinks_size_in_bytes, cudaMemcpyDefault)); + } + // The O matrix is packed as S * B * H * D. - size_t const o_size = s * b * h * dv; + const size_t o_size = s * b * h * dv; // Allocate on the host. float* o_h = (float*) malloc(o_size * sizeof(float)); // The size in bytes. - size_t const o_size_in_bytes = get_size_in_bytes(o_size, data_type); + const size_t o_size_in_bytes = get_size_in_bytes(o_size, data_type); // Allocate on the device. void* o_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&o_d, o_size_in_bytes)); @@ -1206,7 +1238,7 @@ int main(int argc, char** argv) FMHA_CHECK_CUDA(cudaMemset(softmax_stats_d, 0x00, 2 * sizeof(float) * b * s * h)); // The size in bytes. - size_t const tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type); + const size_t tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type); // Allocate on the device. void* tmp_d = nullptr; if (data_type != acc_type) @@ -1220,9 +1252,9 @@ int main(int argc, char** argv) float* softmax_sum_h = (float*) malloc(b * s * h * sizeof(float)); // The P matrix is stored as one big matrix of size S x B x H x S. - size_t const p_size = s * b * h * s; + const size_t p_size = s * b * h * s; // The size in bytes. - size_t const p_size_in_bytes = get_size_in_bytes(p_size, acc_type); + const size_t p_size_in_bytes = get_size_in_bytes(p_size, acc_type); // Allocate on the device. void* p_d = nullptr; if (!skip_checks) @@ -1238,7 +1270,7 @@ int main(int argc, char** argv) #endif // defined(STORE_P) // The size in bytes of the S matrix (the data type may be different from P for int8). - size_t const s_size_in_bytes = get_size_in_bytes(p_size, data_type); + const size_t s_size_in_bytes = get_size_in_bytes(p_size, data_type); // Allocate on the device. void* s_d = nullptr; if (!skip_checks) @@ -1327,7 +1359,7 @@ int main(int argc, char** argv) std::vector seqlens(b, 0); // randomly draw a batch of sequence lengths >= min_s std::transform(seqlens.begin(), seqlens.end(), seqlens.begin(), - [=](uint32_t const) + [=](const uint32_t) { if (fix_s) { @@ -1415,7 +1447,7 @@ int main(int argc, char** argv) FMHA_CHECK_CUDA(cudaMalloc(&mqa_qkv_packed_d, mqa_qkv_packed_size_in_bytes)); FMHA_CHECK_CUDA(cudaMalloc(&mqa_qkv_d, mqa_qkv_size_in_bytes)); - size_t const o_packed_size = cu_seqlens.back() * h * dv; + const size_t o_packed_size = cu_seqlens.back() * h * dv; // Allocate on the host. float* o_packed_h = (float*) malloc(o_packed_size * sizeof(float)); void* o_packed_d = nullptr; @@ -1676,9 +1708,9 @@ int main(int argc, char** argv) total, num_grouped_heads, sliding_window_size, chunked_attention_size, // Paged kv cache. 
tokens_per_block, qkv_d_view, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr, kv_cache_block_offsets_d, - packed_mask_d, cu_mask_rows_d, cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d, softmax_stats_ptr, - scale_bmm2_d, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1, use_int8_scale_max, interleaved, - is_s_padded, has_alibi); + packed_mask_d, cu_mask_rows_d, attention_sinks_d, cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d, + softmax_stats_ptr, scale_bmm2_d, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1, + use_int8_scale_max, interleaved, is_s_padded, has_alibi); // total number of tokens is needed to set TMA desc on the host. launch_params.total_q_seqlen = q_seqlens[b]; @@ -1894,8 +1926,8 @@ int main(int argc, char** argv) ground_truth(bmm1, bmm2, data_type, acc_type, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1, qkv_sbh3d_d, vt_d, // WAR pass in V' - mask_d, p_d, s_d, tmp_d, o_d, softmax_stats_d, cu_seqlens_d, b, s, h, d, dv, runs, warps_m, warps_n, - has_alibi); + mask_d, attention_sinks_d, p_d, s_d, tmp_d, o_d, softmax_stats_d, cu_seqlens_d, b, s, h, d, dv, runs, + warps_m, warps_n, has_alibi); timer.stop(); FMHA_CHECK_CUDA(cudaPeekAtLastError()); FMHA_CHECK_CUDA(cudaDeviceSynchronize()); @@ -2009,7 +2041,6 @@ int main(int argc, char** argv) // Extract the last s_q tokens from the output. extract_and_transpose_output( o_ref_trans_h.data(), o_ref_h, seqlens, q_seqlens, s, s_q, b, h, dv, is_s_padded); - if (verbose) { printf("\nChecking .....: O = V * S\n"); diff --git a/cpp/kernels/fmha_v2/src/fused_multihead_attention.h b/cpp/kernels/fmha_v2/src/fused_multihead_attention.h index f77e3f14d0c..16e2f9a8db5 100644 --- a/cpp/kernels/fmha_v2/src/fused_multihead_attention.h +++ b/cpp/kernels/fmha_v2/src/fused_multihead_attention.h @@ -197,6 +197,9 @@ struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_ba // The stride between rows of softmax_stats_ptr int64_t softmax_stats_stride_in_bytes; + // The attention sinks (per head). + float* attention_sinks; + // array of length b+1 holding prefix sum of actual q sequence lengths. int* cu_q_seqlens; // array of length b+1 holding prefix sum of actual kv sequence lengths. diff --git a/cpp/kernels/fmha_v2/src/fused_multihead_attention_demo_bert_params.h b/cpp/kernels/fmha_v2/src/fused_multihead_attention_demo_bert_params.h index 76670971e57..bacb4938cf2 100644 --- a/cpp/kernels/fmha_v2/src/fused_multihead_attention_demo_bert_params.h +++ b/cpp/kernels/fmha_v2/src/fused_multihead_attention_demo_bert_params.h @@ -87,6 +87,8 @@ struct Fused_multihead_attention_params_v2 fmha::Kv_block_array paged_kv_cache; // The mask to implement drop-out. void* packed_mask_ptr; + // The attention sinks (per head). + float* attention_sinks; // The O matrix (output). 
void* o_ptr; // The Softmax stats vector of layout [2, B, S, H], including softmax_sum and softmax_max diff --git a/cpp/kernels/fmha_v2/src/fused_multihead_cross_attention.cpp b/cpp/kernels/fmha_v2/src/fused_multihead_cross_attention.cpp index 8a2e7a8fc0c..6e37fc6ab43 100644 --- a/cpp/kernels/fmha_v2/src/fused_multihead_cross_attention.cpp +++ b/cpp/kernels/fmha_v2/src/fused_multihead_cross_attention.cpp @@ -23,25 +23,27 @@ using Launch_params = bert::Fused_multihead_attention_launch_params; //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); +void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d, - int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n, - bool has_alibi); +void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, + int warps_n, bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); +void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d, - int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, float softcapping_scale_bmm1, int warps_n, - bool has_alibi); +void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, + float softcapping_scale_bmm1, int warps_n, bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -57,10 +59,10 @@ void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h //////////////////////////////////////////////////////////////////////////////////////////////////// -void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_type const acc_type, +void ground_truth(RefBMM& bmm1, RefBMM& bmm2, const Data_type data_type, const Data_type acc_type, float const scale_bmm1, float const scale_softmax, float const scale_bmm2, void* q_d, void* kv_d, void* vt_d, void* mask_d, void* p_d, void* s_d, void* tmp_d, void* o_d, void* softmax_sum_d, void* 
cu_seqlens_q_d, - size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const d, int const runs, + const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t d, int const runs, int const warps_m, int const warps_n, bool has_alibi) { @@ -84,20 +86,22 @@ void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_ty // Softmax. if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16) { - run_softmax_fp16(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi); + run_softmax_fp16( + s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi); } else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32) { - run_softmax_fp32(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi); + run_softmax_fp32( + s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi); } else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32) { - run_softmax_e4m3(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_softmax, 0.f, - warps_n, has_alibi); + run_softmax_e4m3(s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_softmax, + 0.f, warps_n, has_alibi); } else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32) { - run_softmax_int8(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_bmm1, + run_softmax_int8(s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_bmm1, scale_softmax, 0.f, warps_n, has_alibi); } else @@ -148,8 +152,8 @@ static inline void set_params(bert::Fused_multihead_attention_params_mhca& param // types Data_type data_type, Data_type acc_type, // sizes - size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const d, size_t const d_padded, - size_t const total, + const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t d, const size_t d_padded, + const size_t total, // device pointers void* q_packed_d, void* kv_packed_d, void* cu_seqlens_q_d, void* cu_seqlens_kv_d, void* o_packed_d, void* p_d, void* s_d, @@ -515,17 +519,17 @@ int main(int argc, char** argv) launch_params.use_tma = use_tma; // The Q matrix of size S_Q x B x H x D. - size_t const q_size = s_q * b * h * d; + const size_t q_size = s_q * b * h * d; // The K and V matrices are packed into one big matrix of size S_KV x B x H x 2 x D. - size_t const kv_size = s_kv_padded * b * h * 2 * d; + const size_t kv_size = s_kv_padded * b * h * 2 * d; // Allocate on the host. float* q_h = (float*) malloc(q_size * sizeof(float)); // Allocate on the host. float* kv_h = (float*) malloc(kv_size * sizeof(float)); // The size in bytes. - size_t const q_size_in_bytes = get_size_in_bytes(q_size, data_type); + const size_t q_size_in_bytes = get_size_in_bytes(q_size, data_type); // The size in bytes. - size_t const kv_size_in_bytes = get_size_in_bytes(kv_size, data_type); + const size_t kv_size_in_bytes = get_size_in_bytes(kv_size, data_type); // Allocate on the device. void* q_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&q_d, q_size_in_bytes)); @@ -534,11 +538,11 @@ int main(int argc, char** argv) FMHA_CHECK_CUDA(cudaMalloc(&kv_d, kv_size_in_bytes)); // The mask for dropout. - size_t const mask_size = s_q * b * s_kv_padded; + const size_t mask_size = s_q * b * s_kv_padded; // Allocate on the host. 
float* mask_h = (float*) malloc(mask_size * sizeof(float)); // The size in bytes. - size_t const mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8); + const size_t mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8); // Allocate on the device. void* mask_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&mask_d, mask_size_in_bytes)); @@ -554,28 +558,28 @@ int main(int argc, char** argv) v1 ? 1 : 2); // The number of threads per CTA. - size_t const threads_per_cta = warps_m * warps_n * warps_k * 32; + const size_t threads_per_cta = warps_m * warps_n * warps_k * 32; // The number of mmas in the M dimension. We use one uint32_t per MMA in the M dimension. - size_t const mmas_m = (s_q + 16 * warps_m - 1) / (16 * warps_m); + const size_t mmas_m = (s_q + 16 * warps_m - 1) / (16 * warps_m); // The number of mmas in the N dimension. - size_t const mmas_n = (s_kv_padded + 16 * warps_n - 1) / (16 * warps_n); + const size_t mmas_n = (s_kv_padded + 16 * warps_n - 1) / (16 * warps_n); // We do not support more than 4 MMAS in the N dimension (as each MMA needs 8 bits in the mask). assert(!v1 || mmas_n <= 4); // The packed mask for dropout (in the fused kernel). Layout is B * MMAS_M * THREADS_PER_CTA. - size_t const packed_mask_size = b * mmas_m * threads_per_cta; + const size_t packed_mask_size = b * mmas_m * threads_per_cta; // The size in bytes. - size_t const packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t); + const size_t packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t); // Allocate on the host. uint32_t* packed_mask_h = (uint32_t*) malloc(packed_mask_size_in_bytes); // Allocate on the device. void* packed_mask_d = nullptr; // The O matrix is packed as S_Q * B * H * D. - size_t const o_size = s_q * b * h * d; + const size_t o_size = s_q * b * h * d; // Allocate on the host. float* o_h = (float*) malloc(o_size * sizeof(float)); // The size in bytes. - size_t const o_size_in_bytes = get_size_in_bytes(o_size, data_type); + const size_t o_size_in_bytes = get_size_in_bytes(o_size, data_type); // Allocate on the device. void* o_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&o_d, o_size_in_bytes)); @@ -587,7 +591,7 @@ int main(int argc, char** argv) FMHA_CHECK_CUDA(cudaMemset(softmax_max_d, 0x00, sizeof(float) * b * s_q * h)); // The size in bytes. - size_t const tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type); + const size_t tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type); // Allocate on the device. void* tmp_d = nullptr; if (data_type != acc_type) @@ -599,9 +603,9 @@ int main(int argc, char** argv) float* o_ref_h = (float*) malloc(o_size * sizeof(float)); // The P matrix is stored as one big matrix of size S_Q x B x H x S_KV. - size_t const p_size = s_q * b * h * s_kv_padded; + const size_t p_size = s_q * b * h * s_kv_padded; // The size in bytes. - size_t const p_size_in_bytes = get_size_in_bytes(p_size, acc_type); + const size_t p_size_in_bytes = get_size_in_bytes(p_size, acc_type); // Allocate on the device. void* p_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&p_d, p_size_in_bytes)); @@ -614,7 +618,7 @@ int main(int argc, char** argv) #endif // defined(STORE_P) // The size in bytes of the S matrix (the data type may be different from P for int8). - size_t const s_size_in_bytes = get_size_in_bytes(p_size, data_type); + const size_t s_size_in_bytes = get_size_in_bytes(p_size, data_type); // Allocate on the device. 
void* s_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&s_d, s_size_in_bytes)); @@ -634,9 +638,9 @@ int main(int argc, char** argv) // WAR fOR MISSING CUBLAS FP8 NN SUPPORT. // Transpose V, so that we can do a TN BMM2, i.e. O = S x V' instead of O = S x V. - size_t const v_size = s_kv_padded * b * h * d; + const size_t v_size = s_kv_padded * b * h * d; // The size in bytes. - size_t const v_size_in_bytes = get_size_in_bytes(v_size, data_type); + const size_t v_size_in_bytes = get_size_in_bytes(v_size, data_type); float* vt_h = (float*) malloc(v_size * sizeof(float)); void* vt_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&vt_d, v_size_in_bytes)); @@ -676,7 +680,7 @@ int main(int argc, char** argv) = [min_s, fix_s, b](int s, std::vector& seqlens, std::vector& cu_seqlens, void** cu_seqlens_d) { std::transform(seqlens.begin(), seqlens.end(), seqlens.begin(), - [=](uint32_t const) + [=](const uint32_t) { if (fix_s) { @@ -728,7 +732,7 @@ int main(int argc, char** argv) void* kv_packed_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&kv_packed_d, kv_packed_size_in_bytes)); - size_t const o_packed_size = cu_seqlens_q.back() * h * d; + const size_t o_packed_size = cu_seqlens_q.back() * h * d; // Allocate on the host. float* o_packed_h = (float*) malloc(o_packed_size * sizeof(float)); float* o_ref_packed_h = (float*) malloc(o_packed_size * sizeof(float)); diff --git a/cpp/kernels/fmha_v2/src/softmax_bf16.cu b/cpp/kernels/fmha_v2/src/softmax_bf16.cu index 5212d317174..79b681b5023 100644 --- a/cpp/kernels/fmha_v2/src/softmax_bf16.cu +++ b/cpp/kernels/fmha_v2/src/softmax_bf16.cu @@ -12,9 +12,10 @@ #include "softmax_impl.h" -void run_softmax_bf16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi) +void run_softmax_bf16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi) { - run_softmax(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f, - softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, + b, h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi); } diff --git a/cpp/kernels/fmha_v2/src/softmax_fp16.cu b/cpp/kernels/fmha_v2/src/softmax_fp16.cu index 1fb68b1136d..9df37605a2e 100644 --- a/cpp/kernels/fmha_v2/src/softmax_fp16.cu +++ b/cpp/kernels/fmha_v2/src/softmax_fp16.cu @@ -12,9 +12,10 @@ #include "softmax_impl.h" -void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi) +void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi) { - run_softmax(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f, - softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, + h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi); } diff --git a/cpp/kernels/fmha_v2/src/softmax_fp32.cu b/cpp/kernels/fmha_v2/src/softmax_fp32.cu index 2b3bb6acbb7..12bcd8624d9 100644 --- 
a/cpp/kernels/fmha_v2/src/softmax_fp32.cu +++ b/cpp/kernels/fmha_v2/src/softmax_fp32.cu @@ -12,9 +12,10 @@ #include "softmax_impl.h" -void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi) +void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi) { - run_softmax(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f, - softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, + b, h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi); } diff --git a/cpp/kernels/fmha_v2/src/softmax_fp8.cu b/cpp/kernels/fmha_v2/src/softmax_fp8.cu index 0a8e5f50299..26c2f5e88d7 100644 --- a/cpp/kernels/fmha_v2/src/softmax_fp8.cu +++ b/cpp/kernels/fmha_v2/src/softmax_fp8.cu @@ -12,10 +12,10 @@ #include "softmax_impl.h" -void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n, - bool has_alibi) +void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, + int warps_n, bool has_alibi) { - run_softmax(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, - scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, + b, h, 0.f, scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); } diff --git a/cpp/kernels/fmha_v2/src/softmax_impl.h b/cpp/kernels/fmha_v2/src/softmax_impl.h index 2bc9f3380be..ca652627442 100644 --- a/cpp/kernels/fmha_v2/src/softmax_impl.h +++ b/cpp/kernels/fmha_v2/src/softmax_impl.h @@ -10,6 +10,7 @@ * its affiliates is strictly prohibited. */ +#include #include #include #include @@ -33,6 +34,8 @@ struct Softmax_params Src_type const* src; // Masks. int8_t const* mask; + // Attention sinks (per head). + float const* attention_sinks; // Softmax sum pointer. float* softmax_sum; // ALiBi @@ -148,7 +151,8 @@ static inline __device__ float apply_exp_(float x, float max) //////////////////////////////////////////////////////////////////////////////////////////////////// template -static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&mask)[N][1], int warps_n, float& sum_fp32) +static inline __device__ void reduce( + float (&data_fp32)[N][1], const int8_t (&mask)[N][1], int warps_n, float& sum_fp32, float const attention_sink) { // Apply the masks. @@ -233,7 +237,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&ma } // Normalize. 
- float inv_sum_fp32 = 1.f / sum_fp32; + float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32)); #pragma unroll for (int ii = 0; ii < N; ++ii) { @@ -244,7 +248,8 @@ static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&ma //////////////////////////////////////////////////////////////////////////////////////////////////// template -static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&mask)[N][2], int warps_n, float& sum_fp32) +static inline __device__ void reduce( + float (&data_fp32)[N][2], const int8_t (&mask)[N][2], int warps_n, float& sum_fp32, float const attention_sink) { // Apply the masks. #pragma unroll @@ -401,7 +406,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&ma } // Normalize. - float inv_sum_fp32 = 1.f / sum_fp32; + float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32)); #pragma unroll for (int ii = 0; ii < N; ++ii) { @@ -413,7 +418,8 @@ static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&ma //////////////////////////////////////////////////////////////////////////////////////////////////// template -static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&mask)[N][4], int warps_n, float& sum_fp32) +static inline __device__ void reduce( + float (&data_fp32)[N][4], const int8_t (&mask)[N][4], int warps_n, float& sum_fp32, float const attention_sink) { // Apply the masks. @@ -824,7 +830,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&ma } // Normalize. - float inv_sum_fp32 = 1.f / sum_fp32; + float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32)); #pragma unroll for (int ii = 0; ii < N; ++ii) { @@ -994,9 +1000,16 @@ static __global__ void softmax_kernel(Softmax_params params) } } + // The attention sink value. + float attention_sink = -FLT_MAX; + if (params.attention_sinks != nullptr) + { + attention_sink = params.attention_sinks[hi]; + } + // Do the reduction. float sum_fp32 = 0.f; - reduce(data_fp32, mask_, params.warps_n, sum_fp32); + reduce(data_fp32, mask_, params.warps_n, sum_fp32, attention_sink); if (threadIdx.x == 0) { int sum_s = params.cu_q_seqlens[bi]; @@ -1025,9 +1038,9 @@ static __global__ void softmax_kernel(Softmax_params params) //////////////////////////////////////////////////////////////////////////////////////////////////// template -void run_softmax(void* dst, void const* src, void const* mask, void* softmax_sum, void* cu_q_seqlens, int s_inner, - int s_outer, int b, int h, float scale_bmm1, float scale_softmax, float softcapping_scale_bmm1, int warps_n, - bool has_alibi) +void run_softmax(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum, + void* cu_q_seqlens, int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax, + float softcapping_scale_bmm1, int warps_n, bool has_alibi) { Softmax_params params; @@ -1039,6 +1052,7 @@ void run_softmax(void* dst, void const* src, void const* mask, void* softmax_sum params.softmax_sum = reinterpret_cast(softmax_sum); params.cu_q_seqlens = reinterpret_cast(cu_q_seqlens); params.mask = reinterpret_cast(mask); + params.attention_sinks = reinterpret_cast(attention_sinks); params.has_alibi = has_alibi; // The dimensions and precomputed values. 
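For reference, a small standalone scalar model (not part of the patch) of the row normalization `reduce()` now performs; passing `-FLT_MAX` as the sink reproduces the previous behavior, which is exactly what `softmax_kernel` does when `params.attention_sinks` is null:

```cpp
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <vector>

// Scalar model of one softmax row with an attention sink: the per-head sink
// contributes one extra exp(sink - max) term to the denominator, nothing else.
std::vector<float> softmaxWithSink(std::vector<float> const& row, float attention_sink)
{
    float row_max = -FLT_MAX;
    for (float x : row)
    {
        row_max = std::max(row_max, x);
    }

    float sum = 0.f;
    std::vector<float> probs(row.size());
    for (size_t i = 0; i < row.size(); ++i)
    {
        probs[i] = std::exp(row[i] - row_max);
        sum += probs[i];
    }

    // attention_sink == -FLT_MAX makes the extra term vanish, recovering the
    // plain softmax (the kernel's default when no sinks are provided).
    float inv_sum = 1.f / (sum + std::exp(attention_sink - row_max));
    for (float& p : probs)
    {
        p *= inv_sum;
    }
    return probs;
}
```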
diff --git a/cpp/kernels/fmha_v2/src/softmax_int8.cu b/cpp/kernels/fmha_v2/src/softmax_int8.cu index 772fe1520ce..28701de9789 100644 --- a/cpp/kernels/fmha_v2/src/softmax_int8.cu +++ b/cpp/kernels/fmha_v2/src/softmax_int8.cu @@ -12,10 +12,10 @@ #include "softmax_impl.h" -void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax, float softcapping_scale_bmm1, - int warps_n, bool has_alibi) +void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax, + float softcapping_scale_bmm1, int warps_n, bool has_alibi) { - run_softmax(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, scale_bmm1, - scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, + scale_bmm1, scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); } diff --git a/cpp/kernels/xqa/mha.cu b/cpp/kernels/xqa/mha.cu index c9690cbc6b0..69d93e901c3 100644 --- a/cpp/kernels/xqa/mha.cu +++ b/cpp/kernels/xqa/mha.cu @@ -1379,6 +1379,19 @@ __device__ inline ThrdRegRowMax mergeRowMax( return mergedRowMax; } +__device__ inline void addAttentionSinks( + ThrdRegRowMax& globalRowSum, ThrdRegRowMax const globalRowMax, float const* attentionSinks) +{ + for (uint32_t i = 0; i < globalRowSum.size; i++) + { + uint32_t srcOffset = warp_size * i + laneId(); + if (srcOffset < headGrpSize) + { + globalRowSum[i] += expf(attentionSinks[srcOffset] - globalRowMax[i]); + } + } +} + #ifdef NDEBUG __device__ __forceinline__ #else @@ -1405,6 +1418,7 @@ CUBIN_EXPORT __global__ #if SPEC_DEC MaskType const* __restrict__ mask, // [qSeqLen, divUp(qSeqLen, 32)]. #endif + float const* attentionSinks, // [headGrpSize] #ifdef NDEBUG KVCacheList const& cacheList, #if BEAM_WIDTH > 1 @@ -2371,6 +2385,12 @@ CUBIN_EXPORT __global__ float voScale = (isKVCacheQuantized ? kvCacheScale[0] : 1.F); if (seqIterInit < nbSeqIters) { // otherwise rcpRowSum will be NAN. + // The attention sinks are moved to the multi-block reduction part if the multi-block is enabled. + if (!isMultiBlock && attentionSinks != nullptr) + { + // Attention sinks are per head. + addAttentionSinks(globalRowSum, globalRowMax, attentionSinks + headGrpSize * idxHeadGrp); + } ThrdRegRowMax const rcpRowSum = __frcp_rn(globalRowSum); #if LOW_PREC_OUTPUT voScale *= rcpOutScale[0]; @@ -2559,6 +2579,11 @@ CUBIN_EXPORT __global__ assert(std::isfinite(mergedRowSum[0])); } } + if (attentionSinks != nullptr) + { + // Attention sinks are per head. + addAttentionSinks(mergedRowSum, mergedRowMax, attentionSinks + headGrpSize * idxHeadGrp); + } __syncthreads(); rescaleAcc(warp, sumAcc, fullRescaleMask, __frcp_rn(mergedRowSum)); GemmOutRegTile const mergedOutTile = toFp16(sumAcc); @@ -2615,6 +2640,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha( MaskType const* __restrict__ mask, // [qSeqLen, divUp(qSeqLen, 32))] uint2 (each bit represents mask for one col // position). 
#endif + float const* attentionSinks, // [headGrpSize] KVCacheList const cacheList, #if BEAM_WIDTH > 1 BeamSearchParams const beamSearchParams, @@ -2640,7 +2666,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha( #if SPEC_DEC mask, #endif - cacheList, + attentionSinks, cacheList, #if BEAM_WIDTH > 1 beamSearchParams, #endif @@ -2667,6 +2693,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #else InputHead const* q, #endif + float const* attentionSinks, // [headGrpSize] #if USE_PAGED_KV_CACHE #if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, @@ -2760,7 +2787,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #if SPEC_DEC mask, #endif - cacheList, + attentionSinks, cacheList, #if BEAM_WIDTH > 1 beamSearchParams, #endif @@ -2788,7 +2815,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #if SPEC_DEC mask, #endif - cacheList, + attentionSinks, cacheList, #if BEAM_WIDTH > 1 beamSearchParams, #endif diff --git a/cpp/kernels/xqa/mha.h b/cpp/kernels/xqa/mha.h index 39c94f985ec..d35ad48104a 100644 --- a/cpp/kernels/xqa/mha.h +++ b/cpp/kernels/xqa/mha.h @@ -101,6 +101,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads, #else InputHead const* q, #endif + float const* attentionSinks, // [headGrpSize] #if USE_PAGED_KV_CACHE #if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, @@ -140,6 +141,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #else InputHead const* q, #endif + float const* attentionSinks, // [headGrpSize] #if USE_PAGED_KV_CACHE #if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, diff --git a/cpp/kernels/xqa/mha_sm90.cu b/cpp/kernels/xqa/mha_sm90.cu index 88d4c75e30b..9a438df9a2a 100644 --- a/cpp/kernels/xqa/mha_sm90.cu +++ b/cpp/kernels/xqa/mha_sm90.cu @@ -428,6 +428,7 @@ __device__ RegColWiseVec computeWarpColSum(Gemm0Acc& src); __device__ void storeGemm0AccToShm( uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc); __device__ RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smemVec); +__device__ RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound); #else __device__ RegRowWiseVec computeWarpGrpRowMax_sync(uint32_t warpRank, ShmQWiseVec& smemColMax, Gemm0Acc const& src); __device__ void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uint32_t validColEnd); @@ -453,7 +454,8 @@ __device__ void rescaleGemm1AccForNewColMax_sync(uint32_t warpRank, ShmQWiseVec template __device__ void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar, - ShmQWiseVec const& accColSum, uint32_t nbKHeads = 0 /* only for final result in spec dec. */); + ShmQWiseVec const& accColSum, ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec, + uint32_t nbKHeads = 0 /* only for final result in spec dec. 
*/); #else __device__ void transposeVTile( uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, SharedMem::VBuffer const& src); @@ -651,6 +653,7 @@ CUBIN_EXPORT __global__ #else IOHead const* __restrict__ const q, // [nbReq][beamWidth][nbQHeads], #endif + float const* attentionSinks, // [headGrpSize] KVCacheList const cacheList, #if USE_BEAM_SEARCH BeamSearchParams const beamSearchParams, @@ -1252,7 +1255,7 @@ CUBIN_EXPORT __global__ IOHead* const dst = (scratchMem.tokens() + idxChunk).template cast(); #if SWAP_AB finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, - smem.gemm1WarpGrpBar, smem.gemm1AccColSum); + smem.gemm1WarpGrpBar, smem.gemm1AccColSum, smem.gemm1AccColMax, nullptr); #else finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, smem.gemm1AccColSum, 1, ctaNbValidTokens); @@ -1262,9 +1265,16 @@ CUBIN_EXPORT __global__ { uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp); OutputHead* const dst = &output[outOffset]; + ShmQWiseVec const* attentionSinksVec = nullptr; + if (attentionSinks != nullptr) + { + attentionSinksVec + = reinterpret_cast(attentionSinks + headGrpSize * idxHeadGrp); + } #if SWAP_AB finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, - xvoScale, smem.gemm1WarpGrpBar, smem.gemm1AccColSum, nbKHeads); + xvoScale, smem.gemm1WarpGrpBar, smem.gemm1AccColSum, smem.gemm1AccColMax, attentionSinksVec, + nbKHeads); #else finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, smem.gemm1AccColSum, nbKHeads, ctaNbValidTokens); @@ -1585,6 +1595,17 @@ CUBIN_EXPORT __global__ } unused(bar.consumed.arrive()); } + // Add the attention sinks. + if (attentionSinks != nullptr) + { + for (uint32_t i = 0; i < headsPerWarp; i++) + { + uint32_t const idxHead = wid + nbMathWarps * i; + float sink = expf( + attentionSinks[mha::min(idxHead, headGrpSize - 1) + idxHeadGrp * headGrpSize] - states[i].max); + states[i].sum += sink; + } + } __syncthreads(); uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp); auto const dst = &output[outOffset]; @@ -2029,6 +2050,22 @@ __device__ inline RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smem return ret; } +__device__ inline RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound) +{ + RegColWiseVec ret; + constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols); + auto const idx = laneId() % nbThrdsPerInstNBase; +#pragma unroll + for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) + { + static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)); + ret[i] = reinterpret_cast< + Vec, exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>( + gmemVec)[mha::min(i * nbThrdsPerInstNBase + idx, bound)]; + } + return ret; +} + __device__ inline void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd) { uint32_t const idxInQuad = laneId() % 4; @@ -2878,12 +2915,19 @@ __device__ inline void saveTransposedOutput(uint32_t threadRank, uint32_t warpRa template __device__ inline void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar, - ShmQWiseVec const& accColSum, uint32_t nbKHeads) + ShmQWiseVec const& accColSum, ShmQWiseVec 
const& accColMax, ShmQWiseVec const* attentionSinksVec, uint32_t nbKHeads) { // @fixme: if ctaNbQHeads is large, use loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of mufu.rcp // static_assert(ctaNbQHeads <= 8, "Warning: consider using loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of // mufu.rcp"); - auto const regColSum = loadShmColWiseVecWithDup(accColSum); + auto regColSum = loadShmColWiseVecWithDup(accColSum); + if (attentionSinksVec != nullptr) + { + auto const regAccColMax = loadShmColWiseVecWithDup(accColMax); + auto const regAttentionSinks = loadGmemColWiseVecWithDup(attentionSinksVec[0], headGrpSize - 1); + auto regColSinks = expf(regAttentionSinks - regAccColMax); + regColSum = regColSum + regColSinks; + } auto const regOutScale = __frcp_rn(regColSum) * xvoScale; rescaleAcc(acc, regOutScale); @@ -3175,6 +3219,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #else InputHead const* q, #endif + float const* attentionSinks, // [headGrpSize] #if USE_PAGED_KV_CACHE #if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, @@ -3286,7 +3331,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #else q, #endif - cacheList, + attentionSinks, cacheList, #if USE_BEAM_SEARCH beamSearchParams, #endif @@ -3322,7 +3367,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #else q, #endif - cacheList, + attentionSinks, cacheList, #if USE_BEAM_SEARCH beamSearchParams, #endif diff --git a/cpp/kernels/xqa/mla_sm120.cu b/cpp/kernels/xqa/mla_sm120.cu index 74877512a7d..072908fe3e8 100644 --- a/cpp/kernels/xqa/mla_sm120.cu +++ b/cpp/kernels/xqa/mla_sm120.cu @@ -1859,12 +1859,13 @@ CUtensorMap makeTensorMapForQ( #endif // IS_MLA void launchMLA(cudaDeviceProp const& prop, - uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed + uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed float qScale, OutputHead* output, InputHead const* q, + float* attentionSinks, // [headGrpSize], not supported. #if USE_PAGED_KV_CACHE - GMemCacheHead* pool, // global pool of pages + GMemCacheHead* pool, // global pool of pages KVCachePageIndex const* - kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] + kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] #else GMemKVCacheHead* kvCacheData, #endif diff --git a/cpp/kernels/xqa/test/refAttention.cpp b/cpp/kernels/xqa/test/refAttention.cpp index d8f1a688f5d..dd356c101c0 100644 --- a/cpp/kernels/xqa/test/refAttention.cpp +++ b/cpp/kernels/xqa/test/refAttention.cpp @@ -45,7 +45,7 @@ using Vector = Matrix; template Eigen::Matrix refFlashAttention(IOHead const* q, CacheSeq const& k, CacheSeq const& v, uint32_t seqLen, float qScale, - float kvScale, float xScale, uint32_t slidingWinSize) + float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks) { uint32_t const nbTiles = divUp(seqLen, tileSize); auto gemm1Acc = Eigen::Matrix::Zero().eval(); @@ -113,6 +113,16 @@ Eigen::Matrix refFlashAt } rowSum += tileRowSum; } + + // Add the attention sinks. 
+ if (attentionSinks != nullptr) + { + for (uint32_t i = 0; i < headGrpSize; i++) + { + rowSum[i] += expf(attentionSinks[i] - rowMax[i]); + } + } + Eigen::Matrix out = gemm1Acc.array().colwise() * (xScale * kvScale / rowSum.array()); std::for_each(out.data(), out.data() + out.size(), [](float& e) { e = float(OutputElem(e)); }); @@ -123,7 +133,7 @@ Eigen::Matrix refFlashAt template Eigen::Matrix \ refFlashAttention(IOHead const* q, \ CacheSeq const& k, CacheSeq const& v, uint32_t seqLen, \ - float qScale, float kvScale, float xScale, uint32_t slidingWinSize) + float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks) INSTANTIATE_refFlashAttention(CacheElem, 64, false, false); INSTANTIATE_refFlashAttention(CacheElem, 64, false, true); @@ -143,7 +153,7 @@ Eigen::Matrix refAttenti #else Eigen::Matrix refAttention(IOHead const* q, CacheSeq const& k, CacheSeq const& v, uint32_t seqLen, float qScale, - float kvScale, float xScale, uint32_t slidingWinSize) + float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks) { #endif float const rcpXScale = 1.f / xScale; @@ -184,7 +194,7 @@ Eigen::Matrix refAttenti Eigen::Matrix x = (gemm0Acc.colwise() - rowMax).array().exp().eval(); - Eigen::Vector const rowSum = x.rowwise().sum().eval(); + Eigen::Vector rowSum = x.rowwise().sum().eval(); std::for_each(x.data(), x.data() + x.size(), [&](float& e) { e = float(MathElem(e * rcpXScale)); }); @@ -200,6 +210,18 @@ Eigen::Matrix refAttenti } } } + + // Add the attention sinks. +#if !SPEC_DEC + if (attentionSinks != nullptr) + { + for (uint32_t i = 0; i < headGrpSize; i++) + { + rowSum[i] += expf(attentionSinks[i] - rowMax[i]); + } + } +#endif + Eigen::Matrix out = gemm1Acc.array().colwise() * (xScale * kvScale / rowSum.array()); std::for_each(out.data(), out.data() + out.size(), [](float& e) { e = float(OutputElem(e)); }); @@ -217,7 +239,7 @@ Eigen::Matrix refAttenti template Eigen::Matrix \ refAttention(IOHead const* q, CacheSeq const& k, \ CacheSeq const& v, uint32_t seqLen, float qScale, float kvScale, float xScale, \ - uint32_t slidingWinSize) + uint32_t slidingWinSize, float* attentionSinks) #endif INSTANTIATE_refAttention(InputElem, false, false); INSTANTIATE_refAttention(InputElem, false, true); diff --git a/cpp/kernels/xqa/test/refAttention.h b/cpp/kernels/xqa/test/refAttention.h index bfab1418294..a073ed0e801 100644 --- a/cpp/kernels/xqa/test/refAttention.h +++ b/cpp/kernels/xqa/test/refAttention.h @@ -83,7 +83,7 @@ struct CacheSeq template Eigen::Matrix refFlashAttention(IOHead const* q, CacheSeq const& k, CacheSeq const& v, uint32_t seqLen, float qScale, - float kvScale, float xScale, uint32_t slidingWinSize); + float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks); template #if SPEC_DEC @@ -93,7 +93,7 @@ Eigen::Matrix refAttenti #else Eigen::Matrix refAttention(IOHead const* q, CacheSeq const& k, CacheSeq const& v, uint32_t seqLen, float qScale, - float kvScale, float xScale, uint32_t slidingWinSize); + float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks); #endif template diff --git a/cpp/kernels/xqa/test/test.cpp b/cpp/kernels/xqa/test/test.cpp index b9228578623..91b35f3e1a4 100644 --- a/cpp/kernels/xqa/test/test.cpp +++ b/cpp/kernels/xqa/test/test.cpp @@ -130,7 +130,7 @@ template #endif #endif void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, bool verbose = false, - bool saveData = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30) + bool saveData = false, 
bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30) { #if IS_MLA if (nbKHeads != 1) @@ -613,6 +613,17 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, } } + // Allocate the attention sinks (per head) + auto attentionSinks = ManagedMemBuf(nbQHeads); + // The attention sinks ptr. + float* attentionSinksPtr = hasAttentionSinks ? reinterpret_cast(attentionSinks.get()) : nullptr; + // Initialize the attention sinks (use large values to detect the potential bugs). + for (uint32_t i = 0; i < nbQHeads; i++) + { + // Range: [2, 5] + attentionSinks.get()[i] = 2.f + float(i % 4); + } + if (verbose) { printf("migrating data to gpu\n"); @@ -640,6 +651,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, #if BEAM_WIDTH > 1 cacheIndir.prefetch(dev, stream); #endif + attentionSinks.prefetch(dev, stream); }; prefetchToDevice(device); checkCuda(cudaMemsetAsync(semaphores.get(), 0, 4 * nbSemaphores, stream)); @@ -720,6 +732,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, &qHeads[0][0][0], #endif #endif + attentionSinksPtr, #if PAGED_KV_CACHE_LAYOUT == 1 && USE_PAGED_KV_CACHE cacheKHeads.get(), cacheVHeads.get(), #else @@ -1028,10 +1041,13 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, hostMask, qSeqLen, q_len); #else Eigen::Matrix refOutput; + auto const refAttentionSinks + = hasAttentionSinks ? attentionSinksPtr + headGrpSize * idxKHead : nullptr; if (useQGMMA) { refOutput = refFlashAttention(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq, - vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize); + vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize, + refAttentionSinks); // refOutput = refAttention(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq, // vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize); } @@ -1039,8 +1055,9 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, { // refOutput = refFlashAttention(&qHeads[req][b][headGrpSize * idxKHead], // kCacheSeq, vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale); - refOutput = refAttention(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq, - vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize); + refOutput + = refAttention(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq, vCacheSeq, + seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize, refAttentionSinks); } #endif if (lowPrecOutput) @@ -1196,11 +1213,23 @@ TEST(RefCheck, llama_V2_70b) runTest<2>(2, 514, false, true); runTest<1>(1, 4096, false, true); #if SLIDING_WINDOW - runTest<2>(2, 4096, false, true, false, false, ~0, 256); - runTest<2>(2, 400, false, true, false, false, ~0U, 256); + runTest<2>(2, 4096, false, true, false, false, false, ~0, 256); + runTest<2>(2, 400, false, true, false, false, false, ~0U, 256); #endif runTest<8>(120, 367, false, true); - // runTest<8>(1792, 2048, false, true); + runTest<8>(1792, 2048, false, true); +} + +TEST(RefCheck, attention_sinks) +{ + auto runAttentionSinksTest = [](uint32_t batchSize, uint32_t seqLen) + { runTest<8>(batchSize, seqLen, false, true, false, false, /*hasAttentionSinks*/ true); }; + + runAttentionSinksTest(2, 2); + runAttentionSinksTest(2, 15); + runAttentionSinksTest(2, 256); + runAttentionSinksTest(2, 514); + runAttentionSinksTest(1, 4096); } TEST(Perf, tracing_long) @@ -1264,7 +1293,7 @@ TEST(Perf, mlperf_gptj) #ifndef NDEBUG 
GTEST_SKIP() << "Skipping perf tests for debug build"; #endif - runTest<32>(396, 800 + 224, true, false, false, false, 800); + runTest<32>(396, 800 + 224, true, false, false, false, false, 800); } TEST(Perf, mlperf_llama) diff --git a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h index 565c170e1df..2559ae54840 100644 --- a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h +++ b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h @@ -53,6 +53,7 @@ using namespace CUTLASS_MOE_GEMM_KERNELS_NAMESPACE; using CUTLASS_MOE_GEMM_NAMESPACE::TmaWarpSpecializedGroupedGemmInput; using CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::CutlassMoeFCRunner; using CUTLASS_MOE_GEMM_NAMESPACE::ActivationType; +using CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::ActivationParams; using CUTLASS_MOE_GEMM_NAMESPACE::isGatedActivation; static BufferManager::CudaStreamPtr streamPtr; @@ -980,11 +981,11 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture auto stream = streamPtr->get(); MoeMinLatencyParams min_latency_params; #ifdef USING_OSS_CUTLASS_MOE_GEMM - mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, + mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, true, mSelectedExperts + mSelectedExpertsSize * mBufferIndex, mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr, mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex, - mActType, mExpertWeight2 + mExpertWeight2Size * mBufferIndex, + ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex, mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex, mFinalOutput + mFinalOutputSize * mBufferIndex, @@ -992,11 +993,11 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture /*enable_alltoall=*/false, mUseLora, mLoraParams[mBufferIndex], /*use_fp8_block_scaling=*/false, /*min_latency_mode=*/false, min_latency_params, stream); #else - mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, + mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, true, mSelectedExperts + mSelectedExpertsSize * mBufferIndex, mUseFinalScale ? 
mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr, mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex, - mActType, mExpertWeight2 + mExpertWeight2Size * mBufferIndex, + ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex, mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex, mFinalOutput + mFinalOutputSize * mBufferIndex, diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index d95ca1b412b..503c2e6c5d0 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -75,6 +75,7 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques bool CacheFormatter::needSendCache( CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx) { + // int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism; auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx); if (targetInfo.mDupHeadFactor <= 1) { @@ -89,9 +90,8 @@ bool CacheFormatter::needSendCache( = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize; selfTpRankInDpGroup = selfTpRank % selfTPNumInDPGroup; } - int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0; - return (destDPRank % targetInfo.mDupHeadFactor) == (selfTpRankInDpGroup % targetInfo.mDupHeadFactor); + return selfTpRankInDpGroup % targetInfo.mDupHeadFactor == 0; } void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::CacheState const& selfConfig, @@ -128,12 +128,11 @@ std::vector CacheFormatter::pickRecvConnections( return ret; } TLLM_CHECK(numConnections == targetInfo.mIRanks.size()); - int selfDPRank = selfConfig.getParallelConfig().mEnableAttentionDP ? 
selfConfig.getParallelConfig().mDPrank : 0; std::vector ret; for (int i = 0; i < targetInfo.mDomainTPSize; i++) { - if ((i % targetInfo.mPeerDupHeadFactor) == (selfDPRank % targetInfo.mPeerDupHeadFactor)) + if (i % targetInfo.mPeerDupHeadFactor == 0) { for (int j = 0; j < targetInfo.mDomainPPSize; j++) { @@ -361,7 +360,7 @@ void CacheFormatter::format(TransferSession& session) } double cacheTransferTime = std::max(0.0, std::chrono::duration(endTime - startTime).count()); - kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, delay, cacheTransferTime, size); + session.appendMeasure(delay, cacheTransferTime, size); }; if (connections.size() > 1) @@ -713,7 +712,7 @@ void CacheFormatter::unformat(TransferSession& session) } double cacheTransferTime = std::max(0.0, std::chrono::duration(endTime - startTime).count()); - kvCacheMeasureHelper.appendKVCacheTransfer(ctxReqId, delay, cacheTransferTime, size); + session.appendMeasure(delay, cacheTransferTime, size); }; if (pickUpConnections.size() > 1) { @@ -847,16 +846,23 @@ void CacheFormatter::unformat(TransferSession& session) } int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(); int selfPPSize = selfConfig.getParallelConfig().mPipelineParallelism; + int destPPSize = destConfig.getParallelConfig().mPipelineParallelism; + int destNumLayers = destConfig.getModelConfig().mNbKvHeadsPerLayer.size(); + + if (selfPPSize == destPPSize) + { + return true; + } if (selfNumLayers % selfPPSize != 0) { - TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers must be divisible by pipeline parallelism"); + TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers %d must be divisible by pipeline parallelism :%d", + selfNumLayers, selfPPSize); return false; } - int destNumLayers = destConfig.getModelConfig().mNbKvHeadsPerLayer.size(); - int destPPSize = destConfig.getParallelConfig().mPipelineParallelism; if (destNumLayers % destPPSize != 0) { - TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers must be divisible by pipeline parallelism"); + TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers %d must be divisible by pipeline parallelism :%d ", + destNumLayers, destPPSize); return false; } return true; diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h index ee199c2fb1c..8ae8ee5f2ca 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h @@ -76,15 +76,6 @@ class BaseCacheFormatter /// @brief Destructor. virtual ~BaseCacheFormatter() = default; - - // TODO: better way for context/generation tagging - void markAsSender(bool isSender) - { - kvCacheMeasureHelper.markAsSender(isSender); - } - -protected: - KvCacheMeasureHelper kvCacheMeasureHelper{common::getEnvKVCacheTransferOutputPath()}; }; // Simple cache block copy. 
Because it does not involve data splitting or merging, it performs best when the diff --git a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp index 93df2f96ec0..16771709bb4 100644 --- a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp +++ b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp @@ -44,14 +44,16 @@ namespace tensorrt_llm::batch_manager using SizeType32 = CreateNewDecoderRequests::SizeType32; using TensorPtr = CreateNewDecoderRequests::TensorPtr; +using SharedConstPtr = CreateNewDecoderRequests::SharedConstPtr; namespace { void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffers& inputBuffers, - ITensor& sequenceLengths, SizeType32 beamWidth, runtime::BufferManager const& manager, - runtime::CudaStream const& stream) + ITensor& sequenceLengths, SizeType32 beamWidth, runtime::CudaStream const& stream) { + auto const bufferManager = BufferManager{std::make_shared(stream.get())}; + auto const batchSize = contextRequests.size(); auto batchSlotsView = tr::ITensor::slice(inputBuffers.setupBatchSlots, 0, batchSize); auto fillValuesView = tr::ITensor::slice(inputBuffers.fillValues, 0, batchSize); @@ -79,8 +81,8 @@ void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffe auto batchSlotsDeviceView = tr::ITensor::slice(inputBuffers.setupBatchSlotsDevice, 0, batchSize); auto fillValuesViewDevice = tr::ITensor::slice(inputBuffers.fillValuesDevice, 0, batchSize); - manager.copy(*batchSlotsView, *batchSlotsDeviceView); - manager.copy(*fillValuesView, *fillValuesViewDevice); + bufferManager.copy(*batchSlotsView, *batchSlotsDeviceView); + bufferManager.copy(*fillValuesView, *fillValuesViewDevice); tr::kernels::invokeFillBatch(sequenceLengths, *batchSlotsDeviceView, beamWidth, *fillValuesViewDevice, stream); } } @@ -127,10 +129,10 @@ void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffe std::tuple, std::vector, std::vector> CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, - executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests, - runtime::BufferManager const& bufferManager, nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers, - runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream, - SizeType32 maxSequenceLength, SizeType32 beamWidth, OptionalRef medusaBuffers) const + executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests, nvinfer1::DataType logitsType, + DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, + CudaStream const& decoderStream, SizeType32 maxSequenceLength, SizeType32 beamWidth, + OptionalRef medusaBuffers) const { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(CreateNewDecoderRequests); @@ -141,13 +143,13 @@ CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, ru if (!finishedContextRequests.empty()) { - copySequenceLengths(finishedContextRequests, inputBuffers, *decoderState.getSequenceLengths(), beamWidth, - bufferManager, runtimeStream); + copySequenceLengths( + finishedContextRequests, inputBuffers, *decoderState.getSequenceLengths(), beamWidth, runtimeStream); } - auto [lookaheadPrompt, lookaheadAlgoConfigs] = createDecoderRequests(finishedContextRequests, - inputBuffers.inputsIds, decodingConfig, 
decoderState, bufferManager, logitsType, modelConfig, worldConfig, - runtimeStream, decoderStream, maxSequenceLength, medusaBuffers); + auto [lookaheadPrompt, lookaheadAlgoConfigs] + = createDecoderRequests(finishedContextRequests, inputBuffers.inputsIds, decodingConfig, decoderState, + logitsType, modelConfig, worldConfig, runtimeStream, decoderStream, maxSequenceLength, medusaBuffers); auto const batchSize = finishedContextRequests.size(); @@ -165,115 +167,122 @@ CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, ru std::move(lookaheadAlgoConfigs)}; } -void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder_batch::Request const& request, - SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig, - runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream, - SizeType32 maxSequenceLength) +namespace { - TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - - TLLM_CHECK(batchSlot >= 0); - BufferManager manager{std::make_shared(decoderStream.get())}; - - auto const batchSize = decoderState.getMaxBatchSize(); - TLLM_CHECK(0 <= batchSize && batchSlot < batchSize); - auto const maxBeamWidth = decoderState.getMaxBeamWidth(); - auto const beamWidth = samplingConfig.beamWidth; - TLLM_CHECK_WITH_INFO(beamWidth <= maxBeamWidth, - tc::fmtstr("Beam width (%d) must be smaller than maxBeamWidth (%d) passed to decoder setup function.", - beamWidth, maxBeamWidth)); - auto const& requestIds = request.ids; - auto const inputLength = request.inputLen; - auto const numDecodingEngineTokens = request.generatedTokensPerEngineStep; +void initializeInputLengths(DecodingInput& dJointInput, SizeType32 batchSlot, SizeType32 inputLength, + std::optional maxNewTokensOpt, SizeType32 numDecodingEngineTokens, SizeType32 maxSequenceLength, + BufferManager const& manager) +{ auto const numDecodingDraftEngineTokens = numDecodingEngineTokens - 1; - auto const maxNewTokens - = request.maxNewTokens.value_or(maxSequenceLength - inputLength - numDecodingDraftEngineTokens); + auto const maxNewTokens = maxNewTokensOpt.value_or(maxSequenceLength - inputLength - numDecodingDraftEngineTokens); TLLM_CHECK_WITH_INFO(inputLength + maxNewTokens + numDecodingDraftEngineTokens <= maxSequenceLength, tc::fmtstr( "Input length (%d) + max new tokens (%d) + draft tokens (%d) must be less than max sequence length (%d).", inputLength, maxNewTokens, numDecodingDraftEngineTokens, maxSequenceLength)); - TLLM_CHECK(requestIds->getDataType() == TRTDataType::value); - auto const endId = request.endId.value_or(-1); - // input - auto& dJointInput = decoderState.getJointDecodingInput(); + TensorPtr const sequenceLimitLength{ + ITensor::slice(constPointerCast(dJointInput.sequenceLimitLength), batchSlot, 1)}; + runtime::kernels::invokeFill(*sequenceLimitLength, inputLength + maxNewTokens, manager.getStream()); - dJointInput.beamWidths.at(batchSlot) = beamWidth; - decoderState.setNumDecodingEngineTokens(batchSlot, numDecodingEngineTokens); + TensorPtr const inputLengths{ITensor::slice(constPointerCast(dJointInput.lengths), batchSlot, 1)}; + runtime::kernels::invokeFill(*inputLengths, inputLength, manager.getStream()); +} +void initializeRequestIds(DecodingInput& dJointInput, DecodingOutput& dJointOutput, SizeType32 batchSlot, + SharedConstPtr const& requestIds, SizeType32 endId, SizeType32 beamWidth, SizeType32 maxSequenceLength, + BufferManager const& manager) +{ TensorPtr const 
endIdTensorPtr{ITensor::slice(constPointerCast(dJointInput.endIds), batchSlot, 1)}; - runtime::kernels::invokeFill(*endIdTensorPtr, endId, decoderStream); + runtime::kernels::invokeFill(*endIdTensorPtr, endId, manager.getStream()); + // fill outputIds with endIds + TensorPtr const outputIds = ITensor::slice(dJointOutput.ids, batchSlot, 1); + auto outputIdsTileView = ITensor::view(outputIds, ITensor::makeShape({beamWidth, maxSequenceLength})); + runtime::kernels::invokeFill(*outputIdsTileView, endId, manager.getStream()); + + // copy the request ids into outputIds + auto const requestIdsShape = requestIds->getShape(); + auto outputIdsView = ITensor::view(outputIds, requestIdsShape); + manager.copy(*requestIds, *outputIdsView); +} + +void initializeBeamSearch(DecodingInput& dJointInput, DecodingOutput& dJointOutput, SizeType32 batchSlot, + SizeType32 endId, SizeType32 beamWidth, SizeType32 maxSequenceLength, BufferManager const& manager) +{ + TensorPtr const cumLogProbs = ITensor::slice(dJointOutput.cumLogProbs, batchSlot, 1); + runtime::kernels::invokeFill( + *IBuffer::slice(cumLogProbs, 1, beamWidth - 1), DecodingOutput::kNegativeInfinity, manager.getStream()); + + auto parentIds = ITensor::slice(dJointOutput.parentIds, batchSlot, 1); + auto const outputIdsShape = ITensor::makeShape({1, beamWidth, maxSequenceLength}); + parentIds->reshape(outputIdsShape); + manager.setZero(*parentIds); + + auto cacheIndirectionInput = ITensor::slice(dJointInput.cacheIndirection, batchSlot, 1); + manager.setZero(*cacheIndirectionInput); + + auto cacheIndirectionOutput = ITensor::slice(dJointOutput.cacheIndirection, batchSlot, 1); + manager.setZero(*cacheIndirectionOutput); + + auto beamHypotheses = dJointOutput.beamHypotheses.slice(batchSlot, 1); + beamHypotheses.init(manager, endId); +} + +void initializeEmbeddingBias(DecodingInput& dJointInput, SizeType32 batchSlot, + std::optional const& embeddingBias, nvinfer1::DataType logitsType, + runtime::ModelConfig const& modelConfig, BufferManager const& manager) +{ TensorPtr const embeddingBiasSlice = ITensor::slice(constPointerCast(dJointInput.embeddingBias), batchSlot, 1); - if (request.embeddingBias) + if (embeddingBias.has_value()) { - TLLM_CHECK(request.embeddingBias->getShape().nbDims == 2); - TLLM_CHECK(request.embeddingBias->getShape().d[0] == 1); - TLLM_CHECK_WITH_INFO(request.embeddingBias->getShape().d[1] == modelConfig.getVocabSize(), + auto embeddingBiasTensor = getEmbeddingBias(logitsType, embeddingBias.value()); + + TLLM_CHECK(embeddingBiasTensor->getShape().nbDims == 2); + TLLM_CHECK(embeddingBiasTensor->getShape().d[0] == 1); + TLLM_CHECK_WITH_INFO(embeddingBiasTensor->getShape().d[1] == modelConfig.getVocabSize(), "The embedding bias shape is not as expected. 
Expected last dimension to be same as vocab size: %d.", modelConfig.getVocabSize()); - manager.copy(*request.embeddingBias, *embeddingBiasSlice); + manager.copy(*embeddingBiasTensor, *embeddingBiasSlice); } else { manager.setZero(*embeddingBiasSlice); } +} - auto setupWords = [](std::vector& jointWordsLists, TensorPtr const& requestWordsList, - SharedConstPtr& jointWordsPtrs, SharedConstPtr& jointWordsLens, SizeType32& jointMaxWordsLen, - SizeType32 batchSlot) +void setupWords(std::vector& jointWordsLists, + std::optional const& requestWordsList, SharedConstPtr& jointWordsPtrs, SharedConstPtr& jointWordsLens, + SizeType32& jointMaxWordsLen, SizeType32 batchSlot, BufferManager const& manager) +{ + if (requestWordsList.has_value()) { - if (requestWordsList) - { - auto const wordsLen = requestWordsList->getShape().d[1]; - BufferRange(*constPointerCast(jointWordsPtrs))[batchSlot] - = runtime::bufferCast(*requestWordsList); - runtime::bufferCast(*constPointerCast(jointWordsLens))[batchSlot] = wordsLen; - // FIXME: this is monotonically growing size - jointMaxWordsLen = std::max(static_cast(wordsLen), jointMaxWordsLen); - - // NOTE: jointWordsList is not used in gptDecoder, but required to keep WordsList's - // memory allocated - jointWordsLists[batchSlot] = requestWordsList; - } - else - { - runtime::bufferCast(*constPointerCast(jointWordsLens))[batchSlot] = 0; - } - }; - - setupWords(dJointInput.stopWordsLists, request.stopWordsList, dJointInput.stopWordsPtrs, dJointInput.stopWordsLens, - dJointInput.maxStopWordsLen, batchSlot); - - setupWords(dJointInput.badWordsLists, request.badWordsList, dJointInput.badWordsPtrs, dJointInput.badWordsLens, - dJointInput.maxBadWordsLen, batchSlot); - - TensorPtr const sequenceLimitLength{ - ITensor::slice(constPointerCast(dJointInput.sequenceLimitLength), batchSlot, 1)}; - runtime::kernels::invokeFill(*sequenceLimitLength, inputLength + maxNewTokens, decoderStream); - - TensorPtr const inputLengths{ITensor::slice(constPointerCast(dJointInput.lengths), batchSlot, 1)}; - runtime::kernels::invokeFill(*inputLengths, inputLength, decoderStream); - - // output - auto& dJointOutput = decoderState.getJointDecodingOutput(); - auto const outputIdsShape = ITensor::makeShape({1, beamWidth, maxSequenceLength}); - - auto finishedSum = ITensor::slice(dJointOutput.finishedSum, batchSlot, 1); - manager.setZero(*finishedSum); - - for (SizeType32 ti = 0; ti < decoderState.getMaxDecodingEngineTokens(); ++ti) + // Move to GPU and remove leading bs1 dimension since this is what decoderRequest expects + TensorPtr wordsList = manager.copyFrom(*requestWordsList.value(), MemoryType::kGPU); + wordsList->squeeze(0); + + auto const wordsLen = wordsList->getShape().d[1]; + BufferRange(*constPointerCast(jointWordsPtrs))[batchSlot] + = runtime::bufferCast(*wordsList); + runtime::bufferCast(*constPointerCast(jointWordsLens))[batchSlot] = wordsLen; + // FIXME: this is monotonically growing size + jointMaxWordsLen = std::max(static_cast(wordsLen), jointMaxWordsLen); + + // NOTE: jointWordsList is not used in gptDecoder, but required to keep WordsList's + // memory allocated + jointWordsLists[batchSlot] = wordsList; + } + else { - TensorPtr const newTokensStepView = ITensor::slice(dJointOutput.newTokensSteps, ti, 1); - newTokensStepView->squeeze(0); - auto newTokensVec = ITensor::slice(newTokensStepView, batchSlot, 1); - manager.setZero(*newTokensVec); + runtime::bufferCast(*constPointerCast(jointWordsLens))[batchSlot] = 0; } +}; - TensorPtr const finishedStepsSlice = 
ITensor::slice(decoderState.getFinishReasons(), batchSlot, 1); - manager.setZero(*finishedStepsSlice); +void initializeLogProbs(DecodingOutput& dJointOutput, SizeType32 batchSlot, SamplingConfig const& samplingConfig, + BufferManager const& manager) +{ + auto const beamWidth = samplingConfig.beamWidth; // cumLogProb is mandatory for beamWidth > 1 if ((samplingConfig.cumLogProbs.has_value() && samplingConfig.cumLogProbs->at(0)) || beamWidth > 1) @@ -287,49 +296,32 @@ void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder auto logProbs = ITensor::slice(dJointOutput.logProbs, batchSlot, 1); manager.setZero(*logProbs); } +} - if (beamWidth > 1) - { - TensorPtr const cumLogProbs = ITensor::slice(dJointOutput.cumLogProbs, batchSlot, 1); - runtime::kernels::invokeFill( - *IBuffer::slice(cumLogProbs, 1, beamWidth - 1), DecodingOutput::kNegativeInfinity, decoderStream); - - auto parentIds = ITensor::slice(dJointOutput.parentIds, batchSlot, 1); - parentIds->reshape(outputIdsShape); - manager.setZero(*parentIds); - - auto cacheIndirectionInput = ITensor::slice(dJointInput.cacheIndirection, batchSlot, 1); - manager.setZero(*cacheIndirectionInput); - - auto cacheIndirectionOutput = ITensor::slice(dJointOutput.cacheIndirection, batchSlot, 1); - manager.setZero(*cacheIndirectionOutput); +void initializeOutputs(DecodingOutput& dJointOutput, SizeType32 batchSlot, SizeType32 maxDecodingEngineTokens, + BufferManager const& manager) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - auto beamHypotheses = dJointOutput.beamHypotheses.slice(batchSlot, 1); - beamHypotheses.init(manager, endId); - } + auto finishedSum = ITensor::slice(dJointOutput.finishedSum, batchSlot, 1); + manager.setZero(*finishedSum); - // Speculative execution - if (numDecodingEngineTokens > 1 || decoderState.getSpeculativeDecodingMode().isDraftTokensExternal()) + for (SizeType32 ti = 0; ti < maxDecodingEngineTokens; ++ti) { - TLLM_CHECK(beamWidth == 1); - newRequestSpeculativeDecoding(batchSlot, request, samplingConfig, modelConfig, - decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(), runtimeStream, decoderStream, - decoderState.getSpeculativeDecodingMode(), decoderState.getMaxDecodingEngineTokens()); + TensorPtr const newTokensStepView = ITensor::slice(dJointOutput.newTokensSteps, ti, 1); + newTokensStepView->squeeze(0); + auto newTokensVec = ITensor::slice(newTokensStepView, batchSlot, 1); + manager.setZero(*newTokensVec); } - // fill outputIds with endIds - TensorPtr const outputIds = ITensor::slice(dJointOutput.ids, batchSlot, 1); - auto outputIdsTileView = ITensor::view(outputIds, ITensor::makeShape({beamWidth, maxSequenceLength})); - runtime::kernels::invokeFill(*outputIdsTileView, endId, decoderStream); - - // copy the request ids into outputIds - auto const requestIdsShape = requestIds->getShape(); - auto outputIdsView = ITensor::view(outputIds, requestIdsShape); - manager.copy(*requestIds, *outputIdsView); + TensorPtr const finishedStepsSlice = ITensor::slice(dJointOutput.finishReasons, batchSlot, 1); + manager.setZero(*finishedStepsSlice); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } +} // namespace + void CreateNewDecoderRequests::newRequestSpeculativeDecoding(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig, DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, @@ -557,11 +549,12 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, 
runtime::dec std::tuple, std::vector> CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedContextRequests, TensorPtr const& inputIds, executor::DecodingConfig const& decodingConfig, runtime::decoder::DecoderState& decoderState, - BufferManager const& bufferManager, nvinfer1::DataType logitsType, runtime::ModelConfig const& modelConfig, - runtime::WorldConfig const& worldConfig, runtime::CudaStream const& runtimeStream, - runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength, + nvinfer1::DataType logitsType, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, + runtime::CudaStream const& runtimeStream, runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength, OptionalRef medusaBuffers) const { + auto const decoderBufferManager = BufferManager{std::make_shared(decoderStream.get())}; + unsigned decoderInputSize{0}; for (auto const& llmReq : finishedContextRequests) { @@ -586,26 +579,38 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon SizeType32 inputOffset{0}; for (auto const& llmReq : finishedContextRequests) { + llmReq->mSamplingConfig.normalizeLogProbs = mIsNormalizeLogProbs; + + TLLM_CHECK(llmReq->mSeqSlot.has_value()); + auto const batchSlot = llmReq->mSeqSlot.value(); + auto const batchSize = decoderState.getMaxNumSequences(); + TLLM_CHECK(0 <= batchSlot && batchSlot < batchSize); + + auto const& samplingConfig = llmReq->mSamplingConfig; + + auto const beamWidth = samplingConfig.beamWidth; + auto const maxBeamWidth = decoderState.getMaxBeamWidth(); + TLLM_CHECK_WITH_INFO(beamWidth <= maxBeamWidth, + tc::fmtstr("Beam width (%d) must be smaller than maxBeamWidth (%d) passed to decoder setup function.", + beamWidth, maxBeamWidth)); + decoderState.setBeamWidth(batchSlot, beamWidth); + auto const promptLen = llmReq->getPromptLen(); - auto const& reqTokens = llmReq->getTokens(0); - TLLM_CHECK(reqTokens.size() == static_cast(promptLen)); - TensorPtr inputView = ITensor::slice(inputIds, inputOffset, promptLen); - bufferManager.copy(reqTokens.data(), *inputView); - auto decoderRequest = decoder_batch::Request{inputView, promptLen, llmReq->mMaxNewTokens, llmReq->mEndId}; + auto decoderRequest = decoder_batch::Request{promptLen}; - llmReq->mSamplingConfig.normalizeLogProbs = mIsNormalizeLogProbs; if (modelConfig.getSpeculativeDecodingMode().isDraftTokensExternal()) { if (llmReq->hasDraftTokens()) { auto const& draftTokens = llmReq->getDraftTokens(); - decoderRequest.draftTokens = bufferManager.copyFrom(*draftTokens, MemoryType::kPINNEDPOOL); + // Copy to pinned host memory (don't care about stream of bufferManager) + decoderRequest.draftTokens = decoderBufferManager.copyFrom(*draftTokens, MemoryType::kPINNEDPOOL); auto const& draftLogits = llmReq->getDraftLogits(); if (draftLogits.has_value()) { decoderRequest.draftLogits - = retrieveDraftLogits(modelConfig, worldConfig, draftLogits.value(), bufferManager); + = retrieveDraftLogits(modelConfig, worldConfig, draftLogits.value(), decoderBufferManager); } decoderRequest.generatedTokensPerEngineStep = draftTokens->size() + 1; } @@ -618,48 +623,77 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon { decoderRequest.generatedTokensPerEngineStep = modelConfig.getMaxDecodingTokens(); } - if (modelConfig.getSpeculativeDecodingMode().isMedusa()) - { - TLLM_CHECK(medusaBuffers); - llmReq->mSamplingConfig.topKMedusaHeads = {medusaBuffers->mTopKs}; - // FIXME: we must set medusa paths and tree ids not from seq 
slot, but from llmRequest? - // When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot. - decoderRequest.medusaPaths = ITensor::slice(medusaBuffers->medusaPathsDevice, 0, 1); - decoderRequest.medusaTreeIds = ITensor::slice(medusaBuffers->medusaTreeIdsDevice, 0, 1); - } - else if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()) - { - lookaheadPrompt.emplace_back(ITensor::slice(decoderRequest.ids, 0, decoderRequest.inputLen)); - auto const& lookaheadRuntimeConfig - = llmReq->getLookaheadConfig().value_or(decodingConfig.getLookaheadDecodingConfig().value()); - lookaheadAlgoConfigs.emplace_back(lookaheadRuntimeConfig); - } - else if (modelConfig.getSpeculativeDecodingMode().isEagle()) - { - decoderRequest.eagleConfig - = llmReq->getEagleConfig() ? llmReq->getEagleConfig() : decodingConfig.getEagleConfig(); - } - if (llmReq->getEmbeddingBias().has_value()) - { - decoderRequest.embeddingBias = getEmbeddingBias(logitsType, llmReq->getEmbeddingBias().value()); - } - if (llmReq->getBadWordsList().has_value()) + auto& dJointInput = decoderState.getJointDecodingInput(); + + auto const numDecodingEngineTokens = decoderRequest.generatedTokensPerEngineStep; + initializeInputLengths(dJointInput, batchSlot, promptLen, llmReq->mMaxNewTokens, numDecodingEngineTokens, + maxSequenceLength, decoderBufferManager); + decoderState.setNumDecodingEngineTokens(batchSlot, numDecodingEngineTokens); + + initializeEmbeddingBias( + dJointInput, batchSlot, llmReq->getEmbeddingBias(), logitsType, modelConfig, decoderBufferManager); + + setupWords(dJointInput.badWordsLists, llmReq->getBadWordsList(), dJointInput.badWordsPtrs, + dJointInput.badWordsLens, dJointInput.maxBadWordsLen, batchSlot, decoderBufferManager); + + setupWords(dJointInput.stopWordsLists, llmReq->getStopWordsList(), dJointInput.stopWordsPtrs, + dJointInput.stopWordsLens, dJointInput.maxStopWordsLen, batchSlot, decoderBufferManager); + + auto& dJointOutput = decoderState.getJointDecodingOutput(); + + initializeOutputs(dJointOutput, batchSlot, decoderState.getMaxDecodingEngineTokens(), decoderBufferManager); + + initializeLogProbs(dJointOutput, batchSlot, samplingConfig, decoderBufferManager); + + auto const& reqTokens = llmReq->getTokens(0); + TLLM_CHECK(reqTokens.size() == static_cast(promptLen)); + TensorPtr requestIds = ITensor::slice(inputIds, inputOffset, promptLen); + // Copy to pinned host memory (don't care about stream of bufferManager) + decoderBufferManager.copy(reqTokens.data(), *requestIds); + auto const endId = llmReq->mEndId.value_or(-1); + + initializeRequestIds(dJointInput, dJointOutput, batchSlot, requestIds, endId, beamWidth, maxSequenceLength, + decoderBufferManager); + + if (beamWidth > 1) { - // Move to GPU and remove leading bs1 dimension since this is what decoderRequest expects - decoderRequest.badWordsList = bufferManager.copyFrom(*llmReq->getBadWordsList().value(), MemoryType::kGPU); - decoderRequest.badWordsList->squeeze(0); + initializeBeamSearch( + dJointInput, dJointOutput, batchSlot, endId, beamWidth, maxSequenceLength, decoderBufferManager); } - if (llmReq->getStopWordsList().has_value()) + + // Speculative execution + if (!decoderState.getSpeculativeDecodingMode().isNone()) { - decoderRequest.stopWordsList - = bufferManager.copyFrom(*llmReq->getStopWordsList().value(), MemoryType::kGPU); - decoderRequest.stopWordsList->squeeze(0); - } + TLLM_CHECK(beamWidth == 1); - TLLM_CHECK(llmReq->mSeqSlot.has_value()); - newRequest(llmReq->mSeqSlot.value(), decoderRequest, 
llmReq->mSamplingConfig, modelConfig, decoderState, - runtimeStream, decoderStream, maxSequenceLength); + if (modelConfig.getSpeculativeDecodingMode().isMedusa()) + { + TLLM_CHECK(medusaBuffers); + llmReq->mSamplingConfig.topKMedusaHeads = {medusaBuffers->mTopKs}; + // FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest? + // When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot. + decoderRequest.medusaPaths = ITensor::slice(medusaBuffers->medusaPathsDevice, 0, 1); + decoderRequest.medusaTreeIds = ITensor::slice(medusaBuffers->medusaTreeIdsDevice, 0, 1); + } + else if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()) + { + lookaheadPrompt.emplace_back(requestIds); + + auto const& lookaheadRuntimeConfig + = llmReq->getLookaheadConfig().value_or(decodingConfig.getLookaheadDecodingConfig().value()); + lookaheadAlgoConfigs.emplace_back(lookaheadRuntimeConfig); + } + else if (modelConfig.getSpeculativeDecodingMode().isEagle()) + { + decoderRequest.eagleConfig + = llmReq->getEagleConfig() ? llmReq->getEagleConfig() : decodingConfig.getEagleConfig(); + } + + newRequestSpeculativeDecoding(batchSlot, decoderRequest, samplingConfig, modelConfig, + decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(), runtimeStream, + decoderStream, decoderState.getSpeculativeDecodingMode(), decoderState.getMaxDecodingEngineTokens()); + } decoderRequests.push_back(decoderRequest); diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index a4617c0d53d..522ec80f84a 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -91,6 +91,43 @@ std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo) return totalSize; } +void TransferSession::appendMeasure(double delay, double duration, size_t size) +{ + if (!mRecordMeasure) + { + return; + } + auto bandwidth = size * 8 / (duration / 1000) / 1e9; // byte, ms => Gbps + mMeasures.emplace_back(Measure{delay, duration, bandwidth}); +} + +void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const +{ + if (mMeasures.empty()) + { + return; + } + // write header if not exist + if (outFile.tellp() == 0) + { + outFile << "RequestID"; + for (size_t i = 0; i < mMeasures.size(); i++) + { + outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)"; + } + outFile << '\n'; + } + // write measures + TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value()); + auto reqId = isContext ? 
mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId(); + outFile << reqId; + for (auto const& measure : mMeasures) + { + outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth; + } + outFile << '\n' << std::flush; +} + class DataResponder::Impl { public: diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h index 91215ff66c2..ef66cd1382d 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h @@ -97,15 +97,23 @@ class RequestInfo class TransferSession { public: + struct Measure + { + double delay; // from last token (ctx) or arrival time (gen), in ms + double duration; // in ms + double bandwidth; // in Gbps + }; + TransferSession(std::vector connections, DataContext dataContext, executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState, - runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr) + runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr, bool recordMeasure = false) : mConnections(std::move(connections)) , mDataContext(dataContext) , mSelfState(&selfState) , mOtherState(std::move(otherState)) , mBufferManager(&bufferManager) , mRequest(llmRequest) + , mRecordMeasure(recordMeasure) { TLLM_CHECK(!mConnections.empty()); } @@ -163,6 +171,11 @@ class TransferSession mRequest = &llmRequest; } + void appendMeasure(double delay, double duration, size_t size); + + // TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file + void exportMeasure(std::ofstream& outFile, bool isContext) const; + private: std::vector mConnections; DataContext mDataContext; @@ -170,6 +183,8 @@ class TransferSession executor::DataTransceiverState mOtherState; runtime::BufferManager const* mBufferManager; LlmRequest const* mRequest; + bool mRecordMeasure; + std::vector mMeasures; }; // Operators required for data transmission in specific communication protocols. @@ -266,79 +281,4 @@ class DataRequester std::unique_ptr mImpl; }; -class KvCacheMeasureHelper -{ -public: - struct Measure - { - double delay; // from last token (ctx) or arrival time (gen), in ms - double duration; // in ms - double bandwidth; // in Gbps - }; - - KvCacheMeasureHelper(std::string output_path) - : mOutputPath(std::move(output_path)) - { - } - - void markAsSender(bool isSender) - { - mIsSender = isSender; - } - - void appendKVCacheTransfer(LlmRequest::RequestIdType requestId, double delay, double duration, size_t size) - { - auto bandwidth = size * 8 / (duration / 1000) / 1e9; - if (mOutputPath.empty()) - { - return; - } - - std::lock_guard lock(mMutex); - mRequestKVCacheTranfserMeasure[requestId].emplace_back(Measure{delay, duration, bandwidth}); - } - - ~KvCacheMeasureHelper() - { - if (!mRequestKVCacheTranfserMeasure.empty() && !mOutputPath.empty()) - { - TLLM_CHECK(mIsSender.has_value()); - auto rank = mpi::MpiComm::world().getRank(); - std::string outFilePath - = mOutputPath + "rank_" + std::to_string(rank) + "_" + (mIsSender.value() ? 
"send" : "recv") + ".csv"; - std::ofstream outFile(outFilePath); - - TLLM_CHECK_WITH_INFO(outFile.is_open(), "Cannot write to file " + outFilePath); - - size_t numTransferMeasure = mRequestKVCacheTranfserMeasure.begin()->second.size(); - - outFile << "RequestID"; - for (size_t i = 0; i < numTransferMeasure; i++) - { - outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)"; - } - outFile << '\n'; - - for (auto const& [requestID, measures] : mRequestKVCacheTranfserMeasure) - { - outFile << requestID; - - for (auto const& measure : measures) - { - outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth; - } - outFile << '\n'; - } - - outFile.close(); - } - } - -private: - std::map> mRequestKVCacheTranfserMeasure; - std::string mOutputPath; - std::mutex mMutex; - std::optional mIsSender; -}; - } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp index 9a72bf2d00f..1a5c7fab4dd 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp @@ -21,6 +21,8 @@ #include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" +#include + namespace tensorrt_llm::batch_manager { @@ -30,6 +32,21 @@ static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId) return ((requestId & 0xFFF) << 8) | (kDATA_TAG & 0xFF); } +namespace fs = std::filesystem; + +static fs::path getTransferOutputPath(char const* tag) +{ + auto outputPath = common::getEnvKVCacheTransferOutputPath(); + if (!outputPath.empty()) + { + auto rank = mpi::MpiComm::world().getRank(); + auto path = fs::path(outputPath); + fs::create_directories(path); + return path / ("rank_" + std::to_string(rank) + "_" + tag + ".csv"); + } + return {}; +} + DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr formatter) : mManager{manager} @@ -39,7 +56,6 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager, { TLLM_CHECK(mManager); TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex); - mFormatter->markAsSender(true); } [[nodiscard]] RequestInfo DataSenderImpl::recvRequestInfo() @@ -86,7 +102,8 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager, if (it == mRequestToSession.end()) { auto session = TransferSession(std::vector(peerRelativeRanks.size(), nullptr), - DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager); + DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager, nullptr, + !common::getEnvKVCacheTransferOutputPath().empty()); it = mRequestToSession.emplace(requestId, std::move(session)).first; } it->second.setConnection(peerIdx, connection); @@ -125,6 +142,17 @@ void DataSenderImpl::release(LlmRequest::RequestIdType requestId) auto it = mRequestToSession.find(requestId); TLLM_CHECK(it != mRequestToSession.end()); std::unique_lock lk(mMtxForMap); + if (!common::getEnvKVCacheTransferOutputPath().empty()) + { + if (!mMeasuresFile.is_open()) + { + auto outputPath = getTransferOutputPath("send"); + mMeasuresFile.open(outputPath); + TLLM_CHECK_WITH_INFO( + mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str()); + } + it->second.exportMeasure(mMeasuresFile, true); + } 
mRequestToSession.erase(it); } @@ -137,7 +165,6 @@ DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manage TLLM_CHECK(mManager); TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex); TLLM_CHECK(mFormatter); - mFormatter->markAsSender(false); } TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest) @@ -203,12 +230,24 @@ TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest) } auto const& resource = getReceiveCacheResource(llmRequest); return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState, - contextState, resource->mBufferManager, &llmRequest); + contextState, resource->mBufferManager, &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty()); } void DataReceiverImpl::receiveSync(TransferSession& session) { mFormatter->unformat(session); + if (!common::getEnvKVCacheTransferOutputPath().empty()) + { + std::unique_lock lock(mMeasuresFileMutex); + if (!mMeasuresFile.is_open()) + { + auto outputPath = getTransferOutputPath("recv"); + mMeasuresFile.open(outputPath); + TLLM_CHECK_WITH_INFO( + mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str()); + } + session.exportMeasure(mMeasuresFile, false); + } } void DataReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info) diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h index fa8d2728329..2f277f14fff 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h @@ -23,6 +23,8 @@ #include "tensorrt_llm/common/envUtils.h" #include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h" +#include + namespace tensorrt_llm::batch_manager { struct TransceiverTag @@ -67,6 +69,7 @@ class DataSenderImpl : public DataSender, public TransceiverTag std::unique_ptr mFormatter; std::mutex mMtxForMap; runtime::BufferManager mBufferManager; + std::ofstream mMeasuresFile; }; class DataReceiverImpl : public DataReceiver, public TransceiverTag @@ -103,6 +106,8 @@ class DataReceiverImpl : public DataReceiver, public TransceiverTag std::unique_ptr mFormatter; std::unordered_map> mProcessToResources; std::mutex mProcessIoResouceMutex; + std::ofstream mMeasuresFile; + std::mutex mMeasuresFileMutex; }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp b/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp index 040dcd147e9..ea5f0981074 100644 --- a/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp +++ b/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp @@ -88,8 +88,7 @@ void GuidedDecoder::build(ScheduledRequests const& scheduledRequests) continue; } auto const seqSlot = llmReq->mSeqSlot.value(); - if (llmReq->isContextInitState() - && llmReq->getContextCurrentPosition() == llmReq->getPrepopulatedPromptLen()) + if (llmReq->isContextInitState() && llmReq->isFirstContextChunk()) { // The request is in the first context forward step (considering kv cache reuse). 
auto const& guideType = guidedDecodingParams->getGuideType(); diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp index ff2a2f6b787..ac37278d45f 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp @@ -18,20 +18,51 @@ #include "tensorrt_llm/batch_manager/kvCacheEventManager.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/executor/serialization.h" +#include "tensorrt_llm/runtime/utils/mpiUtils.h" namespace tle = tensorrt_llm::executor; namespace tensorrt_llm::batch_manager::kv_cache_manager { -KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries) +KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries, std::optional attentionDpRank, + std::optional attentionDpSize, SizeType32 attentionDpEventsGatherPeriodMs) : mRun{true} , mMaxSize{maxKVEventEntries} , mEventId{0} + , mAttentionDpRank{attentionDpRank} + , mAttentionDpSize{attentionDpSize} + , mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs) { TLLM_CHECK(mMaxSize > 0); - // mWorkerThread = std::thread(std::bind(&KVCacheEventManager::worker, this)); + if (mAttentionDpRank) + { + TLLM_CHECK_WITH_INFO( + mAttentionDpSize.has_value(), "If attention DP rank is set, the attention DP size must also be set"); + TLLM_CHECK_WITH_INFO(mAttentionDpRank.value() < mAttentionDpSize.value(), + "Attention DP rank must be less than attention DP size"); + if (mAttentionDpRank.value() == 0) + { + // Rank 0 will gather events from all other ranks + // Need to increase size + mMaxSize *= mAttentionDpSize.value(); + } + // Create a communicator to be used for event exchange + mMpiComm = std::make_unique(COMM_SESSION.split(0, mAttentionDpRank.value())); + } + else + { + TLLM_CHECK_WITH_INFO( + !mAttentionDpSize.has_value(), "If attention DP rank is not set, the attention DP size must not be set"); + } mWorkerThread = std::thread([this]() { this->worker(); }); +#if ENABLE_MULTI_DEVICE + if (mAttentionDpRank) + { + mExchangeAttentionDpThread = std::thread([this]() { this->exchangeAttentionDpThread(); }); + } +#endif }; KVCacheEventManager::~KVCacheEventManager() @@ -40,12 +71,18 @@ KVCacheEventManager::~KVCacheEventManager() mPendingEmptyCV.notify_all(); mEmptyCV.notify_all(); mWorkerThread.join(); +#if ENABLE_MULTI_DEVICE + if (mAttentionDpRank) + { + mExchangeAttentionDpThread.join(); + } +#endif } void KVCacheEventManager::enqueueCreatedEvent( std::vector const& numBlocksPerCacheLevel, SizeType32 windowSize) { - enqueueEvent({mEventId++, tle::KVCacheCreatedData{numBlocksPerCacheLevel}, windowSize}); + enqueueEvent({mEventId++, tle::KVCacheCreatedData{numBlocksPerCacheLevel}, windowSize, mAttentionDpRank}); } void KVCacheEventManager::enqueueStoredEvent(std::vector const& blocks, SizeType32 windowSize) @@ -68,7 +105,7 @@ void KVCacheEventManager::enqueueStoredEvent(std::vector const& blocks block->isPrimary() ? 
kPrimaryLevel : kSecondaryLevel, block->getPriority()); } - enqueueEvent({mEventId++, data, windowSize}); + enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank}); } void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32 windowSize) @@ -81,13 +118,13 @@ void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32 } else { - enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize}); + enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize, mAttentionDpRank}); } } void KVCacheEventManager::enqueueUpdatedEvent(tle::KVCacheUpdatedData const& data, SizeType32 windowSize) { - enqueueEvent({mEventId++, data, windowSize}); + enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank}); } void KVCacheEventManager::enqueueEvent(tle::KVCacheEvent&& event) @@ -120,8 +157,76 @@ void KVCacheEventManager::flush() mPendingEmptyCV.notify_one(); } +void KVCacheEventManager::exchangeAttentionDpThread() +{ +#if ENABLE_MULTI_DEVICE + while (true) + { + TLLM_CHECK(mAttentionDpRank); + + // Check if any of the ranks have been shutdown + int32_t numFinished = 0; + int32_t finished = mRun ? 0 : 1; + mMpiComm->allreduce(&finished, &numFinished, 1, mpi::MpiType::kINT32, mpi::MpiOp::SUM); + if (numFinished > 0) + { + TLLM_LOG_INFO("One of the rank has been shut down, exiting"); + break; + } + + // If we are not rank 0, send events to rank 0 + if (mAttentionDpRank.value() != 0) + { + std::vector serializedEvents; + uint64_t numEvents = 0; + { + std::lock_guard lck(mEventsMutex); + serializedEvents = executor::Serialization::serialize(mEvents); + numEvents = mEvents.size(); + mEvents.clear(); + } + uint64_t vecSize = numEvents > 0 ? serializedEvents.size() : 0; + mMpiComm->send(&vecSize, 1, mpi::MpiType::kUINT64, 0, mpi::MpiTag::kKvCacheEventSize); + if (vecSize > 0) + { + mMpiComm->send(serializedEvents.data(), serializedEvents.size(), mpi::MpiType::kCHAR, 0, + mpi::MpiTag::kKvCacheEvent); + } + } + else + { + TLLM_CHECK(mAttentionDpSize.has_value()); + // Loop until have received events from all ranks + for (int rank = 1; rank < mAttentionDpSize.value(); ++rank) + { + uint64_t vecSize{0}; + mMpiComm->recv(&vecSize, 1, mpi::MpiType::kUINT64, rank, mpi::MpiTag::kKvCacheEventSize); + if (vecSize > 0) + { + std::vector serializedEvents(vecSize); + mMpiComm->recv( + serializedEvents.data(), vecSize, mpi::MpiType::kCHAR, rank, mpi::MpiTag::kKvCacheEvent); + + // Deserialize the events and add them to the local queue + auto rankEvents = executor::Serialization::deserializeKVCacheEvents(serializedEvents); + { + std::lock_guard lck(mEventsMutex); + mEvents.insert(mEvents.end(), rankEvents.begin(), rankEvents.end()); + mEmptyCV.notify_one(); + } + } + } + } + std::this_thread::sleep_for(std::chrono::milliseconds(mAttentionDpEventsGatherPeriodMs)); + } +#else + TLLM_THROW("Multi device support is disabled."); +#endif +} + void KVCacheEventManager::worker() { + while (true) { std::deque events; @@ -151,6 +256,8 @@ void KVCacheEventManager::worker() // If there's still too many events, take from the front of the events queue. 
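The `exchangeAttentionDpThread` loop above uses a simple two-message protocol: every non-zero attention-DP rank serializes its queued events and sends the payload size followed by the payload itself to rank 0, which receives and deserializes them into its own queue. A hedged sketch of that pattern using plain MPI calls (the real code goes through the `tensorrt_llm::mpi` wrapper and its `MpiTag` enum; the tags below are placeholders):

```cpp
#include <mpi.h>

#include <cstdint>
#include <vector>

// Placeholder tags; the diff uses mpi::MpiTag::kKvCacheEventSize / kKvCacheEvent.
constexpr int kSizeTagSketch = 100;
constexpr int kDataTagSketch = 101;

// Non-zero ranks: announce the serialized payload size, then send the bytes (if any).
void sendEventsToRankZeroSketch(MPI_Comm comm, std::vector<char> const& serializedEvents)
{
    std::uint64_t vecSize = serializedEvents.size();
    MPI_Send(&vecSize, 1, MPI_UINT64_T, /*dest=*/0, kSizeTagSketch, comm);
    if (vecSize > 0)
    {
        MPI_Send(serializedEvents.data(), static_cast<int>(vecSize), MPI_CHAR, 0, kDataTagSketch, comm);
    }
}

// Rank 0: receive the size first, then the payload, for one peer rank.
std::vector<char> recvEventsFromRankSketch(MPI_Comm comm, int peerRank)
{
    std::uint64_t vecSize = 0;
    MPI_Recv(&vecSize, 1, MPI_UINT64_T, peerRank, kSizeTagSketch, comm, MPI_STATUS_IGNORE);
    std::vector<char> buffer(vecSize);
    if (vecSize > 0)
    {
        MPI_Recv(buffer.data(), static_cast<int>(vecSize), MPI_CHAR, peerRank, kDataTagSketch, comm,
            MPI_STATUS_IGNORE);
    }
    return buffer;
}
```

Rank 0 would then pass each non-empty buffer to `executor::Serialization::deserializeKVCacheEvents` and append the result under the events mutex, exactly as the diff shows.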
mEvents.insert(mEvents.end(), events.begin() + std::max(0, elementsToRemove), events.end()); + + // Notify the empty condition variable to wake up any waiting threads mEmptyCV.notify_one(); } } diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 4202ba348ac..d5fa982a37a 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -504,8 +504,7 @@ BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, Si std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, - std::shared_ptr eventManager, bool enableHashKey, bool enablePartialReuse, - bool copyOnPartialReuse) + std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse) : mNumLayers{static_cast(numKvHeadsPerLayer.size())} , mTokensPerBlock{tokensPerBlock} , mEventManager{std::move(eventManager)} @@ -530,7 +529,7 @@ BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, Si TLLM_CHECK(allottedPrimaryBlocks > 0); // You can't have a model with negative primary blocks... mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer, sizePerHead, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream, - onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enableHashKey, enablePartialReuse, + onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse, copyOnPartialReuse); } @@ -573,8 +572,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr stream, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, - std::shared_ptr eventManager, bool enableHashKey, bool enablePartialReuse, - bool copyOnPartialReuse) + std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse) : mDataType{dtype} , mWindowSize{windowSize} , mNumPrimaryBlocks{blocksInPrimaryPool} @@ -596,7 +594,6 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind , mLogPrefix{tensorrt_llm::common::fmtstr("BlockManager[windowSize=%u]", mWindowSize)} , mReusedTokens{0.0} , mTotalInputTokens{0.0} - , mEnableHashKey{enableHashKey} , mEnablePartialReuse{enablePartialReuse} , mCopyOnPartialReuse{copyOnPartialReuse} { @@ -920,50 +917,6 @@ void BlockManager::setOffsets(tk::KVCacheIndex* offsetsPtr, nvinfer1::Dims const mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId); } -void WindowBlockManager::addBlockToHashMap(BlockPtr const& block) -{ - if (!mEnableHashKey) - { - return; - } - auto range = mContextBlocksByHash.equal_range(block->getHash()); - for (auto it = range.first; it != range.second; ++it) - { - if (it->second == block) - { - // TODO: change to assert when reused block is added only once - TLLM_LOG_TRACE( - "Block %d by %zx exists", block->getBlockId(), block->getHash(), mContextBlocksByHash.size()); - return; - } - } - TLLM_LOG_TRACE( - "Add block %d by %zx, block n = %zu", block->getBlockId(), block->getHash(), mContextBlocksByHash.size()); - mContextBlocksByHash.emplace(block->getHash(), std::move(block)); -} - -void 
WindowBlockManager::removeBlockFromHashMap(BlockPtr const& block) -{ - if (mContextBlocksByHash.empty() || block->getBlockKey().uniqueTokens.empty()) - { - // Hash key not enabled / Empty block - return; - } - auto range = mContextBlocksByHash.equal_range(block->getHash()); - TLLM_LOG_TRACE( - "Remove block %d by %zx, block n = %zu", block->getBlockId(), block->getHash(), mContextBlocksByHash.size()); - for (auto it = range.first; it != range.second; ++it) - { - if (it->second == block) - { - mContextBlocksByHash.erase(it); - return; - } - } - // TODO: should be unreachable - TLLM_LOG_DEBUG("Trying to remove block %d by %zx that is not in hash map", block->getBlockId(), block->getHash()); -} - void BlockManager::onboardBlock(BlockPtr const& offloadBlock, SizeType32 windowSize) { mWindowBlockManagers.at(windowSize).onboardBlock(offloadBlock); @@ -1104,7 +1057,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs); TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Reused partially filled block %d", mLogPrefix.c_str(), matchingBlockId); - addBlockToHashMap(matchingBlock); } searchRoot = nullptr; // no matching needed for following blocks } @@ -1114,7 +1066,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& mEvictionPolicy->claimBlock( matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs); TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Matched full block %d", mLogPrefix.c_str(), matchingBlockId); - addBlockToHashMap(matchingBlock); searchRoot = matchingBlock; } onboardBlock(matchingBlock); @@ -1145,7 +1096,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& ++blockItr; } freeBlock->setHash(); - addBlockToHashMap(freeBlock); ++mMissedBlocks; } } @@ -1169,7 +1119,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& ++blockItr; } freeBlock->setHash(); - addBlockToHashMap(freeBlock); TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Beam %d. 
Allocated non-shared block %d for bi %d", mLogPrefix.c_str(), beamIdx, freeBlock->getBlockId(), bi); } @@ -1369,9 +1318,7 @@ void WindowBlockManager::storeBlocks( if (oldHash != newHash) { TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash); - removeBlockFromHashMap(block); block->setHash(newHash); - addBlockToHashMap(block); } searchRoot = block; } @@ -1408,7 +1355,6 @@ void WindowBlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeTyp if (!block->hasRefs()) { mEvictionPolicy->releaseBlock(block); - removeBlockFromHashMap(block); } } @@ -1473,7 +1419,6 @@ void WindowBlockManager::releaseLastBlock(GenerationRequest& sequence) if (!block->hasRefs()) { mEvictionPolicy->releaseBlock(block, true); - removeBlockFromHashMap(block); } // Remove block from allocated blocks allocatedBlocks.pop_back(); @@ -1616,7 +1561,6 @@ void WindowBlockManager::releaseBlocks(GenerationRequest& sequence) if (!block->hasRefs()) { mEvictionPolicy->releaseBlock(block); - removeBlockFromHashMap(block); } } // Remove stored block ids in sequence @@ -1654,8 +1598,7 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size : KVCacheManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, std::make_shared(reinterpret_cast(stream)), maxSequenceLength, - enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, false, enablePartialReuse, - copyOnPartialReuse) + enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, enablePartialReuse, copyOnPartialReuse) { } @@ -1682,8 +1625,7 @@ KVCacheManager::KVCacheManager(std::vector const& numKvHeadsPerLayer SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, - std::shared_ptr eventManager, bool enableHashKey, bool enablePartialReuse, - bool copyOnPartialReuse) + std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse) : mMaxBeamWidth(maxBeamWidth) , mDataType(dtype) , mMaxAttentionWindow(*std::max_element(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end())) @@ -1693,10 +1635,9 @@ KVCacheManager::KVCacheManager(std::vector const& numKvHeadsPerLayer , mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager), - enableHashKey, enablePartialReuse, copyOnPartialReuse) + enablePartialReuse, copyOnPartialReuse) // disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case , mEnableBlockReuse{mSinkBubbleLength > 0 ? 
false : enableBlockReuse} - , mEnableHashKey{enableHashKey} { TLLM_CHECK_DEBUG(std::find(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end(), mMaxAttentionWindow) != maxAttentionWindowVec.end()); @@ -1716,12 +1657,11 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, - std::shared_ptr eventManager, bool enableHashKey, bool enablePartialReuse, - bool copyOnPartialReuse) + std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse) : KVCacheManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, - std::move(eventManager), enableHashKey, enablePartialReuse, copyOnPartialReuse) + std::move(eventManager), enablePartialReuse, copyOnPartialReuse) { } @@ -2085,30 +2025,6 @@ void KVCacheManager::addSequence( llmRequest->mRequestId); } mBlockManager.addSequence(sequence, numContextBlocks, unsharedBlockIdx, windowSize); - if (mEnableHashKey && llmRequest.has_value() && beamWidth == 1) - { - constexpr SizeType32 beamIdx = 0; - auto const& blockIds = sequence.getCacheBlockIds(windowSize).at(beamIdx); - auto const& uniqueTokens = llmRequest->getUniqueTokens(beamIdx); - auto blockedUniqueTokens = chopVectorIntoBlocks( - uniqueTokens, uniqueTokens.size() - 1, getTokensPerBlock(), true); - auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest); - auto tokensPerBlock = static_cast(getTokensPerBlock()); - for (size_t i = 0; i < blockIds.size(); i++) - { - auto const& block = mBlockManager.getBlockById(blockIds[i], windowSize); - if (i < blockKeys.size()) - { - block->setBlockKey(blockKeys[i], blockKeys[i].uniqueTokens.size() == tokensPerBlock); - } - else - { - block->setBlockKey({}, false); - } - block->setHash(); - mBlockManager.addBlockToHashMap(block, windowSize); - } - } } cacheBlockOffsets(sequence, windowSize); } @@ -2127,10 +2043,13 @@ void KVCacheManager::addSequence( void KVCacheManager::storeContextBlocks(LlmRequest const& llmRequest) { auto const requestId = llmRequest.mRequestId; - auto& sequence = getSequence(requestId); - if (mEnableBlockReuse && !sequence.isCyclic() && !llmRequest.isDummyRequest()) + if (mSequences.find(requestId) != mSequences.end()) { - mBlockManager.storeContextBlocks(sequence, llmRequest); + auto& sequence = getSequence(requestId); + if (mEnableBlockReuse && !sequence.isCyclic() && !llmRequest.isDummyRequest()) + { + mBlockManager.storeContextBlocks(sequence, llmRequest); + } } } diff --git a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp index a9a4aec5dfc..dcebc9c3ac6 100644 --- a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp @@ -365,4 +365,10 @@ void LlmRequest::moveLoraWeightsToGpu(runtime::BufferManager const& manager) mLoraWeights = gpuLoraWeights; } +void LlmRequest::removeLoraTensors() +{ + mLoraWeights.reset(); + mLoraConfig.reset(); +} + } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp index 824a31129f8..22756f25527 100644 --- 
a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp @@ -45,12 +45,10 @@ std::vector MLACacheFormatter::pickRecvConnections( auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx); TLLM_CHECK(numConnections == targetInfo.mIRanks.size()); std::vector ret; - // targetInfo , mRanks [tpranks, ppranks] - int dpRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0; - + // targetInfo , mRanks [tpranks, dpranks] for (int i = 0; i < targetInfo.mDomainPPSize; i++) { - ret.push_back(i + (dpRank % (targetInfo.mDomainTPSize)) * targetInfo.mDomainPPSize); + ret.push_back(i); } return ret; } @@ -60,24 +58,19 @@ bool MLACacheFormatter::needSendCache( { int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism; - int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP - ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize - : destConfig.getParallelConfig().mTensorParallelism; - int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0; - if (selfConfig.getParallelConfig().mEnableAttentionDP) { int selfTPNumInDPGroup = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize; - + int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP + ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize + : destConfig.getParallelConfig().mTensorParallelism; int selfTPrankINDPGroup = selfTpRank % selfTPNumInDPGroup; if (selfTPNumInDPGroup <= destTPNumInDPGroup) { return true; } - - int dupHeadFactor = selfTPNumInDPGroup / destTPNumInDPGroup; - return selfTPrankINDPGroup % dupHeadFactor == destDPRank; + return selfTPrankINDPGroup % (selfTPNumInDPGroup / destTPNumInDPGroup) == 0; } int destTPNum = destConfig.getParallelConfig().mEnableAttentionDP @@ -88,8 +81,7 @@ bool MLACacheFormatter::needSendCache( { return true; } - int dupHeadFactor = selfTPNum / destTPNum; - return selfTpRank % dupHeadFactor == destDPRank; + return selfTpRank % (selfTPNum / destTPNum) == 0; } void MLACacheFormatter::format(TransferSession& session) @@ -244,7 +236,7 @@ void MLACacheFormatter::format(TransferSession& session) } double cacheTransferTime = std::max(0.0, std::chrono::duration(endTime - startTime).count()); - kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, delay, cacheTransferTime, size); + session.appendMeasure(delay, cacheTransferTime, size); }; if (connections.size() > 1) @@ -441,7 +433,7 @@ void MLACacheFormatter::unformat(TransferSession& session) } double cacheTransferTime = std::max(0.0, std::chrono::duration(endTime - startTime).count()); - kvCacheMeasureHelper.appendKVCacheTransfer(ctxReqId, delay, cacheTransferTime, size); + session.appendMeasure(delay, cacheTransferTime, size); }; if (pickUpConnections.size() > 1) @@ -591,6 +583,28 @@ void MLACacheFormatter::unformat(TransferSession& session) return false; } + int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(); + int selfPPSize = selfConfig.getParallelConfig().mPipelineParallelism; + int destPPSize = destConfig.getParallelConfig().mPipelineParallelism; + int destNumLayers = destConfig.getModelConfig().mNbKvHeadsPerLayer.size(); + + if (selfPPSize == destPPSize) + { + return true; + } + if (selfNumLayers % selfPPSize != 0) + { + TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers 
%d must be divisible by pipeline parallelism :%d", + selfNumLayers, selfPPSize); + return false; + } + if (destNumLayers % destPPSize != 0) + { + TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers %d must be divisible by pipeline parallelism :%d ", + destNumLayers, destPPSize); + return false; + } + return true; } } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp index f513f2a3a10..cc62bd3eb04 100644 --- a/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp @@ -591,10 +591,9 @@ SizeType32 PeftCacheManager::determineNumPages(std::shared_ptr llmRe TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); if (llmRequest->getLoraTaskId().has_value()) { - auto taskId = llmRequest->getLoraTaskId().value(); try { - return mHostLoraCache->determineNumPages(taskId); + return mHostLoraCache->determineNumPages(llmRequest->getLoraTaskId().value()); } catch (std::runtime_error& e) { @@ -602,16 +601,6 @@ SizeType32 PeftCacheManager::determineNumPages(std::shared_ptr llmRe { return mHostLoraCache->determineNumPages(llmRequest->getLoraConfig().value()); } - if (!llmRequest->getLoraWeights().has_value()) - { - auto const reqId = llmRequest->mRequestId; - std::string errMsg - = "Request ID " + std::to_string(reqId) + " has no LoRA adapter weights while configured with LoRA task " - + std::to_string(taskId) + " that's not found in LoRA CPU cache." - " Note that currently a request with LoRA task that was already loaded is sent without its LoRA weights to save its serialization, copy and deserialization," - " so if this LoRA task was evicted from LoRA CPU cache, then its reuse is currently not supported."; - throw PeftTaskNotCachedException(errMsg); - } throw; } } diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index 4a5ddb89286..08cb4d407c1 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -693,7 +693,7 @@ std::unique_ptr TrtGptModelInflightBatching::c kvCacheConfig.getEventBufferMaxSize() > 0 ? 
std::make_unique(kvCacheConfig.getEventBufferMaxSize()) : nullptr, - false, kvCacheConfig.getEnablePartialReuse(), kvCacheConfig.getCopyOnPartialReuse()); + kvCacheConfig.getEnablePartialReuse(), kvCacheConfig.getCopyOnPartialReuse()); reshapeKvTensors(kvCacheManager->getOffsetTableDimensions()); @@ -1866,9 +1866,9 @@ void TrtGptModelInflightBatching::setupDecoderStep( auto const logitsType = mRuntime->getEngine().getTensorDataType("logits"); auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] - = (*mCreateNewDecoderRequests)(mModelConfig, mWorldConfig, mDecodingConfig, contextRequests, - mRuntime->getBufferManager(), logitsType, inputBuffers, *mDecoderState, mRuntime->getStream(), - *mDecoder->getDecoderStream(), getMaxSequenceLen(), mOperatingBeamWidth, buffers.mMedusaBuffers); + = (*mCreateNewDecoderRequests)(mModelConfig, mWorldConfig, mDecodingConfig, contextRequests, logitsType, + inputBuffers, *mDecoderState, mRuntime->getStream(), *mDecoder->getDecoderStream(), getMaxSequenceLen(), + mOperatingBeamWidth, buffers.mMedusaBuffers); auto const localBatchSize = batchSlots->getSize(); if (localBatchSize > 0) diff --git a/cpp/tensorrt_llm/common/attentionOp.cpp b/cpp/tensorrt_llm/common/attentionOp.cpp index 6e1498ba713..03d03eca3af 100644 --- a/cpp/tensorrt_llm/common/attentionOp.cpp +++ b/cpp/tensorrt_llm/common/attentionOp.cpp @@ -55,6 +55,7 @@ struct FusedQKVMaskedAttentionDispatchParams T const* qkv_bias; T const* relative_attention_bias; bool const* attention_mask; + float const* attention_sinks; float const* logn_scaling_ptr; int const* cache_indir; void* context_buf; @@ -71,6 +72,7 @@ struct FusedQKVMaskedAttentionDispatchParams RotaryScalingType rotary_embedding_scale_type; float rotary_embedding_scale; float const* rotary_embedding_inv_freq_cache; + float2 const* rotary_embedding_cos_sin_cache; float rotary_embedding_short_m_scale; float rotary_embedding_long_m_scale; int rotary_embedding_max_positions; @@ -225,6 +227,7 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams& xqaParams.output = generationsParams.context_buf; xqaParams.qkv = generationsParams.attention_input; xqaParams.cache_indir = generationsParams.cache_indir; + xqaParams.attention_sinks = generationsParams.attention_sinks; xqaParams.kv_scale_orig_quant = generationsParams.kv_scale_orig_quant; xqaParams.kv_scale_quant_orig = generationsParams.kv_scale_quant_orig; xqaParams.host_past_key_value_lengths = generationsParams.host_past_key_value_lengths; @@ -275,7 +278,8 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams& xqaParams.logn_scaling_ptr = generationsParams.logn_scaling_ptr; xqaParams.total_num_input_tokens = mCpSize > 1 ? generationsParams.num_requests : generationsParams.num_tokens; xqaParams.is_fp8_output = mFP8ContextFMHA; - xqaParams.fp8_out_scale = (mFP8ContextFMHA ? generationsParams.attention_output_orig_quant : nullptr); + xqaParams.fp8_out_scale + = ((mFP8ContextFMHA || mFP8ContextMLA) ? generationsParams.attention_output_orig_quant : nullptr); // Parameters required for FP4 output. 
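The new `attention_sinks` pointer above (one float per query head) is only plumbed through to the XQA/FMHA kernel parameters in this diff; the kernel math itself is not shown here. Under the common formulation where a sink acts as an extra per-head logit that joins the softmax denominator without contributing a value vector (an assumption on my part, not something this patch states), the effect looks like this host-side sketch:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Hedged sketch: softmax over attention scores with one extra "sink" logit in
// the denominator. The sink has no value vector, so with a large sinkLogit the
// returned probabilities sum to well below 1, i.e. the sink absorbs mass.
std::vector<float> softmaxWithSinkSketch(std::vector<float> const& scores, float sinkLogit)
{
    float maxLogit = sinkLogit;
    for (float s : scores)
    {
        maxLogit = std::max(maxLogit, s);
    }
    float denom = std::exp(sinkLogit - maxLogit);
    for (float s : scores)
    {
        denom += std::exp(s - maxLogit);
    }
    std::vector<float> probs;
    probs.reserve(scores.size());
    for (float s : scores)
    {
        probs.push_back(std::exp(s - maxLogit) / denom);
    }
    return probs;
}
```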
xqaParams.output_sf = generationsParams.context_buf_sf; xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale; @@ -596,6 +600,7 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_paramsisSeparateQAndKvInput() + int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim); + int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim); + int const dim_v_per_head = (mMLAParams.v_head_dim); + + // Total dimension per token across all heads for Q, K, and V components respectively + int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head; + int const total_k_dim_all_heads + = mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout + int const total_v_dim_all_heads + = mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout + + int const num_total_qkv_elements + = max_num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads); + + size_t fp8_qkv_buffer_size = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? max_num_tokens * size_t(local_hidden_units_qo + 2 * local_hidden_units_kv) : 0; + if (mFP8ContextMLA) + { + fp8_qkv_buffer_size + = mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0; + } + size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens; size_t const encoder_padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens; // Each token holds (batch_idx, token_idx_in_seq) int2. @@ -1342,10 +1369,26 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea size_t const qk_buf_float_size = mEnableContextFMHA ? 0 : sizeof(float) * params.batch_size * mNumHeads * params.input_seq_length * kv_seq_length; - size_t const fp8_qkv_buffer_size - = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() + int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim); + int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim); + int const dim_v_per_head = (mMLAParams.v_head_dim); + + // Total dimension per token across all heads for Q, K, and V components respectively + int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head; + int const total_k_dim_all_heads + = mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout + int const total_v_dim_all_heads + = mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout + int const num_total_qkv_elements + = params.num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads); + size_t fp8_qkv_buffer_size = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? params.num_tokens * (local_hidden_units_qo + 2 * local_hidden_units_kv) : 0; + if (mFP8ContextMLA) + { + fp8_qkv_buffer_size + = mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0; + } size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.input_seq_length; size_t const encoder_padding_offset_size @@ -1353,8 +1396,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea // Each token holds (batch_idx, token_idx_in_seq) int2. size_t const tokens_info_size = sizeof(int2) * params.num_tokens; size_t const fmha_scheduler_counter = mEnableContextFMHA ? 
sizeof(uint32_t) : 0; - size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0; - size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0; + size_t const fmha_bmm1_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) * 2 : 0; + size_t const fmha_bmm2_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) : 0; // cp workspace size upper bound size_t const cpMaxPadedSequenceLength = params.num_tokens + params.batch_size * (mCpSize - 1); @@ -1601,6 +1644,15 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea params.mla_param->cache_type = cache_type; params.mla_param->cu_q_seqlens = cu_q_seqlens; params.mla_param->quant_scale_kv = params.kv_scale_orig_quant; + // Set BMM scales for FP8 context computation + params.mla_param->bmm1_scale = fmha_bmm1_scale_ptr; + params.mla_param->bmm2_scale = fmha_bmm2_scale_ptr; + params.mla_param->host_bmm1_scale = decoder_params.fmhaHostBmm1Scale; + params.mla_param->quant_attention_input_buf = mFP8ContextMLA ? fp8_qkv_buffer : nullptr; + // Set additional scales for context phase + params.mla_param->quant_scale_o = params.attention_output_orig_quant; + params.mla_param->dequant_scale_q = params.kv_scale_quant_orig; + params.mla_param->dequant_scale_kv = params.kv_scale_quant_orig; if (mPagedContextFMHA && mPagedKVCache) { TLLM_CHECK_WITH_INFO(params.mla_param->context_paged_kv_ptr != nullptr, @@ -1679,8 +1731,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea // TODO: set it correctly for contiguous kv buffer (cross-attention). fmhaParams.totalKvSeqLen = isCrossAttention() ? params.num_encoder_tokens : params.num_tokens; // Device buffer pointers. - fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast(fp8_qkv_buffer) - : reinterpret_cast(attention_input); + fmhaParams.qkvPtr = (mFP8ContextFMHA || mFP8ContextMLA) ? reinterpret_cast(fp8_qkv_buffer) + : reinterpret_cast(attention_input); fmhaParams.qPtr = reinterpret_cast(q_buf_2_); // TODO: add contiguous kv buffer (cross-attention). fmhaParams.kvPtr = nullptr; @@ -1691,6 +1743,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea fmhaParams.outputPtr = mCpSize > 1 ? gatherOutBuffer : params.context_buf; // only use [totalLength, h / cpSize, Dh] fmhaParams.outputSfPtr = params.context_buf_sf; + fmhaParams.attentionSinksPtr = params.attention_sinks; fmhaParams.packedMaskPtr = params.attention_packed_mask; if constexpr (std::is_same_v) { @@ -2220,6 +2273,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams const& params, cud dispatch_params.relative_attention_bias_stride = relative_attention_bias_stride; dispatch_params.attention_mask = params.attention_mask; dispatch_params.attention_mask_stride = params.attention_mask_stride; + dispatch_params.attention_sinks = params.attention_sinks; dispatch_params.max_distance = max_distance; dispatch_params.cache_indir = params.cache_indir; dispatch_params.context_buf = mCpSize > 1 ? 
mhaOutput : params.context_buf; // @@ -2267,6 +2321,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams const& params, cud dispatch_params.rotary_embedding_scale_type = mRotaryEmbeddingScaleType; dispatch_params.rotary_embedding_scale = mRotaryEmbeddingScale; dispatch_params.rotary_embedding_inv_freq_cache = params.rotary_inv_freq; + dispatch_params.rotary_embedding_cos_sin_cache = params.rotary_cos_sin; dispatch_params.rotary_embedding_short_m_scale = mRotaryEmbeddingShortMscale; dispatch_params.rotary_embedding_long_m_scale = mRotaryEmbeddingLongMscale; dispatch_params.rotary_embedding_max_positions = mRotaryEmbeddingMaxPositions; @@ -2477,7 +2532,7 @@ int AttentionOp::initialize() noexcept } // FP8 FMHA should be used with fp8 workflow together. - if (mFP8ContextFMHA) + if (mFP8ContextFMHA || mFP8ContextMLA) { data_type = DATA_TYPE_E4M3; } @@ -2510,6 +2565,11 @@ int AttentionOp::initialize() noexcept fmhaParams.dataTypeOut = DATA_TYPE_BF16; fmhaParams.dataTypeKv = DATA_TYPE_BF16; } + if (mFP8ContextMLA && mKVCacheQuantMode.hasFp8KvCache()) + { + fmhaParams.dataTypeKv = DATA_TYPE_E4M3; + fmhaParams.dataTypeOut = DATA_TYPE_BF16; + } // TODO: remove forceFp32Acc from MHARunnerFixedParams after adding host_runtime_perf_knobs to // bertAttentionPlugin input tensors, so that we can change mLaunchParams.force_fp32_acc value in runtime. fmhaParams.forceFp32Acc = false; @@ -2563,7 +2623,7 @@ int AttentionOp::initialize() noexcept // Deepseek-V2 Generation needs a differ fmha with different argumments if (mIsMLAEnabled) { - mEnableXQA = (mSM == kSM_120); + mEnableXQA = (mSM == kSM_120) && mIsGenerationMLA; if (mUseTllmGen) { Data_type qDataType = DATA_TYPE_FP32; @@ -2826,6 +2886,7 @@ std::string AttentionOp::toString() const ss << "mPosShiftEnabled: " << std::boolalpha << mPosShiftEnabled << std::endl; ss << "mPagedContextFMHA: " << std::boolalpha << mPagedContextFMHA << std::endl; ss << "mFP8ContextFMHA: " << std::boolalpha << mFP8ContextFMHA << std::endl; + ss << "mFP8ContextMLA: " << std::boolalpha << mFP8ContextMLA << std::endl; ss << "mDenseContextFMHA: " << std::boolalpha << mDenseContextFMHA << std::endl; ss << "mEnableContextFMHA: " << std::boolalpha << mEnableContextFMHA << std::endl; ss << "mFMHAForceFP32Acc: " << std::boolalpha << mFMHAForceFP32Acc << std::endl; diff --git a/cpp/tensorrt_llm/common/attentionOp.h b/cpp/tensorrt_llm/common/attentionOp.h index fb71c06d57b..25d95dfea2b 100644 --- a/cpp/tensorrt_llm/common/attentionOp.h +++ b/cpp/tensorrt_llm/common/attentionOp.h @@ -65,6 +65,8 @@ class AttentionOp T const* qkv_bias = nullptr; // Attention mask input, which has shape of [batch_size, attention_mask_stride]. bool const* attention_mask = nullptr; + // Attention sinks with shape of [num_heads_q] float. + float const* attention_sinks = nullptr; // Rotary inv_freq cache buffer to avoid re-computing. float const* rotary_inv_freq = nullptr; // Rotary cos sin cache buffer to avoid re-computing. 
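A quick arithmetic check on the `fp8_qkv_buffer_size` added for `mFP8ContextMLA` in `enqueueContext` above: Q and K each occupy `qk_rope_head_dim + qk_nope_head_dim` per head, V occupies `v_head_dim`, and the element count (one byte each for E4M3) scales with the token count. The dimensions below are illustrative DeepSeek-style numbers, not values taken from this diff:

```cpp
#include <cstddef>
#include <cstdio>

int main()
{
    // Assumed MLA head geometry, for illustration only.
    int const numHeads = 128;
    int const qkNopeHeadDim = 128;
    int const qkRopeHeadDim = 64;
    int const vHeadDim = 128;
    std::size_t const numTokens = 8192;

    // Mirrors the sizing logic: Q and K use rope + nope dims, V uses v_head_dim,
    // and the effective KV head count is taken equal to the Q head count.
    int const dimQPerHead = qkRopeHeadDim + qkNopeHeadDim; // 192
    int const dimKPerHead = qkRopeHeadDim + qkNopeHeadDim; // 192
    int const dimVPerHead = vHeadDim;                      // 128

    std::size_t const totalQ = static_cast<std::size_t>(numHeads) * dimQPerHead; // 24576
    std::size_t const totalK = static_cast<std::size_t>(numHeads) * dimKPerHead; // 24576
    std::size_t const totalV = static_cast<std::size_t>(numHeads) * dimVPerHead; // 16384

    // One E4M3 byte per element, so elements == bytes for the FP8 QKV buffer.
    std::size_t const numTotalQkvElements = numTokens * (totalQ + totalK + totalV);
    std::printf("fp8 qkv buffer: %zu bytes (~%.2f GiB)\n", numTotalQkvElements,
        numTotalQkvElements / (1024.0 * 1024.0 * 1024.0));
    return 0;
}
```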
@@ -386,6 +388,7 @@ class AttentionOp bool mPosShiftEnabled = false; bool mPagedContextFMHA = false; bool mFP8ContextFMHA = false; + bool mFP8ContextMLA = false; bool mFP8GenerationMLA = false; bool mDenseContextFMHA = false; bool mHasFullAttentionMask = false; diff --git a/cpp/tensorrt_llm/common/envUtils.cpp b/cpp/tensorrt_llm/common/envUtils.cpp index f7480229410..59c9d2fffe4 100644 --- a/cpp/tensorrt_llm/common/envUtils.cpp +++ b/cpp/tensorrt_llm/common/envUtils.cpp @@ -366,6 +366,12 @@ bool getEnvForceDeterministicMOE() return forceDeterministic; } +bool getEnvMOEDisableFinalizeFusion() +{ + static bool const moeDisableFinalizeFusion = getBoolEnv("TRTLLM_MOE_DISABLE_FINALIZE_FUSION"); + return moeDisableFinalizeFusion; +} + bool getEnvForceDeterministicAttention() { static bool const forceDeterministic @@ -386,7 +392,7 @@ size_t getEnvAllReduceWorkspaceSize() return workspaceSize; } -std::string getEnvKVCacheTransferOutputPath() +std::string const& getEnvKVCacheTransferOutputPath() { static std::string outputPath = getStrEnv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH").value_or(""); return outputPath; diff --git a/cpp/tensorrt_llm/common/envUtils.h b/cpp/tensorrt_llm/common/envUtils.h index 5e29dfaca71..f5c0d854ba4 100644 --- a/cpp/tensorrt_llm/common/envUtils.h +++ b/cpp/tensorrt_llm/common/envUtils.h @@ -76,7 +76,7 @@ bool getEnvDisableKVCacheTransferOverlap(); bool getEnvEnableReceiveKVCacheParallel(); -std::string getEnvKVCacheTransferOutputPath(); +std::string const& getEnvKVCacheTransferOutputPath(); bool getEnvTryZCopyForKVCacheTransfer(); @@ -86,6 +86,9 @@ bool getEnvForceDeterministic(); // Force deterministic behavior for MoE plugin. bool getEnvForceDeterministicMOE(); +// Disable finalize fusion in MoE plugin +bool getEnvMOEDisableFinalizeFusion(); + // Force deterministic behavior for attention plugin. 
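Minor but worth noting: `getEnvKVCacheTransferOutputPath` now returns `std::string const&` because it is consulted on every `release()` / `receiveSync()` call, and its value is already cached in a function-local static. A sketch of that caching pattern (hypothetical free functions, not the `envUtils` helpers themselves):

```cpp
#include <cstdlib>
#include <string>

// Read the environment variable once, keep the result in a function-local
// static, and hand out a const reference so hot paths never copy the string.
std::string const& kvCacheTransferOutputPathSketch()
{
    static std::string const cached = []
    {
        char const* value = std::getenv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH");
        return value ? std::string(value) : std::string();
    }();
    return cached;
}

// Usage: measurement recording is enabled only when the path is non-empty,
// matching the recordMeasure checks in the data transceiver changes above.
bool transferRecordingEnabledSketch()
{
    return !kvCacheTransferOutputPathSketch().empty();
}
```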
bool getEnvForceDeterministicAttention(); diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp index c10df82d54c..53dc9e053ad 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp @@ -27,6 +27,78 @@ namespace cutlass::gemm::collective::detail { +using namespace cute; + +typedef uint32_t __nv_fp4x8_storage_t; +typedef uint32_t __nv_bf16x2_storage_t; +typedef cutlass::uint128_t __nv_bf16x8_storage_t; + +constexpr int int4_group_size = 128; +constexpr int mxfp4_group_size = 32; + +inline __device__ unsigned prmt(unsigned hi, unsigned lo, unsigned select_code) +{ + unsigned res = 0; + + asm volatile( + "{\n" + "prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(res) + : "r"(lo), "r"(hi), "r"(select_code)); + + return res; +} + +__device__ __inline__ __nv_fp8x4_storage_t cvt_lut_bf16(unsigned const index) +{ + const __nv_fp8x4_storage_t h4b_lut = 0x03020100U; // 7654 + const __nv_fp8x4_storage_t l4b_lut = 0xFFFEFC00U; // 3210 + + __nv_fp8x4_storage_t lut_res = prmt(h4b_lut, l4b_lut, index); + + return lut_res; +} + +__device__ __inline__ __nv_bf16x8_storage_t psx_cvt_lut_prmt_fp4x8_to_bf16x8(const __nv_fp4x8_storage_t fp4x8) +{ + __nv_bf16x8_storage_t bf16x8_raw = {0, 0}; + __nv_bf16x2_storage_t* bf16x2_raw = reinterpret_cast<__nv_bf16x2_storage_t*>(&bf16x8_raw); + + unsigned zero_padding = 0x00000000U; + + unsigned h4b_em_fp4x4 = (fp4x8 & 0x77770000U) >> 16U; + unsigned l4b_em_fp4x4 = (fp4x8 & 0x00007777U); + + __nv_fp8x4_storage_t h4b_2to9_bits = cvt_lut_bf16(h4b_em_fp4x4); // 7654 + __nv_fp8x4_storage_t l4b_2to9_bits = cvt_lut_bf16(l4b_em_fp4x4); // 3210 + + bf16x2_raw[0] = prmt(zero_padding, l4b_2to9_bits, 0x1707U) >> 2U; // 1 0 + bf16x2_raw[1] = prmt(zero_padding, l4b_2to9_bits, 0x3727U) >> 2U; // 3 2 + bf16x2_raw[2] = prmt(h4b_2to9_bits, zero_padding, 0x5040U) >> 2U; // 5 4 + bf16x2_raw[3] = prmt(h4b_2to9_bits, zero_padding, 0x7060U) >> 2U; // 7 6 + + __nv_bf16x2_storage_t bf16x2_0to1_bits; + + __nv_fp8x4_storage_t h_fp8x2_0to1_bits = (fp4x8 & 0x0000C0C0U); // 3 1 + __nv_fp8x4_storage_t l_fp8x2_0to1_bits = (fp4x8 & 0x00000C0CU) << 4U; // 2 0 + + bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x4707U); // 1 0 + bf16x2_raw[0] = bf16x2_raw[0] | bf16x2_0to1_bits; + bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x5717U); // 3 2 + bf16x2_raw[1] = bf16x2_raw[1] | bf16x2_0to1_bits; + + h_fp8x2_0to1_bits = (fp4x8 & 0xC0C00000U); // 7 5 + l_fp8x2_0to1_bits = (fp4x8 & 0x0C0C0000U) << 4U; // 6 4 + + bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x6020U); // 5 4 + bf16x2_raw[2] = bf16x2_raw[2] | bf16x2_0to1_bits; + bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x7030U); // 7 6 + bf16x2_raw[3] = bf16x2_raw[3] | bf16x2_0to1_bits; + + return bf16x8_raw; +} + template struct MixedGroupedGemmInputUtils { @@ -46,6 +118,7 @@ struct MixedGroupedGemmInputUtils static constexpr auto KernelConversionMode = Collective::KernelConversionMode; static constexpr auto ModeHasScales = Collective::ModeHasScales; static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable; + static constexpr auto UseFP4ToBF16LookupTable = Collective::UseFP4ToBF16LookupTable; public: static constexpr auto elements_per_smem_scale() @@ 
-239,6 +312,27 @@ struct MixedGroupedGemmInputUtils } } + // The core converter uses a lookup table to converts i4 -> 8 bit value. + template + CUTLASS_DEVICE static void fp4tobf16_lookup_table_convert( // Accept mutable temporaries + Tensor const& src, Tensor&& dst) + { + fp4tobf16_lookup_table_convert(src, dst); + } + + template + CUTLASS_DEVICE static void fp4tobf16_lookup_table_convert( + Tensor const& src, Tensor& dst) + { + + // View the input as reg + auto&& src_ = cute::recast<__nv_fp4x8_storage_t>(src)(0); + auto&& dst_ = cute::recast<__nv_bf16x8_storage_t>(dst)(0); + + dst_ = psx_cvt_lut_prmt_fp4x8_to_bf16x8(src_); + } + /// Utilities to dequantize A. template CUTLASS_DEVICE static void static_check_scale(Layout const& tensor) @@ -253,7 +347,6 @@ struct MixedGroupedGemmInputUtils static_check_scale(flatten(Layout{})); } - // dequantize_A_kblock is here!!! template CUTLASS_DEVICE static void dequantize_A_kblock(Tensor const& tCrA_load, Tensor& tCrA_mma, cute::tuple& partitioned_extra_info, int const k_block) @@ -288,8 +381,6 @@ struct MixedGroupedGemmInputUtils } else if constexpr (UseScaleLookupTable) { - // this path - constexpr int num_elements = decltype(size(src))::value; static_assert(is_same_v, "Lookup table only supports int4 being the quant type now."); @@ -424,7 +515,6 @@ struct MixedGroupedGemmInputUtils static_assert(size_v == cosize_v); static_assert(size_v == cosize_v); using SrcType = typename EngineIn::value_type; - using DstType = typename EngineOut::value_type; Tensor src = tCrA_load(_, _, k_block); Tensor dst = tCrA_mma(_, _, k_block); @@ -441,7 +531,14 @@ struct MixedGroupedGemmInputUtils CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<1>(dst_vm); ++i) { - LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); + if constexpr (UseFP4ToBF16LookupTable) + { + fp4tobf16_lookup_table_convert(src_vm(_, i), dst_vm(_, i)); + } + else + { + LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); + } } } diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp deleted file mode 100644 index 09ae3e013ee..00000000000 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp +++ /dev/null @@ -1,568 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/*! \file - \brief Functor performing elementwise operations used by epilogues. 
-*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/fast_math.h" - -#include "cute/numeric/numeric_types.hpp" -#include "cute/tensor.hpp" -#include "cutlass/trace.h" - -#include "cutlass_extensions/arch/copy_red_global.hpp" -#include "cutlass_extensions/util/gather_tensor.hpp" - -#include "cutlass/epilogue/collective/builders/sm90_builder.inl" -#include "cutlass/epilogue/collective/builders/sm90_common.inl" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace epilogue -{ -namespace collective -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -class EpilogueMoeFusedFinalize -{ -public: - using EpilogueSchedule = PtrArrayNoSmemWarpSpecialized; - using DispatchPolicy = PtrArrayNoSmemWarpSpecialized; - - using ThreadEpilogueOp = ThreadEpilogueOp_; - using ElementOutput = typename ThreadEpilogueOp::ElementOutput; - using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; - using ElementCompute = typename ThreadEpilogueOp::ElementCompute; - using ElementIntermediate = typename ThreadEpilogueOp::ElementD; - - using ElementC = typename ThreadEpilogueOp::ElementC; - using StrideC = StrideC_; - using InternalStrideC = cute::remove_pointer_t; - using ElementD = ElementD_; - using StrideD = StrideD_; - using InternalStrideD = cute::remove_pointer_t; - - static_assert(!is_same_v, "Stride C must be a pointer"); - static_assert(is_same_v, "Stride D must not be a pointer"); - - using CopyAtomR2S = Copy_Atom; - using CopyAtomS2R = Copy_Atom; - using CopyAtomR2G = Copy_Atom; - static constexpr int AlignmentD = CopyAtomR2G::NumValSrc; - - using SmemLayoutD = decltype(tile_to_shape(SmemLayoutAtomD{}, EpilogueTile{})); - - constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); - - struct SharedStorage - { - alignas(SmemAlignmentD) cute::ArrayEngine> smem_D; - }; - - struct TensorMapStorage - { - }; - - struct Arguments - { - typename ThreadEpilogueOp::Params thread{}; - ElementC const** ptr_C{}; - StrideC dC{}; - ElementD* ptr_D{}; - StrideD dD{}; - ElementBias const* ptr_bias; - StrideBias dBias{}; - ElementScale const* ptr_scale; - StrideScale dScale{}; - int64_t const* group_offset{}; - int32_t const* scatter_index{}; - cutlass::FastDivmod num_rows_in_final_output; - }; - - using Params = Arguments; - - // - // Methods - // - - template - static constexpr Params to_underlying_arguments( - ProblemShape const&, Arguments const& args, [[maybe_unused]] void* workspace) - { - return args; - } - - template - static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0) - { - return 0; - } - - template - static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, - void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) - { - return cutlass::Status::kSuccess; - } - - template - CUTLASS_HOST_DEVICE static bool can_implement( - [[maybe_unused]] ProblemShape problem_shape, [[maybe_unused]] Arguments const& args) - { - bool implementable = true; - if (problem_shape.is_host_problem_shape_available()) - { - // Check alignment for all problem sizes - for (int i = 0; i < problem_shape.groups(); i++) - { - auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); - auto [M, N, K, L] = problem_shape_MNKL; - 
implementable = implementable - && cutlass::detail::check_alignment(cute::make_shape(M, N, L), InternalStrideD{}); - } - } - - if (!implementable) - { - CUTLASS_TRACE_HOST( - " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for selected global " - "reduction instruction.\n"); - } - return implementable; - } - - CUTLASS_HOST_DEVICE - EpilogueMoeFusedFinalize(Params const& params_) - : params(params_) - { - } - - CUTLASS_DEVICE - bool is_source_needed() - { - // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. - return params.ptr_C != nullptr - && (params.thread.beta_ptr_array || params.thread.beta_ptr || params.thread.beta != 0); - } - - template - CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK, - BlockCoordMNKL blk_coord_mnkl, cute::Tensor const& accumulators, TiledMma tiled_mma, - ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf) - { - using namespace cute; - using X = Underscore; - - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(is_static::value, "ThreadBlock tile shape must be static"); - static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); - static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); - - auto synchronize = [&]() - { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - - // Separate out problem shape for convenience - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - - auto mma_tile_m = tile_size<0>(tiled_mma); - auto mma_tile_n = tile_size<1>(tiled_mma); - auto epi_tile_m = size<0>(EpilogueTile{}); - auto epi_tile_n = size<1>(EpilogueTile{}); - - CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); - CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); - - // Batches are managed by using appropriate pointers to C and D matrices - int32_t const mock_L = 1; - int32_t const mock_l_coord = 0; - - // Slice to get the tile this CTA is responsible for - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; - - // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. - // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, - // we get the correct alpha/beta values for the current batch/group using group index. - ThreadEpilogueOp epilogue_op(params.thread, l_coord); - - SharedStorage& storage = *reinterpret_cast(smem_buf); - - Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{}); - Tensor sD = as_position_independent_swizzle_tensor(sD_); - - // Function to scatter output rows - auto& num_rows = params.num_rows_in_final_output; - auto read_scatter_map = tensorrt_llm::cutlass_extensions::IndexedGather( - make_gmem_ptr(params.scatter_index + params.group_offset[l_coord])); - auto get_scatter_idx = [&](auto i) - { - auto scatter = read_scatter_map(i); - int quot, rem; - num_rows(quot, rem, scatter); - return rem; - }; - - // Represent the full output tensor - ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr; - auto dC = epilogue_op.is_source_needed() ? 
params.dC[l_coord] : InternalStrideC{}; - Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC); // (m,n,l) - Tensor mD_mnl = tensorrt_llm::cutlass_extensions::make_gather_tensor( - make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l) - - // Use fake shape for bias, it doesn't matter - bool const is_bias_needed = params.ptr_bias != nullptr; - Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias); - Tensor mScale_mnl = make_tensor( - make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale); - - Tensor gC_mnl - = local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gD_mnl - = local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - - Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N) - Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N) - - Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - Tensor gBias_mnl - = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gScale_mnl - = local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - - Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N) - Tensor gScale = gScale_mnl(_, _, m_coord, n_coord); // (BLK_M,BLK_N) - - Tensor gBias_epi = flat_divide(gBias, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - // Get the smallest tiled copy we can use to retile the accumulators - TiledCopy tiled_copy_C_atom - = make_tiled_copy_C_atom(Copy_Atom{}, tiled_mma); - TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom); - - auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); - Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) - Tensor tRS_sD = thread_r2s.partition_D(sD); // ((R2S,R2S_V),R2S_M,R2S_N) - Tensor tRS_rD = make_tensor(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N) - - // Make a tiled copy vectorized along major direction of D - auto tiled_s2r = [&]() - { - if constexpr (cutlass::gemm::detail::is_k_major()) - { - constexpr int NumThreadsMajor = epi_tile_n / AlignmentD; - constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; - return make_tiled_copy(CopyAtomS2R{}, - Layout, Int>, Stride, _1>>{}, - Layout>>{}); - } - else if constexpr (cutlass::gemm::detail::is_mn_major()) - { - constexpr int NumThreadsMajor = epi_tile_m / AlignmentD; - constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; - return make_tiled_copy(CopyAtomS2R{}, - Layout, Int>, Stride<_1, Int>>{}, - Layout, _1>>{}); - } - else - { - static_assert(cute::is_void_v, "Unsupported D gmem layout."); - } - }(); - - auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); - Tensor tSR_sD = thread_s2r.partition_S(sD); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_gD = thread_s2r.partition_D(gD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - Tensor tSR_gC = thread_s2r.partition_D(gC_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - Tensor tSR_gScale = 
thread_s2r.partition_D(gScale_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - - // Allocate intermediate registers for a single subtile - Tensor tSR_rD = make_tensor(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rD_final = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rC = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rBias = make_tensor(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rScale = make_tensor(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) - - // Make an identity coordinate tensor for predicating our output MN tile - Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); - Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor tSR_cD = thread_s2r.partition_D(cD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - - // epilogue subtile loop - CUTLASS_PRAGMA_UNROLL - for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) - { - CUTLASS_PRAGMA_UNROLL - for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) - { - int mma_m = (epi_m * epi_tile_m) / mma_tile_m; - int mma_n = (epi_n * epi_tile_n) / mma_tile_n; - Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n); - - int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); - int r2s_v = epi_n_in_mma * size(tRS_rD); - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v) - { - tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v); - } - - copy(tiled_r2s, tRS_rD, tRS_sD); - synchronize(); - - copy(tiled_s2r, tSR_sD, tSR_rD); - synchronize(); - - Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n); - Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n); - Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n); - Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n); - Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n); - - if (epilogue_op.is_source_needed()) - { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tSR_rD); ++m) - { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tSR_rD); ++n) - { - if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) - { - copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n)); - if (is_bias_needed) - { - copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n)); - } - copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(tSR_rD); ++i) - { - auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n)); - if (is_bias_needed) - { - epi_value += static_cast(tSR_rBias(i, m, n)); - } - tSR_rD_final(i, m, n) = static_cast(tSR_rScale(i, m, n) * epi_value); - } - copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n)); - } - } - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tSR_rD); ++m) - { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tSR_rD); ++n) - { - if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) - { - if (is_bias_needed) - { - copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n)); - } - copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(tSR_rD); ++i) - { - auto epi_value = epilogue_op(tSR_rD(i, m, n)); - if (is_bias_needed) - { - epi_value += static_cast(tSR_rBias(i, m, n)); - } - tSR_rD_final(i, m, n) = static_cast(tSR_rScale(i, m, n) * epi_value); - } - copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n)); - } - } - } - } - } - } - } - -private: - Params params; -}; - -namespace detail -{ - -template 
-constexpr auto get_vectorized_atomic_add_op() -{ - using namespace cute; - - auto constexpr MaxVecSize = size(MaxVec{}); - - if constexpr (is_same_v) - { - if constexpr (MaxVecSize >= 8) - { - return SM90_RED_ADD_NOFTZ_F16x2_V4{}; - } - else if constexpr (MaxVecSize >= 4) - { - return SM90_RED_ADD_NOFTZ_F16x2_V2{}; - } - else if constexpr (MaxVecSize >= 2) - { - return SM70_RED_ADD_NOFTZ_F16x2{}; - } - else - { - return SM70_RED_ADD_NOFTZ_F16{}; - } - } - else if constexpr (is_same_v) - { - if constexpr (MaxVecSize >= 8) - { - return SM90_RED_ADD_NOFTZ_BF16x2_V4{}; - } - else if constexpr (MaxVecSize >= 4) - { - return SM90_RED_ADD_NOFTZ_BF16x2_V2{}; - } - else if constexpr (MaxVecSize >= 2) - { - return SM90_RED_ADD_NOFTZ_BF16x2{}; - } - else - { - return SM90_RED_ADD_NOFTZ_BF16{}; - } - } - else - { - // non-vectorized atomic add for all other types until supported - return TypedAtomicAdd{}; - } -} - -} // namespace detail - -template -struct EpilogueMoeFusedFinalizeBuilder -{ - - // assuming cooperative kernel schedule - using EpiTileN = decltype(cute::min(size<1>(TileShape{}), _32{})); - using EpilogueTile = Shape<_128, EpiTileN>; - - // Output of linear combination is ElementCompute instead of ElementD - // since we will be doing more computate on it, no need to cast yet. - using ThreadEpilogueOp - = cutlass::epilogue::thread::LinearCombination; - - using SmemLayoutAtomD - = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()); - using CopyAtomR2S - = decltype(detail::sm90_get_smem_store_op_for_accumulator()); - using CopyAtomS2R = DefaultCopy; - using CopyAtomR2G = decltype(detail::get_vectorized_atomic_add_op()); - - template - struct TmaWarpSpecializedAdapterWithSmemStorageImpl : Base - { - // We need to override this one using declaration because otherwise we double up on the smem - using TensorMapStorage = typename EpilogueOp::TensorMapStorage; - - // using Base = detail::Sm90TmaWarpSpecializedAdapter; - - CUTLASS_HOST_DEVICE - TmaWarpSpecializedAdapterWithSmemStorageImpl( - typename EpilogueOp::Params const& params, [[maybe_unused]] typename Base::TensorStorage& shared_tensors) - : Base(params) - { - } - - CUTLASS_DEVICE auto load_init([[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] int32_t sm_count, - [[maybe_unused]] int32_t sm_idx) - { - return cute::make_tuple(nullptr); - } - - CUTLASS_DEVICE auto store_init([[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] int32_t sm_count, - [[maybe_unused]] int32_t sm_idx, [[maybe_unused]] int32_t warp_group_idx) - { - return cute::make_tuple(nullptr); - } - - // Dummy methods to perform different parts of TMA/Tensormap modifications - - template - CUTLASS_DEVICE void tensormaps_perform_update([[maybe_unused]] TensorMapStorage& shared_tensormaps, - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] ProblemShapeMNKL problem_shape, - [[maybe_unused]] int32_t next_batch, [[maybe_unused]] int32_t warp_group_idx) - { - } - - template - CUTLASS_DEVICE void tensormaps_cp_fence_release([[maybe_unused]] TensorMapStorage& shared_tensormaps, - [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] int32_t warp_group_idx) - { - } - - template - CUTLASS_DEVICE void tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap) - { - } - }; - - template - using 
TmaWarpSpecializedAdapterWithSmemStorage = TmaWarpSpecializedAdapterWithSmemStorageImpl< - std::conditional_t= 100, detail::Sm100TmaWarpSpecializedAdapter, - detail::Sm90TmaWarpSpecializedAdapter>, - EpilogueOp>; - - using CollectiveOp = TmaWarpSpecializedAdapterWithSmemStorage< - EpilogueMoeFusedFinalize>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp new file mode 100644 index 00000000000..3571906a64f --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp @@ -0,0 +1,547 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Visitor tree store operations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +#include "cutlass_extensions/arch/copy_red_global.hpp" +#include "cutlass_extensions/util/gather_tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// clang-format off + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +template < + class EpilogueTile, + class StrideOutput, + class SmemLayoutAtom, + class CopyOpR2S, + class ElementOutput, + int AlignmentOutput = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct Sm90ScatterPtrArray { + + using SmemShape = decltype(make_shape(size(make_layout(get<0>(EpilogueTile{}))), size(make_layout(get<1>(EpilogueTile{}))))); + using SmemLayout = decltype(tile_to_shape(SmemLayoutAtom{}, SmemShape{})); + + using ElementIndex = int32_t; + // TODO: more generic treatment, or pass StrideIndex via template param? + using StrideIndex = conditional_t(), Stride<_0,_1,_0>, Stride<_1,_0,_0>>; + + struct SharedStorage {}; + + struct Arguments { + ElementOutput* ptr_out = nullptr; + StrideOutput dOut = {}; + ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index + int index_modulo{}; // modulo used to transform the index before store + bool use_reduction = true; + }; + + struct Params { + ElementOutput* ptr_out = nullptr; + StrideOutput dOut = {}; + ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index + cutlass::FastDivmod index_divmod{}; // modulo used to transform the index before store + bool use_reduction = true; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return { + args.ptr_out, + args.dOut, + args.ptr_index, + cutlass::FastDivmod(args.index_modulo), + args.use_reduction + }; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90ScatterPtrArray() { } + + CUTLASS_HOST_DEVICE + Sm90ScatterPtrArray(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template< + class ArgsTuple + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(ArgsTuple&& args_tuple) + : args_tuple(std::move(args_tuple)) {} + + ArgsTuple args_tuple; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + + auto& [tC_rOut, tiled_r2s, tRG_gOut, tRG_cD, tiled_r2g_red, 
tiled_r2g_stg, use_reduction, thread_idx, residue_cD] = args_tuple; + + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + Tensor tC_rOut_frg = recast>(coalesce(tC_rOut)); // (EPI_V) + tC_rOut_frg(epi_v) = convert_input(frg_input); + + return tC_rOut_frg(epi_v); + } + + template + CUTLASS_DEVICE void + reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + + auto& [tC_rOut, tiled_r2s, tRG_gOut, tRG_cD, tiled_r2g_red, tiled_r2g_stg, use_reduction, thread_idx, residue_cD] = args_tuple; + + Tensor byte_buffer = recast(reduction_buffer); + static_assert(cosize(byte_buffer.layout()) * sizeof_bits_v >= cosize(SmemLayout{}) * sizeof_bits_v, + "Not enough space in scratch smem buffer"); + + Tensor sOut = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(recast_ptr(byte_buffer.data())), SmemLayout{})); + + auto thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_sOut_epi = thread_r2s.partition_D(sOut); + Tensor tRS_rOut_epi = thread_r2s.retile_S(tC_rOut); + + auto thread_r2g = tiled_r2g_red.get_slice(thread_idx); + Tensor tRG_gOut_epi = tRG_gOut(_,_,_,epi_m,epi_n); + Tensor tRG_sOut_epi = thread_r2g.partition_D(sOut); + Tensor tRG_rOut_epi = thread_r2g.retile_S(make_tensor(tC_rOut.data(), shape(tRG_sOut_epi))); // reuse D registers + + // sanity check for register reuse + CUTE_STATIC_ASSERT_V(cosize(tC_rOut.layout()) == cosize(tRG_rOut_epi.layout()), "Invalid register count for R2G"); + + copy(tiled_r2s, tRS_rOut_epi, tRS_sOut_epi); + sync_fn(); + copy(tRG_sOut_epi, tRG_rOut_epi); + + auto residue = residue_cD; // capturing structured bindings is a C++20 feature + Tensor tRG_cD_epi = tRG_cD(0,_,_,epi_m,epi_n); + auto pred = cute::lazy::transform(tRG_cD_epi, [&](auto c){ return elem_less(c, residue); }); + + if (use_reduction) { + copy_if(tiled_r2g_red, pred, tRG_rOut_epi, tRG_gOut_epi); + } + else { + copy_if(tiled_r2g_stg, pred, tRG_rOut_epi, tRG_gOut_epi); + } + } + }; + + template + static constexpr auto get_reduction_op() + { + using namespace cute; + + // For now only support red.add + if constexpr (is_same_v) { + if constexpr (MaxVecSize % 8 == 0) { + return SM90_RED_ADD_NOFTZ_F16x2_V4{}; + } + else if constexpr (MaxVecSize % 4 == 0) { + return SM90_RED_ADD_NOFTZ_F16x2_V2{}; + } + else if constexpr (MaxVecSize % 2 == 0) { + return SM70_RED_ADD_NOFTZ_F16x2{}; + } + else { + return SM70_RED_ADD_NOFTZ_F16{}; + } + } + else if constexpr (is_same_v) { + if constexpr (MaxVecSize % 8 == 0) { + return SM90_RED_ADD_NOFTZ_BF16x2_V4{}; + } + else if constexpr (MaxVecSize % 4 == 0) { + return SM90_RED_ADD_NOFTZ_BF16x2_V2{}; + } + else if constexpr (MaxVecSize % 2 == 0) { + return SM90_RED_ADD_NOFTZ_BF16x2{}; + } + else { + return SM90_RED_ADD_NOFTZ_BF16{}; + } + } + else { + // non-vectorized atomic add for all other types until supported + return TypedAtomicAdd{}; + } + } + + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + auto index_read = [index = params_ptr->ptr_index[l], divmod = params_ptr->index_divmod](auto i){ return divmod.rem(index[i]); }; + Tensor mOut = cutlass::util::make_gather_tensor(params_ptr->ptr_out, make_shape(M,N,Int<1>{}), params_ptr->dOut, index_read); // (M,N,_1) + Tensor gOut = local_tile(mOut, take<0,2>(args.tile_shape_mnk), make_coord(m,n,Int<0>{})); // (CTA_M,CTA_N) + Tensor gOut_epi = flat_divide(gOut, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor mIdx = make_tensor(params_ptr->ptr_index[l], make_shape(M,N,Int<1>{}), StrideIndex{}); // (M,N,_1) + Tensor gIdx = local_tile(mIdx, take<0,2>(args.tile_shape_mnk), make_coord(m,n,Int<0>{})); // (CTA_M,CTA_N) + Tensor gIdx_epi = flat_divide(gIdx, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor cD_epi = flat_divide(args.cD, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor tC_gOut = sm90_partition_for_epilogue(gOut, args.epi_tile, args.tiled_copy, args.thread_idx); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Tensor tC_rOut = make_tensor(take<0,3>(shape(tC_gOut))); // (CPY,CPY_M,CPY_N) + + auto tiled_r2s = conditional_return( + make_tiled_copy_S(Copy_Atom{}, args.tiled_copy), + make_tiled_copy_D(Copy_Atom{}, args.tiled_copy) + ); + + // Vectorization must not exceed alignment and also the number of values per thread in the tile + int constexpr NumThreads = CUTE_STATIC_V(size(args.tiled_copy)); + int constexpr NumValTile = product(take<0,2>(shape(cD_epi))); + int constexpr MaxVecSize = cute::min(AlignmentOutput, NumValTile / NumThreads); + + // Choose the largest available red.global op and an st.global op with matching vectorization + using CopyOpR2GRed = decltype(get_reduction_op()); + using CopyOpR2GStg = UniversalCopy::NumValSrc * sizeof_bits_v>>; + + auto make_tiled_r2g = [&](auto copy_op) + { + using CopyAtomR2G = Copy_Atom; + constexpr int VecSize = CopyAtomR2G::NumValSrc; + if constexpr (cutlass::gemm::detail::is_k_major()) { + constexpr int ThreadsMajor = size<1>(args.epi_tile) / VecSize; + constexpr int ThreadsMinor = NumThreads / ThreadsMajor; + return make_tiled_copy(CopyAtomR2G{}, + Layout, Int>, Stride, _1>>{}, + Layout>>{}); + } + else if constexpr (cutlass::gemm::detail::is_mn_major()) { + constexpr int ThreadsMajor = size<0>(args.epi_tile) / VecSize; + constexpr int ThreadsMinor = NumThreads / ThreadsMajor; + return make_tiled_copy(CopyAtomR2G{}, + Layout, Int>, Stride<_1, Int>>{}, + Layout, _1>>{}); + } + else { + static_assert(cute::is_void_v, "Unsupported D gmem layout."); + } + }; + + auto tiled_r2g_red = make_tiled_r2g(CopyOpR2GRed{}); + auto tiled_r2g_stg = make_tiled_r2g(CopyOpR2GStg{}); + + // Sanity checks - since we will be using one tiled copy with tensors partitioned with the other tiled copy, + // ensure they have matching layouts/tilers + using TiledR2GRed = decltype(tiled_r2g_red); + using TiledR2GStg = decltype(tiled_r2g_stg); + static_assert(typename TiledR2GRed::AtomLayoutSrc{} == typename TiledR2GStg::AtomLayoutSrc{}, "Mismatching AtomLayoutSrc"); + static_assert(typename TiledR2GRed::AtomLayoutDst{} == typename TiledR2GStg::AtomLayoutDst{}, "Mismatching AtomLayoutDst"); + static_assert(typename TiledR2GRed::TiledLayout_TV{} == typename TiledR2GStg::TiledLayout_TV{}, "Mismatching TiledLayout_TV"); + static_assert(typename TiledR2GRed::Tiler_MN{} == typename 
TiledR2GStg::Tiler_MN{}, "Mismatching Tiler_MN"); + + auto thread_r2g = tiled_r2g_red.get_slice(args.thread_idx); + Tensor tRG_gOut = thread_r2g.partition_D(gOut_epi); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N) + Tensor tRG_cD = thread_r2g.partition_D(cD_epi); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N) + + auto args_tuple = make_tuple( + cute::move(tC_rOut), + tiled_r2s, + tRG_gOut, + tRG_cD, + tiled_r2g_red, + tiled_r2g_stg, + params_ptr->use_reduction, + args.thread_idx, + args.residue_cD); + + return ConsumerStoreCallbacks(std::move(args_tuple)); + } +}; + +template< + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledAccPerRowBias + : ScaledAcc +{ + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsPerRowBiasSupported = true; +}; + +template< + class GmemLayoutTagOut, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScale = ElementCompute, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / cute::sizeof_bits_v, + int AlignmentOutput = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct ScaledAccPerRowBiasPerColScaleScatter + : ScaledAccPerRowBias +{ + using ElementAux = ElementOutput; + using GmemLayoutTagAux = GmemLayoutTagOut; + static constexpr int AlignmentAux = AlignmentOutput; + static constexpr bool IsAuxOutSupported = true; +}; + +// D = alpha * acc + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledAccPerRowBiasPtrArray = + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcastPtrArray>, // alpha + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias *, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias + >; + +template< + class CtaTileShapeMNK, + class EpilogueTile, + class StrideOutput, + class SmemLayoutAtom, + class CopyOpR2S, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScale = ElementCompute, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / cute::sizeof_bits_v, + int AlignmentOutput = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray = + Sm90EVT, // scatter store + Sm90EVT, // scale * (alpha * acc + bias) + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar *, ElementCompute, Stride<_0,_1,int64_t>, 1>, // scale + Sm90ScaledAccPerRowBiasPtrArray // alpha * acc + bias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class GmemLayoutTagOut, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementScale, + class ElementScalar, + int AlignmentBias, + int AlignmentOutput, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + fusion::ScaledAccPerRowBiasPerColScaleScatter, + CtaTileShapeMNK, + EpilogueTile, + 
SmemLayoutAtom, + CopyOpR2S +> : Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray< + CtaTileShapeMNK, + EpilogueTile, + cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, + ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar, + AlignmentBias, AlignmentOutput, RoundStyle + > { + + using StrideOutput = cutlass::gemm::TagToStrideC_t; + + using Impl = Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray< + CtaTileShapeMNK, + EpilogueTile, + StrideOutput, + SmemLayoutAtom, CopyOpR2S, + ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar, + AlignmentBias, AlignmentOutput, RoundStyle + >; + using Operation = fusion::ScaledAccPerRowBiasPerColScaleScatter< + GmemLayoutTagOut, + ElementOutput, + ElementCompute, + ElementBias, + ElementScale, + ElementScalar, + AlignmentBias, + AlignmentOutput, + RoundStyle>; + + struct Arguments { + + using StrideAlpha = Stride<_0,_0,int64_t>; + ElementScalar alpha = ElementScalar(1); + ElementScalar const* alpha_ptr{}; + ElementScalar const* const* alpha_ptr_array{}; + StrideAlpha dAlpha{}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* const* bias_ptr{}; + StrideBias dBias{}; + + using StrideScale = Stride<_0,_1,int64_t>; + ElementScalar const* const* scale_ptr_array{}; + StrideScale dScale{}; + + // Nested args not usable due to a compiler bug with constexpr evaluation + // using ScatterArguments = typename Sm90ScatterPtrArray::Arguments; + // ScatterArguments scatter{}; + + ElementOutput* ptr_out = nullptr; + StrideOutput dOut = {}; + int const* const* ptr_index{}; // per-group pointer to the scatter index + int index_modulo{}; // modulo used to transform the index before store + bool use_reduction = true; + + operator typename Impl::Arguments() const { + return + { // unary op: reduce(scale * (beta * C + (alpha * acc))) + { // binary op: scale * (beta * C + (alpha * acc)) + { scale_ptr_array, ElementScalar(1), dScale }, // leaf args : scale broadcast + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end binary op + {} // binary args: multiply + }, // end binary op + //scatter // unary args: reduce + { ptr_out, dOut, ptr_index, index_modulo, use_reduction } + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; + +}; + +} // namespace cutlass::epilogue::fusion + +// clang-format on diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp index 1ee109fd648..2332950629f 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp @@ -30,37 +30,12 @@ #include "cute/atom/mma_atom.hpp" #include "cute/numeric/arithmetic_tuple.hpp" -#define GROUP_SIZE 128 - ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::collective { using namespace cute; -template -CUTE_HOST_DEVICE void warpgroup_wait_() -{ -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - 
cutlass::arch::synclog_emit_warpgroup_wait(__LINE__, N); - asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(N) : "memory"); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.wait_group without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif -} - -CUTLASS_DEVICE void warpgroup_wait_dispatch(int onthefly_count) -{ - switch (onthefly_count) - { - case 0: warpgroup_wait_<0>(); break; - case 4: warpgroup_wait_<4>(); break; - case 8: warpgroup_wait_<8>(); break; - case 12: warpgroup_wait_<12>(); break; - default: assert(false && "Invalid onthefly_count value"); - } -} - ///////////////////////////////////////////////////////////////////////////////////////////////// // WarpSpecialized Mainloop @@ -91,7 +66,7 @@ struct CollectiveMmaArrayMixedInput< private: template friend struct detail::MixedGroupedGemmInputUtils; - using CollectiveType = CollectiveMma; using Utils = detail::MixedGroupedGemmInputUtils; @@ -146,6 +121,11 @@ struct CollectiveMmaArrayMixedInput< static_assert(cutlass::gemm::detail::is_mn_major(), "Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled]."); + static constexpr bool IsMXFP4 = cute::is_same_v; + // Group size 128 for int4 weights + // Group size 32 for mxfp4 weights + static constexpr int ScalingGroupSize = IsMXFP4 ? detail::mxfp4_group_size : detail::int4_group_size; + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); using TiledMma = TiledMma_; using ElementAccumulator = typename TiledMma::ValTypeC; @@ -268,6 +248,8 @@ struct CollectiveMmaArrayMixedInput< || KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; static constexpr bool UseScaleLookupTable = KernelConversionMode == ConversionMode::ConvertAndScale && cutlass::detail::is_Array_v; + static constexpr bool UseFP4ToBF16LookupTable = KernelConversionMode == ConversionMode::ConvertAndScale + && cute::is_same_v && cute::is_same_v; static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); @@ -705,7 +687,7 @@ struct CollectiveMmaArrayMixedInput< { // The real scale_k that actually works // auto scale_k = K / mainloop_params.chunk_size; - auto scale_k = K / GROUP_SIZE; + auto scale_k = K / ScalingGroupSize; Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M, scale_k, L)); // (m,scale_k,l) Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _)); // (BLK_M,BLK_Scale_K,m,scale_k,l) @@ -872,7 +854,6 @@ struct CollectiveMmaArrayMixedInput< } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - // zero copy auto tZgZ = get<2>(extra_input_partitions); auto tZsZ = get<3>(extra_input_partitions); if (cute::elect_one_sync()) @@ -979,7 +960,8 @@ struct CollectiveMmaArrayMixedInput< return make_tensor_like(tCsA(_, _, _, Int<0>{})); } }(); - Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + // tCrB is just a view of the tensor tCsB Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) // @@ -1013,8 +995,8 @@ struct CollectiveMmaArrayMixedInput< multiply_add fma; - constexpr int NumMMAsPerChunk = GROUP_SIZE / cute::get<0, 1>(tCsB.shape())(); - constexpr int NumChunksPerTileK = cute::size<1>(sA.shape())() / GROUP_SIZE; + constexpr int 
NumMMAsPerChunk = ScalingGroupSize / cute::get<0, 1>(tCsB.shape())(); + constexpr int NumChunksPerTileK = cute::size<1>(sA.shape())() / ScalingGroupSize; cute::array intermediate_array; constexpr int K_BLOCK_MAX = size<2>(tCrA_load); @@ -1045,8 +1027,6 @@ struct CollectiveMmaArrayMixedInput< // src: tCrA_load, dst: tCrA_mma Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0); - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int chunk_id = 0; chunk_id < NumChunksPerTileK; ++chunk_id) @@ -1079,10 +1059,11 @@ struct CollectiveMmaArrayMixedInput< } } + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) { - warpgroup_wait_dispatch((NumChunksPerTileK - chunk_id_ - 1) * NumMMAsPerChunk); warpgroup_fence_operand(intermediate_array[chunk_id_]); // Apply the group-wise scaling @@ -1129,7 +1110,6 @@ struct CollectiveMmaArrayMixedInput< Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); - warpgroup_wait(); Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0); } } @@ -1169,8 +1149,6 @@ struct CollectiveMmaArrayMixedInput< tiled_mma.accumulate_ = GMMA::ScaleOut::One; warpgroup_commit_batch(); - warpgroup_wait(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, - // so we can release prior barrier if (k_block == K_BLOCK_MAX - 1) { pipeline.consumer_release( @@ -1187,10 +1165,11 @@ struct CollectiveMmaArrayMixedInput< { // The last k_block + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) { - warpgroup_wait_dispatch((NumChunksPerTileK - chunk_id_ - 1) * NumMMAsPerChunk); warpgroup_fence_operand(intermediate_array[chunk_id_]); // Apply the group-wise scaling @@ -1257,7 +1236,6 @@ struct CollectiveMmaArrayMixedInput< tiled_mma.accumulate_ = GMMA::ScaleOut::One; warpgroup_commit_batch(); - warpgroup_wait(); if (k_block == K_BLOCK_MAX - 1) { // release prior barrier @@ -1318,7 +1296,7 @@ struct CollectiveMmaArrayMixedInput< smem_pipe_release.advance(k_tile_count); // Wait on all GMMAs to complete - warpgroup_wait<0>(); + // warpgroup_wait<0>(); for (int count = 0; count < prologue_mma_count; ++count) { @@ -1462,7 +1440,7 @@ struct CollectiveMmaArrayMixedInput< { NonVoidElementScale const* ptr_S = nullptr; // auto scale_k = K / mainloop_params.chunk_size; - auto scale_k = K / GROUP_SIZE; + auto scale_k = K / ScalingGroupSize; Tensor tensor_scale = make_tensor( detail::get_logical_ptr(ptr_S), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]); cute::detail::fill_tma_gmem_shape_stride( @@ -1472,7 +1450,7 @@ struct CollectiveMmaArrayMixedInput< { ElementZero const* ptr_Z = nullptr; // auto scale_k = K / mainloop_params.chunk_size; - auto scale_k = K / GROUP_SIZE; + auto scale_k = K / ScalingGroupSize; Tensor tensor_zero = make_tensor( detail::get_logical_ptr(ptr_Z), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]); cute::detail::fill_tma_gmem_shape_stride( diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h index f9355860bec..fe75687e368 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h @@ -133,11 +133,6 @@ enum class 
CutlassTileConfigSM100 CtaShape128x256x128B, CtaShape128x128x256B, CtaShape128x256x256B, - - // M=256 - CtaShape256x64x128B, - CtaShape256x128x128B, - CtaShape256x256x128B, }; enum class CutlassTileConfigSM120 diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp index e529ffc1faa..a83bf6a0830 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp @@ -19,7 +19,7 @@ #include "cute/tensor.hpp" #include "cute/util/print.hpp" -namespace tensorrt_llm::cutlass_extensions +namespace cutlass::util { /// Function object that applies an index to its argument @@ -81,7 +81,7 @@ struct CustomStride template CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div) { - return CustomStride(s.func_, safe_div(s.stride_, div)); + return CustomStride(s.func_, cute::safe_div(s.stride_, div)); } // Circumvent the requirement on make_layout that shape and stride are integral @@ -116,7 +116,7 @@ CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, S Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); } -} // namespace tensorrt_llm::cutlass_extensions +} // namespace cutlass::util namespace cute { diff --git a/cpp/tensorrt_llm/deep_ep/CMakeLists.txt b/cpp/tensorrt_llm/deep_ep/CMakeLists.txt index 088391aef4f..870c1f317c6 100644 --- a/cpp/tensorrt_llm/deep_ep/CMakeLists.txt +++ b/cpp/tensorrt_llm/deep_ep/CMakeLists.txt @@ -1,4 +1,4 @@ -set(DEEP_EP_COMMIT edf3ea2b086a393d3163bf2773eab69d9191cc01) +set(DEEP_EP_COMMIT 515a311f290eb6d9592fcccfcc80c40f5123ca72) set(NVSHMEM_URL_HASH SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a) @@ -19,8 +19,15 @@ foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES) set(CUDA_ARCH_MINOR ${CMAKE_MATCH_2}) set(CUDA_ARCH_POSTFIX ${CMAKE_MATCH_3}) if(${CUDA_ARCH_MAJOR} GREATER_EQUAL 9) - list(APPEND DEEP_EP_CUDA_ARCHITECTURES - "${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}${CUDA_ARCH_POSTFIX}") + # The FP4-related conversion instructions in DeepEP require SM100a, SM110a, + # or SM120a. 
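// Illustrative sketch, not taken from this patch: the CMake branch below opts SM100-class GPUs
// into the architecture-specific "a" target because DeepEP's FP4 conversion paths compile only
// when arch-specific features are enabled. A hedged example of the kind of compile-time guard
// such device code relies on; the macro shown is the arch-feature macro nvcc defines for
// sm_100a-family targets (analogous macros exist for the other "a" targets), and the guarded
// body is a placeholder comment only.
#if defined(__CUDA_ARCH__) && defined(__CUDA_ARCH_FEAT_SM100_ALL)
    // FP4 (e2m1) conversion instructions may be emitted on this compilation path.
#endif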
+ if(${CUDA_ARCH_MAJOR} GREATER_EQUAL 10 AND ${CUDA_ARCH_MINOR} EQUAL 0) + list(APPEND DEEP_EP_CUDA_ARCHITECTURES + "${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}a${CUDA_ARCH_POSTFIX}") + else() + list(APPEND DEEP_EP_CUDA_ARCHITECTURES + "${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}${CUDA_ARCH_POSTFIX}") + endif() endif() endforeach() diff --git a/cpp/tensorrt_llm/executor/executor.cpp b/cpp/tensorrt_llm/executor/executor.cpp index 70ca2be41ab..091bb512823 100644 --- a/cpp/tensorrt_llm/executor/executor.cpp +++ b/cpp/tensorrt_llm/executor/executor.cpp @@ -132,10 +132,12 @@ std::optional> Executor::getKVCacheEventMan return mImpl->getKVCacheEventManager(); } -KVCacheEvent::KVCacheEvent(size_t eventId, KVCacheEventData data, SizeType32 windowSize) +KVCacheEvent::KVCacheEvent( + size_t eventId, KVCacheEventData data, SizeType32 windowSize, std::optional attentionDpRank) : eventId{eventId} , data{std::move(data)} , windowSize{windowSize} + , attentionDpRank{attentionDpRank} { } diff --git a/cpp/tensorrt_llm/executor/kvCacheConfig.cpp b/cpp/tensorrt_llm/executor/kvCacheConfig.cpp index 51b047ebd27..21cf314c875 100644 --- a/cpp/tensorrt_llm/executor/kvCacheConfig.cpp +++ b/cpp/tensorrt_llm/executor/kvCacheConfig.cpp @@ -27,6 +27,7 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional co std::optional const& hostCacheSize, bool onboardBlocks, std::optional const& crossKvCacheFraction, std::optional secondaryOffloadMinPriority, size_t eventBufferMaxSize, bool enablePartialReuse, bool copyOnPartialReuse, bool useUvm, + SizeType32 attentionDpEventsGatherPeriodMs, std::optional const& runtimeDefaults) : mEnableBlockReuse(enableBlockReuse) , mHostCacheSize(hostCacheSize) @@ -36,6 +37,7 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional co , mEnablePartialReuse{enablePartialReuse} , mCopyOnPartialReuse{copyOnPartialReuse} , mUseUvm{useUvm} + , mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs) { if (maxTokens) { @@ -61,6 +63,8 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional co { fillEmptyFieldsFromRuntimeDefaults(runtimeDefaults.value()); } + TLLM_CHECK_WITH_INFO( + mAttentionDpEventsGatherPeriodMs > 0, "Attention DP events gather period must be greater than 0"); } bool KvCacheConfig::getEnableBlockReuse() const @@ -128,6 +132,11 @@ bool KvCacheConfig::getUseUvm() const return mUseUvm; } +SizeType32 KvCacheConfig::getAttentionDpEventsGatherPeriodMs() const +{ + return mAttentionDpEventsGatherPeriodMs; +} + void KvCacheConfig::setEnableBlockReuse(bool enableBlockReuse) { mEnableBlockReuse = enableBlockReuse; @@ -204,6 +213,12 @@ void KvCacheConfig::setUseUvm(bool useUvm) mUseUvm = useUvm; } +void KvCacheConfig::setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEventsGatherPeriodMs) +{ + TLLM_CHECK(attentionDpEventsGatherPeriodMs > 0); + mAttentionDpEventsGatherPeriodMs = attentionDpEventsGatherPeriodMs; +} + void KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults) { if (!mMaxAttentionWindowVec && runtimeDefaults.maxAttentionWindowVec) diff --git a/cpp/tensorrt_llm/executor/loraConfig.cpp b/cpp/tensorrt_llm/executor/loraConfig.cpp index 058b1a86710..c8499f36d4d 100644 --- a/cpp/tensorrt_llm/executor/loraConfig.cpp +++ b/cpp/tensorrt_llm/executor/loraConfig.cpp @@ -27,26 +27,29 @@ LoraConfig::LoraConfig(IdType taskId, std::optional weights, std::option , mWeights(std::move(weights)) , mConfig(std::move(config)) { - if (mWeights.has_value() || mConfig.has_value()) + if 
(mConfig.has_value()) { - TLLM_CHECK_WITH_INFO(mWeights.has_value() && mConfig.has_value(), - "Request for LoRA inference must have both lora weights and lora config"); - - SizeType32 constexpr expectedWeightsDims = 2; SizeType32 constexpr expectedConfigDims = 2; - - TLLM_CHECK_WITH_INFO( - mWeights.value().getShape().size() == expectedWeightsDims, "Expected weights tensor to have 2 dimensions"); TLLM_CHECK_WITH_INFO( mConfig.value().getShape().size() == expectedConfigDims, "Expected config tensor to have 2 dimensions"); - TLLM_CHECK_WITH_INFO(mWeights.value().getMemoryType() != MemoryType::kGPU - && mWeights.value().getMemoryType() != MemoryType::kUNKNOWN, - "Expected lora weights to be in CPU memory"); TLLM_CHECK_WITH_INFO(mConfig.value().getMemoryType() != MemoryType::kGPU && mConfig.value().getMemoryType() != MemoryType::kUNKNOWN, - "Expected lora weights to be in CPU memory"); + "Expected lora config to be in CPU memory"); TLLM_CHECK_WITH_INFO( mConfig.value().getDataType() == DataType::kINT32, "Expected lora config tensor to have type kINT32"); + } + if (mWeights.has_value()) + { + SizeType32 constexpr expectedWeightsDims = 2; + TLLM_CHECK_WITH_INFO( + mConfig.has_value(), "Request for LoRA inference with lora weights must also have lora config"); + + TLLM_CHECK_WITH_INFO( + mWeights.value().getShape().size() == expectedWeightsDims, "Expected weights tensor to have 2 dimensions"); + + TLLM_CHECK_WITH_INFO(mWeights.value().getMemoryType() != MemoryType::kGPU + && mWeights.value().getMemoryType() != MemoryType::kUNKNOWN, + "Expected lora weights to be in CPU memory"); TLLM_CHECK_WITH_INFO(mConfig.value().getShape()[0] == mWeights.value().getShape()[0], "Expected dim 0 of lora weights and lora config to have the same size"); diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index 65718f0405d..38256edbc75 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -23,6 +23,7 @@ #include "tensorrt_llm/executor/serializeUtils.h" #include "tensorrt_llm/executor/types.h" #include "tensorrt_llm/runtime/cudaStream.h" +#include #include #include #include @@ -1162,10 +1163,11 @@ KvCacheConfig Serialization::deserializeKvCacheConfig(std::istream& is) auto secondaryOffloadMinPriority = su::deserialize>(is); auto eventBufferMaxSize = su::deserialize(is); auto useUvm = su::deserialize(is); + auto attentionDpEventsGatherPeriodMs = su::deserialize(is); return KvCacheConfig{enableBlockReuse, maxTokens, maxAttentionWindowVec, sinkTokenLength, freeGpuMemoryFraction, hostCacheSize, onboardBlocks, crossKvCacheFraction, secondaryOffloadMinPriority, eventBufferMaxSize, - enablePartialReuse, copyOnPartialReuse, useUvm}; + enablePartialReuse, copyOnPartialReuse, useUvm, attentionDpEventsGatherPeriodMs}; } void Serialization::serialize(KvCacheConfig const& kvCacheConfig, std::ostream& os) @@ -1183,6 +1185,7 @@ void Serialization::serialize(KvCacheConfig const& kvCacheConfig, std::ostream& su::serialize(kvCacheConfig.getSecondaryOffloadMinPriority(), os); su::serialize(kvCacheConfig.getEventBufferMaxSize(), os); su::serialize(kvCacheConfig.getUseUvm(), os); + su::serialize(kvCacheConfig.getAttentionDpEventsGatherPeriodMs(), os); } size_t Serialization::serializedSize(KvCacheConfig const& kvCacheConfig) @@ -1202,6 +1205,7 @@ size_t Serialization::serializedSize(KvCacheConfig const& kvCacheConfig) totalSize += su::serializedSize(kvCacheConfig.getSecondaryOffloadMinPriority()); totalSize += 
su::serializedSize(kvCacheConfig.getEventBufferMaxSize()); totalSize += su::serializedSize(kvCacheConfig.getUseUvm()); + totalSize += su::serializedSize(kvCacheConfig.getAttentionDpEventsGatherPeriodMs()); return totalSize; } @@ -2181,6 +2185,237 @@ std::vector Serialization::deserializeRequestStatsPerI return iterRequestStatsVec; } +// KVCacheEvents deque +std::vector Serialization::serialize(std::deque const& eventQueue) +{ + // Compute the size of serialized buffer + size_t totalSize = 0; + totalSize += sizeof(size_t); + for (auto const& event : eventQueue) + { + totalSize += su::serializedSize(event); + } + + std::vector buffer(totalSize); + std::stringbuf strbuf(std::ios_base::out | std::ios_base::in); + strbuf.pubsetbuf(buffer.data(), buffer.size()); + std::ostream os(&strbuf); + + su::serialize(eventQueue.size(), os); + for (auto const& event : eventQueue) + { + su::serialize(event, os); + } + return buffer; +} + +std::deque Serialization::deserializeKVCacheEvents(std::vector& buffer) +{ + std::deque kvCacheEvents; + su::VectorWrapBuf strbuf(buffer); + std::istream is(&strbuf); + auto numEvents = su::deserialize(is); + for (std::size_t event = 0; event < numEvents; ++event) + { + kvCacheEvents.emplace_back(Serialization::deserializeKVCacheEvent(is)); + } + return kvCacheEvents; +} + +// KVCacheEvent +size_t Serialization::serializedSize(KVCacheEvent const& event) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(event.eventId); + totalSize += su::serializedSize(event.data); + totalSize += su::serializedSize(event.windowSize); + totalSize += su::serializedSize(event.attentionDpRank); + return totalSize; +} + +void Serialization::serialize(KVCacheEvent const& event, std::ostream& os) +{ + su::serialize(event.eventId, os); + su::serialize(event.data, os); + su::serialize(event.windowSize, os); + su::serialize(event.attentionDpRank, os); +} + +KVCacheEvent Serialization::deserializeKVCacheEvent(std::istream& is) +{ + auto eventId = su::deserialize(is); + auto data = su::deserialize(is); + auto windowSize = su::deserialize(is); + auto attentionDpRank = su::deserialize>(is); + + return KVCacheEvent{eventId, data, windowSize, attentionDpRank}; +} + +// KVCacheCreatedData +size_t Serialization::serializedSize(KVCacheCreatedData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.numBlocksPerCacheLevel); + return totalSize; +} + +void Serialization::serialize(KVCacheCreatedData const& data, std::ostream& os) +{ + su::serialize(data.numBlocksPerCacheLevel, os); +} + +KVCacheCreatedData Serialization::deserializeKVCacheCreatedData(std::istream& is) +{ + auto numBlocksPerCacheLevel = su::deserialize>(is); + return KVCacheCreatedData{numBlocksPerCacheLevel}; +} + +// KVCacheStoredData +size_t Serialization::serializedSize(KVCacheStoredData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.parentHash); + totalSize += su::serializedSize(data.blocks); + return totalSize; +} + +void Serialization::serialize(KVCacheStoredData const& data, std::ostream& os) +{ + su::serialize(data.parentHash, os); + su::serialize(data.blocks, os); +} + +KVCacheStoredData Serialization::deserializeKVCacheStoredData(std::istream& is) +{ + auto parentHash = su::deserialize>(is); + auto blocks = su::deserialize>(is); + return KVCacheStoredData{parentHash, blocks}; +} + +// KVCacheStoredBlockData +size_t Serialization::serializedSize(KVCacheStoredBlockData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.blockHash); + totalSize 
+= su::serializedSize(data.tokens); + totalSize += su::serializedSize(data.loraId); + totalSize += su::serializedSize(data.cacheLevel); + totalSize += su::serializedSize(data.priority); + return totalSize; +} + +void Serialization::serialize(KVCacheStoredBlockData const& data, std::ostream& os) +{ + su::serialize(data.blockHash, os); + su::serialize(data.tokens, os); + su::serialize(data.loraId, os); + su::serialize(data.cacheLevel, os); + su::serialize(data.priority, os); +} + +KVCacheStoredBlockData Serialization::deserializeKVCacheStoredBlockData(std::istream& is) +{ + auto blockHash = su::deserialize(is); + auto tokens = su::deserialize(is); + auto loraId = su::deserialize>(is); + auto cacheLevel = su::deserialize(is); + auto priority = su::deserialize(is); + + return KVCacheStoredBlockData{blockHash, tokens, loraId, cacheLevel, priority}; +} + +// KVcacheRemovedData + +size_t Serialization::serializedSize(KVCacheRemovedData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.blockHashes); + return totalSize; +} + +void Serialization::serialize(KVCacheRemovedData const& data, std::ostream& os) +{ + su::serialize(data.blockHashes, os); +} + +KVCacheRemovedData Serialization::deserializeKVCacheRemovedData(std::istream& is) +{ + auto blockHashes = su::deserialize>(is); + return KVCacheRemovedData{blockHashes}; +} + +// KVCacheEventDiff +template +size_t Serialization::serializedSize(KVCacheEventDiff const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.oldValue); + totalSize += su::serializedSize(data.newValue); + return totalSize; +} + +template +void Serialization::serialize(KVCacheEventDiff const& data, std::ostream& os) +{ + su::serialize(data.oldValue, os); + su::serialize(data.newValue, os); +} + +template +KVCacheEventDiff Serialization::deserializeKVCacheEventDiff(std::istream& is) +{ + auto oldValue = su::deserialize(is); + auto newValue = su::deserialize(is); + return KVCacheEventDiff{oldValue, newValue}; +} + +// KVCacheUpdatedData +size_t Serialization::serializedSize(KVCacheUpdatedData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.blockHash); + totalSize += su::serializedSize(data.cacheLevel); + totalSize += su::serializedSize(data.priority); + return totalSize; +} + +void Serialization::serialize(KVCacheUpdatedData const& data, std::ostream& os) +{ + su::serialize(data.blockHash, os); + su::serialize(data.cacheLevel, os); + su::serialize(data.priority, os); +} + +KVCacheUpdatedData Serialization::deserializeKVCacheUpdatedData(std::istream& is) +{ + auto blockHash = su::deserialize(is); + auto cacheLevel = su::deserialize>>(is); + auto priority = su::deserialize>>(is); + return KVCacheUpdatedData{blockHash, cacheLevel, priority}; +} + +// UniqueToken +size_t Serialization::serializedSize(tensorrt_llm::runtime::UniqueToken const& token) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(token.tokenId); + totalSize += su::serializedSize(token.tokenExtraId); + return totalSize; +} + +void Serialization::serialize(tensorrt_llm::runtime::UniqueToken const& token, std::ostream& os) +{ + su::serialize(token.tokenId, os); + su::serialize(token.tokenExtraId, os); +} + +tensorrt_llm::runtime::UniqueToken Serialization::deserializeUniqueToken(std::istream& is) +{ + auto tokenId = su::deserialize(is); + auto tokenExtraId = su::deserialize(is); + return tensorrt_llm::runtime::UniqueToken{tokenId, tokenExtraId}; +} + // String std::string Serialization::deserializeString(std::istream& is) { diff 
--git a/cpp/tensorrt_llm/executor/serializeUtils.h b/cpp/tensorrt_llm/executor/serializeUtils.h index 8f26c58d622..40b50f92309 100644 --- a/cpp/tensorrt_llm/executor/serializeUtils.h +++ b/cpp/tensorrt_llm/executor/serializeUtils.h @@ -122,6 +122,14 @@ static_assert(hasSerializedSize(size_t())); static_assert(!hasSerializedSize(size_t())); static_assert(!hasSerializedSize>(size_t())); static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize>(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); template size_t serializedSize(T const& data) @@ -219,6 +227,14 @@ static_assert(hasSerialize(nullptr)); static_assert(!hasSerialize(nullptr)); static_assert(!hasSerialize>(nullptr)); static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize>(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); template void serialize(T const& data, std::ostream& os) @@ -291,6 +307,22 @@ struct get_variant_alternative_type } }; +template +T deserialize(std::istream& is); + +// Helper function to deserialize variant by index using template recursion +template +T deserializeVariantByIndex(std::istream& is, std::size_t index, std::index_sequence /*indices*/) +{ + T result; + bool found = ((Is == index ? (result = deserialize>(is), true) : false) || ...); + if (!found) + { + TLLM_THROW("Invalid variant index during deserialization: " + std::to_string(index)); + } + return result; +} + // Deserialize template T deserialize(std::istream& is) @@ -511,6 +543,38 @@ T deserialize(std::istream& is) { return Serialization::deserializeCacheTransceiverConfig(is); } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheEvent(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheCreatedData(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheStoredData(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheStoredBlockData(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheRemovedData(is); + } + else if constexpr (std::is_same_v>) + { + return Serialization::deserializeKVCacheEventDiff(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheUpdatedData(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeUniqueToken(is); + } // Optional else if constexpr (std::is_same_v::type>>) { @@ -547,23 +611,7 @@ T deserialize(std::istream& is) std::size_t index = 0; is.read(reinterpret_cast(&index), sizeof(index)); - // TODO: Is there a better way to implement this? 
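// Illustrative sketch, not taken from this patch: a self-contained version of the fold-expression
// dispatch introduced above (deserializeVariantByIndex), which replaces the hard-coded
// two-alternative handling removed just below. The stored index selects which variant alternative
// to reconstruct, with no limit on the variant's size. The trivial byte-wise reader and the
// example variant type are assumptions for illustration; the real code dispatches to deserialize<U>.
#include <cstddef>
#include <istream>
#include <stdexcept>
#include <utility>
#include <variant>

template <typename T>
T readTrivial(std::istream& is)
{
    T value{};
    is.read(reinterpret_cast<char*>(&value), sizeof(value));
    return value;
}

template <typename VariantT, std::size_t... Is>
VariantT readVariantByIndex(std::istream& is, std::size_t index, std::index_sequence<Is...>)
{
    // Requires the first alternative to be default-constructible (true for the variants used here).
    VariantT result;
    // Exactly one pack element matches the stored index; the fold short-circuits once it is found.
    bool found
        = ((Is == index ? (result = readTrivial<std::variant_alternative_t<Is, VariantT>>(is), true) : false) || ...);
    if (!found)
    {
        throw std::runtime_error("Invalid variant index during deserialization");
    }
    return result;
}

// Usage example:
//   using V = std::variant<int, float, double>;
//   V v = readVariantByIndex<V>(is, index, std::make_index_sequence<std::variant_size_v<V>>{});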
- T data; - if (index == 0) - { - using U = std::variant_alternative_t<0, T>; - data = deserialize(is); - } - else if (index == 1) - { - using U = std::variant_alternative_t<1, T>; - data = deserialize(is); - } - else - { - TLLM_THROW("Serialization of variant of size > 2 is not supported."); - } - return data; + return deserializeVariantByIndex(is, index, std::make_index_sequence>{}); } else { diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu index 27d041618e7..84710a96365 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu @@ -256,9 +256,9 @@ public: constexpr int SF_VEC_SIZE = 16; using PackedVec = PackedVec; PackedVec pack_val = *reinterpret_cast(&val); - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset(std::nullopt, token_id, - m_access_id_in_token, std::nullopt, m_params.hidden_dim, - reinterpret_cast(m_params.scale_out), m_params.layout); + auto sf_out = cvt_quant_get_sf_out_offset(std::nullopt, token_id, m_access_id_in_token, + std::nullopt, m_params.hidden_dim / SF_VEC_SIZE, reinterpret_cast(m_params.scale_out), + m_params.layout); reinterpret_cast(m_params.quant_out)[m_access_id] = cvt_warp_fp16_to_fp4(pack_val, m_scale_factor, sf_out); } diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h index dbf45ebe1cc..52487b25d4e 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h +++ b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h @@ -132,7 +132,7 @@ struct AllReduceFusionParams float rms_eps; float* scale_factor; bool use_oneshot; - FP4QuantizationSFLayout layout = FP4QuantizationSFLayout::SWIZZLED; + QuantizationSFLayout layout = QuantizationSFLayout::SWIZZLED; cudaStream_t stream; AllReduceFusionPattern pattern; bool trigger_completion_at_end = true; diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu index 2176ba759f4..c38abd95785 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu @@ -99,15 +99,15 @@ __device__ struct __attribute__((aligned(32))) LamportFlags uint32_t* offset_access_ptr; uint32_t* buffer_flags; - __device__ explicit LamportFlags(uint32_t* buffer_flags) + __device__ explicit LamportFlags(uint32_t* buffer_flags, uint32_t buffer_size) : offset_access_ptr(&buffer_flags[4]) , buffer_flags(buffer_flags) + , buffer_size(buffer_size) { uint4 flag = reinterpret_cast(buffer_flags)[0]; - buffer_size = flag.z; input_offset = flag.x * (buffer_size << 1U); clear_offset = flag.y * (buffer_size << 1U); - num_tokens_prev = flag.w; + num_tokens_prev = flag.z; } __device__ void cta_arrive() @@ -135,7 +135,7 @@ __device__ struct __attribute__((aligned(32))) LamportFlags uint4 flag = reinterpret_cast(buffer_flags)[0]; buffer_flags[0] = (flag.x + 1) % 3; buffer_flags[1] = (flag.y + 1) % 3; - buffer_flags[3] = num_tokens; + buffer_flags[2] = num_tokens; *(offset_access_ptr) = 0; } } @@ -144,7 +144,7 @@ __device__ struct __attribute__((aligned(32))) LamportFlags template __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int 
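In the LamportFlags change above, the buffer size is passed in by the host instead of being read back from the flag word, and the previous iteration's token count moves to buffer_flags[2]. The offsets come from a three-slot rotation in which each slot spans twice buffer_size and both the input and clear slots advance modulo 3 once per iteration. The host-side model below mirrors only that bookkeeping; the field names and the initial slot assignment are illustrative.

```cpp
// Host-side sketch of the three-slot rotation behind the Lamport flag offsets;
// inputSlot / clearSlot / numTokensPrev are illustrative names only.
#include <cstdint>
#include <iostream>

struct FlagState
{
    uint32_t inputSlot;     // flag.x in the kernel
    uint32_t clearSlot;     // flag.y in the kernel
    uint32_t numTokensPrev; // flag.z / buffer_flags[2] after this change
};

// Each slot spans 2 * bufferSize elements, matching (buffer_size << 1U) in the kernel.
uint64_t slotOffset(uint32_t slot, uint32_t bufferSize)
{
    return static_cast<uint64_t>(slot) * (static_cast<uint64_t>(bufferSize) << 1U);
}

// Completion path: both slots advance modulo 3 and the finished iteration's
// token count is recorded so the next iteration knows how much to clear.
void advance(FlagState& f, uint32_t numTokens)
{
    f.inputSlot = (f.inputSlot + 1) % 3;
    f.clearSlot = (f.clearSlot + 1) % 3;
    f.numTokensPrev = numTokens;
}

int main()
{
    uint32_t const bufferSize = 1024;
    FlagState f{0, 2, 0}; // assumed starting configuration
    for (int iter = 0; iter < 4; ++iter)
    {
        std::cout << "iter " << iter << ": input offset " << slotOffset(f.inputSlot, bufferSize)
                  << ", clear offset " << slotOffset(f.clearSlot, bufferSize) << "\n";
        advance(f, /*numTokens=*/16);
    }
    return 0;
}
```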
num_tokens, - int buffer_M, int token_dim, int rank, uint32_t* buffer_flags, bool wait_for_results) + int buffer_M, int token_dim, int rank, uint32_t buffer_size, uint32_t* buffer_flags, bool wait_for_results) { int elt = blockIdx.y * blockDim.x + threadIdx.x; if (elt >= token_dim) @@ -155,7 +155,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ cudaGridDependencySynchronize(); #endif - LamportFlags flags(buffer_flags); + LamportFlags flags(buffer_flags, buffer_size); // Capture the number of tokens in previous iteration so that we can properly clear the buffer // The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up @@ -217,15 +217,17 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaTriggerProgrammaticLaunchCompletion(); #endif - - // Similarly clear broadcast buffer here - for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) + if (elt < token_dim) { - uint32_t clr_token_idx = token + clr_tok * gridDim.x; - if (clr_token_idx < buffer_M) + // Similarly clear broadcast buffer here + for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) { - input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt] - = fromFloat(-0.f); + uint32_t clr_token_idx = token + clr_tok * gridDim.x; + if (clr_token_idx < buffer_M) + { + input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt] + = fromFloat(-0.f); + } } } @@ -240,20 +242,24 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ // blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32) if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD)) { - uint64_t current_pos = blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD; - - void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos]; - // We have 2 assumptions here: - // 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B - // 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32) - float2 val = loadfloat2(lamport_ptr); - while (isNegZero(*(T*) &val)) - { - val = loadfloat2(lamport_ptr); - } - if (output_ptr) + uint64_t elt_load_offset = blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD; + if (elt_load_offset < token_dim) { - *((float2*) &output_ptr[current_pos]) = val; + uint64_t current_pos = blockIdx.x * token_dim + elt_load_offset; + + void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos]; + // We have 2 assumptions here: + // 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B + // 2. 
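The clear path above refills the broadcast region with -0.f, and the wait path that follows polls each slot until that fill value has been overwritten. The host-side sketch below shows why -0.0f is usable as a "not yet written" marker: it compares equal to 0.0f, so the check inspects the bit pattern, which an arriving payload is expected to replace. The isNegZero helper and the polling loop here are illustrative stand-ins, not the device code.

```cpp
// Sketch: detecting a -0.0f sentinel by bit pattern rather than by comparison
// (-0.0f == 0.0f evaluates to true, so operator== cannot tell them apart).
#include <atomic>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <thread>

bool isNegZero(float v)
{
    uint32_t bits;
    std::memcpy(&bits, &v, sizeof(bits));
    return bits == 0x80000000u; // sign bit set, all other bits zero
}

int main()
{
    std::atomic<float> slot{-0.0f}; // a cleared slot holds the sentinel

    std::thread producer(
        [&]
        {
            std::this_thread::sleep_for(std::chrono::milliseconds(10));
            slot.store(3.5f); // the real value arrives later
        });

    float v = slot.load();
    while (isNegZero(v)) // consumer spins until the sentinel is overwritten
    {
        v = slot.load();
    }
    std::cout << "got " << v << "\n";
    producer.join();
    return 0;
}
```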
The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32) + float2 val = loadfloat2(lamport_ptr); + while (isNegZero(*(T*) &val)) + { + val = loadfloat2(lamport_ptr); + } + if (output_ptr) + { + *((float2*) &output_ptr[current_pos]) = val; + } } } @@ -263,10 +269,11 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ } #define LAUNCH_ALL_REDUCE_KERNEL(WORLD_SIZE, T) \ - TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, &twoshot_allreduce_kernel, \ - reinterpret_cast(params.output), reinterpret_cast(params.input), \ - reinterpret_cast(params.buffer_ptrs_dev), (T*) params.multicast_ptr, params.num_tokens, params.buffer_M, \ - params.token_dim, params.rank, reinterpret_cast(params.buffer_flags), params.wait_for_results)); + TLLM_CUDA_CHECK( \ + cudaLaunchKernelEx(&config, &twoshot_allreduce_kernel, reinterpret_cast(params.output), \ + reinterpret_cast(params.input), reinterpret_cast(params.buffer_ptrs_dev), \ + (T*) params.multicast_ptr, params.num_tokens, params.buffer_M, params.token_dim, params.rank, \ + params.buffer_size, reinterpret_cast(params.buffer_flags), params.wait_for_results)); void twoshot_allreduce_op(AllReduceParams const& params) { @@ -369,20 +376,33 @@ inline __device__ T add(T a, T b) } #define FINAL_MASK 0xffffffff +#define WARP_SIZE 32 template __inline__ __device__ T warpReduceSum(T val) { + // Get the actual number of active threads in this warp + int active_warp_size = min(WARP_SIZE, blockDim.x - (threadIdx.x & ~(WARP_SIZE - 1))); + unsigned int mask = (1U << active_warp_size) - 1; + #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80 + for (int offset = 16; offset > 0; offset >>= 1) + { + if (offset < active_warp_size) + { + val = add(val, __shfl_xor_sync(mask, val, offset, WARP_SIZE)); + } + } return val; } inline __device__ float block_reduce_sum(float val) { - __shared__ float smem[32]; - int lane_id = threadIdx.x % 32, warp_id = threadIdx.x / 32, warp_num = blockDim.x / 32; + __shared__ float smem[WARP_SIZE]; + int lane_id = threadIdx.x % WARP_SIZE; + int warp_id = threadIdx.x / WARP_SIZE; + int warp_num = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; // Ceiling division to include partial warps + val = warpReduceSum(val); if (lane_id == 0) { @@ -391,6 +411,7 @@ inline __device__ float block_reduce_sum(float val) __syncthreads(); val = lane_id < warp_num ? 
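warpReduceSum above now derives a participation mask from the number of live lanes so a trailing partial warp does not shuffle against threads that do not exist. To keep the example self-contained, the CUDA sketch below reduces a partial warp with a slightly different scheme, __shfl_down_sync plus an existence guard; it illustrates the problem being solved rather than reproducing the kernel's exact masking.

```cuda
// Sketch: summing across a warp whose last lanes may not exist because
// blockDim.x is not a multiple of 32.
#include <cstdio>
#include <cuda_runtime.h>

__inline__ __device__ float partialWarpReduceSum(float val, int active)
{
    // Mask of lanes that actually exist in this warp.
    unsigned int const mask = (active >= 32) ? 0xffffffffu : ((1u << active) - 1u);
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
    {
        float const other = __shfl_down_sync(mask, val, offset, 32);
        if (static_cast<int>(threadIdx.x % 32) + offset < active)
        {
            val += other; // only fold values coming from live lanes
        }
    }
    return val; // lane 0 of the warp holds the warp's full sum
}

__global__ void partialWarpSum(float const* in, float* out)
{
    // Live lanes in this thread's warp; the last warp of the block may be partial.
    int const active = min(32, static_cast<int>(blockDim.x) - (static_cast<int>(threadIdx.x) & ~31));
    float const v = partialWarpReduceSum(in[threadIdx.x], active);
    if (threadIdx.x % 32 == 0)
    {
        atomicAdd(out, v);
    }
}

int main()
{
    int const n = 20; // one deliberately partial warp
    float *d_in, *d_out, h_out = 0.f;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemset(d_out, 0, sizeof(float));

    float h_in[n];
    for (int i = 0; i < n; ++i)
    {
        h_in[i] = 1.0f; // expected sum: 20
    }
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

    partialWarpSum<<<1, n>>>(d_in, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum = %f\n", h_out);

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}
```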
smem[lane_id] : 0.f; val = warpReduceSum(val); + return val; } @@ -410,7 +431,7 @@ __device__ float4 loadfloat4(void const* ptr) template __global__ void __launch_bounds__(128, 1) RMSNorm(T_IN* input_plus_residual, T_OUT* output_norm, T_IN const* buffer_input, T_IN const* gamma, float epsilon, - T_IN const* residual, int batch_size, uint32_t* buffer_flags) + T_IN const* residual, int batch_size, uint32_t buffer_size, uint32_t* buffer_flags) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) static bool const LAMPORT = true; @@ -433,7 +454,7 @@ __global__ void __launch_bounds__(128, 1) int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)]; - LamportFlags flags(buffer_flags); + LamportFlags flags(buffer_flags, buffer_size); T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size]; cudaTriggerProgrammaticLaunchCompletion(); @@ -598,16 +619,15 @@ __global__ void __launch_bounds__(128, 1) #endif } -template +template void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T const* gamma, double epsilon, - T const* residual, uint32_t* buffer_flags, int batch, cudaStream_t stream) + T const* residual, uint32_t buffer_size, uint32_t* buffer_flags, int batch, cudaStream_t stream) { // input to rmsnorm is the buffer in the twoshot ar // We should use prenorm output to determine the actual used size float _epsilon{static_cast(epsilon)}; - static constexpr int NUM_THREADS = 128; static constexpr int CGA_THREADS = NUM_THREADS; constexpr int iters = H_DIM / CGA_THREADS; @@ -628,28 +648,34 @@ void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T cons &RMSNorm, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); config.dynamicSmemBytes = shmem_size; TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, &RMSNorm, prenorm_output, normed_output, - input, gamma, _epsilon, residual, batch, buffer_flags)); + input, gamma, _epsilon, residual, batch, buffer_size, buffer_flags)); } -#define LAUNCH_RMSNORM_KERNEL(T, H_DIM) \ - twoshot_rmsnorm(static_cast(params.residual_output), static_cast(params.output), \ +#define LAUNCH_RMSNORM_KERNEL(T, H_DIM, NUM_THREADS) \ + twoshot_rmsnorm(static_cast(params.residual_output), static_cast(params.output), \ static_cast(params.input), static_cast(params.gamma), params.epsilon, \ - static_cast(params.residual), params.buffer_flags, params.batch, params.stream) + static_cast(params.residual), params.buffer_size, params.buffer_flags, params.batch, params.stream) void twoshot_rmsnorm_op(RMSNormParams const& params) { auto dtype = params.dtype; + +#define CASE_DISPATCH_RMSNORM(T, H_DIM, NUM_THREADS) \ + case H_DIM: LAUNCH_RMSNORM_KERNEL(T, H_DIM, NUM_THREADS); break; + +#define TYPE_DISPATCH_RMSNORM(T) \ + CASE_DISPATCH_RMSNORM(T, 2048, 128) \ + CASE_DISPATCH_RMSNORM(T, 2880, 120) \ + CASE_DISPATCH_RMSNORM(T, 4096, 128) \ + CASE_DISPATCH_RMSNORM(T, 5120, 128) \ + CASE_DISPATCH_RMSNORM(T, 7168, 128) \ + CASE_DISPATCH_RMSNORM(T, 8192, 128) + if (dtype == nvinfer1::DataType::kFLOAT) { switch (params.hidden_dim) { - case 2048: LAUNCH_RMSNORM_KERNEL(float, 2048); break; - case 4096: LAUNCH_RMSNORM_KERNEL(float, 4096); break; - // Llama-4 Hidden Dimension - case 5120: LAUNCH_RMSNORM_KERNEL(float, 5120); break; - // DeepSeek Hidden Dimension - case 7168: LAUNCH_RMSNORM_KERNEL(float, 7168); break; - case 8192: LAUNCH_RMSNORM_KERNEL(float, 8192); break; + TYPE_DISPATCH_RMSNORM(float); default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim."); } } @@ -657,13 +683,7 @@ void 
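block_reduce_sum above switches to a ceiling division for the warp count so a trailing partial warp still deposits its partial sum. The standalone sketch below shows the same two-level shape, a shuffle within each warp followed by a second shuffle over per-warp partials staged in shared memory; for brevity it assumes blockDim.x is a multiple of 32 so full shuffle masks are safe.

```cuda
// Sketch of a two-level block reduction: warp shuffle, then a second shuffle
// over per-warp partials staged in shared memory. Assumes full 32-lane warps.
#include <cstdio>
#include <cuda_runtime.h>

__inline__ __device__ float warpReduceSum(float val)
{
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
    {
        val += __shfl_xor_sync(0xffffffffu, val, offset, 32);
    }
    return val;
}

__inline__ __device__ float blockReduceSum(float val)
{
    __shared__ float smem[32];
    int const laneId = threadIdx.x % 32;
    int const warpId = threadIdx.x / 32;
    // Ceiling division, as in the diff, so a partial warp would still be counted.
    int const warpNum = (blockDim.x + 31) / 32;

    val = warpReduceSum(val);
    if (laneId == 0)
    {
        smem[warpId] = val; // one partial sum per warp
    }
    __syncthreads();

    // Every warp folds the staged partials, so all threads end up with the block sum.
    val = (laneId < warpNum) ? smem[laneId] : 0.f;
    return warpReduceSum(val);
}

__global__ void blockSum(float const* in, float* out, int n)
{
    int const idx = blockIdx.x * blockDim.x + threadIdx.x;
    float const v = blockReduceSum(idx < n ? in[idx] : 0.f);
    if (threadIdx.x == 0)
    {
        atomicAdd(out, v);
    }
}

int main()
{
    int const n = 1 << 12;
    float *d_in, *d_out, h_out = 0.f;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemset(d_out, 0, sizeof(float));

    float* h_in = new float[n];
    for (int i = 0; i < n; ++i)
    {
        h_in[i] = 0.5f; // expected sum: 2048
    }
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

    blockSum<<<(n + 127) / 128, 128>>>(d_in, d_out, n);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum = %f\n", h_out);

    delete[] h_in;
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}
```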
twoshot_rmsnorm_op(RMSNormParams const& params) { switch (params.hidden_dim) { - case 2048: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 2048); break; - case 4096: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 4096); break; - // Llama-4 Hidden Dimension - case 5120: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 5120); break; - // DeepSeek Hidden Dimension - case 7168: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 7168); break; - case 8192: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 8192); break; + TYPE_DISPATCH_RMSNORM(__nv_bfloat16); default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim."); } } @@ -671,13 +691,7 @@ void twoshot_rmsnorm_op(RMSNormParams const& params) { switch (params.hidden_dim) { - case 2048: LAUNCH_RMSNORM_KERNEL(__nv_half, 2048); break; - case 4096: LAUNCH_RMSNORM_KERNEL(__nv_half, 4096); break; - // Llama-4 Hidden Dimension - case 5120: LAUNCH_RMSNORM_KERNEL(__nv_half, 5120); break; - // DeepSeek Hidden Dimension - case 7168: LAUNCH_RMSNORM_KERNEL(__nv_half, 7168); break; - case 8192: LAUNCH_RMSNORM_KERNEL(__nv_half, 8192); break; + TYPE_DISPATCH_RMSNORM(__nv_half); default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim."); } } @@ -685,6 +699,8 @@ void twoshot_rmsnorm_op(RMSNormParams const& params) { TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported dtype."); } +#undef TYPE_DISPATCH_RMSNORM +#undef CASE_DISPATCH_RMSNORM } } // namespace tensorrt_llm::kernels::mnnvl diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.h b/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.h index ccca256b5a2..3a0fb753db2 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.h +++ b/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.h @@ -30,6 +30,7 @@ struct AllReduceParams int buffer_M; int num_tokens; int token_dim; + uint32_t buffer_size; void** buffer_ptrs_dev; void* multicast_ptr; void* buffer_flags; @@ -50,6 +51,7 @@ struct RMSNormParams void const* gamma; double epsilon; void* residual; + uint32_t buffer_size; uint32_t* buffer_flags; int batch; int hidden_dim; diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu index 577f4b5ff4f..7bc9e326fb2 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu @@ -150,8 +150,8 @@ __device__ __forceinline__ void fused_op( constexpr int SF_VEC_SIZE = 16; using PackedVec = PackedVec; PackedVec pack_val = *reinterpret_cast(&norm_val); - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset(std::nullopt /* batchIdx */, - token_id, access_id_in_token, std::nullopt /* numRows */, params.hidden_dim, + auto sf_out = cvt_quant_get_sf_out_offset(std::nullopt /* batchIdx */, token_id, + access_id_in_token, std::nullopt /* numRows */, params.hidden_dim / SF_VEC_SIZE, reinterpret_cast(params.scale_out), params.layout); reinterpret_cast(params.quant_out)[access_id] = cvt_warp_fp16_to_fp4(pack_val, *params.scale_factor, sf_out); diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h b/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h index 9ebc7de6509..4a35d14bf09 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h +++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h @@ 
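The twoshot_rmsnorm dispatch above is now generated from a single (hidden_dim, NUM_THREADS) table, so each supported shape, including the new 2880 entry with 120 threads, is listed once instead of once per dtype switch. A reduced sketch of the same macro-driven dispatch follows, with a hypothetical launchRMSNorm<T, H_DIM, NUM_THREADS>() standing in for the real launcher.

```cpp
// Sketch of table-driven switch generation; launchRMSNorm is a hypothetical
// stand-in for the real kernel launcher.
#include <cstdio>

template <typename T, int H_DIM, int NUM_THREADS>
void launchRMSNorm()
{
    std::printf("launch hidden_dim=%d with %d threads\n", H_DIM, NUM_THREADS);
}

#define CASE_DISPATCH_RMSNORM(T, H_DIM, NUM_THREADS)                                                                   \
    case H_DIM: launchRMSNorm<T, H_DIM, NUM_THREADS>(); break;

#define TYPE_DISPATCH_RMSNORM(T)                                                                                       \
    CASE_DISPATCH_RMSNORM(T, 2048, 128)                                                                                \
    CASE_DISPATCH_RMSNORM(T, 2880, 120)                                                                                \
    CASE_DISPATCH_RMSNORM(T, 4096, 128)

template <typename T>
bool dispatchRMSNorm(int hiddenDim)
{
    switch (hiddenDim)
    {
        TYPE_DISPATCH_RMSNORM(T)
    default: return false; // unsupported hidden_dim
    }
    return true;
}

#undef TYPE_DISPATCH_RMSNORM
#undef CASE_DISPATCH_RMSNORM

int main()
{
    dispatchRMSNorm<float>(2880);
    if (!dispatchRMSNorm<float>(3000))
    {
        std::printf("unsupported hidden_dim\n");
    }
    return 0;
}
```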
-55,7 +55,7 @@ struct AllReduceFusionParams void* rms_gamma; float rms_eps; float* scale_factor; - FP4QuantizationSFLayout layout = FP4QuantizationSFLayout::SWIZZLED; + QuantizationSFLayout layout = QuantizationSFLayout::SWIZZLED; cudaStream_t stream; }; diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_32_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_32_ldgsts_sm90.cubin.cpp deleted file mode 100644 index 81208594d0f..00000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_32_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d5bb139b12206a563daec9fa473dda422319bde5ae5f965d37cf5ca67d325c49 -size 1005546 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_64_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_64_ldgsts_sm90.cubin.cpp deleted file mode 100644 index 7086ad9f485..00000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_64_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c4357a935656d47414a459939720b66311c67213f450168715e1cb0238653768 -size 1066324 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp index 0acae9aa71b..2ae91e52cd7 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0a0671e7cbbed9f51dc0c47e4b970e2f72067d629ff6562c9d65f9cd55c68578 -size 361861 +oid sha256:c709dce149c0f4500539e495c90d1da2d86cec28c4187ee9494b015642e158cf +size 363441 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp index 4cb6bcd1c18..bce0c66bcf1 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5ec9817bebb07483ce29d8d91c45d35c2c05f0101bfa70146fba5a6576a6b825 -size 1091614 +oid sha256:b9170581da010aca67f4bafd9f6f59aaaf5fd1958a1fdd336aa208146599ac06 +size 1094770 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp index 470904148ad..caa735d5724 100644 --- 
a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0540cdb398818ec54a60c34b462c158e169347db73d244d633669d74211696ba -size 1467312 +oid sha256:2147a246067f7ea74ca382fbc8c02a26332479e5205ecfbe08fb84161a3a87ec +size 1483888 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp index 281985341d5..0b584163a86 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:69bdfba64f1faff30ed8389a28b7b9ef37c0d180b1df643722b280011c8f74e8 -size 692990 +oid sha256:279bd48b8ac53690bb4e37dffbe9060428db80c1417ff29c6f4d4a10ab35a7c9 +size 700094 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp index 8b8738474dd..496df695fcc 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c8173308813999ab64ba8236016b23fbfd3f3f1501f61290bf71ea027ead2920 -size 642456 +oid sha256:db5d186ce70d7a94cae2b6619b3449ca557903944beba1ee738d2ee425792d74 +size 652718 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp index 6ca952af647..c6692932cdb 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f41ae066b01b2a9c3b5165535f743461a9a1d559f6fcd0a00a04c554f8a50962 -size 414757 +oid sha256:089a98cf8ab0bbd7530e69821c42220ea02578b740bff62a3e6e33de45209114 +size 416335 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp index 1a973c5d2e6..555f6268648 100644 --- 
a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ab0be8e667d459e13135f96469613f1c095e47187b24e5d40c7c57583351a076 -size 1194236 +oid sha256:1f0cc486ec5e9c1720f495a2a5e7c26d42e737694d307d4746a08b6ead5cc225 +size 1197394 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp index 8faf85254d9..b5884bba556 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:03d86280f76994e2e01d43747cb5c811496b8340d031ebb0c3bdd46437422994 -size 1654394 +oid sha256:398965e34c1a4c747b42d8836c04934daaa43903b7931586ed12120e17a61f76 +size 1672548 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp index 53f3032a30e..696620f8791 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:35c5715bcb1a16c343f3a28be105fb6fee1bbca24cf832f71a7d0f20cf9a0b3e -size 365015 +oid sha256:77cbd7d45164d24be73e021bc0a8745b4f021e4369a254e216ee00b36d3c7263 +size 366593 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_sm90.cubin.cpp index 89a4eaa580c..22a4ff75bf6 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a3335a8d4b2c0ca63f006c3f957d57aa3f808ef06d4adda322c311a333286d84 +oid sha256:3a3f74fbe72ef54b9c028d957353c1ecbff1d20bcc9619ff17ee37471934a2ab size 1126352 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp index 9cb2eb33c23..e0b9335b45e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp +++ 
b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:fdc0bf099862d352b3b765e117437240a82e4749d3efd104881647dd4ea14562 +oid sha256:b3af082c6742f385d0d2c96489ff1de314458eb992d6d5a251c737f8ec912e79 size 644092 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp index 153555cbe42..ec999849faf 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ccd938df8f78af4eae306c6e9e669599c2baf6f095f956318470063c560fbd3c -size 1091610 +oid sha256:8e26f3b8cc173301b3cf07ba1ca7893b6f140432410b0b298361ecff597604c2 +size 1095556 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp index cab205493aa..284e084f3df 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ce4d35ab4c7b65476f0dcec635db1791fcb718afd6b3531338712f5b2bc9aa84 -size 1460204 +oid sha256:32220d11bc3542e9edcc36d51b4866bf40044213114d7e237e003afc1fc7c464 +size 1478358 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86.cubin.cpp index ab21a448f54..69a3f4789c8 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d088ce37b21d335ba1f92034cf97f78fc968d7fecaa0c4f9ec83a0d5165f1d99 +oid sha256:3ee5ae75df4866d848e90616562345d3740b17b68c90f06329dc074dba5217a9 size 482709 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89.cubin.cpp index 2fa6ba246ed..c19635d6887 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:40653ec672098e2cb1f94c473fa67852efcf6b49a6e8109e4fcf39422281acb4 
+oid sha256:817ae5c1eb8a8c6f22a76ab0b88075fd3391d06abb7dd6d9ab51206b809cd69d size 657930 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm90.cubin.cpp index ebdb0563ef9..a625def240f 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:96348957990518db6f51af7c681a71e625dede568cc8f8303dd2de8ad09bfc28 +oid sha256:680734da0abb1c3029dce32e892687f649c4219f66574acb15ab88471f508263 size 677218 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp index 7cd5b267e07..1691a77e1fe 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4687df80ac2fa9454b0564b0a80d78cfaedc2c7796c8f3a1010dd7ebbf722c83 +oid sha256:c27e871dd680022920081c30c5e239613e53b42129680fdb1d17668b5c5ddd9a size 369401 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_kv_128_tma_ws_sm90.cubin.cpp index f4da9b9d86f..6e7098d6c73 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d8b9985065f5f2c62b74c05f8eed02b1909c96656b26fbd7779cc57a2146b037 -size 947140 +oid sha256:3e1ecaa635067924b692b665241d86e1d8c1d60a19290de7adde1ff2ca7dbeb0 +size 956612 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp index 8ffdb6589d9..c38c3b29fd6 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:23599e63b07ad966df921daf3cb97a9ed5cde27eeda0fd96ba5abd835b48f89a -size 590779 +oid sha256:d3018c622303f89c6f22f037ec99eaeaeea9cfe8911e22463b48a22c13116805 +size 592357 diff --git 
a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp index 1153714c7e1..5d286a73e53 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:cd1c452565583b20913d835de9b14c2f19c0cc431bc926ea6c92295362a85bca -size 1813864 +oid sha256:a7a381f2855236f418a40124a5254401c95001d5e15c074a704e22cc7ed89aa2 +size 1818600 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp index b6383dcbd5c..5290f97cfb8 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b20de2c6bb3081564ddfbf7ece80fb2c17e66f4e7ff0e0969da4e4655e90d1ec -size 2407418 +oid sha256:9bb49ace4dedc4faa3de2b9c22e09db0f3990129ce7ab4afb6419c38a5d48a16 +size 2427152 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp index 3713748af50..cb3d89f0704 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:33a0e8bb2391128e688e5c6356f09a5ed189ce5c1bcdeef4efc0ce0415dc2849 -size 555245 +oid sha256:9769d7cb9754718798be515c84c45ff48e43322573f3f12e31c2e42e99d8dbd4 +size 557613 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90.cubin.cpp index 795d4d68fc9..de925119b38 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4b014f41b1cfdf6ed2729778841213a36440191eb3c087346a02c21510bd3f0e -size 665794 +oid sha256:134f4a73e0e6b02b717319ec49e3b3ea0a585cad385a1f300e6c5761f12de9d7 +size 671320 diff --git 
a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp index 5c8dbe22b24..64bb52e0df9 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:bd77afeb7dcd1ff8d6be80788b20e92e4fbc8c3026ba12d1d522c99316754a7c -size 1740442 +oid sha256:7935b0f053a79a7e620c0efe274fa5b4c840fc9c6e439a381c4d380446e1cb68 +size 1744388 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_tma_ws_sm90.cubin.cpp index ee1a46c9bc9..87d96af432a 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b674707d02aac297b66d523de8b11618ca1598c49eeaf7ce9b1c9d516ce95c4b -size 2247958 +oid sha256:74ecbbaa19b2efe97a3b12c488f0e03c2102f16c460239df4bfc19976fc4365e +size 2266902 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_32_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_32_sm89.cubin.cpp index 349c2efdfe3..15ad1d62a91 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_32_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_32_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7556f88488e05ee669e763b839afa1b7690060cfa9d8482d419c0ca336df9352 +oid sha256:813265d25709bd2d39982efbaf092c9163b124bd990fccab505b3c22134522aa size 595585 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_64_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_64_sm89.cubin.cpp index 2ccc55f1447..4e62255a629 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_64_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_64_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ac9d879aa0c70967bb3a79cd7034998baf43a544c0dd4444ebddeb76e78df5ae +oid sha256:dd36195c01bf7c2a2013d5f31d2e74c2579c471385d7b45be7e35ea2f0652608 size 908162 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_32_sm89.cubin.cpp 
b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_32_sm89.cubin.cpp index ec1ef8aae91..10ee7b3d8c4 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_32_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_32_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4e781c0278fc46142f578ae51bfeb38767e89d9c25b92023215948f99dd1d3ed +oid sha256:31d4d6dca68c4632d1f435e9179582cfe2ad7a75ee0f7625ee67b0044c914f10 size 1371512 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_40_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_40_sm89.cubin.cpp index d904de0acb2..407d34a6552 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_40_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_40_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d608e9e3ec460d2a38f43067a7d7a2dd408e068db690806bbafb11007e175336 +oid sha256:6570d3ee7b651dec797e82b31eb21fd3261c6e2639fb7c9b157f251bf98bb3bf size 1419662 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_48_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_48_sm89.cubin.cpp index 798e8482b41..d6b829a9a00 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_48_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_48_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:9c1e1d300866c6425c2495e550230051debdca0a7eb85874ae33c0c2de8a81cb +oid sha256:88b972677c5436b90fe85870278e3b23d6f709608f99295bddf0be3861d95d1a size 1419662 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_64_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_64_sm89.cubin.cpp index bbcce09e729..7cac9a83250 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_64_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_64_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:132d83639e34af1b431abdcb3f09542d0389030b85752e18a3ae221ead7d24a3 +oid sha256:d975f605d62c3070d6cf72f6114d98642c520e66989ed2d2845c3213e921ebf7 size 1965880 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm89.cubin.cpp index 83287a0376a..9dd7d6bf8e6 100644 --- 
a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4a96710f6c691580c2363c187a75fd436f5e6be732810a1a45182ce72dc52d1e +oid sha256:ef5a2728cbd3241f45f3d8285c91a818e11b2a9fedf322f343a9461d31a6ad30 size 1380182 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_40_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_40_sm89.cubin.cpp index 00623779346..1b6d6cddf5e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_40_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_40_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a6339f008f451d030aa36a6b3fac7179e7534f7f2474d641fa0ebfbf487074e7 +oid sha256:16b5f3d3f8760dabc0849217cf11edf18d19896dda475a5fc233bbfd444faf33 size 1401494 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_48_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_48_sm89.cubin.cpp index 0d719af97a3..90decb87938 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_48_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_48_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:57ebcae2b70fc28881f2b3969868d64c203ef4a9cbc9588a9e28051c5f5b6849 +oid sha256:cbacb235f39adaeabd68e2fc46c51aac6ca26cdf96293a6a7eb60b5be40640ef size 1401494 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_64_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_64_sm89.cubin.cpp index ceab132d423..5628ced1f3e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_64_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_64_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5e2a4ce1b944feb2b3ed535943089a2d5968bf523b149885df78f7fa4bd7e835 +oid sha256:e6f3e068435339a64d47673f8018b66c202f6259d68e0a97a4a30acb7505a7fd size 1935872 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_128_sm89.cubin.cpp index 2780675d9d0..552a78df4f2 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid 
sha256:f5d456b30f89ad05ba5b852fabcffb3f8269913d83ef8c0e4e319f2243dee54d +oid sha256:7c2d7ab0692de5405b26d19a0c57d720285366ac12a8550bbabca1613cce7f0c size 305897 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_72_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_72_sm89.cubin.cpp index 2aa3fd4b0a3..ca2d2a604da 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_72_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_72_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:85593d3c2fecb6842a72952c6dcbde19a70e6b26245829d279ca50bb391eb636 +oid sha256:91a26adfddc0bcaf8b42249f59f1a0b9f74be0f82c7378fe4b56f3a2fa3d4bf1 size 290109 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_104_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_104_sm89.cubin.cpp index b050acbb5aa..da475b4a2d1 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_104_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_104_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:69cd61bd8334d2109067ef0460a91b8dba4c2cb07392eb636d72d025ccb15bf9 +oid sha256:6ef79c9e2e2d8bba55d7803dc8dc147b5d8babc29e906a43407a8722bbd8d939 size 498507 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_128_sm89.cubin.cpp index e741d50f4cd..09b401a0036 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0427b7729ce3cfa652a4595d04f936a947febec8f2c96ce33eed7cbaaa05613e +oid sha256:0eef025f8e8581868b02bcea37ff225afebcbb2966450fb29fb0e32ac54eccd4 size 668214 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_160_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_160_sm89.cubin.cpp index eee064e2804..0c6a45eacc1 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_160_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_160_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:321bcd81b8965c8dfc08682f775508ae18e3ff711490ee8dff5fe56c20f74843 +oid sha256:abb2857ffb85cc36aae90ebb674635dffee2b2c5f7ad1ea81bb8002b65d5a0f8 size 711628 diff --git 
a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89.cubin.cpp index 33f4d9cab3b..9ecb64bd23f 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:aa77d3789c0ca314689125ec303a8af76554120a708a4b63395c69b7aad07f04 +oid sha256:49a3661535314b139e2794fe16f6f3e0a8d45742b68ea59ba99a9113068adf2c size 752698 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm89.cubin.cpp index 31383430901..d836cccd03a 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:aa35aa70d0fa304c776c076a1a189d32a054d3f696dac5d99018085d1108c73b +oid sha256:d76fb6c4f8bb2de687bc5f9f275389356934119c1f0db9983dcf0ec7b68c6197 size 748726 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89.cubin.cpp index ca7815f7109..79e1e96e9bb 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d1a702d456b5acf279487dd810e3e33efdd1c7bd82530ceb5a32ad30ec30396c +oid sha256:be8ee89f4489c430d0ff6e9c6cf4e07379ac05abf468d47e34e084ad594b2037 size 946060 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89.cubin.cpp index 8bb9403c511..3c8b2528fc3 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:558aa7d42de329c49361c94c4baef16738304b21b6adbe675d77c7819ef37660 +oid sha256:aa4be8ca2dd52e56c9a6af76b90ac353d217fad5fa931b21129ac5a811b5283a size 489823 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_80_sm89.cubin.cpp 
b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_80_sm89.cubin.cpp index 0754f76695b..22fce024ea0 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_80_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_80_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7b5baa6048e6c33e74c6d343eb7c76252ff2e534fe467b3189af12b5d64af37c +oid sha256:cb0482b768a40bc7f8a86fa23a84bab62fb82c205f3237ff60becda50cbafc90 size 489823 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_96_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_96_sm89.cubin.cpp index 68de134acba..c02b557e7f4 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_96_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_96_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:e17cb191ad092e6db255ea503e49ea883ed56322fc58ed8d68710f6687376c1f +oid sha256:95b1796f4e7c905eca82ed3691427025f68e765797440b962b0114a5ab32b1d7 size 500083 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_104_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_104_sm89.cubin.cpp index 3ebcc110ecd..cbc081aae2c 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_104_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_104_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:bfca5660a931e08941347f7a0aefa82c214940e8eaa6b6d89cfded621f34a490 +oid sha256:2d9f13977fc865e716f1f35dfdb222a38000b224ff7394134230ed5c88119947 size 496125 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89.cubin.cpp index c0c882331e1..cc613cc08d5 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:fffd2cd799953808034d7e7b89a57d4fede24db124bfb0d3938188177acbdfeb +oid sha256:007e32a06fcac853159dc5786940447281c57ba70406d38beb6f089fd037053d size 182023 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_fp16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_fp16_sm89.cubin.cpp index 458aa250b4a..d8ba5241135 
100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_fp16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_fp16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:19ada3a5d449542f103077db8d193bc2293a8f48ccee201e366473964287314c +oid sha256:26241ea5909395116e1b1a0f19cadc448886f6a6ab2b3ba76c092b67cd0148f0 size 182023 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sm89.cubin.cpp index 65edc3e52ac..0206f719811 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b9c32124cd708aab7da30637d85437da0af9bf2157d163c19c6fe14498698cda +oid sha256:86e4ca60a459117c5e701631fbd3c67ca66e81d177c394c1fc9ad3b66396e69a size 661096 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_160_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_160_sm89.cubin.cpp index 8213475b06f..3444d759b7f 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_160_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_160_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7f248fd42759509c61d20f912ae74dc3a85448a9c8386370ea92492ed9031e80 +oid sha256:770db1f4ec1c2d3c25767593b60cb095e49f7a6eb7abe054bbdec6e72db97f8d size 672936 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89.cubin.cpp index 75bd11ff6e7..b99affa0208 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:190fd946ddc7e1b5e9ca2172ec1de39c6288829773d9ce29fe98374256eff566 +oid sha256:0b6428cae2d0c8c813925be9589c94771098cfe5a6d0ff2036104d3e36384b81 size 721900 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm89.cubin.cpp index ed5e241d9e9..e93db30f53a 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm89.cubin.cpp +++ 
b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b7cd5976c836bcd75c0cadfe968050ac60bf89b93df021ad6c1681e159c497c5 +oid sha256:36c6932301fe3dc29631c28fcb8cb6b08652103bc7a36fd74a03a8189a1c77e4 size 717928 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm89.cubin.cpp index 44ce0c307f1..8f42d5a2769 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7c536d725e1d9ebd2cb836dfe3993edcc81101534db6b7f1943c8a9443838bf4 +oid sha256:d858f6dcaf3f49fb3fa18b1c8c20ee1b933e2c8ddd1a429c8d3b5b4d269fb875 size 927892 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_72_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_72_sm89.cubin.cpp index 0216db308c5..0cb2a134102 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_72_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_72_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b5907da5a2f68c010d44bbbd0d780e097f9625be15b2f85e8dd1f00dd4c31ff9 +oid sha256:7dc92ab65ed0fc5f9d821f52a396a6d55ea9ae37e080eac7ff9e9c14eae741e7 size 631890 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89.cubin.cpp index c63b37264a5..648e3acb008 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:9cf14c71134a89ed6ffc83c0b7db06ed10e22b55294dc15ddf7f016427f01033 +oid sha256:d66606a37cfe8eb78ccc3f548a231f770df9f46e70f6d3ba22fb8abe6216480e size 159919 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89.cubin.cpp index 7d1ac808673..6028cc1f326 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version 
https://git-lfs.github.com/spec/v1 -oid sha256:f2b83c70dbc8ab0b3695dab3f4d2069b7ee7119e9140d7860b8c19f59a498589 +oid sha256:b723b296cff04602f64a5da9928e6f9b6a03c5cc608ba9ef7d8055f23f1f4ea2 size 159919 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89.cubin.cpp index 4041bfc97a4..b1ee67b880c 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:fc8369f5701dceea91d429a713ddcbb4ecb0ad08d3c9042688557ead5f00e9da +oid sha256:d40578a5684262cd8136705367e2c98493ea9b9fcfc123c7efa3ead14017b5b8 size 483493 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_96_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_96_sm89.cubin.cpp index f0afe3fcf10..4ce3d2dba50 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_96_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_96_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4e9fffff2d13d49613e5f9334a010ca9bcde43b3bb55a792fd97fe2c867760dc +oid sha256:60cc82b9d11c53392de91a7c4c097263c20a56f9b346278c7c9af12ef2bb5fbf size 496123 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89.cubin.cpp index 03a4b33cefc..d24465ed9c8 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:dd3041ba5a52263f7f02d64f1911c50e346151bf529e865c1abf22583abd3e21 +oid sha256:8f685b6b2a0a573953f31fad89fa37e949361db245de69c0c06ce0bbb14eacef size 443285 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89.cubin.cpp index 6984f3c1700..dc49a306271 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:12482099b086249163085e6e3421a61f6e304f865aaf56dd15382614be5e48e7 +oid sha256:834f0f3601c589893a21b957be2864df594f96b34b2cfd6018ada8319986aa21 size 441683 diff --git 
a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89.cubin.cpp index 2bb4cc25821..4763a29923c 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:bfea1ea1627eaef7b614db08bad00bda8b611c8e466c858e050c0ce2aee2eafb +oid sha256:3d81a070e7ed49f1e1a322d38a757a3505186cf5cbded99814e950e07229a46a size 298049 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89.cubin.cpp index 7e76c5e13df..c8587a81d35 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f828600699faa3a0474085cbbe88d2e0ac7c8e056c976b81a882c3a72682e527 +oid sha256:b9de5bc49d888699da1880d24ccf6a9cb6c0049d7a244d1ae9ab64b7365ecd5a size 296445 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89.cubin.cpp index 1c1f7bdc42f..7d299b87052 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2d4b297922065ecb79b4a1278d048b253b57601d011fc5833a32f9fc1b78e58e +oid sha256:e30ed0df4b0d0b1da1ace5831dc0a7a526e04001b25860f862345c78acff5a43 size 427485 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89.cubin.cpp index 68394c07c1c..47eeb69632b 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:3fd5305445c9856fbd5d9dfaffdd7f87b9014638f33fb63fb2cb4fce9893b20b +oid sha256:030015dc1811e3dc2ae36ed770f51063a3f46deae42ead5e1523c977b438a133 size 425883 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80.cubin.cpp 
b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80.cubin.cpp index 51778ad0e9d..1a5b22eed8a 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2b7fee97097f799830df2bcb1c782c7ea9018243cbd5cd0e0f47ec299b49db79 +oid sha256:6921a204892e1336cef2a308be38855f3c888e56bd6a16752d2806aa9e93c431 size 1524634 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp index 537871847de..834fa7d1c0b 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:8ac2f9270988bc02329ce11ef3413395b2b8cdc55fcf4911d170536c6e618317 -size 403697 +oid sha256:200df98fb2fcc734e8fc012c98c5d78c2061e5718eef6ffd50c2358a3d664197 +size 406065 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp index 6bf814ac8a9..e085961e987 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1234cf31a3a6b84ed25fa0ad6c4df9b53f673f6bac2f639a66086ba50f8717ba -size 1120818 +oid sha256:430194fe07e526ad01a1e0fb43273b240c269215b132c9af248ba386dcbda23e +size 1124766 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp index 3bebbebcf15..2d56be2925e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0fff300932a16d30844e317ace515a178f159c483e436f6955983b96c5c424c6 -size 1549402 +oid sha256:53a07904a7bfbf82380c96af99c5e24bc86f77906c5d6fdc85ef9720639d76d2 +size 1569136 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp index 
ef64a376820..6d074921cde 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ed10767ec913d314936fc5dbd1fd70c5381a622bf3fcf1590f837da6d3285bca -size 723774 +oid sha256:1ce4d27b11fee3e5f6489510b55613177e174660b6c7a6fb4efed862b62c50d7 +size 731668 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp index d0bc52f1318..a6268993164 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7e7a7a9653a9c4e4e9b0514fc1d70abbb4521c7edbede52568d17d0779d62ffb -size 671662 +oid sha256:3992d7bd34e72089c5cffc4fc6de3f70a3995145b989811f83b00b47c96b5159 +size 681924 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp index 3056a533d67..d95d392d536 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1e18db0cd4de65e76e30f219d24ec00095fb16005882c43322182c5fa3f59032 -size 445541 +oid sha256:521417177fc0447809c07ff86b58725fedbf1a6b9412ace4c50268a20bc2680d +size 447119 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80.cubin.cpp index 50d7f1becef..c405f483aed 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:9aceb502c1a95f58f1eab515cf2aeac92be6d255ef405008a4fd871fd54e9ba6 +oid sha256:cb063c946558e6928faabb85df9775fecd2b9444b40b3e06cf0f863db80a5ad8 size 1242842 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp index 1a74df12889..e88a310b64b 100644 --- 
a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ec96248452f638bb9ca50d3630dd67caf71322c01b17aff301c4a98eb7e27974 -size 1215548 +oid sha256:31e6b7442b277f5206cc1d70fa6021f36170265b311106281e88b4611d1a5b6b +size 1220284 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp index e03f7c2575c..0db1249a289 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:dabc44860e81532e9b7ecb35773d0ad409d45361e20c9510d24387039999a7c3 -size 1720698 +oid sha256:c1342769efa91794d5bd35ac623b3014738b075b2671441668e2f0d5c1eef78a +size 1739642 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp index b1d87c1278f..4d68087ca12 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0d9c8d1fe282f46c12898ed4851a2640cb33ba5d75c5fe9da8a988f818a0e733 -size 407639 +oid sha256:a49dd8abcca57a64eb2ab4e00e4e0d26edf68488fb67086a4b466f8e6651522e +size 410007 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_sm90.cubin.cpp index 2a12ddb7118..deb498b1a29 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:849a280994b3fa1f18ca6c3866a16a68a9b02831f134f8dfcf0d34502c1d6772 +oid sha256:a7013b1eea12719ebeaf47facc37ef730bb0d6af03ca2ad890724a25448616a9 size 1102672 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp index a2c78e856df..4bf37280a0e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp +++ 
b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4e209b01409585433406f8392c77a7398270ee1b58446b728cf74faa6fe1bf9a +oid sha256:a16aeaf5d11a4c25461452b5f3145136b31861ef9c443d7ec82066565275d6f8 size 629884 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp index 61bbc8d762e..0115c2c36f3 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0a22bb0202916831eced0a44acbab769d5647937155e0a2b5e6d0d0cb83c726f -size 1122394 +oid sha256:a7d4526887fe860e0d9c482fc7fe2cfe646c7a20bc8a0813ce33a01fd9cc733c +size 1125550 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp index e0170f8db7f..5d1d2207551 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:582d17d48c7a751a345f74cc8c74f9b8c05278ddfc185da4906310a4973a9bdb -size 1547030 +oid sha256:b880e78ffc354edb541bd612e543dd894843fc4163f7bd65ce53282892381b8a +size 1566764 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_sm90.cubin.cpp index 456d75f72fe..fbab68022c3 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:70f02b7329eef7ceeb73dd43c3bf8f6ea6132c593bba6dbbed720d8b8ff0c287 +oid sha256:de26acaa532f197e339b6d5b2a2dd8032d505c9e169fce38000b02b2a4188eff size 603809 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp index 0c0712acaf1..8315c080842 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid 
sha256:f67d4e70c39bf379ed0f3ef73a3690ac64efaee1e7134c793a760924c270f046 +oid sha256:cef5bcfe63650bc924d9e45d2755b50940534999fb4fbad3a8abf0ba73b9245a size 329935 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp index f35d06ef066..c57602da24a 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c2c284c6cb66207bd204bd1b6abe45aa8bf2e0c92631681861df237b8f849a46 -size 363451 +oid sha256:b332d4c6047c98b504cd3be72cc5028d240621c8e0a3260d64c17804982104db +size 365029 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp index 73d9547cf2d..a0fe210d9b0 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d3bede327d80be420e7bf011ee1a4156365afff7020bbf5a8434da18cb19fb23 -size 1093202 +oid sha256:a16c23767a2e5efbd7330728ed87af2ec62a7731debe1da557705c6db6d3268e +size 1096360 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp index 998e46d1f16..3c10c481369 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5ee7695bd5bb0a03eafe29a497060d84caec96ca4d159e99e4f02b99977dd2a6 -size 1469690 +oid sha256:66950bc137b734d509f0574152bcf9cf7efcb17a7483450d5fdbf480e9f83001 +size 1486266 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp index a76bf3814f7..0b4847611fd 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid 
sha256:cecca7ad5c652989a3008c8219177811ab9c7d617adbbc9ed8548141803c66f5 -size 694578 +oid sha256:bba586d9fe487c49cef2abfbfb0a078dde907d28e04b4d2335018cdb7031879c +size 701682 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp index 71a5743dd98..fb1751942e2 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:bd6847c0e897eb794a9b1ff67e64358527fe64c3e01fc214545cf76ec60edc6d -size 644046 +oid sha256:d3e45ab30e471f4649807f5b7640512e2c6678cf623cadfcb26c93eb4ad60ec0 +size 654306 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp index ea50fb06310..ca8b31a0105 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:118cc6d4a5e3e12ce0f2727361fd1d52d1a49c67d0bd1837c24e528c064a0dd7 -size 415557 +oid sha256:1932937b7f4ad0370341c77a03db133dd676bdf844b13eb45ec10243d1dfd16b +size 417135 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp index 285c32ec70e..85d85fa4d99 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:36d6c97af5fb15f32cd1ff13f53dd98a7d670cb80ee766765f42cc453f730812 -size 1195826 +oid sha256:c11f5d464b0486023b78babfdfe9d2768e4b0d13caeb436d6f73110ede72498c +size 1198982 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp index bd266daa63a..465fcafeced 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid 
sha256:7775bbc1b43487236cf7570d2ed900f1c9830eab70aac1fa9dc59c439cc0c687 -size 1657562 +oid sha256:3bac9b40302bbfc6ee5a49e5c45d3238f46cff45619acd1b098d90e758d3ce30 +size 1675716 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp index 2d3c2887bea..c65fa93d24e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:199b1ff3cc3d0ff04477ff8f1e6390dd62b3a7c9dd264cc73ce6c716af20a0f9 -size 366603 +oid sha256:26f09ab86b52c40b283652e555f677850f00902151d17e375e016b9a99a97794 +size 368183 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_sm90.cubin.cpp index e0073c3730b..36bdbdda6bf 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2e743b470f9607abcbc8b71e7ef67455e6104daf3a80d0bd012a96ecf90a8f18 +oid sha256:960c3f9e4fe46fc6390207ba0ed85ec25435045e2213b60c5d44ea9ab4fa56aa size 1128730 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_sm90.cubin.cpp index 1553e77aee6..58a89a84a2e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:366aa4e9f3263f73c4e76c0ea8008c0449b6d89bcade761500af949912786e32 +oid sha256:ac167d89ea3150f7b65614645ef09f13e2543bdc0523c1eddce5bbd9cfd306ee size 644892 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp index cd0531dde0e..cd64d2fe381 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5b8a8d76e17a24afd7af1dc5e112828f98ace78e3f85a7efaadb0cf1937085cc -size 1093198 +oid 
sha256:9d0cf59a8114940070448d87d02d9e83d53bb371ca9915c3983e03626d17024e +size 1097144 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp index 54fd20f69c9..f3194ad186e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:aeffa2db467fbae3ace85fae9f31e2b8a7c0923ab349ade42318ae6f55249ac8 -size 1462582 +oid sha256:ff1449b6795f5beda0b6a62e8a1171ce952b07c4e63b607c06f5fedddb2debe9 +size 1480736 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_sm90.cubin.cpp index 673041f7af9..87c5afddecc 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ffc92513e64631c33290f1e88e5666f5b85251506d527745c493f2e90da39de4 +oid sha256:cb14ae0271f8a83216f67c111530d3fe1be2231541ded5f992ff45226ae90e69 size 678808 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90.cubin.cpp index c39e7fa450e..dad37ebd422 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:faad8cb1e44f5e16f61720966d2a6c9e782461c209cd8000263b50d42093444d +oid sha256:46a0d8e0a9495e03f72526b4ee04fa3d2a2d87984057b44550cabf4ffa745ef4 size 370201 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_32_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_32_ldgsts_sm90.cubin.cpp deleted file mode 100644 index e2ee736b49d..00000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_32_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:dd930ed415b0303a973a37550ee33fa4975ad6be0cc58d461370b127f9a90f8e -size 1020542 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_64_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_64_ldgsts_sm90.cubin.cpp deleted file mode 100644 index 95d9b2bf647..00000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_64_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ 
-version https://git-lfs.github.com/spec/v1 -oid sha256:4f2b243127e1ce00a850a10cca104ffc42512711f434fbdf8683eeeb49b8ce42 -size 1056062 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_32_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_32_ldgsts_sm90.cubin.cpp deleted file mode 100644 index 0c093db643c..00000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_32_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:2ce9cc89b1db7f7e4b76b94cf1c3b04db49a2d86b529b1fc85b19057a99bc9fa -size 1007924 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_64_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_64_ldgsts_sm90.cubin.cpp deleted file mode 100644 index c24e239dd0c..00000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_64_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e176513fa0074d688620299dfca53adc3902491e97ea9b6938a4ceb2fcf17ef5 -size 1068702 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp index a0197d8083a..9454d308bc0 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp @@ -238,6 +238,9 @@ void FusedMHARunnerV2::setupKernelParams(MHARunnerParams runnerParams) mKernelParams.packed_mask_ptr = runnerParams.packedMaskPtr; mKernelParams.cu_mask_rows = reinterpret_cast(runnerParams.cuMaskRowsPtr); } + TLLM_CHECK_WITH_INFO( + runnerParams.attentionSinksPtr == nullptr || mSM == kSM_90, "The attention sinks is only supported on SM90."); + mKernelParams.attention_sinks_ptr = runnerParams.attentionSinksPtr; mKernelParams.cu_q_seqlens = reinterpret_cast(runnerParams.cuQSeqLenPtr); mKernelParams.tile_id_counter_ptr = reinterpret_cast(runnerParams.tileCounterPtr); // TRT doesn't support host scales. Use device scales instead. @@ -294,6 +297,11 @@ void FusedMHARunnerV2::setupLaunchParams(MHARunnerParams runnerParams) = mFixedParams.isSPadded ? runnerParams.b * runnerParams.qSeqLen : runnerParams.totalQSeqLen; mLaunchParams.total_kv_seqlen = mFixedParams.isSPadded ? runnerParams.b * runnerParams.kvSeqLen : runnerParams.totalKvSeqLen; + // Workaround for nvbug 5412456: total_kv_seqlen fallbacks to total_q_seqlen if it's zero. + if (mLaunchParams.total_kv_seqlen == 0) + { + mLaunchParams.total_kv_seqlen = mLaunchParams.total_q_seqlen; + } TLLM_CHECK_WITH_INFO(mFixedParams.headSize > 0, "Head size should be greater than 0."); // Pad head size to next power of 2. diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h index 96435cca528..e9098866161 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h @@ -263,6 +263,8 @@ struct MHARunnerParams void* outputSfPtr; // The softmax_status ptr for RingAttention. void* softmaxStatsPtr; + // The attention sinks ptr. 
+ float const* attentionSinksPtr; // The packed mask ptr. void const* packedMaskPtr; // The cumulative Q sequence lengths. @@ -352,6 +354,8 @@ struct Fused_multihead_attention_params_v2 KVBlockArrayForContextFMHA paged_kv_cache; // The mask to implement drop-out. void const* packed_mask_ptr; + // The attention sinks. + float const* attention_sinks_ptr; // The O matrix (output). void* o_ptr; // The Softmax stats vector of layout [2, B, S, H], including softmax_sum and softmax_max diff --git a/cpp/tensorrt_llm/kernels/customAllReduceKernels.h b/cpp/tensorrt_llm/kernels/customAllReduceKernels.h index 934679a944c..6758558e277 100644 --- a/cpp/tensorrt_llm/kernels/customAllReduceKernels.h +++ b/cpp/tensorrt_llm/kernels/customAllReduceKernels.h @@ -56,6 +56,8 @@ enum class AllReduceStrategyType : int8_t ONESHOT = 4, TWOSHOT = 5, LOWPRECISION = 6, + MNNVL = 7, + NCCL_SYMMETRIC = 8, }; enum class AllReduceStrategyConfig : int8_t diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt b/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt index 7a02cdee73f..fd89ae4a194 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt @@ -218,6 +218,11 @@ if(USING_OSS_CUTLASS_MOE_GEMM) set(MOE_GEMM_SRC_CU_LAUNCHER ${MOE_GEMM_SRC_CU}) list(FILTER MOE_GEMM_SRC_CU_LAUNCHER EXCLUDE REGEX ".*moe_gemm_kernels_.*") list(FILTER MOE_GEMM_SRC_CU INCLUDE REGEX ".*moe_gemm_kernels_.*") + set(MOE_GEMM_SRC_CU_HOPPER_FP4 ${MOE_GEMM_SRC_CU}) + list(FILTER MOE_GEMM_SRC_CU_HOPPER_FP4 INCLUDE REGEX + ".*moe_gemm_kernels_(bf16|fp16)_fp4.*") + list(FILTER MOE_GEMM_SRC_CU EXCLUDE REGEX + ".*moe_gemm_kernels_(bf16|fp16)_fp4.*") set(MOE_GEMM_SRC_CU_FP4 ${MOE_GEMM_SRC_CU}) list(FILTER MOE_GEMM_SRC_CU_FP4 INCLUDE REGEX ".*fp4.*") list(FILTER MOE_GEMM_SRC_CU EXCLUDE REGEX ".*fp4.*") @@ -230,6 +235,10 @@ if(USING_OSS_CUTLASS_MOE_GEMM) add_library(_moe_gemm_launcher OBJECT ${MOE_GEMM_SRC_CU_LAUNCHER}) add_cuda_architectures(_moe_gemm_launcher 89) + add_library(_moe_gemm_hopper_fp4 OBJECT ${MOE_GEMM_SRC_CU_HOPPER_FP4}) + set_cuda_architectures(_moe_gemm_hopper_fp4 90) + process_target(_moe_gemm_hopper_fp4 true false) + add_library(_moe_gemm_fp4 OBJECT ${MOE_GEMM_SRC_CU_FP4}) set_cuda_architectures(_moe_gemm_fp4 100f 120f) process_target(_moe_gemm_fp4 false true) @@ -239,8 +248,9 @@ if(USING_OSS_CUTLASS_MOE_GEMM) process_target(_moe_gemm_fp8 true true) add_instantiations(moe_gemm_src ${INSTANTIATION_GENERATION_DIR}/gemm_grouped) - target_link_libraries(moe_gemm_src PRIVATE _moe_gemm_launcher _moe_gemm_fp4 - _moe_gemm_fp8) + target_link_libraries( + moe_gemm_src PRIVATE _moe_gemm_launcher _moe_gemm_hopper_fp4 _moe_gemm_fp4 + _moe_gemm_fp8) target_include_directories( moe_gemm_src PUBLIC ${PROJECT_SOURCE_DIR}/tensorrt_llm/cutlass_extensions/include) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp index 9e3bbaa32b7..837b916f366 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp @@ -377,72 +377,62 @@ std::vector get_candidate_configs_sm100(CutlassGemmConfig::Ca if (config & CutlassGemmConfig::GROUPED_GEMM) { std::vector candidate_configs; - if ((config & CutlassGemmConfig::FP4_ONLY) != 0) + if (config & CutlassGemmConfig::FP4_ONLY) { candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO, 
EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); - candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x128x128B, + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1}); candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); - candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x256x128B, - MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B, - MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1}); - candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x64x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); return candidate_configs; } - for (int cluster_m = 1; cluster_m <= 2; cluster_m++) + std::vector> tile_configs{ + {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_1x1x1}, + {CutlassTileConfigSM100::CtaShape128x256x128B, ClusterShape::ClusterShape_1x1x1}, + {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_1x2x1}, + {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_1x2x1}, + {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_2x1x1}, + {CutlassTileConfigSM100::CtaShape64x256x128B, ClusterShape::ClusterShape_2x1x1}, + {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_2x2x1}, + {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_2x2x1}, + {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_2x1x1}, + {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_2x1x1}, + {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_2x1x1}, + {CutlassTileConfigSM100::CtaShape128x256x128B, ClusterShape::ClusterShape_2x1x1}, + {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_2x2x1}, + {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_2x2x1}, + {CutlassTileConfigSM100::CtaShape128x32x128B, ClusterShape::ClusterShape_1x1x1}, + {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_1x1x1}, + {CutlassTileConfigSM100::CtaShape64x32x128B, ClusterShape::ClusterShape_1x2x1}, + {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_1x1x1}, + {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_1x2x1}, + {CutlassTileConfigSM100::CtaShape64x256x128B, ClusterShape::ClusterShape_1x1x1}, + {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_1x2x1}, + {CutlassTileConfigSM100::CtaShape128x64x128B, 
ClusterShape::ClusterShape_1x1x1}, + {CutlassTileConfigSM100::CtaShape128x32x128B, ClusterShape::ClusterShape_1x2x1}, + }; + + if (config & CutlassGemmConfig::FP8_ONLY) { - bool Is2SM = cluster_m == 2; - for (int cluster_n = 1; cluster_n <= 2; cluster_n++) - { - std::vector base = {// M=128 - CutlassTileConfigSM100::CtaShape128x128x128B, CutlassTileConfigSM100::CtaShape128x256x128B}; - - if (Is2SM) - { - if (cluster_n == 1) - { - base.push_back(CutlassTileConfigSM100::CtaShape128x64x128B); - base.push_back(CutlassTileConfigSM100::CtaShape256x64x128B); - } - - std::vector twosm = {// M=256 - CutlassTileConfigSM100::CtaShape256x128x128B, CutlassTileConfigSM100::CtaShape256x256x128B}; - std::copy(twosm.begin(), twosm.end(), std::back_inserter(base)); - } - else - { - if (cluster_n == 1) - { - base.push_back(CutlassTileConfigSM100::CtaShape128x32x128B); - if ((config & CutlassGemmConfig::FP8_ONLY) != 0) - { - base.push_back(CutlassTileConfigSM100::CtaShape128x16x128B); - } - } - - std::vector onesm{CutlassTileConfigSM100::CtaShape64x64x128B, - CutlassTileConfigSM100::CtaShape64x128x128B, CutlassTileConfigSM100::CtaShape64x256x128B, - CutlassTileConfigSM100::CtaShape128x64x128B}; - std::copy(onesm.begin(), onesm.end(), std::back_inserter(base)); - } + tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x16x128B, ClusterShape::ClusterShape_1x1x1}); + // TODO: re-enable when handled by the MoE GEMM dispatch + // tile_configs.push_back({ CutlassTileConfigSM100::CtaShape128x8x256B, ClusterShape::ClusterShape_1x1x1 }); + } - constexpr std::array cluster_shapes - = {std::array{ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1}, - std::array{ClusterShape::ClusterShape_2x1x1, ClusterShape::ClusterShape_2x2x1}}; - auto cluster = cluster_shapes[cluster_m - 1][cluster_n - 1]; - for (auto tile : base) - { - CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, cluster}; - candidate_configs.push_back(config); - } - } + for (auto [tile, cluster] : tile_configs) + { + CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, cluster}; + candidate_configs.push_back(config); } return candidate_configs; } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h index e6c3a6bbfa2..646be2575ca 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h @@ -27,6 +27,7 @@ enum class ActivationType Silu, Swiglu, Geglu, + SwigluBias, Identity, InvalidType }; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h index 7ddd756e0d0..3c814851c91 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h @@ -37,11 +37,6 @@ namespace tensorrt_llm::kernels::cutlass_kernels { -template -constexpr auto transpose_stride(T const& t) -{ - return cute::prepend(cute::prepend(cute::take<2, cute::rank_v>(t), cute::get<0>(t)), cute::get<1>(t)); -} template struct GroupedGemmInput @@ -72,8 +67,6 @@ struct GroupedGemmInput struct TmaWarpSpecializedGroupedGemmInput { - template - using TransposeStride = decltype(transpose_stride(T{})); template using TransposeLayoutTag = std::conditional_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; @@ -86,6 +79,7 @@ struct TmaWarpSpecializedGroupedGemmInput using 
LayoutA = TransposeLayoutTag; // Layout type for A matrix operand using LayoutB = TransposeLayoutTag; // Layout type for B matrix operand using LayoutC = TransposeLayoutTag; // Layout type for C matrix operand + using LayoutD = TransposeLayoutTag; // Layout type for D matrix operand constexpr static int NVFP4BlockScaleVectorSize = 16; constexpr static int MXFPXBlockScaleVectorSize = 32; @@ -121,6 +115,7 @@ struct TmaWarpSpecializedGroupedGemmInput using StrideB = std::remove_pointer_t>; // Use A because they will be swapped using StrideC = std::remove_pointer_t>; + using StrideD = std::remove_pointer_t>; #ifdef ENABLE_FP8 template @@ -147,37 +142,26 @@ struct TmaWarpSpecializedGroupedGemmInput StrideC* stride_c = nullptr; void const** ptr_c = nullptr; - struct DefaultEpilogue - { - using LayoutD = TransposeLayoutTag; // Layout type for D matrix operand - using StrideD = std::remove_pointer_t>; - - StrideD* stride_d = nullptr; - void** ptr_d = nullptr; - }; + // D is used in all cases except fused finalize + StrideD* stride_d = nullptr; + void** ptr_d = nullptr; struct FusedFinalizeEpilogue { - using StrideFinalOutput = DefaultEpilogue::StrideD; - using StrideBias = TransposeStride>; - using StrideRouterScales = TransposeStride>; + using StrideFinalOutput = cutlass::detail::TagToStrideC_t; void* ptr_final_output = nullptr; StrideFinalOutput stride_final_output{}; - void const* ptr_bias = nullptr; - StrideBias stride_bias{}; - - float const* ptr_router_scales = nullptr; - StrideRouterScales stride_router_scales{}; + void const** ptr_bias = nullptr; + float const** ptr_router_scales = nullptr; - int64_t const* ptr_expert_first_token_offset = nullptr; - int const* ptr_source_token_index = nullptr; + int const** ptr_source_token_index = nullptr; + int num_rows_in_final_output = 0; - size_t num_rows_in_final_output = 0; + bool use_reduction = true; }; - DefaultEpilogue default_epilogue; FusedFinalizeEpilogue fused_finalize_epilogue; enum class EpilogueFusion @@ -210,8 +194,10 @@ struct TmaWarpSpecializedGroupedGemmInput struct INT4GroupwiseParams { - constexpr static int group_size = 128; // Unused, hard-coded to 128 + constexpr static int int4_group_size = 128; + constexpr static int wfp4a16_group_size = 32; bool enabled = false; + bool use_wfp4a16 = false; using SFA = __nv_bfloat16; using SFB = __nv_bfloat16; // Unused using ProblemShapeInt = cutlass::gemm::GroupProblemShape>; @@ -233,7 +219,7 @@ struct TmaWarpSpecializedGroupedGemmInput uint8_t* gemm_workspace = nullptr; size_t gemm_workspace_size = 0; - static std::array workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type); + static std::array workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type); static size_t workspaceSize(int num_experts, FpXBlockScalingType scaling_type); @@ -245,16 +231,15 @@ struct TmaWarpSpecializedGroupedGemmInput return stride_a != nullptr && ptr_a != nullptr; } - void setFinalizeFusionParams(void* final_output, float const* router_scales, - int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, - int num_output_tokens); + void setFinalizeFusionParams(void* final_output, int hidden_size, int num_output_tokens, bool use_reduction); std::string toString() const; }; constexpr bool isGatedActivation(ActivationType activation_type) { - return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu; + return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu + || 
activation_type == ActivationType::SwigluBias; } template && (std::is_same_v || std::is_same_v); +#else + static constexpr bool use_wfp4a16 = std::is_same_v && std::is_same_v; +#endif #if defined(ENABLE_FP8) static constexpr bool use_fp8 = (std::is_same_v @@ -282,6 +273,7 @@ class MoeGemmRunner static constexpr bool use_w4afp8 = false; static constexpr bool use_wfp4afp4 = false; #endif + static constexpr bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; #if defined(ENABLE_FP4) static constexpr bool use_fp4 = std::is_same_v; @@ -306,9 +298,9 @@ class MoeGemmRunner [[nodiscard]] bool isTmaWarpSpecialized(cutlass_extensions::CutlassGemmConfig gemm_config) const; [[nodiscard]] bool supportsTmaWarpSpecialized() const; - [[nodiscard]] bool isFusedGatedActivation( - cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const; - [[nodiscard]] bool supportsFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const; + [[nodiscard]] bool isFusedGatedActivation(cutlass_extensions::CutlassGemmConfig gemm_config, + ActivationType activation_type, int gemm_n, int gemm_k) const; + [[nodiscard]] bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n, int gemm_k) const; size_t getMaxWorkspaceSize(int num_experts) const; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index c7c9a55b959..7d592bed0e4 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -87,6 +87,62 @@ struct LoraParams namespace cutlass_kernels { +static inline size_t pad_to_multiple_of_16(size_t const& input) +{ + static constexpr int ALIGNMENT = 16; + return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +} + +class CubKeyValueSorter +{ +public: + CubKeyValueSorter(); + + CubKeyValueSorter(int const num_experts_per_node); + + void updateNumExperts(int const num_experts_per_node); + + static size_t getWorkspaceSize(size_t const num_key_value_pairs, int const num_experts_per_node); + + void run(void* workspace, size_t const workspace_size, int const* keys_in, int* keys_out, int const* values_in, + int* values_out, size_t const num_key_value_pairs, cudaStream_t stream); + +private: + static int expertsToBits(int experts); + int num_experts_; + int num_bits_; +}; + +struct ActivationParams +{ + ActivationType activation_type; + float const* swiglu_alpha = nullptr; + float const* swiglu_beta = nullptr; + float const* swiglu_limit = nullptr; + + explicit ActivationParams(ActivationType activation_type) + : activation_type(activation_type) + { + TLLM_CHECK_WITH_INFO(activation_type != ActivationType::SwigluBias, + "SwigluBias is not supported in ActivationParams without swiglu_alpha and swiglu_beta"); + } + + ActivationParams( + ActivationType activation_type, float const* swiglu_alpha, float const* swiglu_beta, float const* swiglu_limit) + : activation_type(activation_type) + , swiglu_alpha(swiglu_alpha) + , swiglu_beta(swiglu_beta) + , swiglu_limit(swiglu_limit) + { + } + + // TODO Port everything properly and get rid of these implicit conversions + operator ActivationType() const + { + return activation_type; + } +}; + /** * \brief Describes what parallelism mode the MoE is using * @@ -392,14 +448,14 @@ class CutlassMoeFCRunnerInterface = 0; virtual std::vector getTactics() = 0; - virtual void runMoe(void const* input_activations, void const* input_sf, int const* 
token_selected_experts, - float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases, - ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, - QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, - int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, - bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, - MoeMinLatencyParams& min_latency_params, cudaStream_t stream) + virtual void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf, + int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights, + void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights, + void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, + int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr, + void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, + bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, + bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) = 0; // Aliases for profiling the gemms @@ -410,7 +466,7 @@ class CutlassMoeFCRunnerInterface float const* const fc2_fp8_quant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, + int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids) @@ -439,7 +495,8 @@ class CutlassMoeFCRunnerInterface void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1, - void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) + void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales, + int const* permuted_row_to_unpermuted_row, cudaStream_t stream) = 0; virtual std::pair @@ -456,13 +513,13 @@ class CutlassMoeFCRunnerInterface virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const = 0; bool is_profiler = false; - bool use_deterministic_hopper_reduce_ = false; + bool use_fused_finalize_ = true; }; // Assumes inputs activations are row major. Weights need to be preprocessed by th_op/weight_quantize.cc . // Nested in a class to avoid multiple calls to cudaGetDeviceProperties as this call can be expensive. // Avoid making several duplicates of this class. 
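// A minimal usage sketch of the ActivationParams wrapper defined above, showing how a
// runMoe() caller is expected to populate it; d_swiglu_alpha, d_swiglu_beta and
// d_swiglu_limit below are illustrative placeholder device-pointer names, not symbols
// introduced by this change.
//
//   // Plain Swiglu/Geglu: the explicit one-argument constructor suffices, and the
//   // implicit conversion back to ActivationType keeps existing helpers such as
//   // isGatedActivation() working unchanged.
//   ActivationParams fc1_activation(ActivationType::Swiglu);
//   bool const gated = isGatedActivation(fc1_activation);
//
//   // SwigluBias additionally needs the per-expert alpha/beta/limit device arrays, so
//   // the four-argument constructor must be used; the one-argument form rejects
//   // SwigluBias via TLLM_CHECK_WITH_INFO.
//   ActivationParams fc1_activation_bias(
//       ActivationType::SwigluBias, d_swiglu_alpha, d_swiglu_beta, d_swiglu_limit);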
-template ; + +#if defined(ENABLE_BF16) + static constexpr bool use_wfp4a16 + = std::is_same_v && (std::is_same_v || std::is_same_v); +#else + static constexpr bool use_wfp4a16 = std::is_same_v && std::is_same_v; +#endif #if defined(ENABLE_FP8) static constexpr bool use_fp8 = (std::is_same_v || std::is_same_v) &&!std::is_same_v; @@ -485,6 +549,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface static constexpr bool use_fp8 = false; static constexpr bool use_w4afp8 = false; #endif + static constexpr bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; #if defined(ENABLE_FP4) static constexpr bool act_fp4 = std::is_same_v; static constexpr bool weight_fp4 = std::is_same_v; @@ -539,14 +604,14 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface return RunnerType::getConfigs(sm); } - void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts, - float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases, - ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, - QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, - int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, - bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, - MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override; + void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf, + int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights, + void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights, + void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, + int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr, + void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, + bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, + bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override; // We make these GEMM1 & GEMM2 static because they need to be stateless for the profiler to work static void gemm1(MoeGemmRunner& gemm_runner, @@ -563,7 +628,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, + int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids); @@ -591,7 +656,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface float const* const fc2_fp8_quant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, 
QuantParams quant_params, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, + int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids) override @@ -645,7 +710,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1, - void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) override + void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales, + int const* permuted_row_to_unpermuted_row, cudaStream_t stream) override { return Self::computeStridesTmaWarpSpecialized(expert_first_token_offset, layout_info1, layout_info2, num_tokens, expanded_num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts_per_node, @@ -654,7 +720,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface alpha_scale_flat1, alpha_scale_flat2, fp4_act_flat1, fp4_act_flat2, quant_params, reinterpret_cast(bias1), reinterpret_cast(bias2), reinterpret_cast(gemm1_output), - reinterpret_cast(gemm2_output), stream); + reinterpret_cast(gemm2_output), router_scales, permuted_row_to_unpermuted_row, + stream); } std::pair @@ -679,7 +746,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface private: std::pair setupTmaWarpSpecializedInputs( - int64_t num_rows, int64_t expanded_num_rows, ActivationType fc1_activation_type, int64_t hidden_size, + int64_t num_rows, int64_t expanded_num_rows, ActivationParams fc1_activation_type, int64_t hidden_size, int64_t inter_size, int64_t num_experts_per_node, void const* input_activations_void, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void* final_output, WeightType const* fc1_expert_weights, WeightType const* fc2_expert_weights, QuantParams quant_params, @@ -696,7 +763,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output, - UnfusedGemmOutputType* gemm2_output, cudaStream_t stream); + UnfusedGemmOutputType* gemm2_output, float const* router_scales, int const* permuted_row_to_unpermuted_row, + cudaStream_t stream); static std::pair computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1, TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, @@ -726,8 +794,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface bool mayHaveFinalizeFused() const { - return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() == 90 - && !use_deterministic_hopper_reduce_ && !use_w4afp8; + return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() >= 90 && use_fused_finalize_ + && 
!use_w4_groupwise; } // TODO: This should eventually take the quant params to give more flexibility @@ -758,7 +826,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface WeightType const* const fc1_expert_weights, ScaleBiasType const* const fc1_expert_biases, float const* const fc2_fp8_quant, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, - ActivationType fc1_activation_type, QuantParams& quant_params, cudaStream_t stream); + ActivationParams fc1_activation_type, QuantParams& quant_params, cudaStream_t stream); static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output, OutputType* const final_output, int64_t const* const expert_first_token_offset, diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h index b1676993ded..0b86afda684 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h @@ -58,7 +58,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const hidden_size, int const k, int const num_experts_per_node, QuantParams const& quant_params, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, cudaStream_t stream); + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, + void const* prequant_scales, cudaStream_t stream); template void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_rows, diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl index d5f0b198fd8..e92186a3f5c 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -26,21 +26,14 @@ #include "cute/tensor.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/fusion/operations.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/group_array_problem_shape.hpp" #include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/tensor_ref.h" -#include "cutlass_extensions/compute_occupancy.h" -#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp" -#include "cutlass_extensions/epilogue_helpers.h" -#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" -#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" -#include "cutlass_extensions/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" @@ -189,17 +182,19 @@ using SafeBF16 = void; TmaWarpSpecializedGroupedGemmInput tma_ws_input, int num_experts, int const multi_processor_count, \ cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size) \ { \ + using ArchTag = cutlass::arch::ArchTag_; \ constexpr static EpilogueFusion FUSION = EpilogueFusion::FUSION_; \ + constexpr static bool IsMXFPX = MXFPX_; \ + constexpr bool IsBlackwell = ArchTag::kMinComputeCapability >= 100; \ + constexpr bool IsSM120 = ArchTag::kMinComputeCapability == 120 || ArchTag::kMinComputeCapability == 121; \ + constexpr bool Is2SM = IsBlackwell && (CGA_M_ % 2 == 0); \ /* constexpr static bool BIAS = BIAS_; */ /* Always false */ \ - using ArchTag = cutlass::arch::ArchTag_; \ using T = DataType_; \ using WeightType = WeightType_; \ using OutputType = OutputType_; \ using EpilogueTag = tensorrt_llm::cutlass_extensions::EpilogueTag_; \ - using TileShape = cute::Shape, cute::Int, cute::Int>; \ + using MmaTileShape = cute::Shape, cute::Int, cute::Int>; \ using ClusterShape = cute::Shape, cute::Int, cute::Int>; \ - constexpr static bool IsMXFPX = MXFPX_; \ - \ if constexpr (!COMPILE_HOPPER_TMA_GROUPED_GEMMS_ENABLED && ArchTag::kMinComputeCapability >= 90 \ && ArchTag::kMinComputeCapability < 100) \ { \ @@ -217,18 +212,15 @@ using SafeBF16 = void; TLLM_THROW( \ "Please recompile with support for blackwell by passing 120-real as an arch to build_wheel.py."); \ } \ - else if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v) \ + else if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v) \ { \ using namespace cute; \ /* Helper class for defining all the cutlass types \ // template \ + // typename MmaTileShape, typename ClusterShape, bool BIAS, EpilogueFusion FUSION> \ // struct TmaWarpSpecializedGroupedGemmInfo \ { */ \ - using Arch = ArchTag; \ - constexpr static bool IsBlackwell = Arch::kMinComputeCapability >= 100; \ - constexpr static bool IsSM120 = Arch::kMinComputeCapability == 120 || Arch::kMinComputeCapability == 121; \ constexpr static bool IsWFP4AFP8 = cutlass::platform::is_same::value \ && cutlass::platform::is_same::value; \ constexpr static bool IsFP4 = cutlass::platform::is_same::value; \ @@ -308,8 +300,8 @@ using SafeBF16 = void; // units of elements (up to 16 bytes)*/ \ \ /* D matrix configuration */ \ - using LayoutD = TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::LayoutD; \ - using StrideD = 
TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::StrideD; \ + using LayoutD = TmaWarpSpecializedGroupedGemmInput::LayoutD; \ + using StrideD = TmaWarpSpecializedGroupedGemmInput::StrideD; \ constexpr static int AlignmentD \ = 128 / cutlass::sizeof_bits::value; /* Memory access granularity/alignment of D matrix \ // in units of elements (up to 16 bytes) */ \ @@ -327,30 +319,24 @@ using SafeBF16 = void; // cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, \ // cutlass::epilogue::?????????????????? /// <<<<<< what supports activations \ // >;*/ \ - using EpilogueScheduleSM90 = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; \ + using EpilogueScheduleSM90 = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \ \ - constexpr static bool Is2SM = IsBlackwell && (cute::size<0>(ClusterShape{}) % 2) == 0; \ using EpilogueScheduleSM100 = std::conditional_t; \ using EpilogueScheduleSM120 = cutlass::epilogue::TmaWarpSpecialized; \ - using EpilogueScheduleBW = std ::conditional_t; \ + using EpilogueScheduleBW = std::conditional_t; \ using EpilogueSchedule = std::conditional_t; \ \ - using EpilogueTileShapeSm90 = TileShape; \ - using AtomClusterDiv = std::conditional_t; \ - using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape{})); \ - using EpilogueTileShapeSm100 = decltype(shape_div(TileShape{}, AtomThrShape{})); \ - using EpilogueTileShape = std::conditional_t; \ using EpilogueElementC = std::conditional_t; \ using EpilogueTensorOp = std::conditional_t; \ - using EpilogueSubTile \ - = std::conditional_t, cutlass::epilogue::collective::EpilogueTileAuto>; \ + using EpilogueSubTile = std::conditional_t, cutlass::epilogue::collective::EpilogueTileAuto>; \ /* Epilogue For Default Finalize */ \ using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder::CollectiveOp; \ \ /* Epilogue For Fused Finalize */ \ - using CollectiveEpilogueFinalize = \ - typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< /**/ \ - Arch, EpilogueTileShape, /**/ \ - ElementCSafe, StrideC*, /**/ \ - ElementFinalOutput, \ - TmaWarpSpecializedGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, /**/ \ - ElementAccumulator, /**/ \ - ElementAccumulator, /**/ \ - ElementBias, TmaWarpSpecializedGroupedGemmInput::FusedFinalizeEpilogue::StrideBias, /**/ \ - ElementRouterScales, \ - TmaWarpSpecializedGroupedGemmInput::FusedFinalizeEpilogue::StrideRouterScales /**/ \ - >::CollectiveOp; \ + using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::CollectiveBuilder /**/ \ + >::CollectiveOp; \ \ using CollectiveEpilogue = std::conditional_t; \ @@ -405,16 +390,12 @@ using SafeBF16 = void; using MainloopElementA = std::conditional_t; \ using MainloopElementB = std::conditional_t; \ \ - using MainloopTileShapeSm90 = TileShape; \ - using MainloopTileShapeSm100 = decltype(shape_div(TileShape{}, AtomThrShape{})); \ - using MainloopTileShape = std::conditional_t; \ - \ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder::CollectiveOp; \ \ using GemmKernel = cutlass::gemm::kernel::GemmUniversal; \ /*}; \ - \ \ + // \ // using namespace cute; \ // using GemmInfo = TmaWarpSpecializedGroupedGemmInfo;; \ // \ // using ElementAccumulator = typename GemmInfo::ElementAccumulator; \ @@ -478,7 +459,7 @@ using SafeBF16 = void; TLLM_CHECK(tma_ws_input.ptr_a); \ TLLM_CHECK(tma_ws_input.ptr_b); \ \ - auto make_mainloop_params = [&]() -> MainloopArguments \ + MainloopArguments const mainloop_args = [&] \ { \ if constexpr (IsBlockScaled) \ { \ 
@@ -498,67 +479,46 @@ using SafeBF16 = void; reinterpret_cast(tma_ws_input.ptr_b), tma_ws_input.stride_b, \ reinterpret_cast(tma_ws_input.ptr_a), tma_ws_input.stride_a); \ } \ - }; \ - \ - auto const mainloop_params = make_mainloop_params(); \ - \ - using EpilogueArguments = typename CollectiveEpilogue::Arguments; \ - using EpilogueScalars = decltype(EpilogueArguments{}.thread); \ - auto make_epilogue_scalars = [&]() \ + }(); \ + using FusionArguments = typename CollectiveEpilogue::FusionCallbacks::Arguments; \ + FusionArguments fusion_args = [&] \ { \ - if constexpr (IsBlackwell) \ - { \ - return construct_if_true(ElementAccumulator(1.f), \ - tma_ws_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f), nullptr, nullptr, \ - tma_ws_input.alpha_scale_ptr_array, nullptr, \ - cute::Shape<_0, _0, int64_t>{ \ - cute::_0{}, cute::_0{}, (tma_ws_input.alpha_scale_ptr_array != nullptr) ? 1 : 0}, \ - cute::Shape<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 0}); \ - } \ - else if (tma_ws_input.alpha_scale_ptr_array) \ + if constexpr (FUSION == EpilogueFusion::FINALIZE) \ { \ - return construct_if_true(tma_ws_input.alpha_scale_ptr_array); \ + auto epi_params = tma_ws_input.fused_finalize_epilogue; \ + return construct_if_true( \ + ElementAccumulator(1), nullptr, tma_ws_input.alpha_scale_ptr_array, \ + Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, /* alpha */ \ + reinterpret_cast(epi_params.ptr_bias), \ + Stride<_1, _0, int64_t>{}, /* bias */ \ + epi_params.ptr_router_scales, Stride<_0, _1, int64_t>{}, /* scale */ \ + reinterpret_cast(epi_params.ptr_final_output), \ + epi_params.stride_final_output, epi_params.ptr_source_token_index, \ + epi_params.num_rows_in_final_output, epi_params.use_reduction); \ } \ else \ { \ - return construct_if_true(ElementAccumulator(1.f), \ - tma_ws_input.ptr_c ? 
ElementAccumulator(1.f) : ElementAccumulator(0.f)); \ + return construct_if_true( \ + ElementAccumulator(1), ElementAccumulator(0), nullptr, nullptr, \ + tma_ws_input.alpha_scale_ptr_array, nullptr, \ + Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, Stride<_0, _0, int64_t>{}); \ } \ - }; \ - auto epilogue_scalars = make_epilogue_scalars(); \ - /* TODO ptr_c casts to ElementCSafe** because there is a workaround in CUTLASS */ \ - auto make_epi_args = [&]() \ - { \ - static_assert(FUSION == EpilogueFusion::NONE || FUSION == EpilogueFusion::FINALIZE, \ - "Unimplemented fusion provided to TMA WS MoE gemm launcher"); \ + }(); \ \ - if constexpr (FUSION == EpilogueFusion::NONE) \ + using EpilogueArguments = typename CollectiveEpilogue::Arguments; \ + EpilogueArguments epilogue_args = [&] \ + { \ + if constexpr (FUSION == EpilogueFusion::FINALIZE) \ { \ - auto epi_params = tma_ws_input.default_epilogue; \ - return construct_if_true(epilogue_scalars, \ - nullptr, tma_ws_input.stride_c, reinterpret_cast(epi_params.ptr_d), \ - epi_params.stride_d); \ + return construct_if_true( \ + fusion_args, nullptr, nullptr, nullptr, nullptr); \ } \ - else if constexpr (FUSION == EpilogueFusion::FINALIZE) \ + else \ { \ - /* Parameters for fused finalize */ \ - auto epi_params = tma_ws_input.fused_finalize_epilogue; \ - return construct_if_true( \ - epilogue_scalars, /* Parameters to underlying epilogue */ \ - nullptr, tma_ws_input.stride_c, /* C params */ \ - reinterpret_cast(epi_params.ptr_final_output), \ - epi_params.stride_final_output, /* D (output) params */ \ - reinterpret_cast(epi_params.ptr_bias), \ - epi_params.stride_bias, /* Bias params */ \ - epi_params.ptr_router_scales, epi_params.stride_router_scales, /* Router scales */ \ - epi_params.ptr_expert_first_token_offset, /* Offset of this expert's token in the \ - router scales */ \ - epi_params.ptr_source_token_index, /* Index of the source token to sum into */ \ - epi_params.num_rows_in_final_output /* Number of tokens in the output buffer */ \ - ); \ + return construct_if_true(fusion_args, \ + nullptr, nullptr, reinterpret_cast(tma_ws_input.ptr_d), tma_ws_input.stride_d); \ } \ - }; \ - EpilogueArguments const epilogue_params = make_epi_args(); \ + }(); \ /* EpilogueArguments const epilogue_params = make_epi_args( \ // tma_ws_input, epilogue_scalars \ @@ -568,7 +528,7 @@ using SafeBF16 = void; 1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; \ \ const typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, \ - tma_ws_input.shape_info, mainloop_params, epilogue_params, hw_info, scheduler_args}; \ + tma_ws_input.shape_info, mainloop_args, epilogue_args, hw_info, scheduler_args}; \ \ size_t calculated_ws_size = gemm.get_workspace_size(args); \ TLLM_CHECK_WITH_INFO(calculated_ws_size <= tma_ws_input.gemm_workspace_size, \ diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl index a0ebfbde343..719824c4c6c 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl @@ -85,15 +85,14 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput::type; - using ElementA = cutlass::float_e4m3_t; + using ElementA = typename TllmToCutlassTypeAdapter::type; using LayoutA = 
cutlass::layout::RowMajor; // Layout type for A matrix operand constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements (up to 16 bytes) // B matrix configuration - // using ElementB = typename TllmToCutlassTypeAdapter::type; - using ElementB = typename cutlass::int4b_t; + using ElementB_ = typename TllmToCutlassTypeAdapter::type; + using ElementB = std::conditional_t, cutlass::int4b_t, ElementB_>; using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of @@ -108,9 +107,13 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput>; // Scale configuration - constexpr int PackedScalesNum = get<2>(CTAShape{}) / 128; - using ElementScalePacked - = cutlass::Array; + constexpr bool use_wfp4a16 = std::is_same_v; + constexpr int group_size = use_wfp4a16 ? cutlass::gemm::collective::detail::mxfp4_group_size + : cutlass::gemm::collective::detail::int4_group_size; + constexpr int PackedScalesNum = get<2>(CTAShape{}) / group_size; + using ElementScale = std::conditional_t; + using ElementScalePacked = cutlass::Array; using LayoutScale = cutlass::layout::RowMajor; // C/D matrix configuration @@ -170,20 +173,21 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput(hopper_inputs.ptr_b), hopper_inputs.stride_b, reinterpret_cast(hopper_inputs.ptr_a), hopper_inputs.stride_a, reinterpret_cast(hopper_inputs.int4_groupwise_params.ptr_s_a), - hopper_inputs.int4_groupwise_params.stride_s_a, int(inputs.groupwise_quant_group_size)}, + hopper_inputs.int4_groupwise_params.stride_s_a, group_size}, {fusion_args, reinterpret_cast(hopper_inputs.ptr_c), hopper_inputs.stride_c, - reinterpret_cast(hopper_inputs.default_epilogue.ptr_d), - hopper_inputs.default_epilogue.stride_d}, + reinterpret_cast(hopper_inputs.ptr_d), hopper_inputs.stride_d}, hw_info}; *workspace_size = gemm.get_workspace_size(args); return; @@ -205,10 +208,9 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput(hopper_inputs.ptr_b), hopper_inputs.stride_b, reinterpret_cast(hopper_inputs.ptr_a), hopper_inputs.stride_a, reinterpret_cast(hopper_inputs.int4_groupwise_params.ptr_s_a), - hopper_inputs.int4_groupwise_params.stride_s_a, int(inputs.groupwise_quant_group_size)}, + hopper_inputs.int4_groupwise_params.stride_s_a, group_size}, {fusion_args, reinterpret_cast(hopper_inputs.ptr_c), hopper_inputs.stride_c, - reinterpret_cast(hopper_inputs.default_epilogue.ptr_d), - hopper_inputs.default_epilogue.stride_d}, + reinterpret_cast(hopper_inputs.ptr_d), hopper_inputs.stride_d}, hw_info}; if (gemm.get_workspace_size(arguments) > hopper_inputs.gemm_workspace_size) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu new file mode 100644 index 00000000000..be29019bc6a --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moe_gemm_template_dispatch.h" + +namespace tensorrt_llm::kernels::cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_bfloat16, __nv_fp4_e2m1, __nv_bfloat16>; +#endif +} // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu new file mode 100644 index 00000000000..f1a885ea77d --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moe_gemm_template_dispatch.h" + +namespace tensorrt_llm::kernels::cutlass_kernels +{ +template class MoeGemmRunner; +} diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index ff582ec6e68..56a8299f18f 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -99,6 +99,7 @@ struct genericMoeGemmKernelLauncher static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value + || cutlass::platform::is_same::value || cutlass::platform::is_same::value); static_assert(arch::kMinComputeCapability < 90, "Sm90+ architecture should use specialized kernels"); @@ -503,7 +504,8 @@ MoeGemmRunner::getAmpereConfigs(int sm auto config_type_param = static_cast( weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag); - if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation() || (use_w4afp8 && sm != 89)) + if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation() || (use_w4afp8 && sm != 89) + || use_wfp4a16) { return {}; } @@ -580,18 +582,19 @@ int MoeGemmRunner::getSM() const // currently support sm80 bf16/fp16 gate activation, only set predication tensor for m direction template bool MoeGemmRunner::supportsFusedGatedActivation( - bool is_gated_activation, int gemm_n, int gemm_k) const + ActivationType activation_type, int gemm_n, int gemm_k) const { constexpr bool ENABLE_FUSED_GATED_ACTIVATION = true; - return is_gated_activation && std::is_same_v && !std::is_same_v && !use_fp8 - && (this->getSM() >= 80) && (gemm_k % 64 == 0) && (gemm_n % 64 == 0) && ENABLE_FUSED_GATED_ACTIVATION; + return (activation_type == ActivationType::Swiglu || activation_type 
== ActivationType::Geglu) + && std::is_same_v && !std::is_same_v && !use_fp8 && (this->getSM() >= 80) + && (gemm_k % 64 == 0) && (gemm_n % 64 == 0) && ENABLE_FUSED_GATED_ACTIVATION; } template bool MoeGemmRunner::isFusedGatedActivation( - cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const + cutlass_extensions::CutlassGemmConfig gemm_config, ActivationType activation_type, int gemm_n, int gemm_k) const { - return supportsFusedGatedActivation(is_gated_activation, gemm_n, gemm_k) && !gemm_config.is_tma_warp_specialized; + return supportsFusedGatedActivation(activation_type, gemm_n, gemm_k) && !gemm_config.is_tma_warp_specialized; } template @@ -623,26 +626,41 @@ void MoeGemmRunner::dispatchToArch( if (sm_ >= 75 && sm_ < 80) { - dispatchMoeGemmToCutlass( - inputs, multi_processor_count_); + if constexpr (!std::is_same_v) + { + dispatchMoeGemmToCutlass( + inputs, multi_processor_count_); + } + else + { + TLLM_THROW("FP4 data type is not supported on SM < 90"); + } } else if (sm_ >= 80 && sm_ < 90) { - if constexpr (use_fp8 || use_w4afp8) + + if constexpr (!std::is_same_v) { + if constexpr (use_fp8 || use_w4afp8) + { #if defined(ENABLE_FP8) - static_assert(!std::is_same_v && !std::is_same_v, - "FP8 GEMM Output not supported"); + static_assert(!std::is_same_v && !std::is_same_v, + "FP8 GEMM Output not supported"); #endif - TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); - dispatchMoeGemmToCutlass( - inputs, multi_processor_count_); + TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); + dispatchMoeGemmToCutlass( + inputs, multi_processor_count_); + } + else + { + dispatchMoeGemmToCutlass( + inputs, multi_processor_count_); + } } else { - dispatchMoeGemmToCutlass( - inputs, multi_processor_count_); + TLLM_THROW("FP4 data type is not supported on SM < 90"); } } else if (sm_ >= 90) @@ -659,7 +677,7 @@ void MoeGemmRunner::dispatchToArch( } if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation() - && !use_w4afp8) + && !use_w4_groupwise) { // We allow both tma warp specialized and SM80 configurations to coexist because for some cases with small // numbers of tokens SM80 is faster. 
We check here to see which is selected @@ -701,33 +719,39 @@ void MoeGemmRunner::dispatchToArch( // Hopper finegrained INT4 WS grouped GEMM if constexpr (use_w4afp8) { - if (inputs.gemm_config.is_tma_warp_specialized) + TLLM_CHECK_WITH_INFO( + inputs.gemm_config.is_tma_warp_specialized, "w4afp8 is only supported for TMA warp specialization"); + // EpilogueTag is ignored + if (inputs.k % 512 == 0) { - // EpilogueTag is ignored - if (inputs.k % 512 == 0) - { - sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( - inputs, hopper_inputs, multi_processor_count_, nullptr); - } - else if (inputs.k % 256 == 0) - { - sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( - inputs, hopper_inputs, multi_processor_count_, nullptr); - } - else if (inputs.k % 128 == 0) - { - sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( - inputs, hopper_inputs, multi_processor_count_, nullptr); - } - else - { - TLLM_THROW("Invalid GEMM K size %d", (int) inputs.k); - } - return; - }; + sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass(inputs, hopper_inputs, multi_processor_count_, nullptr); + } + else if (inputs.k % 256 == 0) + { + sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass(inputs, hopper_inputs, multi_processor_count_, nullptr); + } + else if (inputs.k % 128 == 0) + { + sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass(inputs, hopper_inputs, multi_processor_count_, nullptr); + } + else + { + TLLM_THROW("Invalid GEMM K size %d", (int) inputs.k); + } + return; + } + + if constexpr (use_wfp4a16) + { + TLLM_CHECK_WITH_INFO( + inputs.gemm_config.is_tma_warp_specialized, "wfp4a16 is only supported for TMA warp specialization"); + // EpilogueTag is ignored + sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass(inputs, hopper_inputs, multi_processor_count_, nullptr); + return; } #endif @@ -779,7 +803,7 @@ size_t MoeGemmRunner::getMaxWorkspaceS template size_t MoeGemmRunner::calcMaxWorkspaceSize(int num_experts) const { - if constexpr (use_w4afp8) + if constexpr (use_w4_groupwise) { return calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput( num_experts, multi_processor_count_); @@ -788,7 +812,8 @@ size_t MoeGemmRunner::calcMaxWorkspace { return 0; } - if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation() && !use_w4afp8) + if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation() && !use_w4afp8 + && !use_wfp4a16) { auto configs = getTmaWarpSpecializedConfigs(sm_); auto fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h index d9df31513f3..40496a6a0eb 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h @@ -138,11 +138,11 @@ void dispatchMoeGemmSelectBiasTmaWarpSpecialized(TmaWarpSpecializedGroupedGemmIn } } -template +template constexpr bool are_tile_shapes_supported_sm100() { using namespace cute; - using CtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + // This is the epilogue shape. 
The MMA shape will be twice this for 2SM constexpr auto TileM = size<0>(CtaShape{}); constexpr auto TileN = size<1>(CtaShape{}); @@ -353,6 +353,7 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized(TmaWarpSpecializedGroupedG { switch (gemm_config.tile_config_sm100) { + SHAPE_CASE(100, 64, 32, 128) SHAPE_CASE(100, 64, 64, 128) SHAPE_CASE(100, 64, 128, 128) SHAPE_CASE(100, 64, 256, 128) @@ -363,13 +364,8 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized(TmaWarpSpecializedGroupedG SHAPE_CASE(100, 128, 128, 128) SHAPE_CASE(100, 128, 256, 128) - SHAPE_CASE(100, 256, 64, 128) - SHAPE_CASE(100, 256, 128, 128) - SHAPE_CASE(100, 256, 256, 128) - // SHAPE_CASE(100, 128, 128, 64) // SHAPE_CASE(100, 128, 256, 64) - // SHAPE_CASE(100, 256, 256, 64) DEFAULT_CASE(100) } } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h index 9a9f2ebeb38..affa4d8c409 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h @@ -153,10 +153,13 @@ void sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best // for mixed type gemms. - constexpr int Ktile = 128 * PackedScalesNum / sizeof(T); - TLLM_CHECK(sizeof(T) == 1); + constexpr int Ntile = (std::is_same_v) ? 64 : 128; + constexpr int Ktile = (std::is_same_v) ? 128 : 128 * PackedScalesNum / sizeof(T); + TLLM_CHECK(sizeof(T) == (std::is_same_v) ? 2 : 1); + using _Ntile = Int; using _Ktile = Int; + switch (inputs.gemm_config.tile_config_sm90) { case tkc::CutlassTileConfigSM90::CtaShape64x16x128B: @@ -172,8 +175,8 @@ void sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( inputs, hopper_inputs, sm_count_, workspace_size); break; case tkc::CutlassTileConfigSM90::CtaShape64x128x128B: - sm90_dispatch_moe_mixed_dtype_gemm_config>( - inputs, hopper_inputs, sm_count_, workspace_size); + sm90_dispatch_moe_mixed_dtype_gemm_config>(inputs, hopper_inputs, sm_count_, workspace_size); break; // case tkc::CutlassTileConfigSM90::CtaShape64x256x128B: // sm90_dispatch_moe_mixed_dtype_gemm_config size_t calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput(int num_experts, int sm_count_) { size_t count = 0; + constexpr int Ktile = (std::is_same_v) ? 
256 : 512; + using _Ktile = Int; + #ifdef COMPILE_HOPPER_TMA_GROUPED_GEMMS GroupedGemmInput inputs{}; inputs.num_experts = num_experts; sm90_generic_mixed_moe_gemm_kernelLauncher, Shape<_1, _1, _1>, + tensorrt_llm::cutlass_extensions::EpilogueOpDefault, Shape<_128, _64, _Ktile>, Shape<_1, _1, _1>, cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>( inputs, TmaWarpSpecializedGroupedGemmInput{}, sm_count_, &count); diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu index 485c19496f3..b49dfec9992 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu @@ -27,14 +27,14 @@ namespace tensorrt_llm::kernels::cutlass_kernels { -std::array TmaWarpSpecializedGroupedGemmInput::workspaceBuffers( +std::array TmaWarpSpecializedGroupedGemmInput::workspaceBuffers( int num_experts, FpXBlockScalingType scaling_type) { size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts; size_t stride_a_size = sizeof(StrideA) * num_experts; size_t stride_b_size = sizeof(StrideB) * num_experts; size_t stride_c_size = sizeof(StrideC) * num_experts; - size_t stride_d_size = sizeof(DefaultEpilogue::StrideD) * num_experts; + size_t stride_d_size = sizeof(StrideD) * num_experts; size_t ptr_buf_size = sizeof(void*) * num_experts; size_t scale_buf_size = sizeof(float*) * num_experts; @@ -53,9 +53,12 @@ std::array TmaWarpSpecializedGroupedGemmInput::workspaceBuffers( size_t int4_groupwise_sf_a_size = sizeof(INT4GroupwiseParams::SFA*) * num_experts; size_t int4_groupwise_stride_sf_a_size = sizeof(INT4GroupwiseParams::StrideSFA) * num_experts; + size_t ptr_token_map_size = sizeof(int**) * num_experts; + return std::array{problem_shape_size, stride_a_size, stride_b_size, stride_c_size, stride_d_size, ptr_buf_size, ptr_buf_size, ptr_buf_size, ptr_buf_size, scale_buf_size, sf_a_size, sf_b_size, stride_sf_a_size, - stride_sf_b_size, int4_groupwise_problem_shape_size, int4_groupwise_sf_a_size, int4_groupwise_stride_sf_a_size}; + stride_sf_b_size, int4_groupwise_problem_shape_size, int4_groupwise_sf_a_size, int4_groupwise_stride_sf_a_size, + ptr_buf_size, scale_buf_size, ptr_token_map_size}; } size_t TmaWarpSpecializedGroupedGemmInput::workspaceSize(int num_experts, FpXBlockScalingType scaling_type) @@ -68,7 +71,7 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i size_t gemm_workspace_size, FpXBlockScalingType scaling_type) { auto buffers = workspaceBuffers(num_experts, scaling_type); - std::array pointers{}; + std::array pointers{}; TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers"); for (int i = 0; i < buffers.size(); i++) { @@ -82,12 +85,12 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i stride_a = reinterpret_cast(pointers[1]); stride_b = reinterpret_cast(pointers[2]); stride_c = reinterpret_cast(pointers[3]); - default_epilogue.stride_d = reinterpret_cast(pointers[4]); + stride_d = reinterpret_cast(pointers[4]); ptr_a = reinterpret_cast(pointers[5]); ptr_b = reinterpret_cast(pointers[6]); ptr_c = reinterpret_cast(pointers[7]); - default_epilogue.ptr_d = reinterpret_cast(pointers[8]); + ptr_d = 
reinterpret_cast(pointers[8]); alpha_scale_ptr_array = reinterpret_cast(pointers[9]); @@ -103,28 +106,24 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i int4_groupwise_params.ptr_s_a = reinterpret_cast(pointers[15]); int4_groupwise_params.stride_s_a = reinterpret_cast(pointers[16]); + fused_finalize_epilogue.ptr_bias = reinterpret_cast(pointers[17]); + fused_finalize_epilogue.ptr_router_scales = reinterpret_cast(pointers[18]); + fused_finalize_epilogue.ptr_source_token_index = reinterpret_cast(pointers[19]); + this->gemm_workspace = reinterpret_cast(gemm_workspace); this->gemm_workspace_size = gemm_workspace_size; } -void TmaWarpSpecializedGroupedGemmInput::setFinalizeFusionParams(void* final_output, float const* router_scales, - int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, - int num_output_tokens) +void TmaWarpSpecializedGroupedGemmInput::setFinalizeFusionParams( + void* final_output, int hidden_size, int num_output_tokens, bool use_reduction) { fused_finalize_epilogue.ptr_final_output = final_output; - fused_finalize_epilogue.ptr_router_scales = router_scales; - fused_finalize_epilogue.ptr_bias = bias; - fused_finalize_epilogue.ptr_expert_first_token_offset = expert_first_token_offset; - fused_finalize_epilogue.ptr_source_token_index = source_token_index; - - fused_finalize_epilogue.stride_final_output - = cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput{}, - transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1))); - fused_finalize_epilogue.stride_bias - = transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size)); - fused_finalize_epilogue.stride_router_scales = {}; + + fused_finalize_epilogue.stride_final_output = cutlass::make_cute_packed_stride( + FusedFinalizeEpilogue::StrideFinalOutput{}, cute::make_shape(hidden_size, num_output_tokens, 1)); fused_finalize_epilogue.num_rows_in_final_output = num_output_tokens; + fused_finalize_epilogue.use_reduction = use_reduction; } std::string TmaWarpSpecializedGroupedGemmInput::toString() const @@ -143,16 +142,13 @@ std::string TmaWarpSpecializedGroupedGemmInput::toString() const ss << "Final Output: " << (PrintType) fused_finalize_epilogue.ptr_final_output; ss << " with Stride: " << fused_finalize_epilogue.stride_final_output; ss << ",\nBias: " << (PrintType) fused_finalize_epilogue.ptr_bias; - ss << " with Stride: " << fused_finalize_epilogue.stride_bias; ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales; - ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales; - ss << ",\nExpert Offset: " << (PrintType) fused_finalize_epilogue.ptr_expert_first_token_offset; ss << ", Source Map: " << (PrintType) fused_finalize_epilogue.ptr_source_token_index; } else { - ss << "Ptr D: " << (PrintType) default_epilogue.ptr_d; - ss << " with Stride: " << (PrintType) default_epilogue.stride_d; + ss << "Ptr D: " << (PrintType) ptr_d; + ss << " with Stride: " << (PrintType) stride_d; } ss << '\n'; ss << "Alpha scale ptr: " << (PrintType) alpha_scale_ptr_array << "\n"; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index 0caf687b569..730840717c2 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. 
All rights reserved. + * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -997,12 +997,12 @@ __device__ auto quantizePackedFPXValue(ComputeElem& post_act_val, float global_s TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType scaling_type) { constexpr bool is_fp8 = std::is_same_v; - static constexpr int NumThreadsPerSF = VecSize / CVT_FP4_ELTS_PER_THREAD; + static constexpr int NumThreadsPerSF = VecSize / CVT_ELTS_PER_THREAD; // Quantize the input to FP4 static_assert(std::is_same_v || std::is_same_v); - static_assert(ComputeElem::kElements == CVT_FP4_ELTS_PER_THREAD); + static_assert(ComputeElem::kElements == CVT_ELTS_PER_THREAD); PackedVec packed_vec{}; - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) + for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) { packed_vec.elts[i].x = static_cast(post_act_val[i * 2 + 0]); packed_vec.elts[i].y = static_cast(post_act_val[i * 2 + 1]); @@ -1013,10 +1013,9 @@ __device__ auto quantizePackedFPXValue(ComputeElem& post_act_val, float global_s = act_sf_flat + getOffsetActivationSF(expert_id, num_tokens_before_expert, num_cols, scaling_type); // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this expert - auto sf_out - = cvt_quant_to_fp4_get_sf_out_offset( - std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, - num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED); + auto sf_out = cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, act_sf_expert, QuantizationSFLayout::SWIZZLED); // Do the conversion and set the output and scaling factor auto func = [&]() @@ -1043,7 +1042,7 @@ __device__ auto quantizePackedFPXValue(ComputeElem& post_act_val, float global_s template __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, int64_t source_token_id, int64_t token_id, int64_t elem_idx, int64_t num_cols, TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf) + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf = true) { static constexpr int NumThreadsPerSF = VecSize / ElementsPerThread; @@ -1055,20 +1054,31 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, int : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this expert - auto sf_out - = cvt_quant_to_fp4_get_sf_out_offset( - std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, - num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED); + auto sf_out = cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, act_sf_expert, QuantizationSFLayout::SWIZZLED); if (sf_out) { if (input_sf) { - auto const sf_in - = cvt_quant_to_fp4_get_sf_out_offset(std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, - num_cols, const_cast(input_sf), - FP4QuantizationSFLayout::SWIZZLED); - *sf_out = *sf_in; + if (swizzled_input_sf) + { + auto const sf_in + = cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, 
source_token_id, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, const_cast(input_sf), + QuantizationSFLayout::SWIZZLED); + *sf_out = *sf_in; + } + else + { + auto const sf_in + = cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, const_cast(input_sf), + QuantizationSFLayout::LINEAR); + *sf_out = *sf_in; + } } else { @@ -1155,14 +1165,19 @@ __device__ void computeTmaWarpSpecializedInputStrides( } if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { - layout_info.default_epilogue.stride_d[out_idx] = cutlass::make_cute_packed_stride( - TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::StrideD{}, cute::make_shape(gemm_n, gemm_m, 1)); + layout_info.stride_d[out_idx] = cutlass::make_cute_packed_stride( + TmaWarpSpecializedGroupedGemmInput::StrideD{}, cute::make_shape(gemm_n, gemm_m, 1)); } if (layout_info.int4_groupwise_params.enabled) { layout_info.int4_groupwise_params.stride_s_a[out_idx] = cutlass::make_cute_packed_stride(TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::StrideSFA{}, - cute::make_shape(gemm_n, gemm_k / 128, 1)); + cute::make_shape(gemm_n, + gemm_k + / (layout_info.int4_groupwise_params.use_wfp4a16 + ? TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size + : TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size), + 1)); } } @@ -1170,7 +1185,8 @@ template __device__ void computeTmaWarpSpecializedInputPointers(TmaWarpSpecializedGroupedGemmInput& layout_info, int64_t gemm_m, int64_t gemm_n, int64_t gemm_k, int num_tokens_before_expert, int64_t expert, T const* in, WeightType const* weights, TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::SFA const* w4a8_weight_scale, - ScaleBiasType const* bias, OutputType* output, int64_t const out_idx) + ScaleBiasType const* bias, OutputType* output, float const* router_scales, + int const* permuted_row_to_unpermuted_row, int64_t const out_idx) { // The input prior to this contains K elements per token, with `num_tokens_before_expert` tokens layout_info.ptr_a[out_idx] = safe_inc_ptr(in, num_tokens_before_expert * gemm_k); @@ -1181,12 +1197,28 @@ __device__ void computeTmaWarpSpecializedInputPointers(TmaWarpSpecializedGrouped if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { // The output prior to this contains N elements per token, with `num_tokens_before_expert` tokens - layout_info.default_epilogue.ptr_d[out_idx] = safe_inc_ptr(output, num_tokens_before_expert * gemm_n); + layout_info.ptr_d[out_idx] = safe_inc_ptr(output, num_tokens_before_expert * gemm_n); + } + if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE) + { + + layout_info.fused_finalize_epilogue.ptr_source_token_index[expert] + = permuted_row_to_unpermuted_row + num_tokens_before_expert; + layout_info.fused_finalize_epilogue.ptr_router_scales[expert] = router_scales + num_tokens_before_expert; + if (layout_info.fused_finalize_epilogue.ptr_bias != nullptr) + { + layout_info.fused_finalize_epilogue.ptr_bias[expert] = bias + gemm_n * expert; + } } if (layout_info.int4_groupwise_params.enabled) { - layout_info.int4_groupwise_params.ptr_s_a[out_idx] - = safe_inc_ptr(w4a8_weight_scale, expert * (gemm_n * gemm_k / 128)); + // The group size of wfp4a16 is multiplied by 2 because each scale uses 1 byte instead of 2 bytes + layout_info.int4_groupwise_params.ptr_s_a[out_idx] = safe_inc_ptr(w4a8_weight_scale, + expert + * (gemm_n * 
@@ -1199,7 +1231,8 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir
     WeightType const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2,
     TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
     TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params,
-    ScaleBiasType const* bias1, ScaleBiasType const* bias2, OutputType* gemm1_output, OutputType* gemm2_output)
+    ScaleBiasType const* bias1, ScaleBiasType const* bias2, OutputType* gemm1_output, OutputType* gemm2_output,
+    float const* router_scales, int const* permuted_row_to_unpermuted_row)
 {
     // First, compute the global tid. We only need 1 thread per expert.
     int const expert = blockIdx.x * blockDim.x + threadIdx.x;
@@ -1277,12 +1310,12 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir
         gemm1_in, weights1,
         reinterpret_cast(
             quant_params.groupwise.fc1.weight_scales),
-        bias1, gemm1_output, expert);
+        bias1, gemm1_output, nullptr, nullptr, expert);
     computeTmaWarpSpecializedInputPointers(layout_info2, gemm_m, gemm2_n, gemm2_k, num_tokens_before_expert, expert,
         gemm2_in, weights2,
         reinterpret_cast(
             quant_params.groupwise.fc2.weight_scales),
-        bias2, gemm2_output, expert);
+        bias2, gemm2_output, router_scales, permuted_row_to_unpermuted_row, expert);
 }
 
 template
@@ -1400,12 +1433,12 @@ __global__ void computeStridesTmaWarpSpecializedLowLatencyKernel(TmaWarpSpeciali
         layout_info2.ptr_b[expert] = safe_inc_ptr(weights2, local_expert * (gemm1_n * gemm2_k));
 
         assert(layout_info1.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE);
-        layout_info1.default_epilogue.ptr_d[expert] = safe_inc_ptr(output1, expert * num_tokens * gemm1_n);
+        layout_info1.ptr_d[expert] = safe_inc_ptr(output1, expert * num_tokens * gemm1_n);
 
         if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE)
         {
             // The output prior to this contains N elements per token, with `num_tokens` tokens
-            layout_info2.default_epilogue.ptr_d[expert] = safe_inc_ptr(output2, expert * num_tokens * gemm2_n);
+            layout_info2.ptr_d[expert] = safe_inc_ptr(output2, expert * num_tokens * gemm2_n);
         }
     }
     else
@@ -1415,10 +1448,10 @@ __global__ void computeStridesTmaWarpSpecializedLowLatencyKernel(TmaWarpSpeciali
         layout_info1.ptr_b[expert] = nullptr;
         layout_info2.ptr_b[expert] = nullptr;
 
-        layout_info1.default_epilogue.ptr_d[expert] = nullptr;
+        layout_info1.ptr_d[expert] = nullptr;
         if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE)
         {
-            layout_info2.default_epilogue.ptr_d[expert] = nullptr;
+            layout_info2.ptr_d[expert] = nullptr;
         }
     }
 }
@@ -1452,8 +1485,8 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
     int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size, int64_t const k,
     float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t const* expert_first_token_offset,
     TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
-    TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, int64_t const num_experts_per_node,
-    InputActivationsType const* prequant_scales = nullptr)
+    TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf,
+    int64_t const num_experts_per_node, InputActivationsType const* prequant_scales = nullptr)
 {
     static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE || !PRE_QUANT_AWQ,
         "AWQ and Block Scaling are mutually exclusive");
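The new `router_scales` and `permuted_row_to_unpermuted_row` arguments are consumed only by the FINALIZE fused epilogue path set up earlier, which is why the GEMM1 call above passes `nullptr` for both while the GEMM2 call forwards the real pointers. Below is a host-side sketch of how per-expert segment pointers fall out of the `expert_first_token_offset` prefix sum; `ExpertSegments` and `buildSegments` are illustrative names, not part of this patch.

```cpp
#include <cstdint>
#include <vector>

// One pointer pair per expert into the flat, expert-sorted (permuted) token arrays.
struct ExpertSegments
{
    std::vector<int const*> source_token_index;
    std::vector<float const*> router_scales;
};

// expert_first_token_offset is an exclusive prefix sum of per-expert token counts,
// so expert e owns permuted rows [offset[e], offset[e + 1]).
ExpertSegments buildSegments(int const* permuted_row_to_unpermuted_row, float const* router_scales,
    int64_t const* expert_first_token_offset, int num_experts)
{
    ExpertSegments out;
    out.source_token_index.resize(num_experts);
    out.router_scales.resize(num_experts);
    for (int expert = 0; expert < num_experts; ++expert)
    {
        int64_t const num_tokens_before_expert = expert_first_token_offset[expert];
        out.source_token_index[expert] = permuted_row_to_unpermuted_row + num_tokens_before_expert;
        out.router_scales[expert] = router_scales + num_tokens_before_expert;
    }
    return out;
}

int main()
{
    std::vector<int> permuted_to_unpermuted{3, 0, 2, 1};
    std::vector<float> scales{0.5f, 1.0f, 0.25f, 0.75f};
    std::vector<int64_t> offsets{0, 2, 4}; // expert 0 owns rows [0, 2), expert 1 owns rows [2, 4)
    auto segments = buildSegments(permuted_to_unpermuted.data(), scales.data(), offsets.data(), 2);
    return segments.router_scales[1][0] == 0.25f ? 0 : 1;
}
```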
@@ -1487,7 +1520,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
         : TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize;

     constexpr int64_t ELEM_PER_THREAD
-        = (is_nvfp4 || is_mxfp8) ? CVT_FP4_ELTS_PER_THREAD : (128 / sizeof_bits::value);
+        = (is_nvfp4 || is_mxfp8) ? CVT_ELTS_PER_THREAD : (128 / sizeof_bits::value);

     // This should be VecSize * 4 elements
     // We assume at least VecSize alignment or the quantization will fail
@@ -1555,7 +1588,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
             {
                 assert(act_scale_idx == 0 && "Cannot use per-expert act scale for pre-quantized activations");
                 writeSF(num_tokens_before_expert, expert, source_row, permuted_row,
-                    elem_index, padded_hidden_size, fc1_act_sf_flat, input_sf);
+                    elem_index, padded_hidden_size, fc1_act_sf_flat, input_sf, swizzled_input_sf);
                 dest_row_ptr[elem_index] = in_vec;
             }
         }
@@ -1656,7 +1689,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
     int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const hidden_size, int const k,
     int const num_experts_per_node, QuantParams const& quant_params, bool use_per_expert_act_scale,
     int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
-    TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, cudaStream_t stream)
+    TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf,
+    void const* prequant_scales, cudaStream_t stream)
 {
 #ifdef ENABLE_FP4
     TLLM_CHECK_WITH_INFO(
@@ -1732,8 +1766,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
     config.attrs = attrs;
     cudaLaunchKernelEx(&config, func, unpermuted_input, permuted_output, unpermuted_scales, permuted_scales,
         permuted_row_to_unpermuted_row, num_rows, hidden_size, k, quant_params.fp4.fc1.act_global_scale,
-        use_per_expert_act_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf, num_experts_per_node,
-        reinterpret_cast(prequant_scales));
+        use_per_expert_act_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf, swizzled_input_sf,
+        num_experts_per_node, reinterpret_cast(prequant_scales));
 }
 
 #define INSTANTIATE_EXPAND_INPUT_ROWS(InputActivationsType, ExpandedActivationsType) \
@@ -1743,8 +1777,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
         int64_t const num_rows, int64_t const hidden_size, int const k, int const num_experts_per_node, \
         QuantParams const& quant_params, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, \
         TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, \
-        TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, \
-        cudaStream_t stream)
+        TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, \
+        void const* prequant_scales, cudaStream_t stream)
 
 // Instantiate the data types that are used by the external pytorch op
 INSTANTIATE_EXPAND_INPUT_ROWS(float, float);
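`ELEM_PER_THREAD` above is plain vector-width arithmetic: the block-scaled FP4/MXFP8 paths take the converter's fixed element count, and every other activation type packs one 128-bit vector per thread. A compile-time sketch of that rule follows; `kCvtEltsPerThread` is an assumed stand-in for `CVT_ELTS_PER_THREAD`, and `sizeofBits` stands in for the `sizeof_bits` trait whose template argument was lost above.

```cpp
#include <cstdint>

// Assumed stand-in for CVT_ELTS_PER_THREAD (the FP4/MXFP8 converter's fixed element count).
constexpr int64_t kCvtEltsPerThread = 8;

// Stand-in for the sizeof_bits trait: bit width of a plain activation element type.
template <typename T>
constexpr int64_t sizeofBits = static_cast<int64_t>(sizeof(T)) * 8;

template <typename T>
constexpr int64_t elemPerThread(bool is_block_scaled)
{
    // Block-scaled paths use the converter's width; everything else packs a 128-bit vector of T.
    return is_block_scaled ? kCvtEltsPerThread : (128 / sizeofBits<T>);
}

static_assert(elemPerThread<float>(false) == 4, "a 128-bit vector holds 4 fp32 values");
static_assert(elemPerThread<std::uint16_t>(false) == 8, "a 128-bit vector holds 8 fp16-sized values");

int main()
{
    return 0;
}
```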
@@ -1994,8 +2028,8 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro
 #define INSTANTIATE_FINALIZE_MOE_ROUTING(OutputT, GemmOutputT, ScaleBiasT) \
     template void finalizeMoeRoutingKernelLauncher( \
         GemmOutputT const* expanded_permuted_rows, OutputT* reduced_unpermuted_output, ScaleBiasT const* bias, \
-        float const* final_scales, int const* expanded_source_row_to_expanded_dest_row, \
-        int const* expanded_dest_row_to_expanded_source_row, int const* expert_for_source_row, \
+        float const* final_scales, int const* unpermuted_row_to_permuted_row, \
+        int const* permuted_row_to_unpermuted_row, int const* expert_for_source_row, \
         int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const cols, \
         int64_t const experts_per_token, int64_t const num_experts_per_node, MOEParallelismConfig parallelism_config, \
         bool const enable_alltoall, cudaStream_t stream);
@@ -2007,16 +2041,67 @@ INSTANTIATE_FINALIZE_MOE_ROUTING(float, float, float);
 INSTANTIATE_FINALIZE_MOE_ROUTING(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16);
 #endif
 
+// ============================== Activation Adaptors =================================
+template
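The renamed macro parameters (`unpermuted_row_to_permuted_row`, `permuted_row_to_unpermuted_row`) name the two directions of the token permutation: original order to expert-sorted order and back. For a single expert per token they are exact inverse permutations, as in this illustrative sketch (not library code).

```cpp
#include <cassert>
#include <vector>

// permuted_row_to_unpermuted_row maps an expert-sorted row back to its original token;
// the inverse maps an original token to its expert-sorted slot (k = 1 case).
std::vector<int> invertPermutation(std::vector<int> const& permuted_row_to_unpermuted_row)
{
    std::vector<int> unpermuted_row_to_permuted_row(permuted_row_to_unpermuted_row.size());
    for (int permuted_row = 0; permuted_row < static_cast<int>(permuted_row_to_unpermuted_row.size()); ++permuted_row)
    {
        unpermuted_row_to_permuted_row[permuted_row_to_unpermuted_row[permuted_row]] = permuted_row;
    }
    return unpermuted_row_to_permuted_row;
}

int main()
{
    std::vector<int> permuted_to_unpermuted{2, 0, 1}; // permuted row 0 holds original token 2, ...
    auto unpermuted_to_permuted = invertPermutation(permuted_to_unpermuted);
    assert(unpermuted_to_permuted[2] == 0);
    assert(unpermuted_to_permuted[0] == 1);
    assert(unpermuted_to_permuted[1] == 2);
    return 0;
}
```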