Commit a28ce62

Ivycrazydemo authored and committed
add byoc-dnnl patterns and their test cases
1 parent d6c8862 · commit a28ce62

File tree: 5 files changed, +485 -173 lines changed

python/tvm/relay/op/contrib/dnnl.py

Lines changed: 106 additions & 11 deletions
@@ -32,12 +32,17 @@
 - The other way is to implement the function by themselves to
 check the attributes of the op and decide if it should be offloaded to DNNL.
 """
+import logging
+
 import tvm.ir
-from ...dataflow_pattern import wildcard, is_op
-from .register import register_pattern_table
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
 
+from ...dataflow_pattern import wildcard, is_op
+from .register import register_pattern_table
+
+logger = logging.getLogger("DNNL")
+
 
 def _register_external_op_helper(op_name, supported=True):
     """The helper function to indicate that a given operator can be supported
@@ -65,11 +70,26 @@ def _func_wrapper(expr):
 _register_external_op_helper("nn.conv2d")
 _register_external_op_helper("nn.dense")
 _register_external_op_helper("nn.relu")
+_register_external_op_helper("tanh")
+_register_external_op_helper("sigmoid")
 _register_external_op_helper("add")
 _register_external_op_helper("multiply")
 
 
-def make_pattern(with_bias=True):
+def make_conv_pattern(with_bias=True, with_eltwise=None):
+    """Create patterns related to nn.conv2d.
+
+    Parameters
+    ----------
+    with_bias : bool
+        Whether attach `bias_add` to `nn.conv2d`.
+    with_eltwise : str
+        The attached elementwise post-op name.
+    Returns
+    -------
+    conv_out : CallPattern
+        Call node sequence.
+    """
     data = wildcard()
     weight = wildcard()
     bias = wildcard()
@@ -78,21 +98,88 @@ def make_pattern(with_bias=True):
         conv_out = is_op("add")(conv, bias)
     else:
         conv_out = conv
-    return is_op("nn.relu")(conv_out)
+    if with_eltwise:
+        return is_op(with_eltwise)(conv_out)
+    return conv_out
+
+
+def make_dense_pattern(with_bias=True, with_eltwise=None):
+    """Create patterns related to nn.dense.
+
+    Parameters
+    ----------
+    with_bias : bool
+        Whether attach `bias_add` to `nn.dense`.
+    with_eltwise : str
+        The attached elementwise post-op name.
+    Returns
+    -------
+    dense_out : CallPattern
+        Call node sequence.
+    """
+    data = wildcard()
+    weight = wildcard()
+    bias = wildcard()
+    dense = is_op("nn.dense")(data, weight)
+    if with_bias:
+        dense_out = is_op("add")(dense, bias)
+    else:
+        dense_out = dense
+    if with_eltwise:
+        dense_out = is_op(with_eltwise)(dense_out)
+    return dense_out
+
+
+def make_dnnl_pattern(op, with_bias, with_eltwise):
+    """Create dnnl patterns.
+
+    Parameters
+    ----------
+    op : str
+        The first call node's op name.
+    with_bias : bool
+        Whether attach `bias_add` to `nn.dense`.
+    with_eltwise : str
+        The attached elementwise post-op name.
+    Returns
+    -------
+    pattern : Tuple(pattern_name, CallPattern)
+        Created pattern name, along with its CallPattern.
+    """
+    pat_name = "dnnl." + op
+    pat_name += "_bias" if with_bias else ""
+    pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
+    if op == "conv2d":
+        dnnl_pattern = (pat_name, make_conv_pattern(with_bias, with_eltwise))
+    elif op == "dense":
+        dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise))
+    else:
+        logger.warning("Currently, only conv2d and dense op are supported, but got %s.", op)
+        dnnl_pattern = ()
+    return dnnl_pattern
 
 
 @register_pattern_table("dnnl")
 def pattern_table():
-    conv2d_bias_relu_pat = ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True))
-    conv2d_relu_pat = ("dnnl.conv2d_relu", make_pattern(with_bias=False))
-    dnnl_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]
+    """Create dnnl patterns.
+
+    Returns
+    -------
+    dnnl_patterns : List[dnnl_pattern]
+        Created patterns.
+    """
+    elt_list = ["nn.relu", "tanh", "sigmoid", None]
+    dnnl_patterns = []
+    for with_bias in [True, False]:
+        for elt in elt_list:
+            if not with_bias and not elt:
+                return dnnl_patterns
+            dnnl_patterns.append(make_dnnl_pattern("conv2d", with_bias, elt))
+            dnnl_patterns.append(make_dnnl_pattern("dense", with_bias, elt))
     return dnnl_patterns
 
 
-def partition_for_dnnl(
-    mod,
-    params=None,
-):
+def partition_for_dnnl(mod, params=None):
     """Partition the graph greedily offloading supported operators to DNNL.
 
     Parameters
@@ -111,6 +198,14 @@ def partition_for_dnnl(
         mod["main"] = bind_params_by_name(mod["main"], params)
     seq = tvm.transform.Sequential(
         [
+            transform.CanonicalizeOps(),
+            transform.InferType(),
+            transform.SimplifyInference(),
+            transform.FoldConstant(),
+            transform.FoldScaleAxis(),
+            # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
+            transform.SimplifyExpr(),
+            transform.FoldConstant(),
             transform.MergeComposite(pattern_table()),
             transform.AnnotateTarget("dnnl"),
             transform.MergeCompilerRegions(),
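For context, a minimal sketch of how the updated pattern table and partition entry point can be exercised. The graph, shapes, and variable names below are illustrative and not part of the commit; the sketch only assumes that `pattern_table` and `partition_for_dnnl` are importable from the module changed above.

# Illustrative only: exercise the new DNNL pattern table and partitioning.
import tvm
from tvm import relay
from tvm.relay.op.contrib.dnnl import partition_for_dnnl, pattern_table

# The generated pattern names now cover conv2d/dense x (bias, no bias) x {relu, tanh, sigmoid, none}.
print([name for name, _ in pattern_table()])
# e.g. ['dnnl.conv2d_bias_relu', 'dnnl.dense_bias_relu', ..., 'dnnl.dense_sigmoid']

# A small conv2d -> bias_add -> relu graph; CanonicalizeOps inside partition_for_dnnl
# rewrites bias_add to add, so the chain matches "dnnl.conv2d_bias_relu".
data = relay.var("data", shape=(1, 32, 14, 14), dtype="float32")
weight = relay.var("weight", shape=(16, 32, 3, 3), dtype="float32")
bias = relay.var("bias", shape=(16,), dtype="float32")
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))
out = relay.nn.relu(relay.nn.bias_add(conv, bias))
mod = tvm.IRModule.from_expr(relay.Function([data, weight, bias], out))

partitioned = partition_for_dnnl(mod)
print(partitioned)  # the fused region should carry Composite="dnnl.conv2d_bias_relu"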

src/relay/backend/contrib/dnnl/codegen.cc

Lines changed: 18 additions & 0 deletions
@@ -455,9 +455,27 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
 
   if (name == "dnnl.conv2d_bias_relu") {
     call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "nn.relu"});
+  } else if (name == "dnnl.conv2d_bias_tanh") {
+    call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "tanh"});
+    ICHECK(call->op.as<OpNode>()) << "Not op node";
+  } else if (name == "dnnl.conv2d_bias_sigmoid") {
+    call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "sigmoid"});
+    ICHECK(call->op.as<OpNode>()) << "Not op node";
+  } else if (name == "dnnl.conv2d_bias") {
+    call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "add"});
+    ICHECK(call->op.as<OpNode>()) << "Not op node";
   } else if (name == "dnnl.conv2d_relu") {
     call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "nn.relu"});
     ICHECK(call->op.as<OpNode>()) << "Not op node";
+  } else if (name == "dnnl.conv2d_tanh") {
+    call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "tanh"});
+    ICHECK(call->op.as<OpNode>()) << "Not op node";
+  } else if (name == "dnnl.conv2d_sigmoid") {
+    call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "sigmoid"});
+    ICHECK(call->op.as<OpNode>()) << "Not op node";
+  } else if (name == "dnnl.dense_bias") {
+    call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.dense", "add"});
+    ICHECK(call->op.as<OpNode>()) << "Not op node";
   } else {
     LOG(FATAL) << "Unrecognized DNNL pattern: " << name;
   }
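Every new branch above follows the same recipe: walk from the composite function body down the named op chain to the root nn.conv2d or nn.dense call, whose attributes are then serialized. The chains handled after this change can be summarized as follows (an illustrative Python table, not code from the commit):

# Composite pattern name -> op chain from root call to composite output,
# mirroring the GetRootCall arguments in the serializer above.
DNNL_ROOT_CALL_CHAINS = {
    "dnnl.conv2d_bias_relu": ["nn.conv2d", "add", "nn.relu"],
    "dnnl.conv2d_bias_tanh": ["nn.conv2d", "add", "tanh"],
    "dnnl.conv2d_bias_sigmoid": ["nn.conv2d", "add", "sigmoid"],
    "dnnl.conv2d_bias": ["nn.conv2d", "add"],
    "dnnl.conv2d_relu": ["nn.conv2d", "nn.relu"],
    "dnnl.conv2d_tanh": ["nn.conv2d", "tanh"],
    "dnnl.conv2d_sigmoid": ["nn.conv2d", "sigmoid"],
    "dnnl.dense_bias": ["nn.dense", "add"],
}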

src/runtime/contrib/dnnl/dnnl_json_runtime.cc

Lines changed: 66 additions & 33 deletions
@@ -103,15 +103,31 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
       if ("nn.conv2d" == op_name) {
         Conv2d(nid);
       } else if ("dnnl.conv2d_relu" == op_name) {
-        Conv2d(nid, true, false);
+        Conv2d(nid, true, false, dnnl::algorithm::eltwise_relu);
+      } else if ("dnnl.conv2d_tanh" == op_name) {
+        Conv2d(nid, true, false, dnnl::algorithm::eltwise_tanh);
+      } else if ("dnnl.conv2d_sigmoid" == op_name) {
+        Conv2d(nid, true, false, dnnl::algorithm::eltwise_logistic);
+      } else if ("dnnl.conv2d_bias" == op_name) {
+        Conv2d(nid, false, true);
       } else if ("dnnl.conv2d_bias_relu" == op_name) {
-        Conv2d(nid, true, true);
+        Conv2d(nid, true, true, dnnl::algorithm::eltwise_relu);
+      } else if ("dnnl.conv2d_bias_tanh" == op_name) {
+        Conv2d(nid, true, true, dnnl::algorithm::eltwise_tanh);
+      } else if ("dnnl.conv2d_bias_sigmoid" == op_name) {
+        Conv2d(nid, true, true, dnnl::algorithm::eltwise_logistic);
       } else if ("nn.dense" == op_name) {
         Dense(nid);
+      } else if ("dnnl.dense_bias" == op_name) {
+        Dense(nid, true);
       } else if ("nn.batch_norm" == op_name) {
         BatchNorm(nid);
       } else if ("nn.relu" == op_name) {
-        Relu(nid);
+        Eltwise(nid, dnnl::algorithm::eltwise_relu);
+      } else if ("tanh" == op_name) {
+        Eltwise(nid, dnnl::algorithm::eltwise_tanh);
+      } else if ("sigmoid" == op_name) {
+        Eltwise(nid, dnnl::algorithm::eltwise_logistic);
       } else if ("add" == op_name) {
         Binary(nid, dnnl::algorithm::binary_add);
       } else if ("multiply" == op_name) {
@@ -150,7 +166,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     return entry_out_mem_[eid].first;
   }
 
-  void Conv2d(const size_t& nid, const bool has_relu = false, const bool has_bias = false) {
+  void Conv2d(const size_t& nid, const bool has_elt = false, const bool has_bias = false,
+              dnnl::algorithm algo = dnnl::algorithm::eltwise_relu) {
     auto node = nodes_[nid];
 
     // Setup attributes.
@@ -159,24 +176,29 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
     dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
     std::vector<std::string> str_strides = node.GetAttr<std::vector<std::string>>("strides");
+    std::vector<std::string> str_dilates = node.GetAttr<std::vector<std::string>>("dilation");
     std::vector<std::string> str_padding = node.GetAttr<std::vector<std::string>>("padding");
     dnnl::memory::dim groups = std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);
 
-    dnnl::memory::dim N = input_shape[0],       // batch size
-        IC = input_shape[1],                    // input channels
-        IH = input_shape[2],                    // input height
-        IW = input_shape[3],                    // input width
-        OC = weight_shape[0],                   // output channels
-        KH = weight_shape[2],                   // weight height
-        KW = weight_shape[3],                   // weight width
-        PW_L = std::stoi(str_padding[1]),       // width padding: left
-        PW_R = std::stoi(str_padding[3]),       // width padding: right
-        PH_L = std::stoi(str_padding[0]),       // height padding: top
-        PH_R = std::stoi(str_padding[2]),       // height padding: bottom
-        SH = std::stoi(str_strides[0]),         // height-wise stride
-        SW = std::stoi(str_strides[1]),         // weight-wise stride
-        OH = (IH - KH + PH_L + PH_R) / SH + 1,  // output height
-        OW = (IW - KW + PW_L + PW_R) / SW + 1;  // output width
+    dnnl::memory::dim N = input_shape[0],         // batch size
+        IC = input_shape[1],                      // input channels
+        IH = input_shape[2],                      // input height
+        IW = input_shape[3],                      // input width
+        OC = weight_shape[0],                     // output channels
+        KH = weight_shape[2],                     // weight height
+        KW = weight_shape[3],                     // weight width
+        PW_L = std::stoi(str_padding[1]),         // width padding: left
+        PW_R = std::stoi(str_padding[3]),         // width padding: right
+        PH_L = std::stoi(str_padding[0]),         // height padding: top
+        PH_R = std::stoi(str_padding[2]),         // height padding: bottom
+        SH = std::stoi(str_strides[0]),           // height-wise stride
+        SW = std::stoi(str_strides[1]),           // weight-wise stride
+        DH = std::stoi(str_dilates[0]) - 1,       // height-wise dilate
+        DW = std::stoi(str_dilates[1]) - 1,       // weight-wise dilate
+        DKH = 1 + (KH - 1) * (DH + 1),            // dilated weight height
+        DKW = 1 + (KW - 1) * (DW + 1),            // dilated weight width
+        OH = (IH - DKH + PH_L + PH_R) / SH + 1,   // output height
+        OW = (IW - DKW + PW_L + PW_R) / SW + 1;   // output width
 
     // Memory shapes.
     dnnl::memory::dims src_dims = {N, IC, IH, IW};
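The new dilation-aware shape arithmetic in this hunk can be sanity-checked by hand. For instance, a 3x3 kernel with Relay dilation 2, stride 1, and padding 2 keeps a 14x14 input at 14x14 (a quick illustrative check in Python mirroring the formulas above, not code from the commit):

IH, KH, PH_L, PH_R, SH = 14, 3, 2, 2, 1    # input size, kernel, padding, stride
DH = 2 - 1                                 # Relay dilation of 2 -> DNNL dilate of 1
DKH = 1 + (KH - 1) * (DH + 1)              # dilated kernel height = 5
OH = (IH - DKH + PH_L + PH_R) // SH + 1    # output height = 14
print(DKH, OH)                             # 5 14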
@@ -187,6 +209,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     dnnl::memory::dims bias_dims = {OC};
     dnnl::memory::dims dst_dims = {N, OC, OH, OW};
     dnnl::memory::dims strides_dims = {SH, SW};
+    dnnl::memory::dims dilates_dims = {DH, DW};
     dnnl::memory::dims padding_dims_l = {PH_L, PW_L};
     dnnl::memory::dims padding_dims_r = {PH_R, PW_R};
 
@@ -199,13 +222,14 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     // Covn2d description.
     auto conv_desc = dnnl::convolution_forward::desc(
         dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, conv_src_md,
-        conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, padding_dims_l, padding_dims_r);
+        conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, dilates_dims, padding_dims_l,
+        padding_dims_r);
 
-    // Enable ReLU
+    // Enable elementwise post-ops
     dnnl::primitive_attr attr;
-    if (has_relu) {
+    if (has_elt) {
       dnnl::post_ops ops;
-      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
+      ops.append_eltwise(1.f, algo, 0.f, 0.f);
       attr.set_post_ops(ops);
     }
 
@@ -245,7 +269,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
                          {DNNL_ARG_DST, conv2d_dst_memory}});
   }
 
-  void Dense(const size_t& nid) {
+  void Dense(const size_t& nid, const bool has_bias = false) {
     auto node = nodes_[nid];
 
     // Setup attributes.
@@ -281,9 +305,18 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     // Memories.
     auto data_memory = BindDNNLMemory(data_entry, data_md);
     auto weight_memory = BindDNNLMemory(weight_entry, weight_md);
+
+    // Bias memory.
     auto bias_memory = dnnl::memory(bias_md, engine_);
-    float bias[OC] = {0};
-    write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float));
+    if (has_bias) {
+      auto bias_entry = node.GetInputs()[2];
+      BindDNNLMemory(bias_entry, bias_memory);
+    } else {
+      float bias[OC] = {0};
+      write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float));
+    }
+
+    // Output memory.
     JSONGraphNodeEntry out_entry(nid, 0);
     auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc());
 
@@ -335,20 +368,20 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
                          {DNNL_ARG_VARIANCE, variance_memory}});
   }
 
-  void Relu(const size_t& nid) {
+  void Eltwise(const size_t& nid, dnnl::algorithm algo) {
     auto node = nodes_[nid];
 
     auto data_entry = node.GetInputs()[0];
     dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
     dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32);
 
-    auto relu_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference,
-                                                 dnnl::algorithm::eltwise_relu, data_md, 0);
-    auto relu_prim_desc = dnnl::eltwise_forward::primitive_desc(relu_desc, engine_);
-    ICHECK(data_md == relu_prim_desc.dst_desc());
+    auto elt_desc =
+        dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, data_md, 0);
+    auto elt_prim_desc = dnnl::eltwise_forward::primitive_desc(elt_desc, engine_);
+    ICHECK(data_md == elt_prim_desc.dst_desc());
 
-    auto relu = dnnl::eltwise_forward(relu_prim_desc);
-    net_.push_back(relu);
+    auto elt = dnnl::eltwise_forward(elt_prim_desc);
+    net_.push_back(elt);
 
     auto data_memory = BindDNNLMemory(data_entry, data_md);
     JSONGraphNodeEntry out_entry(nid, 0);
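The commit title mentions accompanying test cases, which are not shown in this view. A minimal sketch of the kind of check such a test would perform for the new dense-with-bias path; the shapes and variable names here are assumptions for illustration, not the commit's actual test:

# Illustrative only: dense + add should be fused into a "dnnl.dense_bias" composite.
import tvm
from tvm import relay
from tvm.relay.op.contrib.dnnl import partition_for_dnnl

x = relay.var("x", shape=(1, 16), dtype="float32")
w = relay.var("w", shape=(8, 16), dtype="float32")
b = relay.var("b", shape=(8,), dtype="float32")
y = relay.add(relay.nn.dense(x, w), b)
mod = tvm.IRModule.from_expr(relay.Function([x, w, b], y))

mod = partition_for_dnnl(mod)
# The partitioned module should contain a composite tagged "dnnl.dense_bias",
# which the JSON runtime above dispatches to Dense(nid, true).
assert "dnnl.dense_bias" in mod.astext()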
