diff --git a/CMakeLists.txt b/CMakeLists.txt index e59a112fab04..09b9aeb4db3b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -431,6 +431,25 @@ if(USE_GTEST) find_package(GTest REQUIRED) endif() if(GTEST_FOUND) + if(NOT TARGET GTest::gmock) + # GMock is formally supported in CMake 3.20; for now, expect libgmock.a in the same directory, + # and require that folks compiling against GTest::gmock also link against GTest::GTest + # (for the includes dir). + add_library(GTest::gmock STATIC IMPORTED GLOBAL) + get_target_property(GTEST_LIB_PATH GTest::GTest IMPORTED_LOCATION) + if("${GTEST_LIB_PATH}" STREQUAL "GTEST_LIB_PATH-NOTFOUND") + # CMake >= 3.20 makes GTest::GTest into a compatibility target. The real import location is in + # GTest::gtest. + get_target_property(GTEST_LIB_PATH GTest::gtest IMPORTED_LOCATION) + if("${GTEST_LIB_PATH}" STREQUAL "GTEST_LIB_PATH-NOTFOUND") + message(FATAL_ERROR "Neither GTest::GTest nor GTest::gtest targets define IMPORTED_LOCATION") + endif() + endif() + get_filename_component(GTEST_LIB_DIR "${GTEST_LIB_PATH}" DIRECTORY) + set_target_properties(GTest::gmock PROPERTIES + IMPORTED_LOCATION "${GTEST_LIB_DIR}/libgmock.a") + endif() + enable_testing() include(CTest) endif() @@ -626,7 +645,7 @@ if(GTEST_FOUND) add_executable(cpptest ${TEST_SRCS}) # include runtime files for unit testing target_include_directories(cpptest PUBLIC "src/runtime") - target_link_libraries(cpptest PRIVATE ${TVM_TEST_LIBRARY_NAME} GTest::GTest GTest::Main pthread dl) + target_link_libraries(cpptest PRIVATE ${TVM_TEST_LIBRARY_NAME} GTest::GTest GTest::Main GTest::gmock pthread dl) set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_ALL 1) set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1) # For some reason, compile definitions are not propagated correctly, so we manually add them here diff --git a/include/tvm/runtime/metadata.h b/include/tvm/runtime/metadata.h index cd65f6fb7486..b7f7c6c0a458 100644 --- a/include/tvm/runtime/metadata.h +++ b/include/tvm/runtime/metadata.h @@ -116,6 +116,7 @@ class MetadataNode : public MetadataBaseNode { public: explicit MetadataNode(const struct ::TVMMetadata* data) : data_{data} {} static constexpr const char* _type_key = "metadata.MetadataNode"; + const char* get_c_struct_name() const override; inline int64_t version() const { return int64_t(data_->version); } inline int64_t num_inputs() const { return data_->num_inputs; } ArrayAccessor inputs(); @@ -141,6 +142,7 @@ class TensorInfoNode : public MetadataBaseNode { public: explicit TensorInfoNode(const struct ::TVMTensorInfo* data) : data_{data} {} static constexpr const char* _type_key = "metadata.TensorInfoNode"; + const char* get_c_struct_name() const override; inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); } inline int64_t num_shape() const { return data_->num_shape; } inline ::tvm::support::Span shape() const { diff --git a/include/tvm/runtime/metadata_base.h b/include/tvm/runtime/metadata_base.h index 96743199fe28..698f56d46d28 100644 --- a/include/tvm/runtime/metadata_base.h +++ b/include/tvm/runtime/metadata_base.h @@ -44,6 +44,8 @@ namespace metadata { */ class MetadataBaseNode : public ::tvm::runtime::Object { public: + virtual const char* get_c_struct_name() const = 0; + static constexpr const char* _type_key = "metadata.MetadataBaseNode"; TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object); }; @@ -157,7 +159,7 @@ class ArrayAccessor { * * These are separate from TIR DataType because TIR does not model
structs. */ -enum MetadataTypeIndex : uint8_t { +enum MetadataKind : uint8_t { kUint64 = 0, kInt64 = 1, kBool = 2, @@ -173,12 +175,29 @@ enum MetadataTypeIndex : uint8_t { */ class MetadataArrayNode : public MetadataBaseNode { public: - MetadataArrayNode(Array array, MetadataTypeIndex type_index, const char* struct_name) - : array(::std::move(array)), type_index{type_index}, struct_name{struct_name} {} + MetadataArrayNode(Array array, MetadataKind kind, const char* type_key) + : array(::std::move(array)), kind{kind}, type_key{type_key} {} + + const char* get_c_struct_name() const final; + + std::string get_element_c_struct_name() const { + CHECK(kind == MetadataKind::kMetadata) + << "cannot get struct name for MetadataArray with kind=" << kind; + constexpr int prefix_size = sizeof("metadata.") - 1; + constexpr int suffix_size = sizeof("Node") - 1; + std::string type_key_str(type_key); + return std::string("TVM") + + type_key_str.substr(prefix_size, type_key_str.size() - prefix_size - suffix_size); + } Array array; - MetadataTypeIndex type_index; - const char* struct_name; + + /*! \brief Describes the storage class of the emitted struct member. */ + MetadataKind kind; + + /*! \brief When `kind` is Metadata, type_key of the MetadataBaseNode used with this array. */ + const char* type_key; + static constexpr const char* _type_key = "metadata.MetadataArrayNode"; TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode); }; @@ -186,7 +205,7 @@ class MetadataArrayNode : public MetadataBaseNode { /*! \brief Reference class for MetadataArray. */ class MetadataArray : public MetadataBase { public: - MetadataArray(Array array, MetadataTypeIndex type_index, const char* struct_name); + MetadataArray(Array array, MetadataKind kind, const char* struct_name); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode); }; diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 3d0fb407ef3f..45eaa8b8be77 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -870,7 +870,8 @@ class PreflattenedBufferMap(SpecialStmt): Example ------- .. code-block:: python - T.preflattened_buffer_map({}) + A0 = T.match_buffer(A, (48,), dtype="float32") + T.preflattened_buffer(A0, (1, 4, 4, 3), elem_offset=1, align=4, dtype="float32") """ def __init__(self): @@ -892,12 +893,30 @@ def preflattened_buffer( for key, value in self.context.func_buffer_map.items(): if value.same_as(postflattened): param = key + break assert ( param is not None ), f"Post-flatten buffer {postflattened.name} does not appear in the buffer map."
+ if data is None: + data = self.context.func_buffer_map[param].data + buffer_name: str = f"{postflattened.name}_preflatten" + if align != -1: + if isinstance(align, IntImm): + align = align.value + else: + assert isinstance(align, int), f"align: want int or IntImm, got {align!r}" + + if offset_factor != 0: + if isinstance(offset_factor, IntImm): + offset_factor = offset_factor.value + else: + assert isinstance( + offset_factor, int + ), f"offset_factor: want int or IntImm, got {offset_factor!r}" + preflattened = tvm.tir.decl_buffer( shape, dtype, diff --git a/python/tvm/testing/tir.py b/python/tvm/testing/tir.py index f9115fc61bfa..cedaafe80a52 100644 --- a/python/tvm/testing/tir.py +++ b/python/tvm/testing/tir.py @@ -17,10 +17,14 @@ # pylint: disable=invalid-name, import-outside-toplevel, unused-variable """Common utility functions in TVM tir""" import inspect +import re import tvm from tvm.ir.diagnostics import override_renderer +CHECK_ERROR_RE = re.compile(r"^.*# check_error: (.+)$") + + def check_error(func, rel_lineno): """check if TIR script throws error""" # Override the default renderer to accumulate errors @@ -46,3 +50,12 @@ def render(e): assert ( d.span.line - 1 == rel_lineno ), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}" + + error_line = source_code.split("\n")[rel_lineno] + m = CHECK_ERROR_RE.match(error_line) + if m: + expected_error_text = m.group(1) + errors = [e.message for e in errors] + assert ( + expected_error_text in errors + ), f'check_error expects "{expected_error_text}" in str(errors): {errors}' diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 1ef62c257648..fe829016b6b5 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -151,6 +151,17 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) { doc << Doc::Indent( 2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}"); } + + if (op->preflattened_buffer_map.size() != 0) { + // print preflattened_buffer_map + std::vector preflattened_buffer_map_doc; + for (auto& v : op->preflattened_buffer_map) { + preflattened_buffer_map_doc.push_back(Print(v.first) << ": " << Print(v.second)); + } + doc << Doc::Indent(2, Doc::NewLine() + << "preflattened_buffer_map = {" + << PrintSep(preflattened_buffer_map_doc, Doc::Text(", ")) << "}"); + } doc << PrintBody(op->body); return doc; } diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 542bcd163995..c2b2ac0fc5e2 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -263,26 +263,72 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { /*! \brief Code generator for AOT executor */ class AOTExecutorCodegen : public MixedModeVisitor { protected: - /*! - * \brief Utility function to allocate a DLTensor or TVMValue - * \param type the type of allocation - * \param num the number of variable to allocate on the stack - * \return PrimExpr representing the allocated object - */ - PrimExpr StackAlloca(std::string type, size_t num) { - Array args = {tir::StringImm(type), ConstInt32(num)}; - return tir::Call(DataType::Handle(), tir::builtin::tvm_stack_alloca(), args); - } - - /*! - * \brief Utility function to convert a concrete integer to a PrimExpr.
- * \param num the number to convert - * \return PrimExpr representing num - */ - inline PrimExpr ConstInt32(int32_t num) { - ICHECK_LE(num, std::numeric_limits::max()); - return tir::make_const(DataType::Int(32), static_cast(num)); - } + /*! \brief Describes the type of kernel call emitted. */ + enum CallType { + /*! + * \brief Emit PackedFunc calls bound just-in-time using TVMBackend* functions. + * + * When this type is selected, assumes all operators must be called via TVMFuncCall. Given the + * implementation of TVMFuncCall in the C++ runtime, this in practice implies that those + * functions are of type TVMBackendPackedCFunc. + * + * The following code is emitted at call sites to call a function named `func`: + * void* func_ptr = TVMBackendGetFuncFromEnv("func"); + * TVMFuncCall(func_ptr, values, tcodes, num_args, ret_values, ret_tcodes) + * + * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and `num_args` + * by LowerTVMBuiltin TIR transform. + * + * If `resource_handle` is passed to `func`, it is determined by TVMFuncCall (often, + * `resource_handle` is registered with the C++ runtime to provide a `this` equivalent when + * `func` is implemented in C). + * + * Compatible with both C++ and C runtimes, implemented with the C runtime only. + */ + kPacked, // Emit tir.call_packed and wrap all arguments in DLTensor. + + /*! + * \brief Directly call a TVMBackendPackedCFunc named according to the tir::Call. + * + * When this type is selected, assumes all operators are implemented in functions of type + * `TVMBackendPackedCFunc` and should be called directly. That is, presumes at the time of + * downstream compilation that there is a symbol named after the 0th arg to tir::Call of + * type `TVMBackendPackedCFunc`. This situation should occur when target_host == target. + * + * The following code is emitted at call sites to call a function named `func`: + * func(values, tcodes, num_args, ret_values, ret_tcodes, resource_handle) + * + * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and `num_args` + * by LowerTVMBuiltin TIR transform. + * + * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it is + * always the device context parameter when not null. At present, the implementation does not + * support forwarding device context parameters to CPacked. + * + * Compatible with the C runtime and C++ runtime (so long as target_host == target). Implemented + * in the same scenarios. + */ + kCPacked, // Emit tir.call_cpacked and wrap all arguments in DLTensor. + + /*! \brief Directly call a function accepting the `data` arrays as args. + * + * When this type is selected, assumes all operators are implemented in C functions whose + * arguments are 1-to-1 with those in the tir::Call. DLTensor arguments are encoded as just the + * `data` parameters (i.e. no DLTensor object is passed along). + * + * The following code is emitted at call sites to a function named `func`: + * func(void* arg0, void* arg1, ..., void* argN) // no resource_handle + * -or- + * func(void* arg0, void* arg1, ..., void* argN, void* resource_handle) // with resource_handle + * + * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it is + * always the device context parameter when not null. + * + * Compatible with the C runtime and C++ runtime (so long as target_host == target). Implemented + * with the C runtime only.
+ */ + kUnpacked, // Emit tir.call_extern passing only the `data` part of DLTensors. + }; /*! * \brief Return a vector of variables that represents the sids for the given Relay Expr @@ -323,6 +369,21 @@ class AOTExecutorCodegen : public MixedModeVisitor { } } + /*! + * \brief Reverse lookup the device name in devices_ map. + * \param device_context Value in devices_ to find. + * \return Key matching device_context in devices_. + */ + std::string FindDeviceName(tir::Var device_context) { + for (std::pair kv : devices_) { + if (kv.second->name_hint == device_context->name_hint) { + return kv.first; + } + } + ICHECK(false) << "Did not find a device name associated with " << device_context; + return ""; + } + void PushArgs(const Expr& expr, const std::vector& sids, Array* args) { const TupleNode* t = expr.as(); if (t != nullptr) { @@ -338,12 +399,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { * returns the passed Call */ tir::Call AddCheckReturn(tir::Call existing_call) { - if (use_unpacked_api_) { - Array args = {ConstInt32(0), ConstInt32(-1), existing_call}; - return tir::Call(DataType::Int(32), tir::builtin::tvm_check_return(), args); - } - - return existing_call; + Array args = {tir::make_const(DataType::Int(32, 1), 0, Span()), + tir::make_const(DataType::Int(32, 1), -1, Span()), existing_call}; + return tir::Call(DataType::Int(32), tir::builtin::tvm_check_return(), args); } /*! @@ -378,56 +436,59 @@ class AOTExecutorCodegen : public MixedModeVisitor { auto result_expr_sid = PackSid(result_expr); PushArgs(result_expr, result_expr_sid, &args); - // Choose call style based on Runtime/Executor config. - Op calling_pattern; - if (use_unpacked_api_) { - calling_pattern = tvm::tir::builtin::call_extern(); - } else if (use_call_cpacked_) { - calling_pattern = tvm::tir::builtin::tvm_call_cpacked(); - } else { - calling_pattern = tvm::tir::builtin::tvm_call_packed(); - } - GlobalVar global_var = call_lowered_props.lowered_func; - tir::Var empty_var("no_device_context", DataType::Handle()); bool has_c_device_api_context = device_contexts_.count(global_var) != 0; + tir::Var device_context; + tir::Stmt func_call; + + switch (call_type_) { + case CallType::kUnpacked: { + // call_extern calling convention with optional context + if (has_c_device_api_context) { + device_context = device_contexts_.Get(global_var).value(); + args.push_back(device_context); + } + func_call = tir::Evaluate(AddCheckReturn( + tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args))); + break; + } + case CallType::kCPacked: { + if (has_c_device_api_context) { + device_context = device_contexts_.Get(global_var).value(); + args.push_back(device_context); + } else { + // NOTE: LowerTVMBuiltin expects some device_context placeholder. + args.push_back(tir::make_zero(DataType::Handle())); + } + func_call = tir::Evaluate( + tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_cpacked(), args)); + create_func_call_stmts.push_back(func_call); + break; + } + case CallType::kPacked: { + // call_packed does not accept a device context. 
+ CHECK(!has_c_device_api_context) << "CallType::kPacked does not accept a device context"; + func_call = tir::Evaluate(AddCheckReturn( + tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args))); + create_func_call_stmts.push_back(func_call); + break; + } + default: + ICHECK(false) << "Unknown CallType: " << call_type_; + } + + ICHECK(func_call.defined()) << "Must define func_call"; - // The device context is passed to the operator in one of the following calling patterns: - // * Unpacked / direct function call with context: - // operator(arg0, arg1, device_context); - // * Unpacked / direct function call without context: - // operator(arg0, arg1); - // * Type-erased packed function call with context: - // operator(args, type_codes, int num_args, out_ret_value, out_ret_tcode, - // device_context_my_device) - // * Type-erased packed function call without context (we create an empty var for codegen): - // operator(args, type_codes, int num_args, out_ret_value, out_ret_tcode, - // no_device_context) if (has_c_device_api_context) { - // call_extern calling convention with context - tir::Var context = device_contexts_.Get(global_var).value(); - args.push_back(context); - - tir::Evaluate func_call( - AddCheckReturn(tvm::tir::Call(DataType::Int(32), calling_pattern, args))); - create_func_call_stmts.push_back(tir::SeqStmt({ - GenerateDeviceHook(context, "Open"), + func_call = tir::SeqStmt(Array({ + GenerateDeviceHook(device_context, "Open"), func_call, - GenerateDeviceHook(context, "Close"), + GenerateDeviceHook(device_context, "Close"), })); - } else if (use_call_cpacked_) { - // call_cpacked calling convention needs a blank context - args.push_back(tir::make_zero(DataType::Handle())); - tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args)); - create_func_call_stmts.push_back(func_call); - } else { - // call_extern calling convention without context - tir::Evaluate func_call( - AddCheckReturn(tvm::tir::Call(DataType::Int(32), calling_pattern, args))); - create_func_call_stmts.push_back(func_call); } - tir::Stmt body = tir::SeqStmt(create_func_call_stmts); + tir::Stmt body = tir::SeqStmt({func_call}); + LOG(INFO) << "CreateFuncCall: " << call_lowered_props.lowered_func->name_hint << " -> " << body; stmts_.push_back(body); } @@ -446,9 +507,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { te::Var loop_idx("i", DataType::Int(32)); auto retval_i = tir::BufferLoad(tmp_read, {loop_idx}); // Copy the variable from the input to the output - tir::Stmt copy = - tir::For(loop_idx, 0, ConstInt32(size), tir::ForKind::kSerial, - tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx})); + tir::Stmt copy = tir::For( + loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial, + tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx})); stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy)); } @@ -692,7 +753,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { for (int i = 0; i < ndim; i++) { int shape = kv.second->data->shape[i]; - extents.push_back(tir::make_const(DataType::Int(32), shape)); + extents.push_back(tir::make_const(DataType::Int(32), shape, Span())); } body = tir::AllocateConst(buffer_var, dtype, extents, kv.second->data, body); } @@ -855,30 +916,10 @@ class AOTExecutorCodegen : public MixedModeVisitor { /*! \brief target host */ Target target_host_; /*! 
- * \brief unpacked api toggle - * When set to true, the generated code will use unpacked calls to functions: - * func(void* arg0, void* arg1) - * Rather than packed calls (in which arg0 and arg1 are in `arg_values`). - * func(TVMValue* arg_values, int* arg_type_codes, int num_args, ...) - * Defaults to using the packed calling convention - * - * Unpacked API is supported when runtime == "c" and interface_api is "c". - */ - Bool use_unpacked_api_; - /*! - * \brief cpacked api toggle - * When set to true, the generated code will use call_cpacked to call functions directly, assuming - * they exist in a DSO-exportable module: - * func(...) - * Rather than through the traditional call_packed calls, which should use function pointers - * looked-up through TVMBackendGetFuncFromEnv: - * TVMBackendPackedCFunc* func_ptr = TVMBackendGetFuncFromEnv("func"); - * func_ptr(...) - * Defaults to using the packed calling convention - * - * call_cpacked is required when runtime is "c++" and supported when runtime is "c" + * \brief The type of kernel call to be emitted. + * See CallType for more documentation. */ - Bool use_call_cpacked_; + CallType call_type_; /*! * \brief parameters (i.e. ConstantNodes found in the graph). @@ -907,11 +948,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { public: AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host) - : mod_(mod), - targets_(targets), - target_host_(target_host), - use_unpacked_api_(Bool(false)), - use_call_cpacked_(Bool(false)) {} + : mod_(mod), targets_(targets), target_host_(target_host) {} LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) { VLOG_CONTEXT << "AOT"; @@ -923,23 +960,36 @@ class AOTExecutorCodegen : public MixedModeVisitor { Runtime runtime_config = mod->GetAttr(tvm::attr::kRuntime).value(); Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); - String interface_api = executor_config->GetAttr("interface-api").value_or("packed"); + std::string interface_api = + executor_config->GetAttr("interface-api").value_or("packed"); Integer workspace_byte_alignment = executor_config->GetAttr("workspace-byte-alignment").value_or(16); - use_unpacked_api_ = executor_config->GetAttr("unpacked-api").value_or(Bool(false)); - use_call_cpacked_ = !use_unpacked_api_; + bool unpacked_api = executor_config->GetAttr("unpacked-api").value_or(Bool(false)); // Validate choice of use_unpacked_api_ and use_call_cpacked_ if (runtime_config->name == kTvmRuntimeCrt) { - ICHECK(interface_api == "packed" || static_cast(use_unpacked_api_) == true) - << "Either need interface_api == \"packed\" (got: " << interface_api - << ") or unpacked-api == true (got: " << use_unpacked_api_ - << ") when targeting c runtime"; + if (unpacked_api == true) { + call_type_ = CallType::kUnpacked; + } else if (unpacked_api == false && interface_api == "packed") { + call_type_ = CallType::kCPacked; + } else { + CHECK(interface_api == "packed" || unpacked_api == true) + << "Either need interface_api == \"packed\" (got: " << interface_api + << ") or unpacked-api == true (got: " << unpacked_api << ") when targeting c runtime"; + ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api + << ", unpacked-api=" << unpacked_api; + } } else if (runtime_config->name == kTvmRuntimeCpp) { - ICHECK(static_cast(use_unpacked_api_) == false) - << "Need unpacked-api == false (got: " << use_unpacked_api_ - << ") and interface-api == \"packed\" (got: " << interface_api - << ") when targeting c++ 
runtime"; + if (unpacked_api == false && interface_api == "packed") { + call_type_ = CallType::kCPacked; + } else { + CHECK(static_cast(unpacked_api) == false && interface_api == "packed") + << "Need unpacked-api == false (got: " << unpacked_api + << ") and interface-api == \"packed\" (got: " << interface_api + << ") when targeting c++ runtime"; + ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api + << ", unpacked-api=" << unpacked_api; + } } else { ICHECK(false) << "runtime_config (" << runtime_config->name << ") is not one of the expected values"; @@ -1037,7 +1087,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Legalize AOT if needed. This means that all the packed calls // need to be wrapped in TVMValues (unless use_unpacked_api is set) - if (!use_unpacked_api_) { + if (call_type_ == CallType::kCPacked || call_type_ == CallType::kPacked) { auto pack_calls = tir::transform::LegalizePackedCalls(); lowered_mod = pack_calls(lowered_mod); } @@ -1106,7 +1156,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { ret.metadata = ExecutorCodegenMetadata( inputs, input_tensor_types, output_var_names, output_tensor_types, pool_vars, devices, - runtime::kTvmExecutorAot, mod_name, interface_api, use_unpacked_api_, pool_var_info); + runtime::kTvmExecutorAot, mod_name, interface_api, unpacked_api, pool_var_info); return ret; } diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc index 90469fabad2c..c08f2872fe8a 100644 --- a/src/runtime/metadata.cc +++ b/src/runtime/metadata.cc @@ -18,7 +18,7 @@ */ /*! - * \file tvm/runtime/metadata.h + * \file src/runtime/metadata.cc * \brief Defines implementations of TVM metadata which can exist in the runtime. */ @@ -47,20 +47,27 @@ ArrayAccessor MetadataNode::pools() { TVM_REGISTER_OBJECT_TYPE(MetadataBaseNode); -MetadataArray::MetadataArray(Array array, MetadataTypeIndex type_index, - const char* struct_name) - : MetadataBase{make_object(array, type_index, struct_name)} {} +MetadataArray::MetadataArray(Array array, MetadataKind kind, const char* struct_name) + : MetadataBase{make_object(array, kind, struct_name)} {} +const char* MetadataArrayNode::get_c_struct_name() const { + ICHECK(false) << "MetadataArrayNode get_c_struct_name is unimplemented"; + return nullptr; +} TVM_REGISTER_OBJECT_TYPE(MetadataArrayNode); Metadata::Metadata(const struct ::TVMMetadata* data) : MetadataBase{make_object(data)} {} TVM_REGISTER_OBJECT_TYPE(MetadataNode); +const char* MetadataNode::get_c_struct_name() const { return "TVMMetadata"; } + TensorInfo::TensorInfo(const struct ::TVMTensorInfo* data) : MetadataBase{make_object(data)} {} TVM_REGISTER_OBJECT_TYPE(TensorInfoNode); +const char* TensorInfoNode::get_c_struct_name() const { return "TVMTensorInfo"; } + } // namespace metadata class MetadataModuleNode : public ::tvm::runtime::ModuleNode { diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 53c8f7754602..033275ae5286 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -30,8 +30,10 @@ #include #include #include +#include #include "../func_registry_generator.h" +#include "../metadata_utils.h" namespace tvm { namespace codegen { @@ -74,8 +76,7 @@ void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, // void* resource_handle); ftype_tvm_backend_packed_c_func_ = llvm::FunctionType::get( t_int_, - {t_tvm_func_handle_, t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_int_, - t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), 
t_void_p_}, + {t_void_p_, t_int_->getPointerTo(), t_int_, t_void_p_, t_int_->getPointerTo(), t_void_p_}, false); t_tvm_crt_func_registry_ = llvm::StructType::create( {t_char_->getPointerTo(), ftype_tvm_backend_packed_c_func_->getPointerTo()}); @@ -802,10 +803,10 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& args, const DataType& r_type, - const int64_t begin, const int64_t end) { + const int64_t begin, const int64_t end, + bool use_string_lookup) { PackedCall pc; std::string func_name = args[0].as()->value; - llvm::Value* handle = GetPackedFuncHandle(func_name); // call the function int64_t nargs = end - begin; ICHECK_GE(nargs, 0); @@ -822,14 +823,43 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& TypedPointer ret_tcode = CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(end)}, DataType::Int(32)); + llvm::FunctionType* callee_ftype = nullptr; + llvm::Value* callee_value = nullptr; + std::vector call_args; + + if (use_string_lookup) { + callee_ftype = ftype_tvm_func_call_; + callee_value = RuntimeTVMFuncCall(); + call_args.push_back(GetPackedFuncHandle(func_name)); + call_args.insert(call_args.end(), + {arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); + } else { + callee_ftype = ftype_tvm_backend_packed_c_func_; + callee_value = module_->getFunction(func_name); + if (callee_value == nullptr) { + callee_value = + llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage, + func_name, module_.get()); + } + + nargs -= 1; + call_args.insert(call_args.end(), { + builder_->CreateBitCast(arg_value, t_void_p_), + arg_tcode.addr, + ConstInt32(nargs), + builder_->CreateBitCast(ret_value, t_void_p_), + ret_tcode.addr, + }); + call_args.push_back(llvm::ConstantPointerNull::get(t_void_p_)); + } #if TVM_LLVM_VERSION >= 90 - auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); + auto call_callee = llvm::FunctionCallee(callee_ftype, callee_value); #else - auto call_callee = RuntimeTVMFuncCall(); + (void)callee_ftype; // use callee_ftype to avoid unused variable warning when using older LLVM. + auto call_callee = callee_value; #endif - llvm::Value* call = builder_->CreateCall( - call_callee, - {handle, arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); + llvm::Value* call = builder_->CreateCall(call_callee, call_args); + llvm::BasicBlock* end_block = CheckCallSuccess(call); // Load the return value and cast it to the designated type (r_type). @@ -858,17 +888,18 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& return pc; } -llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op) { - ICHECK_EQ(op->args.size(), 5U); +llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op, bool use_string_lookup) { + auto expected_num_args = use_string_lookup ? 5U : 6U; + ICHECK_EQ(op->args.size(), expected_num_args); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); + op->args[4].as()->value, use_string_lookup); return pc.ret_value; } llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { ICHECK_EQ(op->args.size(), 6U); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); + op->args[4].as()->value, true); // Get traced value. 
llvm::Value* traced_value = MakeValue(op->args[5]); // The update_block handles case when we need to update the return value. @@ -914,6 +945,306 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() { return GetContextPtr(gv_tvm_parallel_barrier_); } +/*! \brief Defines LLVM Types for each Metadata member type. */ +struct MetadataLlvmTypes { + llvm::Type* t_float64; + llvm::Type* t_uint8; + llvm::Type* t_int64; + llvm::Type* t_bool; + llvm::Type* t_cstring; + llvm::Type* t_void_p; + llvm::StructType* t_data_type; + + /*! \brief Maps a MetadataBase subclass' type_key to its corresponding LLVM StructType. */ + ::std::unordered_map structs_by_type_key; +}; + +class MetadataTypeDefiner : public AttrVisitor { + public: + MetadataTypeDefiner(llvm::LLVMContext* ctx, struct MetadataLlvmTypes* llvm_types) + : ctx_{ctx}, llvm_types_{llvm_types} {} + + void Visit(const char* key, double* value) final { + elements_.emplace_back(llvm_types_->t_float64); + } + void Visit(const char* key, int64_t* value) final { + elements_.emplace_back(llvm_types_->t_int64); + } + void Visit(const char* key, uint64_t* value) final { + elements_.emplace_back(llvm_types_->t_int64); + } + void Visit(const char* key, int* value) final { elements_.emplace_back(llvm_types_->t_int64); } + void Visit(const char* key, bool* value) final { elements_.emplace_back(llvm_types_->t_bool); } + void Visit(const char* key, std::string* value) final { + elements_.emplace_back(llvm_types_->t_cstring); + } + void Visit(const char* key, void** value) final { elements_.emplace_back(llvm_types_->t_void_p); } + void Visit(const char* key, DataType* value) final { + elements_.emplace_back(llvm_types_->t_data_type); + } + void Visit(const char* key, runtime::NDArray* value) final { + CHECK(false) << "Do not support serializing NDArray"; + } + + private: + void VisitMetadataBase(runtime::metadata::MetadataBase metadata) { + elements_.emplace_back(llvm::PointerType::getUnqual( + llvm::StructType::create(*ctx_, metadata->get_c_struct_name()))); + if (visited_.find(metadata->get_c_struct_name()) != visited_.end()) { + return; + } + + if (to_visit_.find(metadata->get_c_struct_name()) != to_visit_.end()) { + return; + } + to_visit_[metadata->get_c_struct_name()] = metadata; + } + + public: + using MetadataKind = runtime::metadata::MetadataKind; + + void VisitArray(const runtime::metadata::MetadataArrayNode* arr) { + switch (arr->kind) { + case MetadataKind::kUint64: // LLVM encodes signed and unsigned with same types. 
+ case MetadataKind::kInt64: + elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_int64)); + break; + case MetadataKind::kBool: + elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_bool)); + break; + case MetadataKind::kString: + elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_cstring)); + break; + case MetadataKind::kHandle: + CHECK(false) << "Do not support handle"; + break; + case MetadataKind::kMetadata: + elements_.emplace_back( + llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[arr->type_key])); + break; + default: + CHECK(false) << "Unsupported metadata kind " << arr->kind; + break; + } + } + + void Visit(const char* key, ObjectRef* value) final { + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + if (arr != nullptr) { + VisitArray(arr); + } else { + elements_.emplace_back( + llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[(*value)->GetTypeKey()])); + } + } + + void DefineType(runtime::metadata::MetadataBase metadata) { + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + for (auto e : elements_) { + std::string value; + llvm::raw_string_ostream os(value); + e->print(os, true); + } + llvm_types_->structs_by_type_key[metadata->GetTypeKey()] = + llvm::StructType::create(*ctx_, elements_, metadata->get_c_struct_name()); + elements_.clear(); + } + + llvm::LLVMContext* ctx_; + struct MetadataLlvmTypes* llvm_types_; + ::std::unordered_set<::std::string> visited_; + ::std::unordered_map<::std::string, runtime::metadata::MetadataBase> to_visit_; + ::std::vector elements_; +}; + +class MetadataSerializerLLVM : public AttrVisitor { + using MetadataKind = runtime::metadata::MetadataKind; + + public: + MetadataSerializerLLVM(CodeGenLLVM* codegen, struct MetadataLlvmTypes* llvm_types) + : codegen_{codegen}, llvm_types_{llvm_types} {} + + void Visit(const char* key, double* value) final { + elements_.back().emplace_back(llvm::ConstantFP::get(llvm_types_->t_float64, *value)); + } + void Visit(const char* key, int64_t* value) final { + elements_.back().emplace_back(llvm::ConstantInt::get( + llvm_types_->t_int64, static_cast(*value), true /* isSigned */)); + } + void Visit(const char* key, uint64_t* value) final { + elements_.back().emplace_back( + llvm::ConstantInt::get(llvm_types_->t_int64, *value, false /* isSigned */)); + } + void Visit(const char* key, int* value) final { + elements_.back().emplace_back( + llvm::ConstantInt::get(llvm_types_->t_int64, *value, true /* isSigned */)); + } + void Visit(const char* key, bool* value) final { + elements_.back().emplace_back(llvm::ConstantInt::get( + llvm_types_->t_uint8, static_cast(*value), false /* isSigned */)); + } + void Visit(const char* key, std::string* value) final { + elements_.back().emplace_back(codegen_->GetConstString(*value)); + } + void Visit(const char* key, void** value) final { + CHECK(false) << "Do not support serializing void*"; + } + void Visit(const char* key, DataType* value) final { + elements_.back().emplace_back(llvm::ConstantStruct::get( + llvm_types_->t_data_type, + {llvm::ConstantInt::get(llvm_types_->t_uint8, value->code(), false /* isSigned */), + llvm::ConstantInt::get(llvm_types_->t_uint8, value->bits(), false /* isSigned */), + llvm::ConstantInt::get(llvm_types_->t_uint8, value->lanes(), false /* isSigned */)})); + } + + void Visit(const char* key, runtime::NDArray* value) final { + CHECK(false) << "Do not support serializing NDArray"; + } + + void VisitMetadata(runtime::metadata::MetadataBase metadata) 
{ + elements_.emplace_back(std::vector()); + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + auto struct_elements = elements_.back(); + elements_.pop_back(); + auto struct_ty = llvm_types_->structs_by_type_key[metadata->GetTypeKey()]; + ICHECK(struct_ty != nullptr) << "Did not find LLVM StructType* for type_key=" + << metadata->GetTypeKey(); + CHECK_EQ(struct_elements.size(), struct_ty->getNumElements()); + auto out = llvm::ConstantStruct::get(struct_ty, struct_elements); + if (elements_.size() > 0) { + elements_.back().push_back(out); + } else { + last_production_ = out; + } + } + + void VisitArray(const runtime::metadata::MetadataArrayNode* arr) { + llvm::Type* element_type; + switch (arr->kind) { + case MetadataKind::kInt64: + element_type = llvm_types_->t_int64; + break; + case MetadataKind::kUint64: + element_type = llvm_types_->t_int64; + break; + case MetadataKind::kBool: + element_type = llvm_types_->t_uint8; + break; + case MetadataKind::kString: + element_type = llvm_types_->t_cstring; + break; + case MetadataKind::kMetadata: { + element_type = llvm_types_->structs_by_type_key[arr->type_key]; + ICHECK(element_type != nullptr) + << "Did not find LLVM StructType* for type_key=" << arr->type_key; + break; + } + default: + LOG(FATAL) << "unknown metadata kind " << arr->kind; + break; + } + + elements_.emplace_back(std::vector()); + for (auto o : arr->array) { + if (o->IsInstance()) { + double value = Downcast(o)->value; + Visit(nullptr, &value); + } + if (o->IsInstance()) { + auto value = Downcast(o)->value; + Visit(nullptr, &value); + } else if (o->IsInstance()) { + ::std::string value = Downcast(o); + Visit(nullptr, &value); + } else { + // nested array not possible. + VisitMetadata(Downcast(o)); + } + } + auto array = elements_.back(); + elements_.pop_back(); + CHECK(element_type != nullptr); + auto arr_ty = llvm::ArrayType::get(element_type, array.size()); + auto llvm_arr = llvm::ConstantArray::get(arr_ty, array); + + if (elements_.size() > 0) { + elements_.back().emplace_back( + codegen_->GetGlobalConstant(llvm_arr, "", llvm::GlobalValue::PrivateLinkage)); + } else { + last_production_ = llvm_arr; + } + } + + void Visit(const char* key, ObjectRef* value) final { + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + if (arr != nullptr) { + VisitArray(arr); + return; + } + + runtime::metadata::MetadataBase metadata = Downcast(*value); + VisitMetadata(metadata); + } + + llvm::Constant* Serialize(runtime::metadata::MetadataBase metadata) { + Visit(nullptr, &metadata); + ICHECK(last_production_); + return codegen_->GetGlobalConstant(last_production_); + } + + CodeGenLLVM* codegen_; + MetadataLlvmTypes* llvm_types_; + llvm::LLVMContext* ctx_; + llvm::Module* module_; + std::vector> elements_; + llvm::Constant* last_production_; +}; + +void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) { + MetadataLlvmTypes llvm_types{ + t_float64_ /* t_float64 */, + llvm::Type::getInt8Ty(*ctx_) /* t_uint8 */, + t_int64_ /* t_int64 */, + llvm::Type::getInt8Ty(*ctx_) /* t_bool */, + t_char_->getPointerTo() /* t_cstring */, + t_void_p_ /* t_void_p */, + llvm::StructType::create(*ctx_, {t_int8_, t_int8_, t_int8_}, "DLDataType") /* t_data_type */, + }; + + std::vector queue; + metadata::DiscoverComplexTypesVisitor discover_complex{&queue}; + discover_complex.Discover(metadata); + + MetadataTypeDefiner definer{ctx_, &llvm_types}; + for (auto md : queue) { + if (md.defined()) { + definer.DefineType(md); + } + } + + MetadataSerializerLLVM 
serializer{this, &llvm_types}; + auto metadata_constant_gv = serializer.Serialize(metadata); + + function_ = + llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage, + "get_c_metadata", module_.get()); + function_->setCallingConv(llvm::CallingConv::C); + function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); + + llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + builder_->SetInsertPoint(entry_point_entry); + + auto ret_values_p = builder_->CreateBitCast(GetArg(function_, 3), t_void_p_->getPointerTo()); + builder_->CreateStore(builder_->CreateBitCast(metadata_constant_gv, t_void_p_), ret_values_p); + + auto ret_tcode = builder_->CreateBitCast(GetArg(function_, 4), t_int_->getPointerTo()); + builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle), ret_tcode); + + builder_->CreateRet(ConstInt32(0)); +} + void CodeGenCPU::DefineFunctionRegistry(Array func_names) { ICHECK(is_system_lib_) << "Loading of --system-lib modules is yet to be defined for C runtime"; Array symbols; @@ -980,9 +1311,11 @@ void CodeGenCPU::AddStartupFunction() { llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { if (op->op.same_as(builtin::tvm_call_packed_lowered())) { - return CreateCallPacked(op); + return CreateCallPacked(op, true /* use_string_lookup */); } else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) { return CreateCallTracePacked(op); + } else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) { + return CreateCallPacked(op, false /* use_string_lookup */); } else if (op->op.same_as(builtin::tvm_static_handle())) { return CreateStaticHandle(); } else if (op->op.same_as(builtin::tvm_throw_last_error())) { @@ -1052,6 +1385,7 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_); // fail condition. builder_->SetInsertPoint(fail_block); + #if TVM_LLVM_VERSION >= 90 auto err_callee = llvm::FunctionCallee(ftype_tvm_api_set_last_error_, RuntimeTVMAPISetLastError()); diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index 26f251f1a9c8..a491d539a6ea 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -56,6 +56,12 @@ class CodeGenCPU : public CodeGenLLVM { */ void DefineFunctionRegistry(Array func_names); + /*! + * \brief Serialize the metadata object as data, and implement get_c_metadata function. + * \param metadata The metadata which should be serialized. + */ + void DefineMetadata(runtime::metadata::Metadata metadata); + protected: void AddStartupFunction() final; // meta data @@ -117,9 +123,9 @@ class CodeGenCPU : public CodeGenLLVM { llvm::BasicBlock* end_block; }; PackedCall MakeCallPackedLowered(const Array& args, const DataType& r_type, - const int64_t begin, const int64_t end); + const int64_t begin, const int64_t end, bool use_string_lookup); // create call into tvm packed function. - llvm::Value* CreateCallPacked(const CallNode* op); + llvm::Value* CreateCallPacked(const CallNode* op, bool use_string_lookup); // Create trace call into tvm packed function. 
llvm::Value* CreateCallTracePacked(const CallNode* op); // Create static initialization diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 8cd8a5199d54..d54d3c1c51c5 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -37,6 +37,7 @@ #include "codegen_cpu.h" #include "codegen_params.h" #include "llvm/Support/raw_os_ostream.h" +#include "llvm_common.h" namespace tvm { namespace codegen { @@ -134,11 +135,11 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; - ICHECK(module_->getFunction(static_cast(global_symbol.value())) == nullptr) - << "Function " << global_symbol << " already exist in module"; - - function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, - global_symbol.value().operator std::string(), module_.get()); + function_ = module_->getFunction(static_cast(global_symbol.value())); + if (function_ == nullptr) { + function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, + global_symbol.value().operator std::string(), module_.get()); + } function_->setCallingConv(llvm::CallingConv::C); function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); @@ -191,6 +192,19 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { } } +llvm::GlobalVariable* CodeGenLLVM::GetLinkedParamSymbol(const std::string& param_name, + llvm::ConstantArray* array) { + std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + param_name; + llvm::GlobalVariable* var = module_->getGlobalVariable(symbol_name, true /* AllowInternal */); + if (var == nullptr) { + CHECK(array != nullptr) << "Expect param symbol " << symbol_name + << " to either be defined or for the array to be supplied"; + var = new llvm::GlobalVariable(*module_, static_cast(array->getType()), true, + llvm::GlobalValue::InternalLinkage, array, symbol_name); + } + return var; +} + void CodeGenLLVM::LinkParameters(const Map params) { // It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc, // but they are at a different layer in the compiler... 
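// Illustrative usage sketch, not part of the patch: GetLinkedParamSymbol() is now the single
// place that derives the "__tvm_param__" global name, so LinkParameters() and the
// lookup_param() lowering in VisitExpr_(CallNode*) stay consistent. Assuming a hypothetical
// linked parameter named "p0" (with `nd_array` standing in for its NDArray constant):
//
//   // In LinkParameters(): create the global and attach the serialized NDArray contents.
//   llvm::GlobalVariable* created = GetLinkedParamSymbol("p0", NDArrayToLLVMArray(ctx_, nd_array));
//   // In lookup_param("p0") lowering: reuse the global created above.
//   llvm::GlobalVariable* reused = GetLinkedParamSymbol("p0", /*array=*/nullptr);
//
// Passing nullptr for a parameter that was never linked trips the CHECK inside
// GetLinkedParamSymbol, which flags lookup_param() calls with no corresponding linked data.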
@@ -209,22 +223,13 @@ void CodeGenLLVM::LinkParameters(const Map params) { llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function); builder_->SetInsertPoint(entry); - auto getArg = [function](int i) -> llvm::Argument* { -#if TVM_LLVM_VERSION >= 100 - return function->getArg(i); -#elif TVM_LLVM_VERSION >= 50 - return &function->arg_begin()[i]; -#else - return &*std::next(function->arg_begin(), i); -#endif - }; - llvm::Type* t_int64_p = t_int64_->getPointerTo(GetGlobalAddressSpace()); - llvm::Value* sid = builder_->CreateLoad(t_int64_, builder_->CreateBitCast(getArg(0), t_int64_p)); + llvm::Value* sid = + builder_->CreateLoad(t_int64_, builder_->CreateBitCast(GetArg(function, 0), t_int64_p)); - auto ret_tcode = builder_->CreateBitCast(getArg(4), t_int_p); - auto ret_value = - builder_->CreateBitCast(getArg(3), t_void_p_->getPointerTo(GetGlobalAddressSpace())); + auto ret_tcode = builder_->CreateBitCast(GetArg(function, 4), t_int_p); + auto ret_value = builder_->CreateBitCast(GetArg(function, 3), + t_void_p_->getPointerTo(GetGlobalAddressSpace())); llvm::BasicBlock* default_block = llvm::BasicBlock::Create(*ctx_, "default_block", function); llvm::SwitchInst* switch_inst = builder_->CreateSwitch(sid, default_block, params.size() + 1); @@ -236,9 +241,7 @@ void CodeGenLLVM::LinkParameters(const Map params) { // Add data to the global section. for (auto kv : params) { auto array = NDArrayToLLVMArray(ctx_, kv.second->param); - std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + kv.first; - llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable( - *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name); + llvm::GlobalVariable* param_symbol = GetLinkedParamSymbol(kv.first, array); auto dtype = tvm::runtime::DataType(kv.second->param->dtype); size_t align = std::max(tvm::runtime::GetVectorBytes(dtype), tvm::runtime::kAllocAlignment); #if TVM_LLVM_VERSION >= 100 @@ -246,8 +249,10 @@ void CodeGenLLVM::LinkParameters(const Map params) { #else param_symbol->setAlignment(align); #endif + param_symbol->setInitializer(array); - llvm::BasicBlock* case_block = llvm::BasicBlock::Create(*ctx_, "case_" + symbol_name, function); + llvm::BasicBlock* case_block = + llvm::BasicBlock::Create(*ctx_, "case_" + param_symbol->getName(), function); switch_inst->addCase( llvm::cast(llvm::ConstantInt::get(t_int64_, kv.second->id)), case_block); builder_->SetInsertPoint(case_block); @@ -388,6 +393,7 @@ void CodeGenLLVM::Optimize() { fpass.run(*it); } fpass.doFinalization(); + // PrintModule(module_.get()); mpass.run(*module_); } @@ -770,21 +776,27 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va } } -llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { - auto it = str_map_.find(str); - if (it != str_map_.end()) return it->second; - llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1); - llvm::GlobalVariable* global = new llvm::GlobalVariable( - *module_, type, true, llvm::GlobalValue::PrivateLinkage, nullptr, ".str"); +llvm::Constant* CodeGenLLVM::GetGlobalConstant(llvm::Constant* const_data, const std::string& name, + llvm::GlobalValue::LinkageTypes linkage_type) { + llvm::Type* ty = const_data->getType(); + llvm::GlobalVariable* global = + new llvm::GlobalVariable(*module_, ty, true, linkage_type, const_data, name); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(1)); #else global->setAlignment(1); #endif - 
global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str)); llvm::Constant* zero = ConstInt32(0); llvm::Constant* indices[] = {zero, zero}; - llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(type, global, indices); + llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(ty, global, indices); + return ptr; +} + +llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { + auto it = str_map_.find(str); + if (it != str_map_.end()) return it->second; + auto llvm_str = llvm::ConstantDataArray::getString(*ctx_, str); + auto ptr = GetGlobalConstant(llvm_str, ".str", llvm::GlobalValue::PrivateLinkage); str_map_[str] = ptr; return ptr; } @@ -1407,7 +1419,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { if (auto* ptr_op = op->op.as()) { auto call_op = GetRef(ptr_op); - if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { + if (op->op.same_as(builtin_lookup_param_)) { + return GetLinkedParamSymbol(Downcast(op->args[0])->value, nullptr); + } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { // call extern intrinsic ICHECK_GE(op->args.size(), 1U); auto global_symbol = Downcast(op->args[0]); @@ -1418,7 +1432,10 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { return this->CreateCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], op->args, false); } else { - return CreateIntrinsic(op); + VLOG(2) << "CreateIntrinsic: " << GetRef(op); + auto x = CreateIntrinsic(op); + VLOG(2) << "CreateIntrinsic done"; + return x; } } else { ICHECK(op->op.as()); @@ -1563,7 +1580,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { ICHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; - size_t constant_size = op->ConstantAllocationSize(); + int32_t constant_size = op->ConstantAllocationSize(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; if (constant_size % 4 == 0 && info.alignment == 0) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 7a7ca6578f28..7f84119345db 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -23,6 +23,7 @@ */ #ifndef TVM_TARGET_LLVM_CODEGEN_LLVM_H_ #define TVM_TARGET_LLVM_CODEGEN_LLVM_H_ +#include #ifdef TVM_LLVM_VERSION #include @@ -190,6 +191,13 @@ class CodeGenLLVM : public ExprFunctor, void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; + // Get constant string + llvm::Constant* GetConstString(const std::string& str); + + llvm::Constant* GetGlobalConstant( + llvm::Constant* const_data, const std::string& name = "", + llvm::GlobalValue::LinkageTypes linkage_type = llvm::GlobalValue::InternalLinkage); + protected: /*! * \brief Address and type pair to assist in handling opaque pointers. @@ -341,6 +349,14 @@ class CodeGenLLVM : public ExprFunctor, */ llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type, llvm::ArrayRef arg_types); + /*! + * \brief Lookup or create a GlobalVariable whose content is the data field of a DLTensor for a + * given linked_param() CallNode. + * \param param_name Parameter name (e.g. unmangled, from lookup_param node). + * \return the GlobalVariable indicated in the brief. + */ + llvm::GlobalVariable* GetLinkedParamSymbol(const ::std::string& param_name, + llvm::ConstantArray* array); /*! 
* \brief Get the number of elements in the given vector value. * \param vec The value, must be of a vector type. @@ -353,8 +369,6 @@ class CodeGenLLVM : public ExprFunctor, int* p_native_bits); // Returns whether the LLVM type has padding for alignment bool HasAlignmentPadding(DataType dtype); - // Get constant string - llvm::Constant* GetConstString(const std::string& str); // do a scalarize call with f llvm::Value* CreateScalarizedCall(const CallNode* op, llvm::Function* f, const std::vector& args); @@ -389,6 +403,27 @@ class CodeGenLLVM : public ExprFunctor, unsigned int shared_address_space, int alignment, llvm::GlobalValue::LinkageTypes linkage); + /*! + * \brief Get the `i`th argument to the given function, respecting LLVM API changes. + * + * NOTE: in LLVM < 10.0, the underlying API returns a const llvm::Argument*. To provide a uniform + * API, const is removed here. Proper usage of LLVM APIs depends on having a non-const Argument*, + * so we take this approach here rather than adding const. + * + * \param function The function containing the arguments. + * \param i The index of the argument to retrieve. + * \return The retrieved argument. + */ + llvm::Argument* GetArg(const llvm::Function* function, int i) const { +#if TVM_LLVM_VERSION >= 100 + return function->getArg(i); +#elif TVM_LLVM_VERSION >= 50 + return const_cast(&function->arg_begin()[i]); +#else + return const_cast(&*std::next(function->arg_begin(), i)); +#endif + } + // The IRBuilder. using IRBuilder = llvm::IRBuilder; // The current function @@ -447,6 +482,8 @@ class CodeGenLLVM : public ExprFunctor, const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin(); const Op& builtin_call_llvm_pure_intrin_ = builtin::call_llvm_pure_intrin(); + const Op& builtin_lookup_param_ = builtin::lookup_param(); + const Op& builtin_tvm_call_cpacked_lowered_ = builtin::tvm_call_cpacked_lowered(); /*! \brief Helper struct for debug infos.
*/ struct DebugInfo { @@ -481,6 +518,7 @@ void CodeGenLLVM::AddFunctionsOrdered(IterType begin, IterType end, ConvType pfu return name_a < name_b; }); for (auto& f : funcs) { + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); AddFunction(f); } } diff --git a/src/target/llvm/llvm_common.cc b/src/target/llvm/llvm_common.cc index 06b2be2d9fb6..f13e8563e053 100644 --- a/src/target/llvm/llvm_common.cc +++ b/src/target/llvm/llvm_common.cc @@ -189,6 +189,13 @@ std::string LLVMTargetToString(const Target& target) { return os.str(); } +void PrintModule(const llvm::Module* mod) { + std::string modpe_str; + llvm::raw_string_ostream rso(modpe_str); + mod->print(rso, nullptr); + LOG(INFO) << rso.str(); +} + } // namespace codegen } // namespace tvm #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/llvm_common.h b/src/target/llvm/llvm_common.h index 556f05d2e33a..e2e3384c1a19 100644 --- a/src/target/llvm/llvm_common.h +++ b/src/target/llvm/llvm_common.h @@ -126,6 +126,8 @@ std::unique_ptr GetLLVMTargetMachine(const Target& target, */ std::string LLVMTargetToString(const Target& target); +void PrintModule(const llvm::Module* mod); + } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index cf8b59357b47..ab679bdedd1f 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -308,14 +308,14 @@ class LLVMModuleNode final : public runtime::ModuleNode { cg->SetFastMathFlag(fmf); + if (found_linked_params) { + cg->LinkParameters(linked_params); + } cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); if (entry_func.length() != 0) { cg->AddMainFunction(entry_func); } - if (found_linked_params) { - cg->LinkParameters(linked_params); - } module_ = cg->Finish(); module_->addModuleFlag(llvm::Module::Warning, "tvm_target", llvm::MDString::get(*ctx_, LLVMTargetToString(target))); @@ -527,6 +527,41 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob") return runtime::Module(n); }); +runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target, + tvm::relay::Runtime runtime) { + InitializeLLVM(); + auto tm = GetLLVMTargetMachine(target); + bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); + auto ctx = std::make_shared(); + std::unique_ptr cg{new CodeGenCPU()}; + + cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib, + false /* target_c_runtime */); + + cg->DefineMetadata(metadata); + auto mod = cg->Finish(); + mod->addModuleFlag(llvm::Module::Warning, "tvm_target", + llvm::MDString::get(*ctx, LLVMTargetToString(target))); + mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION); + + if (tm->getTargetTriple().isOSDarwin()) { + mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); + } + + std::string verify_errors_storage; + llvm::raw_string_ostream verify_errors(verify_errors_storage); + LOG_IF(FATAL, llvm::verifyModule(*mod, &verify_errors)) + << "LLVM module verification failed with the following errors: \n" + << verify_errors.str(); + + auto n = make_object(); + n->Init(std::move(mod), ctx); + + auto meta_mod = MetadataModuleCreate(metadata); + meta_mod->Import(runtime::Module(n)); + return meta_mod; +} + runtime::Module CreateLLVMCrtMetadataModule(const Array& modules, Target target, tvm::relay::Runtime runtime) { Array func_names; diff --git a/src/target/llvm/llvm_module.h b/src/target/llvm/llvm_module.h index 933030e213d2..660d81400b0d 100644 --- a/src/target/llvm/llvm_module.h +++ 
b/src/target/llvm/llvm_module.h @@ -33,6 +33,9 @@ namespace tvm { namespace codegen { +runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target, + tvm::relay::Runtime runtime); + runtime::Module CreateLLVMCrtMetadataModule(const Array& modules, Target target, tvm::relay::Runtime runtime); diff --git a/src/target/metadata.h b/src/target/metadata.h index b8ca24580f15..5dc1c9d0eec5 100644 --- a/src/target/metadata.h +++ b/src/target/metadata.h @@ -56,7 +56,8 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { inputs_array.push_back(::tvm::runtime::metadata::TensorInfo{inputs_accessor[i]}); } ::tvm::runtime::metadata::MetadataArray inputs_metadata_array{ - inputs_array, ::tvm::runtime::metadata::MetadataTypeIndex::kMetadata, "TVMTensorInfo"}; + inputs_array, ::tvm::runtime::metadata::MetadataKind::kMetadata, + ::tvm::runtime::metadata::TensorInfoNode::_type_key}; v->Visit("inputs", &inputs_metadata_array); int64_t num_inputs_cpp = num_inputs(); v->Visit("num_inputs", &num_inputs_cpp); @@ -67,7 +68,8 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { outputs_array.push_back(::tvm::runtime::metadata::TensorInfo{outputs_accessor[i]}); } ::tvm::runtime::metadata::MetadataArray outputs_metadata_array{ - outputs_array, ::tvm::runtime::metadata::MetadataTypeIndex::kMetadata, "TVMTensorInfo"}; + outputs_array, ::tvm::runtime::metadata::MetadataKind::kMetadata, + ::tvm::runtime::metadata::TensorInfoNode::_type_key}; v->Visit("outputs", &outputs_metadata_array); int64_t num_outputs_cpp = num_outputs(); v->Visit("num_outputs", &num_outputs_cpp); @@ -78,7 +80,8 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { pools_array.push_back(::tvm::runtime::metadata::TensorInfo{pools_accessor[i]}); } ::tvm::runtime::metadata::MetadataArray pools_metadata_array{ - pools_array, ::tvm::runtime::metadata::MetadataTypeIndex::kMetadata, "TVMTensorInfo"}; + pools_array, ::tvm::runtime::metadata::MetadataKind::kMetadata, + ::tvm::runtime::metadata::TensorInfoNode::_type_key}; v->Visit("pools", &pools_metadata_array); int64_t num_pools_cpp = num_pools(); v->Visit("num_pools", &num_pools_cpp); @@ -156,7 +159,7 @@ class VisitableTensorInfoNode : public ::tvm::runtime::metadata::TensorInfoNode shape_array.push_back(::tvm::Integer{static_cast(shape_accessor[i])}); } ::tvm::runtime::metadata::MetadataArray shape_metadata_array{ - shape_array, ::tvm::runtime::metadata::MetadataTypeIndex::kInt64, nullptr}; + shape_array, ::tvm::runtime::metadata::MetadataKind::kInt64, nullptr}; v->Visit("shape", &shape_metadata_array); int64_t num_shape_cpp = num_shape(); v->Visit("num_shape", &num_shape_cpp); diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index 8abd18c1d8f3..5457946322c3 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -144,6 +144,12 @@ static runtime::Module CreateCppMetadataModule( auto metadata_module = CreateCSourceCppMetadataModule(runtime_metadata); metadata_module->Import(target_module); target_module = metadata_module; +#ifdef TVM_LLVM_VERSION // defining TVM_LLVM_VERSION indicates TVM was compiled with USE_LLVM ON. 
+ } else if (target->kind->name == "llvm") { + auto metadata_module = CreateLLVMCppMetadataModule(runtime_metadata, target, runtime); + metadata_module->Import(target_module); + target_module = metadata_module; +#endif // TVM_LLVM_VERSION } else { CHECK(false) << "Don't know how to create MetadataModule for target type " << target->str(); } diff --git a/src/target/metadata_utils.cc b/src/target/metadata_utils.cc new file mode 100644 index 000000000000..db17d1862846 --- /dev/null +++ b/src/target/metadata_utils.cc @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/metadata_utils.cc + * \brief Defines utility functions and classes for emitting metadata. + */ +#include "metadata_utils.h" + +namespace tvm { +namespace codegen { +namespace metadata { + +std::string AddressFromParts(const std::vector& parts) { + std::stringstream ss; + for (unsigned int i = 0; i < parts.size(); ++i) { + if (i > 0) { + ss << "_"; + } + ss << parts[i]; + } + return ss.str(); +} + +DiscoverArraysVisitor::DiscoverArraysVisitor(std::vector* queue) : queue_{queue} {} + +void DiscoverArraysVisitor::Visit(const char* key, double* value) {} +void DiscoverArraysVisitor::Visit(const char* key, int64_t* value) {} +void DiscoverArraysVisitor::Visit(const char* key, uint64_t* value) {} +void DiscoverArraysVisitor::Visit(const char* key, int* value) {} +void DiscoverArraysVisitor::Visit(const char* key, bool* value) {} +void DiscoverArraysVisitor::Visit(const char* key, std::string* value) {} +void DiscoverArraysVisitor::Visit(const char* key, DataType* value) {} +void DiscoverArraysVisitor::Visit(const char* key, runtime::NDArray* value) {} +void DiscoverArraysVisitor::Visit(const char* key, void** value) {} + +void DiscoverArraysVisitor::Visit(const char* key, ObjectRef* value) { + address_parts_.push_back(key); + if (value->as() != nullptr) { + auto metadata = Downcast(*value); + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + if (arr != nullptr) { + for (unsigned int i = 0; i < arr->array.size(); i++) { + ObjectRef o = arr->array[i]; + if (o.as() != nullptr) { + std::stringstream ss; + ss << i; + address_parts_.push_back(ss.str()); + runtime::metadata::MetadataBase metadata = Downcast(o); + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + address_parts_.pop_back(); + } + } + + queue_->push_back(std::make_tuple(AddressFromParts(address_parts_), + Downcast(metadata))); + } else { + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + } + } + address_parts_.pop_back(); +} + +void DiscoverComplexTypesVisitor::Visit(const char* key, double* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, int64_t* value) {} +void 
DiscoverComplexTypesVisitor::Visit(const char* key, uint64_t* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, int* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, bool* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, std::string* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, DataType* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, runtime::NDArray* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, void** value) {} + +bool DiscoverComplexTypesVisitor::DiscoverType(std::string type_key) { + VLOG(2) << "DiscoverType " << type_key; + auto position_it = type_key_to_position_.find(type_key); + if (position_it != type_key_to_position_.end()) { + return false; + } + + queue_->emplace_back(tvm::runtime::metadata::MetadataBase()); + type_key_to_position_[type_key] = queue_->size() - 1; + return true; +} + +void DiscoverComplexTypesVisitor::DiscoverInstance(runtime::metadata::MetadataBase md) { + auto position_it = type_key_to_position_.find(md->GetTypeKey()); + ICHECK(position_it != type_key_to_position_.end()) + << "DiscoverInstance requires that DiscoverType has already been called: type_key=" + << md->GetTypeKey(); + + int queue_position = (*position_it).second; + if (!(*queue_)[queue_position].defined() && md.defined()) { + VLOG(2) << "DiscoverInstance " << md->GetTypeKey() << ":" << md; + (*queue_)[queue_position] = md; + } +} + +void DiscoverComplexTypesVisitor::Visit(const char* key, ObjectRef* value) { + ICHECK_NOTNULL(value->as()); + + auto metadata = Downcast(*value); + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + + if (arr == nullptr) { + VLOG(2) << "No array, object-traversing " << metadata->GetTypeKey(); + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + DiscoverType(metadata->GetTypeKey()); + DiscoverInstance(metadata); + return; + } + + if (arr->kind != tvm::runtime::metadata::MetadataKind::kMetadata) { + return; + } + + bool needs_instance = DiscoverType(arr->type_key); + for (unsigned int i = 0; i < arr->array.size(); i++) { + tvm::runtime::metadata::MetadataBase o = + Downcast(arr->array[i]); + if (needs_instance) { + DiscoverInstance(o); + needs_instance = false; + } + ReflectionVTable::Global()->VisitAttrs(o.operator->(), this); + } +} + +void DiscoverComplexTypesVisitor::Discover(runtime::metadata::MetadataBase metadata) { + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + DiscoverType(metadata->GetTypeKey()); + DiscoverInstance(metadata); +} + +} // namespace metadata +} // namespace codegen +} // namespace tvm diff --git a/src/target/metadata_utils.h b/src/target/metadata_utils.h new file mode 100644 index 000000000000..977a0f412bb5 --- /dev/null +++ b/src/target/metadata_utils.h @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/metadata_utils.h + * \brief Declares utility functions and classes for emitting metadata. + */ +#ifndef TVM_TARGET_METADATA_UTILS_H_ +#define TVM_TARGET_METADATA_UTILS_H_ + +#include +#include +#include + +#include +#include +#include +#include + +#include "metadata.h" + +namespace tvm { +namespace codegen { +namespace metadata { + +/*! + * \brief Construct a unique string "address" for a struct member from a vector of pieces. + * + * In codegen, it is frequently necessary to assemble a C-style identifier for an + * otherwise-anonymous member of Metadata. For instance, suppose Metadata declares an array: + * struct TVMMetadata { + * int64_t* shape; + * }; + * + * In order to properly initialize this struct, the array must be declared separately with a global + * name. This function produces such a name, here termed "address." + * + * \param parts A vector of pieces, typically the struct member names which identify the path to + * this member. + * \return The joined pieces. + */ +std::string AddressFromParts(const std::vector& parts); + +/*! + * \brief A prefix in metadata symbol names. + * This prefix is typically given to AddressFromParts as the 0th item in parts. + */ +static constexpr const char* kMetadataGlobalSymbol = "kTvmgenMetadata"; + +/*! + * \brief Post-order traverse metadata to discover arrays which need to be forward-defined. + */ +class DiscoverArraysVisitor : public AttrVisitor { + public: + /*! \brief Models a single array discovered in this visitor. + * Contains two fields: + * 0. An address which uniquely identifies the array in this Metadata instance. + * 1. The discovered MetadataArray. + */ + using DiscoveredArray = std::tuple; + explicit DiscoverArraysVisitor(std::vector* queue); + + void Visit(const char* key, double* value) final; + void Visit(const char* key, int64_t* value) final; + void Visit(const char* key, uint64_t* value) final; + void Visit(const char* key, int* value) final; + void Visit(const char* key, bool* value) final; + void Visit(const char* key, std::string* value) final; + void Visit(const char* key, DataType* value) final; + void Visit(const char* key, runtime::NDArray* value) final; + void Visit(const char* key, void** value) final; + + void Visit(const char* key, ObjectRef* value) final; + + private: + /*! \brief The queue to be filled with discovered arrays. */ + std::vector* queue_; + + /*! \brief Tracks the preceding address pieces. */ + std::vector address_parts_; +}; + +/*! + * \brief Post-order traverse Metadata to discover all complex types which need to be + * forward-defined. This visitor finds one defined() MetadataBase instance for each unique subclass + * present inside Metadata in the order in which the subclass was first discovered. + */ +class DiscoverComplexTypesVisitor : public AttrVisitor { + public: + /*! \brief Construct a new instance.
+ * \param queue An ordered list to be filled with one discovered MetadataBase instance per complex type. + */ + explicit DiscoverComplexTypesVisitor(std::vector* queue) + : queue_{queue} {} + + void Visit(const char* key, double* value) final; + void Visit(const char* key, int64_t* value) final; + void Visit(const char* key, uint64_t* value) final; + void Visit(const char* key, int* value) final; + void Visit(const char* key, bool* value) final; + void Visit(const char* key, std::string* value) final; + void Visit(const char* key, DataType* value) final; + void Visit(const char* key, runtime::NDArray* value) final; + void Visit(const char* key, void** value) final; + + void Visit(const char* key, ObjectRef* value) final; + + void Discover(runtime::metadata::MetadataBase metadata); + + private: + bool DiscoverType(std::string type_key); + + void DiscoverInstance(runtime::metadata::MetadataBase md); + + std::vector* queue_; + + /*! \brief map type_key to index in queue_. */ + std::unordered_map type_key_to_position_; +}; + +} // namespace metadata +} // namespace codegen +} // namespace tvm + +#endif // TVM_TARGET_METADATA_UTILS_H_ diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 0b74a1a1c4d9..d7a121c631f5 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -273,7 +273,7 @@ std::string CodeGenCHost::GetPackedName(const CallNode* op) { CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op, bool has_resource_handle) { const StringImmNode* s = op->args[0].as(); - ICHECK(s != nullptr) << "tvm_call_{c}packed_lowered expects first argument as function name"; + ICHECK(s != nullptr) << "tvm_call_[c]packed_lowered expects first argument as function name"; int64_t begin = op->args[3].as()->value; int64_t end = op->args[4].as()->value; int64_t num_args = end - begin; @@ -281,10 +281,30 @@ CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op, std::string func_name = s->value; if (has_resource_handle) { - std::string resource_handle_name = op->args[5].as()->value; - return {func_name, num_args - 1, resource_handle_name}; + const StringImmNode* resource_handle_var = op->args[5].as(); + if (resource_handle_var != nullptr) { + std::string resource_handle_name = resource_handle_var->value; + return {func_name, num_args - 1, resource_handle_name}; + } else { + // The final arg should be "(void*) NULL" to indicate the empty resource_handle.
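// For orientation only, a rough sketch (not code from this patch): when the resource_handle slot
// holds the reinterpret(0) placeholder rather than a named var, GetFunctionInfo now reports the
// literal "NULL" as the handle, so the host-side call emitted by CodeGenCHost ends in NULL, roughly:
//   if (my_cpacked_func(arg_values, arg_type_ids, num_args, &ret_value, &ret_type_code, NULL) != 0) {
//     return -1;
//   }
// Here my_cpacked_func and the argument names are placeholders, not symbols introduced by this change.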
+ num_args--; + + const CallNode* reinterpret_call = op->args[5].as(); + ICHECK_NE(reinterpret_call, (void*)nullptr) + << "At CallNode to " << s + << "arg 5: Expect either StringImm naming the resource_handle var from interface API or " + << "reinterpret(0); got: " << op->args[5]; + ICHECK_EQ(reinterpret_call->op, builtin::reinterpret()) + << "At CallNode to " << s + << "arg 5: Expect either StringImm naming the resource_handle var from interface API or " + << "reinterpret(0); got: " << op->args[5]; + ICHECK(is_zero(reinterpret_call->args[0])) << "At CallNode to " << s + << " arg 5: Expect either StringImm naming the " + "resource_handle var from interface API, or " + << "zero; got " << op->args[5]; + } } - return {func_name, num_args}; + return {func_name, num_args, "NULL"}; } void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 80b4f1b970f3..ef5755f3e84b 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -23,13 +23,13 @@ */ #include "source_module.h" +#include #include #include #include #include #include -#include #include #include #include @@ -40,6 +40,7 @@ #include "../../support/str_escape.h" #include "../func_registry_generator.h" #include "../metadata.h" +#include "../metadata_utils.h" #include "codegen_source_base.h" namespace tvm { @@ -523,69 +524,10 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { } }; -static std::string address_from_parts(const std::vector& parts) { - std::stringstream ss; - for (unsigned int i = 0; i < parts.size(); ++i) { - if (i > 0) { - ss << "_"; - } - ss << parts[i]; - } - return ss.str(); -} - -class MetadataQueuer : public AttrVisitor { - public: - using QueueItem = std::tuple; - explicit MetadataQueuer(std::vector* queue) : queue_{queue} {} - - void Visit(const char* key, double* value) final {} - void Visit(const char* key, int64_t* value) final {} - void Visit(const char* key, uint64_t* value) final {} - void Visit(const char* key, int* value) final {} - void Visit(const char* key, bool* value) final {} - void Visit(const char* key, std::string* value) final {} - void Visit(const char* key, DataType* value) final {} - void Visit(const char* key, runtime::NDArray* value) final {} - void Visit(const char* key, void** value) final {} - - void Visit(const char* key, ObjectRef* value) final { - address_parts_.push_back(key); - if (value->as() != nullptr) { - auto metadata = Downcast(*value); - const runtime::metadata::MetadataArrayNode* arr = - value->as(); - if (arr != nullptr) { - for (unsigned int i = 0; i < arr->array.size(); i++) { - ObjectRef o = arr->array[i]; - if (o.as() != nullptr) { - std::stringstream ss; - ss << i; - address_parts_.push_back(ss.str()); - runtime::metadata::MetadataBase metadata = Downcast(o); - ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); - address_parts_.pop_back(); - } - } - } else { - ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); - } - - queue_->push_back(std::make_tuple(address_from_parts(address_parts_), - Downcast(*value))); - } - address_parts_.pop_back(); - } - - private: - std::vector* queue_; - std::vector address_parts_; -}; - class MetadataSerializer : public AttrVisitor { public: static constexpr const char* kGlobalSymbol = "kTvmgenMetadata"; - using MetadataTypeIndex = ::tvm::runtime::metadata::MetadataTypeIndex; + using MetadataKind = ::tvm::runtime::metadata::MetadataKind; 
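// A rough sketch for orientation (not code from this patch): with the kind-based handling below,
// the serializer emits one C array definition per discovered MetadataArray and then the top-level
// struct, with symbol names built by metadata::AddressFromParts. For a model with a single
// (1, 5, 5, 3) input the generated source looks approximately like:
//   const int64_t kTvmgenMetadata_inputs_0_shape[4] = {1, 5, 5, 3};
//   const struct TVMTensorInfo kTvmgenMetadata_inputs[1] = { /* one TVMTensorInfo initializer */ };
//   const struct TVMMetadata kTvmgenMetadata = { /* TVMMetadata fields, referencing the arrays above */ };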
MetadataSerializer() : is_first_item_{true} {} @@ -653,29 +595,54 @@ class MetadataSerializer : public AttrVisitor { ICHECK(false) << "do not support serializing NDArray as metadata"; } - void VisitArray(const runtime::metadata::MetadataArrayNode* array) { + void VisitArray(runtime::metadata::MetadataArray array) { auto old_is_first_item = is_first_item_; is_first_item_ = true; for (unsigned int i = 0; i < array->array.size(); ++i) { ObjectRef o = array->array[i]; - if (o->IsInstance()) { - int64_t i = Downcast(o); - Visit(nullptr, &i); - continue; - } - if (o->IsInstance()) { - std::string s = Downcast(o); - Visit(nullptr, &s); - continue; + switch (array->kind) { + case MetadataKind::kUint64: { + int64_t i = Downcast(o); + CHECK_GT(i, 0) + << "Metadata is of type uint64_t, but array type contains a negative number"; + uint64_t ui = static_cast(i); + Visit(nullptr, &ui); + continue; + } + case MetadataKind::kInt64: { + int64_t i = Downcast(o); + Visit(nullptr, &i); + continue; + } + case MetadataKind::kBool: { + bool b = Downcast(o); + Visit(nullptr, &b); + break; + } + case MetadataKind::kString: { + std::string s = Downcast(o); + Visit(nullptr, &s); + break; + } + case MetadataKind::kHandle: + CHECK(false) << "Don't know how to serialize handle"; + break; + + case MetadataKind::kMetadata: { + runtime::metadata::MetadataBase metadata = Downcast(o); + std::stringstream i_str; + i_str << i; + address_.push_back(i_str.str()); + Visit(nullptr, &metadata); + address_.pop_back(); + break; + } + default: + CHECK(false) << "Unknown MetadataKind for array: " << array->kind; + break; } - - runtime::metadata::MetadataBase metadata = Downcast(o); - std::stringstream i_str; - i_str << i; - address_.push_back(i_str.str()); - Visit(nullptr, &metadata); - address_.pop_back(); + is_first_item_ = false; } is_first_item_ = old_is_first_item; } @@ -688,7 +655,7 @@ class MetadataSerializer : public AttrVisitor { if (key != nullptr) { address_.push_back(key); } - code_ << address_from_parts(address_); + code_ << metadata::AddressFromParts(address_); if (key != nullptr) { address_.pop_back(); } @@ -705,59 +672,72 @@ class MetadataSerializer : public AttrVisitor { } } + private: + void EmitCType(const runtime::metadata::MetadataArrayNode* arr, std::ostream& os) { + switch (arr->kind) { + case MetadataKind::kUint64: + os << "uint64_t"; + break; + case MetadataKind::kInt64: + os << "int64_t"; + break; + case MetadataKind::kBool: + os << "bool"; + break; + case MetadataKind::kString: + os << "const char*"; + break; + case MetadataKind::kHandle: + os << "void*"; + break; + case MetadataKind::kMetadata: + os << "struct " << arr->get_element_c_struct_name(); + break; + default: + CHECK(false) << "Unknown kind in MetadataArray: " << arr->kind + << " (struct_name=" << arr->get_c_struct_name() << ")"; + break; + } + } + + public: void CodegenMetadata(::tvm::runtime::metadata::Metadata metadata) { decl_ << "#include " << std::endl << "#include " << std::endl << "#include " << std::endl; - std::vector queue; - MetadataQueuer queuer{&queue}; - queuer.Visit(kGlobalSymbol, &metadata); - - for (MetadataQueuer::QueueItem item : queue) { - auto struct_name = std::get<0>(item); - auto obj = std::get<1>(item); - auto arr = obj.as(); - is_first_item_ = true; - address_.push_back(struct_name); - if (arr != nullptr) { - const char* const_part = "const "; - if (arr->type_index == MetadataTypeIndex::kString) { - const_part = ""; - } - code_ << const_part; - switch (arr->type_index) { - case MetadataTypeIndex::kUint64: - code_ << 
"uint64_t"; - break; - case MetadataTypeIndex::kInt64: - code_ << "int64_t"; - break; - case MetadataTypeIndex::kBool: - code_ << "bool"; - break; - case MetadataTypeIndex::kString: - code_ << "const char*"; - break; - case MetadataTypeIndex::kHandle: - code_ << "void*"; - break; - case MetadataTypeIndex::kMetadata: - code_ << "struct " << arr->struct_name; - break; - default: - CHECK(false) << "Unknown type_index in array: " << arr->type_index - << " (struct_name=" << arr->struct_name << ")"; - break; - } - code_ << " " << struct_name << "[" << arr->array.size() << "] = {" << std::endl; - VisitArray(arr); - } else { - code_ << "const struct TVMMetadata " << struct_name << " = {" << std::endl; - Visit(nullptr, &obj); + std::vector queue; + metadata::DiscoverArraysVisitor array_discover{&queue}; + array_discover.Visit(metadata::kMetadataGlobalSymbol, &metadata); + + for (auto item : queue) { + auto struct_address = std::get<0>(item); + address_.push_back(struct_address); + + auto arr = std::get<1>(item); + + // Prepend const with everything except C-string, which needs appending. + if (arr->kind != MetadataKind::kString) { + code_ << "const "; + } + EmitCType(arr.operator->(), code_); + if (arr->kind == MetadataKind::kString) { + code_ << " const"; } + code_ << " " << struct_address << "[" << arr->array.size() << "] = {" << std::endl; + is_first_item_ = true; + + VisitArray(arr); address_.pop_back(); code_ << "};" << std::endl; } + + // Finally, emit overall struct. + address_.push_back(metadata::kMetadataGlobalSymbol); + code_ << "const struct TVMMetadata " << metadata::AddressFromParts(address_) << " = {" + << std::endl; + Visit(nullptr, &metadata); + code_ << "};" << std::endl; + address_.pop_back(); } std::string GetOutput() { return decl_.str() + code_.str(); } @@ -804,8 +784,8 @@ runtime::Module CreateCSourceCppMetadataModule(runtime::metadata::Metadata metad << "(TVMValue* arg_values, int* arg_tcodes, int " "num_args, TVMValue* ret_values, int* ret_tcodes, void* resource_handle) {" << std::endl; - lookup_func << " ret_values[0].v_handle = (void*) &" << MetadataSerializer::kGlobalSymbol - << ";" << std::endl; + lookup_func << " ret_values[0].v_handle = (void*) &" << metadata::kMetadataGlobalSymbol << ";" + << std::endl; lookup_func << " ret_tcodes[0] = kTVMOpaqueHandle;" << std::endl; lookup_func << " return 0;" << std::endl; lookup_func << "};" << std::endl; diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 07b341dfd2c7..f4dbc238c120 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -810,7 +810,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Call Call::Call(DataType dtype, RelayExpr op, Array args, Span span) { for (size_t i = 0; i < args.size(); ++i) { - ICHECK(args[i].defined()); + ICHECK(args[i].defined()) << "arg " << i << " is not defined()"; } ObjectPtr node = make_object(); diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc index 2d8b6681fa84..43cb1fb03fa2 100644 --- a/src/tir/transforms/legalize_packed_calls.cc +++ b/src/tir/transforms/legalize_packed_calls.cc @@ -43,10 +43,9 @@ using InputMap = */ class PackedCallLegalizer : public StmtExprMutator { public: - Stmt Legalize(const InputMap& params, tir::Stmt body) { - inputs_ = params; - return StmtExprMutator::VisitStmt(body); - } + PackedCallLegalizer(IRModule m, const InputMap& inputs) : mod_{m}, inputs_{inputs} {} + + Stmt Legalize(tir::Stmt body) { return StmtExprMutator::VisitStmt(body); } Stmt VisitStmt_(const EvaluateNode* op) final { if 
(tir::is_const_int(op->value)) return StmtExprMutator::VisitStmt_(op); @@ -56,49 +55,62 @@ class PackedCallLegalizer : public StmtExprMutator { // let B_packed = set_struct(tvm_value2, B) // let C_packed = set_struct(tvm_value3, C) // call_packed(f, A_packed, B_packed, C_packed) - std::vector new_stmts; if (call) { if (call->op.same_as(builtin::tvm_call_cpacked())) { Array packed_args{call->args[0]}; - std::vector tvm_values; - for (unsigned i = 1; i < call->args.size(); i++) { + VLOG(2) << "Legalize call:" << call; + BaseFunc base_func = mod_->Lookup(Downcast(call->args[0])->value); + const PrimFuncNode* prim_func = base_func.as(); + VLOG(2) << " to func " << base_func; + for (unsigned i = 1; i < call->args.size() - 1; i++) { // No need to pack inputs of the prim_func if (inputs_[call->args[i]] == true) { packed_args.push_back(call->args[i]); } else { - // Pack the argument inside a TVMValue - std::stringstream ss; - ss << "tvm_value_" << tvm_value_index_++; - auto sid_array = tir::Var(ss.str(), DataType::Handle()); - tvm_values.push_back(sid_array); - - new_stmts.push_back(tir::Evaluate( - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {sid_array, 0, tir::builtin::kArrData, call->args[i]}))); - new_stmts.push_back(tir::Evaluate( - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {sid_array, 0, tir::builtin::kArrDeviceType, kDLCPU}))); - new_stmts.push_back(tir::Evaluate( - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {sid_array, 0, tir::builtin::kArrDeviceId, 0}))); - packed_args.push_back(sid_array); + // Stack-allocate a DLTensor for this parameter. Note that LowerTVMBuiltin will collect + // all such stack-allocated tensors and minimize the storage needed by reusing + // DLTensors. + Array call_args{call->args[i]}; + tvm::runtime::Map::iterator param_buf_it; + if (prim_func != nullptr) { + auto param_var = prim_func->params[i - 1]; + param_buf_it = prim_func->preflattened_buffer_map.find(param_var); + } + if (prim_func != nullptr && param_buf_it != prim_func->preflattened_buffer_map.end()) { + Buffer param = (*param_buf_it).second; + PrimExpr shape = tvm::tir::Call( + DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), param->shape); + Cast var_type(param->dtype, IntImm(DataType::Int(32), 0)); + call_args.push_back(shape /* shape */); + call_args.push_back(make_zero(DataType::Handle()) /* strides */); + call_args.push_back(tvm::IntImm(DataType::UInt(32), param->shape.size()) /* ndim */); + call_args.push_back(var_type /* carries dtype */); + call_args.push_back(param->elem_offset /* elem_offset */); + } else { + // When the PrimFunc cannot be found, most DLTensor information cannot be populated. 
+ PrimExpr shape = tvm::tir::Call( + DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), Array()); + Cast var_type(DataType::Handle(), IntImm(DataType::Int(32), 0)); + call_args.push_back(shape /* shape */); + call_args.push_back(make_zero(DataType::Handle()) /* strides */); + call_args.push_back(tvm::IntImm(DataType::UInt(32), 0) /* ndim */); + call_args.push_back(var_type /* carries dtype */); + call_args.push_back(tvm::IntImm(DataType::UInt(64), 0) /* elem_offset */); + } + packed_args.push_back(tvm::tir::Call( + DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), call_args)); } } + packed_args.push_back(call->args[call->args.size() - 1]); // push device_context // Evaluate the packed call - new_stmts.push_back(tir::Evaluate(tir::Call(call->dtype, call->op, packed_args))); - tir::Stmt call_stmt = tir::SeqStmt(new_stmts); - - // Allocate the TVMValues on the stack and define the variables - for (auto v : tvm_values) { - call_stmt = LetStmt(v, StackAlloca("array", 1), call_stmt); - } - return call_stmt; + return tir::Evaluate(tir::Call(call->dtype, call->op, packed_args)); } } return StmtExprMutator::VisitStmt_(op); } private: + IRModule mod_; InputMap inputs_; // Store the inputs to the primfunc that don't need to be packed. int tvm_value_index_; // Index of the actual tvm_value variable }; @@ -109,12 +121,12 @@ Pass LegalizePackedCalls() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - // Create the + // Note which Var are inputs and exclude them from packing. InputMap inputs; for (auto i : f->params) { inputs[i] = true; } - n->body = PackedCallLegalizer().Legalize(inputs, std::move(n->body)); + n->body = PackedCallLegalizer(m, inputs).Legalize(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LegalizePackedCalls", {}); diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index e474683b39fc..9d0087cc7a0b 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -109,11 +109,14 @@ class BuiltinLower : public StmtExprMutator { precheck.device_type_ = this->device_type_; precheck.alloca_scope_.emplace_back(); - auto& scope = precheck.alloca_scope_.back(); - scope.stack_shape = - decl_buffer({IntImm(DataType::Int(64), 0)}, DataType::Int(64), "stack_shape"); - scope.stack_tcode = - decl_buffer({IntImm(DataType::UInt(64), 0)}, DataType::Int(32), "stack_tcode"); + { + // NOTE: this scope reference is invalid after any mutation is applied to alloca_scope_. + auto& scope = precheck.alloca_scope_.back(); + scope.stack_shape = + decl_buffer({IntImm(DataType::Int(64), 0)}, DataType::Int(64), "stack_shape"); + scope.stack_tcode = + decl_buffer({IntImm(DataType::UInt(64), 0)}, DataType::Int(32), "stack_tcode"); + } precheck.VisitStmt(stmt); @@ -130,31 +133,35 @@ class BuiltinLower : public StmtExprMutator { } alloca_scope_.emplace_back(); - auto& scope = alloca_scope_.back(); - - // Initial check to identify maximum stack sizes. These are used - // to construct Buffer objects to hold the stack, which are then - // used when mutating. 
- scope.max_sizes = GetMaxStack(stmt); - - if (scope.max_sizes.shape_stack != -1) { - scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), scope.max_sizes.shape_stack)}, - DataType::Int(64), "stack_shape"); - stmt = - LetStmt(scope.stack_shape->data, StackAlloca("shape", scope.max_sizes.shape_stack), stmt); - } + { + // NOTE: this scope reference is invalid after any mutation is applied to alloca_scope_. + auto& scope = alloca_scope_.back(); + + // Initial check to identify maximum stack sizes. These are used + // to construct Buffer objects to hold the stack, which are then + // used when mutating. + scope.max_sizes = GetMaxStack(stmt); + + if (scope.max_sizes.shape_stack != -1) { + scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), scope.max_sizes.shape_stack)}, + DataType::Int(64), "stack_shape"); + stmt = LetStmt(scope.stack_shape->data, StackAlloca("shape", scope.max_sizes.shape_stack), + stmt); + } - if (scope.max_sizes.array_stack != 0) { - stmt = LetStmt(scope.stack_array, StackAlloca("array", scope.max_sizes.array_stack), stmt); - } + if (scope.max_sizes.array_stack != 0) { + stmt = LetStmt(scope.stack_array, StackAlloca("array", scope.max_sizes.array_stack), stmt); + } - if (scope.max_sizes.arg_stack != 0) { - scope.stack_tcode = decl_buffer({IntImm(DataType::UInt(64), scope.max_sizes.arg_stack)}, - DataType::Int(32), "stack_tcode"); - stmt = LetStmt(scope.stack_value, StackAlloca("arg_value", scope.max_sizes.arg_stack), stmt); + if (scope.max_sizes.arg_stack != 0) { + scope.stack_tcode = decl_buffer({IntImm(DataType::UInt(64), scope.max_sizes.arg_stack)}, + DataType::Int(32), "stack_tcode"); + stmt = + LetStmt(scope.stack_value, StackAlloca("arg_value", scope.max_sizes.arg_stack), stmt); - stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode", scope.max_sizes.arg_stack), - stmt); + stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode", scope.max_sizes.arg_stack), + stmt); + } } stmt = this->VisitStmt(stmt); @@ -169,14 +176,22 @@ class BuiltinLower : public StmtExprMutator { // allocate space to hold prepare stmts before s prep_seq_stack_.emplace_back(std::vector()); + auto scope_size = alloca_scope_.size(); auto stmt = StmtExprMutator::VisitStmt(s); - auto& scope = alloca_scope_.back(); - // This invariant asserts the assumption that - // make_stack_shape only happens within a call_packed. - // We could relax this in the future if we want to - // introduce root scope as a separate scope - ICHECK_EQ(scope.run_sizes.shape_stack, -1); - ICHECK_EQ(scope.run_sizes.array_stack, 0); + { + // NOTE: this scope reference is invalid after any mutation is applied to alloca_scope_. + auto& scope = alloca_scope_.back(); + // This invariant asserts the assumption that + // make_stack_shape only happens within a call_packed. 
+ // We could relax this in the future if we want to + // introduce root scope as a separate scope + ICHECK_EQ(alloca_scope_.size(), scope_size) + << "alloca_scope_ length is different before and after recursion"; + ICHECK_EQ(scope.run_sizes.shape_stack, -1) + << "Expect no tvm_stack_make_shape outside of CallNodes"; + ICHECK_EQ(scope.run_sizes.array_stack, 0) + << "Expect no tvm_stack_make_array outside of CallNodes"; + } auto prep_seq = std::move(prep_seq_stack_.back()); prep_seq_stack_.pop_back(); @@ -369,9 +384,12 @@ class BuiltinLower : public StmtExprMutator { make_const(DataType::UInt(16), dtype.lanes()))); // set byte offset int data_bytes = GetVectorBytes(dtype); - PrimExpr byte_offset = op->args[5]; - if (!is_zero(byte_offset)) { - byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes); + PrimExpr elem_offset = op->args[5]; + PrimExpr byte_offset; + if (!is_zero(elem_offset)) { + byte_offset = elem_offset * make_const(elem_offset.dtype(), data_bytes); + } else { + byte_offset = elem_offset; } prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrByteOffset, cast(DataType::UInt(64), byte_offset))); @@ -436,8 +454,14 @@ class BuiltinLower : public StmtExprMutator { // cpacked call resource_handle if (!use_string_lookup) { - tir::Var resource_handle = Downcast(op->args[arg_count]); - packed_args.push_back(StringImm(resource_handle->name_hint)); + PrimExpr last_arg = op->args[arg_count]; + const VarNode* var_node = last_arg.as(); + if (var_node != nullptr) { + tir::Var resource_handle = GetRef(var_node); + packed_args.push_back(StringImm(resource_handle->name_hint)); + } else { + packed_args.push_back(last_arg); + } } auto builtin_call = use_string_lookup ? builtin::tvm_call_packed_lowered() @@ -561,6 +585,7 @@ Pass LowerTVMBuiltin() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = BuiltinLower().Build(n->body); + VLOG(2) << "LowerTVMBuiltin: " << f; return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {}); diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index b73534090ab5..ba5ab891baa4 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -200,8 +200,11 @@ PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::Upda int pool_size = all_pools_sizes_[pool_info]; String buffer_var_name = pool_ref_name + "_buffer_var"; - si.buffer_map.Set(pool_var, Buffer(buffer_var, elem_dtype, {pool_size}, {1}, 1, buffer_var_name, - 16, 1, BufferType::kDefault)); + si.buffer_map.Set(pool_var, + Buffer(buffer_var /* data */, elem_dtype /* dtype */, {pool_size} /* shape */, + {1} /* strides */, 0 /* elem_offset */, buffer_var_name /* name */, + 16 /* data_alignment */, 1 /* offset_factor */, + BufferType::kDefault /* buffer-type */)); } if (resource_handle) { si.params.push_back(resource_handle.value()); @@ -223,8 +226,8 @@ PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams( if (emit_tvmscript_printable_) { original_attrs = DictAttrs(); } - PrimFunc ret = PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, {}, - original_attrs); + PrimFunc ret = PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, + si.buffer_map, original_attrs); if (!emit_tvmscript_printable_) { ret = WithAttr(ret, tvm::attr::kPoolArgs, si.allocated_pool_params); } 
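As background for the symbol names checked by the tests added below, a minimal standalone sketch of the naming scheme: AddressFromParts (metadata_utils.cc above) joins the address pieces with underscores, and kMetadataGlobalSymbol ("kTvmgenMetadata") is supplied as the leading piece. JoinParts here is a stand-in written for illustration; it is not a symbol from this patch.

#include <iostream>
#include <string>
#include <vector>

// Stand-in mirroring the underscore join performed by tvm::codegen::metadata::AddressFromParts.
std::string JoinParts(const std::vector<std::string>& parts) {
  std::string out;
  for (size_t i = 0; i < parts.size(); ++i) {
    if (i > 0) out += "_";
    out += parts[i];
  }
  return out;
}

int main() {
  // Address of the shape array for input 0: prefix, field name, element index, field name.
  std::cout << JoinParts({"kTvmgenMetadata", "inputs", "0", "shape"}) << std::endl;
  // Prints: kTvmgenMetadata_inputs_0_shape
  return 0;
}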
diff --git a/tests/cpp/aot_metadata_test.cc b/tests/cpp/aot_metadata_test.cc index abf37ce4569a..b1dea64aaa9c 100644 --- a/tests/cpp/aot_metadata_test.cc +++ b/tests/cpp/aot_metadata_test.cc @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -25,6 +24,7 @@ #include #include "../src/target/metadata.h" +#include "../src/target/metadata_utils.h" namespace { @@ -46,12 +46,28 @@ const struct TVMMetadata kNormal = { } // namespace using ::testing::ElementsAre; +using ::testing::ElementsAreArray; using ::testing::Eq; +using ::testing::Matcher; +using ::testing::MatcherInterface; +using ::testing::MatchResultListener; using ::testing::StrEq; + +using ::tvm::codegen::metadata::DiscoverArraysVisitor; +using ::tvm::codegen::metadata::DiscoverComplexTypesVisitor; +using ::tvm::codegen::metadata::kMetadataGlobalSymbol; + +using ::tvm::runtime::Array; using ::tvm::runtime::Downcast; +using ::tvm::runtime::ObjectRef; + +using ::tvm::runtime::metadata::Metadata; +using ::tvm::runtime::metadata::MetadataArray; +using ::tvm::runtime::metadata::MetadataKind; +using ::tvm::runtime::metadata::TensorInfo; TEST(Metadata, ParseStruct) { - tvm::runtime::metadata::Metadata md = tvm::runtime::metadata::Metadata(&kNormal); + Metadata md = Metadata(&kNormal); EXPECT_THAT(md->version(), Eq(TVM_METADATA_VERSION)); EXPECT_THAT(md->num_inputs(), Eq(2)); @@ -137,7 +153,7 @@ class TestVisitor : public tvm::AttrVisitor { }; TEST(Metadata, Visitor) { - tvm::runtime::metadata::Metadata md = tvm::runtime::metadata::Metadata(&kNormal); + Metadata md = Metadata(&kNormal); TestVisitor v; ::tvm::ReflectionVTable::Global()->VisitAttrs(md.operator->(), &v); @@ -149,17 +165,17 @@ TEST(Metadata, Visitor) { EXPECT_THAT(Downcast(v.values[0])->value, Eq(TVM_METADATA_VERSION)); // Just identify the tensor. 
- auto input_array = Downcast(v.values[1]); - EXPECT_THAT(input_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata)); - EXPECT_THAT(input_array->struct_name, StrEq("TVMTensorInfo")); + auto input_array = Downcast(v.values[1]); + EXPECT_THAT(input_array->kind, Eq(MetadataKind::kMetadata)); + EXPECT_THAT(input_array->type_key, StrEq("metadata.TensorInfoNode")); EXPECT_THAT(input_array->array.size(), Eq(2)); - auto input1 = Downcast(input_array->array[0]); + auto input1 = Downcast(input_array->array[0]); EXPECT_THAT(input1->name(), StrEq("input1")); EXPECT_THAT(input1->shape(), ElementsAre(1, 5, 5, 3)); EXPECT_THAT(input1->dtype(), tvm::runtime::DataType(DLDataType{1, 2, 3})); - auto input2 = Downcast(input_array->array[1]); + auto input2 = Downcast(input_array->array[1]); EXPECT_THAT(input1->name(), StrEq("input1")); EXPECT_THAT(input1->shape(), ElementsAre(1, 5, 5, 3)); EXPECT_THAT(input1->dtype(), tvm::runtime::DataType(DLDataType{1, 2, 3})); @@ -167,20 +183,20 @@ TEST(Metadata, Visitor) { auto num_inputs = Downcast(v.values[2]); EXPECT_THAT(num_inputs->value, Eq(2)); - auto output_array = Downcast(v.values[3]); - EXPECT_THAT(output_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata)); - EXPECT_THAT(output_array->struct_name, StrEq("TVMTensorInfo")); - auto output1 = Downcast(output_array->array[0]); + auto output_array = Downcast(v.values[3]); + EXPECT_THAT(output_array->kind, Eq(MetadataKind::kMetadata)); + EXPECT_THAT(output_array->type_key, StrEq("metadata.TensorInfoNode")); + auto output1 = Downcast(output_array->array[0]); EXPECT_THAT(output1->name(), Eq("output1")); auto num_outputs = Downcast(v.values[4]); EXPECT_THAT(num_outputs->value, Eq(1)); - auto pool_array = Downcast(v.values[5]); - EXPECT_THAT(pool_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata)); - EXPECT_THAT(pool_array->struct_name, StrEq("TVMTensorInfo")); - auto pool1 = Downcast(pool_array->array[0]); + auto pool_array = Downcast(v.values[5]); + EXPECT_THAT(pool_array->kind, Eq(MetadataKind::kMetadata)); + EXPECT_THAT(pool_array->type_key, StrEq("metadata.TensorInfoNode")); + auto pool1 = Downcast(pool_array->array[0]); EXPECT_THAT(pool1->name(), Eq("pool1")); @@ -193,27 +209,24 @@ TEST(Metadata, Visitor) { using ::tvm::runtime::make_object; TEST(Metadata, InMemory) { - tvm::runtime::metadata::Metadata md = - tvm::runtime::metadata::Metadata(make_object( - TVM_METADATA_VERSION, - std::vector( - {tvm::runtime::metadata::TensorInfo( - make_object( - tvm::String("Input1"), std::vector{1, 5, 5, 3}, - tvm::runtime::DataType(DLDataType{1, 2, 3}))), - tvm::runtime::metadata::TensorInfo( - make_object( - tvm::String("Input2"), std::vector{1, 5, 5, 3}, - tvm::runtime::DataType(DLDataType{2, 3, 4})))}), - std::vector({tvm::runtime::metadata::TensorInfo( - make_object( - tvm::String("Output1"), std::vector{3, 8, 8}, - tvm::runtime::DataType(DLDataType{3, 4, 5})))}), - std::vector({tvm::runtime::metadata::TensorInfo( - make_object( - tvm::String("Pool1"), std::vector{5, 10, 10}, - tvm::runtime::DataType(DLDataType{3, 4, 7})))}), - "default")); + Metadata md = Metadata(make_object( + TVM_METADATA_VERSION, + std::vector( + {TensorInfo(make_object( + tvm::String("Input1"), std::vector{1, 5, 5, 3}, + tvm::runtime::DataType(DLDataType{1, 2, 3}))), + TensorInfo(make_object( + tvm::String("Input2"), std::vector{1, 5, 5, 3}, + tvm::runtime::DataType(DLDataType{2, 3, 4})))}), + std::vector( + {TensorInfo(make_object( + tvm::String("Output1"), std::vector{3, 8, 8}, 
+ tvm::runtime::DataType(DLDataType{3, 4, 5})))}), + std::vector( + {TensorInfo(make_object( + tvm::String("Pool1"), std::vector{5, 10, 10}, + tvm::runtime::DataType(DLDataType{3, 4, 7})))}), + "default")); auto md_data = md->data(); EXPECT_THAT(md_data->version, Eq(TVM_METADATA_VERSION)); @@ -251,14 +264,13 @@ TEST(Metadata, InMemory) { } TEST(Metadata, ZeroElementLists) { - tvm::runtime::metadata::Metadata md = - tvm::runtime::metadata::Metadata(make_object( - TVM_METADATA_VERSION, std::vector({}), - std::vector({tvm::runtime::metadata::TensorInfo( - make_object( - tvm::String("Output1"), std::vector{}, - tvm::runtime::DataType(DLDataType{3, 4, 5})))}), - std::vector({}), "default")); + Metadata md = Metadata(make_object( + TVM_METADATA_VERSION, std::vector({}), + std::vector( + {TensorInfo(make_object( + tvm::String("Output1"), std::vector{}, + tvm::runtime::DataType(DLDataType{3, 4, 5})))}), + std::vector({}), "default")); EXPECT_THAT(md->data()->num_inputs, Eq(0)); EXPECT_THAT(md->inputs().size(), Eq(0)); @@ -274,3 +286,84 @@ TEST(Metadata, ZeroElementLists) { EXPECT_THAT(md->num_pools(), Eq(0)); EXPECT_THAT(md->pools(), ElementsAre()); } + +TEST(MetadataArray, GetElementCStructName) { + MetadataArray arr_struct{make_object( + Array(), MetadataKind::kMetadata, "metadata.FooMetadataNode")}; + EXPECT_THAT(arr_struct->kind, Eq(MetadataKind::kMetadata)); + EXPECT_THAT(arr_struct->get_element_c_struct_name(), StrEq("TVMFooMetadata")); + + MetadataArray arr_int{make_object( + Array(), MetadataKind::kInt64, nullptr)}; + EXPECT_THROW(arr_int->get_element_c_struct_name(), std::runtime_error); +} + +namespace { +std::string ExplainDiscoveredNameEq(bool negation, std::string expected_name) { + std::stringstream ss; + ss << "std::get<0>(discovered_array) " << (negation ? 
"isn't" : "is") << " equal to " + << expected_name; + return ss.str(); +} +} // namespace + +MATCHER_P(DiscoveredNameEq, expected_name, ExplainDiscoveredNameEq(negation, expected_name)) { + return std::string(std::get<0>(arg)) == expected_name; +} + +TEST(DiscoverArraysVisitor, DiscoverArrays) { + std::vector q; + DiscoverArraysVisitor visitor(&q); + + Metadata md = Metadata(&kNormal); + visitor.Visit(kMetadataGlobalSymbol, &md); + + EXPECT_THAT(q, ElementsAreArray({DiscoveredNameEq("kTvmgenMetadata_inputs_0_shape"), + DiscoveredNameEq("kTvmgenMetadata_inputs_1_shape"), + DiscoveredNameEq("kTvmgenMetadata_inputs"), + DiscoveredNameEq("kTvmgenMetadata_outputs_0_shape"), + DiscoveredNameEq("kTvmgenMetadata_outputs"), + DiscoveredNameEq("kTvmgenMetadata_pools_0_shape"), + DiscoveredNameEq("kTvmgenMetadata_pools")})); +} + +template ::value, bool> = + true> +class TVMObjectIsInstanceMatcher : public MatcherInterface { + public: + using is_gtest_matcher = void; + + bool MatchAndExplain(tvm::runtime::metadata::MetadataBase arg, + MatchResultListener* os) const override { + bool result = arg->IsInstance(); + if (!result) { + (*os) << "is an instance of type " << T::ContainerType::_type_key; + } + + return result; + } + + void DescribeTo(std::ostream* os) const override { + (*os) << "is an instance of type " << T::ContainerType::_type_key; + } + + void DescribeNegationTo(std::ostream* os) const override { + (*os) << "is not an instance of type " << T::ContainerType::_type_key; + } +}; + +template +Matcher TVMObjectIsInstance() { + return Matcher(new TVMObjectIsInstanceMatcher()); +} + +TEST(DiscoverComplexTypesVisitor, DiscoverComplexTypes) { + std::vector q; + DiscoverComplexTypesVisitor visitor(&q); + + Metadata md = Metadata(&kNormal); + visitor.Discover(md); + + EXPECT_THAT(q, ElementsAre(TVMObjectIsInstance(), TVMObjectIsInstance())); +} diff --git a/tests/python/contrib/test_hexagon/test_launcher.py b/tests/python/contrib/test_hexagon/test_launcher.py index 72a6fe3f83b8..dbc581ae3dfd 100644 --- a/tests/python/contrib/test_hexagon/test_launcher.py +++ b/tests/python/contrib/test_hexagon/test_launcher.py @@ -321,7 +321,7 @@ def test_aot_executor(hexagon_session): params=params, target=tvm.target.Target(target_hexagon, host="c"), runtime=Runtime("cpp"), - executor=Executor("aot", {"unpacked-api": False, "interface-api": "c"}), + executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}), ) if hexagon_session is None: @@ -401,7 +401,7 @@ def test_aot_executor_multiple_conv2d(hexagon_session): params=params, target=tvm.target.Target(target_hexagon, host="c"), runtime=Runtime("cpp"), - executor=Executor("aot", {"unpacked-api": False, "interface-api": "c"}), + executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}), ) if hexagon_session is None: diff --git a/tests/python/relay/aot/test_c_device_api.py b/tests/python/relay/aot/test_c_device_api.py index 6a12a38d35c2..d547b52e85c3 100644 --- a/tests/python/relay/aot/test_c_device_api.py +++ b/tests/python/relay/aot/test_c_device_api.py @@ -143,6 +143,7 @@ def test_device_api_hooks_unpacked_api(device_api_main_func): + " device_context_ethos_u))\n" ) # Open Device + print("main func", repr(main_func.body)) assert ( str(main_func.body[1][0][0][0]) == "tir.tvm_check_return(0, -1, tir.call_extern(" @@ -239,23 +240,11 @@ def test_without_device_api_packed_api(non_device_api_main_func): main_func = non_device_api_main_func(interface_api="packed", use_unpacked_api=False) assert str(main_func.body) == ( - 'let 
tvm_value_3 = tir.tvm_stack_alloca("array", 1)\n' - 'let tvm_value_2 = tir.tvm_stack_alloca("array", 1)\n' - 'let tvm_value_1 = tir.tvm_stack_alloca("array", 1)\n' - 'let tvm_value_0 = tir.tvm_stack_alloca("array", 1)\n' - "tir.tvm_struct_set(tvm_value_0, 0, 1, x_buffer_var)\n" - "tir.tvm_struct_set(tvm_value_0, 0, 10, 1)\n" - "tir.tvm_struct_set(tvm_value_0, 0, 9, 0)\n" - "tir.tvm_struct_set(tvm_value_1, 0, 1, y_buffer_var)\n" - "tir.tvm_struct_set(tvm_value_1, 0, 10, 1)\n" - "tir.tvm_struct_set(tvm_value_1, 0, 9, 0)\n" - "tir.tvm_struct_set(tvm_value_2, 0, 1, output_buffer_var)\n" - "tir.tvm_struct_set(tvm_value_2, 0, 10, 1)\n" - "tir.tvm_struct_set(tvm_value_2, 0, 9, 0)\n" - "tir.tvm_struct_set(tvm_value_3, 0, 1, tir.reinterpret((uint64)0))\n" - "tir.tvm_struct_set(tvm_value_3, 0, 10, 1)\n" - "tir.tvm_struct_set(tvm_value_3, 0, 9, 0)\n" - 'tir.tvm_call_cpacked("tvmgen_default_fused_multiply", tvm_value_0, tvm_value_1, tvm_value_2, tvm_value_3)\n' + 'tir.tvm_call_cpacked("tvmgen_default_fused_multiply", ' + "tir.tvm_stack_make_array(x_buffer_var, tir.tvm_stack_make_shape(10, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), " + "tir.tvm_stack_make_array(y_buffer_var, tir.tvm_stack_make_shape(1, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), " + "tir.tvm_stack_make_array(output_buffer_var, tir.tvm_stack_make_shape(10, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), " + "tir.reinterpret((uint64)0))\n" ) diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py index 48057404dd4c..2a11e7e28748 100644 --- a/tests/python/relay/aot/test_cpp_aot.py +++ b/tests/python/relay/aot/test_cpp_aot.py @@ -24,20 +24,10 @@ import pytest import tvm -from tvm import relay, TVMError -from tvm.ir.module import IRModule -from tvm.relay import backend, testing, transform -from tvm.relay.testing import byoc -from tvm.relay.op.annotation import compiler_begin, compiler_end -from aot_test_utils import ( - AOTTestModel, - AOT_DEFAULT_RUNNER, - generate_ref_data, - convert_to_relay, - compile_and_run, - compile_models, - parametrize_aot_options, -) +from tvm import IRModule +from tvm import relay +from tvm.relay import backend, testing +from aot_test_utils import AOT_DEFAULT_RUNNER, AOTTestModel, generate_ref_data, compile_and_run def test_error_c_interface(): @@ -51,25 +41,22 @@ def test_error_c_interface(): with pytest.raises( tvm.TVMError, match=re.escape( - 'Either need interface_api == "packed" (got: c) or ' - "unpacked-api == true (got: (bool)0) when targeting " - "c runtime" + 'Need unpacked-api == false (got: 0) and interface-api == "packed" (got: c) when ' + "targeting c++ runtime" ), ): - compile_and_run( - AOTTestModel( - module=IRModule.from_expr(func), inputs={}, outputs=generate_ref_data(func, {}) - ), - test_runner, - interface_api, - use_unpacked_api, + tvm.relay.build( + IRModule.from_expr(func), + target="llvm", + executor=backend.Executor("aot", {"interface-api": "c"}), ) enable_usmp = tvm.testing.parameter(True, False) +target_kind = tvm.testing.parameter("c", "llvm") -def test_conv2d(enable_usmp): +def test_conv2d(enable_usmp, target_kind): RELAY_MODEL = textwrap.dedent( """\ #[version = "0.0.5"] @@ -117,7 +104,7 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5), mod = tvm.relay.build( ir_mod, params=params, - target="c", + target=target_kind, executor=backend.Executor("aot", {"interface-api": "packed"}), ) @@ -131,18 +118,20 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : 
Tensor[(3, 3, 5, 5), assert (runner.get_output(0).asnumpy() == list(ref_outputs.values())[0]).all() -def test_mobilenet(): +def test_mobilenet(enable_usmp, target_kind): ir_mod, params = testing.mobilenet.get_workload(batch_size=1) data_shape = [int(x) for x in ir_mod["main"].checked_type.arg_types[0].shape] data = np.random.uniform(size=data_shape).astype("float32") inputs = {"data": data} ref_outputs = generate_ref_data(ir_mod, inputs, params) - with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + with tvm.transform.PassContext( + opt_level=3, config={"tir.disable_vectorize": True, "tir.usmp.enable": enable_usmp} + ): mod = tvm.relay.build( ir_mod, params=params, - target="c", + target=target_kind, executor=backend.Executor("aot", {"interface-api": "packed"}), ) diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 51a503ecfe38..3c44d2bf1bc8 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -60,7 +60,7 @@ def test_error_c_interface_with_packed_api(): tvm.TVMError, match=re.escape( 'Either need interface_api == "packed" (got: c) or ' - "unpacked-api == true (got: (bool)0) when targeting " + "unpacked-api == true (got: 0) when targeting " "c runtime" ), ): diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py index 54561ade23e4..c7c0daa30e2f 100644 --- a/tests/python/unittest/test_aot_legalize_packed_call.py +++ b/tests/python/unittest/test_aot_legalize_packed_call.py @@ -24,11 +24,24 @@ @tvm.script.ir_module class Module: + @T.prim_func + def tvm_test_cpacked( + A: T.handle, B: T.handle, C: T.handle, device_context: T.handle + ) -> T.handle: + A_0 = T.match_buffer(A, (1,), dtype="float32") + A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32") + B_0 = T.match_buffer(B, (1,), dtype="float32") + B_0pre = T.preflattened_buffer(B_0, (1,), dtype="float32") + C_0 = T.match_buffer(C, (1,), dtype="float32") + C_0pre = T.preflattened_buffer(C_0, (1,), dtype="float32") + T.evaluate(C) + @T.prim_func def tir_packed_call() -> None: A = T.var("handle") B = T.var("handle") C = T.var("handle") + device_context = T.var("handle") # body T.evaluate( T.tvm_call_cpacked( @@ -36,6 +49,7 @@ def tir_packed_call() -> None: A, B, C, + device_context, dtype="int32", ) ) @@ -43,40 +57,60 @@ def tir_packed_call() -> None: @tvm.script.ir_module class Expected: + @T.prim_func + def tvm_test_cpacked( + A: T.handle, B: T.handle, C: T.handle, device_context: T.handle + ) -> T.handle: + A_0 = T.match_buffer(A, (1,), dtype="float32") + A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32") + B_0 = T.match_buffer(B, (1,), dtype="float32") + B_0pre = T.preflattened_buffer(B_0, (1,), dtype="float32") + C_0 = T.match_buffer(C, (1,), dtype="float32") + C_0pre = T.preflattened_buffer(C_0, (1,), dtype="float32") + T.evaluate(C) + @T.prim_func def tir_packed_call() -> None: A = T.var("handle") B = T.var("handle") C = T.var("handle") + device_context = T.var("handle") # body - tvm_value_2 = T.var("handle") - tvm_value_1 = T.var("handle") - tvm_value_0 = T.var("handle") - with T.let(tvm_value_2, T.tvm_stack_alloca("array", 1, dtype="handle")): - with T.let(tvm_value_1, T.tvm_stack_alloca("array", 1, dtype="handle")): - with T.let(tvm_value_0, T.tvm_stack_alloca("array", 1, dtype="handle")): - T.evaluate(T.tvm_struct_set(tvm_value_0, 0, 1, A, dtype="handle")) - T.evaluate(T.tvm_struct_set(tvm_value_0, 0, 10, 1, dtype="handle")) 
-                    T.evaluate(T.tvm_struct_set(tvm_value_0, 0, 9, 0, dtype="handle"))
-
-                    T.evaluate(T.tvm_struct_set(tvm_value_1, 0, 1, B, dtype="handle"))
-                    T.evaluate(T.tvm_struct_set(tvm_value_1, 0, 10, 1, dtype="handle"))
-                    T.evaluate(T.tvm_struct_set(tvm_value_1, 0, 9, 0, dtype="handle"))
-
-                    T.evaluate(T.tvm_struct_set(tvm_value_2, 0, 1, C, dtype="handle"))
-                    T.evaluate(T.tvm_struct_set(tvm_value_2, 0, 10, 1, dtype="handle"))
-                    T.evaluate(T.tvm_struct_set(tvm_value_2, 0, 9, 0, dtype="handle"))
-
-                    T.evaluate(
-                        T.tvm_call_cpacked(
-                            "tvm_test_cpacked",
-                            tvm_value_0,
-                            tvm_value_1,
-                            tvm_value_2,
-                            dtype="int32",
-                        )
-                    )
+        T.evaluate(
+            T.tvm_call_cpacked(
+                "tvm_test_cpacked",
+                T.tvm_stack_make_array(
+                    A,
+                    T.tvm_stack_make_shape(1, dtype="handle"),
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+                    T.uint32(1),
+                    T.cast(0, dtype="float32"),
+                    0,
+                    dtype="handle",
+                ),
+                T.tvm_stack_make_array(
+                    B,
+                    T.tvm_stack_make_shape(1, dtype="handle"),
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+                    T.uint32(1),
+                    T.cast(0, dtype="float32"),
+                    0,
+                    dtype="handle",
+                ),
+                T.tvm_stack_make_array(
+                    C,
+                    T.tvm_stack_make_shape(1, dtype="handle"),
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+                    T.uint32(1),
+                    T.cast(0, dtype="float32"),
+                    0,
+                    dtype="handle",
+                ),
+                device_context,
+                dtype="int32",
+            )
+        )
 def test_aot_packed_call():
diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
index 4ed02615cd44..ce8675f575ee 100644
--- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
+++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
@@ -74,8 +74,11 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T
         # function attr dict
         T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True})
         placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        T.preflattened_buffer(placeholder_4, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
         placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        T.preflattened_buffer(placeholder_5, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1)
         T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        T.preflattened_buffer(T_subtract_1, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1)
         # body
         for ax0_ax1_fused_1 in T.serial(0, 224):
             for ax2_1, ax3_inner_1 in T.grid(224, 3):
@@ -86,9 +89,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde
         # function attr dict
         T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True})
         placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        T.preflattened_buffer(placeholder_65, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1)
         placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        T.preflattened_buffer(placeholder_66, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1)
         placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1)
+        T.preflattened_buffer(placeholder_67, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1)
         T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        T.preflattened_buffer(T_cast_21, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
         # body
         PaddedInput_7 = T.allocate([157323], "int16", "global")
         for i0_i1_fused_7 in T.serial(0, 229):
@@ -108,7 +115,9 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6:
         # function attr dict
         T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True})
         placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        T.preflattened_buffer(placeholder_29, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
         T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        T.preflattened_buffer(T_cast_7, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1)
         # body
         tensor_2 = T.allocate([200704], "uint8", "global")
         for ax0_ax1_fused_4 in T.serial(0, 56):
@@ -140,9 +149,9 @@ def __tvm_main__(input: T.handle, output: T.handle) -> None:
 @tvm.script.ir_module
 class LinearStructurePlanned:
     @T.prim_func
-    def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory_1_var: T.Ptr[T.uint8], output: T.handle) -> None:
-        fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16)
-        slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16)
+    def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory_1_var: T.Ptr[T.uint8], output: T.handle) -> None:
+        fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         T.attr("default", "device_id", 0)
         T.attr("default", "device_type", 1)
@@ -155,9 +164,13 @@ def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory
     @T.prim_func
     def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.Ptr[T.uint8], slow_memory_7_var: T.Ptr[T.uint8]) -> None:
         placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8")
+        T.preflattened_buffer(placeholder_29, [802816], dtype="uint8")
         T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16")
-        fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16)
-        slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        T.preflattened_buffer(T_cast_7, [177], dtype="int16")
+        fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        T.preflattened_buffer(fast_memory_6_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        T.preflattened_buffer(slow_memory_7_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         tensor_2_let = T.buffer_decl([200704], dtype="uint8")
         with T.let(tensor_2_let.data, T.address_of(fast_memory_6_buffer_var[0], dtype="handle")):
@@ -172,10 +185,15 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6:
     @T.prim_func
     def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.Ptr[T.uint8], slow_memory_3_var: T.Ptr[T.uint8]) -> None:
         placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8")
+        T.preflattened_buffer(placeholder_4, [150528], dtype="uint8")
         placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16")
+        T.preflattened_buffer(placeholder_5, [1], dtype="int16")
         T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16")
-        fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16)
-        slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        T.preflattened_buffer(T_subtract_1, [452], dtype="int16")
+        fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        T.preflattened_buffer(fast_memory_2_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        T.preflattened_buffer(slow_memory_3_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3):
             T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] = T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5[0]
@@ -183,11 +201,17 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T
     @T.prim_func
     def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.Ptr[T.uint8], slow_memory_5_var: T.Ptr[T.uint8]) -> None:
         placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16")
+        T.preflattened_buffer(placeholder_65, [150528], dtype="int16")
         placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16")
+        T.preflattened_buffer(placeholder_66, [9408], dtype="int16")
         placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32")
+        T.preflattened_buffer(placeholder_67, [64], dtype="int32")
         T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8")
-        fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16)
-        slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        T.preflattened_buffer(T_cast_21, [289], dtype="uint8")
+        fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        T.preflattened_buffer(fast_memory_4_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        T.preflattened_buffer(slow_memory_5_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         PaddedInput_7_let = T.buffer_decl([157323], "int16")
         with T.let(PaddedInput_7_let.data, T.address_of(slow_memory_5_buffer_var[802816], dtype="handle")):
@@ -251,8 +275,11 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p
         # function attr dict
         T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True})
         placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8")
+        T.preflattened_buffer(placeholder_2, [360000], dtype="uint8")
         placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
+        T.preflattened_buffer(placeholder_3, [64], dtype="int32")
         T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16")
+        T.preflattened_buffer(T_cast_1, [215], dtype="int16")
         # body
         for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
             T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16")
@@ -262,9 +289,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla
         # function attr dict
         T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True})
         placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16")
+        T.preflattened_buffer(placeholder_13, [360000], dtype="int16")
         placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16")
+        T.preflattened_buffer(placeholder_14, [36864], dtype="int16")
         placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32")
+        T.preflattened_buffer(placeholder_15, [64], dtype="int32")
         T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16")
+        T.preflattened_buffer(T_cast_5, [215], dtype="int16")
         # body
         PaddedInput_1 = T.allocate([379456], "int16", "global")
         for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64):
@@ -283,9 +314,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s
         # function attr dict
         T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True})
         placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16")
+        T.preflattened_buffer(placeholder_19, [360000], dtype="int16")
         placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16")
+        T.preflattened_buffer(placeholder_20, [16384], dtype="int16")
         placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32")
+        T.preflattened_buffer(placeholder_21, [256], dtype="int32")
         T_add_1 = T.match_buffer(T_add, [407], dtype="int32")
+        T.preflattened_buffer(T_add_1, [407], dtype="int32")
         # body
         PaddedInput_2 = T.allocate([360000], "int16", "global")
         for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64):
@@ -305,10 +340,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s
         # function attr dict
         T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True})
         placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16")
+        T.preflattened_buffer(placeholder_29, [360000], dtype="int16")
         placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16")
+        T.preflattened_buffer(placeholder_27, [16384], dtype="int16")
         placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32")
+        T.preflattened_buffer(placeholder_26, [256], dtype="int32")
         placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32")
+        T.preflattened_buffer(placeholder_28, [1440000], dtype="int32")
         T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8")
+        T.preflattened_buffer(T_cast_7, [407], dtype="uint8")
         # body
         PaddedInput_3 = T.allocate([360000], "int16", "global")
         for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64):
@@ -345,9 +385,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place
         # function attr dict
         T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True})
         placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16")
+        T.preflattened_buffer(placeholder_7, [360000], dtype="int16")
         placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16")
+        T.preflattened_buffer(placeholder_8, [4096], dtype="int16")
         placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32")
+        T.preflattened_buffer(placeholder_9, [64], dtype="int32")
         T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16")
+        T.preflattened_buffer(T_cast_3, [215], dtype="int16")
         # body
         PaddedInput = T.allocate([360000], "int16", "global")
         for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
@@ -369,9 +413,13 @@ class ResnetStructurePlanned:
     @T.prim_func
     def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.Ptr[T.uint8]) -> None:
         placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8")
+        T.preflattened_buffer(placeholder_2, [360000], dtype="uint8")
         placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
+        T.preflattened_buffer(placeholder_3, [64], dtype="int32")
         T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16")
-        global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        T.preflattened_buffer(T_cast_1, [215], dtype="int16")
+        global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        T.preflattened_buffer(global_workspace_1_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
             T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16")
@@ -379,11 +427,17 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p
     @T.prim_func
     def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.Ptr[T.uint8]) -> None:
         placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16")
+        T.preflattened_buffer(placeholder_29, [360000], dtype="int16")
         placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16")
+        T.preflattened_buffer(placeholder_27, [16384], dtype="int16")
         placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32")
+        T.preflattened_buffer(placeholder_26, [256], dtype="int32")
         placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32")
+        T.preflattened_buffer(placeholder_28, [1440000], dtype="int32")
         T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8")
-        global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        T.preflattened_buffer(T_cast_7, [407], dtype="uint8")
+        global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        T.preflattened_buffer(global_workspace_5_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         PaddedInput_3_let = T.buffer_decl([360000], 'int16')
         with T.let(PaddedInput_3_let.data, T.address_of(global_workspace_5_buffer_var[6480000], dtype="handle")):
@@ -403,10 +457,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s
     @T.prim_func
     def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.Ptr[T.uint8]) -> None:
         placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16")
+        T.preflattened_buffer(placeholder_19, [360000], dtype="int16")
         placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16")
+        T.preflattened_buffer(placeholder_20, [16384], dtype="int16")
         placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32")
+        T.preflattened_buffer(placeholder_21, [256], dtype="int32")
         T_add_1 = T.match_buffer(T_add, [407], dtype="int32")
-        global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        T.preflattened_buffer(T_add_1, [407], dtype="int32")
+        global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        T.preflattened_buffer(global_workspace_4_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         PaddedInput_2_let = T.buffer_decl([360000], "int16")
         with T.let(PaddedInput_2_let.data, T.address_of(global_workspace_4_buffer_var[7200000], dtype="handle")):
@@ -426,10 +485,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s
     @T.prim_func
     def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.Ptr[T.uint8]) -> None:
         placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16")
+        T.preflattened_buffer(placeholder_7, [360000], dtype="int16")
         placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16")
+        T.preflattened_buffer(placeholder_8, [4096], dtype="int16")
         placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32")
+        T.preflattened_buffer(placeholder_9, [64], dtype="int32")
         T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16")
-        global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        T.preflattened_buffer(T_cast_3, [215], dtype="int16")
+        global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        T.preflattened_buffer(global_workspace_2_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         PaddedInput_let = T.buffer_decl([360000], "int16")
         with T.let(PaddedInput_let.data, T.address_of(global_workspace_2_buffer_var[7200000], dtype="handle")):
@@ -448,10 +512,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place
     @T.prim_func
     def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.Ptr[T.uint8]) -> None:
         placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16")
+        T.preflattened_buffer(placeholder_13, [360000], dtype="int16")
         placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16")
+        T.preflattened_buffer(placeholder_14, [36864], dtype="int16")
         placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32")
+        T.preflattened_buffer(placeholder_15, [64], dtype="int32")
         T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16")
-        global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        T.preflattened_buffer(T_cast_5, [215], dtype="int16")
+        global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        T.preflattened_buffer(global_workspace_3_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         PaddedInput_1_let = T.buffer_decl([379456], "int16")
         with T.let(PaddedInput_1_let.data, T.address_of(global_workspace_3_buffer_var[0], dtype="handle")):
@@ -469,7 +538,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla
     @T.prim_func
     def __tvm_main__(input: T.handle, global_workspace_0_var: T.Ptr[T.uint8], output: T.handle) -> None:
-        global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         T.attr("default", "device_id", 0)
         T.attr("default", "device_type", 1)
diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py
index 73be9d8cdc58..0610559a05d8 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -636,5 +636,27 @@ def test_non_integer_typed_block_iter():
     check_error(non_integer_typed_block_iter, 3)
+def preflattened_buffer_map_align_nonint(foo: T.handle):
+    foo_1 = T.match_buffer(foo, [1])
+    T.preflattened_buffer(
+        foo_1, [1], align="bar"
+    )  # check_error: align: want int or IntImm, got 'bar'
+
+
+def test_preflattened_buffer_map_align():
+    check_error(preflattened_buffer_map_align_nonint, 3)
+
+
+def preflattened_buffer_map_offset_factor_nonint(foo: T.handle):
+    foo_1 = T.match_buffer(foo, [1])
+    T.preflattened_buffer(
+        foo_1, [1], offset_factor="bar"
+    )  # check_error: offset_factor: want int or IntImm, got 'bar'
+
+
+def test_preflattened_buffer_map_offset_factor():
+    check_error(preflattened_buffer_map_offset_factor_nonint, 3)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index 26a6f4530bda..4a2482c11d22 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -181,6 +181,23 @@ def test_dynamic_shape_gemm():
     assert_structural_equal(gemm_dyn_shape, gemm_dyn_shape_roundtrip)
+@T.prim_func
+def preflattened_buffer_map(A: T.handle, B: T.handle):
+    A_1 = T.match_buffer(A, [1])
+    T.preflattened_buffer(A_1, [1], align=T.int32(1), offset_factor=T.int64(2))
+    B_1 = T.match_buffer(B, [1])
+    T.preflattened_buffer(B_1, [1])
+    B_1[0] = A_1[0]
+
+
+def test_preflattened_buffer_map():
+    A_var = [
+        k for k, _ in preflattened_buffer_map.preflattened_buffer_map.items() if k.name == "A"
+    ][0]
+    assert preflattened_buffer_map.preflattened_buffer_map[A_var].data_alignment == 1
+    assert preflattened_buffer_map.preflattened_buffer_map[A_var].offset_factor == 2
+
+
 @T.prim_func
 def match_buffer_int64(a: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, (T.int64(128), T.int64(128)), dtype="float32")
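Note: the syntax-sugar hunk above is the smallest self-contained illustration of the new T.preflattened_buffer statement and of how the declared view is read back from a PrimFunc. The sketch below restates that pattern outside the patch; the function name copy_one_elem and the parameter names a/b are hypothetical, and it assumes only the APIs the patch itself exercises (T.match_buffer, T.preflattened_buffer, and the preflattened_buffer_map attribute on the parsed PrimFunc).

    from tvm.script import tir as T


    @T.prim_func
    def copy_one_elem(a: T.handle, b: T.handle) -> None:
        # Flattened (1-D) view actually used by the function body.
        A = T.match_buffer(a, [1], dtype="float32")
        # Pre-flattened view recorded in the function's preflattened_buffer_map.
        T.preflattened_buffer(A, [1], dtype="float32", align=T.int32(1), offset_factor=T.int64(2))
        B = T.match_buffer(b, [1], dtype="float32")
        T.preflattened_buffer(B, [1], dtype="float32")
        B[0] = A[0]


    # The parser keys the declared view by the handle parameter, mirroring the test above.
    a_var = [k for k, _ in copy_one_elem.preflattened_buffer_map.items() if k.name == "a"][0]
    assert copy_one_elem.preflattened_buffer_map[a_var].data_alignment == 1
    assert copy_one_elem.preflattened_buffer_map[a_var].offset_factor == 2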