From 30e381fe14e82df78ded2adc688d2df9928ef954 Mon Sep 17 00:00:00 2001
From: Anilturaga
Date: Fri, 25 Jul 2025 01:26:55 +0530
Subject: [PATCH] Model name bug and LiteLLM support

- The model name provided during Agent creation was not being used.
- LiteLLM support, similar to openai-agents. An example model name would be
  'litellm/anthropic/claude-sonnet-4-20250514'.
---
 temporalio/contrib/openai_agents/__init__.py    |   9 +-
 .../openai_agents/_invoke_model_activity.py     |  63 +++-
 .../contrib/openai_agents/_litellm_model.py     | 271 ++++++++++++++++++
 .../contrib/openai_agents/_openai_runner.py     |  10 +-
 4 files changed, 344 insertions(+), 9 deletions(-)
 create mode 100644 temporalio/contrib/openai_agents/_litellm_model.py

diff --git a/temporalio/contrib/openai_agents/__init__.py b/temporalio/contrib/openai_agents/__init__.py
index 274f5b98b..a3422806c 100644
--- a/temporalio/contrib/openai_agents/__init__.py
+++ b/temporalio/contrib/openai_agents/__init__.py
@@ -8,11 +8,13 @@
 Use with caution in production environments.
 """
 
+from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
 from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
 from temporalio.contrib.openai_agents._temporal_openai_agents import (
     OpenAIAgentsPlugin,
     TestModel,
     TestModelProvider,
+    set_open_ai_agent_temporal_overrides,
 )
 from temporalio.contrib.openai_agents._trace_interceptor import (
     OpenAIAgentsTracingInterceptor,
@@ -21,9 +23,12 @@
 from . import workflow
 
 __all__ = [
-    "OpenAIAgentsPlugin",
+    "ModelActivity",
     "ModelActivityParameters",
-    "workflow",
+    "OpenAIAgentsPlugin",
+    "OpenAIAgentsTracingInterceptor",
+    "set_open_ai_agent_temporal_overrides",
     "TestModel",
     "TestModelProvider",
+    "workflow",
 ]
diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py
index 2fc60df02..c172349bd 100644
--- a/temporalio/contrib/openai_agents/_invoke_model_activity.py
+++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py
@@ -14,6 +14,7 @@
     FileSearchTool,
     FunctionTool,
     Handoff,
+    Model,
     ModelProvider,
     ModelResponse,
     ModelSettings,
@@ -126,21 +127,73 @@ class ActivityModelInput(TypedDict, total=False):
 
 
 class ModelActivity:
-    """Class wrapper for model invocation activities to allow model customization. By default, we use an OpenAIProvider with retries disabled.
+    """Class wrapper for model invocation activities to allow model customization.
+
+    By default, we use an OpenAIProvider with retries disabled. The activity automatically
+    detects models prefixed with 'litellm/' and routes them to LiteLLM when available.
+
+    Disabling retries in your model of choice is recommended to allow activity retries to define the retry model.
""" def __init__(self, model_provider: Optional[ModelProvider] = None): """Initialize the activity with a model provider.""" - self._model_provider = model_provider or OpenAIProvider( - openai_client=AsyncOpenAI(max_retries=0) - ) + self._custom_model_provider = model_provider + self._default_model_provider = None # Lazy initialization + + def _get_openai_provider(self) -> ModelProvider: + """Get the OpenAI provider, initializing it lazily to avoid requiring API key upfront.""" + if self._custom_model_provider: + return self._custom_model_provider + + if self._default_model_provider is None: + self._default_model_provider = OpenAIProvider( + openai_client=AsyncOpenAI(max_retries=0) + ) + + return self._default_model_provider + + def _get_litellm_model(self, model_name: str) -> Model: + """Get a LiteLLM model for the given model name. + + Args: + model_name: Model name prefixed with 'litellm/' (e.g., 'litellm/anthropic/claude-3-5-sonnet') + + Returns: + A LiteLLM model instance + + Raises: + ImportError: If LiteLLM is not installed + ValueError: If model name is invalid + """ + try: + from temporalio.contrib.openai_agents._litellm_model import LiteLLMModel + except ImportError: + raise ImportError( + f"LiteLLM model '{model_name}' requested but LiteLLM is not installed. " + "Install with: pip install litellm" + ) + + # Remove the 'litellm/' prefix to get the actual model name + actual_model_name = model_name[8:] # len("litellm/") = 8 + + if not actual_model_name: + raise ValueError("Model name cannot be just 'litellm/' - provide the actual model after the prefix") + + return LiteLLMModel(model=actual_model_name) @activity.defn @_auto_heartbeater async def invoke_model_activity(self, input: ActivityModelInput) -> ModelResponse: """Activity that invokes a model with the given input.""" - model = self._model_provider.get_model(input.get("model_name")) + model_name = input.get("model_name") + + # Check if this is a LiteLLM model (prefixed with 'litellm/') + if model_name and model_name.startswith("litellm/"): + model = self._get_litellm_model(model_name) + else: + # Use regular model provider (AsyncOpenAI by default, lazy initialization) + openai_provider = self._get_openai_provider() + model = openai_provider.get_model(model_name) async def empty_on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str: return "" diff --git a/temporalio/contrib/openai_agents/_litellm_model.py b/temporalio/contrib/openai_agents/_litellm_model.py new file mode 100644 index 000000000..2223566fc --- /dev/null +++ b/temporalio/contrib/openai_agents/_litellm_model.py @@ -0,0 +1,271 @@ +"""LiteLLM model class used when models are prefixed with 'litellm/'. +The routing logic is handled in the ModelActivity itself. 
+""" + +from typing import Optional, Union + +from agents import Model, ModelResponse, Usage +from agents.models.chatcmpl_converter import Converter +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails +from openai.types.chat import ChatCompletionMessage + +try: + import litellm + from litellm import acompletion + LITELLM_AVAILABLE = True +except ImportError: + LITELLM_AVAILABLE = False + litellm = None + acompletion = None + + +class LiteLLMConverter: + """Helper class to convert LiteLLM messages to OpenAI format.""" + + @classmethod + def convert_message_to_openai(cls, message) -> ChatCompletionMessage: + """Convert LiteLLM message to OpenAI ChatCompletionMessage format.""" + if hasattr(message, 'role') and message.role != "assistant": + raise ValueError(f"Unsupported role: {message.role}") + + # Convert tool calls if present + tool_calls = None + if hasattr(message, 'tool_calls') and message.tool_calls: + tool_calls = [] + for tc in message.tool_calls: + tool_calls.append({ + 'id': tc.id, + 'type': tc.type, + 'function': { + 'name': tc.function.name, + 'arguments': tc.function.arguments + } + }) + + # Handle provider-specific fields like refusal + provider_specific_fields = getattr(message, 'provider_specific_fields', None) or {} + refusal = provider_specific_fields.get('refusal', None) + + return ChatCompletionMessage( + content=getattr(message, 'content', None), + refusal=refusal, + role="assistant", + tool_calls=tool_calls, + ) + + +class LiteLLMModel(Model): + """LiteLLM model implementation compatible with OpenAI Agents SDK interface. + + This model provides the same interface as other models in the OpenAI Agents SDK + while using LiteLLM's unified API to communicate with various LLM providers. + """ + + def __init__( + self, + model: str, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + **kwargs + ): + """Initialize the LiteLLM model. + + Args: + model: Model name/identifier (e.g., "gpt-4", "anthropic/claude-3-5-sonnet") + api_key: API key for the model provider + base_url: Optional base URL for custom endpoints + **kwargs: Additional configuration for LiteLLM + """ + if not LITELLM_AVAILABLE: + raise ImportError( + "LiteLLM is not installed. Install it with: pip install litellm" + ) + + self._model = model + self._api_key = api_key + self._base_url = base_url + self._config = kwargs + + async def get_response( + self, + system_instructions: Union[str, None], + input: Union[str, list], + model_settings, + tools: list, + output_schema=None, + handoffs: list = None, + tracing=None, + *, + previous_response_id: Union[str, None] = None, + prompt=None, + **kwargs + ): + """Get a response from the LiteLLM model. + + This method translates OpenAI Agents SDK parameters to LiteLLM's + completion API format and returns a ModelResponse. 
+ """ + + # Build messages from system instructions and input + messages = [] + + if system_instructions: + messages.append({"role": "system", "content": system_instructions}) + + # Handle different input types + if isinstance(input, str): + messages.append({"role": "user", "content": input}) + elif isinstance(input, list): + # Handle list of messages/content items + for item in input: + if isinstance(item, dict): + messages.append(item) + else: + # Convert other types to user message + messages.append({"role": "user", "content": str(item)}) + + # Prepare LiteLLM parameters + litellm_params = { + "model": self._model, + "messages": messages, + } + + # Add API key if available + if self._api_key: + litellm_params["api_key"] = self._api_key + + # Add base URL if available + if self._base_url: + litellm_params["base_url"] = self._base_url + + # Map model settings to LiteLLM parameters + if hasattr(model_settings, 'temperature') and model_settings.temperature is not None: + litellm_params["temperature"] = model_settings.temperature + if hasattr(model_settings, 'max_tokens') and model_settings.max_tokens is not None: + litellm_params["max_tokens"] = model_settings.max_tokens + if hasattr(model_settings, 'top_p') and model_settings.top_p is not None: + litellm_params["top_p"] = model_settings.top_p + + # Handle tools if provided + if tools: + # Convert tools to OpenAI format that LiteLLM expects + litellm_tools = [] + for tool in tools: + if hasattr(tool, 'name') and hasattr(tool, 'description'): + tool_def = { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + } + } + if hasattr(tool, 'params_json_schema'): + tool_def["function"]["parameters"] = tool.params_json_schema + litellm_tools.append(tool_def) + + if litellm_tools: + litellm_params["tools"] = litellm_tools + + # Handle output schema/response format + if output_schema and not (hasattr(output_schema, 'is_plain_text') and output_schema.is_plain_text()): + if hasattr(output_schema, 'json_schema'): + try: + schema = output_schema.json_schema() + litellm_params["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": getattr(output_schema, 'name', lambda: "response")(), + "schema": schema, + "strict": getattr(output_schema, 'is_strict_json_schema', lambda: False)() + } + } + except Exception: + # Fallback to basic JSON mode if schema extraction fails + litellm_params["response_format"] = {"type": "json_object"} + + # Add any additional config + litellm_params.update(self._config) + + try: + # Make the LiteLLM API call + response = await acompletion(**litellm_params) + + # Convert response to ModelResponse format using the same approach as official openai-agents + choice = response.choices[0] + message = choice.message + + # Convert LiteLLM message to OpenAI ChatCompletionMessage format + openai_message = LiteLLMConverter.convert_message_to_openai(message) + + # Convert message to output items using agents SDK converter + output_items = Converter.message_to_output_items(openai_message) + + # Handle usage conversion similar to official implementation + if hasattr(response, "usage") and response.usage: + response_usage = response.usage + usage = Usage( + requests=1, + input_tokens=getattr(response_usage, 'prompt_tokens', 0), + output_tokens=getattr(response_usage, 'completion_tokens', 0), + total_tokens=getattr(response_usage, 'total_tokens', 0), + input_tokens_details=InputTokensDetails( + cached_tokens=getattr( + getattr(response_usage, 'prompt_tokens_details', None), + 
'cached_tokens', 0 + ) or 0 + ), + output_tokens_details=OutputTokensDetails( + reasoning_tokens=getattr( + getattr(response_usage, 'completion_tokens_details', None), + 'reasoning_tokens', 0 + ) or 0 + ) + ) + else: + # Fallback if no usage data + usage = Usage( + requests=1, + input_tokens=0, + output_tokens=0, + total_tokens=0, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens_details=OutputTokensDetails(reasoning_tokens=0) + ) + + # Create ModelResponse + model_response = ModelResponse( + output=output_items, + usage=usage, + response_id=getattr(response, 'id', None) + ) + + return model_response + + except Exception as e: + # Re-raise with more context + raise RuntimeError(f"LiteLLM API call failed for model {self._model}: {str(e)}") from e + + def stream_response( + self, + system_instructions: Optional[str], + input: Union[str, list], + model_settings, + tools: list, + output_schema=None, + handoffs: list = None, + tracing=None, + *, + previous_response_id: Optional[str] = None, + prompt=None, + **kwargs + ): + """Stream response from the LiteLLM model. + + Note: Streaming implementation would require additional complexity + to handle the OpenAI Agents SDK streaming interface. This is a + placeholder for future implementation. + """ + raise NotImplementedError( + "Streaming is not yet implemented for LiteLLM models. " + "Use get_response() for non-streaming responses." + ) \ No newline at end of file diff --git a/temporalio/contrib/openai_agents/_openai_runner.py b/temporalio/contrib/openai_agents/_openai_runner.py index 1ccbc5f4d..3c3ce05d3 100644 --- a/temporalio/contrib/openai_agents/_openai_runner.py +++ b/temporalio/contrib/openai_agents/_openai_runner.py @@ -60,14 +60,20 @@ async def run( if run_config is None: run_config = RunConfig() - if run_config.model is not None and not isinstance(run_config.model, str): + # Use run_config.model if specified, otherwise fall back to the agent's model + agent_model = getattr(starting_agent, 'model', None) + run_config_model = getattr(run_config, 'model', None) + + model_name = run_config_model if run_config_model is not None else agent_model + + if model_name is not None and not isinstance(model_name, str): raise ValueError( "Temporal workflows require a model name to be a string in the run config." ) updated_run_config = replace( run_config, model=_TemporalModelStub( - run_config.model, + model_name, model_params=self.model_params, ), )
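
Below is a minimal usage sketch of what this patch enables: the model name set directly on the Agent is now honored by the Temporal runner, and a 'litellm/' prefix makes ModelActivity route the call through LiteLLM instead of the default OpenAI client. The workflow, task queue, and server address here are illustrative assumptions, and the sketch assumes LiteLLM can pick up the provider credentials (for example ANTHROPIC_API_KEY) from the worker's environment.

import asyncio
from datetime import timedelta

from agents import Agent, Runner
from temporalio import workflow
from temporalio.client import Client
from temporalio.contrib.openai_agents import (
    ModelActivityParameters,
    OpenAIAgentsPlugin,
)
from temporalio.worker import Worker


@workflow.defn
class HaikuWorkflow:
    @workflow.run
    async def run(self, prompt: str) -> str:
        # The agent-level model name is now used by the Temporal runner; the
        # 'litellm/' prefix routes the model call through LiteLLM.
        agent = Agent(
            name="Assistant",
            instructions="You only respond in haikus.",
            model="litellm/anthropic/claude-sonnet-4-20250514",
        )
        result = await Runner.run(agent, input=prompt)
        return result.final_output


async def main() -> None:
    # Assumes a local Temporal server and provider credentials in the environment.
    client = await Client.connect(
        "localhost:7233",
        plugins=[
            OpenAIAgentsPlugin(
                model_params=ModelActivityParameters(
                    start_to_close_timeout=timedelta(seconds=60),
                ),
            ),
        ],
    )
    worker = Worker(client, task_queue="openai-agents-task-queue", workflows=[HaikuWorkflow])
    await worker.run()


if __name__ == "__main__":
    asyncio.run(main())

Model names without the 'litellm/' prefix continue to go through the default OpenAIProvider path, so existing workflows are unaffected.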