Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 79 additions & 1 deletion vllm/benchmarks/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,7 +1020,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
default="random",
choices=[
"sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
"custom", "prefix_repetition"
"custom", "prefix_repetition", "spec_bench"
],
help="Name of the dataset to benchmark on.",
)
Expand Down Expand Up @@ -1053,6 +1053,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"Skip applying chat template to prompt, used only for custom dataset.",
)

spec_bench_group = parser.add_argument_group("spec bench dataset options")
spec_bench_group.add_argument(
"--spec-bench-output-len",
type=int,
default=256,
help=
"Num of output tokens per request, used only for spec bench dataset.",
)
spec_bench_group.add_argument(
"--spec-bench-category",
type=str,
default=None,
help=
"Category for spec bench dataset. If None, use all categories.",
)

sonnet_group = parser.add_argument_group("sonnet dataset options")
sonnet_group.add_argument(
"--sonnet-input-len",
Expand Down Expand Up @@ -1404,6 +1420,14 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
else:
# For datasets that follow a similar structure, use a mapping.
dataset_mapping = {
"spec_bench":
lambda: SpecBench(dataset_path=args.dataset_path,
category=args.spec_bench_category).sample(
num_requests=args.num_prompts,
tokenizer=tokenizer,
output_len=args.spec_bench_output_len,
request_id_prefix=args.request_id_prefix,
),
"sharegpt": lambda: ShareGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
Expand Down Expand Up @@ -1541,6 +1565,14 @@ def sample(
request_id_prefix: str = "",
**kwargs,
) -> list:
# load all data if needed
self.num_available_samples = len(self.data)
if num_requests <= 0:
num_requests = self.num_available_samples
logger.info("num_requests is set to 0 or negative, "
"so using all available samples: %d",
num_requests)

sampled_requests = []
for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
Expand Down Expand Up @@ -1572,6 +1604,52 @@ def sample(
return sampled_requests


# -----------------------------------------------------------------------------
# Spec Bench Dataset Implementation
# -----------------------------------------------------------------------------


class SpecBench(CustomDataset):
    """
    Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench
    Download the dataset using:
    wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl

    Each JSONL record must provide a ``turns`` sequence; only its first
    entry is used as the prompt. When a ``category`` is given, only records
    whose ``category`` field equals it are kept.
    """ # noqa: E501

    def __init__(self, **kwargs) -> None:
        """
        Create the dataset and load its prompts.

        Args:
            **kwargs: Forwarded to ``CustomDataset`` after removing the
                optional ``category`` entry, which selects a single
                category (``None``/empty keeps every record).
        """
        # `category` is specific to this subclass, so take it out before
        # handing the remaining keyword arguments to the parent class.
        self.category = kwargs.pop("category", None)
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """
        Read the JSONL file at ``self.dataset_path`` into ``self.data``.

        ``self.data`` becomes a list of ``{"prompt": <first turn>}`` dicts,
        shuffled with the global ``random`` module seeded by
        ``self.random_seed``.

        Raises:
            ValueError: If ``dataset_path`` is None, if the file has no
                ``turns`` column, or if a category filter was requested
                but the file has no ``category`` column.
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        self.data = []

        # Load the JSONL file
        jsonl_data = pd.read_json(path_or_buf=self.dataset_path,
                                  lines=True)

        # check if the JSONL file has a 'turns' column
        if "turns" not in jsonl_data.columns:
            raise ValueError("JSONL file must contain a 'turns' column.")

        if self.category:
            # Fail early with a clear message instead of a bare KeyError
            # raised from the middle of the row loop.
            if "category" not in jsonl_data.columns:
                raise ValueError(
                    "JSONL file must contain a 'category' column "
                    "when a category is specified.")
            # sample only from the requested category
            selected_turns = [
                turns for turns, category in zip(jsonl_data["turns"],
                                                 jsonl_data["category"])
                if self.category == category
            ]
        else:
            selected_turns = list(jsonl_data["turns"])

        # Only the first turn of each record is used as the prompt.
        self.data = [{"prompt": turns[0]} for turns in selected_turns]

        # Re-seed before shuffling so the order is reproducible for a
        # given random_seed.
        random.seed(self.random_seed)
        random.shuffle(self.data)

    def sample(self, **kwargs) -> list:
        """
        Sample requests through ``CustomDataset.sample``.

        ``skip_chat_template`` is always forced to ``False``, overriding
        any value supplied by the caller; every other keyword argument is
        passed through unchanged.
        """
        # leverage CustomDataset sample
        kwargs["skip_chat_template"] = False
        return super().sample(**kwargs)


# -----------------------------------------------------------------------------
# Sonnet Dataset Implementation
# -----------------------------------------------------------------------------
Expand Down