Merged
23 changes: 19 additions & 4 deletions examples/llm-api/quickstart_advanced.py
@@ -1,10 +1,11 @@
import argparse

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
DraftTargetDecodingConfig, EagleDecodingConfig,
KvCacheConfig, MoeConfig, MTPDecodingConfig,
NGramDecodingConfig, TorchCompileConfig)
from tensorrt_llm.llmapi import (AttentionDpConfig, AutoDecodingConfig,
CudaGraphConfig, DraftTargetDecodingConfig,
EagleDecodingConfig, KvCacheConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig,
TorchCompileConfig)

example_prompts = [
"Hello, my name is",
@@ -57,6 +58,13 @@ def add_llm_args(parser):
parser.add_argument('--enable_attention_dp',
default=False,
action='store_true')
parser.add_argument('--attention_dp_enable_balance',

Reviewer comment (Collaborator): One suggestion: if this feature is important, please add a dedicated example under the examples/llm-api directory (you can reference llm_inference.py); it can ship in a separate PR. quickstart_advanced.py is only for quick tests, and only the standalone examples under the llm-api directory appear in the examples documentation.

default=False,
action='store_true')
parser.add_argument('--attention_dp_time_out_iters', type=int, default=0)
parser.add_argument('--attention_dp_batching_wait_iters',
type=int,
default=0)
parser.add_argument('--enable_trtllm_sampler',
default=False,
action='store_true')
@@ -196,6 +204,12 @@ def setup_llm(args, **kwargs):
enable_padding=args.cuda_graph_padding_enabled,
) if args.use_cuda_graph else None

attention_dp_config = AttentionDpConfig(
enable_balance=args.attention_dp_enable_balance,
timeout_iters=args.attention_dp_time_out_iters,
batching_wait_iters=args.attention_dp_batching_wait_iters,
)

llm = LLM(
model=args.model_dir,
backend='pytorch',
@@ -218,6 +232,7 @@ def setup_llm(args, **kwargs):
max_batch_size=args.max_batch_size,
max_num_tokens=args.max_num_tokens,
enable_attention_dp=args.enable_attention_dp,
attention_dp_config=attention_dp_config,
tensor_parallel_size=args.tp_size,
pipeline_parallel_size=args.pp_size,
moe_expert_parallel_size=args.moe_ep_size,
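
For orientation, a minimal sketch (not part of this diff) of how the new CLI flags map onto an AttentionDpConfig passed to the LLM constructor; the model path and values are illustrative only.

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import AttentionDpConfig

# Mirrors what quickstart_advanced.py now builds from the new CLI flags.
adp_config = AttentionDpConfig(
    enable_balance=True,      # --attention_dp_enable_balance
    timeout_iters=50,         # --attention_dp_time_out_iters
    batching_wait_iters=10,   # --attention_dp_batching_wait_iters
)

llm = LLM(
    model="path/to/model",           # hypothetical model path
    backend='pytorch',
    enable_attention_dp=True,        # balancing only takes effect when attention DP is enabled
    attention_dp_config=adp_config,
)
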
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -132,6 +132,9 @@ def __init__(
self.pytorch_backend_config.enable_iter_perf_stats = False
self.pytorch_backend_config.enable_iter_req_stats = False
self.pytorch_backend_config.stream_interval = 1
self.pytorch_backend_config.attention_dp_enable_balance = False
self.pytorch_backend_config.attention_dp_time_out_iters = 50
self.pytorch_backend_config.attention_dp_batching_wait_iters = 10
self.iter_counter = 0

# NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/config.py
@@ -46,6 +46,10 @@ class PyTorchConfig:
moe_max_num_tokens: Optional[int] = None
moe_load_balancer: Optional[Union[MoeLoadBalancerConfig, dict, str]] = None

attention_dp_enable_balance: bool = False
attention_dp_time_out_iters: int = 50
attention_dp_batching_wait_iters: int = 10

attn_backend: str = 'TRTLLM'
moe_backend: str = 'CUTLASS'

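
For completeness, a short sketch (not part of this diff) showing the new PyTorchConfig knobs set directly; it assumes the remaining dataclass fields keep their defaults. In normal use TorchLlmArgs.get_pytorch_backend_config fills these from AttentionDpConfig, as shown further down in llm_args.py.

from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

# Directly setting the attention-DP balancing knobs added above.
cfg = PyTorchConfig(
    attention_dp_enable_balance=True,
    attention_dp_time_out_iters=50,
    attention_dp_batching_wait_iters=10,
)
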
65 changes: 63 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -177,6 +177,9 @@ def __init__(self,
self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
self.stream_interval = model_engine.pytorch_backend_config.stream_interval
self.attention_dp_enable_balance = model_engine.pytorch_backend_config.attention_dp_enable_balance
self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
self.num_fetch_requests_cur_rank = 0
self.num_fetch_requests = 0
self.shutdown_event = threading.Event()
@@ -215,6 +218,9 @@ def __init__(self,
self.draft_model_engine.warmup(self.resource_manager)

self.is_shutdown = False
self.max_batch_size = max_batch_size
self.adp_ctx_waiting_iters_count = 0
self.adp_ctx_batching_wait_iters_count = 0

# request fetcher initialization
self.executor_request_queue = ExecutorRequestQueue(
@@ -1131,13 +1137,68 @@ def _add_kv_cache_events(self):
# to be transferred to main thread when user needs them.
kv_cache_manager.flush_iteration_events()

def _balance_adp_requests(self, context_requests: list[LlmRequest],
generation_requests: list[LlmRequest]):
balanced_context_requests = context_requests
num_scheduled_context_requests = len(context_requests)
num_scheduled_generation_requests = len(generation_requests)
num_scheduled_tokens = sum(
[len(req.get_tokens(0))
for req in context_requests]) + num_scheduled_generation_requests
responses_list = self.dist.tp_allgather([
num_scheduled_context_requests, num_scheduled_generation_requests,
num_scheduled_tokens
])
all_ranks_num_scheduled_context_requests = [
response[0] for response in responses_list
]
all_ranks_num_scheduled_generation_requests = [
response[1] for response in responses_list
]
all_ranks_have_free_ctx_slots = all([
num_gen < self.max_batch_size
for num_gen in all_ranks_num_scheduled_generation_requests
])
all_ranks_have_ctx_requests = all([
num_ctx > 0 for num_ctx in all_ranks_num_scheduled_context_requests
])
all_ranks_have_gen_requests = all([
num_gen > 0
for num_gen in all_ranks_num_scheduled_generation_requests
])

if self.attention_dp_enable_balance:
            # wait until all ranks have context requests
if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests:
self.adp_ctx_waiting_iters_count = 0
# balance number of context requests across ranks
if all_ranks_have_gen_requests:
if self.adp_ctx_batching_wait_iters_count < self.attention_dp_batching_wait_iters:
self.adp_ctx_batching_wait_iters_count += 1
balanced_context_requests = []
else:
self.adp_ctx_batching_wait_iters_count = 0
else:
self.adp_ctx_waiting_iters_count += 1
balanced_context_requests = []
timeout_reached = self.adp_ctx_waiting_iters_count >= self.attention_dp_time_out_iters
if timeout_reached or not all_ranks_have_gen_requests:
self.adp_ctx_waiting_iters_count = 0
balanced_context_requests = context_requests
return balanced_context_requests

@nvtx_range("_schedule")
def _schedule(self):
scheduler_output = self.scheduler.schedule_request(
self.active_requests, self.inflight_req_ids)
scheduled_requests = ScheduledRequests()
scheduled_context_requests = scheduler_output.context_requests
if self.enable_attention_dp and self.attention_dp_enable_balance:
scheduled_context_requests = self._balance_adp_requests(
scheduler_output.context_requests,
scheduler_output.generation_requests)

scheduled_requests.context_requests = scheduler_output.context_requests
scheduled_requests = ScheduledRequests()
scheduled_requests.context_requests = scheduled_context_requests
scheduled_requests.generation_requests = scheduler_output.generation_requests
scheduled_requests.paused_requests = scheduler_output.paused_requests
return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests
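
To make the gating in _balance_adp_requests easier to follow, here is a simplified, self-contained sketch (not part of this diff); the per-rank counts are passed in directly instead of coming from self.dist.tp_allgather, and all names are illustrative.

from dataclasses import dataclass

@dataclass
class AdpBalanceState:
    waiting_iters: int = 0        # iterations spent waiting for every rank to have context work
    batching_wait_iters: int = 0  # iterations spent accumulating context requests for batching

def decide_context_schedule(ctx_counts, gen_counts, local_ctx_requests,
                            max_batch_size, timeout_iters, batching_wait_iters,
                            state: AdpBalanceState):
    """Return the context requests this rank should schedule this iteration."""
    all_free = all(n < max_batch_size for n in gen_counts)  # every rank has a free slot
    all_have_ctx = all(n > 0 for n in ctx_counts)           # every rank has context work queued
    all_have_gen = all(n > 0 for n in gen_counts)           # every rank has generation work

    if all_free and all_have_ctx:
        state.waiting_iters = 0
        if all_have_gen:
            if state.batching_wait_iters < batching_wait_iters:
                state.batching_wait_iters += 1
                return []              # keep accumulating so ranks batch context together
            state.batching_wait_iters = 0
        return local_ctx_requests      # all ranks schedule their context requests together
    state.waiting_iters += 1
    if state.waiting_iters >= timeout_iters or not all_have_gen:
        state.waiting_iters = 0
        return local_ctx_requests      # timeout hit, or some rank has no generation work: stop holding back
    return []                          # hold this rank's context requests for another iteration

In effect, context prefills are deferred until every attention-DP rank can launch one, with timeout_iters bounding how long a rank holds back and batching_wait_iters allowing a few extra iterations of accumulation when generation is already running on all ranks.
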
20 changes: 11 additions & 9 deletions tensorrt_llm/llmapi/__init__.py
@@ -4,15 +4,16 @@
from .build_cache import BuildCacheConfig
from .llm import LLM, RequestOutput
# yapf: disable
from .llm_args import (AutoDecodingConfig, BatchingType, CacheTransceiverConfig,
CalibConfig, CapacitySchedulerPolicy,
ContextChunkingPolicy, CudaGraphConfig,
DraftTargetDecodingConfig, DynamicBatchConfig,
EagleDecodingConfig, ExtendedRuntimePerfKnobConfig,
KvCacheConfig, LlmArgs, LookaheadDecodingConfig,
MedusaDecodingConfig, MoeConfig, MTPDecodingConfig,
NGramDecodingConfig, SchedulerConfig, TorchCompileConfig,
TorchLlmArgs, TrtLlmArgs, UserProvidedDecodingConfig)
from .llm_args import (AttentionDpConfig, AutoDecodingConfig, BatchingType,
CacheTransceiverConfig, CalibConfig,
CapacitySchedulerPolicy, ContextChunkingPolicy,
CudaGraphConfig, DraftTargetDecodingConfig,
DynamicBatchConfig, EagleDecodingConfig,
ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig,
TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
UserProvidedDecodingConfig)
from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
QuantConfig)
from .mpi_session import MpiCommSession
@@ -54,4 +55,5 @@
'TorchLlmArgs',
'TrtLlmArgs',
'AutoDecodingConfig',
'AttentionDpConfig',
]
57 changes: 56 additions & 1 deletion tensorrt_llm/llmapi/llm_args.py
@@ -187,6 +187,23 @@ def from_dict(cls, data: dict):
return cls(**data)


class AttentionDpConfig(StrictBaseModel):
"""
Configuration for attention DP.
"""
enable_balance: bool = Field(default=False,
description="Whether to enable balance.")
timeout_iters: int = Field(
default=50, description="The number of iterations to timeout.")
batching_wait_iters: int = Field(
default=10,
description="The number of iterations to wait for batching.")

@classmethod
def from_dict(cls, data: dict):
return cls(**data)


@dataclass
class _ParallelConfig:
''' The model distribution configs for LLM. '''
@@ -1988,6 +2005,11 @@ class TorchLlmArgs(BaseLlmArgs):
Note that each CUDA graph can use up to 200 MB of extra memory.",
status="beta")

attention_dp_config: Optional[AttentionDpConfig] = Field(
default=None,
description="Optimized load-balancing for the DP Attention scheduler.",
status="beta")

disable_overlap_scheduler: bool = Field(
default=False,
description="Disable the overlap scheduler.",
@@ -2253,6 +2275,29 @@ def warn_on_unstable_feature_usage(self) -> 'TorchLlmArgs':

return self

@model_validator(mode='after')
def validate_attention_dp_config(self) -> 'TorchLlmArgs':
"""Validate attention DP configuration.

Ensures that:
        1. If attention_dp_config.enable_balance is true, attention_dp_config.batching_wait_iters must be greater than or equal to 0
        2. If attention_dp_config.enable_balance is true, attention_dp_config.timeout_iters must be greater than or equal to 0
"""
if self.attention_dp_config is None:
return self

config = self.attention_dp_config
if config.enable_balance:
if config.batching_wait_iters < 0:
raise ValueError(
"attention_dp_config.batching_wait_iters must be greater or equal to 0 when enable_balance is true"
)
if config.timeout_iters < 0:
raise ValueError(
"attention_dp_config.timeout_iters must be greater or equal to 0 when enable_balance is true"
)
return self

# TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
def get_pytorch_backend_config(self) -> "PyTorchConfig":
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
@@ -2303,7 +2348,16 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
enable_min_latency=self.enable_min_latency,
stream_interval=self.stream_interval,
force_dynamic_quantization=self.force_dynamic_quantization,
allreduce_strategy=self.allreduce_strategy)
allreduce_strategy=self.allreduce_strategy,
attention_dp_enable_balance=bool(
self.attention_dp_config is not None
and self.attention_dp_config.enable_balance),
attention_dp_time_out_iters=self.attention_dp_config.timeout_iters
if self.attention_dp_config is not None else
AttentionDpConfig.model_fields['timeout_iters'].default,
attention_dp_batching_wait_iters=self.attention_dp_config.
batching_wait_iters if self.attention_dp_config is not None else
AttentionDpConfig.model_fields['batching_wait_iters'].default)


def update_llm_args_with_extra_dict(
@@ -2320,6 +2374,7 @@ def update_llm_args_with_extra_dict(
"speculative_config": DecodingBaseConfig,
"lora_config": LoraConfig,
"moe_config": MoeConfig,
"attention_dp_config": AttentionDpConfig,
}
for field_name, field_type in field_mapping.items():
if field_name in llm_args_dict:
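
For reference, a small sketch (not from this diff) of building AttentionDpConfig from a plain dictionary, the same shape the "attention_dp_config" entry in the field_mapping above appears to expect; the values are illustrative.

from tensorrt_llm.llmapi import AttentionDpConfig

# Same dict shape an extra-options file would supply under "attention_dp_config".
adp = AttentionDpConfig.from_dict({
    "enable_balance": True,
    "timeout_iters": 50,
    "batching_wait_iters": 10,
})

# When this config is attached to TorchLlmArgs with enable_balance set, the new
# validate_attention_dp_config validator rejects negative timeout_iters or
# batching_wait_iters by raising a ValueError.
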
39 changes: 39 additions & 0 deletions tests/integration/defs/test_e2e.py
@@ -1800,6 +1800,45 @@ def test_ptp_quickstart_advanced_auto(llm_root, llm_venv, model_name,
_check_mem_usage(running_log, [27.0, 0, 0, 0])


@skip_post_blackwell
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("model_name,model_path", [
pytest.param(
'DeepSeek-V3-Lite-FP8', 'DeepSeek-V3-Lite/fp8', marks=skip_pre_hopper),
])
def test_ptp_quickstart_advanced_deepseek_v3_lite_4gpus_adp_balance(
llm_root, llm_venv, model_name, model_path):
print(f"Testing {model_name}.")
example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
with tempfile.NamedTemporaryFile(mode='w+t',
suffix=f".{model_name}.log",
dir="./",
delete=True,
delete_on_close=True) as running_log:
llm_venv.run_cmd([
str(example_root / "quickstart_advanced.py"),
"--model_dir",
f"{llm_models_root()}/{model_path}",
"--moe_tp_size=1",
"--moe_ep_size=4",
"--tp_size=4",
"--use_cuda_graph",
"--enable_attention_dp",
f"--kv_cache_fraction={_MEM_FRACTION_95}",
"--max_batch_size=1",
"--max_seq_len=3000",
"--disable_kv_cache_reuse",
"--attention_dp_enable_balance",
"--attention_dp_time_out_iters",
"10",
"--attention_dp_batching_wait_iters",
"10",
],
stdout=running_log)
_check_mem_usage(running_log, [106.3, 0, 0, 0], 8)


@skip_post_blackwell
@pytest.mark.skip_less_device_memory(110000)
@pytest.mark.skip_less_device(8)
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -51,6 +51,7 @@ l0_dgx_h100:
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2]
- test_e2e.py::test_ptp_quickstart_advanced_bs1
- test_e2e.py::test_ptp_quickstart_advanced_deepseek_v3_lite_4gpus_adp_balance[DeepSeek-V3-Lite-FP8-DeepSeek-V3-Lite/fp8]
- condition:
ranges:
system_gpu_count:
4 changes: 4 additions & 0 deletions tests/unittest/api_stability/references/llm.yaml
@@ -79,6 +79,10 @@ methods:
annotation: Optional[tensorrt_llm.llmapi.llm_args.CudaGraphConfig]
default: null
status: beta
attention_dp_config:
annotation: Optional[tensorrt_llm.llmapi.llm_args.AttentionDpConfig]
default: null
status: beta
checkpoint_loader:
annotation: Optional[tensorrt_llm._torch.models.checkpoints.BaseCheckpointLoader]
default: null