diff --git a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Controllers/RoutingSpeederController.cs b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Controllers/RoutingSpeederController.cs index 67cb0c401..e55d95a34 100644 --- a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Controllers/RoutingSpeederController.cs +++ b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Controllers/RoutingSpeederController.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.Text; using System.Threading.Tasks; +using BotSharp.Abstraction.Conversations.Models; using BotSharp.Plugin.RoutingSpeeder.Providers; using BotSharp.Plugin.RoutingSpeeder.Providers.Models; using Microsoft.AspNetCore.Authorization; @@ -24,9 +25,16 @@ public RoutingSpeederController(IServiceProvider service) public IActionResult TrainIntentClassifier(TrainingParams trainingParams) { var intentClassifier = _service.GetRequiredService(); - intentClassifier.InitClassifer(trainingParams.Inference); intentClassifier.Train(trainingParams); return Ok(intentClassifier.Labels); } + [HttpPost("/routing-speeder/classifier/inference")] + public IActionResult InferenceIntentClassifier([FromBody] DialoguePredictionModel message) + { + var intentClassifier = _service.GetRequiredService(); + var vector = intentClassifier.GetTextEmbedding(message.Text); + var predText = intentClassifier.Predict(vector); + return Ok(predText); + } } diff --git a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/IntentClassifier.cs b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/IntentClassifier.cs index 08acb6736..977e43bb2 100644 --- a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/IntentClassifier.cs +++ b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/IntentClassifier.cs @@ -15,33 +15,33 @@ using System.Linq; using Tensorflow.Keras; using BotSharp.Abstraction.Agents; +using Microsoft.Extensions.Logging; namespace BotSharp.Plugin.RoutingSpeeder.Providers; public class IntentClassifier { private readonly IServiceProvider _services; + private readonly ILogger _logger; private KnowledgeBaseSettings _knowledgeBaseSettings; Model _model; public Model model => _model; private bool _isModelReady; public bool isModelReady => _isModelReady; private ClassifierSetting _settings; + private bool _inferenceMode = true; private string[] _labels; - public string[] Labels => GetLabels(); - private int _numLabels - { - get - { - return Labels.Length; - } - } + public string[] Labels => _labels == null ? GetLabels() : _labels; - public IntentClassifier(IServiceProvider services, ClassifierSetting settings, KnowledgeBaseSettings knowledgeBaseSettings) + public IntentClassifier(IServiceProvider services, + ClassifierSetting settings, + KnowledgeBaseSettings knowledgeBaseSettings, + ILogger logger) { _services = services; _settings = settings; _knowledgeBaseSettings = knowledgeBaseSettings; + _logger = logger; } private void Reset() @@ -65,7 +65,7 @@ private void Build() keras.layers.InputLayer((vector.Dimension), name: "Input"), keras.layers.Dense(256, activation:"relu"), keras.layers.Dense(256, activation:"relu"), - keras.layers.Dense(_numLabels, activation: keras.activations.Softmax) + keras.layers.Dense(GetLabels().Length, activation: keras.activations.Softmax) }; _model = keras.Sequential(layers); @@ -73,7 +73,6 @@ private void Build() Console.WriteLine(); _model.summary(); #endif - _isModelReady = true; } private void Fit(NDArray x, NDArray y, TrainingParams trainingParams) @@ -97,7 +96,7 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams) earlyStop }; - var weights = LoadWeights(trainingParams.Inference); + var weights = LoadWeights(); _model.fit(x, y, batch_size: trainingParams.BatchSize, @@ -110,15 +109,15 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams) _isModelReady = true; } - public string LoadWeights(bool inference = true) + public string LoadWeights() { var agentService = _services.CreateScope() .ServiceProvider .GetRequiredService(); - var weightsFile = Path.Combine(agentService.GetDataDir(), _settings.MODEL_DIR, $"intent-classifier.h5"); + var weightsFile = Path.Combine(agentService.GetDataDir(), _settings.MODEL_DIR, _settings.WEIGHT_FILE_NAME); - if (File.Exists(weightsFile) && inference) + if (File.Exists(weightsFile) && _inferenceMode) { _model.load_weights(weightsFile); _isModelReady = true; @@ -126,8 +125,9 @@ public string LoadWeights(bool inference = true) } else { - var logInfo = inference ? "No available weights." : "Will implement model training process and write trained weights into local"; - Console.WriteLine(logInfo); + var logInfo = _inferenceMode ? "No available weights." : "Will implement model training process and write trained weights into local"; + _isModelReady = false; + _logger.LogInformation(logInfo); } return weightsFile; @@ -159,7 +159,14 @@ public NDArray GetTextEmbedding(string text) if (!Directory.Exists(rootDirectory)) { - throw new Exception($"No training data found! Please put training data in this path: {rootDirectory}"); + Directory.CreateDirectory(rootDirectory); + } + + int numFiles = Directory.GetFiles(rootDirectory).Length; + + if (numFiles == 0) + { + throw new Exception($"No dialogue data found in {rootDirectory} folder! Please put dialogue data in this path: {rootDirectory}"); } // Do embedding and store results @@ -214,25 +221,32 @@ public string[] GetFiles(string prefix = "") public string[] GetLabels() { - if (_labels == null) + var agentService = _services.CreateScope() + .ServiceProvider + .GetRequiredService(); + string labelPath = Path.Combine( + agentService.GetDataDir(), + _settings.MODEL_DIR, + _settings.LABEL_FILE_NAME); + + if (_inferenceMode) { - var agentService = _services.CreateScope() - .ServiceProvider - .GetRequiredService(); - - string[] labels = GetFiles() + if (_labels == null) + { + if (!File.Exists(labelPath)) + { + throw new Exception($"Label file doesn't exist. Please training model first or move label.txt to {labelPath}"); + } + _labels = File.ReadAllLines(labelPath); + } + } + else + { + _labels = GetFiles() .Select(x => Path.GetFileName(x).Split(".")[^2]) + .OrderBy(x => x) .ToArray(); - - string writePath = Path.Combine( - agentService.GetDataDir(), - _settings.MODEL_DIR, - _settings.LABEL_FILE_NAME); - - _labels = labels.OrderBy(x => x).ToArray(); - - // Write labels into the local txt file - File.WriteAllLines(writePath, _labels); + File.WriteAllLines(labelPath, _labels); } return _labels; } @@ -248,24 +262,25 @@ public string Predict(NDArray vector, float confidenceScore = 0.9f) var prob = _model.predict(vector).numpy(); var probLabel = tf.arg_max(prob, -1).numpy().ToArray(); prob = np.squeeze(prob, axis: 0); + var labelIndex = probLabel[0]; if (prob[probLabel[0]] < confidenceScore) { return string.Empty; } - var labelIndex = probLabel[0]; return _labels[labelIndex]; } - public void InitClassifer(bool inference = true) + public void InitClassifer() { Reset(); Build(); - LoadWeights(inference); + LoadWeights(); } public void Train(TrainingParams trainingParams) { + _inferenceMode = false; Reset(); (var x, var y) = PrepareLoadData(); Build(); diff --git a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/Models/DialoguePredictionModel.cs b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/Models/DialoguePredictionModel.cs index 4641b9cdb..36b7c38cd 100644 --- a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/Models/DialoguePredictionModel.cs +++ b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/Models/DialoguePredictionModel.cs @@ -7,7 +7,7 @@ namespace BotSharp.Plugin.RoutingSpeeder.Providers.Models; public class DialoguePredictionModel { public int Id { get; set; } - public string text { get; set; } - public string? label { get; set; } - public string? prediction { get; set; } + public string Text { get; set; } + public string? Label { get; set; } + public string? Prediction { get; set; } } diff --git a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Settings/classifierSetting.cs b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Settings/classifierSetting.cs index 09bda6e86..5e8c34c8b 100644 --- a/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Settings/classifierSetting.cs +++ b/src/Plugins/BotSharp.Plugin.RoutingSpeeder/Settings/classifierSetting.cs @@ -7,13 +7,14 @@ namespace BotSharp.Plugin.RoutingSpeeder.Settings; public class ClassifierSetting { public Dictionary LabelMappingDict { get; set; } = new Dictionary() - { - {"goodbye", 0f}, - {"greeting", 1f}, - {"other", 2f} - }; + { + {"goodbye", 0f}, + {"greeting", 1f}, + {"other", 2f} + }; public string RAW_DATA_DIR { get; set; } = "raw_data"; public string MODEL_DIR { get; set; } = "models"; public string LABEL_FILE_NAME { get; set; } = "label.txt"; + public string WEIGHT_FILE_NAME { get; set; } = "intent-classifier.h5"; }