@@ -28,42 +28,49 @@ class Relu3(ReLUSquaredActivation):
 
 
 @pytest.mark.parametrize(
-    "env, torch_level, ops_enabled, default_on",
+    "env, torch_level, use_inductor, ops_enabled, default_on",
     [
         # Default values based on compile level
-        ("", 0, [True] * 4, True),
-        ("", 1, [True] * 4, True),
-        ("", 2, [True] * 4, True),  # All by default
-        ("", 3, [False] * 4, False),
-        ("", 4, [False] * 4, False),  # None by default
+        # - All by default (no Inductor compilation)
+        ("", 0, False, [True] * 4, True),
+        ("", 1, True, [True] * 4, True),
+        ("", 2, False, [True] * 4, True),
+        # - None by default (with Inductor)
+        ("", 3, True, [False] * 4, False),
+        ("", 4, True, [False] * 4, False),
+        # - All by default (without Inductor)
+        ("", 3, False, [True] * 4, True),
+        ("", 4, False, [True] * 4, True),
         # Explicitly enabling/disabling
         #
         # Default: all
         #
         # All but SiluAndMul
-        ("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True),
+        ("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
         # Only ReLU3
-        ("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False),
+        ("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
         # All but SiluAndMul
-        ("all,-silu_and_mul", 1, [1, 0, 1, 1], True),
+        ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
         # All but ReLU3 (even if ReLU2 is on)
-        ("-relu3,relu2", 1, [1, 1, 1, 0], True),
-        # GeluAndMul and SiluAndMul
-        ("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False),
+        ("-relu3,relu2", 3, False, [1, 1, 1, 0], True),
+        # RMSNorm and SiluAndMul
+        ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
         # All but RMSNorm
-        ("-rms_norm", 2, [0, 1, 1, 1], True),
+        ("-rms_norm", 3, False, [0, 1, 1, 1], True),
         #
         # Default: none
         #
         # Only ReLU3
-        ("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False),
+        ("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
         # All but RMSNorm
-        ("all,-rms_norm", 4, [0, 1, 1, 1], True),
+        ("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
     ])
-def test_enabled_ops(env: str, torch_level: int, ops_enabled: list[int],
-                     default_on: bool):
-    vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=torch_level, custom_ops=env.split(",")))
+def test_enabled_ops(env: str, torch_level: int, use_inductor: bool,
+                     ops_enabled: list[int], default_on: bool):
+    vllm_config = VllmConfig(
+        compilation_config=CompilationConfig(use_inductor=bool(use_inductor),
+                                             level=torch_level,
+                                             custom_ops=env.split(",")))
     with set_current_vllm_config(vllm_config):
         assert CustomOp.default_on() == default_on
 
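For readability, the rows above are consistent with a simple rule, sketched below. This is an illustrative reconstruction of what the parametrization exercises, not vLLM's actual CustomOp.default_on() implementation: the PIECEWISE constant, the standalone default_on helper, and the "all"-before-"none" precedence are assumptions made for the example.

# Illustrative sketch only; mirrors the expected values in the table above.
PIECEWISE = 3  # assumed: level at which Inductor compiles the model piecewise

def default_on(level: int, use_inductor: bool, custom_ops: list[str]) -> bool:
    # Explicit "all"/"none" entries in custom_ops override the default.
    if "all" in custom_ops:
        return True
    if "none" in custom_ops:
        return False
    # Otherwise custom ops stay on unless Inductor compilation is active
    # at level PIECEWISE (3) or above.
    return not (use_inductor and level >= PIECEWISE)

# Spot-checks against rows from the parametrization:
assert default_on(0, False, [""]) is True   # ("", 0, False, ...) -> True
assert default_on(3, True, [""]) is False   # ("", 3, True, ...)  -> False
assert default_on(3, False, [""]) is True   # ("", 3, False, ...) -> True
assert default_on(2, True, "all,-silu_and_mul".split(",")) is True    # "all" wins
assert default_on(1, False, "none,-rms_norm,+relu3".split(",")) is False  # "none" wins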