-
Notifications
You must be signed in to change notification settings - Fork 51
LCORE-178 Token counting #215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
53a5441
ac90f08
4c1bd95
67269cb
ebd1424
4ef7798
cd5bb03
3c5a3e1
5818297
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,6 +32,7 @@ | |
| from utils.endpoints import check_configuration_loaded, get_system_prompt | ||
| from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups | ||
| from utils.suid import get_suid | ||
| from utils.token_counter import get_token_counter | ||
| from utils.types import GraniteToolParser | ||
|
|
||
| logger = logging.getLogger("app.endpoints.handlers") | ||
|
|
@@ -121,7 +122,7 @@ def query_endpoint_handler( | |
| # try to get Llama Stack client | ||
| client = LlamaStackClientHolder().get_client() | ||
| model_id = select_model_id(client.models.list(), query_request) | ||
| response, conversation_id = retrieve_response( | ||
| response, conversation_id, token_usage = retrieve_response( | ||
| client, | ||
| model_id, | ||
| query_request, | ||
|
|
@@ -144,7 +145,12 @@ def query_endpoint_handler( | |
| attachments=query_request.attachments or [], | ||
| ) | ||
|
|
||
| return QueryResponse(conversation_id=conversation_id, response=response) | ||
| return QueryResponse( | ||
| conversation_id=conversation_id, | ||
| response=response, | ||
| input_tokens=token_usage["input_tokens"], | ||
| output_tokens=token_usage["output_tokens"], | ||
| ) | ||
|
|
||
| # connection to Llama Stack server | ||
| except APIConnectionError as e: | ||
|
|
@@ -202,13 +208,21 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> s | |
| return model_id | ||
|
|
||
|
|
||
| def _build_toolgroups(client: LlamaStackClient) -> list[Toolgroup] | None: | ||
| """Build toolgroups from vector DBs and MCP servers.""" | ||
| vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()] | ||
| return (get_rag_toolgroups(vector_db_ids) or []) + [ | ||
| mcp_server.name for mcp_server in configuration.mcp_servers | ||
| ] | ||
|
|
||
|
|
||
| def retrieve_response( | ||
| client: LlamaStackClient, | ||
| model_id: str, | ||
| query_request: QueryRequest, | ||
| token: str, | ||
| mcp_headers: dict[str, dict[str, str]] | None = None, | ||
| ) -> tuple[str, str]: | ||
| ) -> tuple[str, str, dict[str, int]]: | ||
| """Retrieve response from LLMs and agents.""" | ||
| available_shields = [shield.identifier for shield in client.shields.list()] | ||
| if not available_shields: | ||
|
|
@@ -251,19 +265,37 @@ def retrieve_response( | |
| ), | ||
| } | ||
|
|
||
| vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()] | ||
| toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ | ||
| mcp_server.name for mcp_server in configuration.mcp_servers | ||
| ] | ||
| response = agent.create_turn( | ||
| messages=[UserMessage(role="user", content=query_request.query)], | ||
| session_id=conversation_id, | ||
| documents=query_request.get_documents(), | ||
| stream=False, | ||
| toolgroups=toolgroups or None, | ||
| toolgroups=_build_toolgroups(client) or None, | ||
| ) | ||
|
|
||
| return str(response.output_message.content), conversation_id # type: ignore[union-attr] | ||
| response_content = str(response.output_message.content) # type: ignore[union-attr] | ||
|
|
||
| # Currently (2025-07-08) the usage is not returned by the API, so we need to estimate it | ||
| # try: | ||
| # token_usage = { | ||
| # "input_tokens": response.usage.get("prompt_tokens", 0), | ||
| # "output_tokens": response.usage.get("completion_tokens", 0), | ||
| # } | ||
| # except AttributeError: | ||
| # Estimate token usage | ||
| try: | ||
| token_counter = get_token_counter(model_id) | ||
| token_usage = token_counter.count_conversation_turn_tokens( | ||
| conversation_id, system_prompt, query_request, response_content | ||
| ) | ||
| except Exception as e: # pylint: disable=broad-exception-caught | ||
| logger.warning("Failed to estimate token usage: %s", e) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: This probably should be an error instead of a warning |
||
| token_usage = { | ||
| "input_tokens": 0, | ||
| "output_tokens": 0, | ||
| } | ||
|
|
||
| return response_content, conversation_id, token_usage | ||
|
|
||
|
|
||
| def validate_attachments_metadata(attachments: list[Attachment]) -> None: | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -9,10 +9,12 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from llama_stack_client import APIConnectionError | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from llama_stack_client.types.agents.turn_create_params import Toolgroup | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from llama_stack_client import AsyncLlamaStackClient # type: ignore | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from llama_stack_client.types.shared.interleaved_content_item import TextContentItem | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from llama_stack_client.types import UserMessage # type: ignore | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from fastapi import APIRouter, HTTPException, Request, Depends, status | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from fastapi.responses import StreamingResponse | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -24,6 +26,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from utils.common import retrieve_user_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from utils.suid import get_suid | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from utils.token_counter import get_token_counter | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from utils.types import GraniteToolParser | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from app.endpoints.conversations import conversation_id_to_agent_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -97,8 +100,13 @@ def stream_start_event(conversation_id: str) -> str: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def stream_end_event(metadata_map: dict) -> str: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Yield the end of the data stream.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def stream_end_event(metadata_map: dict, metrics_map: dict[str, int]) -> str: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Yield the end of the data stream. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metadata_map: Dictionary containing metadata about referenced documents | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metrics_map: Dictionary containing metrics like 'input_tokens' and 'output_tokens' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return format_stream_data( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "event": "end", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -114,8 +122,8 @@ def stream_end_event(metadata_map: dict) -> str: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "truncated": None, # TODO(jboos): implement truncated | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "input_tokens": 0, # TODO(jboos): implement input tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "output_tokens": 0, # TODO(jboos): implement output tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "input_tokens": metrics_map.get("input_tokens", 0), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "output_tokens": metrics_map.get("output_tokens", 0), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "available_quotas": {}, # TODO(jboos): implement available quotas | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -204,7 +212,7 @@ async def streaming_query_endpoint_handler( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # try to get Llama Stack client | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| client = AsyncLlamaStackClientHolder().get_client() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_id = select_model_id(await client.models.list(), query_request) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| response, conversation_id = await retrieve_response( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| response, conversation_id, token_usage = await retrieve_response( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| client, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| query_request, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -229,7 +237,24 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| chunk_id += 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| yield event | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| yield stream_end_event(metadata_map) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Currently (2025-07-08) the usage is not returned by the API, so we need to estimate | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # output_tokens = response.usage.get("completion_tokens", 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # except AttributeError: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Estimate output tokens from complete response | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_counter = get_token_counter(model_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_tokens = token_counter.count_tokens(complete_response) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.debug("Estimated output tokens: %s", output_tokens) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as e: # pylint: disable=broad-exception-caught | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning("Failed to estimate output tokens: %s", e) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto: s/warning/error/g |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_tokens = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metrics_map = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "input_tokens": token_usage["input_tokens"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "output_tokens": output_tokens, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| yield stream_end_event(metadata_map, metrics_map) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+240
to
+257
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Improve error handling and remove commented code. The token counting implementation has broad exception handling and contains commented-out code that should be cleaned up. - # Currently (2025-07-08) the usage is not returned by the API, so we need to estimate
- # try:
- # output_tokens = response.usage.get("completion_tokens", 0)
- # except AttributeError:
- # Estimate output tokens from complete response
try:
token_counter = get_token_counter(model_id)
output_tokens = token_counter.count_tokens(complete_response)
logger.debug("Estimated output tokens: %s", output_tokens)
- except Exception as e: # pylint: disable=broad-exception-caught
+ except (KeyError, ValueError, AttributeError) as e:
logger.warning("Failed to estimate output tokens: %s", e)
output_tokens = 0Rationale: Remove outdated commented code and use more specific exception handling instead of catching all exceptions. 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not is_transcripts_enabled(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.debug("Transcript collection is disabled in the configuration") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -260,13 +285,23 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) from e | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def _build_toolgroups(client: AsyncLlamaStackClient) -> list[Toolgroup] | None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Build toolgroups from vector DBs and MCP servers.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vector_db_ids = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vector_db.identifier for vector_db in await client.vector_dbs.list() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return (get_rag_toolgroups(vector_db_ids) or []) + [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mcp_server.name for mcp_server in configuration.mcp_servers | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def retrieve_response( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| client: AsyncLlamaStackClient, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_id: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| query_request: QueryRequest, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mcp_headers: dict[str, dict[str, str]] | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> tuple[Any, str]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> tuple[Any, str, dict[str, int]]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Retrieve response from LLMs and agents.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| available_shields = [shield.identifier for shield in await client.shields.list()] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not available_shields: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -312,18 +347,33 @@ async def retrieve_response( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.debug("Session ID: %s", conversation_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vector_db_ids = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vector_db.identifier for vector_db in await client.vector_dbs.list() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mcp_server.name for mcp_server in configuration.mcp_servers | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| response = await agent.create_turn( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| messages=[UserMessage(role="user", content=query_request.query)], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| session_id=conversation_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| documents=query_request.get_documents(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stream=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| toolgroups=toolgroups or None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| toolgroups=await _build_toolgroups(client) or None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return response, conversation_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Currently (2025-07-08) the usage is not returned by the API, so we need to estimate it | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # token_usage = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # "input_tokens": response.usage.get("prompt_tokens", 0), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # "output_tokens": 0, # Will be calculated during streaming | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # except AttributeError: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # # Estimate input tokens (Output will be calculated during streaming) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_counter = get_token_counter(model_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_usage = token_counter.count_conversation_turn_tokens( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| conversation_id, system_prompt, query_request | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as e: # pylint: disable=broad-exception-caught | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning("Failed to estimate token usage: %s", e) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_usage = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "input_tokens": 0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "output_tokens": 0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return response, conversation_id, token_usage | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+359
to
+379
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Improve error handling and remove commented code. Similar to the earlier issue, this section has broad exception handling and outdated commented code. - # Currently (2025-07-08) the usage is not returned by the API, so we need to estimate it
- # try:
- # token_usage = {
- # "input_tokens": response.usage.get("prompt_tokens", 0),
- # "output_tokens": 0, # Will be calculated during streaming
- # }
- # except AttributeError:
- # # Estimate input tokens (Output will be calculated during streaming)
try:
token_counter = get_token_counter(model_id)
token_usage = token_counter.count_conversation_turn_tokens(
conversation_id, system_prompt, query_request
)
- except Exception as e: # pylint: disable=broad-exception-caught
+ except (KeyError, ValueError, AttributeError) as e:
logger.warning("Failed to estimate token usage: %s", e)
token_usage = {
"input_tokens": 0,
"output_tokens": 0,
}Rationale: Remove outdated commented code and use more specific exception handling. 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it part of this PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was to alleviate a "too many local variables" error that appeared while I was fixing merge conflicts.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably we can just # pylint: disable=too-many-locals for this