Skip to content

Commit a01a38e

Browse files
Giuseppe RossiniMousius
andcommitted
[AOT] Introducing AOT in TVM
This change adds the code generation and minimal runtime API to use the Ahead Of Time (AOT) compilation flow. The main logic is contained in: - src/relay/backend/aot_codegen.cc Which produces a TIR PrimFunc traversing the Relay graph The runtime interface (authored by @Mousius) leaves a gap for future iterations using platform-specific features from RTOS. Currently AOT runs successfully on x86 in a host OS, running these tests on micro is coming soon. This PR is based on the RFC described here: https://discuss.tvm.apache.org/t/implementing-aot-in-tvm/9206 Co-authored-by: Christopher Sidebottom <[email protected]> Change-Id: I9f731c953231f129e1472298915dddc01788efd7
1 parent 43ec869 commit a01a38e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+2140
-88
lines changed

cmake/modules/StandaloneCrt.cmake

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ if(USE_MICRO)
4040
"3rdparty/dmlc-core/include *.h -> include"
4141
"include/tvm/runtime c_*_api.h -> include/tvm/runtime"
4242
"include/tvm/runtime/crt *.h -> include/tvm/runtime/crt"
43+
"include/tvm/runtime/crt/aot *.h -> src/runtime/crt/aot"
4344
"src/runtime/crt Makefile -> ."
4445
"src/runtime/crt/include *.h -> include"
4546
"src/runtime/crt/common *.c -> src/runtime/crt/common"
@@ -48,6 +49,7 @@ if(USE_MICRO)
4849
"src/runtime/crt/host crt_config.h -> template/host"
4950
"src/runtime/crt/host *.cc -> template/host"
5051
"src/runtime/crt/memory *.c -> src/runtime/crt/memory"
52+
"src/runtime/crt/aot *.c -> src/runtime/crt/aot"
5153
"src/runtime/crt/utvm_rpc_common *.cc -> src/runtime/crt/utvm_rpc_common"
5254
"src/runtime/crt/utvm_rpc_server *.cc -> src/runtime/crt/utvm_rpc_server"
5355
"src/runtime/minrpc *.h -> src/runtime/minrpc"
@@ -135,6 +137,7 @@ if(USE_MICRO)
135137
file(GLOB TEST_SRCS ${CMAKE_SOURCE_DIR}/tests/crt/*_test.cc)
136138
find_path(GTEST_INCLUDE_DIR gtest/gtest.h)
137139
find_library(GTEST_LIB gtest "$ENV{GTEST_LIB}")
140+
set(aot_executor_src "${standalone_crt_base}/src/runtime/crt/aot/tvm_executor.c")
138141

139142
# Create the `crttest` target if we can find GTest. If not, we create dummy
140143
# targets that give the user an informative error message.
@@ -144,7 +147,9 @@ if(USE_MICRO)
144147
string(REPLACE ".cc" "" __execname ${__srcname})
145148
add_executable(${__execname} ${__srcpath})
146149
list(APPEND TEST_EXECS ${__execname})
147-
target_include_directories(${__execname} PUBLIC ${GTEST_INCLUDE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include ${CMAKE_SOURCE_DIR}/src/runtime/crt/host)
150+
target_sources(${__execname} PRIVATE ${aot_executor_src})
151+
target_include_directories(${__execname} PUBLIC ${GTEST_INCLUDE_DIR} ${CMAKE_SOURCE_DIR}/src/runtime/crt/host)
152+
target_include_directories(${__execname} PUBLIC ${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include ${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/src/runtime/crt/aot)
148153
target_compile_options(${__execname} PRIVATE -pthread)
149154
target_link_libraries(${__execname} ${cmake_crt_libraries} ${GTEST_LIB} pthread)
150155
set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_ALL 1)
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file include/tvm/runtime/crt/aot/tvm_backend.h
22+
* \brief Backend functions for the AOT executor
23+
*
24+
* These are not designed to user-facing and may change without warning
25+
*/
26+
27+
#ifndef TVM_RUNTIME_CRT_AOT_TVM_BACKEND_H_
28+
#define TVM_RUNTIME_CRT_AOT_TVM_BACKEND_H_
29+
30+
#include <stddef.h>
31+
#include <stdint.h>
32+
33+
#include "tvm_error.h"
34+
35+
#ifdef __cplusplus
36+
extern "C" {
37+
#endif
38+
39+
/*! Memory alignment for allocator */
40+
#ifndef TVM_RUNTIME_ALLOC_ALIGNMENT
41+
#define TVM_RUNTIME_ALLOC_ALIGNMENT 16
42+
#endif
43+
44+
/*! The AOT runtime links staticly */
45+
#define TVM_DLL
46+
47+
/*!
48+
* \brief Minimal TVMValue
49+
*/
50+
typedef union {
51+
int64_t v_int64; /** Currently used for parameter lookup */
52+
void* v_handle; /** Pointer to other values */
53+
} TVMValue;
54+
55+
/*!
56+
* \brief Packed function signature definition
57+
*/
58+
typedef int32_t(tvm_function_t)(void* args, void* arg_type_ids, int32_t num_args,
59+
void* out_ret_value, void* out_ret_tcode, void* resource_handle);
60+
61+
/*!
62+
* \brief Workspace memory structure
63+
*/
64+
typedef struct {
65+
uint8_t* next_alloc; /** Pointer to the next block of bytes to allocate */
66+
uint8_t* workspace; /** Pointer to start of the workspace */
67+
size_t workspace_size; /** Total number of bytes in the workspace */
68+
} tvm_workspace_t;
69+
70+
/**
71+
* \brief Backend function to allocate temporal workspace.
72+
*
73+
* \note The result allocated space is ensured to be aligned to TVM_RUNTIME_ALLOC_ALIGNMENT.
74+
* \note Currently matches CRT runtime signature but this will change in future to accommodate
75+
* memory planning
76+
*
77+
* \param device_type Ignored
78+
* \param device_id Ignored
79+
* \param nbytes The size of the space requested.
80+
* \param dtype_code_hint Ignored
81+
* \param dtype_bits_hint Ignored
82+
* \return void* NULL on error, a valid pointer on success
83+
*/
84+
void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t nbytes, int dtype_code_hint,
85+
int dtype_bits_hint);
86+
87+
/*!
88+
* \brief Backend function to free temporal workspace.
89+
*
90+
* \note Currently matches CRT runtime signature but this will change in future to accomodate memory
91+
* planning
92+
*
93+
* \param ptr The result allocated space pointer.
94+
* \param device_type Ignored
95+
* \param device_id Ignored
96+
* \return tvm_crt_error_t Containing any error statuses
97+
*/
98+
tvm_crt_error_t TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr);
99+
100+
#ifdef __cplusplus
101+
} // extern "C"
102+
#endif
103+
104+
#endif // TVM_RUNTIME_CRT_AOT_TVM_BACKEND_H_
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file include/tvm/runtime/crt/aot/tvm_error.h
22+
* \brief Defines a subset of error codes returned by the CRT AOT executor.
23+
*/
24+
25+
#ifndef TVM_RUNTIME_CRT_AOT_TVM_ERROR_H_
26+
#define TVM_RUNTIME_CRT_AOT_TVM_ERROR_H_
27+
28+
#ifdef __cplusplus
29+
extern "C" {
30+
#endif
31+
32+
#define TVM_CRT_ERROR_CATEGORY_Pos 8
33+
#define TVM_CRT_ERROR_CATEGORY_Msk (0xff << TVM_CRT_ERROR_CATEGORY_Pos)
34+
#define TVM_CRT_ERROR_CODE_Pos 0
35+
#define TVM_CRT_ERROR_CODE_Msk (0xff << TVM_CRT_ERROR_CODE_Pos)
36+
37+
#define DEFINE_TVM_CRT_ERROR(category, code) \
38+
(((category) << TVM_CRT_ERROR_CATEGORY_Pos) | ((code) << TVM_CRT_ERROR_CODE_Pos))
39+
typedef enum {
40+
kTvmErrorCategoryPlatform = 5,
41+
kTvmErrorCategoryFunctionCall = 8,
42+
} tvm_crt_error_category_t;
43+
44+
typedef enum {
45+
kTvmErrorNoError = 0,
46+
47+
// Platform
48+
kTvmErrorPlatformCheckFailure = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 0),
49+
kTvmErrorPlatformMemoryManagerInitialized = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 1),
50+
kTvmErrorPlatformShutdown = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 2),
51+
kTvmErrorPlatformNoMemory = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 3),
52+
kTvmErrorPlatformTimerBadState = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 4),
53+
54+
// Function Calls - common problems encountered calling functions.
55+
kTvmErrorFunctionCallNumArguments = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 0),
56+
kTvmErrorFunctionCallWrongArgType = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 1),
57+
kTvmErrorFunctionCallNotImplemented = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 2),
58+
59+
// System errors are always negative integers; this mask indicates presence of a system error.
60+
// Cast tvm_crt_error_t to a signed integer to interpret the negative error code.
61+
kTvmErrorSystemErrorMask = (1 << (sizeof(int) * 4 - 1)),
62+
} tvm_crt_error_t;
63+
64+
#ifdef __cplusplus
65+
}
66+
#endif
67+
68+
#endif // TVM_RUNTIME_CRT_AOT_TVM_ERROR_H_
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file include/tvm/runtime/crt/aot/tvm_executor.h
22+
* \brief TVM Executor for the Ahead-of-Time Runtime
23+
*
24+
* AOT models are described by the TVM model descriptor format
25+
* which can be passed to tvm_runtime_run. These descriptors will be
26+
* generated by the AOT compilation process. This can optionally be
27+
* augmented with platform specific context to be passed to the TVM
28+
* operators.
29+
*
30+
* Example:
31+
* extern tvm_model_t my_network;
32+
* int main() {
33+
* void* data = get_data();
34+
* void* output[4] = {0, 0, 0, 0};
35+
* void* inputs = {data};
36+
* void* outputs = {output};
37+
* tvm_context_t my_context = {
38+
* .driver = ...;
39+
* };
40+
* tvm_runtime_run(
41+
* &my_network,
42+
* inputs,
43+
* outputs
44+
* &my_context
45+
* );
46+
* return 0;
47+
* }
48+
*/
49+
50+
#ifndef TVM_RUNTIME_CRT_AOT_TVM_EXECUTOR_H_
51+
#define TVM_RUNTIME_CRT_AOT_TVM_EXECUTOR_H_
52+
53+
#include <stdint.h>
54+
55+
#include "tvm_backend.h"
56+
#include "tvm_error.h"
57+
58+
#ifdef __cplusplus
59+
extern "C" {
60+
#endif
61+
62+
/*!
63+
* \brief Context information for future integrations
64+
* which is passed through to the operators.
65+
*
66+
* \note Can be used for drivers and platform specific information.
67+
*/
68+
typedef struct {
69+
} tvm_context_t;
70+
71+
/*!
72+
* \brief TVM Model descriptor to describe the
73+
* model to the runtime.
74+
*/
75+
typedef struct {
76+
uint32_t num_input_tensors; /** Number of expected input tensors */
77+
uint32_t num_output_tensors; /** Number of expected output tensors */
78+
tvm_function_t* run_func; /** Generated model function, called through tvm_runtime_run */
79+
tvm_workspace_t* workspace; /** Memory workspace for the model to use */
80+
} tvm_model_t;
81+
82+
/*!
83+
* \brief Main entry point for
84+
* \param model Model descriptor structure to reference for runtime information
85+
* \param inputs Pointer to input pointer(s)
86+
* \param outputs Pointer to output pointer(s)
87+
* \param context Context information to be passed through to operators
88+
* \return tvm_status_t containing success or errors from the model run
89+
*/
90+
tvm_crt_error_t tvm_runtime_run(const tvm_model_t* model, void** inputs, void** outputs,
91+
tvm_context_t* context);
92+
93+
#ifdef __cplusplus
94+
} // extern "C"
95+
#endif
96+
97+
#endif // TVM_RUNTIME_CRT_AOT_TVM_EXECUTOR_H_

include/tvm/runtime/module.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ constexpr const char* tvm_module_main = "__tvm_main__";
230230
constexpr const char* tvm_param_prefix = "__tvm_param__";
231231
/*! \brief A PackedFunc that looks up linked parameters by storage_id. */
232232
constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param";
233+
/*! \brief The main AOT executor function */
234+
constexpr const char* tvm_run_func_prefix = "tvm__run_func";
233235
} // namespace symbol
234236

235237
// implementations of inline functions.

include/tvm/tir/builtin.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,10 @@ TVM_DLL const Op& tvm_stack_make_array();
343343
*/
344344
TVM_DLL const Op& tvm_call_packed();
345345

346+
// This achieve the same of a packed call, but with an extern call
347+
// directly to the operator
348+
TVM_DLL const Op& tvm_call_unpacked();
349+
346350
/*!
347351
* \brief See pesudo code
348352
*

python/tvm/driver/tvmc/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def compile_model(
241241

242242
# TODO we need to update this return to use the updated graph module APIs
243243
# as these getter functions will be deprecated in the next release (@leandron)
244-
return graph_module.get_json(), graph_module.get_lib(), graph_module.get_params(), dumps
244+
return graph_module.get_graph(), graph_module.get_lib(), graph_module.get_params(), dumps
245245

246246

247247
def save_module(module_path, graph, lib, params, cross=None):

0 commit comments

Comments
 (0)