Skip to content

Commit 9cc4620

Browse files
committed
[MetaSchedule] Introduce Async Pipeline in MultiLevelTiling
This PR introduces an async pipeline in TVM's current MultiLevelTiling rules. This PR is blocked on #13966 since some conv2d workloads use `tir.if_then_else` to pad the input to the correct size, and this PR uses async copy in such copy statements. 1. Add a subrule in `src/meta_schedule/schedule_rule/multi_level_tiling.h/.cc` that annotates async copy for MLT. On CUDA cores, this PR has a perf boost of around 1 TFLOP/s in most Conv2d test cases and 1 ~ 2 TFLOP/s in most GEMM test cases. All generated code, scripts, and traces are available at https://github.com/Rainy-Memory/tvm-async-rule-benchmark. Currently tested on commit `afbfb7aa7e43732cb716f8e443df696110be6afc` on the conv2d NHWC workload, with an RTX 3080 GPU. Workload: Conv2d NHWC |Shape|Mainline TVM|Mainline TVM with Async| |-|-|-| |N=1_H=224_W=224_C=3_K=64_R=7_S=7_STR=2_PAD=3_DIL=1|13838.05219|14687.89452| |N=1_H=56_W=56_C=64_K=64_R=1_S=1_STR=1_PAD=0_DIL=1|5398.305085|5613.892553| |N=1_H=56_W=56_C=64_K=64_R=3_S=3_STR=1_PAD=1_DIL=1|11652.96825|13157.88249| |N=1_H=56_W=56_C=64_K=256_R=1_S=1_STR=1_PAD=0_DIL=1|10638.8309|11674.68499| |N=1_H=56_W=56_C=256_K=64_R=1_S=1_STR=1_PAD=0_DIL=1|8692.32829|9469.264089| |N=1_H=56_W=56_C=256_K=128_R=1_S=1_STR=2_PAD=0_DIL=1|4685.767442|5698.19634| |N=1_H=28_W=28_C=128_K=128_R=3_S=3_STR=1_PAD=1_DIL=1|9872.787087|10404.60405| |N=1_H=28_W=28_C=128_K=512_R=1_S=1_STR=1_PAD=0_DIL=1|9974.281496|10073.31657| |N=1_H=28_W=28_C=512_K=128_R=1_S=1_STR=1_PAD=0_DIL=1|7075.866932|8564.572712| |N=1_H=28_W=28_C=512_K=256_R=1_S=1_STR=2_PAD=0_DIL=1|3648.330914|4021.923142| |N=1_H=14_W=14_C=256_K=256_R=3_S=3_STR=1_PAD=1_DIL=1|8192.954618|9160.182054| |N=1_H=14_W=14_C=256_K=1024_R=1_S=1_STR=1_PAD=0_DIL=1|8008.870153|9362.825279| |N=1_H=14_W=14_C=1024_K=256_R=1_S=1_STR=1_PAD=0_DIL=1|5210.062241|6051.208379| |N=1_H=14_W=14_C=1024_K=512_R=1_S=1_STR=2_PAD=0_DIL=1|2550.787202|3587.902938| |N=1_H=7_W=7_C=512_K=512_R=3_S=3_STR=1_PAD=1_DIL=1|4350.626084|5432.788068| 
|N=1_H=7_W=7_C=512_K=2048_R=1_S=1_STR=1_PAD=0_DIL=1|6672.068026|7663.725217| |N=1_H=7_W=7_C=2048_K=512_R=1_S=1_STR=1_PAD=0_DIL=1|3142.564263|4297.988014| Workload: GEMM NN |Shape|Mainline TVM|Mainline TVM with Async| |-|-|-| |M=512_N=256_K=640|8678.46|10607.37| |M=512_N=384_K=256|8109.13|10290.72| |M=512_N=512_K=512|11419.83|14000.86| |M=512_N=3072_K=768|19709.39|18351.61| |M=512_N=768_K=3072|12844.59|13730.88| |M=896_N=896_K=896|16149.91|16131.39| |M=1024_N=1024_K=1024|18842.11|19662.8| |M=1152_N=1152_K=1152|15386.79|16736.1| |M=1536_N=1536_K=1536|18522.67|18872.06| |M=2048_N=2048_K=2048|19515.42|18874.85| |M=3072_N=3072_K=3072|19233.9|19291.42| |M=4096_N=4096_K=4096|17122.17|19259.01|
1 parent 49b6c3a commit 9cc4620

File tree

116 files changed

+2754
-1940
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

116 files changed

+2754
-1940
lines changed

apps/microtvm/arduino/template_project/microtvm_api_server.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,8 @@ def _disassemble_mlf(self, mlf_tar_path, source_dir):
197197
metadata = json.load(f)
198198
return metadata
199199

200-
def _template_model_header(self, source_dir, metadata):
201-
with open(source_dir / "model.h", "r") as f:
200+
def _template_model(self, source_dir, metadata):
201+
with open(source_dir / "platform.c", "r") as f:
202202
model_h_template = Template(f.read())
203203

204204
all_module_names = []
@@ -218,7 +218,7 @@ def _template_model_header(self, source_dir, metadata):
218218
"workspace_size_bytes": workspace_size_bytes,
219219
}
220220

221-
with open(source_dir / "model.h", "w") as f:
221+
with open(source_dir / "platform.c", "w") as f:
222222
f.write(model_h_template.substitute(template_values))
223223

224224
# Arduino ONLY recognizes .ino, .cpp, .c, .h
@@ -415,9 +415,9 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec
415415
metadata = self._disassemble_mlf(model_library_format_path, source_dir)
416416
shutil.copy2(model_library_format_path, project_dir / MODEL_LIBRARY_FORMAT_RELPATH)
417417

418-
# For AOT, template model.h with metadata to minimize space usage
418+
# For AOT, template platform.c with metadata to minimize space usage
419419
if project_type == "example_project":
420-
self._template_model_header(source_dir, metadata)
420+
self._template_model(source_dir, metadata)
421421

422422
self._change_cpp_file_extensions(source_dir)
423423

apps/microtvm/arduino/template_project/src/example_project/model.c renamed to apps/microtvm/arduino/template_project/src/example_project/platform.c

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,22 @@
1717
* under the License.
1818
*/
1919

20-
#include "model.h"
20+
/*!
21+
* \brief Implementation of TVMPlatform functions in tvm/runtime/crt/platform.h
22+
*/
2123

2224
#include "Arduino.h"
2325
#include "standalone_crt/include/dlpack/dlpack.h"
2426
#include "standalone_crt/include/tvm/runtime/crt/stack_allocator.h"
2527

28+
#define TVM_WORKSPACE_SIZE_BYTES $workspace_size_bytes
29+
2630
// AOT memory array, stack allocator wants it aligned
27-
static uint8_t g_aot_memory[WORKSPACE_SIZE]
31+
static uint8_t g_aot_memory[TVM_WORKSPACE_SIZE_BYTES]
2832
__attribute__((aligned(TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES)));
2933
tvm_workspace_t app_workspace;
3034

35+
// Called when an internal error occurs and execution cannot continue.
3136
// Blink code for debugging purposes
3237
void TVMPlatformAbort(tvm_crt_error_t error) {
3338
TVMLogf("TVMPlatformAbort: 0x%08x\n", error);
@@ -45,19 +50,23 @@ void TVMPlatformAbort(tvm_crt_error_t error) {
4550
}
4651
}
4752

48-
void TVMLogf(const char* msg, ...) {}
49-
53+
// Allocate memory for use by TVM.
5054
tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) {
5155
return StackMemoryManager_Allocate(&app_workspace, num_bytes, out_ptr);
5256
}
5357

58+
// Free memory used by TVM.
5459
tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) {
5560
return StackMemoryManager_Free(&app_workspace, ptr);
5661
}
5762

63+
// Internal logging API call implementation.
64+
void TVMLogf(const char* msg, ...) {}
65+
5866
unsigned long g_utvm_start_time_micros;
5967
int g_utvm_timer_running = 0;
6068

69+
// Start a device timer.
6170
tvm_crt_error_t TVMPlatformTimerStart() {
6271
if (g_utvm_timer_running) {
6372
return kTvmErrorPlatformTimerBadState;
@@ -67,6 +76,7 @@ tvm_crt_error_t TVMPlatformTimerStart() {
6776
return kTvmErrorNoError;
6877
}
6978

79+
// Stop the running device timer and get the elapsed time (in microseconds).
7080
tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
7181
if (!g_utvm_timer_running) {
7282
return kTvmErrorPlatformTimerBadState;
@@ -77,14 +87,19 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
7787
return kTvmErrorNoError;
7888
}
7989

90+
// Fill a buffer with random data.
8091
tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
8192
for (size_t i = 0; i < num_bytes; i++) {
8293
buffer[i] = rand();
8394
}
8495
return kTvmErrorNoError;
8596
}
8697

87-
void TVMInitialize() { StackMemoryManager_Init(&app_workspace, g_aot_memory, WORKSPACE_SIZE); }
98+
// Initialize TVM inference.
99+
tvm_crt_error_t TVMPlatformInitialize() {
100+
StackMemoryManager_Init(&app_workspace, g_aot_memory, sizeof(g_aot_memory));
101+
return kTvmErrorNoError;
102+
}
88103

89104
void TVMExecute(void* input_data, void* output_data) {
90105
int ret_val = tvmgen_default___tvm_main__(input_data, output_data);

apps/microtvm/arduino/template_project/src/example_project/model.h renamed to apps/microtvm/arduino/template_project/src/example_project/platform.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,10 @@
1717
* under the License.
1818
*/
1919

20-
#define WORKSPACE_SIZE $workspace_size_bytes
21-
2220
#ifdef __cplusplus
2321
extern "C" {
2422
#endif
2523

26-
void TVMInitialize();
27-
2824
/* TODO template this function signature with the input and output
2925
* data types and sizes. For example:
3026
*

apps/microtvm/arduino/template_project/src/example_project/project.ino

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
* under the License.
1818
*/
1919

20-
#include "src/model.h"
20+
#include "src/standalone_crt/include/tvm/runtime/crt/platform.h"
2121

2222
void setup() {
23-
TVMInitialize();
23+
TVMPlatformInitialize();
2424
// If desired, initialize the RNG with random noise
2525
// randomSeed(analogRead(0));
2626
}

apps/microtvm/arduino/template_project/src/host_driven/model_support.c renamed to apps/microtvm/arduino/template_project/src/host_driven/platform.c

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,28 @@
1717
* under the License.
1818
*/
1919

20+
/*!
21+
* \brief Implementation of TVMPlatform functions in tvm/runtime/crt/platform.h
22+
*/
23+
2024
#include "standalone_crt/include/dlpack/dlpack.h"
2125
#include "standalone_crt/include/tvm/runtime/crt/error_codes.h"
2226
#include "stdarg.h"
2327

24-
// Blink code for debugging purposes
28+
// Called when an internal error occurs and execution cannot continue.
2529
void TVMPlatformAbort(tvm_crt_error_t error) {
2630
TVMLogf("TVMPlatformAbort: 0x%08x\n", error);
2731
for (;;)
2832
;
2933
}
3034

35+
// Called by the microTVM RPC server to implement TVMLogf.
3136
size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt,
3237
va_list args) {
3338
return vsnprintf(out_buf, out_buf_size_bytes, fmt, args);
3439
}
3540

41+
// Allocate memory for use by TVM.
3642
tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) {
3743
if (num_bytes == 0) {
3844
num_bytes = sizeof(int);
@@ -41,6 +47,7 @@ tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void**
4147
return (*out_ptr == NULL) ? kTvmErrorPlatformNoMemory : kTvmErrorNoError;
4248
}
4349

50+
// Free memory used by TVM.
4451
tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) {
4552
free(ptr);
4653
return kTvmErrorNoError;
@@ -49,6 +56,7 @@ tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) {
4956
unsigned long g_utvm_start_time_micros;
5057
int g_utvm_timer_running = 0;
5158

59+
// Start a device timer.
5260
tvm_crt_error_t TVMPlatformTimerStart() {
5361
if (g_utvm_timer_running) {
5462
return kTvmErrorPlatformTimerBadState;
@@ -58,6 +66,7 @@ tvm_crt_error_t TVMPlatformTimerStart() {
5866
return kTvmErrorNoError;
5967
}
6068

69+
// Stop the running device timer and get the elapsed time (in microseconds).
6170
tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
6271
if (!g_utvm_timer_running) {
6372
return kTvmErrorPlatformTimerBadState;
@@ -68,6 +77,7 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
6877
return kTvmErrorNoError;
6978
}
7079

80+
// Fill a buffer with random data.
7181
tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
7282
for (size_t i = 0; i < num_bytes; i++) {
7383
buffer[i] = rand();

apps/microtvm/zephyr/template_project/CMakeLists.txt.template

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,4 @@ endif()
8383

8484
file(GLOB_RECURSE app_srcs src/**.c src/**.cc)
8585
target_sources(app PRIVATE ${app_srcs} ${cmsis_lib_srcs})
86-
target_include_directories(app PRIVATE crt_config ${CMAKE_SOURCE_DIR}/include crt/include ${cmsis_includes})
86+
target_include_directories(app PRIVATE crt_config include ${CMAKE_SOURCE_DIR}/include crt/include ${cmsis_includes})

apps/microtvm/zephyr/template_project/microtvm_api_server.py

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -210,14 +210,14 @@ def _get_board_mem_size_bytes(zephyr_base: str, board: str):
210210
return None
211211

212212

213-
DEFAULT_HEAP_SIZE_BYTES = 216 * 1024
213+
DEFAULT_WORKSPACE_SIZE_BYTES = 216 * 1024
214214

215215

216216
def _get_recommended_heap_size_bytes(board: str):
217217
prop = BOARD_PROPERTIES[board]
218218
if "recommended_heap_size_bytes" in prop:
219219
return prop["recommended_heap_size_bytes"]
220-
return DEFAULT_HEAP_SIZE_BYTES
220+
return DEFAULT_WORKSPACE_SIZE_BYTES
221221

222222

223223
def generic_find_serial_port(serial_number: str = None):
@@ -358,11 +358,11 @@ def _get_nrf_device_args(serial_number: str = None) -> list:
358358
help="Run on the FVP emulator instead of hardware.",
359359
),
360360
server.ProjectOption(
361-
"heap_size_bytes",
361+
"workspace_size_bytes",
362362
optional=["generate_project"],
363363
type="int",
364364
default=None,
365-
help="Sets the value for HEAP_SIZE_BYTES passed to K_HEAP_DEFINE() to service TVM memory allocation requests.",
365+
help="Sets the value for TVM_WORKSPACE_SIZE_BYTES passed to K_HEAP_DEFINE() to service TVM memory allocation requests.",
366366
),
367367
]
368368

@@ -403,7 +403,13 @@ def server_info_query(self, tvm_version):
403403
}
404404

405405
def _create_prj_conf(
406-
self, project_dir: pathlib.Path, board: str, project_type: str, config_main_stack_size
406+
self,
407+
project_dir: pathlib.Path,
408+
board: str,
409+
project_type: str,
410+
config_main_stack_size: int,
411+
config_led: bool,
412+
use_fvp: bool,
407413
):
408414
with open(project_dir / "prj.conf", "w") as f:
409415
f.write(
@@ -413,6 +419,13 @@ def _create_prj_conf(
413419
"CONFIG_UART_INTERRUPT_DRIVEN=y\n"
414420
"\n"
415421
)
422+
if (
423+
config_led
424+
and not self._is_qemu(board, use_fvp)
425+
and not self._is_fvp(board, use_fvp)
426+
):
427+
f.write("# For debugging.\n" "CONFIG_LED=y\n" "\n")
428+
416429
f.write("# For TVMPlatformAbort().\n" "CONFIG_REBOOT=y\n" "\n")
417430

418431
if project_type == "host_driven":
@@ -522,6 +535,18 @@ def _generate_cmake_args(
522535

523536
return cmake_args
524537

538+
def _copy_src_and_header_files(self, src_dir: pathlib.Path, dst_dir: pathlib.Path):
539+
"""Copy content of src_dir from template project to dst_dir in separate
540+
source and header sub-directories.
541+
"""
542+
for file in os.listdir(src_dir):
543+
file = src_dir / file
544+
if file.is_file():
545+
if file.suffix in [".cc", ".c"]:
546+
shutil.copy2(file, dst_dir / "src")
547+
elif file.suffix in [".h"]:
548+
shutil.copy2(file, dst_dir / "include" / "tvm")
549+
525550
def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options):
526551
zephyr_board = options["board"]
527552
project_type = options["project_type"]
@@ -533,7 +558,7 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec
533558
verbose = options.get("verbose")
534559

535560
recommended_heap_size = _get_recommended_heap_size_bytes(zephyr_board)
536-
heap_size_bytes = options.get("heap_size_bytes") or recommended_heap_size
561+
workspace_size_bytes = options.get("workspace_size_bytes") or recommended_heap_size
537562
board_mem_size = _get_board_mem_size_bytes(zephyr_base, zephyr_board)
538563

539564
compile_definitions = options.get("compile_definitions")
@@ -602,7 +627,7 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec
602627
else:
603628
shutil.copy2(src_path, dst_path)
604629

605-
# Populate Makefile.
630+
# Populate CMakeLists.
606631
with open(project_dir / CMAKELIST_FILENAME, "w") as cmake_f:
607632
with open(API_SERVER_DIR / f"{CMAKELIST_FILENAME}.template", "r") as cmake_template_f:
608633
for line in cmake_template_f:
@@ -629,10 +654,10 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec
629654

630655
if board_mem_size is not None:
631656
assert (
632-
heap_size_bytes < board_mem_size
633-
), f"Heap size {heap_size_bytes} is larger than memory size {board_mem_size} on this board."
657+
workspace_size_bytes < board_mem_size
658+
), f"Workspace size {workspace_size_bytes} is larger than memory size {board_mem_size} on this board."
634659
cmake_f.write(
635-
f"target_compile_definitions(app PUBLIC -DHEAP_SIZE_BYTES={heap_size_bytes})\n"
660+
f"target_compile_definitions(app PUBLIC -DTVM_WORKSPACE_SIZE_BYTES={workspace_size_bytes})\n"
636661
)
637662

638663
if compile_definitions:
@@ -649,7 +674,9 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec
649674
if self._is_fvp(zephyr_board, use_fvp):
650675
cmake_f.write(f"target_compile_definitions(app PUBLIC -DFVP=1)\n")
651676

652-
self._create_prj_conf(project_dir, zephyr_board, project_type, config_main_stack_size)
677+
self._create_prj_conf(
678+
project_dir, zephyr_board, project_type, config_main_stack_size, verbose, use_fvp
679+
)
653680

654681
# Populate crt-config.h
655682
crt_config_dir = project_dir / "crt_config"
@@ -658,13 +685,19 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec
658685
API_SERVER_DIR / "crt_config" / "crt_config.h", crt_config_dir / "crt_config.h"
659686
)
660687

661-
# Populate src/
688+
# Populate `src` and `include`
662689
src_dir = project_dir / "src"
663-
if project_type != "host_driven" or self._is_fvp(zephyr_board, use_fvp):
664-
shutil.copytree(API_SERVER_DIR / "src" / project_type, src_dir)
665-
else:
666-
src_dir.mkdir()
667-
shutil.copy2(API_SERVER_DIR / "src" / project_type / "main.c", src_dir)
690+
src_dir.mkdir()
691+
include_dir = project_dir / "include" / "tvm"
692+
include_dir.mkdir(parents=True)
693+
src_project_type_dir = API_SERVER_DIR / "src" / project_type
694+
self._copy_src_and_header_files(src_project_type_dir, project_dir)
695+
696+
if self._is_fvp(zephyr_board, use_fvp):
697+
self._copy_src_and_header_files(src_project_type_dir / "fvp", project_dir)
698+
699+
if project_type == "mlperftiny":
700+
shutil.copytree(src_project_type_dir / "api", src_dir / "api")
668701

669702
# Populate extra_files
670703
if extra_files_tar:

0 commit comments

Comments
 (0)