Skip to content

Commit d69472e

Browse files
committed
[RPC][RUNTIME] Support dynamic reload of runtime API according to config (apache#19)
1 parent 981559c commit d69472e

File tree

17 files changed

+433
-211
lines changed

17 files changed

+433
-211
lines changed

Makefile

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,36 +40,31 @@ ifneq ($(ADD_LDFLAGS), NONE)
4040
LDFLAGS += $(ADD_LDFLAGS)
4141
endif
4242

43-
ifeq ($(UNAME_S), Darwin)
44-
SHARED_LIBRARY_SUFFIX := dylib
45-
WHOLE_ARCH= -all_load
46-
NO_WHOLE_ARCH= -noall_load
47-
LDFLAGS += -undefined dynamic_lookup
48-
else
49-
SHARED_LIBRARY_SUFFIX := so
50-
WHOLE_ARCH= --whole-archive
51-
NO_WHOLE_ARCH= --no-whole-archive
52-
endif
5343

54-
55-
all: lib/libvta.$(SHARED_LIBRARY_SUFFIX)
44+
all: lib/libvta.so lib/libvta_runtime.so
5645

5746
VTA_LIB_SRC = $(wildcard src/*.cc src/tvm/*.cc)
47+
5848
ifeq ($(TARGET), VTA_PYNQ_TARGET)
5949
VTA_LIB_SRC += $(wildcard src/pynq/*.cc)
6050
LDFLAGS += -L/usr/lib -lsds_lib
61-
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/ -l:libdma.so
51+
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/
52+
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/
53+
LDFLAGS += -l:libdma.so
6254
endif
63-
VTA_LIB_OBJ = $(patsubst %.cc, build/%.o, $(VTA_LIB_SRC))
6455

65-
test: $(TEST)
56+
VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC))
6657

67-
build/src/%.o: src/%.cc
58+
build/%.o: src/%.cc
6859
@mkdir -p $(@D)
69-
$(CXX) $(CFLAGS) -MM -MT build/src/$*.o $< >build/src/$*.d
60+
# The -MT name must match this rule's target (build/%.o) so the generated .d file applies to the object actually built.
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
7061
$(CXX) -c $(CFLAGS) -c $< -o $@
7162

72-
lib/libvta.$(SHARED_LIBRARY_SUFFIX): $(VTA_LIB_OBJ)
63+
lib/libvta.so: $(filter-out build/runtime.o, $(VTA_LIB_OBJ))
64+
@mkdir -p $(@D)
65+
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS)
66+
67+
lib/libvta_runtime.so: build/runtime.o
7368
@mkdir -p $(@D)
7469
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS)
7570

@@ -79,7 +74,7 @@ cpplint:
7974
python nnvm/dmlc-core/scripts/lint.py vta cpp include src hardware tests
8075

8176
pylint:
82-
pylint python/tvm_vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc
77+
pylint python/vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc
8378

8479
doc:
8580
doxygen docs/Doxyfile

apps/pynq_rpc/start_rpc_server.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
#!/bin/bash
2-
export PYTHONPATH=${PYTHONPATH}:/home/xilinx/tvm/python
2+
export PYTHONPATH=${PYTHONPATH}:/home/xilinx/tvm/python:/home/xilinx/vta/python
33
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/
4-
python -m tvm.exec.rpc_server --load-library /home/xilinx/vta/lib/libvta.so
4+
python -m vta.exec.rpc_server

examples/resnet18/pynq/imagenet_predict.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@
3434
# Program the FPGA remotely
3535
assert tvm.module.enabled("rpc")
3636
remote = rpc.connect(host, port)
37-
remote.upload(BITSTREAM_FILE, BITSTREAM_FILE)
38-
fprogram = remote.get_function("tvm.contrib.vta.init")
39-
fprogram(BITSTREAM_FILE)
37+
vta.program_fpga(remote, BITSTREAM_FILE)
4038

4139
if verbose:
4240
logging.basicConfig(level=logging.INFO)

include/vta/runtime.h

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,40 +23,20 @@ extern "C" {
2323
#define VTA_DEBUG_SKIP_WRITE_BARRIER (1 << 4)
2424
#define VTA_DEBUG_FORCE_SERIAL (1 << 5)
2525

26-
/*! \brief VTA command handle */
27-
typedef void * VTACommandHandle;
28-
29-
/*! \brief Shutdown hook of VTA to cleanup resources */
30-
void VTARuntimeShutdown();
31-
32-
/*!
33-
* \brief Get thread local command handle.
34-
* \return A thread local command handle.
35-
*/
36-
VTACommandHandle VTATLSCommandHandle();
37-
3826
/*!
3927
* \brief Allocate data buffer.
4028
* \param cmd The VTA command handle.
4129
* \param size Buffer size.
4230
* \return A pointer to the allocated buffer.
4331
*/
44-
void* VTABufferAlloc(VTACommandHandle cmd, size_t size);
32+
void* VTABufferAlloc(size_t size);
4533

4634
/*!
4735
* \brief Free data buffer.
4836
* \param cmd The VTA command handle.
4937
* \param buffer The data buffer to be freed.
5038
*/
51-
void VTABufferFree(VTACommandHandle cmd, void* buffer);
52-
53-
/*!
54-
* \brief Get the buffer access pointer on CPU.
55-
* \param cmd The VTA command handle.
56-
* \param buffer The data buffer.
57-
* \return The pointer that can be accessed by the CPU.
58-
*/
59-
void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer);
39+
void VTABufferFree(void* buffer);
6040

6141
/*!
6242
* \brief Copy data buffer from one location to another.
@@ -68,20 +48,32 @@ void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer);
6848
* \param size Size of copy.
6949
* \param kind_mask The memory copy kind.
7050
*/
71-
void VTABufferCopy(VTACommandHandle cmd,
72-
const void* from,
51+
void VTABufferCopy(const void* from,
7352
size_t from_offset,
7453
void* to,
7554
size_t to_offset,
7655
size_t size,
7756
int kind_mask);
7857

58+
/*! \brief VTA command handle */
59+
typedef void* VTACommandHandle;
60+
61+
/*! \brief Shutdown hook of VTA to cleanup resources */
62+
void VTARuntimeShutdown();
63+
7964
/*!
80-
* \brief Set debug mode on the command handle.
65+
* \brief Get thread local command handle.
66+
* \return A thread local command handle.
67+
*/
68+
VTACommandHandle VTATLSCommandHandle();
69+
70+
/*!
71+
* \brief Get the buffer access pointer on CPU.
8172
* \param cmd The VTA command handle.
82-
* \param debug_flag The debug flag.
73+
* \param buffer The data buffer.
74+
* \return The pointer that can be accessed by the CPU.
8375
*/
84-
void VTASetDebugMode(VTACommandHandle cmd, int debug_flag);
76+
void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer);
8577

8678
/*!
8779
* \brief Perform a write barrier to make a memory region visible to the CPU.
@@ -92,9 +84,10 @@ void VTASetDebugMode(VTACommandHandle cmd, int debug_flag);
9284
* \param extent The end of the region (in elements).
9385
*/
9486
void VTAWriteBarrier(VTACommandHandle cmd,
95-
void* buffer, uint32_t elem_bits,
96-
uint32_t start, uint32_t extent);
97-
87+
void* buffer,
88+
uint32_t elem_bits,
89+
uint32_t start,
90+
uint32_t extent);
9891
/*!
9992
* \brief Perform a read barrier to a memory region visible to VTA.
10093
* \param cmd The VTA command handle.
@@ -104,8 +97,17 @@ void VTAWriteBarrier(VTACommandHandle cmd,
10497
* \param extent The end of the region (in elements).
10598
*/
10699
void VTAReadBarrier(VTACommandHandle cmd,
107-
void* buffer, uint32_t elem_bits,
108-
uint32_t start, uint32_t extent);
100+
void* buffer,
101+
uint32_t elem_bits,
102+
uint32_t start,
103+
uint32_t extent);
104+
105+
/*!
106+
* \brief Set debug mode on the command handle.
107+
* \param cmd The VTA command handle.
108+
* \param debug_flag The debug flag.
109+
*/
110+
void VTASetDebugMode(VTACommandHandle cmd, int debug_flag);
109111

110112
/*!
111113
* \brief Perform a 2D data load from DRAM.

make/config.mk

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ VTA_LOG_WGT_BUFF_SIZE = 15
5454
# Log of acc buffer size in Bytes
5555
VTA_LOG_ACC_BUFF_SIZE = 17
5656

57+
5758
#---------------------
5859
# Derived VTA hardware parameters
5960
#--------------------

python/vta/__init__.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
1-
"""TVM VTA runtime"""
1+
"""TVM-based VTA Compiler Toolchain"""
22
from __future__ import absolute_import as _abs
33

44
from .hw_spec import *
55

6-
from .runtime import SCOPE_INP, SCOPE_OUT, SCOPE_WGT, DMA_COPY, ALU
7-
from .intrin import GEVM, GEMM
8-
from .build import debug_mode
6+
try:
7+
from .runtime import SCOPE_INP, SCOPE_OUT, SCOPE_WGT, DMA_COPY, ALU
8+
from .intrin import GEVM, GEMM
9+
from .build import debug_mode
10+
from . import mock, ir_pass
11+
from . import arm_conv2d, vta_conv2d
12+
except AttributeError:
13+
pass
914

10-
from . import mock, ir_pass
11-
from . import arm_conv2d, vta_conv2d
12-
from . import graph
15+
from .rpc_client import reconfig_runtime, program_fpga
16+
17+
try:
18+
from . import graph
19+
except ImportError:
20+
pass

python/vta/exec/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""VTA Command line utils."""

python/vta/exec/rpc_server.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""VTA customized TVM RPC Server
2+
3+
Provides additional runtime function and library loading.
4+
"""
5+
from __future__ import absolute_import
6+
7+
import logging
8+
import argparse
9+
import os
10+
import ctypes
11+
import tvm
12+
from tvm.contrib import rpc, util, cc
13+
14+
15+
@tvm.register_func("tvm.contrib.rpc.server.start", override=True)
16+
def server_start():
17+
curr_path = os.path.dirname(
18+
os.path.abspath(os.path.expanduser(__file__)))
19+
dll_path = os.path.abspath(
20+
os.path.join(curr_path, "../../../lib/libvta_runtime.so"))
21+
runtime_dll = []
22+
_load_module = tvm.get_global_func("tvm.contrib.rpc.server.load_module")
23+
24+
@tvm.register_func("tvm.contrib.rpc.server.load_module", override=True)
25+
def load_module(file_name):
26+
if not runtime_dll:
27+
runtime_dll.append(ctypes.CDLL(dll_path, ctypes.RTLD_GLOBAL))
28+
return _load_module(file_name)
29+
30+
@tvm.register_func("tvm.contrib.rpc.server.shutdown", override=True)
31+
def server_shutdown():
32+
if runtime_dll:
33+
runtime_dll[0].VTARuntimeShutdown()
34+
runtime_dll.pop()
35+
36+
@tvm.register_func("tvm.contrib.vta.reconfig_runtime", override=True)
37+
def reconfig_runtime(cflags):
38+
"""Rebuild and reload runtime with new configuration.
39+
40+
Parameters
41+
----------
42+
cflags : str
43+
Whitespace-separated compiler flags used to rebuild the runtime.
44+
"""
45+
if runtime_dll:
46+
raise RuntimeError("Can only reconfig in the beginning of session...")
47+
cflags = cflags.split()
48+
cflags += ["-O2", "-std=c++11"]
49+
lib_name = dll_path
50+
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
51+
proj_root = os.path.abspath(os.path.join(curr_path, "../../../"))
52+
runtime_source = os.path.join(proj_root, "src/runtime.cc")
53+
cflags += ["-I%s/include" % proj_root]
54+
cflags += ["-I%s/nnvm/tvm/include" % proj_root]
55+
cflags += ["-I%s/nnvm/tvm/dlpack/include" % proj_root]
56+
cflags += ["-I%s/nnvm/dmlc-core/include" % proj_root]
57+
logging.info("Rebuild runtime dll with %s", str(cflags))
58+
cc.create_shared(lib_name, [runtime_source], cflags)
59+
60+
61+
def main():
62+
"""Main function"""
63+
parser = argparse.ArgumentParser()
64+
parser.add_argument('--host', type=str, default="0.0.0.0",
65+
help='the hostname of the server')
66+
parser.add_argument('--port', type=int, default=9090,
67+
help='The port of the PRC')
68+
parser.add_argument('--port-end', type=int, default=9199,
69+
help='The end search port of the PRC')
70+
parser.add_argument('--key', type=str, default="",
71+
help="RPC key used to identify the connection type.")
72+
parser.add_argument('--tracker', type=str, default="",
73+
help="Report to RPC tracker")
74+
args = parser.parse_args()
75+
logging.basicConfig(level=logging.INFO)
76+
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
77+
proj_root = os.path.abspath(os.path.join(curr_path, "../../../"))
78+
lib_path = os.path.abspath(os.path.join(proj_root, "lib/libvta.so"))
79+
80+
libs = []
81+
for file_name in [lib_path]:
82+
libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
83+
logging.info("Load additional library %s", file_name)
84+
85+
if args.tracker:
86+
url, port = args.tracker.split(":")
87+
port = int(port)
88+
tracker_addr = (url, port)
89+
if not args.key:
90+
raise RuntimeError(
91+
"Need key to present type of resource when tracker is available")
92+
else:
93+
tracker_addr = None
94+
95+
server = rpc.Server(args.host,
96+
args.port,
97+
args.port_end,
98+
key=args.key,
99+
tracker_addr=tracker_addr)
100+
server.libs += libs
101+
server.proc.join()
102+
103+
if __name__ == "__main__":
104+
main()

python/vta/hw_spec.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,31 @@
11
"""VTA configuration constants (should match hw_spec.h)"""
22
from __future__ import absolute_import as _abs
33

4+
# Log of input/activation width in bits (default 3 -> 8 bits)
5+
VTA_LOG_INP_WIDTH = 3
6+
# Log of kernel weight width in bits (default 3 -> 8 bits)
7+
VTA_LOG_WGT_WIDTH = 3
8+
# Log of accum width in bits (default 5 -> 32 bits)
9+
VTA_LOG_ACC_WIDTH = 5
10+
# Log of tensor batch size (A in (A,B)x(B,C) matrix multiplication)
11+
VTA_LOG_BATCH = 0
12+
# Log of tensor inner block size (B in (A,B)x(B,C) matrix multiplication)
13+
VTA_LOG_BLOCK_IN = 4
14+
# Log of tensor outer block size (C in (A,B)x(B,C) matrix multiplication)
15+
VTA_LOG_BLOCK_OUT = 4
16+
VTA_LOG_OUT_WIDTH = VTA_LOG_INP_WIDTH
17+
# Log of uop buffer size in Bytes
18+
VTA_LOG_UOP_BUFF_SIZE = 15
19+
# Log of acc buffer size in Bytes
20+
VTA_LOG_ACC_BUFF_SIZE = 17
21+
422
# The Constants
523
VTA_WGT_WIDTH = 8
624
VTA_INP_WIDTH = VTA_WGT_WIDTH
725
VTA_OUT_WIDTH = 32
826

27+
VTA_TARGET = "VTA_PYNQ_TARGET"
28+
929
# Dimensions of the GEMM unit
1030
# (BATCH,BLOCK_IN) x (BLOCK_IN,BLOCK_OUT)
1131
VTA_BATCH = 1
@@ -67,4 +87,4 @@
6787
DEBUG_DUMP_INSN = (1 << 1)
6888
DEBUG_DUMP_UOP = (1 << 2)
6989
DEBUG_SKIP_READ_BARRIER = (1 << 3)
70-
DEBUG_SKIP_WRITE_BARRIER = (1 << 4)
90+
DEBUG_SKIP_WRITE_BARRIER = (1 << 4)

0 commit comments

Comments
 (0)