Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ public interface IPlanningHook
Task<string> GetSummaryAdditionalRequirements(string planner, RoleDialogModel message)
=> Task.FromResult(string.Empty);

Task OnSourceCodeGenerated(string planner, RoleDialogModel msg, string language)
=> Task.CompletedTask;

Task OnPlanningCompleted(string planner, RoleDialogModel msg)
=> Task.CompletedTask;
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class OneStepForwardReasoner : IRoutingReasoner
private readonly IServiceProvider _services;
private readonly ILogger _logger;

public OneStepForwardReasoner(IServiceProvider services, ILogger<NaiveReasoner> logger)
public OneStepForwardReasoner(IServiceProvider services, ILogger<OneStepForwardReasoner> logger)
{
_services = services;
_logger = logger;
Expand Down Expand Up @@ -116,7 +116,7 @@ public async Task<bool> AgentExecuted(Agent router, FunctionCallFromLlm inst, Ro
}
else
{
context.Empty(reason: $"Agent queue is cleared by {nameof(NaiveReasoner)}");
context.Empty(reason: $"Agent queue is cleared by {nameof(OneStepForwardReasoner)}");
// context.Push(inst.OriginalAgent, "Push user goal agent");
}
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class SequentialReasoner : IRoutingReasoner
public int MaxLoopCount => 100;
private FunctionCallFromLlm _lastInst;

public SequentialReasoner(IServiceProvider services, ILogger<NaiveReasoner> logger)
public SequentialReasoner(IServiceProvider services, ILogger<SequentialReasoner> logger)
{
_services = services;
_logger = logger;
Expand Down
10 changes: 6 additions & 4 deletions src/Plugins/BotSharp.Plugin.Planner/Functions/SummaryPlanFn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,15 @@ public async Task<bool> Execute(RoleDialogModel message)
var summary = await GetAiResponse(plannerAgent);
message.Content = summary.Content;

// Validate the sql result
// Emit event if the sql statement is generated by planner
var args = JsonSerializer.Deserialize<SummaryPlan>(message.FunctionArgs);
if (args.IsSqlTemplate == false)
if (args != null && !args.IsSqlTemplate && args.ContainsSqlStatements)
{
await fn.InvokeFunction("validate_sql", message);
await HookEmitter.Emit<IPlanningHook>(_services, async hook =>
await hook.OnSourceCodeGenerated(nameof(TwoStageTaskPlanner), message, "sql")
);
}

await HookEmitter.Emit<IPlanningHook>(_services, async hook =>
await hook.OnPlanningCompleted(nameof(TwoStageTaskPlanner), message)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@ public class SummaryPlan
{
[JsonPropertyName("is_sql_template")]
public bool IsSqlTemplate { get; set; } = false;

[JsonPropertyName("contains_sql_statements")]
public bool ContainsSqlStatements { get; set; } = false;
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"id": "282a7128-69a1-44b0-878c-a9159b88f3b9",
"name": "Planner",
"description": "Plan feasible implementation steps for user task request",
"description": "Plan feasible implementation steps for complex user task request",
"type": "task",
"createdDateTime": "2023-08-27T10:39:00Z",
"updatedDateTime": "2023-08-27T14:39:00Z",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
"type": "boolean",
"description": "If user request is to generate sql template instead of actual sql statement."
},
"contains_sql_statements": {
"type": "boolean",
"description": "Set to true if the response contains sql statements."
},
"related_tables": {
"type": "array",
"description": "table name in planning steps",
Expand All @@ -17,6 +21,6 @@
}
}
},
"required": [ "related_tables", "is_sql_template" ]
"required": [ "related_tables", "is_sql_template", "contains_sql_statements" ]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,23 @@ public SqlDriverPlanningHook(IServiceProvider services)
_services = services;
}

public async Task OnPlanningCompleted(string planner, RoleDialogModel msg)
public async Task OnSourceCodeGenerated(string planner, RoleDialogModel msg, string language)
{
// envoke validate
if (language != "sql")
{
return;
}

var routing = _services.GetRequiredService<IRoutingService>();
await routing.InvokeFunction("validate_sql", msg);

await HookEmitter.Emit<ISqlDriverHook>(_services, async (hook) =>
{
await hook.SqlGenerated(msg);
});

var settings = _services.GetRequiredService<SqlDriverSetting>();
var settings = _services.GetRequiredService<SqlDriverSetting>();
if (!settings.ExecuteSqlSelectAutonomous)
{
var conversationStateService = _services.GetRequiredService<IConversationStateService>();
Expand All @@ -51,7 +60,6 @@ await HookEmitter.Emit<ISqlDriverHook>(_services, async (hook) =>
var response = await completion.GetChatCompletions(agent, wholeDialogs);

// Invoke "execute_sql"
var routing = _services.GetRequiredService<IRoutingService>();
await routing.InvokeFunction(response.FunctionName, response);

msg.CurrentAgentId = agent.Id;
Expand All @@ -61,6 +69,11 @@ await HookEmitter.Emit<ISqlDriverHook>(_services, async (hook) =>
msg.StopCompletion = response.StopCompletion;
}

public async Task OnPlanningCompleted(string planner, RoleDialogModel msg)
{

}

public async Task<string> GetSummaryAdditionalRequirements(string planner, RoleDialogModel message)
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
Expand Down