Skip to content

Commit 4178617

Browse files
areuschmasahi
andauthored
[AOT] Support LLVM backend with C++ runtime (#10753)
* add get_c_struct_name() method to Metadata to distinguish struct type name in llvm * add metadata serialization support to llvm codegen * Organize MetadataQueuer into a separate file. * Add DiscoverArraysVisitor to metadata_utils * Fill DLTensor metadata in LegalizePackedCalls. * Improve error message from Call asserts * Pass non-String device_context down to codegen. * this is necessary to allow CodeGenCPU to emit calls that include resource_handle. * Scope usage of lvalue refs in LowerTVMBuiltin to avoid corrupt memory. * test fixes * Also fill preflattened_buffer_map (TODO, maybe don't do this) * Fix C codegen. * Set USMP elem_offset to 0. * Clarify calculation of byte_offset from elem_offset. * fix tests * Fix arm compile warning * Fix hexagon test. * previously I believe we required interface_api == "c", but this really means to generate C API bindings, and we are generating "packed" bindings. * I think "c" was chosen here because the distinction between interface-api and use-unpacked-api is confusing. "c" interface-api means to generate an entrypoint API for microcontrollers that accepts bare data buffers. "packed" interface-api means to generate a TVMBackendPackedCFunc entrypoint. use-unpacked-api forms the same determination for the operator functions. * A further confusion here is that there are two ways to call "packed" operator functions: tir.tvm_builtin_call_packed and tir.tvm_builtin_call_cpacked. This distinction describes whether or not to late-bind calls via TVMBackendGetFuncFromEnv. Right now, AOT only ever requires call_cpacked because target_host == target, and for all suitable target_host, we expect a single DSO-exportable runtime.Module. When we move away from this by introducing heterogeneous target support to AOT, we can use this as a condition to help us choose between call_cpacked and call_packed (and possibly add a compile-time option to assert it is call_cpacked, for situations where we really don't want call_packed). * Document T.preflattened_buffer * Fix test_aot_legalize_packed_calls * Address manupa comments * Fix convert_pool_allocations_to_offsets test. * lint * Fix T.preflattened_buffer * Add preflattened_buffer_map to TIRTextPrinter * Fix tests * Fix BYOC * Fix invoking C device API. * remove comments * Address Mousius comments * lint * lint * Fix GMock linking on new CMake * address masahi comment Co-authored-by: Masahiro Masuda <[email protected]>
1 parent 0cd4fa6 commit 4178617

35 files changed

+1667
-527
lines changed

CMakeLists.txt

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,25 @@ if(USE_GTEST)
431431
find_package(GTest REQUIRED)
432432
endif()
433433
if(GTEST_FOUND)
434+
if(NOT TARGET GTest::gmock)
435+
# GMock is formally supported in CMake 3.20; for now, expect libgmock.a in the same directory,
436+
# and require that folks compiling against GTest::gmock also link against GTest::GTest
437+
# (for the includes dir).
438+
add_library(GTest::gmock STATIC IMPORTED GLOBAL)
439+
get_target_property(GTEST_LIB_PATH GTest::GTest IMPORTED_LOCATION)
440+
if("${GTEST_LIB_PATH}" STREQUAL "GTEST_LIB_PATH-NOTFOUND")
441+
# CMake >= 3.20 makes GTest::GTest into a compatibility target. The real import location is in
442+
# GTest::gtest.
443+
get_target_property(GTEST_LIB_PATH GTest::gtest IMPORTED_LOCATION)
444+
if("${GTEST_LIB_PATH}" STREQUAL "GTEST_LIB_PATH-NOTFOUND")
445+
message(FATAL_ERROR "Neither GTest::GTest nor GTets::gtest targets defined IMPORTED_LOCATION")
446+
endif()
447+
endif()
448+
get_filename_component(GTEST_LIB_DIR "${GTEST_LIB_PATH}" DIRECTORY)
449+
set_target_properties(GTest::gmock PROPERTIES
450+
IMPORTED_LOCATION "${GTEST_LIB_DIR}/libgmock.a")
451+
endif()
452+
434453
enable_testing()
435454
include(CTest)
436455
endif()
@@ -626,7 +645,7 @@ if(GTEST_FOUND)
626645
add_executable(cpptest ${TEST_SRCS})
627646
# include runtime files for unit testing
628647
target_include_directories(cpptest PUBLIC "src/runtime")
629-
target_link_libraries(cpptest PRIVATE ${TVM_TEST_LIBRARY_NAME} GTest::GTest GTest::Main pthread dl)
648+
target_link_libraries(cpptest PRIVATE ${TVM_TEST_LIBRARY_NAME} GTest::GTest GTest::Main GTest::gmock pthread dl)
630649
set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_ALL 1)
631650
set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1)
632651
# For some reason, compile definitions are not propagated correctly, so we manually add them here

include/tvm/runtime/metadata.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ class MetadataNode : public MetadataBaseNode {
116116
public:
117117
explicit MetadataNode(const struct ::TVMMetadata* data) : data_{data} {}
118118
static constexpr const char* _type_key = "metadata.MetadataNode";
119+
const char* get_c_struct_name() const override;
119120
inline int64_t version() const { return int64_t(data_->version); }
120121
inline int64_t num_inputs() const { return data_->num_inputs; }
121122
ArrayAccessor<struct TVMTensorInfo, TensorInfo> inputs();
@@ -141,6 +142,7 @@ class TensorInfoNode : public MetadataBaseNode {
141142
public:
142143
explicit TensorInfoNode(const struct ::TVMTensorInfo* data) : data_{data} {}
143144
static constexpr const char* _type_key = "metadata.TensorInfoNode";
145+
const char* get_c_struct_name() const override;
144146
inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); }
145147
inline int64_t num_shape() const { return data_->num_shape; }
146148
inline ::tvm::support::Span<const int64_t, int64_t> shape() const {

include/tvm/runtime/metadata_base.h

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ namespace metadata {
4444
*/
4545
class MetadataBaseNode : public ::tvm::runtime::Object {
4646
public:
47+
virtual const char* get_c_struct_name() const = 0;
48+
4749
static constexpr const char* _type_key = "metadata.MetadataBaseNode";
4850
TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object);
4951
};
@@ -157,7 +159,7 @@ class ArrayAccessor<const char*, ::tvm::runtime::String> {
157159
*
158160
* These are separate from TIR DataType because TIR does not model structs.
159161
*/
160-
enum MetadataTypeIndex : uint8_t {
162+
enum MetadataKind : uint8_t {
161163
kUint64 = 0,
162164
kInt64 = 1,
163165
kBool = 2,
@@ -173,20 +175,37 @@ enum MetadataTypeIndex : uint8_t {
173175
*/
174176
class MetadataArrayNode : public MetadataBaseNode {
175177
public:
176-
MetadataArrayNode(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name)
177-
: array(::std::move(array)), type_index{type_index}, struct_name{struct_name} {}
178+
MetadataArrayNode(Array<ObjectRef> array, MetadataKind kind, const char* type_key)
179+
: array(::std::move(array)), kind{kind}, type_key{type_key} {}
180+
181+
const char* get_c_struct_name() const final;
182+
183+
std::string get_element_c_struct_name() const {
184+
CHECK(kind == MetadataKind::kMetadata)
185+
<< "cannot get struct name for MetadataArray with kind=" << kind;
186+
constexpr int prefix_size = sizeof("metadata.") - 1;
187+
constexpr int suffix_size = sizeof("Node") - 1;
188+
std::string type_key_str(type_key);
189+
return std::string("TVM") +
190+
type_key_str.substr(prefix_size, type_key_str.size() - prefix_size - suffix_size);
191+
}
178192

179193
Array<ObjectRef> array;
180-
MetadataTypeIndex type_index;
181-
const char* struct_name;
194+
195+
/*! \brief Describes the storage class of the emitted struct member. */
196+
MetadataKind kind;
197+
198+
/*! \brief When `kind` is Metadata, type_key of the MetadataBaseNode used with this array. */
199+
const char* type_key;
200+
182201
static constexpr const char* _type_key = "metadata.MetadataArrayNode";
183202
TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode);
184203
};
185204

186205
/*! \brief Reference class for MetadataArray. */
187206
class MetadataArray : public MetadataBase {
188207
public:
189-
MetadataArray(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name);
208+
MetadataArray(Array<ObjectRef> array, MetadataKind kind, const char* struct_name);
190209

191210
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode);
192211
};

python/tvm/script/tir/special_stmt.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,8 @@ class PreflattenedBufferMap(SpecialStmt):
870870
Example
871871
-------
872872
.. code-block:: python
873-
T.preflattened_buffer_map({})
873+
A0 = T.match_buffer(A, (48,), dtype="float32")
874+
T.preflattened_buffer_map(A, (1, 4, 4, 3), elem_offset=1, align=4, dtype="float32")
874875
"""
875876

876877
def __init__(self):
@@ -892,12 +893,30 @@ def preflattened_buffer(
892893
for key, value in self.context.func_buffer_map.items():
893894
if value.same_as(postflattened):
894895
param = key
896+
break
895897

896898
assert (
897899
param is not None
898900
), f"Post-flatten buffer {postflattened.name} does not appear in the buffer map."
899901

902+
if data is None:
903+
data = self.context.func_buffer_map[param].data
904+
900905
buffer_name: str = f"{postflattened.name}_preflatten"
906+
if align != -1:
907+
if isinstance(align, IntImm):
908+
align = align.value
909+
else:
910+
assert isinstance(align, int), f"align: want int or IntImm, got {align!r}"
911+
912+
if offset_factor != 0:
913+
if isinstance(offset_factor, IntImm):
914+
offset_factor = offset_factor.value
915+
else:
916+
assert isinstance(
917+
offset_factor, int
918+
), f"offset_factor: want int or IntImm, got {offset_factor!r}"
919+
901920
preflattened = tvm.tir.decl_buffer(
902921
shape,
903922
dtype,

python/tvm/testing/tir.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@
1717
# pylint: disable=invalid-name, import-outside-toplevel, unused-variable
1818
"""Common utility functions in TVM tir"""
1919
import inspect
20+
import re
2021
import tvm
2122
from tvm.ir.diagnostics import override_renderer
2223

2324

25+
CHECK_ERROR_RE = re.compile(r"^.*# check_error: (.+)$")
26+
27+
2428
def check_error(func, rel_lineno):
2529
"""check if TIR script throws error"""
2630
# Override the default renderer to accumulate errors
@@ -46,3 +50,12 @@ def render(e):
4650
assert (
4751
d.span.line - 1 == rel_lineno
4852
), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}"
53+
54+
error_line = source_code.split("\n")[rel_lineno]
55+
m = CHECK_ERROR_RE.match(error_line)
56+
if m:
57+
expected_error_text = m.group(1)
58+
errors = [e.message for e in errors]
59+
assert (
60+
expected_error_text in errors
61+
), f'check_error expects "{expected_error_text} in str(errors): {errors}'

src/printer/tir_text_printer.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,17 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
151151
doc << Doc::Indent(
152152
2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
153153
}
154+
155+
if (op->preflattened_buffer_map.size() != 0) {
156+
// print preflattened_buffer_map
157+
std::vector<Doc> preflattened_buffer_map_doc;
158+
for (auto& v : op->preflattened_buffer_map) {
159+
preflattened_buffer_map_doc.push_back(Print(v.first) << ": " << Print(v.second));
160+
}
161+
doc << Doc::Indent(2, Doc::NewLine()
162+
<< "preflattened_buffer_map = {"
163+
<< PrintSep(preflattened_buffer_map_doc, Doc::Text(", ")) << "}");
164+
}
154165
doc << PrintBody(op->body);
155166
return doc;
156167
}

0 commit comments

Comments
 (0)