diff --git a/tensorrt_llm/serve/scripts/benchmark_dataset.py b/tensorrt_llm/serve/scripts/benchmark_dataset.py
index 494c68f526f..1948d32892a 100644
--- a/tensorrt_llm/serve/scripts/benchmark_dataset.py
+++ b/tensorrt_llm/serve/scripts/benchmark_dataset.py
@@ -25,6 +25,7 @@
 import numpy as np
 import pandas as pd
+from benchmark_utils import download_and_cache_file
 from datasets import load_dataset
 from transformers import PreTrainedTokenizerBase


@@ -242,14 +243,23 @@ class ShareGPTDataset(BenchmarkDataset):
     Implements the ShareGPT dataset. Loads data from a JSON file and generates
     sample requests based on conversation turns.
     """
+    URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"

-    def __init__(self, **kwargs) -> None:
+    def __init__(self,
+                 download_timeout: int,
+                 download_path: Optional[str] = None,
+                 **kwargs) -> None:
         super().__init__(**kwargs)
-        self.load_data()
+        self.load_data(download_timeout, download_path)

-    def load_data(self) -> None:
+    def load_data(self,
+                  download_timeout: int,
+                  download_path: Optional[str] = None) -> None:
         if self.dataset_path is None:
-            raise ValueError("dataset_path must be provided for loading data.")
+            logger.warning("dataset_path not provided; downloading ShareGPT")
+            self.dataset_path = download_and_cache_file(
+                ShareGPTDataset.URL, download_path,
+                ShareGPTDataset.URL.split("/")[-1], download_timeout)

         with open(self.dataset_path, encoding="utf-8") as f:
             self.data = json.load(f)
diff --git a/tensorrt_llm/serve/scripts/benchmark_serving.py b/tensorrt_llm/serve/scripts/benchmark_serving.py
index 70f4576fcfa..89e452671fe 100644
--- a/tensorrt_llm/serve/scripts/benchmark_serving.py
+++ b/tensorrt_llm/serve/scripts/benchmark_serving.py
@@ -612,7 +612,9 @@ def main(args: argparse.Namespace):
     # For datasets that follow a similar structure, use a mapping.
     dataset_mapping = {
         "sharegpt":
-        lambda: ShareGPTDataset(random_seed=args.seed,
+        lambda: ShareGPTDataset(download_path=args.download_path,
+                                download_timeout=args.download_timeout,
+                                random_seed=args.seed,
                                 dataset_path=args.dataset_path).sample(
                                     tokenizer=tokenizer,
                                     num_requests=args.num_prompts,
@@ -778,6 +780,16 @@
         default=None,
         help="Path to the sharegpt/sonnet dataset. "
" "Or the huggingface dataset ID if using HF dataset.") + parser.add_argument( + "--download-path", + type=str, + default=None, + help="Path to download dataset if dataset-path is None, " + "only sharegpt is supported for now") + parser.add_argument("--download-timeout", + type=int, + default=180, + help="Timeout for downloading datasets") parser.add_argument( "--max-concurrency", type=int, diff --git a/tensorrt_llm/serve/scripts/benchmark_utils.py b/tensorrt_llm/serve/scripts/benchmark_utils.py index 638ac22b679..946a6ceeb07 100644 --- a/tensorrt_llm/serve/scripts/benchmark_utils.py +++ b/tensorrt_llm/serve/scripts/benchmark_utils.py @@ -6,7 +6,10 @@ import json import math import os -from typing import Any +from typing import Any, Optional + +import requests +from tqdm.asyncio import tqdm def convert_to_pytorch_benchmark_format(args: argparse.Namespace, @@ -69,3 +72,60 @@ def iterencode(self, o: Any, *args, **kwargs) -> Any: def write_to_json(filename: str, records: list) -> None: with open(filename, "w") as f: json.dump(records, f, cls=InfEncoder) + + +def download_and_cache_file(url: str, path: Optional[str], name: str, + timeout: int) -> str: + # Adapted from + # https://github.com/sgl-project/sglang/blob/58f10679e1850fdc86046057c23bac5193156de9/python/sglang/bench_serving.py#L586 + """Read and cache a file from a url.""" + + # Check if the path is valid and if the file exists + if path is None or not os.path.exists(path): + raise ValueError("download_path is not provided or does not exist") + filename = os.path.join(path, name) + + if is_file_valid_json(filename): + return filename + + print(f"Downloading from {url} to {filename}") + + # Stream the response to show the progress bar + response = requests.get(url, stream=True, timeout=timeout) + response.raise_for_status() # Check for request errors + + # Total size of the file in bytes + total_size = int(response.headers.get("content-length", 0)) + chunk_size = 1024 # Download in chunks of 1KB + + # Use tqdm to display the progress bar + with open(filename, "wb") as f, tqdm( + desc=filename, + total=total_size, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for chunk in response.iter_content(chunk_size=chunk_size): + f.write(chunk) + bar.update(len(chunk)) + + return filename + + +def is_file_valid_json(path) -> bool: + # Adapted from + # https://github.com/sgl-project/sglang/blob/58f10679e1850fdc86046057c23bac5193156de9/python/sglang/bench_serving.py#L620 + if not os.path.isfile(path): + return False + + # TODO can fuse into the real file open later + try: + with open(path) as f: + json.load(f) + return True + except json.JSONDecodeError as e: + print( + f"{path} exists but json loading fails ({e=}), thus treat as invalid file" + ) + return False