18 | 18 | #
19 | 19 | import json |
20 | 20 | import os |
| 21 | +from typing import Any, Dict |
21 | 22 |
22 | 23 | import jsonschema |
23 | 24 | import pytest |
24 | 25 | import regex as re |
| 26 | + |
| 27 | +from vllm_ascend.utils import vllm_version_is |
| 28 | + |
| 29 | +if vllm_version_is("0.10.2"): |
| 30 | + from vllm.sampling_params import GuidedDecodingParams, SamplingParams |
| 31 | +else: |
| 32 | + from vllm.sampling_params import SamplingParams, StructuredOutputsParams |
| 33 | + |
25 | 34 | from vllm.outputs import RequestOutput |
26 | | -from vllm.sampling_params import GuidedDecodingParams, SamplingParams |
27 | 35 |
28 | 36 | from tests.e2e.conftest import VllmRunner |
29 | 37 |
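
The same version branch recurs in each test below. A minimal sketch of one way to factor it out, assuming only the imports already present in this file; `build_structured_params` is a hypothetical helper, not part of the patch:

from typing import Any, Dict, Tuple

def build_structured_params(
        backend: str,
        **constraints: Any) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Return (extra SamplingParams kwargs, VllmRunner kwargs) for the installed vLLM."""
    if vllm_version_is("0.10.2"):
        # Pre-0.11 API; GuidedDecodingParams was imported at module level
        # under the same version check, so it is defined on this branch.
        return ({"guided_decoding": GuidedDecodingParams(**constraints)},
                {"seed": 0, "guided_decoding_backend": backend})
    # Newer API: structured outputs replace guided decoding.
    return ({"structured_outputs": StructuredOutputsParams(**constraints)},
            {"seed": 0, "structured_outputs_config": {"backend": backend}})

With it, each test body would reduce to something like `extra, runner_kwargs = build_structured_params(guided_decoding_backend, json=sample_json_schema)` followed by `SamplingParams(temperature=1.0, max_tokens=500, **extra)`.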
@@ -84,16 +92,29 @@ def sample_json_schema(): |
84 | 92 | @pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend) |
85 | 93 | def test_guided_json_completion(guided_decoding_backend: str, |
86 | 94 | sample_json_schema): |
87 | | - sampling_params = SamplingParams( |
88 | | - temperature=1.0, |
89 | | - max_tokens=500, |
90 | | - guided_decoding=GuidedDecodingParams(json=sample_json_schema)) |
91 | | - |
92 | | - with VllmRunner( |
93 | | - MODEL_NAME, |
94 | | - seed=0, |
95 | | - guided_decoding_backend=guided_decoding_backend, |
96 | | - ) as vllm_model: |
| 95 | + runner_kwargs: Dict[str, Any] = {} |
| 96 | + if vllm_version_is("0.10.2"): |
| 97 | + sampling_params = SamplingParams( |
| 98 | + temperature=1.0, |
| 99 | + max_tokens=500, |
| 100 | + guided_decoding=GuidedDecodingParams(json=sample_json_schema)) |
| 101 | + runner_kwargs = { |
| 102 | + "seed": 0, |
| 103 | + "guided_decoding_backend": guided_decoding_backend, |
| 104 | + } |
| 105 | + else: |
| 106 | + sampling_params = SamplingParams( |
| 107 | + temperature=1.0, |
| 108 | + max_tokens=500, |
| 109 | + structured_outputs=StructuredOutputsParams( |
| 110 | + json=sample_json_schema)) |
| 111 | + runner_kwargs = { |
| 112 | + "seed": 0, |
| 113 | + "structured_outputs_config": { |
| 114 | + "backend": guided_decoding_backend |
| 115 | + }, |
| 116 | + } |
| 117 | + with VllmRunner(MODEL_NAME, **runner_kwargs) as vllm_model: |
97 | 118 | prompts = [ |
98 | 119 | f"Give an example JSON for an employee profile " |
99 | 120 | f"that fits this schema: {sample_json_schema}" |
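
The hunk ends before the assertions, but the file's imports (`json`, `jsonschema`, `RequestOutput`) suggest their shape. A minimal sketch, assuming the runner exposes the underlying `LLM` as `vllm_model.model`, as upstream vLLM's test fixtures do:

outputs = vllm_model.model.generate(prompts, sampling_params)
for output in outputs:
    assert isinstance(output, RequestOutput)
    generated_text = output.outputs[0].text
    # The backend should have constrained generation to schema-valid JSON.
    parsed = json.loads(generated_text)
    jsonschema.validate(instance=parsed, schema=sample_json_schema)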
@@ -121,17 +142,29 @@ def test_guided_json_completion(guided_decoding_backend: str, |
121 | 142 | def test_guided_regex(guided_decoding_backend: str, sample_regex): |
122 | 143 | if guided_decoding_backend == "outlines": |
123 | 144 | pytest.skip("Outlines doesn't support regex-based guided decoding.") |
| 145 | + runner_kwargs: Dict[str, Any] = {} |
| 146 | + if vllm_version_is("0.10.2"): |
| 147 | + sampling_params = SamplingParams( |
| 148 | + temperature=0.8, |
| 149 | + top_p=0.95, |
| 150 | + guided_decoding=GuidedDecodingParams(regex=sample_regex)) |
| 151 | + runner_kwargs = { |
| 152 | + "seed": 0, |
| 153 | + "guided_decoding_backend": guided_decoding_backend, |
| 154 | + } |
| 155 | + else: |
| 156 | + sampling_params = SamplingParams( |
| 157 | + temperature=0.8, |
| 158 | + top_p=0.95, |
| 159 | + structured_outputs=StructuredOutputsParams(regex=sample_regex)) |
| 160 | + runner_kwargs = { |
| 161 | + "seed": 0, |
| 162 | + "structured_outputs_config": { |
| 163 | + "backend": guided_decoding_backend |
| 164 | + }, |
| 165 | + } |
124 | 166 |
125 | | - sampling_params = SamplingParams( |
126 | | - temperature=0.8, |
127 | | - top_p=0.95, |
128 | | - guided_decoding=GuidedDecodingParams(regex=sample_regex)) |
129 | | - |
130 | | - with VllmRunner( |
131 | | - MODEL_NAME, |
132 | | - seed=0, |
133 | | - guided_decoding_backend=guided_decoding_backend, |
134 | | - ) as vllm_model: |
| 167 | + with VllmRunner(MODEL_NAME, **runner_kwargs) as vllm_model: |
135 | 168 | prompts = [ |
136 | 169 | f"Give an example IPv4 address with this regex: {sample_regex}" |
137 | 170 | ] * 2 |
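
Likewise for the regex test: a minimal sketch of the check implied by the `regex as re` import, under the same `vllm_model.model` assumption:

outputs = vllm_model.model.generate(prompts, sampling_params)
for output in outputs:
    generated_text = output.outputs[0].text
    # fullmatch ensures the entire completion satisfies the constraint,
    # not just a substring of it.
    assert re.fullmatch(sample_regex, generated_text) is not None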