Skip to content

Commit 83a5243

Browse files
author
Animesh Bohara
committed
Introduce ModelConfig and ModelVisionConfig to hold relevant parameters
1 parent 7c01023 commit 83a5243

File tree

6 files changed

+151
-48
lines changed

6 files changed

+151
-48
lines changed

cpp/json_ffi/config.cc

Lines changed: 99 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,99 @@ namespace json_ffi {
1111

1212
using namespace mlc::llm;
1313

14+
/****************** Model vision config ******************/
15+
16+
ModelVisionConfig ModelVisionConfig::FromJSON(const picojson::object& json_obj, std::string* err) {
  ModelVisionConfig config;

  // Every field is optional (required=false): when a key is absent the member
  // keeps its default value and no error is appended to `err`.
  // Helper for the common "optional int64 field -> int member" pattern.
  auto read_int_field = [&](const char* key, int& field) {
    int64_t value;
    if (json::ParseJSONField(json_obj, key, value, err, false)) {
      field = value;
    }
  };

  read_int_field("hidden_size", config.hidden_size);
  read_int_field("image_size", config.image_size);
  read_int_field("intermediate_size", config.intermediate_size);
  read_int_field("num_attention_heads", config.num_attention_heads);
  read_int_field("num_hidden_layers", config.num_hidden_layers);
  read_int_field("patch_size", config.patch_size);
  read_int_field("projection_dim", config.projection_dim);
  read_int_field("vocab_size", config.vocab_size);

  std::string dtype;
  if (json::ParseJSONField(json_obj, "dtype", dtype, err, false)) {
    config.dtype = dtype;
  }

  read_int_field("num_channels", config.num_channels);

  double layer_norm_eps;
  if (json::ParseJSONField(json_obj, "layer_norm_eps", layer_norm_eps, err, false)) {
    config.layer_norm_eps = layer_norm_eps;
  }

  return config;
}
66+
67+
/****************** Model config ******************/
68+
69+
ModelConfig ModelConfig::FromJSON(const picojson::object& json_obj, std::string* err) {
70+
ModelConfig config;
71+
72+
int64_t vocab_size;
73+
if (json::ParseJSONField(json_obj, "vocab_size", vocab_size, err, false)) {
74+
config.vocab_size = vocab_size;
75+
}
76+
int64_t context_window_size;
77+
if (json::ParseJSONField(json_obj, "context_window_size", context_window_size, err, false)) {
78+
config.context_window_size = context_window_size;
79+
}
80+
int64_t sliding_window_size;
81+
if (json::ParseJSONField(json_obj, "sliding_window_size", sliding_window_size, err, false)) {
82+
config.sliding_window_size = sliding_window_size;
83+
}
84+
int64_t prefill_chunk_size;
85+
if (json::ParseJSONField(json_obj, "prefill_chunk_size", prefill_chunk_size, err, false)) {
86+
config.prefill_chunk_size = prefill_chunk_size;
87+
}
88+
int64_t tensor_parallel_shards;
89+
if (json::ParseJSONField(json_obj, "tensor_parallel_shards", tensor_parallel_shards, err,
90+
false)) {
91+
config.tensor_parallel_shards = tensor_parallel_shards;
92+
}
93+
int64_t max_batch_size;
94+
if (json::ParseJSONField(json_obj, "max_batch_size", max_batch_size, err, false)) {
95+
config.max_batch_size = max_batch_size;
96+
}
97+
98+
if (json_obj.count("vision_config")) {
99+
const picojson::object& vision_config_obj =
100+
json_obj.at("vision_config").get<picojson::object>();
101+
config.vision_config = ModelVisionConfig::FromJSON(vision_config_obj, err);
102+
}
103+
104+
return config;
105+
}
106+
14107
/****************** Model-defined generation config ******************/
15108

16109
TVM_REGISTER_OBJECT_TYPE(ModelDefinedGenerationConfigNode);
@@ -63,7 +156,7 @@ std::vector<std::string> Conversation::CheckMessageSeps(std::vector<std::string>
63156
return seps;
64157
}
65158

66-
std::optional<std::vector<Data>> Conversation::AsPrompt(picojson::object config, DLDevice device,
159+
std::optional<std::vector<Data>> Conversation::AsPrompt(ModelConfig config, DLDevice device,
67160
std::string* err) {
68161
// Get the system message
69162
std::string system_msg = system_template;
@@ -155,47 +248,15 @@ std::optional<std::vector<Data>> Conversation::AsPrompt(picojson::object config,
155248
// we are just assuming this as the URL for now
156249
std::string base64_image = image_url.substr(image_url.find(",") + 1);
157250
std::optional<NDArray> image_data = LoadImageFromBase64(base64_image, err);
158-
if (!image_data) {
159-
return std::nullopt;
160-
}
161-
162-
if (config.find("model_config") == config.end()) {
163-
*err += "model_config is required in config";
164-
return std::nullopt;
165-
}
166-
if (!config["model_config"].is<picojson::object>()) {
167-
*err += "model_config should be an object";
251+
if (!image_data.has_value()) {
168252
return std::nullopt;
169253
}
170-
picojson::object model_config = config["model_config"].get<picojson::object>();
171-
if (model_config.find("vision_config") == model_config.end()) {
172-
*err += "vision_config is required in model_config";
254+
if (!config.vision_config.has_value()) {
255+
*err += "Vision config is required for image input";
173256
return std::nullopt;
174257
}
175-
if (!model_config["vision_config"].is<picojson::object>()) {
176-
*err += "vision_config should be an object";
177-
return std::nullopt;
178-
}
179-
picojson::object vision_config = model_config["vision_config"].get<picojson::object>();
180-
if (vision_config.find("image_size") == vision_config.end()) {
181-
*err += "image_size is required in vision_config";
182-
return std::nullopt;
183-
}
184-
if (!vision_config["image_size"].is<int64_t>()) {
185-
*err += "image_size should be an integer";
186-
return std::nullopt;
187-
}
188-
if (vision_config.find("patch_size") == vision_config.end()) {
189-
*err += "patch_size is required in vision_config";
190-
return std::nullopt;
191-
}
192-
if (!vision_config["patch_size"].is<int64_t>()) {
193-
*err += "patch_size should be an integer";
194-
return std::nullopt;
195-
}
196-
197-
int image_size = vision_config["image_size"].get<int64_t>();
198-
int patch_size = vision_config["patch_size"].get<int64_t>();
258+
int image_size = config.vision_config.value().image_size;
259+
int patch_size = config.vision_config.value().patch_size;
199260

200261
int embed_size = (image_size * image_size) / (patch_size * patch_size);
201262

cpp/json_ffi/config.h

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,43 @@ namespace mlc {
2222
namespace llm {
2323
namespace json_ffi {
2424

25+
/****************** Model vision config ******************/
26+
27+
/*! \brief Defines the Vision config of the model (if present) */
28+
class ModelVisionConfig {
29+
public:
30+
int hidden_size;
31+
int image_size;
32+
int intermediate_size;
33+
int num_attention_heads;
34+
int num_hidden_layers;
35+
int patch_size;
36+
int projection_dim;
37+
int vocab_size;
38+
std::string dtype;
39+
int num_channels;
40+
double layer_norm_eps;
41+
42+
static ModelVisionConfig FromJSON(const picojson::object& json_obj, std::string* err);
43+
};
44+
45+
/****************** Model config ******************/
46+
47+
/*! \brief Defines the config of the model.
48+
Populated from "model_config" field in mlc-chat-config.json */
49+
class ModelConfig {
50+
public:
51+
int vocab_size;
52+
int context_window_size;
53+
int sliding_window_size;
54+
int prefill_chunk_size;
55+
int tensor_parallel_shards;
56+
int max_batch_size;
57+
std::optional<ModelVisionConfig> vision_config = std::nullopt;
58+
59+
static ModelConfig FromJSON(const picojson::object& json_obj, std::string* err);
60+
};
61+
2562
/****************** Model-defined generation config ******************/
2663

2764
class ModelDefinedGenerationConfigNode : public Object {
@@ -129,8 +166,7 @@ struct Conversation {
129166
* \brief Create the list of prompts from the messages based on the conversation template.
130167
* When creation fails, errors are dumped to the input error string, and nullopt is returned.
131168
*/
132-
std::optional<std::vector<Data>> AsPrompt(picojson::object config, DLDevice device,
133-
std::string* err);
169+
std::optional<std::vector<Data>> AsPrompt(ModelConfig config, DLDevice device, std::string* err);
134170

135171
/*!
136172
* \brief Create a Conversation instance from the given JSON object.

cpp/json_ffi/image_utils.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ size_t Base64DecodedSize(const std::string& base64_str) {
5151
return 3 * len / 4 - padding;
5252
}
5353

54-
std::optional<NDArray> LoadImageFromBase64(std::string base64_str, std::string* err) {
54+
std::optional<NDArray> LoadImageFromBase64(const std::string& base64_str, std::string* err) {
5555
MemoryBufferStream stream(base64_str.c_str(), base64_str.size());
5656
tvm::support::Base64InStream base64_stream(&stream);
5757
size_t decoded_size = Base64DecodedSize(base64_str);

cpp/json_ffi/image_utils.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@ namespace mlc {
1515
namespace llm {
1616
namespace json_ffi {
1717

18-
using namespace tvm::runtime;
18+
/*! \brief Load a base64 encoded image string into a CPU NDArray of shape {height, width, 3} */
19+
std::optional<tvm::runtime::NDArray> LoadImageFromBase64(const std::string& base64_str,
20+
std::string* err);
1921

20-
std::optional<NDArray> LoadImageFromBase64(std::string base64_str, std::string* err);
21-
22-
NDArray ClipPreprocessor(NDArray image_data, int target_size, DLDevice device, std::string* err);
22+
/*! \brief Preprocess the CPU image for CLIP encoder and return an NDArray on the given device */
23+
tvm::runtime::NDArray ClipPreprocessor(tvm::runtime::NDArray image_data, int target_size,
24+
DLDevice device, std::string* err);
2325

2426
} // namespace json_ffi
2527
} // namespace llm

cpp/json_ffi/json_ffi_engine.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,9 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
147147
}
148148
std::string model_config_str = std::string((std::istreambuf_iterator<char>(model_config_file)),
149149
std::istreambuf_iterator<char>());
150-
this->model_config_ = json::LoadJSONFromString(model_config_str, &err_).value();
150+
picojson::object model_config_obj = json::LoadJSONFromString(model_config_str, &err_).value();
151+
this->model_config_ =
152+
ModelConfig::FromJSON(model_config_obj.at("model_config").get<picojson::object>(), &err_);
151153

152154
this->device_ = std::move(device);
153155

@@ -183,7 +185,9 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
183185
}
184186
std::string model_config_str = std::string((std::istreambuf_iterator<char>(model_config_file)),
185187
std::istreambuf_iterator<char>());
186-
this->model_config_ = json::LoadJSONFromString(model_config_str, &err_).value();
188+
picojson::object model_config_obj = json::LoadJSONFromString(model_config_str, &err_).value();
189+
this->model_config_ =
190+
ModelConfig::FromJSON(model_config_obj.at("model_config").get<picojson::object>(), &err_);
187191
}
188192

189193
void Unload() { this->engine_->Unload(); }

cpp/json_ffi/json_ffi_engine.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class JSONFFIEngine {
5050
TextStreamer streamer_; // TODO: Support "n", and support different streamers for each request
5151
Conversation conv_template_;
5252
Map<String, ModelDefinedGenerationConfig> model_generation_cfgs;
53-
picojson::object model_config_;
53+
ModelConfig model_config_;
5454
DLDevice device_;
5555
};
5656

0 commit comments

Comments
 (0)