From 589e2552a969e2c169c2fbeef2aaf310b1946e78 Mon Sep 17 00:00:00 2001
From: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
Date: Tue, 22 Jul 2025 11:31:21 -0700
Subject: [PATCH] fix: Fixing kv_cache_events unit tests

Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
---
 tests/unittest/llmapi/test_llm_kv_cache_events.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/tests/unittest/llmapi/test_llm_kv_cache_events.py b/tests/unittest/llmapi/test_llm_kv_cache_events.py
index 718cd531dda..f5efbe2bcf8 100644
--- a/tests/unittest/llmapi/test_llm_kv_cache_events.py
+++ b/tests/unittest/llmapi/test_llm_kv_cache_events.py
@@ -1,10 +1,8 @@
 import asyncio
 import time
 
-import pytest
-
 import tensorrt_llm
-from tensorrt_llm._tensorrt_engine import LLM
+from tensorrt_llm import LLM
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
 from tensorrt_llm._utils import KVCacheEventSerializer
@@ -16,7 +14,6 @@
 
 default_model_name = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
 llama_model_path = get_model_path(default_model_name)
-
 global_kvcache_config = KvCacheConfig(free_gpu_memory_fraction=0.4,
                                       event_buffer_max_size=1024,
                                       enable_block_reuse=True,
@@ -50,8 +47,7 @@ def create_llm(tensor_parallel_size=1):
     return LLM(model=llama_model_path,
                tensor_parallel_size=tensor_parallel_size,
                kv_cache_config=global_kvcache_config,
-               enable_autotuner=False,
-               backend="pytorch")
+               enable_autotuner=False)
 
 
 def create_llm_request(id, input_tokens, new_tokens=1):
@@ -103,7 +99,6 @@ def test_kv_cache_event_data_serialization():
     serialized_event = KVCacheEventSerializer.serialize(events)
 
 
-@pytest.mark.skip(reason="https://nvbugs/5362412")
 def test_expected_kv_cache_events():
     llm = create_llm()
     sampling_params = SamplingParams(max_tokens=6, temperature=0.01)
@@ -122,7 +117,6 @@ def test_expected_kv_cache_events():
         assert event["data"]["type"] == "stored"
 
 
-@pytest.mark.skip(reason="https://nvbugs/5362412")
 def test_kv_cache_event_async_api():
     llm = create_llm()
     sampling_params = SamplingParams(max_tokens=6, temperature=0.01)
@@ -150,7 +144,6 @@ async def main():
     asyncio.run(main())
 
 
-@pytest.mark.skip(reason="https://nvbugs/5362412")
 def test_llm_kv_events_api():
     llm = create_llm()
     sampling_params = SamplingParams(max_tokens=6, temperature=0.01)