diff --git a/lib/ruby_llm/active_record/acts_as.rb b/lib/ruby_llm/active_record/acts_as.rb index 678a3eea4..248114711 100644 --- a/lib/ruby_llm/active_record/acts_as.rb +++ b/lib/ruby_llm/active_record/acts_as.rb @@ -12,8 +12,9 @@ module ActsAs def acts_as_chat(message_class: 'Message', tool_call_class: 'ToolCall') include ChatMethods + acts_as_chat_with_tool_call(tool_call_class: tool_call_class) if RubyLLM.config.active_record_use_tool_calls + @message_class = message_class.to_s - @tool_call_class = tool_call_class.to_s has_many :messages, -> { order(created_at: :asc) }, @@ -25,22 +26,20 @@ def acts_as_chat(message_class: 'Message', tool_call_class: 'ToolCall') to: :to_llm end - def acts_as_message(chat_class: 'Chat', tool_call_class: 'ToolCall', touch_chat: false) # rubocop:disable Metrics/MethodLength + def acts_as_message( + chat_class: 'Chat', + chat_foreign_key: 'chat_id', + tool_call_class: 'ToolCall', + touch_chat: false + ) include MessageMethods - @chat_class = chat_class.to_s - @tool_call_class = tool_call_class.to_s + if RubyLLM.config.active_record_use_tool_calls + acts_as_message_with_tool_call(tool_call_class: tool_call_class) + end - belongs_to :chat, class_name: @chat_class, touch: touch_chat - has_many :tool_calls, class_name: @tool_call_class, dependent: :destroy - - belongs_to :parent_tool_call, - class_name: @tool_call_class, - foreign_key: 'tool_call_id', - optional: true, - inverse_of: :result - - delegate :tool_call?, :tool_result?, :tool_results, to: :to_llm + @chat_class = chat_class.to_s + belongs_to :chat, class_name: @chat_class, touch: touch_chat, foreign_key: chat_foreign_key end def acts_as_tool_call(message_class: 'Message') @@ -54,6 +53,29 @@ def acts_as_tool_call(message_class: 'Message') inverse_of: :parent_tool_call, dependent: :nullify end + + private + + def acts_as_chat_with_tool_call(tool_call_class:) + include ChatToolCallMethods + + @tool_call_class = tool_call_class.to_s + end + + def acts_as_message_with_tool_call(tool_call_class:) + include MessageToolCallMethods + + @tool_call_class = tool_call_class.to_s + has_many :tool_calls, class_name: @tool_call_class, dependent: :destroy + + belongs_to :parent_tool_call, + class_name: @tool_call_class, + foreign_key: 'tool_call_id', + optional: true, + inverse_of: :result + + delegate :tool_call?, :tool_result?, :tool_results, to: :to_llm + end end end @@ -94,16 +116,6 @@ def with_instructions(instructions, replace: false) self end - def with_tool(tool) - to_llm.with_tool(tool) - self - end - - def with_tools(*tools) - to_llm.with_tools(*tools) - self - end - def with_model(model_id, provider: nil) to_llm.with_model(model_id, provider: provider) self @@ -141,22 +153,48 @@ def persist_new_message ) end - def persist_message_completion(message) # rubocop:disable Metrics/AbcSize,Metrics/MethodLength + def persist_message_completion(message) return unless message - if message.tool_call_id - tool_call_id = self.class.tool_call_class.constantize.find_by(tool_call_id: message.tool_call_id).id - end - transaction do @message.update!( role: message.role, content: message.content, model_id: message.model_id, - tool_call_id: tool_call_id, input_tokens: message.input_tokens, output_tokens: message.output_tokens ) + end + end + end + + # Methods mixed into chat models to handle tool calls. + module ChatToolCallMethods + include ChatMethods + + def with_tool(tool) + to_llm.with_tool(tool) + self + end + + def with_tools(*tools) + to_llm.with_tools(*tools) + self + end + + private + + def persist_message_completion(message) + return super unless message&.tool_call_id || message&.tool_calls + + transaction do + super + + if message.tool_call_id + tool_call_id = self.class.tool_call_class.constantize.find_by(tool_call_id: message.tool_call_id).id + @message.update!(tool_call_id: tool_call_id) + end + persist_tool_calls(message.tool_calls) if message.tool_calls.present? end end @@ -179,14 +217,27 @@ def to_llm RubyLLM::Message.new( role: role.to_sym, content: extract_content, - tool_calls: extract_tool_calls, - tool_call_id: extract_tool_call_id, input_tokens: input_tokens, output_tokens: output_tokens, model_id: model_id ) end + def extract_content + content + end + end + + # Methods mixed into message models to handle tool calls. + module MessageToolCallMethods + def to_llm + RubyLLM::Message.new( + **super.to_h, + tool_calls: extract_tool_calls, + tool_call_id: extract_tool_call_id + ) + end + def extract_tool_calls tool_calls.to_h do |tool_call| [ @@ -203,10 +254,6 @@ def extract_tool_calls def extract_tool_call_id parent_tool_call&.tool_call_id end - - def extract_content - content - end end end end diff --git a/lib/ruby_llm/configuration.rb b/lib/ruby_llm/configuration.rb index f20185a48..f9733dd58 100644 --- a/lib/ruby_llm/configuration.rb +++ b/lib/ruby_llm/configuration.rb @@ -29,7 +29,9 @@ class Configuration :max_retries, :retry_interval, :retry_backoff_factor, - :retry_interval_randomness + :retry_interval_randomness, + # Feature flags + :active_record_use_tool_calls def initialize # Connection configuration @@ -43,6 +45,8 @@ def initialize @default_model = 'gpt-4.1-nano' @default_embedding_model = 'text-embedding-3-small' @default_image_model = 'dall-e-3' + + @active_record_use_tool_calls = true end end end