143 changes: 138 additions & 5 deletions tests/compile/test_async_tp.py
@@ -22,6 +22,8 @@
multi_gpu_test)
from .backend import TestBackend

FP8_DTYPE = current_platform.fp8_dtype()
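# Typically torch.float8_e4m3fn on CUDA (torch.float8_e4m3fnuz on ROCm).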

prompts = [
"Hello, my name is",
"The president of the United States is",
@@ -32,9 +34,10 @@

class TestMMRSModel(torch.nn.Module):

def __init__(self, hidden_size=16):
def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__()
self.hidden_size = hidden_size
self.dtype = dtype
self.gate_proj = torch.nn.Parameter(torch.empty(
(self.hidden_size * 2, hidden_size)),
requires_grad=False)
@@ -64,9 +67,10 @@ def ops_in_model_after(self):

class TestAGMMModel(torch.nn.Module):

def __init__(self, hidden_size=16):
def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__()
self.hidden_size = hidden_size
self.dtype = dtype
self.weight = torch.nn.Parameter(torch.empty(
(hidden_size, hidden_size)),
requires_grad=False)
@@ -91,8 +95,125 @@ def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_all_gather_matmul.default]


class _BaseScaledMMModel(torch.nn.Module):

def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__()
self.hidden_size = hidden_size
self.dtype = dtype
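        # contiguous().transpose(0, 1) gives the weight a column-major layout,
        # which torch._scaled_mm and cutlass_scaled_mm expect for the second
        # (weight) operand.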
self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\
.contiguous().transpose(0, 1)

        # scale_b for _scaled_mm / cutlass_scaled_mm: shape (1, hidden_size),
        # i.e. one scale per output channel (column-wise scaling).
self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)


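# The async-TP pass is expected to rewrite the scaled_mm -> reduce_scatter
# pattern below into the fused symmetric-memory op; ops_in_model_before/after
# list the ops that should disappear/appear once the pass has run.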
class TestScaledMMRSModel(_BaseScaledMMModel):

def forward(self, input: torch.Tensor):
"""
Forward pass implementing the scaled_mm + reduce scatter in the FX graph

"""
fp8_input = input.to(FP8_DTYPE)
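        # Per-token (row-wise) scale: one scale per input row, shape (M, 1).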
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
scaled_mm = torch._scaled_mm(fp8_input,
self.weight,
scale_a=scale_a,
scale_b=self.scale_b,
out_dtype=self.dtype)
reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
return reduce_scatter

def ops_in_model_before(self):
return [torch.ops.vllm.reduce_scatter.default]

def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]


class TestAGScaledMMModel(_BaseScaledMMModel):

def forward(self, input: torch.Tensor):
"""
Forward pass implementing the all gather + scaled_mm in the FX graph
"""
        # Cast the input to FP8, then all-gather across TP ranks
fp8_input = input.to(FP8_DTYPE)
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)

scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
scaled_mm = torch._scaled_mm(all_gather,
self.weight,
scale_a=scale_a,
scale_b=self.scale_b,
out_dtype=self.dtype)
return scaled_mm

def ops_in_model_before(self):
return [torch.ops.vllm.all_gather.default]

def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]


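# The two Cutlass variants below exercise vLLM's custom cutlass_scaled_mm op
# instead of torch._scaled_mm, so the fusion pass is checked against both GEMM
# entry points.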
class TestCutlassScaledMMRSModel(_BaseScaledMMModel):

def forward(self, input: torch.Tensor):
"""
Forward pass implementing the cutlass_scaled_mm + reduce scatter
in the FX graph

"""
fp8_input = input.to(FP8_DTYPE)
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]),
dtype=self.dtype,
device=input.device)
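        # cutlass_scaled_mm is an out-parameter op: it writes the GEMM result
        # into the preallocated mm_out buffer rather than returning a tensor.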
torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a,
self.scale_b, None)
reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
return reduce_scatter

def ops_in_model_before(self):
return [torch.ops.vllm.reduce_scatter.default]

def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]


class TestAGCutlassScaledMMModel(_BaseScaledMMModel):

def forward(self, input: torch.Tensor):
"""
Forward pass implementing the all gather + cutlass_scaled_mm
in the FX graph
"""
        # Cast the input to FP8, then all-gather across TP ranks
fp8_input = input.to(FP8_DTYPE)
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)

scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)

mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]),
dtype=self.dtype,
device=all_gather.device)
torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight,
scale_a, self.scale_b, None)
return mm_out

def ops_in_model_before(self):
return [torch.ops.vllm.all_gather.default]

def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("test_model", [TestMMRSModel, TestAGMMModel])
@pytest.mark.parametrize("test_model", [
TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel,
TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@@ -101,6 +222,14 @@ def ops_in_model_after(self):
reason="Only test on CUDA")
def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
if test_model in (TestScaledMMRSModel, TestAGScaledMMModel,
TestCutlassScaledMMRSModel,
TestAGCutlassScaledMMModel) and dtype == torch.float16:
pytest.skip(
"Only bf16 high precision output types are supported for " \
"per-token (row-wise) scaling"
)

num_processes = 2

def run_torch_spawn(fn, nprocs):
@@ -155,7 +284,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
async_tp_pass = AsyncTPPass(vllm_config)
backend = TestBackend(async_tp_pass)

model = test_model_cls(hidden_size)
model = test_model_cls(hidden_size,
dtype) # Pass dtype to model constructor

hidden_states = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype,
@@ -174,7 +304,10 @@


@create_new_process_for_each_test()
@pytest.mark.parametrize("model_id", ["meta-llama/Llama-3.2-1B-Instruct"])
@pytest.mark.parametrize("model_id", [
"meta-llama/Llama-3.2-1B-Instruct",
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
])
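# Assumption: the FP8 checkpoint is included so the end-to-end test also
# covers the quantized (scaled-mm) path introduced above.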
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("async_tp_enabled", [True])
@pytest.mark.parametrize("distributed_backend", ["mp"])