@@ -26,7 +26,7 @@
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer import GradCkptCollection, ShardConfig, ShardFormer
+from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
@@ -930,7 +930,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
         pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
         num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
-        gradient_ckpt_collection (GradCkptCollection, optional): The configuration for gradient checkpointing. Defaults to None.
+        gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
         enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
     """

@@ -970,7 +970,7 @@ def __init__(
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
-        gradient_ckpt_collection: Optional[GradCkptCollection] = None,
+        gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
     ) -> None:
         super().__init__()
@@ -1045,7 +1045,7 @@ def __init__(
             enable_sequence_parallelism=enable_sequence_parallelism,
             enable_sequence_overlap=enable_sequence_overlap,
             parallel_output=parallel_output,
-            gradient_ckpt_collection=gradient_ckpt_collection,
+            gradient_checkpoint_config=gradient_checkpoint_config,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
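For callers, the visible effect of this rename is that the plugin now accepts a GradientCheckpointConfig through the gradient_checkpoint_config keyword. The sketch below shows one way this might be wired up after the change; the no-argument GradientCheckpointConfig() construction and the tp_size/pp_size values are illustrative assumptions not shown in this diff, so the actual fields should be checked against the colossalai.shardformer definition.

# Minimal sketch, assuming the renamed keyword from this diff.
# GradientCheckpointConfig() with default fields is an assumption; its
# constructor is not shown here and may expose tuning options.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer import GradientCheckpointConfig

plugin = HybridParallelPlugin(
    tp_size=2,                # example parallel sizes, not taken from this diff
    pp_size=2,
    pp_style="1f1b",
    num_model_chunks=1,
    # Renamed from gradient_ckpt_collection; passing None keeps the old default behavior.
    gradient_checkpoint_config=GradientCheckpointConfig(),
)
booster = Booster(plugin=plugin)

Because the old gradient_ckpt_collection keyword is removed rather than aliased, any existing call sites passing it by name would need the same rename.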