|
24 | 24 | """ |
25 | 25 | import argparse |
26 | 26 | import asyncio |
| 27 | +import copy |
27 | 28 | import dataclasses |
28 | 29 | import json |
29 | 30 | import os |
30 | 31 | import random |
31 | 32 | import time |
| 33 | +import uuid |
32 | 34 | import warnings |
33 | 35 | from collections.abc import AsyncGenerator |
34 | 36 | from dataclasses import dataclass |
@@ -109,24 +111,43 @@ class SampleRequest: |
109 | 111 |
|
110 | 112 | def sample_requests(tokenizer: PreTrainedTokenizerBase, |
111 | 113 | args: argparse.Namespace) -> list[SampleRequest]: |
112 | | - if args.dataset == 'json': |
| 114 | + if args.dataset == 'json' or args.dataset == 'json-unique': |
113 | 115 | if args.json_schema_path is None: |
114 | 116 | dir_path = os.path.dirname(os.path.realpath(__file__)) |
115 | 117 | args.json_schema_path = os.path.join(dir_path, |
116 | 118 | "structured_schemas", |
117 | 119 | "structured_schema_1.json") |
| 120 | + json_schemas = [] |
118 | 121 | with open(args.json_schema_path) as f: |
119 | 122 | schema = json.load(f) |
120 | | - prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501 |
121 | | - input_len = len(tokenizer(prompt).input_ids) |
122 | | - print(f"Input length of the prompt: {input_len} tokens") |
| 123 | + |
| 124 | + if args.dataset == 'json-unique': |
| 125 | + json_schemas = [ |
| 126 | + copy.deepcopy(schema) for _ in range(args.num_prompts) |
| 127 | + ] |
| 128 | + for i in range(len(json_schemas)): |
| 129 | + json_schemas[i]["properties"][ |
| 130 | + f"__optional_field_{uuid.uuid4()}"] = { |
| 131 | + "type": |
| 132 | + "string", |
| 133 | + "description": |
| 134 | + "An unique optional field to avoid cached schemas" |
| 135 | + } |
| 136 | + |
| 137 | + def gen_prompt(index: int): |
| 138 | + schema = json_schemas[index % len(json_schemas)] |
| 139 | + return f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501 |
| 140 | + |
| 141 | + def get_schema(index: int): |
| 142 | + return json_schemas[index % len(json_schemas)] |
| 143 | + |
123 | 144 | requests = [ |
124 | | - SampleRequest(prompt=prompt, |
125 | | - prompt_len=input_len, |
| 145 | + SampleRequest(prompt=gen_prompt(i), |
| 146 | + prompt_len=len(tokenizer(gen_prompt(i)).input_ids), |
126 | 147 | expected_output_len=args.output_len, |
127 | | - schema=schema, |
| 148 | + schema=get_schema(i), |
128 | 149 | structure_type=args.structure_type) |
129 | | - for _ in range(args.num_prompts) |
| 150 | + for i in range(args.num_prompts) |
130 | 151 | ] |
131 | 152 |
|
132 | 153 | elif args.dataset == "grammar": |
@@ -821,10 +842,12 @@ def main(args: argparse.Namespace): |
821 | 842 | default="/v1/completions", |
822 | 843 | help="API endpoint.", |
823 | 844 | ) |
824 | | - parser.add_argument( |
825 | | - "--dataset", |
826 | | - default='json', |
827 | | - choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench']) |
| 845 | + parser.add_argument("--dataset", |
| 846 | + default='json', |
| 847 | + choices=[ |
| 848 | + 'json', 'json-unique', 'grammar', 'regex', |
| 849 | + 'choice', 'xgrammar_bench' |
| 850 | + ]) |
828 | 851 | parser.add_argument("--json_schema_path", |
829 | 852 | type=str, |
830 | 853 | default=None, |
@@ -966,11 +989,12 @@ def main(args: argparse.Namespace): |
966 | 989 | type=float, |
967 | 990 | default=1.0, |
968 | 991 | help="Ratio of Structured Outputs requests") |
969 | | - parser.add_argument("--structured-output-backend", |
970 | | - type=str, |
971 | | - choices=["outlines", "lm-format-enforcer", "xgrammar"], |
972 | | - default="xgrammar", |
973 | | - help="Backend to use for structured outputs") |
| 992 | + parser.add_argument( |
| 993 | + "--structured-output-backend", |
| 994 | + type=str, |
| 995 | + choices=["outlines", "lm-format-enforcer", "xgrammar", "json-unique"], |
| 996 | + default="xgrammar", |
| 997 | + help="Backend to use for structured outputs") |
974 | 998 |
|
975 | 999 | args = parser.parse_args() |
976 | 1000 | main(args) |
0 commit comments