
Commit c6c69fc

[TEST] Xavier initialization for benchmarks (apache#54)

* [TEST] Xavier initialization for benchmarks
* remove additional line

1 parent 4fcb5d0 · commit c6c69fc

File tree (4 files changed: +129 −12 lines)

nnvm/python/nnvm/testing/init.py
nnvm/python/nnvm/testing/mobilenet.py
nnvm/python/nnvm/testing/utils.py
nnvm/tutorials/imagenet_inference_gpu.py

nnvm/python/nnvm/testing/init.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+"""Initializer of parameters."""
+import numpy as np
+
+class Initializer(object):
+    """The base class of an initializer."""
+    def __init__(self, **kwargs):
+        self._kwargs = kwargs
+
+    def __call__(self, desc, arr):
+        """Initialize an array
+
+        Parameters
+        ----------
+        desc : str
+            Initialization pattern descriptor.
+
+        arr : NDArray
+            The array to be initialized.
+        """
+        if desc.endswith('weight'):
+            self._init_weight(desc, arr)
+        elif desc.endswith('bias'):
+            self._init_bias(desc, arr)
+        elif desc.endswith('gamma'):
+            self._init_gamma(desc, arr)
+        elif desc.endswith('beta'):
+            self._init_beta(desc, arr)
+        elif desc.endswith('mean'):
+            self._init_mean(desc, arr)
+        elif desc.endswith('var'):
+            self._init_var(desc, arr)
+        else:
+            self._init_default(desc, arr)
+
+    def _init_bias(self, _, arr):
+        arr[:] = 0.0
+
+    def _init_gamma(self, _, arr):
+        arr[:] = 1.0
+
+    def _init_beta(self, _, arr):
+        arr[:] = 0.0
+
+    def _init_mean(self, _, arr):
+        arr[:] = 0.0
+
+    def _init_var(self, _, arr):
+        arr[:] = 1.0
+
+    def _init_weight(self, name, arr):
+        """Abstract method to Initialize weight."""
+        raise NotImplementedError("Must override it")
+
+    def _init_default(self, name, _):
+        raise ValueError(
+            'Unknown initialization pattern for %s. ' \
+            'Default initialization is now limited to '\
+            '"weight", "bias", "gamma" (1.0), and "beta" (0.0).' \
+            'Please use mx.sym.Variable(init=mx.init.*) to set initialization pattern' % name)
+
+
+class Xavier(Initializer):
+    """ "Xavier" initialization for weights
+
+    Parameters
+    ----------
+    rnd_type: str, optional
+        Random generator type, can be ``'gaussian'`` or ``'uniform'``.
+
+    factor_type: str, optional
+        Can be ``'avg'``, ``'in'``, or ``'out'``.
+
+    magnitude: float, optional
+        Scale of random number.
+    """
+    def __init__(self, rnd_type="uniform", factor_type="avg", magnitude=3):
+        super(Xavier, self).__init__(rnd_type=rnd_type,
+                                     factor_type=factor_type,
+                                     magnitude=magnitude)
+        self.rnd_type = rnd_type
+        self.factor_type = factor_type
+        self.magnitude = float(magnitude)
+
+    def _init_weight(self, name, arr):
+        shape = arr.shape
+        hw_scale = 1.
+        if len(shape) < 2:
+            raise ValueError('Xavier initializer cannot be applied to vector {0}. It requires at'
+                             ' least 2D.'.format(name))
+        if len(shape) > 2:
+            hw_scale = np.prod(shape[2:])
+        fan_in, fan_out = shape[1] * hw_scale, shape[0] * hw_scale
+        factor = 1.
+        if self.factor_type == "avg":
+            factor = (fan_in + fan_out) / 2.0
+        elif self.factor_type == "in":
+            factor = fan_in
+        elif self.factor_type == "out":
+            factor = fan_out
+        else:
+            raise ValueError("Incorrect factor type")
+        # Hack for mobilenet, because there is less connectivity
+        if "depthwise" in name:
+            factor = 3 * 3
+        scale = np.sqrt(self.magnitude / factor)
+        if self.rnd_type == "uniform":
+            arr[:] = np.random.uniform(-scale, scale, size=arr.shape)
+        else:
+            raise ValueError("Unknown random type")

nnvm/python/nnvm/testing/mobilenet.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def separable_conv_block(data, name, depthwise_channels,
     # depthwise convolution + bn + relu
     conv1 = sym.conv2d(data=data, channels=depthwise_channels,
                        groups=depthwise_channels, kernel_size=kernel_size, strides=strides,
-                       padding=padding, use_bias=False, layout="NCHW", name=name + "_conv1")
+                       padding=padding, use_bias=False, layout="NCHW", name=name + "_depthwise_conv1")
     bn1 = sym.batch_norm(data=conv1, epsilon=epsilon, name=name + "_bn1")
     act1 = sym.relu(data=bn1, name=name + "_relu1")
     # pointwise convolution + bn + relu
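This rename is what arms the MobileNet special case added in init.py above: Xavier._init_weight checks for the substring "depthwise" in the parameter name and pins the fan factor to 3 * 3 for those kernels, since a depthwise 3x3 convolution has far less connectivity than its tensor shape suggests. A minimal sketch of the match (the layer name is illustrative):

name = "separable_conv_block_1_depthwise_conv1_weight"
if "depthwise" in name:   # matches only after this commit's rename
    factor = 3 * 3        # fixed fan for the 3x3 depthwise kernel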

nnvm/python/nnvm/testing/utils.py

Lines changed: 14 additions & 10 deletions
@@ -5,9 +5,10 @@
 import tvm
 from ..compiler import graph_util
 from ..import graph
+from . init import Xavier
 
-
-def create_workload(net, batch_size, image_shape=(3, 224, 224), dtype="float32"):
+def create_workload(net, batch_size, image_shape=(3, 224, 224),
+                    dtype="float32", initializer=None, seed=0):
     """Helper function to create benchmark workload for input network
 
     Parameters
@@ -24,6 +25,12 @@ def create_workload(net, batch_size, image_shape=(3, 224, 224), dtype="float32")
     dtype : str, optional
         The data type
 
+    initializer : Initializer
+        The initializer used
+
+    seed : int
+        The seed used in initialization.
+
     Returns
     -------
     net : nnvm.Symbol
@@ -38,15 +45,12 @@ def create_workload(net, batch_size, image_shape=(3, 224, 224), dtype="float32")
     g = graph.create(net)
     input_shapes, _ = graph_util.infer_shape(g, data=data_shape)
     shape_dict = dict(zip(g.index.input_names, input_shapes))
+    np.random.seed(seed)
+    initializer = initializer if initializer else Xavier(magnitude=3)
     for k, v in shape_dict.items():
         if k == "data":
             continue
-        # Specially generate non-negative parameters.
-        if k.endswith("gamma"):
-            init = np.random.uniform(0.9, 1, size=v)
-        elif k.endswith("var"):
-            init = np.random.uniform(0.9, 1, size=v)
-        else:
-            init = np.random.uniform(-0.1, 0.1, size=v)
-        params[k] = tvm.nd.array(init.astype(dtype), ctx=tvm.cpu(0))
+        init_value = np.zeros(v).astype(dtype)
+        initializer(k, init_value)
+        params[k] = tvm.nd.array(init_value, ctx=tvm.cpu(0))
     return net, params
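A hedged usage sketch of the new signature (the network symbol and shapes are assumptions, not part of this diff). Parameters are now allocated as zeros and filled in place by the initializer, with np.random.seed(seed) making the workload reproducible across runs:

from nnvm.testing.init import Xavier
from nnvm.testing.utils import create_workload

# `net` is assumed to be an NNVM symbol whose only non-parameter input
# is named "data", e.g. one returned by an nnvm.testing model builder.
net, params = create_workload(net, batch_size=1,
                              image_shape=(3, 224, 224),
                              dtype="float32",
                              initializer=Xavier(magnitude=3),
                              seed=42)
# `params` maps parameter names to tvm.nd.array values on tvm.cpu(0).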

nnvm/tutorials/imagenet_inference_gpu.py

Lines changed: 4 additions & 1 deletion
@@ -8,6 +8,7 @@
 To begin with, we import nnvm(for compilation) and TVM(for deployment).
 """
 import tvm
+import numpy as np
 from tvm.contrib import nvcc, graph_runtime
 import nnvm.compiler
 import nnvm.testing
@@ -64,6 +65,7 @@ def tvm_callback_cuda_compile(code):
 graph, lib, params = nnvm.compiler.build(
     net, target, shape={"data": data_shape}, params=params)
 
+
 ######################################################################
 # Run the Compiled Module
 # -----------------------
@@ -74,10 +76,11 @@ def tvm_callback_cuda_compile(code):
 # This example runs on the same machine.
 #
 # Note that the code below no longer depends on NNVM, and only relies TVM's runtime to run(deploy).
-
+data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
 module = graph_runtime.create(graph, lib, ctx)
 # set input
 module.set_input(**params)
+module.set_input("data", data)
 # run
 module.run()
 # get output
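Below this hunk the tutorial reads the result back; a hedged sketch of the typical follow-up, assuming an ImageNet classifier whose first output is a (batch_size, 1000) logit array:

out_shape = (1, 1000)  # assumed output shape, not shown in this diff
out = module.get_output(0, tvm.nd.empty(out_shape, "float32"))
top1 = np.argmax(out.asnumpy()[0])  # index of the predicted class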
