diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 9d67b46f2e3e..bd7157568e84 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Callable + import pytest from tests.models.registry import HF_EXAMPLE_MODELS @@ -8,7 +10,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.sampling_params import SamplingParams -from ...utils import check_logprobs_close +from ...utils import check_logprobs_close, check_outputs_equal # Mark all tests as hybrid pytestmark = pytest.mark.hybrid_model @@ -332,3 +334,413 @@ def test_fp32_cache_state( name_0="hf", name_1="vllm", ) + + +# Helper functions for the APC tests +def _get_vllm_runner_params(model, max_model_len, tensor_parallel_size=1): + return { + 'model_name': model, + 'enable_prefix_caching': False, + 'max_model_len': max_model_len, + 'tensor_parallel_size': tensor_parallel_size, + 'gpu_memory_utilization': 0.4 + } + + +def _get_vLLM_output(vllm_runner, + kwargs, + prompts, + max_tokens, + num_logprobs, + num_repetitions=1, + vllm_model=None): + outs = [] + if vllm_model is None: + vllm_model = vllm_runner(**kwargs) + for _ in range(num_repetitions): + if num_logprobs < 0: + vllm_output = vllm_model.generate_greedy(prompts, max_tokens) + else: + vllm_output = vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs) + outs.append(vllm_output) + + return outs, vllm_model + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_apc_single_prompt( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + num_logprobs: int, + tensor_parallel_size: int, +) -> None: + + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator: Callable = check_logprobs_close \ + if num_logprobs > 0 else check_outputs_equal # type: ignore + + MULTIPLE = 300 + + # Sample prompts. 
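+    # Repeating the first example prompt MULTIPLE times makes the prompt
+    # long enough to span several cache blocks, so the second cached
+    # repetition can reuse the prefix filled in by the first one.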
+ generated_prompts = [MULTIPLE * example_prompts[0]] + + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params( + model, max_model_len, tensor_parallel_size=tensor_parallel_size) + vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32" + vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, max_tokens, + num_logprobs) + + vllm_runner_kwargs['enable_prefix_caching'] = True + vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, max_tokens, + num_logprobs, n_repetitions) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_apc_single_prompt_block_align_alignment( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + num_logprobs: int, + tensor_parallel_size: int, +) -> None: + + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator: Callable = check_logprobs_close \ + if num_logprobs > 0 else check_outputs_equal # type: ignore + + MULTIPLE = 300 + + # Sample prompts. This custom prompt is used, as it causes the most issues + generated_prompts = ["The president of the United States is " * MULTIPLE] + + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params( + model, max_model_len, tensor_parallel_size=tensor_parallel_size) + vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32" + + vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, max_tokens, + num_logprobs) + + vllm_runner_kwargs['enable_prefix_caching'] = True + with vllm_runner(**vllm_runner_kwargs) as vllm_model: + # Retrieve the default mamba state block size + mamba_block_size = vllm_model.llm.llm_engine.cache_config. 
\ + mamba_block_size + + # In case the hybrid model does not have the + # "mamba_block_size" assume a fixed constant + if mamba_block_size is None: + mamba_block_size = 512 + + mamba_block_size_multiplier = 10 + for offsets in [ + -3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3 + ]: + + vllm_runner_kwargs[ + 'max_num_batched_tokens'] = mamba_block_size_multiplier * \ + mamba_block_size - offsets + vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, + max_tokens, num_logprobs, + n_repetitions) + + # Check alignment of the output logits when using APC + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_apc_multiple_prompts_all_cached_outputs( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + num_logprobs: int, + tensor_parallel_size: int, +) -> None: + + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator: Callable = check_logprobs_close \ + if num_logprobs > 0 else check_outputs_equal # type: ignore + + MULTIPLE = 300 + + # Sample prompts. 
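+    # Every example prompt is repeated MULTIPLE times, so all sequences
+    # are long enough to be cached across several blocks.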
+ generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] + + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params( + model, max_model_len, tensor_parallel_size=tensor_parallel_size) + vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32" + + vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, max_tokens, + num_logprobs) + + vllm_runner_kwargs['enable_prefix_caching'] = True + vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, max_tokens, + num_logprobs, n_repetitions) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_apc_multiple_prompts_block_align_alignment( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + num_logprobs: int, + tensor_parallel_size: int, +) -> None: + + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator: Callable = check_logprobs_close \ + if num_logprobs > 0 else check_outputs_equal # type: ignore + + MULTIPLE = 300 + + # Sample prompts. This custom prompt is used, as it causes the most issues + prompt_text = "The president of the United States is " + prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31] + generated_prompts = [ + prompt_text[offset:] * MULTIPLE for offset in prompt_offsets + ] + + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params(model, max_model_len, + tensor_parallel_size) + vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32" + + vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, max_tokens, + num_logprobs) + + vllm_runner_kwargs['enable_prefix_caching'] = True + with vllm_runner(**vllm_runner_kwargs) as vllm_model: + # Retrieve the default mamba state block size + mamba_block_size = vllm_model.llm.llm_engine.cache_config. 
\ + mamba_block_size + + # In case the hybrid model does not have the + # "mamba_block_size" assume a fixed constant + if mamba_block_size is None: + mamba_block_size = 512 + + mamba_block_size_multiplier = 10 + for offsets in [ + -3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3 + ]: + + vllm_runner_kwargs[ + 'max_num_batched_tokens'] = mamba_block_size_multiplier * \ + mamba_block_size - offsets + vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, + max_tokens, num_logprobs, + n_repetitions) + + # Check alignment of the output logits when using APC + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_apc_multiple_prompts_partial_cached_outputs( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + num_logprobs: int, + tensor_parallel_size: int, +) -> None: + + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator: Callable = check_logprobs_close \ + if num_logprobs > 0 else check_outputs_equal # type: ignore + + MULTIPLE = 300 + + # Sample prompts. 
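+    # Same long prompts as above; only the first three are run (and cached)
+    # first, before all prompts are generated with the same cached engine.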
+ generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] + + max_model_len = max( + len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params( + model, max_model_len, tensor_parallel_size=tensor_parallel_size) + vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32" + + vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, max_tokens, + num_logprobs) + + # Cache only part of all the prompts + vllm_runner_kwargs['enable_prefix_caching'] = True + vllm_outputs_partial_cache, vllm_model = _get_vLLM_output( + vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens, + num_logprobs) + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0][:3], + outputs_1_lst=vllm_outputs_partial_cache[0], + name_0="vllm_no_cache", + name_1="vllm_partial_cache", + ) + + vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner, + vllm_runner_kwargs, + generated_prompts, + max_tokens, + num_logprobs, + n_repetitions, + vllm_model=vllm_model) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 58770649a8af..bdfa99cd79a3 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -92,7 +92,8 @@ class CacheConfig: mamba_page_size_padded: Optional[int] = None """ Optional override for mamba page size; used by hybrid mamba/attention models to ensure exact alignment with attention page size.""" - + mamba_block_size: Optional[int] = None + """Size of a contiguous cache block in number of tokens for mamba cache.""" mamba_cache_dtype: MambaDType = "auto" """The data type to use for the Mamba cache (both the conv as well as the ssm state). If set to 'auto', the data type will be inferred from the model diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bf293a4d2aa9..89a881675ad6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1563,7 +1563,12 @@ def _set_default_args(self, usage_context: UsageContext, self.enable_prefix_caching = False if self.enable_prefix_caching is None: - self.enable_prefix_caching = True + # Disable prefix caching default for hybrid models + # since the feature is still experimental. 
+ if model_config.is_hybrid: + self.enable_prefix_caching = False + else: + self.enable_prefix_caching = True else: pooling_type = model_config.pooler_config.pooling_type diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index bfb0666d361f..56df9cf511e6 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -489,6 +489,9 @@ def forward_cuda( # stay the same and reused for all mamba layers in the same iteration attn_metadata: AttentionMetadata = forward_context.attn_metadata + assert self.cache_config is not None + mamba_block_size = self.cache_config.mamba_block_size + prefix_caching_enabled = self.cache_config.enable_prefix_caching if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] @@ -573,6 +576,25 @@ def forward_cuda( dim=0, ) + if prefix_caching_enabled: + # If prefix caching is enabled, retrieve the relevant variables + # for prefill and decode + last_state_idx_d, last_state_idx_p = torch.split( + attn_metadata.last_state_idx, [num_decodes, num_prefills], + dim=0) + current_last_idx_d, current_last_idx_p = torch.split( + attn_metadata.current_last_idx, [num_decodes, num_prefills], + dim=0) + # Prefill-only variables: + current_first_idx_p = attn_metadata.current_first_idx_p + context_lens_p = attn_metadata.context_lens_p + last_computed_offset_p = attn_metadata.last_computed_offset_p + else: + last_state_idx_d, last_state_idx_p = None, None + current_last_idx_d, current_last_idx_p = None, None + current_first_idx_p = None + context_lens_p = None + # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs preallocated_ssm_out = torch.empty( @@ -592,8 +614,17 @@ def forward_cuda( # Process prefill requests if has_prefill: # 2. Convolution sequence transformation - # - "cache_indices" updates the conv_state cache in positions - # pointed to by "state_indices_tensor" + # - It will read the initial states for every sequence, + # that has "has_initial_states_p" == True, + # from "cache_indices", using "state_indices_tensor_p". + # - It updates the "conv_state" cache in positions pointed + # to by "state_indices_tensor_p". + # In particular, it will always write the state at the + # sequence end. + # In addition, "current_first_idx_p" and "current_last_idx_p" + # are provided (which are pointers into + # "state_indices_tensor_p"), it will write additional cache + # states aligned at "block_size_to_align". x = hidden_states_B_C_p.transpose( 0, 1) # this is the form that causal-conv see hidden_states_B_C_p = causal_conv1d_fn( @@ -604,6 +635,11 @@ def forward_cuda( conv_states=conv_state, has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, + current_first_idx=current_first_idx_p, + current_last_idx=current_last_idx_p, + initial_state_idx=last_state_idx_p, + context_lens=context_lens_p, + block_size_to_align=mamba_block_size, metadata=attn_metadata, query_start_loc=query_start_loc_p).transpose( 0, 1)[:num_prefill_tokens] @@ -614,9 +650,13 @@ def forward_cuda( # 3. 
State Space Model sequence transformation initial_states = None if (has_initial_states_p is not None and prep_initial_states): + kernel_ssm_indices = state_indices_tensor_p + if prefix_caching_enabled: + kernel_ssm_indices = state_indices_tensor_p.gather( + 1, last_state_idx_p.unsqueeze(1)).squeeze(1) initial_states = torch.where( has_initial_states_p[:, None, None, None], - ssm_state[state_indices_tensor_p], 0) + ssm_state[kernel_ssm_indices], 0) # NOTE: final output is an in-place update of out tensor varlen_states = mamba_chunk_scan_combined_varlen( @@ -638,18 +678,71 @@ def forward_cuda( cu_chunk_seqlens=cu_chunk_seqlen_p, last_chunk_indices=last_chunk_indices_p, initial_states=initial_states, + return_intermediate_states=prefix_caching_enabled, dt_softplus=True, dt_limit=(0.0, float("inf")), out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim), state_dtype=ssm_state.dtype) - # update ssm states - # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor - ssm_state[state_indices_tensor_p] = varlen_states + if prefix_caching_enabled: + # Save states for sequences with more than just the final state: + n_blocks_to_fill = current_last_idx_p - current_first_idx_p + for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1): + cache_blocks_to_fill = state_indices_tensor_p[ + seq_idx, current_first_idx_p[seq_idx]: + current_first_idx_p[seq_idx] + + n_blocks_to_fill[seq_idx]] + # chunks = [0 1 2 3 4 5 6 ...] + # First aligned chunk would typically be: + # mamba_block_size = 1024, chunk_size = 256 + # 1024 // 256 - 1 --> chunks[3] + # But when last chunk wasn't block aligned: + # - last_computed_offset_p[seq_idx] // chunk_size + # e.g. 1000 // 256 -> 3 completed --> store chunk[0] + # e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1) + # e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2) + # e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3) + chunk_stride = mamba_block_size // chunk_size + first_aligned_chunk = \ + torch.concat([torch.zeros(1, \ + dtype=last_chunk_indices_p.dtype, \ + device=last_chunk_indices_p.device), \ + last_chunk_indices_p + 1])[seq_idx] \ + + chunk_stride - 1 \ + - last_computed_offset_p[seq_idx] // chunk_size + from_where = varlen_states[ + first_aligned_chunk:first_aligned_chunk + + n_blocks_to_fill[seq_idx] * chunk_stride:chunk_stride] + ssm_state[cache_blocks_to_fill] = from_where + + #For all seqs, store the last state (Note: might be partial): + ssm_state[state_indices_tensor_p.gather(1, + current_last_idx_p.unsqueeze(1)).squeeze(1)] = \ + varlen_states[last_chunk_indices_p] + else: + # update ssm states + # - varlen state is a (num_prefills, nheads, headdim, dstate) + # tensor + ssm_state[state_indices_tensor_p] = varlen_states # Process decode requests if has_decode: + if prefix_caching_enabled: + state_indices_tensor_d_input = \ + state_indices_tensor_d.gather(1, + last_state_idx_d.unsqueeze(1)).squeeze(1) + state_indices_tensor_d_output = \ + state_indices_tensor_d.gather(1, + current_last_idx_d.unsqueeze(1)).squeeze(1) + #Note: + # for decode always: current_first_idx_d == current_last_idx_d + # at block boundaries: current_first_idx_d > last_state_idx_d + else: + # Without caching, read and write in-place to the same blocks: + state_indices_tensor_d_input = state_indices_tensor_d + state_indices_tensor_d_output = state_indices_tensor_d + # 2. 
Convolution sequence transformation hidden_states_B_C_d = causal_conv1d_update( hidden_states_B_C_d, @@ -657,7 +750,10 @@ def forward_cuda( conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=state_indices_tensor_d) + conv_state_indices=state_indices_tensor_d, + current_last_idx=current_last_idx_d, + initial_state_idx=last_state_idx_d, + ) hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn( hidden_states_B_C_d) @@ -689,7 +785,8 @@ def forward_cuda( z=None, dt_bias=dt_bias, dt_softplus=True, - state_batch_indices=state_indices_tensor_d, + state_batch_indices=state_indices_tensor_d_input, + dst_state_batch_indices=state_indices_tensor_d_output, out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), ) diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index c4102c4753c7..a02bba5d4ddd 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -20,19 +20,23 @@ def _causal_conv1d_fwd_kernel( # continuous batching w_ptr, # (dim, width) bias_ptr, initial_states_ptr, # conv_states_ptr - cache_indices_ptr, # conv_state_indices_ptr + cache_indices_ptr, # (batch, n_blocks + padding) The second dimension contains + # the block indices relevant for each sequence + # plus potential 0-padding at the beginning and at the end has_initial_states_ptr, query_start_loc_ptr, batch_ptr, token_chunk_offset_ptr, + current_first_idx, # (batch,) + current_last_idx, # (batch,) + initial_state_idx, # (batch,) + context_lens, # (batch,) o_ptr, # (dim, seqlen) - actually pointing to x_ptr # Matrix dimensions - batch: tl.int32, # actually padded_batch dim: tl.constexpr, seqlen: tl.int32, # cu_seqlen num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines # Strides - stride_x_seq: tl.constexpr, # stride to get to next sequence, stride_x_dim: tl.constexpr, # stride to get to next feature-value, stride_x_token: tl. constexpr, # stride to get to next token (same feature-index, same sequence-index) @@ -42,18 +46,16 @@ def _causal_conv1d_fwd_kernel( # continuous batching stride_istate_dim: tl.constexpr, stride_istate_token: tl.constexpr, stride_cache_indices: tl.constexpr, - stride_o_seq: tl.constexpr, stride_o_dim: tl.constexpr, stride_o_token: tl.constexpr, + stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M # others pad_slot_id: tl.constexpr, # Meta-parameters HAS_BIAS: tl.constexpr, KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, - HAS_INITIAL_STATES: tl.constexpr, - HAS_CACHE: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_APC_ENABLED: tl.constexpr, USE_PAD_SLOT: tl.constexpr, NP2_STATELEN: tl.constexpr, BLOCK_M: tl.constexpr, @@ -84,26 +86,57 @@ def _causal_conv1d_fwd_kernel( # continuous batching # find the actual sequence length seqlen = sequence_end_index - sequence_start_index + B_size: tl.constexpr = (stride_block_m * BLOCK_M) + + if IS_APC_ENABLED: + # Handle the case if prefix caching is enabled. + # In particular, if prefix caching is enabled, the program write additional cache states to "cache_indices_ptr" + + # Get the length of the completed sequence so far and compute the offset. 
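+        # Illustrative numbers: with BLOCK_M = 8 and stride_block_m = 64,
+        # B_size = 512 tokens per cache block. For context_lens = 1000 and
+        # seqlen = 600, the computation below gives
+        #   sequence_completed_offset_token = 1000 % 512 = 488
+        #   seq_completed_offset = 512 - 488 = 24
+        #   seq_end_offset = (600 - 24) % 512 = 64
+        # i.e. the last block-aligned full state ends 64 tokens before the
+        # end of this sequence's new tokens.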
+ current_first_index = tl.load(current_first_idx + idx_seq) + current_last_index = tl.load(current_last_idx + idx_seq) + sequence_completed_index = tl.load(context_lens + idx_seq) + + # Compute the offset where the first stride_block_m-aligned first full block is + # Value in "token-space" + sequence_completed_offset_token = sequence_completed_index % B_size + seq_completed_offset = B_size - sequence_completed_offset_token + seq_end_offset = (seqlen - seq_completed_offset) % B_size + last_full_block_token_index = sequence_end_index - seq_end_offset + # If the sequence without the sequence_offset_index is stride_cache_chunk-aligned, then the last full chunk is the second-to-last one + if seq_end_offset == 0: + last_full_block_token_index = last_full_block_token_index - B_size + + # Get the number of blocks to be filled for the current sequence + # If n_block_to_fill = 0, then only the state at the sequence end is stored + n_block_to_fill = current_last_index - current_first_index + + # Get the index of the init block + conv_state_init_index = tl.load(initial_state_idx + idx_seq) + else: + n_block_to_fill = 0 + current_last_index = 0 + conv_state_init_index = 0 + current_first_index = 0 + last_full_block_token_index = 0 + token_offset = BLOCK_M * chunk_offset segment_len = min(BLOCK_M, seqlen - token_offset) # base of the sequence x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,] - if IS_CONTINUOUS_BATCHING: - # cache_idx - conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_cache_indices).to( - tl.int64) - else: - # cache_idx - conv_state_batch_coord = idx_seq + # cache_idx + conv_states_input_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_cache_indices + + conv_state_init_index).to(tl.int64) + if USE_PAD_SLOT: # noqa - if conv_state_batch_coord == pad_slot_id: + if conv_states_input_coord == pad_slot_id: # not processing as this is not the actual sequence return conv_states_base = (conv_states_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + + (conv_states_input_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] @@ -113,10 +146,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching # 2. 
update conv_state with new data [only by the Triton program handles chunk_offset=0] if chunk_offset == 0: # read from conv_states - load_init_state = False - if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES - load_init_state = tl.load(has_initial_states_ptr + idx_seq).to( - tl.int1) + load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1) if load_init_state: # load from conv_states prior_tokens = conv_states_base + (state_len - @@ -175,15 +205,23 @@ def _causal_conv1d_fwd_kernel( # continuous batching (idx_feats < dim)[None, :] ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) - new_conv_state = tl.load(x_ptrs, mask_x, 0.0) idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - conv_states_ptrs_target = conv_states_base[None, :] + ( - idx_tokens_conv * stride_conv_state_tok)[:, None] + + # Compute the offset where the last block should be written in the conv_states + conv_states_output_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_cache_indices + + current_last_index).to(tl.int64) + + conv_states_ptrs_target = ( + conv_states_ptr + (conv_states_output_coord * + stride_conv_state_seq) + # Offset from seq + (idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,] + idx_tokens_conv * stride_conv_state_tok)[:, None] mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] tl.debug_barrier() # NOTE: use this due to bug in Triton compiler - tl.store(conv_states_ptrs_target, new_conv_state, mask) + tl.store(conv_states_ptrs_target, loaded_x, mask) else: if load_init_state: @@ -192,12 +230,12 @@ def _causal_conv1d_fwd_kernel( # continuous batching conv_states_ptrs_source = ( conv_states_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + + (conv_states_input_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim)[None, :] + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None] ) # [BLOCK_M, BLOCK_N] - mask = ((conv_state_batch_coord < num_cache_lines) + mask = ((conv_states_input_coord < num_cache_lines) & ((idx_tokens_conv + seqlen) < state_len)[:, None] & (idx_feats < dim)[None, :]) conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0) @@ -280,6 +318,45 @@ def _causal_conv1d_fwd_kernel( # continuous batching conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N] col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + # Store intermediate states aligned with stride_block_m + # The additional states are cached starting from the last stride_block_m. + # For example: + # If n_block_to_fill = 0, then only the state at the sequence end is cached and the process below is not involved. + # If n_block_to_fill > 0, then the states at the sequence end and at the n_block_to_fill-last + # stride_block_m are cached. 
+ # For example chunk_offset = n_block_to_fill stores the state at last_full_block + if (chunk_offset - 1) < n_block_to_fill: + # Store the states at the chunk boundaries from the start of the sequence + idx_tokens_last = (last_full_block_token_index - + (n_block_to_fill - chunk_offset) * B_size - + state_len) + tl.arange( + 0, NP2_STATELEN) # [BLOCK_M] + x_ptrs = x_ptr + (idx_tokens_last * stride_x_token)[:, None] + ( + idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,] + + mask_x = ( + (idx_tokens_last >= 0)[:, None] & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # cache_idx + conv_states_output_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_cache_indices + + current_first_index + + (chunk_offset - 1)).to(tl.int64) + + conv_states_ptrs_target = ( + conv_states_ptr + (conv_states_output_coord * + stride_conv_state_seq) + # Offset from seq + (idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,] + idx_tokens_conv * stride_conv_state_tok)[:, None] + + mask = (idx_tokens_conv < state_len)[:, None] & \ + (idx_feats < dim)[None, :] + tl.debug_barrier() # NOTE: use this due to bug in Triton compiler + tl.store(conv_states_ptrs_target, loaded_x, mask) + if HAS_BIAS: bias = bias_ptr + idx_feats mask_bias = idx_feats < dim @@ -368,6 +445,11 @@ def causal_conv1d_fn( has_initial_state: Optional[torch.Tensor] = None, activation: Optional[str] = "silu", pad_slot_id: int = PAD_SLOT_ID, + current_first_idx: Optional[torch.Tensor] = None, + current_last_idx: Optional[torch.Tensor] = None, + initial_state_idx: Optional[torch.Tensor] = None, + context_lens: Optional[torch.Tensor] = None, + block_size_to_align=0, metadata=None, validate_data=False, ): @@ -378,7 +460,7 @@ def causal_conv1d_fn( sequences are concatenated from left to right for varlen weight: (dim, width) conv_states: (...,dim,width - 1) itype - updated inplace if provided + updated inplace if cache_indices are not provided [it use `cache_indices` to get the index to the cache of conv_state for that sequence conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True @@ -410,7 +492,16 @@ def causal_conv1d_fn( for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 - + current_first_idx: (batch,), dtype int32 + The pointer into cache_indices, where the first cache block to be filled is located. + current_last_idx: (batch,), dtype int32 + The pointer into cache_indices, where the last cache block to be filled is located. + initial_state_idx: (batch,), dtype int32 + The pointer into cache_indices, where the cache block containing the initial state is located. + context_lens: (batch,), dtype int32 + The number of tokens already completed for each sequence + block_size_to_align: int + The block size to align the cached states to out: same shape as `x` """ if isinstance(activation, bool) and activation: @@ -451,7 +542,6 @@ def causal_conv1d_fn( np2_statelen = triton.next_power_of_2(state_len) padded_batch = query_start_loc.size(0) - 1 - stride_x_seq = 0 stride_x_dim = x.stride(0) stride_x_token = x.stride(1) stride_w_dim = weight.stride(0) @@ -460,6 +550,7 @@ def causal_conv1d_fn( stride_istate_dim = 0 stride_istate_token = 0 num_cache_lines = 0 + BLOCK_M = 8 if conv_states is not None: # extensions to support vLLM: # 1. 
conv_states is used to replaced initial_states @@ -475,11 +566,9 @@ def causal_conv1d_fn( stride_istate_token = conv_states.stride(2) assert stride_istate_dim == 1 if out.dim() == 2: - stride_o_seq = 0 stride_o_dim = out.stride(0) stride_o_token = out.stride(1) else: - stride_o_seq = out.stride(0) stride_o_dim = out.stride(1) stride_o_token = out.stride(2) stride_cache_indices = cache_indices.stride( @@ -502,6 +591,12 @@ def causal_conv1d_fn( assert weight.stride(1) == 1 assert (dim, width) == weight.shape assert is_channel_last, "Need to run in channel-last layout" + if block_size_to_align is not None and block_size_to_align > 0: + assert ( + block_size_to_align % BLOCK_M + ) == 0, "The mamba block size needs to be divisible by the BLOCK_M" + else: + block_size_to_align = BLOCK_M if metadata is None: @@ -584,14 +679,16 @@ def grid(META): query_start_loc, batch_ptr, token_chunk_offset_ptr, + current_first_idx, + current_last_idx, + initial_state_idx, + context_lens, out, # Matrix dimensions - padded_batch, dim, cu_seqlen, num_cache_lines, # stride - stride_x_seq, stride_x_dim, stride_x_token, stride_w_dim, @@ -600,22 +697,20 @@ def grid(META): stride_istate_dim, stride_istate_token, stride_cache_indices, - stride_o_seq, stride_o_dim, stride_o_token, + block_size_to_align // BLOCK_M, # others pad_slot_id, # META HAS_BIAS=bias is not None, KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], - HAS_INITIAL_STATES=has_initial_state is not None, - HAS_CACHE=conv_states is not None, - IS_CONTINUOUS_BATCHING=cache_indices is not None, + IS_APC_ENABLED=current_last_idx is not None, USE_PAD_SLOT=pad_slot_id is not None, NP2_STATELEN=np2_statelen, #launch_cooperative_grid=True - BLOCK_M=8, + BLOCK_M=BLOCK_M, BLOCK_N=256, num_stages=2, ) @@ -629,10 +724,11 @@ def _causal_conv1d_update_kernel( w_ptr, # (dim, width) bias_ptr, conv_state_ptr, - cache_seqlens_ptr, # circular buffer conv_state_indices_ptr, num_accepted_tokens_ptr, query_start_loc_ptr, # (batch + 1) + current_last_idx, # (batch,) + initial_state_idx, #(batch,) o_ptr, # (batch, dim, seqlen) # Matrix dimensions batch: int, @@ -660,7 +756,7 @@ def _causal_conv1d_update_kernel( KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, IS_VARLEN: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_APC_ENABLED: tl.constexpr, IS_SPEC_DECODING: tl.constexpr, NP2_STATELEN: tl.constexpr, USE_PAD_SLOT: tl.constexpr, @@ -674,15 +770,21 @@ def _causal_conv1d_update_kernel( # [BLOCK_N,] elements along the feature-dimension (channel) idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) - if IS_CONTINUOUS_BATCHING: - # mask = idx_seq < batch - conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices).to( - tl.int64) + if IS_APC_ENABLED: + # Get the state from the initial_state_idx + conv_state_init = tl.load(initial_state_idx + idx_seq) + current_last_index = tl.load(current_last_idx + idx_seq) else: - conv_state_batch_coord = idx_seq + conv_state_init = 0 + current_last_index = 0 + + # cache_idx + conv_states_input_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices + + conv_state_init).to(tl.int64) + if USE_PAD_SLOT: # noqa - if conv_state_batch_coord == pad_slot_id: + if conv_states_input_coord == pad_slot_id: # not processing as this is not the actual sequence return @@ -726,7 +828,7 @@ def _causal_conv1d_update_kernel( # STEP 1: READ init_state data conv_states_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + + 
(conv_states_input_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim)) mask_w = idx_feats < dim @@ -754,12 +856,12 @@ def _causal_conv1d_update_kernel( # window manner, at each forward pass, the tokens are shift by 1, so we # load since idx_tokens + 1. conv_state_ptrs_source = ( - conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + + conv_state_ptr + (conv_states_input_coord * stride_conv_state_seq) + conv_state_token_offset * stride_conv_state_tok + (idx_feats * stride_conv_state_dim)[None, :] + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N] - mask = ((conv_state_batch_coord < num_cache_lines) + mask = ((conv_states_input_coord < num_cache_lines) & ((idx_tokens + seqlen) < state_len)[:, None] & (idx_feats < dim)[None, :]) conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) @@ -778,11 +880,16 @@ def _causal_conv1d_update_kernel( new_conv_state = tl.where(mask, conv_state, loaded_x) - conv_state_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] - conv_state_ptrs_target = conv_state_base + ( - idx_tokens * stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] + # Get the state from the initial_state_idx + # cache_idx + conv_states_offset = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices + + current_last_index).to(tl.int64) + conv_state_ptrs_target = ( + conv_state_ptr + + (conv_states_offset * stride_conv_state_seq) + # Offset from seq + (idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,] + idx_tokens * stride_conv_state_tok)[:, None] mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] tl.store(conv_state_ptrs_target, new_conv_state, mask) @@ -923,12 +1030,13 @@ def causal_conv1d_update( weight: torch.Tensor, bias: Optional[torch.Tensor] = None, activation: Union[bool, str, None] = None, - cache_seqlens: Optional[torch.Tensor] = None, conv_state_indices: Optional[torch.Tensor] = None, num_accepted_tokens: Optional[torch.Tensor] = None, query_start_loc: Optional[torch.Tensor] = None, max_query_len: int = -1, pad_slot_id: int = PAD_SLOT_ID, + current_last_idx: Optional[torch.Tensor] = None, + initial_state_idx: Optional[torch.Tensor] = None, validate_data=False, ): """ @@ -942,15 +1050,14 @@ def causal_conv1d_update( conv_state: (..., dim, state_len), where state_len >= width - 1 weight: (dim, width) bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state - starting at the index - @cache_seqlens % state_len. conv_state_indices: (batch,), dtype int32 If not None, the conv_state is a larger tensor along the batch dim, and we are selecting the batch coords specified by conv_state_indices. Useful for a continuous batching scenario. + current_last_idx: (batch,), dtype int32 + The pointer into conv_state_indices, where the last cache block to be filled is located. + initial_state_idx: (batch,), dtype int32 + The pointer into conv_state_indices, where the cache block containing the initial state is located. num_accepted_tokens: (batch,), dtype int32 If not None, it indicates the number of accepted tokens for each sequence in the batch. @@ -963,15 +1070,14 @@ def causal_conv1d_update( If query_start_loc is not None, this indicates the maximum query length in the batch. 
pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded + if conv_state_indices is passed, lets the kernel identify padded entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + for example: conv_state_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x` """ if validate_data: - assert cache_seqlens is None # not implemented yet - ok for vLLM assert pad_slot_id is not None assert x.stride(1) == 1 if isinstance(activation, bool): @@ -1011,7 +1117,6 @@ def causal_conv1d_update( assert num_cache_lines >= batch assert weight.stride(1) == 1 # Need this - assert cache_seqlens is None # not needed for vLLM - circular buffer # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' out = x @@ -1050,10 +1155,11 @@ def grid(META): weight, bias, conv_state, - cache_seqlens, conv_state_indices, num_accepted_tokens, query_start_loc, + current_last_idx, + initial_state_idx, out, # Matrix dimensions batch, @@ -1081,7 +1187,7 @@ def grid(META): KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], IS_VARLEN=query_start_loc is not None, - IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + IS_APC_ENABLED=current_last_idx is not None, IS_SPEC_DECODING=num_accepted_tokens is not None, NP2_STATELEN=np2_statelen, USE_PAD_SLOT=pad_slot_id is not None, diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 838290a9f5fb..21bc32ddecd4 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -52,6 +52,7 @@ def _selective_scan_update_kernel( z_ptr, out_ptr, state_batch_indices_ptr, + dst_state_batch_indices_ptr, pad_slot_id, # Matrix dimensions batch, @@ -107,11 +108,17 @@ def _selective_scan_update_kernel( # is taken from the state_batch_indices_ptr Otherwise, the state coordinate # is the same as the batch id. 
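+    # With prefix caching the updated state may be written to a different
+    # cache slot (dst_state_batch_indices_ptr) than the one it was read
+    # from, e.g. when a decode step crosses a mamba block boundary and the
+    # previously cached state must be preserved.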
if HAS_STATE_BATCH_INDICES: + dst_state_batch_indices_ptr += pid_b + dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64) + dst_state_ptr = state_ptr + (dst_state_batch_idx * stride_state_batch + + pid_h * stride_state_head) state_batch_indices_ptr += pid_b state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) state_ptr += (state_batch_idx * stride_state_batch + pid_h * stride_state_head) else: + dst_state_ptr = state_ptr + pid_b * stride_state_batch + \ + pid_h * stride_state_head state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head @@ -131,6 +138,8 @@ def _selective_scan_update_kernel( offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) + dst_state_ptrs = dst_state_ptr + (offs_m[:, None] * stride_state_dim + + offs_n[None, :] * stride_state_dstate) x_ptrs = x_ptr + offs_m * stride_x_dim dt_ptrs = dt_ptr + offs_m * stride_dt_dim if HAS_DT_BIAS: @@ -185,7 +194,7 @@ def _selective_scan_update_kernel( mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) if HAS_STATE_BATCH_INDICES: mask &= (state_batch_idx != pad_slot_id) - tl.store(state_ptrs, state, mask=mask) + tl.store(dst_state_ptrs, state, mask=mask) out = tl.sum(state * C[None, :], axis=1) if HAS_D: out += x * D @@ -205,6 +214,7 @@ def selective_state_update(state, dt_bias=None, dt_softplus=False, state_batch_indices=None, + dst_state_batch_indices=None, pad_slot_id=PAD_SLOT_ID, out=None): """ @@ -266,6 +276,11 @@ def selective_state_update(state, assert dt_bias.shape == (nheads, dim) if state_batch_indices is not None: assert state_batch_indices.shape == (batch, ) + if dst_state_batch_indices is not None: + assert dst_state_batch_indices.shape == (batch, ) + else: + # revert to the default behavior of in-place state updates + dst_state_batch_indices = state_batch_indices assert out.shape == x.shape grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) @@ -292,6 +307,7 @@ def selective_state_update(state, z, out, state_batch_indices, + dst_state_batch_indices, pad_slot_id, batch, nheads, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index f3eb61d5840e..e9e589115b8a 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -35,6 +35,7 @@ def _mamba_chunk_scan_combined_fwd(x, z=None, dt_bias=None, initial_states=None, + return_intermediate_states=False, seq_idx=None, cu_seqlens=None, cu_chunk_seqlens=None, @@ -151,28 +152,32 @@ def _mamba_chunk_scan_combined_fwd(x, initial_states=initial_states, ) - return states[last_chunk_indices] + if return_intermediate_states: + return states + else: + return states[last_chunk_indices] def mamba_chunk_scan_combined_varlen( - x, - dt, - A, - B, - C, - chunk_size, - cu_seqlens, - cu_chunk_seqlens, - last_chunk_indices, - seq_idx, - out, - D=None, - z=None, - dt_bias=None, - initial_states=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - state_dtype=None, + x, + dt, + A, + B, + C, + chunk_size, + cu_seqlens, + cu_chunk_seqlens, + last_chunk_indices, + seq_idx, + out, + D=None, + z=None, + dt_bias=None, + initial_states=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + return_intermediate_states=False, + state_dtype=None, ): """ Argument: @@ -213,6 +218,7 @@ def mamba_chunk_scan_combined_varlen( z=z, dt_bias=dt_bias, 
initial_states=initial_states, + return_intermediate_states=return_intermediate_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, cu_chunk_seqlens=cu_chunk_seqlens, diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 4a6154dc548a..c58d6eaa19cb 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -453,12 +453,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "Bamba currently does not support prefix caching" - self.quant_config = vllm_config.quant_config super().__init__() diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 5711b5ebe85e..283cd2bb8b41 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -292,10 +292,33 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config = vllm_config.cache_config compilation_config = vllm_config.compilation_config - # TODO(tdoublep): remove once prefix caching is enabled - cache_config.enable_prefix_caching = False - logger.info("Hybrid or mamba-based model detected: disabling prefix " - "caching since it is not yet supported.") + # Set mamba block size to max_model_len (this may get + # override by prefix caching logic later) + cache_config.mamba_block_size = model_config.max_model_len + + # TODO(@tdoublep) find a better way to do this than whitelist + MAMBA2_MODELS = [ + "BambaForCausalLM", + "FalconH1ForCausalLM", + "GraniteMoeHybridForCausalLM", + "Mamba2ForCausalLM", + "NemotronHForCausalLM", + "Zamba2ForCausalLM", + ] + if cache_config.enable_prefix_caching: + if model_config.architecture in MAMBA2_MODELS: + logger.info("Warning: Prefix caching is currently enabled. " + "Its support for Mamba2 layers is experimental. " + "Please report any issues you may observe.") + else: + logger.info("Hybrid or mamba-based model detected without " + "support for prefix caching: disabling.") + cache_config.enable_prefix_caching = False + + # TODO(tdoublep): remove once cascade attention is supported + logger.info("Disabling cascade attention since it is not supported " + "for hybrid models.") + model_config.disable_cascade_attn = True # TODO(tdoublep): remove as full cuda graph support is added FCG_NOT_SUPPORTED_MODELS = [ @@ -360,12 +383,38 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: block_size=model_config.max_model_len, ).page_size_bytes - # some attention backends (e.g. FA) only support setting - # block size to multiple of 16, so let's suggest a value - # that would work (note: FA is currently not compatible - # with mamba layers, use FlashInfer instead). - attn_block_size = 16 * cdiv(mamba_page_size, - 16 * attn_page_size_1_token) + if cache_config.enable_prefix_caching: + # With prefix caching, select attention block size to + # optimize for mamba kernel performance + + # mamba SSD kernel uses a chunk_size, e.g. 256 + # Align the block to the kernel: use lowest multiple of chunk_size + # of attention tokens that would fit mamba_page_size: + # e.g. 
for mamba page size = 788kB + # attn_1_token = 2kB -> fits ~394 tokens + # then round up to a mulitple of 256 -> 512 tokens + # End result: + # attn_block_size = 512 + # mamba_block_size = 512 (aligned to a multiple of chunk_size) + # TODO(tdoublep): this constraint can be relaxed fairly + # easily by changing the way we layout chunks in the + # mamba2 kernels. + chunk_size = model_config.get_mamba_chunk_size() + attn_tokens_per_mamba_state = \ + cdiv(mamba_page_size, attn_page_size_1_token) + attn_block_size = chunk_size * \ + cdiv(attn_tokens_per_mamba_state, chunk_size) + cache_config.mamba_block_size = attn_block_size + else: + # Without prefix caching, select minimum valid attention block size + # to minimize mamba state padding + + # some attention backends (e.g. FA) only support setting + # block size to multiple of 16, so let's suggest a value + # that would work (note: FA is currently not compatible + # with mamba layers, use FlashInfer instead). + attn_block_size = 16 * cdiv(mamba_page_size, + 16 * attn_page_size_1_token) # override attention block size if either (a) the # user has not set it or (b) the user has set it diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index f382018e2222..ccea9add093f 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -540,11 +540,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert (not cache_config.enable_prefix_caching - ), "FalconH1 currently does not support prefix caching" self.quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index f5751fe47bb8..dc213e029cd5 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -549,13 +549,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - if cache_config.enable_prefix_caching: - raise RuntimeError( - "GraniteMoeHybrid currently does not support prefix caching") - self.quant_config = vllm_config.quant_config self.config = config self.scheduler_config = scheduler_config diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index f8a5a8f6081b..250698a61387 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -222,11 +222,8 @@ def get_mamba_state_shape_from_config( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "Mamba does not support prefix caching" super().__init__() self.config = config diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 987920ecc331..c89550923938 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ 
b/vllm/model_executor/models/nemotron_h.py @@ -505,11 +505,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "NemotronH currently does not support prefix caching" self.quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index 1d68320bd9b2..1803fa259cf4 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -868,11 +868,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: (not supported by Mamba) """ config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "Mamba does not support prefix caching" super().__init__() self.config = config diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 68b6ff73ba3f..49fe1584e79c 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -122,6 +122,11 @@ class Mamba2AttentionMetadata: last_chunk_indices_p: Optional[torch.Tensor] state_indices_tensor: torch.Tensor # shape: [batch,] + current_last_idx: torch.Tensor + current_first_idx_p: torch.Tensor + last_state_idx: torch.Tensor + context_lens_p: torch.Tensor + last_computed_offset_p: torch.Tensor # The following attributes are for triton implementation of causal_conv1d nums_dict: Optional[dict] = None @@ -138,6 +143,24 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() assert self.chunk_size is not None, ( "chunk_size needs to be set in the model config for Mamba2 models") + if self.vllm_config.cache_config.enable_prefix_caching: + self.state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs, + cdiv(vllm_config.model_config.max_model_len, + kv_cache_spec.block_size)), + dtype=torch.int32, + device=device, + ) + self.current_last_idx = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.int32, + device=device, + ) + self.last_state_idx = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.int32, + device=device, + ) def build(self, common_prefix_len: int, @@ -158,7 +181,45 @@ def build(self, # for causal_conv1d nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + context_lens, context_lens_p = None, None + current_first_idx, current_first_idx_p = None, None + last_computed_offset, last_computed_offset_p = None, None + + if self.vllm_config.cache_config.enable_prefix_caching: + # Return a tensor of shape (#requests, #max blocks) + state_indices_tensor = common_attn_metadata.block_table_tensor + + # Additional cache-related varaiables: + mamba_block_size = self.kv_cache_spec.block_size + seq_lens_pending = ( + torch.roll(common_attn_metadata.query_start_loc, -1, -1) - + common_attn_metadata.query_start_loc)[:-1] + context_lens = common_attn_metadata.seq_lens - \ + seq_lens_pending + last_computed_offset = \ + context_lens % mamba_block_size + # Indices: last_computed <= current_first 
<= current_last + # Cases: + # last_computed == current_first if last state was partially + # computed and needs to be updated + # current_first == current_last if no block crossing occurs, and + # only one state will be stored + # 0th based indexing leads to "-1" -> e.g. 16 computed -> state[15]: + current_last_idx = cdiv(context_lens + seq_lens_pending, + mamba_block_size) - 1 + current_first_idx = cdiv(context_lens + 1, mamba_block_size) - 1 + last_state_idx = cdiv(context_lens, mamba_block_size) - 1 + # -1 in case it's non-computed and causes later issues with indexing + last_state_idx = \ + last_state_idx.clamp(min=0) + + else: + # Always return just a single block per each request: + state_indices_tensor = common_attn_metadata.block_table_tensor[:, + 0] + # Additional cache-related varaiables: + current_last_idx = None + last_state_idx = None num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( @@ -178,6 +239,16 @@ def build(self, query_start_loc_p = common_attn_metadata.query_start_loc[ -num_prefills - 1:] - num_decode_tokens + if self.vllm_config.cache_config.enable_prefix_caching: + assert context_lens is not None + context_lens_p = context_lens[num_reqs - num_prefills:num_reqs] + assert last_computed_offset is not None + last_computed_offset_p = last_computed_offset[ + num_reqs - num_prefills:num_reqs] + assert current_first_idx is not None + current_first_idx_p = current_first_idx[num_reqs - + num_prefills:num_reqs] + num_computed_tokens_p = \ common_attn_metadata.num_computed_tokens_cpu[ num_reqs - num_prefills:num_reqs] @@ -252,6 +323,19 @@ def build(self, state_indices_tensor = self.state_indices_tensor[:num_input_tokens] state_indices_tensor[num_decodes:] = PAD_SLOT_ID + if self.vllm_config.cache_config.enable_prefix_caching: + self.current_last_idx[:num_decodes].copy_(current_last_idx, + non_blocking=True) + current_last_idx = \ + self.current_last_idx[:num_input_tokens] + current_last_idx[num_decodes:] = 0 + + self.last_state_idx[:num_decodes].copy_(last_state_idx, + non_blocking=True) + last_state_idx = \ + self.last_state_idx[:num_input_tokens] + last_state_idx[num_decodes:] = 0 + attn_metadata = Mamba2AttentionMetadata( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, @@ -269,5 +353,10 @@ def build(self, nums_dict=nums_dict, batch_ptr=batch_ptr, token_chunk_offset_ptr=token_chunk_offset_ptr, + current_last_idx=current_last_idx, + current_first_idx_p=current_first_idx_p, + last_state_idx=last_state_idx, + context_lens_p=context_lens_p, + last_computed_offset_p=last_computed_offset_p, ) return attn_metadata diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 27ea1c4db2a5..07777efc3281 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -546,20 +546,38 @@ def find_longest_cache_hit( kv_cache_spec, MambaSpec), ("MambaManager can only be used for mamba groups") assert dcp_world_size == 1, "DCP not support mamba now." - # Prefix caching is not supported for mamba now. Always return empty - # list. computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( [] for _ in range(len(kv_cache_group_ids))) + + max_num_blocks = max_length // kv_cache_spec.block_size + # Search from right to left and early stop when a match is found. 
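+        # Only the right-most cached block matters: a mamba block's state
+        # summarizes the entire prefix up to that block, so earlier
+        # positions can be padded with null blocks purely to keep the
+        # hit-length computation consistent.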
+ for i in range(max_num_blocks - 1, -1, -1): + if cached_block := block_pool.get_cached_block( + block_hashes[i], kv_cache_group_ids): + for computed, cached in zip(computed_blocks, cached_block): + # the hit length logic later assumes: + # hit_length = len(hit_blocks_other_attn[0]) + # * self.other_block_size + # so we insert dummy blocks at the beginning: + if i > 0: + computed.extend([block_pool.null_block] * i) + computed.append(cached) + break # we just need the last match - early stopping + return computed_blocks def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: - # Each request will always have 1 block at this moment, so no need to - # remove blocks. + # Here unused blocks may be freed up for running requests. + # TODO(@s3woz) Free up all blocks that aren't needed by Mamba2 + # (for which find_longest_cache_hit returns block_pool.null_block) pass def get_num_common_prefix_blocks(self, request_id: str, num_running_requests: int) -> int: + """ + cascade attention is not supported by mamba + """ return 0 def get_num_blocks_to_allocate( diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 281816653540..054ab591b817 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -233,10 +233,8 @@ def page_size_bytes(self) -> int: return page_size def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - # We allocate 1 block for each request now, so max_memory_usage_bytes is - # the same as page_size_bytes. - # Need to update this when supporting prefix caching. - return self.page_size_bytes + max_model_len = vllm_config.model_config.max_model_len + return cdiv(max_model_len, self.block_size) * self.page_size_bytes @dataclass(frozen=True) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ff95acf0c016..11e24e4d13dc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4240,21 +4240,15 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: not in ["qwen3_next"]): raise NotImplementedError( "Mamba with speculative decoding is not supported yet.") - if self.vllm_config.cache_config.enable_prefix_caching: - raise NotImplementedError( - "Prefix caching is not supported for Mamba yet.") - max_model_len = self.vllm_config.model_config.max_model_len - + mamba_block_size = self.vllm_config.cache_config.mamba_block_size page_size_padded = ( self.vllm_config.cache_config.mamba_page_size_padded) - # Set block_size to max_model_len, so that mamba model will always - # have only one block in the KV cache. for layer_name, mamba_module in mamba_layers.items(): kv_cache_spec[layer_name] = MambaSpec( shapes=mamba_module.get_state_shape(), dtypes=mamba_module.get_state_dtype(), - block_size=max_model_len, + block_size=mamba_block_size, page_size_padded=page_size_padded, mamba_type=mamba_module.mamba_type, num_speculative_blocks=(