|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | 2 |
|
3 | 3 | import json |
| 4 | +import re |
4 | 5 |
|
5 | 6 | import jsonschema |
6 | 7 | import pytest |
@@ -219,25 +220,24 @@ def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str): |
219 | 220 | guided_decoding=GuidedDecodingParams( |
220 | 221 | regex=sample_regex, |
221 | 222 | backend=guided_decoding_backend)) |
222 | | - with pytest.raises(ValueError, |
223 | | - match="Regex guided decoding is not supported."): |
224 | | - llm.generate(prompts=[ |
| 223 | + outputs = llm.generate( |
| 224 | + prompts=[ |
225 | 225 | f"Give an example IPv4 address with this regex: {sample_regex}" |
226 | 226 | ] * 2, |
227 | | - sampling_params=sampling_params, |
228 | | - use_tqdm=True) |
| 227 | + sampling_params=sampling_params, |
| 228 | + use_tqdm=True, |
| 229 | + ) |
229 | 230 |
|
230 | | - # Once regex is supported -- |
231 | | - #assert outputs is not None |
232 | | - #for output in outputs: |
233 | | - # assert output is not None |
234 | | - # assert isinstance(output, RequestOutput) |
235 | | - # prompt = output.prompt |
236 | | - # generated_text = output.outputs[0].text |
237 | | - # print(generated_text) |
238 | | - # assert generated_text is not None |
239 | | - # assert re.fullmatch(sample_regex, generated_text) is not None |
240 | | - # print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") |
| 231 | + assert outputs is not None |
| 232 | + for output in outputs: |
| 233 | + assert output is not None |
| 234 | + assert isinstance(output, RequestOutput) |
| 235 | + prompt = output.prompt |
| 236 | + generated_text = output.outputs[0].text |
| 237 | + print(generated_text) |
| 238 | + assert generated_text is not None |
| 239 | + assert re.fullmatch(sample_regex, generated_text) is not None |
| 240 | + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") |
241 | 241 |
|
242 | 242 |
|
243 | 243 | @pytest.mark.skip_global_cleanup |
|
0 commit comments