diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 86139d598582..92ce10a9efc0 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional import pytest import torch @@ -34,15 +35,15 @@ class Relu3(ReLUSquaredActivation): [ # Default values based on compile level # - All by default (no Inductor compilation) - ("", 0, False, [True] * 4, True), - ("", 1, True, [True] * 4, True), - ("", 2, False, [True] * 4, True), + (None, 0, False, [True] * 4, True), + (None, 1, True, [True] * 4, True), + (None, 2, False, [True] * 4, True), # - None by default (with Inductor) - ("", 3, True, [False] * 4, False), - ("", 4, True, [False] * 4, False), + (None, 3, True, [False] * 4, False), + (None, 4, True, [False] * 4, False), # - All by default (without Inductor) - ("", 3, False, [True] * 4, True), - ("", 4, False, [True] * 4, True), + (None, 3, False, [True] * 4, True), + (None, 4, False, [True] * 4, True), # Explicitly enabling/disabling # # Default: all @@ -54,7 +55,7 @@ class Relu3(ReLUSquaredActivation): # All but SiluAndMul ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True), # All but ReLU3 (even if ReLU2 is on) - ("-relu3,relu2", 3, False, [1, 1, 1, 0], True), + ("-relu3,+relu2", 3, False, [1, 1, 1, 0], True), # RMSNorm and SiluAndMul ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False), # All but RMSNorm @@ -67,12 +68,13 @@ class Relu3(ReLUSquaredActivation): # All but RMSNorm ("all,-rms_norm", 4, True, [0, 1, 1, 1], True), ]) -def test_enabled_ops(env: str, torch_level: int, use_inductor: bool, +def test_enabled_ops(env: Optional[str], torch_level: int, use_inductor: bool, ops_enabled: list[int], default_on: bool): + custom_ops = env.split(',') if env else [] vllm_config = VllmConfig( compilation_config=CompilationConfig(use_inductor=bool(use_inductor), level=torch_level, - custom_ops=env.split(","))) + custom_ops=custom_ops)) with set_current_vllm_config(vllm_config): assert CustomOp.default_on() == default_on