diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index 5a8c68643ee..1a590aaebad 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -529,8 +529,13 @@ def _check_arguments(self, prompt_len: int, query_len: int,
                 raise ValueError(
                     f"PyTorch backend currently only supports `logprobs=1`. Received `logprobs={sampling_params.logprobs}` (Top{sampling_params.logprobs} logprobs). Please set `logprobs=1` in `sampling_params` instead."
                 )
-            return
-        elif self.args.backend == "_autodeploy":
+            # Check prompt length and query length against max_num_tokens to filter illegal requests.
+            if self.args.backend == "pytorch" and not self.args.enable_chunked_prefill:
+                max_num_tokens = self.args.max_num_tokens
+                if max_num_tokens and prompt_len / self.args.parallel_config.cp_size + query_len > max_num_tokens:
+                    raise ValueError(
+                        f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}) and query length ({query_len}) should not exceed "
+                        f"max_num_tokens ({max_num_tokens})")
             return
 
         build_config = self.args.build_config
@@ -547,7 +552,7 @@ def _check_arguments(self, prompt_len: int, query_len: int,
                 (sampling_params.max_tokens or 0) > max_seq_len):
             raise ValueError(
                 f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}) and query length ({query_len}) max_tokens ({sampling_params.max_tokens}) should not exceed "
-                f"max_seq_len ({build_config.max_seq_len})")
+                f"max_seq_len ({max_seq_len})")
 
         if sampling_params.use_beam_search and sampling_params.best_of > build_config.max_beam_width:
             if sampling_params.n == sampling_params.best_of:
diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py
index 043c13c22f3..0631496ab17 100644
--- a/tests/unittest/llmapi/test_llm.py
+++ b/tests/unittest/llmapi/test_llm.py
@@ -2061,24 +2061,37 @@ def success_path():
     success_path()
 
 
-def _test_llm_capture_request_error(tp_size: int = 1):
-    build_config = BuildConfig()
-    build_config.max_num_tokens = 64
+def _test_llm_capture_request_error(pytorch_backend: bool, tp_size: int = 1):
+    llm_args_extra = {}
+    if pytorch_backend:
+        from tensorrt_llm._torch import LLM as LLM_torch
+        LLM_CLASS = LLM_torch
+        llm_args_extra["max_num_tokens"] = 64
+    else:
+        LLM_CLASS = LLM
+        build_config = BuildConfig()
+        build_config.max_num_tokens = 64
+        llm_args_extra["fast_build"] = True
+        llm_args_extra["build_config"] = build_config
 
-    llm = LLM(
+    llm = LLM_CLASS(
         model=llama_model_path,
-        build_config=build_config,
-        fast_build=True,
+        tensor_parallel_size=tp_size,
+        **llm_args_extra,
     )
 
     prompt = 'A ' * 65  # the minimum max_num_tokens is 64
-
-    with pytest.raises(RequestError):
-        llm.generate(prompt)
+    if pytorch_backend:
+        # The PyTorch backend raises ValueError when the prompt exceeds max_num_tokens.
+        with pytest.raises(ValueError):
+            llm.generate(prompt)
+    else:
+        with pytest.raises(RequestError):
+            llm.generate(prompt)
 
 
 def test_llm_capture_request_error():
-    _test_llm_capture_request_error(tp_size=1)
+    _test_llm_capture_request_error(pytorch_backend=False, tp_size=1)
 
 
 def test_llm_api_jupyter_scenario():
diff --git a/tests/unittest/llmapi/test_llm_multi_gpu.py b/tests/unittest/llmapi/test_llm_multi_gpu.py
index f14b358f63a..55b9d6f4c25 100644
--- a/tests/unittest/llmapi/test_llm_multi_gpu.py
+++ b/tests/unittest/llmapi/test_llm_multi_gpu.py
@@ -452,7 +452,7 @@ def test_llm_get_stats_async_tp2(pytorch_backend):
 
 
 def test_llm_capture_request_error():
-    _test_llm_capture_request_error(tp_size=2)
+    _test_llm_capture_request_error(pytorch_backend=False, tp_size=2)
 
 
 def test_llm_with_postprocess_parallel_tp2():
diff --git a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py
index 8dc1450f339..55ba1927eea 100644
--- a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py
+++ b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py
@@ -5,12 +5,17 @@
 from tensorrt_llm.llmapi import KvCacheConfig
 from .test_llm_pytorch import (llama_v2_13b_lora_test_harness,
                                llama_7b_multi_lora_test_harness)
-
+from .test_llm import _test_llm_capture_request_error
 # isort: on
 
 global_kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
 
 
+@pytest.mark.gpu2
+def test_llm_capture_request_error():
+    _test_llm_capture_request_error(pytorch_backend=True, tp_size=2)
+
+
 @pytest.mark.gpu4
 def test_tinyllama_logits_processor_tp2pp2():
     tinyllama_logits_processor_test_harness(backend="pytorch",
diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py
index 65f3d16ac69..411ccfb8158 100644
--- a/tests/unittest/llmapi/test_llm_pytorch.py
+++ b/tests/unittest/llmapi/test_llm_pytorch.py
@@ -4,12 +4,11 @@
 from tensorrt_llm.sampling_params import SamplingParams
 
 # isort: off
-from .test_llm import (get_model_path, global_kvcache_config, llama_model_path,
-                       llm_get_stats_async_test_harness,
-                       llm_get_stats_test_harness, prompts,
-                       run_llm_abort_request,
-                       run_llm_with_postprocess_parallel_and_result_handler,
-                       tinyllama_logits_processor_test_harness)
+from .test_llm import (
+    get_model_path, global_kvcache_config, llama_model_path,
+    llm_get_stats_async_test_harness, llm_get_stats_test_harness, prompts,
+    run_llm_abort_request, run_llm_with_postprocess_parallel_and_result_handler,
+    tinyllama_logits_processor_test_harness, _test_llm_capture_request_error)
 from utils.util import force_ampere, similar, skip_gpu_memory_less_than_40gb, skip_gpu_memory_less_than_80gb, skip_gpu_memory_less_than_138gb
 from utils.llm_data import llm_models_root
 from tensorrt_llm.lora_manager import LoraConfig
@@ -64,6 +63,10 @@ def test_llm_get_stats_async(return_context_logits, use_overlap,
                                       enable_iter_req_stats=enable_iter_req_stats)
 
 
+def test_llm_capture_request_error():
+    _test_llm_capture_request_error(pytorch_backend=True, tp_size=1)
+
+
 @force_ampere
 @pytest.mark.parametrize(
     "sampling_params",
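
For reference, a minimal sketch of how the new pre-check is expected to surface through the LLM API on the PyTorch backend. The model path is a placeholder; max_num_tokens=64 and the over-long prompt mirror _test_llm_capture_request_error above.

# Minimal sketch, assuming a local Llama checkpoint (placeholder path below);
# mirrors _test_llm_capture_request_error(pytorch_backend=True, tp_size=1).
from tensorrt_llm._torch import LLM

llm = LLM(
    model="/path/to/llama-model",  # hypothetical path, for illustration only
    max_num_tokens=64,
)

prompt = "A " * 65  # tokenizes to more than 64 tokens

try:
    llm.generate(prompt)
except ValueError as err:
    # Without chunked prefill, _check_arguments() now rejects the request
    # up front instead of surfacing a failure from the executor later.
    print(f"Rejected as expected: {err}")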