
Commit be82b5d

[hotfix] Fix the bug where process groups were not being properly released. (#4940)
* Fix the bug where process groups were not being properly released.
* test
* Revert "test"
  This reverts commit 479900c.
1 parent 4f0234f commit be82b5d

2 files changed: +69 -2 lines changed

colossalai/cluster/process_group_mesh.py

Lines changed: 19 additions & 0 deletions
@@ -1,3 +1,4 @@
+import gc
 import itertools
 from functools import reduce
 from operator import mul
@@ -44,6 +45,24 @@ def __init__(self, *size: int) -> None:
         self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
         self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
 
+    def __del__(self):
+        r"""
+        Destructor method for the ProcessGroupMesh class.
+
+        When the ProcessGroupMesh object is deleted or goes out of scope, this method is called. It is responsible for
+        cleaning up any process groups that were created during the lifetime of the object.
+
+        Note:
+            All process groups in PyTorch are represented as global variables, and they may not be automatically destroyed
+            when the ProcessGroupMesh's lifetime ends. This method manually destroys the process groups to release
+            system resources.
+        """
+        for group in self._ranks_to_group.values():
+            dist.destroy_process_group(group)
+
+        # Manually clear all process groups to save memory
+        gc.collect()
+
     @property
     def shape(self) -> Tuple[int, ...]:
         return self._shape
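
The __del__ hook added above follows a general PyTorch pattern: process groups created with torch.distributed.new_group are registered in the library's global state, so dropping the Python object that created them does not free them; they must be destroyed explicitly. Below is a minimal sketch of that pattern, not the ColossalAI implementation; the GroupOwner class, the single cached group, and the gloo backend are illustrative assumptions.

import gc

import torch.distributed as dist


class GroupOwner:
    """Owns process groups created via dist.new_group and releases them on deletion."""

    def __init__(self, ranks):
        # new_group registers the group in torch.distributed's global state;
        # losing the Python reference alone does not unregister it.
        self._groups = {tuple(ranks): dist.new_group(ranks=ranks)}

    def __del__(self):
        # Explicitly unregister each owned group, then prompt a GC pass so the
        # backing communicator objects can be reclaimed promptly.
        for group in self._groups.values():
            dist.destroy_process_group(group)
        gc.collect()


if __name__ == "__main__":
    # Assumes launch via torchrun so init_process_group can read rank and
    # world size from the environment.
    dist.init_process_group(backend="gloo")
    owner = GroupOwner(ranks=list(range(dist.get_world_size())))
    del owner  # triggers __del__ and releases the sub-group
    dist.destroy_process_group()  # tear down the default group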

colossalai/tensor/d_tensor/layout_converter.py

Lines changed: 50 additions & 2 deletions
@@ -4,6 +4,7 @@
 from typing import Dict, List, Tuple
 
 import torch
+import torch.distributed as dist
 
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.tensor.d_tensor.comm_spec import *
@@ -438,11 +439,58 @@ def layout_converting(
         MAX_TRANSFORM_STEPS = 20
         total_steps = 0
         transform_path = []
-        comm_action_sequence = []
+        comm_action_sequence: List[CommSpec] = []
         spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence))
 
         if spec_pairs in self.cached_solution:
-            return self.cached_solution[spec_pairs]
+            # Solution Cache hit
+
+            def _group_alive_check(cached_comm_action_sequence):
+                r"""
+                Check if the process groups required for sharding have been deleted by the
+                torch.distributed.destroy_process_group method. If not deleted, return True; otherwise, return False.
+
+                Args:
+                    cached_comm_action_sequence (List[CommSpec]): A list of communication specifications representing actions.
+
+                Returns:
+                    bool: True if all process groups are still registered, False if at least one has been deleted.
+
+                Raises:
+                    RuntimeError: If there is an error while checking the status of a process group.
+                """
+
+                # Collect all process groups used in communication actions from the cached sequence
+                used_process_groups = [
+                    pg for comm_spec in cached_comm_action_sequence for pg in comm_spec.process_group_dict.values()
+                ]
+
+                # Check if each process group is still alive
+                for process_group in used_process_groups:
+                    try:
+                        dist.get_rank(process_group)
+                    except RuntimeError as e:
+                        # If the group is not registered, it means it has been deleted
+                        if str(e) == (
+                            f"Group {process_group} is not registered, please create group with torch.distributed.new_group API"
+                        ):
+                            return False
+                        elif str(e) == "The given group does not exist":
+                            return False
+                        else:
+                            # Re-raise the exception if it's not related to group deletion
+                            raise e
+                # All process groups are alive
+                return True
+
+            cached_transform_path, cached_comm_action_sequence = self.cached_solution[spec_pairs]
+
+            if _group_alive_check(cached_comm_action_sequence):
+                # If all process groups have not been deleted, the cache is valid
+                return cached_transform_path, cached_comm_action_sequence
+            else:
+                # If at least one process group has been deleted, the cache is invalid, so delete it
+                del self.cached_solution[spec_pairs]
 
         # We do nothing if the sharding spec is all the same.
         if source_spec.spec_diff(target_spec) == 0:
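
The cache-validation logic above hinges on one behavior: querying the rank of a process group that has already been destroyed makes torch.distributed raise a RuntimeError, which is exactly what _group_alive_check catches. A standalone sketch of that probe follows; the helper name is_group_alive is illustrative, and unlike the commit it treats any RuntimeError as "group gone" rather than inspecting the error message.

import torch.distributed as dist


def is_group_alive(group) -> bool:
    """Return True if `group` is still registered with torch.distributed."""
    try:
        # Rank queries on a destroyed or unregistered group raise RuntimeError
        # (the commit above matches the specific error messages).
        dist.get_rank(group)
        return True
    except RuntimeError:
        return False


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # assumes launch via torchrun
    group = dist.new_group(ranks=list(range(dist.get_world_size())))
    print(is_group_alive(group))  # expected: True, the group is registered
    dist.destroy_process_group(group)
    print(is_group_alive(group))  # expected: False, the group was torn down
    dist.destroy_process_group()  # clean up the default group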
