
Commit e44f768

feat: Add no_kv_cache_reuse option and streaming support for trtllm serve bench (#4971)
Signed-off-by: Yi Zhang <[email protected]>
1 parent 8f67e36 commit e44f768
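
In short: each of the three benchmark request functions below (async_request_trt_llm, async_request_openai_completions, async_request_openai_chat_completions) gains a "streaming: bool = True" keyword argument, the payload's "stream" field now follows that flag, and a new non-streaming branch reads the whole response body at once. A rough sketch of the intended call pattern; the RequestFuncInput fields shown are assumptions based on what this diff references, not the full dataclass:

    import asyncio

    from tensorrt_llm.serve.scripts.backend_request_func import (
        RequestFuncInput, async_request_trt_llm)

    async def main():
        # Assumed fields; check the RequestFuncInput dataclass for the real set.
        req = RequestFuncInput(
            api_url="http://localhost:8000/v2/models/ensemble/generate_stream",
            prompt="Hello, world",
            output_len=64,
        )
        # Default behavior is unchanged: stream and record TTFT/ITL per token.
        streamed = await async_request_trt_llm(req, streaming=True)
        # New in this commit: one-shot response; ttft is set to a -1 sentinel.
        whole = await async_request_trt_llm(req, streaming=False)
        print(streamed.ttft, whole.ttft)

    asyncio.run(main())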

3 files changed: +300, -143 lines

tensorrt_llm/serve/scripts/backend_request_func.py

Lines changed: 153 additions & 107 deletions
@@ -48,6 +48,7 @@ class RequestFuncOutput:
 
 async def async_request_trt_llm(
     request_func_input: RequestFuncInput,
+    streaming: bool = True,
     pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
@@ -61,7 +62,7 @@ async def async_request_trt_llm(
             "temperature": 0.0,
             "top_p": 1.0,
             "max_tokens": request_func_input.output_len,
-            "stream": True,
+            "stream": streaming,
         }
         if request_func_input.ignore_eos:
             payload["min_length"] = request_func_input.output_len
@@ -74,30 +75,39 @@ async def async_request_trt_llm(
         try:
             async with session.post(url=api_url, json=payload) as response:
                 if response.status == 200:
-                    async for chunk_bytes in response.content:
-                        chunk_bytes = chunk_bytes.strip()
-                        if not chunk_bytes:
-                            continue
-
-                        chunk = chunk_bytes.decode("utf-8").removeprefix(
-                            "data:")
-
-                        data = json.loads(chunk)
-                        output.generated_text += data["text_output"]
-                        timestamp = time.perf_counter()
-                        # First token
-                        if ttft == 0.0:
-                            ttft = timestamp - st
-                            output.ttft = ttft
-
-                        # Decoding phase
-                        else:
-                            output.itl.append(timestamp - most_recent_timestamp)
+                    output.success = True
+                    if streaming:
+                        async for chunk_bytes in response.content:
+                            chunk_bytes = chunk_bytes.strip()
+                            if not chunk_bytes:
+                                continue
 
-                        most_recent_timestamp = timestamp
+                            chunk = chunk_bytes.decode("utf-8").removeprefix(
+                                "data:")
 
-                    output.latency = most_recent_timestamp - st
-                    output.success = True
+                            data = json.loads(chunk)
+                            output.generated_text += data["text_output"]
+                            timestamp = time.perf_counter()
+                            # First token
+                            if ttft == 0.0:
+                                ttft = timestamp - st
+                                output.ttft = ttft
+
+                            # Decoding phase
+                            else:
+                                output.itl.append(timestamp -
+                                                  most_recent_timestamp)
+
+                            most_recent_timestamp = timestamp
+
+                        output.latency = most_recent_timestamp - st
+                    else:
+                        content = await response.content.read()
+                        data = json.loads(content.decode())
+                        output.ttft = -1
+                        output.itl = []
+                        output.generated_text = data["text_output"]
+                        output.latency = time.perf_counter() - st
 
                 else:
                     output.error = response.reason or ""
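
Note the non-streaming branch above: since no chunks arrive, per-token timing is undefined, so the patch stores sentinels (ttft = -1, itl = []) and only the end-to-end latency remains meaningful. A hypothetical downstream guard, not part of this commit, could use the sentinel to keep token-level metrics clean:

    def collect_token_latency(outputs):
        # Hypothetical helper: aggregate TTFT/ITL only from streamed results,
        # which this patch distinguishes by ttft >= 0.
        ttfts = [o.ttft for o in outputs if o.success and o.ttft >= 0]
        itls = [gap for o in outputs if o.success for gap in o.itl]
        return ttfts, itls
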
@@ -114,6 +124,7 @@ async def async_request_trt_llm(
 
 async def async_request_openai_completions(
     request_func_input: RequestFuncInput,
+    streaming: bool = True,
     pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
@@ -131,11 +142,10 @@ async def async_request_openai_completions(
             "repetition_penalty": 1.0,
             "max_tokens": request_func_input.output_len,
             "logprobs": request_func_input.logprobs,
-            "stream": True,
-            "stream_options": {
-                "include_usage": True,
-            },
+            "stream": streaming,
         }
+        if streaming:
+            payload["stream_options"] = {"include_usage": True}
         if request_func_input.ignore_eos:
             payload["ignore_eos"] = request_func_input.ignore_eos
         if request_func_input.extra_body:
@@ -154,50 +164,62 @@ async def async_request_openai_completions(
         async with session.post(url=api_url, json=payload,
                                 headers=headers) as response:
             if response.status == 200:
-                first_chunk_received = False
-                async for chunk_bytes in response.content:
-                    chunk_bytes = chunk_bytes.strip()
-                    if not chunk_bytes:
-                        continue
-
-                    chunk = chunk_bytes.decode("utf-8").removeprefix(
-                        "data: ")
-                    if chunk != "[DONE]":
-                        data = json.loads(chunk)
-
-                        # NOTE: Some completion API might have a last
-                        # usage summary response without a token so we
-                        # want to check a token was generated
-                        if choices := data.get("choices"):
-                            # Note that text could be empty here
-                            # e.g. for special tokens
-                            text = choices[0].get("text")
-                            timestamp = time.perf_counter()
-                            # First token
-                            if not first_chunk_received:
-                                first_chunk_received = True
-                                ttft = time.perf_counter() - st
-                                output.ttft = ttft
-
-                            # Decoding phase
-                            else:
-                                output.itl.append(timestamp -
-                                                  most_recent_timestamp)
-
-                            most_recent_timestamp = timestamp
-                            generated_text += text or ""
-                        elif usage := data.get("usage"):
-                            output.output_tokens = usage.get(
-                                "completion_tokens")
-                if first_chunk_received:
-                    output.success = True
+                if streaming:
+                    first_chunk_received = False
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+
+                        chunk = chunk_bytes.decode("utf-8").removeprefix(
+                            "data: ")
+                        if chunk != "[DONE]":
+                            data = json.loads(chunk)
+
+                            # NOTE: Some completion API might have a last
+                            # usage summary response without a token so we
+                            # want to check a token was generated
+                            if choices := data.get("choices"):
+                                # Note that text could be empty here
+                                # e.g. for special tokens
+                                text = choices[0].get("text")
+                                timestamp = time.perf_counter()
+                                # First token
+                                if not first_chunk_received:
+                                    first_chunk_received = True
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp -
+                                                      most_recent_timestamp)
+
+                                most_recent_timestamp = timestamp
+                                generated_text += text or ""
+                            elif usage := data.get("usage"):
+                                output.output_tokens = usage.get(
+                                    "completion_tokens")
+                    if first_chunk_received:
+                        output.success = True
+                    else:
+                        output.success = False
+                        output.error = (
+                            "Never received a valid chunk to calculate TTFT."
+                            "This response will be marked as failed!")
+                    output.generated_text = generated_text
+                    output.latency = most_recent_timestamp - st
                 else:
-                    output.success = False
-                    output.error = (
-                        "Never received a valid chunk to calculate TTFT."
-                        "This response will be marked as failed!")
-                output.generated_text = generated_text
-                output.latency = most_recent_timestamp - st
+                    content = await response.content.read()
+                    data = json.loads(content.decode())
+                    generated_text = data["choices"][0]["text"]
+                    output.success = True
+                    output.generated_text = generated_text
+                    output.latency = time.perf_counter() - st
+                    output.ttft = -1
+                    output.itl = []
+                    output.output_tokens = data["usage"][
+                        "completion_tokens"]
             else:
                 output.error = response.reason or ""
                 output.success = False
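
For the completions endpoint, the non-streaming branch parses one ordinary JSON body instead of "data: "-prefixed SSE chunks. A self-contained illustration of exactly the two fields that branch reads; the body values here are invented:

    import json

    raw = json.dumps({
        "choices": [{"index": 0, "text": "Hello there!"}],
        "usage": {"prompt_tokens": 5, "completion_tokens": 3},
    })
    data = json.loads(raw)
    assert data["choices"][0]["text"] == "Hello there!"
    assert data["usage"]["completion_tokens"] == 3
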
@@ -213,6 +235,7 @@ async def async_request_openai_completions(
 
 async def async_request_openai_chat_completions(
     request_func_input: RequestFuncInput,
+    streaming: bool = True,
     pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
@@ -222,23 +245,34 @@ async def async_request_openai_chat_completions(
 
     async with aiohttp.ClientSession(trust_env=True,
                                      timeout=AIOHTTP_TIMEOUT) as session:
-        content = [{"type": "text", "text": request_func_input.prompt}]
         payload = {
             "model": request_func_input.model_name \
                 if request_func_input.model_name else request_func_input.model,
             "messages": [
-                {
-                    "role": "user",
-                    "content": content
-                },
             ],
             "temperature": 0.0,
             "max_completion_tokens": request_func_input.output_len,
-            "stream": True,
-            "stream_options": {
-                "include_usage": True,
-            },
+            "stream": streaming,
         }
+
+        if isinstance(request_func_input.prompt, list) and all(
+                [isinstance(i, int) for i in request_func_input.prompt]):
+            payload["prompt_token_ids"] = request_func_input.prompt
+        else:
+            assert isinstance(
+                request_func_input.prompt,
+                str), "Prompt must be a string or a list of integers"
+            payload["messages"].append({
+                "role":
+                "user",
+                "content": [{
+                    "type": "text",
+                    "text": request_func_input.prompt
+                }]
+            })
+
+        if streaming:
+            payload["stream_options"] = {"include_usage": True}
         if request_func_input.ignore_eos:
             payload["ignore_eos"] = request_func_input.ignore_eos
         if request_func_input.extra_body:
@@ -259,39 +293,51 @@ async def async_request_openai_chat_completions(
         async with session.post(url=api_url, json=payload,
                                 headers=headers) as response:
             if response.status == 200:
-                async for chunk_bytes in response.content:
-                    chunk_bytes = chunk_bytes.strip()
-                    if not chunk_bytes:
-                        continue
-
-                    chunk = chunk_bytes.decode("utf-8").removeprefix(
-                        "data: ")
-                    if chunk != "[DONE]":
-                        timestamp = time.perf_counter()
-                        data = json.loads(chunk)
+                output.success = True
+                if streaming:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+
+                        chunk = chunk_bytes.decode("utf-8").removeprefix(
+                            "data: ")
+                        if chunk != "[DONE]":
+                            timestamp = time.perf_counter()
+                            data = json.loads(chunk)
 
-                        if choices := data.get("choices"):
-                            content = choices[0]["delta"].get("content")
-                            # First token
-                            if ttft == 0.0:
-                                ttft = timestamp - st
-                                output.ttft = ttft
+                            if choices := data.get("choices"):
+                                content = choices[0]["delta"].get("content")
+                                # First token
+                                if ttft == 0.0:
+                                    ttft = timestamp - st
+                                    output.ttft = ttft
 
-                            # Decoding phase
-                            else:
-                                output.itl.append(timestamp -
-                                                  most_recent_timestamp)
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp -
+                                                      most_recent_timestamp)
 
-                            generated_text += content or ""
-                        elif usage := data.get("usage"):
-                            output.output_tokens = usage.get(
-                                "completion_tokens")
+                                generated_text += content or ""
+                            elif usage := data.get("usage"):
+                                output.output_tokens = usage.get(
+                                    "completion_tokens")
 
-                    most_recent_timestamp = timestamp
+                            most_recent_timestamp = timestamp
+
+                    output.generated_text = generated_text
+                    output.latency = most_recent_timestamp - st
+                else:
+                    content = await response.content.read()
+                    data = json.loads(content.decode())
+                    output.generated_text = data["choices"][0]["message"][
+                        "content"]
+                    output.output_tokens = data["usage"][
+                        "completion_tokens"]
+                    output.itl = []
+                    output.latency = time.perf_counter() - st
+                    output.ttft = -1
 
-                output.generated_text = generated_text
-                output.success = True
-                output.latency = most_recent_timestamp - st
             else:
                 output.error = response.reason or ""
                 output.success = False