From 48541109039873290ae70a03cc8bbce44b4edf16 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 28 Sep 2024 21:28:32 -0700 Subject: [PATCH 01/44] set_current_metadata --- vllm/attention/backends/abstract.py | 6 ++++++ vllm/attention/backends/flash_attn.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 2bc36ff18a96..79a026abf263 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -88,6 +88,12 @@ def advance_step(self, model_input: "ModelRunnerInputBase", block_size: int, num_seqs: int, num_queries: int) -> None: raise NotImplementedError + @contextmanager + @staticmethod + def set_current_metadata(metadata: "AttentionMetadata"): + """Context manager to set the current metadata.""" + raise NotImplementedError + @dataclass class AttentionMetadata: diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 43ca6c9ff160..d1200171b8f4 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,4 +1,5 @@ """Attention layer with FlashAttention.""" +from contextlib import contextmanager from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type @@ -221,6 +222,20 @@ def copy_blocks( value_caches = [kv_cache[1] for kv_cache in kv_caches] ops.copy_blocks(key_caches, value_caches, src_to_dists) + @contextmanager + @staticmethod + def set_current_metadata( + metadata: "FlashAttentionMetadata"): # type: ignore + global current_metadata + try: + current_metadata = metadata + yield + finally: + current_metadata = None + + +current_metadata: Optional["FlashAttentionMetadata"] = None + @dataclass class FlashAttentionMetadata(AttentionMetadata): From 02929efe29aac9c4e7acf5ccff4144300cfcbced Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 28 Sep 2024 21:33:00 -0700 Subject: [PATCH 02/44] add unified_flash_attention --- vllm/attention/backends/flash_attn.py | 269 ++++++++++++++++---------- 1 file changed, 166 insertions(+), 103 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index d1200171b8f4..426f567c77d5 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -716,108 +716,171 @@ def forward( assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") - num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if kv_cache.numel() > 0: - key_cache = kv_cache[0] - value_cache = kv_cache[1] - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - torch.ops.vllm.reshape_and_cache_flash( - key, - value, - kv_cache, - attn_metadata.slot_mapping.flatten(), - self.kv_cache_dtype, - k_scale, - v_scale, - ) + if not torch.compiler.is_compiling(): + global current_metadata + current_metadata = attn_metadata + # if torch.compiler.is_compiling(), the metadata is set + # in the context manager from the caller of the whole model. + + return torch.ops.vllm.unified_flash_attention( + query, + key, + value, + self.num_heads, + self.head_size, + self.num_kv_heads, + kv_cache, + self.kv_cache_dtype, + k_scale, + v_scale, + self.scale, + self.sliding_window, + self.alibi_slopes, + self.logits_soft_cap, + ) - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - # QKV for prefill. - query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - prefill_output: Optional[torch.Tensor] = None - decode_output: Optional[torch.Tensor] = None - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if (kv_cache.numel() == 0 or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): - # normal attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. - prefill_output = torch.ops.vllm.flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ) - else: - # prefix-enabled attention - assert prefill_meta.seq_lens is not None - max_seq_len = max(prefill_meta.seq_lens) - prefill_output = torch.ops.vllm.flash_attn_varlen_func( # noqa - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=max_seq_len, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, - softcap=self.logits_soft_cap, - ) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - decode_output = torch.ops.vllm.flash_attn_with_kvcache( - decode_query.unsqueeze(1), - key_cache, - value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, + +@torch.library.custom_op("vllm::unified_flash_attention", mutates_args=[]) +def unified_flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + num_heads: int, + head_size: int, + num_kv_heads: int, + kv_cache: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + softmax_scale: float, + window_size: Optional[List[int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + logits_soft_cap: Optional[float] = None, +) -> torch.Tensor: + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, num_heads, head_size) + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + + global current_metadata + assert current_metadata is not None + attn_metadata = current_metadata + + if kv_cache.numel() > 0: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + torch.ops.vllm.reshape_and_cache_flash( + key, + value, + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + prefill_output: Optional[torch.Tensor] = None + decode_output: Optional[torch.Tensor] = None + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if (kv_cache.numel() or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + prefill_output = torch.ops.vllm.flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=softmax_scale, causal=True, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ).squeeze(1) - - if prefill_output is None: - assert decode_output is not None - return decode_output.view(num_decode_tokens, hidden_size) - if decode_output is None: - assert prefill_output is not None - return prefill_output.view(num_prefill_tokens, hidden_size) - output = torch.cat([prefill_output, decode_output], dim=0) - return output.view(num_tokens, hidden_size) + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + ) + else: + # prefix-enabled attention + assert prefill_meta.seq_lens is not None + max_seq_len = max(prefill_meta.seq_lens) + prefill_output = torch.ops.vllm.flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, + softmax_scale=softmax_scale, + causal=True, + alibi_slopes=alibi_slopes, + block_table=prefill_meta.block_tables, + softcap=logits_soft_cap, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + decode_output = torch.ops.vllm.flash_attn_with_kvcache( + decode_query.unsqueeze(1), + key_cache, + value_cache, + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=softmax_scale, + causal=True, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + ).squeeze(1) + + if prefill_output is None: + assert decode_output is not None + return decode_output.view(num_decode_tokens, hidden_size) + if decode_output is None: + assert prefill_output is not None + return prefill_output.view(num_prefill_tokens, hidden_size) + output = torch.cat([prefill_output, decode_output], dim=0) + return output.view(num_tokens, hidden_size) + + +@unified_flash_attention.register_fake +def _( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + num_heads: int, + head_size: int, + num_kv_heads: int, + kv_cache: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + softmax_scale: float, + window_size: Optional[List[int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + logits_soft_cap: Optional[float] = None, +) -> torch.Tensor: + return torch.empty_like(query) From 383b51c3676baacd4788922bd248c775256c82d2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 28 Sep 2024 22:55:55 -0700 Subject: [PATCH 03/44] expose attention_backend from attention metadata --- vllm/attention/backends/abstract.py | 4 ++++ vllm/attention/backends/flash_attn.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 79a026abf263..e91065972a98 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -111,6 +111,10 @@ class AttentionMetadata: # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor + @property + def attention_backend(self): + return AttentionBackend + @property @abstractmethod def prefill_metadata(self) -> Optional["AttentionMetadata"]: diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 426f567c77d5..2710a2db56a9 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -296,6 +296,10 @@ class FlashAttentionMetadata(AttentionMetadata): _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + @property + def attention_backend(self): + return FlashAttentionBackend + @property def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: if self.num_prefills == 0: From 861a65e6ea6b8a8100202016101ac1b3ca7c18eb Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 28 Sep 2024 23:17:40 -0700 Subject: [PATCH 04/44] init draft --- vllm/compilation/backends.py | 86 ++++++++++++++++++++++++++++- vllm/compilation/wrapper.py | 47 +++++++++++++++- vllm/model_executor/models/llama.py | 38 ++++++++++++- vllm/plugins/__init__.py | 14 ++++- 4 files changed, 180 insertions(+), 5 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index de0b1d8a7575..0ceff54a6d57 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,8 +1,16 @@ import operator +from typing import Callable, Dict, Optional, Tuple +from weakref import ReferenceType import torch import torch.fx as fx +from vllm.logger import init_logger + +from .wrapper import TorchCompileWrapperWithCustomDispatcher + +logger = init_logger(__name__) + def fix_functionalization(graph: fx.Graph): """ @@ -148,9 +156,85 @@ def fix_functionalization(graph: fx.Graph): # print(graph.python_code(root_module="self", verbose=True).src, file=f) -def vllm_backend(graph, example_inputs): +def wrap_inductor(graph, example_inputs, additional_inductor_config): from torch._inductor import config current_config = config.shallow_copy_dict() from torch._inductor.compile_fx import compile_fx + current_config['post_grad_custom_post_pass'] = fix_functionalization + if additional_inductor_config is not None: + current_config.update(additional_inductor_config) return compile_fx(graph, example_inputs, config_patches=current_config) + + +def vllm_backend( + graph, + example_inputs, + model_ref: Optional[ + ReferenceType[TorchCompileWrapperWithCustomDispatcher]] = None, + additional_inductor_config: Optional[Dict] = None) -> Callable: + + # flags for all the seen shapes, whether we need to specialize + runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {} + + # if we need to specialize, the compiled graph for that shape + runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {} + + # this is the first compilation, we will compile a graph with + # dynamic shape, as the caller will mark first dimension as dynamic + logger.info("Compiling a graph for general shapes") + graph_for_symbolic_shape = wrap_inductor(graph, example_inputs, + additional_inductor_config) + + first_run = True + + # this is the function we return to Dynamo to run finally + def compiled_graph_wrapper(*args): + + # Dynamo calling convention: the first integer arguments are the + # runtime shapes of the dynamic dimensions + runtime_shapes = [] + for x in args: + if isinstance(x, int): + runtime_shapes.append(x) + else: + # important to break and exit early + # the list of args can be very long + break + + nonlocal first_run + nonlocal runtime_shapes_to_compile_flags + nonlocal runtime_shapes_to_compiled_graph + + if first_run: + # the first compilation is for profiling, we directly run it + first_run = False + return graph_for_symbolic_shape(*args) + + if model_ref is None: + # no information about the model, we cannot specialize + return graph_for_symbolic_shape(*args) + + model: TorchCompileWrapperWithCustomDispatcher = model_ref() + assert model is not None, "model is garbage collected" + + if runtime_shapes not in runtime_shapes_to_compile_flags: + # we haven't seen this shape before + # query the model if we need to specialize for this shape + runtime_shapes_to_compile_flags[ + runtime_shapes] = model.need_to_specialize(runtime_shapes) + + if not runtime_shapes_to_compile_flags[runtime_shapes]: + # we don't need to specialize for this shape + return graph_for_symbolic_shape(*args) + + if runtime_shapes not in runtime_shapes_to_compiled_graph: + # we need to specialize for this shape, and we haven't compiled + # compile the graph for this shape + logger.info("Compiling a graph for shapes %s", runtime_shapes) + runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor( + graph, args, additional_inductor_config) + + return runtime_shapes_to_compiled_graph[runtime_shapes](*args) + + return compiled_graph_wrapper diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index e923bd36ccc0..b585237d2ff9 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -1,9 +1,10 @@ import os import sys +import weakref from abc import abstractmethod from contextlib import contextmanager from types import CodeType -from typing import Callable, List +from typing import Callable, List, Optional, Tuple import torch @@ -23,7 +24,37 @@ class TorchCompileWrapperWithCustomDispatcher: `torch.compile` over the forward method. """ - def __init__(self, compiled_callable: Callable): + def __init__(self, compiled_callable: Optional[Callable] = None): + + if compiled_callable is None: + # default compilation settings + # compiling the forward method + + # choose the compile backend + + # if the user has set the backend, use it + from vllm.plugins import get_torch_compile_backend + backend = get_torch_compile_backend() + if backend is None: + # otherwise, use the default backend, + # which compiles one general graph and + # several specialized graphs + from vllm.compilation.backends import vllm_backend + # in this case, users can only customize the inductor config + from vllm.plugins import get_inductor_additional_configs + additional_configs = get_inductor_additional_configs() + + from functools import partial + backend = partial( + vllm_backend, + model=weakref.ref(self), + additional_inductor_config=additional_configs) + + compiled_callable = torch.compile( + self.forward, + fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=backend) + self.compiled_callable = compiled_callable self.original_code_object = self.__class__.forward.__code__ self.compiled_codes: List[CodeType] = [] @@ -35,6 +66,8 @@ def __init__(self, compiled_callable: Callable): self.use_custom_dispatcher: bool = \ envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER + self.sizes_to_specialize = [] + def __call__(self, *args, **kwargs): """Implement the dispatch logic here, beyond the torch.compile level. NOTE: this function can have additional arguments beyond the forward @@ -79,3 +112,13 @@ def dispatch_to_code(self, index: int): self.__class__.forward.__code__ = self.compiled_codes[index] yield self.__class__.forward.__code__ = self.original_code_object + + def need_to_specialize(self, runtime_shapes: Tuple[int, ...]) -> bool: + """Check if the current runtime shapes need to be specialized. + If not, we can use the graph for general shapes. + If yes, we will compile the graph for the current shapes. + The argument `runtime_shapes` is a tuple of integers, representing + the runtime shapes of the dimensions marked as dynamic during graph + capture. + """ + return False diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 5ff31e3833ec..61d27d3f3463 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -27,7 +27,9 @@ from torch import nn from transformers import LlamaConfig +import vllm.envs as envs from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -344,7 +346,8 @@ def forward( return hidden_states -class LlamaForCausalLM(nn.Module, SupportsLoRA): +class LlamaForCausalLM(nn.Module, SupportsLoRA, + TorchCompileWrapperWithCustomDispatcher): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -436,6 +439,39 @@ def __init__( self.sampler = Sampler() else: self.lm_head = PPMissingLayer() + self._use_torch_compile = envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE + if self._use_torch_compile: + TorchCompileWrapperWithCustomDispatcher.__init__(self) + + def need_to_specialize(self, runtime_shapes: Tuple[int, ...]) -> bool: + if len(self.sizes_to_specialize) == 0: + return False + return runtime_shapes[0] in self.sizes_to_specialize + + def __call__( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if not self._use_torch_compile: + return self.forward(input_ids, positions, kv_caches, attn_metadata, + intermediate_tensors) + with attn_metadata.attention_backend.set_current_metadata( + attn_metadata): + if len(self.compiled_codes) < 1: + torch._dynamo.mark_dynamic(input_ids, 0) + torch._dynamo.mark_dynamic(positions, 0) + return self.compiled_callable(input_ids, positions, kv_caches, + attn_metadata, + intermediate_tensors) + with self.dispatch_to_code(0): + model_output = self.forward(input_ids, positions, kv_caches, + attn_metadata, + intermediate_tensors) + return model_output def forward( self, diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 7939688ef0da..1e89ba8569a9 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Optional, Union +from typing import Callable, Dict, Optional, Union import vllm.envs as envs @@ -42,3 +42,15 @@ def set_torch_compile_backend(backend: Union[Callable, str]): def get_torch_compile_backend() -> Optional[Union[Callable, str]]: return _torch_compile_backend + + +_inductor_additional_configs: Optional[Dict] = None + + +def set_inductor_additional_configs(configs: Optional[Dict]): + global _inductor_additional_configs + _inductor_additional_configs = configs + + +def get_inductor_additional_configs() -> Optional[Dict]: + return _inductor_additional_configs From d751293916953766ee862fc9c23750019ce88fbc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 28 Sep 2024 23:25:00 -0700 Subject: [PATCH 05/44] finish --- vllm/compilation/wrapper.py | 6 +++++- vllm/worker/model_runner.py | 17 +++++------------ 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index b585237d2ff9..2e7b0c84da2d 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -4,7 +4,7 @@ from abc import abstractmethod from contextlib import contextmanager from types import CodeType -from typing import Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple import torch @@ -113,6 +113,10 @@ def dispatch_to_code(self, index: int): yield self.__class__.forward.__code__ = self.original_code_object + def set_sizes_to_specialize(self, sizes: List[Any]): + """Set the sizes to specialize for the compiled code.""" + self.sizes_to_specialize = sizes + def need_to_specialize(self, runtime_shapes: Tuple[int, ...]) -> bool: """Check if the current runtime shapes need to be specialized. If not, we can use the graph for general shapes. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4ac67a5fade8..9db7c230b40c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -14,10 +14,10 @@ import torch.distributed import torch.nn as nn -import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.utils import CommonAttentionState +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -46,8 +46,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d, - flatten_2d_lists, is_hip, is_pin_memory_available, - supports_dynamo) + flatten_2d_lists, is_hip, is_pin_memory_available) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -1088,15 +1087,6 @@ def load_model(self) -> None: "provided. Defaulting to scaling factors of 1.0. " "This may lead to less accurate results!") - if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo(): - from vllm.compilation.backends import vllm_backend - from vllm.plugins import get_torch_compile_backend - backend = get_torch_compile_backend() or vllm_backend - self.model = torch.compile( - self.model, - fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend=backend) - def save_sharded_state( self, path: str, @@ -1386,6 +1376,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] + if isinstance(self.model, TorchCompileWrapperWithCustomDispatcher): + self.model.set_sizes_to_specialize(batch_size_capture_list) + with self.attn_state.graph_capture( max_batch_size), graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the From dc5e931762b2b7e1aa46e970a7a3ec823ae95e6d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 28 Sep 2024 23:26:54 -0700 Subject: [PATCH 06/44] warning for overwritten config --- vllm/compilation/backends.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 0ceff54a6d57..0e3ad2dc5840 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -161,9 +161,13 @@ def wrap_inductor(graph, example_inputs, additional_inductor_config): current_config = config.shallow_copy_dict() from torch._inductor.compile_fx import compile_fx - current_config['post_grad_custom_post_pass'] = fix_functionalization if additional_inductor_config is not None: current_config.update(additional_inductor_config) + if 'post_grad_custom_post_pass' in current_config: + logger.warning( + "post_grad_custom_post_pass is already set in the config. " + "Overwriting it with the fix_functionalization") + current_config['post_grad_custom_post_pass'] = fix_functionalization return compile_fx(graph, example_inputs, config_patches=current_config) From 558ea39ea8eb2d23e676c2f6659b191c6f49e592 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 00:22:42 -0700 Subject: [PATCH 07/44] unify flags --- tests/compile/utils.py | 4 ++-- vllm/compilation/backends.py | 26 +++++++++++++++++++++++++- vllm/compilation/wrapper.py | 18 +++++------------- vllm/envs.py | 17 +++++++++-------- vllm/model_executor/custom_op.py | 2 +- vllm/model_executor/models/llama.py | 2 +- vllm/plugins/__init__.py | 6 +++--- 7 files changed, 46 insertions(+), 29 deletions(-) diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 2d06a0946d91..f3261131e7dc 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -70,8 +70,8 @@ def check_full_graph_support(model, model_kwargs, backend, tp_size=1): # make sure these models can be captured in full graph mode - if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ: - os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" + if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: + os.environ["VLLM_TORCH_COMPILE_LEVEL"] = "1" os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1" # Inductor doesn't support fp8/gptq_marlin_24 yet. diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 0e3ad2dc5840..313ab6112069 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,5 +1,5 @@ import operator -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple, Union from weakref import ReferenceType import torch @@ -242,3 +242,27 @@ def compiled_graph_wrapper(*args): return runtime_shapes_to_compiled_graph[runtime_shapes](*args) return compiled_graph_wrapper + + +def select_default_backend(level: int) -> Union[str, Callable]: + if level == 1: + backend = "eager" + return backend + assert level in [2, 3], f"Invalid level {level}" + + from vllm.compilation.backends import vllm_backend + from vllm.plugins import get_inductor_additional_configs + additional_configs = get_inductor_additional_configs() + + if level == 3: + if "max_autotune" in additional_configs and not additional_configs[ + "max_autotune"]: + logger.warning( + "max_autotune is disabled, but is overridden by level 3") + additional_configs['max_autotune'] = True + + from functools import partial + backend = partial(vllm_backend, + additional_inductor_config=additional_configs) + + return backend diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 2e7b0c84da2d..e860d4ebf714 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -36,19 +36,11 @@ def __init__(self, compiled_callable: Optional[Callable] = None): from vllm.plugins import get_torch_compile_backend backend = get_torch_compile_backend() if backend is None: - # otherwise, use the default backend, - # which compiles one general graph and - # several specialized graphs - from vllm.compilation.backends import vllm_backend - # in this case, users can only customize the inductor config - from vllm.plugins import get_inductor_additional_configs - additional_configs = get_inductor_additional_configs() - - from functools import partial - backend = partial( - vllm_backend, - model=weakref.ref(self), - additional_inductor_config=additional_configs) + from vllm.compilation.backends import get_default_backend + backend = get_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL) + if not isinstance(backend, str): + from functools import partial + backend = partial(backend, model=weakref.ref(self)) compiled_callable = torch.compile( self.forward, diff --git a/vllm/envs.py b/vllm/envs.py index 7cbffc83a625..22fe7dbe22a8 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -201,19 +201,20 @@ def get_default_config_root(): "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH": lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BEAM_SEARCH", "0") == "1", - # Internal flag to enable Dynamo graph capture - "VLLM_TEST_DYNAMO_GRAPH_CAPTURE": - lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), + # torch.compile optimization level + # 0: no optimization (don't use torch.compile) + # 1: capture the graph, run in eager mode (seconds of compilation time) + # 2: capture the graph, compile with inductor (minutes of compilation time) + # 3: capture the graph, compile with inductor max-autotune (dozens of minutes of compilation time) # noqa + "VLLM_TORCH_COMPILE_LEVEL": + lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")), + + # Internal flag for Dynamo testing "VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": lambda: (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in ("true", "1")), - # Internal flag to control whether we use custom op, - # or use the native pytorch implementation - "VLLM_TEST_COMPILE_NO_CUSTOM_OPS": - lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")), - # Internal flag to enable Dynamo fullgraph capture "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": lambda: bool( diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 9102b5e19ebe..09fe35ad6179 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -55,7 +55,7 @@ def dispatch_forward(self): # NOTE(woosuk): Here we assume that vLLM was built for only one # specific backend. Currently, we do not support dynamic dispatching. - if envs.VLLM_TEST_COMPILE_NO_CUSTOM_OPS: + if envs.VLLM_TORCH_COMPILE_LEVEL >= 2: return self.forward_native if is_hip(): diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 61d27d3f3463..35aff20c0ae4 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -439,7 +439,7 @@ def __init__( self.sampler = Sampler() else: self.lm_head = PPMissingLayer() - self._use_torch_compile = envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE + self._use_torch_compile = envs.VLLM_TORCH_COMPILE_LEVEL > 0 if self._use_torch_compile: TorchCompileWrapperWithCustomDispatcher.__init__(self) diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 1e89ba8569a9..211fedbc6e2e 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -44,13 +44,13 @@ def get_torch_compile_backend() -> Optional[Union[Callable, str]]: return _torch_compile_backend -_inductor_additional_configs: Optional[Dict] = None +_inductor_additional_configs: Dict = {} -def set_inductor_additional_configs(configs: Optional[Dict]): +def set_inductor_additional_configs(configs: Dict): global _inductor_additional_configs _inductor_additional_configs = configs -def get_inductor_additional_configs() -> Optional[Dict]: +def get_inductor_additional_configs() -> Dict: return _inductor_additional_configs From 1074d7af1daed9606c17693606be552707420503 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 00:27:41 -0700 Subject: [PATCH 08/44] fix code --- vllm/compilation/wrapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index e860d4ebf714..0b12e635af53 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -36,8 +36,8 @@ def __init__(self, compiled_callable: Optional[Callable] = None): from vllm.plugins import get_torch_compile_backend backend = get_torch_compile_backend() if backend is None: - from vllm.compilation.backends import get_default_backend - backend = get_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL) + from vllm.compilation.backends import select_default_backend + backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL) if not isinstance(backend, str): from functools import partial backend = partial(backend, model=weakref.ref(self)) From 6f65ec51e473b98a4a0908b72424f3c35725b6b5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 00:37:01 -0700 Subject: [PATCH 09/44] store forward context --- vllm/attention/backends/abstract.py | 10 ---------- vllm/attention/backends/flash_attn.py | 28 +++++---------------------- vllm/compilation/__init__.py | 28 +++++++++++++++++++++++++++ vllm/model_executor/models/llama.py | 4 ++-- 4 files changed, 35 insertions(+), 35 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index e91065972a98..2bc36ff18a96 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -88,12 +88,6 @@ def advance_step(self, model_input: "ModelRunnerInputBase", block_size: int, num_seqs: int, num_queries: int) -> None: raise NotImplementedError - @contextmanager - @staticmethod - def set_current_metadata(metadata: "AttentionMetadata"): - """Context manager to set the current metadata.""" - raise NotImplementedError - @dataclass class AttentionMetadata: @@ -111,10 +105,6 @@ class AttentionMetadata: # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor - @property - def attention_backend(self): - return AttentionBackend - @property @abstractmethod def prefill_metadata(self) -> Optional["AttentionMetadata"]: diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 2710a2db56a9..5a14f5c194cb 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,5 +1,4 @@ """Attention layer with FlashAttention.""" -from contextlib import contextmanager from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type @@ -14,6 +13,7 @@ compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) +from vllm.compilation import get_forward_context, set_forward_context from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: @@ -222,20 +222,6 @@ def copy_blocks( value_caches = [kv_cache[1] for kv_cache in kv_caches] ops.copy_blocks(key_caches, value_caches, src_to_dists) - @contextmanager - @staticmethod - def set_current_metadata( - metadata: "FlashAttentionMetadata"): # type: ignore - global current_metadata - try: - current_metadata = metadata - yield - finally: - current_metadata = None - - -current_metadata: Optional["FlashAttentionMetadata"] = None - @dataclass class FlashAttentionMetadata(AttentionMetadata): @@ -296,10 +282,6 @@ class FlashAttentionMetadata(AttentionMetadata): _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None - @property - def attention_backend(self): - return FlashAttentionBackend - @property def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: if self.num_prefills == 0: @@ -721,8 +703,7 @@ def forward( "key/v_scale is not supported in FlashAttention.") if not torch.compiler.is_compiling(): - global current_metadata - current_metadata = attn_metadata + set_forward_context(attn_metadata) # if torch.compiler.is_compiling(), the metadata is set # in the context manager from the caller of the whole model. @@ -767,9 +748,10 @@ def unified_flash_attention( key = key.view(-1, num_kv_heads, head_size) value = value.view(-1, num_kv_heads, head_size) - global current_metadata + current_metadata = get_forward_context() assert current_metadata is not None - attn_metadata = current_metadata + assert isinstance(current_metadata, FlashAttentionMetadata) + attn_metadata: FlashAttentionMetadata = current_metadata if kv_cache.numel() > 0: key_cache = kv_cache[0] diff --git a/vllm/compilation/__init__.py b/vllm/compilation/__init__.py index e69de29bb2d1..229503b167d7 100644 --- a/vllm/compilation/__init__.py +++ b/vllm/compilation/__init__.py @@ -0,0 +1,28 @@ +from contextlib import contextmanager +from typing import Any + +_forward_context: Any = None + + +def get_forward_context() -> Any: + """Get the current forward context.""" + return _forward_context + + +def set_forward_context(context: Any): + """Set the current forward context.""" + global _forward_context + _forward_context = context + + +@contextmanager +def forward_context(context: Any): + """A context manager that stores the current forward context, + can be attention metadata, etc.""" + global _forward_context + prev_context = _forward_context + _forward_context = context + try: + yield + finally: + _forward_context = prev_context diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 35aff20c0ae4..615e161703bf 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -29,6 +29,7 @@ import vllm.envs as envs from vllm.attention import Attention, AttentionMetadata +from vllm.compilation import forward_context from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -459,8 +460,7 @@ def __call__( if not self._use_torch_compile: return self.forward(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) - with attn_metadata.attention_backend.set_current_metadata( - attn_metadata): + with forward_context(attn_metadata): if len(self.compiled_codes) < 1: torch._dynamo.mark_dynamic(input_ids, 0) torch._dynamo.mark_dynamic(positions, 0) From e6c21c7f0629ce130ad1b2a7938df7606e725db6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 00:41:44 -0700 Subject: [PATCH 10/44] fix --- vllm/compilation/wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 0b12e635af53..c147a13b411a 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -40,7 +40,7 @@ def __init__(self, compiled_callable: Optional[Callable] = None): backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL) if not isinstance(backend, str): from functools import partial - backend = partial(backend, model=weakref.ref(self)) + backend = partial(backend, model_ref=weakref.ref(self)) compiled_callable = torch.compile( self.forward, From ae97d2ca4e32a093f468dc3ef93f8677934ba4f3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 00:44:15 -0700 Subject: [PATCH 11/44] fix --- vllm/compilation/backends.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 313ab6112069..4c58d2597b49 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -163,7 +163,7 @@ def wrap_inductor(graph, example_inputs, additional_inductor_config): if additional_inductor_config is not None: current_config.update(additional_inductor_config) - if 'post_grad_custom_post_pass' in current_config: + if current_config['post_grad_custom_post_pass'] is not None: logger.warning( "post_grad_custom_post_pass is already set in the config. " "Overwriting it with the fix_functionalization") @@ -206,6 +206,8 @@ def compiled_graph_wrapper(*args): # the list of args can be very long break + runtime_shapes = tuple(runtime_shapes) + nonlocal first_run nonlocal runtime_shapes_to_compile_flags nonlocal runtime_shapes_to_compiled_graph From 2b4fe53bb41cc7fdecdcf9f3903a4b7375018884 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 10:37:46 -0700 Subject: [PATCH 12/44] get symint --- vllm/compilation/backends.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 4c58d2597b49..9fd52894a864 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -190,23 +190,17 @@ def vllm_backend( graph_for_symbolic_shape = wrap_inductor(graph, example_inputs, additional_inductor_config) + sym_shape_indices = [ + i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt) + ] + first_run = True # this is the function we return to Dynamo to run finally def compiled_graph_wrapper(*args): - # Dynamo calling convention: the first integer arguments are the - # runtime shapes of the dynamic dimensions - runtime_shapes = [] - for x in args: - if isinstance(x, int): - runtime_shapes.append(x) - else: - # important to break and exit early - # the list of args can be very long - break - - runtime_shapes = tuple(runtime_shapes) + runtime_shapes: Tuple[int, + ...] = tuple(args[i] for i in sym_shape_indices) nonlocal first_run nonlocal runtime_shapes_to_compile_flags From a6f0e3be00fbffd7580ef30e3353350293660969 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 10:40:49 -0700 Subject: [PATCH 13/44] fix bugs --- vllm/worker/model_runner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9db7c230b40c..50cc9a716980 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1225,9 +1225,13 @@ def profile_run(self) -> None: # it by reference, rather by specializing on the value ``None``. # the `dtype` argument does not matter, and we use `float32` as # a placeholder (it has wide hardware support). + # it is important to create tensors inside the loop, rather than + # multiplying the list, to avoid Dynamo from treating them as + # tensor aliasing. kv_caches = [ torch.tensor([], dtype=torch.float32, device=self.device) - ] * num_layers + for _ in range(num_layers) + ] finished_requests_ids = [seq.request_id for seq in seqs] model_input = self.prepare_model_input( seqs, finished_requests_ids=finished_requests_ids) From 99a281e6a113b458cc92367058edcb0031233f58 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 10:42:39 -0700 Subject: [PATCH 14/44] fix the rest --- vllm/worker/embedding_model_runner.py | 3 ++- vllm/worker/enc_dec_model_runner.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 5c5d20a51e7d..1ccf10f1a60d 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -103,7 +103,8 @@ def execute_model( # a placeholder (it has wide hardware support). kv_caches = [ torch.tensor([], dtype=torch.float32, device=self.device) - ] * num_layers + for _ in range(num_layers) + ] execute_model_kwargs = { "input_ids": diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 3bb4e28c6e1b..3e37b3519c8e 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -346,7 +346,8 @@ def profile_run(self) -> None: # a placeholder (it has wide hardware support). kv_caches = [ torch.tensor([], dtype=torch.float32, device=self.device) - ] * num_layers + for _ in range(num_layers) + ] finished_requests_ids = [seq.request_id for seq in seqs] model_input = self.prepare_model_input( seqs, finished_requests_ids=finished_requests_ids) From 44328eb5e0f74f5abfdcda3963ff3a3aa8fcb076 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 10:55:39 -0700 Subject: [PATCH 15/44] fix tpu --- vllm/worker/tpu_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 2472ac25aee4..038f0c31f95c 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -765,7 +765,8 @@ def forward( slot_mapping = slot_mapping.flatten() attn_metadata.slot_mapping = slot_mapping - hidden_states = self.model( + # directly call `forward` to avoid interference from compilation + hidden_states = self.model.forward( token_ids, position_ids, kv_caches, From 500430b650a0b398d90b6b1d0c08d111424001cf Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 11:00:20 -0700 Subject: [PATCH 16/44] leave todo --- vllm/compilation/backends.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 9fd52894a864..98ab6667a689 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -190,6 +190,9 @@ def vllm_backend( graph_for_symbolic_shape = wrap_inductor(graph, example_inputs, additional_inductor_config) + # TODO: Dynamo does not pass all dynamic shapes. + # Need to investigate why. It works now because all the dynamic + # shapes have the same value, and either of them can be used. sym_shape_indices = [ i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt) ] From 5b50c684623ec86dfe9efbb12e456a30635940e4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 11:07:27 -0700 Subject: [PATCH 17/44] add tests --- tests/compile/test_compile.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 tests/compile/test_compile.py diff --git a/tests/compile/test_compile.py b/tests/compile/test_compile.py new file mode 100644 index 000000000000..5e03ca19f88a --- /dev/null +++ b/tests/compile/test_compile.py @@ -0,0 +1,31 @@ +import os + +from vllm import LLM, SamplingParams + + +def test_compile_correctness(): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0) + + all_outputs = [] + all_levels = [0, 1, 2] + for level in all_levels: + os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(level) + llm = LLM(model="meta-llama/Meta-Llama-3-8B", + enforce_eager=True, + tensor_parallel_size=1, + disable_custom_all_reduce=True) + outputs = llm.generate(prompts, sampling_params) + all_outputs.append(outputs) + reference_outputs = all_outputs[0] + for level, outputs in zip(all_levels[1:], all_outputs[1:]): + for ref_output, output in zip(reference_outputs, outputs): + prompt = output.prompt + generated_text = output.outputs[0].text + ref_generated_text = ref_output.outputs[0].text + assert generated_text == ref_generated_text, f"level: {level}, prompt: {prompt}, generated_text: {generated_text}, ref_generated_text: {ref_generated_text}" # noqa From 55d54fef768df7200ff2882c26cc377661718cd5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 11:09:38 -0700 Subject: [PATCH 18/44] run 3 tests --- tests/compile/test_compile.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/compile/test_compile.py b/tests/compile/test_compile.py index 5e03ca19f88a..29ae02abd353 100644 --- a/tests/compile/test_compile.py +++ b/tests/compile/test_compile.py @@ -19,7 +19,8 @@ def test_compile_correctness(): llm = LLM(model="meta-llama/Meta-Llama-3-8B", enforce_eager=True, tensor_parallel_size=1, - disable_custom_all_reduce=True) + disable_custom_all_reduce=True, + gpu_memory_utilization=0.3) outputs = llm.generate(prompts, sampling_params) all_outputs.append(outputs) reference_outputs = all_outputs[0] From 954caf806c158d43285ad91e2318a871deefab7a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 11:10:42 -0700 Subject: [PATCH 19/44] rename --- tests/compile/{test_compile.py => test_correctness.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/compile/{test_compile.py => test_correctness.py} (100%) diff --git a/tests/compile/test_compile.py b/tests/compile/test_correctness.py similarity index 100% rename from tests/compile/test_compile.py rename to tests/compile/test_correctness.py From ee2100e593eb7bcac6d8ff2d1c5e6c6f4532a7dc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 11:16:44 -0700 Subject: [PATCH 20/44] support pp --- vllm/model_executor/models/llama.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 615e161703bf..57a8f94aad64 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -464,6 +464,9 @@ def __call__( if len(self.compiled_codes) < 1: torch._dynamo.mark_dynamic(input_ids, 0) torch._dynamo.mark_dynamic(positions, 0) + if intermediate_tensors is not None: + for tensors in intermediate_tensors.tensors.values(): + torch._dynamo.mark_dynamic(tensors, 0) return self.compiled_callable(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) From b5fc0f1ccd980c4ef7d832e740296fcecb5cbbad Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 11:31:01 -0700 Subject: [PATCH 21/44] move to decorators --- vllm/compilation/decorators.py | 60 +++++++++++++++++++++++++++++ vllm/model_executor/models/llama.py | 43 ++------------------- 2 files changed, 63 insertions(+), 40 deletions(-) create mode 100644 vllm/compilation/decorators.py diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py new file mode 100644 index 000000000000..602334e8fa44 --- /dev/null +++ b/vllm/compilation/decorators.py @@ -0,0 +1,60 @@ +from typing import List, Optional, Tuple, Union + +import torch + +import vllm.envs as envs +from vllm.attention import AttentionMetadata +from vllm.compilation import forward_context +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.sequence import IntermediateTensors + + +def support_compile_llama_style(cls: type): + cls.__bases__ = (TorchCompileWrapperWithCustomDispatcher, ) + cls.__bases__ + + old_init = cls.__init__ + + def __init__(self, *args, **kwargs): + old_init(self, *args, **kwargs) + self._use_torch_compile = envs.VLLM_TORCH_COMPILE_LEVEL > 0 + if self._use_torch_compile: + TorchCompileWrapperWithCustomDispatcher.__init__(self) + + cls.__init__ = __init__ + + def need_to_specialize(self, runtime_shapes: Tuple[int, ...]) -> bool: + if len(self.sizes_to_specialize) == 0: + return False + return runtime_shapes[0] in self.sizes_to_specialize + + cls.need_to_specialize = need_to_specialize + + def __call__( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if not self._use_torch_compile: + return self.forward(input_ids, positions, kv_caches, attn_metadata, + intermediate_tensors) + with forward_context(attn_metadata): + if len(self.compiled_codes) < 1: + torch._dynamo.mark_dynamic(input_ids, 0) + torch._dynamo.mark_dynamic(positions, 0) + if intermediate_tensors is not None: + for tensors in intermediate_tensors.tensors.values(): + torch._dynamo.mark_dynamic(tensors, 0) + return self.compiled_callable(input_ids, positions, kv_caches, + attn_metadata, + intermediate_tensors) + with self.dispatch_to_code(0): + model_output = self.forward(input_ids, positions, kv_caches, + attn_metadata, + intermediate_tensors) + return model_output + + cls.__call__ = __call__ + return cls diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 57a8f94aad64..d44e29b78bdb 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -27,10 +27,8 @@ from torch import nn from transformers import LlamaConfig -import vllm.envs as envs from vllm.attention import Attention, AttentionMetadata -from vllm.compilation import forward_context -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.compilation.decorators import support_compile_llama_style from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -347,8 +345,8 @@ def forward( return hidden_states -class LlamaForCausalLM(nn.Module, SupportsLoRA, - TorchCompileWrapperWithCustomDispatcher): +@support_compile_llama_style +class LlamaForCausalLM(nn.Module, SupportsLoRA): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -440,41 +438,6 @@ def __init__( self.sampler = Sampler() else: self.lm_head = PPMissingLayer() - self._use_torch_compile = envs.VLLM_TORCH_COMPILE_LEVEL > 0 - if self._use_torch_compile: - TorchCompileWrapperWithCustomDispatcher.__init__(self) - - def need_to_specialize(self, runtime_shapes: Tuple[int, ...]) -> bool: - if len(self.sizes_to_specialize) == 0: - return False - return runtime_shapes[0] in self.sizes_to_specialize - - def __call__( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if not self._use_torch_compile: - return self.forward(input_ids, positions, kv_caches, attn_metadata, - intermediate_tensors) - with forward_context(attn_metadata): - if len(self.compiled_codes) < 1: - torch._dynamo.mark_dynamic(input_ids, 0) - torch._dynamo.mark_dynamic(positions, 0) - if intermediate_tensors is not None: - for tensors in intermediate_tensors.tensors.values(): - torch._dynamo.mark_dynamic(tensors, 0) - return self.compiled_callable(input_ids, positions, kv_caches, - attn_metadata, - intermediate_tensors) - with self.dispatch_to_code(0): - model_output = self.forward(input_ids, positions, kv_caches, - attn_metadata, - intermediate_tensors) - return model_output def forward( self, From 246e6e5fb8e4ca9521afa3cf2731cada346130e0 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 11:38:25 -0700 Subject: [PATCH 22/44] fix mro --- vllm/compilation/decorators.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 602334e8fa44..709049f51cd2 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -10,7 +10,10 @@ def support_compile_llama_style(cls: type): - cls.__bases__ = (TorchCompileWrapperWithCustomDispatcher, ) + cls.__bases__ + # take care of method resolution order + # make sure super().__init__ is called on the base class + # other than TorchCompileWrapperWithCustomDispatcher + cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) old_init = cls.__init__ From 49aa7ccff73e21c8e5d26d9556c00bf2a8899cdb Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 11:40:01 -0700 Subject: [PATCH 23/44] add comments --- vllm/compilation/decorators.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 709049f51cd2..75ae578ac8d0 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -10,6 +10,12 @@ def support_compile_llama_style(cls: type): + """ + A decorator to add support for compiling the forward method of a class. + If a module's **forward signature** is compatible with llama, this + decorator can be used to enable the compilation of the forward method. + """ + # take care of method resolution order # make sure super().__init__ is called on the base class # other than TorchCompileWrapperWithCustomDispatcher From 99144b3a63613c3eb3f4b98b15de46dd6e674c69 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 12:10:12 -0700 Subject: [PATCH 24/44] fix mutates_args --- vllm/attention/backends/flash_attn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 5a14f5c194cb..346f9393d176 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -725,7 +725,8 @@ def forward( ) -@torch.library.custom_op("vllm::unified_flash_attention", mutates_args=[]) +@torch.library.custom_op("vllm::unified_flash_attention", + mutates_args=["kv_cache"]) def unified_flash_attention( query: torch.Tensor, key: torch.Tensor, From 6ae09bd08c7f59a5414c8b7050542d28b51ce20b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 14:02:58 -0700 Subject: [PATCH 25/44] fix forward context --- vllm/attention/backends/flash_attn.py | 8 +++++++- vllm/worker/model_runner.py | 24 +++++++++++++----------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 346f9393d176..fa803cc6d76c 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -703,11 +703,12 @@ def forward( "key/v_scale is not supported in FlashAttention.") if not torch.compiler.is_compiling(): + old_context = get_forward_context() set_forward_context(attn_metadata) # if torch.compiler.is_compiling(), the metadata is set # in the context manager from the caller of the whole model. - return torch.ops.vllm.unified_flash_attention( + output = torch.ops.vllm.unified_flash_attention( query, key, value, @@ -724,6 +725,11 @@ def forward( self.logits_soft_cap, ) + if not torch.compiler.is_compiling(): + set_forward_context(old_context) + + return output + @torch.library.custom_op("vllm::unified_flash_attention", mutates_args=["kv_cache"]) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 50cc9a716980..33744a12e2a2 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -17,6 +17,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.utils import CommonAttentionState +from vllm.compilation import forward_context from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, @@ -1455,8 +1456,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: # encoder-decoder models. self._update_inputs_to_capture_for_enc_dec_model( capture_inputs) - - graph_runner.capture(**capture_inputs) + with forward_context(attn_metadata): + graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[virtual_engine][batch_size] = ( graph_runner) @@ -1598,15 +1599,16 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + with forward_context(model_input.attn_metadata): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): From ec2191fa09baf9464616e29a853f9d6fb13f00e4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 14:06:42 -0700 Subject: [PATCH 26/44] surface errors --- vllm/attention/backends/flash_attn.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index fa803cc6d76c..f0a21262a076 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -13,7 +13,7 @@ compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) -from vllm.compilation import get_forward_context, set_forward_context +from vllm.compilation import get_forward_context from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: @@ -702,12 +702,6 @@ def forward( assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") - if not torch.compiler.is_compiling(): - old_context = get_forward_context() - set_forward_context(attn_metadata) - # if torch.compiler.is_compiling(), the metadata is set - # in the context manager from the caller of the whole model. - output = torch.ops.vllm.unified_flash_attention( query, key, @@ -725,9 +719,6 @@ def forward( self.logits_soft_cap, ) - if not torch.compiler.is_compiling(): - set_forward_context(old_context) - return output From 889794e005e0b683812afc7db63ea288b6e60be9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 14:11:20 -0700 Subject: [PATCH 27/44] fix more --- vllm/worker/embedding_model_runner.py | 4 +++- vllm/worker/enc_dec_model_runner.py | 24 +++++++++++++----------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 1ccf10f1a60d..6e7d452ba670 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -3,6 +3,7 @@ import torch +from vllm.compilation import forward_context from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -119,7 +120,8 @@ def execute_model( device=self.device), } - hidden_states = model_executable(**execute_model_kwargs) + with forward_context(model_input.attn_metadata): + hidden_states = model_executable(**execute_model_kwargs) # Only perform pooling in the driver worker. if not self.is_driver_worker: diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 3e37b3519c8e..b1af1287eed5 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -11,6 +11,7 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, get_global_forced_attn_backend, global_force_attn_backend) +from vllm.compilation import forward_context from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -198,17 +199,18 @@ def execute_model( } if self.has_seqlen_agnostic else {} multi_modal_kwargs = model_input.multi_modal_kwargs or {} - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - encoder_input_ids=model_input.encoder_input_tokens, - encoder_positions=model_input.encoder_input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + with forward_context(model_input.attn_metadata): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + encoder_input_ids=model_input.encoder_input_tokens, + encoder_positions=model_input.encoder_input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) From ed80d6728d3936e2da64bdf89e67a4b388ec1aec Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 16:07:59 -0700 Subject: [PATCH 28/44] fix spec decode --- vllm/spec_decode/draft_model_runner.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index cf64af72a14a..659e0c9f7525 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -2,6 +2,7 @@ import torch +from vllm.compilation import forward_context from vllm.model_executor.layers.sampler import SamplerOutput try: @@ -293,16 +294,17 @@ def execute_model( if previous_hidden_states is not None else {} # Run model - hidden_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **kwargs, - ) + with forward_context(model_input.attn_metadata): + hidden_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **kwargs, + ) # Compute the logits. logits = self.model.compute_logits(hidden_states, From 2b0c543303f2b533fd45f82a3477e3397ae27b76 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 16:47:52 -0700 Subject: [PATCH 29/44] complicated bug, thank you chatgpt --- vllm/attention/backends/flash_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index f0a21262a076..eb5f0c0c1618 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -788,7 +788,7 @@ def unified_flash_attention( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if (kv_cache.numel() or prefill_meta.block_tables is None + if (kv_cache.numel() == 0 or prefill_meta.block_tables is None or prefill_meta.block_tables.numel() == 0): # normal attention # When block_tables are not filled, it means q and k are the From ca79dd5e585e63bd6dba7fb1935204a68b286589 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 16:55:15 -0700 Subject: [PATCH 30/44] simplification, model runner set context, model does not --- vllm/compilation/decorators.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 75ae578ac8d0..0c2840b35949 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -4,7 +4,6 @@ import vllm.envs as envs from vllm.attention import AttentionMetadata -from vllm.compilation import forward_context from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.sequence import IntermediateTensors @@ -49,21 +48,18 @@ def __call__( if not self._use_torch_compile: return self.forward(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) - with forward_context(attn_metadata): - if len(self.compiled_codes) < 1: - torch._dynamo.mark_dynamic(input_ids, 0) - torch._dynamo.mark_dynamic(positions, 0) - if intermediate_tensors is not None: - for tensors in intermediate_tensors.tensors.values(): - torch._dynamo.mark_dynamic(tensors, 0) - return self.compiled_callable(input_ids, positions, kv_caches, - attn_metadata, - intermediate_tensors) - with self.dispatch_to_code(0): - model_output = self.forward(input_ids, positions, kv_caches, - attn_metadata, - intermediate_tensors) - return model_output + if len(self.compiled_codes) < 1: + torch._dynamo.mark_dynamic(input_ids, 0) + torch._dynamo.mark_dynamic(positions, 0) + if intermediate_tensors is not None: + for tensors in intermediate_tensors.tensors.values(): + torch._dynamo.mark_dynamic(tensors, 0) + return self.compiled_callable(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + with self.dispatch_to_code(0): + model_output = self.forward(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output cls.__call__ = __call__ return cls From fad55cb07d2c02af01aed2f0c2cbeae3330142fd Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 17:12:42 -0700 Subject: [PATCH 31/44] fix tests --- tests/compile/test_full_graph.py | 13 ++++++++----- tests/compile/test_full_graph_multi_gpu.py | 10 ++++++---- tests/compile/test_full_graph_smoke.py | 13 ++++++++----- tests/compile/utils.py | 15 +++++++-------- 4 files changed, 29 insertions(+), 22 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 5dd65ad7236f..03e377023b99 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -1,13 +1,16 @@ import pytest -from vllm.compilation.backends import vllm_backend - +from ..utils import fork_new_process_for_each_test from .utils import TEST_MODELS, check_full_graph_support @pytest.mark.parametrize("model_info", TEST_MODELS) -@pytest.mark.parametrize("backend", ["eager", vllm_backend]) -def test_full_graph(model_info, backend): +@pytest.mark.parametrize("optimization_level", [1, 2]) +@fork_new_process_for_each_test +def test_full_graph(model_info, optimization_level): model = model_info[0] model_kwargs = model_info[1] - check_full_graph_support(model, model_kwargs, backend, tp_size=1) + check_full_graph_support(model, + model_kwargs, + optimization_level, + tp_size=1) diff --git a/tests/compile/test_full_graph_multi_gpu.py b/tests/compile/test_full_graph_multi_gpu.py index e9883d5254e7..caf66c0d269f 100644 --- a/tests/compile/test_full_graph_multi_gpu.py +++ b/tests/compile/test_full_graph_multi_gpu.py @@ -1,6 +1,5 @@ import pytest -from vllm.compilation.backends import vllm_backend from vllm.utils import cuda_device_count_stateless from ..utils import fork_new_process_for_each_test @@ -9,9 +8,9 @@ @pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) @pytest.mark.parametrize("tp_size", [2]) -@pytest.mark.parametrize("backend", ["eager", vllm_backend]) +@pytest.mark.parametrize("optimization_level", [1, 2]) @fork_new_process_for_each_test -def test_full_graph_multi_gpu(model_info, tp_size, backend): +def test_full_graph_multi_gpu(model_info, tp_size, optimization_level): model = model_info[0] model_kwargs = model_info[1] @@ -19,4 +18,7 @@ def test_full_graph_multi_gpu(model_info, tp_size, backend): if cuda_device_count_stateless() < tp_size: pytest.skip("Not enough CUDA devices for the test.") - check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size) + check_full_graph_support(model, + model_kwargs, + optimization_level, + tp_size=tp_size) diff --git a/tests/compile/test_full_graph_smoke.py b/tests/compile/test_full_graph_smoke.py index 0c5a95b4ead4..0e12e43e2483 100644 --- a/tests/compile/test_full_graph_smoke.py +++ b/tests/compile/test_full_graph_smoke.py @@ -1,13 +1,16 @@ import pytest -from vllm.compilation.backends import vllm_backend - +from ..utils import fork_new_process_for_each_test from .utils import TEST_MODELS_SMOKE, check_full_graph_support @pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) -@pytest.mark.parametrize("backend", ["eager", vllm_backend]) -def test_full_graph(model_info, backend): +@pytest.mark.parametrize("optimization_level", [1, 2]) +@fork_new_process_for_each_test +def test_full_graph(model_info, optimization_level): model = model_info[0] model_kwargs = model_info[1] - check_full_graph_support(model, model_kwargs, backend, tp_size=1) + check_full_graph_support(model, + model_kwargs, + optimization_level, + tp_size=1) diff --git a/tests/compile/utils.py b/tests/compile/utils.py index f3261131e7dc..6509d530011b 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -4,7 +4,6 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.plugins import set_torch_compile_backend from vllm.utils import is_hip TEST_MODELS_SMOKE = [ @@ -68,20 +67,20 @@ })) -def check_full_graph_support(model, model_kwargs, backend, tp_size=1): +def check_full_graph_support(model, + model_kwargs, + optimization_level, + tp_size=1): # make sure these models can be captured in full graph mode - if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = "1" - os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1" + os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level) + os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1" # Inductor doesn't support fp8/gptq_marlin_24 yet. quantization = model_kwargs.get("quantization") if (quantization == "fp8" or quantization == "gptq_marlin" - or quantization == "gptq_marlin_24") and backend != "eager": + or quantization == "gptq_marlin_24") and optimization_level > 1: return - set_torch_compile_backend(backend) - prompts = [ "Hello, my name is", "The president of the United States is", From e1958414ab3f48a6c2bb0491d16fc5e8ceb63bd5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 17:26:08 -0700 Subject: [PATCH 32/44] add compare_all_settings --- tests/utils.py | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 3eff77f396e1..5c1ebd9ce50d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -180,9 +180,24 @@ def compare_two_settings(model: str, env1: The first set of environment variables to pass to the API server. env2: The second set of environment variables to pass to the API server. """ + compare_all_settings(model, [arg1, arg2], [env1, env2], max_wait_seconds) + +def compare_all_settings(model: str, + all_args: List[List[str]], + all_envs: List[Optional[Dict[str, str]]], + max_wait_seconds: Optional[float] = None) -> None: + """ + Launch API server with several different sets of arguments/environments + and compare the results of the API calls with the first set of arguments. + + Args: + model: The model to test. + all_args: A list of argument lists to pass to the API server. + all_envs: A list of environment dictionaries to pass to the API server. + """ trust_remote_code = "--trust-remote-code" - if trust_remote_code in arg1 or trust_remote_code in arg2: + if any(trust_remote_code in args for args in all_args): tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) else: @@ -191,7 +206,7 @@ def compare_two_settings(model: str, prompt = "Hello, my name is" token_ids = tokenizer(prompt)["input_ids"] results = [] - for args, env in ((arg1, env1), (arg2, env2)): + for args, env in zip(all_args, all_envs): with RemoteOpenAIServer(model, args, env_dict=env, @@ -299,13 +314,17 @@ def compare_two_settings(model: str, "texts": texts, }) - n = len(results) // 2 - arg1_results = results[:n] - arg2_results = results[n:] - for arg1_result, arg2_result in zip(arg1_results, arg2_results): - assert arg1_result == arg2_result, ( - f"Results for {model=} are not the same with {arg1=} and {arg2=}. " - f"{arg1_result=} != {arg2_result=}") + n = len(results) // len(all_args) + ref_results = results[:n] + ref_args = all_args[0] + for i in range(1, len(all_args)): + compare_results = results[i * n:(i + 1) * n] + compare_args = all_args[i] + for ref_result, compare_result in zip(ref_results, compare_results): + assert ref_result == compare_result, ( + f"Results for {model=} are not the same with " + f"{ref_args=} and {compare_args=}. " + f"{ref_result=} != {compare_result=}") def init_test_distributed_environment( From fbd3231c3796a87ebbe1335781d0728c996ca32c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 17:31:04 -0700 Subject: [PATCH 33/44] change tests --- tests/compile/test_correctness.py | 38 +++++++++---------------------- 1 file changed, 11 insertions(+), 27 deletions(-) diff --git a/tests/compile/test_correctness.py b/tests/compile/test_correctness.py index 29ae02abd353..b52b40a7aab6 100644 --- a/tests/compile/test_correctness.py +++ b/tests/compile/test_correctness.py @@ -1,32 +1,16 @@ -import os +from typing import Dict, List, Optional -from vllm import LLM, SamplingParams +from ..utils import compare_all_settings def test_compile_correctness(): - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", + all_args = [ + ["--enforce-eager"], + ["--enforce-eager"], + ["--enforce-eager"], ] - sampling_params = SamplingParams(temperature=0) - - all_outputs = [] - all_levels = [0, 1, 2] - for level in all_levels: - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(level) - llm = LLM(model="meta-llama/Meta-Llama-3-8B", - enforce_eager=True, - tensor_parallel_size=1, - disable_custom_all_reduce=True, - gpu_memory_utilization=0.3) - outputs = llm.generate(prompts, sampling_params) - all_outputs.append(outputs) - reference_outputs = all_outputs[0] - for level, outputs in zip(all_levels[1:], all_outputs[1:]): - for ref_output, output in zip(reference_outputs, outputs): - prompt = output.prompt - generated_text = output.outputs[0].text - ref_generated_text = ref_output.outputs[0].text - assert generated_text == ref_generated_text, f"level: {level}, prompt: {prompt}, generated_text: {generated_text}, ref_generated_text: {ref_generated_text}" # noqa + all_envs: List[Optional[Dict[str, str]]] = [{ + "VLLM_TORCH_COMPILE_LEVEL": + str(i) + } for i in range(3)] + compare_all_settings("meta-llama/Meta-Llama-3-8B", all_args, all_envs) From 4781c14a4a75128fbd311a7337d10c1c5261b1e8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 17:35:17 -0700 Subject: [PATCH 34/44] repurpose smoke tests --- tests/compile/test_full_graph_smoke.py | 26 ++++++++++++++++---------- tests/compile/utils.py | 7 +++---- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/tests/compile/test_full_graph_smoke.py b/tests/compile/test_full_graph_smoke.py index 0e12e43e2483..ddfcba7b75f7 100644 --- a/tests/compile/test_full_graph_smoke.py +++ b/tests/compile/test_full_graph_smoke.py @@ -1,16 +1,22 @@ +from typing import Dict, List, Optional + import pytest -from ..utils import fork_new_process_for_each_test -from .utils import TEST_MODELS_SMOKE, check_full_graph_support +from ..utils import compare_all_settings +from .utils import TEST_MODELS_SMOKE @pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) -@pytest.mark.parametrize("optimization_level", [1, 2]) -@fork_new_process_for_each_test -def test_full_graph(model_info, optimization_level): +def test_compile_correctness(model_info): model = model_info[0] - model_kwargs = model_info[1] - check_full_graph_support(model, - model_kwargs, - optimization_level, - tp_size=1) + model_args = model_info[1] + all_args = [ + ["--enforce-eager"] + model_args, + ["--enforce-eager"] + model_args, + ["--enforce-eager"] + model_args, + ] + all_envs: List[Optional[Dict[str, str]]] = [{ + "VLLM_TORCH_COMPILE_LEVEL": + str(i) + } for i in range(3)] + compare_all_settings(model, all_args, all_envs) diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 6509d530011b..eb5b2e741f96 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -7,10 +7,9 @@ from vllm.utils import is_hip TEST_MODELS_SMOKE = [ - ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", { - "quantization": "compressed-tensors" - }), - ("meta-llama/Meta-Llama-3-8B", {}), + ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", + ["--quantization", "compressed-tensors"]), + ("meta-llama/Meta-Llama-3-8B", []), ] TEST_MODELS = [ From cbc922923cdbaecbbb42d31c9c41782b437341e4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 17:35:36 -0700 Subject: [PATCH 35/44] remove --- tests/compile/test_correctness.py | 16 ---------------- 1 file changed, 16 deletions(-) delete mode 100644 tests/compile/test_correctness.py diff --git a/tests/compile/test_correctness.py b/tests/compile/test_correctness.py deleted file mode 100644 index b52b40a7aab6..000000000000 --- a/tests/compile/test_correctness.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Dict, List, Optional - -from ..utils import compare_all_settings - - -def test_compile_correctness(): - all_args = [ - ["--enforce-eager"], - ["--enforce-eager"], - ["--enforce-eager"], - ] - all_envs: List[Optional[Dict[str, str]]] = [{ - "VLLM_TORCH_COMPILE_LEVEL": - str(i) - } for i in range(3)] - compare_all_settings("meta-llama/Meta-Llama-3-8B", all_args, all_envs) From 5970a6fd38a1dcd84f388173d109c92c0ad600a1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 17:36:30 -0700 Subject: [PATCH 36/44] restore --- tests/compile/test_full_graph_smoke.py | 22 ---------------------- 1 file changed, 22 deletions(-) delete mode 100644 tests/compile/test_full_graph_smoke.py diff --git a/tests/compile/test_full_graph_smoke.py b/tests/compile/test_full_graph_smoke.py deleted file mode 100644 index ddfcba7b75f7..000000000000 --- a/tests/compile/test_full_graph_smoke.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Dict, List, Optional - -import pytest - -from ..utils import compare_all_settings -from .utils import TEST_MODELS_SMOKE - - -@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) -def test_compile_correctness(model_info): - model = model_info[0] - model_args = model_info[1] - all_args = [ - ["--enforce-eager"] + model_args, - ["--enforce-eager"] + model_args, - ["--enforce-eager"] + model_args, - ] - all_envs: List[Optional[Dict[str, str]]] = [{ - "VLLM_TORCH_COMPILE_LEVEL": - str(i) - } for i in range(3)] - compare_all_settings(model, all_args, all_envs) From ca587a8197fa66957b249dddfee39e8ada17ab18 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 17:36:36 -0700 Subject: [PATCH 37/44] restore --- tests/compile/test_correctness.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 tests/compile/test_correctness.py diff --git a/tests/compile/test_correctness.py b/tests/compile/test_correctness.py new file mode 100644 index 000000000000..ddfcba7b75f7 --- /dev/null +++ b/tests/compile/test_correctness.py @@ -0,0 +1,22 @@ +from typing import Dict, List, Optional + +import pytest + +from ..utils import compare_all_settings +from .utils import TEST_MODELS_SMOKE + + +@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) +def test_compile_correctness(model_info): + model = model_info[0] + model_args = model_info[1] + all_args = [ + ["--enforce-eager"] + model_args, + ["--enforce-eager"] + model_args, + ["--enforce-eager"] + model_args, + ] + all_envs: List[Optional[Dict[str, str]]] = [{ + "VLLM_TORCH_COMPILE_LEVEL": + str(i) + } for i in range(3)] + compare_all_settings(model, all_args, all_envs) From 7ea321cd6ba0b96e30913ddba20f76bd963a41aa Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 17:37:33 -0700 Subject: [PATCH 38/44] restore --- tests/compile/{test_correctness.py => test_full_graph_smoke.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/compile/{test_correctness.py => test_full_graph_smoke.py} (100%) diff --git a/tests/compile/test_correctness.py b/tests/compile/test_full_graph_smoke.py similarity index 100% rename from tests/compile/test_correctness.py rename to tests/compile/test_full_graph_smoke.py From f3a5a5e98d99bc34252da18d77fa893d98b54b5a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 19:04:36 -0700 Subject: [PATCH 39/44] fix for pp --- vllm/sequence.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 781bcedde2b5..894473deb11a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1149,10 +1149,9 @@ def __eq__(self, other: object) -> bool: return self.embeddings == other.embeddings -class IntermediateTensors( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] +# cannot use msgspec.Struct here because Dynamo does not support it +@dataclass +class IntermediateTensors: """For all pipeline stages except the last, we need to return the hidden states and residuals to be sent to the next stage. This data structure contains the hidden states and residuals for a request. From a8644754ac97e6ef1204f30795badb4a326316cb Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 19:07:53 -0700 Subject: [PATCH 40/44] add tests --- tests/compile/test_full_graph_multi_gpu.py | 24 ---------------------- tests/compile/test_full_graph_smoke.py | 14 ++++++++----- 2 files changed, 9 insertions(+), 29 deletions(-) delete mode 100644 tests/compile/test_full_graph_multi_gpu.py diff --git a/tests/compile/test_full_graph_multi_gpu.py b/tests/compile/test_full_graph_multi_gpu.py deleted file mode 100644 index caf66c0d269f..000000000000 --- a/tests/compile/test_full_graph_multi_gpu.py +++ /dev/null @@ -1,24 +0,0 @@ -import pytest - -from vllm.utils import cuda_device_count_stateless - -from ..utils import fork_new_process_for_each_test -from .utils import TEST_MODELS_SMOKE, check_full_graph_support - - -@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) -@pytest.mark.parametrize("tp_size", [2]) -@pytest.mark.parametrize("optimization_level", [1, 2]) -@fork_new_process_for_each_test -def test_full_graph_multi_gpu(model_info, tp_size, optimization_level): - model = model_info[0] - model_kwargs = model_info[1] - - # Skip the test if there are not enough CUDA devices. - if cuda_device_count_stateless() < tp_size: - pytest.skip("Not enough CUDA devices for the test.") - - check_full_graph_support(model, - model_kwargs, - optimization_level, - tp_size=tp_size) diff --git a/tests/compile/test_full_graph_smoke.py b/tests/compile/test_full_graph_smoke.py index ddfcba7b75f7..a73fdea0d14e 100644 --- a/tests/compile/test_full_graph_smoke.py +++ b/tests/compile/test_full_graph_smoke.py @@ -2,19 +2,23 @@ import pytest +from vllm.utils import cuda_device_count_stateless + from ..utils import compare_all_settings from .utils import TEST_MODELS_SMOKE @pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) -def test_compile_correctness(model_info): +@pytest.mark.parametrize("pp_size", [1, 2]) +def test_compile_correctness(model_info, pp_size): + if cuda_device_count_stateless() < pp_size: + pytest.skip("Not enough CUDA devices for the test.") model = model_info[0] model_args = model_info[1] all_args = [ - ["--enforce-eager"] + model_args, - ["--enforce-eager"] + model_args, - ["--enforce-eager"] + model_args, - ] + ["--enforce-eager"] + model_args + ["--max_model_len", "1024"] + + ["-pp", str(pp_size)], + ] * 3 all_envs: List[Optional[Dict[str, str]]] = [{ "VLLM_TORCH_COMPILE_LEVEL": str(i) From 1d9aacdd10a50de41c799f8afef4404e9c0cfebb Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 19:08:23 -0700 Subject: [PATCH 41/44] rename --- .../{test_full_graph_smoke.py => test_basic_correctness.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/compile/{test_full_graph_smoke.py => test_basic_correctness.py} (100%) diff --git a/tests/compile/test_full_graph_smoke.py b/tests/compile/test_basic_correctness.py similarity index 100% rename from tests/compile/test_full_graph_smoke.py rename to tests/compile/test_basic_correctness.py From f2330876f71306fdfb00c767c5720df8c66c2759 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 19:12:14 -0700 Subject: [PATCH 42/44] update tests --- .buildkite/test-pipeline.yaml | 6 ++++-- tests/compile/test_basic_correctness.py | 7 +++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index bb42b5f29a72..d5ebff80b91d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -110,7 +110,9 @@ steps: - vllm/core/ - tests/distributed - tests/spec_decode/e2e/test_integration_dist_tp4 + - tests/compile commands: + - pytest -v -s compile/test_basic_correctness.py - pytest -v -s distributed/test_pynccl.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py @@ -218,7 +220,7 @@ steps: - vllm/ - tests/compile commands: - - pytest -v -s compile/test_full_graph_smoke.py + - pytest -v -s compile/test_basic_correctness.py - label: "PyTorch Fullgraph Test" # 18min source_file_dependencies: @@ -382,7 +384,7 @@ steps: - tests/distributed/ - vllm/compilation commands: - - pytest -v -s ./compile/test_full_graph_multi_gpu.py + - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index a73fdea0d14e..5b009e7d1a7b 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -11,8 +11,11 @@ @pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) @pytest.mark.parametrize("pp_size", [1, 2]) def test_compile_correctness(model_info, pp_size): - if cuda_device_count_stateless() < pp_size: - pytest.skip("Not enough CUDA devices for the test.") + # this test is run under multiple suits, with different GPUs. + # make sure we only run the test with correct CUDA devices. + # don't use "<", as it will duplicate the tests. + if cuda_device_count_stateless() != pp_size: + pytest.skip("Not correct CUDA devices for the test.") model = model_info[0] model_args = model_info[1] all_args = [ From d2f1b97d4dad32e6839f73f3b6b638102158986f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 19:13:57 -0700 Subject: [PATCH 43/44] prepare for tp test --- tests/compile/test_basic_correctness.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 5b009e7d1a7b..c31b87421856 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -10,18 +10,17 @@ @pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) @pytest.mark.parametrize("pp_size", [1, 2]) -def test_compile_correctness(model_info, pp_size): +@pytest.mark.parametrize("tp_size", [1]) +def test_compile_correctness(model_info, pp_size, tp_size): # this test is run under multiple suits, with different GPUs. # make sure we only run the test with correct CUDA devices. # don't use "<", as it will duplicate the tests. - if cuda_device_count_stateless() != pp_size: + if cuda_device_count_stateless() != pp_size * tp_size: pytest.skip("Not correct CUDA devices for the test.") model = model_info[0] model_args = model_info[1] - all_args = [ - ["--enforce-eager"] + model_args + ["--max_model_len", "1024"] + - ["-pp", str(pp_size)], - ] * 3 + all_args = [["--enforce-eager"] + model_args + ["--max_model_len", "1024"] + + ["-pp", str(pp_size)] + ["-tp", str(tp_size)]] * 3 all_envs: List[Optional[Dict[str, str]]] = [{ "VLLM_TORCH_COMPILE_LEVEL": str(i) From 1b8ee5a3c0cd8a8b112456b1d42fe6c670ef5137 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Sep 2024 19:31:15 -0700 Subject: [PATCH 44/44] early error --- tests/utils.py | 98 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 60 insertions(+), 38 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 5c1ebd9ce50d..618b8b7a70f4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -205,8 +205,9 @@ def compare_all_settings(model: str, prompt = "Hello, my name is" token_ids = tokenizer(prompt)["input_ids"] - results = [] - for args, env in zip(all_args, all_envs): + ref_results: List = [] + for i, (args, env) in enumerate(zip(all_args, all_envs)): + compare_results: List = [] with RemoteOpenAIServer(model, args, env_dict=env, @@ -217,10 +218,13 @@ def compare_all_settings(model: str, models = client.models.list() models = models.data served_model = models[0] - results.append({ - "test": "models_list", - "id": served_model.id, - "root": served_model.root, + (ref_results if i == 0 else compare_results).append({ + "test": + "models_list", + "id": + served_model.id, + "root": + served_model.root, }) # test with text prompt @@ -229,11 +233,15 @@ def compare_all_settings(model: str, max_tokens=5, temperature=0.0) - results.append({ - "test": "single_completion", - "text": completion.choices[0].text, - "finish_reason": completion.choices[0].finish_reason, - "usage": completion.usage, + (ref_results if i == 0 else compare_results).append({ + "test": + "single_completion", + "text": + completion.choices[0].text, + "finish_reason": + completion.choices[0].finish_reason, + "usage": + completion.usage, }) # test using token IDs @@ -244,11 +252,15 @@ def compare_all_settings(model: str, temperature=0.0, ) - results.append({ - "test": "token_ids", - "text": completion.choices[0].text, - "finish_reason": completion.choices[0].finish_reason, - "usage": completion.usage, + (ref_results if i == 0 else compare_results).append({ + "test": + "token_ids", + "text": + completion.choices[0].text, + "finish_reason": + completion.choices[0].finish_reason, + "usage": + completion.usage, }) # test seeded random sampling @@ -258,11 +270,15 @@ def compare_all_settings(model: str, seed=33, temperature=1.0) - results.append({ - "test": "seeded_sampling", - "text": completion.choices[0].text, - "finish_reason": completion.choices[0].finish_reason, - "usage": completion.usage, + (ref_results if i == 0 else compare_results).append({ + "test": + "seeded_sampling", + "text": + completion.choices[0].text, + "finish_reason": + completion.choices[0].finish_reason, + "usage": + completion.usage, }) # test seeded random sampling with multiple prompts @@ -272,7 +288,7 @@ def compare_all_settings(model: str, seed=33, temperature=1.0) - results.append({ + (ref_results if i == 0 else compare_results).append({ "test": "seeded_sampling", "text": [choice.text for choice in completion.choices], @@ -290,10 +306,13 @@ def compare_all_settings(model: str, temperature=0.0, ) - results.append({ - "test": "simple_list", - "text0": batch.choices[0].text, - "text1": batch.choices[1].text, + (ref_results if i == 0 else compare_results).append({ + "test": + "simple_list", + "text0": + batch.choices[0].text, + "text1": + batch.choices[1].text, }) # test streaming @@ -309,22 +328,25 @@ def compare_all_settings(model: str, assert len(chunk.choices) == 1 choice = chunk.choices[0] texts[choice.index] += choice.text - results.append({ + (ref_results if i == 0 else compare_results).append({ "test": "streaming", "texts": texts, }) - n = len(results) // len(all_args) - ref_results = results[:n] - ref_args = all_args[0] - for i in range(1, len(all_args)): - compare_results = results[i * n:(i + 1) * n] - compare_args = all_args[i] - for ref_result, compare_result in zip(ref_results, compare_results): - assert ref_result == compare_result, ( - f"Results for {model=} are not the same with " - f"{ref_args=} and {compare_args=}. " - f"{ref_result=} != {compare_result=}") + if i > 0: + # if any setting fails, raise an error early + ref_args = all_args[0] + ref_envs = all_envs[0] + compare_args = all_args[i] + compare_envs = all_envs[i] + for ref_result, compare_result in zip(ref_results, + compare_results): + assert ref_result == compare_result, ( + f"Results for {model=} are not the same.\n" + f"{ref_args=} {ref_envs=}\n" + f"{compare_args=} {compare_envs=}\n" + f"{ref_result=}\n" + f"{compare_result=}\n") def init_test_distributed_environment(