260 changes: 153 additions & 107 deletions in tensorrt_llm/serve/scripts/backend_request_func.py
@@ -48,6 +48,7 @@ class RequestFuncOutput:

async def async_request_trt_llm(
request_func_input: RequestFuncInput,
streaming: bool = True,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
@@ -61,7 +62,7 @@ async def async_request_trt_llm(
"temperature": 0.0,
"top_p": 1.0,
"max_tokens": request_func_input.output_len,
"stream": True,
"stream": streaming,
}
if request_func_input.ignore_eos:
payload["min_length"] = request_func_input.output_len
@@ -74,30 +75,39 @@ async def async_request_trt_llm(
try:
async with session.post(url=api_url, json=payload) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue

chunk = chunk_bytes.decode("utf-8").removeprefix(
"data:")

data = json.loads(chunk)
output.generated_text += data["text_output"]
timestamp = time.perf_counter()
# First token
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft

# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
output.success = True
if streaming:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue

most_recent_timestamp = timestamp
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data:")

output.latency = most_recent_timestamp - st
output.success = True
data = json.loads(chunk)
output.generated_text += data["text_output"]
timestamp = time.perf_counter()
# First token
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft

# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)

most_recent_timestamp = timestamp

output.latency = most_recent_timestamp - st
else:
content = await response.content.read()
data = json.loads(content.decode())
output.ttft = -1
output.itl = []
output.generated_text = data["text_output"]
output.latency = time.perf_counter() - st

else:
output.error = response.reason or ""
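With this change, `async_request_trt_llm` only walks the SSE stream when `streaming` is true; otherwise it reads the body in one shot, records only end-to-end latency, and reports TTFT as the sentinel -1 with an empty inter-token-latency list. A minimal sketch of how a downstream aggregator might honor that sentinel (the helper name and the aggregation itself are illustrative, not part of this PR):

```python
# Sketch only: assumes RequestFuncOutput carries success, latency and ttft
# exactly as set in the diff above; summarize_latency is a hypothetical helper.
from typing import Iterable


def summarize_latency(outputs: Iterable["RequestFuncOutput"]) -> dict:
    """Aggregate latency and TTFT, skipping the ttft == -1 sentinel that
    non-streaming requests report because no first token is observed."""
    latencies = [o.latency for o in outputs if o.success]
    ttfts = [o.ttft for o in outputs if o.success and o.ttft >= 0.0]
    return {
        "mean_latency_s": sum(latencies) / len(latencies) if latencies else 0.0,
        "mean_ttft_s": sum(ttfts) / len(ttfts) if ttfts else None,
    }
```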
@@ -114,6 +124,7 @@ async def async_request_trt_llm(

async def async_request_openai_completions(
request_func_input: RequestFuncInput,
streaming: bool = True,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
@@ -131,11 +142,10 @@ async def async_request_openai_completions(
"repetition_penalty": 1.0,
"max_tokens": request_func_input.output_len,
"logprobs": request_func_input.logprobs,
"stream": True,
"stream_options": {
"include_usage": True,
},
"stream": streaming,
}
if streaming:
payload["stream_options"] = {"include_usage": True}
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.extra_body:
@@ -154,50 +164,62 @@ async def async_request_openai_completions(
async with session.post(url=api_url, json=payload,
headers=headers) as response:
if response.status == 200:
first_chunk_received = False
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue

chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
if chunk != "[DONE]":
data = json.loads(chunk)

# NOTE: Some completion API might have a last
# usage summary response without a token so we
# want to check a token was generated
if choices := data.get("choices"):
# Note that text could be empty here
# e.g. for special tokens
text = choices[0].get("text")
timestamp = time.perf_counter()
# First token
if not first_chunk_received:
first_chunk_received = True
ttft = time.perf_counter() - st
output.ttft = ttft

# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)

most_recent_timestamp = timestamp
generated_text += text or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")
if first_chunk_received:
output.success = True
if streaming:
first_chunk_received = False
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue

chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
if chunk != "[DONE]":
data = json.loads(chunk)

# NOTE: Some completion API might have a last
# usage summary response without a token so we
# want to check a token was generated
if choices := data.get("choices"):
# Note that text could be empty here
# e.g. for special tokens
text = choices[0].get("text")
timestamp = time.perf_counter()
# First token
if not first_chunk_received:
first_chunk_received = True
ttft = time.perf_counter() - st
output.ttft = ttft

# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)

most_recent_timestamp = timestamp
generated_text += text or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")
if first_chunk_received:
output.success = True
else:
output.success = False
output.error = (
"Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!")
output.generated_text = generated_text
output.latency = most_recent_timestamp - st
else:
output.success = False
output.error = (
"Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!")
output.generated_text = generated_text
output.latency = most_recent_timestamp - st
content = await response.content.read()
data = json.loads(content.decode())
generated_text = data["choices"][0]["text"]
output.success = True
output.generated_text = generated_text
output.latency = time.perf_counter() - st
output.ttft = -1
output.itl = []
output.output_tokens = data["usage"][
"completion_tokens"]
else:
output.error = response.reason or ""
output.success = False
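The completions helper gets the same toggle, and `stream_options` is now attached only when streaming is requested, since `include_usage` applies only to streamed responses. A rough sketch of the payload this produces (the model name and any keys not visible in the hunk above are placeholders):

```python
# Sketch of the request body the updated helper builds; keys mirror the diff
# above, the model name is a placeholder, and other payload fields are omitted.
def build_completions_payload(prompt: str, max_tokens: int, streaming: bool) -> dict:
    payload = {
        "model": "my-model",
        "prompt": prompt,
        "repetition_penalty": 1.0,
        "max_tokens": max_tokens,
        "stream": streaming,
    }
    if streaming:
        # Only meaningful for streamed responses: asks the server for a final
        # usage chunk so completion_tokens can be recorded from it.
        payload["stream_options"] = {"include_usage": True}
    return payload


# Example: a non-streaming request simply omits stream_options.
print(build_completions_payload("Hello", 128, streaming=False))
```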
@@ -213,6 +235,7 @@ async def async_request_openai_completions(

async def async_request_openai_chat_completions(
request_func_input: RequestFuncInput,
streaming: bool = True,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
@@ -222,23 +245,34 @@ async def async_request_openai_chat_completions(

async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
content = [{"type": "text", "text": request_func_input.prompt}]
payload = {
"model": request_func_input.model_name \
if request_func_input.model_name else request_func_input.model,
"messages": [
{
"role": "user",
"content": content
},
],
"temperature": 0.0,
"max_completion_tokens": request_func_input.output_len,
"stream": True,
"stream_options": {
"include_usage": True,
},
"stream": streaming,
}

if isinstance(request_func_input.prompt, list) and all(
[isinstance(i, int) for i in request_func_input.prompt]):
payload["prompt_token_ids"] = request_func_input.prompt
else:
assert isinstance(
request_func_input.prompt,
str), "Prompt must be a string or a list of integers"
payload["messages"].append({
"role":
"user",
"content": [{
"type": "text",
"text": request_func_input.prompt
}]
})

if streaming:
payload["stream_options"] = {"include_usage": True}
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.extra_body:
@@ -259,39 +293,51 @@ async def async_request_openai_chat_completions(
async with session.post(url=api_url, json=payload,
headers=headers) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue

chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)
output.success = True
if streaming:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue

chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)

if choices := data.get("choices"):
content = choices[0]["delta"].get("content")
# First token
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft
if choices := data.get("choices"):
content = choices[0]["delta"].get("content")
# First token
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft

# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)

generated_text += content or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")
generated_text += content or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")

most_recent_timestamp = timestamp
most_recent_timestamp = timestamp

output.generated_text = generated_text
output.latency = most_recent_timestamp - st
else:
content = await response.content.read()
data = json.loads(content.decode())
output.generated_text = data["choices"][0]["message"][
"content"]
output.output_tokens = data["usage"][
"completion_tokens"]
output.itl = []
output.latency = time.perf_counter() - st
output.ttft = -1

output.generated_text = generated_text
output.success = True
output.latency = most_recent_timestamp - st
else:
output.error = response.reason or ""
output.success = False
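Beyond the streaming toggle, the chat helper now also accepts a pre-tokenized prompt: a list of token ids is sent as `prompt_token_ids`, while a string prompt is wrapped in a normal `messages` entry. A hedged sketch of the two request shapes (values are placeholders; `prompt_token_ids` is a server-side extension rather than a standard OpenAI chat field, and `messages` is assumed to stay empty when ids are supplied):

```python
# Sketch only: the two bodies the updated helper can produce for the chat
# endpoint. Token ids, model name and lengths are made up for illustration.
string_prompt_payload = {
    "model": "my-model",
    "messages": [{
        "role": "user",
        "content": [{"type": "text", "text": "Hello, world"}],
    }],
    "temperature": 0.0,
    "max_completion_tokens": 128,
    "stream": False,  # non-streaming: body is read once, ttft reported as -1
}

token_id_payload = {
    "model": "my-model",
    "messages": [],  # assumed to stay empty when ids are supplied
    "prompt_token_ids": [101, 2023, 102],  # pre-tokenized prompt, sent as-is
    "temperature": 0.0,
    "max_completion_tokens": 128,
    "stream": True,  # streaming: TTFT and ITL are measured per SSE chunk
}
```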