
Conversation


brb-nv commented Sep 8, 2025

Description

This PR adds support for KV cache transmission with context parallelism (CP) on the generation side in disaggregated serving.

Current scope:

  • MLA only.
  • CP on the gen side only. Hence, no changes to MLACacheFormatter::unformat() and concatKvCacheV2Dispatch() in this PR; those would be needed if prefill also used CP.

Known limitations to be addressed in a future PR:

  • Over-allocation of KV cache on gen ranks with CP. With genCP=4, each gen rank needs only about a quarter of the sequence's KV blocks, but today the full KV cache is still allocated on every gen rank.
  • There's also a non-strict mode in getBlockNumAccountingForCP() that allows additional block allocation while sending (see the sketch after this list). This will be addressed together with the previous limitation.
  • The current implementation assumes that context ranks "know" gen ranks expect blocks in a round-robin fashion. Ideally, gen ranks should request specific block hashes so that no such assumption is needed; the onus would then be on the gen ranks to request the right hashes. An ongoing effort by @Tabrizian to enable requesting by block hash should make this possible in the near future.
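
To make the strict vs. non-strict accounting concrete, here is a minimal sketch of what getBlockNumAccountingForCP() could compute, assuming the round-robin distribution described under the design decisions below. The signature matches the declaration in mlaCacheFormatter.h, but the body is only an illustration, not necessarily the merged implementation.

// Illustrative sketch, assuming round-robin assignment of global block b to CP rank (b % cpSize).
int getBlockNumAccountingForCP(int cpRank, int cpSize, int numTotalBlocks, bool strict)
{
    int const base = numTotalBlocks / cpSize;
    int const remainder = numTotalBlocks % cpSize;
    if (strict)
    {
        // Exact per-rank count: the first `remainder` ranks own one extra block.
        return base + (cpRank < remainder ? 1 : 0);
    }
    // Non-strict: every rank provisions for the worst case (the ceiling),
    // which over-allocates on some ranks; this is the limitation noted above.
    return base + (remainder > 0 ? 1 : 0);
}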

A couple of design decisions:

  • We split blocks rather than tokens (this keeps the KV cache transmission logic simple).
  • We distribute blocks across CP ranks in a round-robin fashion rather than as contiguous chunks; we felt this enables better KV cache reuse.

For example, say seq1 has 4 blocks while seq2 has 6 blocks (the first 4 being the same as seq1's).

Round-robin:
seq1 cache:
cpRank0 - block0, block2
cpRank1 - block1, block3

seq2 cache:
cpRank0 - block0, block2, block4 (first 2 reused from seq1)
cpRank1 - block1, block3, block5 (first 2 reused from seq1)

Contiguous:
seq1 cache:
cpRank0 - block0, block1
cpRank1 - block2, block3

seq2 cache:
cpRank0 - block0, block1, block2 (first 2 reused from seq1)
cpRank1 - block3, block4, block5 (no reuse)
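
To see why round-robin preserves reuse, consider a hypothetical helper (illustration only, not part of this PR) that maps a global block id to its CP rank and rank-local slot:

#include <utility>

// Hypothetical helper, illustration only.
// Round-robin: global block b lives on CP rank (b % cpSize) at local slot (b / cpSize).
std::pair<int, int> blockToCpRankAndSlot(int blockId, int cpSize)
{
    return {blockId % cpSize, blockId / cpSize};
}

Growing seq1 into seq2 only appends new slots on each rank (block4 lands at rank0/slot2, block5 at rank1/slot2), so previously cached slots are untouched. Under the contiguous scheme, growing the sequence shifts the ownership boundary, so cpRank1's cached blocks no longer line up with its new assignment.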

Please let me know if something sounds amiss.

Test Coverage

$ TRTLLM_USE_UCX_KVCACHE=1 mpirun -n 8 ./tests/unit_tests/multi_gpu/cacheTransceiverTest --gtest_filter="AsymmetricCaseTest0WithCPForMLA/AsymmetricalCacheTest.TestCase/*"
$ TRTLLM_USE_UCX_KVCACHE=1 mpirun -n 8 ./tests/unit_tests/multi_gpu/cacheTransceiverTest --gtest_filter="AsymmetricCaseTest1WithCPForMLA/AsymmetricalCacheTest.TestCase/*"

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only supported backends: [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

Summary by CodeRabbit

  • New Features

    • Added context parallelism (CP) support to MLA cache distribution and transmission, enabling PP×CP configurations for larger multi-GPU setups.
    • Introduced environment-controlled debug logging (TLLM_DEBUG_RANK) for targeted cache formatting and split/concat diagnostics.
    • Enabled 4D tensor value printing to simplify debugging.
  • Tests

    • Expanded multi-GPU test coverage with CP-aware MLA scenarios, including safeguards for over-allocation and rank validation to prevent false failures.

brb-nv changed the title from "TRTLLM-7731 KV cache transmission in disagg with CP on gen side" to "[TRTLLM-7731][feat] KV cache transmission in disagg with CP on gen side" on Sep 8, 2025

coderabbitai bot commented Sep 8, 2025

📝 Walkthrough

Extends MLA KV-cache formatting/split/concat to support CP (context parallel) domain alongside PP/TP, updates buffer sizing and indexing, adds environment-gated debug logging, relaxes a CP constraint, enhances tensor print for 4D, and augments multi-GPU tests for CP-aware scenarios with additional rank decomposition and validations.

Changes

  • MLA cache formatter (CP support, logging, sizing): cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
    Adds CP-domain distribution to MLA cache formatting (targets, buffers, concurrency). Introduces getEnvMpiDebugRank() for MPI debug gating. Implements CP-aware block accounting and sizing. Removes the restriction requiring CP=1. Adds guarded logging of input/output blocks. Internal refactors for per-peer layer/block math.
  • MLA cache formatter API (CP block accounting): cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h
    Declares a new function, int getBlockNumAccountingForCP(int cpRank, int cpSize, int numTotalBlocks, bool strict), in tensorrt_llm::batch_manager::kv_cache_manager. No other signature changes.
  • Split/concat kernels (CP dimension added): cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu
    Adds the CP domain to MLA split/concat. The splitKVCacheForMLAKernel signature is extended with domainCPSize; indexing and offsets are updated to be CP-aware. Host calls propagate CP; counts/validations use PP×TP×CP. Adjusts MLA cache counts accordingly. Minor const cleanups.
  • TargetRanksInfo const-correctness: cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h
    Makes TargetRanksInfo::getPeerPPDomainLayerNum(int) const; uses .at() for bounds-checked access.
  • Tensor debug print enhancement: cpp/tensorrt_llm/runtime/iTensor.cpp
    Enables printing for 4D tensors; updates the guard to skip only when dims > 4; adds a 4D iteration/printing branch.
  • Multi-GPU tests (CP-aware MLA/DP): cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
    Adds CP rank/size handling, environment-gated debug (TLLM_DEBUG_RANK), CP-aware token/block indexing, over-allocation checks, MPI proc checks, and new CP-enabled MLA test instantiations. Adds helpers getEnvMpiDebugRank() and isBlockOverallocated(...).

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Host as Host (Transceiver)
  participant TR as TargetRanksInfo
  participant Kern as splitKVCacheForMLAKernel
  participant Buff as Output Buffers (PP×CP)
  note over Host,TR: CP-aware MLA split path

  Host->>TR: Query DomainPPSize, DomainTPSize, DomainCPSize<br/>and peer layer/block metadata
  TR-->>Host: Sizes and indexing info
  Host->>Host: Compute output cache count = PP×CP<br/>Validate IRanks size = PP×TP×CP
  Host->>Buff: Allocate/prepare buffers per peer (PP×CP)
  Host->>Kern: Launch splitKVCacheForMLAKernel(..., DomainPPSize, DomainTPSize, domainCPSize, ...)
  rect rgba(200, 235, 255, 0.25)
    note right of Kern: CP-aware indexing<br/>outputCacheIdx = (blockId % CP) * PP + rankInPP<br/>offset uses blockIdInDomainCP
  end
  Kern-->>Buff: Write K/V slices to PP×CP outputs
  Buff-->>Host: Buffers ready for transmission
sequenceDiagram
  autonumber
  participant FM as MLACacheFormatter
  participant Peers as Peers (PP×CP)
  participant Buf as Send/Recv Buffers
  note over FM,Peers: CP-extended formatting

  FM->>FM: Read topology (PP, CP), peer layer counts
  FM->>FM: Compute per-peer block counts (getBlockNumAccountingForCP)
  FM->>Buf: Allocate send/recv buffers for PP×CP targets
  FM->>Peers: Distribute/collect cache blocks with CP indexing
  opt Debug (TLLM_DEBUG_RANK)
    FM->>FM: Print inputKvCacheBlocks / outputBuffers snapshots
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes



coderabbitai bot left a comment


Actionable comments posted: 4


⚠️ Outside diff range comments (1)
cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp (1)

1034-1034: Critical: Verify CP-aware token ID calculation

The token ID calculation changes significantly with CP support. The fill function uses blockId * tokensPerBlock while verify uses (blockId * mCpSize + mCpRank) * tokensPerBlock. This asymmetry needs verification.

The token ID calculation differs between fill and verify functions:

  • Fill: int startTokenId = blockId * tokensPerBlock;
  • Verify: int startTokenId = (blockId * mCpSize + mCpRank) * tokensPerBlock;

This appears intentional for CP-aware distribution, but should be validated:

#!/bin/bash
# Search for other startTokenId calculations to ensure consistency
rg "startTokenId\s*=" --type cpp -B 2 -A 2

# Check if there's documentation about this token distribution strategy
rg "CP.*token.*distribution|token.*CP.*distribution" --type cpp

Also applies to: 1141-1141
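
For intuition, the asymmetry is consistent with the round-robin CP distribution if fill iterates over global block ids (context side, no CP) while verify iterates over rank-local block ids (gen side with CP). A sketch of the assumed mapping (illustrative; names are not from the PR):

// Assumed mapping under round-robin CP distribution.
// Local block `localBlockId` on CP rank `cpRank` corresponds to global block
// localBlockId * cpSize + cpRank, hence the verify-side formula.
int globalStartTokenId(int localBlockId, int cpRank, int cpSize, int tokensPerBlock)
{
    return (localBlockId * cpSize + cpRank) * tokensPerBlock;
}

With cpSize = 1 and cpRank = 0 this reduces to blockId * tokensPerBlock, matching the fill side.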

🧹 Nitpick comments (15)
cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h (2)

48-52: Good: const-correctness and bounds-checked access

Making getPeerPPDomainLayerNum() const and switching to at() improves safety without API friction.

If practical, add a lightweight precondition to guard mDomainPPSize > 0 to avoid UB on modulo when misconfigured.

 int getPeerPPDomainLayerNum(int targetRankIdx) const
 {
-    int ppDomainRankIdx = targetRankIdx % mDomainPPSize;
+    // Defensive: ensure domain size is valid before modulo
+    // (consider TLLM_CHECK if available here)
+    int ppDomainRankIdx = targetRankIdx % mDomainPPSize;
     return mPeerAttentionLayerNumInDomainPP.at(ppDomainRankIdx);
 }

24-33: Header hygiene: include the standard headers you use

This header uses std::vector but doesn’t include <vector>. Prefer self-contained headers. Also, the coding guidelines ask for include guards; consider adding them (or confirm a project-wide exception).

Proposed additions (outside the changed hunk):

 #pragma once
+// Prefer include guards per guidelines if applicable:
+// #ifndef TRTLLM_CACHE_SPLIT_CONCAT_H
+// #define TRTLLM_CACHE_SPLIT_CONCAT_H
+
+#include <vector>
 ...
-} // namespace tensorrt_llm::executor::kv_cache
+} // namespace tensorrt_llm::executor::kv_cache
+
+// #endif // TRTLLM_CACHE_SPLIT_CONCAT_H

Also applies to: 35-47

cpp/tensorrt_llm/runtime/iTensor.cpp (1)

197-213: 4D print fallback can misformat when shape.d[3] == 1

For 4D tensors with the last dim == 1, the code falls through to the 2D branch and ignores d[2], printing a contiguous slice of length d[1], which is misleading. Add an explicit 4D path for d[3] == 1.

Apply before the current “nbDims >= 2” branch:

-    else if (shape.nbDims == 4 && shape.d[3] > 1)
+    else if (shape.nbDims == 4 && shape.d[3] > 1)
     {
         ...
     }
+    else if (shape.nbDims == 4 && shape.d[3] == 1)
+    {
+        for (int i = 0; i < shape.d[0]; ++i)
+        {
+            for (int j = 0; j < shape.d[1]; ++j)
+            {
+                out << "i=" << i << " j=" << j << ": ";
+                // Print d[2] scalars at k = 0..d[2]-1, d[3] fixed at 0
+                for (int k = 0; k < shape.d[2]; ++k)
+                {
+                    auto const idx = tc::flat_index(shape.d, i, j, k, 0);
+                    // Print as a single value followed by space
+                    out << static_cast<TOutput>(hostData[idx]) << (k + 1 < shape.d[2] ? " " : "");
+                }
+                out << "\n";
+            }
+        }
+    }

Please confirm this matches your intended visualization for 4D tensors where the trailing dimension is degenerate.

cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu (5)

573-577: MLA kernel signature extended with domainCPSize — good, but assert invariants

Adding domainCPSize is correct for CP-aware layout. Add a device-side assert that domainCPSize > 0 to protect modulo/division.

-__global__ void splitKVCacheForMLAKernel(..., int DomainTPSize, int domainCPSize, int kvFactor, uint64_t* prefixLayerNumDevPtr)
+__global__ void splitKVCacheForMLAKernel(..., int DomainTPSize, int domainCPSize, int kvFactor, uint64_t* prefixLayerNumDevPtr)
 {
+    assert(domainCPSize > 0);

606-614: CP-aware output index assumes round-robin block distribution

Using (blockId % domainCPSize) to select the CP group and blockId / domainCPSize for the per-CP offset encodes a strict round-robin mapping. This can misaddress outputs if getBlockNumAccountingForCP(strict=false) yields uneven assignment (noted in PR). At minimum, guard with bounds in host sizing and add a TODO here referencing the distributor to avoid silent OOB writes.

  • Validate at host: each output cache has capacity >= ceil(inputBlockNumSum / domainCPSize) blocks.
  • Alternatively, pass a precomputed blockIdToCpRank/table from host to device when non-strict mode is enabled.
    Example TODO comment:
-                // We do blockId % domainCPSize because blocks are distributed among cpRanks in a round-robin fashion.
+                // We do blockId % domainCPSize assuming round-robin CP distribution.
+                // TODO(TRTLLM-7731): If non-strict distribution is active, replace with an index map provided by host.
                 int outputCacheIdx = (blockId % domainCPSize) * DomainPPSize + rankInDomainPP;
                 ...
-                int const blockIdInDomainCP = blockId / domainCPSize;
+                int const blockIdInDomainCP = blockId / domainCPSize; // relies on round-robin

Also applies to: 618-621
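
For reference, the routing math the comment above describes can be mirrored on the host as follows (names are illustrative; the kernel itself operates on raw buffers):

// Illustrative host-side mirror of the CP-aware routing in splitKVCacheForMLAKernel.
struct BlockRoute
{
    int outputCacheIdx;    // which of the PP x CP output caches receives the block
    int blockIdInDomainCP; // the block's slot within that cache
};

BlockRoute routeBlock(int blockId, int rankInDomainPP, int domainPPSize, int domainCPSize)
{
    // Round-robin over CP: block b belongs to CP rank (b % domainCPSize).
    int const cpRank = blockId % domainCPSize;
    return {cpRank * domainPPSize + rankInDomainPP, blockId / domainCPSize};
}

An index map supplied by the host (as suggested above for non-strict mode) would replace the modulo/division here.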


1151-1158: Deriving selfPPRank with CP accounted — good; include CP in debug log

selfPPRank now divides by TP×CP, which is correct. Consider adding domainCPSize to the debug log for easier triage.

-    TLLM_LOG_DEBUG(
-        "splitKVCache - numLayers: %d, headNum: %d, domainPPSize: %d, domainTPSize: %d, "
-        "headsPerDomainTP: %d",
-        numLayers, headNum, DomainPPSize, DomainTPSize, headNumDomainTP);
+    TLLM_LOG_DEBUG(
+        "splitKVCache - numLayers: %d, headNum: %d, domainPPSize: %d, domainTPSize: %d, domainCPSize: %d, "
+        "headsPerDomainTP: %d",
+        numLayers, headNum, DomainPPSize, DomainTPSize, domainCPSize, headNumDomainTP);

Also applies to: 1163-1167


90-97: Pointer table buffers typed as INT64 — minor type-safety nit

You allocate the pointer tables with DataType::kINT64 and assert sizeInBytes vs sizeof(T*). That’s fine on 64-bit, but subtly couples to pointer size. Consider using a dedicated byte buffer (kUINT8) and computing sizes via sizeof(void*) for clarity.

Also applies to: 107-110, 1430-1432


18-36: Missing include for std::accumulate

This TU uses std::accumulate; include <numeric> explicitly to keep the file self-contained.

 #include <sstream>
+#include <numeric>
 #include <string>
 #include <vector>

Also applies to: 107-110

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp (4)

43-53: Enhance debug rank environment variable documentation

The getEnvMpiDebugRank() function could benefit from better documentation about the expected values and their meanings.

+// Returns debug rank from TLLM_DEBUG_RANK environment variable
+// -1: debug all ranks
+// -2: no debug output (default)
+// >=0: debug specific rank
 int getEnvMpiDebugRank()
 {
-    // Look-up env variable TLLM_DEBUG_RANK.
     char const* const env = std::getenv("TLLM_DEBUG_RANK");
     if (env == nullptr)
     {
-        return -2;  // -1 means all ranks, -2 means no debug rank.
+        return -2;  // Default: no debug output
     }
     return std::stoi(env);
 }

155-159: Refactor repeated debug logging pattern

The debug logging pattern for MPI rank is repeated multiple times. Consider extracting it into a helper function or macro.

-static const int TARGET_RANK = getEnvMpiDebugRank(); // -1 means all ranks.
-if (TARGET_RANK == -1 || mpi::MpiComm::world().getRank() == TARGET_RANK)
+auto const shouldLogDebug = [](int targetRank = getEnvMpiDebugRank()) {
+    return targetRank == -1 || mpi::MpiComm::world().getRank() == targetRank;
+};
+if (shouldLogDebug())
 {
     std::cerr << "[mpiRank:" << mpi::MpiComm::world().getRank() << "]" << "[MLACacheFormatter::format] inputKvCacheBlocks[" << blockNum << "]: \n" << *it << std::endl;
 }

552-562: Remove or gate verbose debug output appropriately

The debug output for all output buffers could be very verbose in production. Consider using a more specific debug flag or removing it.

 static const int TARGET_RANK = getEnvMpiDebugRank(); // -1 means all ranks.
 if (TARGET_RANK == -1 || mpi::MpiComm::world().getRank() == TARGET_RANK)
 {
     bufferManager.getStream().synchronize();
-    int blockNum = 0;
-    for (auto const& block : outputBuffers)
-    {
-        std::cerr << "[mpiRank:" << mpi::MpiComm::world().getRank() << "]" << "[MLACacheFormatter::format] outputBuffers[" << blockNum << "]: \n" << *block << std::endl;
-        blockNum++;
-    }
+    TLLM_LOG_DEBUG("Completed concat of %zu output buffers for rank %d", 
+                   outputBuffers.size(), mpi::MpiComm::world().getRank());
 }

409-409: Fix comment typo

Minor typo in the comment.

-// @B: Maybe no updates are needed because contextCP is always 1?
+// Note: Maybe no updates are needed because contextCP is always 1?
cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp (3)

540-543: Improve test skip condition with clear message

The test skip condition could provide more informative feedback about why the test is being skipped.

 if (tensorrt_llm::mpi::MpiComm::world().getSize() < nprocs)
 {
-    GTEST_SKIP() << "mpirun with procs=" << nprocs << " is required to run this test.";
+    GTEST_SKIP() << "Test requires " << nprocs << " MPI processes (contextTp=" << contextTp 
+                  << " * contextPp=" << contextPp << " * contextCp=" << contextCp 
+                  << " + genTp=" << genTp << " * genPp=" << genPp << " * genCp=" << genCp 
+                  << "), but only " << tensorrt_llm::mpi::MpiComm::world().getSize() << " available";
 }

1056-1068: Consider extracting verbose logging to a separate debug utility

The verbose per-value logging in fillBlockData could impact performance even when disabled. Consider moving to a separate debug utility class.

Consider creating a debug utility class to handle verbose logging:

class KVCacheDebugLogger {
public:
    static bool shouldLog(int targetRank = getEnvMpiDebugRank()) {
        return targetRank == -1 || mpi::MpiComm::world().getRank() == targetRank;
    }
    
    static void logBlockValue(const char* op, int blockId, int layerId, int headId, 
                              int tokenId, int hiddenId, size_t index, double value, 
                              nvinfer1::DataType dataType, bool isKey) {
        if (!shouldLog()) return;
        // logging implementation
    }
};

Also applies to: 1080-1092


1320-1320: Consider adding more iterations for cache reuse testing

The loop is set to only 1 iteration, which doesn't test cache reuse as mentioned in the comment.

 // the second loop is for cache reuse
-for (int i = 0; i < 1; i++)
+for (int i = 0; i < 2; i++)  // Test both initial use and cache reuse
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 77657a1 and 96398fb.

📒 Files selected for processing (6)
  • cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp (7 hunks)
  • cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h (1 hunks)
  • cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu (10 hunks)
  • cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h (1 hunks)
  • cpp/tensorrt_llm/runtime/iTensor.cpp (1 hunks)
  • cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp (15 hunks)
🧰 Additional context used
📓 Path-based instructions (7)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}: Namespace closing braces must include a trailing comment with the namespace name (e.g., '} // namespace foo').
Prefer const or constexpr variables over #define for constants.
Declare variables that are not modified after initialization as const.
Avoid magic literals in code; except for 0, nullptr, true, false. Use named constants for comparisons and logic.
Use Allman brace style for formatting.
Place the semicolon of an empty for/while loop on a new line.
Bodies of switch/while/do-while/for must be compound statements (brace-delimited), and if/else must always be followed by brace-delimited statements.
Type names (e.g., classes) must be CamelCase starting with an uppercase letter (e.g., FooBar).
Local variables, methods, and namespaces use lowerCamelCase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not in an anonymous namespace must be lowerCamelCase prefixed with 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number globals that are static or in an anonymous namespace use lowerCamelCase prefixed with 's' (e.g., sMutableStaticGlobal).
Locally visible static variables use lowerCamelCase with 's' prefix (e.g., static std::once_flag sFlag).
Private/protected member variables use 'm' prefix with CamelCase (e.g., mNbFooValues). Public members may omit, but 'm' is encouraged for clarity.
Constants (enums, global constants, static constants, and function-scope magic/literal constants) use uppercase SNAKE_CASE with 'k' prefix (e.g., kDIGIT_NUM).
Function-scope constants that are not magic numbers or literals are named like non-constant variables (e.g., bool const pass = a && b).
If macros are necessary, name them in UPPER_SNAKE_CASE (e.g., FOO_VERSION) and prefer constants over #define.
Use LLVM clang-format; wrap lines at a maximum of 120 columns; use '// clang-format off/on' sparingly with justification.
Use smart pointers for heap allocations; prefer unique_ptr for sole ownership, shared_ptr for shared...

Files:

  • cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h
  • cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h
  • cpp/tensorrt_llm/runtime/iTensor.cpp
  • cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
  • cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
  • cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu
**/*.{cpp,cxx,cc,cu,h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

C++ filenames should be lowerCamelCase (first letter lowercase) and must be case-insensitive unique within a compilation target.

Files:

  • cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h
  • cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h
  • cpp/tensorrt_llm/runtime/iTensor.cpp
  • cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
  • cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
  • cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h
  • cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h
  • cpp/tensorrt_llm/runtime/iTensor.cpp
  • cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
  • cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
  • cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu
**/*.{h,hpp,hh,hxx}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Document new class interfaces and function prototypes with Doxygen; use //! for single-line and //!< for members.

Files:

  • cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h
  • cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h
**/*.{h,hpp,hh,hxx,cpp,cxx,cc}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc}: Prefer anonymous namespaces over 'static' for internal linkage of functions.
All templates (class/function/member/static) must be instantiated at least once; non-POD classes should have private data members.

Files:

  • cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h
  • cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h
  • cpp/tensorrt_llm/runtime/iTensor.cpp
  • cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
  • cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
**/*.{h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use include guards named 'TRTLLM_<FILE_NAME_IN_CAPS_WITH_UNDERSCORES>_H' (no leading or trailing underscore; directory names excluded).

Files:

  • cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h
  • cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h
  • cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h
  • cpp/tensorrt_llm/runtime/iTensor.cpp
  • cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
  • cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
  • cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (8)
cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu (4)

1032-1038: LGTM: CP included in IRank sizing and MLA output cache count

The checks now enforce |IRanks| = PP×TP×CP and, for MLA, outputCacheNum = PP×CP, which matches the kernel’s expectations.


1175-1182: Host wiring passes domainCPSize for all vec widths — consistent

All splitKVCacheForMLAKernel instantiations propagate domainCPSize. Good coverage across 16/8/4/2/1-byte paths.

Also applies to: 1199-1206, 1224-1231, 1258-1265, 1287-1294


1372-1380: Concat path keeps CP collapsed for MLA — verify symmetrical with split

concatKVCache() keeps inputCacheNum = PP for MLA (no CP), which matches your “gen side only” scope. Please confirm tests cover mixed PP×CP on sender and PP-only on receiver to avoid off-by-one in cache indexing.


1477-1487: Repeat: selfPPRank with CP accounted — good

Same correctness note as in splitKVCache.

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp (1)

195-209: CP domain calculation is consistent across codebase
PP rank divisor and buffer size logic correctly incorporate mContextParallelism and getBlockNumAccountingForCP in all instances; no changes needed.

cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp (3)

953-961: Good defensive programming for over-allocation protection

The isBlockOverallocated function properly handles the known limitation of over-allocation in CP-enabled generation ranks. The implementation correctly identifies blocks that exceed the strict CP accounting.
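
A plausible shape for such a helper, assuming it flags rank-local block indices that fall beyond the strict per-rank count (illustrative; the test's actual implementation may differ):

// Illustrative sketch; getBlockNumAccountingForCP is the function declared in this PR.
bool isBlockOverallocated(int localBlockId, int cpRank, int cpSize, int numTotalBlocks)
{
    int const strictCount = getBlockNumAccountingForCP(cpRank, cpSize, numTotalBlocks, /*strict=*/true);
    return localBlockId >= strictCount;
}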


979-983: LGTM: Proper handling of over-allocated blocks

The generation verification correctly skips over-allocated blocks with appropriate logging. This aligns with the PR's known limitations.


1553-1570: Good test coverage for CP scenarios

The new test instantiation AsymmetricCaseTestWithCPForMLA provides comprehensive coverage for CP-enabled MLA scenarios with various configurations. This aligns well with the PR objectives.

brb-nv force-pushed the user/brb/cache-transmission-with-distinct-cp-redo branch from a95c0a6 to 019f8c1 on September 8, 2025 19:22

brb-nv commented Sep 8, 2025

/bot run

@tensorrt-cicd

PR_Github #18090 [ run ] triggered by Bot

brb-nv force-pushed the user/brb/cache-transmission-with-distinct-cp-redo branch from 019f8c1 to 8ef531f on September 8, 2025 21:35

brb-nv commented Sep 8, 2025

/bot run

@tensorrt-cicd

PR_Github #18098 [ run ] triggered by Bot

@tensorrt-cicd

PR_Github #18090 [ run ] completed with state ABORTED

@tensorrt-cicd

PR_Github #18098 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13562 completed with status: 'FAILURE'


brb-nv commented Sep 9, 2025

/bot run

@tensorrt-cicd

PR_Github #18115 [ run ] triggered by Bot

@tensorrt-cicd

PR_Github #18115 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13576 completed with status: 'FAILURE'


brb-nv commented Sep 9, 2025

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #18243 [ run ] triggered by Bot

brb-nv requested a review from a team as a code owner on September 10, 2025 06:26
brb-nv force-pushed the user/brb/cache-transmission-with-distinct-cp-redo branch 2 times, most recently from b69a4ef to 65ea160 on September 10, 2025 06:39

brb-nv commented Sep 10, 2025

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #18318 [ run ] triggered by Bot

@tensorrt-cicd

PR_Github #18318 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #13738 completed with status: 'FAILURE'

@tensorrt-cicd

PR_Github #18833 [ run ] triggered by Bot

@tensorrt-cicd

PR_Github #18833 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14117 completed with status: 'FAILURE'

brb-nv force-pushed the user/brb/cache-transmission-with-distinct-cp-redo branch from 735a31d to 1389582 on September 17, 2025 00:53

brb-nv commented Sep 17, 2025

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #18850 [ run ] triggered by Bot

@tensorrt-cicd

PR_Github #18850 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14130 completed with status: 'FAILURE'

brb-nv force-pushed the user/brb/cache-transmission-with-distinct-cp-redo branch from 1389582 to 9229e04 on September 17, 2025 16:07

brb-nv commented Sep 17, 2025

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #19027 [ run ] triggered by Bot


brb-nv commented Sep 17, 2025

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #19037 [ run ] triggered by Bot

@tensorrt-cicd

PR_Github #19027 [ run ] completed with state ABORTED
LLM/main/L0_MergeRequest_PR #14268 (Blue Ocean) completed with status: ABORTED


brb-nv commented Sep 17, 2025

/bot run

@tensorrt-cicd

PR_Github #19052 [ run ] triggered by Bot

@tensorrt-cicd

PR_Github #19037 [ run ] completed with state ABORTED
LLM/main/L0_MergeRequest_PR #14278 (Blue Ocean) completed with status: ABORTED

@tensorrt-cicd

PR_Github #19052 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14290 completed with status: 'FAILURE'


brb-nv commented Sep 18, 2025

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #19209 [ run ] triggered by Bot

@tensorrt-cicd

PR_Github #19209 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14420 completed with status: 'FAILURE'


brb-nv commented Sep 19, 2025

/bot run


brb-nv commented Sep 19, 2025

/bot run

@tensorrt-cicd

PR_Github #19354 [ run ] triggered by Bot


brb-nv commented Sep 19, 2025

/bot run

@tensorrt-cicd

PR_Github #19368 [ run ] triggered by Bot

@tensorrt-cicd

PR_Github #19368 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14544 completed with status: 'SUCCESS'

brb-nv merged commit e10a027 into NVIDIA:main on Sep 20, 2025
5 checks passed
MrGeva pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Sep 21, 2025
nv-lschneider pushed a commit to nv-lschneider/TensorRT-LLM that referenced this pull request Sep 22, 2025