|  | 
| 4 | 4 | from typing import Dict, List, Tuple | 
| 5 | 5 | 
 | 
| 6 | 6 | import torch | 
|  | 7 | +import torch.distributed as dist | 
| 7 | 8 | 
 | 
| 8 | 9 | from colossalai.context.singleton_meta import SingletonMeta | 
| 9 | 10 | from colossalai.tensor.d_tensor.comm_spec import * | 
| @@ -438,11 +439,58 @@ def layout_converting( | 
| 438 | 439 |         MAX_TRANSFORM_STEPS = 20 | 
| 439 | 440 |         total_steps = 0 | 
| 440 | 441 |         transform_path = [] | 
| 441 |  | -        comm_action_sequence = [] | 
|  | 442 | +        comm_action_sequence: List[CommSpec] = [] | 
| 442 | 443 |         spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence)) | 
| 443 | 444 | 
 | 
| 444 | 445 |         if spec_pairs in self.cached_solution: | 
| 445 |  | -            return self.cached_solution[spec_pairs] | 
|  | 446 | +            # Solution Cache hit | 
|  | 447 | + | 
def _group_alive_check(cached_comm_action_sequence):
    r"""
    Check whether every process group used by a cached solution is still registered.

    A process group may have been removed by ``torch.distributed.destroy_process_group``
    after the solution was cached; such a solution must not be reused.

    Args:
        cached_comm_action_sequence (List[CommSpec]): Cached communication specs; the
            values of each spec's ``process_group_dict`` are the groups that get probed.

    Returns:
        bool: True if all process groups are still registered (also for an empty
            sequence), False as soon as one deleted group is found.

    Raises:
        RuntimeError, ValueError: If probing a process group fails for a reason other
            than the group having been deleted.
    """
    for comm_spec in cached_comm_action_sequence:
        for process_group in comm_spec.process_group_dict.values():
            try:
                dist.get_rank(process_group)
            except (RuntimeError, ValueError) as e:
                # Depending on the PyTorch version, an unregistered group surfaces as
                # RuntimeError or ValueError, so both are caught and the message decides.
                group_deleted_messages = (
                    f"Group {process_group} is not registered, please create group with torch.distributed.new_group API",
                    "The given group does not exist",
                )
                if str(e) in group_deleted_messages:
                    return False
                # Not related to group deletion: propagate with the original traceback.
                raise
    # All process groups are alive
    return True
|  | 485 | + | 
|  | 486 | +            cached_transform_path, cached_comm_action_sequence = self.cached_solution[spec_pairs] | 
|  | 487 | + | 
|  | 488 | +            if _group_alive_check(cached_comm_action_sequence): | 
|  | 489 | +                # If all process groups have not been deleted, the cache is valid | 
|  | 490 | +                return cached_transform_path, cached_comm_action_sequence | 
|  | 491 | +            else: | 
|  | 492 | +                # If at least one process group has been deleted, the cache is invalid, so delete it | 
|  | 493 | +                del self.cached_solution[spec_pairs] | 
| 446 | 494 | 
 | 
| 447 | 495 |         # We do nothing if the sharding spec is all the same. | 
| 448 | 496 |         if source_spec.spec_diff(target_spec) == 0: | 
|  | 
0 commit comments