2020from  langchain .callbacks .base  import  AsyncCallbackHandler , BaseCallbackManager 
2121from  langchain .prompts .base  import  StringPromptValue 
2222from  langchain .prompts .chat  import  ChatPromptValue 
23- from  langchain . schema  import  AIMessage , HumanMessage , SystemMessage 
23+ from  langchain_core . messages  import  AIMessage , HumanMessage , SystemMessage ,  ToolMessage 
2424
2525from  nemoguardrails .colang .v2_x .lang .colang_ast  import  Flow 
2626from  nemoguardrails .colang .v2_x .runtime .flows  import  InternalEvent , InternalEvents 
27- from  nemoguardrails .context  import  llm_call_info_var , reasoning_trace_var 
27+ from  nemoguardrails .context  import  (
28+     llm_call_info_var ,
29+     reasoning_trace_var ,
30+     tool_calls_var ,
31+ )
2832from  nemoguardrails .logging .callbacks  import  logging_callbacks 
2933from  nemoguardrails .logging .explain  import  LLMCallInfo 
3034
@@ -72,7 +76,22 @@ async def llm_call(
7276    custom_callback_handlers : Optional [List [AsyncCallbackHandler ]] =  None ,
7377) ->  str :
7478    """Calls the LLM with a prompt and returns the generated text.""" 
75-     # We initialize a new LLM call if we don't have one already 
79+     _setup_llm_call_info (llm , model_name , model_provider )
80+     all_callbacks  =  _prepare_callbacks (custom_callback_handlers )
81+ 
82+     if  isinstance (prompt , str ):
83+         response  =  await  _invoke_with_string_prompt (llm , prompt , all_callbacks , stop )
84+     else :
85+         response  =  await  _invoke_with_message_list (llm , prompt , all_callbacks , stop )
86+ 
87+     _store_tool_calls (response )
88+     return  _extract_content (response )
89+ 
90+ 
91+ def  _setup_llm_call_info (
92+     llm : BaseLanguageModel , model_name : Optional [str ], model_provider : Optional [str ]
93+ ) ->  None :
94+     """Initialize or update LLM call info in context.""" 
7695    llm_call_info  =  llm_call_info_var .get ()
7796    if  llm_call_info  is  None :
7897        llm_call_info  =  LLMCallInfo ()
@@ -81,52 +100,84 @@ async def llm_call(
81100    llm_call_info .llm_model_name  =  model_name  or  _infer_model_name (llm )
82101    llm_call_info .llm_provider_name  =  model_provider 
83102
103+ 
104+ def  _prepare_callbacks (
105+     custom_callback_handlers : Optional [List [AsyncCallbackHandler ]],
106+ ) ->  BaseCallbackManager :
107+     """Prepare callback manager with custom handlers if provided.""" 
84108    if  custom_callback_handlers  and  custom_callback_handlers  !=  [None ]:
85-         all_callbacks   =  BaseCallbackManager (
109+         return  BaseCallbackManager (
86110            handlers = logging_callbacks .handlers  +  custom_callback_handlers ,
87111            inheritable_handlers = logging_callbacks .handlers  +  custom_callback_handlers ,
88112        )
89-     else :
90-         all_callbacks  =  logging_callbacks 
113+     return  logging_callbacks 
91114
92-     if  isinstance (prompt , str ):
93-         # stop sinks here 
94-         try :
95-             result  =  await  llm .agenerate_prompt (
96-                 [StringPromptValue (text = prompt )], callbacks = all_callbacks , stop = stop 
115+ 
116+ async  def  _invoke_with_string_prompt (
117+     llm : BaseLanguageModel ,
118+     prompt : str ,
119+     callbacks : BaseCallbackManager ,
120+     stop : Optional [List [str ]],
121+ ):
122+     """Invoke LLM with string prompt.""" 
123+     try :
124+         return  await  llm .ainvoke (prompt , config = {"callbacks" : callbacks , "stop" : stop })
125+     except  Exception  as  e :
126+         raise  LLMCallException (e )
127+ 
128+ 
129+ async  def  _invoke_with_message_list (
130+     llm : BaseLanguageModel ,
131+     prompt : List [dict ],
132+     callbacks : BaseCallbackManager ,
133+     stop : Optional [List [str ]],
134+ ):
135+     """Invoke LLM with message list after converting to LangChain format.""" 
136+     messages  =  _convert_messages_to_langchain_format (prompt )
137+     try :
138+         return  await  llm .ainvoke (
139+             messages , config = {"callbacks" : callbacks , "stop" : stop }
140+         )
141+     except  Exception  as  e :
142+         raise  LLMCallException (e )
143+ 
144+ 
145+ def  _convert_messages_to_langchain_format (prompt : List [dict ]) ->  List :
146+     """Convert message list to LangChain message format.""" 
147+     messages  =  []
148+     for  msg  in  prompt :
149+         msg_type  =  msg ["type" ] if  "type"  in  msg  else  msg ["role" ]
150+ 
151+         if  msg_type  ==  "user" :
152+             messages .append (HumanMessage (content = msg ["content" ]))
153+         elif  msg_type  in  ["bot" , "assistant" ]:
154+             messages .append (AIMessage (content = msg ["content" ]))
155+         elif  msg_type  ==  "system" :
156+             messages .append (SystemMessage (content = msg ["content" ]))
157+         elif  msg_type  ==  "tool" :
158+             messages .append (
159+                 ToolMessage (
160+                     content = msg ["content" ],
161+                     tool_call_id = msg .get ("tool_call_id" , "" ),
162+                 )
97163            )
98-         except  Exception  as  e :
99-             raise  LLMCallException (e )
100-         llm_call_info .raw_response  =  result .llm_output 
164+         else :
165+             raise  ValueError (f"Unknown message type { msg_type }  " )
101166
102-         # TODO: error handling 
103-         return  result .generations [0 ][0 ].text 
104-     else :
105-         # We first need to translate the array of messages into LangChain message format 
106-         messages  =  []
107-         for  _msg  in  prompt :
108-             msg_type  =  _msg ["type" ] if  "type"  in  _msg  else  _msg ["role" ]
109-             if  msg_type  ==  "user" :
110-                 messages .append (HumanMessage (content = _msg ["content" ]))
111-             elif  msg_type  in  ["bot" , "assistant" ]:
112-                 messages .append (AIMessage (content = _msg ["content" ]))
113-             elif  msg_type  ==  "system" :
114-                 messages .append (SystemMessage (content = _msg ["content" ]))
115-             else :
116-                 # TODO: add support for tool-related messages 
117-                 raise  ValueError (f"Unknown message type { msg_type }  " )
167+     return  messages 
118168
119-         try :
120-             result  =  await  llm .agenerate_prompt (
121-                 [ChatPromptValue (messages = messages )], callbacks = all_callbacks , stop = stop 
122-             )
123169
124-         except  Exception  as  e :
125-             raise  LLMCallException (e )
170+ def  _store_tool_calls (response ) ->  None :
171+     """Extract and store tool calls from response in context.""" 
172+     tool_calls  =  getattr (response , "tool_calls" , None )
173+     tool_calls_var .set (tool_calls )
126174
127-         llm_call_info .raw_response  =  result .llm_output 
128175
129-         return  result .generations [0 ][0 ].text 
176+ def  _extract_content (response ) ->  str :
177+     """Extract text content from response.""" 
178+     if  hasattr (response , "content" ):
179+         return  response .content 
180+     return  str (response )
130181
131182
132183def  get_colang_history (
@@ -175,15 +226,15 @@ def get_colang_history(
175226                history  +=  f'user "{ event ["text" ]}  "\n ' 
176227            elif  event ["type" ] ==  "UserIntent" :
177228                if  include_texts :
178-                     history  +=  f'   { event [" intent"  ]} \n '  
229+                     history  +=  f"   { event [' intent'  ]} \n "  
179230                else :
180-                     history  +=  f' user { event [" intent"  ]} \n '  
231+                     history  +=  f" user { event [' intent'  ]} \n "  
181232            elif  event ["type" ] ==  "BotIntent" :
182233                # If we have instructions, we add them before the bot message. 
183234                # But we only do that for the last bot message. 
184235                if  "instructions"  in  event  and  idx  ==  last_bot_intent_idx :
185236                    history  +=  f"# { event ['instructions' ]} \n " 
186-                 history  +=  f' bot { event [" intent"  ]} \n '  
237+                 history  +=  f" bot { event [' intent'  ]} \n "  
187238            elif  event ["type" ] ==  "StartUtteranceBotAction"  and  include_texts :
188239                history  +=  f'  "{ event ["script" ]}  "\n ' 
189240            # We skip system actions from this log 
@@ -352,9 +403,9 @@ def flow_to_colang(flow: Union[dict, Flow]) -> str:
352403            if  "_type"  not  in   element :
353404                raise  Exception ("bla" )
354405            if  element ["_type" ] ==  "UserIntent" :
355-                 colang_flow  +=  f' user { element [" intent_name"  ]} \n '  
406+                 colang_flow  +=  f" user { element [' intent_name'  ]} \n "  
356407            elif  element ["_type" ] ==  "run_action"  and  element ["action_name" ] ==  "utter" :
357-                 colang_flow  +=  f' bot { element [" action_params" ][ " value"  ]} \n '  
408+                 colang_flow  +=  f" bot { element [' action_params' ][ ' value'  ]} \n "  
358409
359410    return  colang_flow 
360411
@@ -592,3 +643,15 @@ def get_and_clear_reasoning_trace_contextvar() -> Optional[str]:
592643        reasoning_trace_var .set (None )
593644        return  reasoning_trace 
594645    return  None 
646+ 
647+ 
648+ def  get_and_clear_tool_calls_contextvar () ->  Optional [list ]:
649+     """Get the current tool calls and clear them from the context. 
650+ 
651+     Returns: 
652+         Optional[list]: The tool calls if they exist, None otherwise. 
653+     """ 
654+     if  tool_calls  :=  tool_calls_var .get ():
655+         tool_calls_var .set (None )
656+         return  tool_calls 
657+     return  None 
0 commit comments