|
1 | 1 | using BotSharp.Abstraction.Agents; |
2 | 2 | using BotSharp.Abstraction.Agents.Enums; |
3 | | -using BotSharp.Abstraction.Conversations; |
4 | 3 | using BotSharp.Abstraction.Loggers; |
| 4 | +using BotSharp.Abstraction.Functions.Models; |
| 5 | +using BotSharp.Abstraction.Routing; |
5 | 6 | using BotSharp.Plugin.GoogleAI.Settings; |
6 | 7 | using LLMSharp.Google.Palm; |
7 | 8 | using Microsoft.Extensions.Logging; |
| 9 | +using LLMSharp.Google.Palm.DiscussService; |
8 | 10 |
|
9 | 11 | namespace BotSharp.Plugin.GoogleAI.Providers; |
10 | 12 |
|
@@ -34,29 +36,105 @@ public RoleDialogModel GetChatCompletions(Agent agent, List<RoleDialogModel> con |
34 | 36 | hook.BeforeGenerating(agent, conversations)).ToArray()); |
35 | 37 |
|
36 | 38 | var client = new GooglePalmClient(apiKey: _settings.PaLM.ApiKey); |
37 | | - var messages = conversations.Select(c => new PalmChatMessage(c.Content, c.Role == AgentRole.User ? "user" : "AI")) |
38 | | - .ToList(); |
39 | 39 |
|
40 | | - var agentService = _services.GetRequiredService<IAgentService>(); |
41 | | - var instruction = agentService.RenderedInstruction(agent); |
42 | | - var response = client.ChatAsync(messages, instruction, null).Result; |
| 40 | + var (prompt, messages, hasFunctions) = PrepareOptions(agent, conversations); |
43 | 41 |
|
44 | | - var message = response.Candidates.First(); |
45 | | - var msg = new RoleDialogModel(AgentRole.Assistant, message.Content) |
| 42 | + RoleDialogModel msg; |
| 43 | + |
| 44 | + if (hasFunctions) |
| 45 | + { |
| 46 | + // emulate function calling via a chat completion whose context carries the function schemas; the text-completion API below would also work |
| 47 | + // var response = client.GenerateTextAsync(prompt, null).Result; |
| 48 | + var response = client.ChatAsync(new PalmChatCompletionRequest |
| 49 | + { |
| 50 | + Context = prompt, |
| 51 | + Messages = messages, |
| 52 | + Temperature = 0.1f |
| 53 | + }).Result; |
| 54 | + |
| 55 | + var message = response.Candidates.First(); |
| 56 | + |
| 57 | + // check whether the model replied with a function call |
| 58 | + var llmResponse = message.Content.JsonContent<FunctionCallingResponse>(); |
| 59 | + |
| 60 | + msg = new RoleDialogModel(llmResponse.Role, llmResponse.Content) |
| 61 | + { |
| 62 | + CurrentAgentId = agent.Id, |
| 63 | + FunctionName = llmResponse.FunctionName, |
| 64 | + FunctionArgs = JsonSerializer.Serialize(llmResponse.Args) |
| 65 | + }; |
| 66 | + } |
| 67 | + else |
46 | 68 | { |
47 | | - CurrentAgentId = agent.Id |
48 | | - }; |
| 69 | + var response = client.ChatAsync(messages, context: prompt, examples: null, options: null).Result; |
| 70 | + |
| 71 | + var message = response.Candidates.First(); |
| 72 | + |
| 73 | + // the model may still reply in the function-calling JSON shape; fall back to raw content otherwise |
| 74 | + var llmResponse = message.Content.JsonContent&lt;FunctionCallingResponse&gt;(); |
| 75 | + |
| 76 | + msg = new RoleDialogModel(llmResponse?.Role ?? AgentRole.Assistant, llmResponse?.Content ?? message.Content) |
| 77 | + { |
| 78 | + CurrentAgentId = agent.Id |
| 79 | + }; |
| 80 | + } |
49 | 81 |
|
50 | 82 | // After chat completion hook |
51 | 83 | Task.WaitAll(hooks.Select(hook => |
52 | 84 | hook.AfterGenerated(msg, new TokenStatsModel |
53 | 85 | { |
| 86 | + Prompt = prompt, |
54 | 87 | Model = _model |
55 | 88 | })).ToArray()); |
56 | 89 |
|
57 | 90 | return msg; |
58 | 91 | } |
59 | 92 |
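For context, `GetChatCompletions` reads four properties off `FunctionCallingResponse` (imported from `BotSharp.Abstraction.Functions.Models`). A minimal sketch consistent with that usage — the `JsonPropertyName` spellings are assumptions, since the diff only shows the C# property names:

```csharp
using System.Text.Json.Serialization;

// Sketch of the model deserialized from the LLM reply; the real type ships in
// BotSharp.Abstraction.Functions.Models. Property names match the usage in
// GetChatCompletions; the JSON field spellings are assumptions.
public class FunctionCallingResponse
{
    [JsonPropertyName("role")]
    public string Role { get; set; }

    [JsonPropertyName("content")]
    public string? Content { get; set; }

    [JsonPropertyName("function_name")]
    public string? FunctionName { get; set; }

    [JsonPropertyName("args")]
    public object? Args { get; set; }
}
```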
|
| 93 | + private (string, List<PalmChatMessage>, bool) PrepareOptions(Agent agent, List<RoleDialogModel> conversations) |
| 94 | + { |
| 95 | + var prompt = ""; |
| 96 | + |
| 97 | + var agentService = _services.GetRequiredService<IAgentService>(); |
| 98 | + |
| 99 | + if (!string.IsNullOrEmpty(agent.Instruction)) |
| 100 | + { |
| 101 | + prompt += agentService.RenderedInstruction(agent); |
| 102 | + } |
| 103 | + |
| 104 | + var routing = _services.GetRequiredService<IRoutingService>(); |
| 105 | + var router = routing.Router; |
| 106 | + |
| 107 | + var messages = conversations.Select(c => new PalmChatMessage(c.Content, c.Role == AgentRole.User ? "user" : "AI")) |
| 108 | + .ToList(); |
| 109 | + |
| 110 | + if (agent.Functions != null && agent.Functions.Count > 0) |
| 111 | + { |
| 112 | + prompt += "\r\n\r\n[Functions] defined in JSON Schema:\r\n"; |
| 113 | + prompt += JsonSerializer.Serialize(agent.Functions, new JsonSerializerOptions |
| 114 | + { |
| 115 | + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, |
| 116 | + WriteIndented = true |
| 117 | + }); |
| 118 | + |
| 119 | + prompt += "\r\n\r\n[Conversations]\r\n"; |
| 120 | + foreach (var dialog in conversations) |
| 121 | + { |
| 122 | + prompt += dialog.Role == AgentRole.Function ? |
| 123 | + $"{dialog.Role}: {dialog.FunctionName} => {dialog.Content}\r\n" : |
| 124 | + $"{dialog.Role}: {dialog.Content}\r\n"; |
| 125 | + } |
| 126 | + |
| 127 | + prompt += "\r\n\r\n" + router.Templates.FirstOrDefault(x => x.Name == "response_with_function").Content; |
| 128 | + |
| 129 | + return (prompt, new List<PalmChatMessage> |
| 130 | + { |
| 131 | + new PalmChatMessage("Which function should be used for the next step based on the latest user or function response? Output your response in JSON:", AgentRole.User), |
| 132 | + }, true); |
| 133 | + } |
| 134 | + |
| 135 | + return (prompt, messages, false); |
| 136 | + } |
| 137 | + |
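To make the parsing contract concrete, here is a hypothetical reply the function-calling branch would accept. The authoritative output format comes from the router's `response_with_function` template, which this diff does not include, so the field spellings below are assumptions; `get_weather` and its args are invented for illustration:

```csharp
// Hypothetical model reply that the function-calling branch would deserialize.
var sampleReply = """
{
    "role": "function",
    "function_name": "get_weather",
    "args": { "city": "Tokyo" },
    "content": ""
}
""";

// Same extension method the provider uses on message.Content above.
var parsed = sampleReply.JsonContent<FunctionCallingResponse>();
```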
60 | 138 | public Task<bool> GetChatCompletionsAsync(Agent agent, List<RoleDialogModel> conversations, Func<RoleDialogModel, Task> onMessageReceived, Func<RoleDialogModel, Task> onFunctionExecuting) |
61 | 139 | { |
62 | 140 | throw new NotImplementedException(); |
|