
Commit cb5bbce

Fix binding errors.
Signed-off-by: Shiyu Li <[email protected]>
1 parent 60ad4d1 commit cb5bbce

File tree: 5 files changed, +25 additions, -21 deletions

cpp/tensorrt_llm/nanobind/runtime/bindings.cpp

Lines changed: 3 additions & 1 deletion
@@ -340,7 +340,9 @@ void initBindings(nb::module_& m)
         "Reset the current virtual memory allocator and stop allocating virtual memory for CUDA allocations");
 
     nb::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(nb::init<size_t, uint32_t, uint32_t, uint32_t, at::Device, bool>())
+        .def(nb::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), nb::arg("buf_size"),
+            nb::arg("group_size"), nb::arg("group_rank"), nb::arg("split_color"), nb::arg("device_idx"),
+            nb::arg("mn_nvlink"))
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer)
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer);
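For orientation, the constructor is now registered with named arguments and a plain integer device index instead of an at::Device. A minimal Python-side sketch of a call matching the new binding; the import path and the concrete values are assumptions, not part of this commit:

# Hypothetical usage sketch; module path and values are assumptions.
from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer

buf = McastGPUBuffer(
    buf_size=64 << 20,   # bytes to reserve
    group_size=8,        # ranks in the communicator
    group_rank=0,        # this rank's position in the group
    split_color=0,       # topology split color
    device_idx=0,        # plain CUDA device index, no longer an at::Device
    mn_nvlink=True,      # multi-node NVLink path
)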

cpp/tensorrt_llm/pybind/runtime/bindings.cpp

Lines changed: 3 additions & 1 deletion
@@ -434,7 +434,9 @@ void initBindings(pybind11::module_& m)
         "Reset the current virtual memory allocator and stop allocating virtual memory for CUDA allocations");
 
     py::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(py::init<size_t, uint32_t, uint32_t, uint32_t, at::Device, bool>())
+        .def(py::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), py::arg("buf_size"),
+            py::arg("group_size"), py::arg("group_rank"), py::arg("split_color"), py::arg("device_idx"),
+            py::arg("mn_nvlink"))
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer)
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer);
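The pybind11 module receives the same signature fix. Once constructed, the two bound getters return tensor views over the unicast and multicast buffer sections; a hedged sketch of the Python-side calls, where the import path, shapes, dtype, and offsets are illustrative assumptions:

# Hypothetical usage sketch; shapes, dtype, and offset are assumptions.
import torch
from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer  # assumed module path

buf = McastGPUBuffer(buf_size=64 << 20, group_size=8, group_rank=0,
                     split_color=0, device_idx=0, mn_nvlink=True)

uc = buf.get_uc_buffer(0, [1024, 8192], torch.bfloat16, 0)  # unicast view of rank 0's slice
mc = buf.get_mc_buffer([1024, 8192], torch.bfloat16, 0)     # multicast view for the group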

cpp/tensorrt_llm/runtime/mcastGPUBuffer.h

Lines changed: 13 additions & 7 deletions
@@ -38,10 +38,10 @@ class McastGPUBuffer
     //! \param device The CUDA device for buffer allocation.
     //! \param mnNvlink Flag indicating if multi-node NVLink is used.
     McastGPUBuffer(
-        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, at::Device device, bool mnNvlink)
-        : mMcastDeviceMemory(bufSize, groupSize, groupRank, splitColor, device.index(), mnNvlink)
+        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, uint32_t deviceIdx, bool mnNvlink)
+        : mMcastDeviceMemory(bufSize, groupSize, groupRank, splitColor, deviceIdx, mnNvlink)
         , mBufSize(bufSize)
-        , mLocalDevice(device)
+        , mLocalDevice(at::Device(at::DeviceType::CUDA, deviceIdx))
     {
     }
 
@@ -51,7 +51,7 @@ class McastGPUBuffer
     //! \param dtype The data type of the tensor elements.
     //! \param storageOffset The offset in elements from the start of the buffer.
     //! \return An ATen tensor wrapping the unicast buffer section.
-    at::Tensor getUCBuffer(uint32_t rank, c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
+    at::Tensor getUCBuffer(uint32_t rank, std::vector<long int> sizes, torch::ScalarType dtype, int64_t storageOffset)
     {
         size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
         size_t const elementSize = c10::elementSize(dtype);
 
@@ -61,15 +61,18 @@ class McastGPUBuffer
         auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getUnicastPtr(rank)) + storageOffset * elementSize;
 
         auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
-        return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
+        return at::for_blob(dataPtr, c10::IntArrayRef(sizes))
+            .options(options)
+            .target_device(mLocalDevice)
+            .make_tensor();
     }
 
     //! \brief Returns a PyTorch tensor view of the multicast buffer portion.
     //! \param sizes The desired shape (dimensions) of the tensor.
     //! \param dtype The data type of the tensor elements.
     //! \param storageOffset The offset in elements from the start of the buffer.
     //! \return An ATen tensor wrapping the multicast buffer section.
-    at::Tensor getMCBuffer(c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
+    at::Tensor getMCBuffer(std::vector<long int> sizes, torch::ScalarType dtype, int64_t storageOffset)
     {
         size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
         size_t const elementSize = c10::elementSize(dtype);
 
@@ -79,7 +82,10 @@ class McastGPUBuffer
         auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getMulticastPtr()) + storageOffset * elementSize;
 
         auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
-        return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
+        return at::for_blob(dataPtr, c10::IntArrayRef(sizes))
+            .options(options)
+            .target_device(mLocalDevice)
+            .make_tensor();
     }
 
 private:
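The getter signatures switch from the non-owning c10::IntArrayRef to an owned std::vector<long int> (and an IntArrayRef is rebuilt locally before at::for_blob), which binds more cleanly to a plain Python list of dimensions. The surrounding arithmetic is unchanged: numel is the product of the sizes and the byte offset is storageOffset times the element size. A small Python mirror of that arithmetic, purely for illustration with assumed values:

# Illustrative mirror of the numel / byte-offset computation in getUCBuffer/getMCBuffer.
import math
import torch

sizes = [1024, 8192]          # assumed tensor shape
dtype = torch.bfloat16
storage_offset = 256          # assumed offset in elements

numel = math.prod(sizes)                                        # std::accumulate(...) equivalent
element_size = torch.tensor([], dtype=dtype).element_size()     # 2 bytes for bf16
byte_offset = storage_offset * element_size

print(numel, element_size, byte_offset)  # 8388608 2 512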

tensorrt_llm/_torch/distributed/ops.py

Lines changed: 5 additions & 11 deletions
@@ -1,4 +1,3 @@
-import logging
 import math
 import os
 import platform
@@ -17,7 +16,6 @@
 from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
 
 _thread_local = threading.local()
-logger = logging.getLogger(__name__)
 
 
 def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
@@ -61,8 +59,9 @@ def get_allreduce_mnnvl_workspace(
         setattr(_thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}',
                 {})
     # Support topology split
-    comm = mpi_comm().Split(mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
-                            mapping.tp_rank)
+    comm = mpi_comm().Split(
+        int(mapping.pp_rank * mapping.cp_size + mapping.cp_rank),
+        mapping.tp_rank)
     force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1"
 
     allreduce_mnnvl_workspaces = getattr(
@@ -82,7 +81,7 @@ def get_allreduce_mnnvl_workspace(
             mapping.tp_rank,
             # Split the communicator according to the topology
             mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
-            torch.device("cuda", mapping.local_rank),
+            mapping.local_rank,
             True,  # mnNvlink
         )
 
@@ -463,12 +462,7 @@ def __init__(self,
         # Initialize MNNVL AllReduce if needed
         if self.strategy in (AllReduceStrategy.AUTO,
                              AllReduceStrategy.MNNVL):
-            if self.mapping.tp_size != self.mapping.world_size:
-                logger.debug(
-                    f"MNNVLAllReduce is disabled due to tp_size:{self.mapping.tp_size} "
-                    f"!= world_size:{self.mapping.world_size}")
-                self.mnnvl_allreduce = None
-            elif MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
+            if MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
                 try:
                     self.mnnvl_allreduce = MNNVLAllReduce(
                         self.mapping, dtype) if dtype else None
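The workspace setup splits the world communicator so that ranks sharing the same (pp, cp) pair land in one TP sub-communicator, with tp_rank as the ordering key; the commit wraps the color in int(), presumably because Split expects a plain Python integer. A standalone mpi4py sketch of that split pattern, with assumed rank values:

# Sketch of the communicator split pattern; values are assumptions.
from mpi4py import MPI

world = MPI.COMM_WORLD
pp_rank, cp_rank, cp_size = 0, 0, 1          # assumed topology coordinates
tp_rank = world.Get_rank()

# Ranks with the same color join the same sub-communicator; key orders them within it.
color = int(pp_rank * cp_size + cp_rank)
sub_comm = world.Split(color, tp_rank)
print(sub_comm.Get_rank(), sub_comm.Get_size())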

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 1 deletion
@@ -749,7 +749,7 @@ def _compute_mlp_tp_size(self, intermediate_size: int,
             mlp_tp_size = math.gcd(
                 tp,
                 self.mapping.gpus_per_node,
-            )  # Avoid costly inter-node TP when MNNVL is not supported
+            )  # Avoid costly inter-node TP
         else:
             mlp_tp_size = tp
         return mlp_tp_size
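Only the comment changes here; the computation still takes the gcd of the requested TP size and the GPUs per node, i.e. the largest TP degree that divides both, so the MLP's tensor parallelism stays within a node. A tiny worked example with assumed values:

# Illustrative values, not from any model config.
import math

tp = 16              # assumed requested tensor-parallel size
gpus_per_node = 8    # assumed GPUs per node

mlp_tp_size = math.gcd(tp, gpus_per_node)  # -> 8: MLP TP stays intra-node
print(mlp_tp_size)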
