Skip to content
15 changes: 12 additions & 3 deletions src/pdl/pdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class InterpreterConfig(TypedDict, total=False):
"""
cwd: Path
"""Path considered as the current working directory for file reading."""
replay: dict[str, Any]
"""Execute the program reusing some already computed values.
"""


def exec_program(
Expand All @@ -66,9 +69,10 @@ def exec_program(
output: Configure the output of the returned value of this function. Defaults to `"result"`

Returns:
Return the final result if `output` is set to `"result"`. If set of `all`, it returns a dictionary containing, `result`, `scope`, and `trace`.
Return the final result if `output` is set to `"result"`. If set to `"all"`, it returns a dictionary containing `result`, `scope`, `trace`, and `replay`.
"""
config = config or {}
config = config or InterpreterConfig()
config["replay"] = dict(config.get("replay", {}))
state = InterpreterState(**config)
if not isinstance(scope, PdlDict):
scope = PdlDict(scope or {})
Expand All @@ -83,7 +87,12 @@ def exec_program(
return result
case "all":
scope = future_scope.result()
return {"result": result, "scope": scope, "trace": trace}
return {
"result": result,
"scope": scope,
"trace": trace,
"replay": state.replay,
}
case _:
assert False, 'The `output` variable should be "result" or "all"'

Expand Down
43 changes: 41 additions & 2 deletions src/pdl/pdl_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ class InterpreterState(BaseModel):
"""Event loop to schedule LLM calls."""
current_pdl_context: Ref[LazyMessages] = Ref(DependentContext([]))
"""Current value of the context set at the beginning of the execution of the block."""
replay: dict[str, Any] = {}

def with_yield_result(self: "InterpreterState", b: bool) -> "InterpreterState":
return self.model_copy(update={"yield_result": b})
Expand Down Expand Up @@ -305,7 +306,7 @@ def process_prog(
stdlib_file = Path(__file__).parent / "pdl_stdlib.pdl"
stdlib, _ = parse_file(stdlib_file)
_, _, stdlib_dict, _ = process_block(
state.with_yield_background(False).with_yield_result(False),
state.with_yield_background(False).with_yield_result(False).with_id("stdlib"),
empty_scope,
stdlib.root,
loc,
Expand Down Expand Up @@ -505,7 +506,7 @@ def process_advance_block_retry( # noqa: C901
trial_total = max_retry + 1
for trial_idx in range(trial_total): # pylint: disable=too-many-nested-blocks
try:
result, background, new_scope, trace = process_block_body(
result, background, new_scope, trace = process_block_body_with_replay(
state, scope, block, loc
)

Expand Down Expand Up @@ -640,6 +641,42 @@ def result_with_type_checking(
return result


def process_block_body_with_replay(
    state: InterpreterState,
    scope: ScopeType,
    block: AdvancedBlockType,
    loc: PdlLocationType,
) -> tuple[PdlLazy[Any], LazyMessages, ScopeType, AdvancedBlockType]:
    """Execute a block body, reusing a value stored in `state.replay` when possible.

    Only `LeafBlock` instances are replayed; any other block is delegated
    directly to `process_block_body`.

    For a leaf block, `state.replay` is looked up with the block's `pdl__id`.
    A `ModelBlock` that declares `modelResponse` additionally needs the entry
    stored under `<pdl__id>.modelResponse`. If every required entry is present,
    the stored result is reused: the background context is rebuilt from it, the
    result/background are yielded according to the state flags, and the raw
    model response is bound in the scope. Otherwise the block is executed
    normally and its result is recorded in `state.replay` under its id.

    Args:
        state: Interpreter state; its `replay` dictionary is read and updated.
        scope: Scope in which the block is evaluated.
        block: Block to execute or replay.
        loc: Location of the block, forwarded to `process_block_body`.

    Returns:
        The tuple `(result, background, scope, trace)`. On a replay hit the
        trace is the unmodified input block.
    """
    if not isinstance(block, LeafBlock):
        result, background, scope, trace = process_block_body(state, scope, block, loc)
        return result, background, scope, trace

    block_id = block.pdl__id
    assert isinstance(block_id, str)

    # Name of the scope variable receiving the raw model response, if any.
    model_response_var = block.modelResponse if isinstance(block, ModelBlock) else None

    # Keep the `try` limited to the replay lookups: a KeyError must mean
    # "nothing to replay" only. Doing the lookups before any yield also
    # guarantees that a partially populated replay (result present but raw
    # model response missing) does not emit output twice.
    try:
        result = state.replay[block_id]
        raw_result = (
            state.replay[block_id + ".modelResponse"]
            if model_response_var is not None
            else None
        )
    except KeyError:
        result, background, scope, trace = process_block_body(state, scope, block, loc)
        state.replay[block_id] = result
        return result, background, scope, trace

    # Replay hit: rebuild the background from the stored result.
    background: LazyMessages = SingletonContext(
        PdlDict({"role": state.role, "content": result})
    )
    if state.yield_result:
        yield_result(result.result(), block.kind)
    if state.yield_background:
        yield_background(background)
    # Special case: model blocks also expose their raw response in the scope.
    if model_response_var is not None:
        scope = scope | {model_response_var: raw_result}
    trace = block
    return result, background, scope, trace


def process_block_body(
state: InterpreterState,
scope: ScopeType,
Expand Down Expand Up @@ -1815,6 +1852,8 @@ def get_transformed_inputs(kwargs):
)
if block.modelResponse is not None:
scope = scope | {block.modelResponse: raw_result}
assert block.pdl__id is not None
state.replay[block.pdl__id + ".modelResponse"] = raw_result
trace: BlockTypeTVarProcessCallModel = concrete_block.model_copy(
update={"pdl__result": result}
) # pyright: ignore
Expand Down
3 changes: 0 additions & 3 deletions tests/data/function.pdl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,3 @@ defs:
${ notes }

### Answer:



55 changes: 45 additions & 10 deletions tests/test_examples_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import random
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

import yaml
from pytest import CaptureFixture, MonkeyPatch
Expand Down Expand Up @@ -104,6 +104,7 @@ class FailedResults:
wrong_results: Dict[str, str] = field(default_factory=lambda: {})
unexpected_parse_error: Dict[str, str] = field(default_factory=lambda: {})
unexpected_runtime_error: Dict[str, str] = field(default_factory=lambda: {})
wrong_replay_results: Dict[str, str] = field(default_factory=lambda: {})


# pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -161,7 +162,9 @@ def __init__(self, monkeypatch: MonkeyPatch) -> None:
self.__collect_expected_results()

# Inits execution results for each PDL file
self.execution_results: Dict[str, ExecutionResult] = {}
self.execution_results: Dict[
str, Tuple[ExecutionResult, ExecutionResult | None]
] = {}

# Init failed results
self.failed_results = FailedResults()
Expand Down Expand Up @@ -199,13 +202,11 @@ def __collect_expected_results(self) -> None:

self.expected_results[file] = expected_result

def __execute_file(self, pdl_file_name: str) -> None:
def __execute_and_replay_file(self, pdl_file_name: str) -> None:
"""
Executes a single PDL file, re-executes it using the replay data produced by the first run (when that run produced an output), and stores both execution results
"""

exec_result = ExecutionResult()

pdl_file_path = pathlib.Path(pdl_file_name)
scope: ScopeType = PdlDict({})

Expand All @@ -217,13 +218,27 @@ def __execute_file(self, pdl_file_name: str) -> None:
if inputs.scope is not None:
scope = inputs.scope

exec_result, output = self.__execute_file(pdl_file_path, scope, replay={})

if output is not None:
replay_result, _ = self.__execute_file(
pdl_file_path, scope, replay=output["replay"]
)
else:
replay_result = None

self.execution_results[pdl_file_name] = exec_result, replay_result

def __execute_file(self, pdl_file_path, scope, replay):
exec_result = ExecutionResult()
output = None
try:
# Execute file
output = pdl.exec_file(
pdl_file_path,
scope=scope,
output="all",
config=pdl.InterpreterConfig(batch=1),
config=pdl.InterpreterConfig(batch=1, replay=replay),
)

exec_result.result = str(output["result"])
Expand All @@ -235,8 +250,7 @@ def __execute_file(self, pdl_file_name: str) -> None:
except Exception as exc:
exec_result.result = str(exc)
exec_result.error_code = ExecutionErrorCode.RUNTIME_ERROR

self.execution_results[pdl_file_name] = exec_result
return exec_result, output

def populate_exec_result_for_checks(self) -> None:
"""
Expand All @@ -245,7 +259,7 @@ def populate_exec_result_for_checks(self) -> None:

for file in self.check:
if file not in self.skip:
self.__execute_file(file)
self.__execute_and_replay_file(file)

def validate_expected_and_actual(self) -> None:
"""
Expand All @@ -256,11 +270,12 @@ def validate_expected_and_actual(self) -> None:
wrong_result: Dict[str, str] = {}
unexpected_parse_error: Dict[str, str] = {}
unexpected_runtime_error: Dict[str, str] = {}
wrong_replay_result: Dict[str, str] = {}

for file in self.check:
if file not in self.skip:
expected_result = self.expected_results[file]
actual_result = self.execution_results[file]
actual_result, replay_result = self.execution_results[file]
match = expected_result.compare_to_execution(actual_result)

if not match:
Expand All @@ -274,7 +289,14 @@ def validate_expected_and_actual(self) -> None:
if actual_result.result is not None:
wrong_result[file] = actual_result.result

if replay_result is not None:
match_replay = expected_result.compare_to_execution(replay_result)
if not match_replay:
if replay_result.result is not None:
wrong_replay_result[file] = replay_result.result

self.failed_results.wrong_results = wrong_result
self.failed_results.wrong_replay_results = wrong_replay_result
self.failed_results.unexpected_parse_error = unexpected_parse_error
self.failed_results.unexpected_runtime_error = unexpected_runtime_error

Expand Down Expand Up @@ -347,6 +369,16 @@ def test_example_runs(capsys: CaptureFixture[str], monkeypatch: MonkeyPatch) ->
f"Actual result (copy everything below this line):\n✂️ ------------------------------------------------------------\n{actual}\n-------------------------------------------------------------"
)

# Print the actual results for wrong replay results
for file, actual in background.failed_results.wrong_replay_results.items():
print(
"\n============================================================================"
)
print(f"File that produced wrong REPLAY result: {file}")
print(
f"Replay result:\n ------------------------------------------------------------\n{actual}\n-------------------------------------------------------------"
)

assert (
len(background.failed_results.unexpected_parse_error) == 0
), f"Unexpected parse error: {background.failed_results.unexpected_parse_error}"
Expand All @@ -356,3 +388,6 @@ def test_example_runs(capsys: CaptureFixture[str], monkeypatch: MonkeyPatch) ->
assert (
len(background.failed_results.wrong_results) == 0
), f"Wrong results: {background.failed_results.wrong_results}"
assert (
len(background.failed_results.wrong_replay_results) == 0
), f"Wrong replay results: {background.failed_results.wrong_results}"