6 changes: 3 additions & 3 deletions cpp/tensorrt_llm/common/opUtils.cpp
@@ -50,14 +50,14 @@ namespace
{

// Get NCCL unique ID for a group of ranks.
ncclUniqueId getUniqueId(std::set<int> const& group) noexcept
ncclUniqueId getUniqueId(std::set<int> const& group)
{
auto const rank = COMM_SESSION.getRank();
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
ncclUniqueId id;
if (rank == *group.begin())
{
NCCLCHECK(ncclGetUniqueId(&id));
NCCLCHECK_THROW(ncclGetUniqueId(&id));
for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
{
COMM_SESSION.sendValue(id, *it, tensorrt_llm::mpi::MpiTag::kDefault);
@@ -122,7 +122,7 @@ std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group)
#else
setenv("NCCL_RUNTIME_CONNECT", "0", 0);
#endif // _WIN32
NCCLCHECK(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank));
NCCLCHECK_THROW(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank));
commMap[group] = ncclComm;
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
return ncclComm;
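For context, getUniqueId has the lowest rank in the group create the NCCL unique ID and ship it to every other member over the communication session before ncclCommInitRank runs. A minimal, self-contained MPI sketch of that leader-distributes-a-token pattern is below; the 128-byte character buffer stands in for a real ncclUniqueId, and plain MPI_Send/MPI_Recv stand in for the project's COMM_SESSION wrapper.

#include <mpi.h>

#include <array>
#include <cstdio>

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    int rank = 0;
    int worldSize = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &worldSize);

    std::array<char, 128> id{}; // stand-in for ncclUniqueId
    if (rank == 0)
    {
        // The leader creates the token and sends it to every other rank.
        std::snprintf(id.data(), id.size(), "token-from-rank-0");
        for (int dst = 1; dst < worldSize; ++dst)
        {
            MPI_Send(id.data(), static_cast<int>(id.size()), MPI_CHAR, dst, /*tag=*/0, MPI_COMM_WORLD);
        }
    }
    else
    {
        // Everyone else blocks until the leader's token arrives.
        MPI_Recv(id.data(), static_cast<int>(id.size()), MPI_CHAR, /*source=*/0, /*tag=*/0, MPI_COMM_WORLD,
            MPI_STATUS_IGNORE);
    }
    std::printf("rank %d has id '%s'\n", rank, id.data());
    MPI_Finalize();
    return 0;
}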
20 changes: 20 additions & 0 deletions cpp/tensorrt_llm/common/opUtils.h
@@ -199,6 +199,16 @@ inline bool isBuilding()
} \
} while (0)

#define NCCLCHECK_THROW(cmd) \
do \
{ \
ncclResult_t r = cmd; \
if (TLLM_UNLIKELY(r != ncclSuccess)) \
{ \
TLLM_THROW("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \
} \
} while (0)

std::unordered_map<nvinfer1::DataType, ncclDataType_t>* getDtypeMap();

std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group);
@@ -308,3 +318,13 @@ std::shared_ptr<cublasLtHandle_t> getCublasLtHandle();
exit(EXIT_FAILURE); \
} \
} while (0)

#define NVML_CHECK_THROW(cmd) \
do \
{ \
nvmlReturn_t r = cmd; \
if (TLLM_UNLIKELY(r != NVML_SUCCESS)) \
{ \
TLLM_THROW("Failed, NVML error %s:%d '%s'\n", __FILE__, __LINE__, nvmlErrorString(r)); \
} \
} while (0)
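Both new macros replace the exit-on-error behavior of the existing NCCLCHECK/NVML_CHECK with a thrown TLLM_THROW exception, which lines up with the noexcept qualifiers being dropped from the op methods elsewhere in this change: the error can now propagate to the caller instead of terminating the process. A self-contained sketch of the same check-and-throw pattern, using a made-up fakeStatus_t API in place of NCCL/NVML and std::runtime_error in place of TLLM_THROW:

#include <iostream>
#include <sstream>
#include <stdexcept>

// Hypothetical C-style status API standing in for NCCL/NVML.
enum fakeStatus_t
{
    FAKE_OK = 0,
    FAKE_ERR_INTERNAL = 1
};

inline char const* fakeErrorString(fakeStatus_t s)
{
    return s == FAKE_OK ? "no error" : "internal error";
}

// Same shape as NCCLCHECK_THROW / NVML_CHECK_THROW: evaluate once, throw on failure.
#define FAKE_CHECK_THROW(cmd)                                                                  \
    do                                                                                         \
    {                                                                                          \
        fakeStatus_t r = (cmd);                                                                \
        if (r != FAKE_OK)                                                                      \
        {                                                                                      \
            std::ostringstream oss;                                                            \
            oss << "Failed, error " << __FILE__ << ":" << __LINE__ << " '"                     \
                << fakeErrorString(r) << "'";                                                  \
            throw std::runtime_error(oss.str());                                               \
        }                                                                                      \
    } while (0)

fakeStatus_t fakeCall(bool fail)
{
    return fail ? FAKE_ERR_INTERNAL : FAKE_OK;
}

int main()
{
    try
    {
        FAKE_CHECK_THROW(fakeCall(true)); // throws instead of calling exit()
    }
    catch (std::exception const& e)
    {
        std::cerr << e.what() << std::endl; // the caller decides how to recover
    }
    return 0;
}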
11 changes: 8 additions & 3 deletions cpp/tensorrt_llm/common/tllmException.cpp
@@ -57,7 +57,13 @@ std::string TllmException::getTrace() const
#if defined(_MSC_VER)
return "";
#else
auto const trace = backtrace_symbols(mCallstack.data(), mNbFrames);
auto const trace = std::unique_ptr<char const*, void (*)(char const**)>(
const_cast<char const**>(backtrace_symbols(mCallstack.data(), mNbFrames)),
[](char const** p) { std::free(p); });
if (trace == nullptr)
{
throw std::bad_alloc();
}
std::ostringstream buf;
for (auto i = 1; i < mNbFrames; ++i)
{
@@ -70,7 +76,7 @@ std::string TllmException::getTrace() const
}
else
{
buf << fmtstr("%-3d %*p %s", i, VOID_PTR_SZ, mCallstack[i], trace[i]);
buf << fmtstr("%-3d %*p %s", i, VOID_PTR_SZ, mCallstack[i], trace.get()[i]);
}
if (i < mNbFrames - 1)
buf << std::endl;
@@ -79,7 +85,6 @@ std::string TllmException::getTrace() const
if (mNbFrames == MAX_FRAMES)
buf << std::endl << "[truncated]";

std::free(trace);
return buf.str();
#endif
}
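The change above hands the malloc-allocated array returned by backtrace_symbols to a std::unique_ptr with a custom deleter, so the buffer is freed on every exit path from getTrace, including when the formatting code throws, and the manual std::free at the end becomes unnecessary. A minimal, self-contained sketch of that custom-deleter idiom, using an ordinary std::malloc buffer instead of the platform-specific backtrace_symbols:

#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <new>

int main()
{
    // The memory comes from a C allocator, so the matching deleter calls std::free.
    auto buf = std::unique_ptr<char, void (*)(void*)>(
        static_cast<char*>(std::malloc(32)), [](void* p) { std::free(p); });
    if (buf == nullptr)
    {
        throw std::bad_alloc(); // mirrors the null check added above
    }
    std::strcpy(buf.get(), "owned by unique_ptr");
    std::cout << buf.get() << std::endl;
    return 0; // std::free runs automatically here, or during stack unwinding
}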
13 changes: 6 additions & 7 deletions cpp/tensorrt_llm/thop/allgatherOp.cpp
@@ -47,15 +47,15 @@ class AllgatherOp

~AllgatherOp() = default;

int initialize() noexcept
int initialize()
{
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
mNcclComm = getComm(mGroup);
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
return 0;
}

torch::Tensor run(torch::Tensor input, torch::optional<torch::List<int64_t>> sizes) noexcept
torch::Tensor run(torch::Tensor input, torch::optional<torch::List<int64_t>> sizes)
{
TLLM_CHECK_WITH_INFO(mNcclComm.get() != nullptr, "mNcclComm should be initialized before use");
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
@@ -78,7 +78,7 @@ class AllgatherOp
for (int root = 0; root < static_cast<int>(mGroup.size()); ++root)
{
auto split_size = sizes.value()[root];
NCCLCHECK(ncclBroadcast(input.data_ptr(),
NCCLCHECK_THROW(ncclBroadcast(input.data_ptr(),
output.index({torch::indexing::Slice(split_offset, torch::indexing::None)}).mutable_data_ptr(),
numel_base * split_size, (*getDtypeMap())[type], root, *mNcclComm, stream));
split_offset += split_size;
@@ -87,14 +87,13 @@ class AllgatherOp
}
else
{
NCCLCHECK(ncclAllGather(input.data_ptr(), output.mutable_data_ptr(), input.numel(), (*getDtypeMap())[type],
*mNcclComm, stream));
NCCLCHECK_THROW(ncclAllGather(input.data_ptr(), output.mutable_data_ptr(), input.numel(),
(*getDtypeMap())[type], *mNcclComm, stream));
}
return output;
}

std::vector<torch::Tensor> run_list(
torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes) noexcept
std::vector<torch::Tensor> run_list(torch::TensorList input_list, torch::optional<torch::List<int64_t>> sizes)
{
std::vector<torch::Tensor> output_list;
output_list.reserve(input_list.size());
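In the uneven-size branch above, each rank contributes sizes[root] slices of numel_base elements, and the broadcast rooted at rank root is written at the running slice offset of the gathered output. A small, self-contained sketch of just that offset bookkeeping, with made-up sizes and numel_base values for illustration:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    std::vector<int64_t> const sizes{3, 1, 2}; // slices contributed by ranks 0, 1, 2
    int64_t const numel_base = 4;              // elements per slice (product of the trailing dims)

    int64_t split_offset = 0;
    for (std::size_t root = 0; root < sizes.size(); ++root)
    {
        int64_t const split_size = sizes[root];
        // The broadcast rooted at `root` fills numel_base * split_size elements
        // starting at slice `split_offset` of the output tensor.
        std::cout << "root " << root << ": slices [" << split_offset << ", " << split_offset + split_size
                  << "), elements " << numel_base * split_size << std::endl;
        split_offset += split_size;
    }
    return 0;
}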
41 changes: 18 additions & 23 deletions cpp/tensorrt_llm/thop/allreduceOp.cpp
@@ -65,7 +65,7 @@ class NvmlManager
public:
NvmlManager()
{
NVML_CHECK(nvmlInit());
NVML_CHECK_THROW(nvmlInit());
}

~NvmlManager()
@@ -159,7 +159,7 @@ class AllreduceOp

std::vector<torch::Tensor> run(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
torch::optional<torch::Tensor> const& norm_weight, torch::optional<torch::Tensor> const& scale,
torch::optional<torch::Tensor> const& bias, torch::optional<torch::Tensor> workspace) noexcept
torch::optional<torch::Tensor> const& bias, torch::optional<torch::Tensor> workspace)
{
size_t size = input.numel();
size_t seq_len = input.size(0);
@@ -187,7 +187,7 @@ class AllreduceOp
}
}

int initialize() noexcept
int initialize()
{
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
mNcclComm = getComm(mGroup);
@@ -203,7 +203,7 @@ class AllreduceOp
private:
std::vector<torch::Tensor> runUBAllReduce(torch::Tensor const& input,
torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias)
{
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
int size = input.numel();
@@ -283,14 +283,14 @@ class AllreduceOp

std::vector<torch::Tensor> runNCCLAllReduce(torch::Tensor const& input,
torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias)
{

auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
int size = input.numel();

torch::Tensor reduce_output = torch::empty_like(input);
NCCLCHECK(ncclAllReduce(input.data_ptr(), reduce_output.mutable_data_ptr(), size, (*getDtypeMap())[mType],
NCCLCHECK_THROW(ncclAllReduce(input.data_ptr(), reduce_output.mutable_data_ptr(), size, (*getDtypeMap())[mType],
ncclSum, *mNcclComm, stream));

if (mOp == AllReduceFusionOp::NONE)
@@ -372,7 +372,7 @@ class AllreduceOp
std::vector<torch::Tensor> runFusionAllReduce(torch::Tensor const& input,
torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias,
torch::optional<torch::Tensor> workspace, AllReduceStrategyType strategy) noexcept
torch::optional<torch::Tensor> workspace, AllReduceStrategyType strategy)
{
// Should handle only Lamport implementation
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
@@ -554,7 +554,7 @@ class AllreduceOp
std::vector<torch::Tensor> fallbackRunSubsequentOps(torch::Tensor const& input,
torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias,
torch::Tensor& reduce_output) noexcept
torch::Tensor& reduce_output)
{
// If we reach here, it means the extra fallback operations are required.
// All patterns are broken into AllReduce + residual_rms_norm + the following operations (quantization, etc.)
@@ -619,7 +619,7 @@ class AllreduceOp
return {};
}

AllReduceStrategyType getRuntimeStrategy(size_t seq_len, size_t size) noexcept
AllReduceStrategyType getRuntimeStrategy(size_t seq_len, size_t size)
{
static char* force_nccl_all_reduce_strategy_char = std::getenv("FORCE_NCCL_ALL_REDUCE_STRATEGY");
bool force_nccl_all_reduce_strategy = (force_nccl_all_reduce_strategy_char != nullptr);
@@ -648,7 +648,7 @@ class AllreduceOp
return runtime_strategy;
}

void logRunTimeStrategy(AllReduceStrategyType strategy, int rank) noexcept
void logRunTimeStrategy(AllReduceStrategyType strategy, int rank)
{
switch (strategy)
{
@@ -676,12 +676,7 @@ class AllreduceOp
}
}

bool Fusable() noexcept
{
return mOp != AllReduceFusionOp::NONE;
}

void initGroupTopology() noexcept
void initGroupTopology()
{
static std::map<std::set<int>, std::tuple<bool, bool>> cache;
if (cache.find(mGroup) != cache.end())
@@ -695,7 +690,7 @@ class AllreduceOp
cache[mGroup] = {mIsNVLINKSupported, mIsP2PSupported};
}

void setGroupTopology() noexcept
void setGroupTopology()
{
auto const rank = COMM_SESSION.getRank();
TLLM_LOG_INFO("Detecting local TP group for rank %d", rank);
@@ -738,7 +733,7 @@ class AllreduceOp
}

nvmlDevice_t first_device;
NVML_CHECK(nvmlDeviceGetHandleByIndex(first_device_id, &first_device));
NVML_CHECK_THROW(nvmlDeviceGetHandleByIndex(first_device_id, &first_device));

bool is_NVLINK = false;

@@ -757,7 +752,7 @@ class AllreduceOp
{
// Two GPUs are connected directly through nvlink
unsigned int remote_device_id;
NVML_CHECK(nvmlDeviceGetIndex(remote_device, &remote_device_id));
NVML_CHECK_THROW(nvmlDeviceGetIndex(remote_device, &remote_device_id));

if (remote_device_id == static_cast<unsigned int>(second_device_id))
{
@@ -771,7 +766,7 @@ class AllreduceOp
// determine whether NVLink is supported by checking whether the two GPUs are connected to the same
// NVSwitch.
nvmlDevice_t second_device;
NVML_CHECK(nvmlDeviceGetHandleByIndex(second_device_id, &second_device));
NVML_CHECK_THROW(nvmlDeviceGetHandleByIndex(second_device_id, &second_device));

for (unsigned int second_link = 0; second_link < NVML_NVLINK_MAX_LINKS; second_link++)
{
Expand All @@ -791,7 +786,7 @@ class AllreduceOp
}
else
{
NVML_CHECK(result);
NVML_CHECK_THROW(result);
}

if (is_NVLINK)
Expand All @@ -806,7 +801,7 @@ class AllreduceOp
}
}

bool ifFallbackToNCCL(size_t seq_len, size_t message_size_bytes, size_t max_workspace_size, bool is_auto) noexcept
bool ifFallbackToNCCL(size_t seq_len, size_t message_size_bytes, size_t max_workspace_size, bool is_auto)
{
// If messageSize exceeds maxWorkspaceSize, fall back to NCCL regardless of the fusion type.
if (message_size_bytes > max_workspace_size)
@@ -842,7 +837,7 @@ class AllreduceOp
}

AllReduceStrategyType selectImplementation(
size_t seq_len, size_t message_size, int world_size, nvinfer1::DataType type) noexcept
size_t seq_len, size_t message_size, int world_size, nvinfer1::DataType type)
{

if (isUsingLowPrecision(message_size))
8 changes: 4 additions & 4 deletions cpp/tensorrt_llm/thop/reducescatterOp.cpp
@@ -47,15 +47,15 @@ class ReducescatterOp

~ReducescatterOp() = default;

int initialize() noexcept
int initialize()
{
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
mNcclComm = getComm(mGroup);
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
return 0;
}

torch::Tensor run(torch::Tensor const& input, torch::optional<torch::List<int64_t>> sizes) noexcept
torch::Tensor run(torch::Tensor const& input, torch::optional<torch::List<int64_t>> sizes)
{
TLLM_CHECK_WITH_INFO(mNcclComm.get() != nullptr, "mNcclComm should be initialized before use");
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
@@ -87,7 +87,7 @@ class ReducescatterOp
for (int root = 0; root < static_cast<int>(mGroup.size()); ++root)
{
auto split_size = sizes.value()[root];
NCCLCHECK(
NCCLCHECK_THROW(
ncclReduce(input.index({torch::indexing::Slice(split_offset, torch::indexing::None)}).data_ptr(),
output.mutable_data_ptr(), numel_base * split_size, (*getDtypeMap())[type], ncclSum, root,
*mNcclComm, stream));
@@ -97,7 +97,7 @@ class ReducescatterOp
}
else
{
NCCLCHECK(ncclReduceScatter(input.data_ptr(), output.mutable_data_ptr(), output.numel(),
NCCLCHECK_THROW(ncclReduceScatter(input.data_ptr(), output.mutable_data_ptr(), output.numel(),
(*getDtypeMap())[type], ncclSum, *mNcclComm, stream));
}
return output;
14 changes: 9 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -139,13 +139,17 @@ def get_token_num_for_estimation(executor_config, model_config):
return None


def estimate_max_kv_cache_tokens(
py_executor: PyExecutor, model_engine: PyTorchModelEngine,
executor_config: ExecutorConfig, mapping: Mapping, origin_seq_len: int,
ctx_chunk_config,
draft_model_engine: PyTorchModelEngine) -> Optional[int]:
def estimate_max_kv_cache_tokens(py_executor: PyExecutor,
model_engine: PyTorchModelEngine,
executor_config: ExecutorConfig,
mapping: Mapping, origin_seq_len: int,
ctx_chunk_config,
draft_model_engine: PyTorchModelEngine) -> int:
# TODO: support CP by generating dummy requests for it.
if 'cp_type' in mapping.cp_config:
# This is called from create_py_executor, which ensures that
# executor_config.max_num_tokens is set.
assert executor_config.max_num_tokens is not None
return executor_config.max_num_tokens

vocab_size = model_engine.model.model_config.pretrained_config.vocab_size