Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions colossalai/cluster/process_group_mesh.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import gc
import itertools
from functools import reduce
from operator import mul
Expand Down Expand Up @@ -44,6 +45,24 @@ def __init__(self, *size: int) -> None:
self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}

def __del__(self):
r"""
Destructor method for the ProcessGroupMesh class.

When the ProcessGroupMesh object is deleted or goes out of scope, this method is called. It is responsible for
cleaning up any process groups that were created during the lifetime of the object.

Note:
All process groups in PyTorch are represented as global variables, and they may not be automatically destroyed
when the ProcessGroupMesh's lifetime ends. This method manually destroys the process groups to release
system resources.
"""
for group in self._ranks_to_group.values():
dist.destroy_process_group(group)

# Manually clear all process groups to save memory
gc.collect()

@property
def shape(self) -> Tuple[int, ...]:
return self._shape
Expand Down
52 changes: 50 additions & 2 deletions colossalai/tensor/d_tensor/layout_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict, List, Tuple

import torch
import torch.distributed as dist

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.d_tensor.comm_spec import *
Expand Down Expand Up @@ -438,11 +439,58 @@ def layout_converting(
MAX_TRANSFORM_STEPS = 20
total_steps = 0
transform_path = []
comm_action_sequence = []
comm_action_sequence: List[CommSpec] = []
spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence))

if spec_pairs in self.cached_solution:
return self.cached_solution[spec_pairs]
# Solution Cache hit

def _group_alive_check(cached_comm_action_sequence):
r"""
Check if the process groups required for sharding have been deleted by torch.distributed.destroy_process_group method.
If not deleted, return True; otherwise, return False.

Args:
cached_comm_action_sequence (List[CommSpec]): A list of communication specifications representing actions.

Returns:
bool: True if all process groups are still registered, False if at least one has been deleted.

Raises:
RuntimeError: If there is an error while checking the status of a process group.
"""

# Collect all process groups used in communication actions from the cached sequence
used_process_groups = [
pg for comm_spec in cached_comm_action_sequence for pg in comm_spec.process_group_dict.values()
]

# Check if each process group is still alive
for process_group in used_process_groups:
try:
dist.get_rank(process_group)
except RuntimeError as e:
# If the group is not registered, it means it has been deleted
if str(e) == (
f"Group {process_group} is not registered, please create group with torch.distributed.new_group API"
):
return False
elif str(e) == "The given group does not exist":
return False
else:
# Re-raise the exception if it's not related to group deletion
raise e
# All process groups are alive
return True

cached_transform_path, cached_comm_action_sequence = self.cached_solution[spec_pairs]

if _group_alive_check(cached_comm_action_sequence):
# If all process groups have not been deleted, the cache is valid
return cached_transform_path, cached_comm_action_sequence
else:
# If at least one process group has been deleted, the cache is invalid, so delete it
del self.cached_solution[spec_pairs]

# We do nothing if the sharding spec is all the same.
if source_spec.spec_diff(target_spec) == 0:
Expand Down