Commit 51b1308

feat(tool-calling): add tool call passthrough support in LLMRails (#1364)
Implements tool call extraction and passthrough functionality in LLMRails:

- Add tool_calls_var context variable for storing LLM tool calls
- Refactor llm_call utils to extract and store tool calls from responses
- Support tool calls in both GenerationResponse and dict message formats
- Add ToolMessage support for langchain message conversion
- Comprehensive test coverage for tool calling integration
1 parent 5d974e5 commit 51b1308
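
In practical terms, a caller that sends chat messages to a tool-bound model should now get the model's tool calls back on the returned message instead of having them silently dropped. A minimal usage sketch, assuming an OpenAI chat model, a hypothetical get_weather tool, and an existing ./config directory, none of which are part of this commit:

import asyncio

from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from nemoguardrails import LLMRails, RailsConfig


@tool
def get_weather(location: str) -> str:
    """Illustrative stub tool; not part of this commit."""
    return f"It is sunny in {location}."


async def main():
    # Assumes a minimal guardrails config lives in ./config.
    config = RailsConfig.from_path("./config")
    llm = ChatOpenAI(model="gpt-4o-mini").bind_tools([get_weather])
    rails = LLMRails(config, llm=llm)

    # With the passthrough in place, tool calls made by the LLM are attached
    # to the returned message.
    result = await rails.generate_async(
        messages=[{"role": "user", "content": "What's the weather in NYC?"}]
    )
    print(result.get("tool_calls"))


asyncio.run(main())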

File tree

8 files changed, +890 -44 lines changed


nemoguardrails/actions/llm/utils.py

Lines changed: 106 additions & 43 deletions
@@ -20,11 +20,15 @@
 from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
-from langchain.schema import AIMessage, HumanMessage, SystemMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage

 from nemoguardrails.colang.v2_x.lang.colang_ast import Flow
 from nemoguardrails.colang.v2_x.runtime.flows import InternalEvent, InternalEvents
-from nemoguardrails.context import llm_call_info_var, reasoning_trace_var
+from nemoguardrails.context import (
+    llm_call_info_var,
+    reasoning_trace_var,
+    tool_calls_var,
+)
 from nemoguardrails.logging.callbacks import logging_callbacks
 from nemoguardrails.logging.explain import LLMCallInfo

@@ -72,7 +76,22 @@ async def llm_call(
     custom_callback_handlers: Optional[List[AsyncCallbackHandler]] = None,
 ) -> str:
     """Calls the LLM with a prompt and returns the generated text."""
-    # We initialize a new LLM call if we don't have one already
+    _setup_llm_call_info(llm, model_name, model_provider)
+    all_callbacks = _prepare_callbacks(custom_callback_handlers)
+
+    if isinstance(prompt, str):
+        response = await _invoke_with_string_prompt(llm, prompt, all_callbacks, stop)
+    else:
+        response = await _invoke_with_message_list(llm, prompt, all_callbacks, stop)
+
+    _store_tool_calls(response)
+    return _extract_content(response)
+
+
+def _setup_llm_call_info(
+    llm: BaseLanguageModel, model_name: Optional[str], model_provider: Optional[str]
+) -> None:
+    """Initialize or update LLM call info in context."""
     llm_call_info = llm_call_info_var.get()
     if llm_call_info is None:
         llm_call_info = LLMCallInfo()
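
The refactor breaks the old monolithic llm_call into small helpers: set up the LLMCallInfo, build the callback manager, invoke the model through ainvoke, stash any tool calls, and return the text. A small sketch of what the last two steps operate on, using a hand-built langchain_core AIMessage in place of a real ainvoke result:

from langchain_core.messages import AIMessage

# Stand-in for a chat-model response that requested a tool call.
response = AIMessage(
    content="",
    tool_calls=[{"name": "get_weather", "args": {"location": "NYC"}, "id": "call_123"}],
)

# What _store_tool_calls does: getattr tolerates plain string responses
# that have no tool_calls attribute.
tool_calls = getattr(response, "tool_calls", None)

# What _extract_content does: chat models return a message with .content;
# anything else falls back to str().
content = response.content if hasattr(response, "content") else str(response)

print(tool_calls)     # one get_weather call
print(repr(content))  # '': the text channel is empty when only tools are called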
@@ -81,52 +100,84 @@ async def llm_call(
     llm_call_info.llm_model_name = model_name or _infer_model_name(llm)
     llm_call_info.llm_provider_name = model_provider

+
+def _prepare_callbacks(
+    custom_callback_handlers: Optional[List[AsyncCallbackHandler]],
+) -> BaseCallbackManager:
+    """Prepare callback manager with custom handlers if provided."""
     if custom_callback_handlers and custom_callback_handlers != [None]:
-        all_callbacks = BaseCallbackManager(
+        return BaseCallbackManager(
             handlers=logging_callbacks.handlers + custom_callback_handlers,
             inheritable_handlers=logging_callbacks.handlers + custom_callback_handlers,
         )
-    else:
-        all_callbacks = logging_callbacks
+    return logging_callbacks

-    if isinstance(prompt, str):
-        # stop sinks here
-        try:
-            result = await llm.agenerate_prompt(
-                [StringPromptValue(text=prompt)], callbacks=all_callbacks, stop=stop
+
+async def _invoke_with_string_prompt(
+    llm: BaseLanguageModel,
+    prompt: str,
+    callbacks: BaseCallbackManager,
+    stop: Optional[List[str]],
+):
+    """Invoke LLM with string prompt."""
+    try:
+        return await llm.ainvoke(prompt, config={"callbacks": callbacks, "stop": stop})
+    except Exception as e:
+        raise LLMCallException(e)
+
+
+async def _invoke_with_message_list(
+    llm: BaseLanguageModel,
+    prompt: List[dict],
+    callbacks: BaseCallbackManager,
+    stop: Optional[List[str]],
+):
+    """Invoke LLM with message list after converting to LangChain format."""
+    messages = _convert_messages_to_langchain_format(prompt)
+    try:
+        return await llm.ainvoke(
+            messages, config={"callbacks": callbacks, "stop": stop}
+        )
+    except Exception as e:
+        raise LLMCallException(e)
+
+
+def _convert_messages_to_langchain_format(prompt: List[dict]) -> List:
+    """Convert message list to LangChain message format."""
+    messages = []
+    for msg in prompt:
+        msg_type = msg["type"] if "type" in msg else msg["role"]
+
+        if msg_type == "user":
+            messages.append(HumanMessage(content=msg["content"]))
+        elif msg_type in ["bot", "assistant"]:
+            messages.append(AIMessage(content=msg["content"]))
+        elif msg_type == "system":
+            messages.append(SystemMessage(content=msg["content"]))
+        elif msg_type == "tool":
+            messages.append(
+                ToolMessage(
+                    content=msg["content"],
+                    tool_call_id=msg.get("tool_call_id", ""),
+                )
             )
-        except Exception as e:
-            raise LLMCallException(e)
-        llm_call_info.raw_response = result.llm_output
+        else:
+            raise ValueError(f"Unknown message type {msg_type}")

-        # TODO: error handling
-        return result.generations[0][0].text
-    else:
-        # We first need to translate the array of messages into LangChain message format
-        messages = []
-        for _msg in prompt:
-            msg_type = _msg["type"] if "type" in _msg else _msg["role"]
-            if msg_type == "user":
-                messages.append(HumanMessage(content=_msg["content"]))
-            elif msg_type in ["bot", "assistant"]:
-                messages.append(AIMessage(content=_msg["content"]))
-            elif msg_type == "system":
-                messages.append(SystemMessage(content=_msg["content"]))
-            else:
-                # TODO: add support for tool-related messages
-                raise ValueError(f"Unknown message type {msg_type}")
+    return messages

-        try:
-            result = await llm.agenerate_prompt(
-                [ChatPromptValue(messages=messages)], callbacks=all_callbacks, stop=stop
-            )

-        except Exception as e:
-            raise LLMCallException(e)
+def _store_tool_calls(response) -> None:
+    """Extract and store tool calls from response in context."""
+    tool_calls = getattr(response, "tool_calls", None)
+    tool_calls_var.set(tool_calls)

-    llm_call_info.raw_response = result.llm_output

-    return result.generations[0][0].text
+def _extract_content(response) -> str:
+    """Extract text content from response."""
+    if hasattr(response, "content"):
+        return response.content
+    return str(response)


 def get_colang_history(
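
The converter now also accepts tool-result messages, which previously hit the "Unknown message type" branch. A sketch of the round trip for one such message, mirroring the new "tool" branch above:

from langchain_core.messages import ToolMessage

# A tool result in the dict format accepted by _convert_messages_to_langchain_format.
msg = {
    "role": "tool",
    "content": "It is sunny in NYC.",
    "tool_call_id": "call_123",
}

# Equivalent of the new "tool" branch: the result is tied back to the
# originating call via tool_call_id.
converted = ToolMessage(
    content=msg["content"],
    tool_call_id=msg.get("tool_call_id", ""),
)
print(type(converted).__name__, converted.tool_call_id)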
@@ -175,15 +226,15 @@ def get_colang_history(
             history += f'user "{event["text"]}"\n'
         elif event["type"] == "UserIntent":
             if include_texts:
-                history += f' {event["intent"]}\n'
+                history += f" {event['intent']}\n"
             else:
-                history += f'user {event["intent"]}\n'
+                history += f"user {event['intent']}\n"
         elif event["type"] == "BotIntent":
             # If we have instructions, we add them before the bot message.
             # But we only do that for the last bot message.
             if "instructions" in event and idx == last_bot_intent_idx:
                 history += f"# {event['instructions']}\n"
-            history += f'bot {event["intent"]}\n'
+            history += f"bot {event['intent']}\n"
         elif event["type"] == "StartUtteranceBotAction" and include_texts:
             history += f' "{event["script"]}"\n'
         # We skip system actions from this log
@@ -352,9 +403,9 @@ def flow_to_colang(flow: Union[dict, Flow]) -> str:
         if "_type" not in element:
             raise Exception("bla")
         if element["_type"] == "UserIntent":
-            colang_flow += f'user {element["intent_name"]}\n'
+            colang_flow += f"user {element['intent_name']}\n"
         elif element["_type"] == "run_action" and element["action_name"] == "utter":
-            colang_flow += f'bot {element["action_params"]["value"]}\n'
+            colang_flow += f"bot {element['action_params']['value']}\n"

     return colang_flow

@@ -592,3 +643,15 @@ def get_and_clear_reasoning_trace_contextvar() -> Optional[str]:
         reasoning_trace_var.set(None)
         return reasoning_trace
     return None
+
+
+def get_and_clear_tool_calls_contextvar() -> Optional[list]:
+    """Get the current tool calls and clear them from the context.
+
+    Returns:
+        Optional[list]: The tool calls if they exist, None otherwise.
+    """
+    if tool_calls := tool_calls_var.get():
+        tool_calls_var.set(None)
+        return tool_calls
+    return None
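
The helper mirrors the existing reasoning-trace helper: it reads the context variable and immediately resets it, so a value can only be consumed once per response. A self-contained sketch of that consume-once behaviour using the same contextvar pattern:

import contextvars
from typing import Optional

tool_calls_var: contextvars.ContextVar[Optional[list]] = contextvars.ContextVar(
    "tool_calls", default=None
)


def get_and_clear_tool_calls_contextvar() -> Optional[list]:
    """Return the stored tool calls once, then clear them."""
    if tool_calls := tool_calls_var.get():
        tool_calls_var.set(None)
        return tool_calls
    return None


tool_calls_var.set([{"name": "get_weather", "args": {"location": "NYC"}, "id": "call_123"}])
print(get_and_clear_tool_calls_contextvar())  # the stored list
print(get_and_clear_tool_calls_contextvar())  # None: already consumed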

nemoguardrails/context.py

Lines changed: 5 additions & 0 deletions
@@ -37,3 +37,8 @@
 reasoning_trace_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
     "reasoning_trace", default=None
 )
+
+# The tool calls from the current LLM response.
+tool_calls_var: contextvars.ContextVar[Optional[list]] = contextvars.ContextVar(
+    "tool_calls", default=None
+)

nemoguardrails/rails/llm/llmrails.py

Lines changed: 8 additions & 0 deletions
@@ -33,6 +33,7 @@
 from nemoguardrails.actions.llm.generation import LLMGenerationActions
 from nemoguardrails.actions.llm.utils import (
     get_and_clear_reasoning_trace_contextvar,
+    get_and_clear_tool_calls_contextvar,
     get_colang_history,
 )
 from nemoguardrails.actions.output_mapping import is_output_blocked
@@ -1084,6 +1085,8 @@ async def generate_async(
             options.log.llm_calls = True
             options.log.internal_events = True

+        tool_calls = get_and_clear_tool_calls_contextvar()
+
         # If we have generation options, we prepare a GenerationResponse instance.
         if options:
             # If a prompt was used, we only need to return the content of the message.
@@ -1100,6 +1103,9 @@ async def generate_async(
                         reasoning_trace + res.response[0]["content"]
                     )

+            if tool_calls:
+                res.tool_calls = tool_calls
+
             if self.config.colang_version == "1.0":
                 # If output variables are specified, we extract their values
                 if options.output_vars:
@@ -1228,6 +1234,8 @@ async def generate_async(
         if prompt:
             return new_message["content"]
         else:
+            if tool_calls:
+                new_message["tool_calls"] = tool_calls
             return new_message

     def stream_async(
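
When GenerationOptions are passed, the result is a GenerationResponse and the tool calls are attached to it rather than to a dict message. A hedged sketch of that path, reusing the rails object assumed in the earlier sketch:

from nemoguardrails.rails.llm.options import GenerationOptions

# Inside an async function, with `rails` built as in the earlier sketch
# (an LLM bound to a tool).
options = GenerationOptions(log={"llm_calls": True})
res = await rails.generate_async(
    messages=[{"role": "user", "content": "What's the weather in NYC?"}],
    options=options,
)

# res is a GenerationResponse; tool calls, if any, live on res.tool_calls,
# and res.response holds the assistant message(s).
print(res.tool_calls)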

nemoguardrails/rails/llm/options.py

Lines changed: 4 additions & 0 deletions
@@ -408,6 +408,10 @@ class GenerationResponse(BaseModel):
         default=None,
         description="A state object which can be used in subsequent calls to continue the interaction.",
     )
+    tool_calls: Optional[list] = Field(
+        default=None,
+        description="Tool calls extracted from the LLM response, if any.",
+    )


 if __name__ == "__main__":

tests/rails/llm/test_options.py

Lines changed: 84 additions & 1 deletion
@@ -15,7 +15,11 @@

 import pytest

-from nemoguardrails.rails.llm.options import GenerationOptions, GenerationRailsOptions
+from nemoguardrails.rails.llm.options import (
+    GenerationOptions,
+    GenerationRailsOptions,
+    GenerationResponse,
+)


 def test_generation_options_initialization():
@@ -110,3 +114,82 @@ def test_generation_options_serialization():
     assert '"output":false' in options_json
     assert '"activated_rails":true' in options_json
     assert '"llm_calls":true' in options_json
+
+
+def test_generation_response_initialization():
+    """Test GenerationResponse initialization."""
+    response = GenerationResponse(response="Hello, world!")
+    assert response.response == "Hello, world!"
+    assert response.tool_calls is None
+    assert response.llm_output is None
+    assert response.state is None
+
+
+def test_generation_response_with_tool_calls():
+    test_tool_calls = [
+        {
+            "name": "get_weather",
+            "args": {"location": "NYC"},
+            "id": "call_123",
+            "type": "tool_call",
+        },
+        {
+            "name": "calculate",
+            "args": {"expression": "2+2"},
+            "id": "call_456",
+            "type": "tool_call",
+        },
+    ]
+
+    response = GenerationResponse(
+        response=[{"role": "assistant", "content": "I'll help you with that."}],
+        tool_calls=test_tool_calls,
+    )
+
+    assert response.tool_calls == test_tool_calls
+    assert len(response.tool_calls) == 2
+    assert response.tool_calls[0]["id"] == "call_123"
+    assert response.tool_calls[1]["name"] == "calculate"
+
+
+def test_generation_response_empty_tool_calls():
+    """Test GenerationResponse with empty tool calls list."""
+    response = GenerationResponse(response="No tools needed", tool_calls=[])
+
+    assert response.tool_calls == []
+    assert len(response.tool_calls) == 0
+
+
+def test_generation_response_serialization_with_tool_calls():
+    test_tool_calls = [
+        {"name": "test_func", "args": {}, "id": "call_test", "type": "tool_call"}
+    ]
+
+    response = GenerationResponse(response="Response text", tool_calls=test_tool_calls)
+
+    response_dict = response.dict()
+    assert "tool_calls" in response_dict
+    assert response_dict["tool_calls"] == test_tool_calls
+
+    response_json = response.json()
+    assert "tool_calls" in response_json
+    assert "test_func" in response_json
+
+
+def test_generation_response_model_validation():
+    """Test GenerationResponse model validation."""
+    test_tool_calls = [
+        {"name": "valid_function", "args": {}, "id": "call_123", "type": "tool_call"},
+        {"name": "another_function", "args": {}, "id": "call_456", "type": "tool_call"},
+    ]
+
+    response = GenerationResponse(
+        response=[{"role": "assistant", "content": "Test response"}],
+        tool_calls=test_tool_calls,
+        llm_output={"token_usage": {"total_tokens": 50}},
+    )
+
+    assert response.tool_calls is not None
+    assert isinstance(response.tool_calls, list)
+    assert len(response.tool_calls) == 2
+    assert response.llm_output["token_usage"]["total_tokens"] == 50
