Commit b9926f7

Support block size 32 (#35)
1 parent: ee88a7e

File tree

4 files changed: +49 -5 lines

cacheflow/master/block_manager.py
cacheflow/master/server.py
csrc/attention_kernels.cu
tests/kernels/attention.py

cacheflow/master/block_manager.py

Lines changed: 2 additions & 2 deletions
@@ -15,9 +15,9 @@ def __init__(
         block_size: int,
         num_blocks: int,
     ) -> None:
-        if block_size not in [8, 16]:
+        if block_size not in [8, 16, 32]:
             raise ValueError(f'Unsupported block size: {block_size}'
-                             'The block size must be either 8 or 16.')
+                             'The block size must be one of {8, 16, 32}.')
         self.device = device
         self.block_size = block_size
         self.num_blocks = num_blocks
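
The sizes allowed here track the kernel instantiations added in csrc/attention_kernels.cu below: a size outside that compiled set would only fail at the assert(false) fallback inside the C++ dispatch, so the Python layer rejects it up front. A minimal sketch of the same check as a standalone helper (hypothetical; the real check runs inside the manager's __init__):

    # Hypothetical standalone version of the check added above.
    def validate_block_size(block_size: int) -> None:
        if block_size not in [8, 16, 32]:
            raise ValueError(f'Unsupported block size: {block_size}. '
                             'The block size must be one of {8, 16, 32}.')

    validate_block_size(32)    # accepted after this commit
    # validate_block_size(64)  # would raise ValueError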

cacheflow/master/server.py

Lines changed: 1 addition & 1 deletion
@@ -174,7 +174,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
     # KV cache arguments
-    parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
+    parser.add_argument('--block-size', type=int, default=8, choices=[8, 16, 32], help='token block size')
     # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
     parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
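
With 32 added to choices, argparse now accepts --block-size 32 at server launch. A self-contained sketch of just this flag (the real add_server_arguments registers many more options):

    # Illustrative parser with only the changed flag; not the full
    # add_server_arguments() from cacheflow/master/server.py.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--block-size', type=int, default=8,
                        choices=[8, 16, 32], help='token block size')

    args = parser.parse_args(['--block-size', '32'])
    assert args.block_size == 32  # 32 is now a valid choice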

csrc/attention_kernels.cu

Lines changed: 44 additions & 0 deletions
@@ -654,6 +654,16 @@ void single_query_cached_kv_attention(
       block_tables,
       context_lens,
       max_context_len);
+  } else if (block_size == 32) {
+    single_query_cached_kv_attention_launcher<uint16_t, 32>(
+      out,
+      query,
+      key_cache,
+      value_cache,
+      scale,
+      block_tables,
+      context_lens,
+      max_context_len);
   } else {
     assert(false);
   }

@@ -679,6 +689,16 @@ void single_query_cached_kv_attention(
       block_tables,
       context_lens,
       max_context_len);
+  } else if (block_size == 32) {
+    single_query_cached_kv_attention_launcher<float, 32>(
+      out,
+      query,
+      key_cache,
+      value_cache,
+      scale,
+      block_tables,
+      context_lens,
+      max_context_len);
   } else {
     assert(false);
   }

@@ -834,6 +854,18 @@ void multi_query_cached_kv_attention(
       block_tables,
       context_lens,
       max_context_len);
+  } else if (block_size == 32) {
+    multi_query_cached_kv_attention_launcher<uint16_t, 32>(
+      cu_query_lens,
+      seq_prompt_mapping,
+      out,
+      query,
+      key_cache,
+      value_cache,
+      scale,
+      block_tables,
+      context_lens,
+      max_context_len);
   } else {
     assert(false);
   }

@@ -863,6 +895,18 @@ void multi_query_cached_kv_attention(
       block_tables,
       context_lens,
       max_context_len);
+  } else if (block_size == 32) {
+    multi_query_cached_kv_attention_launcher<float, 32>(
+      cu_query_lens,
+      seq_prompt_mapping,
+      out,
+      query,
+      key_cache,
+      value_cache,
+      scale,
+      block_tables,
+      context_lens,
+      max_context_len);
   } else {
     assert(false);
   }
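
Since the block size is a compile-time template parameter of the launchers, each supported size needs its own instantiation, repeated per dtype (uint16_t for half, float) and per kernel (single-query and multi-query); that is why the same branch appears four times above. An illustrative Python analogue of this dispatch pattern (the launcher names mirror the CUDA code, but the registry itself is hypothetical):

    # Hypothetical sketch: one pre-instantiated "kernel" per
    # (dtype, block_size) pair, selected by explicit lookup at call time,
    # like the C++ if/else chains above.
    LAUNCHERS = {
        (dtype, bs): f'single_query_cached_kv_attention_launcher<{dtype}, {bs}>'
        for dtype in ('uint16_t', 'float')
        for bs in (8, 16, 32)
    }

    def dispatch(dtype: str, block_size: int) -> str:
        if (dtype, block_size) not in LAUNCHERS:
            # Mirrors the assert(false) fallback in the CUDA code.
            raise AssertionError(f'no kernel for ({dtype}, {block_size})')
        return LAUNCHERS[(dtype, block_size)]

    print(dispatch('uint16_t', 32))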

tests/kernels/attention.py

Lines changed: 2 additions & 2 deletions
@@ -350,7 +350,7 @@ def test_attention(seed: int) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     for dtype in [torch.half, torch.float]:
-        for block_size in [8, 16]:
+        for block_size in [8, 16, 32]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '

@@ -368,7 +368,7 @@ def test_attention(seed: int) -> None:
     # note that the test is also more likely to fail due to the much
     # larger amount of tokens in the input may increase the variance.
     for dtype in [torch.half, torch.float]:
-        for block_size in [8, 16]:
+        for block_size in [8, 16, 32]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing multi_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
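
The tests now sweep the full dtype × block_size × head_size grid, so the new size is exercised against both kernels. An equivalent way to express the sweep (illustration only; the real test constructs inputs and checks kernel outputs inside the loop):

    # Sketch of the test's parameter grid via itertools.product; the real
    # test body, which runs the kernels and verifies results, is omitted.
    import itertools
    import torch

    grid = itertools.product(
        [torch.half, torch.float],              # dtype
        [8, 16, 32],                            # block_size, now with 32
        [32, 64, 80, 96, 128, 160, 192, 256],   # head_size
    )
    for dtype, block_size, head_size in grid:
        print(f'dtype={dtype}, block_size={block_size}, head_size={head_size}')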
