Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ opentelemetry = [
]
pydantic = ["pydantic>=2.0.0,<3"]
openai-agents = [
"openai-agents >= 0.0.19,<0.1",
"openai-agents >= 0.1,<0.2",
"eval-type-backport>=0.2.2; python_version < '3.10'"
]

Expand Down
55 changes: 27 additions & 28 deletions temporalio/contrib/openai_agents/_temporal_model_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,36 +11,35 @@

logger = logging.getLogger(__name__)

with workflow.unsafe.imports_passed_through():
from typing import Any, AsyncIterator, Optional, Sequence, Union, cast
from typing import Any, AsyncIterator, Optional, Sequence, Union, cast

from agents import (
AgentOutputSchema,
AgentOutputSchemaBase,
ComputerTool,
FileSearchTool,
FunctionTool,
Handoff,
Model,
ModelResponse,
ModelSettings,
ModelTracing,
Tool,
TResponseInputItem,
WebSearchTool,
)
from agents.items import TResponseStreamEvent
from openai.types.responses.response_prompt_param import ResponsePromptParam
from agents import (
AgentOutputSchema,
AgentOutputSchemaBase,
ComputerTool,
FileSearchTool,
FunctionTool,
Handoff,
Model,
ModelResponse,
ModelSettings,
ModelTracing,
Tool,
TResponseInputItem,
WebSearchTool,
)
from agents.items import TResponseStreamEvent
from openai.types.responses.response_prompt_param import ResponsePromptParam

from temporalio.contrib.openai_agents.invoke_model_activity import (
ActivityModelInput,
AgentOutputSchemaInput,
FunctionToolInput,
HandoffInput,
ModelActivity,
ModelTracingInput,
ToolInput,
)
from temporalio.contrib.openai_agents.invoke_model_activity import (
ActivityModelInput,
AgentOutputSchemaInput,
FunctionToolInput,
HandoffInput,
ModelActivity,
ModelTracingInput,
ToolInput,
)


class _TemporalModelStub(Model):
Expand Down
146 changes: 3 additions & 143 deletions temporalio/contrib/openai_agents/open_ai_data_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,147 +5,7 @@

from __future__ import annotations

import importlib
import inspect
from typing import Any, Optional, Type, TypeVar
import temporalio.contrib.pydantic

from agents import Usage
from agents.items import TResponseOutputItem
from openai import NOT_GIVEN, BaseModel
from pydantic import RootModel, TypeAdapter

import temporalio.api.common.v1
from temporalio import workflow
from temporalio.converter import (
CompositePayloadConverter,
DataConverter,
DefaultPayloadConverter,
EncodingPayloadConverter,
JSONPlainPayloadConverter,
)

T = TypeVar("T", bound=BaseModel)


class _WrapperModel(RootModel[T]):
    # Generic root-model wrapper so any value (not just BaseModel subclasses)
    # can be run through pydantic serialization/validation.
    # arbitrary_types_allowed permits field types pydantic has no native
    # validator for (plain classes, etc.).
    model_config = {
        "arbitrary_types_allowed": True,
    }


class _OpenAIJSONPlainPayloadConverter(EncodingPayloadConverter):
    """Payload converter for OpenAI agent types that supports Pydantic models and standard Python types.

    This converter extends the standard JSON payload converter to handle OpenAI agent-specific
    types, particularly Pydantic models. It supports:

    1. All Pydantic models and their nested structures
    2. Standard JSON-serializable types
    3. Python standard library types like:
       - dataclasses
       - datetime objects
       - sets
       - UUIDs
    4. Custom types composed of any of the above

    The converter uses Pydantic's serialization capabilities to ensure proper handling
    of complex types while maintaining compatibility with Temporal's payload system.

    See https://docs.pydantic.dev/latest/api/standard_library_types/ for details
    on supported types.
    """

    @property
    def encoding(self) -> str:
        """Get the encoding identifier for this converter.

        Returns:
            The string "json/plain" indicating this is a plain JSON converter.
        """
        return "json/plain"

    def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]:
        """Convert a value to a Temporal payload.

        This method wraps the value in a Pydantic RootModel to handle arbitrary types
        and serializes it to JSON.

        Args:
            value: The value to convert to a payload.

        Returns:
            A Temporal payload containing the serialized value, or None if the value
            cannot be converted.
        """
        # RootModel wrapper lets pydantic serialize values that are not
        # themselves BaseModel instances (dataclasses, datetimes, sets, ...).
        wrapper = _WrapperModel[Any](root=value)
        # exclude_unset keeps the payload minimal: fields never explicitly set
        # on the model are omitted rather than emitted with default values.
        data = wrapper.model_dump_json(exclude_unset=True).encode()

        return temporalio.api.common.v1.Payload(
            metadata={"encoding": self.encoding.encode()}, data=data
        )

    def from_payload(
        self,
        payload: temporalio.api.common.v1.Payload,
        type_hint: Optional[Type] = None,
    ) -> Any:
        """Convert a Temporal payload back to a Python value.

        This method deserializes the JSON payload and validates it against the
        provided type hint using Pydantic's validation system.

        Args:
            payload: The Temporal payload to convert.
            type_hint: Optional type hint for validation.

        Returns:
            The deserialized and validated value.

        Note:
            The type hint is used for validation but the actual type returned
            may be a Pydantic model instance.
        """
        _type_hint = type_hint or Any
        wrapper = _WrapperModel[_type_hint]  # type: ignore[valid-type]

        # Rebuilding the model resolves forward references in the type hint,
        # which requires importing openai modules; run with imports passed
        # through and the sandbox unrestricted so those imports are allowed
        # when this executes inside a workflow sandbox.
        with workflow.unsafe.imports_passed_through():
            with workflow.unsafe.sandbox_unrestricted():
                wrapper.model_rebuild(
                    _types_namespace=_get_openai_modules()
                    | {"TResponseOutputItem": TResponseOutputItem, "Usage": Usage}
                )
        return TypeAdapter(wrapper).validate_json(payload.data.decode()).root


def _get_openai_modules() -> dict[Any, Any]:
    """Build a merged namespace of every module reachable under ``openai.types``.

    Used as the types namespace when rebuilding pydantic models so forward
    references into openai's type modules can be resolved.
    """

    def collect(module) -> dict[Any, Any]:
        # Merge each submodule's namespace, recursing into nested submodules.
        namespace: dict[Any, Any] = {}
        for _name, submodule in inspect.getmembers(module, inspect.ismodule):
            namespace |= submodule.__dict__ | collect(submodule)
        return namespace

    return collect(importlib.import_module("openai.types"))


class OpenAIPayloadConverter(CompositePayloadConverter):
    """Payload converter for payloads containing pydantic model instances.

    JSON conversion is replaced with a converter that uses
    :py:class:`PydanticJSONPlainPayloadConverter`.
    """

    def __init__(self) -> None:
        """Initialize object"""
        # Start from the default converter set, swapping the plain-JSON
        # converter for the OpenAI-aware one; all others are kept as-is.
        openai_json_converter = _OpenAIJSONPlainPayloadConverter()
        converters = [
            openai_json_converter if isinstance(c, JSONPlainPayloadConverter) else c
            for c in DefaultPayloadConverter.default_encoding_payload_converters
        ]
        super().__init__(*converters)


open_ai_data_converter = DataConverter(payload_converter_class=OpenAIPayloadConverter)
"""Open AI Agent library types data converter"""
open_ai_data_converter = temporalio.contrib.pydantic.pydantic_data_converter
"""DEPRECATED, use temporalio.contrib.pydantic.pydantic_data_converter"""
7 changes: 3 additions & 4 deletions temporalio/contrib/openai_agents/temporal_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
from datetime import timedelta
from typing import Any, Callable, Optional

from agents import FunctionTool, RunContextWrapper, Tool
from agents.function_schema import function_schema

from temporalio import activity, workflow
from temporalio.common import Priority, RetryPolicy
from temporalio.exceptions import ApplicationError, TemporalError
from temporalio.workflow import ActivityCancellationType, VersioningIntent, unsafe

with unsafe.imports_passed_through():
from agents import FunctionTool, RunContextWrapper, Tool
from agents.function_schema import function_schema


class ToolSerializationError(TemporalError):
"""Error that occurs when a tool output could not be serialized."""
Expand Down
93 changes: 74 additions & 19 deletions temporalio/worker/workflow_sandbox/_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,30 @@ def applied(self) -> Iterator[None]:
while it is running and therefore should be locked against other
code running at the same time.
"""
with _thread_local_sys_modules.applied(sys, "modules", self.new_modules):
with _thread_local_import.applied(builtins, "__import__", self.import_func):
with self._builtins_restricted():
yield None
orig_importer = Importer.current_importer()
Importer._thread_local_current.importer = self
try:
with _thread_local_sys_modules.applied(sys, "modules", self.new_modules):
with _thread_local_import.applied(
builtins, "__import__", self.import_func
):
with self._builtins_restricted():
yield None
finally:
Importer._thread_local_current.importer = orig_importer

@contextmanager
def _unapplied(self) -> Iterator[None]:
orig_importer = Importer.current_importer()
Importer._thread_local_current.importer = None
# Set orig modules, then unset on complete
with _thread_local_sys_modules.unapplied():
with _thread_local_import.unapplied():
with self._builtins_unrestricted():
yield None
try:
with _thread_local_sys_modules.unapplied():
with _thread_local_import.unapplied():
with self._builtins_unrestricted():
yield None
finally:
Importer._thread_local_current.importer = orig_importer

def _traced_import(
self,
Expand Down Expand Up @@ -211,6 +223,8 @@ def _import(
# Put it on the parent
if parent:
setattr(sys.modules[parent], child, sys.modules[full_name])
# All children of this module that are on the original sys
# modules but not here and are passthrough

# If the module is __temporal_main__ and not already in sys.modules,
# we load it from whatever file __main__ was originally in
Expand Down Expand Up @@ -251,21 +265,31 @@ def _assert_valid_module(self, name: str) -> None:
):
raise RestrictedWorkflowAccessError(name)

def module_configured_passthrough(self, name: str) -> bool:
    """Whether the given module name is configured as passthrough."""
    # Direct hits: everything passed through, or an exact module match.
    if self.restrictions.passthrough_all_modules:
        return True
    if name in self.restrictions.passthrough_modules:
        return True
    # Otherwise, a module is passthrough if any ancestor package is
    # configured as passthrough (e.g. "a.b.c" matches config entry "a.b").
    dot_index = name.find(".")
    while dot_index != -1:
        if name[:dot_index] in self.restrictions.passthrough_modules:
            return True
        dot_index = name.find(".", dot_index + 1)
    return False

def _maybe_passthrough_module(self, name: str) -> Optional[types.ModuleType]:
# If imports not passed through and all modules are not passed through
# and name not in passthrough modules, check parents
if (
not temporalio.workflow.unsafe.is_imports_passed_through()
and not self.restrictions.passthrough_all_modules
and name not in self.restrictions.passthrough_modules
and not self.module_configured_passthrough(name)
):
end_dot = -1
while True:
end_dot = name.find(".", end_dot + 1)
if end_dot == -1:
return None
elif name[:end_dot] in self.restrictions.passthrough_modules:
break
return None
# Do the pass through
with self._unapplied():
_trace("Passing module %s through from host", name)
Expand Down Expand Up @@ -311,6 +335,13 @@ def _builtins_unrestricted(self) -> Iterator[None]:
stack.enter_context(thread_local.unapplied())
yield None

# Holds the importer active on the current thread; assigned while an
# importer's applied()/_unapplied() context is in effect.
_thread_local_current = threading.local()

@staticmethod
def current_importer() -> Optional[Importer]:
    """Get the current importer if any."""
    # __dict__.get avoids AttributeError on threads where no importer has
    # ever been set.
    return Importer._thread_local_current.__dict__.get("importer")


_T = TypeVar("_T")

Expand Down Expand Up @@ -385,13 +416,23 @@ class _ThreadLocalSysModules(
MutableMapping[str, types.ModuleType],
):
def __contains__(self, key: object) -> bool:
return key in self.current
if key in self.current:
return True
return (
isinstance(key, str)
and self._lazily_passthrough_if_available(key) is not None
)

def __delitem__(self, key: str) -> None:
del self.current[key]

def __getitem__(self, key: str) -> types.ModuleType:
return self.current[key]
try:
return self.current[key]
except KeyError:
if module := self._lazily_passthrough_if_available(key):
return module
raise

def __len__(self) -> int:
return len(self.current)
Expand Down Expand Up @@ -431,6 +472,20 @@ def copy(self) -> Dict[str, types.ModuleType]:
def fromkeys(cls, *args, **kwargs) -> Any:
return dict.fromkeys(*args, **kwargs)

def _lazily_passthrough_if_available(self, key: str) -> Optional[types.ModuleType]:
    """Pass a module through from the host lazily, if permitted.

    Returns the original (host) module after caching it in the sandbox's
    module table, or None when lazy passthrough does not apply.
    """
    # We only lazily pass through if it's in orig, lazy not disabled, and
    # module configured as pass through
    if (
        key in self.orig
        and (importer := Importer.current_importer())
        and not importer.restrictions.disable_lazy_sys_module_passthrough
        and importer.module_configured_passthrough(key)
    ):
        orig = self.orig[key]
        # Cache so subsequent lookups hit self.current directly
        self.current[key] = orig
        return orig
    return None


_thread_local_sys_modules = _ThreadLocalSysModules(sys.modules)

Expand Down
Loading