
Commit bcd95b5

Fix func test
Signed-off-by: Luka Govedič <[email protected]>
1 parent bb0254a commit bcd95b5

File tree

2 files changed: +17 −19 lines changed

csrc/layernorm_quant_kernels.cu

Lines changed: 2 additions & 0 deletions
@@ -216,6 +216,8 @@ void fused_add_rms_norm_static_fp8_quant(
     double epsilon) {
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(residual.is_contiguous());
+  TORCH_CHECK(residual.scalar_type() == input.scalar_type());
+  TORCH_CHECK(weight.scalar_type() == input.scalar_type());
   int hidden_size = input.size(-1);
   int input_stride = input.stride(-2);
   int num_tokens = input.numel() / hidden_size;
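The two new TORCH_CHECK assertions make a dtype mismatch between input, residual, and weight fail loudly at the op boundary instead of feeding mixed-precision buffers into the kernel. A minimal sketch of the kind of call the checks now reject, assuming the op is exposed under torch.ops._C with the argument order (out, input, residual, weight, scale, epsilon); the binding name and a CUDA build of vLLM are assumptions here, not part of this commit:

import torch
# Requires vLLM's custom-op extension to be loaded (e.g. by importing vllm).

# Hypothetical repro: the residual deliberately uses a different dtype than the input.
n, d = 4, 8
out = torch.empty(n, d, dtype=torch.float8_e4m3fn, device="cuda")
x = torch.randn(n, d, dtype=torch.float16, device="cuda")
residual = torch.randn(n, d, dtype=torch.bfloat16, device="cuda")  # mismatched on purpose
weight = torch.ones(d, dtype=torch.float16, device="cuda")
scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

# With this commit the mismatched residual dtype should raise a RuntimeError
# from TORCH_CHECK instead of silently producing a wrong result.
torch.ops._C.fused_add_rms_norm_static_fp8_quant(out, x, residual, weight, scale, 1e-5)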

tests/compile/test_functionalization.py

Lines changed: 15 additions & 19 deletions
@@ -54,8 +54,7 @@ def forward(self, x):
         return y
 
     def example_inputs(self, num_tokens=32, hidden_size=128):
-        dtype = torch.float16 if TEST_FP8 else torch.float32
-        return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),)
+        return (torch.rand(num_tokens, hidden_size * 2),)
 
     def ops_in_model(self, do_fusion):
         if TEST_FP8 and do_fusion:
@@ -73,15 +72,11 @@ def __init__(self, hidden_size=16, intermediate_size=32):
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
 
-        dtype = torch.float16 if TEST_FP8 else torch.float32
-
         self.gate_proj = torch.nn.Parameter(
-            torch.empty((intermediate_size, hidden_size), dtype=dtype)
+            torch.empty((intermediate_size, hidden_size))
         )
         self.norm = RMSNorm(intermediate_size, 1e-05)
-        self.norm.weight = torch.nn.Parameter(
-            torch.ones(intermediate_size, dtype=dtype)
-        )
+        self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size))
 
         torch.nn.init.normal_(self.gate_proj, std=0.02)
 
@@ -118,9 +113,8 @@ def forward(self, hidden_states, residual):
         return norm_output, residual_output
 
     def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
-        dtype = torch.float16 if TEST_FP8 else torch.float32
-        hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
-        residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
+        hidden_states = torch.randn((batch_size * seq_len, hidden_size))
+        residual = torch.randn((batch_size * seq_len, hidden_size))
         return (hidden_states, residual)
 
     def ops_in_model(self, do_fusion):
@@ -151,10 +145,9 @@ def forward(self, positions, q, k):
         return q_rotated, k_rotated
 
     def example_inputs(self, num_tokens=32, head_dim=64):
-        dtype = torch.float16
         positions = torch.arange(num_tokens, dtype=torch.long)
-        q = torch.randn(num_tokens, head_dim, dtype=dtype)
-        k = torch.randn(num_tokens, head_dim, dtype=dtype)
+        q = torch.randn(num_tokens, head_dim)
+        k = torch.randn(num_tokens, head_dim)
         return (positions, q, k)
 
     def ops_in_model(self, do_fusion):
@@ -172,7 +165,7 @@ def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000):
         self.hidden_size = head_dim * num_heads
 
         self.qkv_proj = torch.nn.Linear(
-            self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16
+            self.hidden_size, self.hidden_size * 3, bias=False
         )
 
         self.rotary_emb = get_rope(
@@ -196,10 +189,9 @@ def forward(self, positions, hidden_states):
         return qkv_updated
 
     def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
-        dtype = torch.float16
         hidden_size = head_dim * num_heads
         positions = torch.arange(num_tokens, dtype=torch.long)
-        hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
+        hidden_states = torch.randn(num_tokens, hidden_size)
         return (positions, hidden_states)
 
     def ops_in_model(self, do_fusion):
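Dropping the hard-coded dtype= arguments in the hunks above works because PyTorch factory calls such as torch.rand, torch.randn, torch.empty, and torch.ones fall back to the process-wide default dtype when none is given, while integer factories like torch.arange keep their own defaults (which is why positions still passes dtype=torch.long). A small, self-contained illustration of that behaviour:

import torch

# Factory calls without an explicit dtype follow the global default,
# so a single set_default_dtype call switches every float tensor a test builds.
torch.set_default_dtype(torch.bfloat16)
assert torch.randn(4, 8).dtype == torch.bfloat16
assert torch.nn.Parameter(torch.ones(8)).dtype == torch.bfloat16

torch.set_default_dtype(torch.float16)
assert torch.rand(2, 2).dtype == torch.float16

# Integer factories are unaffected by the floating-point default.
assert torch.arange(4).dtype == torch.int64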
@@ -217,14 +209,18 @@ def ops_not_in_model(self):
         ]
 
 
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("model_class", MODELS)
 @pytest.mark.parametrize("do_fusion", [True, False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
-def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
+def test_fix_functionalization(
+    model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype
+):
     torch.set_default_device("cuda")
+    torch.set_default_dtype(dtype)
 
     vllm_config = VllmConfig(
-        model_config=ModelConfig(dtype=torch.bfloat16),
+        model_config=ModelConfig(dtype=dtype),
         compilation_config=CompilationConfig(
             custom_ops=["all"],
             pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
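Taken together, the test now runs each model/fusion combination under both float16 and bfloat16, with the default tensor dtype and the ModelConfig dtype pinned to the same parametrized value. A stripped-down sketch of that pattern (a standalone toy test, not the vLLM test itself):

import pytest
import torch


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_default_dtype_follows_param(dtype: torch.dtype):
    # Mirror of the pattern above: pin the global default once, then every
    # dtype-less factory call in the model under test inherits it.
    torch.set_default_dtype(dtype)
    x = torch.randn(4, 4)
    assert x.dtype == dtype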
