@@ -131,7 +131,7 @@ ModelConfig ModelConfig::FromJSON(const picojson::object& json_obj) {
131131
132132/* ***************** Conversation template ******************/
133133
134- std::map <MessagePlaceholders, std::string> PLACEHOLDERS = {
134+ std::unordered_map <MessagePlaceholders, std::string> PLACEHOLDERS = {
135135 {MessagePlaceholders::SYSTEM, " {system_message}" },
136136 {MessagePlaceholders::USER, " {user_message}" },
137137 {MessagePlaceholders::ASSISTANT, " {assistant_message}" },
@@ -153,120 +153,213 @@ Conversation::Conversation()
153153 {" assistant" , PLACEHOLDERS[MessagePlaceholders::ASSISTANT]},
154154 {" tool" , PLACEHOLDERS[MessagePlaceholders::TOOL]}}) {}
155155
156- Result<std::vector<Data>> Conversation::AsPrompt (ModelConfig config, DLDevice device) {
157- using TResult = Result<std::vector<Data>>;
158- // Get the system message
159- std::string system_msg = system_template;
160- size_t pos = system_msg.find (PLACEHOLDERS[MessagePlaceholders::SYSTEM]);
156+ std::string Conversation::GetSystemText (const std::string& system_msg) const {
157+ std::string system_text = this ->system_template ;
158+ static std::string system_placeholder = PLACEHOLDERS[MessagePlaceholders::SYSTEM];
159+ size_t pos = system_text.find (system_placeholder);
161160 if (pos != std::string::npos) {
162- system_msg.replace (pos, PLACEHOLDERS[MessagePlaceholders::SYSTEM].length (),
163- this ->system_message );
161+ system_text.replace (pos, system_placeholder.length (), system_msg);
164162 }
163+ return system_text;
164+ }
165165
166- // Get the message strings
167- std::vector<Data> message_list;
168- std::vector<std::string> separators = seps;
169- if (separators.size () == 1 ) {
170- separators.push_back (separators[0 ]);
166+ std::string Conversation::GetRoleText (const std::string& role, const std::string& content,
167+ const std::optional<std::string>& fn_call_string) const {
168+ std::string role_text = this ->role_templates .at (role);
169+ std::string placeholder = PLACEHOLDERS[MessagePlaceholderFromString (role)];
170+ size_t pos = role_text.find (placeholder);
171+ if (pos != std::string::npos) {
172+ role_text.replace (pos, placeholder.length (), content);
173+ }
174+ if (fn_call_string) {
175+ // replace placeholder[FUNCTION] with function_string
176+ // this assumes function calling is used for a single request scenario only
177+ pos = role_text.find (PLACEHOLDERS[MessagePlaceholders::FUNCTION]);
178+ if (pos != std::string::npos) {
179+ role_text.replace (pos, PLACEHOLDERS[MessagePlaceholders::FUNCTION].length (),
180+ fn_call_string.value ());
181+ }
171182 }
183+ return role_text;
184+ }
172185
173- if (!system_msg.empty ()) {
174- system_msg += separators[0 ];
175- message_list.push_back (TextData (system_message));
186+ // / Try to detect if function calling is needed, if so, return the function calling string
187+ Result<std::optional<std::string>> TryGetFunctionCallingString (
188+ const Conversation& conv, const ChatCompletionRequest& request) {
189+ using TResult = Result<std::optional<std::string>>;
190+ if (!request.tools .has_value () ||
191+ (request.tool_choice .has_value () && request.tool_choice .value () == " none" )) {
192+ return TResult::Ok (std::nullopt );
193+ }
194+ std::vector<ChatTool> tools_ = request.tools .value ();
195+ std::string tool_choice_ = request.tool_choice .value ();
196+
197+ // TODO: support with tool choice as dict
198+ for (const auto & tool : tools_) {
199+ if (tool.function .name == tool_choice_) {
200+ picojson::value function_str (tool.function .AsJSON ());
201+ return TResult::Ok (function_str.serialize ());
202+ }
176203 }
177204
178- for ( int i = 0 ; i < messages. size (); i++ ) {
179- std::string role = messages[i]. role ;
180- // Todo(mlc-team): support content to be a single string.
181- std::optional<std::vector<std::unordered_map<std::string, std::string>>> content =
182- messages[i]. content ;
183- if (roles. find (role) == roles. end () ) {
184- return TResult::Error ( " Role \" " + role + " \" is not supported " );
185- }
205+ if (tool_choice_ != " auto " ) {
206+ return TResult::Error ( " Invalid tool_choice value in the request: " + tool_choice_) ;
207+ }
208+
209+ picojson::array function_list ;
210+ for ( const auto & tool : tools_ ) {
211+ function_list. push_back ( picojson::value (tool. function . AsJSON ()) );
212+ }
186213
187- std::string separator = separators[role == " assistant" ]; // check assistant role
214+ picojson::value function_list_json (function_list);
215+ return TResult::Ok (function_list_json.serialize ());
216+ };
188217
189- // If content is empty, add the role and separator
190- // assistant's turn to generate text
191- if (!content.has_value ()) {
192- message_list.push_back (TextData (roles[role] + role_empty_sep));
193- continue ;
194- }
218+ Result<std::vector<Data>> CreatePrompt (const Conversation& conv,
219+ const ChatCompletionRequest& request,
220+ const ModelConfig& config, DLDevice device) {
221+ using TResult = Result<std::vector<Data>>;
222+
223+ Result<std::optional<std::string>> fn_call_str_tmp = TryGetFunctionCallingString (conv, request);
224+ if (fn_call_str_tmp.IsErr ()) {
225+ return TResult::Error (fn_call_str_tmp.UnwrapErr ());
226+ }
227+ std::optional<std::string> fn_call_string = fn_call_str_tmp.Unwrap ();
228+
229+ // Handle system message
230+ // concz
231+ bool has_custom_system = false ;
232+ std::string custom_system_inputs;
195233
196- std::string message = " " ;
197- std::string role_prefix = " " ;
198- // Do not append role prefix if this is the first message and there
199- // is already a system message
200- if (add_role_after_system_message || system_msg.empty () || i != 0 ) {
201- role_prefix = roles[role] + role_content_sep;
234+ auto f_populate_system_message = [&](const std::vector<ChatCompletionMessage>& msg_vec) {
235+ for (ChatCompletionMessage msg : msg_vec) {
236+ if (msg.role == " system" ) {
237+ ICHECK (msg.content .IsText ()) << " System message must be text" ;
238+ custom_system_inputs += msg.content .Text ();
239+ has_custom_system = true ;
240+ }
202241 }
242+ };
243+ // go through messages in template and passed in.
244+ f_populate_system_message (conv.messages );
245+ f_populate_system_message (request.messages );
203246
204- message += role_prefix;
247+ // pending text records the text to be put into data
248+ // we lazily accumulate the pending text
249+ // to reduce amount of segments in the Data vector
250+ std::string pending_text =
251+ conv.GetSystemText (has_custom_system ? custom_system_inputs : conv.system_message );
205252
206- for (const auto & item : content.value ()) {
207- auto it_type = item.find (" type" );
208- if (it_type == item.end ()) {
209- return TResult::Error (" The content of a message does not have \" type\" field" );
253+ // the seperator after system message.
254+ if (!pending_text.empty ()) {
255+ pending_text += conv.seps [0 ];
256+ }
257+
258+ // Get the message strings
259+ std::vector<Data> message_list;
260+ size_t non_system_msg_count = 0 ;
261+
262+ // returns error if error happens
263+ auto f_process_messages =
264+ [&](const std::vector<ChatCompletionMessage>& msg_vec) -> std::optional<TResult> {
265+ for (size_t i = 0 ; i < msg_vec.size (); ++i) {
266+ const ChatCompletionMessage& msg = msg_vec[i];
267+ auto role_it = conv.roles .find (msg.role );
268+ if (role_it == conv.roles .end ()) {
269+ return TResult::Error (" Role \" " + msg.role + " \" is not supported" );
210270 }
211- if (it_type->second == " text" ) {
212- auto it_text = item.find (" text" );
213- if (it_text == item.end ()) {
214- return TResult::Error (" The text type content of a message does not have \" text\" field" );
215- }
216- // replace placeholder[ROLE] with input message from role
217- std::string role_text = role_templates[role];
218- std::string placeholder = PLACEHOLDERS[MessagePlaceholderFromString (role)];
219- size_t pos = role_text.find (placeholder);
220- if (pos != std::string::npos) {
221- role_text.replace (pos, placeholder.length (), it_text->second );
222- }
223- if (use_function_calling) {
224- // replace placeholder[FUNCTION] with function_string
225- // this assumes function calling is used for a single request scenario only
226- if (!function_string.has_value ()) {
227- return TResult::Error (
228- " The function string in conversation template is not defined for function "
229- " calling." );
271+ const std::string& role_name = role_it->second ;
272+ // skip system message as it is already processed
273+ if (msg.role == " system" ) continue ;
274+ // skip when content is empty
275+ if (msg.content .IsNull ()) {
276+ pending_text += role_name + conv.role_empty_sep ;
277+ continue ;
278+ }
279+ ++non_system_msg_count;
280+ // assistant uses conv.seps[1] if there are two seps
281+ int sep_offset = msg.role == " assistant" ? 1 : 0 ;
282+ const std::string& seperator = conv.seps [sep_offset % conv.seps .size ()];
283+ // setup role prefix
284+ std::string role_prefix = " " ;
285+ // Do not append role prefix if this is the first message and there is already a system
286+ // message
287+ if (conv.add_role_after_system_message || pending_text.empty () || non_system_msg_count != 1 ) {
288+ role_prefix = role_name + conv.role_content_sep ;
289+ }
290+ pending_text += role_prefix;
291+
292+ if (msg.content .IsParts ()) {
293+ for (const auto & item : msg.content .Parts ()) {
294+ auto it_type = item.find (" type" );
295+ if (it_type == item.end ()) {
296+ return TResult::Error (" The content of a message does not have \" type\" field" );
230297 }
231- pos = role_text.find (PLACEHOLDERS[MessagePlaceholders::FUNCTION]);
232- if (pos != std::string::npos) {
233- role_text.replace (pos, PLACEHOLDERS[MessagePlaceholders::FUNCTION].length (),
234- function_string.value ());
298+ if (it_type->second == " text" ) {
299+ auto it_text = item.find (" text" );
300+ if (it_text == item.end ()) {
301+ return TResult::Error (
302+ " The text type content of a message does not have \" text\" field" );
303+ }
304+ // replace placeholder[ROLE] with input message from role
305+ pending_text += conv.GetRoleText (msg.role , it_text->second , fn_call_string);
306+ } else if (it_type->second == " image_url" ) {
307+ if (item.find (" image_url" ) == item.end ()) {
308+ return TResult::Error (" Content should have an image_url field" );
309+ }
310+ std::string image_url =
311+ item.at (" image_url" ); // TODO(mlc-team): According to OpenAI API reference this
312+ // should be a map, with a "url" key containing the URL, but
313+ // we are just assuming this as the URL for now
314+ std::string base64_image = image_url.substr (image_url.find (" ," ) + 1 );
315+ Result<NDArray> image_data_res = LoadImageFromBase64 (base64_image);
316+ if (image_data_res.IsErr ()) {
317+ return TResult::Error (image_data_res.UnwrapErr ());
318+ }
319+ if (!config.vision_config .has_value ()) {
320+ return TResult::Error (" Vision config is required for image input" );
321+ }
322+ int image_size = config.vision_config .value ().image_size ;
323+ int patch_size = config.vision_config .value ().patch_size ;
324+
325+ int embed_size = (image_size * image_size) / (patch_size * patch_size);
326+
327+ auto image_ndarray = ClipPreprocessor (image_data_res.Unwrap (), image_size, device);
328+ // lazily commit text data
329+ if (pending_text.length () != 0 ) {
330+ message_list.push_back (TextData (pending_text));
331+ pending_text = " " ;
332+ }
333+ message_list.push_back (ImageData (image_ndarray, embed_size));
334+ } else {
335+ return TResult::Error (" Unsupported content type: " + it_type->second );
235336 }
236337 }
237- message += role_text;
238- } else if (it_type->second == " image_url" ) {
239- if (item.find (" image_url" ) == item.end ()) {
240- return TResult::Error (" Content should have an image_url field" );
241- }
242- std::string image_url =
243- item.at (" image_url" ); // TODO(mlc-team): According to OpenAI API reference this
244- // should be a map, with a "url" key containing the URL, but
245- // we are just assuming this as the URL for now
246- std::string base64_image = image_url.substr (image_url.find (" ," ) + 1 );
247- Result<NDArray> image_data_res = LoadImageFromBase64 (base64_image);
248- if (image_data_res.IsErr ()) {
249- return TResult::Error (image_data_res.UnwrapErr ());
250- }
251- if (!config.vision_config .has_value ()) {
252- return TResult::Error (" Vision config is required for image input" );
253- }
254- int image_size = config.vision_config .value ().image_size ;
255- int patch_size = config.vision_config .value ().patch_size ;
256-
257- int embed_size = (image_size * image_size) / (patch_size * patch_size);
258-
259- auto image_ndarray = ClipPreprocessor (image_data_res.Unwrap (), image_size, device);
260- message_list.push_back (ImageData (image_ndarray, embed_size));
261338 } else {
262- return TResult::Error (" Unsupported content type: " + it_type->second );
339+ ICHECK (msg.content .IsText ());
340+ pending_text += conv.GetRoleText (msg.role , msg.content .Text (), fn_call_string);
263341 }
342+ pending_text += seperator;
264343 }
344+ return std::nullopt ;
345+ };
265346
266- message += separator;
267- message_list.push_back (TextData (message));
347+ if (auto err = f_process_messages (conv.messages )) {
348+ return err.value ();
349+ }
350+ if (auto err = f_process_messages (request.messages )) {
351+ return err.value ();
352+ }
353+ // append last assistant begin message
354+ ChatCompletionMessage last_assistant_begin;
355+ last_assistant_begin.role = " assistant" ;
356+ last_assistant_begin.content = std::nullopt ;
357+ if (auto err = f_process_messages ({last_assistant_begin})) {
358+ return err.value ();
359+ }
360+ if (pending_text.length () != 0 ) {
361+ message_list.push_back (TextData (pending_text));
268362 }
269-
270363 return TResult::Ok (message_list);
271364}
272365
@@ -383,7 +476,10 @@ Result<Conversation> Conversation::FromJSON(const picojson::object& json_obj) {
383476 content.push_back (std::move (item_map));
384477 }
385478 }
386- conv.messages .push_back ({role_res.Unwrap (), content});
479+ ChatCompletionMessage msg;
480+ msg.role = role_res.Unwrap ();
481+ msg.content = content;
482+ conv.messages .push_back (msg);
387483 }
388484
389485 Result<picojson::array> seps_arr_res =
@@ -438,21 +534,6 @@ Result<Conversation> Conversation::FromJSON(const picojson::object& json_obj) {
438534 }
439535 conv.stop_token_ids .push_back (stop.get <int64_t >());
440536 }
441-
442- Result<std::optional<std::string>> function_string_res =
443- json::LookupOptionalWithResultReturn<std::string>(json_obj, " function_string" );
444- if (function_string_res.IsErr ()) {
445- return TResult::Error (function_string_res.UnwrapErr ());
446- }
447- conv.function_string = function_string_res.Unwrap ();
448-
449- Result<bool > use_function_calling_res = json::LookupOrDefaultWithResultReturn<bool >(
450- json_obj, " use_function_calling" , conv.use_function_calling );
451- if (use_function_calling_res.IsErr ()) {
452- return TResult::Error (use_function_calling_res.UnwrapErr ());
453- }
454- conv.use_function_calling = use_function_calling_res.Unwrap ();
455-
456537 return TResult::Ok (conv);
457538}
458539
0 commit comments