
Commit 0cf393f

FoolPlayer authored and flybird11111 committed
[Pipeline inference] Combine kvcache with pipeline inference (hpcaitech#4938)
* merge kvcache with pipeline inference and refactor the code structure
* support ppsize > 2
* refactor pipeline code
* do pre-commit
* modify benchmark
* fix bench mark
* polish code
* add docstring and update readme
* refactor the code
* fix some logic bug of ppinfer
* polish readme
* fix typo
* skip infer test
1 parent 15fee0d commit 0cf393f

19 files changed, +881 -704 lines changed


colossalai/inference/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
 from .pipeline import PPInferEngine
 
-__all__ = ["PPInferEngine"]
+
+__all__ = ['PPInferEngine']

colossalai/inference/pipeline/README.md

Lines changed: 37 additions & 38 deletions
@@ -17,7 +17,7 @@
 Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManager` and `generate` [schedule](https://github.com/hpcaitech/ColossalAI/blob/feature/pipeline-infer/colossalai/pipeline/schedule/generate.py).
 
 1. `PPInderEngine` is the High-Level API for users to use. It is responsible for the following tasks:
-    - Initialize the pipeline inference environment with `PipelineStageManager` and mdoel with `ShardFormer`.
+    - Initialize the pipeline inference environment with `PipelineStageManager` and model with `ShardFormer`.
     - Run the pipeline inference model.
 
 2. `MicroBatchManager` is a structure to manage the micro-batch information. It is responsible for the following tasks:
@@ -31,54 +31,53 @@ Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManag
 
 ### Example
 ```python
-from colossalai.pipeline import PPInferEngine
-# Suppose the pipeline size is 2, and use fp16 to do infenrence. Use Llama as an example.
-model = LlamaForCausalLM.from_pretrained('/path/to/model')
-inputs = tokenizer("Hello, my dog is cute", "What a good day", return_tensors="pt")
-engine = PPInferEngine(
-    pp_size=2,
-    dtype='fp16',
-    micro_batch_size=1,
-    new_length=10,
-    model=model,
-    model_policy=LlamaForCausalLMPipelinePolicy())
-
-output = engine.inference([inputs])
+from colossalai.inference import PPInferEngine
+from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
+import colossalai
+from transformers import LlamaForCausalLM, LlamaTokenizer
 
-```
+colossalai.launch_from_torch(config={})
+
+model = LlamaForCausalLM.from_pretrained("/path/to/model")
+tokenizer = LlamaTokenizer.from_pretrained("/path/to/model")
 
-### Quick start
-```shell
-cd benchmark
-sh run.sh
+# assume the model is inferred with 2 pipeline stages
+inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=32)
+
+input = ["Introduce a landmark in London","Introduce a landmark in Singapore"]
+data = tokenizer(input, return_tensors='pt')
+output = inferengine.inference(data.to('cuda'))
+print(tokenizer.batch_decode(output))
 ```
 
 ## Performance
 
-We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2*A10, 20G.
+We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2 * A10, 20G / 2 * A800, 80G.
 
-### Llama Throughput(tokens/s)
+### Llama Throughput (tokens/s) | input length=1024, output length=128
 
-#### 7b, fp16
+#### A10 7b, fp16
 | batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
 | :---: | :---: | :---: | :---: | :---: | :---: | :---:|
-| Pipeline Inference(1024, 128) | 33.31 | 59.98 | 98.92 | 143.47 | 152.61 | OOM |
-| Hugging Face(1024, 128) | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM |
-| Pipeline Inference(512, 512) | 43.37 | 82.81 | 148.03 | 229.06 | 238.67 | 312.82 |
-| Hugging Face(512, 512) | 49.13 | 84.91 | 132.87 | 178.30 | OOM| OOM |
+| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM |
+| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM |
 
-#### 7b, fp32
+#### A10 13b, fp16
 | batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
 | :---: | :---: | :---: | :---: | :---: |
-| Pipeline Inference(1024, 128) | 20.61 | 31.23 | 45.20 | 47.46 |
-| Hugging Face(1024, 128) | 19.80 | 29.37| OOM | OOM |
-| Pipeline Inference(512, 512) | 28.07 | 46.76 | 79.35 | 81.70 |
-| Hugging Face(512, 512) | 25.67 | 43.97 | 60.67 | OOM |
+| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
+| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |
 
-#### 13b, fp16
-| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
-| :---: | :---: | :---: | :---: | :---: |
-| Pipeline Inference(1024, 128) | 21.73 | 38.06 | 61.02 | 64.30 |
-| Hugging Face(1024, 128) | 23.48 | 37.59 | 53.44 | OOM |
-| Pipeline Inference(512, 512) | 26.65 | 49.48 | 86.11 | 88.44 |
-| Hugging Face(512, 512) | 27.45 | 47.74 | 74.46 | OOM |
+
+#### A800 7b, fp16
+| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
+| :---: | :---: | :---: | :---: | :---: | :---: |
+| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
+| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |
+
+
+#### A800 13b, fp16
+| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
+| :---: | :---: | :---: | :---: | :---: | :---: |
+| Pipeline Inference | 41.78 | 94.18 | 172.67| 310.75| 470.15 |
+| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 |
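
A note on reading the updated throughput tables: the README does not spell out how the tokens/s figure is computed. Assuming the usual convention of generated tokens divided by end-to-end latency (an assumption, not stated in this commit), a table entry can be reproduced from a raw timing as in the minimal sketch below; the helper `tokens_per_second` and the 8.8 s timing are hypothetical.

```python
# Minimal sketch, assuming throughput = generated tokens / end-to-end latency.
# Neither the helper nor the timing comes from this commit; they only
# illustrate how a table entry such as 232.7 tokens/s could arise.
def tokens_per_second(batch_size: int, new_tokens: int, end2end_seconds: float) -> float:
    return batch_size * new_tokens / end2end_seconds

# 16 sequences x 128 generated tokens finishing in ~8.8 s is about 232.7 tokens/s,
# matching the 16(8) column of the A10 7b fp16 row above.
print(f"{tokens_per_second(16, 128, 8.8):.1f} tokens/s")
```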

colossalai/inference/pipeline/benchmark/benchmark.py

Lines changed: 5 additions & 2 deletions
@@ -7,7 +7,7 @@
 
 import colossalai
 from colossalai.inference import PPInferEngine
-from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy
+from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
 
 GIGABYTE = 1024**3
 MEGABYTE = 1024 * 1024
@@ -117,8 +117,11 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
     micro_batch_size=args.mb_size,
     new_length=args.new_length,
     model=model,
-    model_policy=LlamaForCausalLMPipelinePolicy(),
+    model_policy=LlamaModelInferPolicy(),
     verbose=True,
+    max_batch_size=args.mb_size,
+    max_input_len=args.seq_len,
+    max_output_len=args.seq_len + args.new_length + 256,
 )
 data = data_gen(args.batch_size, args.seq_len)
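
The three new arguments bound the KV cache the engine now pre-allocates. As a quick back-of-the-envelope check, the sketch below reuses the `max_total_token_num` formula added to `PPInferEngine._init_manager` later in this commit; the concrete values (seq_len=1024, new_length=128, mb_size=1) are just one example configuration, not taken verbatim from the scripts.

```python
# Sketch only: mirrors max_total_token_num = max_batch_size * (max_input_len + max_output_len)
# from _init_manager. seq_len / new_length / mb_size below are example values.
seq_len, new_length, mb_size = 1024, 128, 1

max_input_len = seq_len
max_output_len = seq_len + new_length + 256   # same headroom as the benchmark passes
max_total_token_num = mb_size * (max_input_len + max_output_len)

print(max_output_len)        # 1408
print(max_total_token_num)   # 2432 token slots reserved per cache manager
```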

colossalai/inference/pipeline/benchmark/run.sh

Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
 script_dir=$(cd "$(dirname "$0")" && pwd)
 cd "${script_dir}"
 
-# 7b, fp32, 2 gpu, 1024, 128
+# 7b, fp16, 2 gpu, 1024, 128
 for BATCH_SIZE in 2 4 8 16; do
     CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
         --model="7b" \
@@ -13,7 +13,7 @@ for BATCH_SIZE in 2 4 8 16; do
         --pp_size=2
 done
 
-# 7b, fp32, 2 gpu, 512, 512
+# 7b, fp16, 2 gpu, 512, 512
 for BATCH_SIZE in 2 4 8 16 32; do
     CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
         --model="7b" \
@@ -25,7 +25,7 @@ for BATCH_SIZE in 2 4 8 16 32; do
         --pp_size=2
 done
 
-# 7b, fp32, 2 gpu, 1024, 128
+# 7b, fp16, 2 gpu, 1024, 128
 for BATCH_SIZE in 2 4 8; do
     CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
         --model="13b" \

colossalai/inference/pipeline/engine.py

Lines changed: 74 additions & 17 deletions
@@ -1,12 +1,14 @@
 import torch
 import torch.nn as nn
+from transformers.tokenization_utils_base import BatchEncoding
 
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.pipeline.schedule.generate import GenerateSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer.policies.base_policy import Policy
 
+from ..tensor_parallel.kvcache_manager import MemoryManager
 from .microbatch_manager import MicroBatchManager
 
 
@@ -23,20 +25,29 @@ class PPInferEngine:
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
         new_length (int): the new length of the input sequence.
         early_stopping (bool): whether to stop early.
+        max_batch_size (int): the maximum batch size.
+        max_input_len (int): the maximum input length.
+        max_output_len (int): the maximum output length.
 
     Example:
 
     ```python
-    from colossalai.ppinference import PPInferEngine
-    from transformers import GPT2LMHeadModel, GPT2Tokenizer
+    from colossalai.inference import PPInferEngine
+    from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
+    import colossalai
+    from transformers import LlamaForCausalLM, LlamaTokenizer
 
-    model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
-    # assume the model is infered with 4 pipeline stages
-    inferengine = PPInferEngine(pp_size=4, model=model, model_policy={Your own policy for pipeline sharding})
+    colossalai.launch_from_torch(config={})
+
+    model = LlamaForCausalLM.from_pretrained("your_path_to_model")
+    tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
+    # assume the model is infered with 2 pipeline stages
+    inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8)
+
+    input = ["Introduce a landmark in China ","Introduce a landmark in China "]
+    data = tokenizer(input, return_tensors='pt')
+    output = inferengine.inference([data.to('cuda').data])
 
-    input = ["Hello, my dog is cute, and I like"]
-    tokenized_input = tokenizer(input, return_tensors='pt')
-    output = engine.inference([tokenized_input])
     ```
 
     """
@@ -51,31 +62,63 @@ def __init__(
         new_length: int = 32,
         micro_batch_size: int = 1,
         micro_batch_buffer_size: int = None,
+        max_batch_size: int = 4,
+        max_input_len: int = 32,
+        max_output_len: int = 32,
         verbose: bool = False,
         # TODO: implement early_stopping, and various gerneration options
         early_stopping: bool = False,
         do_sample: bool = False,
         num_beams: int = 1,
     ) -> None:
         assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided."
+        assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
+
+        max_output_len = max(max_output_len, max_input_len + new_length)
+
         self.pp_size = pp_size
+        if dtype == "fp16":
+            self.dtype = torch.float16
+            model.half()
+        elif dtype == "bf16":
+            self.dtype = torch.bfloat16
+            model.to(torch.bfloat16)
+        else:
+            self.dtype = torch.float32
         self.pg_mesh = ProcessGroupMesh(pp_size)
         self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
+        self.model = pp_model or self._shardformer(model, model_policy)
+        self.cache_manager_list = [
+            self._init_manager(max_batch_size, max_input_len, max_output_len)
+            for _ in range(micro_batch_buffer_size or pp_size)
+        ]
         self.mb_manager = MicroBatchManager(
-            self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size
+            self.stage_manager.stage,
+            new_length,
+            micro_batch_size,
+            micro_batch_buffer_size or pp_size,
+            max_input_len,
+            max_output_len,
+            self.cache_manager_list,
         )
         self.verbose = verbose
         self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
 
-        assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
-        if dtype == "fp16":
-            model.half()
-        elif dtype == "bf16":
-            model.to(torch.bfloat16)
-        self.model = pp_model or self._shardformer(model, model_policy)
-
     def inference(self, input_list):
-        out, timestamp = self.schedule.generate_step(self.model, iter(input_list))
+        """
+        Args:
+            input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`.
+
+        Returns:
+            out (list): a list of output data, each element is a list of token.
+            timestamp (float): the time cost of the inference, only return when verbose is `True`.
+        """
+        assert isinstance(
+            input_list, (BatchEncoding, dict)
+        ), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}."
+        if isinstance(input_list, BatchEncoding):
+            input_list = input_list.data
+        out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
         if self.verbose:
             return out, timestamp
         else:
@@ -95,3 +138,17 @@ def _shardformer(self, model, model_policy):
         shardformer = ShardFormer(shard_config=shardconfig)
         shard_model, _ = shardformer.optimize(model, model_policy)
         return shard_model.cuda()
+
+    def _init_manager(self, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
+        max_total_token_num = max_batch_size * (max_input_len + max_output_len)
+        head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
+        head_num = self.model.config.num_attention_heads
+        num_hidden_layers = (
+            self.model.config.num_hidden_layers
+            if hasattr(self.model.config, "num_hidden_layers")
+            else self.model.config.num_layers
+        )
+        layer_num = num_hidden_layers // self.pp_size
+
+        cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
+        return cache_manager
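
`_init_manager` sizes one `MemoryManager` per micro-batch buffer slot, but the diff does not show what that costs in memory. Below is a rough estimate, assuming the cache stores one key and one value vector per reserved token slot for each layer hosted on the local pipeline stage; this is an assumption about `MemoryManager`, not something the commit states, and the Llama-7B shapes are illustrative.

```python
# Rough estimate only; the actual MemoryManager layout may differ.
def kv_cache_bytes(max_total_token_num: int, head_num: int, head_dim: int,
                   layer_num: int, dtype_bytes: int = 2) -> int:
    # K and V (factor 2) for every token slot, every local layer, every head.
    return max_total_token_num * layer_num * 2 * head_num * head_dim * dtype_bytes

# Llama-7B-like shapes on pp_size=2 (hypothetical numbers): hidden_size=4096
# -> head_num=32, head_dim=128; 32 hidden layers split over 2 stages -> 16 local layers.
per_manager = kv_cache_bytes(max_total_token_num=2432, head_num=32, head_dim=128, layer_num=16)
print(f"{per_manager / 1024**3:.2f} GiB per cache manager")  # ~0.59 GiB in fp16
```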
