From 8ce55630bb5c2f617d1eeac57e0772826867e3e1 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 18 May 2023 00:04:18 +0000
Subject: [PATCH] Use pytest format

---
 .../{activation.py => test_activation.py}     |  8 ++---
 .../{attention.py => test_attention.py}       | 28 ++++++++++++++
 tests/kernels/{cache.py => test_cache.py}     | 30 +++++++++++--------
 .../{layernorm.py => test_layernorm.py}       |  6 ++--
 .../{pos_encoding.py => test_pos_encoding.py} |  8 ++---
 5 files changed, 43 insertions(+), 37 deletions(-)
 rename tests/kernels/{activation.py => test_activation.py} (82%)
 rename tests/kernels/{attention.py => test_attention.py} (94%)
 rename tests/kernels/{cache.py => test_cache.py} (92%)
 rename tests/kernels/{layernorm.py => test_layernorm.py} (95%)
 rename tests/kernels/{pos_encoding.py => test_pos_encoding.py} (96%)

diff --git a/tests/kernels/activation.py b/tests/kernels/test_activation.py
similarity index 82%
rename from tests/kernels/activation.py
rename to tests/kernels/test_activation.py
index b35bea61d04d..4d7d2a4bf702 100644
--- a/tests/kernels/activation.py
+++ b/tests/kernels/test_activation.py
@@ -10,7 +10,7 @@ def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
 
 
 @torch.inference_mode()
-def test_silu_and_mul(
+def run_silu_and_mul(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
@@ -22,9 +22,9 @@
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
 
 
-if __name__ == '__main__':
+def test_silu_and_mul() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 13824]:
+            for d in [512, 4096, 5120, 13824]:
                 print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                test_silu_and_mul(num_tokens, d, dtype)
+                run_silu_and_mul(num_tokens, d, dtype)
diff --git a/tests/kernels/attention.py b/tests/kernels/test_attention.py
similarity index 94%
rename from tests/kernels/attention.py
rename to tests/kernels/test_attention.py
index ae46fd6bcc06..49c745a53678 100644
--- a/tests/kernels/attention.py
+++ b/tests/kernels/test_attention.py
@@ -8,6 +8,7 @@
 from cacheflow import attention_ops
 
 MAX_SEQ_LEN = 4096
+TEST_SEED = 0
 
 
 def ref_masked_attention(
@@ -155,7 +156,8 @@ def ref_multi_query_cached_kv_attention(
     return ref_output
 
 
-def test_single_query_cached_kv_attention(
+@torch.inference_mode()
+def run_single_query_cached_kv_attention(
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -223,7 +225,8 @@
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
 
 
-def test_multi_query_kv_attention(
+@torch.inference_mode()
+def run_multi_query_kv_attention(
     num_seqs: int,
     num_heads: int,
     head_size: int,
@@ -264,19 +267,16 @@
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
 
 
-@torch.inference_mode()
-def test_attention(seed: int) -> None:
-    # NOTE(woosuk): Even when the seed is fixed, there is a chance that
-    # the test fails due to the precision issue. Re-run the test if it fails.
-    torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+def test_single_query_cached_kv_attention() -> None:
+    torch.random.manual_seed(TEST_SEED)
+    torch.cuda.manual_seed(TEST_SEED)
     for dtype in [torch.half, torch.bfloat16]:
         for block_size in [8, 16, 32, 64]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
                       f'head_size={head_size}')
-                test_single_query_cached_kv_attention(
+                run_single_query_cached_kv_attention(
                     num_tokens=37,
                     num_heads=3,
                     head_size=head_size,
@@ -285,17 +285,17 @@
                 dtype=dtype,
             )
 
+
+def test_multi_query_kv_attention() -> None:
+    torch.random.manual_seed(TEST_SEED)
+    torch.cuda.manual_seed(TEST_SEED)
     for dtype in [torch.half, torch.bfloat16]:
         for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
             print(f'Testing multi_query_kv_attention with dtype={dtype}, '
                   f'head_size={head_size}')
-            test_multi_query_kv_attention(
+            run_multi_query_kv_attention(
                 num_seqs=5,
                 num_heads=3,
                 head_size=head_size,
                 dtype=dtype,
             )
-
-
-if __name__ == '__main__':
-    test_attention(seed=0)
diff --git a/tests/kernels/cache.py b/tests/kernels/test_cache.py
similarity index 92%
rename from tests/kernels/cache.py
rename to tests/kernels/test_cache.py
index b750ca97e985..25f16b13da2b 100644
--- a/tests/kernels/cache.py
+++ b/tests/kernels/test_cache.py
@@ -5,7 +5,8 @@
 from cacheflow import cache_ops
 
 
-def test_copy_blocks(
+@torch.inference_mode()
+def run_copy_blocks(
     num_mappings: int,
     num_layers: int,
     num_heads: int,
@@ -60,7 +61,8 @@
     assert torch.allclose(value_cache, cloned_value_cache)
 
 
-def test_reshape_and_cache(
+@torch.inference_mode()
+def run_reshape_and_cache(
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -99,7 +101,8 @@
     assert torch.allclose(value_cache, cloned_value_cache)
 
 
-def test_gather_cached_kv(
+@torch.inference_mode()
+def run_gather_cached_kv(
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -140,19 +143,22 @@
     assert torch.allclose(value, cloned_value)
 
 
-@torch.inference_mode()
-def test_cache() -> None:
+def test_copy_blocks() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
-        test_copy_blocks(
+        run_copy_blocks(
             num_mappings=23, num_layers=7, num_heads=17, head_size=16,
             block_size=8, num_blocks=1024, dtype=dtype)
-        test_reshape_and_cache(
-            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
-            dtype=dtype)
-        test_gather_cached_kv(
+
+
+def test_reshape_and_cache() -> None:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        run_reshape_and_cache(
             num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
             dtype=dtype)
 
 
-if __name__ == '__main__':
-    test_cache()
+def test_gather_cached_kv() -> None:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        run_gather_cached_kv(
+            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
+            dtype=dtype)
diff --git a/tests/kernels/layernorm.py b/tests/kernels/test_layernorm.py
similarity index 95%
rename from tests/kernels/layernorm.py
rename to tests/kernels/test_layernorm.py
index a61fa9b67aa7..6fe1ca5cfda6 100644
--- a/tests/kernels/layernorm.py
+++ b/tests/kernels/test_layernorm.py
@@ -22,7 +22,7 @@ def forward(self, hidden_states):
 
 
 @torch.inference_mode()
-def test_rms_norm(
+def run_rms_norm(
     num_tokens: int,
     hidden_size: int,
     dtype: torch.dtype,
@@ -41,13 +41,13 @@ def run_rms_norm(
     assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)
 
 
-if __name__ == '__main__':
+def test_rms_norm() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 128, 2048]:
             for hidden_size in [13, 64, 1024, 5120]:
                 print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
                       f'{num_tokens}, hidden_size={hidden_size}')
-                test_rms_norm(
+                run_rms_norm(
                     num_tokens=num_tokens,
                     hidden_size=hidden_size,
                     dtype=dtype,
diff --git a/tests/kernels/pos_encoding.py b/tests/kernels/test_pos_encoding.py
similarity index 96%
rename from tests/kernels/pos_encoding.py
rename to tests/kernels/test_pos_encoding.py
index 16b3992a30f1..8299cd0e608a 100644
--- a/tests/kernels/pos_encoding.py
+++ b/tests/kernels/test_pos_encoding.py
@@ -76,7 +76,7 @@ def forward(
 
 
 @torch.inference_mode()
-def test_rotary_embedding_neox(
+def run_rotary_embedding_neox(
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -128,15 +128,15 @@
     assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
 
 
-if __name__ == '__main__':
+def test_rotary_embedding_neox() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
         for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
             print(f'Running tests for head_size={head_size} and dtype={dtype}')
-            test_rotary_embedding_neox(
+            run_rotary_embedding_neox(
                 num_tokens=2145,
                 num_heads=5,
                 head_size=head_size,
                 max_position=8192,
-                rotary_dim=int(head_size * 0.25),
+                rotary_dim=head_size,
                 dtype=dtype,
             )
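
A possible follow-up, sketched below and not part of this patch, is to replace the nested loops inside the new test_* functions with pytest.mark.parametrize, so that each parameter combination is reported (and can fail) as its own test case. The sketch assumes it lives in tests/kernels/test_activation.py next to the run_silu_and_mul helper introduced above; the parameter values simply mirror the ones used in the patch.

import pytest
import torch

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048]
DIMS = [512, 4096, 5120, 13824]


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_silu_and_mul(num_tokens: int, d: int, dtype: torch.dtype) -> None:
    # Stacked parametrize decorators produce the full cross product, so each
    # (dtype, d, num_tokens) combination becomes a separate pytest test case.
    run_silu_and_mul(num_tokens, d, dtype)

Keeping the run_* helpers separate from the test_* entry points, as this patch does, would make that kind of parametrization a mechanical change later.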