9 changes: 5 additions & 4 deletions android/library/prepare_model_lib.py
@@ -1,5 +1,6 @@
import json
import os

from tvm.contrib import ndk


@@ -23,8 +24,8 @@ def main():
tar_list = []
model_set = set()

for model, model_lib_path in app_config["model_lib_path_for_prepare_libs"].items():
path = os.path.join(artifact_path, model_lib_path)
for model, model_lib in app_config["model_lib_path_for_prepare_libs"].items():
path = os.path.join(artifact_path, model_lib)
if not os.path.isfile(path):
raise RuntimeError(f"Cannot find android library {path}")
tar_list.append(path)
@@ -58,11 +59,11 @@ def main():
model_prefix_pattern not in global_symbol_map
and "_" + model_prefix_pattern not in global_symbol_map
):
model_lib_path = app_config["model_lib_path_for_prepare_libs"][model_lib]
model_lib = app_config["model_lib_path_for_prepare_libs"][model_lib]
print(
"ValidationError:\n"
f"\tmodel_lib {model_lib} requested in {app_config_path} is not found in {lib_path}\n"
f"\tspecifically the model_lib for {model_lib_path} in model_lib_path_for_prepare_libs.\n"
f"\tspecifically the model_lib for {model_lib} in model_lib_path_for_prepare_libs.\n"
f"\tcurrent available model_libs in {lib_path}: {available_model_libs}"
)
error_happened = True
43 changes: 2 additions & 41 deletions cpp/json_ffi/config.cc → cpp/json_ffi/conv_template.cc
@@ -1,36 +1,15 @@
#include "config.h"
#include "conv_template.h"

#include <tvm/runtime/registry.h>

#include "../metadata/json_parser.h"
#include "../support/json_parser.h"

namespace mlc {
namespace llm {
namespace json_ffi {

using namespace mlc::llm;

/****************** Model-defined generation config ******************/

TVM_REGISTER_OBJECT_TYPE(ModelDefinedGenerationConfigNode);

ModelDefinedGenerationConfig::ModelDefinedGenerationConfig(double temperature, double top_p,
double frequency_penalty,
double presence_penalty) {
ObjectPtr<ModelDefinedGenerationConfigNode> n = make_object<ModelDefinedGenerationConfigNode>();
n->temperature = temperature;
n->top_p = top_p;
n->frequency_penalty = frequency_penalty;
n->presence_penalty = presence_penalty;
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("mlc.json_ffi.ModelDefinedGenerationConfig")
.set_body_typed([](double temperature, double top_p, double frequency_penalty,
double presence_penalty) {
return ModelDefinedGenerationConfig(temperature, top_p, frequency_penalty, presence_penalty);
});

/****************** Conversation template ******************/

std::map<MessagePlaceholders, std::string> PLACEHOLDERS = {
@@ -334,24 +313,6 @@ std::optional<Conversation> Conversation::FromJSON(const std::string& json_str,
return Conversation::FromJSON(json_obj.value(), err);
}

/****************** JSON FFI engine config ******************/

TVM_REGISTER_OBJECT_TYPE(JSONFFIEngineConfigNode);

JSONFFIEngineConfig::JSONFFIEngineConfig(
String conv_template, Map<String, ModelDefinedGenerationConfig> model_generation_cfgs) {
ObjectPtr<JSONFFIEngineConfigNode> n = make_object<JSONFFIEngineConfigNode>();
n->conv_template = conv_template;
n->model_generation_cfgs = model_generation_cfgs;
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("mlc.json_ffi.JSONFFIEngineConfig")
.set_body_typed([](String conv_template,
Map<String, ModelDefinedGenerationConfig> model_generation_cfgs) {
return JSONFFIEngineConfig(std::move(conv_template), std::move(model_generation_cfgs));
});

} // namespace json_ffi
} // namespace llm
} // namespace mlc
57 changes: 4 additions & 53 deletions cpp/json_ffi/config.h → cpp/json_ffi/conv_template.h
@@ -1,9 +1,5 @@
#ifndef MLC_LLM_JSON_FFI_CONFIG_H
#define MLC_LLM_JSON_FFI_CONFIG_H

#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/object.h>
#ifndef MLC_LLM_JSON_FFI_CONV_TEMPLATE_H
#define MLC_LLM_JSON_FFI_CONV_TEMPLATE_H

#include <iostream>
#include <map>
@@ -22,35 +18,11 @@ namespace mlc {
namespace llm {
namespace json_ffi {

/****************** Model-defined generation config ******************/

class ModelDefinedGenerationConfigNode : public Object {
public:
double temperature;
double top_p;
double frequency_penalty;
double presence_penalty;

static constexpr const char* _type_key = "mlc.json_ffi.ModelDefinedGenerationConfig";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_BASE_OBJECT_INFO(ModelDefinedGenerationConfigNode, Object);
};

class ModelDefinedGenerationConfig : public ObjectRef {
public:
explicit ModelDefinedGenerationConfig(double temperature, double top_p, double frequency_penalty,
double presence_penalty);

TVM_DEFINE_OBJECT_REF_METHODS(ModelDefinedGenerationConfig, ObjectRef,
ModelDefinedGenerationConfigNode);
};

/****************** Conversation template ******************/

enum class MessagePlaceholders { SYSTEM, USER, ASSISTANT, TOOL, FUNCTION };

MessagePlaceholders messagePlaceholderFromString(const std::string& role);
MessagePlaceholders MessagePlaceholderFromString(const std::string& role);

class Message {
public:
@@ -144,29 +116,8 @@ struct Conversation {
static std::optional<Conversation> FromJSON(const std::string& json_str, std::string* err);
};

/****************** JSON FFI engine config ******************/

class JSONFFIEngineConfigNode : public Object {
public:
String conv_template;
Map<String, ModelDefinedGenerationConfig> model_generation_cfgs;

static constexpr const char* _type_key = "mlc.json_ffi.JSONFFIEngineConfig";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_BASE_OBJECT_INFO(JSONFFIEngineConfigNode, Object);
};

class JSONFFIEngineConfig : public ObjectRef {
public:
explicit JSONFFIEngineConfig(String conv_template,
Map<String, ModelDefinedGenerationConfig> model_generation_cfgs);

TVM_DEFINE_OBJECT_REF_METHODS(JSONFFIEngineConfig, ObjectRef, JSONFFIEngineConfigNode);
};

} // namespace json_ffi
} // namespace llm
} // namespace mlc

#endif /* MLC_LLM_JSON_FFI_CONV_TEMPLATE_H */
#endif // MLC_LLM_JSON_FFI_CONV_TEMPLATE_H
71 changes: 47 additions & 24 deletions cpp/json_ffi/json_ffi_engine.cc
@@ -4,6 +4,10 @@
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>

#include "../serve/model.h"
#include "../support/json_parser.h"
#include "../support/result.h"

namespace mlc {
namespace llm {
namespace json_ffi {
@@ -83,13 +87,27 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
Array<Data> inputs = inputs_obj.value();

// generation_cfg
Optional<GenerationConfig> generation_cfg = GenerationConfig::Create(
request_json_str, &err_, conv_template, this->model_generation_cfgs[request.model]);
if (!generation_cfg.defined()) {
return false;
Array<String> stop_strs;
stop_strs.reserve(conv_template.stop_str.size());
for (const std::string& stop_str : conv_template.stop_str) {
stop_strs.push_back(stop_str);
}
if (request.stop.has_value()) {
stop_strs.reserve(stop_strs.size() + request.stop.value().size());
for (const std::string& stop_str : request.stop.value()) {
stop_strs.push_back(stop_str);
}
}

Request engine_request(request_id, inputs, generation_cfg.value());
GenerationConfig generation_cfg(request.n, request.temperature, request.top_p,
request.frequency_penalty, request.presence_penalty,
/*repetition_penalty=*/std::nullopt, request.logprobs,
request.top_logprobs, request.logit_bias, request.seed,
request.ignore_eos, request.max_tokens, std::move(stop_strs),
conv_template.stop_token_ids, /*response_format=*/std::nullopt,
this->default_generation_cfg_json_str_);

Request engine_request(request_id, inputs, generation_cfg);
this->engine_->AddRequest(engine_request);

return true;
@@ -122,22 +140,8 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop);
TVM_MODULE_VTABLE_END();

void InitBackgroundEngine(JSONFFIEngineConfig json_ffi_engine_config, EngineConfig engine_config,
Device device, Optional<PackedFunc> request_stream_callback,
void InitBackgroundEngine(Device device, Optional<PackedFunc> request_stream_callback,
Optional<EventTraceRecorder> trace_recorder) {
std::optional<Conversation> conv_template =
Conversation::FromJSON(json_ffi_engine_config->conv_template, &err_);
if (!conv_template.has_value()) {
LOG(FATAL) << "Invalid conversation template JSON: " << err_;
}
this->conv_template_ = conv_template.value();
this->model_generation_cfgs = json_ffi_engine_config->model_generation_cfgs;

// Todo(mlc-team): decouple InitBackgroundEngine into two functions
// by removing `engine_config` from arguments, after properly handling
// streamers.
this->streamer_ = TextStreamer(Tokenizer::FromPath(engine_config->model));

CHECK(request_stream_callback.defined())
<< "JSONFFIEngine requires request stream callback function, but it is not given.";
this->request_stream_callback_ = request_stream_callback.value();
@@ -150,12 +154,31 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
};

request_stream_callback = PackedFunc(frequest_stream_callback_wrapper);
this->engine_->InitBackgroundEngine(device, std::move(request_stream_callback),
std::move(trace_recorder));
this->engine_->Reload(std::move(engine_config));
this->engine_->InitThreadedEngine(device, std::move(request_stream_callback),
std::move(trace_recorder));
}

void Reload(EngineConfig engine_config) { this->engine_->Reload(std::move(engine_config)); }
void Reload(String engine_config_json_str) {
this->engine_->Reload(engine_config_json_str);
this->default_generation_cfg_json_str_ = this->engine_->GetDefaultGenerationConfigJSONString();
picojson::object engine_config_json =
json::ParseToJsonObject(this->engine_->GetCompleteEngineConfigJSONString());

// Load conversation template.
Result<picojson::object> model_config_json =
serve::Model::LoadModelConfig(json::Lookup<std::string>(engine_config_json, "model"));
CHECK(model_config_json.IsOk()) << model_config_json.UnwrapErr();
std::optional<Conversation> conv_template = Conversation::FromJSON(
json::Lookup<picojson::object>(model_config_json.Unwrap(), "conv_template"), &err_);
if (!conv_template.has_value()) {
LOG(FATAL) << "Invalid conversation template JSON: " << err_;
}
this->conv_template_ = conv_template.value();
// Create streamer.
// Todo(mlc-team): Create one streamer for each request, instead of a global one.
this->streamer_ =
TextStreamer(Tokenizer::FromPath(json::Lookup<std::string>(engine_config_json, "model")));
}

void Unload() { this->engine_->Unload(); }

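
The Reload hunk above replaces the EngineConfig argument with a single engine-config JSON string. Below is a standalone sketch (not part of the diff) of a minimal argument for the new signature: the only key implied by this hunk is "model", which Reload uses to load the model config and tokenizer; the path shown and the call site are assumptions, not a verified schema.

#include <string>

int main() {
  // Minimal engine-config JSON for the new Reload(String) signature.
  // Only the "model" key is implied by the diff; the value is illustrative.
  std::string engine_config_json_str = R"({
    "model": "dist/Llama-3-8B-Instruct-q4f16_1-MLC"
  })";
  // Hypothetical call site: a JSONFFIEngineImpl instance would receive this
  // string through its module interface.
  // engine->Reload(engine_config_json_str);
  return 0;
}
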
4 changes: 2 additions & 2 deletions cpp/json_ffi/json_ffi_engine.h
@@ -12,7 +12,7 @@

#include "../serve/threaded_engine.h"
#include "../streamer.h"
#include "config.h"
#include "conv_template.h"
#include "openai_api_protocol.h"

namespace mlc {
@@ -49,7 +49,7 @@ class JSONFFIEngine {
PackedFunc request_stream_callback_;
TextStreamer streamer_; // TODO: Support "n", and support different streamers for each request
Conversation conv_template_;
Map<String, ModelDefinedGenerationConfig> model_generation_cfgs;
String default_generation_cfg_json_str_;
};

} // namespace json_ffi
2 changes: 1 addition & 1 deletion cpp/json_ffi/openai_api_protocol.cc
@@ -5,7 +5,7 @@
*/
#include "openai_api_protocol.h"

#include "../metadata/json_parser.h"
#include "../support/json_parser.h"

namespace mlc {
namespace llm {
4 changes: 2 additions & 2 deletions cpp/json_ffi/openai_api_protocol.h
@@ -13,7 +13,7 @@
#include <unordered_map>
#include <vector>

#include "config.h"
#include "conv_template.h"
#include "picojson.h"

namespace mlc {
@@ -94,7 +94,7 @@ class ChatCompletionRequest {
std::optional<double> presence_penalty = std::nullopt;
bool logprobs = false;
int top_logprobs = 0;
std::optional<std::unordered_map<int, double>> logit_bias = std::nullopt;
std::optional<std::vector<std::pair<int, float>>> logit_bias = std::nullopt;
std::optional<int> max_tokens = std::nullopt;
int n = 1;
std::optional<int> seed = std::nullopt;
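
The logit_bias field changes from a hash map keyed by token id to an ordered vector of (token id, bias) pairs. A standalone sketch (not part of the diff) of filling the new representation follows; the token ids and bias values are purely illustrative.

#include <optional>
#include <utility>
#include <vector>

int main() {
  // New representation: ordered (token_id, bias) pairs instead of an
  // unordered_map<int, double>; insertion order is preserved and float suffices.
  std::optional<std::vector<std::pair<int, float>>> logit_bias = std::nullopt;

  std::vector<std::pair<int, float>> bias;
  bias.emplace_back(50256, -100.0f);  // illustrative token id and bias
  bias.emplace_back(198, 2.5f);       // illustrative token id and bias
  logit_bias = std::move(bias);
  return 0;
}
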
27 changes: 15 additions & 12 deletions cpp/metadata/model.cc
@@ -4,7 +4,7 @@

#include <unordered_map>

#include "./json_parser.h"
#include "../support/json_parser.h"

namespace mlc {
namespace llm {
@@ -39,6 +39,16 @@ ModelMetadata::Param ModelMetadata::Param::FromJSON(const picojson::object& param
return result;
}

ModelMetadata::KVCacheMetadata ModelMetadata::KVCacheMetadata::FromJSON(
const picojson::object& json) {
KVCacheMetadata kv_cache_metadata;
kv_cache_metadata.num_hidden_layers = json::Lookup<int64_t>(json, "num_hidden_layers");
kv_cache_metadata.head_dim = json::Lookup<int64_t>(json, "head_dim");
kv_cache_metadata.num_attention_heads = json::Lookup<int64_t>(json, "num_attention_heads");
kv_cache_metadata.num_key_value_heads = json::Lookup<int64_t>(json, "num_key_value_heads");
return kv_cache_metadata;
}

ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata,
const picojson::object& model_config) {
ModelMetadata result;
@@ -53,6 +63,8 @@ ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata,
if (metadata.count("attention_sink_size")) // remove after sink is decoupled from model lib
result.attention_sink_size = json::Lookup<int64_t>(metadata, "attention_sink_size");
result.tensor_parallel_shards = json::Lookup<int64_t>(metadata, "tensor_parallel_shards");
result.kv_cache_metadata =
KVCacheMetadata::FromJSON(json::Lookup<picojson::object>(metadata, "kv_cache"));
{
std::vector<ModelMetadata::Param>& params = result.params;
picojson::array json_params = json::Lookup<picojson::array>(metadata, "params");
@@ -76,17 +88,8 @@
ModelMetadata ModelMetadata::FromModule(tvm::runtime::Module module,
const picojson::object& model_config) {
std::string json_str = "";
try {
TypedPackedFunc<String()> pf = module.GetFunction("_metadata");
if (pf == nullptr) {
// legacy path
// TODO: remove this after full SLMify
return ModelMetadata();
}
json_str = pf();
} catch (...) {
return ModelMetadata(); // TODO: add a warning message about legacy usecases
}
TypedPackedFunc<String()> pf = module.GetFunction("_metadata");
json_str = pf();
picojson::object json = json::ParseToJsonObject(json_str);
try {
return ModelMetadata::FromJSON(json, model_config);
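
The new KVCacheMetadata::FromJSON above reads a "kv_cache" object out of the model-lib metadata. A standalone sketch (not part of the diff) of the shape it expects follows, assuming the default picojson configuration; the concrete numbers are placeholders, not values from any real model.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <string>

#include "picojson.h"

int main() {
  // Hypothetical metadata fragment mirroring the fields looked up in
  // KVCacheMetadata::FromJSON; all values are placeholders.
  std::string metadata_json = R"({
    "kv_cache": {
      "num_hidden_layers": 32,
      "head_dim": 128,
      "num_attention_heads": 32,
      "num_key_value_heads": 8
    }
  })";

  picojson::value parsed;
  std::string err = picojson::parse(parsed, metadata_json);
  assert(err.empty());

  const picojson::object& kv_cache =
      parsed.get<picojson::object>().at("kv_cache").get<picojson::object>();
  // Without PICOJSON_USE_INT64, JSON numbers come back as double.
  auto head_dim = static_cast<int64_t>(kv_cache.at("head_dim").get<double>());
  std::cout << "head_dim = " << head_dim << "\n";
  return 0;
}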