|
30 | 30 |
|
31 | 31 | ### Quick Start |
32 | 32 |
|
33 | | -The sample API usage is given below(If you enable the use of flash attention, please install `flash_attn`. In addition, xformers's `cutlass_op` provide a supplementary optimization, It requires that the sequence length be a multiple of 8.): |
| 33 | +The sample API usage is given below (if you enable flash attention, please install `flash_attn`; in addition, xformers's `cutlass_op` provides a supplementary optimization): |
34 | 34 |
|
35 | 35 | ```python |
36 | | -from colossalai.shardformer import ShardConfig, Shard |
| 36 | +from colossalai.shardformer import ShardConfig, ShardFormer |
37 | 37 | from transformers import BertConfig, BertForMaskedLM |
| 38 | +import colossalai |
38 | 39 |
|
39 | 40 | # launch colossalai |
40 | | -colossalai.launch_from_torch() |
| 41 | +colossalai.launch_from_torch(config={}) |
41 | 42 |
|
42 | 43 | # create model |
43 | 44 | config = BertConfig.from_pretrained('bert-base-uncased') |
44 | 45 | model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config) |
45 | 46 |
|
46 | 47 | # create the shard config; `tp_group` and `stage_manager` must be created beforehand (see the sketch below) |
47 | | -shard_config = ShardConfig() |
| 48 | +shard_config = ShardConfig(tensor_parallel_process_group=tp_group, |
| 49 | + pipeline_stage_manager=stage_manager, |
| 50 | + enable_tensor_parallelism=True, |
| 51 | + enable_fused_normalization=True, |
| 52 | + enable_flash_attention=True, |
| 53 | + enable_jit_fused=True, |
| 54 | + enable_sequence_parallelism=True, |
| 55 | + enable_sequence_overlap=True) |
| 56 | + |
48 | 57 | shard_former = ShardFormer(shard_config=shard_config) |
49 | | -sharded_model = shard_former.optimize(model).to('cuda') |
| 58 | +sharded_model, shared_params = shard_former.optimize(model)  # returns a tuple; move sharded_model to 'cuda' separately |
50 | 59 |
|
51 | 60 | # do everything like normal |
52 | 61 | ... |
53 | 62 | ``` |
| 63 | +Shardformer configuration (a sketch of creating the `tp_group` and `stage_manager` objects used above follows this list): |
| 64 | + |
| 65 | +- `tensor_parallel_process_group`: the process group used for tensor parallelism; required when tensor parallelism is enabled. |
| 66 | +- `pipeline_stage_manager`: the stage manager that handles inter-stage communication; required when pipeline parallelism is enabled. |
| 67 | +{{ autodoc:colossalai.pipeline.stage_manager.PipelineStageManager }} |
| 68 | +- `enable_tensor_parallelism`: whether to enable tensor parallelism, which partitions model weights along rows or columns. |
| 69 | +- `enable_fused_normalization`: whether to use the Apex fused layernorm kernel. |
| 70 | +- `enable_flash_attention`: whether to use flash attention. |
| 71 | +- `enable_jit_fused`: whether to use JIT-fused operators. |
| 72 | +- `enable_sequence_parallelism`: whether to enable sequence parallelism, which partitions the non-tensor-parallel regions along the sequence dimension. |
| 73 | +- `enable_sequence_overlap`: whether to overlap computation and communication in sequence parallelism; only effective together with `enable_sequence_parallelism`. |
| 74 | + |
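For reference, here is a minimal sketch of how the `tp_group` and `stage_manager` objects used in the Quick Start might be created. This snippet is an assumption based on `colossalai.cluster.ProcessGroupMesh` and `colossalai.pipeline.stage_manager.PipelineStageManager`; exact constructor signatures can differ between ColossalAI versions, and the 2x2 layout assumes 4 GPUs.

```python
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager

colossalai.launch_from_torch(config={})

# assumed 4-GPU layout: 2 pipeline stages x 2 tensor-parallel ranks
PP_AXIS, TP_AXIS = 0, 1
pg_mesh = ProcessGroupMesh(2, 2)

# process group spanning the tensor-parallel axis for this rank
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
# stage manager that tracks which pipeline stage this rank belongs to
stage_manager = PipelineStageManager(pg_mesh, PP_AXIS)
```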
54 | 75 |
|
55 | 76 | ### Write your own policy |
56 | 77 |
|
@@ -82,44 +103,30 @@ We will follow this roadmap to develop Shardformer: |
82 | 103 | - [x] API Implementation |
83 | 104 | - [x] Unit Testing |
84 | 105 | - [ ] Policy Implementation |
85 | | - - [ ] Hugging Face |
86 | | - - [ ] NLP |
87 | | - - [x] BERT |
88 | | - - [x] T5 |
89 | | - - [x] LlaMa |
90 | | - - [x] GPT2 |
91 | | - - [x] OPT |
92 | | - - [x] BLOOM |
93 | | - - [ ] GLM |
94 | | - - [ ] RoBERTa |
95 | | - - [ ] ALBERT |
96 | | - - [ ] ERNIE |
97 | | - - [ ] GPT Neo |
98 | | - - [ ] GPT-J |
99 | | - - [ ] CV |
100 | | - - [x] ViT |
101 | | - - [ ] BEiT |
102 | | - - [ ] SwinTransformer |
103 | | - - [ ] SwinTransformer V2 |
104 | | - - [ ] Audio |
105 | | - - [x] Whisper |
106 | | - - [ ] Multi-modal |
107 | | - - [x] SAM |
108 | | - - [x] BLIP-2 |
109 | | -- [ ] Flash Attention Support |
110 | | - - [ ] NLP |
111 | | - - [x] BERT |
112 | | - - [x] T5 |
113 | | - - [x] LlaMa |
114 | | - - [x] GPT2 |
115 | | - - [x] OPT |
116 | | - - [x] BLOOM |
117 | | - - [ ] GLM |
118 | | - - [ ] RoBERTa |
119 | | - - [ ] ALBERT |
120 | | - - [ ] ERNIE |
121 | | - - [ ] GPT Neo |
122 | | - - [ ] GPT-J |
| 106 | + |
| 107 | +| model | tensor parallel | pipeline parallel | lazy initialization | xformers | flash attention 2 | jit fused operator | fused layernorm | sequence parallel | overlap | |
| 108 | +| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: | |
| 109 | +| bert | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | |
| 110 | +| t5 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | |
| 111 | +| llama V1/V2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | |
| 112 | +| gpt2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | |
| 113 | +| opt | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | |
| 114 | +| bloom | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | |
| 115 | +| chatglm2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | |
| 116 | +| vit | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | |
| 117 | +| whisper | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | |
| 118 | +| sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | |
| 119 | +| blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | |
| 120 | +| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | |
| 121 | +| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | |
| 122 | +| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | |
| 123 | +| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | |
| 124 | +| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | |
| 125 | +| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | |
| 126 | +| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | |
| 127 | +| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | |
| 128 | +| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | |
| 129 | + |
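As a worked example of reading this table: the `t5` row supports tensor parallelism, pipeline parallelism, and the fused kernels, but not sequence parallelism or overlap, so a matching configuration leaves those two flags off. The snippet below is a sketch built only from the API shown in the Quick Start; `tp_group` and `stage_manager` are assumed to be created as in the earlier sketch.

```python
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import T5ForConditionalGeneration

# enable only the features the support matrix marks as available for t5
t5_shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,
    pipeline_stage_manager=stage_manager,
    enable_tensor_parallelism=True,
    enable_fused_normalization=True,
    enable_flash_attention=True,
    enable_jit_fused=True,
    enable_sequence_parallelism=False,   # not supported for t5 per the table
    enable_sequence_overlap=False,
)

t5_model = T5ForConditionalGeneration.from_pretrained('t5-small')
sharded_t5, shared_params = ShardFormer(shard_config=t5_shard_config).optimize(t5_model)
```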
123 | 130 |
|
124 | 131 | ## 💡 API Design |
125 | 132 |
|
@@ -286,41 +293,36 @@ class ShardFormer: |
286 | 293 |
|
287 | 294 | Example: |
288 | 295 |
|
| 296 | + org_model = BertForMaskedLM.from_pretrained('bert-base-uncased') |
| 297 | + shard_config = ShardConfig() |
289 | 298 | shard_former = ShardFormer(shard_config=shard_config) |
290 | | - shard_former.init_distributed() |
291 | | - model = shard_former.optimize(model, policy=policy) |
292 | | - dataloader = shard_former.shard_dataset(dataset) |
| 299 | + model, shared_params = shard_former.optimize(org_model) |
293 | 300 |
|
294 | 301 | """ |
295 | 302 |
|
296 | 303 | def __init__(self, shard_config: ShardConfig): |
297 | 304 | """ |
298 | 305 | Do two things: |
299 | | - 1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp |
| 306 | + 1. Create a distributed coordinator |
300 | 307 | 2. Serve as a store for the shard config |
301 | 308 | """ |
302 | 309 | self.shard_config = shard_config |
303 | | - self.pg_manager = None |
| 310 | + self.coordinator = DistCoordinator() |
304 | 311 |
|
305 | | - def init_distributed(self) -> colossalai.cluster.ProcessGroupManager: |
306 | | - """ |
307 | | - Initialize the distributed process group according to the |
308 | | - """ |
309 | | - pg_manager = ... |
310 | | - self.pg_manager = pg_manager |
311 | | - return pg_manager |
| 312 | + def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]: |
| 313 | + r""" |
| 314 | + This method will optimize the model based on the given policy. |
312 | 315 |
|
313 | | - def shard_model(self, model: torch.nn.Module,policy: Policy) -> torch.nn.Module: |
314 | | - """ |
315 | | - Shard model for TP and PP |
316 | | - """ |
317 | | - ... |
| 316 | + Args: |
| 317 | +            model (`torch.nn.Module`): the original huggingface model |
| 318 | +            policy (`Policy`, optional): a custom policy for sharding; if omitted, a suitable |
| 319 | +                built-in policy is selected automatically for the model |
318 | 320 |
|
319 | | - def shard_dataset(self, dataset: Dataset) -> Dataloader: |
| 321 | + Returns: the sharded model and the shared parameters |
320 | 322 | """ |
321 | | - Shard dataset for DP |
322 | | - """ |
323 | | - ... |
| 323 | + sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy) |
| 324 | + shared_params = sharder.shard() |
| 325 | + return model, shared_params |
324 | 326 | ``` |
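For illustration, a short usage sketch of the API above (a hypothetical snippet, not part of the design itself): it assumes a default `ShardConfig` and that omitting `policy` falls back to an automatically selected built-in policy.

```python
import colossalai
from transformers import BertConfig, BertForMaskedLM
from colossalai.shardformer import ShardConfig, ShardFormer

colossalai.launch_from_torch(config={})

org_model = BertForMaskedLM(BertConfig())
params_before = sum(p.numel() for p in org_model.parameters())

shard_former = ShardFormer(shard_config=ShardConfig())
sharded_model, shared_params = shard_former.optimize(org_model)  # policy omitted

# sharding modifies the module in place; each rank holds fewer parameters when the
# tensor-parallel group has more than one rank, and shared_params lists parameters
# shared across pipeline stages (empty when pipeline parallelism is disabled)
params_after = sum(p.numel() for p in sharded_model.parameters())
print(params_before, params_after, len(shared_params))
```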
325 | 327 |
|
326 | 328 | ## ⌨️ Development Notes |
@@ -429,13 +431,24 @@ As shown in the figures above, when the sequence length is around 1000 or greate |
429 | 431 | ### Convergence |
430 | 432 |
|
431 | 433 |
|
432 | | -To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](../../examples/language/bert/finetune.py) using both shardformer and non-shardformer approaches. The example that utilizes Shardformer simultaneously with Pipeline Parallelism and Data Parallelism (Zero1). We then compared the accuracy, loss, and F1 score of the training results. |
| 434 | +To validate that training a model with Shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/convergence_benchmark.py) with and without Shardformer. The Shardformer run combines pipeline parallelism with data parallelism (ZeRO-1). We then compared the accuracy, loss, and F1 score of the training results. |
| 435 | + |
| 436 | +The training configuration is as follows: |
| 437 | +```python |
| 438 | +batch_size = 2 |
| 439 | +epoch = 3 |
| 440 | +lr = 2.4e-5 |
| 441 | +accumulation_steps = 8 |
| 442 | +warmup_fraction = 0.03 |
| 443 | +``` |
| 444 | + |
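For context, a hedged sketch of how these hyperparameters might translate into an optimizer and learning-rate schedule. The optimizer choice, the dataset size, and the reuse of `sharded_model` from the Quick Start are assumptions, not details taken from the benchmark script.

```python
import torch
from transformers import get_linear_schedule_with_warmup

batch_size = 2
epoch = 3
lr = 2.4e-5
accumulation_steps = 8
warmup_fraction = 0.03

steps_per_epoch = 1000                                   # hypothetical: len(dataset) // batch_size
num_update_steps = epoch * steps_per_epoch // accumulation_steps
num_warmup_steps = int(warmup_fraction * num_update_steps)

optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_update_steps)

# effective batch size per optimizer step = batch_size * accumulation_steps = 16 (per data-parallel rank)
```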
433 | 445 |
|
434 | 446 |
|
435 | 447 | | accuracy | f1 | loss | GPU number | model sharded | |
436 | 448 | | :------: | :-----: | :-----: | :--------: | :---------: | |
437 | | -| 0.84589 | 0.88613 | 0.43414 | 4 | True | |
438 | | -| 0.83594 | 0.88064 | 0.43298 | 1 | False | |
| 449 | +| 0.82971 | 0.87713 | 0.23194 | 4 | True | |
| 450 | +| 0.83797 | 0.88006 | 0.22683 | 2 | True | |
| 451 | +| 0.84521 | 0.88700 | 0.21822 | 1 | False | |
439 | 452 |
|
440 | 453 |
|
441 | 454 | Overall, the results demonstrate that using shardformers during model training does not affect the convergence. |