
Commit bdffdf0

Using new Result interface

Author: Animesh Bohara (committed)
1 parent f181ce2 commit bdffdf0

File tree

12 files changed: 481 additions, 100 deletions


.gitmodules

Lines changed: 3 additions & 0 deletions

@@ -10,3 +10,6 @@
 [submodule "3rdparty/tvm"]
 	path = 3rdparty/tvm
 	url = https://github.com/mlc-ai/relax.git
+[submodule "3rdparty/stb"]
+	path = 3rdparty/stb
+	url = https://github.com/nothings/stb.git

3rdparty/stb

Submodule stb added at ae721c5
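
After checking out this commit, run git submodule update --init 3rdparty/stb (or git submodule update --init --recursive) to populate the new directory.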

CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -88,6 +88,7 @@ target_include_directories(mlc_llm_objs PRIVATE ${MLC_LLM_INCLUDES})
 target_compile_definitions(mlc_llm_objs PRIVATE ${MLC_LLM_COMPILE_DEFS})
 target_include_directories(mlc_llm_objs PRIVATE ${TOKENZIER_CPP_PATH}/include)
 target_compile_definitions(mlc_llm_objs PRIVATE -DMLC_LLM_EXPORTS)
+target_include_directories(mlc_llm_objs PRIVATE 3rdparty/stb)

 add_library(mlc_llm SHARED $<TARGET_OBJECTS:mlc_llm_objs>)
 add_library(mlc_llm_static STATIC $<TARGET_OBJECTS:mlc_llm_objs>)
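
stb is a single-header library, so the new include directory is the only build change needed; there is no library target to link. The convention, used by cpp/json_ffi/image_utils.cc below, is that exactly one translation unit defines STB_IMAGE_IMPLEMENTATION before including the header. A minimal sketch of the pattern:

// In exactly one .cc file (here, cpp/json_ffi/image_utils.cc), compile the
// stb implementation into that translation unit:
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"  // found via the 3rdparty/stb include path added above

// Every other file that calls stbi_* functions includes the header without
// the macro and gets declarations only.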

cpp/json_ffi/conv_template.cc

Lines changed: 143 additions & 1 deletion

@@ -3,13 +3,132 @@
 #include <tvm/runtime/registry.h>

 #include "../support/json_parser.h"
+#include "image_utils.h"

 namespace mlc {
 namespace llm {
 namespace json_ffi {

 using namespace mlc::llm;

+/****************** Model vision config ******************/
+
+ModelVisionConfig ModelVisionConfig::FromJSON(const picojson::object& json_obj) {
+  ModelVisionConfig config;
+
+  Result<int64_t> hidden_size_res = json::LookupWithResultReturn<int64_t>(json_obj, "hidden_size");
+  if (hidden_size_res.IsOk()) {
+    config.hidden_size = hidden_size_res.Unwrap();
+  }
+
+  Result<int64_t> image_size_res = json::LookupWithResultReturn<int64_t>(json_obj, "image_size");
+  if (image_size_res.IsOk()) {
+    config.image_size = image_size_res.Unwrap();
+  }
+
+  Result<int64_t> intermediate_size_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "intermediate_size");
+  if (intermediate_size_res.IsOk()) {
+    config.intermediate_size = intermediate_size_res.Unwrap();
+  }
+
+  Result<int64_t> num_attention_heads_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "num_attention_heads");
+  if (num_attention_heads_res.IsOk()) {
+    config.num_attention_heads = num_attention_heads_res.Unwrap();
+  }
+
+  Result<int64_t> num_hidden_layers_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "num_hidden_layers");
+  if (num_hidden_layers_res.IsOk()) {
+    config.num_hidden_layers = num_hidden_layers_res.Unwrap();
+  }
+
+  Result<int64_t> patch_size_res = json::LookupWithResultReturn<int64_t>(json_obj, "patch_size");
+  if (patch_size_res.IsOk()) {
+    config.patch_size = patch_size_res.Unwrap();
+  }
+
+  Result<int64_t> projection_dim_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "projection_dim");
+  if (projection_dim_res.IsOk()) {
+    config.projection_dim = projection_dim_res.Unwrap();
+  }
+
+  Result<int64_t> vocab_size_res = json::LookupWithResultReturn<int64_t>(json_obj, "vocab_size");
+  if (vocab_size_res.IsOk()) {
+    config.vocab_size = vocab_size_res.Unwrap();
+  }
+
+  Result<std::string> dtype_res = json::LookupWithResultReturn<std::string>(json_obj, "dtype");
+  if (dtype_res.IsOk()) {
+    config.dtype = dtype_res.Unwrap();
+  }
+
+  Result<int64_t> num_channels_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "num_channels");
+  if (num_channels_res.IsOk()) {
+    config.num_channels = num_channels_res.Unwrap();
+  }
+
+  Result<double> layer_norm_eps_res =
+      json::LookupWithResultReturn<double>(json_obj, "layer_norm_eps");
+  if (layer_norm_eps_res.IsOk()) {
+    config.layer_norm_eps = layer_norm_eps_res.Unwrap();
+  }
+
+  return config;
+}
+
+/****************** Model config ******************/
+
+ModelConfig ModelConfig::FromJSON(const picojson::object& json_obj) {
+  ModelConfig config;
+
+  Result<int64_t> vocab_size_res = json::LookupWithResultReturn<int64_t>(json_obj, "vocab_size");
+  if (vocab_size_res.IsOk()) {
+    config.vocab_size = vocab_size_res.Unwrap();
+  }
+
+  Result<int64_t> context_window_size_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "context_window_size");
+  if (context_window_size_res.IsOk()) {
+    config.context_window_size = context_window_size_res.Unwrap();
+  }
+
+  Result<int64_t> sliding_window_size_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "sliding_window_size");
+  if (sliding_window_size_res.IsOk()) {
+    config.sliding_window_size = sliding_window_size_res.Unwrap();
+  }
+
+  Result<int64_t> prefill_chunk_size_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "prefill_chunk_size");
+  if (prefill_chunk_size_res.IsOk()) {
+    config.prefill_chunk_size = prefill_chunk_size_res.Unwrap();
+  }
+
+  Result<int64_t> tensor_parallel_shards_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "tensor_parallel_shards");
+  if (tensor_parallel_shards_res.IsOk()) {
+    config.tensor_parallel_shards = tensor_parallel_shards_res.Unwrap();
+  }
+
+  Result<int64_t> max_batch_size_res =
+      json::LookupWithResultReturn<int64_t>(json_obj, "max_batch_size");
+  if (max_batch_size_res.IsOk()) {
+    config.max_batch_size = max_batch_size_res.Unwrap();
+  }
+
+  if (json_obj.count("vision_config")) {
+    const picojson::object& vision_config_obj =
+        json_obj.at("vision_config").get<picojson::object>();
+    config.vision_config = ModelVisionConfig::FromJSON(vision_config_obj);
+  }
+
+  return config;
+}
+
 /****************** Conversation template ******************/

 std::map<MessagePlaceholders, std::string> PLACEHOLDERS = {
@@ -34,7 +153,7 @@ Conversation::Conversation()
     {"assistant", PLACEHOLDERS[MessagePlaceholders::ASSISTANT]},
     {"tool", PLACEHOLDERS[MessagePlaceholders::TOOL]}}) {}

-Result<std::vector<Data>> Conversation::AsPrompt() {
+Result<std::vector<Data>> Conversation::AsPrompt(ModelConfig config, DLDevice device) {
   using TResult = Result<std::vector<Data>>;
   // Get the system message
   std::string system_msg = system_template;
@@ -116,6 +235,29 @@ Result<std::vector<Data>> Conversation::AsPrompt() {
         }
       }
       message += role_text;
+    } else if (it_type->second == "image_url") {
+      if (item.find("image_url") == item.end()) {
+        return TResult::Error("Content should have an image_url field");
+      }
+      std::string image_url =
+          item.at("image_url");  // TODO(mlc-team): According to the OpenAI API reference
+                                 // this should be a map with a "url" key containing the
+                                 // URL, but we just treat it as the URL for now.
+      std::string base64_image = image_url.substr(image_url.find(",") + 1);
+      Result<NDArray> image_data_res = LoadImageFromBase64(base64_image);
+      if (image_data_res.IsErr()) {
+        return TResult::Error(image_data_res.UnwrapErr());
+      }
+      if (!config.vision_config.has_value()) {
+        return TResult::Error("Vision config is required for image input");
+      }
+      int image_size = config.vision_config.value().image_size;
+      int patch_size = config.vision_config.value().patch_size;
+
+      int embed_size = (image_size * image_size) / (patch_size * patch_size);
+
+      auto image_ndarray = ClipPreprocessor(image_data_res.Unwrap(), image_size, device);
+      message_list.push_back(ImageData(image_ndarray, embed_size));
     } else {
       return TResult::Error("Unsupported content type: " + it_type->second);
     }
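
The commit message's "new Result interface" is the pattern on display in the two FromJSON readers above: every optional field goes through json::LookupWithResultReturn<T>, and the returned Result<T> is only Unwrap()-ed after an IsOk() check, so absent keys leave the struct untouched instead of raising an error. A minimal caller sketch for the extended AsPrompt signature; this is hypothetical, not part of the commit, and conv and config_json_str are assumed inputs:

// Hypothetical driver: parse a model config, then build the prompt for CPU.
picojson::value config_json;
std::string err = picojson::parse(config_json, config_json_str);
ICHECK(err.empty()) << "Failed to parse model config: " << err;
ModelConfig model_config = ModelConfig::FromJSON(config_json.get<picojson::object>());

DLDevice device{kDLCPU, 0};
Result<std::vector<Data>> prompt_res = conv.AsPrompt(model_config, device);
if (prompt_res.IsErr()) {
  // e.g. "Vision config is required for image input" when an image_url
  // message arrives but the model config carries no vision_config.
  LOG(FATAL) << prompt_res.UnwrapErr();
}
std::vector<Data> prompt = prompt_res.Unwrap();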

cpp/json_ffi/conv_template.h

Lines changed: 38 additions & 1 deletion

@@ -19,6 +19,43 @@ namespace mlc {
 namespace llm {
 namespace json_ffi {

+/****************** Model vision config ******************/
+
+/*! \brief Defines the vision config of the model (if present). */
+class ModelVisionConfig {
+ public:
+  int hidden_size;
+  int image_size;
+  int intermediate_size;
+  int num_attention_heads;
+  int num_hidden_layers;
+  int patch_size;
+  int projection_dim;
+  int vocab_size;
+  std::string dtype;
+  int num_channels;
+  double layer_norm_eps;
+
+  static ModelVisionConfig FromJSON(const picojson::object& json_obj);
+};
+
+/****************** Model config ******************/
+
+/*! \brief Defines the config of the model.
+ *  Populated from the "model_config" field in mlc-chat-config.json. */
+class ModelConfig {
+ public:
+  int vocab_size;
+  int context_window_size;
+  int sliding_window_size;
+  int prefill_chunk_size;
+  int tensor_parallel_shards;
+  int max_batch_size;
+  std::optional<ModelVisionConfig> vision_config = std::nullopt;
+
+  static ModelConfig FromJSON(const picojson::object& json_obj);
+};
+
 /****************** Conversation template ******************/

 enum class MessagePlaceholders { SYSTEM, USER, ASSISTANT, TOOL, FUNCTION };
@@ -92,7 +129,7 @@ struct Conversation {
   Conversation();

   /*! \brief Create the list of prompts from the messages based on the conversation template. */
-  Result<std::vector<Data>> AsPrompt();
+  Result<std::vector<Data>> AsPrompt(ModelConfig config, DLDevice device);

   /*! \brief Create a Conversation instance from the given JSON object. */
   static Result<Conversation> FromJSON(const picojson::object& json);
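
The two classes mirror the "model_config" section of mlc-chat-config.json, with vision_config optional. To make the numbers concrete: with CLIP ViT-L/14-336-style values image_size = 336 and patch_size = 14, the embed_size computed in AsPrompt works out to (336 × 336) / (14 × 14) = 576 embedding positions per image. A sketch of feeding such a config through the new FromJSON, with an illustrative JSON literal (not copied from any shipped model):

// Illustrative config: field names match the struct members above.
const std::string kConfig = R"({
  "vocab_size": 32000,
  "context_window_size": 4096,
  "vision_config": { "image_size": 336, "patch_size": 14 }
})";

picojson::value v;
std::string err = picojson::parse(v, kConfig);
ICHECK(err.empty()) << err;
ModelConfig config = ModelConfig::FromJSON(v.get<picojson::object>());
// config.vision_config.has_value() is now true; fields absent from the JSON
// (hidden_size, dtype, ...) are simply left unset by FromJSON.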

cpp/json_ffi/image_utils.cc

Lines changed: 156 additions & 0 deletions

@@ -0,0 +1,156 @@
+#include "image_utils.h"
+
+#include <dmlc/io.h>
+
+#include "../../3rdparty/tvm/src/support/base64.h"
+#define STB_IMAGE_IMPLEMENTATION
+#include "stb_image.h"
+
+namespace mlc {
+namespace llm {
+namespace json_ffi {
+
+using namespace tvm::runtime;
+
+class MemoryBufferStream : public dmlc::Stream {
+ public:
+  MemoryBufferStream(const char* data, size_t size) : data_(data), size_(size), pos_(0) {}
+
+  size_t Read(void* ptr, size_t size) override {
+    size_t remaining = size_ - pos_;
+    if (size > remaining) {
+      size = remaining;
+    }
+    if (size == 0) {
+      return 0;
+    }
+    std::memcpy(ptr, data_ + pos_, size);
+    pos_ += size;
+    return size;
+  }
+
+  void Write(const void* ptr, size_t size) override {
+    LOG(FATAL) << "MemoryBufferStream does not support write";
+  }
+
+ private:
+  const char* data_;
+  size_t size_;
+  size_t pos_;
+};
+
+size_t Base64DecodedSize(const std::string& base64_str) {
+  size_t len = base64_str.size();
+  size_t padding = 0;
+  if (base64_str[len - 1] == '=') {
+    padding++;
+  }
+  if (base64_str[len - 2] == '=') {
+    padding++;
+  }
+  return 3 * len / 4 - padding;
+}
+
+Result<NDArray> LoadImageFromBase64(const std::string& base64_str) {
+  using TResult = Result<NDArray>;
+  MemoryBufferStream stream(base64_str.c_str(), base64_str.size());
+  tvm::support::Base64InStream base64_stream(&stream);
+  size_t decoded_size = Base64DecodedSize(base64_str);
+  std::vector<unsigned char> decoded(decoded_size);
+  base64_stream.InitPosition();
+  base64_stream.Read((void*)decoded.data(), decoded_size);
+  int width, height, num_channels;
+  unsigned char* image_data =
+      stbi_load_from_memory(decoded.data(), decoded_size, &width, &height, &num_channels, 3);
+  if (!image_data) {
+    return TResult::Error(stbi_failure_reason());
+  }
+  auto image_ndarray = NDArray::Empty({height, width, 3}, {kDLUInt, 8, 1}, {kDLCPU, 0});
+  image_ndarray.CopyFromBytes((void*)image_data, width * height * 3);
+  stbi_image_free(image_data);
+  return TResult::Ok(image_ndarray);
+}
+
+NDArray ClipPreprocessor(NDArray image_data, int target_size, DLDevice device) {
+  int height = image_data->shape[0];
+  int width = image_data->shape[1];
+  // Resize so the short side matches target_size, preserving the aspect ratio.
+  const int short_side = width < height ? width : height;
+  const int long_side = width > height ? width : height;
+  const int new_short_side = target_size;
+  const int new_long_side = (int)(new_short_side * (long_side / (float)short_side));
+  const int new_width = width < height ? new_short_side : new_long_side;
+  const int new_height = width > height ? new_short_side : new_long_side;
+
+  std::vector<float> processed_image_data(new_width * new_height * 3);
+
+  // Bilinear interpolation
+  for (int y = 0; y < new_height; y++) {
+    for (int x = 0; x < new_width; x++) {
+      const float x_ratio = float(width - 1) / new_width;
+      const float y_ratio = float(height - 1) / new_height;
+      const int x1 = int(x_ratio * x);
+      const int y1 = int(y_ratio * y);
+      const int x2 = x1 + 1;
+      const int y2 = y1 + 1;
+      const float x_diff = x_ratio * x - x1;
+      const float y_diff = y_ratio * y - y1;
+      for (int c = 0; c < 3; c++) {
+        const uint8_t top_left = ((uint8_t*)image_data->data)[(y1 * width + x1) * 3 + c];
+        const uint8_t top_right = ((uint8_t*)image_data->data)[(y1 * width + x2) * 3 + c];
+        const uint8_t bottom_left = ((uint8_t*)image_data->data)[(y2 * width + x1) * 3 + c];
+        const uint8_t bottom_right = ((uint8_t*)image_data->data)[(y2 * width + x2) * 3 + c];
+        processed_image_data[(y * new_width + x) * 3 + c] =
+            (float)(int(top_left * (1 - x_diff) * (1 - y_diff) + top_right * x_diff * (1 - y_diff) +
+                        bottom_left * y_diff * (1 - x_diff) + bottom_right * x_diff * y_diff));
+      }
+    }
+  }
+
+  // Center crop to a target_size x target_size square
+  const int crop_x = (new_width - target_size) / 2;
+  const int crop_y = (new_height - target_size) / 2;
+  std::vector<float> cropped_image_data(target_size * target_size * 3);
+  for (int y = 0; y < target_size; y++) {
+    for (int x = 0; x < target_size; x++) {
+      for (int c = 0; c < 3; c++) {
+        cropped_image_data[(y * target_size + x) * 3 + c] =
+            processed_image_data[((y + crop_y) * new_width + x + crop_x) * 3 + c];
+      }
+    }
+  }
+
+  // Rescale to [0, 1]
+  for (int i = 0; i < target_size * target_size * 3; i++) {
+    cropped_image_data[i] = cropped_image_data[i] / 255.0f;
+  }
+
+  // Normalize with the CLIP per-channel mean and std
+  const float IMAGE_MEAN[] = {0.48145466f, 0.4578275f, 0.40821073f};
+  const float IMAGE_STD[] = {0.26862954f, 0.26130258f, 0.27577711f};
+  for (int i = 0; i < target_size * target_size * 3; i++) {
+    const int c = i % 3;
+    cropped_image_data[i] = (cropped_image_data[i] - IMAGE_MEAN[c]) / IMAGE_STD[c];
+  }
+
+  // Transpose HWC -> CHW
+  std::vector<float> image_data_channel_first(target_size * target_size * 3);
+  for (int y = 0; y < target_size; y++) {
+    for (int x = 0; x < target_size; x++) {
+      for (int c = 0; c < 3; c++) {
+        image_data_channel_first[c * target_size * target_size + y * target_size + x] =
+            cropped_image_data[(y * target_size + x) * 3 + c];
+      }
+    }
+  }
+
+  // Create the output NDArray on the target device
+  auto image_ndarray = NDArray::Empty({1, 3, target_size, target_size}, {kDLFloat, 32, 1}, device);
+  image_ndarray.CopyFromBytes((void*)image_data_channel_first.data(),
+                              target_size * target_size * 3 * sizeof(float));
+
+  return image_ndarray;
+}
+
+}  // namespace json_ffi
+}  // namespace llm
+}  // namespace mlc
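
Two details worth calling out. Base64DecodedSize is the standard size formula: every 4 base64 characters carry 3 bytes, minus one byte per trailing '='; for example, a 12-character string ending in a single '=' decodes to 3 × 12 / 4 − 1 = 8 bytes. The preprocessing chain itself is the usual CLIP recipe: resize so the short side hits target_size, center-crop to a square, rescale to [0, 1], normalize with the CLIP per-channel mean/std, and transpose HWC to NCHW. A hedged end-to-end sketch of the two helpers, assuming base64_image already holds the payload after the "base64," prefix of a data: URL:

// Decode the image bytes, then produce a (1, 3, S, S) float32 tensor.
Result<NDArray> decoded = LoadImageFromBase64(base64_image);
if (decoded.IsErr()) {
  LOG(FATAL) << decoded.UnwrapErr();  // stb's failure reason, e.g. unknown format
}
// In real use, 336 comes from ModelVisionConfig::image_size.
NDArray pixel_values =
    ClipPreprocessor(decoded.Unwrap(), /*target_size=*/336, DLDevice{kDLCPU, 0});
// pixel_values has shape {1, 3, 336, 336}, dtype float32, CLIP-normalized.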
