
Commit 779c97e

fix: remove GradCkptCollection

1 parent e14c496 commit 779c97e

File tree: 6 files changed (+23, -52 lines)

colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 4 additions & 4 deletions
@@ -26,7 +26,7 @@
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer import GradCkptCollection, ShardConfig, ShardFormer
+from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
@@ -930,7 +930,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
         pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
         num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
-        gradient_ckpt_collection (GradCkptCollection, optional): The configuration for gradient checkpointing. Defaults to None.
+        gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
         enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
     """

@@ -970,7 +970,7 @@ def __init__(
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
-        gradient_ckpt_collection: Optional[GradCkptCollection] = None,
+        gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
     ) -> None:
         super().__init__()
@@ -1045,7 +1045,7 @@ def __init__(
             enable_sequence_parallelism=enable_sequence_parallelism,
             enable_sequence_overlap=enable_sequence_overlap,
             parallel_output=parallel_output,
-            gradient_ckpt_collection=gradient_ckpt_collection,
+            gradient_checkpoint_config=gradient_checkpoint_config,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
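Taken together, the hunks above mean callers now pass a single config object instead of wrapping it in a collection. A minimal usage sketch, assuming an already-initialized distributed setup; every argument other than gradient_checkpoint_config (tp_size, pp_size, num_microbatches, precision) is illustrative and not part of this commit:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer import PipelineGradientCheckpointConfig

# Previously: gradient_ckpt_collection=GradCkptCollection([PipelineGradCkptConfig(...)])
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    num_microbatches=4,
    precision="fp16",
    gradient_checkpoint_config=PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
)
booster = Booster(plugin=plugin)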

colossalai/shardformer/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-from .shard import GradCkptCollection, ModelSharder, PipelineGradCkptConfig, ShardConfig, ShardFormer
+from .shard import GradientCheckpointConfig, ModelSharder, PipelineGradientCheckpointConfig, ShardConfig, ShardFormer
colossalai/shardformer/shard/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
-from .grad_ckpt_config import GradCkptCollection, PipelineGradCkptConfig
+from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig
 from .shard_config import ShardConfig
 from .sharder import ModelSharder
 from .shardformer import ShardFormer

-__all__ = ["ShardConfig", "ModelSharder", "ShardFormer", "PipelineGradCkptConfig", "GradCkptCollection"]
+__all__ = ["ShardConfig", "ModelSharder", "ShardFormer", "PipelineGradientCheckpointConfig", "GradientCheckpointConfig"]

colossalai/shardformer/shard/grad_ckpt_config.py

Lines changed: 4 additions & 25 deletions
@@ -1,43 +1,22 @@
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import List, Optional


 @dataclass
-class GradCkptConfig:
+class GradientCheckpointConfig:
     # TODO: for future use
     _dummy_value: Optional[float] = None

-    def __post_init__(self):
-        raise NotImplementedError()
-
     @property
     def control_gradient_checkpointing(self) -> bool:
-        raise NotImplementedError()
+        return False

     def get_num_ckpt_layers(self, *args, **kwargs) -> int:
         raise NotImplementedError()


 @dataclass
-class GradCkptCollection:
-    gradient_ckpt_configs: List[GradCkptConfig] = field(default_factory=list)
-
-    def __post_init__(self):
-        assert all([isinstance(config, GradCkptConfig) for config in self.gradient_ckpt_configs])
-
-    @property
-    def control_gradient_checkpointing(self) -> bool:
-        return any([config.control_gradient_checkpointing for config in self.gradient_ckpt_configs])
-
-    def get_num_ckpt_layers(self, *args, **kwargs) -> int:
-        for config in self.gradient_ckpt_configs:
-            if config.control_gradient_checkpointing:
-                return config.get_num_ckpt_layers(*args, **kwargs)
-        raise RuntimeError("No checkpointed layers information is provided")
-
-
-@dataclass
-class PipelineGradCkptConfig(GradCkptConfig):
+class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
     r"""
     The pipeline gradient config is designed to provide more flexibility for users to control gradient checkpoint in pipeline parallelism.
     Combined with PipelineStageManager.set_distribution_config, user can fully control the distribution of layers and checkpointed layers in pipeline parallelism.
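Since a caller now holds at most one GradientCheckpointConfig, the any()/first-match logic of the removed collection collapses into a direct check on that single object. A small sketch of how a consumer might use the interface visible above; the num_layers argument and the all-layers fallback are assumptions for illustration, not taken from this diff:

from typing import Optional

from colossalai.shardformer import GradientCheckpointConfig


def resolve_num_ckpt_layers(config: Optional[GradientCheckpointConfig], num_layers: int) -> int:
    # The base class now returns False here instead of raising NotImplementedError,
    # so a plain GradientCheckpointConfig (or no config at all) takes the fallback path.
    if config is None or not config.control_gradient_checkpointing:
        return num_layers  # fallback chosen for this sketch: checkpoint every layer
    # Subclasses such as PipelineGradientCheckpointConfig override get_num_ckpt_layers;
    # the exact keyword arguments they expect are not shown in this diff.
    return config.get_num_ckpt_layers(num_layers=num_layers)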

colossalai/shardformer/shard/shard_config.py

Lines changed: 3 additions & 3 deletions
@@ -6,7 +6,7 @@

 from colossalai.pipeline.stage_manager import PipelineStageManager

-from .grad_ckpt_config import GradCkptCollection
+from .grad_ckpt_config import GradientCheckpointConfig

 __all__ = ["ShardConfig"]

@@ -25,7 +25,7 @@ class ShardConfig:
         enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
         enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
-        gradient_ckpt_collection (Optional[GradCkptCollection]): The gradient checkpointing configs. Defaults to None.
+        gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
         enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
     """
     tensor_parallel_process_group: Optional[ProcessGroup] = None
@@ -38,7 +38,7 @@ class ShardConfig:
     enable_sequence_parallelism: bool = False
     enable_sequence_overlap: bool = False
     parallel_output: bool = True
-    gradient_ckpt_collection: Optional[GradCkptCollection] = None
+    gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
     # TODO padding vocab
     # make_vocab_size_divisible_by: int = 128
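The plugin normally builds ShardConfig itself (see the hybrid_parallel_plugin.py hunk above), but the renamed field can also be set directly. A standalone sketch; enable_tensor_parallelism=False is only there so the snippet does not need an initialized process group, and the ratio value is illustrative:

from colossalai.shardformer import PipelineGradientCheckpointConfig, ShardConfig

# ShardConfig is a dataclass, so the renamed field is just another keyword argument;
# every other field keeps the defaults shown in the diff above.
shard_config = ShardConfig(
    enable_tensor_parallelism=False,  # skip tensor-parallel setup for this sketch
    gradient_checkpoint_config=PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
)
print(shard_config.gradient_checkpoint_config)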

tests/test_shardformer/test_model/test_shard_llama.py

Lines changed: 9 additions & 17 deletions
@@ -5,7 +5,7 @@

 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import GradCkptCollection, PipelineGradCkptConfig
+from colossalai.shardformer import PipelineGradientCheckpointConfig
 from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@@ -107,7 +107,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
             "enable_gradient_checkpointing": True,
-            "gradient_ckpt_collection": GradCkptCollection([PipelineGradCkptConfig(gradient_checkpointing_ratio=0.5)]),
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
         },
         {
             "tp_size": 1,
@@ -116,12 +116,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "use_lazy_init": False,
             "precision": "fp32",
             "enable_gradient_checkpointing": True,
-            "gradient_ckpt_collection": GradCkptCollection(
-                [
-                    PipelineGradCkptConfig(
-                        num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0]
-                    )
-                ]
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
+                num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0]
             ),
         },
         {
@@ -205,15 +201,11 @@ def run_llama_test(test_config):
             "zero_stage": 1,
             "initial_scale": 1,
             "enable_gradient_checkpointing": True,
-            "gradient_ckpt_collection": GradCkptCollection(
-                [
-                    PipelineGradCkptConfig(
-                        num_stages=2,
-                        num_model_chunks=2,
-                        num_model_layers=8,
-                        num_ckpt_layers_per_stage=[0, 1, 2, 2],
-                    )
-                ]
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
+                num_stages=2,
+                num_model_chunks=2,
+                num_model_layers=8,
+                num_ckpt_layers_per_stage=[0, 1, 2, 2],
             ),
         },
     ],
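A hedged reading of the interleaved test case above, inferred from the argument names rather than stated in the diff: with num_stages=2 and num_model_chunks=2 there are four virtual pipeline stages, so num_ckpt_layers_per_stage appears to list one checkpointed-layer count per virtual stage, summing to no more than num_model_layers:

# Sanity check on the values used in the last test config (sketch only).
num_stages, num_model_chunks, num_model_layers = 2, 2, 8
num_ckpt_layers_per_stage = [0, 1, 2, 2]
assert len(num_ckpt_layers_per_stage) == num_stages * num_model_chunks  # one entry per virtual stage
assert sum(num_ckpt_layers_per_stage) <= num_model_layers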
