@@ -54,8 +54,7 @@ def forward(self, x):
         return y
 
     def example_inputs(self, num_tokens=32, hidden_size=128):
-        dtype = torch.float16 if TEST_FP8 else torch.float32
-        return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),)
+        return (torch.rand(num_tokens, hidden_size * 2),)
 
     def ops_in_model(self, do_fusion):
         if TEST_FP8 and do_fusion:
@@ -73,15 +72,11 @@ def __init__(self, hidden_size=16, intermediate_size=32):
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
 
-        dtype = torch.float16 if TEST_FP8 else torch.float32
-
         self.gate_proj = torch.nn.Parameter(
-            torch.empty((intermediate_size, hidden_size), dtype=dtype)
+            torch.empty((intermediate_size, hidden_size))
         )
         self.norm = RMSNorm(intermediate_size, 1e-05)
-        self.norm.weight = torch.nn.Parameter(
-            torch.ones(intermediate_size, dtype=dtype)
-        )
+        self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size))
 
         torch.nn.init.normal_(self.gate_proj, std=0.02)
 
@@ -118,9 +113,8 @@ def forward(self, hidden_states, residual):
         return norm_output, residual_output
 
     def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
-        dtype = torch.float16 if TEST_FP8 else torch.float32
-        hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
-        residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
+        hidden_states = torch.randn((batch_size * seq_len, hidden_size))
+        residual = torch.randn((batch_size * seq_len, hidden_size))
         return (hidden_states, residual)
 
     def ops_in_model(self, do_fusion):
@@ -151,10 +145,9 @@ def forward(self, positions, q, k):
         return q_rotated, k_rotated
 
     def example_inputs(self, num_tokens=32, head_dim=64):
-        dtype = torch.float16
         positions = torch.arange(num_tokens, dtype=torch.long)
-        q = torch.randn(num_tokens, head_dim, dtype=dtype)
-        k = torch.randn(num_tokens, head_dim, dtype=dtype)
+        q = torch.randn(num_tokens, head_dim)
+        k = torch.randn(num_tokens, head_dim)
         return (positions, q, k)
 
     def ops_in_model(self, do_fusion):
@@ -172,7 +165,7 @@ def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000):
         self.hidden_size = head_dim * num_heads
 
         self.qkv_proj = torch.nn.Linear(
-            self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16
+            self.hidden_size, self.hidden_size * 3, bias=False
         )
 
         self.rotary_emb = get_rope(
@@ -196,10 +189,9 @@ def forward(self, positions, hidden_states):
         return qkv_updated
 
     def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
-        dtype = torch.float16
         hidden_size = head_dim * num_heads
         positions = torch.arange(num_tokens, dtype=torch.long)
-        hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
+        hidden_states = torch.randn(num_tokens, hidden_size)
         return (positions, hidden_states)
 
     def ops_in_model(self, do_fusion):
@@ -217,14 +209,18 @@ def ops_not_in_model(self):
 ]
 
 
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("model_class", MODELS)
 @pytest.mark.parametrize("do_fusion", [True, False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
-def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
+def test_fix_functionalization(
+    model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype
+):
     torch.set_default_device("cuda")
+    torch.set_default_dtype(dtype)
 
     vllm_config = VllmConfig(
-        model_config=ModelConfig(dtype=torch.bfloat16),
+        model_config=ModelConfig(dtype=dtype),
         compilation_config=CompilationConfig(
             custom_ops=["all"],
             pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
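For reference, a minimal standalone sketch (test name, shapes, and dtypes here are illustrative assumptions, not part of this file) of the behavior the parametrized test relies on: after torch.set_default_dtype(dtype), floating-point factories such as torch.rand/torch.randn and freshly created module parameters inherit that dtype, which is why the example_inputs helpers above can drop their explicit dtype= arguments.

import pytest
import torch


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_default_dtype_controls_factories(dtype: torch.dtype):
    # Setting the global default dtype makes float factories inherit it.
    torch.set_default_dtype(dtype)
    try:
        x = torch.randn(4, 8)                         # no dtype= needed
        w = torch.nn.Linear(8, 8, bias=False).weight  # parameters follow suit
        assert x.dtype == dtype
        assert w.dtype == dtype
    finally:
        torch.set_default_dtype(torch.float32)        # restore for other tests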