@@ -403,7 +403,7 @@ def __init__(self,
403403 strategy : AllReduceStrategy = AllReduceStrategy .AUTO ,
404404 dtype : Optional [torch .dtype ] = None ):
405405 super ().__init__ ()
406- '''
406+ """
407407 AllReduce is a module that performs an all-reduce operation on a tensor.
408408
409409 Args:
@@ -440,7 +440,7 @@ def __init__(self,
440440 https://github.com/NVIDIA/TensorRT-LLM/blob/main/tests/unittest/_torch/multi_gpu/test_allreduce.py
441441
442442 The LOWPRECISION strategy can be selected by directly specifying it in the constructor.
443- '''
443+ """
444444
445445 self .mapping = mapping
446446 self .workspace = None
@@ -486,7 +486,7 @@ def forward(
486486 * ,
487487 all_reduce_params : Optional [AllReduceParams ] = None ,
488488 ) -> Union [torch .Tensor , Tuple [torch .Tensor , ...]]:
489- """
489+ '''
490490 The input tensors in the different ranks must have the same shape.
491491 The output tensor will have the same shape as the input tensor.
492492 The output tensor will be replicated among the TP group.
@@ -508,7 +508,7 @@ def forward(
508508 RESIDUAL_RMS_NORM_OUT_QUANT_FP8: [norm, norm_quant, residual]
509509 RESIDUAL_RMS_NORM_QUANT_NVFP4: [norm_quant_fp4, scale_factor, residual]
510510 RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: [norm, norm_quant_fp4, scale_factor, residual]
511- """
511+ '''
512512 if self .mapping .tp_size == 1 or (all_reduce_params is not None
513513 and all_reduce_params .enable_allreduce
514514 == False ):
0 commit comments