8 changes: 5 additions & 3 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -243,9 +243,11 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
Currently, the optimization methods include fused normalization, flash attention and JIT.
Defaults to False.
enable_fused_normalization (bool, optional): Whether to switch on fused normalization. Defaults to False.
enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT. Default to Falase.
enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
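
For context, here is a minimal sketch of how these arguments could be passed when building the plugin. It is illustrative only: `tp_size`/`pp_size` and the surrounding setup are assumptions based on the plugin's documented tensor/pipeline parallel settings, not an excerpt from this PR.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# All enable_* switches default to False per the docstring above.
plugin = HybridParallelPlugin(
    tp_size=2,                        # tensor parallel degree (assumed name)
    pp_size=2,                        # pipeline parallel degree (assumed name)
    num_microbatches=4,               # required when pipeline parallelism is on
    enable_fused_normalization=True,  # fused normalization in Shardformer
    enable_flash_attention=True,      # flash attention in Shardformer
    enable_jit_fused=True,            # JIT fused operators in Shardformer
)
booster = Booster(plugin=plugin)
```
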
32 changes: 21 additions & 11 deletions colossalai/shardformer/README.md
@@ -60,18 +60,28 @@ sharded_model, shared_params = shard_former.optimize(model).to('cuda')
# do everything like normal
...
```
shardformer configuration

`tensor_parallel_process_group`: the process group of tensor parallelism, it's necessary when using tensor parallel.
`pipeline_stage_manager`: If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism.
{{ autodoc:colossalai.pipeline.stage_manager.PipelineStageManager }}
`enable_tensor_parallelism`: using tensor parallel, partition the model along the columns or along the rows
`enable_fused_normalization`: using apex fused layernorm
`enable_flash_attention`: using flash attention
`enable_jit_fused`: using jit fused operators
`enable_sequence_parallelism`: using sequence parallelism, partition these non-tensor parallel regions along the sequence dimension.
`enable_sequence_overlap`: overlap the computation and communication in the sequence parallelism, it's used with `enable_sequence_parallelism`.

The following are descriptions of `ShardConfig`'s arguments:

- `tensor_parallel_process_group`: The process group for tensor parallelism; it is required when using tensor parallelism. Defaults to None, which is the global process group.

- `pipeline_stage_manager`: The pipeline stage manager that handles inter-process communication when using pipeline parallelism. Defaults to None, which means pipeline parallelism is not used.

- `enable_tensor_parallelism`: Whether to use tensor parallelism. Defaults to True.

- `enable_fused_normalization`: Whether to use fused layernorm. Defaults to False.

- `enable_flash_attention`: Whether to switch on flash attention. Defaults to False.

- `enable_jit_fused`: Whether to switch on JIT fused operators. Defaults to False.

- `enable_sequence_parallelism`: Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.

- `enable_sequence_overlap`: Whether to turn on sequence overlap, which overlaps computation and communication in sequence parallelism. It can only be used when `enable_sequence_parallelism` is True. Defaults to False.

- `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalization`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.

- `inference_only`: Whether to run only the forward pass. Defaults to False.
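
As a quick illustration, a hedged sketch of building a `ShardConfig` and handing it to `ShardFormer` (the import path follows this repository's layout; the flag values and the `model` object are placeholders, and a launched distributed environment is assumed):

```python
from colossalai.shardformer import ShardConfig, ShardFormer

# Sketch: shard along tensor-parallel ranks with fused layernorm enabled.
# Leaving tensor_parallel_process_group as None uses the global process group.
shard_config = ShardConfig(
    enable_tensor_parallelism=True,
    enable_fused_normalization=True,
)
shard_former = ShardFormer(shard_config=shard_config)
# `optimize` returns the sharded model plus parameters shared across shards.
sharded_model, shared_params = shard_former.optimize(model)  # model: your HF model
```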

### Write your own policy

26 changes: 11 additions & 15 deletions colossalai/shardformer/shard/shard_config.py
@@ -15,32 +15,28 @@ class ShardConfig:
The config for sharding the huggingface model

Args:
tensor_parallel_process_group (Optional[ProcessGroup]): The process group for tensor parallelism, defaults to None, which is the global process group.
pipeline_stage_manager (Optional[PipelineStageManager]): The pipeline stage manager, defaults to None, which means no pipeline.
enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
enable_all_optimization (bool): Whether to turn on all optimization, default is False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, default is False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap, default is False.
tensor_parallel_process_group (Optional[ProcessGroup]): The process group for tensor parallelism; it is required when using tensor parallelism. Defaults to None, which is the global process group.
pipeline_stage_manager (Optional[PipelineStageManager]): The pipeline stage manager that handles inter-process communication when using pipeline parallelism. Defaults to None, which means pipeline parallelism is not used.
enable_tensor_parallelism (bool): Whether to use tensor parallelism. Defaults to True.
enable_fused_normalization (bool): Whether to use fused layernorm. Defaults to False.
enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlaps computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
inference_only (bool): Whether to run only the forward pass. Defaults to False.
"""
tensor_parallel_process_group: Optional[ProcessGroup] = None
pipeline_stage_manager: Optional[PipelineStageManager] = None
enable_tensor_parallelism: bool = True
enable_fused_normalization: bool = False
enable_all_optimization: bool = False
enable_flash_attention: bool = False
enable_jit_fused: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
enable_all_optimization: bool = False
inference_only: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False

# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
# inference_only: bool = True
# gather_output: bool = True

@property
def tensor_parallel_size(self):
3 changes: 2 additions & 1 deletion docs/source/en/basics/booster_api.md
@@ -9,7 +9,8 @@ Author: [Mingyan Jiang](https://github.com/jiangmingyan), [Jianghai Chen](https:

**Example Code**

- [Train with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet)
- [Train ResNet on CIFAR-10 with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet)
- [Train LLaMA-1/2 on RedPajama with Booster](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2)

## Introduction

2 changes: 1 addition & 1 deletion docs/source/en/basics/booster_plugins.md
@@ -73,7 +73,7 @@ More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/fsdp.h

This plugin implements the combination of various parallel training strategies and optimization tools. The features of HybridParallelPlugin can be generally divided into four parts:

1. Shardformer: This plugin provides an entrance to Shardformer, which controls model sharding under tensor parallel and pipeline parallel setting. Shardformer also overloads the logic of model's forward/backward process to ensure the smooth working of tp/pp. Also, optimization tools including fused normalization, flash attention (xformers), JIT and sequence parallel are injected into the overloaded forward/backward method by Shardformer.
1. Shardformer: This plugin provides an entrance to Shardformer, which controls model sharding under tensor parallel and pipeline parallel settings. Shardformer also overloads the model's forward/backward logic to keep tp/pp running smoothly. In addition, optimization tools including fused normalization, flash attention (xformers), JIT and sequence parallelism are injected into the overloaded forward/backward methods by Shardformer. More details can be found in the chapter [Shardformer Doc](../features/shardformer.md).

2. Mixed Precision Training: Support for fp16/bf16 mixed precision training. More details about its arguments configuration can be found in [Mixed Precision Training Doc](../features/mixed_precision_training_with_booster.md).

4 changes: 4 additions & 0 deletions docs/source/en/features/1D_tensor_parallel.md
@@ -2,6 +2,8 @@

Author: Zhengda Bian, Yongbin Li

> ⚠️ The information on this page is outdated and will be deprecated. Please check [Shardformer](./shardformer.md) for more information.

**Prerequisite**
- [Define Your Configuration](../basics/define_your_config.md)
- [Configure Parallelization](../basics/configure_parallelization.md)
@@ -116,3 +118,5 @@ Output of the first linear layer: torch.Size([16, 512])
Output of the second linear layer: torch.Size([16, 256])
```
The output of the first linear layer is split into 2 partitions (each has the shape `[16, 512]`), while the second layer has identical outputs across the GPUs.
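
Those shapes can be sanity-checked with a standalone sketch (plain PyTorch, independent of Colossal-AI's parallel layers; the 256→1024→256 MLP dimensions are assumptions inferred from the shapes printed above):

```python
import torch
import torch.nn.functional as F

batch, hidden = 16, 256
x = torch.randn(batch, hidden)

# Column-parallel first layer: each of 2 ranks holds 512 of 1024 output
# features, so every rank produces an activation of shape [16, 512].
w1_shard = torch.randn(512, hidden)
print(F.linear(x, w1_shard).shape)  # torch.Size([16, 512])

# Row-parallel second layer: each rank multiplies its [16, 512] activation
# by a [256, 512] weight shard; summing the per-rank results (an all-reduce
# in practice) yields the identical [16, 256] output on every GPU.
w2_shard = torch.randn(256, 512)
print(F.linear(F.linear(x, w1_shard), w2_shard).shape)  # torch.Size([16, 256])
```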

<!-- doc-test-command: echo -->
143 changes: 143 additions & 0 deletions docs/source/en/features/shardformer.md
@@ -0,0 +1,143 @@
# Shardformer

Author: [Baizhou Zhang](https://github.com/Fridge003)

**Prerequisite**
- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md)
- [Booster API](../basics/booster_api.md)
- [Booster Plugins](../basics/booster_plugins.md)

**Example Code**
- [Tensor Parallelism with Shardformer](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/shardformer/examples)
- [Enabling Shardformer using HybridParallelPlugin](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert)

**Related Paper**
- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473)
- [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965)
- [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691)
- [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120)


## Introduction

When training large transformer models such as LLaMA-2 70B or OPT 175B, model parallelism methods that divide a huge model into smaller shards, including tensor parallelism and pipeline parallelism, are essential for fitting within the limits of GPU memory.
However, manually cutting a model and rewriting its forward/backward logic can be difficult for users who are not familiar with distributed training.
Meanwhile, the Huggingface transformers library has gradually become users' first choice of model source, and most mainstream large models have been open-sourced in the Huggingface transformers model library.

Out of this motivation, the ColossalAI team developed **Shardformer**, a feature that automatically prepares popular HuggingFace transformer models for model parallelism (tensor parallelism/pipeline parallelism).
This module aims to make parallelization hassle-free for users who do not come from a systems background.
Within a few lines of code, users can turn a model into a state ready for distributed training.
Shardformer also contains various optimization tools for acceleration and memory saving during the forward/backward pass.


## How Shardformer Works

Generally, Shardformer works through the following four kinds of *replacements*:

1. Replacing original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module.
The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters.
Also, new `forward` methods will replace the original ones so as to execute distributed computation, such as the split/gather operations of linear layers under tensor parallelism.
Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module.

2. Replacing attributes of original Huggingface Transformers layers with appropriate attributes for distributed training.
For example, when training LLaMA-2 with a tensor parallel size of 2, the attribute `num_heads` of `LlamaDecoderLayer` (the number of attention heads in each layer) should be replaced with `model.config.num_attention_heads // 2` (a toy sketch of these replacements appears below, after this list).

3. Replacing the `forward` methods implemented by the original Huggingface
Transformers libraries with our customized `forward` methods.
This replacement is essential for pipeline parallelism, where a customized function is needed to pass intermediate hidden states between different pipeline stages.
Also, optimization methods such as flash attention or sequence parallelism can be injected into the `forward` process through our customized `forward` method.

4. Replacing the whole copy of model parameters and optimizer states with incomplete ones controlled by the current device (this is why it's called Shardformer).
By executing the `ModelSharder.shard` method, the current device keeps only the part of the model parameters it is supposed to take care of.
To be specific, these are the assigned parameter shards when using tensor parallelism, or the parameters belonging to the current pipeline stage when using pipeline parallelism, or both.
All other parameters are released so as to free memory.
As a result, the optimizer only computes the states corresponding to this subset of parameters, saving memory even further.

All of these replacements are implemented with manually written policies and forward functions.
If you want to delve deeper into the design of Shardformer or customize your own Shardformer policies, please refer to our [Shardformer development document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md) and [pipeline parallelism design](https://github.com/hpcaitech/ColossalAI/discussions/4050) for more details.
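
Below is the toy sketch referenced in the list above, illustrating replacement kinds 1 and 2. It is a hand-written illustration under stated assumptions, not Shardformer's actual implementation; only the `from_native_module` naming convention is taken from the description above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyColumnParallelLinear(nn.Module):
    """Keeps only this rank's column shard of a full nn.Linear (toy version)."""

    def __init__(self, weight_shard: torch.Tensor, bias_shard: torch.Tensor):
        super().__init__()
        self.weight = nn.Parameter(weight_shard)
        self.bias = nn.Parameter(bias_shard)

    @staticmethod
    def from_native_module(module: nn.Linear, tp_size: int, tp_rank: int):
        # Replacement kind 1: split the output dimension across tp ranks.
        w = module.weight.detach().chunk(tp_size, dim=0)[tp_rank]
        b = module.bias.detach().chunk(tp_size, dim=0)[tp_rank]
        return ToyColumnParallelLinear(w, b)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A real distributed module would also gather/reduce across ranks.
        return F.linear(x, self.weight, self.bias)

# Replacement kind 2: attributes are rescaled by the tensor parallel size,
# e.g. a layer's head count becomes model.config.num_attention_heads // tp_size.
layer = ToyColumnParallelLinear.from_native_module(nn.Linear(256, 1024), tp_size=2, tp_rank=0)
print(layer.weight.shape)  # torch.Size([512, 256])
```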

## Usage

### Shardformer Configuration

The configuration of Shardformer is controlled by class `ShardConfig`:

{{ autodoc:colossalai.shardformer.ShardConfig }}

If you want to enable Apex Fused Layernorm, please install `apex`.
If you want to enable the usage of flash attention, please install `flash_attn`.
In addition, xFormers' `cutlass_op` can serve as a backup for flash attention.

### Enabling Shardformer

#### 1. Enabling Shardformer Through Booster (Recommended)

Enabling `Shardformer` through `Booster` initialized with `HybridParallelPlugin` is the recommended way to awaken the power of Shardformer.
The main reason is that pipeline parallelism cannot work without calling the `execute_pipeline` method of `Booster`. Besides, `HybridParallelPlugin` provides the capacity to combine the features of `Shardformer` with other useful features, such as mixed precision training or Zero.

More details about this usage can be found in chapter [Booster API](../basics/booster_api.md) and [Booster Plugins](../basics/booster_plugins.md).

[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example on how to trigger `Shardformer` through `HybridParallelPlugin`. Please be aware that the forward/backward pass is done differently depending on whether pipeline parallelism is used.
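
As a hedged sketch of this flow (the model, optimizer, and dataloader are placeholders you would define yourself; the criterion signature follows the convention used in the linked example):

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4)
booster = Booster(plugin=plugin)

# Convention from the linked example: the criterion maps model outputs
# (and the input batch) to a scalar loss.
criterion = lambda outputs, inputs: outputs.loss
model, optimizer, criterion, dataloader, _ = booster.boost(
    model, optimizer, criterion=criterion, dataloader=dataloader)

# With pp_size > 1, forward/backward must go through the booster instead of
# the conventional model(input) / loss.backward() calls.
outputs = booster.execute_pipeline(iter(dataloader), model, criterion,
                                   optimizer, return_loss=True)
optimizer.step()
```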


#### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended)

You can also use Shardformer by manually calling the Shardformer APIs. However, this usage is not recommended, since pipeline parallelism can't run without `Booster`.

[Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py)
is an example on how to trigger `Shardformer` through calling Shardformer APIs.
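
A hedged sketch of this manual route (mirroring the linked benchmark; the model choice is a placeholder, and the policy is left as None so the default policy for the architecture is picked up):

```python
import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import BertForSequenceClassification

colossalai.launch_from_torch(config={})

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
shard_config = ShardConfig(enable_tensor_parallelism=True,
                           enable_fused_normalization=True)
shard_former = ShardFormer(shard_config=shard_config)
# Returns the sharded model and the parameters shared across shards.
sharded_model, shared_params = shard_former.optimize(model, policy=None)
```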


### Precautions

1. When enabling pipeline parallel, please don't do the forward/backward pass in the conventional way (`model(input)`, `loss.backward()`), which will cause unexpected errors. Rather, please do forward/backward pass through calling `booster.execute_pipeline` method.

2. When you use Shardformer to process classification models such as `GPT2ForSequenceClassification` or `ViTForImageClassification`, please ensure that the total number of labels is an integer multiple of the tensor parallel size; otherwise Shardformer can't process the classifier layer correctly. A simple fix is appending dummy labels in the transformers config. This bug will be fixed in a future version of Shardformer.

3. The case of training ChatGLM-2 6B is a little special: since Huggingface transformers doesn't officially support ChatGLM at present, please import the configuration/model classes through
```python
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
```
when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes.


## Supporting Information

List of Huggingface transformers model families currently supported by Shardformer:
- LLaMA-1/LLaMA-2
- GPT2
- BERT
- OPT
- BLOOM
- T5
- ViT
- ChatGLM-2 6B
- Whisper

List of optimization tools currently supported by Shardformer:
- Flash Attention 2
- JIT Fused Operator
- xFormers
- Fused Layer Normalization
- Sequence Parallel
- Sequence Overlap

List of model families we plan to support in the near future:
- SAM
- Blip2
- RoBERTa
- ALBERT
- ERNIE
- GPT Neo
- GPT-J
- BEiT
- SwinTransformer V1/V2
- Qwen

These lists will grow longer as more models and optimization tools emerge in the future. If you have any suggestions on the models/optimizations we should support, please feel free to mention them in the [Issues](https://github.com/hpcaitech/ColossalAI/issues) section of our project.

For more details about compatibility between each optimization tool and each supported model, please refer to chapter Roadmap in our [develop document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md).


<!-- doc-test-command: echo -->
5 changes: 3 additions & 2 deletions docs/source/zh-Hans/basics/booster_api.md
@@ -1,4 +1,4 @@
# Using booster
# Booster API

Author: [Mingyan Jiang](https://github.com/jiangmingyan), [Jianghai Chen](https://github.com/CjhHa1), [Baizhou Zhang](https://github.com/Fridge003)

@@ -11,7 +11,8 @@

<!-- update this url-->

- [Training with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet)
- [Train ResNet on CIFAR-10 with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet)
- [Train LLaMA-1/2 on RedPajama with Booster](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2)

## Introduction
