67 changes: 61 additions & 6 deletions cpp/serve/engine.cc
@@ -12,6 +12,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/runtime/threading_backend.h>

#include <numeric>
#include <optional>
#include <tuple>
#include <unordered_set>
@@ -63,10 +64,19 @@ class EngineImpl : public Engine {
this->models_.clear();
this->model_workspaces_.clear();

auto f_create_model = [this, &engine_config, &device, &trace_recorder](
const String& model_path, const String& model_lib_path) {
Model model = Model::Create(model_lib_path, std::move(model_path), device,
engine_config->max_num_sequence,
std::vector<picojson::object> model_configs;
model_configs.push_back(Model::LoadModelConfig(engine_config->model));
for (const auto& model_path : engine_config->additional_models) {
model_configs.push_back(Model::LoadModelConfig(model_path));
}

Optional<Session> session = CreateDiscoSession(model_configs, device);

auto f_create_model = [this, &engine_config, &device, &trace_recorder, &model_configs,
&session](const String& model_path, const String& model_lib_path,
int model_index) {
Model model = Model::Create(model_lib_path, std::move(model_path), model_configs[model_index],
device, engine_config->max_num_sequence, session,
/*trace_enabled=*/trace_recorder.defined());
model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence,
engine_config->max_total_sequence_length,
@@ -81,13 +91,13 @@ class EngineImpl : public Engine {
ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()});
};

f_create_model(engine_config->model, engine_config->model_lib_path);
f_create_model(engine_config->model, engine_config->model_lib_path, /*model_index=*/0);
CHECK_EQ(engine_config->additional_models.size(),
engine_config->additional_model_lib_paths.size())
<< "The additional model and lib path list has mismatched size.";
for (int i = 0; i < static_cast<int>(engine_config->additional_models.size()); ++i) {
f_create_model(engine_config->additional_models[i],
engine_config->additional_model_lib_paths[i]);
engine_config->additional_model_lib_paths[i], /*model_index=*/i + 1);
}

int max_num_tokens = engine_config->max_num_sequence;
@@ -287,6 +297,51 @@ class EngineImpl : public Engine {
"action (e.g. prefill, decode, etc.) but it does not.";
}

/************** Utility Functions **************/
Optional<Session> CreateDiscoSession(std::vector<picojson::object> model_configs, Device device) {
const auto& base_model_config = model_configs[0];

auto f_get_num_shards = [](const picojson::object& model_config) -> int {
constexpr auto kNumShardsKey = "tensor_parallel_shards";
if (model_config.count(kNumShardsKey)) {
const auto& val = model_config.at(kNumShardsKey);
CHECK(val.is<int64_t>());
return static_cast<int>(val.get<int64_t>());
} else {
LOG(FATAL) << "Key \"tensor_parallel_shards\" not found.";
}
throw;
};

int num_shards = std::transform_reduce(
model_configs.begin(), model_configs.end(), 1, [](int a, int b) { return std::max(a, b); },
f_get_num_shards);
Optional<Session> session = NullOpt;
if (num_shards > 1) {
constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool";
if (Registry::Get(f_create_process_pool) == nullptr) {
LOG(FATAL) << "Cannot find process launcher `" << f_create_process_pool << "`. "
<< "Multi-GPU inference depends on MLC LLM Python API to launch process.";
}
std::string ccl;
if (device.device_type == kDLCUDA) {
ccl = "nccl";
} else if (device.device_type == kDLROCM) {
ccl = "rccl";
} else {
LOG(FATAL) << "ValueError: Multi-GPU on device " << DLDeviceType2Str(device.device_type)
<< " is not supported. Currently, only NCCL and RCCL are integrated.";
}
std::vector<int64_t> device_ids(num_shards);
for (int i = 0; i < num_shards; ++i) {
device_ids[i] = i;
}
session = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker");
session.value()->InitCCL(ccl, ShapeTuple(device_ids));
}
return session;
}

/************** Debug/Profile **************/

void DebugCallFuncOnAllAllWorker(const String& func_name) final {
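The CreateDiscoSession helper added above derives the Disco worker count as the maximum "tensor_parallel_shards" value across all model configs. A minimal standalone sketch of that std::transform_reduce pattern, with hypothetical shard counts standing in for the values read from each model's mlc-chat-config.json:

// Standalone sketch (not part of the diff): the reduction used by
// CreateDiscoSession to pick the number of Disco workers.
#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical per-model shard counts.
  std::vector<int> shards_per_model = {4, 1, 2};
  int num_shards = std::transform_reduce(
      shards_per_model.begin(), shards_per_model.end(), /*init=*/1,
      [](int a, int b) { return std::max(a, b); },  // reduce: keep the maximum
      [](int s) { return s; });                     // transform: identity for plain ints
  std::cout << "Disco workers to launch: " << num_shards << "\n";  // prints 4
  return 0;
}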
24 changes: 3 additions & 21 deletions cpp/serve/function_table.cc
@@ -69,7 +69,8 @@ PackedFunc FunctionTable::SessionFuncAsPackedFunc(Session sess, DRef sess_func,
});
}

void FunctionTable::Init(String reload_lib_path, Device device, picojson::object model_config) {
void FunctionTable::Init(String reload_lib_path, Device device, picojson::object model_config,
Optional<Session> session) {
local_gpu_device = device;
Device null_device{DLDeviceType(0), 0};
int num_shards;
@@ -85,27 +86,8 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object
this->cached_buffers = Map<String, ObjectRef>();

if (num_shards > 1) {
constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool";
if (Registry::Get(f_create_process_pool) == nullptr) {
LOG(FATAL) << "Cannot find process launcher `" << f_create_process_pool << "`. "
<< "Multi-GPU inference depends on MLC LLM Python API to launch process.";
}
std::string ccl;
if (device.device_type == kDLCUDA) {
ccl = "nccl";
} else if (device.device_type == kDLROCM) {
ccl = "rccl";
} else {
LOG(FATAL) << "ValueError: Multi-GPU on device " << DLDeviceType2Str(device.device_type)
<< " is not supported. Currently, only NCCL and RCCL are integrated.";
}
std::vector<int64_t> device_ids(num_shards);
for (int i = 0; i < num_shards; ++i) {
device_ids[i] = i;
}
this->sess = session.value();
this->use_disco = true;
this->sess = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker");
this->sess->InitCCL(ccl, ShapeTuple(device_ids));
this->disco_mod = sess->CallPacked(sess->GetGlobalFunc("runtime.disco.load_vm_module"),
reload_lib_path, null_device);
this->mod_get_func = [this,
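For reference, the calling convention after this change: the engine creates one session up front (or NullOpt for single-GPU runs) and each model's function table attaches to it instead of building its own. A simplified sketch following the signatures in this diff, not verbatim engine code:

// Simplified engine-side sketch; error handling and KV-cache setup omitted.
Optional<Session> session = CreateDiscoSession(model_configs, device);  // NullOpt when num_shards == 1
FunctionTable ft;
ft.Init(reload_lib_path, device, model_configs[0], session);
// Inside Init, the multi-GPU branch now reduces to:
//   this->sess = session.value();
//   this->use_disco = true;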
3 changes: 2 additions & 1 deletion cpp/serve/function_table.h
@@ -41,7 +41,8 @@ using namespace tvm::runtime;
struct FunctionTable {
static PackedFunc SessionFuncAsPackedFunc(Session sess, DRef sess_func, String name);

void Init(String reload_lib_path, Device device, picojson::object model_config);
void Init(String reload_lib_path, Device device, picojson::object model_config,
Optional<Session> session);

ObjectRef LoadParams(const std::string& model_path, Device device);

52 changes: 27 additions & 25 deletions cpp/serve/model.cc
@@ -26,10 +26,27 @@ class ModelImpl;

TVM_REGISTER_OBJECT_TYPE(ModelObj);

Model Model::Create(String reload_lib_path, String model_path, DLDevice device,
int max_num_sequence, bool trace_enabled) {
return Model(
make_object<ModelImpl>(reload_lib_path, model_path, device, max_num_sequence, trace_enabled));
Model Model::Create(String reload_lib_path, String model_path, const picojson::object& model_config,
DLDevice device, int max_num_sequence, const Optional<Session>& session,
bool trace_enabled) {
return Model(make_object<ModelImpl>(reload_lib_path, model_path, model_config, device,
max_num_sequence, session, trace_enabled));
}

picojson::object Model::LoadModelConfig(const String& model_path) {
picojson::object model_config;
std::ifstream config_istream((model_path + "/mlc-chat-config.json").c_str());
std::ostringstream config_ostream;
ICHECK(config_istream);
config_ostream << config_istream.rdbuf();
std::string config_str = config_ostream.str();
picojson::value config_json;
std::string err = picojson::parse(config_json, config_str);
if (!err.empty()) {
LOG(FATAL) << err;
}
picojson::object config = config_json.get<picojson::object>();
return config;
}

class ModelImpl : public ModelObj {
@@ -38,23 +55,16 @@ class ModelImpl : public ModelObj {
* \brief Constructor of ModelImpl.
* \sa Model::Create
*/
explicit ModelImpl(String reload_lib_path, String model_path, DLDevice device,
int max_num_sequence, bool trace_enabled)
explicit ModelImpl(String reload_lib_path, String model_path, picojson::object model_config,
DLDevice device, int max_num_sequence, const Optional<Session>& session,
bool trace_enabled)
: device_(device) {
// Step 1. Process model config json string.
picojson::object model_config;
{
std::ifstream config_istream((model_path + "/mlc-chat-config.json").c_str());
std::ostringstream config_ostream;
ICHECK(config_istream);
config_ostream << config_istream.rdbuf();
std::string config_str = config_ostream.str();
model_config = LoadModelConfigJSON(config_str);
}
LoadModelConfigJSON(model_config);
// Step 2. Initialize vm, we use the packed function mechanism
// so there is no explicit abi dependency on these extra
// classes other than basic tvm runtime.
this->ft_.Init(reload_lib_path, device_, model_config);
this->ft_.Init(reload_lib_path, device_, model_config, session);
// Step 3. Load params in nd-array cache.
this->params_ = ft_.LoadParams(model_path, device_);
// Step 4. Set max_num_sequence
@@ -891,15 +901,7 @@ class ModelImpl : public ModelObj {

private:
/*! \brief Load model configuration from JSON. */
picojson::object LoadModelConfigJSON(const std::string& config_str) {
picojson::value config_json;
std::string err = picojson::parse(config_json, config_str);
if (!err.empty()) {
LOG(FATAL) << err;
}

// Get json fields.
picojson::object config = config_json.get<picojson::object>();
picojson::object LoadModelConfigJSON(picojson::object config) {
if (config.count("context_window_size")) {
CHECK(config["context_window_size"].is<int64_t>());
this->max_window_size_ = config["context_window_size"].get<int64_t>();
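With the config now parsed once up front in Model::LoadModelConfig, LoadModelConfigJSON only consumes the already-parsed object, and field access follows picojson's count/is/get pattern. A small self-contained sketch of that pattern (hypothetical helper, not part of the diff), assuming PICOJSON_USE_INT64 is defined as the is<int64_t>() checks above imply:

// Sketch of the picojson field-access pattern used above.
#define PICOJSON_USE_INT64  // needed so integer JSON fields can be read as int64_t
#include <picojson.h>
#include <stdexcept>
#include <string>

int GetIntField(const picojson::object& config, const std::string& key, int fallback) {
  auto it = config.find(key);
  if (it == config.end()) {
    return fallback;  // field absent, e.g. an optional config entry
  }
  if (!it->second.is<int64_t>()) {
    throw std::runtime_error(key + " is not an integer");  // the real code uses CHECK/LOG(FATAL)
  }
  return static_cast<int>(it->second.get<int64_t>());
}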
15 changes: 13 additions & 2 deletions cpp/serve/model.h
@@ -319,13 +319,24 @@ class Model : public ObjectRef {
* \brief Create the runtime module for LLM functions.
* \param reload_lib_path The model library path.
* \param model_path The path to the model weight parameters.
* \param model_config The model config json object.
* \param device The device to run the model on.
* \param max_num_sequence The maximum number of sequences to be processed.
* \param session The session to run the model on.
* \param trace_enabled A boolean indicating whether tracing is enabled.
* \return The created runtime module.
*/
TVM_DLL static Model Create(String reload_lib_path, String model_path, DLDevice device,
int max_num_sequence, bool trace_enabled);
TVM_DLL static Model Create(String reload_lib_path, String model_path,
const picojson::object& model_config, DLDevice device,
int max_num_sequence, const Optional<Session>& session,
bool trace_enabled);

/*!
* \brief Load the model config from the given model path.
* \param model_path The path to the model weight parameters.
* \return The model config json object.
*/
static picojson::object LoadModelConfig(const String& model_path);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Model, ObjectRef, ModelObj);
};
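Taken together, the reworked interface is used in two steps: load each model's config once, then pass the parsed object and the shared session into Model::Create. A rough usage sketch following the declarations above (engine-side flow simplified; engine_config, device, and CreateDiscoSession come from the engine code in this PR, error handling omitted):

// Usage sketch of the reworked Model interface.
std::vector<picojson::object> model_configs;
model_configs.push_back(Model::LoadModelConfig(engine_config->model));
for (const String& path : engine_config->additional_models) {
  model_configs.push_back(Model::LoadModelConfig(path));
}
Optional<Session> session = CreateDiscoSession(model_configs, device);  // shared by all models
Model model = Model::Create(engine_config->model_lib_path, engine_config->model,
                            model_configs[0], device, engine_config->max_num_sequence,
                            session, /*trace_enabled=*/false);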