
Commit e0757c3

[inference] Dynamic Batching for Single and Multiple GPUs (#4831)
* finish batch manager
* 1
* first
* fix
* fix dynamic batching
* llama infer
* finish test
* support different lengths generating
* del prints
* del prints
* fix
* fix bug
---------
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
1 parent 8aed02b commit e0757c3

File tree

16 files changed: +1221 -48 lines changed


colossalai/inference/dynamic_batching/__init__.py

Whitespace-only changes.
Lines changed: 346 additions & 0 deletions
@@ -0,0 +1,346 @@
import collections
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import torch

from colossalai.inference.tensor_parallel import MemoryManager

# make batch infer state an attr of InferBatch


class InferSamplingParams:
    def __init__(
        self,
        do_sample: bool = False,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        vocab_size: int = -1,
    ) -> None:
        self.do_sample = do_sample
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        if self.top_k == -1:
            self.top_k = vocab_size
        return


@dataclass
class InferBatch:
    batch_id: int
    requests: List
    requests_idx_mapping: Dict[int, int]

    input_ids: torch.Tensor

    all_input_ids: List[List[int]]
    input_lengths: List[int]

    out_token_id_counts: List
    sampling_param_list: List[InferSamplingParams]

    nopad_total_token_num: int
    nopad_max_len_in_batch: int
    nopad_b_loc: torch.Tensor
    nopad_b_start_loc: torch.Tensor
    nopad_b_seq_len: torch.Tensor
    cache_manager: MemoryManager
    max_total_len: int

    @classmethod
    @torch.no_grad()
    def init_batch(
        cls,
        batch_id,
        requests,
        dtype: torch.dtype,
        device: torch.device,
        cache_manager: MemoryManager,
        vocab_size: int,
        max_total_len: int,
    ) -> "InferBatch":
        input_lengths = []
        all_input_ids = []
        requests_idx_mapping = {}

        out_token_id_counts = []
        sampling_param_list = []

        nopad_total_token_num = 0
        nopad_max_len_in_batch = 0
        nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda")
        # to avoid a memory leak, we pre-allocate 12 extra slots for each batch
        nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda")
        for i, r in enumerate(requests):
            # request id -> idx in list mapping
            requests_idx_mapping[r["request_id"]] = i

            tokenized_input = r["input_id"]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)
            all_input_ids.append(tokenized_input)
            out_token_id_counts.append(collections.defaultdict(int))

            # postprocessor
            sampling_param = r["sampling_param"]
            sampling_param["vocab_size"] = vocab_size
            sampling_param_list.append(InferSamplingParams(**sampling_param))

            nopad_total_token_num += input_length
            nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length)

        nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda")
        nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]

        if len(requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
        else:
            input_ids = all_input_ids[0]

        # Create tensors on device
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)

        return cls(
            batch_id=batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            input_lengths=input_lengths,
            all_input_ids=all_input_ids,
            nopad_total_token_num=nopad_total_token_num,
            nopad_max_len_in_batch=nopad_max_len_in_batch,
            nopad_b_loc=nopad_b_loc,
            nopad_b_start_loc=nopad_b_start_loc,
            nopad_b_seq_len=nopad_b_seq_len,
            out_token_id_counts=out_token_id_counts,
            sampling_param_list=sampling_param_list,
            cache_manager=cache_manager,
            max_total_len=max_total_len,
        )

    @torch.no_grad()
    def free_self(self) -> None:
        """
        Free the cache memory held by the InferBatch itself.
        """
        # gather the KV-cache slot indices owned by each request and hand them back to the cache manager
        remove_index = []
        for idx in range(len(self)):
            remove_index.append(
                self.nopad_b_loc[
                    idx,
                    (self.nopad_max_len_in_batch - 1)
                    - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1),
                ]
            )
        remove_index = torch.cat(remove_index, dim=-1)
        self.cache_manager.free(remove_index)

    @torch.no_grad()
    def filter(self, request_ids: List[int]) -> "InferBatch":
        """
        Filter out finished requests and return a new InferBatch containing the remaining ones.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(request_ids) == len(self):
            return self
        requests_idx_mapping = {}
        indices = []
        requests = []
        all_input_ids = []
        input_lengths = []
        nopad_total_token_num = 0
        nopad_max_len_in_batch = 0
        nopad_b_loc = torch.empty((len(request_ids), self.max_total_len + 12), dtype=torch.long, device="cuda")
        nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda")
        nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda")

        left_idx = []
        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            left_idx.append(idx)

        # free the cache slots of every request that is not kept
        left_idx_set = set(left_idx)
        remove_index = []
        for idx in range(len(self)):
            if idx not in left_idx_set:
                remove_index.append(
                    self.nopad_b_loc[
                        idx,
                        (self.nopad_max_len_in_batch - 1)
                        - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1),
                    ]
                )
        remove_index = torch.cat(remove_index, dim=-1)
        self.cache_manager.free(remove_index)

        nopad_max_len_in_batch = 0
        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)

        nopad_b_seq_len[:] = self.nopad_b_seq_len[indices]
        nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item()
        nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]
        nopad_total_token_num = torch.sum(nopad_b_seq_len).item()

        nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[
            indices,
            (self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1) : (self.nopad_max_len_in_batch - 1),
        ]
        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            requests.append(self.requests[idx])
            all_input_ids.append(self.all_input_ids[idx])
            input_lengths.append(self.input_lengths[idx])

        input_ids = self.input_ids[indices]

        return InferBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            input_lengths=input_lengths,
            all_input_ids=all_input_ids,
            nopad_total_token_num=nopad_total_token_num,
            nopad_max_len_in_batch=nopad_max_len_in_batch,
            nopad_b_loc=nopad_b_loc,
            nopad_b_start_loc=nopad_b_start_loc,
            nopad_b_seq_len=nopad_b_seq_len,
            out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices],
            sampling_param_list=[self.sampling_param_list[_i] for _i in indices],
            cache_manager=self.cache_manager,
            max_total_len=self.max_total_len,
        )

    @classmethod
    @torch.no_grad()
    def merge(cls, batch1, batch2) -> "InferBatch":
        """
        Return a new InferBatch merged from batch1 and batch2.
        """
        requests = batch1.requests + batch2.requests
        requests_idx_mapping = {}
        new_batch_size = len(batch1) + len(batch2)

        input_ids = batch1.input_ids.new_empty(new_batch_size)
        all_input_ids = []
        input_lengths = []
        out_token_id_counts = []
        sampling_param_list = []

        cumulative_batch_size = 0
        nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num
        nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2.nopad_max_len_in_batch)
        max_total_len = max(batch1.max_total_len, batch2.max_total_len)
        nopad_b_loc = torch.empty((new_batch_size, batch1.max_total_len + 12), dtype=torch.long, device="cuda")
        nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda")
        nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda")
        nopad_start_loc_len_temp = 0
        batches = [batch1, batch2]
        for i, batch in enumerate(batches):
            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size
            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            input_ids[start_index:end_index] = batch.input_ids
            nopad_b_seq_len[start_index:end_index] = batch.nopad_b_seq_len
            nopad_b_start_loc[start_index:end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp
            nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1]
            # copy this batch's cache-location table, right-aligned to the wider merged table
            nopad_b_loc[
                start_index:end_index,
                nopad_max_len_in_batch - batch.nopad_max_len_in_batch : nopad_max_len_in_batch - 1,
            ] = batch.nopad_b_loc[:, : batch.nopad_max_len_in_batch - 1]

            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            out_token_id_counts.extend(batch.out_token_id_counts)
            sampling_param_list.extend(batch.sampling_param_list)
            # update the running batch offset
            cumulative_batch_size += len(batch)

        nopad_b_loc[:, nopad_max_len_in_batch - 1] = (
            nopad_total_token_num - new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device="cuda")
        )
        return InferBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            input_lengths=input_lengths,
            all_input_ids=all_input_ids,
            nopad_total_token_num=nopad_total_token_num,
            nopad_max_len_in_batch=nopad_max_len_in_batch,
            nopad_b_loc=nopad_b_loc,
            nopad_b_start_loc=nopad_b_start_loc,
            nopad_b_seq_len=nopad_b_seq_len,
            out_token_id_counts=out_token_id_counts,
            sampling_param_list=sampling_param_list,
            cache_manager=batches[0].cache_manager,
            max_total_len=max_total_len,
        )

    def __len__(self):
        return len(self.requests)

    def get_post_sample_tensors(
        self,
    ) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int
    ]:
        # flatten the per-request sampling parameters and generated-token counts so the
        # post-processing (penalty / top-p / top-k) step can consume them on GPU
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        temperatures: List[float] = []
        top_ps: List[float] = []
        top_ks: List[int] = []
        p_token_ids: List[int] = []
        p_token_counts: List[int] = []
        p_seq_len: List[int] = [
            0,
        ]
        p_max_len_in_batch: int = 0
        for i, id_to_count in enumerate(self.out_token_id_counts):
            sample_param = self.sampling_param_list[i]
            presence_penalties.append(sample_param.presence_penalty)
            frequency_penalties.append(sample_param.frequency_penalty)
            temperatures.append(sample_param.temperature)
            top_ps.append(sample_param.top_p)
            top_ks.append(sample_param.top_k)

            for token_id, count in id_to_count.items():
                p_token_ids.append(token_id)
                p_token_counts.append(count)
            p_seq_len.append(len(id_to_count))
            p_max_len_in_batch = max(p_max_len_in_batch, len(id_to_count))

        presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda")
        frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda")
        temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda")
        top_ps = torch.tensor(top_ps, dtype=torch.float, device="cuda")
        top_ks = torch.tensor(top_ks, dtype=torch.int32, device="cuda")
        p_token_ids = torch.tensor(p_token_ids, dtype=torch.int32, device="cuda")
        p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda")
        p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda")
        p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, dtype=torch.int32)
        return (
            presence_penalties,
            frequency_penalties,
            temperatures,
            top_ps,
            top_ks,
            p_token_ids,
            p_token_counts,
            p_cumsum_seq_len,
            p_max_len_in_batch,
        )
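For context, here is a minimal sketch of how a dynamic-batching manager might build and inspect an InferBatch. It is an illustration, not part of this commit: the import path (the new file's name is not visible in this diff), the demo_build_batch helper, and the concrete token ids and sampling values are assumptions; a CUDA device and an already-constructed MemoryManager are required, since init_batch allocates its bookkeeping tensors on "cuda" and stores the cache manager for later filter()/free_self() calls.

import torch

# assumed import path for the file added above
from colossalai.inference.dynamic_batching.infer_batch import InferBatch
from colossalai.inference.tensor_parallel import MemoryManager


def demo_build_batch(cache_manager: MemoryManager, vocab_size: int = 32000, max_total_len: int = 2048) -> InferBatch:
    # Each request is a dict with a request id, the tokenized prompt, and its
    # per-request sampling parameters (vocab_size is filled in by init_batch).
    requests = [
        {"request_id": 0, "input_id": [1, 306, 4658], "sampling_param": {"temperature": 0.8, "top_p": 0.95}},
        {"request_id": 1, "input_id": [1, 15043, 29892, 3186], "sampling_param": {"top_k": 50}},
    ]
    batch = InferBatch.init_batch(
        batch_id=0,
        requests=requests,
        dtype=torch.float16,
        device=torch.device("cuda"),
        cache_manager=cache_manager,
        vocab_size=vocab_size,
        max_total_len=max_total_len,
    )

    # Flattened per-request sampling tensors for the post-processing
    # (penalty / top-p / top-k) step after each decoding iteration.
    presence, frequency, temperatures, top_ps, top_ks, *_ = batch.get_post_sample_tensors()
    print(temperatures, top_ks)

    # Between decoding steps, the batch manager that drives this class folds newly
    # arrived requests in with InferBatch.merge(batch, new_batch), drops finished
    # ones with batch.filter(remaining_request_ids), and releases the remaining
    # KV-cache slots with batch.free_self() once the whole batch is exhausted.
    return batch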
