Skip to content

Commit e181a45

Browse files
authored
feat: collect global statistics on LLM usage (#1328)
Signed-off-by: Louis Mandel <[email protected]>
1 parent bf92869 commit e181a45

File tree

9 files changed

+134
-95
lines changed

9 files changed

+134
-95
lines changed

pdl-live-react/src/pdl_ast.d.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,9 @@ export type As4 = "reduce"
537537
export type Reduce = LocalizedExpression | string
538538
export type PdlTrace = BlockType[] | null
539539
export type PdlTrace1 = BlockType[] | null
540+
export type ModelCalls = number
541+
export type CompletionTokens = number
542+
export type PromptTokens = number
540543
/**
541544
* Optional field to ensure that the block is using granite-io.
542545
*
@@ -3610,8 +3613,9 @@ export interface JoinReduce {
36103613
* Internal data structure to record token consumption usage information.
36113614
*/
36123615
export interface PdlUsage {
3613-
completion_tokens?: number | null
3614-
prompt_tokens?: number | null
3616+
model_calls?: ModelCalls
3617+
completion_tokens?: CompletionTokens
3618+
prompt_tokens?: PromptTokens
36153619
[k: string]: unknown
36163620
}
36173621
/**

src/pdl/pdl-schema.json

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4174,13 +4174,20 @@
41744174
"PdlUsage": {
41754175
"description": "Internal data structure to record token consumption usage information.",
41764176
"properties": {
4177+
"model_calls": {
4178+
"default": 0,
4179+
"title": "Model Calls",
4180+
"type": "integer"
4181+
},
41774182
"completion_tokens": {
4178-
"$ref": "#/$defs/OptionalInt",
4179-
"default": 0
4183+
"default": 0,
4184+
"title": "Completion Tokens",
4185+
"type": "integer"
41804186
},
41814187
"prompt_tokens": {
4182-
"$ref": "#/$defs/OptionalInt",
4183-
"default": 0
4188+
"default": 0,
4189+
"title": "Prompt Tokens",
4190+
"type": "integer"
41844191
}
41854192
},
41864193
"title": "PdlUsage",

src/pdl/pdl.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
BlockType,
1616
PdlBlock,
1717
PdlLocationType,
18+
PdlUsage,
1819
Program,
1920
RoleType,
2021
ScopeType,
@@ -64,6 +65,7 @@ class Result(TypedDict):
6465
scope: dict[str, Any]
6566
trace: BlockType
6667
replay: dict[str, Any]
68+
usage: PdlUsage
6769

6870

6971
def exec_program(
@@ -105,6 +107,7 @@ def exec_program(
105107
"scope": future_scope.result(),
106108
"trace": trace,
107109
"replay": state.replay,
110+
"usage": state.llm_usage,
108111
}
109112
return result_all
110113
case _:

src/pdl/pdl_ast.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -376,11 +376,14 @@ class PdlTiming(BaseModel):
376376
class PdlUsage(BaseModel):
    """Internal data structure to record token consumption usage information."""

    # All fields default to 0 so that independent usage records can be
    # summed additively (see InterpreterState.add_usage).
    model_calls: int = 0
    """Number of calls to LLMs.
    """
    completion_tokens: int = 0
    """Completion tokens consumed.
    """
    prompt_tokens: int = 0
    """Prompt tokens consumed.
    """
385388

386389

src/pdl/pdl_dumper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ def timing_to_dict(timing: PdlTiming) -> dict:
407407

408408
def usage_to_dict(usage: PdlUsage) -> dict:
    """Serialize a PdlUsage record into a plain dictionary.

    The resulting dict carries the three usage counters under the same
    keys as the corresponding PdlUsage attributes.
    """
    return {
        "model_calls": usage.model_calls,
        "completion_tokens": usage.completion_tokens,
        "prompt_tokens": usage.prompt_tokens,
    }

src/pdl/pdl_interpreter.py

Lines changed: 27 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -11,44 +11,39 @@
1111
# TODO: temporarily disabling warnings to mute a pydantic warning from liteLLM
1212
import warnings
1313
from abc import ABC, abstractmethod
14-
from asyncio import AbstractEventLoop
1514
from concurrent.futures import ThreadPoolExecutor
1615
from functools import partial, reduce
1716
from itertools import count
1817
from os import getenv
19-
20-
warnings.filterwarnings("ignore", "Valid config keys have changed in V2")
21-
22-
from pathlib import Path # noqa: E402
23-
from typing import ( # noqa: E402
18+
from pathlib import Path
19+
from typing import (
2420
IO,
2521
Any,
2622
Generator,
27-
Generic,
2823
Iterable,
2924
Optional,
3025
Sequence,
3126
Tuple,
3227
TypeVar,
3328
)
3429

35-
import httpx # noqa: E402
36-
import json_repair # noqa: E402
37-
import yaml # noqa: E402
38-
from jinja2 import ( # noqa: E402
30+
import httpx
31+
import json_repair
32+
import yaml
33+
from jinja2 import (
3934
Environment,
4035
StrictUndefined,
4136
Template,
4237
TemplateSyntaxError,
4338
UndefinedError,
4439
meta,
4540
)
46-
from jinja2.nodes import TemplateData # noqa: E402
47-
from jinja2.runtime import Undefined # noqa: E402
48-
from pydantic import BaseModel, ConfigDict, Field # noqa: E402
49-
from pydantic.json_schema import SkipJsonSchema # noqa: E402
41+
from jinja2.nodes import TemplateData
42+
from jinja2.runtime import Undefined
43+
from pydantic import Field
44+
from pydantic.json_schema import SkipJsonSchema
5045

51-
from .pdl_ast import ( # noqa: E402
46+
from .pdl_ast import (
5247
AdvancedBlockType,
5348
AggregatorBlock,
5449
AnyPattern,
@@ -121,7 +116,7 @@
121116
TextBlock,
122117
empty_block_location,
123118
)
124-
from .pdl_context import ( # noqa: E402
119+
from .pdl_context import (
125120
DependentContext,
126121
IndependentContext,
127122
PDLContext,
@@ -131,19 +126,19 @@
131126
deserialize,
132127
ensure_context,
133128
)
134-
from .pdl_lazy import PdlConst, PdlDict, PdlLazy, PdlList, lazy_apply # noqa: E402
135-
from .pdl_llms import LitellmModel # noqa: E402
136-
from .pdl_location_utils import append, get_loc_string # noqa: E402
137-
from .pdl_parser import PDLParseError, parse_file, parse_str # noqa: E402
138-
from .pdl_python_repl import PythonREPL # noqa: E402
139-
from .pdl_scheduler import ( # noqa: E402
140-
create_event_loop_thread,
129+
from .pdl_interpreter_state import InterpreterState
130+
from .pdl_lazy import PdlConst, PdlDict, PdlLazy, PdlList, lazy_apply
131+
from .pdl_llms import LitellmModel
132+
from .pdl_location_utils import append, get_loc_string
133+
from .pdl_parser import PDLParseError, parse_file, parse_str
134+
from .pdl_python_repl import PythonREPL
135+
from .pdl_scheduler import (
141136
yield_background,
142137
yield_result,
143138
)
144-
from .pdl_schema_utils import get_json_schema # noqa: E402
145-
from .pdl_schema_validator import type_check_args, type_check_spec # noqa: E402
146-
from .pdl_utils import ( # noqa: E402
139+
from .pdl_schema_utils import get_json_schema
140+
from .pdl_schema_validator import type_check_args, type_check_spec
141+
from .pdl_utils import (
147142
GeneratorWrapper,
148143
apply_defaults,
149144
get_contribute_context_value,
@@ -153,6 +148,8 @@
153148
write_trace,
154149
)
155150

151+
warnings.filterwarnings("ignore", "Valid config keys have changed in V2")
152+
156153
empty_scope: ScopeType = PdlDict(
157154
{
158155
"pdl_context": DependentContext([]),
@@ -162,60 +159,6 @@
162159
)
163160

164161

165-
RefT = TypeVar("RefT")
166-
167-
168-
class Ref(Generic[RefT]):
169-
def __init__(self, ref: RefT):
170-
self.ref = ref
171-
172-
173-
class InterpreterState(BaseModel):
174-
model_config = ConfigDict(arbitrary_types_allowed=True)
175-
176-
yield_result: bool = False
177-
"""Stream the result on the standard output as soon as possible."""
178-
yield_background: bool = False
179-
"""Stream the toplevel pdl_context on the standard output as soon as possible."""
180-
batch: int = 1
181-
"""
182-
Stream the output of the LLM
183-
- batch=0: streaming
184-
- batch=1: call to generate with `input`
185-
"""
186-
role: RoleType = "user"
187-
"""Current role to add messages in the context."""
188-
cwd: Path = Path.cwd()
189-
"""Current working directory."""
190-
id_stack: list[str] = []
191-
"""Id generator for the UI."""
192-
193-
# The following are shared variable that should be modified by side effects
194-
imported: dict[str, tuple[ScopeType, BlockType]] = {}
195-
"""Cache containing the imported files."""
196-
event_loop: AbstractEventLoop = Field(default_factory=create_event_loop_thread)
197-
"""Event loop to schedule LLM calls."""
198-
current_pdl_context: Ref[LazyMessages] = Ref(DependentContext([]))
199-
"""Current value of the context set at the beginning of the execution of the block."""
200-
replay: dict[str, Any] = {}
201-
202-
def with_yield_result(self: "InterpreterState", b: bool) -> "InterpreterState":
203-
return self.model_copy(update={"yield_result": b})
204-
205-
def with_yield_background(self: "InterpreterState", b: bool) -> "InterpreterState":
206-
return self.model_copy(update={"yield_background": b})
207-
208-
def with_role(self: "InterpreterState", role: RoleType) -> "InterpreterState":
209-
return self.model_copy(update={"role": role})
210-
211-
def with_id(self: "InterpreterState", n: str) -> "InterpreterState":
212-
stack = self.id_stack if self.id_stack is not None else []
213-
return self.model_copy(update={"id_stack": stack + [n]})
214-
215-
def with_iter(self: "InterpreterState", i: int) -> "InterpreterState":
216-
return self.with_id(str(i))
217-
218-
219162
class ClosureBlock(FunctionBlock):
220163
pdl__scope: SkipJsonSchema[Optional[ScopeType]] = Field(repr=False)
221164
pdl__state: SkipJsonSchema[InterpreterState] = Field(repr=False)
@@ -2017,9 +1960,11 @@ def generate_client_response_streaming(
20171960
and usage["prompt_tokens"] is not None
20181961
):
20191962
block.pdl__usage = PdlUsage(
1963+
model_calls=1,
20201964
completion_tokens=usage["completion_tokens"],
20211965
prompt_tokens=usage["prompt_tokens"],
20221966
)
1967+
state.add_usage(block.pdl__usage)
20231968
return PdlConst(complete_msg), PdlConst(raw_result)
20241969

20251970

@@ -2057,11 +2002,11 @@ def generate_client_response_single(
20572002
match block:
20582003
case LitellmModelBlock():
20592004
message, response = LitellmModel.generate_text(
2005+
state=state,
20602006
block=block,
20612007
model_id=value_of_expr(block.model),
20622008
messages=model_input,
20632009
parameters=litellm_parameters_to_dict(parameters),
2064-
event_loop=state.event_loop,
20652010
)
20662011
case GraniteioModelBlock():
20672012
from .pdl_granite_io import GraniteioModel

src/pdl/pdl_interpreter_state.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from asyncio import AbstractEventLoop
2+
from pathlib import Path
3+
from typing import Any
4+
5+
from pydantic import BaseModel, ConfigDict, Field
6+
7+
from pdl.pdl_context import DependentContext
8+
from pdl.pdl_utils import Ref
9+
10+
from .pdl_ast import BlockType, LazyMessages, PdlUsage, RoleType, ScopeType
11+
from .pdl_scheduler import create_event_loop_thread
12+
13+
14+
class InterpreterState(BaseModel):
    """Execution state threaded through the PDL interpreter.

    The ``with_*`` helpers perform functional updates: they return a
    modified copy via ``model_copy`` and leave ``self`` untouched.  The
    fields listed under the "shared variable" comment below are instead
    mutated in place as side effects during execution.
    """

    # Required so pydantic accepts non-pydantic field types
    # (AbstractEventLoop, Ref).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    yield_result: bool = False
    """Stream the result on the standard output as soon as possible."""
    yield_background: bool = False
    """Stream the toplevel pdl_context on the standard output as soon as possible."""
    batch: int = 1
    """
    Stream the output of the LLM
    - batch=0: streaming
    - batch=1: call to generate with `input`
    """
    role: RoleType = "user"
    """Current role to add messages in the context."""
    cwd: Path = Path.cwd()
    """Current working directory."""
    id_stack: list[str] = []
    """Id generator for the UI."""

    # The following are shared variable that should be modified by side effects
    imported: dict[str, tuple[ScopeType, BlockType]] = {}
    """Cache containing the imported files."""
    event_loop: AbstractEventLoop = Field(default_factory=create_event_loop_thread)
    """Event loop to schedule LLM calls."""
    current_pdl_context: Ref[LazyMessages] = Ref(DependentContext([]))
    """Current value of the context set at the beginning of the execution of the block."""
    replay: dict[str, Any] = {}
    """Dictionary that associate runtime block ids with their values to be able to replay an execution."""
    llm_usage: PdlUsage = PdlUsage()
    """Record statistics about LLM usage."""

    def with_yield_result(self: "InterpreterState", b: bool) -> "InterpreterState":
        # Functional update: copy of self with `yield_result` set to `b`.
        return self.model_copy(update={"yield_result": b})

    def with_yield_background(self: "InterpreterState", b: bool) -> "InterpreterState":
        # Functional update: copy of self with `yield_background` set to `b`.
        return self.model_copy(update={"yield_background": b})

    def with_role(self: "InterpreterState", role: RoleType) -> "InterpreterState":
        # Functional update: copy of self with the current role replaced.
        return self.model_copy(update={"role": role})

    def with_id(self: "InterpreterState", n: str) -> "InterpreterState":
        # Functional update: copy of self with `n` pushed on the id stack.
        stack = self.id_stack if self.id_stack is not None else []
        return self.model_copy(update={"id_stack": stack + [n]})

    def with_iter(self: "InterpreterState", i: int) -> "InterpreterState":
        # Convenience wrapper: push a loop-iteration index as an id segment.
        return self.with_id(str(i))

    def add_usage(self, usage: PdlUsage) -> None:
        # In-place side effect: accumulate `usage` into the global
        # `llm_usage` counters (called once per LLM response).
        self.llm_usage.model_calls += usage.model_calls
        self.llm_usage.completion_tokens += usage.completion_tokens
        self.llm_usage.prompt_tokens += usage.prompt_tokens

0 commit comments

Comments
 (0)