diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py index 916ec2b83df4..9a51e6b3514f 100644 --- a/tests/compile/test_async_tp.py +++ b/tests/compile/test_async_tp.py @@ -22,6 +22,8 @@ multi_gpu_test) from .backend import TestBackend +FP8_DTYPE = current_platform.fp8_dtype() + prompts = [ "Hello, my name is", "The president of the United States is", @@ -32,9 +34,10 @@ class TestMMRSModel(torch.nn.Module): - def __init__(self, hidden_size=16): + def __init__(self, hidden_size=16, dtype=torch.float16): super().__init__() self.hidden_size = hidden_size + self.dtype = dtype self.gate_proj = torch.nn.Parameter(torch.empty( (self.hidden_size * 2, hidden_size)), requires_grad=False) @@ -64,9 +67,10 @@ def ops_in_model_after(self): class TestAGMMModel(torch.nn.Module): - def __init__(self, hidden_size=16): + def __init__(self, hidden_size=16, dtype=torch.float16): super().__init__() self.hidden_size = hidden_size + self.dtype = dtype self.weight = torch.nn.Parameter(torch.empty( (hidden_size, hidden_size)), requires_grad=False) @@ -91,8 +95,125 @@ def ops_in_model_after(self): return [torch.ops.symm_mem.fused_all_gather_matmul.default] +class _BaseScaledMMModel(torch.nn.Module): + + def __init__(self, hidden_size=16, dtype=torch.float16): + super().__init__() + self.hidden_size = hidden_size + self.dtype = dtype + self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\ + .contiguous().transpose(0, 1) + + # Initialize scale_b for _scaled_mm. + self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32) + + +class TestScaledMMRSModel(_BaseScaledMMModel): + + def forward(self, input: torch.Tensor): + """ + Forward pass implementing the scaled_mm + reduce scatter in the FX graph + + """ + fp8_input = input.to(FP8_DTYPE) + scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32) + scaled_mm = torch._scaled_mm(fp8_input, + self.weight, + scale_a=scale_a, + scale_b=self.scale_b, + out_dtype=self.dtype) + reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0) + return reduce_scatter + + def ops_in_model_before(self): + return [torch.ops.vllm.reduce_scatter.default] + + def ops_in_model_after(self): + return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default] + + +class TestAGScaledMMModel(_BaseScaledMMModel): + + def forward(self, input: torch.Tensor): + """ + Forward pass implementing the all gather + scaled_mm in the FX graph + """ + # Reshape input + fp8_input = input.to(FP8_DTYPE) + all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0) + + scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32) + scaled_mm = torch._scaled_mm(all_gather, + self.weight, + scale_a=scale_a, + scale_b=self.scale_b, + out_dtype=self.dtype) + return scaled_mm + + def ops_in_model_before(self): + return [torch.ops.vllm.all_gather.default] + + def ops_in_model_after(self): + return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default] + + +class TestCutlassScaledMMRSModel(_BaseScaledMMModel): + + def forward(self, input: torch.Tensor): + """ + Forward pass implementing the cutlass_scaled_mm + reduce scatter + in the FX graph + + """ + fp8_input = input.to(FP8_DTYPE) + scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32) + mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]), + dtype=self.dtype, + device=input.device) + torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a, + self.scale_b, None) + reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0) + return 
reduce_scatter + + def ops_in_model_before(self): + return [torch.ops.vllm.reduce_scatter.default] + + def ops_in_model_after(self): + return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default] + + +class TestAGCutlassScaledMMModel(_BaseScaledMMModel): + + def forward(self, input: torch.Tensor): + """ + Forward pass implementing the all gather + cutlass_scaled_mm + in the FX graph + """ + # Reshape input + fp8_input = input.to(FP8_DTYPE) + all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0) + + scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32) + + mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]), + dtype=self.dtype, + device=all_gather.device) + torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight, + scale_a, self.scale_b, None) + return mm_out + + def ops_in_model_before(self): + return [torch.ops.vllm.all_gather.default] + + def ops_in_model_after(self): + return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default] + + @multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("test_model", [TestMMRSModel, TestAGMMModel]) +@pytest.mark.parametrize("test_model", [ + TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel, + TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel +]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("hidden_size", [16]) @@ -101,6 +222,14 @@ def ops_in_model_after(self): reason="Only test on CUDA") def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype): + if test_model in (TestScaledMMRSModel, TestAGScaledMMModel, + TestCutlassScaledMMRSModel, + TestAGCutlassScaledMMModel) and dtype == torch.float16: + pytest.skip( + "Only bf16 high precision output types are supported for " \ + "per-token (row-wise) scaling" + ) + num_processes = 2 def run_torch_spawn(fn, nprocs): @@ -155,7 +284,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, async_tp_pass = AsyncTPPass(vllm_config) backend = TestBackend(async_tp_pass) - model = test_model_cls(hidden_size) + model = test_model_cls(hidden_size, + dtype) # Pass dtype to model constructor hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype, @@ -174,7 +304,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, @create_new_process_for_each_test() -@pytest.mark.parametrize("model_id", ["meta-llama/Llama-3.2-1B-Instruct"]) +@pytest.mark.parametrize("model_id", [ + "meta-llama/Llama-3.2-1B-Instruct", + "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8" +]) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("async_tp_enabled", [True]) @pytest.mark.parametrize("distributed_backend", ["mp"]) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 0e7961841bd3..cb99fe8310e7 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -15,10 +15,13 @@ from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op from .vllm_inductor_pass import VllmInductorPass +FP8_DTYPE = current_platform.fp8_dtype() + if find_spec("flashinfer"): try: import flashinfer.comm as flashinfer_comm @@ -28,7 +31,6 @@ flashinfer_comm = None else: flashinfer_comm = None -from vllm.platforms import current_platform logger = 
init_logger(__name__) @@ -118,6 +120,230 @@ def replacement( pm.fwd_only, pm_pass) +class ScaledMMReduceScatterPattern(BasePattern): + + def get_inputs(self): + input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) + mm_weight = torch.empty([16, 16], device=self.device, + dtype=FP8_DTYPE).contiguous().transpose(0, 1) + scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32) + scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) + return [input, mm_weight, scale_a, scale_b] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(input: torch.Tensor, mat2: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor) -> torch.Tensor: + scaled_mm = torch.ops.aten._scaled_mm.default(input, + mat2=mat2, + scale_a=scale_a, + scale_b=scale_b, + bias=None, + scale_result=None, + out_dtype=self.dtype) + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + scaled_mm, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name) + return reduce_scatter + + def replacement(input: torch.Tensor, mat2: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor) -> torch.Tensor: + gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter( + input, + mat2, + scale_a, + scale_b, + "avg", + scatter_dim=0, + out_dtype=self.dtype, + group_name=self.tp.device_group.group_name, + ) + + return gemm_rs + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AllGatherScaledMMPattern(BasePattern): + + def get_inputs(self): + x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE) + weight = torch.empty([16, 16], device=self.device, + dtype=FP8_DTYPE).contiguous().transpose(0, 1) + + s1 = x.shape[0] * self.tp_size + + scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32) + scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) + + return [x, weight, scale_a, scale_b] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + x: torch.Tensor, + weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + ) -> torch.Tensor: + all_gather = torch.ops.vllm.all_gather.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name) + + return torch.ops.aten._scaled_mm.default(all_gather, + mat2=weight, + scale_a=scale_a, + scale_b=scale_b, + bias=None, + scale_result=None, + out_dtype=self.dtype) + + def replacement(x: torch.Tensor, weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor) -> torch.Tensor: + ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa + x, + [weight], + scale_a, + [scale_b], + gather_dim=0, + biases=[None], + result_scales=[None], + out_dtypes=[self.dtype], + use_fast_accum=[False], + group_name=self.tp.device_group.group_name, + ) + return mm_outputs + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class CutlassScaledMMReduceScatterPattern(BasePattern): + + def get_inputs(self): + input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) + mm_weight = torch.empty([16, 16], device=self.device, + dtype=FP8_DTYPE).contiguous().transpose(0, 1) + scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32) + scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) + + cutlass_mm_output = torch.empty([16, 16], + device=self.device, + dtype=self.dtype) + return [input, mm_weight, scale_a, scale_b, cutlass_mm_output] + + def register(self, pm_pass: 
PatternMatcherPass): + + def pattern(input: torch.Tensor, weight: torch.Tensor, + scale_a: torch.Tensor, scale_b: torch.Tensor, + cutlass_mm_output: torch.Tensor) -> torch.Tensor: + cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.cutlass_scaled_mm.default, + out=cutlass_mm_output, + a=input, + b=weight, + a_scales=scale_a, + b_scales=scale_b, + bias=None) + + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + cutlass_scaled_mm[1], + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name) + return reduce_scatter + + def replacement(input: torch.Tensor, mat2: torch.Tensor, + scale_a: torch.Tensor, scale_b: torch.Tensor, + cutlass_mm_output: torch.Tensor) -> torch.Tensor: + gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter( + input, + mat2, + scale_a, + scale_b, + "avg", + scatter_dim=0, + out_dtype=self.dtype, + group_name=self.tp.device_group.group_name, + ) + + return gemm_rs + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AllGatherCutlassScaledMMPattern(BasePattern): + + def get_inputs(self): + x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE) + weight = torch.empty([16, 16], device=self.device, + dtype=FP8_DTYPE).contiguous().transpose(0, 1) + + s1 = x.shape[0] * self.tp_size + + scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32) + scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) + + s2 = weight.shape[1] + output = torch.empty([s1, s2], device=self.device, dtype=self.dtype) + + return [x, weight, scale_a, scale_b, output] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + x: torch.Tensor, + weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + output: torch.Tensor, + ) -> torch.Tensor: + all_gather = torch.ops.vllm.all_gather.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name) + + cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.cutlass_scaled_mm.default, + out=output, + a=all_gather, + b=weight, + a_scales=scale_a, + b_scales=scale_b, + bias=None) + return cutlass_scaled_mm[1] + + def replacement(x: torch.Tensor, weight: torch.Tensor, + scale_a: torch.Tensor, scale_b: torch.Tensor, + output: torch.Tensor) -> torch.Tensor: + ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa + x, + [weight], + scale_a, + [scale_b], + gather_dim=0, + biases=[None], + result_scales=[None], + out_dtypes=[self.dtype], + use_fast_accum=[False], + group_name=self.tp.device_group.group_name, + ) + return mm_outputs + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + class AsyncTPPass(VllmInductorPass): def __init__(self, config: VllmConfig): @@ -133,6 +359,20 @@ def __init__(self, config: VllmConfig): AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns) + # These fusions are enabled only for bfloat16 models because + # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling + # only supports bfloat16 as the output dtype. 
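+        # The matching test models in tests/compile/test_async_tp.py skip
+        # torch.float16 for the same reason.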
+ if self.model_dtype == torch.bfloat16: + ScaledMMReduceScatterPattern(self.model_dtype, + self.device).register(self.patterns) + AllGatherScaledMMPattern(self.model_dtype, + self.device).register(self.patterns) + + CutlassScaledMMReduceScatterPattern( + self.model_dtype, self.device).register(self.patterns) + AllGatherCutlassScaledMMPattern( + self.model_dtype, self.device).register(self.patterns) + def is_applicable_for_shape(self, shape: Optional[int]) -> bool: # only do replace for specific shapes tp_size = get_tensor_model_parallel_world_size() @@ -142,7 +382,7 @@ def __call__(self, graph: fx.Graph): self.begin() self.dump_graph(graph, "before_async_tp_pass") count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns", count) + logger.debug("Replaced %s patterns with async TP pass.", count) self.dump_graph(graph, "after_async_tp_pass") self.end_and_log() diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 6107046e40dc..ebc025cba71e 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -477,6 +477,6 @@ def __call__(self, graph: fx.Graph): self.begin() self.dump_graph(graph, "before_sequence_parallelism_pass") count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns", count) + logger.debug("Replaced %s patterns with sequence parallelism", count) self.dump_graph(graph, "after_sequence_parallelism_pass") self.end_and_log()
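
# --- Note (not part of the patch above): a minimal sketch of the dtype
# constraint the new scaled-MM patterns rely on. It assumes a CUDA GPU that
# supports row-wise FP8 scaled matmul (roughly sm90-class) and the
# torch.float8_e4m3fn dtype; the exact error text may differ across PyTorch
# versions. Only torch._scaled_mm, whose call signature already appears in
# the diff, is used here.
import torch

M, K, N = 16, 16, 16
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# _scaled_mm expects the second operand in column-major layout, hence the
# contiguous().t(), mirroring the weight setup in the test models above.
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).contiguous().t()
scale_a = torch.ones(M, 1, device="cuda", dtype=torch.float32)  # per-token (row-wise)
scale_b = torch.ones(1, N, device="cuda", dtype=torch.float32)  # per-channel

# bfloat16 output works with per-token (row-wise) scaling ...
out_bf16 = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                            out_dtype=torch.bfloat16)

# ... while float16 output is rejected, which is why the test skips fp16 for
# the scaled-MM models and AsyncTPPass registers these patterns only when the
# model dtype is bfloat16.
try:
    torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                     out_dtype=torch.float16)
except RuntimeError as exc:
    print(exc)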