158 changes: 79 additions & 79 deletions tests/integration/defs/disaggregated/test_disaggregated.py
@@ -20,6 +20,8 @@
from defs.conftest import skip_no_hopper
from defs.trt_test_alternative import check_call, popen

+from tensorrt_llm.logger import logger


def cleanup_output_files():
"""Clean up output files from previous runs."""
@@ -130,88 +132,86 @@ def run_disaggregated_test(example_dir,
str(server_start_timeout), '-c', config_file
]

-    with (  # Start workers
-            open('output_workers.log', 'w') as output_workers,
-            popen(workers_cmd,
-                  stdout=output_workers,
-                  stderr=subprocess.STDOUT,
-                  env=env,
-                  cwd=cwd),
-            # Start server
-            open('output_disagg.log', 'w') as output_disagg,
-            popen(server_cmd,
-                  stdout=output_disagg,
-                  stderr=subprocess.STDOUT,
-                  env=env,
-                  cwd=cwd)):
-        client_dir = f"{example_dir}/clients"
-        for _ in range(num_iters):
-            client_cmd = [
-                'python3', f'{client_dir}/disagg_client.py', '-c',
-                f'{example_dir}/disagg_config.yaml', '-p',
-                f'{client_dir}/prompts.json', '--ignore-eos',
-                '--server-start-timeout',
-                str(server_start_timeout)
-            ]
-            check_call(client_cmd, env=env)
-
-            # Streaming client run
-            streaming_client_cmd = client_cmd + [
-                '--streaming', '-o', 'output_streaming.json'
-            ]
-            check_call(streaming_client_cmd, env=env)
-
-            # Run the chat completion endpoint test only for TinyLlama
-            if test_desc == "overlap":
-                chat_client_cmd = client_cmd + [
-                    '-e', 'chat', '-o', 'output_chat.json'
-                ]
-                check_call(chat_client_cmd, env=env)
-
-                streaming_chat_client_cmd = chat_client_cmd + [
-                    '--streaming', '-o', 'output_streaming_chat.json'
-                ]
-                check_call(streaming_chat_client_cmd, env=env)
-
-            # Verify outputs
-            not_expected_strings = ["Berlin Berlin"]
-
-            output_files = ['output.json', 'output_streaming.json']
-            if test_desc == "overlap":
-                # Disable streaming chat completion for overlap test
-                # due to bug
-                output_files.extend(['output_chat.json'])
-
-            if test_desc.startswith("gen_only"):
-                continue
-
-            for output_file in output_files:
-                with open(output_file, 'r') as f:
-                    content = f.read()
-                    if "deepseek_v3_lite" in test_desc or output_file == "output_chat.json":
-                        expected_strings = ["Berlin", "Asyncio is a"]
-                    else:
-                        expected_strings = [
-                            "The capital of Germany is Berlin",
-                            "Asyncio is a Python library"
-                        ]
-                    for expected_string in expected_strings:
-                        assert expected_string in content, f"Expected string '{expected_string}' not found in {output_file}"
-                    for not_expected_string in not_expected_strings:
-                        assert not_expected_string not in content, f"Unexpected string '{not_expected_string}' found in {output_file}"
-
-    # Print outputs
-    print("------------------")
-    print("Workers output:")
-    print("------------------")
-    with open('output_workers.log', 'r') as f:
-        print(f.read())
-
-    print("\n\n------------------")
-    print("Disagg server output")
-    print("------------------")
-    with open('output_disagg.log', 'r') as f:
-        print(f.read())
+    try:
+        with (  # Start workers
+                open('output_workers.log', 'w') as output_workers,
+                popen(workers_cmd,
+                      stdout=output_workers,
+                      stderr=subprocess.STDOUT,
+                      env=env,
+                      cwd=cwd),
+                # Start server
+                open('output_disagg.log', 'w') as output_disagg,
+                popen(server_cmd,
+                      stdout=output_disagg,
+                      stderr=subprocess.STDOUT,
+                      env=env,
+                      cwd=cwd)):
+            client_dir = f"{example_dir}/clients"
+            for _ in range(num_iters):
+                client_cmd = [
+                    'python3', f'{client_dir}/disagg_client.py', '-c',
+                    f'{example_dir}/disagg_config.yaml', '-p',
+                    f'{client_dir}/prompts.json', '--ignore-eos',
+                    '--server-start-timeout',
+                    str(server_start_timeout)
+                ]
+                check_call(client_cmd, env=env)
+
+                # Streaming client run
+                streaming_client_cmd = client_cmd + [
+                    '--streaming', '-o', 'output_streaming.json'
+                ]
+                check_call(streaming_client_cmd, env=env)
+
+                # Run the chat completion endpoint test only for TinyLlama
+                if test_desc == "overlap":
+                    chat_client_cmd = client_cmd + [
+                        '-e', 'chat', '-o', 'output_chat.json'
+                    ]
+                    check_call(chat_client_cmd, env=env)
+
+                    streaming_chat_client_cmd = chat_client_cmd + [
+                        '--streaming', '-o', 'output_streaming_chat.json'
+                    ]
+                    check_call(streaming_chat_client_cmd, env=env)
+
+                # Verify outputs
+                not_expected_strings = ["Berlin Berlin"]
+
+                output_files = ['output.json', 'output_streaming.json']
+                if test_desc == "overlap":
+                    # Disable streaming chat completion for overlap test
+                    # due to bug
+                    output_files.extend(['output_chat.json'])
+
+                if test_desc.startswith("gen_only"):
+                    continue
+
+                for output_file in output_files:
+                    with open(output_file, 'r') as f:
+                        content = f.read()
+                        if "deepseek_v3_lite" in test_desc or output_file == "output_chat.json":
+                            expected_strings = ["Berlin", "Asyncio is a"]
+                        else:
+                            expected_strings = [
+                                "The capital of Germany is Berlin",
+                                "Asyncio is a Python library"
+                            ]
+                        for expected_string in expected_strings:
+                            assert expected_string in content, f"Expected string '{expected_string}' not found in {output_file}"
+                        for not_expected_string in not_expected_strings:
+                            assert not_expected_string not in content, f"Unexpected string '{not_expected_string}' found in {output_file}"
+    except Exception:
+        # Print outputs on error
+        logger.error("-------- Workers output --------")
+        with open('output_workers.log', 'r') as f:
+            logger.error(f.read())
+
+        logger.error("-------- Disagg server output --------")
+        with open('output_disagg.log', 'r') as f:
+            logger.error(f.read())
+        raise


@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
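The hunk above replaces the unconditional post-run printing of the worker and server logs with a try/except that replays both log files through the TensorRT-LLM logger only when a step fails, then re-raises. Below is a minimal, self-contained sketch of that pattern; it uses the standard logging module and placeholder commands instead of the repository's popen/check_call helpers, and the function name run_with_log_dump_on_failure is illustrative only.

import logging
import subprocess

logger = logging.getLogger("disagg_test_sketch")  # stand-in for tensorrt_llm.logger


def run_with_log_dump_on_failure(workers_cmd, server_cmd, client_cmds):
    """Run the background processes with output captured to files; if any step
    fails, replay the captured logs through the logger and re-raise."""
    procs = []
    try:
        with open('output_workers.log', 'w') as workers_log, \
                open('output_disagg.log', 'w') as server_log:
            procs.append(subprocess.Popen(workers_cmd, stdout=workers_log,
                                          stderr=subprocess.STDOUT))
            procs.append(subprocess.Popen(server_cmd, stdout=server_log,
                                          stderr=subprocess.STDOUT))
            for cmd in client_cmds:
                subprocess.run(cmd, check=True)  # raises CalledProcessError on failure
    except Exception:
        # Only surface the captured process output when something went wrong.
        for name, path in [("Workers", 'output_workers.log'),
                           ("Disagg server", 'output_disagg.log')]:
            logger.error("-------- %s output --------", name)
            with open(path, 'r') as f:
                logger.error(f.read())
        raise
    finally:
        for proc in procs:
            proc.terminate()
            proc.wait()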
21 changes: 12 additions & 9 deletions tests/integration/defs/disaggregated/test_workers.py
@@ -4,7 +4,7 @@
import json
import os
import subprocess
-from typing import List, Optional, Tuple
+from typing import Generator, List, Optional, Tuple

import aiohttp
import pytest
@@ -40,7 +40,7 @@ def run_disaggregated_workers(
env: Optional[dict] = None,
cwd: Optional[str] = None,
num_ranks: Optional[int] = None
-) -> Tuple[subprocess.Popen, List[str], List[str]]:
+) -> Tuple[Generator[subprocess.Popen, None, None], List[str], List[str]]:

ctx_servers, gen_servers = get_ctx_gen_server_urls_from_cfg(config_file)

@@ -424,7 +424,7 @@ async def test_eviction(self):
# send a dummy request for initialization
dummy_request = {
"model": MODEL_NAME,
"prompt": [3] * 100,
"prompt": [3] * 200,
"max_tokens": 1,
"ignore_eos": True,
"temperature": 0.0,
@@ -509,10 +509,14 @@ def background_workers(llm_venv, config_file: str, num_ranks: int = None):
env=llm_venv._new_env,
cwd=cwd,
num_ranks=num_ranks)
-with workers_proc as proc:
-    yield ctx_servers, gen_servers
-proc.terminate()
-proc.wait()
+try:
+    with workers_proc as proc:
+        yield ctx_servers, gen_servers
+except Exception:
+    log_file.seek(0)
+    logger.error("-------- Worker output --------")
+    logger.error(log_file.read())
+    raise


@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
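With this change, background_workers wraps its worker context manager in try/except so that the captured log is replayed through the logger whenever setup or teardown raises, before re-raising. A rough, self-contained pytest sketch of the same idea follows; the fixture name, the sleep command, and the temporary log file are invented for illustration and are not the repository's actual helpers.

import logging
import subprocess
import tempfile

import pytest

logger = logging.getLogger("workers_fixture_sketch")


@pytest.fixture
def background_process():
    """Start a placeholder background process with its output captured; if
    setup or teardown raises, replay the captured output, then re-raise."""
    log_file = tempfile.TemporaryFile(mode="w+")
    try:
        with subprocess.Popen(["sleep", "30"],  # placeholder worker command
                              stdout=log_file,
                              stderr=subprocess.STDOUT) as proc:
            yield proc
            proc.terminate()
    except Exception:
        log_file.seek(0)
        logger.error("-------- Worker output --------")
        logger.error(log_file.read())
        raise
    finally:
        log_file.close()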
@@ -552,7 +556,6 @@ def test_workers_kv_cache_events(disaggregated_test_root,
def test_workers_kv_cache_aware_router(disaggregated_test_root,
disaggregated_example_root, llm_venv,
llama_model_root):
-pytest.skip("https://nvbugspro.nvidia.com/bug/5301492")
config_file = os.path.join(
disaggregated_test_root,
'test_configs/disagg_config_cache_aware_balance.yaml')
@@ -562,7 +565,7 @@ def test_workers_kv_cache_aware_router(disaggregated_test_root,
4) as (ctx_servers, gen_servers):
tester = KvCacheAwareRouterTester(ctx_servers, gen_servers)
prompts = load_default_prompts(disaggregated_example_root)
-asyncio.run(tester.test_multi_round_request(prompts, 6, 4))
+asyncio.run(tester.test_multi_round_request(prompts, 16, 4))


@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
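One of the hunks above retypes the worker handle as Tuple[Generator[subprocess.Popen, None, None], ...] rather than a bare subprocess.Popen. As a hedged illustration of how that annotation is normally used, the sketch below types a contextlib.contextmanager helper that yields a Popen; the managed_process name and the sleep command are made up for the example.

import subprocess
from contextlib import contextmanager
from typing import Generator, List


@contextmanager
def managed_process(cmd: List[str]) -> Generator[subprocess.Popen, None, None]:
    """Yield a running process and always clean it up afterwards. The generator
    function is annotated Generator[yield_type, send_type, return_type]."""
    proc = subprocess.Popen(cmd)
    try:
        yield proc
    finally:
        proc.terminate()
        proc.wait()


if __name__ == "__main__":
    with managed_process(["sleep", "1"]) as proc:  # placeholder command
        print("started pid", proc.pid)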