From f12b92a5969b393d06b0c020a3782e256bd51382 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 29 Apr 2025 21:09:15 -0400 Subject: [PATCH 1/6] [V1] Disable pickle fallback by default for better security This changes `vllm.v1.serial_utils` to disable the pickle fallback by default. Add an environment variable that can be used to turn the pickle fallback back on. Follow-up to #17427. Signed-off-by: Russell Bryant --- vllm/envs.py | 7 +++++++ vllm/v1/serial_utils.py | 23 +++++++++++++++-------- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index ea40bfff11b5..0e4b3a9893d0 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -110,6 +110,7 @@ VLLM_USE_DEEP_GEMM: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 + VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False def get_default_cache_root(): @@ -727,6 +728,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # limit will actually be zero-copy decoded. "VLLM_MSGPACK_ZERO_COPY_THRESHOLD": lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")), + + # If set, allow insecure serialization using pickle. + # This is useful for environments where it is deemed safe to use the + # insecure method and it is needed for some reason. 
+ "VLLM_ALLOW_INSECURE_SERIALIZATION": + lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))), } # end-env-vars-definition diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index e00ecde66af0..0b0c3c51d110 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -14,6 +14,7 @@ from msgspec import msgpack from vllm import envs +from vllm.logger import init_logger from vllm.multimodal.inputs import (BaseMultiModalField, MultiModalBatchedField, MultiModalFieldConfig, MultiModalFieldElem, @@ -21,6 +22,8 @@ MultiModalKwargsItem, MultiModalSharedField, NestedTensors) +logger = init_logger(__name__) + CUSTOM_TYPE_PICKLE = 1 CUSTOM_TYPE_CLOUDPICKLE = 2 CUSTOM_TYPE_RAW_VIEW = 3 @@ -47,9 +50,7 @@ class MsgpackEncoder: via dedicated messages. Note that this is a per-tensor limit. """ - def __init__(self, - size_threshold: Optional[int] = None, - allow_pickle: bool = True): + def __init__(self, size_threshold: Optional[int] = None): if size_threshold is None: size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) @@ -58,7 +59,10 @@ def __init__(self, # pass custom data to the hook otherwise. self.aux_buffers: Optional[list[bytestr]] = None self.size_threshold = size_threshold - self.allow_pickle = allow_pickle + if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: + logger.warning( + "Allowing insecure serialization using pickle due to " + "VLLM_ALLOW_INSECURE_SERIALIZATION=1") def encode(self, obj: Any) -> Sequence[bytestr]: try: @@ -108,7 +112,7 @@ def enc_hook(self, obj: Any) -> Any: for itemlist in mm._items_by_modality.values() for item in itemlist] - if not self.allow_pickle: + if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION: raise TypeError(f"Object of type {type(obj)} is not serializable") if isinstance(obj, FunctionType): @@ -185,13 +189,16 @@ class MsgpackDecoder: not thread-safe when encoding tensors / numpy arrays. 
""" - def __init__(self, t: Optional[Any] = None, allow_pickle: bool = True): + def __init__(self, t: Optional[Any] = None): args = () if t is None else (t, ) self.decoder = msgpack.Decoder(*args, ext_hook=self.ext_hook, dec_hook=self.dec_hook) self.aux_buffers: Sequence[bytestr] = () - self.allow_pickle = allow_pickle + if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: + logger.warning( + "Allowing insecure deserialization using pickle due to " + "VLLM_ALLOW_INSECURE_SERIALIZATION=1") def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any: if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)): @@ -273,7 +280,7 @@ def ext_hook(self, code: int, data: memoryview) -> Any: if code == CUSTOM_TYPE_RAW_VIEW: return data - if self.allow_pickle: + if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: if code == CUSTOM_TYPE_PICKLE: return pickle.loads(data) if code == CUSTOM_TYPE_CLOUDPICKLE: From f8a20a32ded5c0b7fa024a0be5d7408cb41e742a Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 6 May 2025 08:16:51 -0700 Subject: [PATCH 2/6] Handle python slices Signed-off-by: Nick Hill --- vllm/v1/serial_utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 0b0c3c51d110..2592ea710679 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -93,6 +93,9 @@ def enc_hook(self, obj: Any) -> Any: if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'): return self._encode_ndarray(obj) + if isinstance(obj, slice): + return obj.start, obj.stop, obj.step + if isinstance(obj, MultiModalKwargs): mm: MultiModalKwargs = obj if not mm.modalities: @@ -219,6 +222,8 @@ def dec_hook(self, t: type, obj: Any) -> Any: return self._decode_ndarray(obj) if issubclass(t, torch.Tensor): return self._decode_tensor(obj) + if t is slice: + return slice(*obj) if issubclass(t, MultiModalKwargs): if isinstance(obj, list): return MultiModalKwargs.from_items( @@ -260,6 +265,12 @@ def _decode_mm_items(self, obj: 
list) -> list[MultiModalKwargsItem]: factory_meth_name, *field_args = v["field"] factory_meth = getattr(MultiModalFieldConfig, factory_meth_name) + + # Special case: decode the union "slices" field of + # MultiModalFlatField + if factory_meth_name == "flat": + field_args[0] = self._decode_nested_slices(field_args[0]) + v["field"] = factory_meth(None, *field_args).field elems.append(MultiModalFieldElem(**v)) decoded_items.append(MultiModalKwargsItem.from_elems(elems)) @@ -276,6 +287,11 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors: return self._decode_tensor(obj) return [self._decode_nested_tensors(x) for x in obj] + def _decode_nested_slices(self, obj: Any) -> Any: + if obj and isinstance(obj[0], int): + return slice(*obj) + return [self._decode_nested_slices(x) for x in obj] + def ext_hook(self, code: int, data: memoryview) -> Any: if code == CUSTOM_TYPE_RAW_VIEW: return data From 13670fe0e9e58101a1a0ede71f158bd7ec5631b2 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 6 May 2025 18:59:41 -0700 Subject: [PATCH 3/6] Fix slice decode, update test Signed-off-by: Nick Hill --- tests/v1/test_serial_utils.py | 150 +++++++++++++++++----------------- vllm/v1/serial_utils.py | 3 +- 2 files changed, 79 insertions(+), 74 deletions(-) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index d1271b210ad8..3a5a793b133e 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -9,8 +9,8 @@ import torch from vllm.multimodal.inputs import (MultiModalBatchedField, - MultiModalFieldElem, MultiModalKwargs, - MultiModalKwargsItem, + MultiModalFieldElem, MultiModalFlatField, + MultiModalKwargs, MultiModalKwargsItem, MultiModalSharedField, NestedTensors) from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -36,59 +36,62 @@ class MyType: empty_tensor: torch.Tensor -def test_encode_decode(): +def test_encode_decode(monkeypatch: pytest.MonkeyPatch): """Test encode/decode loop with zero-copy tensors.""" - obj 
= MyType( - tensor1=torch.randint(low=0, - high=100, - size=(1024, ), - dtype=torch.int32), - a_string="hello", - list_of_tensors=[ - torch.rand((1, 10), dtype=torch.float32), - torch.rand((3, 5, 4000), dtype=torch.float64), - torch.tensor(1984), # test scalar too - # Make sure to test bf16 which numpy doesn't support. - torch.rand((3, 5, 1000), dtype=torch.bfloat16), - torch.tensor([float("-inf"), float("inf")] * 1024, - dtype=torch.bfloat16), - ], - numpy_array=np.arange(512), - unrecognized=UnrecognizedType(33), - small_f_contig_tensor=torch.rand(5, 4).t(), - large_f_contig_tensor=torch.rand(1024, 4).t(), - small_non_contig_tensor=torch.rand(2, 4)[:, 1:3], - large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20], - empty_tensor=torch.empty(0), - ) + with monkeypatch.context() as m: + m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - encoder = MsgpackEncoder(size_threshold=256) - decoder = MsgpackDecoder(MyType) + obj = MyType( + tensor1=torch.randint(low=0, + high=100, + size=(1024, ), + dtype=torch.int32), + a_string="hello", + list_of_tensors=[ + torch.rand((1, 10), dtype=torch.float32), + torch.rand((3, 5, 4000), dtype=torch.float64), + torch.tensor(1984), # test scalar too + # Make sure to test bf16 which numpy doesn't support. + torch.rand((3, 5, 1000), dtype=torch.bfloat16), + torch.tensor([float("-inf"), float("inf")] * 1024, + dtype=torch.bfloat16), + ], + numpy_array=np.arange(512), + unrecognized=UnrecognizedType(33), + small_f_contig_tensor=torch.rand(5, 4).t(), + large_f_contig_tensor=torch.rand(1024, 4).t(), + small_non_contig_tensor=torch.rand(2, 4)[:, 1:3], + large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20], + empty_tensor=torch.empty(0), + ) - encoded = encoder.encode(obj) + encoder = MsgpackEncoder(size_threshold=256) + decoder = MsgpackDecoder(MyType) - # There should be the main buffer + 4 large tensor buffers - # + 1 large numpy array. "large" is <= 512 bytes. - # The two small tensors are encoded inline. 
- assert len(encoded) == 8 + encoded = encoder.encode(obj) + + # There should be the main buffer + 4 large tensor buffers + # + 1 large numpy array. "large" is <= 512 bytes. + # The two small tensors are encoded inline. + assert len(encoded) == 8 - decoded: MyType = decoder.decode(encoded) + decoded: MyType = decoder.decode(encoded) - assert_equal(decoded, obj) + assert_equal(decoded, obj) - # Test encode_into case + # Test encode_into case - preallocated = bytearray() + preallocated = bytearray() - encoded2 = encoder.encode_into(obj, preallocated) + encoded2 = encoder.encode_into(obj, preallocated) - assert len(encoded2) == 8 - assert encoded2[0] is preallocated + assert len(encoded2) == 8 + assert encoded2[0] is preallocated - decoded2: MyType = decoder.decode(encoded2) + decoded2: MyType = decoder.decode(encoded2) - assert_equal(decoded2, obj) + assert_equal(decoded2, obj) class MyRequest(msgspec.Struct): @@ -122,7 +125,7 @@ def test_multimodal_kwargs(): total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) # expected total encoding length, should be 44559, +-20 for minor changes - assert total_len >= 44539 and total_len <= 44579 + assert 44539 <= total_len <= 44579 decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] assert all(nested_equal(d[k], decoded[k]) for k in d) @@ -135,14 +138,15 @@ def test_multimodal_items_by_modality(): "video", "v0", [torch.zeros(1000, dtype=torch.int8) for _ in range(4)], - MultiModalBatchedField(), + MultiModalFlatField( + [[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0), ) e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000, dtype=torch.int32), MultiModalSharedField(4)) - e4 = MultiModalFieldElem("image", "i1", torch.zeros(1000, - dtype=torch.int32), - MultiModalBatchedField()) + e4 = MultiModalFieldElem( + "image", "i1", torch.zeros(1000, dtype=torch.int32), + MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2)) audio = MultiModalKwargsItem.from_elems([e1]) video = 
MultiModalKwargsItem.from_elems([e2]) image = MultiModalKwargsItem.from_elems([e3, e4]) @@ -161,7 +165,7 @@ def test_multimodal_items_by_modality(): total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) # expected total encoding length, should be 14255, +-20 for minor changes - assert total_len >= 14235 and total_len <= 14275 + assert 14250 <= total_len <= 14300 decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] # check all modalities were recovered and do some basic sanity checks @@ -178,8 +182,7 @@ def test_multimodal_items_by_modality(): def nested_equal(a: NestedTensors, b: NestedTensors): if isinstance(a, torch.Tensor): return torch.equal(a, b) - else: - return all(nested_equal(x, y) for x, y in zip(a, b)) + return all(nested_equal(x, y) for x, y in zip(a, b)) def assert_equal(obj1: MyType, obj2: MyType): @@ -199,11 +202,10 @@ def assert_equal(obj1: MyType, obj2: MyType): assert torch.equal(obj1.empty_tensor, obj2.empty_tensor) -@pytest.mark.parametrize("allow_pickle", [True, False]) -def test_dict_serialization(allow_pickle: bool): +def test_dict_serialization(): """Test encoding and decoding of a generic Python object using pickle.""" - encoder = MsgpackEncoder(allow_pickle=allow_pickle) - decoder = MsgpackDecoder(allow_pickle=allow_pickle) + encoder = MsgpackEncoder() + decoder = MsgpackDecoder() # Create a sample Python object obj = {"key": "value", "number": 42} @@ -218,11 +220,10 @@ def test_dict_serialization(allow_pickle: bool): assert obj == decoded, "Decoded object does not match the original object." 
-@pytest.mark.parametrize("allow_pickle", [True, False]) -def test_tensor_serialization(allow_pickle: bool): +def test_tensor_serialization(): """Test encoding and decoding of a torch.Tensor.""" - encoder = MsgpackEncoder(allow_pickle=allow_pickle) - decoder = MsgpackDecoder(torch.Tensor, allow_pickle=allow_pickle) + encoder = MsgpackEncoder() + decoder = MsgpackDecoder(torch.Tensor) # Create a sample tensor tensor = torch.rand(10, 10) @@ -238,11 +239,10 @@ def test_tensor_serialization(allow_pickle: bool): tensor, decoded), "Decoded tensor does not match the original tensor." -@pytest.mark.parametrize("allow_pickle", [True, False]) -def test_numpy_array_serialization(allow_pickle: bool): +def test_numpy_array_serialization(): """Test encoding and decoding of a numpy array.""" - encoder = MsgpackEncoder(allow_pickle=allow_pickle) - decoder = MsgpackDecoder(np.ndarray, allow_pickle=allow_pickle) + encoder = MsgpackEncoder() + decoder = MsgpackDecoder(np.ndarray) # Create a sample numpy array array = np.random.rand(10, 10) @@ -268,26 +268,30 @@ def __eq__(self, other): return isinstance(other, CustomClass) and self.value == other.value -def test_custom_class_serialization_allowed_with_pickle(): +def test_custom_class_serialization_allowed_with_pickle( + monkeypatch: pytest.MonkeyPatch): """Test that serializing a custom class succeeds when allow_pickle=True.""" - encoder = MsgpackEncoder(allow_pickle=True) - decoder = MsgpackDecoder(CustomClass, allow_pickle=True) - obj = CustomClass("test_value") + with monkeypatch.context() as m: + m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + encoder = MsgpackEncoder() + decoder = MsgpackDecoder(CustomClass) - # Encode the custom class - encoded = encoder.encode(obj) + obj = CustomClass("test_value") - # Decode the custom class - decoded = decoder.decode(encoded) + # Encode the custom class + encoded = encoder.encode(obj) - # Verify the decoded object matches the original - assert obj == decoded, "Decoded object does not 
match the original object." + # Decode the custom class + decoded = decoder.decode(encoded) + + # Verify the decoded object matches the original + assert obj == decoded, "Decoded object does not match the original object." def test_custom_class_serialization_disallowed_without_pickle(): """Test that serializing a custom class fails when allow_pickle=False.""" - encoder = MsgpackEncoder(allow_pickle=False) + encoder = MsgpackEncoder() obj = CustomClass("test_value") diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 2592ea710679..7b3212c109c9 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -288,7 +288,8 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors: return [self._decode_nested_tensors(x) for x in obj] def _decode_nested_slices(self, obj: Any) -> Any: - if obj and isinstance(obj[0], int): + assert isinstance(obj, (list, tuple)) + if obj and not isinstance(obj[0], (list, tuple)): return slice(*obj) return [self._decode_nested_slices(x) for x in obj] From c558931983ee247c00fa5b860e5470b4a9acb6dc Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 6 May 2025 19:11:18 -0700 Subject: [PATCH 4/6] fix linting Signed-off-by: Nick Hill --- tests/v1/test_serial_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index 3a5a793b133e..ee490071f6a2 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -286,7 +286,8 @@ def test_custom_class_serialization_allowed_with_pickle( decoded = decoder.decode(encoded) # Verify the decoded object matches the original - assert obj == decoded, "Decoded object does not match the original object." 
+ assert obj == decoded, ( + "Decoded object does not match the original object.") def test_custom_class_serialization_disallowed_without_pickle(): From 31590f57e4d59f29edbaabb4099b9186e23fa7fa Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 7 May 2025 11:21:04 -0400 Subject: [PATCH 5/6] Make structured output tests less likely to fail Update all prompts with an instruction to keep responses short. We see occasional failures when output has not reached the end before hitting the output length limit, particularly in the JSON cases. Hopefully this helps. Signed-off-by: Russell Bryant --- .../llm/test_struct_output_generate.py | 84 +++++++++++-------- 1 file changed, 50 insertions(+), 34 deletions(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index c3ea024f58cb..81601c87ad8b 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -105,8 +105,9 @@ def test_structured_output( max_tokens=1000, guided_decoding=GuidedDecodingParams(json=sample_json_schema)) outputs = llm.generate(prompts=[ - f"Give an example JSON for an employee profile " - f"that fits this schema: {sample_json_schema}" + (f"Give an example JSON for an employee profile that fits this " + f"schema. Make the response as short as possible. Schema: " + f"{sample_json_schema}") ] * 2, sampling_params=sampling_params, use_tqdm=True) @@ -136,7 +137,8 @@ def test_structured_output( outputs = llm.generate( prompts=("Generate a JSON object with curly braces for a person with " - "name and age fields for John Smith who is 31 years old."), + "name and age fields for John Smith who is 31 years old. 
" + "Make the response as short as possible."), sampling_params=sampling_params, use_tqdm=True) @@ -165,19 +167,20 @@ def test_structured_output( with pytest.raises(ValueError, match="The provided JSON schema contains features " "not supported by xgrammar."): - llm.generate(prompts=[ - f"Give an example JSON for an employee profile " - f"that fits this schema: {unsupported_json_schema}" - ] * 2, - sampling_params=sampling_params, - use_tqdm=True) + llm.generate( + prompts=[(f"Give an example JSON for an employee profile that " + f"fits this schema: {unsupported_json_schema}. " + f"Make the response as short as possible.")] * 2, + sampling_params=sampling_params, + use_tqdm=True) else: - outputs = llm.generate( - prompts=("Give an example JSON object for a grade " - "that fits this schema: " - f"{unsupported_json_schema}"), - sampling_params=sampling_params, - use_tqdm=True) + outputs = llm.generate(prompts=( + "Give an example JSON object for a grade " + "that fits this schema: " + f"{unsupported_json_schema}. Make the response as short as " + "possible."), + sampling_params=sampling_params, + use_tqdm=True) assert outputs is not None for output in outputs: assert output is not None @@ -199,8 +202,10 @@ def test_structured_output( max_tokens=1000, guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf)) outputs = llm.generate( - prompts=("Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1"), + prompts=( + "Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. 
Make the response as short as " + "possible."), sampling_params=sampling_params, use_tqdm=True, ) @@ -231,8 +236,10 @@ def test_structured_output( max_tokens=1000, guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark)) outputs = llm.generate( - prompts=("Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1"), + prompts=( + "Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short as " + "possible."), sampling_params=sampling_params, use_tqdm=True, ) @@ -269,8 +276,10 @@ def test_structured_output( guided_decoding=GuidedDecodingParams(grammar="not a grammar")) with pytest.raises(ValueError, match="Failed to convert the grammar "): llm.generate( - prompts=("Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1"), + prompts=( + "Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short " + "as possible."), sampling_params=sampling_params, use_tqdm=True, ) @@ -284,7 +293,8 @@ def test_structured_output( guided_decoding=GuidedDecodingParams(regex=sample_regex)) outputs = llm.generate( prompts=[ - f"Give an example IPv4 address with this regex: {sample_regex}" + (f"Give an example IPv4 address with this regex: {sample_regex}. " + f"Make the response as short as possible.") ] * 2, sampling_params=sampling_params, use_tqdm=True, @@ -309,7 +319,8 @@ def test_structured_output( top_p=0.95, guided_decoding=GuidedDecodingParams(choice=sample_guided_choice)) outputs = llm.generate( - prompts="The best language for type-safe systems programming is ", + prompts=("The best language for type-safe systems programming is " + "(Make the response as short as possible.) 
"), sampling_params=sampling_params, use_tqdm=True) assert outputs is not None @@ -331,11 +342,12 @@ def test_structured_output( temperature=1.0, max_tokens=1000, guided_decoding=GuidedDecodingParams(json=json_schema)) - outputs = llm.generate( - prompts="Generate a JSON with the brand, model and car_type of" - "the most iconic car from the 90's", - sampling_params=sampling_params, - use_tqdm=True) + outputs = llm.generate(prompts=( + "Generate a JSON with the brand, model and car_type of the most " + "iconic car from the 90's. Make the response as short as " + "possible."), + sampling_params=sampling_params, + use_tqdm=True) assert outputs is not None @@ -373,7 +385,8 @@ def test_structured_output( guided_decoding=GuidedDecodingParams(json=json_schema)) outputs = llm.generate( - prompts="Generate a description of a frog using 50 characters.", + prompts=("Generate a description of a frog using 50 characters. " + "Make the response as short as possible."), sampling_params=sampling_params, use_tqdm=True) @@ -452,7 +465,8 @@ def test_structured_output( You are a helpful assistant. -Given the previous instructions, what is the weather in New York City? +Given the previous instructions, what is the weather in New York City? \ +Make the response as short as possible. """ # Change this once other backends support structural_tag @@ -509,9 +523,10 @@ def test_structured_output_auto_mode( max_tokens=1000, guided_decoding=GuidedDecodingParams(json=unsupported_json_schema)) - prompts = ("Give an example JSON object for a grade " - "that fits this schema: " - f"{unsupported_json_schema}") + prompts = ( + "Give an example JSON object for a grade " + "that fits this schema: " + f"{unsupported_json_schema}. Make the response as short as possible.") # This would fail with the default of "xgrammar", but in "auto" # we will handle fallback automatically. 
outputs = llm.generate(prompts=prompts, @@ -566,7 +581,8 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): prompt = ( "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a " "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a " - "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20" + "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20. " + "Make the response as short as possible." "<|im_end|>\n<|im_start|>assistant\n") def generate_with_backend(backend): From c95fcec68c08ba782ebbc51e56f813d23f3810c1 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 7 May 2025 10:46:50 -0700 Subject: [PATCH 6/6] Fix serialization of non-int slice elements Signed-off-by: Nick Hill --- vllm/v1/serial_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 7b3212c109c9..6989ad9a40a4 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -94,7 +94,10 @@ def enc_hook(self, obj: Any) -> Any: return self._encode_ndarray(obj) if isinstance(obj, slice): - return obj.start, obj.stop, obj.step + # We are assuming only int-based values will be used here. + return tuple( + int(v) if v is not None else None + for v in (obj.start, obj.stop, obj.step)) if isinstance(obj, MultiModalKwargs): mm: MultiModalKwargs = obj