Skip to content

Commit 679d3a8

Browse files
authored
[REFACTOR] Refactor JSONFFI Conv template (#2331)
This PR refactors JSONFFI conv template to use immutable processing. This helps to prevent bugs from multiple requests and concurrent access to the conversation data structure. It also reduces the need to deep copy the struct.
1 parent 45a0487 commit 679d3a8

File tree

9 files changed

+348
-254
lines changed

9 files changed

+348
-254
lines changed

cpp/json_ffi/conv_template.cc

Lines changed: 190 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ ModelConfig ModelConfig::FromJSON(const picojson::object& json_obj) {
131131

132132
/****************** Conversation template ******************/
133133

134-
std::map<MessagePlaceholders, std::string> PLACEHOLDERS = {
134+
std::unordered_map<MessagePlaceholders, std::string> PLACEHOLDERS = {
135135
{MessagePlaceholders::SYSTEM, "{system_message}"},
136136
{MessagePlaceholders::USER, "{user_message}"},
137137
{MessagePlaceholders::ASSISTANT, "{assistant_message}"},
@@ -153,120 +153,213 @@ Conversation::Conversation()
153153
{"assistant", PLACEHOLDERS[MessagePlaceholders::ASSISTANT]},
154154
{"tool", PLACEHOLDERS[MessagePlaceholders::TOOL]}}) {}
155155

156-
Result<std::vector<Data>> Conversation::AsPrompt(ModelConfig config, DLDevice device) {
157-
using TResult = Result<std::vector<Data>>;
158-
// Get the system message
159-
std::string system_msg = system_template;
160-
size_t pos = system_msg.find(PLACEHOLDERS[MessagePlaceholders::SYSTEM]);
156+
std::string Conversation::GetSystemText(const std::string& system_msg) const {
157+
std::string system_text = this->system_template;
158+
static std::string system_placeholder = PLACEHOLDERS[MessagePlaceholders::SYSTEM];
159+
size_t pos = system_text.find(system_placeholder);
161160
if (pos != std::string::npos) {
162-
system_msg.replace(pos, PLACEHOLDERS[MessagePlaceholders::SYSTEM].length(),
163-
this->system_message);
161+
system_text.replace(pos, system_placeholder.length(), system_msg);
164162
}
163+
return system_text;
164+
}
165165

166-
// Get the message strings
167-
std::vector<Data> message_list;
168-
std::vector<std::string> separators = seps;
169-
if (separators.size() == 1) {
170-
separators.push_back(separators[0]);
166+
std::string Conversation::GetRoleText(const std::string& role, const std::string& content,
167+
const std::optional<std::string>& fn_call_string) const {
168+
std::string role_text = this->role_templates.at(role);
169+
std::string placeholder = PLACEHOLDERS[MessagePlaceholderFromString(role)];
170+
size_t pos = role_text.find(placeholder);
171+
if (pos != std::string::npos) {
172+
role_text.replace(pos, placeholder.length(), content);
173+
}
174+
if (fn_call_string) {
175+
// replace placeholder[FUNCTION] with function_string
176+
// this assumes function calling is used for a single request scenario only
177+
pos = role_text.find(PLACEHOLDERS[MessagePlaceholders::FUNCTION]);
178+
if (pos != std::string::npos) {
179+
role_text.replace(pos, PLACEHOLDERS[MessagePlaceholders::FUNCTION].length(),
180+
fn_call_string.value());
181+
}
171182
}
183+
return role_text;
184+
}
172185

173-
if (!system_msg.empty()) {
174-
system_msg += separators[0];
175-
message_list.push_back(TextData(system_message));
186+
/// Try to detect whether function calling is needed; if so, return the function calling string.
187+
Result<std::optional<std::string>> TryGetFunctionCallingString(
188+
const Conversation& conv, const ChatCompletionRequest& request) {
189+
using TResult = Result<std::optional<std::string>>;
190+
if (!request.tools.has_value() ||
191+
(request.tool_choice.has_value() && request.tool_choice.value() == "none")) {
192+
return TResult::Ok(std::nullopt);
193+
}
194+
std::vector<ChatTool> tools_ = request.tools.value();
195+
std::string tool_choice_ = request.tool_choice.value();
196+
197+
// TODO: support tool_choice specified as a dict
198+
for (const auto& tool : tools_) {
199+
if (tool.function.name == tool_choice_) {
200+
picojson::value function_str(tool.function.AsJSON());
201+
return TResult::Ok(function_str.serialize());
202+
}
176203
}
177204

178-
for (int i = 0; i < messages.size(); i++) {
179-
std::string role = messages[i].role;
180-
// Todo(mlc-team): support content to be a single string.
181-
std::optional<std::vector<std::unordered_map<std::string, std::string>>> content =
182-
messages[i].content;
183-
if (roles.find(role) == roles.end()) {
184-
return TResult::Error("Role \"" + role + "\" is not supported");
185-
}
205+
if (tool_choice_ != "auto") {
206+
return TResult::Error("Invalid tool_choice value in the request: " + tool_choice_);
207+
}
208+
209+
picojson::array function_list;
210+
for (const auto& tool : tools_) {
211+
function_list.push_back(picojson::value(tool.function.AsJSON()));
212+
}
186213

187-
std::string separator = separators[role == "assistant"]; // check assistant role
214+
picojson::value function_list_json(function_list);
215+
return TResult::Ok(function_list_json.serialize());
216+
};
188217

189-
// If content is empty, add the role and separator
190-
// assistant's turn to generate text
191-
if (!content.has_value()) {
192-
message_list.push_back(TextData(roles[role] + role_empty_sep));
193-
continue;
194-
}
218+
Result<std::vector<Data>> CreatePrompt(const Conversation& conv,
219+
const ChatCompletionRequest& request,
220+
const ModelConfig& config, DLDevice device) {
221+
using TResult = Result<std::vector<Data>>;
222+
223+
Result<std::optional<std::string>> fn_call_str_tmp = TryGetFunctionCallingString(conv, request);
224+
if (fn_call_str_tmp.IsErr()) {
225+
return TResult::Error(fn_call_str_tmp.UnwrapErr());
226+
}
227+
std::optional<std::string> fn_call_string = fn_call_str_tmp.Unwrap();
228+
229+
// Handle system message
230+
// Concatenate any custom system messages supplied by the template or the request.
231+
bool has_custom_system = false;
232+
std::string custom_system_inputs;
195233

196-
std::string message = "";
197-
std::string role_prefix = "";
198-
// Do not append role prefix if this is the first message and there
199-
// is already a system message
200-
if (add_role_after_system_message || system_msg.empty() || i != 0) {
201-
role_prefix = roles[role] + role_content_sep;
234+
auto f_populate_system_message = [&](const std::vector<ChatCompletionMessage>& msg_vec) {
235+
for (ChatCompletionMessage msg : msg_vec) {
236+
if (msg.role == "system") {
237+
ICHECK(msg.content.IsText()) << "System message must be text";
238+
custom_system_inputs += msg.content.Text();
239+
has_custom_system = true;
240+
}
202241
}
242+
};
243+
// go through messages in template and passed in.
244+
f_populate_system_message(conv.messages);
245+
f_populate_system_message(request.messages);
203246

204-
message += role_prefix;
247+
// pending text records the text to be put into data
248+
// we lazily accumulate the pending text
249+
// to reduce amount of segments in the Data vector
250+
std::string pending_text =
251+
conv.GetSystemText(has_custom_system ? custom_system_inputs : conv.system_message);
205252

206-
for (const auto& item : content.value()) {
207-
auto it_type = item.find("type");
208-
if (it_type == item.end()) {
209-
return TResult::Error("The content of a message does not have \"type\" field");
253+
// the separator after the system message.
254+
if (!pending_text.empty()) {
255+
pending_text += conv.seps[0];
256+
}
257+
258+
// Get the message strings
259+
std::vector<Data> message_list;
260+
size_t non_system_msg_count = 0;
261+
262+
// returns error if error happens
263+
auto f_process_messages =
264+
[&](const std::vector<ChatCompletionMessage>& msg_vec) -> std::optional<TResult> {
265+
for (size_t i = 0; i < msg_vec.size(); ++i) {
266+
const ChatCompletionMessage& msg = msg_vec[i];
267+
auto role_it = conv.roles.find(msg.role);
268+
if (role_it == conv.roles.end()) {
269+
return TResult::Error("Role \"" + msg.role + "\" is not supported");
210270
}
211-
if (it_type->second == "text") {
212-
auto it_text = item.find("text");
213-
if (it_text == item.end()) {
214-
return TResult::Error("The text type content of a message does not have \"text\" field");
215-
}
216-
// replace placeholder[ROLE] with input message from role
217-
std::string role_text = role_templates[role];
218-
std::string placeholder = PLACEHOLDERS[MessagePlaceholderFromString(role)];
219-
size_t pos = role_text.find(placeholder);
220-
if (pos != std::string::npos) {
221-
role_text.replace(pos, placeholder.length(), it_text->second);
222-
}
223-
if (use_function_calling) {
224-
// replace placeholder[FUNCTION] with function_string
225-
// this assumes function calling is used for a single request scenario only
226-
if (!function_string.has_value()) {
227-
return TResult::Error(
228-
"The function string in conversation template is not defined for function "
229-
"calling.");
271+
const std::string& role_name = role_it->second;
272+
// skip system message as it is already processed
273+
if (msg.role == "system") continue;
274+
// skip when content is empty
275+
if (msg.content.IsNull()) {
276+
pending_text += role_name + conv.role_empty_sep;
277+
continue;
278+
}
279+
++non_system_msg_count;
280+
// assistant uses conv.seps[1] if there are two seps
281+
int sep_offset = msg.role == "assistant" ? 1 : 0;
282+
const std::string& seperator = conv.seps[sep_offset % conv.seps.size()];
283+
// setup role prefix
284+
std::string role_prefix = "";
285+
// Do not append role prefix if this is the first message and there is already a system
286+
// message
287+
if (conv.add_role_after_system_message || pending_text.empty() || non_system_msg_count != 1) {
288+
role_prefix = role_name + conv.role_content_sep;
289+
}
290+
pending_text += role_prefix;
291+
292+
if (msg.content.IsParts()) {
293+
for (const auto& item : msg.content.Parts()) {
294+
auto it_type = item.find("type");
295+
if (it_type == item.end()) {
296+
return TResult::Error("The content of a message does not have \"type\" field");
230297
}
231-
pos = role_text.find(PLACEHOLDERS[MessagePlaceholders::FUNCTION]);
232-
if (pos != std::string::npos) {
233-
role_text.replace(pos, PLACEHOLDERS[MessagePlaceholders::FUNCTION].length(),
234-
function_string.value());
298+
if (it_type->second == "text") {
299+
auto it_text = item.find("text");
300+
if (it_text == item.end()) {
301+
return TResult::Error(
302+
"The text type content of a message does not have \"text\" field");
303+
}
304+
// replace placeholder[ROLE] with input message from role
305+
pending_text += conv.GetRoleText(msg.role, it_text->second, fn_call_string);
306+
} else if (it_type->second == "image_url") {
307+
if (item.find("image_url") == item.end()) {
308+
return TResult::Error("Content should have an image_url field");
309+
}
310+
std::string image_url =
311+
item.at("image_url"); // TODO(mlc-team): According to OpenAI API reference this
312+
// should be a map, with a "url" key containing the URL, but
313+
// we are just assuming this as the URL for now
314+
std::string base64_image = image_url.substr(image_url.find(",") + 1);
315+
Result<NDArray> image_data_res = LoadImageFromBase64(base64_image);
316+
if (image_data_res.IsErr()) {
317+
return TResult::Error(image_data_res.UnwrapErr());
318+
}
319+
if (!config.vision_config.has_value()) {
320+
return TResult::Error("Vision config is required for image input");
321+
}
322+
int image_size = config.vision_config.value().image_size;
323+
int patch_size = config.vision_config.value().patch_size;
324+
325+
int embed_size = (image_size * image_size) / (patch_size * patch_size);
326+
327+
auto image_ndarray = ClipPreprocessor(image_data_res.Unwrap(), image_size, device);
328+
// lazily commit text data
329+
if (pending_text.length() != 0) {
330+
message_list.push_back(TextData(pending_text));
331+
pending_text = "";
332+
}
333+
message_list.push_back(ImageData(image_ndarray, embed_size));
334+
} else {
335+
return TResult::Error("Unsupported content type: " + it_type->second);
235336
}
236337
}
237-
message += role_text;
238-
} else if (it_type->second == "image_url") {
239-
if (item.find("image_url") == item.end()) {
240-
return TResult::Error("Content should have an image_url field");
241-
}
242-
std::string image_url =
243-
item.at("image_url"); // TODO(mlc-team): According to OpenAI API reference this
244-
// should be a map, with a "url" key containing the URL, but
245-
// we are just assuming this as the URL for now
246-
std::string base64_image = image_url.substr(image_url.find(",") + 1);
247-
Result<NDArray> image_data_res = LoadImageFromBase64(base64_image);
248-
if (image_data_res.IsErr()) {
249-
return TResult::Error(image_data_res.UnwrapErr());
250-
}
251-
if (!config.vision_config.has_value()) {
252-
return TResult::Error("Vision config is required for image input");
253-
}
254-
int image_size = config.vision_config.value().image_size;
255-
int patch_size = config.vision_config.value().patch_size;
256-
257-
int embed_size = (image_size * image_size) / (patch_size * patch_size);
258-
259-
auto image_ndarray = ClipPreprocessor(image_data_res.Unwrap(), image_size, device);
260-
message_list.push_back(ImageData(image_ndarray, embed_size));
261338
} else {
262-
return TResult::Error("Unsupported content type: " + it_type->second);
339+
ICHECK(msg.content.IsText());
340+
pending_text += conv.GetRoleText(msg.role, msg.content.Text(), fn_call_string);
263341
}
342+
pending_text += seperator;
264343
}
344+
return std::nullopt;
345+
};
265346

266-
message += separator;
267-
message_list.push_back(TextData(message));
347+
if (auto err = f_process_messages(conv.messages)) {
348+
return err.value();
349+
}
350+
if (auto err = f_process_messages(request.messages)) {
351+
return err.value();
352+
}
353+
// append last assistant begin message
354+
ChatCompletionMessage last_assistant_begin;
355+
last_assistant_begin.role = "assistant";
356+
last_assistant_begin.content = std::nullopt;
357+
if (auto err = f_process_messages({last_assistant_begin})) {
358+
return err.value();
359+
}
360+
if (pending_text.length() != 0) {
361+
message_list.push_back(TextData(pending_text));
268362
}
269-
270363
return TResult::Ok(message_list);
271364
}
272365

@@ -383,7 +476,10 @@ Result<Conversation> Conversation::FromJSON(const picojson::object& json_obj) {
383476
content.push_back(std::move(item_map));
384477
}
385478
}
386-
conv.messages.push_back({role_res.Unwrap(), content});
479+
ChatCompletionMessage msg;
480+
msg.role = role_res.Unwrap();
481+
msg.content = content;
482+
conv.messages.push_back(msg);
387483
}
388484

389485
Result<picojson::array> seps_arr_res =
@@ -438,21 +534,6 @@ Result<Conversation> Conversation::FromJSON(const picojson::object& json_obj) {
438534
}
439535
conv.stop_token_ids.push_back(stop.get<int64_t>());
440536
}
441-
442-
Result<std::optional<std::string>> function_string_res =
443-
json::LookupOptionalWithResultReturn<std::string>(json_obj, "function_string");
444-
if (function_string_res.IsErr()) {
445-
return TResult::Error(function_string_res.UnwrapErr());
446-
}
447-
conv.function_string = function_string_res.Unwrap();
448-
449-
Result<bool> use_function_calling_res = json::LookupOrDefaultWithResultReturn<bool>(
450-
json_obj, "use_function_calling", conv.use_function_calling);
451-
if (use_function_calling_res.IsErr()) {
452-
return TResult::Error(use_function_calling_res.UnwrapErr());
453-
}
454-
conv.use_function_calling = use_function_calling_res.Unwrap();
455-
456537
return TResult::Ok(conv);
457538
}
458539

0 commit comments

Comments
 (0)