From 749222f2c0f9a301c3ec8ccac936a73f3f4468cb Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 9 Sep 2025 17:49:57 -0500 Subject: [PATCH 1/2] Cleaned integrations directory --- .../integrations/langchain/runnable_rails.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/nemoguardrails/integrations/langchain/runnable_rails.py b/nemoguardrails/integrations/langchain/runnable_rails.py index 2e8e0fbf2..edacb1519 100644 --- a/nemoguardrails/integrations/langchain/runnable_rails.py +++ b/nemoguardrails/integrations/langchain/runnable_rails.py @@ -15,9 +15,13 @@ from __future__ import annotations +from typing import Any, List, Optional, Union, cast import logging from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union +from langchain_core.language_models import BaseChatModel +from langchain_core.language_models.llms import BaseLLM +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.language_models import BaseLanguageModel from langchain_core.prompt_values import ChatPromptValue, StringPromptValue from langchain_core.runnables import Runnable, RunnableConfig @@ -33,7 +37,7 @@ message_to_dict, ) from nemoguardrails.integrations.langchain.utils import async_wrap -from nemoguardrails.rails.llm.options import GenerationOptions +from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse logger = logging.getLogger(__name__) @@ -62,7 +66,7 @@ class RunnableRails(Runnable[Input, Output]): def __init__( self, config: RailsConfig, - llm: Optional[BaseLanguageModel] = None, + llm: Optional[Union[BaseLLM, BaseChatModel]] = None, tools: Optional[List[Tool]] = None, passthrough: bool = True, runnable: Optional[Runnable] = None, @@ -110,7 +114,7 @@ def __init__( if self.passthrough_runnable: self._init_passthrough_fn() - def _init_passthrough_fn(self): + def _init_passthrough_fn(self) -> None: """Initialize the passthrough function for the LLM rails instance.""" async def passthrough_fn(context: dict, events: List[dict]): @@ -134,7 +138,8 @@ async def passthrough_fn(context: dict, events: List[dict]): return text, _output - self.rails.llm_generation_actions.passthrough_fn = passthrough_fn + # Dynamically assign passthrough_fn to avoid type checker issues + setattr(self.rails.llm_generation_actions, "passthrough_fn", passthrough_fn) def __or__( self, other: Union[BaseLanguageModel, Runnable[Any, Any]] @@ -687,6 +692,9 @@ def _full_rails_invoke( res = self.rails.generate( messages=input_messages, options=GenerationOptions(output_vars=True) ) + # When using output_vars=True, rails.generate returns a GenerationResponse + if not isinstance(res, GenerationResponse): + raise Exception(f"Expected GenerationResponse, got {type(res)}") context = res.output_data result = res.response From e50dac6bbb83b09d0a084a0cca22f98bd764ed74 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 14 Oct 2025 15:27:17 -0500 Subject: [PATCH 2/2] Clean up post-merge imports --- nemoguardrails/integrations/langchain/runnable_rails.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/nemoguardrails/integrations/langchain/runnable_rails.py b/nemoguardrails/integrations/langchain/runnable_rails.py index edacb1519..c719fbb38 100644 --- a/nemoguardrails/integrations/langchain/runnable_rails.py +++ b/nemoguardrails/integrations/langchain/runnable_rails.py @@ -15,14 +15,12 @@ from __future__ import annotations -from typing import Any, List, Optional, Union, cast import logging -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast -from langchain_core.language_models import BaseChatModel +from langchain_core.language_models import BaseChatModel, BaseLanguageModel from langchain_core.language_models.llms import BaseLLM from langchain_core.messages import AIMessage, HumanMessage, SystemMessage -from langchain_core.language_models import BaseLanguageModel from langchain_core.prompt_values import ChatPromptValue, StringPromptValue from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.runnables.utils import Input, Output, gather_with_concurrency