@@ -35,6 +35,15 @@ def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
     assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk)
 
 
+def test_rocm_aiter_grouped_topk_custom_op_registration():
+    """Test that the custom op is correctly registered."""
+    # Check if the op exists in torch.ops.vllm
+    assert hasattr(torch.ops.vllm, 'rocm_aiter_grouped_topk')
+
+    # Check if the op is callable
+    assert callable(torch.ops.vllm.rocm_aiter_grouped_topk)
+
+
 def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
     """Test that the op can be used with torch.compile."""
     # Create test tensors
@@ -120,3 +129,87 @@ def biased_grouped_topk_fn(gating_output, e_score_correction_bias,
                           rtol=1e-2,
                           atol=1e-2)
     assert torch.allclose(topk_ids_original, topk_ids_compiled)
+
+
+def test_rocm_aiter_grouped_topk_torch_compile_compatibility():
+    """Test that the op can be used with torch.compile."""
+    # Create test tensors
+    token = 64
+    expert = 256
+    num_expert_group = 8
+    topk = 8
+    topk_group = 4
+    renormalize = True
+    scoring_func = "softmax"
+    scale_factor = 1.0
+
+    gating_output = torch.randn((token, expert),
+                                dtype=torch.bfloat16,
+                                device="cuda")
+
+    device = gating_output.device
+    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+    topk_weights = torch.empty((token, topk),
+                               dtype=torch.float32,
+                               device=device)
+
+    # Define a function that uses the op
+    def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func):
+        return torch.ops.vllm.rocm_aiter_grouped_topk(
+            gating_output, topk_weights, topk_ids, num_expert_group,
+            topk_group, renormalize, scoring_func, scale_factor)
+
+    # Verify the op's fake implementation
+    torch.library.opcheck(torch.ops.vllm.rocm_aiter_grouped_topk,
+                          (gating_output, topk_weights, topk_ids),
+                          kwargs={
+                              "num_expert_group": num_expert_group,
+                              "topk_group": topk_group,
+                              "need_renorm": renormalize,
+                              "scoring_func": scoring_func,
+                              "routed_scaling_factor": scale_factor
+                          },
+                          test_utils=("test_faketensor"))
+
+    # Compile the function with appropriate settings
+    compiled_fn = torch.compile(grouped_topk_fn,
+                                fullgraph=True,
+                                backend="inductor",
+                                mode="reduce-overhead",
+                                dynamic=False)
+
+    topk_weights_original = torch.empty((token, topk),
+                                        dtype=torch.float32,
+                                        device=device)
+    topk_ids_original = torch.empty((token, topk),
+                                    dtype=torch.int32,
+                                    device=device)
+
+    topk_weights_compiled = torch.empty((token, topk),
+                                        dtype=torch.float32,
+                                        device=device)
+    topk_ids_compiled = torch.empty((token, topk),
+                                    dtype=torch.int32,
+                                    device=device)
+
+    # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
+    grouped_topk_fn(gating_output, topk_weights_original, topk_ids_original,
+                    scoring_func)
+    compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled,
+                scoring_func)
+
+    # Sort the results for comparison since the order might not be deterministic
+    topk_ids_original, indices_original = torch.sort(topk_ids_original)
+    topk_weights_original = torch.gather(topk_weights_original, 1,
+                                         indices_original)
+
+    topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
+    topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
+                                         indices_compiled)
+
+    # Verify results match
+    assert torch.allclose(topk_weights_original,
+                          topk_weights_compiled,
+                          rtol=1e-2,
+                          atol=1e-2)
+    assert torch.allclose(topk_ids_original, topk_ids_compiled)
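
Note: both tests allocate tensors with device="cuda" (which maps to HIP on ROCm builds), yet the rocm_aiter_* custom ops are only registered when vLLM runs on ROCm with AITER enabled, so outside that environment they would typically be skipped. A minimal module-level guard, sketched under the assumption that vLLM's current_platform helper is the right availability check (this guard is not part of the diff above):

import pytest

from vllm.platforms import current_platform

# Skip the whole module off ROCm: the rocm_aiter_* ops are never registered
# there, so the hasattr()/opcheck() assertions above would fail spuriously.
# (Assumption: current_platform.is_rocm() is the appropriate check.)
pytestmark = pytest.mark.skipif(not current_platform.is_rocm(),
                                reason="AITER grouped topk ops require ROCm")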