18 changes: 14 additions & 4 deletions tensorrt_llm/serve/scripts/benchmark_dataset.py
@@ -25,6 +25,7 @@
 
 import numpy as np
 import pandas as pd
+from benchmark_utils import download_and_cache_file
 from datasets import load_dataset
 from transformers import PreTrainedTokenizerBase
 
@@ -242,14 +243,23 @@ class ShareGPTDataset(BenchmarkDataset):
     Implements the ShareGPT dataset. Loads data from a JSON file and generates
     sample requests based on conversation turns.
     """
+    URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
 
-    def __init__(self, **kwargs) -> None:
+    def __init__(self,
+                 download_timeout: int,
+                 download_path: Optional[str] = None,
+                 **kwargs) -> None:
         super().__init__(**kwargs)
-        self.load_data()
+        self.load_data(download_timeout, download_path)
 
-    def load_data(self) -> None:
+    def load_data(self,
+                  download_timeout: int,
+                  download_path: Optional[str] = None) -> None:
         if self.dataset_path is None:
-            raise ValueError("dataset_path must be provided for loading data.")
+            logger.warning("dataset_path is not provided")
+            self.dataset_path = download_and_cache_file(
+                ShareGPTDataset.URL, download_path,
+                ShareGPTDataset.URL.split("/")[-1], download_timeout)
 
         with open(self.dataset_path, encoding="utf-8") as f:
             self.data = json.load(f)
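With this change, omitting dataset_path makes ShareGPTDataset fetch and cache the ShareGPT JSON automatically instead of raising. A minimal usage sketch (the values are illustrative, the tokenizer setup is assumed, and sample() may accept more arguments than shown):

    # Hypothetical example: dataset_path=None triggers the download fallback,
    # which fetches ShareGPTDataset.URL into download_path and caches it there.
    dataset = ShareGPTDataset(
        random_seed=0,
        dataset_path=None,
        download_path="/tmp/datasets",  # directory must already exist
        download_timeout=180)           # seconds, matching the CLI default
    sampled = dataset.sample(tokenizer=tokenizer, num_requests=100)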
14 changes: 13 additions & 1 deletion tensorrt_llm/serve/scripts/benchmark_serving.py
@@ -612,7 +612,9 @@ def main(args: argparse.Namespace):
     # For datasets that follow a similar structure, use a mapping.
     dataset_mapping = {
         "sharegpt":
-        lambda: ShareGPTDataset(random_seed=args.seed,
+        lambda: ShareGPTDataset(download_path=args.download_path,
+                                download_timeout=args.download_timeout,
+                                random_seed=args.seed,
                                 dataset_path=args.dataset_path).sample(
                                     tokenizer=tokenizer,
                                     num_requests=args.num_prompts,
@@ -778,6 +780,16 @@
         default=None,
         help="Path to the sharegpt/sonnet dataset. "
         "Or the huggingface dataset ID if using HF dataset.")
+    parser.add_argument(
+        "--download-path",
+        type=str,
+        default=None,
+        help="Directory to download the dataset into when --dataset-path "
+        "is not given; only sharegpt is supported for now.")
+    parser.add_argument("--download-timeout",
+                        type=int,
+                        default=180,
+                        help="Timeout in seconds for downloading datasets.")
     parser.add_argument(
         "--max-concurrency",
         type=int,
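A hypothetical end-to-end invocation exercising the new flags (the entry point and the other flags, such as --dataset-name, are assumed rather than shown in this diff):

    python benchmark_serving.py \
        --dataset-name sharegpt \
        --download-path /tmp/datasets \
        --download-timeout 300

Note that --download-path is only consulted when --dataset-path is omitted, and the directory must already exist: download_and_cache_file raises ValueError rather than creating it.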
62 changes: 61 additions & 1 deletion tensorrt_llm/serve/scripts/benchmark_utils.py
@@ -6,7 +6,10 @@
 import json
 import math
 import os
-from typing import Any
+from typing import Any, Optional
+
+import requests
+from tqdm.asyncio import tqdm
 
 
 def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
@@ -69,3 +72,60 @@ def iterencode(self, o: Any, *args, **kwargs) -> Any:
 def write_to_json(filename: str, records: list) -> None:
     with open(filename, "w") as f:
         json.dump(records, f, cls=InfEncoder)
+
+
+def download_and_cache_file(url: str, path: Optional[str], name: str,
+                            timeout: int) -> str:
+    """Download a file from a URL and cache it locally."""
+    # Adapted from
+    # https://github.com/sgl-project/sglang/blob/58f10679e1850fdc86046057c23bac5193156de9/python/sglang/bench_serving.py#L586
+
+    # Require a valid download directory; it is not created here.
+    if path is None or not os.path.exists(path):
+        raise ValueError("download_path is not provided or does not exist")
+    filename = os.path.join(path, name)
+
+    if is_file_valid_json(filename):
+        return filename
+
+    print(f"Downloading from {url} to {filename}")
+
+    # Stream the response so the progress bar can update per chunk
+    response = requests.get(url, stream=True, timeout=timeout)
+    response.raise_for_status()  # Check for request errors
+
+    # Total size of the file in bytes
+    total_size = int(response.headers.get("content-length", 0))
+    chunk_size = 1024  # Download in chunks of 1KB
+
+    # Use tqdm to display the progress bar
+    with open(filename, "wb") as f, tqdm(
+            desc=filename,
+            total=total_size,
+            unit="B",
+            unit_scale=True,
+            unit_divisor=1024,
+    ) as bar:
+        for chunk in response.iter_content(chunk_size=chunk_size):
+            f.write(chunk)
+            bar.update(len(chunk))
+
+    return filename
+
+
+def is_file_valid_json(path: str) -> bool:
+    # Adapted from
+    # https://github.com/sgl-project/sglang/blob/58f10679e1850fdc86046057c23bac5193156de9/python/sglang/bench_serving.py#L620
+    if not os.path.isfile(path):
+        return False
+
+    # TODO: fuse this validation into the real file open later
+    try:
+        with open(path) as f:
+            json.load(f)
+        return True
+    except json.JSONDecodeError as e:
+        print(
+            f"{path} exists but JSON loading failed ({e=}); treating it as an invalid file"
+        )
+        return False
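The TODO in is_file_valid_json notes that validation currently parses the file once and the caller then opens and parses it again. A sketch of how the two steps could be fused so the JSON is parsed a single time (the helper name is hypothetical; it reuses this module's json, os, Any, and Optional imports):

    def load_json_if_valid(path: str) -> Optional[Any]:
        # Parse the file exactly once; return the data, or None if the
        # file is missing or not valid JSON.
        if not os.path.isfile(path):
            return None
        try:
            with open(path) as f:
                return json.load(f)
        except json.JSONDecodeError as e:
            print(f"{path} exists but JSON loading failed ({e=})")
            return None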
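For reference, a minimal sketch of calling the download helper directly (the URL and directory are illustrative):

    # `path` must name an existing directory; the helper raises ValueError
    # instead of creating it. A valid cached copy short-circuits the download.
    cached = download_and_cache_file(
        "https://example.com/data.json",  # hypothetical URL
        "/tmp/datasets",                  # existing download directory
        "data.json",                      # file name inside that directory
        timeout=180)                      # seconds, passed to requests.get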