Commit db5b65b

Add metrics to lightspeed-stack
This patch adds the /metrics endpoint to the project.

Differences with road-core/service:

* The metrics are prefixed with "ls_" instead of "ols_"
* The "provider_model_configuration" metric does not set the non-default models/providers to 0, because we currently do not have a way to set a default model/provider in the configuration. A TODO was left in the code.

Supported metrics:

* rest_api_calls_total: Counter to track REST API calls
* response_duration_seconds: Histogram to measure how long it takes to handle requests
* provider_model_configuration: Indicates what provider + model customers are using
* llm_calls_total: How many LLM calls were made for each provider + model
* llm_calls_failures_total: How many LLM calls failed
* llm_calls_validation_errors_total: How many LLM calls had validation errors

Missing metrics:

* llm_token_sent_total: How many tokens were sent
* llm_token_received_total: How many tokens were received

The metrics above are missing because the token counting PR (lightspeed-core#215) is not merged yet.

Signed-off-by: Lucas Alvares Gomes <[email protected]>
1 parent 28abfaf commit db5b65b
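As a quick, hedged illustration of what the commit message describes, the sketch below scrapes the new /metrics endpoint and prints the ls_-prefixed series; the host and port are assumptions for a local run, not something defined by this commit.

# Minimal sketch: scrape the /metrics endpoint added by this commit and print
# the ls_-prefixed series. The base URL is an assumption for a local deployment.
import urllib.request

METRICS_URL = "http://localhost:8080/metrics"  # hypothetical host/port

with urllib.request.urlopen(METRICS_URL) as resp:
    body = resp.read().decode("utf-8")

for line in body.splitlines():
    # Prometheus text exposition format: "# HELP"/"# TYPE" comments plus samples
    if line.startswith("ls_") or line.startswith("# TYPE ls_"):
        print(line)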

15 files changed: +363 additions, -46 deletions


pyproject.toml

Lines changed: 2 additions & 0 deletions

@@ -13,6 +13,8 @@ dependencies = [
     "llama-stack>=0.2.13",
     "rich>=14.0.0",
     "cachetools>=6.1.0",
+    "prometheus-client>=0.22.1",
+    "starlette>=0.47.1",
 ]
 
 [tool.pyright]

src/app/endpoints/metrics.py

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+"""Handler for REST API call to provide metrics."""
+
+from fastapi.responses import PlainTextResponse
+from fastapi import APIRouter, Request
+from prometheus_client import (
+    generate_latest,
+    CONTENT_TYPE_LATEST,
+)
+
+router = APIRouter(tags=["metrics"])
+
+
+@router.get("/metrics", response_class=PlainTextResponse)
+def metrics_endpoint_handler(_request: Request) -> PlainTextResponse:
+    """Handle request to the /metrics endpoint."""
+    return PlainTextResponse(generate_latest(), media_type=CONTENT_TYPE_LATEST)
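A hedged way to exercise this handler through the full routing stack (not part of this commit) is FastAPI's TestClient; importing the project's metrics package is assumed to register the ls_* collectors, and src/ is assumed to be importable.

# Sketch only: mount the new router on a throwaway app and request /metrics.
# Assumes fastapi's TestClient (httpx) is installed; not part of this commit.
from fastapi import FastAPI
from fastapi.testclient import TestClient

import metrics  # noqa: F401  # importing registers the ls_* collectors
from app.endpoints.metrics import router

app = FastAPI()
app.include_router(router)

client = TestClient(app)
response = client.get("/metrics")
assert response.status_code == 200
assert "# TYPE ls_llm_calls_total counter" in response.text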

src/app/endpoints/query.py

Lines changed: 26 additions & 7 deletions

@@ -24,6 +24,7 @@
 from client import LlamaStackClientHolder
 from configuration import configuration
 from app.endpoints.conversations import conversation_id_to_agent_id
+import metrics
 from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
 from models.requests import QueryRequest, Attachment
 import constants
@@ -122,14 +123,18 @@ def query_endpoint_handler(
     try:
         # try to get Llama Stack client
         client = LlamaStackClientHolder().get_client()
-        model_id = select_model_id(client.models.list(), query_request)
+        model_id, provider_id = select_model_and_provider_id(
+            client.models.list(), query_request
+        )
         response, conversation_id = retrieve_response(
             client,
             model_id,
             query_request,
             token,
             mcp_headers=mcp_headers,
         )
+        # Update metrics for the LLM call
+        metrics.llm_calls_total.labels(provider_id, model_id).inc()
 
         if not is_transcripts_enabled():
             logger.debug("Transcript collection is disabled in the configuration")
@@ -150,6 +155,8 @@
 
     # connection to Llama Stack server
     except APIConnectionError as e:
+        # Update metrics for the LLM call failure
+        metrics.llm_calls_failures_total.inc()
         logger.error("Unable to connect to Llama Stack: %s", e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -160,8 +167,10 @@
         ) from e
 
 
-def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> str:
-    """Select the model ID based on the request or available models."""
+def select_model_and_provider_id(
+    models: ModelListResponse, query_request: QueryRequest
+) -> tuple[str, str | None]:
+    """Select the model ID and provider ID based on the request or available models."""
     model_id = query_request.model
     provider_id = query_request.provider
 
@@ -173,9 +182,11 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> str:
                 m
                 for m in models
                 if m.model_type == "llm"  # pyright: ignore[reportAttributeAccessIssue]
-            ).identifier
+            )
+            model_id = model.identifier
+            provider_id = model.provider_id
             logger.info("Selected model: %s", model)
-            return model
+            return model_id, provider_id
         except (StopIteration, AttributeError) as e:
             message = "No LLM model found in available models"
             logger.error(message)
@@ -201,7 +212,7 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> str:
             },
         )
 
-    return model_id
+    return model_id, provider_id
 
 
 def _is_inout_shield(shield: Shield) -> bool:
@@ -218,7 +229,7 @@ def is_input_shield(shield: Shield) -> bool:
     return _is_inout_shield(shield) or not is_output_shield(shield)
 
 
-def retrieve_response(
+def retrieve_response(  # pylint: disable=too-many-locals
     client: LlamaStackClient,
     model_id: str,
     query_request: QueryRequest,
@@ -288,6 +299,14 @@ def retrieve_response(
         toolgroups=toolgroups or None,
     )
 
+    # Check for validation errors in the response
+    steps = getattr(response, "steps", [])
+    for step in steps:
+        if step.step_type == "shield_call" and step.violation:
+            # Metric for LLM validation errors
+            metrics.llm_calls_validation_errors_total.inc()
+            break
+
     return str(response.output_message.content), conversation_id  # type: ignore[union-attr]
src/app/endpoints/streaming_query.py

Lines changed: 12 additions & 2 deletions

@@ -23,6 +23,7 @@
 from auth import get_auth_dependency
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
+import metrics
 from models.requests import QueryRequest
 from utils.endpoints import check_configuration_loaded, get_system_prompt
 from utils.common import retrieve_user_id
@@ -37,7 +38,7 @@
     is_output_shield,
     is_transcripts_enabled,
     store_transcript,
-    select_model_id,
+    select_model_and_provider_id,
     validate_attachments_metadata,
 )
 
@@ -229,6 +230,8 @@ def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]:
             }
         )
     else:
+        # Metric for LLM validation errors
+        metrics.llm_calls_validation_errors_total.inc()
         violation = (
             f"Violation: {violation.user_message} (Metadata: {violation.metadata})"
         )
@@ -421,7 +424,9 @@
     try:
         # try to get Llama Stack client
        client = AsyncLlamaStackClientHolder().get_client()
-        model_id = select_model_id(await client.models.list(), query_request)
+        model_id, provider_id = select_model_and_provider_id(
+            await client.models.list(), query_request
+        )
         response, conversation_id = await retrieve_response(
             client,
             model_id,
@@ -465,9 +470,14 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
                 attachments=query_request.attachments or [],
             )
 
+        # Update metrics for the LLM call
+        metrics.llm_calls_total.labels(provider_id, model_id).inc()
+
         return StreamingResponse(response_generator(response))
     # connection to Llama Stack server
     except APIConnectionError as e:
+        # Update metrics for the LLM call failure
+        metrics.llm_calls_failures_total.inc()
         logger.error("Unable to connect to Llama Stack: %s", e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,

src/app/main.py

Lines changed: 43 additions & 4 deletions

@@ -1,13 +1,18 @@
 """Definition of FastAPI based web service."""
 
-from fastapi import FastAPI
+from typing import Callable, Awaitable
+
+from fastapi import FastAPI, Request, Response
 from fastapi.middleware.cors import CORSMiddleware
-from app import routers
+from starlette.routing import Mount, Route, WebSocketRoute
 
-import version
-from log import get_logger
+from app import routers
 from configuration import configuration
+from log import get_logger
+import metrics
+from metrics.utils import setup_model_metrics
 from utils.common import register_mcp_servers_async
+import version
 
 logger = get_logger(__name__)
 
@@ -34,9 +39,43 @@
     allow_headers=["*"],
 )
 
+
+@app.middleware("")
+async def rest_api_metrics(
+    request: Request, call_next: Callable[[Request], Awaitable[Response]]
+) -> Response:
+    """Middleware with REST API counter update logic."""
+    path = request.url.path
+    logger.debug("Received request for path: %s", path)
+
+    # ignore paths that are not part of the app routes
+    if path not in app_routes_paths:
+        return await call_next(request)
+
+    logger.debug("Processing API request for path: %s", path)
+
+    # measure time to handle duration + update histogram
+    with metrics.response_duration_seconds.labels(path).time():
+        response = await call_next(request)
+
+    # ignore /metrics endpoint that will be called periodically
+    if not path.endswith("/metrics"):
+        # just update metrics
+        metrics.rest_api_calls_total.labels(path, response.status_code).inc()
+    return response
+
+
 logger.info("Including routers")
 routers.include_routers(app)
 
+app_routes_paths = [
+    route.path
+    for route in app.routes
+    if isinstance(route, (Mount, Route, WebSocketRoute))
+]
+
+setup_model_metrics()
+
 
 @app.on_event("startup")
 async def startup_event() -> None:
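The middleware above relies on prometheus_client's Histogram timer; a standalone sketch of the same pattern, with made-up metric and path names, is:

# Illustrative only: Histogram.labels(...).time() observes the elapsed seconds
# of the with-block, the same pattern the middleware uses around call_next().
import time

from prometheus_client import Histogram

demo_duration_seconds = Histogram(
    "demo_response_duration_seconds", "Demo response durations", ["path"]
)

with demo_duration_seconds.labels("/v1/query").time():
    time.sleep(0.05)  # stand-in for awaiting call_next(request)

# Samples land in the per-label buckets plus the _sum and _count series.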

src/app/routers.py

Lines changed: 2 additions & 0 deletions

@@ -13,6 +13,7 @@
     streaming_query,
     authorized,
     conversations,
+    metrics,
 )
 
 
@@ -34,3 +35,4 @@ def include_routers(app: FastAPI) -> None:
     # road-core does not version these endpoints
     app.include_router(health.router)
     app.include_router(authorized.router)
+    app.include_router(metrics.router)

src/metrics/__init__.py

Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+"""Metrics module for Lightspeed Stack."""
+
+from prometheus_client import (
+    Counter,
+    Gauge,
+    Histogram,
+)
+
+# Counter to track REST API calls
+# This will be used to count how many times each API endpoint is called
+# and the status code of the response
+rest_api_calls_total = Counter(
+    "ls_rest_api_calls_total", "REST API calls counter", ["path", "status_code"]
+)
+
+# Histogram to measure response durations
+# This will be used to track how long it takes to handle requests
+response_duration_seconds = Histogram(
+    "ls_response_duration_seconds", "Response durations", ["path"]
+)
+
+# Metric that indicates what provider + model customers are using so we can
+# understand what is popular/important
+provider_model_configuration = Gauge(
+    "ls_provider_model_configuration",
+    "LLM provider/models combinations defined in configuration",
+    ["provider", "model"],
+)
+
+# Metric that counts how many LLM calls were made for each provider + model
+llm_calls_total = Counter(
+    "ls_llm_calls_total", "LLM calls counter", ["provider", "model"]
+)
+
+# Metric that counts how many LLM calls failed
+llm_calls_failures_total = Counter("ls_llm_calls_failures_total", "LLM calls failures")
+
+# Metric that counts how many LLM calls had validation errors
+llm_calls_validation_errors_total = Counter(
+    "ls_llm_validation_errors_total", "LLM validation errors"
+)
+
+# TODO(lucasagomes): Add metric for token usage
+llm_token_sent_total = Counter(
+    "ls_llm_token_sent_total", "LLM tokens sent", ["provider", "model"]
+)
+
+# TODO(lucasagomes): Add metric for token usage
+llm_token_received_total = Counter(
+    "ls_llm_token_received_total", "LLM tokens received", ["provider", "model"]
+)
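To see how these collectors behave, here is a small sketch against prometheus_client's default registry; the provider/model label values are made up, and importing the project's metrics package assumes src/ is on the path.

# Sketch: bump a labeled counter and read the sample back from the default
# registry. The "openai"/"gpt-4o" labels are made up for illustration.
from prometheus_client import REGISTRY

import metrics  # the module defined above; assumes src/ is importable

metrics.llm_calls_total.labels("openai", "gpt-4o").inc()

sample = REGISTRY.get_sample_value(
    "ls_llm_calls_total",
    {"provider": "openai", "model": "gpt-4o"},
)
print(sample)  # 1.0 after a single increment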

src/metrics/utils.py

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+"""Utility functions for metrics handling."""
+
+from client import LlamaStackClientHolder
+from log import get_logger
+import metrics
+
+logger = get_logger(__name__)
+
+
+# TODO(lucasagomes): Change this metric once we are allowed to set the
+# default model/provider via the configuration. The default provider/model
+# will be set to 1, and the rest will be set to 0.
+def setup_model_metrics() -> None:
+    """Perform setup of all metrics related to LLM model and provider."""
+    client = LlamaStackClientHolder().get_client()
+    models = [
+        model
+        for model in client.models.list()
+        if model.model_type == "llm"  # pyright: ignore[reportAttributeAccessIssue]
+    ]
+
+    for model in models:
+        provider = model.provider_id
+        model_name = model.identifier
+        if provider and model_name:
+            label_key = (provider, model_name)
+            metrics.provider_model_configuration.labels(*label_key).set(1)
+            logger.debug(
+                "Set provider/model configuration for %s/%s to 1",
+                provider,
+                model_name,
+            )
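Because setup_model_metrics() needs a live Llama Stack client, a hypothetical unit test (not part of this commit) would stub the holder; the provider and model names below are invented.

# Hypothetical test sketch: stub LlamaStackClientHolder so setup_model_metrics()
# can run without a Llama Stack server. Names are invented for illustration.
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

from prometheus_client import REGISTRY

from metrics.utils import setup_model_metrics


def test_setup_model_metrics_marks_configured_models():
    fake_model = SimpleNamespace(
        model_type="llm", provider_id="fake-provider", identifier="fake-model"
    )
    holder = MagicMock()
    holder.return_value.get_client.return_value.models.list.return_value = [fake_model]

    with patch("metrics.utils.LlamaStackClientHolder", holder):
        setup_model_metrics()

    assert REGISTRY.get_sample_value(
        "ls_provider_model_configuration",
        {"provider": "fake-provider", "model": "fake-model"},
    ) == 1.0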
Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+"""Unit tests for the /metrics REST API endpoint."""
+
+from app.endpoints.metrics import metrics_endpoint_handler
+
+
+def test_metrics_endpoint():
+    """Test the metrics endpoint handler."""
+    response = metrics_endpoint_handler(None)
+    assert response is not None
+    assert response.status_code == 200
+    assert "text/plain" in response.headers["Content-Type"]
+
+    response_body = response.body.decode()
+
+    # Check if the response contains Prometheus metrics format
+    assert "# TYPE ls_rest_api_calls_total counter" in response_body
+    assert "# TYPE ls_response_duration_seconds histogram" in response_body
+    assert "# TYPE ls_provider_model_configuration gauge" in response_body
+    assert "# TYPE ls_llm_calls_total counter" in response_body
+    assert "# TYPE ls_llm_calls_failures_total counter" in response_body
+    assert "# TYPE ls_llm_calls_failures_created gauge" in response_body
+    assert "# TYPE ls_llm_validation_errors_total counter" in response_body
+    assert "# TYPE ls_llm_validation_errors_created gauge" in response_body
+    assert "# TYPE ls_llm_token_sent_total counter" in response_body
+    assert "# TYPE ls_llm_token_received_total counter" in response_body
