tests/unittest/llmapi/test_llm_kv_cache_events.py (11 changes: 2 additions & 9 deletions)
@@ -1,10 +1,8 @@
 import asyncio
 import time

 import pytest

-import tensorrt_llm
-from tensorrt_llm._tensorrt_engine import LLM
+from tensorrt_llm import LLM
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
 from tensorrt_llm._utils import KVCacheEventSerializer
@@ -16,7 +14,6 @@

 default_model_name = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
 llama_model_path = get_model_path(default_model_name)
-
 global_kvcache_config = KvCacheConfig(free_gpu_memory_fraction=0.4,
                                       event_buffer_max_size=1024,
                                       enable_block_reuse=True,
@@ -50,8 +47,7 @@ def create_llm(tensor_parallel_size=1):
     return LLM(model=llama_model_path,
                tensor_parallel_size=tensor_parallel_size,
                kv_cache_config=global_kvcache_config,
-               enable_autotuner=False,
-               backend="pytorch")
+               enable_autotuner=False)


 def create_llm_request(id, input_tokens, new_tokens=1):
@@ -103,7 +99,6 @@ def test_kv_cache_event_data_serialization():
     serialized_event = KVCacheEventSerializer.serialize(events)


-@pytest.mark.skip(reason="https://nvbugs/5362412")
 def test_expected_kv_cache_events():
     llm = create_llm()
     sampling_params = SamplingParams(max_tokens=6, temperature=0.01)
@@ -122,7 +117,6 @@ def test_expected_kv_cache_events():
         assert event["data"]["type"] == "stored"


-@pytest.mark.skip(reason="https://nvbugs/5362412")
 def test_kv_cache_event_async_api():
     llm = create_llm()
     sampling_params = SamplingParams(max_tokens=6, temperature=0.01)
@@ -150,7 +144,6 @@ async def main():
     asyncio.run(main())


-@pytest.mark.skip(reason="https://nvbugs/5362412")
 def test_llm_kv_events_api():
     llm = create_llm()
     sampling_params = SamplingParams(max_tokens=6, temperature=0.01)
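For context, the three un-skipped tests drive the KV cache event stream through the public LLM API. Below is a minimal sketch of that usage pattern, assuming the get_kv_cache_events(timeout) accessor these tests call; the model path and the printing at the end are illustrative, not taken from the diff.

# Sketch only: mirrors the pattern the re-enabled tests exercise. The LLM
# class is now imported from the top-level package, and backend="pytorch" is
# no longer passed because the PyTorch backend is the default.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM(model="llama-models-v2/TinyLlama-1.1B-Chat-v1.0",  # assumed local path
          kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4,
                                        event_buffer_max_size=1024,
                                        enable_block_reuse=True),
          enable_autotuner=False)

llm.generate(["Hello, my name is"],
             SamplingParams(max_tokens=6, temperature=0.01))

# Drain the buffered events; the tests assert that "stored" block events
# appear here once a prompt has been processed.
events = llm.get_kv_cache_events(timeout=5)
for event in events:
    print(event)

The synchronous accessor above stands in for the async variant exercised by test_kv_cache_event_async_api, which awaits the same event stream from inside an asyncio coroutine.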