Skip to content

Commit 16b5877

Browse files
tmoreau89tqchen
authored andcommitted
[PYTHON, TVM] Python TVM library, unit tests and end to end example
* VTA python library * Python unit tests * End to end example with Resnet18 * README instructions * Bug fixes
1 parent e7557db commit 16b5877

35 files changed

+4046
-77
lines changed

Makefile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ endif
5555
all: lib/libvta.$(SHARED_LIBRARY_SUFFIX)
5656

5757
VTA_LIB_SRC = $(wildcard src/*.cc src/tvm/*.cc)
58-
ifeq ($(TARGET), PYNQ_TARGET)
58+
ifeq ($(TARGET), VTA_PYNQ_TARGET)
5959
VTA_LIB_SRC += $(wildcard src/pynq/*.cc)
6060
LDFLAGS += -L/usr/lib -lsds_lib
61-
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/ -l:libdma.so
61+
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/ -l:libdma.so
6262
endif
6363
VTA_LIB_OBJ = $(patsubst %.cc, build/%.o, $(VTA_LIB_SRC))
6464

@@ -79,7 +79,7 @@ cpplint:
7979
python nnvm/dmlc-core/scripts/lint.py vta cpp include src hardware tests
8080

8181
pylint:
82-
pylint python/vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc
82+
pylint python/tvm_vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc
8383

8484
doc:
8585
doxygen docs/Doxyfile

apps/pynq_rpc/README.md

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
### PYNQ RPC Server for VTA
2+
3+
This guide describes how to setup a Pynq-based RPC server to accelerate deep learning workloads with VTA.
4+
5+
## Pynq Setup
6+
7+
Follow the getting started tutorial for the [Pynq board](http://pynq.readthedocs.io/en/latest/getting_started.html).
8+
* For this RPC setup make sure to go with the *Connect to a Computer* Ethernet setup.
9+
10+
Make sure that you can ssh into your Pynq board successfully:
11+
```bash
12+
13+
```
14+
15+
When ssh-ing onto the board, the default password for the `xilinx` account is `xilinx`.
16+
17+
For convenience let's go ahead and mount the Pynq board's file system to easily access it and maintain it:
18+
```bash
19+
sshfs [email protected]:/home/xilinx <mountpoint>
20+
```
21+
22+
## Pynq TVM & VTA installation
23+
24+
On your **host PC**, go to the `<mountpoint>` directory of your Pynq board file system.
25+
```bash
26+
cd <mountpoint>
27+
```
28+
29+
From there, clone the VTA repository:
30+
```bash
31+
git clone [email protected]:uwsaml/vta.git --recursive
32+
```
33+
34+
Next, clone the TVM repository:
35+
```bash
36+
git clone [email protected]:dmlc/tvm.git --recursive
37+
```
38+
39+
TVM is rapidly changing, and to ensure stability, we keep track of working TVM checkpoints.
40+
As of now, the TVM checkpoint `e4c2af9abdcb3c7aabafba8084414d7739c17c4c` is known to work with VTA.
41+
```bash
42+
git checkout e4c2af9abdcb3c7aabafba8084414d7739c17c4c
43+
```
44+
45+
Now, ssh into your **Pynq board** to build the TVM runtime with the following commands:
46+
```bash
47+
ssh [email protected] # ssh if you haven't done so
48+
cd ~/tvm
49+
cp make/config.mk .
50+
echo USE_RPC=1 >> config.mk
51+
make runtime -j2
52+
```
53+
54+
## Pynq RPC server setup
55+
56+
We're now ready to build the Pynq RPC server on the Pynq board.
57+
```bash
58+
ssh [email protected] # ssh if you haven't done so
59+
cd ~/vta
60+
export TVM_PATH = /home/xilinx/tvm
61+
make
62+
```
63+
64+
The last stage will build the `192.168.2.99:home/xilinx/vta/lib/libvta.so` library file. We are now ready to launch the RPC server on the Pynq. In order to enable the FPGA drivers, we need to run the RPC server with administrator privileges (using `su`, account: `xilinx`, pwd: `xilinx`).
65+
```bash
66+
ssh [email protected] # ssh if you haven't done so
67+
cd ~/vta
68+
su
69+
./apps/pynq_rpc/start_rpc_server.sh
70+
```
71+
72+
You should see the following being displayed when starting the RPC server:
73+
```
74+
INFO:root:Load additional library /home/xilinx/vta/lib/libvta.so
75+
INFO:root:RPCServer: bind to 0.0.0.0:9091
76+
```
77+
78+
Note that it should be listening on port `9091`.
79+
80+
To kill the RPC server, just enter the `Ctrl + c` command.

apps/pynq_rpc/start_rpc_server.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
#!/bin/bash
22
export PYTHONPATH=${PYTHONPATH}:/home/xilinx/tvm/python
3-
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/
3+
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/opt/python3.6/lib/python3.6/site-packages/pynq/lib/
44
python -m tvm.exec.rpc_server --load-library /home/xilinx/vta/lib/libvta.so

examples/resnet18/pynq/.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
quantize_graph.json
2+
quantize_params.pkl
3+
synset.txt
4+
*.jpg
5+
vta.bit

examples/resnet18/pynq/README.md

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Resnet-18 Example on Pynq-based VTA Design
2+
3+
In order to run this example you'll need to have:
4+
* VTA installed
5+
* TVM installed
6+
* NNVM installed
7+
* A Pynq-based RPC server running
8+
9+
## VTA installation
10+
11+
Clone the VTA repository in the directory of your choosing:
12+
```bash
13+
git clone [email protected]:uwsaml/vta.git --recursive
14+
```
15+
16+
Update your `~/.bashrc` file to include the VTA python libraries in your `PYTHONPATH` (don't forget to source the newly modified `.bashrc` file!):
17+
```bash
18+
export PYTHONPATH=<vta root>/python:${PYTHONPATH}
19+
```
20+
21+
## TVM installation
22+
23+
Clone the TVM repository in the directory of your choosing:
24+
```bash
25+
git clone [email protected]:dmlc/tvm.git --recursive
26+
```
27+
28+
TVM is rapidly changing, and to ensure stability, we keep track of working TVM checkpoints.
29+
As of now, the TVM checkpoint `e4c2af9abdcb3c7aabafba8084414d7739c17c4c` is known to work with VTA.
30+
```bash
31+
git checkout e4c2af9abdcb3c7aabafba8084414d7739c17c4c
32+
```
33+
34+
Before building TVM, copy the `make/config.mk` file into the root TVM directory:
35+
```bash
36+
cd <tvm root>
37+
cp make/config.mk .
38+
```
39+
40+
In the 'config.mk' file sure that:
41+
* `LLVM_CONFIG` points to the llvm-config executable (e.g. `LLVM_CONFIG = /usr/bin/llvm-config-4.0`). You'll need to have llvm4.0 installed or later.
42+
* `USE_RPC` should be set to 1
43+
44+
Launch the compilation, this takes about 5 minutes.
45+
```bash
46+
cd <tvm root>
47+
make -j4
48+
```
49+
50+
Finally update your `~/.bashrc` file to include the TVM python libraries in your `PYTHONPATH` (don't forget to source the newly modified `.bashrc` file!):
51+
```bash
52+
export PYTHONPATH=<tvm root>/python:<tvm root>/topi/python:${PYTHONPATH}
53+
```
54+
55+
## NNVM installation
56+
57+
Clone the NNVM repository from `tqchen` in the directory of your choosing:
58+
```bash
59+
git clone [email protected]:tqchen/nnvm.git --recursive
60+
```
61+
62+
To run this example, we rely on a special branch of NNVM: `qt`:
63+
```bash
64+
cd <nnvm root>
65+
git checkout qt
66+
```
67+
68+
Launch the compilation, this takes less a minute.
69+
```bash
70+
cd <nnvm root>
71+
make -j4
72+
```
73+
74+
Finally update your `~/.bashrc` file to include the NNVM python libraries in your `PYTHONPATH` (don't forget to source the newly modified `.bashrc` file!):
75+
```bash
76+
export PYTHONPATH=<nnvm root>/python:${PYTHONPATH}
77+
```
78+
79+
## Pynq RPC Server Setup
80+
81+
Follow the [Pynq RPC Server Guide](https://github.com/saml/vta/tree/master/apps/pynq_rpc/README.md)
82+
83+
## Running the example
84+
85+
Simply run the following python script:
86+
```bash
87+
python imagenet_predict.py
88+
```
89+
90+
This will run imagenet classification using the ResNet18 architecture on a VTA design that performs 8-bit integer inference, to perform classification on a cat image `cat.jpg`.
91+
92+
The script reports runtime measured on the Pynq board, and the top-1 result category:
93+
```
94+
('x', (1, 3, 224, 224))
95+
Build complete...
96+
('TVM prediction top-1:', 281, 'tabby, tabby cat')
97+
t-cost=0.41906
98+
```
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# some standard imports
2+
import nnvm
3+
import tvm
4+
from nnvm.compiler import graph_attr
5+
import vta
6+
import os
7+
import numpy as np
8+
from PIL import Image
9+
import pickle
10+
import json
11+
import logging
12+
import wget
13+
from tvm.contrib import graph_runtime, rpc, util
14+
15+
factor = 16
16+
host = "pynq"
17+
port = 9091
18+
verbose = False
19+
# only run fpga component, mark non-conv ops as nop
20+
debug_fpga_only = False
21+
22+
# Obtain model and hardware files (they're too large to check-in)
23+
url = "https://homes.cs.washington.edu/~moreau/media/vta/"
24+
TEST_FILE = 'cat.jpg'
25+
CATEG_FILE = 'synset.txt'
26+
RESNET_GRAPH_FILE = 'quantize_graph.json'
27+
RESNET_PARAMS_FILE = 'quantize_params.pkl'
28+
BITSTREAM_FILE = 'vta.bit'
29+
for file in [TEST_FILE, CATEG_FILE, RESNET_GRAPH_FILE, RESNET_PARAMS_FILE, BITSTREAM_FILE]:
30+
if not os.path.isfile(file):
31+
print "Downloading {}".format(file)
32+
wget.download(url+file)
33+
34+
# Program the FPGA remotely
35+
assert tvm.module.enabled("rpc")
36+
remote = rpc.connect(host, port)
37+
remote.upload(BITSTREAM_FILE, BITSTREAM_FILE)
38+
fprogram = remote.get_function("tvm.contrib.vta.init")
39+
fprogram(BITSTREAM_FILE)
40+
41+
if verbose:
42+
logging.basicConfig(level=logging.INFO)
43+
44+
# Change to -device=tcpu to run cpu only inference.
45+
target = "llvm -device=vta"
46+
47+
synset = eval(open(os.path.join(CATEG_FILE)).read())
48+
image = Image.open(os.path.join(TEST_FILE)).resize((224, 224))
49+
50+
def transform_image(image):
51+
image = np.array(image) - np.array([123., 117., 104.])
52+
image /= np.array([58.395, 57.12, 57.375])
53+
image = image.transpose((2, 0, 1))
54+
image = image[np.newaxis, :]
55+
return image
56+
57+
def mark_nop(graph, conv_layer=-1, skip_conv_layer=()):
58+
"""Helper function to mark certain op as nop
59+
60+
Useful to debug performance issues.
61+
"""
62+
jgraph = json.loads(graph.json())
63+
counter = 0
64+
for nid, node in enumerate(jgraph["nodes"]):
65+
op_name = node["op"]
66+
if op_name != "tvm_op":
67+
continue
68+
attrs = node["attrs"]
69+
node_name = node["name"]
70+
func_name = attrs["func_name"]
71+
if func_name.find("quantized_conv2d") != -1:
72+
if conv_layer >= 0:
73+
if counter != conv_layer:
74+
attrs["func_name"] = "__nop"
75+
if counter in skip_conv_layer:
76+
attrs["func_name"] = "__nop"
77+
counter += 1
78+
else:
79+
if conv_layer >= 0:
80+
attrs["func_name"] = "__nop"
81+
attrs["func_name"] = "__nop"
82+
if attrs["func_name"] != "__nop":
83+
print("Run function %s"% func_name)
84+
graph = nnvm.graph.load_json(json.dumps(jgraph))
85+
return graph
86+
87+
x = transform_image(image)
88+
print('x', x.shape)
89+
90+
######################################################################
91+
# now compile the graph
92+
import nnvm.compiler
93+
np.random.seed(0)
94+
sym = nnvm.graph.load_json(
95+
open(os.path.join(RESNET_GRAPH_FILE)).read())
96+
params = pickle.load(
97+
open(os.path.join(RESNET_PARAMS_FILE)))
98+
99+
shape_dict = {"data": x.shape}
100+
dtype_dict = {"data": 'float32'}
101+
shape_dict.update({k: v.shape for k, v in params.items()})
102+
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
103+
104+
graph = nnvm.graph.create(sym)
105+
graph_attr.set_shape_inputs(sym, shape_dict)
106+
graph_attr.set_dtype_inputs(sym, dtype_dict)
107+
graph = graph.apply("InferShape").apply("InferType")
108+
109+
dtype = "float32"
110+
sym = vta.graph.remove_stochastic(sym)
111+
sym = vta.graph.clean_cast(sym)
112+
sym = vta.graph.clean_conv_fuse(sym)
113+
if "vta" in target:
114+
sym = vta.graph.pack(sym, shape_dict, factor)
115+
116+
graph_attr.set_shape_inputs(sym, shape_dict)
117+
sym = sym.apply("InferShape")
118+
graph_attr.set_dtype_inputs(sym, dtype_dict)
119+
sym = sym.apply("InferType")
120+
121+
with nnvm.compiler.build_config(opt_level=3):
122+
bdict = {}
123+
if "vta" not in target:
124+
bdict = {"add_lower_pass": []}
125+
else:
126+
bdict = {"add_lower_pass": vta.debug_mode(0)}
127+
with tvm.build_config(**bdict):
128+
graph, lib, params = nnvm.compiler.build(
129+
sym, target, shape_dict, dtype_dict,
130+
params=params)
131+
132+
remote = rpc.connect(host, port)
133+
temp = util.tempdir()
134+
lib.save(temp.relpath("graphlib.o"))
135+
remote.upload(temp.relpath("graphlib.o"))
136+
lib = remote.load_module("graphlib.o")
137+
ctx = remote.ext_dev(0) if "vta" in target else remote.cpu(0)
138+
139+
print("Build complete...")
140+
141+
def run_e2e(graph):
142+
"""Running end to end example
143+
"""
144+
if debug_fpga_only:
145+
graph = mark_nop(graph, skip_conv_layer=(0,))
146+
m = graph_runtime.create(graph, lib, ctx)
147+
# set inputs
148+
m.set_input('data', tvm.nd.array(x.astype("float32")))
149+
m.set_input(**params)
150+
# execute
151+
timer = m.module.time_evaluator("run", ctx, number=10)
152+
tcost = timer()
153+
# get outputs
154+
tvm_output = m.get_output(
155+
0,tvm.nd.empty((1000,), dtype, remote.cpu(0)))
156+
top1 = np.argmax(tvm_output.asnumpy())
157+
print('TVM prediction top-1:', top1, synset[top1])
158+
print("t-cost=%g" % tcost.mean)
159+
160+
161+
def run_layer(old_graph):
162+
"""Run a certain layer."""
163+
for layer_id in range(1, 2):
164+
graph = mark_nop(old_graph, layer_id)
165+
m = graph_runtime.create(graph, lib, ctx)
166+
# set inputs
167+
m.set_input('data', tvm.nd.array(x.astype("float32")))
168+
m.set_input(**params)
169+
# execute
170+
timer = m.module.time_evaluator("run", ctx, number=10)
171+
tcost = timer()
172+
print("resnet[%d]: %g\n"% (layer_id, tcost.mean))
173+
174+
run_e2e(graph)
File renamed without changes.

0 commit comments

Comments
 (0)