diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc
index c9588cc4e8..d82c886355 100644
--- a/cpp/serve/engine.cc
+++ b/cpp/serve/engine.cc
@@ -12,6 +12,7 @@
 #include
 #include
+#include <numeric>
 #include
 #include
 #include
@@ -63,10 +64,19 @@ class EngineImpl : public Engine {
     this->models_.clear();
     this->model_workspaces_.clear();

-    auto f_create_model = [this, &engine_config, &device, &trace_recorder](
-                              const String& model_path, const String& model_lib_path) {
-      Model model = Model::Create(model_lib_path, std::move(model_path), device,
-                                  engine_config->max_num_sequence,
+    std::vector<picojson::object> model_configs;
+    model_configs.push_back(Model::LoadModelConfig(engine_config->model));
+    for (const auto& model_path : engine_config->additional_models) {
+      model_configs.push_back(Model::LoadModelConfig(model_path));
+    }
+
+    Optional<Session> session = CreateDiscoSession(model_configs, device);
+
+    auto f_create_model = [this, &engine_config, &device, &trace_recorder, &model_configs,
+                           &session](const String& model_path, const String& model_lib_path,
+                                     int model_index) {
+      Model model = Model::Create(model_lib_path, std::move(model_path), model_configs[model_index],
+                                  device, engine_config->max_num_sequence, session,
                                   /*trace_enabled=*/trace_recorder.defined());
       model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence,
                            engine_config->max_total_sequence_length,
@@ -81,13 +91,13 @@ class EngineImpl : public Engine {
                        ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()});
     };
-    f_create_model(engine_config->model, engine_config->model_lib_path);
+    f_create_model(engine_config->model, engine_config->model_lib_path, /*model_index=*/0);
     CHECK_EQ(engine_config->additional_models.size(),
              engine_config->additional_model_lib_paths.size())
         << "The additional model and lib path list has mismatched size.";
     for (int i = 0; i < static_cast<int>(engine_config->additional_models.size()); ++i) {
       f_create_model(engine_config->additional_models[i],
-                     engine_config->additional_model_lib_paths[i]);
+                     engine_config->additional_model_lib_paths[i], /*model_index=*/i + 1);
     }
     int max_num_tokens = engine_config->max_num_sequence;
@@ -287,6 +297,51 @@ class EngineImpl : public Engine {
         "action (e.g. prefill, decode, etc.) but it does not.";
   }

+  /************** Utility Functions **************/
+
+  Optional<Session> CreateDiscoSession(std::vector<picojson::object> model_configs, Device device) {
+    const auto& base_model_config = model_configs[0];
+
+    auto f_get_num_shards = [](const picojson::object& model_config) -> int {
+      constexpr auto kNumShardsKey = "tensor_parallel_shards";
+      if (model_config.count(kNumShardsKey)) {
+        const auto& val = model_config.at(kNumShardsKey);
+        CHECK(val.is<double>());
+        return static_cast<int>(val.get<double>());
+      } else {
+        LOG(FATAL) << "Key \"tensor_parallel_shards\" not found.";
+      }
+      throw;
+    };
+
+    int num_shards = std::transform_reduce(
+        model_configs.begin(), model_configs.end(), 1, [](int a, int b) { return std::max(a, b); },
+        f_get_num_shards);
+    Optional<Session> session = NullOpt;
+    if (num_shards > 1) {
+      constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool";
+      if (Registry::Get(f_create_process_pool) == nullptr) {
+        LOG(FATAL) << "Cannot find process launcher `" << f_create_process_pool << "`. "
+                   << "Multi-GPU inference depends on MLC LLM Python API to launch process.";
+      }
+      std::string ccl;
+      if (device.device_type == kDLCUDA) {
+        ccl = "nccl";
+      } else if (device.device_type == kDLROCM) {
+        ccl = "rccl";
+      } else {
+        LOG(FATAL) << "ValueError: Multi-GPU on device " << DLDeviceType2Str(device.device_type)
+                   << " is not supported. Currently, only NCCL and RCCL are integrated.";
+      }
+      std::vector<int64_t> device_ids(num_shards);
+      for (int i = 0; i < num_shards; ++i) {
+        device_ids[i] = i;
+      }
+      session = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker");
+      session.value()->InitCCL(ccl, ShapeTuple(device_ids));
+    }
+    return session;
+  }
+
   /************** Debug/Profile **************/

   void DebugCallFuncOnAllAllWorker(const String& func_name) final {
diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc
index b721eae7c3..3267f1dd38 100644
--- a/cpp/serve/function_table.cc
+++ b/cpp/serve/function_table.cc
@@ -69,7 +69,8 @@ PackedFunc FunctionTable::SessionFuncAsPackedFunc(Session sess, DRef sess_func,
   });
 }

-void FunctionTable::Init(String reload_lib_path, Device device, picojson::object model_config) {
+void FunctionTable::Init(String reload_lib_path, Device device, picojson::object model_config,
+                         Optional<Session> session) {
   local_gpu_device = device;
   Device null_device{DLDeviceType(0), 0};
   int num_shards;
@@ -85,27 +86,8 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object
   this->cached_buffers = Map<String, ObjectRef>();

   if (num_shards > 1) {
-    constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool";
-    if (Registry::Get(f_create_process_pool) == nullptr) {
-      LOG(FATAL) << "Cannot find process launcher `" << f_create_process_pool << "`. "
-                 << "Multi-GPU inference depends on MLC LLM Python API to launch process.";
-    }
-    std::string ccl;
-    if (device.device_type == kDLCUDA) {
-      ccl = "nccl";
-    } else if (device.device_type == kDLROCM) {
-      ccl = "rccl";
-    } else {
-      LOG(FATAL) << "ValueError: Multi-GPU on device " << DLDeviceType2Str(device.device_type)
-                 << " is not supported. Currently, only NCCL and RCCL are integrated.";
-    }
-    std::vector<int64_t> device_ids(num_shards);
-    for (int i = 0; i < num_shards; ++i) {
-      device_ids[i] = i;
-    }
+    this->sess = session.value();
     this->use_disco = true;
-    this->sess = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker");
-    this->sess->InitCCL(ccl, ShapeTuple(device_ids));
     this->disco_mod = sess->CallPacked(sess->GetGlobalFunc("runtime.disco.load_vm_module"),
                                        reload_lib_path, null_device);
     this->mod_get_func = [this,
diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h
index b6ea3287ad..bc2b4f21c8 100644
--- a/cpp/serve/function_table.h
+++ b/cpp/serve/function_table.h
@@ -41,7 +41,8 @@ using namespace tvm::runtime;
 struct FunctionTable {
   static PackedFunc SessionFuncAsPackedFunc(Session sess, DRef sess_func, String name);

-  void Init(String reload_lib_path, Device device, picojson::object model_config);
+  void Init(String reload_lib_path, Device device, picojson::object model_config,
+            Optional<Session> session);

   ObjectRef LoadParams(const std::string& model_path, Device device);
diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc
index 27a0043850..6f34220219 100644
--- a/cpp/serve/model.cc
+++ b/cpp/serve/model.cc
@@ -26,10 +26,27 @@ class ModelImpl;

 TVM_REGISTER_OBJECT_TYPE(ModelObj);

-Model Model::Create(String reload_lib_path, String model_path, DLDevice device,
-                    int max_num_sequence, bool trace_enabled) {
-  return Model(
-      make_object<ModelImpl>(reload_lib_path, model_path, device, max_num_sequence, trace_enabled));
+Model Model::Create(String reload_lib_path, String model_path, const picojson::object& model_config,
+                    DLDevice device, int max_num_sequence, const Optional<Session>& session,
+                    bool trace_enabled) {
+  return Model(make_object<ModelImpl>(reload_lib_path, model_path, model_config, device,
+                                      max_num_sequence, session, trace_enabled));
+}
+
+picojson::object Model::LoadModelConfig(const String& model_path) {
+  picojson::object model_config;
+  std::ifstream config_istream((model_path + "/mlc-chat-config.json").c_str());
+  std::ostringstream config_ostream;
+  ICHECK(config_istream);
+  config_ostream << config_istream.rdbuf();
+  std::string config_str = config_ostream.str();
+  picojson::value config_json;
+  std::string err = picojson::parse(config_json, config_str);
+  if (!err.empty()) {
+    LOG(FATAL) << err;
+  }
+  picojson::object config = config_json.get<picojson::object>();
+  return config;
+}

 class ModelImpl : public ModelObj {
@@ -38,23 +55,16 @@ class ModelImpl : public ModelObj {
   /*!
    * \brief Constructor of ModelImpl.
    * \sa Model::Create
    */
-  explicit ModelImpl(String reload_lib_path, String model_path, DLDevice device,
-                     int max_num_sequence, bool trace_enabled)
+  explicit ModelImpl(String reload_lib_path, String model_path, picojson::object model_config,
+                     DLDevice device, int max_num_sequence, const Optional<Session>& session,
+                     bool trace_enabled)
       : device_(device) {
     // Step 1. Process model config json string.
-    picojson::object model_config;
-    {
-      std::ifstream config_istream((model_path + "/mlc-chat-config.json").c_str());
-      std::ostringstream config_ostream;
-      ICHECK(config_istream);
-      config_ostream << config_istream.rdbuf();
-      std::string config_str = config_ostream.str();
-      model_config = LoadModelConfigJSON(config_str);
-    }
+    LoadModelConfigJSON(model_config);
     // Step 2. Initialize vm, we use the packed function mechanism
     // so there is no explicit abi dependency on these extra
     // classes other than basic tvm runtime.
-    this->ft_.Init(reload_lib_path, device_, model_config);
+    this->ft_.Init(reload_lib_path, device_, model_config, session);
     // Step 3. Load params in nd-array cache.
     this->params_ = ft_.LoadParams(model_path, device_);
     // Step 4. Set max_num_sequence
@@ -891,15 +901,7 @@ class ModelImpl : public ModelObj {

  private:
   /*! \brief Load model configuration from JSON. */
-  picojson::object LoadModelConfigJSON(const std::string& config_str) {
-    picojson::value config_json;
-    std::string err = picojson::parse(config_json, config_str);
-    if (!err.empty()) {
-      LOG(FATAL) << err;
-    }
-
-    // Get json fields.
-    picojson::object config = config_json.get<picojson::object>();
+  picojson::object LoadModelConfigJSON(picojson::object config) {
     if (config.count("context_window_size")) {
       CHECK(config["context_window_size"].is<int64_t>());
       this->max_window_size_ = config["context_window_size"].get<int64_t>();
diff --git a/cpp/serve/model.h b/cpp/serve/model.h
index 045daff874..bc63840a74 100644
--- a/cpp/serve/model.h
+++ b/cpp/serve/model.h
@@ -319,13 +319,24 @@ class Model : public ObjectRef {
    * \brief Create the runtime module for LLM functions.
    * \param reload_lib_path The model library path.
    * \param model_path The path to the model weight parameters.
+   * \param model_config The model config json object.
    * \param device The device to run the model on.
    * \param max_num_sequence The maximum number of sequences to be processed
+   * \param session The session to run the model on.
    * \param trace_enabled A boolean indicating whether tracing is enabled.
    * \return The created runtime module.
    */
-  TVM_DLL static Model Create(String reload_lib_path, String model_path, DLDevice device,
-                              int max_num_sequence, bool trace_enabled);
+  TVM_DLL static Model Create(String reload_lib_path, String model_path,
+                              const picojson::object& model_config, DLDevice device,
+                              int max_num_sequence, const Optional<Session>& session,
+                              bool trace_enabled);
+
+  /*!
+   * \brief Load the model config from the given model path.
+   * \param model_path The path to the model weight parameters.
+   * \return The model config json object.
+   */
+  static picojson::object LoadModelConfig(const String& model_path);

   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Model, ObjectRef, ModelObj);
 };
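
Note on the shard-count reduction in EngineImpl::CreateDiscoSession above: it transforms each model config into its tensor_parallel_shards value and folds the results with max, so a single Disco session is created that is large enough for the most-sharded model; only when that maximum exceeds 1 does the engine spawn worker processes and initialize NCCL/RCCL, otherwise NullOpt is returned and the model runs on the local device. A minimal standalone sketch of that reduction follows; the ModelConfig struct and MaxNumShards helper are hypothetical stand-ins for illustration (the real code reads the field out of picojson::object):

#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical stand-in for a parsed mlc-chat-config.json object.
struct ModelConfig {
  int tensor_parallel_shards;
};

// Mirrors the reduction in EngineImpl::CreateDiscoSession: map each config
// to its shard count, then fold with max. The init value is 1 because a
// single-GPU model contributes one shard and needs no Disco session.
int MaxNumShards(const std::vector<ModelConfig>& configs) {
  return std::transform_reduce(
      configs.begin(), configs.end(), /*init=*/1,
      [](int a, int b) { return std::max(a, b); },
      [](const ModelConfig& c) { return c.tensor_parallel_shards; });
}

int main() {
  std::vector<ModelConfig> configs = {{1}, {4}, {2}};
  assert(MaxNumShards(configs) == 4);  // the 4-shard model sizes the session
  return 0;
}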