9 changes: 7 additions & 2 deletions temporalio/contrib/openai_agents/__init__.py
@@ -8,11 +8,13 @@
Use with caution in production environments.
"""

from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents._temporal_openai_agents import (
OpenAIAgentsPlugin,
TestModel,
TestModelProvider,
set_open_ai_agent_temporal_overrides,
)
from temporalio.contrib.openai_agents._trace_interceptor import (
OpenAIAgentsTracingInterceptor,
@@ -21,9 +23,12 @@
from . import workflow

__all__ = [
"OpenAIAgentsPlugin",
"ModelActivity",
"ModelActivityParameters",
"workflow",
"OpenAIAgentsPlugin",
"OpenAIAgentsTracingInterceptor",
"set_open_ai_agent_temporal_overrides",
"TestModel",
"TestModelProvider",
"workflow",
]
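
With the expanded __all__, ModelActivity and ModelActivityParameters can be imported directly from the package root. A minimal sketch (illustrative only, not part of this diff; worker registration is omitted):

from temporalio.contrib.openai_agents import ModelActivity, OpenAIAgentsPlugin

# The default OpenAI provider is now created lazily (see _invoke_model_activity.py
# below), so constructing the activity no longer requires an API key up front.
model_activity = ModelActivity()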
63 changes: 58 additions & 5 deletions temporalio/contrib/openai_agents/_invoke_model_activity.py
@@ -14,6 +14,7 @@
FileSearchTool,
FunctionTool,
Handoff,
Model,
ModelProvider,
ModelResponse,
ModelSettings,
@@ -126,21 +127,73 @@ class ActivityModelInput(TypedDict, total=False):


class ModelActivity:
"""Class wrapper for model invocation activities to allow model customization. By default, we use an OpenAIProvider with retries disabled.
"""Class wrapper for model invocation activities to allow model customization.

By default, we use an OpenAIProvider with retries disabled. The activity automatically
detects models prefixed with 'litellm/' and routes them to LiteLLM when available.

Disabling retries in your model of choice is recommended to allow activity retries to define the retry model.
"""

def __init__(self, model_provider: Optional[ModelProvider] = None):
"""Initialize the activity with a model provider."""
self._model_provider = model_provider or OpenAIProvider(
openai_client=AsyncOpenAI(max_retries=0)
)
self._custom_model_provider = model_provider
self._default_model_provider: Optional[ModelProvider] = None  # Lazy initialization

def _get_openai_provider(self) -> ModelProvider:
"""Get the OpenAI provider, initializing it lazily to avoid requiring API key upfront."""
if self._custom_model_provider:
return self._custom_model_provider

if self._default_model_provider is None:
self._default_model_provider = OpenAIProvider(
openai_client=AsyncOpenAI(max_retries=0)
)

return self._default_model_provider

def _get_litellm_model(self, model_name: str) -> Model:
"""Get a LiteLLM model for the given model name.

Args:
model_name: Model name prefixed with 'litellm/' (e.g., 'litellm/anthropic/claude-3-5-sonnet')

Returns:
A LiteLLM model instance

Raises:
ImportError: If LiteLLM is not installed
ValueError: If model name is invalid
"""
try:
from temporalio.contrib.openai_agents._litellm_model import LiteLLMModel
except ImportError:
raise ImportError(
f"LiteLLM model '{model_name}' requested but LiteLLM is not installed. "
"Install with: pip install litellm"
)

# Remove the 'litellm/' prefix to get the actual model name
actual_model_name = model_name[8:] # len("litellm/") = 8

if not actual_model_name:
raise ValueError("Model name cannot be just 'litellm/' - provide the actual model after the prefix")

return LiteLLMModel(model=actual_model_name)

@activity.defn
@_auto_heartbeater
async def invoke_model_activity(self, input: ActivityModelInput) -> ModelResponse:
"""Activity that invokes a model with the given input."""
model = self._model_provider.get_model(input.get("model_name"))
model_name = input.get("model_name")

# Check if this is a LiteLLM model (prefixed with 'litellm/')
if model_name and model_name.startswith("litellm/"):
model = self._get_litellm_model(model_name)
else:
# Use regular model provider (AsyncOpenAI by default, lazy initialization)
openai_provider = self._get_openai_provider()
model = openai_provider.get_model(model_name)

async def empty_on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
return ""
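
For reference, the dispatch added to invoke_model_activity can be read as the standalone sketch below (illustrative only; the helper name resolve_model_route is hypothetical, and ActivityModelInput carries more fields than just the model name):

from typing import Optional

def resolve_model_route(model_name: Optional[str]) -> str:
    # Mirrors the activity's routing: 'litellm/...' names go to LiteLLMModel,
    # everything else to the lazily created OpenAI provider.
    if model_name and model_name.startswith("litellm/"):
        return "LiteLLM -> " + model_name[len("litellm/"):]
    return "OpenAIProvider -> " + (model_name or "<provider default>")

assert resolve_model_route("litellm/anthropic/claude-3-5-sonnet") == "LiteLLM -> anthropic/claude-3-5-sonnet"
assert resolve_model_route("gpt-4o") == "OpenAIProvider -> gpt-4o"
assert resolve_model_route(None) == "OpenAIProvider -> <provider default>"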
271 changes: 271 additions & 0 deletions temporalio/contrib/openai_agents/_litellm_model.py
@@ -0,0 +1,271 @@
"""LiteLLM model class used when models are prefixed with 'litellm/'.
The routing logic is handled in the ModelActivity itself.
"""

from typing import Optional, Union

from agents import Model, ModelResponse, Usage
from agents.models.chatcmpl_converter import Converter
from openai.types.chat import ChatCompletionMessage
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

try:
import litellm
from litellm import acompletion
LITELLM_AVAILABLE = True
except ImportError:
LITELLM_AVAILABLE = False
litellm = None
acompletion = None


class LiteLLMConverter:
"""Helper class to convert LiteLLM messages to OpenAI format."""

@classmethod
def convert_message_to_openai(cls, message) -> ChatCompletionMessage:
"""Convert LiteLLM message to OpenAI ChatCompletionMessage format."""
if hasattr(message, 'role') and message.role != "assistant":
raise ValueError(f"Unsupported role: {message.role}")

# Convert tool calls if present
tool_calls = None
if hasattr(message, 'tool_calls') and message.tool_calls:
tool_calls = []
for tc in message.tool_calls:
tool_calls.append({
'id': tc.id,
'type': tc.type,
'function': {
'name': tc.function.name,
'arguments': tc.function.arguments
}
})

# Handle provider-specific fields like refusal
provider_specific_fields = getattr(message, 'provider_specific_fields', None) or {}
refusal = provider_specific_fields.get('refusal', None)

return ChatCompletionMessage(
content=getattr(message, 'content', None),
refusal=refusal,
role="assistant",
tool_calls=tool_calls,
)


class LiteLLMModel(Model):
"""LiteLLM model implementation compatible with OpenAI Agents SDK interface.

This model provides the same interface as other models in the OpenAI Agents SDK
while using LiteLLM's unified API to communicate with various LLM providers.
"""

def __init__(
self,
model: str,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
**kwargs
):
"""Initialize the LiteLLM model.

Args:
model: Model name/identifier (e.g., "gpt-4", "anthropic/claude-3-5-sonnet")
api_key: API key for the model provider
base_url: Optional base URL for custom endpoints
**kwargs: Additional configuration for LiteLLM
"""
if not LITELLM_AVAILABLE:
raise ImportError(
"LiteLLM is not installed. Install it with: pip install litellm"
)

self._model = model
self._api_key = api_key
self._base_url = base_url
self._config = kwargs

async def get_response(
self,
system_instructions: Union[str, None],
input: Union[str, list],
model_settings,
tools: list,
output_schema=None,
handoffs: Optional[list] = None,
tracing=None,
*,
previous_response_id: Union[str, None] = None,
prompt=None,
**kwargs
):
"""Get a response from the LiteLLM model.

This method translates OpenAI Agents SDK parameters to LiteLLM's
completion API format and returns a ModelResponse.
"""

# Build messages from system instructions and input
messages = []

if system_instructions:
messages.append({"role": "system", "content": system_instructions})

# Handle different input types
if isinstance(input, str):
messages.append({"role": "user", "content": input})
elif isinstance(input, list):
# Handle list of messages/content items
for item in input:
if isinstance(item, dict):
messages.append(item)
else:
# Convert other types to user message
messages.append({"role": "user", "content": str(item)})

# Prepare LiteLLM parameters
litellm_params = {
"model": self._model,
"messages": messages,
}

# Add API key if available
if self._api_key:
litellm_params["api_key"] = self._api_key

# Add base URL if available
if self._base_url:
litellm_params["base_url"] = self._base_url

# Map model settings to LiteLLM parameters
if hasattr(model_settings, 'temperature') and model_settings.temperature is not None:
litellm_params["temperature"] = model_settings.temperature
if hasattr(model_settings, 'max_tokens') and model_settings.max_tokens is not None:
litellm_params["max_tokens"] = model_settings.max_tokens
if hasattr(model_settings, 'top_p') and model_settings.top_p is not None:
litellm_params["top_p"] = model_settings.top_p

# Handle tools if provided
if tools:
# Convert tools to OpenAI format that LiteLLM expects
litellm_tools = []
for tool in tools:
if hasattr(tool, 'name') and hasattr(tool, 'description'):
tool_def = {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
}
}
if hasattr(tool, 'params_json_schema'):
tool_def["function"]["parameters"] = tool.params_json_schema
litellm_tools.append(tool_def)

if litellm_tools:
litellm_params["tools"] = litellm_tools

# Handle output schema/response format
if output_schema and not (hasattr(output_schema, 'is_plain_text') and output_schema.is_plain_text()):
if hasattr(output_schema, 'json_schema'):
try:
schema = output_schema.json_schema()
litellm_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": getattr(output_schema, 'name', lambda: "response")(),
"schema": schema,
"strict": getattr(output_schema, 'is_strict_json_schema', lambda: False)()
}
}
except Exception:
# Fallback to basic JSON mode if schema extraction fails
litellm_params["response_format"] = {"type": "json_object"}

# Add any additional config
litellm_params.update(self._config)

try:
# Make the LiteLLM API call
response = await acompletion(**litellm_params)

# Convert response to ModelResponse format using the same approach as official openai-agents
choice = response.choices[0]
message = choice.message

# Convert LiteLLM message to OpenAI ChatCompletionMessage format
openai_message = LiteLLMConverter.convert_message_to_openai(message)

# Convert message to output items using agents SDK converter
output_items = Converter.message_to_output_items(openai_message)

# Handle usage conversion similar to official implementation
if hasattr(response, "usage") and response.usage:
response_usage = response.usage
usage = Usage(
requests=1,
input_tokens=getattr(response_usage, 'prompt_tokens', 0),
output_tokens=getattr(response_usage, 'completion_tokens', 0),
total_tokens=getattr(response_usage, 'total_tokens', 0),
input_tokens_details=InputTokensDetails(
cached_tokens=getattr(
getattr(response_usage, 'prompt_tokens_details', None),
'cached_tokens', 0
) or 0
),
output_tokens_details=OutputTokensDetails(
reasoning_tokens=getattr(
getattr(response_usage, 'completion_tokens_details', None),
'reasoning_tokens', 0
) or 0
)
)
else:
# Fallback if no usage data
usage = Usage(
requests=1,
input_tokens=0,
output_tokens=0,
total_tokens=0,
input_tokens_details=InputTokensDetails(cached_tokens=0),
output_tokens_details=OutputTokensDetails(reasoning_tokens=0)
)

# Create ModelResponse
model_response = ModelResponse(
output=output_items,
usage=usage,
response_id=getattr(response, 'id', None)
)

return model_response

except Exception as e:
# Re-raise with more context
raise RuntimeError(f"LiteLLM API call failed for model {self._model}: {str(e)}") from e

def stream_response(
self,
system_instructions: Optional[str],
input: Union[str, list],
model_settings,
tools: list,
output_schema=None,
handoffs: Optional[list] = None,
tracing=None,
*,
previous_response_id: Optional[str] = None,
prompt=None,
**kwargs
):
"""Stream response from the LiteLLM model.

Note: Streaming implementation would require additional complexity
to handle the OpenAI Agents SDK streaming interface. This is a
placeholder for future implementation.
"""
raise NotImplementedError(
"Streaming is not yet implemented for LiteLLM models. "
"Use get_response() for non-streaming responses."
)
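
Outside of Temporal, the model can also be exercised directly. A minimal sketch, assuming litellm is installed, the relevant provider API key is set in the environment, and that ModelSettings defaults are acceptable (the model name shown is only an example):

import asyncio

from agents import ModelSettings

from temporalio.contrib.openai_agents._litellm_model import LiteLLMModel


async def main() -> None:
    # Note: no 'litellm/' prefix here -- ModelActivity strips it before
    # constructing LiteLLMModel.
    model = LiteLLMModel(model="anthropic/claude-3-5-sonnet-20241022")
    response = await model.get_response(
        system_instructions="You are a terse assistant.",
        input="Say hello in five words or fewer.",
        model_settings=ModelSettings(temperature=0.2),
        tools=[],
    )
    print(response.output)


asyncio.run(main())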