
Commit 8844691

[shardformer] update shardformer readme (#4689)
* [shardformer] update shardformer readme
* [shardformer] update shardformer readme
* [shardformer] update shardformer readme
* [shardformer] update shardformer readme
* [shardformer] update shardformer readme
1 parent 1d45473 commit 8844691

File tree

4 files changed: +90, -72 lines changed

colossalai/shardformer/README.md

Lines changed: 80 additions & 67 deletions
@@ -30,27 +30,48 @@

 ### Quick Start

-The sample API usage is given below(If you enable the use of flash attention, please install `flash_attn`. In addition, xformers's `cutlass_op` provide a supplementary optimization, It requires that the sequence length be a multiple of 8.):
+The sample API usage is given below (if you enable flash attention, please install `flash_attn`; in addition, xformers' `cutlass_op` provides a supplementary optimization):

 ```python
-from colossalai.shardformer import ShardConfig, Shard
+from colossalai.shardformer import ShardConfig, ShardFormer
 from transformers import BertForMaskedLM
+import colossalai

 # launch colossalai
-colossalai.launch_from_torch()
+colossalai.launch_from_torch(config={})

 # create model
 config = BertConfig.from_pretrained('bert-base-uncased')
 model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)

 # create huggingface model as normal
-shard_config = ShardConfig()
+shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
+                           pipeline_stage_manager=stage_manager,
+                           enable_tensor_parallelism=True,
+                           enable_fused_normalization=True,
+                           enable_flash_attention=True,
+                           enable_jit_fused=True,
+                           enable_sequence_parallelism=True,
+                           enable_sequence_overlap=True)
+
 shard_former = ShardFormer(shard_config=shard_config)
-sharded_model = shard_former.optimize(model).to('cuda')
+sharded_model, shared_params = shard_former.optimize(model)
+sharded_model = sharded_model.to('cuda')

 # do everything like normal
 ...
 ```
+
+Shardformer configuration options:
+
+`tensor_parallel_process_group`: the process group for tensor parallelism; required when tensor parallelism is enabled.
+`pipeline_stage_manager`: when using pipeline parallelism, a pipeline stage manager must be provided for inter-process communication between stages.
+{{ autodoc:colossalai.pipeline.stage_manager.PipelineStageManager }}
+`enable_tensor_parallelism`: use tensor parallelism, partitioning the model along columns or rows.
+`enable_fused_normalization`: use the Apex fused layernorm kernel.
+`enable_flash_attention`: use flash attention.
+`enable_jit_fused`: use JIT-fused operators.
+`enable_sequence_parallelism`: use sequence parallelism, partitioning the non-tensor-parallel regions along the sequence dimension.
+`enable_sequence_overlap`: overlap computation and communication in sequence parallelism; used together with `enable_sequence_parallelism`.
+

 ### Write your own policy
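The updated Quick Start snippet references `tp_group` and `stage_manager` without constructing them. Below is a minimal, hedged sketch of a tensor-parallel-only setup: it assumes a single NCCL process group created the same way as in the `convergence_benchmark.py` change further down, omits `pipeline_stage_manager` (only needed for pipeline parallelism), and enables just a subset of the optimization flags for brevity.

```python
# Hedged sketch (not taken verbatim from the commit): tensor-parallel-only ShardFormer setup.
# Assumes the script is launched with torchrun so that colossalai.launch_from_torch can
# initialize torch.distributed; the single NCCL group below plays the role of `tp_group`.
import torch.distributed as dist
from transformers import BertConfig, BertForMaskedLM

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer

colossalai.launch_from_torch(config={})

# one group spanning the ranks that hold the tensor-parallel shards (illustrative choice)
tp_group = dist.new_group(backend='nccl')

config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)

shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                           enable_tensor_parallelism=True,
                           enable_fused_normalization=True)
shard_former = ShardFormer(shard_config=shard_config)

# optimize() returns the sharded model together with any shared parameters
sharded_model, shared_params = shard_former.optimize(model)
sharded_model = sharded_model.to('cuda')
```

Flash attention, JIT fusion, and sequence parallelism can be switched on the same way once their prerequisites (e.g. `flash_attn`, Apex) are installed.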

@@ -82,44 +103,30 @@ We will follow this roadmap to develop Shardformer:
 - [x] API Implementation
 - [x] Unit Testing
 - [ ] Policy Implementation
-  - [ ] Hugging Face
-    - [ ] NLP
-      - [x] BERT
-      - [x] T5
-      - [x] LlaMa
-      - [x] GPT2
-      - [x] OPT
-      - [x] BLOOM
-      - [ ] GLM
-      - [ ] RoBERTa
-      - [ ] ALBERT
-      - [ ] ERNIE
-      - [ ] GPT Neo
-      - [ ] GPT-J
-    - [ ] CV
-      - [x] ViT
-      - [ ] BEiT
-      - [ ] SwinTransformer
-      - [ ] SwinTransformer V2
-    - [ ] Audio
-      - [x] Whisper
-    - [ ] Multi-modal
-      - [x] SAM
-      - [x] BLIP-2
-  - [ ] Flash Attention Support
-    - [ ] NLP
-      - [x] BERT
-      - [x] T5
-      - [x] LlaMa
-      - [x] GPT2
-      - [x] OPT
-      - [x] BLOOM
-      - [ ] GLM
-      - [ ] RoBERTa
-      - [ ] ALBERT
-      - [ ] ERNIE
-      - [ ] GPT Neo
-      - [ ] GPT-J
+
+| model       | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
+| :---------: | :-------------: | :---------------: | :-----------------: | :-----: | :---------: | :----------------: | :-------------: | :---------------: | :-----: |
+| bert        | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
+| t5          | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| llama V1/V2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| gpt2        | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
+| opt         | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| bloom       | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
+| chatglm2    | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
+| vit         | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| whisper     | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| sam         | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| blip2       | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| roberta     | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| albert      | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| ernie       | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| gpt-neo     | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| gpt-j       | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| beit        | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| swin        | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| swin V2     | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| qwen        | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+

 ## 💡 API Design

@@ -286,41 +293,36 @@ class ShardFormer:

     Example:

+        org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
+        shard_config = ShardConfig()
         shard_former = ShardFormer(shard_config=shard_config)
-        shard_former.init_distributed()
-        model = shard_former.optimize(model, policy=policy)
-        dataloader = shard_former.shard_dataset(dataset)
+        model, shared_params = shard_former.optimize(org_model)

     """

     def __init__(self, shard_config: ShardConfig):
         """
         Do two things:
-        1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp
+        1. Create a distributed coordinator
         2. Serve as a store for the shard config
         """
         self.shard_config = shard_config
-        self.pg_manager = None
+        self.coordinator = DistCoordinator()

-    def init_distributed(self) -> colossalai.cluster.ProcessGroupManager:
-        """
-        Initialize the distributed process group according to the
-        """
-        pg_manager = ...
-        self.pg_manager = pg_manager
-        return pg_manager
+    def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
+        r"""
+        This method will optimize the model based on the given policy.

-    def shard_model(self, model: torch.nn.Module, policy: Policy) -> torch.nn.Module:
-        """
-        Shard model for TP and PP
-        """
-        ...
+        Args:
+            model (`torch.nn.Module`): the original Hugging Face model
+            shard_config (`ShardConfig`): the config describing the distributed setting
+            policy (`Policy`): the custom policy for sharding

-    def shard_dataset(self, dataset: Dataset) -> Dataloader:
+        Returns: the sharded model and the shared parameters
         """
-        Shard dataset for DP
-        """
-        ...
+        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
+        shared_params = sharder.shard()
+        return model, shared_params
 ```

 ## ⌨️ Development Notes
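As a quick illustration of the `optimize()` signature shown in the hunk above, here is a hedged usage sketch. The in-place comment reflects the fact that the method returns the same `model` object it was given; the behaviour when `policy` is left as `None` (a built-in policy selected from the model class) is an assumption.

```python
# Hedged usage sketch of ShardFormer.optimize(), based on the signature above.
# Run under a distributed launch (e.g. torchrun), as in the Quick Start example.
from transformers import BertForMaskedLM

from colossalai.shardformer import ShardConfig, ShardFormer

org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
shard_former = ShardFormer(shard_config=ShardConfig())

# The model is sharded in place and returned together with a list describing any
# shared parameters; leaving `policy=None` presumably falls back to a built-in
# policy matching the model class.
model, shared_params = shard_former.optimize(org_model)
assert model is org_model  # optimize() returns the object it received (see the method body above)
```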
@@ -429,13 +431,24 @@ As shown in the figures above, when the sequence length is around 1000 or greate
 ### Convergence


-To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](../../examples/language/bert/finetune.py) using both shardformer and non-shardformer approaches. The example that utilizes Shardformer simultaneously with Pipeline Parallelism and Data Parallelism (Zero1). We then compared the accuracy, loss, and F1 score of the training results.
+To validate that training a model with Shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/convergence_benchmark.py) with and without Shardformer. The Shardformer example uses Pipeline Parallelism together with Data Parallelism (ZeRO-1). We then compared the accuracy, loss, and F1 score of the training results.
+
+The configuration is as follows:
+```python
+batch_size = 2
+epoch = 3
+lr = 2.4e-5
+accumulation_steps = 8
+warmup_fraction = 0.03
+```
+


 | accuracy | f1      | loss    | GPU number | model sharded |
 | :------: | :-----: | :-----: | :--------: | :-----------: |
-| 0.84589  | 0.88613 | 0.43414 | 4          | True          |
-| 0.83594  | 0.88064 | 0.43298 | 1          | False         |
+| 0.82971  | 0.87713 | 0.23194 | 4          | True          |
+| 0.83797  | 0.88006 | 0.22683 | 2          | True          |
+| 0.84521  | 0.88700 | 0.21822 | 1          | False         |


 Overall, the results demonstrate that using shardformers during model training does not affect the convergence.
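For context on the hyperparameters above, here is a hedged back-of-the-envelope of the optimizer schedule they imply, mirroring the `num_update_steps_per_epoch` computation visible in the `convergence_benchmark.py` change below; the per-epoch batch count is purely illustrative.

```python
# Hedged sketch with illustrative numbers: how the convergence configuration translates
# into optimizer update steps and warmup steps.
batch_size = 2
epoch = 3
accumulation_steps = 8
warmup_fraction = 0.03

num_batches_per_epoch = 1000                                     # illustrative, dataset-dependent
effective_batch_per_update = batch_size * accumulation_steps     # samples per optimizer step, per GPU
updates_per_epoch = num_batches_per_epoch // accumulation_steps  # optimizer steps, not forward passes
total_update_steps = epoch * updates_per_epoch
warmup_steps = int(warmup_fraction * total_update_steps)

print(updates_per_epoch, total_update_steps, warmup_steps)       # 125 375 11 for these numbers
```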

colossalai/shardformer/examples/convergence_benchmark.py

Lines changed: 5 additions & 2 deletions
@@ -49,9 +49,12 @@ def train(args):

     # if multiple GPUs, shard the model
     if dist.get_world_size() > 1:
-        shard_config = ShardConfig(enable_fused_normalization=args.fused_layernorm)
+        tp_group = dist.new_group(backend='nccl')
+        shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
+                                   enable_tensor_parallelism=True,
+                                   enable_all_optimization=True)
         shard_former = ShardFormer(shard_config=shard_config)
-        model = shard_former.optimize(model)
+        model, _ = shard_former.optimize(model)

     optim = Adam(model.parameters(), lr=args.lr)
     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps

colossalai/shardformer/examples/convergence_benchmark.sh

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \
     --model "bert" \
     --pretrain "bert-base-uncased" \
-    --max_epochs 1 \
+    --max_epochs 3 \
     --batch_size 2 \
     --lr 2.4e-5 \
     --fused_layernorm False \

colossalai/shardformer/examples/performance_benchmark.py

Lines changed: 4 additions & 2 deletions
@@ -29,7 +29,8 @@ def data_gen_for_sequence_classification(batch_size, seq_length):
                                        intermediate_size=256,
                                        num_attention_heads=4,
                                        max_position_embeddings=128,
-                                       num_labels=16)
+                                       num_labels=16,
+                                       pad_token_id=2)
 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
 model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)

@@ -73,7 +74,8 @@ def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, d
     if provider == "shard_model":
         shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
         shard_former = ShardFormer(shard_config=shard_config)
-        sharded_model = shard_former.optimize(model).cuda()
+        sharded_model, _ = shard_former.optimize(model)
+        sharded_model = sharded_model.cuda()
         fn = lambda: train(sharded_model, data)
         ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
         return ms
