Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
# limitations under the License.
# -*- coding: utf-8 -*-
"""
Model yaml config for trtllm-bench perf tests
Model pytorch yaml config for trtllm-bench perf tests
"""


def get_model_yaml_config(model_label: str) -> dict:
def get_model_yaml_config(model_label: str, input_lens: list[int]) -> dict:
"""
Return the yaml config corresponding to the model label.
Args:
Expand All @@ -32,7 +32,6 @@ def get_model_yaml_config(model_label: str) -> dict:
'print_iter_log': True,
'use_cuda_graph': True,
'cuda_graph_padding_enabled': True,
'cuda_graph_max_batch_size': 4096,
}
}
model_configs = {
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/defs/perf/test_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
print_warning)

from ..conftest import get_llm_root, llm_models_root, trt_environment
from .model_yaml_config import get_model_yaml_config
from .pytorch_model_config import get_model_yaml_config
from .utils import (AbstractPerfScriptTestClass, PerfBenchScriptTestCmds,
PerfMetricType, PerfScriptTestCmds, generate_test_nodes)

Expand Down