
Commit a2db755

Author: Baizhou Zhang
[doc] polish shardformer doc (#4779)
* fix example format in docstring
* polish shardformer doc
1 parent 26cd6d8 commit a2db755

File tree

8 files changed: +220, -73 lines changed


colossalai/booster/plugin/gemini_plugin.py
Lines changed: 11 additions & 10 deletions

````diff
@@ -229,16 +229,17 @@ class GeminiPlugin(DPPluginBase):
     """
     Plugin for Gemini.

-    Example:
-        >>> from colossalai.booster import Booster
-        >>> from colossalai.booster.plugin import GeminiPlugin
-        >>>
-        >>> model, train_dataset, optimizer, criterion = ...
-        >>> plugin = GeminiPlugin()
-
-        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
-        >>> booster = Booster(plugin=plugin)
-        >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+    ```python
+    from colossalai.booster import Booster
+    from colossalai.booster.plugin import GeminiPlugin
+
+    model, train_dataset, optimizer, criterion = ...
+    plugin = GeminiPlugin()
+
+    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+    booster = Booster(plugin=plugin)
+    model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+    ```

     Args:
         chunk_config_dict (dict, optional): chunk configuration dictionary.
````
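For orientation, here is a minimal end-to-end sketch of how the elided pieces (`model, train_dataset, optimizer, criterion = ...`) in the docstring example above could be filled in. The toy model, dataset and hyperparameters are invented for illustration, and the launch call and optimizer choice may differ across ColossalAI versions:

```python
import colossalai
import torch
from torch.utils.data import TensorDataset
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# Assumes the script is launched with torchrun; the empty config dict is illustrative.
colossalai.launch_from_torch(config={})

# Toy model, dataset and loss, purely for illustration.
model = torch.nn.Linear(32, 2)
train_dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
criterion = torch.nn.CrossEntropyLoss()
optimizer = HybridAdam(model.parameters(), lr=1e-3)

plugin = GeminiPlugin()
booster = Booster(plugin=plugin)
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8, shuffle=True)
model, optimizer, criterion, train_dataloader, _ = booster.boost(
    model, optimizer, criterion=criterion, dataloader=train_dataloader
)

model.train()
for inputs, labels in train_dataloader:
    inputs, labels = inputs.cuda(), labels.cuda()
    loss = criterion(model(inputs), labels)
    booster.backward(loss, optimizer)  # Gemini runs backward through the booster, not loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

Note that with Gemini the backward pass goes through `booster.backward(loss, optimizer)` rather than `loss.backward()`, since the plugin manages the chunked parameters and gradients itself.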

colossalai/booster/plugin/hybrid_parallel_plugin.py
Lines changed: 9 additions & 8 deletions

````diff
@@ -266,16 +266,17 @@ class HybridParallelPlugin(PipelinePluginBase):
     Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin.
     The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size).

-    Example:
-        >>> from colossalai.booster import Booster
-        >>> from colossalai.booster.plugin import HybridParallelPlugin
+    ```python
+    from colossalai.booster import Booster
+    from colossalai.booster.plugin import HybridParallelPlugin

-        >>> model, train_dataset, optimizer, criterion = ...
-        >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
+    model, train_dataset, optimizer, criterion = ...
+    plugin = HybridParallelPlugin(tp_size=2, pp_size=2)

-        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
-        >>> booster = Booster(plugin=plugin)
-        >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
+    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+    booster = Booster(plugin=plugin)
+    model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
+    ```

     Args:
         tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
````
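As a quick illustration of the `dp_size` arithmetic described in the docstring context above (the world size and parallel sizes below are hypothetical):

```python
# Hypothetical 8-GPU job: dp_size is derived rather than passed in.
world_size = 8           # e.g. torchrun --nproc_per_node 8
tp_size, pp_size = 2, 2  # user-provided parallel sizes

assert world_size % (tp_size * pp_size) == 0, "world_size must be divisible by tp_size * pp_size"
dp_size = world_size // (tp_size * pp_size)
print(dp_size)  # 2 -> two data-parallel replicas, each spanning a 2 x 2 tp/pp grid
```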

colossalai/booster/plugin/low_level_zero_plugin.py
Lines changed: 11 additions & 10 deletions

````diff
@@ -213,16 +213,17 @@ class LowLevelZeroPlugin(DPPluginBase):
     """
     Plugin for low level zero.

-    Example:
-        >>> from colossalai.booster import Booster
-        >>> from colossalai.booster.plugin import LowLevelZeroPlugin
-        >>>
-        >>> model, train_dataset, optimizer, criterion = ...
-        >>> plugin = LowLevelZeroPlugin()
-
-        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
-        >>> booster = Booster(plugin=plugin)
-        >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+    ```python
+    from colossalai.booster import Booster
+    from colossalai.booster.plugin import LowLevelZeroPlugin
+
+    model, train_dataset, optimizer, criterion = ...
+    plugin = LowLevelZeroPlugin()
+
+    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+    booster = Booster(plugin=plugin)
+    model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+    ```

     Args:
         strage (int, optional): ZeRO stage. Defaults to 1.
````

colossalai/booster/plugin/torch_ddp_plugin.py
Lines changed: 11 additions & 10 deletions

````diff
@@ -130,16 +130,17 @@ class TorchDDPPlugin(DPPluginBase):
     """
     Plugin for PyTorch DDP.

-    Example:
-        >>> from colossalai.booster import Booster
-        >>> from colossalai.booster.plugin import TorchDDPPlugin
-        >>>
-        >>> model, train_dataset, optimizer, criterion = ...
-        >>> plugin = TorchDDPPlugin()
-
-        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
-        >>> booster = Booster(plugin=plugin)
-        >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+    ```python
+    from colossalai.booster import Booster
+    from colossalai.booster.plugin import TorchDDPPlugin
+
+    model, train_dataset, optimizer, criterion = ...
+    plugin = TorchDDPPlugin()
+
+    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+    booster = Booster(plugin=plugin)
+    model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+    ```

     Args:
         broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Defaults to True.
````

colossalai/booster/plugin/torch_fsdp_plugin.py
Lines changed: 11 additions & 10 deletions

````diff
@@ -143,16 +143,17 @@ class TorchFSDPPlugin(DPPluginBase):
     """
     Plugin for PyTorch FSDP.

-    Example:
-        >>> from colossalai.booster import Booster
-        >>> from colossalai.booster.plugin import TorchFSDPPlugin
-        >>>
-        >>> model, train_dataset, optimizer, criterion = ...
-        >>> plugin = TorchFSDPPlugin()
-
-        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
-        >>> booster = Booster(plugin=plugin)
-        >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+    ```python
+    from colossalai.booster import Booster
+    from colossalai.booster.plugin import TorchFSDPPlugin
+
+    model, train_dataset, optimizer, criterion = ...
+    plugin = TorchFSDPPlugin()
+
+    train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+    booster = Booster(plugin=plugin)
+    model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+    ```

     Args:
         See https://pytorch.org/docs/stable/fsdp.html for details.
````

colossalai/cluster/dist_coordinator.py
Lines changed: 25 additions & 20 deletions

````diff
@@ -20,14 +20,16 @@ class in the whole program.
         - master: the process with rank 0
         - node master: the process with local rank 0 on the current node

-    Example:
-        >>> from colossalai.cluster.dist_coordinator import DistCoordinator
-        >>> coordinator = DistCoordinator()
-        >>>
-        >>> if coordinator.is_master():
-        >>>     do_something()
-        >>>
-        >>> coordinator.print_on_master('hello world')
+
+    ```python
+    from colossalai.cluster.dist_coordinator import DistCoordinator
+    coordinator = DistCoordinator()
+
+    if coordinator.is_master():
+        do_something()
+
+    coordinator.print_on_master('hello world')
+    ```

     Attributes:
         rank (int): the rank of the current process
@@ -131,11 +133,13 @@ def priority_execution(self, executor_rank: int = 0, process_group: ProcessGroup
        other processes in the same process group. This is often useful when downloading is required
        as we only want to download in one process to prevent file corruption.

-        Example:
-            >>> from colossalai.cluster import DistCoordinator
-            >>> dist_coordinator = DistCoordinator()
-            >>> with dist_coordinator.priority_execution():
-            >>>     dataset = CIFAR10(root='./data', download=True)
+
+        ```python
+        from colossalai.cluster import DistCoordinator
+        dist_coordinator = DistCoordinator()
+        with dist_coordinator.priority_execution():
+            dataset = CIFAR10(root='./data', download=True)
+        ```

         Args:
             executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
@@ -174,13 +178,14 @@ def on_master_only(self, process_group: ProcessGroup = None):
        """
        A function wrapper that only executes the wrapped function on the master process (rank 0).

-        Example:
-            >>> from colossalai.cluster import DistCoordinator
-            >>> dist_coordinator = DistCoordinator()
-            >>>
-            >>> @dist_coordinator.on_master_only()
-            >>> def print_on_master(msg):
-            >>>     print(msg)
+        ```python
+        from colossalai.cluster import DistCoordinator
+        dist_coordinator = DistCoordinator()
+
+        @dist_coordinator.on_master_only()
+        def print_on_master(msg):
+            print(msg)
+        ```

        """
        is_master = self.is_master(process_group)

````
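Taken together, the three updated docstring examples correspond to a usage pattern roughly like the following sketch; the dataset download and the master-only helper are illustrative placeholders, and the process group is assumed to be initialized beforehand:

```python
from torchvision.datasets import CIFAR10
from colossalai.cluster import DistCoordinator

# DistCoordinator expects torch.distributed to be initialized already
# (e.g. via colossalai.launch_from_torch under torchrun).
coordinator = DistCoordinator()

# Log once from rank 0 instead of once per process.
coordinator.print_on_master(f"world size: {coordinator.world_size}")

# Let one rank download first; the others block, then reuse the cached files.
with coordinator.priority_execution():
    dataset = CIFAR10(root="./data", download=True)

@coordinator.on_master_only()
def save_report(msg):
    print(msg)  # placeholder for real master-only work (logging, checkpointing, ...)

save_report("epoch finished")
```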

docs/source/en/features/shardformer.md
Lines changed: 71 additions & 3 deletions

````diff
@@ -214,17 +214,83 @@ In addition, xFormers's `cutlass_op` can serve as a backup for flash attention.
 Enabling `Shardformer` through `Booster` initialized with `HybridParallelPlugin` is the recommended way to awaken the power of Shardformer.
 The main reason is that pipeline parallelism cannot successfully work without the calling of `execute_pipeline` method of `Booster`. Besides, `HybridParallelPlugin` provides the capacity to combine the features of `Shardformer` with other useful features, such as mixed precision training or Zero.

-More details about this usage can be found in chapter [Booster API](../basics/booster_api.md) and [Booster Plugins](../basics/booster_plugins.md).
+[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example of how to trigger `Shardformer` through `HybridParallelPlugin`. Move to the root directory of this example, and execute
+```bash
+torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin "hybrid_parallel" --model_type "bert"
+```
+Then you can start finetuning a BERT model wrapped by `Shardformer`. The wrapping is handled by `HybridParallelPlugin`.
+
+Let's delve into the code of `finetune.py`:
+
+In the `main` function, the plugin is created by the following code:
+```python
+...
+elif args.plugin == "hybrid_parallel":
+    # modify the param accordingly for finetuning test cases
+    plugin = HybridParallelPlugin(
+        tp_size=1,
+        pp_size=2,
+        num_microbatches=None,
+        microbatch_size=1,
+        enable_all_optimization=True,
+        zero_stage=1,
+        precision="fp16",
+        initial_scale=1,
+    )
+```
+Here you can change the plugin configuration by setting `tp_size`, `pp_size` or `zero_stage` to other values. More details about plugin configuration can be found in [Booster Plugins Doc](../basics/booster_plugins.md).
+
+If pipeline parallelism is not enabled, training works the same way as with other booster plugins (first boost with Booster, then run forward and backward in the usual way).
+However, if pipeline parallelism is enabled, there are several differences from the normal case:
+
+1. Before running forward or backward, the criterion function (loss function) is wrapped to meet the argument requirements of the pipeline:
+    ```python
+    def _criterion(outputs, inputs):
+        outputs = output_transform_fn(outputs)
+        loss = criterion(outputs)
+        return loss
+    ```
+
+2. In the `train_epoch` function, the dataloader is converted into an `Iterator` before running the pipeline:
+    ```python
+    train_dataloader_iter = iter(train_dataloader)
+    ```

-[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example on how to trigger `Shardformer` through `HybridParallelPlugin`. Please be aware that there's a difference in the way of doing forward and backward between the situation of using pipeline and not using pipeline.
+3. Forward and backward passes are performed by calling the `Booster.execute_pipeline` method:
+    ```python
+    outputs = booster.execute_pipeline(
+        train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
+    )
+    ```
+    The backward pass is completed inside this method, so there is no need to call `loss.backward()` afterwards.
+    More details about `Booster.execute_pipeline` can be found in [Booster API Doc](../basics/booster_api.md).


 #### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended)

 You can also use Shardformer through manually calling Shardformer APIs. However, this usage is not recommended since pipeline parallelism can't run without `Booster`.

 [Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py)
-is an example on how to trigger `Shardformer` through calling Shardformer APIs.
+is an example of how to trigger `Shardformer` by calling Shardformer APIs. In the `train` function of the example code, the model is wrapped by `Shardformer` through the following lines:
+```python
+...
+if dist.get_world_size() > 1:
+    tp_group = dist.new_group(backend="nccl")
+
+    # First create configuration for Shardformer
+    shard_config = ShardConfig(
+        tensor_parallel_process_group=tp_group,
+        enable_tensor_parallelism=True,
+        enable_all_optimization=True
+    )
+
+    # Then create ShardFormer object with created config
+    shard_former = ShardFormer(shard_config=shard_config)
+
+    # Finally shard the model using ShardFormer.optimize method
+    model, _ = shard_former.optimize(model)
+...
+```

 ### Precautions

@@ -241,6 +307,8 @@ is an example on how to trigger `Shardformer` through calling Shardformer APIs.

 ## How Shardformer Works

+### Main Idea
+
 Generally, Shardformer works through the following four kinds of *replacements*:

 1. Replacing original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module.
````
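Putting the three pipeline-specific points above together, one training step in the bert example reduces to roughly the following sketch. It is condensed and lightly renamed (`num_steps_per_epoch` and `lr_scheduler` are stand-ins for the example's own bookkeeping), so treat it as an illustration rather than a verbatim excerpt:

```python
# Condensed sketch of pipeline-parallel training with HybridParallelPlugin.
# Assumes model, optimizer, dataloader and lr_scheduler have already been boosted.
train_dataloader_iter = iter(train_dataloader)

for _ in range(num_steps_per_epoch):
    # Forward AND backward both happen inside execute_pipeline,
    # so loss.backward() must not be called afterwards.
    outputs = booster.execute_pipeline(
        train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
    )

    # With pipeline parallelism, only the last pipeline stage holds the loss.
    if booster.plugin.stage_manager.is_last_stage():
        loss = outputs["loss"]

    optimizer.step()
    optimizer.zero_grad()
    lr_scheduler.step()
```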

docs/source/zh-Hans/features/shardformer.md
Lines changed: 71 additions & 2 deletions

````diff
@@ -207,16 +207,83 @@ The configuration of Shardformer is controlled by the parameters of the `ShardConfig` class:

 Enabling `Shardformer` through a `Booster` initialized with `HybridParallelPlugin` is the most recommended usage. The main reason is that pipeline parallelism cannot work properly without calling the `execute_pipeline` method of `Booster`. In addition, `HybridParallelPlugin` provides the ability to combine the features of `Shardformer` with other features such as mixed precision training or Zero.

-More details about this usage can be found in the [Booster API documentation](../basics/booster_api.md) and the [Booster Plugins documentation](../basics/booster_plugins.md). [Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example of enabling `Shardformer` through `HybridParallelPlugin`.
+[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example of enabling `Shardformer` through `HybridParallelPlugin`.
+Move to the root directory of the example and run:
+```bash
+torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin "hybrid_parallel" --model_type "bert"
+```
+You can then finetune a BERT model wrapped by `Shardformer`; the wrapping is handled by `HybridParallelPlugin`.
+
+Next, let's dig into the code of `finetune.py`:
+
+In the `main` function, the hybrid parallel plugin is created by the following code:
+```python
+...
+elif args.plugin == "hybrid_parallel":
+    # modify the param accordingly for finetuning test cases
+    plugin = HybridParallelPlugin(
+        tp_size=1,
+        pp_size=2,
+        num_microbatches=None,
+        microbatch_size=1,
+        enable_all_optimization=True,
+        zero_stage=1,
+        precision="fp16",
+        initial_scale=1,
+    )
+```
+Here you can change the plugin configuration by setting different values of `tp_size`, `pp_size` or `zero_stage`. More information about plugin configuration can be found in the [Booster Plugins documentation](../basics/booster_plugins.md).
+
+When pipeline parallelism is not enabled, training works the same way as with other plugins (first wrap the model and optimizer with Booster, then run forward and backward in the usual way). However, when pipeline parallelism is enabled, there are several differences from the usual case:
+
+1. Before running forward and backward, the criterion function (loss function) needs to be wrapped to meet the argument requirements of pipeline parallelism:
+    ```python
+    def _criterion(outputs, inputs):
+        outputs = output_transform_fn(outputs)
+        loss = criterion(outputs)
+        return loss
+    ```

+2. In the `train_epoch` function, the dataloader needs to be converted into an `Iterator` before the pipeline's forward and backward passes:
+    ```python
+    train_dataloader_iter = iter(train_dataloader)
+    ```
+
+3. Forward and backward passes are performed by calling the `Booster.execute_pipeline` method:
+    ```python
+    outputs = booster.execute_pipeline(
+        train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
+    )
+    ```
+    This method performs the backward pass automatically, so there is no need to call `loss.backward()` after it.
+    More information about `Booster.execute_pipeline` can be found in the [Booster API documentation](../basics/booster_api.md).

 #### 2. Enabling Shardformer through Shardformer APIs (Not Recommended)

 You can also enable Shardformer by manually calling the Shardformer APIs. However, this usage is not recommended, because pipeline parallelism cannot run properly without `Booster`.

 [Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py)
 is an example of enabling `Shardformer` by calling the Shardformer APIs.
-
+In the `train` function of the example code, the model is wrapped by the following lines:
+```python
+...
+if dist.get_world_size() > 1:
+    tp_group = dist.new_group(backend="nccl")
+
+    # First create configuration for Shardformer
+    shard_config = ShardConfig(
+        tensor_parallel_process_group=tp_group,
+        enable_tensor_parallelism=True,
+        enable_all_optimization=True
+    )
+
+    # Then create ShardFormer object with created config
+    shard_former = ShardFormer(shard_config=shard_config)
+
+    # Finally shard the model using ShardFormer.optimize method
+    model, _ = shard_former.optimize(model)
+...
+```

 ### Precautions

@@ -234,6 +301,8 @@ The configuration of Shardformer is controlled by the parameters of the `ShardConfig` class:

 ## How Shardformer Works

+### Design Philosophy
+
 Generally speaking, Shardformer works through the following four kinds of "replacements":

 1. Replacing the original PyTorch modules (e.g. `nn.Linear`, `nn.Embedding`) with distributed modules of our own design.
````
