diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py
index ccbc6c022f1f..9c614baf1f0c 100644
--- a/benchmarks/benchmark_dataset.py
+++ b/benchmarks/benchmark_dataset.py
@@ -771,6 +771,60 @@ def sample(self,
         return sampled_requests
 
 
+# -----------------------------------------------------------------------------
+# MT-Bench Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class MTBenchDataset(HuggingFaceDataset):
+    """
+    MT-Bench Dataset.
+    https://huggingface.co/datasets/philschmid/mt-bench
+
+    We create a single turn dataset for MT-Bench.
+    This is similar to Spec decoding benchmark setup in vLLM
+    https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
+    """  # noqa: E501
+
+    DEFAULT_OUTPUT_LEN = 256  # avg len used in SD bench in vLLM
+    SUPPORTED_DATASET_PATHS = {
+        "philschmid/mt-bench",
+    }
+
+    def sample(self,
+               tokenizer: PreTrainedTokenizerBase,
+               num_requests: int,
+               output_len: Optional[int] = None,
+               enable_multimodal_chat: bool = False,
+               **kwargs) -> list:
+        output_len = (output_len
+                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+        sampled_requests = []
+
+        for item in self.data:
+            if len(sampled_requests) >= num_requests:
+                break
+            prompt = item['turns'][0]
+
+            # apply template
+            prompt = tokenizer.apply_chat_template([{
+                "role": "user",
+                "content": prompt
+            }],
+                                                   add_generation_prompt=True,
+                                                   tokenize=False)
+
+            prompt_len = len(tokenizer(prompt).input_ids)
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                ))
+        self.maybe_oversample_requests(sampled_requests, num_requests)
+        return sampled_requests
+
+
 # -----------------------------------------------------------------------------
 # AIMO Dataset Implementation
 # -----------------------------------------------------------------------------
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index da124e1a81b4..c236d64261d0 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -52,9 +52,9 @@
 from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
                                ConversationDataset, HuggingFaceDataset,
-                               InstructCoderDataset, RandomDataset,
-                               SampleRequest, ShareGPTDataset, SonnetDataset,
-                               VisionArenaDataset)
+                               InstructCoderDataset, MTBenchDataset,
+                               RandomDataset, SampleRequest, ShareGPTDataset,
+                               SonnetDataset, VisionArenaDataset)
 from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
 
 MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@@ -595,6 +595,9 @@ def main(args: argparse.Namespace):
     elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
         dataset_class = InstructCoderDataset
         args.hf_split = "train"
+    elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS:
+        dataset_class = MTBenchDataset
+        args.hf_split = "train"
     elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
         dataset_class = ConversationDataset
     elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
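
Usage note (not part of the patch): the sketch below exercises the new MTBenchDataset directly. It assumes HuggingFaceDataset's constructor accepts dataset_path and dataset_split, as the other HuggingFace-backed datasets in benchmark_dataset.py appear to; the tokenizer name is only an example, since any tokenizer with a chat template works.

# Minimal sketch; run from the benchmarks/ directory so benchmark_dataset
# is importable. Assumption: HuggingFaceDataset.__init__ takes dataset_path
# and dataset_split; the tokenizer below is only an example.
from transformers import AutoTokenizer

from benchmark_dataset import MTBenchDataset

# Any tokenizer with a chat template works: sample() calls
# tokenizer.apply_chat_template() on the first turn of each conversation.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

dataset = MTBenchDataset(dataset_path="philschmid/mt-bench",
                         dataset_split="train")
requests = dataset.sample(tokenizer=tokenizer, num_requests=8)

for req in requests[:3]:
    # expected_output_len falls back to DEFAULT_OUTPUT_LEN (256) because
    # no output_len was passed above.
    print(req.prompt_len, req.expected_output_len)

End to end, the same path should be reachable through benchmark_serving.py with --dataset-name hf --dataset-path philschmid/mt-bench, since main() now routes that dataset path to MTBenchDataset and pins hf_split to "train".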