Skip to content
27 changes: 18 additions & 9 deletions benchmarks/cpp/prepare_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@
from typing import Optional, Tuple

import click
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, model_validator
from transformers import AutoTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from utils.prepare_real_data import dataset
from utils.prepare_synthetic_data import token_norm_dist, token_unif_dist

Expand All @@ -30,20 +28,25 @@ class RootArgs(BaseModel):
random_seed: int
task_id: int
std_out: bool
trust_remote_code: bool = False
rand_task_id: Optional[Tuple[int, int]]
lora_dir: Optional[str] = None

@field_validator('tokenizer')
def get_tokenizer(cls,
v: str) -> PreTrainedTokenizer | PreTrainedTokenizerFast:
@model_validator(mode='after')
def validate_tokenizer(self):
try:
tokenizer = AutoTokenizer.from_pretrained(v, padding_side='left')
tokenizer = AutoTokenizer.from_pretrained(
self.tokenizer,
padding_side='left',
trust_remote_code=self.trust_remote_code)
except EnvironmentError as e:
raise ValueError(
f"Cannot find a tokenizer from the given string because of {e}\nPlease set tokenizer to the directory that contains the tokenizer, or set to a model name in HuggingFace."
)
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
self.tokenizer = tokenizer

return self


@click.group()
Expand Down Expand Up @@ -82,6 +85,11 @@ def get_tokenizer(cls,
default="info",
type=click.Choice(['info', 'debug']),
help="Logging level.")
@click.option("--trust-remote-code",
is_flag=True,
default=False,
envvar="TRUST_REMOTE_CODE",
help="Trust remote code.")
@click.pass_context
def cli(ctx, **kwargs):
"""This script generates dataset input for gptManagerBenchmark."""
Expand All @@ -98,7 +106,8 @@ def cli(ctx, **kwargs):
random_seed=kwargs['random_seed'],
task_id=kwargs['task_id'],
rand_task_id=kwargs['rand_task_id'],
lora_dir=kwargs['lora_dir'])
lora_dir=kwargs['lora_dir'],
trust_remote_code=kwargs['trust_remote_code'])


cli.add_command(dataset)
Expand Down