From 951d69c5d8f5091c6b9e14b012cd5916b72e7751 Mon Sep 17 00:00:00 2001 From: Vladimir Date: Fri, 26 Jan 2024 03:52:26 -0800 Subject: [PATCH 1/2] add swap_blocks unit tests #2583 --- tests/kernels/test_cache.py | 68 +++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 7b1cc058f2cb..0455ab99cb1d 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -3,8 +3,11 @@ import pytest import torch +from typing import Tuple + from vllm._C import cache_ops +COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [42] # Arbitrary values for testing NUM_LAYERS = [1] # Arbitrary values for testing @@ -149,3 +152,68 @@ def test_reshape_and_cache( assert torch.allclose(key_cache, cloned_key_cache) assert torch.allclose(value_cache, cloned_value_cache) + + +@pytest.mark.parametrize("direction", COPYING_DIRECTION) +@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) +@torch.inference_mode() +def test_swap_blocks( + kv_cache_factory, + direction: Tuple[str, str], + num_mappings: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + src_device = f"{direction[0]}:{device}" if direction[ + 0] == "cuda" else direction[0] + dst_device = f"{direction[1]}:{device}" if direction[ + 1] == "cuda" else direction[1] + + src_blocks = random.sample(range(num_blocks), num_mappings) + # For the same device, mapping must not overlap + if src_device == dst_device: + remaining_blocks = list(set(range(num_blocks)) - set(src_blocks)) + dst_blocks = random.sample(remaining_blocks, num_mappings) + else: + dst_blocks = random.sample(range(num_blocks), num_mappings) + + block_mapping = dict(zip(src_blocks, dst_blocks)) + + # Create the KV caches on the first device. + src_key_caches, src_value_caches = kv_cache_factory( + num_blocks, block_size, 1, num_heads, head_size, dtype, seed, + src_device) + + # Create the KV caches on the second device. + dist_key_caches, dist_value_caches = kv_cache_factory( + num_blocks, block_size, 1, num_heads, head_size, dtype, seed, + dst_device) + + src_key_caches_clone = src_key_caches[0].clone() + src_value_caches_clone = src_value_caches[0].clone() + + # # Call the swap_blocks kernel. + cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping) + cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0], + block_mapping) + + for src, dst in block_mapping.items(): + assert torch.allclose(src_key_caches_clone[src].cpu(), + dist_key_caches[0][dst].cpu()) + assert torch.allclose(src_value_caches_clone[src].cpu(), + dist_value_caches[0][dst].cpu()) From 05f8222e7df7da6cb6ab4e2b06d236e12db2af89 Mon Sep 17 00:00:00 2001 From: Vladimir Date: Tue, 30 Jan 2024 15:03:18 +0100 Subject: [PATCH 2/2] Update tests/kernels/test_cache.py Co-authored-by: Lily Liu --- tests/kernels/test_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 0455ab99cb1d..7bc64699ec1c 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -207,7 +207,7 @@ def test_swap_blocks( src_key_caches_clone = src_key_caches[0].clone() src_value_caches_clone = src_value_caches[0].clone() - # # Call the swap_blocks kernel. + # Call the swap_blocks kernel. cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping) cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)