@@ -48,6 +48,7 @@ class RequestFuncOutput:
 
 async def async_request_trt_llm(
     request_func_input: RequestFuncInput,
+    streaming: bool = True,
     pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
@@ -61,7 +62,7 @@ async def async_request_trt_llm(
             "temperature": 0.0,
             "top_p": 1.0,
             "max_tokens": request_func_input.output_len,
-            "stream": True,
+            "stream": streaming,
         }
         if request_func_input.ignore_eos:
             payload["min_length"] = request_func_input.output_len
@@ -74,30 +75,39 @@ async def async_request_trt_llm(
         try:
             async with session.post(url=api_url, json=payload) as response:
                 if response.status == 200:
-                    async for chunk_bytes in response.content:
-                        chunk_bytes = chunk_bytes.strip()
-                        if not chunk_bytes:
-                            continue
-
-                        chunk = chunk_bytes.decode("utf-8").removeprefix(
-                            "data:")
-
-                        data = json.loads(chunk)
-                        output.generated_text += data["text_output"]
-                        timestamp = time.perf_counter()
-                        # First token
-                        if ttft == 0.0:
-                            ttft = timestamp - st
-                            output.ttft = ttft
-
-                        # Decoding phase
-                        else:
-                            output.itl.append(timestamp - most_recent_timestamp)
+                    output.success = True
+                    if streaming:
+                        async for chunk_bytes in response.content:
+                            chunk_bytes = chunk_bytes.strip()
+                            if not chunk_bytes:
+                                continue
 
-                        most_recent_timestamp = timestamp
+                            chunk = chunk_bytes.decode("utf-8").removeprefix(
+                                "data:")
 
-                    output.latency = most_recent_timestamp - st
-                    output.success = True
+                            data = json.loads(chunk)
+                            output.generated_text += data["text_output"]
+                            timestamp = time.perf_counter()
+                            # First token
+                            if ttft == 0.0:
+                                ttft = timestamp - st
+                                output.ttft = ttft
+
+                            # Decoding phase
+                            else:
+                                output.itl.append(timestamp -
+                                                  most_recent_timestamp)
+
+                            most_recent_timestamp = timestamp
+
+                        output.latency = most_recent_timestamp - st
+                    else:
+                        content = await response.content.read()
+                        data = json.loads(content.decode())
+                        output.ttft = -1
+                        output.itl = []
+                        output.generated_text = data["text_output"]
+                        output.latency = time.perf_counter() - st
 
                 else:
                     output.error = response.reason or ""
@@ -114,6 +124,7 @@ async def async_request_trt_llm(
 
 async def async_request_openai_completions(
     request_func_input: RequestFuncInput,
+    streaming: bool = True,
     pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
@@ -131,11 +142,10 @@ async def async_request_openai_completions(
             "repetition_penalty": 1.0,
             "max_tokens": request_func_input.output_len,
             "logprobs": request_func_input.logprobs,
-            "stream": True,
-            "stream_options": {
-                "include_usage": True,
-            },
+            "stream": streaming,
         }
+        if streaming:
+            payload["stream_options"] = {"include_usage": True}
         if request_func_input.ignore_eos:
             payload["ignore_eos"] = request_func_input.ignore_eos
         if request_func_input.extra_body:
@@ -154,50 +164,62 @@ async def async_request_openai_completions(
             async with session.post(url=api_url, json=payload,
                                     headers=headers) as response:
                 if response.status == 200:
-                    first_chunk_received = False
-                    async for chunk_bytes in response.content:
-                        chunk_bytes = chunk_bytes.strip()
-                        if not chunk_bytes:
-                            continue
-
-                        chunk = chunk_bytes.decode("utf-8").removeprefix(
-                            "data: ")
-                        if chunk != "[DONE]":
-                            data = json.loads(chunk)
-
-                            # NOTE: Some completion API might have a last
-                            # usage summary response without a token so we
-                            # want to check a token was generated
-                            if choices := data.get("choices"):
-                                # Note that text could be empty here
-                                # e.g. for special tokens
-                                text = choices[0].get("text")
-                                timestamp = time.perf_counter()
-                                # First token
-                                if not first_chunk_received:
-                                    first_chunk_received = True
-                                    ttft = time.perf_counter() - st
-                                    output.ttft = ttft
-
-                                # Decoding phase
-                                else:
-                                    output.itl.append(timestamp -
-                                                      most_recent_timestamp)
-
-                                most_recent_timestamp = timestamp
-                                generated_text += text or ""
-                            elif usage := data.get("usage"):
-                                output.output_tokens = usage.get(
-                                    "completion_tokens")
-                    if first_chunk_received:
-                        output.success = True
+                    if streaming:
+                        first_chunk_received = False
+                        async for chunk_bytes in response.content:
+                            chunk_bytes = chunk_bytes.strip()
+                            if not chunk_bytes:
+                                continue
+
+                            chunk = chunk_bytes.decode("utf-8").removeprefix(
+                                "data: ")
+                            if chunk != "[DONE]":
+                                data = json.loads(chunk)
+
+                                # NOTE: Some completion API might have a last
+                                # usage summary response without a token so we
+                                # want to check a token was generated
+                                if choices := data.get("choices"):
+                                    # Note that text could be empty here
+                                    # e.g. for special tokens
+                                    text = choices[0].get("text")
+                                    timestamp = time.perf_counter()
+                                    # First token
+                                    if not first_chunk_received:
+                                        first_chunk_received = True
+                                        ttft = time.perf_counter() - st
+                                        output.ttft = ttft
+
+                                    # Decoding phase
+                                    else:
+                                        output.itl.append(timestamp -
+                                                          most_recent_timestamp)
+
+                                    most_recent_timestamp = timestamp
+                                    generated_text += text or ""
+                                elif usage := data.get("usage"):
+                                    output.output_tokens = usage.get(
+                                        "completion_tokens")
+                        if first_chunk_received:
+                            output.success = True
+                        else:
+                            output.success = False
+                            output.error = (
+                                "Never received a valid chunk to calculate TTFT."
+                                "This response will be marked as failed!")
+                        output.generated_text = generated_text
+                        output.latency = most_recent_timestamp - st
                     else:
-                        output.success = False
-                        output.error = (
-                            "Never received a valid chunk to calculate TTFT."
-                            "This response will be marked as failed!")
-                    output.generated_text = generated_text
-                    output.latency = most_recent_timestamp - st
+                        content = await response.content.read()
+                        data = json.loads(content.decode())
+                        generated_text = data["choices"][0]["text"]
+                        output.success = True
+                        output.generated_text = generated_text
+                        output.latency = time.perf_counter() - st
+                        output.ttft = -1
+                        output.itl = []
+                        output.output_tokens = data["usage"][
+                            "completion_tokens"]
                 else:
                     output.error = response.reason or ""
                     output.success = False
@@ -213,6 +235,7 @@ async def async_request_openai_completions(
 
 async def async_request_openai_chat_completions(
     request_func_input: RequestFuncInput,
+    streaming: bool = True,
     pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
@@ -222,23 +245,34 @@ async def async_request_openai_chat_completions(
 
     async with aiohttp.ClientSession(trust_env=True,
                                      timeout=AIOHTTP_TIMEOUT) as session:
-        content = [{"type": "text", "text": request_func_input.prompt}]
         payload = {
             "model": request_func_input.model_name \
                 if request_func_input.model_name else request_func_input.model,
             "messages": [
-                {
-                    "role": "user",
-                    "content": content
-                },
             ],
             "temperature": 0.0,
             "max_completion_tokens": request_func_input.output_len,
-            "stream": True,
-            "stream_options": {
-                "include_usage": True,
-            },
+            "stream": streaming,
         }
+
+        if isinstance(request_func_input.prompt, list) and all(
+                [isinstance(i, int) for i in request_func_input.prompt]):
+            payload["prompt_token_ids"] = request_func_input.prompt
+        else:
+            assert isinstance(
+                request_func_input.prompt,
+                str), "Prompt must be a string or a list of integers"
+            payload["messages"].append({
+                "role":
+                "user",
+                "content": [{
+                    "type": "text",
+                    "text": request_func_input.prompt
+                }]
+            })
+
+        if streaming:
+            payload["stream_options"] = {"include_usage": True}
         if request_func_input.ignore_eos:
             payload["ignore_eos"] = request_func_input.ignore_eos
         if request_func_input.extra_body:
@@ -259,39 +293,51 @@ async def async_request_openai_chat_completions(
             async with session.post(url=api_url, json=payload,
                                     headers=headers) as response:
                 if response.status == 200:
-                    async for chunk_bytes in response.content:
-                        chunk_bytes = chunk_bytes.strip()
-                        if not chunk_bytes:
-                            continue
-
-                        chunk = chunk_bytes.decode("utf-8").removeprefix(
-                            "data: ")
-                        if chunk != "[DONE]":
-                            timestamp = time.perf_counter()
-                            data = json.loads(chunk)
+                    output.success = True
+                    if streaming:
+                        async for chunk_bytes in response.content:
+                            chunk_bytes = chunk_bytes.strip()
+                            if not chunk_bytes:
+                                continue
+
+                            chunk = chunk_bytes.decode("utf-8").removeprefix(
+                                "data: ")
+                            if chunk != "[DONE]":
+                                timestamp = time.perf_counter()
+                                data = json.loads(chunk)
 
-                            if choices := data.get("choices"):
-                                content = choices[0]["delta"].get("content")
-                                # First token
-                                if ttft == 0.0:
-                                    ttft = timestamp - st
-                                    output.ttft = ttft
+                                if choices := data.get("choices"):
+                                    content = choices[0]["delta"].get("content")
+                                    # First token
+                                    if ttft == 0.0:
+                                        ttft = timestamp - st
+                                        output.ttft = ttft
 
-                                # Decoding phase
-                                else:
-                                    output.itl.append(timestamp -
-                                                      most_recent_timestamp)
+                                    # Decoding phase
+                                    else:
+                                        output.itl.append(timestamp -
+                                                          most_recent_timestamp)
 
-                                generated_text += content or ""
-                            elif usage := data.get("usage"):
-                                output.output_tokens = usage.get(
-                                    "completion_tokens")
+                                    generated_text += content or ""
+                                elif usage := data.get("usage"):
+                                    output.output_tokens = usage.get(
+                                        "completion_tokens")
 
-                            most_recent_timestamp = timestamp
+                                most_recent_timestamp = timestamp
+
+                        output.generated_text = generated_text
+                        output.latency = most_recent_timestamp - st
+                    else:
+                        content = await response.content.read()
+                        data = json.loads(content.decode())
+                        output.generated_text = data["choices"][0]["message"][
+                            "content"]
+                        output.output_tokens = data["usage"][
+                            "completion_tokens"]
+                        output.itl = []
+                        output.latency = time.perf_counter() - st
+                        output.ttft = -1
 
-                    output.generated_text = generated_text
-                    output.success = True
-                    output.latency = most_recent_timestamp - st
                 else:
                     output.error = response.reason or ""
                     output.success = False
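
For reference, a minimal sketch of how the patched request functions might be driven with the new streaming flag. This is illustrative only: the import path and the exact set of required RequestFuncInput fields are assumptions, not part of this diff.

import asyncio

# Assumed import path; in vLLM these helpers live in benchmarks/backend_request_func.py.
from backend_request_func import (RequestFuncInput,
                                  async_request_openai_completions)


async def main() -> None:
    # Only fields visible in the diff are shown; the real dataclass may
    # require more (e.g. prompt_len).
    request = RequestFuncInput(
        prompt="Hello, world!",
        api_url="http://localhost:8000/v1/completions",
        model="my-model",
        output_len=64,
    )
    # streaming=True (the default) keeps the SSE path and fills ttft/itl;
    # streaming=False reads a single JSON body and sets ttft=-1, itl=[].
    result = await async_request_openai_completions(request, streaming=False)
    print(result.generated_text, result.latency)


asyncio.run(main())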