Commit 19dd6c1

Add metrics to lightspeed-stack
This patch adds the /metrics endpoint to the project.

Differences with road-core/service:
* The metrics are prefixed with "ls_" instead of "ols_"
* The "provider_model_configuration" metric does not set the non-default models/providers to 0, because we currently do not have a way to set a default model/provider in the configuration. A TODO was left in the code.

Supported metrics:
* rest_api_calls_total: Counter to track REST API calls
* response_duration_seconds: Histogram to measure how long it takes to handle requests
* provider_model_configuration: Indicates what provider + model customers are using
* llm_calls_total: How many LLM calls were made for each provider + model
* llm_calls_failures_total: How many LLM calls failed
* llm_calls_validation_errors_total: How many LLM calls had validation errors

Missing metrics:
* llm_token_sent_total: How many tokens were sent
* llm_token_received_total: How many tokens were received

The above metrics are missing because the token counting PR (lightspeed-core#215) is not merged yet.

Signed-off-by: Lucas Alvares Gomes <[email protected]>
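For context, the new endpoint serves the standard Prometheus text exposition format, so any HTTP client can inspect it directly. A minimal sketch of scraping it by hand (not part of this commit; the host and port are assumptions and depend on how the service is run):

import urllib.request

# Assumes a locally running lightspeed-stack instance; adjust host/port as needed.
with urllib.request.urlopen("http://localhost:8080/metrics") as resp:
    body = resp.read().decode()

# Every metric added by this patch carries the "ls_" prefix, e.g.
#   # TYPE ls_llm_calls_total counter
for line in body.splitlines():
    if line.startswith(("ls_", "# TYPE ls_")):
        print(line)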
1 parent 1c873fb commit 19dd6c1

File tree

15 files changed (+304 lines, -37 lines)


pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,8 @@ dependencies = [
     "llama-stack>=0.2.13",
     "rich>=14.0.0",
     "cachetools>=6.1.0",
+    "prometheus-client>=0.22.1",
+    "starlette>=0.47.1",
 ]
 
 [tool.pyright]

src/app/endpoints/metrics.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+"""Handler for REST API call to provide metrics."""
+
+from fastapi.responses import PlainTextResponse
+from fastapi import APIRouter, Request
+from prometheus_client import (
+    generate_latest,
+    CONTENT_TYPE_LATEST,
+)
+
+router = APIRouter(tags=["metrics"])
+
+
+@router.get("/metrics", response_class=PlainTextResponse)
+def metrics_endpoint_handler(_request: Request) -> PlainTextResponse:
+    """Handle request to the /metrics endpoint."""
+    return PlainTextResponse(generate_latest(), media_type=CONTENT_TYPE_LATEST)
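A quick way to see the handler in action is to mount the router on a throwaway FastAPI app and call it with the test client; this is an illustrative sketch, not code from the commit:

from fastapi import FastAPI
from fastapi.testclient import TestClient

from app.endpoints.metrics import router

app = FastAPI()
app.include_router(router)

client = TestClient(app)
response = client.get("/metrics")
assert response.status_code == 200
# The body is Prometheus exposition text, e.g. "# TYPE ls_llm_calls_total counter"
print(response.text.splitlines()[:5])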

src/app/endpoints/query.py

Lines changed: 26 additions & 7 deletions
@@ -5,7 +5,7 @@
 import logging
 import os
 from pathlib import Path
-from typing import Any
+from typing import Any, Tuple
 
 from cachetools import TTLCache  # type: ignore
 
@@ -24,6 +24,7 @@
 from client import LlamaStackClientHolder
 from configuration import configuration
 from app.endpoints.conversations import conversation_id_to_agent_id
+import metrics
 from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
 from models.requests import QueryRequest, Attachment
 import constants
@@ -120,14 +121,18 @@ def query_endpoint_handler(
     try:
         # try to get Llama Stack client
         client = LlamaStackClientHolder().get_client()
-        model_id = select_model_id(client.models.list(), query_request)
+        model_id, provider_id = select_model_and_provider_id(
+            client.models.list(), query_request
+        )
         response, conversation_id = retrieve_response(
             client,
             model_id,
             query_request,
             token,
             mcp_headers=mcp_headers,
         )
+        # Update metrics for the LLM call
+        metrics.llm_calls_total.labels(provider_id, model_id).inc()
 
         if not is_transcripts_enabled():
             logger.debug("Transcript collection is disabled in the configuration")
@@ -148,6 +153,8 @@ def query_endpoint_handler(
 
     # connection to Llama Stack server
     except APIConnectionError as e:
+        # Update metrics for the LLM call failure
+        metrics.llm_calls_failures_total.inc()
         logger.error("Unable to connect to Llama Stack: %s", e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -158,8 +165,10 @@ def query_endpoint_handler(
         ) from e
 
 
-def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> str:
-    """Select the model ID based on the request or available models."""
+def select_model_and_provider_id(
+    models: ModelListResponse, query_request: QueryRequest
+) -> Tuple[str, str | None]:
+    """Select the model ID and provider ID based on the request or available models."""
     model_id = query_request.model
     provider_id = query_request.provider
 
@@ -171,9 +180,11 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> str:
             m
             for m in models
             if m.model_type == "llm"  # pyright: ignore[reportAttributeAccessIssue]
-        ).identifier
+        )
+        model_id = model.identifier
+        provider_id = model.provider_id
         logger.info("Selected model: %s", model)
-        return model
+        return model_id, provider_id
     except (StopIteration, AttributeError) as e:
         message = "No LLM model found in available models"
         logger.error(message)
@@ -199,7 +210,7 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> str:
         },
     )
 
-    return model_id
+    return model_id, provider_id
 
 
 def retrieve_response(
@@ -263,6 +274,14 @@ def retrieve_response(
         toolgroups=toolgroups or None,
     )
 
+    # Check for validation errors in the response
+    steps = getattr(response, "steps", [])
+    for step in steps:
+        if step.step_type == "shield_call" and step.violation:
+            # Metric for LLM validation errors
+            metrics.llm_calls_validation_errors_total.inc()
+            break
+
     return str(response.output_message.content), conversation_id  # type: ignore[union-attr]
 
 
src/app/endpoints/streaming_query.py

Lines changed: 27 additions & 2 deletions
@@ -19,6 +19,7 @@
 from auth import get_auth_dependency
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
+import metrics
 from models.requests import QueryRequest
 from utils.endpoints import check_configuration_loaded, get_system_prompt
 from utils.common import retrieve_user_id
@@ -31,7 +32,7 @@
     get_rag_toolgroups,
     is_transcripts_enabled,
     store_transcript,
-    select_model_id,
+    select_model_and_provider_id,
     validate_attachments_metadata,
 )
 
@@ -182,6 +183,24 @@ def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | None:
                 },
             }
         )
+    if (
+        chunk.event.payload.event_type == "step_complete"
+        and chunk.event.payload.step_type == "shield_call"
+    ):
+        violation = chunk.event.payload.step_details.violation
+        if violation:
+            # Metric for LLM validation errors
+            metrics.llm_calls_validation_errors_total.inc()
+            return format_stream_data(
+                {
+                    "event": "token",
+                    "data": {
+                        "id": chunk_id,
+                        "role": chunk.event.payload.step_type,
+                        "token": violation.user_message,
+                    },
+                }
+            )
     return None
 
 
@@ -203,14 +222,18 @@ async def streaming_query_endpoint_handler(
     try:
         # try to get Llama Stack client
         client = AsyncLlamaStackClientHolder().get_client()
-        model_id = select_model_id(await client.models.list(), query_request)
+        model_id, provider_id = select_model_and_provider_id(
+            await client.models.list(), query_request
+        )
         response, conversation_id = await retrieve_response(
             client,
             model_id,
             query_request,
             token,
             mcp_headers=mcp_headers,
        )
+        # Update metrics for the LLM call
+        metrics.llm_calls_total.labels(provider_id, model_id).inc()
         metadata_map: dict[str, dict[str, Any]] = {}
 
         async def response_generator(turn_response: Any) -> AsyncIterator[str]:
@@ -250,6 +273,8 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
         return StreamingResponse(response_generator(response))
     # connection to Llama Stack server
     except APIConnectionError as e:
+        # Update metrics for the LLM call failure
+        metrics.llm_calls_failures_total.inc()
         logger.error("Unable to connect to Llama Stack: %s", e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,

src/app/main.py

Lines changed: 43 additions & 4 deletions
@@ -1,13 +1,18 @@
 """Definition of FastAPI based web service."""
 
-from fastapi import FastAPI
+from typing import Callable, Awaitable
+
+from fastapi import FastAPI, Request, Response
 from fastapi.middleware.cors import CORSMiddleware
-from app import routers
+from starlette.routing import Mount, Route, WebSocketRoute
 
-import version
-from log import get_logger
+from app import routers
 from configuration import configuration
+from log import get_logger
+import metrics
+from metrics.utils import setup_model_metrics
 from utils.common import register_mcp_servers_async
+import version
 
 logger = get_logger(__name__)
 
@@ -34,9 +39,43 @@
     allow_headers=["*"],
 )
 
+
+@app.middleware("")
+async def rest_api_counter(
+    request: Request, call_next: Callable[[Request], Awaitable[Response]]
+) -> Response:
+    """Middleware with REST API counter update logic."""
+    path = request.url.path
+    logger.debug("Received request for path: %s", path)
+
+    # ignore paths that are not part of the app routes
+    if path not in app_routes_paths:
+        return await call_next(request)
+
+    logger.debug("Processing API request for path: %s", path)
+
+    # measure time to handle duration + update histogram
+    with metrics.response_duration_seconds.labels(path).time():
+        response = await call_next(request)
+
+    # ignore /metrics endpoint that will be called periodically
+    if not path.endswith("/metrics"):
+        # just update metrics
+        metrics.rest_api_calls_total.labels(path, response.status_code).inc()
+    return response
+
+
 logger.info("Including routers")
 routers.include_routers(app)
 
+app_routes_paths = [
+    route.path
+    for route in app.routes
+    if isinstance(route, (Mount, Route, WebSocketRoute))
+]
+
+setup_model_metrics()
+
 
 @app.on_event("startup")
 async def startup_event() -> None:

src/app/routers.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@
     streaming_query,
     authorized,
     conversations,
+    metrics,
 )
 
 
@@ -34,3 +35,4 @@ def include_routers(app: FastAPI) -> None:
     # road-core does not version these endpoints
     app.include_router(health.router)
     app.include_router(authorized.router)
+    app.include_router(metrics.router)

src/metrics/__init__.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+"""Metrics module for Lightspeed Stack."""
+
+from prometheus_client import (
+    Counter,
+    Gauge,
+    Histogram,
+)
+
+# Counter to track REST API calls
+# This will be used to count how many times each API endpoint is called
+# and the status code of the response
+rest_api_calls_total = Counter(
+    "ls_rest_api_calls_total", "REST API calls counter", ["path", "status_code"]
+)
+
+# Histogram to measure response durations
+# This will be used to track how long it takes to handle requests
+response_duration_seconds = Histogram(
+    "ls_response_duration_seconds", "Response durations", ["path"]
+)
+
+# Metric that indicates what provider + model customers are using so we can
+# understand what is popular/important
+provider_model_configuration = Gauge(
+    "ls_provider_model_configuration",
+    "LLM provider/models combinations defined in configuration",
+    ["provider", "model"],
+)
+
+# Metric that counts how many LLM calls were made for each provider + model
+llm_calls_total = Counter(
+    "ls_llm_calls_total", "LLM calls counter", ["provider", "model"]
+)
+
+# Metric that counts how many LLM calls failed
+llm_calls_failures_total = Counter("ls_llm_calls_failures_total", "LLM calls failures")
+
+# Metric that counts how many LLM calls had validation errors
+llm_calls_validation_errors_total = Counter(
+    "ls_llm_validation_errors_total", "LLM validation errors"
+)
+
+# TODO(lucasagomes): Add metric for token usage
+llm_token_sent_total = Counter(
+    "ls_llm_token_sent_total", "LLM tokens sent", ["provider", "model"]
+)
+
+# TODO(lucasagomes): Add metric for token usage
+llm_token_received_total = Counter(
+    "ls_llm_token_received_total", "LLM tokens received", ["provider", "model"]
+)
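The counters and the histogram above are meant to be driven from the request path, as the query.py and main.py changes in this commit show. A condensed sketch of the same calls, for reference only (the provider, model, and path values here are placeholders):

import metrics

# Record a successful LLM call for a given provider/model pair
metrics.llm_calls_total.labels("watsonx", "granite-8b").inc()

# Record a failed LLM call (this counter has no labels)
metrics.llm_calls_failures_total.inc()

# Time a request and feed the duration histogram for its path
with metrics.response_duration_seconds.labels("/v1/query").time():
    pass  # handle the request here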

src/metrics/utils.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+"""Utility functions for metrics handling."""
+
+from client import LlamaStackClientHolder
+from log import get_logger
+import metrics
+
+logger = get_logger(__name__)
+
+
+# TODO(lucasagomes): Change this metric once we are allowed to set the
+# default model/provider via the configuration. The default provider/model
+# will be set to 1, and the rest will be set to 0.
+def setup_model_metrics() -> None:
+    """Perform setup of all metrics related to LLM model and provider."""
+    client = LlamaStackClientHolder().get_client()
+    models = [
+        model
+        for model in client.models.list()
+        if model.model_type == "llm"  # pyright: ignore[reportAttributeAccessIssue]
+    ]
+
+    for model in models:
+        provider = model.provider_id
+        model_name = model.identifier
+        if provider and model_name:
+            label_key = (provider, model_name)
+            metrics.provider_model_configuration.labels(*label_key).set(1)
+            logger.debug(
+                "Set provider/model configuration for %s/%s to 1",
+                provider,
+                model_name,
+            )
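setup_model_metrics() only sets the gauge for the provider/model pairs it discovers, so each discovered pair shows up in the scrape output as a labelled sample. A rough sketch of that effect; the gauge is set directly here so the snippet runs without a live Llama Stack, and the provider/model names are placeholders:

from prometheus_client import generate_latest

import metrics

# Simulate what setup_model_metrics() does for one discovered pair
metrics.provider_model_configuration.labels("watsonx", "granite-8b").set(1)

for line in generate_latest().decode().splitlines():
    if "ls_provider_model_configuration" in line:
        print(line)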
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+"""Unit tests for the /metrics REST API endpoint."""
+
+from app.endpoints.metrics import metrics_endpoint_handler
+
+
+def test_metrics_endpoint():
+    """Test the metrics endpoint handler."""
+    response = metrics_endpoint_handler(None)
+    assert response is not None
+    assert response.status_code == 200
+    assert "text/plain" in response.headers["Content-Type"]
+
+    response_body = response.body.decode()
+
+    # Check if the response contains Prometheus metrics format
+    assert "# TYPE ls_rest_api_calls_total counter" in response_body
+    assert "# TYPE ls_response_duration_seconds histogram" in response_body
+    assert "# TYPE ls_provider_model_configuration gauge" in response_body
+    assert "# TYPE ls_llm_calls_total counter" in response_body
+    assert "# TYPE ls_llm_calls_failures_total counter" in response_body
+    assert "# TYPE ls_llm_calls_failures_created gauge" in response_body
+    assert "# TYPE ls_llm_validation_errors_total counter" in response_body
+    assert "# TYPE ls_llm_validation_errors_created gauge" in response_body
+    assert "# TYPE ls_llm_token_sent_total counter" in response_body
+    assert "# TYPE ls_llm_token_received_total counter" in response_body
