From af0e16c70572c222e747a64f637b7c795f884334 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 13 Jun 2024 12:10:11 -0700 Subject: [PATCH 1/2] fix benign error --- vllm/distributed/parallel_state.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index f6a2fc9b05a8..d72dd7c2a4e9 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -23,11 +23,12 @@ from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass -from multiprocessing import resource_tracker, shared_memory +from multiprocessing import shared_memory from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch.distributed import Backend, ProcessGroup +from unittest.mock import patch import vllm.envs as envs from vllm.logger import init_logger @@ -744,7 +745,12 @@ def is_in_the_same_node(pg: ProcessGroup): src=ranks[0], group=pg) name = recv[0] - shm = shared_memory.SharedMemory(name=name) + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. + with patch("multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None): + shm = shared_memory.SharedMemory(name=name) if shm.buf[:len(magic_message)] == magic_message: is_in_the_same_node[rank] = 1 except Exception as e: @@ -757,14 +763,8 @@ def is_in_the_same_node(pg: ProcessGroup): # clean up the shared memory segment with contextlib.suppress(OSError): - if rank == 0: - if shm: + if rank == 0 and shm: shm.unlink() - else: - if shm: - # fix to https://stackoverflow.com/q/62748654/9191338 - resource_tracker.unregister( - shm._name, "shared_memory") # type: ignore[attr-defined] torch.distributed.all_reduce(is_in_the_same_node, group=pg) return is_in_the_same_node.sum().item() == world_size From 74bf481115e3b0fe087aabc2811f87a73b88745b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 13 Jun 2024 12:13:35 -0700 Subject: [PATCH 2/2] fix format --- vllm/distributed/parallel_state.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index d72dd7c2a4e9..16c5297af1b5 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -25,10 +25,10 @@ from dataclasses import dataclass from multiprocessing import shared_memory from typing import Any, Dict, List, Optional, Tuple, Union +from unittest.mock import patch import torch from torch.distributed import Backend, ProcessGroup -from unittest.mock import patch import vllm.envs as envs from vllm.logger import init_logger @@ -746,7 +746,7 @@ def is_in_the_same_node(pg: ProcessGroup): group=pg) name = recv[0] # fix to https://stackoverflow.com/q/62748654/9191338 - # Python incorrectly tracks shared memory even if it is not + # Python incorrectly tracks shared memory even if it is not # created by the process. The following patch is a workaround. with patch("multiprocessing.resource_tracker.register", lambda *args, **kwargs: None): @@ -764,7 +764,7 @@ def is_in_the_same_node(pg: ProcessGroup): # clean up the shared memory segment with contextlib.suppress(OSError): if rank == 0 and shm: - shm.unlink() + shm.unlink() torch.distributed.all_reduce(is_in_the_same_node, group=pg) return is_in_the_same_node.sum().item() == world_size