29 changes: 28 additions & 1 deletion docs/openapi.json
@@ -671,6 +671,11 @@
}
],
"title": "System Prompt"
},
"default_estimation_tokenizer": {
"type": "string",
"title": "Default Estimation Tokenizer",
"description": "The default tokenizer to use for estimating token usage when the model is not supported by tiktoken."
}
},
"type": "object",
@@ -1132,14 +1137,36 @@
"response": {
"type": "string",
"title": "Response"
},
"input_tokens": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Input Tokens"
},
"output_tokens": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Output Tokens"
}
},
"type": "object",
"required": [
"response"
],
"title": "QueryResponse",
"description": "Model representing LLM response to a query.\n\nAttributes:\n conversation_id: The optional conversation ID (UUID).\n response: The response.",
"description": "Model representing LLM response to a query.\n\nAttributes:\n conversation_id: The optional conversation ID (UUID).\n response: The response.\n input_tokens: Number of tokens sent to LLM.\n output_tokens: Number of tokens received from LLM.",
"examples": [
{
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
1 change: 1 addition & 0 deletions pyproject.toml
@@ -13,6 +13,7 @@ dependencies = [
"llama-stack>=0.2.13",
"rich>=14.0.0",
"cachetools>=6.1.0",
"tiktoken>=0.9.0,<1.0.0",
]

[tool.pyright]
50 changes: 41 additions & 9 deletions src/app/endpoints/query.py
@@ -32,6 +32,7 @@
from utils.endpoints import check_configuration_loaded, get_system_prompt
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.suid import get_suid
from utils.token_counter import get_token_counter
from utils.types import GraniteToolParser

logger = logging.getLogger("app.endpoints.handlers")
@@ -121,7 +122,7 @@ def query_endpoint_handler(
# try to get Llama Stack client
client = LlamaStackClientHolder().get_client()
model_id = select_model_id(client.models.list(), query_request)
response, conversation_id = retrieve_response(
response, conversation_id, token_usage = retrieve_response(
client,
model_id,
query_request,
@@ -144,7 +145,12 @@
attachments=query_request.attachments or [],
)

return QueryResponse(conversation_id=conversation_id, response=response)
return QueryResponse(
conversation_id=conversation_id,
response=response,
input_tokens=token_usage["input_tokens"],
output_tokens=token_usage["output_tokens"],
)

# connection to Llama Stack server
except APIConnectionError as e:
@@ -202,13 +208,21 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> s
return model_id


def _build_toolgroups(client: LlamaStackClient) -> list[Toolgroup] | None:
Contributor: Is it part of this PR?

Author: It was to alleviate a "too many local variables" error that appeared while I was fixing merge conflicts.

Contributor: Probably we can just # pylint: disable=too-many-locals for this.

"""Build toolgroups from vector DBs and MCP servers."""
vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
return (get_rag_toolgroups(vector_db_ids) or []) + [
mcp_server.name for mcp_server in configuration.mcp_servers
]


def retrieve_response(
client: LlamaStackClient,
model_id: str,
query_request: QueryRequest,
token: str,
mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[str, str]:
) -> tuple[str, str, dict[str, int]]:
"""Retrieve response from LLMs and agents."""
available_shields = [shield.identifier for shield in client.shields.list()]
if not available_shields:
@@ -251,19 +265,37 @@ def retrieve_response(
),
}

vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
mcp_server.name for mcp_server in configuration.mcp_servers
]
response = agent.create_turn(
messages=[UserMessage(role="user", content=query_request.query)],
session_id=conversation_id,
documents=query_request.get_documents(),
stream=False,
toolgroups=toolgroups or None,
toolgroups=_build_toolgroups(client) or None,
)

return str(response.output_message.content), conversation_id # type: ignore[union-attr]
response_content = str(response.output_message.content) # type: ignore[union-attr]

# Currently (2025-07-08) the usage is not returned by the API, so we need to estimate it
# try:
# token_usage = {
# "input_tokens": response.usage.get("prompt_tokens", 0),
# "output_tokens": response.usage.get("completion_tokens", 0),
# }
# except AttributeError:
# Estimate token usage
try:
token_counter = get_token_counter(model_id)
token_usage = token_counter.count_conversation_turn_tokens(
conversation_id, system_prompt, query_request, response_content
)
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("Failed to estimate token usage: %s", e)
Contributor: nit: This probably should be an error instead of a warning.

token_usage = {
"input_tokens": 0,
"output_tokens": 0,
}

return response_content, conversation_id, token_usage


def validate_attachments_metadata(attachments: list[Attachment]) -> None:
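Note: the utils/token_counter.py module imported above is not part of this diff. The following is a minimal sketch of what such a helper could look like, assuming it wraps tiktoken, falls back to DEFAULT_ESTIMATION_TOKENIZER for model IDs tiktoken does not recognize, and exposes the method names used at the call sites; the actual module may differ.

"""Sketch only: a token counter helper in the spirit of utils.token_counter."""

import tiktoken

from constants import DEFAULT_ESTIMATION_TOKENIZER


class TokenCounter:
    """Estimate token usage for a model via a tiktoken encoding."""

    def __init__(self, model_id: str) -> None:
        try:
            # Use the model-specific encoding when tiktoken knows the model.
            self.encoding = tiktoken.encoding_for_model(model_id)
        except KeyError:
            # Unknown model: fall back to the configured default encoding.
            self.encoding = tiktoken.get_encoding(DEFAULT_ESTIMATION_TOKENIZER)

    def count_tokens(self, text: str) -> int:
        """Return the estimated number of tokens in a plain string."""
        return len(self.encoding.encode(text))

    def count_conversation_turn_tokens(
        self,
        conversation_id: str,
        system_prompt: str,
        query_request,  # a QueryRequest; left untyped to keep the sketch self-contained
        response_content: str = "",
    ) -> dict[str, int]:
        """Estimate input/output tokens for a single conversation turn.

        conversation_id is unused in this sketch; a real implementation could
        use it to include prior turns in the input estimate.
        """
        input_tokens = self.count_tokens(system_prompt) + self.count_tokens(
            query_request.query
        )
        output_tokens = self.count_tokens(response_content) if response_content else 0
        return {"input_tokens": input_tokens, "output_tokens": output_tokens}


def get_token_counter(model_id: str) -> TokenCounter:
    """Return a token counter suited to the given model."""
    return TokenCounter(model_id)
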
80 changes: 65 additions & 15 deletions src/app/endpoints/streaming_query.py
@@ -9,10 +9,12 @@

from llama_stack_client import APIConnectionError
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
from llama_stack_client.types.agents.turn_create_params import Toolgroup
from llama_stack_client import AsyncLlamaStackClient # type: ignore
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
from llama_stack_client.types import UserMessage # type: ignore


from fastapi import APIRouter, HTTPException, Request, Depends, status
from fastapi.responses import StreamingResponse

@@ -24,6 +26,7 @@
from utils.common import retrieve_user_id
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.suid import get_suid
from utils.token_counter import get_token_counter
from utils.types import GraniteToolParser

from app.endpoints.conversations import conversation_id_to_agent_id
@@ -97,8 +100,13 @@ def stream_start_event(conversation_id: str) -> str:
)


def stream_end_event(metadata_map: dict) -> str:
"""Yield the end of the data stream."""
def stream_end_event(metadata_map: dict, metrics_map: dict[str, int]) -> str:
"""Yield the end of the data stream.

Args:
metadata_map: Dictionary containing metadata about referenced documents
metrics_map: Dictionary containing metrics like 'input_tokens' and 'output_tokens'
"""
return format_stream_data(
{
"event": "end",
@@ -114,8 +122,8 @@ def stream_end_event(metadata_map: dict) -> str:
)
],
"truncated": None, # TODO(jboos): implement truncated
"input_tokens": 0, # TODO(jboos): implement input tokens
"output_tokens": 0, # TODO(jboos): implement output tokens
"input_tokens": metrics_map.get("input_tokens", 0),
"output_tokens": metrics_map.get("output_tokens", 0),
},
"available_quotas": {}, # TODO(jboos): implement available quotas
}
@@ -204,7 +212,7 @@ async def streaming_query_endpoint_handler(
# try to get Llama Stack client
client = AsyncLlamaStackClientHolder().get_client()
model_id = select_model_id(await client.models.list(), query_request)
response, conversation_id = await retrieve_response(
response, conversation_id, token_usage = await retrieve_response(
client,
model_id,
query_request,
@@ -229,7 +237,24 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
chunk_id += 1
yield event

yield stream_end_event(metadata_map)
# Currently (2025-07-08) the usage is not returned by the API, so we need to estimate
# try:
# output_tokens = response.usage.get("completion_tokens", 0)
# except AttributeError:
# Estimate output tokens from complete response
try:
token_counter = get_token_counter(model_id)
output_tokens = token_counter.count_tokens(complete_response)
logger.debug("Estimated output tokens: %s", output_tokens)
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("Failed to estimate output tokens: %s", e)
Contributor: ditto: s/warning/error/g

output_tokens = 0

metrics_map = {
"input_tokens": token_usage["input_tokens"],
"output_tokens": output_tokens,
}
yield stream_end_event(metadata_map, metrics_map)
Contributor, commenting on lines +240 to +257:

🛠️ Refactor suggestion

Improve error handling and remove commented code.

The token counting implementation has broad exception handling and contains commented-out code that should be cleaned up.

-            # Currently (2025-07-08) the usage is not returned by the API, so we need to estimate
-            # try:
-            #     output_tokens = response.usage.get("completion_tokens", 0)
-            # except AttributeError:
-            # Estimate output tokens from complete response
             try:
                 token_counter = get_token_counter(model_id)
                 output_tokens = token_counter.count_tokens(complete_response)
                 logger.debug("Estimated output tokens: %s", output_tokens)
-            except Exception as e:  # pylint: disable=broad-exception-caught
+            except (KeyError, ValueError, AttributeError) as e:
                 logger.warning("Failed to estimate output tokens: %s", e)
                 output_tokens = 0

Rationale: Remove outdated commented code and use more specific exception handling instead of catching all exceptions.

🤖 Prompt for AI Agents
In src/app/endpoints/streaming_query.py around lines 240 to 257, remove the
commented-out code related to usage retrieval since it is outdated. Replace the
broad except Exception clause with more specific exception handling that targets
only the expected errors from get_token_counter or count_tokens, such as
AttributeError or a custom token counting error if applicable. This will improve
error clarity and maintain cleaner code.


if not is_transcripts_enabled():
logger.debug("Transcript collection is disabled in the configuration")
@@ -260,13 +285,23 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
) from e


async def _build_toolgroups(client: AsyncLlamaStackClient) -> list[Toolgroup] | None:
"""Build toolgroups from vector DBs and MCP servers."""
vector_db_ids = [
vector_db.identifier for vector_db in await client.vector_dbs.list()
]
return (get_rag_toolgroups(vector_db_ids) or []) + [
mcp_server.name for mcp_server in configuration.mcp_servers
]


async def retrieve_response(
client: AsyncLlamaStackClient,
model_id: str,
query_request: QueryRequest,
token: str,
mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[Any, str]:
) -> tuple[Any, str, dict[str, int]]:
"""Retrieve response from LLMs and agents."""
available_shields = [shield.identifier for shield in await client.shields.list()]
if not available_shields:
@@ -312,18 +347,33 @@ async def retrieve_response(
}

logger.debug("Session ID: %s", conversation_id)
vector_db_ids = [
vector_db.identifier for vector_db in await client.vector_dbs.list()
]
toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
mcp_server.name for mcp_server in configuration.mcp_servers
]

response = await agent.create_turn(
messages=[UserMessage(role="user", content=query_request.query)],
session_id=conversation_id,
documents=query_request.get_documents(),
stream=True,
toolgroups=toolgroups or None,
toolgroups=await _build_toolgroups(client) or None,
)

return response, conversation_id
# Currently (2025-07-08) the usage is not returned by the API, so we need to estimate it
# try:
# token_usage = {
# "input_tokens": response.usage.get("prompt_tokens", 0),
# "output_tokens": 0, # Will be calculated during streaming
# }
# except AttributeError:
# # Estimate input tokens (Output will be calculated during streaming)
try:
token_counter = get_token_counter(model_id)
token_usage = token_counter.count_conversation_turn_tokens(
conversation_id, system_prompt, query_request
)
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("Failed to estimate token usage: %s", e)
token_usage = {
"input_tokens": 0,
"output_tokens": 0,
}

return response, conversation_id, token_usage
Contributor, commenting on lines +359 to +379:

🛠️ Refactor suggestion

Improve error handling and remove commented code.

Similar to the earlier issue, this section has broad exception handling and outdated commented code.

-    # Currently (2025-07-08) the usage is not returned by the API, so we need to estimate it
-    # try:
-    #     token_usage = {
-    #         "input_tokens": response.usage.get("prompt_tokens", 0),
-    #         "output_tokens": 0,  # Will be calculated during streaming
-    #     }
-    # except AttributeError:
-    #     # Estimate input tokens (Output will be calculated during streaming)
     try:
         token_counter = get_token_counter(model_id)
         token_usage = token_counter.count_conversation_turn_tokens(
             conversation_id, system_prompt, query_request
         )
-    except Exception as e:  # pylint: disable=broad-exception-caught
+    except (KeyError, ValueError, AttributeError) as e:
         logger.warning("Failed to estimate token usage: %s", e)
         token_usage = {
             "input_tokens": 0,
             "output_tokens": 0,
         }

Rationale: Remove outdated commented code and use more specific exception handling.

🤖 Prompt for AI Agents
In src/app/endpoints/streaming_query.py around lines 359 to 379, remove the
outdated commented code related to token usage estimation and replace the broad
exception handling with more specific exceptions that can be raised by
get_token_counter or count_conversation_turn_tokens. This will clean up the code
and improve error handling by catching only expected errors instead of all
exceptions.

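For illustration, once metrics_map is threaded through, the dict passed to format_stream_data for the final "end" event built by stream_end_event takes roughly this shape (token values are made up; referenced_documents trimmed for brevity):

# Illustrative shape of the "end" event payload with estimated token counts.
end_event_payload = {
    "event": "end",
    "data": {
        "referenced_documents": [],
        "truncated": None,       # TODO(jboos): implement truncated
        "input_tokens": 187,     # from token_usage["input_tokens"]
        "output_tokens": 412,    # estimated from the full streamed response
    },
    "available_quotas": {},      # TODO(jboos): implement available quotas
}
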
2 changes: 2 additions & 0 deletions src/constants.py
@@ -42,6 +42,8 @@
}
)
DEFAULT_AUTHENTICATION_MODULE = AUTH_MOD_NOOP
# Default tokenizer for estimating token usage
DEFAULT_ESTIMATION_TOKENIZER = "cl100k_base"

# Data collector constants
DATA_COLLECTOR_COLLECTION_INTERVAL = 7200 # 2 hours in seconds
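An illustrative use of the new constant: tiktoken raises KeyError for model names it does not recognize, so callers can fall back to the cl100k_base encoding (the model ID below is hypothetical):

import tiktoken

from constants import DEFAULT_ESTIMATION_TOKENIZER  # "cl100k_base"

try:
    encoding = tiktoken.encoding_for_model("my-provider/custom-model")  # hypothetical model ID
except KeyError:
    # Model not known to tiktoken: use the default estimation encoding.
    encoding = tiktoken.get_encoding(DEFAULT_ESTIMATION_TOKENIZER)

print(len(encoding.encode("How do I scale a deployment?")))  # rough token count
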
1 change: 1 addition & 0 deletions src/models/config.py
@@ -157,6 +157,7 @@ class Customization(BaseModel):
disable_query_system_prompt: bool = False
system_prompt_path: Optional[FilePath] = None
system_prompt: Optional[str] = None
default_estimation_tokenizer: str = constants.DEFAULT_ESTIMATION_TOKENIZER

@model_validator(mode="after")
def check_customization_model(self) -> Self:
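A hypothetical usage sketch for the new field, assuming the model validator accepts the defaults for the other Customization fields; "o200k_base" is just an example encoding name:

from models.config import Customization

# Override only the estimation tokenizer; other fields keep their defaults.
customization = Customization(default_estimation_tokenizer="o200k_base")
print(customization.default_estimation_tokenizer)  # -> "o200k_base"
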
6 changes: 4 additions & 2 deletions src/models/responses.py
@@ -16,8 +16,6 @@ class ModelsResponse(BaseModel):
# - referenced_documents: The optional URLs and titles for the documents used
# to generate the response.
# - truncated: Set to True if conversation history was truncated to be within context window.
# - input_tokens: Number of tokens sent to LLM
# - output_tokens: Number of tokens received from LLM
# - available_quotas: Quota available as measured by all configured quota limiters
# - tool_calls: List of tool requests.
# - tool_results: List of tool results.
@@ -28,10 +26,14 @@ class QueryResponse(BaseModel):
Attributes:
conversation_id: The optional conversation ID (UUID).
response: The response.
input_tokens: Number of tokens sent to LLM.
output_tokens: Number of tokens received from LLM.
"""

conversation_id: Optional[str] = None
response: str
input_tokens: int = 0
output_tokens: int = 0

# provides examples for /docs endpoint
model_config = {
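With the new fields populated, a /query response serializes roughly as follows (a sketch; the token values and response text are illustrative):

from models.responses import QueryResponse

response = QueryResponse(
    conversation_id="123e4567-e89b-12d3-a456-426614174000",
    response="Operator Lifecycle Manager (OLM) manages the lifecycle of operators.",
    input_tokens=187,   # estimated tokens sent to the LLM
    output_tokens=412,  # estimated tokens received from the LLM
)
print(response.model_dump_json(indent=2))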