diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py
index 79961e2d..64bbfc5f 100644
--- a/model-engine/model_engine_server/api/llms_v1.py
+++ b/model-engine/model_engine_server/api/llms_v1.py
@@ -44,6 +44,7 @@
     logger_name,
     make_logger,
 )
+from model_engine_server.core.utils.timer import timer
 from model_engine_server.domain.exceptions import (
     DockerImageNotFoundException,
     EndpointDeleteFailedException,
@@ -313,9 +314,10 @@ async def create_completion_sync_task(
         llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
         tokenizer_repository=external_interfaces.tokenizer_repository,
     )
-    response = await use_case.execute(
-        user=auth, model_endpoint_name=model_endpoint_name, request=request
-    )
+    with timer() as use_case_timer:
+        response = await use_case.execute(
+            user=auth, model_endpoint_name=model_endpoint_name, request=request
+        )
     background_tasks.add_task(
         external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics,
         TokenUsage(
@@ -323,6 +325,7 @@
             num_completion_tokens=response.output.num_completion_tokens
             if response.output
             else None,
+            total_duration=use_case_timer.duration,
         ),
         metric_metadata,
     )
@@ -374,8 +377,9 @@ async def create_completion_stream_task(
     async def event_generator():
         try:
-            async for message in response:
-                yield {"data": message.json()}
+            with timer() as use_case_timer:
+                async for message in response:
+                    yield {"data": message.json()}
 
             background_tasks.add_task(
                 external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics,
                 TokenUsage(
@@ -383,6 +387,7 @@ async def event_generator():
                     num_completion_tokens=message.output.num_completion_tokens
                     if message.output
                     else None,
+                    total_duration=use_case_timer.duration,
                 ),
                 metric_metadata,
             )
diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py
index 346c9ae2..fc531c1f 100644
--- a/model-engine/model_engine_server/common/dtos/llms.py
+++ b/model-engine/model_engine_server/common/dtos/llms.py
@@ -280,13 +280,27 @@ class CompletionStreamV1Response(BaseModel):
 
 
 class TokenUsage(BaseModel):
+    """
+    Token usage for a prompt completion task.
+    """
+
     num_prompt_tokens: Optional[int] = 0
     num_completion_tokens: Optional[int] = 0
+    total_duration: Optional[float] = None
+    """Includes time spent waiting for the model to be ready."""
 
     @property
     def num_total_tokens(self) -> int:
         return (self.num_prompt_tokens or 0) + (self.num_completion_tokens or 0)
 
+    @property
+    def total_tokens_per_second(self) -> float:
+        return (
+            self.num_total_tokens / self.total_duration
+            if self.total_duration and self.total_duration > 0
+            else 0.0
+        )
+
 
 class CreateFineTuneRequest(BaseModel):
     model: str
diff --git a/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py
index 9b63a135..dc419a07 100644
--- a/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py
+++ b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py
@@ -21,6 +21,7 @@ def __init__(self):
         self.database_cache_miss = 0
         self.route_call = defaultdict(int)
         self.token_count = 0
+        self.total_tokens_per_second = 0
 
     def reset(self):
         self.attempted_build = 0
@@ -35,6 +36,7 @@ def reset(self):
         self.database_cache_miss = 0
         self.route_call = defaultdict(int)
         self.token_count = 0
+        self.total_tokens_per_second = 0
 
     def emit_attempted_build_metric(self):
         self.attempted_build += 1
@@ -71,3 +73,4 @@ def emit_route_call_metric(self, route: str, _metadata: MetricMetadata):
 
     def emit_token_count_metrics(self, token_usage: TokenUsage, _metadata: MetricMetadata):
         self.token_count += token_usage.num_total_tokens
+        self.total_tokens_per_second = token_usage.total_tokens_per_second