diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu index 6f85317ae77..2176ba759f4 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu @@ -27,6 +27,10 @@ namespace tensorrt_llm::kernels::mnnvl { + +// Guard for internal helper functions +namespace +{ __device__ bool isNegZero(float v) { return v == 0.f && signbit(v); @@ -49,6 +53,12 @@ inline __device__ float toFloat<__nv_bfloat16>(__nv_bfloat16 val) return __bfloat162float(val); } +template <> +inline __device__ float toFloat<__nv_half>(__nv_half val) +{ + return __half2float(val); +} + template inline __device__ T fromFloat(float val) { @@ -61,30 +71,76 @@ inline __device__ __nv_bfloat16 fromFloat<__nv_bfloat16>(float val) return __float2bfloat16(val); } -__device__ float4 loadfloat4(void const* ptr) +template <> +inline __device__ __nv_half fromFloat<__nv_half>(float val) { + return __float2half(val); +} - float return_value[4]; - - asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n" - : "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), "=f"(return_value[3]) - : "l"(ptr)); - - return *(float4*) return_value; +inline __device__ float2 loadfloat2(void const* ptr) +{ + float2 return_value; + asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" : "=f"(return_value.x), "=f"(return_value.y) : "l"(ptr)); + return return_value; } -__device__ __inline__ float2 loadfloat2(void const* ptr) +template +inline __device__ T divUp(T val, T divisor) { + return (val + divisor - 1) / divisor; +} - float return_value[2]; +__device__ struct __attribute__((aligned(32))) LamportFlags +{ + uint32_t buffer_size; + uint32_t input_offset; + uint32_t clear_offset; + uint32_t num_tokens_prev; + uint32_t* offset_access_ptr; + uint32_t* buffer_flags; + + __device__ explicit LamportFlags(uint32_t* buffer_flags) + : offset_access_ptr(&buffer_flags[4]) + , buffer_flags(buffer_flags) + { + uint4 flag = reinterpret_cast(buffer_flags)[0]; + buffer_size = flag.z; + input_offset = flag.x * (buffer_size << 1U); + clear_offset = flag.y * (buffer_size << 1U); + num_tokens_prev = flag.w; + } - asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" - : "=f"(return_value[0]), "=f"(return_value[1]) - : "l"(ptr) - : "memory"); + __device__ void cta_arrive() + { + __syncthreads(); + if (threadIdx.x == 0) + { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) + asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory"); +#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory"); +#else + atomicAdd(offset_access_ptr, 1); +#endif + } + } - return *(float2*) return_value; -} + __device__ void wait_and_update(uint32_t num_tokens) + { + if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0) + { + while (*reinterpret_cast(offset_access_ptr) < gridDim.x * gridDim.y) + { + } + uint4 flag = reinterpret_cast(buffer_flags)[0]; + buffer_flags[0] = (flag.x + 1) % 3; + buffer_flags[1] = (flag.y + 1) % 3; + buffer_flags[3] = num_tokens; + *(offset_access_ptr) = 0; + } + } +}; +} // namespace template __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int num_tokens, @@ -99,13 +155,14 @@ __global__ void 
twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ cudaGridDependencySynchronize(); #endif - // [input_ptr, clear_ptr, buffer_size, access_counter] - uint4 flag = reinterpret_cast(buffer_flags)[0]; - // Each buffer is M * N and we have 2 buffers in each group, one for reduce-scatter and one for allgather - uint32_t buffer_group_size = flag.z << 1; - uint32_t input_offset = flag.x * buffer_group_size; - uint32_t clear_offset = flag.y * buffer_group_size; - uint32_t* offset_access_ptr = &buffer_flags[3]; + LamportFlags flags(buffer_flags); + + // Capture the number of tokens in previous iteration so that we can properly clear the buffer + // The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up + uint32_t clr_toks_cta + = divUp(flags.num_tokens_prev > num_tokens ? flags.num_tokens_prev : num_tokens, WORLD_SIZE) + * WORLD_SIZE; + clr_toks_cta = divUp(clr_toks_cta, gridDim.x); if (elt < token_dim) { @@ -115,29 +172,33 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ T val = shard_ptr[token * token_dim + elt]; if (isNegZero(val)) val = fromFloat(0.f); - input_ptrs[dest_rank][input_offset + dest_token_offset * token_dim * WORLD_SIZE + rank * token_dim + elt] = val; + input_ptrs[dest_rank][flags.input_offset + dest_token_offset * token_dim * WORLD_SIZE + rank * token_dim + elt] + = val; - // Reduce and broadcast + // Clear the buffer used by the previous call. Note the number of tokens to clear could be larger than the + // number of tokens in the current call. + for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) + { + uint32_t clr_token_idx = token + clr_tok * gridDim.x; + if (clr_token_idx < buffer_M) + { + input_ptrs[rank][flags.clear_offset + clr_token_idx * token_dim + elt] = fromFloat(-0.f); + } + } + // Reduce and broadcast if ((token % WORLD_SIZE) == rank) { int local_token = token / WORLD_SIZE; float accum = 0.f; T values[WORLD_SIZE]; - - for (int r = 0; r < WORLD_SIZE; r++) - { - input_ptrs[rank][clear_offset + local_token * token_dim * WORLD_SIZE + r * token_dim + elt] - = fromFloat(-0.f); - } - while (1) { bool valid = true; for (int r = 0; r < WORLD_SIZE; r++) { - T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][input_offset + T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][flags.input_offset + local_token * token_dim * WORLD_SIZE + r * token_dim + elt]; values[r] = *lamport_ptr; valid &= !isNegZero(values[r]); @@ -149,7 +210,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ { accum += toFloat(values[r]); } - mcast_ptr[input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat(accum); + mcast_ptr[flags.input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat(accum); } } @@ -157,24 +218,23 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ cudaTriggerProgrammaticLaunchCompletion(); #endif - input_ptrs[rank][clear_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat(-0.f); + // Similarly clear broadcast buffer here + for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) + { + uint32_t clr_token_idx = token + clr_tok * gridDim.x; + if (clr_token_idx < buffer_M) + { + input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt] + = fromFloat(-0.f); + } + } // Optionally wait for results if the next layer isn't doing the Lamport check if (wait_for_results) { // Update the atomic counter to indicate the block has read the offsets 
- __syncthreads(); + flags.cta_arrive(); - if (threadIdx.x == 0) - { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) - asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory"); -#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory"); -#else - atomicAdd(offset_access_ptr, 1); -#endif - } // Only use a set of CTAs for lamport sync, reargange the grid constexpr int ELTS_PER_LOAD = sizeof(float2) / sizeof(T); // blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32) @@ -182,7 +242,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ { uint64_t current_pos = blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD; - void* lamport_ptr = (void*) &input_ptrs[rank][input_offset + buffer_M * token_dim + current_pos]; + void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos]; // We have 2 assumptions here: // 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B // 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32) @@ -198,16 +258,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ } // Update the buffer flags - if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0) - { - // Make sure all blocks have finished reading the offsets, 2-D grid - while (*reinterpret_cast(offset_access_ptr) < gridDim.x * gridDim.y) - { - } - buffer_flags[0] = (flag.x + 1) % 3; - buffer_flags[1] = (flag.y + 1) % 3; - *(offset_access_ptr) = 0; - } + flags.wait_and_update(num_tokens); } } @@ -273,12 +324,28 @@ void twoshot_allreduce_op(AllReduceParams const& params) default: TLLM_CHECK_WITH_INFO(false, "TwoShot AllReduce]: unsupported world_size."); } } + else if (dtype == nvinfer1::DataType::kHALF) + { + switch (world_size) + { + case 2: LAUNCH_ALL_REDUCE_KERNEL(2, __nv_half); break; + case 4: LAUNCH_ALL_REDUCE_KERNEL(4, __nv_half); break; + case 8: LAUNCH_ALL_REDUCE_KERNEL(8, __nv_half); break; + case 16: LAUNCH_ALL_REDUCE_KERNEL(16, __nv_half); break; + case 32: LAUNCH_ALL_REDUCE_KERNEL(32, __nv_half); break; + case 64: LAUNCH_ALL_REDUCE_KERNEL(64, __nv_half); break; + default: TLLM_CHECK_WITH_INFO(false, "TwoShot AllReduce]: unsupported world_size."); + } + } else { TLLM_CHECK_WITH_INFO(false, "TwoShot AllReduce]: unsupported dtype."); } } +// Guard for internal helper functions +namespace +{ template __device__ void copy_f4(T_IN* dst, T_IN const* src) { @@ -327,6 +394,19 @@ inline __device__ float block_reduce_sum(float val) return val; } +__device__ float4 loadfloat4(void const* ptr) +{ + + float4 return_value; + + asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n" + : "=f"(return_value.x), "=f"(return_value.y), "=f"(return_value.z), "=f"(return_value.w) + : "l"(ptr)); + + return return_value; +} +} // namespace + template __global__ void __launch_bounds__(128, 1) RMSNorm(T_IN* input_plus_residual, T_OUT* output_norm, T_IN const* buffer_input, T_IN const* gamma, float epsilon, @@ -353,12 +433,8 @@ __global__ void __launch_bounds__(128, 1) int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)]; - uint32_t* offset_access_ptr = &buffer_flags[3]; - uint4 flag = reinterpret_cast(buffer_flags)[0]; - // Buffer size is M * N, and we need two buffers for reduce-scatter and allgather - uint32_t 
buffer_size = flag.z; - uint32_t buffer_offset = flag.x * (buffer_size << 1); - T_IN const* input = &buffer_input[buffer_offset + buffer_size]; + LamportFlags flags(buffer_flags); + T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size]; cudaTriggerProgrammaticLaunchCompletion(); @@ -388,17 +464,7 @@ __global__ void __launch_bounds__(128, 1) } __pipeline_commit(); - __syncthreads(); - if (threadIdx.x == 0) - { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) - asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory"); -#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory"); -#else - atomicAdd(offset_access_ptr, 1); -#endif - } + flags.cta_arrive(); // Load all inputs bool valid = false; @@ -528,16 +594,7 @@ __global__ void __launch_bounds__(128, 1) = out4; } // Update the buffer pointers - if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0) - { - // Make sure all blocks have finished accessing the buffer - while (*reinterpret_cast(offset_access_ptr) < gridDim.x * gridDim.y) - { - } - buffer_flags[0] = (flag.x + 1) % 3; - buffer_flags[1] = (flag.y + 1) % 3; - *(offset_access_ptr) = 0; - } + flags.wait_and_update(batch_size); #endif } @@ -548,8 +605,6 @@ void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T cons // input to rmsnorm is the buffer in the twoshot ar // We should use prenorm output to determine the actual used size - // int batch = normed_output.sizes()[0]; - // int dim = normed_output.sizes()[1]; float _epsilon{static_cast(epsilon)}; static constexpr int NUM_THREADS = 128; @@ -612,6 +667,20 @@ void twoshot_rmsnorm_op(RMSNormParams const& params) default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim."); } } + else if (dtype == nvinfer1::DataType::kHALF) + { + switch (params.hidden_dim) + { + case 2048: LAUNCH_RMSNORM_KERNEL(__nv_half, 2048); break; + case 4096: LAUNCH_RMSNORM_KERNEL(__nv_half, 4096); break; + // Llama-4 Hidden Dimension + case 5120: LAUNCH_RMSNORM_KERNEL(__nv_half, 5120); break; + // DeepSeek Hidden Dimension + case 7168: LAUNCH_RMSNORM_KERNEL(__nv_half, 7168); break; + case 8192: LAUNCH_RMSNORM_KERNEL(__nv_half, 8192); break; + default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim."); + } + } else { TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported dtype."); diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index 83fbf5f91ef..ba713a7d566 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -88,8 +88,8 @@ def get_allreduce_mnnvl_workspace( # This is a buffer to maintain the state of this allreduce Op # Should have the same lifetime with self._buffer - # [Buffer_ptr, Clear_ptr, Buffer_size, atomic access counter] - buffer_flags = torch.tensor([0, 2, max_num_elements, 0], + # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_to_clear,atomic access counter] + buffer_flags = torch.tensor([0, 2, max_num_elements, 0, 0], dtype=torch.uint32, device=torch.device("cuda", mapping.local_rank)) @@ -305,7 +305,7 @@ def __init__(self, mapping: Mapping, dtype: torch.dtype): @staticmethod def get_supported_dtypes(): - return (torch.bfloat16, torch.float32) + return (torch.float16, torch.bfloat16, torch.float32) def forward( self, @@ -458,6 +458,7 @@ def forward( == False): return input + allreduce_strategy = 
self.strategy if all_reduce_params is None: all_reduce_params = AllReduceParams() @@ -469,6 +470,9 @@ def forward( return mnnvl_output # Fall back to regular AllReduce if MNNVL is not available or not applicable + # Make sure the strategy is AUTO since allreduceOp does not have the branch for MNNVL + if allreduce_strategy == AllReduceStrategy.MNNVL: + allreduce_strategy = AllReduceStrategy.AUTO output = torch.ops.trtllm.allreduce( input=input, residual=all_reduce_params.residual, @@ -477,7 +481,7 @@ def forward( bias=all_reduce_params.bias, workspace=self.workspace, group=self.mapping.tp_group, - strategy=self.strategy, + strategy=allreduce_strategy, op=all_reduce_params.fusion_op, eps=all_reduce_params.eps, trigger_completion_at_end=all_reduce_params. diff --git a/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py b/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py index 595ff09d12e..e3d00f4683c 100644 --- a/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py +++ b/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py @@ -47,21 +47,21 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6): def run_single_rank( tensor_parallel_size, single_rank_forward_func, - input, - residual, + input_list, + residual_list, norm_weight, eps, hidden_size, dtype, fused_add_norm, - reference_output, + reference_output_list, ): rank = tensorrt_llm.mpi_rank() torch.cuda.set_device(rank) try: single_rank_forward_func( - input, - residual, + input_list, + residual_list, norm_weight, eps, hidden_size, @@ -69,7 +69,7 @@ def run_single_rank( tensor_parallel_size, rank, fused_add_norm, - reference_output, + reference_output_list, ) except Exception: traceback.print_exc() @@ -79,8 +79,8 @@ def run_single_rank( @torch.inference_mode() def row_linear_residual_norm_fusion_forward( - x: torch.Tensor, - residual: torch.Tensor, + x_list: list[torch.Tensor], + residual_list: list[torch.Tensor], norm_weight: torch.Tensor, eps: float, hidden_size: int, @@ -88,16 +88,21 @@ def row_linear_residual_norm_fusion_forward( tensor_parallel_size: int, tensor_parallel_rank: int, fusion: bool, - reference_output: tuple[torch.Tensor, ...], + reference_output_list: list[tuple[torch.Tensor, ...]], ): - x = x.cuda() - residual = residual.cuda() + # Move all tensors to GPU + x_list = [x.cuda() for x in x_list] + residual_list = [residual.cuda() for residual in residual_list] norm_weight = norm_weight.cuda() - reference_output = tuple(t.cuda() for t in reference_output) + reference_output_list = [ + tuple(t.cuda() for t in ref_output) + for ref_output in reference_output_list + ] MPI.COMM_WORLD.barrier() + # Create a single AllReduce instance to be reused for all sequence lengths allreduce = AllReduce( mapping=Mapping( world_size=tensor_parallel_size, @@ -119,72 +124,106 @@ def func(input, residual, norm_weight, eps, enable_fusion): residual=residual, norm_weight=norm_weight, eps=eps, - )) + ), + ) return (output, residual) else: output = allreduce(input) return (output, ) - output = func(x.clone(), residual.clone(), norm_weight, eps, fusion) + # Process each sequence length using the same AllReduce instance + for i, (x, residual, reference_output) in enumerate( + zip(x_list, residual_list, reference_output_list)): + output = func(x.clone(), residual.clone(), norm_weight, eps, fusion) - torch.testing.assert_close( - output[0], - reference_output[0], - rtol=0.05, - atol=0.15, - ) - - if fusion: torch.testing.assert_close( - output[1], - reference_output[1], + output[0], + reference_output[0], 
rtol=0.05, atol=0.15, ) + if fusion: + torch.testing.assert_close( + output[1], + reference_output[1], + rtol=0.05, + atol=0.15, + ) + @skip_pre_blackwell @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="needs 2 GPUs to run this test") -@pytest.mark.parametrize("seq_len", [1, 4, 32, 128], - ids=lambda x: f"seqlen:{x}") +@pytest.mark.parametrize( + "seq_len", + [ + [1], + [4], + [15], + [32], + [128], + [31, 11, 27, 4], + ], + ids=lambda x: f"seqlen:{x}", +) @pytest.mark.parametrize("hidden_size", [7168], ids=lambda x: f"hidden:{x}") +@pytest.mark.parametrize("dtype", + [torch.float16, torch.bfloat16, torch.float32], + ids=lambda x: f"dtype:{torch.finfo(x).dtype}") @pytest.mark.parametrize( "fusion", [True, False], ids=["fusion", "no_fusion"], ) -def test_row_linear_residual_norm_fusion(seq_len, hidden_size, fusion): +def test_row_linear_residual_norm_fusion(seq_len, hidden_size, dtype, fusion): torch.manual_seed(42) - dtype = torch.bfloat16 tensor_parallel_size = 2 - x = torch.randn((tensor_parallel_size, seq_len, hidden_size), dtype=dtype) - residual = torch.randn((seq_len, hidden_size), dtype=dtype) + # Create norm_weight once (same for all sequence lengths) norm_weight = torch.randn((hidden_size, ), dtype=dtype) eps = 1e-5 - reference_output = (torch.sum(x, dim=0), ) - if fusion: - residual_out = reference_output[0] + residual - reference_output = (rms_norm(residual_out.to(torch.float32), - norm_weight, eps).to(dtype), residual_out) + + # Create lists of tensors for each sequence length + x_list = [] + residual_list = [] + reference_output_list = [] + + for seq_len_val in seq_len: + x = torch.randn((tensor_parallel_size, seq_len_val, hidden_size), + dtype=dtype) + residual = torch.randn((seq_len_val, hidden_size), dtype=dtype) + reference_output = (torch.sum(x, dim=0), ) + if fusion: + residual_out = reference_output[0] + residual + reference_output = (rms_norm(residual_out.to(torch.float32), + norm_weight, + eps).to(dtype), residual_out) + + x_list.append(x) + residual_list.append(residual) + reference_output_list.append(reference_output) with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor: results = executor.map( run_single_rank, - *zip(*[( - tensor_parallel_size, - row_linear_residual_norm_fusion_forward, - x[i, :, :], - residual, - norm_weight, - eps, - hidden_size, - dtype, - fusion, - reference_output, - ) for i in range(tensor_parallel_size)]), + *zip(*[ + ( + tensor_parallel_size, + row_linear_residual_norm_fusion_forward, + [ + x[i, :, :] for x in x_list + ], # Extract the i-th rank's data from each sequence length + residual_list, + norm_weight, + eps, + hidden_size, + dtype, + fusion, + reference_output_list, + ) for i in range(tensor_parallel_size) + ]), ) for r in results: assert r is True
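
The LamportFlags helper introduced above, together with the five-element buffer_flags tensor initialized in get_allreduce_mnnvl_workspace ([input_idx, clear_idx, buffer_size, num_tokens_prev, access_counter] = [0, 2, max_num_elements, 0, 0]), implements a three-way buffer rotation: each call writes into group flags[0], clears group flags[1], and wait_and_update advances both indices mod 3 while recording num_tokens so the next call knows how many stale tokens to wipe. A minimal host-side sketch of that bookkeeping, for illustration only (FlagsView, read_flags, and update_flags are names invented here, not part of the patch):

    #include <cstdint>
    #include <cstdio>

    // Offsets derived as in the LamportFlags constructor: each of the three
    // buffer groups spans 2 * buffer_size elements (scatter half + broadcast
    // half), so group_index * (buffer_size << 1) selects a group.
    struct FlagsView
    {
        uint32_t input_offset;
        uint32_t clear_offset;
        uint32_t num_tokens_prev;
    };

    FlagsView read_flags(uint32_t const flags[5])
    {
        uint32_t group_size = flags[2] << 1U;
        return {flags[0] * group_size, flags[1] * group_size, flags[3]};
    }

    // Mirrors LamportFlags::wait_and_update(): rotate both indices mod 3,
    // record the token count for the next call's clearing pass, and reset the
    // per-grid arrival counter.
    void update_flags(uint32_t flags[5], uint32_t num_tokens)
    {
        flags[0] = (flags[0] + 1) % 3;
        flags[1] = (flags[1] + 1) % 3;
        flags[3] = num_tokens;
        flags[4] = 0;
    }

    int main()
    {
        // Initial state from get_allreduce_mnnvl_workspace (buffer_size value is illustrative).
        uint32_t flags[5] = {0, 2, 4096, 0, 0};
        for (uint32_t call = 0; call < 4; ++call)
        {
            FlagsView v = read_flags(flags);
            std::printf("call %u: input group %u (offset %u), clear group %u (offset %u), prev tokens %u\n",
                call, flags[0], v.input_offset, flags[1], v.clear_offset, v.num_tokens_prev);
            update_flags(flags, /*num_tokens=*/call + 1);
        }
        return 0;
    }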
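
The clr_toks_cta computation at the top of twoshot_allreduce_kernel decides how many stale tokens each CTA wipes: take the larger of num_tokens_prev and num_tokens, round it up to a multiple of WORLD_SIZE (the scatter stage addresses the buffer in WORLD_SIZE-token groups), then split that count across gridDim.x. A worked example under the assumption, made here only for illustration, that the launch uses one token per blockIdx.x so gridDim.x == num_tokens:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    uint32_t div_up(uint32_t val, uint32_t divisor)
    {
        return (val + divisor - 1) / divisor;
    }

    int main()
    {
        uint32_t const world_size = 8;
        uint32_t const num_tokens_prev = 31; // written by the previous call (buffer_flags[3])
        uint32_t const num_tokens = 4;       // tokens in the current call; assumed == gridDim.x here
        // Round the larger token count up to a multiple of WORLD_SIZE ...
        uint32_t clr_toks = div_up(std::max(num_tokens_prev, num_tokens), world_size) * world_size;
        // ... then divide the clearing work across the token-indexed CTAs.
        uint32_t clr_toks_cta = div_up(clr_toks, num_tokens);
        std::printf("tokens to clear: %u, iterations per CTA: %u\n", clr_toks, clr_toks_cta);
        // Prints: tokens to clear: 32, iterations per CTA: 8
        return 0;
    }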
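
isNegZero and the fromFloat(-0.f) clears rely on -0.0f being an impossible payload value: producers flush a genuine -0.0f input to +0.0f before storing, so a slot that still reads as -0.0f means "not written yet", which is exactly what the polling loop over the WORLD_SIZE values and the wait_for_results float2 check test for. A small standalone illustration (host code, not part of the patch):

    #include <cmath>
    #include <cstdio>

    bool is_neg_zero(float v)
    {
        return v == 0.f && std::signbit(v);
    }

    int main()
    {
        float slot = -0.f; // the buffer is cleared to the sentinel value
        std::printf("freshly cleared slot is sentinel: %d\n", is_neg_zero(slot)); // 1

        float incoming = -0.f; // a real -0.0 input would look like the sentinel ...
        if (is_neg_zero(incoming))
        {
            incoming = 0.f; // ... so the producer flushes it to +0.0 first
        }
        slot = incoming; // producer's store

        std::printf("slot after store is sentinel: %d, value: %g\n", is_neg_zero(slot), slot); // 0, 0
        return 0;
    }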
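
With the __nv_half paths added above, the 8-byte Lamport check in the wait_for_results branch behaves the same as for __nv_bfloat16: ELTS_PER_LOAD = sizeof(float2) / sizeof(T) is 4 for both 16-bit types and 2 for float, so the existing assumption that num_tokens * token_dim is divisible by ELTS_PER_LOAD carries over unchanged to FP16. A quick host-side check of the arithmetic, assuming the CUDA headers are on the include path (2 * sizeof(float) stands in for sizeof(float2), and __half is the type the patch spells __nv_half):

    #include <cstdio>
    #include <cuda_bf16.h>
    #include <cuda_fp16.h>

    // One ld.volatile.global.v2.f32 consumes 2 * sizeof(float) == 8 bytes.
    template <typename T>
    constexpr int eltsPerLoad()
    {
        return static_cast<int>((2 * sizeof(float)) / sizeof(T));
    }

    int main()
    {
        std::printf("half: %d, bfloat16: %d, float: %d\n",
            eltsPerLoad<__half>(), eltsPerLoad<__nv_bfloat16>(), eltsPerLoad<float>());
        // Prints: half: 4, bfloat16: 4, float: 2
        return 0;
    }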