57 changes: 12 additions & 45 deletions tensorrt_llm/_torch/distributed/communicator.py
@@ -1,14 +1,15 @@
import os
from abc import ABC, abstractmethod
from typing import Optional, ValuesView
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist

from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_broadcast,
mpi_comm, mpi_isend, mpi_recv, mpi_send,
torch_dtype_to_np)
mpi_comm, mpi_isend, mpi_isend_object,
mpi_recv, mpi_recv_object, mpi_send,
mpi_send_object)
from tensorrt_llm.mapping import Mapping


@@ -121,48 +122,14 @@ def recv(self, buf: np.ndarray, src, tag=0):
# in-place recv numpy buffer
return mpi_recv(buf, src, tag)

def isend_tensor(self, tensor: torch.Tensor, dest, tag=0):
return self.isend(tensor.numpy(), dest, tag)

def recv_tensor(self, tensor: torch.Tensor, src, tag=0):
return self.recv(tensor.numpy(), src, tag)

def isend_tensor_list(self,
tensor_list: ValuesView[torch.Tensor],
dest,
tag=0):
if len(tensor_list) == 0:
return None
elif len(tensor_list) == 1:
return self.isend_tensor(next(iter(tensor_list)), dest, tag)

return self.isend(
np.concatenate([t.numpy().ravel() for t in tensor_list]), dest, tag)

def recv_tensor_list(self,
tensor_list: ValuesView[torch.Tensor],
src,
tag=0):
if len(tensor_list) == 0:
return None
elif len(tensor_list) == 1:
return self.recv_tensor(next(iter(tensor_list)), src, tag)

# Use the first tensor's dtype to infer the buffer dtype
numpy_dtype = torch_dtype_to_np(next(iter(tensor_list)).dtype)
# Prepare buffer to receive tensor_list
recv_buffer = np.empty(sum([t.numel() for t in tensor_list]),
dtype=numpy_dtype)
# Receive tensors
self.recv(recv_buffer, src, tag)
# Assign to tensor_list
offset = 0
for t in tensor_list:
t.copy_(
torch.from_numpy(recv_buffer[offset:offset +
t.numel()]).reshape(t.shape))
offset += t.numel()
return None
def send_object(self, obj, dest, tag=0):
mpi_send_object(obj, dest, tag)

def isend_object(self, obj, dest, tag=0):
return mpi_isend_object(obj, dest, tag)

def recv_object(self, src, tag=0):
return mpi_recv_object(src, tag)

def create_tp_comm(self):
new_group = mpi_comm().group.Incl(self.mapping.tp_group)
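The removed tensor-list helpers required the receiver to pre-allocate a buffer of known size and dtype; the new methods delegate to pickle-based MPI helpers, so arbitrary Python payloads can be exchanged. As a rough illustration only (not the actual tensorrt_llm._utils implementations), such wrappers could be built on mpi4py's lowercase object API:

# Hypothetical sketch of pickle-based MPI helpers; the real mpi_send_object /
# mpi_isend_object / mpi_recv_object in tensorrt_llm._utils may differ.
from mpi4py import MPI

_comm = MPI.COMM_WORLD

def mpi_send_object(obj, dest, tag=0):
    # Blocking send; mpi4py pickles the object on the sender side.
    _comm.send(obj, dest=dest, tag=tag)

def mpi_isend_object(obj, dest, tag=0):
    # Non-blocking send; returns an MPI Request that supports Wait().
    return _comm.isend(obj, dest=dest, tag=tag)

def mpi_recv_object(src, tag=0):
    # Blocking receive; the object is rebuilt by unpickling.
    return _comm.recv(source=src, tag=tag)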
115 changes: 64 additions & 51 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -14,8 +14,6 @@
from itertools import chain
from typing import Dict, List, Optional, Tuple, Union

import dill # nosec B403
import numpy as np
import torch

from tensorrt_llm._utils import (global_mpi_rank, is_trace_enabled, nvtx_range,
@@ -661,6 +659,24 @@ def _executor_loop_cleanup(self):
self.response_cv.notify_all()
self.shutdown_event.set()

def _need_return_logits(self, scheduled_requests: ScheduledRequests):
for req in scheduled_requests.context_requests:
if req.py_return_context_logits:
return True
for req in scheduled_requests.generation_requests:
if req.py_return_generation_logits:
return True
return False

def _need_return_log_probs(self, scheduled_requests: ScheduledRequests):
for req in scheduled_requests.context_requests:
if req.py_return_log_probs:
return True
for req in scheduled_requests.generation_requests:
if req.py_return_log_probs:
return True
return False

def _executor_loop_pp(self):
torch.cuda.set_device(self.device_id)
got_finish_signal = False
@@ -720,8 +736,13 @@ def _executor_loop_pp(self):
else:
with torch.cuda.nvtx.range("_forward_step_last_pp"):
batch_outputs = self._forward_step(scheduled_batch)
logits_host = None
if self._need_return_logits(scheduled_batch):
logits_host = batch_outputs["logits"].to(
"cpu", non_blocking=True)
sample_state = self._sample_async(
scheduled_batch, batch_outputs)
sample_state.logits_host = logits_host
self._update_request_states(scheduled_batch)

if self.enable_iter_perf_stats:
@@ -741,33 +762,51 @@

# Stage 2: Communicate new tokens for previous batch between ranks
# send/recv chain: (pp_size - 1) -> 0 -> 1 -> ... -> (pp_size - 2)
# last rank: sync decoder for previous microbatch to start new tokens comm chain.
# last rank: sync sampler for previous microbatch to start new tokens comm chain.
# other ranks: send/recv tokens for next microbatch to allow overlap
offset = -1 if self.dist.is_last_pp_rank else 1
prev_microbatch_id = (microbatch_id +
offset) % self.num_micro_batches
previous_batch = self.micro_batches[prev_microbatch_id]
if previous_batch is not None:
sample_state = previous_batch.sample_state
if not self.dist.is_last_pp_rank:
torch.cuda.nvtx.range_push(
"_handle_new_tokens_inter_pp")
# Receive tokens from previous pp rank (w.r.t model forward direction)
self.dist.recv_tensor_list(
previous_batch.sample_state.host.values(),
(
logits,
sample_state.log_probs,
sample_state.host,
) = self.dist.recv_object(
src=self.dist.prev_pp_rank,
tag=prev_microbatch_id)
tag=prev_microbatch_id,
)
if logits is not None:
logits_host = torch.from_numpy(logits)
sample_state.logits_host = logits_host
sample_state.logits = logits_host.to(self.device_id)
else:
torch.cuda.nvtx.range_push("_handle_new_tokens_last_pp")
previous_batch.sample_state.sampler_event.synchronize()
sample_state.sampler_event.synchronize()

# Send tokens to next pp rank (w.r.t model forward direction)
# Second last rank does not need to since last rank has original decoded tokens
if not self.dist.is_second_last_pp_rank:
if self.send_handles[prev_microbatch_id] is not None:
self.send_handles[prev_microbatch_id].Wait()
self.send_handles[
prev_microbatch_id] = self.dist.isend_tensor_list(
previous_batch.sample_state.host.values(),
prev_microbatch_id] = self.dist.isend_object(
(
sample_state.logits_host.numpy() if
self._need_return_logits(scheduled_batch) or
(self._need_return_log_probs(
scheduled_batch)
and sample_state.log_probs is not None)
else None,
sample_state.log_probs,
sample_state.host,
),
dest=self.dist.next_pp_rank,
tag=prev_microbatch_id)
torch.cuda.nvtx.range_pop()
@@ -899,7 +938,7 @@ def _executor_loop(self):
self._finish_dummy_request(scheduled_batch)

if self.kv_cache_transceiver:
# For context only req in transmission, we reset the state since decoder might have changed it
# For context only req in transmission, we reset the state since sampler might have changed it
for req in ctx_transmission_reqs:
req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS

@@ -1100,11 +1139,15 @@ def _process_previous_batch(self):

@nvtx_range("_forward_step_inter_pp")
def _forward_step_inter_pp(self, scheduled_batch) -> SampleState:
batch_outputs = self._forward_step(scheduled_batch)
sample_state = self._sample_async(scheduled_batch, batch_outputs)
self._forward_step(scheduled_batch)
sampler_event = torch.cuda.Event()
sampler_event.record()
self._update_request_states(scheduled_batch)
sample_state.sampler_event.synchronize()
return sample_state
sampler_event.synchronize()
return self.sampler.SampleState(
scheduled_requests=scheduled_batch,
sampler_event=sampler_event,
)

def _update_new_active_requests_queue_latency(self, new_requests):
if self.enable_iter_perf_stats and self.dist.rank == 0:
@@ -1125,12 +1168,10 @@ def _broadcast_new_requests(
"""Broadcasts new_requests and optional Python-only metadata (`py_request_objects`) across pipeline stages.
`py_request_objects` is a tuple of (attribute_name, {request_id: object}).
"""
payloads = (new_requests, py_request_objects
) if py_request_objects is not None else new_requests
payloads = (new_requests, py_request_objects)

if not self.dist.has_pp:
result = self.dist.broadcast(payloads, root=0)
return result if isinstance(result, tuple) else (result, None)
return self.dist.broadcast(payloads, root=0)

# broadcast within first tp group before send/recv chain to other tp groups
if self.dist.tp_size > 1 and self.dist.is_first_pp_rank:
@@ -1139,42 +1180,14 @@
# tag = [0, num_micro_batches - 1] used for new_tokens send/recv
tag = self.num_micro_batches

# 1. send metadata: len(num_requests) and serialized buffer size
new_requests = payloads[0] if isinstance(payloads, tuple) else payloads
if self.dist.is_first_pp_rank and len(new_requests) > 0:
buf = np.array(bytearray(dill.dumps(payloads)))
buf_size = len(buf)
else:
buf, buf_size = None, 0
metadata_arr = np.array([len(new_requests), buf_size])

# send payloads
if not self.dist.is_first_pp_rank:
self.dist.recv(metadata_arr, self.dist.prev_pp_rank, tag)
payloads = self.dist.recv_object(self.dist.prev_pp_rank, tag)

if not self.dist.is_last_pp_rank:
self.dist.send(metadata_arr, self.dist.next_pp_rank, tag)

# 2. send serialized buffer when new requests is not empty
num_new_requests = metadata_arr[0]
if num_new_requests > 0:
buf_size = metadata_arr[1]
if not self.dist.is_first_pp_rank:
buf = np.array(bytearray(buf_size))
self.dist.recv(buf, self.dist.prev_pp_rank, tag)

if not self.dist.is_last_pp_rank:
self.dist.send(buf, self.dist.next_pp_rank, tag)

if not self.dist.is_first_pp_rank:
buf_data = dill.loads(buf.tobytes()) # nosec B301
if isinstance(buf_data, tuple):
new_requests, py_request_objects = buf_data
else:
new_requests = buf_data

assert len(new_requests) == num_new_requests
self.dist.send_object(payloads, self.dist.next_pp_rank, tag)

return new_requests, py_request_objects
return payloads

@nvtx_range("_fetch_new_requests")
def _fetch_new_requests(self):
@@ -1975,7 +1988,7 @@ def _handle_cancelled_requests(self):
self.active_requests = left_requests

# enqueue the cancelled requests' responses as they are not
# active_requests and be discarded in the decoder loop.
# active_requests and be discarded in the sampler loop.
self._enqueue_responses(cancelled_responses)

@nvtx_range("_enqueue_responses")
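The rewritten _broadcast_new_requests drops the two-phase metadata/dill-buffer protocol in favor of a single relay of the pickled payload along the pipeline chain. A condensed sketch of that relay pattern, assuming only the recv_object/send_object helpers added in communicator.py (the helper name below is illustrative):

# Condensed sketch of the relay now used across PP ranks; `dist` is assumed to
# expose the recv_object/send_object helpers added in communicator.py.
def relay_along_pp_chain(dist, payload, tag):
    # Non-first ranks overwrite their local payload with the one pickled
    # upstream; the first rank already holds the authoritative copy.
    if not dist.is_first_pp_rank:
        payload = dist.recv_object(dist.prev_pp_rank, tag)
    # Every rank except the last forwards the payload unchanged downstream.
    if not dist.is_last_pp_rank:
        dist.send_object(payload, dist.next_pp_rank, tag)
    return payload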
26 changes: 16 additions & 10 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -32,11 +32,12 @@ def values(self):
return vars(self).values()


@dataclass(frozen=True, kw_only=True)
@dataclass(kw_only=True)
class SampleState:
scheduled_requests: ScheduledRequests

logits: torch.Tensor = None
logits_host: torch.Tensor = None

# Set when decode_async() has evaluated these to avoid computing again in update_requests()
# log_probs[request_idx][token_idx]
@@ -50,6 +51,8 @@ class SampleState:

class Sampler(ABC):

SampleState = SampleState

def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
pass

@@ -81,13 +84,14 @@ def update_requests(self, state: SampleState) -> None:
request.state = LlmRequestState.GENERATION_COMPLETE
# NOTE: This is a hack: set finish reason manually and set the beam 0
request.set_finished_reason(FinishReason.LENGTH, 0)
logits = state.logits[idx]
if logits.ndim == 1:
# For BERT: Add axis to be compatible with LogitsStorage
# (LogitsStorage will interpret this dim as the prompt_len which
# is not relevant for outputting logits of encoder only model).
logits = logits.unsqueeze(0)
request.py_result.append_context_logits(logits)
if request.py_return_context_logits:
logits = state.logits[idx]
if logits.ndim == 1:
# For BERT: Add axis to be compatible with LogitsStorage
# (LogitsStorage will interpret this dim as the prompt_len which
# is not relevant for outputting logits of encoder only model).
logits = logits.unsqueeze(0)
request.py_result.append_context_logits(logits)


def top_k_sampling_batch(logits, top_k=50):
@@ -409,7 +413,8 @@ def update_one_request(self, request: LlmRequest,
num_tokens = request.add_new_token(new_token, beam_idx)

current_logits = logits[output_token_idx].unsqueeze(0)
request.py_result.append_generation_logits(current_logits)
if request.py_return_generation_logits:
request.py_result.append_generation_logits(current_logits)
if request.py_return_log_probs:
_, log_probs = greedy_search_sampling_batch(current_logits)
request.py_result.append_log_probs([[{
@@ -452,14 +457,15 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
sequence_lengths: torch.Tensor


@dataclass(frozen=True, kw_only=True)
@dataclass(kw_only=True)
class SampleStateTRTLLM(SampleState):
host: SampleStateTensorsHostTRTLLM
device: SampleStateTensors


class TRTLLMSampler(Sampler):
MAX_DECODING_TOKENS = 1 # It must be 1 when not in speculative decoding
SampleState = SampleStateTRTLLM

def __init__(
self,
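Publishing the state type as a class attribute (SampleState = SampleState on Sampler, overridden by TRTLLMSampler and MTPSampler) lets generic executor code such as _forward_step_inter_pp construct the correct subclass via self.sampler.SampleState(...). A minimal standalone sketch of the pattern, using toy classes rather than the TRT-LLM ones:

# Minimal illustration of the class-attribute pattern; these are toy classes,
# not the TRT-LLM ones.
from dataclasses import dataclass
from typing import Any, Optional

@dataclass(kw_only=True)
class State:
    scheduled_requests: Any
    sampler_event: Optional[Any] = None

@dataclass(kw_only=True)
class RicherState(State):
    host: Optional[Any] = None

class BaseSampler:
    SampleState = State          # default state type

class RicherSampler(BaseSampler):
    SampleState = RicherState    # subclass swaps in its richer state

def make_inter_pp_state(sampler, batch, event):
    # Generic code builds whichever state class the sampler declares.
    return sampler.SampleState(scheduled_requests=batch, sampler_event=event)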
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/speculative/mtp.py
@@ -20,7 +20,7 @@ class SampleStateTensorsMTP(SampleStateTensors):
next_draft_tokens: torch.Tensor


@dataclass(frozen=True, kw_only=True)
@dataclass(kw_only=True)
class SampleStateMTP(SampleState):
device: SampleStateTensorsMTP
host: SampleStateTensorsMTP
@@ -221,6 +221,8 @@ class MTPSampler(TorchSampler):
MTP sampler.
"""

SampleState = SampleStateMTP

def __init__(self, max_seq_len: int, config: MTPConfig):
super().__init__(max_seq_len, False)
self.mapping = None