
Commit 829a542

adapt to ray server
1 parent a4d1e33 commit 829a542

File tree

2 files changed (+16, -11 lines)


colossalai/inference/dynamic_batching/io_struct.py

Lines changed: 2 additions & 1 deletion
@@ -4,7 +4,7 @@
 
 
 class Req:
-    def __init__(self, request_id, prompt_ids, sample_params: SamplingParams):
+    def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str):
         self.request_id = request_id
         self.prompt_ids = prompt_ids
         self.input_len = len(prompt_ids)
@@ -14,6 +14,7 @@ def __init__(self, request_id, prompt_ids, sample_params: SamplingParams):
         self.output_metadata_list = []
         self.has_generate_finished = False
         self.aborted = False
+        self.prompts = prompts
 
     def to_rpc_obj(self):
         return {
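
This widens Req to carry the raw prompt string alongside the token ids, so the text can be re-attached to the decoded output downstream. A minimal sketch of the new constructor in use; the import paths, token ids, and request id here are assumptions for illustration, and SamplingParams is taken to build with defaults:

# Illustrative only: exercise the new Req signature.
# Assumed import paths; adjust to the actual package layout.
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.dynamic_batching.io_struct import Req

prompts = "What is the capital of France?"
prompt_ids = [101, 2054, 2003, 1996, 3007, 1997, 2605, 1029]  # e.g. tokenizer.encode(prompts)

req = Req(
    request_id="req-0",                # hypothetical id
    prompt_ids=prompt_ids,
    sample_params=SamplingParams(),    # assumed defaults; real params depend on the caller
    prompts=prompts,                   # new field: raw text kept for prefixing the output later
)
assert req.input_len == len(prompt_ids)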

colossalai/inference/manager.py

Lines changed: 14 additions & 10 deletions
@@ -19,6 +19,7 @@ def __init__(
         max_total_token_num,
         batch_max_tokens,
         eos_id,
+        model,
         log_stats=True,
         log_stats_interval=10,
         running_batch: Batch = None,
@@ -30,6 +31,7 @@ def __init__(
         batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests
         running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine
         eos_id : The end token of a seq
+        model : the model weight dir path; the app will load config, weights and tokenizer from this dir
         log_stats : whether to log stats
         log_stats_interval : log stats interval
         running_batch : running batch
@@ -45,31 +47,32 @@ def __init__(
         self.eos_id = eos_id
         self.has_wait_tokens = 0
         self.max_wait_tokens = 10
+        self.model = model
 
         self.stats_tool = Stats(log_stats, log_stats_interval)
         self.mem_usage_interval = log_stats_interval * 2
+        self._set_tokenizer(tokenizer_name=self.model)
 
-    def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str):
+    def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str):
         """
         Add new request to req queue; during initialization all requests are held in the waiting list.
         """
-        req = Req(request_id, prompt_ids, sampling_params)
+        req = Req(request_id, prompt_ids, sampling_params, prompts)
         self.req_queue.append(req)
         return
 
-    def add_input(self, request_id, sampling_params, input_ids):
+    def add_input(self, request_id, sampling_params, prompts):
         """
         Encode and add new input to the req queue. Supports one sequence input for now.
         """
-        prompt_ids = self.tokenizer.encode(input_ids)
+        prompt_ids = self.tokenizer.encode(prompts)
         prompt_len = len(prompt_ids)
-        print(prompt_ids)
         if prompt_len > self.engine.max_input_len:
             raise ValueError(
                 f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
             )
         sampling_params.stop_sentences_to_token_ids(self.tokenizer)
-        self.add_req(prompt_ids, sampling_params, request_id)
+        self.add_req(prompt_ids, sampling_params, request_id, prompts)
         return
 
     def abort(self, request_id):
@@ -90,7 +93,7 @@ def loop_for_fwd(self):
         """
         counter_count = 0
         # self.running_batch is not None or self.req_queue.waiting_req_list
-        while True:
+        while self.running_batch is not None or self.req_queue.waiting_req_list:
             yield from self._step()
             counter_count += 1
             if self.running_batch is not None:
@@ -267,17 +270,17 @@ def _output_process(self, finished_reqs: List[Req]):
         """
         for req in finished_reqs:
             output = self.tokenizer.decode(req.output_ids)
-            yield output, req.request_id, req.output_metadata_list
+            yield req.prompts + output
 
     def clean_up(self):
         # this logic should be implemented in the future.
         pass
 
-    def generate(self, request_id, prompt_id, sampling_params):
+    def generate(self, prompts, sampling_params, request_id):
         """
         Generate the output of a request.
         """
-        self.add_input(request_id, prompt_id, sampling_params)
+        self.add_input(request_id, sampling_params, prompts)
         return self.loop_for_fwd()
 
 def start_dynamic_batching(args, tp_engine, waiting_req_list):
@@ -287,6 +290,7 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list):
         max_total_token_num=args.max_total_token_num,
         batch_max_tokens=args.batch_max_tokens,
         eos_id=args.eos_id,
+        model=args.model,
         log_stats=not args.disable_log_stats,
         log_stats_interval=args.log_stats_interval,
         waiting_req_list=waiting_req_list,
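
With these changes, a Ray server endpoint can hand raw text to the manager and stream back prompt-prefixed completions. A rough driver sketch, assuming an already-initialized manager instance (here batch_manager, obtained via start_dynamic_batching or the class constructor) and default sampling parameters; every name outside the diff is hypothetical:

# Illustrative only: drive the updated generate() entry point.
sampling_params = SamplingParams()  # assumed defaults

# generate() now takes the raw prompt text and returns the loop_for_fwd()
# generator, which runs until the running batch and the waiting list both drain.
for text in batch_manager.generate(
    prompts="Explain dynamic batching in one sentence.",
    sampling_params=sampling_params,
    request_id="ray-req-0",
):
    # _output_process() now yields req.prompts + decoded output instead of
    # (output, request_id, metadata) tuples.
    print(text)

Note that generate() also reorders its parameters (prompts first, request_id last), so any positional callers of the old (request_id, prompt_id, sampling_params) signature need updating.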
