diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py
index c3ea024f58cb..81601c87ad8b 100644
--- a/tests/v1/entrypoints/llm/test_struct_output_generate.py
+++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py
@@ -105,8 +105,9 @@ def test_structured_output(
         max_tokens=1000,
         guided_decoding=GuidedDecodingParams(json=sample_json_schema))
     outputs = llm.generate(prompts=[
-        f"Give an example JSON for an employee profile "
-        f"that fits this schema: {sample_json_schema}"
+        (f"Give an example JSON for an employee profile that fits this "
+         f"schema. Make the response as short as possible. Schema: "
+         f"{sample_json_schema}")
     ] * 2,
                            sampling_params=sampling_params,
                            use_tqdm=True)
@@ -136,7 +137,8 @@ def test_structured_output(
 
     outputs = llm.generate(
         prompts=("Generate a JSON object with curly braces for a person with "
-                 "name and age fields for John Smith who is 31 years old."),
+                 "name and age fields for John Smith who is 31 years old. "
+                 "Make the response as short as possible."),
         sampling_params=sampling_params,
         use_tqdm=True)
@@ -165,19 +167,20 @@ def test_structured_output(
         with pytest.raises(ValueError,
                            match="The provided JSON schema contains features "
                            "not supported by xgrammar."):
-            llm.generate(prompts=[
-                f"Give an example JSON for an employee profile "
-                f"that fits this schema: {unsupported_json_schema}"
-            ] * 2,
-                         sampling_params=sampling_params,
-                         use_tqdm=True)
+            llm.generate(
+                prompts=[(f"Give an example JSON for an employee profile that "
+                          f"fits this schema: {unsupported_json_schema}. "
+                          f"Make the response as short as possible.")] * 2,
+                sampling_params=sampling_params,
+                use_tqdm=True)
     else:
-        outputs = llm.generate(
-            prompts=("Give an example JSON object for a grade "
-                     "that fits this schema: "
-                     f"{unsupported_json_schema}"),
-            sampling_params=sampling_params,
-            use_tqdm=True)
+        outputs = llm.generate(prompts=(
+            "Give an example JSON object for a grade "
+            "that fits this schema: "
+            f"{unsupported_json_schema}. Make the response as short as "
+            "possible."),
+                               sampling_params=sampling_params,
+                               use_tqdm=True)
         assert outputs is not None
         for output in outputs:
             assert output is not None
@@ -199,8 +202,10 @@ def test_structured_output(
         max_tokens=1000,
         guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
     outputs = llm.generate(
-        prompts=("Generate a sql statement that selects col_1 from "
-                 "table_1 where it is equal to 1"),
+        prompts=(
+            "Generate a sql statement that selects col_1 from "
+            "table_1 where it is equal to 1. Make the response as short as "
+            "possible."),
         sampling_params=sampling_params,
         use_tqdm=True,
     )
@@ -231,8 +236,10 @@ def test_structured_output(
         max_tokens=1000,
         guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
     outputs = llm.generate(
-        prompts=("Generate a sql statement that selects col_1 from "
-                 "table_1 where it is equal to 1"),
+        prompts=(
+            "Generate a sql statement that selects col_1 from "
+            "table_1 where it is equal to 1. Make the response as short as "
+            "possible."),
         sampling_params=sampling_params,
         use_tqdm=True,
     )
@@ -269,8 +276,10 @@ def test_structured_output(
         guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
     with pytest.raises(ValueError, match="Failed to convert the grammar "):
         llm.generate(
-            prompts=("Generate a sql statement that selects col_1 from "
-                     "table_1 where it is equal to 1"),
+            prompts=(
+                "Generate a sql statement that selects col_1 from "
+                "table_1 where it is equal to 1. Make the response as short "
+                "as possible."),
             sampling_params=sampling_params,
             use_tqdm=True,
         )
@@ -284,7 +293,8 @@ def test_structured_output(
         guided_decoding=GuidedDecodingParams(regex=sample_regex))
     outputs = llm.generate(
         prompts=[
-            f"Give an example IPv4 address with this regex: {sample_regex}"
+            (f"Give an example IPv4 address with this regex: {sample_regex}. "
+             f"Make the response as short as possible.")
         ] * 2,
         sampling_params=sampling_params,
         use_tqdm=True,
     )
@@ -309,7 +319,8 @@ def test_structured_output(
         top_p=0.95,
         guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
     outputs = llm.generate(
-        prompts="The best language for type-safe systems programming is ",
+        prompts=("The best language for type-safe systems programming is "
+                 "(Make the response as short as possible.) "),
         sampling_params=sampling_params,
         use_tqdm=True)
     assert outputs is not None
@@ -331,11 +342,12 @@ def test_structured_output(
         temperature=1.0,
         max_tokens=1000,
         guided_decoding=GuidedDecodingParams(json=json_schema))
-    outputs = llm.generate(
-        prompts="Generate a JSON with the brand, model and car_type of"
-        "the most iconic car from the 90's",
-        sampling_params=sampling_params,
-        use_tqdm=True)
+    outputs = llm.generate(prompts=(
+        "Generate a JSON with the brand, model and car_type of the most "
+        "iconic car from the 90's. Make the response as short as "
+        "possible."),
+                           sampling_params=sampling_params,
+                           use_tqdm=True)
 
     assert outputs is not None
@@ -373,7 +385,8 @@ def test_structured_output(
         guided_decoding=GuidedDecodingParams(json=json_schema))
 
     outputs = llm.generate(
-        prompts="Generate a description of a frog using 50 characters.",
+        prompts=("Generate a description of a frog using 50 characters. "
+                 "Make the response as short as possible."),
         sampling_params=sampling_params,
         use_tqdm=True)
@@ -452,7 +465,8 @@ def test_structured_output(
 
 You are a helpful assistant.
 
-Given the previous instructions, what is the weather in New York City?
+Given the previous instructions, what is the weather in New York City? \
+Make the response as short as possible.
 """
 
     # Change this once other backends support structural_tag
@@ -509,9 +523,10 @@ def test_structured_output_auto_mode(
         max_tokens=1000,
         guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
 
-    prompts = ("Give an example JSON object for a grade "
-               "that fits this schema: "
-               f"{unsupported_json_schema}")
+    prompts = (
+        "Give an example JSON object for a grade "
+        "that fits this schema: "
+        f"{unsupported_json_schema}. Make the response as short as possible.")
     # This would fail with the default of "xgrammar", but in "auto"
     # we will handle fallback automatically.
     outputs = llm.generate(prompts=prompts,
@@ -566,7 +581,8 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
     prompt = (
         "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
         "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
-        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
+        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20. "
+        "Make the response as short as possible."
"<|im_end|>\n<|im_start|>assistant\n") def generate_with_backend(backend): diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index d1271b210ad8..ee490071f6a2 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -9,8 +9,8 @@ import torch from vllm.multimodal.inputs import (MultiModalBatchedField, - MultiModalFieldElem, MultiModalKwargs, - MultiModalKwargsItem, + MultiModalFieldElem, MultiModalFlatField, + MultiModalKwargs, MultiModalKwargsItem, MultiModalSharedField, NestedTensors) from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -36,59 +36,62 @@ class MyType: empty_tensor: torch.Tensor -def test_encode_decode(): +def test_encode_decode(monkeypatch: pytest.MonkeyPatch): """Test encode/decode loop with zero-copy tensors.""" - obj = MyType( - tensor1=torch.randint(low=0, - high=100, - size=(1024, ), - dtype=torch.int32), - a_string="hello", - list_of_tensors=[ - torch.rand((1, 10), dtype=torch.float32), - torch.rand((3, 5, 4000), dtype=torch.float64), - torch.tensor(1984), # test scalar too - # Make sure to test bf16 which numpy doesn't support. - torch.rand((3, 5, 1000), dtype=torch.bfloat16), - torch.tensor([float("-inf"), float("inf")] * 1024, - dtype=torch.bfloat16), - ], - numpy_array=np.arange(512), - unrecognized=UnrecognizedType(33), - small_f_contig_tensor=torch.rand(5, 4).t(), - large_f_contig_tensor=torch.rand(1024, 4).t(), - small_non_contig_tensor=torch.rand(2, 4)[:, 1:3], - large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20], - empty_tensor=torch.empty(0), - ) + with monkeypatch.context() as m: + m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - encoder = MsgpackEncoder(size_threshold=256) - decoder = MsgpackDecoder(MyType) + obj = MyType( + tensor1=torch.randint(low=0, + high=100, + size=(1024, ), + dtype=torch.int32), + a_string="hello", + list_of_tensors=[ + torch.rand((1, 10), dtype=torch.float32), + torch.rand((3, 5, 4000), dtype=torch.float64), + torch.tensor(1984), # test scalar too + # Make sure to test bf16 which numpy doesn't support. + torch.rand((3, 5, 1000), dtype=torch.bfloat16), + torch.tensor([float("-inf"), float("inf")] * 1024, + dtype=torch.bfloat16), + ], + numpy_array=np.arange(512), + unrecognized=UnrecognizedType(33), + small_f_contig_tensor=torch.rand(5, 4).t(), + large_f_contig_tensor=torch.rand(1024, 4).t(), + small_non_contig_tensor=torch.rand(2, 4)[:, 1:3], + large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20], + empty_tensor=torch.empty(0), + ) - encoded = encoder.encode(obj) + encoder = MsgpackEncoder(size_threshold=256) + decoder = MsgpackDecoder(MyType) - # There should be the main buffer + 4 large tensor buffers - # + 1 large numpy array. "large" is <= 512 bytes. - # The two small tensors are encoded inline. - assert len(encoded) == 8 + encoded = encoder.encode(obj) + + # There should be the main buffer + 4 large tensor buffers + # + 1 large numpy array. "large" is <= 512 bytes. + # The two small tensors are encoded inline. 
+        assert len(encoded) == 8
 
-    decoded: MyType = decoder.decode(encoded)
+        decoded: MyType = decoder.decode(encoded)
 
-    assert_equal(decoded, obj)
+        assert_equal(decoded, obj)
 
-    # Test encode_into case
+        # Test encode_into case
 
-    preallocated = bytearray()
+        preallocated = bytearray()
 
-    encoded2 = encoder.encode_into(obj, preallocated)
+        encoded2 = encoder.encode_into(obj, preallocated)
 
-    assert len(encoded2) == 8
-    assert encoded2[0] is preallocated
+        assert len(encoded2) == 8
+        assert encoded2[0] is preallocated
 
-    decoded2: MyType = decoder.decode(encoded2)
+        decoded2: MyType = decoder.decode(encoded2)
 
-    assert_equal(decoded2, obj)
+        assert_equal(decoded2, obj)
 
 
 class MyRequest(msgspec.Struct):
@@ -122,7 +125,7 @@ def test_multimodal_kwargs():
     total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
 
     # expected total encoding length, should be 44559, +-20 for minor changes
-    assert total_len >= 44539 and total_len <= 44579
+    assert 44539 <= total_len <= 44579
 
     decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
     assert all(nested_equal(d[k], decoded[k]) for k in d)
@@ -135,14 +138,15 @@ def test_multimodal_items_by_modality():
         "video",
         "v0",
         [torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
-        MultiModalBatchedField(),
+        MultiModalFlatField(
+            [[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0),
     )
     e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000,
                                                         dtype=torch.int32),
                              MultiModalSharedField(4))
-    e4 = MultiModalFieldElem("image", "i1", torch.zeros(1000,
-                                                        dtype=torch.int32),
-                             MultiModalBatchedField())
+    e4 = MultiModalFieldElem(
+        "image", "i1", torch.zeros(1000, dtype=torch.int32),
+        MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2))
     audio = MultiModalKwargsItem.from_elems([e1])
     video = MultiModalKwargsItem.from_elems([e2])
     image = MultiModalKwargsItem.from_elems([e3, e4])
@@ -161,7 +165,7 @@ def test_multimodal_items_by_modality():
     total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
 
-    # expected total encoding length, should be 14255, +-20 for minor changes
-    assert total_len >= 14235 and total_len <= 14275
+    # expected total encoding length, should be ~14275, +-25 for minor changes
+    assert 14250 <= total_len <= 14300
 
     decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
 
     # check all modalities were recovered and do some basic sanity checks
@@ -178,8 +182,7 @@ def test_multimodal_items_by_modality():
 def nested_equal(a: NestedTensors, b: NestedTensors):
     if isinstance(a, torch.Tensor):
         return torch.equal(a, b)
-    else:
-        return all(nested_equal(x, y) for x, y in zip(a, b))
+    return all(nested_equal(x, y) for x, y in zip(a, b))
 
 
 def assert_equal(obj1: MyType, obj2: MyType):
@@ -199,11 +202,10 @@ def assert_equal(obj1: MyType, obj2: MyType):
     assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
 
 
-@pytest.mark.parametrize("allow_pickle", [True, False])
-def test_dict_serialization(allow_pickle: bool):
+def test_dict_serialization():
     """Test encoding and decoding of a generic Python object using pickle."""
-    encoder = MsgpackEncoder(allow_pickle=allow_pickle)
-    decoder = MsgpackDecoder(allow_pickle=allow_pickle)
+    encoder = MsgpackEncoder()
+    decoder = MsgpackDecoder()
 
     # Create a sample Python object
     obj = {"key": "value", "number": 42}
@@ -218,11 +220,10 @@ def test_dict_serialization():
     assert obj == decoded, "Decoded object does not match the original object."
 
 
-@pytest.mark.parametrize("allow_pickle", [True, False])
-def test_tensor_serialization(allow_pickle: bool):
+def test_tensor_serialization():
     """Test encoding and decoding of a torch.Tensor."""
-    encoder = MsgpackEncoder(allow_pickle=allow_pickle)
-    decoder = MsgpackDecoder(torch.Tensor, allow_pickle=allow_pickle)
+    encoder = MsgpackEncoder()
+    decoder = MsgpackDecoder(torch.Tensor)
 
     # Create a sample tensor
     tensor = torch.rand(10, 10)
@@ -238,11 +239,10 @@ def test_tensor_serialization():
         tensor, decoded), "Decoded tensor does not match the original tensor."
 
 
-@pytest.mark.parametrize("allow_pickle", [True, False])
-def test_numpy_array_serialization(allow_pickle: bool):
+def test_numpy_array_serialization():
     """Test encoding and decoding of a numpy array."""
-    encoder = MsgpackEncoder(allow_pickle=allow_pickle)
-    decoder = MsgpackDecoder(np.ndarray, allow_pickle=allow_pickle)
+    encoder = MsgpackEncoder()
+    decoder = MsgpackDecoder(np.ndarray)
 
     # Create a sample numpy array
     array = np.random.rand(10, 10)
@@ -268,26 +268,31 @@ def __eq__(self, other):
         return isinstance(other, CustomClass) and self.value == other.value
 
 
-def test_custom_class_serialization_allowed_with_pickle():
+def test_custom_class_serialization_allowed_with_pickle(
+        monkeypatch: pytest.MonkeyPatch):
     """Test that serializing a custom class succeeds when allow_pickle=True."""
-    encoder = MsgpackEncoder(allow_pickle=True)
-    decoder = MsgpackDecoder(CustomClass, allow_pickle=True)
-    obj = CustomClass("test_value")
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+        encoder = MsgpackEncoder()
+        decoder = MsgpackDecoder(CustomClass)
 
-    # Encode the custom class
-    encoded = encoder.encode(obj)
+        obj = CustomClass("test_value")
 
-    # Decode the custom class
-    decoded = decoder.decode(encoded)
+        # Encode the custom class
+        encoded = encoder.encode(obj)
 
-    # Verify the decoded object matches the original
-    assert obj == decoded, "Decoded object does not match the original object."
+        # Decode the custom class
+        decoded = decoder.decode(encoded)
+
+        # Verify the decoded object matches the original
+        assert obj == decoded, (
+            "Decoded object does not match the original object.")
 
 
 def test_custom_class_serialization_disallowed_without_pickle():
     """Test that serializing a custom class fails when allow_pickle=False."""
-    encoder = MsgpackEncoder(allow_pickle=False)
+    encoder = MsgpackEncoder()
 
     obj = CustomClass("test_value")
diff --git a/vllm/envs.py b/vllm/envs.py
index ea40bfff11b5..0e4b3a9893d0 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -110,6 +110,7 @@
     VLLM_USE_DEEP_GEMM: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
+    VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
 
 
 def get_default_cache_root():
@@ -727,6 +728,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     # limit will actually be zero-copy decoded.
     "VLLM_MSGPACK_ZERO_COPY_THRESHOLD":
     lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")),
+
+    # If set, allow insecure serialization using pickle.
+    # This is useful for environments where it is deemed safe to use the
+    # insecure method and it is needed for some reason.
+ "VLLM_ALLOW_INSECURE_SERIALIZATION": + lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))), } # end-env-vars-definition diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index e00ecde66af0..6989ad9a40a4 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -14,6 +14,7 @@ from msgspec import msgpack from vllm import envs +from vllm.logger import init_logger from vllm.multimodal.inputs import (BaseMultiModalField, MultiModalBatchedField, MultiModalFieldConfig, MultiModalFieldElem, @@ -21,6 +22,8 @@ MultiModalKwargsItem, MultiModalSharedField, NestedTensors) +logger = init_logger(__name__) + CUSTOM_TYPE_PICKLE = 1 CUSTOM_TYPE_CLOUDPICKLE = 2 CUSTOM_TYPE_RAW_VIEW = 3 @@ -47,9 +50,7 @@ class MsgpackEncoder: via dedicated messages. Note that this is a per-tensor limit. """ - def __init__(self, - size_threshold: Optional[int] = None, - allow_pickle: bool = True): + def __init__(self, size_threshold: Optional[int] = None): if size_threshold is None: size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) @@ -58,7 +59,10 @@ def __init__(self, # pass custom data to the hook otherwise. self.aux_buffers: Optional[list[bytestr]] = None self.size_threshold = size_threshold - self.allow_pickle = allow_pickle + if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: + logger.warning( + "Allowing insecure serialization using pickle due to " + "VLLM_ALLOW_INSECURE_SERIALIZATION=1") def encode(self, obj: Any) -> Sequence[bytestr]: try: @@ -89,6 +93,12 @@ def enc_hook(self, obj: Any) -> Any: if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'): return self._encode_ndarray(obj) + if isinstance(obj, slice): + # We are assuming only int-based values will be used here. + return tuple( + int(v) if v is not None else None + for v in (obj.start, obj.stop, obj.step)) + if isinstance(obj, MultiModalKwargs): mm: MultiModalKwargs = obj if not mm.modalities: @@ -108,7 +118,7 @@ def enc_hook(self, obj: Any) -> Any: for itemlist in mm._items_by_modality.values() for item in itemlist] - if not self.allow_pickle: + if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION: raise TypeError(f"Object of type {type(obj)} is not serializable") if isinstance(obj, FunctionType): @@ -185,13 +195,16 @@ class MsgpackDecoder: not thread-safe when encoding tensors / numpy arrays. 
""" - def __init__(self, t: Optional[Any] = None, allow_pickle: bool = True): + def __init__(self, t: Optional[Any] = None): args = () if t is None else (t, ) self.decoder = msgpack.Decoder(*args, ext_hook=self.ext_hook, dec_hook=self.dec_hook) self.aux_buffers: Sequence[bytestr] = () - self.allow_pickle = allow_pickle + if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: + logger.warning( + "Allowing insecure deserialization using pickle due to " + "VLLM_ALLOW_INSECURE_SERIALIZATION=1") def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any: if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)): @@ -212,6 +225,8 @@ def dec_hook(self, t: type, obj: Any) -> Any: return self._decode_ndarray(obj) if issubclass(t, torch.Tensor): return self._decode_tensor(obj) + if t is slice: + return slice(*obj) if issubclass(t, MultiModalKwargs): if isinstance(obj, list): return MultiModalKwargs.from_items( @@ -253,6 +268,12 @@ def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: factory_meth_name, *field_args = v["field"] factory_meth = getattr(MultiModalFieldConfig, factory_meth_name) + + # Special case: decode the union "slices" field of + # MultiModalFlatField + if factory_meth_name == "flat": + field_args[0] = self._decode_nested_slices(field_args[0]) + v["field"] = factory_meth(None, *field_args).field elems.append(MultiModalFieldElem(**v)) decoded_items.append(MultiModalKwargsItem.from_elems(elems)) @@ -269,11 +290,17 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors: return self._decode_tensor(obj) return [self._decode_nested_tensors(x) for x in obj] + def _decode_nested_slices(self, obj: Any) -> Any: + assert isinstance(obj, (list, tuple)) + if obj and not isinstance(obj[0], (list, tuple)): + return slice(*obj) + return [self._decode_nested_slices(x) for x in obj] + def ext_hook(self, code: int, data: memoryview) -> Any: if code == CUSTOM_TYPE_RAW_VIEW: return data - if self.allow_pickle: + if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: if code == CUSTOM_TYPE_PICKLE: return pickle.loads(data) if code == CUSTOM_TYPE_CLOUDPICKLE: