|
11 | 11 | # TODO: temporarily disabling warnings to mute a pydantic warning from liteLLM |
12 | 12 | import warnings |
13 | 13 | from abc import ABC, abstractmethod |
14 | | -from asyncio import AbstractEventLoop |
15 | 14 | from concurrent.futures import ThreadPoolExecutor |
16 | 15 | from functools import partial, reduce |
17 | 16 | from itertools import count |
18 | 17 | from os import getenv |
19 | | - |
20 | | -warnings.filterwarnings("ignore", "Valid config keys have changed in V2") |
21 | | - |
22 | | -from pathlib import Path # noqa: E402 |
23 | | -from typing import ( # noqa: E402 |
| 18 | +from pathlib import Path |
| 19 | +from typing import ( |
24 | 20 | IO, |
25 | 21 | Any, |
26 | 22 | Generator, |
27 | | - Generic, |
28 | 23 | Iterable, |
29 | 24 | Optional, |
30 | 25 | Sequence, |
31 | 26 | Tuple, |
32 | 27 | TypeVar, |
33 | 28 | ) |
34 | 29 |
|
35 | | -import httpx # noqa: E402 |
36 | | -import json_repair # noqa: E402 |
37 | | -import yaml # noqa: E402 |
38 | | -from jinja2 import ( # noqa: E402 |
| 30 | +import httpx |
| 31 | +import json_repair |
| 32 | +import yaml |
| 33 | +from jinja2 import ( |
39 | 34 | Environment, |
40 | 35 | StrictUndefined, |
41 | 36 | Template, |
42 | 37 | TemplateSyntaxError, |
43 | 38 | UndefinedError, |
44 | 39 | meta, |
45 | 40 | ) |
46 | | -from jinja2.nodes import TemplateData # noqa: E402 |
47 | | -from jinja2.runtime import Undefined # noqa: E402 |
48 | | -from pydantic import BaseModel, ConfigDict, Field # noqa: E402 |
49 | | -from pydantic.json_schema import SkipJsonSchema # noqa: E402 |
| 41 | +from jinja2.nodes import TemplateData |
| 42 | +from jinja2.runtime import Undefined |
| 43 | +from pydantic import Field |
| 44 | +from pydantic.json_schema import SkipJsonSchema |
50 | 45 |
|
51 | | -from .pdl_ast import ( # noqa: E402 |
| 46 | +from .pdl_ast import ( |
52 | 47 | AdvancedBlockType, |
53 | 48 | AggregatorBlock, |
54 | 49 | AnyPattern, |
|
121 | 116 | TextBlock, |
122 | 117 | empty_block_location, |
123 | 118 | ) |
124 | | -from .pdl_context import ( # noqa: E402 |
| 119 | +from .pdl_context import ( |
125 | 120 | DependentContext, |
126 | 121 | IndependentContext, |
127 | 122 | PDLContext, |
|
131 | 126 | deserialize, |
132 | 127 | ensure_context, |
133 | 128 | ) |
134 | | -from .pdl_lazy import PdlConst, PdlDict, PdlLazy, PdlList, lazy_apply # noqa: E402 |
135 | | -from .pdl_llms import LitellmModel # noqa: E402 |
136 | | -from .pdl_location_utils import append, get_loc_string # noqa: E402 |
137 | | -from .pdl_parser import PDLParseError, parse_file, parse_str # noqa: E402 |
138 | | -from .pdl_python_repl import PythonREPL # noqa: E402 |
139 | | -from .pdl_scheduler import ( # noqa: E402 |
140 | | - create_event_loop_thread, |
| 129 | +from .pdl_interpreter_state import InterpreterState |
| 130 | +from .pdl_lazy import PdlConst, PdlDict, PdlLazy, PdlList, lazy_apply |
| 131 | +from .pdl_llms import LitellmModel |
| 132 | +from .pdl_location_utils import append, get_loc_string |
| 133 | +from .pdl_parser import PDLParseError, parse_file, parse_str |
| 134 | +from .pdl_python_repl import PythonREPL |
| 135 | +from .pdl_scheduler import ( |
141 | 136 | yield_background, |
142 | 137 | yield_result, |
143 | 138 | ) |
144 | | -from .pdl_schema_utils import get_json_schema # noqa: E402 |
145 | | -from .pdl_schema_validator import type_check_args, type_check_spec # noqa: E402 |
146 | | -from .pdl_utils import ( # noqa: E402 |
| 139 | +from .pdl_schema_utils import get_json_schema |
| 140 | +from .pdl_schema_validator import type_check_args, type_check_spec |
| 141 | +from .pdl_utils import ( |
147 | 142 | GeneratorWrapper, |
148 | 143 | apply_defaults, |
149 | 144 | get_contribute_context_value, |
|
153 | 148 | write_trace, |
154 | 149 | ) |
155 | 150 |
|
| 151 | +warnings.filterwarnings("ignore", "Valid config keys have changed in V2") |
| 152 | + |
156 | 153 | empty_scope: ScopeType = PdlDict( |
157 | 154 | { |
158 | 155 | "pdl_context": DependentContext([]), |
|
162 | 159 | ) |
163 | 160 |
|
164 | 161 |
|
165 | | -RefT = TypeVar("RefT") |
166 | | - |
167 | | - |
168 | | -class Ref(Generic[RefT]): |
169 | | - def __init__(self, ref: RefT): |
170 | | - self.ref = ref |
171 | | - |
172 | | - |
173 | | -class InterpreterState(BaseModel): |
174 | | - model_config = ConfigDict(arbitrary_types_allowed=True) |
175 | | - |
176 | | - yield_result: bool = False |
177 | | - """Stream the result on the standard output as soon as possible.""" |
178 | | - yield_background: bool = False |
179 | | - """Stream the toplevel pdl_context on the standard output as soon as possible.""" |
180 | | - batch: int = 1 |
181 | | - """ |
182 | | - Stream the output of the LLM |
183 | | - - batch=0: streaming |
184 | | - - batch=1: call to generate with `input` |
185 | | - """ |
186 | | - role: RoleType = "user" |
187 | | - """Current role to add messages in the context.""" |
188 | | - cwd: Path = Path.cwd() |
189 | | - """Current working directory.""" |
190 | | - id_stack: list[str] = [] |
191 | | - """Id generator for the UI.""" |
192 | | - |
193 | | - # The following are shared variable that should be modified by side effects |
194 | | - imported: dict[str, tuple[ScopeType, BlockType]] = {} |
195 | | - """Cache containing the imported files.""" |
196 | | - event_loop: AbstractEventLoop = Field(default_factory=create_event_loop_thread) |
197 | | - """Event loop to schedule LLM calls.""" |
198 | | - current_pdl_context: Ref[LazyMessages] = Ref(DependentContext([])) |
199 | | - """Current value of the context set at the beginning of the execution of the block.""" |
200 | | - replay: dict[str, Any] = {} |
201 | | - |
202 | | - def with_yield_result(self: "InterpreterState", b: bool) -> "InterpreterState": |
203 | | - return self.model_copy(update={"yield_result": b}) |
204 | | - |
205 | | - def with_yield_background(self: "InterpreterState", b: bool) -> "InterpreterState": |
206 | | - return self.model_copy(update={"yield_background": b}) |
207 | | - |
208 | | - def with_role(self: "InterpreterState", role: RoleType) -> "InterpreterState": |
209 | | - return self.model_copy(update={"role": role}) |
210 | | - |
211 | | - def with_id(self: "InterpreterState", n: str) -> "InterpreterState": |
212 | | - stack = self.id_stack if self.id_stack is not None else [] |
213 | | - return self.model_copy(update={"id_stack": stack + [n]}) |
214 | | - |
215 | | - def with_iter(self: "InterpreterState", i: int) -> "InterpreterState": |
216 | | - return self.with_id(str(i)) |
217 | | - |
218 | | - |
219 | 162 | class ClosureBlock(FunctionBlock): |
220 | 163 | pdl__scope: SkipJsonSchema[Optional[ScopeType]] = Field(repr=False) |
221 | 164 | pdl__state: SkipJsonSchema[InterpreterState] = Field(repr=False) |
@@ -2017,9 +1960,11 @@ def generate_client_response_streaming( |
2017 | 1960 | and usage["prompt_tokens"] is not None |
2018 | 1961 | ): |
2019 | 1962 | block.pdl__usage = PdlUsage( |
| 1963 | + model_calls=1, |
2020 | 1964 | completion_tokens=usage["completion_tokens"], |
2021 | 1965 | prompt_tokens=usage["prompt_tokens"], |
2022 | 1966 | ) |
| 1967 | + state.add_usage(block.pdl__usage) |
2023 | 1968 | return PdlConst(complete_msg), PdlConst(raw_result) |
2024 | 1969 |
|
2025 | 1970 |
|
@@ -2057,11 +2002,11 @@ def generate_client_response_single( |
2057 | 2002 | match block: |
2058 | 2003 | case LitellmModelBlock(): |
2059 | 2004 | message, response = LitellmModel.generate_text( |
| 2005 | + state=state, |
2060 | 2006 | block=block, |
2061 | 2007 | model_id=value_of_expr(block.model), |
2062 | 2008 | messages=model_input, |
2063 | 2009 | parameters=litellm_parameters_to_dict(parameters), |
2064 | | - event_loop=state.event_loop, |
2065 | 2010 | ) |
2066 | 2011 | case GraniteioModelBlock(): |
2067 | 2012 | from .pdl_granite_io import GraniteioModel |
|
0 commit comments