 #include <tvm/runtime/registry.h>
 
 #include "../support/json_parser.h"
+#include "image_utils.h"
 
 namespace mlc {
 namespace llm {
 namespace json_ffi {
 
 using namespace mlc::llm;
 
+/****************** Model vision config ******************/
+
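+// Parse the vision sub-config from JSON. Every field is optional: when a key
+// is absent, the default value from the ModelVisionConfig declaration is kept.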
+ModelVisionConfig ModelVisionConfig::FromJSON(const picojson::object& json_obj) {
+  ModelVisionConfig config;
+
+  Result<int64_t> hidden_size_res = json::LookupWithResultReturn<int64_t>(json_obj, "hidden_size");
+  if (hidden_size_res.IsOk()) {
+    config.hidden_size = hidden_size_res.Unwrap();
+  }
+
+  Result<int64_t> image_size_res = json::LookupWithResultReturn<int64_t>(json_obj, "image_size");
+  if (image_size_res.IsOk()) {
+    config.image_size = image_size_res.Unwrap();
+  }
+
+  Result<int64_t> intermediate_size_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "intermediate_size");
+  if (intermediate_size_res.IsOk()) {
+    config.intermediate_size = intermediate_size_res.Unwrap();
+  }
+
+  Result<int64_t> num_attention_heads_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "num_attention_heads");
+  if (num_attention_heads_res.IsOk()) {
+    config.num_attention_heads = num_attention_heads_res.Unwrap();
+  }
+
+  Result<int64_t> num_hidden_layers_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "num_hidden_layers");
+  if (num_hidden_layers_res.IsOk()) {
+    config.num_hidden_layers = num_hidden_layers_res.Unwrap();
+  }
+
+  Result<int64_t> patch_size_res = json::LookupWithResultReturn<int64_t>(json_obj, "patch_size");
+  if (patch_size_res.IsOk()) {
+    config.patch_size = patch_size_res.Unwrap();
+  }
+
+  Result<int64_t> projection_dim_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "projection_dim");
+  if (projection_dim_res.IsOk()) {
+    config.projection_dim = projection_dim_res.Unwrap();
+  }
+
+  Result<int64_t> vocab_size_res = json::LookupWithResultReturn<int64_t>(json_obj, "vocab_size");
+  if (vocab_size_res.IsOk()) {
+    config.vocab_size = vocab_size_res.Unwrap();
+  }
+
+  Result<std::string> dtype_res = json::LookupWithResultReturn<std::string>(json_obj, "dtype");
+  if (dtype_res.IsOk()) {
+    config.dtype = dtype_res.Unwrap();
+  }
+
+  Result<int64_t> num_channels_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "num_channels");
+  if (num_channels_res.IsOk()) {
+    config.num_channels = num_channels_res.Unwrap();
+  }
+
+  Result<double> layer_norm_eps_res =
+      json::LookupWithResultReturn<double>(json_obj, "layer_norm_eps");
+  if (layer_norm_eps_res.IsOk()) {
+    config.layer_norm_eps = layer_norm_eps_res.Unwrap();
+  }
+
+  return config;
+}
+
+/****************** Model config ******************/
+
+ModelConfig ModelConfig::FromJSON(const picojson::object& json_obj) {
+  ModelConfig config;
+
+  Result<int64_t> vocab_size_res = json::LookupWithResultReturn<int64_t>(json_obj, "vocab_size");
+  if (vocab_size_res.IsOk()) {
+    config.vocab_size = vocab_size_res.Unwrap();
+  }
+
+  Result<int64_t> context_window_size_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "context_window_size");
+  if (context_window_size_res.IsOk()) {
+    config.context_window_size = context_window_size_res.Unwrap();
+  }
+
+  Result<int64_t> sliding_window_size_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "sliding_window_size");
+  if (sliding_window_size_res.IsOk()) {
+    config.sliding_window_size = sliding_window_size_res.Unwrap();
+  }
+
+  Result<int64_t> prefill_chunk_size_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "prefill_chunk_size");
+  if (prefill_chunk_size_res.IsOk()) {
+    config.prefill_chunk_size = prefill_chunk_size_res.Unwrap();
+  }
+
+  Result<int64_t> tensor_parallel_shards_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "tensor_parallel_shards");
+  if (tensor_parallel_shards_res.IsOk()) {
+    config.tensor_parallel_shards = tensor_parallel_shards_res.Unwrap();
+  }
+
+  Result<int64_t> max_batch_size_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "max_batch_size");
+  if (max_batch_size_res.IsOk()) {
+    config.max_batch_size = max_batch_size_res.Unwrap();
+  }
+
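+  // A nested "vision_config" object is present only for vision-language
+  // models; its fields are parsed with the same optional-field rules as above.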
+  if (json_obj.count("vision_config")) {
+    const picojson::object& vision_config_obj =
+        json_obj.at("vision_config").get<picojson::object>();
+    config.vision_config = ModelVisionConfig::FromJSON(vision_config_obj);
+  }
+
+  return config;
+}
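+
+// Usage sketch (hypothetical caller code, not part of this change;
+// `config_json_str` is illustrative):
+//   picojson::value parsed;
+//   std::string err = picojson::parse(parsed, config_json_str);
+//   ICHECK(err.empty()) << err;
+//   ModelConfig config = ModelConfig::FromJSON(parsed.get<picojson::object>());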
+
 /****************** Conversation template ******************/
 
 std::map<MessagePlaceholders, std::string> PLACEHOLDERS = {
@@ -34,7 +153,7 @@ Conversation::Conversation()
             {"assistant", PLACEHOLDERS[MessagePlaceholders::ASSISTANT]},
             {"tool", PLACEHOLDERS[MessagePlaceholders::TOOL]}}) {}
 
-Result<std::vector<Data>> Conversation::AsPrompt() {
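+// `config` supplies the vision parameters needed to size image embeddings;
+// `device` is where preprocessed image tensors are placed.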
+Result<std::vector<Data>> Conversation::AsPrompt(ModelConfig config, DLDevice device) {
   using TResult = Result<std::vector<Data>>;
   // Get the system message
   std::string system_msg = system_template;
@@ -116,6 +235,29 @@ Result<std::vector<Data>> Conversation::AsPrompt() {
           }
         }
         message += role_text;
+      } else if (it_type->second == "image_url") {
+        if (item.find("image_url") == item.end()) {
+          return TResult::Error("Content should have an image_url field");
+        }
+        std::string image_url =
+            item.at("image_url");  // TODO(mlc-team): According to OpenAI API reference this
+                                   // should be a map, with a "url" key containing the URL, but
+                                   // we are just assuming this as the URL for now
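+        // Assumes a data URL of the form "data:image/<fmt>;base64,<payload>";
+        // everything after the first comma is taken as the base64 payload.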
+        std::string base64_image = image_url.substr(image_url.find(",") + 1);
+        Result<NDArray> image_data_res = LoadImageFromBase64(base64_image);
+        if (image_data_res.IsErr()) {
+          return TResult::Error(image_data_res.UnwrapErr());
+        }
+        if (!config.vision_config.has_value()) {
+          return TResult::Error("Vision config is required for image input");
+        }
+        int image_size = config.vision_config.value().image_size;
+        int patch_size = config.vision_config.value().patch_size;
+
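+        // Each (patch_size x patch_size) patch becomes one embedding, so an
+        // image contributes (image_size / patch_size)^2 positions to the prompt.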
+        int embed_size = (image_size * image_size) / (patch_size * patch_size);
+
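+        // ClipPreprocessor (see image_utils.h) is expected to resize/normalize
+        // the decoded image to image_size x image_size and place it on `device`;
+        // ImageData carries embed_size so prompt-length accounting stays correct.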
+        auto image_ndarray = ClipPreprocessor(image_data_res.Unwrap(), image_size, device);
+        message_list.push_back(ImageData(image_ndarray, embed_size));
       } else {
         return TResult::Error("Unsupported content type: " + it_type->second);
       }