diff --git a/benchmarks/cpp/prepare_dataset.py b/benchmarks/cpp/prepare_dataset.py
index 93a225a2504..2f7b5516b62 100644
--- a/benchmarks/cpp/prepare_dataset.py
+++ b/benchmarks/cpp/prepare_dataset.py
@@ -16,10 +16,8 @@
 from typing import Optional, Tuple
 
 import click
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel, model_validator
 from transformers import AutoTokenizer
-from transformers.tokenization_utils import PreTrainedTokenizer
-from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 from utils.prepare_real_data import dataset
 from utils.prepare_synthetic_data import token_norm_dist, token_unif_dist
@@ -30,20 +28,25 @@ class RootArgs(BaseModel):
     random_seed: int
     task_id: int
     std_out: bool
+    trust_remote_code: bool = False
     rand_task_id: Optional[Tuple[int, int]]
     lora_dir: Optional[str] = None
 
-    @field_validator('tokenizer')
-    def get_tokenizer(cls,
-                      v: str) -> PreTrainedTokenizer | PreTrainedTokenizerFast:
+    @model_validator(mode='after')
+    def validate_tokenizer(self):
         try:
-            tokenizer = AutoTokenizer.from_pretrained(v, padding_side='left')
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.tokenizer,
+                padding_side='left',
+                trust_remote_code=self.trust_remote_code)
         except EnvironmentError as e:
             raise ValueError(
                 f"Cannot find a tokenizer from the given string because of {e}\nPlease set tokenizer to the directory that contains the tokenizer, or set to a model name in HuggingFace."
             )
         tokenizer.pad_token = tokenizer.eos_token
-        return tokenizer
+        self.tokenizer = tokenizer
+
+        return self
 
 
 @click.group()
@@ -82,6 +85,11 @@ def get_tokenizer(cls,
               default="info",
               type=click.Choice(['info', 'debug']),
               help="Logging level.")
+@click.option("--trust-remote-code",
+              is_flag=True,
+              default=False,
+              envvar="TRUST_REMOTE_CODE",
+              help="Trust remote code.")
 @click.pass_context
 def cli(ctx, **kwargs):
     """This script generates dataset input for gptManagerBenchmark."""
@@ -98,7 +106,8 @@ def cli(ctx, **kwargs):
                       random_seed=kwargs['random_seed'],
                       task_id=kwargs['task_id'],
                       rand_task_id=kwargs['rand_task_id'],
-                      lora_dir=kwargs['lora_dir'])
+                      lora_dir=kwargs['lora_dir'],
+                      trust_remote_code=kwargs['trust_remote_code'])
 
 
 cli.add_command(dataset)
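
Note: switching from a per-field `field_validator` to `model_validator(mode='after')` is the usual Pydantic v2 pattern when validation needs to read more than one field, because an "after" validator receives the fully constructed model instance and must return it. A minimal, self-contained sketch of that pattern (the `Config` model and its fields below are illustrative, not taken from `prepare_dataset.py`):

```python
from pydantic import BaseModel, model_validator


class Config(BaseModel):
    # Illustrative fields only; they mirror the shape of RootArgs but are
    # not part of the benchmark code.
    tokenizer: str
    trust_remote_code: bool = False

    @model_validator(mode='after')
    def check_tokenizer(self):
        # Runs after all per-field validation. `self` is the built instance,
        # so the validator can combine several fields (here `tokenizer` and
        # `trust_remote_code`) and may also mutate them before returning the
        # model, as the diff does when it replaces the tokenizer string with
        # the loaded tokenizer object.
        if self.trust_remote_code and not self.tokenizer:
            raise ValueError(
                "a tokenizer must be given when trust_remote_code is set")
        return self


cfg = Config(tokenizer="gpt2", trust_remote_code=False)
```

The new `--trust-remote-code` flag (also settable through the `TRUST_REMOTE_CODE` environment variable) is forwarded to `AutoTokenizer.from_pretrained`, which Transformers requires for repositories that ship custom tokenizer code.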