[V1] [Hybrid] Mamba2 Automatic Prefix Caching #25752
```diff
@@ -489,6 +489,9 @@ def forward_cuda(
         # stay the same and reused for all mamba layers in the same iteration
         attn_metadata: AttentionMetadata = forward_context.attn_metadata

+        assert self.cache_config is not None
+        mamba_block_size = self.cache_config.mamba_block_size
+        prefix_caching_enabled = self.cache_config.enable_prefix_caching
         if attn_metadata is not None:
             assert isinstance(attn_metadata, dict)
             attn_metadata = attn_metadata[self.prefix]
```
```diff
@@ -573,6 +576,25 @@ def forward_cuda(
             dim=0,
         )

+        if prefix_caching_enabled:
+            # If prefix caching is enabled, retrieve the relevant variables
+            # for prefill and decode
+            last_state_idx_d, last_state_idx_p = torch.split(
+                attn_metadata.last_state_idx, [num_decodes, num_prefills],
+                dim=0)
+            current_last_idx_d, current_last_idx_p = torch.split(
+                attn_metadata.current_last_idx, [num_decodes, num_prefills],
+                dim=0)
+            # Prefill-only variables:
+            current_first_idx_p = attn_metadata.current_first_idx_p
+            context_lens_p = attn_metadata.context_lens_p
+            last_computed_offset_p = attn_metadata.last_computed_offset_p
+        else:
+            last_state_idx_d, last_state_idx_p = None, None
+            current_last_idx_d, current_last_idx_p = None, None
+            current_first_idx_p = None
+            context_lens_p = None
+
         # Preallocate output tensor to avoid memcpy cost for merging prefill
         # and decode outputs
         preallocated_ssm_out = torch.empty(
```
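For readers unfamiliar with the batch layout, here is a minimal standalone sketch of the split used above. It assumes the convention visible in the diff (decode requests ordered before prefill requests in the batch); the tensor values are made up for illustration.

```python
import torch

# Hypothetical batch: 2 decode requests followed by 3 prefill requests.
num_decodes, num_prefills = 2, 3
# Per-request index of the last cached state block (values are made up).
last_state_idx = torch.tensor([0, 1, 2, 0, 3])

# Same split as in the diff: decode entries first, prefill entries last.
last_state_idx_d, last_state_idx_p = torch.split(
    last_state_idx, [num_decodes, num_prefills], dim=0)

print(last_state_idx_d)  # tensor([0, 1])
print(last_state_idx_p)  # tensor([2, 0, 3])
```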
```diff
@@ -592,8 +614,17 @@ def forward_cuda(
         # Process prefill requests
         if has_prefill:
             # 2. Convolution sequence transformation
-            # - "cache_indices" updates the conv_state cache in positions
-            #   pointed to by "state_indices_tensor"
+            # - It will read the initial states for every sequence that has
+            #   "has_initial_states_p" == True from "cache_indices",
+            #   using "state_indices_tensor_p".
+            # - It updates the "conv_state" cache in positions pointed
+            #   to by "state_indices_tensor_p".
+            #   In particular, it will always write the state at the
+            #   sequence end.
+            # - In addition, if "current_first_idx_p" and "current_last_idx_p"
+            #   are provided (which are pointers into
+            #   "state_indices_tensor_p"), it will write additional cache
+            #   states aligned at "block_size_to_align".
             x = hidden_states_B_C_p.transpose(
                 0, 1)  # this is the form that causal-conv see
             hidden_states_B_C_p = causal_conv1d_fn(
```
```diff
@@ -604,6 +635,11 @@ def forward_cuda(
                 conv_states=conv_state,
                 has_initial_state=has_initial_states_p,
                 cache_indices=state_indices_tensor_p,
+                current_first_idx=current_first_idx_p,
+                current_last_idx=current_last_idx_p,
+                initial_state_idx=last_state_idx_p,
+                context_lens=context_lens_p,
+                block_size_to_align=mamba_block_size,
                 metadata=attn_metadata,
                 query_start_loc=query_start_loc_p).transpose(
                     0, 1)[:num_prefill_tokens]
```
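As a rough illustration of the "aligned at block_size_to_align" behaviour described in the comment above, the following standalone sketch computes which absolute token positions of a prefill would get an extra state snapshot, using hypothetical numbers (the kernel arguments themselves are shown in the diff; this is not the kernel's implementation):

```python
# Hypothetical sizes: 1024-token cache blocks, a request with 1000 tokens of
# already-computed context and 2600 new prefill tokens in this step.
mamba_block_size = 1024          # block_size_to_align
context_len = 1000               # tokens computed in earlier steps
num_new_tokens = 2600            # tokens processed in this prefill

start, end = context_len, context_len + num_new_tokens
# States are persisted at every block boundary crossed by this prefill ...
aligned_positions = [p for p in range(0, end + 1, mamba_block_size)
                     if start < p <= end]
print(aligned_positions)              # [1024, 2048, 3072]
# ... plus the state at the sequence end, which may sit mid-block (partial).
print("final state at token", end)    # final state at token 3600
```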
```diff
@@ -614,9 +650,13 @@ def forward_cuda(
             # 3. State Space Model sequence transformation
             initial_states = None
             if (has_initial_states_p is not None and prep_initial_states):
+                kernel_ssm_indices = state_indices_tensor_p
+                if prefix_caching_enabled:
+                    kernel_ssm_indices = state_indices_tensor_p.gather(
+                        1, last_state_idx_p.unsqueeze(1)).squeeze(1)
                 initial_states = torch.where(
                     has_initial_states_p[:, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
+                    ssm_state[kernel_ssm_indices], 0)

             # NOTE: final output is an in-place update of out tensor
             varlen_states = mamba_chunk_scan_combined_varlen(
```
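The gather above selects, for each prefill request, the cache block that holds its most recently computed state. A minimal sketch with made-up block ids:

```python
import torch

# Hypothetical block table: one row per prefill request, one column per cache
# block allocated to that request (entries are made-up block ids).
state_indices_tensor_p = torch.tensor([[10, 11, 12],
                                       [20, 21, 22]])
# Index of the block holding the most recently computed state per request.
last_state_idx_p = torch.tensor([2, 0])

# Same gather the diff uses to pick the initial-state block per request.
kernel_ssm_indices = state_indices_tensor_p.gather(
    1, last_state_idx_p.unsqueeze(1)).squeeze(1)
print(kernel_ssm_indices)  # tensor([12, 20])
```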
```diff
@@ -638,26 +678,82 @@ def forward_cuda(
                 cu_chunk_seqlens=cu_chunk_seqlen_p,
                 last_chunk_indices=last_chunk_indices_p,
                 initial_states=initial_states,
+                return_intermediate_states=prefix_caching_enabled,
                 dt_softplus=True,
                 dt_limit=(0.0, float("inf")),
                 out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
                                                 self.head_dim),
                 state_dtype=ssm_state.dtype)

-            # update ssm states
-            # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
-            ssm_state[state_indices_tensor_p] = varlen_states
+            if prefix_caching_enabled:
+                # Save states for sequences with more than just the final state:
+                n_blocks_to_fill = current_last_idx_p - current_first_idx_p
+                for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
+                    cache_blocks_to_fill = state_indices_tensor_p[
+                        seq_idx, current_first_idx_p[seq_idx]:
+                        current_first_idx_p[seq_idx] +
+                        n_blocks_to_fill[seq_idx]]
+                    # chunks = [0 1 2 3 4 5 6 ...]
+                    # First aligned chunk would typically be:
+                    #   mamba_block_size = 1024, chunk_size = 256
+                    #   1024 // 256 - 1 --> chunks[3]
+                    # But when the last chunk wasn't block aligned:
+                    #   - last_computed_offset_p[seq_idx] // chunk_size
+                    #   e.g. 1000 // 256 -> 3 completed --> store chunk[0]
+                    #   e.g.  513 // 256 -> 2 completed --> store chunk[1] (skip 1)
+                    #   e.g.  256 // 256 -> 1 completed --> store chunk[2] (skip 2)
+                    #   e.g.   10 // 256 -> 0 completed --> store chunk[3] (skip 3)
+                    chunk_stride = mamba_block_size // chunk_size
+                    first_aligned_chunk = \
+                        torch.concat([torch.zeros(1,
+                            dtype=last_chunk_indices_p.dtype,
+                            device=last_chunk_indices_p.device),
+                            last_chunk_indices_p + 1])[seq_idx] \
+                        + chunk_stride - 1 \
+                        - last_computed_offset_p[seq_idx] // chunk_size
+                    from_where = varlen_states[
+                        first_aligned_chunk:first_aligned_chunk +
+                        n_blocks_to_fill[seq_idx] * chunk_stride:chunk_stride]
+                    ssm_state[cache_blocks_to_fill] = from_where
```

Collaborator: Why do you need an additional torch.zeros?

Member: Re-written this.

Collaborator (on lines +697 to +713): Actually I can't understand this part (and also the comments in 702-705...).

Member: I've re-written the code in a way that I believe is easier to understand what is happening. We plan to fuse this code into the kernels at a later stage, but I would like to make sure the code on main right now is easy to follow.
```diff
+                # For all seqs, store the last state (Note: might be partial):
+                ssm_state[state_indices_tensor_p.gather(1,
+                    current_last_idx_p.unsqueeze(1)).squeeze(1)] = \
+                    varlen_states[last_chunk_indices_p]
+            else:
+                # update ssm states
+                # - varlen state is a (num_prefills, nheads, headdim, dstate)
+                #   tensor
+                ssm_state[state_indices_tensor_p] = varlen_states
```
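To make the chunk-selection arithmetic above concrete, here is a small standalone sketch with hypothetical numbers. It only reproduces the per-sequence offset computation spelled out in the code comments; in the actual code the index is additionally shifted by the sequence's first chunk position in the batch-flattened chunk dimension (the torch.concat over last_chunk_indices_p).

```python
# Hypothetical numbers, mirroring the arithmetic in the diff above.
mamba_block_size = 1024    # tokens per cached state block
chunk_size = 256           # tokens per SSM scan chunk
chunk_stride = mamba_block_size // chunk_size           # 4 chunks per block

# One prefill sequence whose already-computed context ends 513 tokens past
# the last block boundary (i.e. context_len % mamba_block_size == 513).
last_computed_offset = 513
chunks_already_done = last_computed_offset // chunk_size  # 2 completed chunks

# Local index (within this sequence's chunks) of the first chunk whose end
# lands exactly on a block boundary.
first_aligned_chunk = chunk_stride - 1 - chunks_already_done
print(first_aligned_chunk)           # 1

# Every chunk_stride-th intermediate state after that is block aligned, so
# with e.g. 3 whole blocks to fill, the stored chunk indices would be:
n_blocks_to_fill = 3
print(list(range(first_aligned_chunk,
                 first_aligned_chunk + n_blocks_to_fill * chunk_stride,
                 chunk_stride)))     # [1, 5, 9]
```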
```diff
         # Process decode requests
         if has_decode:
+            if prefix_caching_enabled:
+                state_indices_tensor_d_input = \
+                    state_indices_tensor_d.gather(1,
+                        last_state_idx_d.unsqueeze(1)).squeeze(1)
+                state_indices_tensor_d_output = \
+                    state_indices_tensor_d.gather(1,
+                        current_last_idx_d.unsqueeze(1)).squeeze(1)
+                # Note:
+                # for decode always: current_first_idx_d == current_last_idx_d
+                # at block boundaries: current_first_idx_d > last_state_idx_d
+            else:
+                # Without caching, read and write in-place to the same blocks:
+                state_indices_tensor_d_input = state_indices_tensor_d
+                state_indices_tensor_d_output = state_indices_tensor_d
```
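A minimal sketch of this decode-side index selection (block ids and indices are made up): mid-block a request reads and writes the same state slot, while right after a block boundary it reads the last finished block and writes the newly allocated one.

```python
import torch

# Hypothetical decode-time block table: one row per decode request, columns
# are the cache block ids allocated to that request.
state_indices_tensor_d = torch.tensor([[10, 11],
                                       [20, 21]])

# Request 0 is mid-block: it reads and writes the same block.
# Request 1 just crossed a block boundary: it reads the last finished block
# but writes the newly allocated one (current_last_idx_d > last_state_idx_d).
last_state_idx_d = torch.tensor([0, 0])
current_last_idx_d = torch.tensor([0, 1])

inp = state_indices_tensor_d.gather(1, last_state_idx_d.unsqueeze(1)).squeeze(1)
out = state_indices_tensor_d.gather(1, current_last_idx_d.unsqueeze(1)).squeeze(1)
print(inp)  # tensor([10, 20])
print(out)  # tensor([10, 21])
```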
```diff
             # 2. Convolution sequence transformation
             hidden_states_B_C_d = causal_conv1d_update(
                 hidden_states_B_C_d,
                 conv_state,
                 conv_weights,
                 self.conv1d.bias,
                 self.activation,
-                conv_state_indices=state_indices_tensor_d)
+                conv_state_indices=state_indices_tensor_d,
+                current_last_idx=current_last_idx_d,
+                initial_state_idx=last_state_idx_d,
+            )

             hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
                 hidden_states_B_C_d)
```
```diff
@@ -689,7 +785,8 @@ def forward_cuda(
                 z=None,
                 dt_bias=dt_bias,
                 dt_softplus=True,
-                state_batch_indices=state_indices_tensor_d,
+                state_batch_indices=state_indices_tensor_d_input,
+                dst_state_batch_indices=state_indices_tensor_d_output,
                 out=preallocated_ssm_out_d.view(num_decodes, -1,
                                                 self.head_dim),
             )
```