2 changes: 1 addition & 1 deletion CODING_GUIDELINES.md
@@ -298,7 +298,7 @@ for (int i = 0; i < static_cast<int>(mTensors.size()); ++i)
1. C headers should not be used directly.
- Example: Use `<cstdint>` instead of `<stdint.h>`
2. Do not use C library functions, whenever possible.
* Use brace initialization or `std::fill_n()` instead of `memset()`. This is especially important when dealing with non-[POD types](http://en.cppreference.com/w/cpp/concept/PODType). In the example below, using `memset()` will corrupt the vtable of `Foo:`
* Use brace initialization or `std::fill_n()` instead of `memset()`. This is especially important when dealing with non-[POD types](https://en.cppreference.com/w/cpp/named_req/PODType). In the example below, using `memset()` will corrupt the vtable of `Foo:`
```cpp
struct Foo {
virtual int getX() { return x; }
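A minimal sketch of the recommended alternatives, assuming a polymorphic `Foo` like the one above (this example is illustrative and is not part of the guideline's own snippet):

```cpp
#include <algorithm> // std::fill_n
#include <array>
#include <vector>

struct Foo
{
    virtual int getX() { return x; }
    int x{0};
};

int main()
{
    Foo foo{};                  // brace initialization: members are zeroed, the vtable stays intact
    std::array<Foo, 8> pool{};  // value-initializes every element safely

    std::vector<int> buffer(1024);
    std::fill_n(buffer.begin(), buffer.size(), 0); // type-safe alternative to memset(buffer.data(), 0, ...)
    return foo.getX();
}
```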
Git LFS file not shown
@@ -1,2 +1,2 @@
4a23fc1883a3d35ae8b47c6bac3d8e7e85adc0e5615a9afd46acf8b84c8d62a8 libtensorrt_llm_internal_cutlass_kernels_static.a
commit 0971a156bd13fe781c1c16659c94bc69e0ca77e1
b34f70bfe1f97af0b38c1866858d519fbf1c383f97b3f375c2a59985336e1fa2 libtensorrt_llm_internal_cutlass_kernels_static.a
commit 59ea816b923df8ece603c6eb06cbbf80d82557c9
Git LFS file not shown
@@ -1,2 +1,2 @@
ea9382af97a51c33ee4edb438e9add4e52fbdfcea68c6e420204992213013727 libtensorrt_llm_internal_cutlass_kernels_static.a
commit 0971a156bd13fe781c1c16659c94bc69e0ca77e1
eb5d38d57e05d9d6702d29b99e961ace2927bf95356cd44c79b61aa3d450c0ee libtensorrt_llm_internal_cutlass_kernels_static.a
commit 59ea816b923df8ece603c6eb06cbbf80d82557c9
151 changes: 15 additions & 136 deletions cpp/tensorrt_llm/kernels/preQuantScaleKernel.cu
@@ -40,15 +40,17 @@ struct Vec2Type<__nv_bfloat16>
}; // namespace

template <typename T_in, typename T_out, int kProcessRows, typename AccessType>
__global__ void apply_per_channel_scale(
T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows, int cols)
__global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows,
int cols, int64_t const* num_valid_tokens_ptr)
{
static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
T_in scale[kElems], act_vec[kElems];
int col_offset = blockIdx.y * blockDim.x + threadIdx.x;
int row_offset = blockIdx.x;
if (col_offset * kElems >= cols || row_offset * kProcessRows >= rows)
return;
if (num_valid_tokens_ptr && (row_offset * kProcessRows >= *num_valid_tokens_ptr))
return;
act += row_offset * kProcessRows * cols;
smoothed_act += row_offset * kProcessRows * cols;
*reinterpret_cast<AccessType*>(scale) = reinterpret_cast<AccessType const*>(per_channel_scale)[col_offset];
@@ -95,46 +97,46 @@ __global__ void apply_per_channel_scale(
}

template <typename T_in, typename T_out, int kProcessRows, typename AccessType = float4>
void apply_per_channel_scale_kernel_launcher_(
T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows, int cols, cudaStream_t stream = 0)
void apply_per_channel_scale_kernel_launcher_(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale,
int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0)
{
static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
dim3 block(128);
dim3 grid((rows + kProcessRows - 1) / kProcessRows, (cols / kElems + block.x - 1) / block.x);
apply_per_channel_scale<T_in, T_out, kProcessRows, AccessType>
<<<grid, block, 0, stream>>>(smoothed_act, act, per_channel_scale, rows, cols);
<<<grid, block, 0, stream>>>(smoothed_act, act, per_channel_scale, rows, cols, num_valid_tokens_ptr);
}

template <typename T_in, typename T_out>
void apply_per_channel_scale_kernel_launcher(
T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows, int cols, cudaStream_t stream)
void apply_per_channel_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale,
int rows, int cols, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)
{
uint64_t elems = static_cast<uint64_t>(rows) * static_cast<uint64_t>(cols);
if (elems < 2048 * 2048)
{
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 1, float4>(
smoothed_act, act, per_channel_scale, rows, cols, stream);
smoothed_act, act, per_channel_scale, rows, cols, num_valid_tokens_ptr, stream);
}
else if (elems < 4096 * 4096)
{
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 4, float4>(
smoothed_act, act, per_channel_scale, rows, cols, stream);
smoothed_act, act, per_channel_scale, rows, cols, num_valid_tokens_ptr, stream);
}
else if (elems < 8192 * 8192)
{
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 8, float4>(
smoothed_act, act, per_channel_scale, rows, cols, stream);
smoothed_act, act, per_channel_scale, rows, cols, num_valid_tokens_ptr, stream);
}
else
{
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 16, float4>(
smoothed_act, act, per_channel_scale, rows, cols, stream);
smoothed_act, act, per_channel_scale, rows, cols, num_valid_tokens_ptr, stream);
}
}

#define INSTANTIATE_PREQUANT_SCALE(T_in, T_out) \
template void apply_per_channel_scale_kernel_launcher<T_in, T_out>( \
T_out * smoothed_act, const T_in* act, const T_in* per_channel_scale, int rows, int cols, cudaStream_t stream)
template void apply_per_channel_scale_kernel_launcher<T_in, T_out>(T_out * smoothed_act, const T_in* act, \
const T_in* per_channel_scale, int rows, int cols, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)

INSTANTIATE_PREQUANT_SCALE(half, half);
#if defined(ENABLE_FP8)
@@ -148,128 +150,5 @@ INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16, __nv_fp8_e4m3);
#endif
#endif

template <typename T_in, typename T_out, int kProcessRows, typename AccessType>
__global__ void apply_per_expert_scale(T_out* smoothed_act, T_in const* act, T_in const* per_expert_scale,
int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, int rows, int cols)
{
static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
T_in act_vec[kElems];
int col_offset = blockIdx.x * blockDim.x + threadIdx.x;
int row_offset = blockIdx.y;
int expert_idx = permuted_token_selected_experts[row_offset];
T_in scale = per_expert_scale[expert_idx];
if (col_offset * kElems >= cols || row_offset * kProcessRows >= rows)
return;
if (num_valid_tokens_ptr && (row_offset * kProcessRows >= *num_valid_tokens_ptr))
return;
act += row_offset * kProcessRows * cols;
smoothed_act += row_offset * kProcessRows * cols;
#pragma unroll
for (int i = 0; i < kProcessRows; ++i)
{
*reinterpret_cast<AccessType*>(act_vec) = reinterpret_cast<AccessType const*>(act + i * cols)[col_offset];
if constexpr ((std::is_same_v<T_in, half>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|| std::is_same_v<T_in, __nv_bfloat16>
#endif
) &&(kElems % 2 == 0))
{
using Vec2 = typename Vec2Type<T_in>::type;
#pragma unroll
for (int j = 0; j < kElems; j += 2)
{
if constexpr (std::is_same_v<T_in, half>)
{
*reinterpret_cast<Vec2*>(act_vec + j)
= __hmul2(*reinterpret_cast<Vec2*>(act_vec + j), __half2half2(scale));
}
else
{
*reinterpret_cast<Vec2*>(act_vec + j)
= __hmul2(*reinterpret_cast<Vec2*>(act_vec + j), __bfloat162bfloat162(scale));
}
}
}
else
{
#pragma unroll
for (int j = 0; j < kElems; ++j)
{
act_vec[j] = static_cast<T_in>(static_cast<float>(act_vec[j]) * static_cast<float>(scale));
}
}
if constexpr (std::is_same_v<T_in, T_out>)
{
reinterpret_cast<AccessType*>(smoothed_act + i * cols)[col_offset]
= *reinterpret_cast<AccessType*>(act_vec);
}
else
{
#pragma unroll
for (int j = 0; j < kElems; ++j)
{
(smoothed_act + i * cols)[col_offset * kElems + j] = static_cast<T_out>(act_vec[j]);
}
}
}
}

template <typename T_in, typename T_out, int kProcessRows, typename AccessType = float4>
void apply_per_expert_scale_kernel_launcher_(T_out* smoothed_act, T_in const* act, T_in const* per_expert_scale,
int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, int rows, int cols,
cudaStream_t stream = 0)
{
static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
dim3 block(128);
dim3 grid((cols / kElems + block.x - 1) / block.x, (rows + kProcessRows - 1) / kProcessRows);
apply_per_expert_scale<T_in, T_out, kProcessRows, AccessType><<<grid, block, 0, stream>>>(
smoothed_act, act, per_expert_scale, permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols);
}

template <typename T_in, typename T_out>
void apply_per_expert_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, T_in const* per_expert_scale,
int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, int rows, int cols,
cudaStream_t stream)
{
int elems = rows * cols;
if (elems < 2048 * 2048)
{
apply_per_expert_scale_kernel_launcher_<T_in, T_out, 1, float4>(smoothed_act, act, per_expert_scale,
permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols, stream);
}
else if (elems < 4096 * 4096)
{
apply_per_expert_scale_kernel_launcher_<T_in, T_out, 4, float4>(smoothed_act, act, per_expert_scale,
permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols, stream);
}
else if (elems < 8192 * 8192)
{
apply_per_expert_scale_kernel_launcher_<T_in, T_out, 8, float4>(smoothed_act, act, per_expert_scale,
permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols, stream);
}
else
{
apply_per_expert_scale_kernel_launcher_<T_in, T_out, 16, float4>(smoothed_act, act, per_expert_scale,
permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols, stream);
}
}

#define INSTANTIATE_PEREXPERT_SCALE(T_in, T_out) \
template void apply_per_expert_scale_kernel_launcher<T_in, T_out>(T_out * smoothed_act, T_in const* act, \
T_in const* per_expert_scale, int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, \
int rows, int cols, cudaStream_t stream)

INSTANTIATE_PEREXPERT_SCALE(half, half);
#if defined(ENABLE_FP8)
INSTANTIATE_PEREXPERT_SCALE(half, __nv_fp8_e4m3);
#endif

#if defined(ENABLE_BF16)
INSTANTIATE_PEREXPERT_SCALE(__nv_bfloat16, __nv_bfloat16);
#if defined(ENABLE_FP8)
INSTANTIATE_PEREXPERT_SCALE(__nv_bfloat16, __nv_fp8_e4m3);
#endif
#endif

} // namespace kernels
} // namespace tensorrt_llm
9 changes: 2 additions & 7 deletions cpp/tensorrt_llm/kernels/preQuantScaleKernel.h
@@ -36,13 +36,8 @@ namespace kernels
{

template <typename T_in, typename T_out = T_in>
void apply_per_channel_scale_kernel_launcher(
T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows, int cols, cudaStream_t stream = 0);

template <typename T_in, typename T_out = T_in>
void apply_per_expert_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, T_in const* per_expert_scale,
int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, int rows, int cols,
cudaStream_t stream = 0);
void apply_per_channel_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale,
int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0);

} // namespace kernels
} // namespace tensorrt_llm
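For context, a hedged usage sketch of the updated launcher declared above; the sizes, buffer contents, and include path are assumptions for illustration, not code from this change:

```cpp
// Illustrative caller; error checking omitted.
#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "tensorrt_llm/kernels/preQuantScaleKernel.h"

void smoothActivations(cudaStream_t stream)
{
    int const rows = 512;   // padded token count
    int const cols = 4096;  // hidden dimension

    half *d_act, *d_smoothed, *d_scale;
    std::int64_t* d_num_valid_tokens; // device scalar: rows at or beyond this count are skipped
    cudaMalloc(&d_act, sizeof(half) * rows * cols);
    cudaMalloc(&d_smoothed, sizeof(half) * rows * cols);
    cudaMalloc(&d_scale, sizeof(half) * cols);
    cudaMalloc(&d_num_valid_tokens, sizeof(std::int64_t));

    // ... fill d_act, d_scale, and d_num_valid_tokens ...

    tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<half, half>(
        d_smoothed, d_act, d_scale, rows, cols,
        /*num_valid_tokens_ptr=*/d_num_valid_tokens, stream);
    // Passing nullptr (the default) preserves the previous behavior of scaling every row.
}
```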
@@ -394,13 +394,13 @@ void pre_quant_scale_for_act(int const m, int const k, int const mQuantAlgo, int
{
tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<ActType, __nv_fp8_e4m3>(
reinterpret_cast<__nv_fp8_e4m3*>(workspace), reinterpret_cast<ActType const*>(inputs[0]),
reinterpret_cast<ActType const*>(inputs[mPreQuantScaleInputIdx]), m, k, stream);
reinterpret_cast<ActType const*>(inputs[mPreQuantScaleInputIdx]), m, k, nullptr, stream);
}
else
{
tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<ActType, ActType>(
reinterpret_cast<ActType*>(workspace), reinterpret_cast<ActType const*>(inputs[0]),
reinterpret_cast<ActType const*>(inputs[mPreQuantScaleInputIdx]), m, k, stream);
reinterpret_cast<ActType const*>(inputs[mPreQuantScaleInputIdx]), m, k, nullptr, stream);
}
}

@@ -220,7 +220,7 @@ void exec_cutlass_kernel(
{
tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<AType, AType>(
reinterpret_cast<AType*>(scaled_act), reinterpret_cast<AType const*>(params.act),
reinterpret_cast<AType const*>(params.act_scale), params.m, params.k, stream);
reinterpret_cast<AType const*>(params.act_scale), params.m, params.k, nullptr, stream);
act = scaled_act;
}
if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY)
22 changes: 11 additions & 11 deletions docs/source/architecture/core-concepts.md
@@ -4,24 +4,24 @@

TensorRT-LLM has a Model Definition API that can be used to define
Large Language Models. This API is built on top of the powerful
[TensorRT Python API](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/index.html#)
[TensorRT Python API](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/index.html)
to create graph representations of deep neural networks in TensorRT. To become
familiar with the core concepts of the TensorRT API, refer to the
[Core Concepts](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/coreConcepts.html)
[Core Concepts](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/coreConcepts.html)
section of the TensorRT documentation before proceeding further.

In TensorRT-LLM, the [`tensorrt_llm.Builder`](source:tensorrt_llm/builder.py) class
contains a
[`tensorrt.Builder`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Core/Builder.html#tensorrt.Builder)
[`tensorrt.Builder`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Core/Builder.html#id1)
object. That instance is used in the `tensorrt_llm.Builder.create_network`
method to create an instance of the
[`tensorrt.INetworkDefinition`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/Network.html#tensorrt.INetworkDefinition)
[`tensorrt.INetworkDefinition`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Graph/Network.html#tensorrt.INetworkDefinition)
class. The `INetworkDefinition` object can then be populated using the free
functions defined in the
[`tensorrt_llm.functional`](source:tensorrt_llm/functional.py).

A simple example of such a free function is `tensorrt_llm.activation` that inserts a
[`tensorrt.IActivationLayer`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/Layers.html#tensorrt.IActivationLayer)
[`tensorrt.IActivationLayer`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Graph/Layers.html#tensorrt.IActivationLayer)
node in the graph of the model:

```python
@@ -56,23 +56,23 @@ def silu(input: Tensor) -> Tensor:
When the TensorRT-LLM's Model Definition API is utilized, a graph of the network is
assembled. The graph can later be traversed or transformed using the graph
traversal API exposed by the
[`tensorrt.ILayer`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/LayerBase.html#tensorrt.ILayer)
[`tensorrt.ILayer`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Graph/LayerBase.html#tensorrt.ILayer)
class. That graph will also be optimized by TensorRT during the compilation of
the engine, as explained in the next section.
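For instance, a hedged sketch of walking an already populated `tensorrt.INetworkDefinition` with that traversal API; the `print_layers` helper is illustrative, not part of TensorRT-LLM:

```python
import tensorrt as trt


def print_layers(network: trt.INetworkDefinition) -> None:
    """List every layer in the graph with its type and output shapes."""
    for i in range(network.num_layers):
        layer = network.get_layer(i)  # a tensorrt.ILayer
        shapes = [layer.get_output(j).shape for j in range(layer.num_outputs)]
        print(f"{i:3d} {layer.name} ({layer.type}) -> {shapes}")
```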

# Compilation

Once populated, the instance of the
[`tensorrt.INetworkDefinition`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/Network.html#tensorrt.INetworkDefinition),
[`tensorrt.INetworkDefinition`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Graph/Network.html#tensorrt.INetworkDefinition)
can be compiled into an efficient engine by the
[`tensorrt.Builder`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Core/Builder.html#tensorrt.Builder)
[`tensorrt.Builder`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Core/Builder.html#id1).
In TensorRT-LLM, it is done through the `build_engine` member function of the
`tensorrt_llm.Builder` class that calls the
[`build_serialized_network`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Core/Builder.html#tensorrt.Builder.build_serialized_network)
[`build_serialized_network`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Core/Builder.html#tensorrt.Builder.build_serialized_network)
method of the
[`tensorrt.Builder`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Core/Builder.html#tensorrt.Builder)
[`tensorrt.Builder`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Core/Builder.html#id1)
object. That call, if everything works as expected, produces an instance of the
[`tensorrt.IHostMemory`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/FoundationalTypes/HostMemory.html#tensorrt.IHostMemory)
[`tensorrt.IHostMemory`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/FoundationalTypes/HostMemory.html#tensorrt.IHostMemory)
class. That object is an optimized TensorRT engine that can be stored as a
binary file.
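As a hedged end-to-end sketch of that flow using the underlying TensorRT calls (network population is elided and the file name is an arbitrary choice):

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network()        # tensorrt.INetworkDefinition
config = builder.create_builder_config()  # precision, memory budget, ...

# ... populate `network` (this is what the Model Definition API automates) ...

serialized = builder.build_serialized_network(network, config)  # tensorrt.IHostMemory
with open("llm.engine", "wb") as f:
    f.write(serialized)  # the optimized engine stored as a binary file
```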
