11 changes: 10 additions & 1 deletion tensorrt_llm/llmapi/llm.py
@@ -122,15 +122,21 @@ def __init__(self,
         self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor)
         self._llm_id = None
 
+        log_level = logger.level
+        logger.set_level("info") # force display the backend
+
         try:
             backend = kwargs.get('backend', None)
-            if backend == 'pytorch':
+            if backend == "pytorch":
+                logger.info("Using LLM with PyTorch backend")
                 llm_args_cls = TorchLlmArgs
             elif backend == '_autodeploy':
+                logger.info("Using LLM with AutoDeploy backend")
                 from .._torch.auto_deploy.llm_args import \
                     LlmArgs as AutoDeployLlmArgs
                 llm_args_cls = AutoDeployLlmArgs
             else:
+                logger.info("Using LLM with TensorRT backend")
                 llm_args_cls = TrtLlmArgs
 
             # check the kwargs and raise ValueError directly
@@ -160,6 +166,9 @@ def __init__(self,
                 f"Failed to parse the arguments for the LLM constructor: {e}")
             raise e
 
+        finally:
+            logger.set_level(log_level) # restore the log level
+
         print_colored_debug(f"LLM.args.mpi_session: {self.args.mpi_session}\n",
                             "yellow")
        self.mpi_session = self.args.mpi_session
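
Note: the hunk above temporarily forces the log level to "info" so the backend-selection message is always printed, then restores the caller's level in a finally block. Below is a minimal, self-contained sketch of that save/restore idiom, written against Python's standard logging module rather than tensorrt_llm's logger (which exposes level / set_level); the announce_backend helper and logger name are illustrative only, not part of the change.

# Illustrative only: stand-alone version of the save/restore pattern above,
# using the standard logging module instead of tensorrt_llm's logger.
import logging
from typing import Optional

logger = logging.getLogger("llm_backend_demo")  # hypothetical logger name
logging.basicConfig()

def announce_backend(backend: Optional[str]) -> None:
    previous_level = logger.getEffectiveLevel()
    logger.setLevel(logging.INFO)  # make the backend message visible
    try:
        if backend == "pytorch":
            logger.info("Using LLM with PyTorch backend")
        elif backend == "_autodeploy":
            logger.info("Using LLM with AutoDeploy backend")
        else:
            logger.info("Using LLM with TensorRT backend")
    finally:
        logger.setLevel(previous_level)  # restore the caller's configuration

announce_backend("pytorch")
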
19 changes: 13 additions & 6 deletions tests/integration/defs/llmapi/_run_llmapi_llm.py
@@ -1,25 +1,32 @@
 #!/usr/bin/env python3
 import os
+from typing import Optional
 
 import click
 
-from tensorrt_llm._tensorrt_engine import LLM
-from tensorrt_llm.llmapi import BuildConfig, SamplingParams
+from tensorrt_llm._tensorrt_engine import LLM as TrtLLM
+from tensorrt_llm.llmapi import LLM, BuildConfig, SamplingParams
 
 
 @click.command()
 @click.option("--model_dir", type=str, required=True)
 @click.option("--tp_size", type=int, default=1)
 @click.option("--engine_dir", type=str, default=None)
-def main(model_dir: str, tp_size: int, engine_dir: str):
+@click.option("--backend", type=str, default=None)
+def main(model_dir: str, tp_size: int, engine_dir: str, backend: Optional[str]):
     build_config = BuildConfig()
     build_config.max_batch_size = 8
     build_config.max_input_len = 256
     build_config.max_seq_len = 512
 
-    llm = LLM(model_dir,
-              tensor_parallel_size=tp_size,
-              build_config=build_config)
+    backend = backend or "tensorrt"
+    assert backend in ["pytorch", "tensorrt"]
+
+    llm_cls = TrtLLM if backend == "tensorrt" else LLM
+
+    kwargs = {} if backend == "pytorch" else {"build_config": build_config}
+
+    llm = llm_cls(model_dir, tensor_parallel_size=tp_size, **kwargs)
 
     if engine_dir is not None and os.path.abspath(
             engine_dir) != os.path.abspath(model_dir):
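
Note: the QA test added below drives this script through the venv helper; for a quick manual check it can also be launched directly. A hypothetical invocation is sketched here; the model path is a placeholder and must point at a real checkpoint on the test machine.

# Hypothetical manual invocation of the updated script; the model path below
# is a placeholder, not a path shipped with the repository.
import subprocess

subprocess.run(
    [
        "python", "tests/integration/defs/llmapi/_run_llmapi_llm.py",
        "--model_dir", "/path/to/llama-v3-8b-instruct-hf",
        "--backend", "pytorch",
    ],
    check=True,
)
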
70 changes: 70 additions & 0 deletions tests/integration/defs/llmapi/test_llm_api_qa.py
@@ -0,0 +1,70 @@
+# Confirm that the default backend is changed
+import os
+
+from defs.common import venv_check_output
+
+from ..conftest import llm_models_root
+
+model_path = llm_models_root() + "/llama-models-v3/llama-v3-8b-instruct-hf"
+
+
+class TestLlmDefaultBackend:
+    """
+    Check that the default backend is PyTorch for v1.0 breaking change
+    """
+
+    def test_llm_args_type_default(self, llm_root, llm_venv):
+        # Keep the complete example code here
+        from tensorrt_llm.llmapi import LLM, KvCacheConfig, TorchLlmArgs
+
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
+        llm = LLM(model=model_path, kv_cache_config=kv_cache_config)
+
+        # The default backend should be PyTorch
+        assert llm.args.backend == "pytorch"
+        assert isinstance(llm.args, TorchLlmArgs)
+
+        for output in llm.generate(["Hello, world!"]):
+            print(output)
+
+    def test_llm_args_type_tensorrt(self, llm_root, llm_venv):
+        # Keep the complete example code here
+        from tensorrt_llm._tensorrt_engine import LLM
+        from tensorrt_llm.llmapi import KvCacheConfig, TrtLlmArgs
+
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
+
+        llm = LLM(model=model_path, kv_cache_config=kv_cache_config)
+
+        # If the backend is TensorRT, the args should be TrtLlmArgs
+        assert llm.args.backend in ("tensorrt", None)
+        assert isinstance(llm.args, TrtLlmArgs)
+
+        for output in llm.generate(["Hello, world!"]):
+            print(output)
+
+    def test_llm_args_logging(self, llm_root, llm_venv):
+        # It should print the backend in the log
+        script_path = os.path.join(os.path.dirname(__file__),
+                                   "_run_llmapi_llm.py")
+        print(f"script_path: {script_path}")
+
+        # Test with pytorch backend
+        pytorch_cmd = [
+            script_path, "--model_dir", model_path, "--backend", "pytorch"
+        ]
+
+        pytorch_output = venv_check_output(llm_venv, pytorch_cmd)
+
+        # Check that pytorch backend keyword appears in logs
+        assert "Using LLM with PyTorch backend" in pytorch_output, f"Expected 'pytorch' in logs, got: {pytorch_output}"
+
+        # Test with tensorrt backend
+        tensorrt_cmd = [
+            script_path, "--model_dir", model_path, "--backend", "tensorrt"
+        ]
+
+        tensorrt_output = venv_check_output(llm_venv, tensorrt_cmd)
+
+        # Check that tensorrt backend keyword appears in logs
+        assert "Using LLM with TensorRT backend" in tensorrt_output, f"Expected 'tensorrt' in logs, got: {tensorrt_output}"
6 changes: 6 additions & 0 deletions tests/integration/test_lists/qa/llm_function_full.txt
@@ -677,3 +677,9 @@ disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyL
 # These tests will impact triton. They should be at the end of all tests (https://nvbugs/4904271)
 # examples/test_openai.py::test_llm_openai_triton_1gpu
 # examples/test_openai.py::test_llm_openai_triton_plugingen_1gpu
+
+# llm-api promote pytorch to default
+llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_logging
+llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_type_tensorrt
+llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_type_default
+llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_logging