Skip to content

Commit 4992c1d

Browse files
njhillpaulpak58
authored andcommitted
[BugFix] Fix case where collective_rpc returns None (vllm-project#22006)
Signed-off-by: Nick Hill <[email protected]> Signed-off-by: Paul Pak <[email protected]>
1 parent 3e194d6 commit 4992c1d

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

tests/v1/engine/test_engine_core_client.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,10 +305,10 @@ def echo_dc(
305305
return_list: bool = False,
306306
) -> Union[MyDataclass, list[MyDataclass]]:
307307
print(f"echo dc util function called: {msg}")
308+
val = None if msg is None else MyDataclass(msg)
308309
# Return dataclass to verify support for returning custom types
309310
# (for which there is special handling to make it work with msgspec).
310-
return [MyDataclass(msg) for _ in range(3)] if return_list \
311-
else MyDataclass(msg)
311+
return [val for _ in range(3)] if return_list else val
312312

313313

314314
@pytest.mark.asyncio(loop_scope="function")
@@ -351,6 +351,15 @@ async def test_engine_core_client_util_method_custom_return(
351351
assert isinstance(result, list) and all(
352352
isinstance(r, MyDataclass) and r.message == "testarg2"
353353
for r in result)
354+
355+
# Test returning None and list of Nones
356+
result = await core_client.call_utility_async(
357+
"echo_dc", None, False)
358+
assert result is None
359+
result = await core_client.call_utility_async(
360+
"echo_dc", None, True)
361+
assert isinstance(result, list) and all(r is None for r in result)
362+
354363
finally:
355364
client.shutdown()
356365

vllm/v1/serial_utils.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ def _log_insecure_serialization_warning():
4949
"VLLM_ALLOW_INSECURE_SERIALIZATION=1")
5050

5151

52-
def _typestr(t: type):
52+
def _typestr(val: Any) -> Optional[tuple[str, str]]:
53+
if val is None:
54+
return None
55+
t = type(val)
5356
return t.__module__, t.__qualname__
5457

5558

@@ -131,14 +134,13 @@ def enc_hook(self, obj: Any) -> Any:
131134

132135
if isinstance(obj, UtilityResult):
133136
result = obj.result
134-
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION or result is None:
137+
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
135138
return None, result
136139
# Since utility results are not strongly typed, we also encode
137140
# the type (or a list of types in the case it's a list) to
138141
# help with correct msgspec deserialization.
139-
cls = result.__class__
140-
return _typestr(cls) if cls is not list else [
141-
_typestr(type(v)) for v in result
142+
return _typestr(result) if type(result) is not list else [
143+
_typestr(v) for v in result
142144
], result
143145

144146
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
@@ -277,7 +279,9 @@ def _decode_utility_result(self, obj: Any) -> UtilityResult:
277279
]
278280
return UtilityResult(result)
279281

280-
def _convert_result(self, result_type: Sequence[str], result: Any):
282+
def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
283+
if result_type is None:
284+
return result
281285
mod_name, name = result_type
282286
mod = importlib.import_module(mod_name)
283287
result_type = getattr(mod, name)

0 commit comments

Comments
 (0)