Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions tests/v1/engine/test_engine_core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,10 @@ def echo_dc(
return_list: bool = False,
) -> Union[MyDataclass, list[MyDataclass]]:
print(f"echo dc util function called: {msg}")
val = None if msg is None else MyDataclass(msg)
# Return dataclass to verify support for returning custom types
# (for which there is special handling to make it work with msgspec).
return [MyDataclass(msg) for _ in range(3)] if return_list \
else MyDataclass(msg)
return [val for _ in range(3)] if return_list else val


@pytest.mark.asyncio(loop_scope="function")
Expand Down Expand Up @@ -351,6 +351,15 @@ async def test_engine_core_client_util_method_custom_return(
assert isinstance(result, list) and all(
isinstance(r, MyDataclass) and r.message == "testarg2"
for r in result)

# Test returning None and list of Nones
result = await core_client.call_utility_async(
"echo_dc", None, False)
assert result is None
result = await core_client.call_utility_async(
"echo_dc", None, True)
assert isinstance(result, list) and all(r is None for r in result)

finally:
client.shutdown()

Expand Down
16 changes: 10 additions & 6 deletions vllm/v1/serial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def _log_insecure_serialization_warning():
"VLLM_ALLOW_INSECURE_SERIALIZATION=1")


def _typestr(t: type):
def _typestr(val: Any) -> Optional[tuple[str, str]]:
if val is None:
return None
t = type(val)
return t.__module__, t.__qualname__


Expand Down Expand Up @@ -131,14 +134,13 @@ def enc_hook(self, obj: Any) -> Any:

if isinstance(obj, UtilityResult):
result = obj.result
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION or result is None:
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
return None, result
# Since utility results are not strongly typed, we also encode
# the type (or a list of types in the case it's a list) to
# help with correct msgspec deserialization.
cls = result.__class__
return _typestr(cls) if cls is not list else [
_typestr(type(v)) for v in result
return _typestr(result) if type(result) is not list else [
_typestr(v) for v in result
], result

if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
Expand Down Expand Up @@ -277,7 +279,9 @@ def _decode_utility_result(self, obj: Any) -> UtilityResult:
]
return UtilityResult(result)

def _convert_result(self, result_type: Sequence[str], result: Any):
def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
if result_type is None:
return result
mod_name, name = result_type
mod = importlib.import_module(mod_name)
result_type = getattr(mod, name)
Expand Down