@@ -393,11 +393,21 @@ def _format_passthrough_output(self, result: Any, context: Dict[str, Any]) -> An
393393 return passthrough_output
394394
395395 def _format_chat_prompt_output (
396- self , result : Any , tool_calls : Optional [list ] = None
396+ self ,
397+ result : Any ,
398+ tool_calls : Optional [list ] = None ,
399+ metadata : Optional [dict ] = None ,
397400 ) -> AIMessage :
398401 """Format output for ChatPromptValue input."""
399402 content = self ._extract_content_from_result (result )
400- if tool_calls :
403+
404+ if metadata and isinstance (metadata , dict ):
405+ metadata_copy = metadata .copy ()
406+ metadata_copy .pop ("content" , None )
407+ if tool_calls :
408+ metadata_copy ["tool_calls" ] = tool_calls
409+ return AIMessage (content = content , ** metadata_copy )
410+ elif tool_calls :
401411 return AIMessage (content = content , tool_calls = tool_calls )
402412 return AIMessage (content = content )
403413
@@ -406,11 +416,21 @@ def _format_string_prompt_output(self, result: Any) -> str:
406416 return self ._extract_content_from_result (result )
407417
408418 def _format_message_output (
409- self , result : Any , tool_calls : Optional [list ] = None
419+ self ,
420+ result : Any ,
421+ tool_calls : Optional [list ] = None ,
422+ metadata : Optional [dict ] = None ,
410423 ) -> AIMessage :
411424 """Format output for BaseMessage input types."""
412425 content = self ._extract_content_from_result (result )
413- if tool_calls :
426+
427+ if metadata and isinstance (metadata , dict ):
428+ metadata_copy = metadata .copy ()
429+ metadata_copy .pop ("content" , None )
430+ if tool_calls :
431+ metadata_copy ["tool_calls" ] = tool_calls
432+ return AIMessage (content = content , ** metadata_copy )
433+ elif tool_calls :
414434 return AIMessage (content = content , tool_calls = tool_calls )
415435 return AIMessage (content = content )
416436
@@ -434,25 +454,50 @@ def _format_dict_output_for_dict_message_list(
434454 }
435455
436456 def _format_dict_output_for_base_message_list (
437- self , result : Any , output_key : str , tool_calls : Optional [list ] = None
457+ self ,
458+ result : Any ,
459+ output_key : str ,
460+ tool_calls : Optional [list ] = None ,
461+ metadata : Optional [dict ] = None ,
438462 ) -> Dict [str , Any ]:
439463 """Format dict output when user input was a list of BaseMessage objects."""
440464 content = self ._extract_content_from_result (result )
441- if tool_calls :
465+
466+ if metadata and isinstance (metadata , dict ):
467+ metadata_copy = metadata .copy ()
468+ metadata_copy .pop ("content" , None )
469+ if tool_calls :
470+ metadata_copy ["tool_calls" ] = tool_calls
471+ return {output_key : AIMessage (content = content , ** metadata_copy )}
472+ elif tool_calls :
442473 return {output_key : AIMessage (content = content , tool_calls = tool_calls )}
443474 return {output_key : AIMessage (content = content )}
444475
445476 def _format_dict_output_for_base_message (
446- self , result : Any , output_key : str , tool_calls : Optional [list ] = None
477+ self ,
478+ result : Any ,
479+ output_key : str ,
480+ tool_calls : Optional [list ] = None ,
481+ metadata : Optional [dict ] = None ,
447482 ) -> Dict [str , Any ]:
448483 """Format dict output when user input was a BaseMessage."""
449484 content = self ._extract_content_from_result (result )
450- if tool_calls :
485+
486+ if metadata :
487+ metadata_copy = metadata .copy ()
488+ if tool_calls :
489+ metadata_copy ["tool_calls" ] = tool_calls
490+ return {output_key : AIMessage (content = content , ** metadata_copy )}
491+ elif tool_calls :
451492 return {output_key : AIMessage (content = content , tool_calls = tool_calls )}
452493 return {output_key : AIMessage (content = content )}
453494
454495 def _format_dict_output (
455- self , input_dict : dict , result : Any , tool_calls : Optional [list ] = None
496+ self ,
497+ input_dict : dict ,
498+ result : Any ,
499+ tool_calls : Optional [list ] = None ,
500+ metadata : Optional [dict ] = None ,
456501 ) -> Dict [str , Any ]:
457502 """Format output for dictionary input."""
458503 output_key = self .passthrough_bot_output_key
@@ -471,13 +516,13 @@ def _format_dict_output(
471516 )
472517 elif all (isinstance (msg , BaseMessage ) for msg in user_input ):
473518 return self ._format_dict_output_for_base_message_list (
474- result , output_key , tool_calls
519+ result , output_key , tool_calls , metadata
475520 )
476521 else :
477522 return {output_key : result }
478523 elif isinstance (user_input , BaseMessage ):
479524 return self ._format_dict_output_for_base_message (
480- result , output_key , tool_calls
525+ result , output_key , tool_calls , metadata
481526 )
482527
483528 # Generic fallback for dictionaries
@@ -490,6 +535,7 @@ def _format_output(
490535 result : Any ,
491536 context : Dict [str , Any ],
492537 tool_calls : Optional [list ] = None ,
538+ metadata : Optional [dict ] = None ,
493539 ) -> Any :
494540 """Format the output based on the input type and rails result.
495541
@@ -512,17 +558,17 @@ def _format_output(
512558 return self ._format_passthrough_output (result , context )
513559
514560 if isinstance (input , ChatPromptValue ):
515- return self ._format_chat_prompt_output (result , tool_calls )
561+ return self ._format_chat_prompt_output (result , tool_calls , metadata )
516562 elif isinstance (input , StringPromptValue ):
517563 return self ._format_string_prompt_output (result )
518564 elif isinstance (input , (HumanMessage , AIMessage , BaseMessage )):
519- return self ._format_message_output (result , tool_calls )
565+ return self ._format_message_output (result , tool_calls , metadata )
520566 elif isinstance (input , list ) and all (
521567 isinstance (msg , BaseMessage ) for msg in input
522568 ):
523- return self ._format_message_output (result , tool_calls )
569+ return self ._format_message_output (result , tool_calls , metadata )
524570 elif isinstance (input , dict ):
525- return self ._format_dict_output (input , result , tool_calls )
571+ return self ._format_dict_output (input , result , tool_calls , metadata )
526572 elif isinstance (input , str ):
527573 return self ._format_string_prompt_output (result )
528574 else :
@@ -669,7 +715,9 @@ def _full_rails_invoke(
669715 result = result [0 ]
670716
671717 # Format and return the output based in input type
672- return self ._format_output (input , result , context , res .tool_calls )
718+ return self ._format_output (
719+ input , result , context , res .tool_calls , res .llm_metadata
720+ )
673721
674722 async def ainvoke (
675723 self ,
@@ -731,7 +779,9 @@ async def _full_rails_ainvoke(
731779 result = res .response
732780
733781 # Format and return the output based on input type
734- return self ._format_output (input , result , context , res .tool_calls )
782+ return self ._format_output (
783+ input , result , context , res .tool_calls , res .llm_metadata
784+ )
735785
736786 def stream (
737787 self ,
0 commit comments