|
4 | 4 | from typing import Dict, List, Tuple |
5 | 5 |
|
6 | 6 | import torch |
| 7 | +import torch.distributed as dist |
7 | 8 |
|
8 | 9 | from colossalai.context.singleton_meta import SingletonMeta |
9 | 10 | from colossalai.tensor.d_tensor.comm_spec import * |
@@ -438,11 +439,58 @@ def layout_converting( |
438 | 439 | MAX_TRANSFORM_STEPS = 20 |
439 | 440 | total_steps = 0 |
440 | 441 | transform_path = [] |
441 | | - comm_action_sequence = [] |
| 442 | + comm_action_sequence: List[CommSpec] = [] |
442 | 443 | spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence)) |
443 | 444 |
|
444 | 445 | if spec_pairs in self.cached_solution: |
445 | | - return self.cached_solution[spec_pairs] |
| 446 | + # Solution Cache hit |
| 447 | + |
| 448 | + def _group_alive_check(cached_comm_action_sequence): |
| 449 | + r""" |
| 450 | + Check if the process groups required for sharding have been deleted by torch.distributed.destroy_process_group method. |
| 451 | + If not deleted, return True; otherwise, return False. |
| 452 | +
|
| 453 | + Args: |
| 454 | + cached_comm_action_sequence (List[CommSpec]): A list of communication specifications representing actions. |
| 455 | +
|
| 456 | + Returns: |
| 457 | + bool: True if all process groups are still registered, False if at least one has been deleted. |
| 458 | +
|
| 459 | + Raises: |
| 460 | + RuntimeError: If there is an error while checking the status of a process group. |
| 461 | + """ |
| 462 | + |
| 463 | + # Collect all process groups used in communication actions from the cached sequence |
| 464 | + used_process_groups = [ |
| 465 | + pg for comm_spec in cached_comm_action_sequence for pg in comm_spec.process_group_dict.values() |
| 466 | + ] |
| 467 | + |
| 468 | + # Check if each process group is still alive |
| 469 | + for process_group in used_process_groups: |
| 470 | + try: |
| 471 | + dist.get_rank(process_group) |
| 472 | + except RuntimeError as e: |
| 473 | + # If the group is not registered, it means it has been deleted |
| 474 | + if str(e) == ( |
| 475 | + f"Group {process_group} is not registered, please create group with torch.distributed.new_group API" |
| 476 | + ): |
| 477 | + return False |
| 478 | + elif str(e) == "The given group does not exist": |
| 479 | + return False |
| 480 | + else: |
| 481 | + # Re-raise the exception if it's not related to group deletion |
| 482 | + raise e |
| 483 | + # All process groups are alive |
| 484 | + return True |
| 485 | + |
| 486 | + cached_transform_path, cached_comm_action_sequence = self.cached_solution[spec_pairs] |
| 487 | + |
| 488 | + if _group_alive_check(cached_comm_action_sequence): |
| 489 | + # If all process groups have not been deleted, the cache is valid |
| 490 | + return cached_transform_path, cached_comm_action_sequence |
| 491 | + else: |
| 492 | + # If at least one process group has been deleted, the cache is invalid, so delete it |
| 493 | + del self.cached_solution[spec_pairs] |
446 | 494 |
|
447 | 495 | # We do nothing if the sharding spec is all the same. |
448 | 496 | if source_spec.spec_diff(target_spec) == 0: |
|
0 commit comments