diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index bb42b5f29a72..d5ebff80b91d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -110,7 +110,9 @@ steps: - vllm/core/ - tests/distributed - tests/spec_decode/e2e/test_integration_dist_tp4 + - tests/compile commands: + - pytest -v -s compile/test_basic_correctness.py - pytest -v -s distributed/test_pynccl.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py @@ -218,7 +220,7 @@ steps: - vllm/ - tests/compile commands: - - pytest -v -s compile/test_full_graph_smoke.py + - pytest -v -s compile/test_basic_correctness.py - label: "PyTorch Fullgraph Test" # 18min source_file_dependencies: - vllm/ - tests/compile commands: @@ -382,7 +384,7 @@ - tests/distributed/ - vllm/compilation commands: - - pytest -v -s ./compile/test_full_graph_multi_gpu.py + - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py new file mode 100644 index 000000000000..c31b87421856 --- /dev/null +++ b/tests/compile/test_basic_correctness.py @@ -0,0 +1,28 @@ +from typing import Dict, List, Optional + +import pytest + +from vllm.utils import cuda_device_count_stateless + +from ..utils import compare_all_settings +from .utils import TEST_MODELS_SMOKE + + +@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) +@pytest.mark.parametrize("pp_size", [1, 2]) +@pytest.mark.parametrize("tp_size", [1]) +def test_compile_correctness(model_info, pp_size, tp_size): + # this test is run under multiple test suites, with different GPUs. + # make sure we only run the test with the correct number of CUDA devices. + # don't use "<", as it will duplicate the tests.
+ if cuda_device_count_stateless() != pp_size * tp_size: + pytest.skip("Not correct CUDA devices for the test.") + model = model_info[0] + model_args = model_info[1] + all_args = [["--enforce-eager"] + model_args + ["--max_model_len", "1024"] + + ["-pp", str(pp_size)] + ["-tp", str(tp_size)]] * 3 + all_envs: List[Optional[Dict[str, str]]] = [{ + "VLLM_TORCH_COMPILE_LEVEL": + str(i) + } for i in range(3)] + compare_all_settings(model, all_args, all_envs) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 5dd65ad7236f..03e377023b99 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -1,13 +1,16 @@ import pytest -from vllm.compilation.backends import vllm_backend - +from ..utils import fork_new_process_for_each_test from .utils import TEST_MODELS, check_full_graph_support @pytest.mark.parametrize("model_info", TEST_MODELS) -@pytest.mark.parametrize("backend", ["eager", vllm_backend]) -def test_full_graph(model_info, backend): +@pytest.mark.parametrize("optimization_level", [1, 2]) +@fork_new_process_for_each_test +def test_full_graph(model_info, optimization_level): model = model_info[0] model_kwargs = model_info[1] - check_full_graph_support(model, model_kwargs, backend, tp_size=1) + check_full_graph_support(model, + model_kwargs, + optimization_level, + tp_size=1) diff --git a/tests/compile/test_full_graph_multi_gpu.py b/tests/compile/test_full_graph_multi_gpu.py deleted file mode 100644 index e9883d5254e7..000000000000 --- a/tests/compile/test_full_graph_multi_gpu.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest - -from vllm.compilation.backends import vllm_backend -from vllm.utils import cuda_device_count_stateless - -from ..utils import fork_new_process_for_each_test -from .utils import TEST_MODELS_SMOKE, check_full_graph_support - - -@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) -@pytest.mark.parametrize("tp_size", [2]) -@pytest.mark.parametrize("backend", ["eager", vllm_backend]) -@fork_new_process_for_each_test -def test_full_graph_multi_gpu(model_info, tp_size, backend): - model = model_info[0] - model_kwargs = model_info[1] - - # Skip the test if there are not enough CUDA devices. 
- if cuda_device_count_stateless() < tp_size: - pytest.skip("Not enough CUDA devices for the test.") - - check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size) diff --git a/tests/compile/test_full_graph_smoke.py b/tests/compile/test_full_graph_smoke.py deleted file mode 100644 index 0c5a95b4ead4..000000000000 --- a/tests/compile/test_full_graph_smoke.py +++ /dev/null @@ -1,13 +0,0 @@ -import pytest - -from vllm.compilation.backends import vllm_backend - -from .utils import TEST_MODELS_SMOKE, check_full_graph_support - - -@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) -@pytest.mark.parametrize("backend", ["eager", vllm_backend]) -def test_full_graph(model_info, backend): - model = model_info[0] - model_kwargs = model_info[1] - check_full_graph_support(model, model_kwargs, backend, tp_size=1) diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 2d06a0946d91..eb5b2e741f96 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -4,14 +4,12 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.plugins import set_torch_compile_backend from vllm.utils import is_hip TEST_MODELS_SMOKE = [ - ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", { - "quantization": "compressed-tensors" - }), - ("meta-llama/Meta-Llama-3-8B", {}), + ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", + ["--quantization", "compressed-tensors"]), + ("meta-llama/Meta-Llama-3-8B", []), ] TEST_MODELS = [ @@ -68,20 +66,20 @@ })) -def check_full_graph_support(model, model_kwargs, backend, tp_size=1): +def check_full_graph_support(model, + model_kwargs, + optimization_level, + tp_size=1): # make sure these models can be captured in full graph mode - if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ: - os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" - os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1" + os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level) + os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1" # Inductor doesn't support fp8/gptq_marlin_24 yet. quantization = model_kwargs.get("quantization") if (quantization == "fp8" or quantization == "gptq_marlin" - or quantization == "gptq_marlin_24") and backend != "eager": + or quantization == "gptq_marlin_24") and optimization_level > 1: return - set_torch_compile_backend(backend) - prompts = [ "Hello, my name is", "The president of the United States is", diff --git a/tests/utils.py b/tests/utils.py index 3eff77f396e1..618b8b7a70f4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -180,9 +180,24 @@ def compare_two_settings(model: str, env1: The first set of environment variables to pass to the API server. env2: The second set of environment variables to pass to the API server. """ + compare_all_settings(model, [arg1, arg2], [env1, env2], max_wait_seconds) + +def compare_all_settings(model: str, + all_args: List[List[str]], + all_envs: List[Optional[Dict[str, str]]], + max_wait_seconds: Optional[float] = None) -> None: + """ + Launch API server with several different sets of arguments/environments + and compare the results of the API calls with the first set of arguments. + + Args: + model: The model to test. + all_args: A list of argument lists to pass to the API server. + all_envs: A list of environment dictionaries to pass to the API server. 
+ """ trust_remote_code = "--trust-remote-code" - if trust_remote_code in arg1 or trust_remote_code in arg2: + if any(trust_remote_code in args for args in all_args): tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) else: @@ -190,8 +205,9 @@ def compare_two_settings(model: str, prompt = "Hello, my name is" token_ids = tokenizer(prompt)["input_ids"] - results = [] - for args, env in ((arg1, env1), (arg2, env2)): + ref_results: List = [] + for i, (args, env) in enumerate(zip(all_args, all_envs)): + compare_results: List = [] with RemoteOpenAIServer(model, args, env_dict=env, @@ -202,10 +218,13 @@ def compare_two_settings(model: str, models = client.models.list() models = models.data served_model = models[0] - results.append({ - "test": "models_list", - "id": served_model.id, - "root": served_model.root, + (ref_results if i == 0 else compare_results).append({ + "test": + "models_list", + "id": + served_model.id, + "root": + served_model.root, }) # test with text prompt @@ -214,11 +233,15 @@ def compare_two_settings(model: str, max_tokens=5, temperature=0.0) - results.append({ - "test": "single_completion", - "text": completion.choices[0].text, - "finish_reason": completion.choices[0].finish_reason, - "usage": completion.usage, + (ref_results if i == 0 else compare_results).append({ + "test": + "single_completion", + "text": + completion.choices[0].text, + "finish_reason": + completion.choices[0].finish_reason, + "usage": + completion.usage, }) # test using token IDs @@ -229,11 +252,15 @@ def compare_two_settings(model: str, temperature=0.0, ) - results.append({ - "test": "token_ids", - "text": completion.choices[0].text, - "finish_reason": completion.choices[0].finish_reason, - "usage": completion.usage, + (ref_results if i == 0 else compare_results).append({ + "test": + "token_ids", + "text": + completion.choices[0].text, + "finish_reason": + completion.choices[0].finish_reason, + "usage": + completion.usage, }) # test seeded random sampling @@ -243,11 +270,15 @@ def compare_two_settings(model: str, seed=33, temperature=1.0) - results.append({ - "test": "seeded_sampling", - "text": completion.choices[0].text, - "finish_reason": completion.choices[0].finish_reason, - "usage": completion.usage, + (ref_results if i == 0 else compare_results).append({ + "test": + "seeded_sampling", + "text": + completion.choices[0].text, + "finish_reason": + completion.choices[0].finish_reason, + "usage": + completion.usage, }) # test seeded random sampling with multiple prompts @@ -257,7 +288,7 @@ def compare_two_settings(model: str, seed=33, temperature=1.0) - results.append({ + (ref_results if i == 0 else compare_results).append({ "test": "seeded_sampling", "text": [choice.text for choice in completion.choices], @@ -275,10 +306,13 @@ def compare_two_settings(model: str, temperature=0.0, ) - results.append({ - "test": "simple_list", - "text0": batch.choices[0].text, - "text1": batch.choices[1].text, + (ref_results if i == 0 else compare_results).append({ + "test": + "simple_list", + "text0": + batch.choices[0].text, + "text1": + batch.choices[1].text, }) # test streaming @@ -294,18 +328,25 @@ def compare_two_settings(model: str, assert len(chunk.choices) == 1 choice = chunk.choices[0] texts[choice.index] += choice.text - results.append({ + (ref_results if i == 0 else compare_results).append({ "test": "streaming", "texts": texts, }) - n = len(results) // 2 - arg1_results = results[:n] - arg2_results = results[n:] - for arg1_result, arg2_result in zip(arg1_results, arg2_results): - assert 
arg1_result == arg2_result, ( - f"Results for {model=} are not the same with {arg1=} and {arg2=}. " - f"{arg1_result=} != {arg2_result=}") + if i > 0: + # if any setting fails, raise an error early + ref_args = all_args[0] + ref_envs = all_envs[0] + compare_args = all_args[i] + compare_envs = all_envs[i] + for ref_result, compare_result in zip(ref_results, + compare_results): + assert ref_result == compare_result, ( + f"Results for {model=} are not the same.\n" + f"{ref_args=} {ref_envs=}\n" + f"{compare_args=} {compare_envs=}\n" + f"{ref_result=}\n" + f"{compare_result=}\n") def init_test_distributed_environment( diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 43ca6c9ff160..eb5f0c0c1618 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -13,6 +13,7 @@ compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) +from vllm.compilation import get_forward_context from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: @@ -701,108 +702,169 @@ def forward( assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") - num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if kv_cache.numel() > 0: - key_cache = kv_cache[0] - value_cache = kv_cache[1] - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - torch.ops.vllm.reshape_and_cache_flash( - key, - value, - kv_cache, - attn_metadata.slot_mapping.flatten(), - self.kv_cache_dtype, - k_scale, - v_scale, - ) + output = torch.ops.vllm.unified_flash_attention( + query, + key, + value, + self.num_heads, + self.head_size, + self.num_kv_heads, + kv_cache, + self.kv_cache_dtype, + k_scale, + v_scale, + self.scale, + self.sliding_window, + self.alibi_slopes, + self.logits_soft_cap, + ) - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - # QKV for prefill. - query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - prefill_output: Optional[torch.Tensor] = None - decode_output: Optional[torch.Tensor] = None - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if (kv_cache.numel() == 0 or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): - # normal attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. 
- prefill_output = torch.ops.vllm.flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ) - else: - # prefix-enabled attention - assert prefill_meta.seq_lens is not None - max_seq_len = max(prefill_meta.seq_lens) - prefill_output = torch.ops.vllm.flash_attn_varlen_func( # noqa - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=max_seq_len, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, - softcap=self.logits_soft_cap, - ) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - decode_output = torch.ops.vllm.flash_attn_with_kvcache( - decode_query.unsqueeze(1), - key_cache, - value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, + return output + + +@torch.library.custom_op("vllm::unified_flash_attention", + mutates_args=["kv_cache"]) +def unified_flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + num_heads: int, + head_size: int, + num_kv_heads: int, + kv_cache: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + softmax_scale: float, + window_size: Optional[List[int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + logits_soft_cap: Optional[float] = None, +) -> torch.Tensor: + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, num_heads, head_size) + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + + current_metadata = get_forward_context() + assert current_metadata is not None + assert isinstance(current_metadata, FlashAttentionMetadata) + attn_metadata: FlashAttentionMetadata = current_metadata + + if kv_cache.numel() > 0: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + torch.ops.vllm.reshape_and_cache_flash( + key, + value, + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + prefill_output: Optional[torch.Tensor] = None + decode_output: Optional[torch.Tensor] = None + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. 
+ if (kv_cache.numel() == 0 or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + prefill_output = torch.ops.vllm.flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + ) + else: + # prefix-enabled attention + assert prefill_meta.seq_lens is not None + max_seq_len = max(prefill_meta.seq_lens) + prefill_output = torch.ops.vllm.flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, + softmax_scale=softmax_scale, causal=True, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ).squeeze(1) - - if prefill_output is None: - assert decode_output is not None - return decode_output.view(num_decode_tokens, hidden_size) - if decode_output is None: - assert prefill_output is not None - return prefill_output.view(num_prefill_tokens, hidden_size) - output = torch.cat([prefill_output, decode_output], dim=0) - return output.view(num_tokens, hidden_size) + alibi_slopes=alibi_slopes, + block_table=prefill_meta.block_tables, + softcap=logits_soft_cap, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + decode_output = torch.ops.vllm.flash_attn_with_kvcache( + decode_query.unsqueeze(1), + key_cache, + value_cache, + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=softmax_scale, + causal=True, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + ).squeeze(1) + + if prefill_output is None: + assert decode_output is not None + return decode_output.view(num_decode_tokens, hidden_size) + if decode_output is None: + assert prefill_output is not None + return prefill_output.view(num_prefill_tokens, hidden_size) + output = torch.cat([prefill_output, decode_output], dim=0) + return output.view(num_tokens, hidden_size) + + +@unified_flash_attention.register_fake +def _( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + num_heads: int, + head_size: int, + num_kv_heads: int, + kv_cache: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + softmax_scale: float, + window_size: Optional[List[int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + logits_soft_cap: Optional[float] = None, +) -> torch.Tensor: + return torch.empty_like(query) diff --git a/vllm/compilation/__init__.py b/vllm/compilation/__init__.py index e69de29bb2d1..229503b167d7 100644 --- a/vllm/compilation/__init__.py +++ b/vllm/compilation/__init__.py @@ -0,0 +1,28 @@ +from contextlib import contextmanager +from typing import Any + +_forward_context: Any = None + + +def get_forward_context() -> Any: + """Get the current forward context.""" + return _forward_context + + +def set_forward_context(context: Any): + """Set the current forward context.""" + global _forward_context + _forward_context = context + + +@contextmanager +def forward_context(context: Any): + """A context manager that stores the current forward context, + can be attention metadata, etc.""" + global 
_forward_context + prev_context = _forward_context + _forward_context = context + try: + yield + finally: + _forward_context = prev_context diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index de0b1d8a7575..98ab6667a689 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,8 +1,16 @@ import operator +from typing import Callable, Dict, Optional, Tuple, Union +from weakref import ReferenceType import torch import torch.fx as fx +from vllm.logger import init_logger + +from .wrapper import TorchCompileWrapperWithCustomDispatcher + +logger = init_logger(__name__) + def fix_functionalization(graph: fx.Graph): """ @@ -148,9 +156,112 @@ def fix_functionalization(graph: fx.Graph): # print(graph.python_code(root_module="self", verbose=True).src, file=f) -def vllm_backend(graph, example_inputs): +def wrap_inductor(graph, example_inputs, additional_inductor_config): from torch._inductor import config current_config = config.shallow_copy_dict() from torch._inductor.compile_fx import compile_fx + + if additional_inductor_config is not None: + current_config.update(additional_inductor_config) + if current_config['post_grad_custom_post_pass'] is not None: + logger.warning( + "post_grad_custom_post_pass is already set in the config. " + "Overwriting it with fix_functionalization") current_config['post_grad_custom_post_pass'] = fix_functionalization return compile_fx(graph, example_inputs, config_patches=current_config) + + +def vllm_backend( + graph, + example_inputs, + model_ref: Optional[ + ReferenceType[TorchCompileWrapperWithCustomDispatcher]] = None, + additional_inductor_config: Optional[Dict] = None) -> Callable: + + # for each seen shape, a flag for whether we need to specialize + runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {} + + # for each shape that needs specialization, the compiled graph for it + runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {} + + # this is the first compilation; we compile a graph with + # dynamic shapes, as the caller will mark the first dimension as dynamic + logger.info("Compiling a graph for general shapes") + graph_for_symbolic_shape = wrap_inductor(graph, example_inputs, + additional_inductor_config) + + # TODO: Dynamo does not pass all dynamic shapes. + # Need to investigate why. It works now because all the dynamic + # shapes have the same value, and either of them can be used. + sym_shape_indices = [ + i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt) + ] + + first_run = True + + # this is the function we return to Dynamo and that actually runs at runtime + def compiled_graph_wrapper(*args): + + runtime_shapes: Tuple[int, ...]
= tuple(args[i] for i in sym_shape_indices) + + nonlocal first_run + nonlocal runtime_shapes_to_compile_flags + nonlocal runtime_shapes_to_compiled_graph + + if first_run: + # the first compilation is for profiling, we directly run it + first_run = False + return graph_for_symbolic_shape(*args) + + if model_ref is None: + # no information about the model, we cannot specialize + return graph_for_symbolic_shape(*args) + + model: TorchCompileWrapperWithCustomDispatcher = model_ref() + assert model is not None, "model is garbage collected" + + if runtime_shapes not in runtime_shapes_to_compile_flags: + # we haven't seen this shape before + # query the model if we need to specialize for this shape + runtime_shapes_to_compile_flags[ + runtime_shapes] = model.need_to_specialize(runtime_shapes) + + if not runtime_shapes_to_compile_flags[runtime_shapes]: + # we don't need to specialize for this shape + return graph_for_symbolic_shape(*args) + + if runtime_shapes not in runtime_shapes_to_compiled_graph: + # we need to specialize for this shape, and we haven't compiled + # compile the graph for this shape + logger.info("Compiling a graph for shapes %s", runtime_shapes) + runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor( + graph, args, additional_inductor_config) + + return runtime_shapes_to_compiled_graph[runtime_shapes](*args) + + return compiled_graph_wrapper + + +def select_default_backend(level: int) -> Union[str, Callable]: + if level == 1: + backend = "eager" + return backend + assert level in [2, 3], f"Invalid level {level}" + + from vllm.compilation.backends import vllm_backend + from vllm.plugins import get_inductor_additional_configs + additional_configs = get_inductor_additional_configs() + + if level == 3: + if "max_autotune" in additional_configs and not additional_configs[ + "max_autotune"]: + logger.warning( + "max_autotune is disabled, but is overridden by level 3") + additional_configs['max_autotune'] = True + + from functools import partial + backend = partial(vllm_backend, + additional_inductor_config=additional_configs) + + return backend diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py new file mode 100644 index 000000000000..0c2840b35949 --- /dev/null +++ b/vllm/compilation/decorators.py @@ -0,0 +1,65 @@ +from typing import List, Optional, Tuple, Union + +import torch + +import vllm.envs as envs +from vllm.attention import AttentionMetadata +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.sequence import IntermediateTensors + + +def support_compile_llama_style(cls: type): + """ + A decorator to add support for compiling the forward method of a class. + If a module's **forward signature** is compatible with llama, this + decorator can be used to enable the compilation of the forward method. 
+ """ + + # take care of method resolution order + # make sure super().__init__ is called on the base class + # other than TorchCompileWrapperWithCustomDispatcher + cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) + + old_init = cls.__init__ + + def __init__(self, *args, **kwargs): + old_init(self, *args, **kwargs) + self._use_torch_compile = envs.VLLM_TORCH_COMPILE_LEVEL > 0 + if self._use_torch_compile: + TorchCompileWrapperWithCustomDispatcher.__init__(self) + + cls.__init__ = __init__ + + def need_to_specialize(self, runtime_shapes: Tuple[int, ...]) -> bool: + if len(self.sizes_to_specialize) == 0: + return False + return runtime_shapes[0] in self.sizes_to_specialize + + cls.need_to_specialize = need_to_specialize + + def __call__( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if not self._use_torch_compile: + return self.forward(input_ids, positions, kv_caches, attn_metadata, + intermediate_tensors) + if len(self.compiled_codes) < 1: + torch._dynamo.mark_dynamic(input_ids, 0) + torch._dynamo.mark_dynamic(positions, 0) + if intermediate_tensors is not None: + for tensors in intermediate_tensors.tensors.values(): + torch._dynamo.mark_dynamic(tensors, 0) + return self.compiled_callable(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + with self.dispatch_to_code(0): + model_output = self.forward(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output + + cls.__call__ = __call__ + return cls diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index e923bd36ccc0..c147a13b411a 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -1,9 +1,10 @@ import os import sys +import weakref from abc import abstractmethod from contextlib import contextmanager from types import CodeType -from typing import Callable, List +from typing import Any, Callable, List, Optional, Tuple import torch @@ -23,7 +24,29 @@ class TorchCompileWrapperWithCustomDispatcher: `torch.compile` over the forward method. """ - def __init__(self, compiled_callable: Callable): + def __init__(self, compiled_callable: Optional[Callable] = None): + + if compiled_callable is None: + # default compilation settings + # compiling the forward method + + # choose the compile backend + + # if the user has set the backend, use it + from vllm.plugins import get_torch_compile_backend + backend = get_torch_compile_backend() + if backend is None: + from vllm.compilation.backends import select_default_backend + backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL) + if not isinstance(backend, str): + from functools import partial + backend = partial(backend, model_ref=weakref.ref(self)) + + compiled_callable = torch.compile( + self.forward, + fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=backend) + self.compiled_callable = compiled_callable self.original_code_object = self.__class__.forward.__code__ self.compiled_codes: List[CodeType] = [] @@ -35,6 +58,8 @@ def __init__(self, compiled_callable: Callable): self.use_custom_dispatcher: bool = \ envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER + self.sizes_to_specialize = [] + def __call__(self, *args, **kwargs): """Implement the dispatch logic here, beyond the torch.compile level. 
NOTE: this function can have additional arguments beyond the forward @@ -79,3 +104,17 @@ def dispatch_to_code(self, index: int): self.__class__.forward.__code__ = self.compiled_codes[index] yield self.__class__.forward.__code__ = self.original_code_object + + def set_sizes_to_specialize(self, sizes: List[Any]): + """Set the sizes to specialize for the compiled code.""" + self.sizes_to_specialize = sizes + + def need_to_specialize(self, runtime_shapes: Tuple[int, ...]) -> bool: + """Check if the current runtime shapes need to be specialized. + If not, we can use the graph for general shapes. + If yes, we will compile the graph for the current shapes. + The argument `runtime_shapes` is a tuple of integers, representing + the runtime shapes of the dimensions marked as dynamic during graph + capture. + """ + return False diff --git a/vllm/envs.py b/vllm/envs.py index 7cbffc83a625..22fe7dbe22a8 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -201,19 +201,20 @@ def get_default_config_root(): "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH": lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BEAM_SEARCH", "0") == "1", - # Internal flag to enable Dynamo graph capture - "VLLM_TEST_DYNAMO_GRAPH_CAPTURE": - lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), + # torch.compile optimization level + # 0: no optimization (don't use torch.compile) + # 1: capture the graph, run in eager mode (seconds of compilation time) + # 2: capture the graph, compile with inductor (minutes of compilation time) + # 3: capture the graph, compile with inductor max-autotune (dozens of minutes of compilation time) # noqa + "VLLM_TORCH_COMPILE_LEVEL": + lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")), + + # Internal flag for Dynamo testing "VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": lambda: (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in ("true", "1")), - # Internal flag to control whether we use custom op, - # or use the native pytorch implementation - "VLLM_TEST_COMPILE_NO_CUSTOM_OPS": - lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")), - # Internal flag to enable Dynamo fullgraph capture "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": lambda: bool( diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 9102b5e19ebe..09fe35ad6179 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -55,7 +55,7 @@ def dispatch_forward(self): # NOTE(woosuk): Here we assume that vLLM was built for only one # specific backend. Currently, we do not support dynamic dispatching. 
- if envs.VLLM_TEST_COMPILE_NO_CUSTOM_OPS: + if envs.VLLM_TORCH_COMPILE_LEVEL >= 2: return self.forward_native if is_hip(): diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 5ff31e3833ec..d44e29b78bdb 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -28,6 +28,7 @@ from transformers import LlamaConfig from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_compile_llama_style from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -344,6 +345,7 @@ def forward( return hidden_states +@support_compile_llama_style class LlamaForCausalLM(nn.Module, SupportsLoRA): packed_modules_mapping = { "qkv_proj": [ diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 7939688ef0da..211fedbc6e2e 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Optional, Union +from typing import Callable, Dict, Optional, Union import vllm.envs as envs @@ -42,3 +42,15 @@ def set_torch_compile_backend(backend: Union[Callable, str]): def get_torch_compile_backend() -> Optional[Union[Callable, str]]: return _torch_compile_backend + + +_inductor_additional_configs: Dict = {} + + +def set_inductor_additional_configs(configs: Dict): + global _inductor_additional_configs + _inductor_additional_configs = configs + + +def get_inductor_additional_configs() -> Dict: + return _inductor_additional_configs diff --git a/vllm/sequence.py b/vllm/sequence.py index 781bcedde2b5..894473deb11a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1149,10 +1149,9 @@ def __eq__(self, other: object) -> bool: return self.embeddings == other.embeddings -class IntermediateTensors( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] +# cannot use msgspec.Struct here because Dynamo does not support it +@dataclass +class IntermediateTensors: """For all pipeline stages except the last, we need to return the hidden states and residuals to be sent to the next stage. This data structure contains the hidden states and residuals for a request. diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index cf64af72a14a..659e0c9f7525 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -2,6 +2,7 @@ import torch +from vllm.compilation import forward_context from vllm.model_executor.layers.sampler import SamplerOutput try: @@ -293,16 +294,17 @@ def execute_model( if previous_hidden_states is not None else {} # Run model - hidden_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **kwargs, - ) + with forward_context(model_input.attn_metadata): + hidden_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **kwargs, + ) # Compute the logits. 
logits = self.model.compute_logits(hidden_states, diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 5c5d20a51e7d..6e7d452ba670 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -3,6 +3,7 @@ import torch +from vllm.compilation import forward_context from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -103,7 +104,8 @@ def execute_model( # a placeholder (it has wide hardware support). kv_caches = [ torch.tensor([], dtype=torch.float32, device=self.device) - ] * num_layers + for _ in range(num_layers) + ] execute_model_kwargs = { "input_ids": @@ -118,7 +120,8 @@ def execute_model( device=self.device), } - hidden_states = model_executable(**execute_model_kwargs) + with forward_context(model_input.attn_metadata): + hidden_states = model_executable(**execute_model_kwargs) # Only perform pooling in the driver worker. if not self.is_driver_worker: diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 3bb4e28c6e1b..b1af1287eed5 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -11,6 +11,7 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, get_global_forced_attn_backend, global_force_attn_backend) +from vllm.compilation import forward_context from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -198,17 +199,18 @@ def execute_model( } if self.has_seqlen_agnostic else {} multi_modal_kwargs = model_input.multi_modal_kwargs or {} - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - encoder_input_ids=model_input.encoder_input_tokens, - encoder_positions=model_input.encoder_input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + with forward_context(model_input.attn_metadata): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + encoder_input_ids=model_input.encoder_input_tokens, + encoder_positions=model_input.encoder_input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) @@ -346,7 +348,8 @@ def profile_run(self) -> None: # a placeholder (it has wide hardware support). 
kv_caches = [ torch.tensor([], dtype=torch.float32, device=self.device) - ] * num_layers + for _ in range(num_layers) + ] finished_requests_ids = [seq.request_id for seq in seqs] model_input = self.prepare_model_input( seqs, finished_requests_ids=finished_requests_ids) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4ac67a5fade8..33744a12e2a2 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -14,10 +14,11 @@ import torch.distributed import torch.nn as nn -import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.utils import CommonAttentionState +from vllm.compilation import forward_context +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -46,8 +47,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d, - flatten_2d_lists, is_hip, is_pin_memory_available, - supports_dynamo) + flatten_2d_lists, is_hip, is_pin_memory_available) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -1088,15 +1088,6 @@ def load_model(self) -> None: "provided. Defaulting to scaling factors of 1.0. " "This may lead to less accurate results!") - if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo(): - from vllm.compilation.backends import vllm_backend - from vllm.plugins import get_torch_compile_backend - backend = get_torch_compile_backend() or vllm_backend - self.model = torch.compile( - self.model, - fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend=backend) - def save_sharded_state( self, path: str, @@ -1235,9 +1226,13 @@ def profile_run(self) -> None: # it by reference, rather by specializing on the value ``None``. # the `dtype` argument does not matter, and we use `float32` as # a placeholder (it has wide hardware support). + # it is important to create tensors inside the loop, rather than + # multiplying the list, to avoid Dynamo from treating them as + # tensor aliasing. kv_caches = [ torch.tensor([], dtype=torch.float32, device=self.device) - ] * num_layers + for _ in range(num_layers) + ] finished_requests_ids = [seq.request_id for seq in seqs] model_input = self.prepare_model_input( seqs, finished_requests_ids=finished_requests_ids) @@ -1386,6 +1381,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] + if isinstance(self.model, TorchCompileWrapperWithCustomDispatcher): + self.model.set_sizes_to_specialize(batch_size_capture_list) + with self.attn_state.graph_capture( max_batch_size), graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the @@ -1458,8 +1456,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: # encoder-decoder models. 
self._update_inputs_to_capture_for_enc_dec_model( capture_inputs) - - graph_runner.capture(**capture_inputs) + with forward_context(attn_metadata): + graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[virtual_engine][batch_size] = ( graph_runner) @@ -1601,15 +1599,16 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + with forward_context(model_input.attn_metadata): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 2472ac25aee4..038f0c31f95c 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -765,7 +765,8 @@ def forward( slot_mapping = slot_mapping.flatten() attn_metadata.slot_mapping = slot_mapping - hidden_states = self.model( + # directly call `forward` to avoid interference from compilation + hidden_states = self.model.forward( token_ids, position_ids, kv_caches,