@@ -11,6 +11,99 @@ namespace json_ffi {
1111
1212using namespace mlc ::llm;
1313
14+ /* ***************** Model vision config ******************/
15+
16+ ModelVisionConfig ModelVisionConfig::FromJSON (const picojson::object& json_obj, std::string* err) {
17+ ModelVisionConfig config;
18+
19+ int64_t hidden_size;
20+ if (json::ParseJSONField (json_obj, " hidden_size" , hidden_size, err, false )) {
21+ config.hidden_size = hidden_size;
22+ }
23+ int64_t image_size;
24+ if (json::ParseJSONField (json_obj, " image_size" , image_size, err, false )) {
25+ config.image_size = image_size;
26+ }
27+ int64_t intermediate_size;
28+ if (json::ParseJSONField (json_obj, " intermediate_size" , intermediate_size, err, false )) {
29+ config.intermediate_size = intermediate_size;
30+ }
31+ int64_t num_attention_heads;
32+ if (json::ParseJSONField (json_obj, " num_attention_heads" , num_attention_heads, err, false )) {
33+ config.num_attention_heads = num_attention_heads;
34+ }
35+ int64_t num_hidden_layers;
36+ if (json::ParseJSONField (json_obj, " num_hidden_layers" , num_hidden_layers, err, false )) {
37+ config.num_hidden_layers = num_hidden_layers;
38+ }
39+ int64_t patch_size;
40+ if (json::ParseJSONField (json_obj, " patch_size" , patch_size, err, false )) {
41+ config.patch_size = patch_size;
42+ }
43+ int64_t projection_dim;
44+ if (json::ParseJSONField (json_obj, " projection_dim" , projection_dim, err, false )) {
45+ config.projection_dim = projection_dim;
46+ }
47+ int64_t vocab_size;
48+ if (json::ParseJSONField (json_obj, " vocab_size" , vocab_size, err, false )) {
49+ config.vocab_size = vocab_size;
50+ }
51+ std::string dtype;
52+ if (json::ParseJSONField (json_obj, " dtype" , dtype, err, false )) {
53+ config.dtype = dtype;
54+ }
55+ int64_t num_channels;
56+ if (json::ParseJSONField (json_obj, " num_channels" , num_channels, err, false )) {
57+ config.num_channels = num_channels;
58+ }
59+ double layer_norm_eps;
60+ if (json::ParseJSONField (json_obj, " layer_norm_eps" , layer_norm_eps, err, false )) {
61+ config.layer_norm_eps = layer_norm_eps;
62+ }
63+
64+ return config;
65+ }
66+
67+ /* ***************** Model config ******************/
68+
69+ ModelConfig ModelConfig::FromJSON (const picojson::object& json_obj, std::string* err) {
70+ ModelConfig config;
71+
72+ int64_t vocab_size;
73+ if (json::ParseJSONField (json_obj, " vocab_size" , vocab_size, err, false )) {
74+ config.vocab_size = vocab_size;
75+ }
76+ int64_t context_window_size;
77+ if (json::ParseJSONField (json_obj, " context_window_size" , context_window_size, err, false )) {
78+ config.context_window_size = context_window_size;
79+ }
80+ int64_t sliding_window_size;
81+ if (json::ParseJSONField (json_obj, " sliding_window_size" , sliding_window_size, err, false )) {
82+ config.sliding_window_size = sliding_window_size;
83+ }
84+ int64_t prefill_chunk_size;
85+ if (json::ParseJSONField (json_obj, " prefill_chunk_size" , prefill_chunk_size, err, false )) {
86+ config.prefill_chunk_size = prefill_chunk_size;
87+ }
88+ int64_t tensor_parallel_shards;
89+ if (json::ParseJSONField (json_obj, " tensor_parallel_shards" , tensor_parallel_shards, err,
90+ false )) {
91+ config.tensor_parallel_shards = tensor_parallel_shards;
92+ }
93+ int64_t max_batch_size;
94+ if (json::ParseJSONField (json_obj, " max_batch_size" , max_batch_size, err, false )) {
95+ config.max_batch_size = max_batch_size;
96+ }
97+
98+ if (json_obj.count (" vision_config" )) {
99+ const picojson::object& vision_config_obj =
100+ json_obj.at (" vision_config" ).get <picojson::object>();
101+ config.vision_config = ModelVisionConfig::FromJSON (vision_config_obj, err);
102+ }
103+
104+ return config;
105+ }
106+
14107/* ***************** Model-defined generation config ******************/
15108
16109TVM_REGISTER_OBJECT_TYPE (ModelDefinedGenerationConfigNode);
@@ -63,7 +156,7 @@ std::vector<std::string> Conversation::CheckMessageSeps(std::vector<std::string>
63156 return seps;
64157}
65158
66- std::optional<std::vector<Data>> Conversation::AsPrompt (picojson::object config, DLDevice device,
159+ std::optional<std::vector<Data>> Conversation::AsPrompt (ModelConfig config, DLDevice device,
67160 std::string* err) {
68161 // Get the system message
69162 std::string system_msg = system_template;
@@ -155,47 +248,15 @@ std::optional<std::vector<Data>> Conversation::AsPrompt(picojson::object config,
155248 // we are just assuming this as the URL for now
156249 std::string base64_image = image_url.substr (image_url.find (" ," ) + 1 );
157250 std::optional<NDArray> image_data = LoadImageFromBase64 (base64_image, err);
158- if (!image_data) {
159- return std::nullopt ;
160- }
161-
162- if (config.find (" model_config" ) == config.end ()) {
163- *err += " model_config is required in config" ;
164- return std::nullopt ;
165- }
166- if (!config[" model_config" ].is <picojson::object>()) {
167- *err += " model_config should be an object" ;
251+ if (!image_data.has_value ()) {
168252 return std::nullopt ;
169253 }
170- picojson::object model_config = config[" model_config" ].get <picojson::object>();
171- if (model_config.find (" vision_config" ) == model_config.end ()) {
172- *err += " vision_config is required in model_config" ;
254+ if (!config.vision_config .has_value ()) {
255+ *err += " Vision config is required for image input" ;
173256 return std::nullopt ;
174257 }
175- if (!model_config[" vision_config" ].is <picojson::object>()) {
176- *err += " vision_config should be an object" ;
177- return std::nullopt ;
178- }
179- picojson::object vision_config = model_config[" vision_config" ].get <picojson::object>();
180- if (vision_config.find (" image_size" ) == vision_config.end ()) {
181- *err += " image_size is required in vision_config" ;
182- return std::nullopt ;
183- }
184- if (!vision_config[" image_size" ].is <int64_t >()) {
185- *err += " image_size should be an integer" ;
186- return std::nullopt ;
187- }
188- if (vision_config.find (" patch_size" ) == vision_config.end ()) {
189- *err += " patch_size is required in vision_config" ;
190- return std::nullopt ;
191- }
192- if (!vision_config[" patch_size" ].is <int64_t >()) {
193- *err += " patch_size should be an integer" ;
194- return std::nullopt ;
195- }
196-
197- int image_size = vision_config[" image_size" ].get <int64_t >();
198- int patch_size = vision_config[" patch_size" ].get <int64_t >();
258+ int image_size = config.vision_config .value ().image_size ;
259+ int patch_size = config.vision_config .value ().patch_size ;
199260
200261 int embed_size = (image_size * image_size) / (patch_size * patch_size);
201262
0 commit comments