Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions python/tvm/relay/op/contrib/dnnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@


logger = logging.getLogger("DNNL")
supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]


def _register_external_op_helper(op_name, supported=True):
Expand Down Expand Up @@ -120,6 +121,8 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
conv_out : CallPattern
Call node sequence.
"""
if with_eltwise not in supported_post_elts:
raise ValueError("Unsupported eltwise post-op: %s" % with_eltwise)
data = wildcard()
weight = wildcard()
bias = wildcard()
Expand All @@ -128,8 +131,11 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
conv_out = is_op("add")(conv, bias)
else:
conv_out = conv
if with_eltwise:
return is_op(with_eltwise)(conv_out)
if with_eltwise == "swish":
sig_out = is_op("sigmoid")(conv_out)
conv_out = is_op("multiply")(conv_out, sig_out)
elif with_eltwise:
conv_out = is_op(with_eltwise)(conv_out)
return conv_out


Expand All @@ -147,6 +153,8 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
dense_out : CallPattern
Call node sequence.
"""
if with_eltwise not in supported_post_elts:
raise ValueError("Unsupported eltwise post-op: %s" % with_eltwise)
data = wildcard()
weight = wildcard()
bias = wildcard()
Expand All @@ -165,6 +173,9 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
added_erf_val = is_op("add")(erf_val, const2)
mul_val = is_op("multiply")(dense_out, added_erf_val)
dense_out = is_op("multiply")(mul_val, const3)
elif with_eltwise == "swish":
sig_out = is_op("sigmoid")(dense_out)
dense_out = is_op("multiply")(dense_out, sig_out)
elif with_eltwise:
dense_out = is_op(with_eltwise)(dense_out)
return dense_out
Expand All @@ -191,6 +202,7 @@ def make_dnnl_pattern(op_name, with_bias, with_eltwise):
pat_name = "dnnl.deconv" + op_name.split("_")[0][-2::]
pat_name += "_bias" if with_bias else ""
pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
pat_name = pat_name.replace("_swish", "_sigmoid_mul")
if "conv" in op_name:
dnnl_pattern = (pat_name, make_conv_pattern(op_name, with_bias, with_eltwise))
elif op_name == "nn.dense":
Expand Down Expand Up @@ -282,7 +294,7 @@ def pattern_table():
dnnl_patterns.append(make_qnn_conv2d_pattern())
dnnl_patterns.append(make_qnn_dense_pattern())

elt_list = ["nn.relu", "tanh", "sigmoid", "gelu", None]
elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
for with_bias in [True, False]:
for elt in elt_list:
if not with_bias and not elt:
Expand Down Expand Up @@ -380,6 +392,8 @@ def get_shape(tensor):
if isinstance(tensor, tvm.ir.container.Array):
return tensor[-1].shape
if isinstance(tensor, relay.expr.Call):
if tensor.op.name == "multiply":
return tensor.type_args[0].shape
return tensor.checked_type.shape
raise TypeError("Unsupport data type: %s" % type(tensor))

Expand All @@ -395,6 +409,8 @@ def get_dtype(tensor):
if isinstance(tensor, tvm.ir.container.Array):
return tensor[-1].dtype
if isinstance(tensor, relay.expr.Call):
if tensor.op.name == "multiply":
return tensor.type_args[0].dtype
return tensor.checked_type.dtype
raise TypeError("Unsupport data type: %s" % type(tensor))

Expand Down
9 changes: 9 additions & 0 deletions src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,8 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
{"relu", "nn.relu"},
{"tanh", "tanh"},
{"sigmoid", "sigmoid"},
{"clip", "clip"},
{"mul", "multiply"},
{"nn.deconv2d", "nn.conv2d_transpose"},
{"nn.deconv3d", "nn.conv3d_transpose"},
};
Expand Down Expand Up @@ -566,6 +568,13 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
"kernel", /* op_type_ */
inputs, 1 /* num_outputs_ */);
SetCallNodeAttribute(node, call);
// If has post-op `clip`. Assume the last op is clip, add clip's attrs to the pattern attrs.
if (name.find("_clip") != std::string::npos) {
auto clip_call = cn->op.as<FunctionNode>()->body.as<CallNode>();
ICHECK(IsOp(clip_call, "clip"));
SetCallNodeAttribute(node, clip_call);
}
// For QNN.
for (const auto& kvp : extra_attrs) node->SetAttr(kvp.first, kvp.second);

return AddNode(node, GetRef<Expr>(cn));
Expand Down
4 changes: 4 additions & 0 deletions src/relay/backend/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,10 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
current_call->args[valid_node_idx].as<VarNode>()) {
valid_node_idx++;
}
while (valid_node_idx < current_call->args.size() &&
!(IsOp(current_call->args[valid_node_idx].as<CallNode>(), expected_op_names[depth - 1]))) {
valid_node_idx++;
}
const auto* next_call = current_call->args[valid_node_idx].as<CallNode>();
return GetRootCall(next_call, depth - 1, expected_op_names);
}
Expand Down
12 changes: 11 additions & 1 deletion src/runtime/contrib/dnnl/dnnl_json_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
std::regex relu_pat(".*_relu.*");
std::regex tanh_pat(".*_tanh.*");
std::regex sigmoid_pat(".*_sigmoid.*");
std::regex clip_pat(".*_clip.*");
std::regex gelu_pat(".*_gelu.*");

// Parsing post-ops.
Expand All @@ -199,8 +200,17 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
if (std::regex_match(op_name, tanh_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_tanh, 0.f, 0.f);
}
if (std::regex_match(op_name, clip_pat)) {
float a_min = GetNodeAttr<float>(nodes_[nid], "a_min");
float a_max = GetNodeAttr<float>(nodes_[nid], "a_max");
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_clip, a_min, a_max);
}
if (std::regex_match(op_name, sigmoid_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
if (op_name.find("_sigmoid_mul") != std::string::npos) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
} else {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
}
}
if (std::regex_match(op_name, gelu_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
Expand Down
126 changes: 31 additions & 95 deletions tests/python/contrib/test_dnnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ def run_and_verify(mod, input, params, target, run_module, subgraph_num=None, te
if use_dnnl:
processed_mod = partition_for_dnnl(processed_mod, params, alter_layout)
check_dnnl_used(processed_mod)

with tvm.transform.PassContext(opt_level=3):
func = relay.create_executor(
mode, mod=processed_mod, device=dev, target=target
Expand Down Expand Up @@ -237,6 +236,23 @@ def run_and_verify_func(
)


def add_activation(activation, out, dic, param_lst):
if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
elif activation == "clip":
return relay.clip(out, 0.0, 6.0), dic, param_lst
elif activation == "swish":
sig_out = relay.sigmoid(out)
out = relay.multiply(out, sig_out)
return out, dic, param_lst
else:
return out, dic, param_lst


def get_conv1d(
x_shape=((1, 3, 224)),
k_shape=(16, 3, 3),
Expand All @@ -262,15 +278,7 @@ def get_conv1d(
)
dic = {"x": x_shape, "kernel": k_shape}
param_lst = ["kernel"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv1d_bias(x_shape=(1, 3, 224), k_shape=(10, 3, 3), activation=None, dtype="float32"):
Expand All @@ -279,15 +287,7 @@ def get_conv1d_bias(x_shape=(1, 3, 224), k_shape=(10, 3, 3), activation=None, dt
out = relay.nn.bias_add(conv, bias)
dic["bias"] = (k_shape[0],)
param_lst += ["bias"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv1d_bias_bn_relu(x_shape=(1, 3, 224), k_shape=(10, 3, 3), dtype="float32"):
Expand Down Expand Up @@ -334,15 +334,7 @@ def get_conv2d(
)
dic = {"x": x_shape, "kernel": k_shape}
param_lst = ["kernel"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv2d_transpose(
Expand All @@ -367,15 +359,7 @@ def get_conv2d_transpose(
)
dic = {"x": x_shape, "kernel": k_shape}
param_lst = ["kernel"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv2d_weights_const(
Expand Down Expand Up @@ -412,15 +396,7 @@ def get_conv2d_bias(
out = relay.nn.bias_add(conv, bias)
dic["bias"] = (k_shape[0],)
param_lst += ["bias"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv2d_transpose_bias(
Expand All @@ -431,15 +407,7 @@ def get_conv2d_transpose_bias(
out = relay.nn.bias_add(conv, bias)
dic["bias"] = (k_shape[1],)
param_lst += ["bias"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"):
Expand Down Expand Up @@ -503,15 +471,7 @@ def get_conv3d(
)
dic = {"x": x_shape, "kernel": k_shape}
param_lst = ["kernel"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv3d_transpose(
Expand Down Expand Up @@ -542,15 +502,7 @@ def get_conv3d_transpose(
)
dic = {"x": x_shape, "kernel": k_shape}
param_lst = ["kernel"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv3d_bias(
Expand All @@ -561,15 +513,7 @@ def get_conv3d_bias(
out = relay.nn.bias_add(conv, bias)
dic["bias"] = (k_shape[0],)
param_lst += ["bias"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv3d_transpose_bias(
Expand All @@ -580,15 +524,7 @@ def get_conv3d_transpose_bias(
out = relay.nn.bias_add(conv, bias)
dic["bias"] = (k_shape[1],)
param_lst += ["bias"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def gelu_helper(data):
Expand Down Expand Up @@ -797,7 +733,7 @@ def test_conv2d_weights_const(run_module, dtype="float32"):
def test_conv2d_pattern(run_module, dtype="float32"):
x_shape = (1, 32, 8, 8)
k_shape = (16, 32, 3, 3)
activation_lst = [None, "relu", "tanh", "sigmoid"]
activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
for a in activation_lst:
conv2d, dic, param_lst = get_conv2d(x_shape, k_shape, activation=a, dtype=dtype)
conv2d = tvm.IRModule.from_expr(conv2d)
Expand Down Expand Up @@ -839,7 +775,7 @@ def test_conv2d_transpose(run_module, dtype="float32"):


def test_conv2d_transpose_pattern(run_module, dtype="float32"):
activation_lst = [None, "relu", "tanh", "sigmoid"]
activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
for a in activation_lst:
conv2d, dic, param_lst = get_conv2d_transpose(activation=a, dtype=dtype)
conv2d = tvm.IRModule.from_expr(conv2d)
Expand Down Expand Up @@ -872,7 +808,7 @@ def test_conv3d(run_module, dtype="float32"):


def test_conv3d_pattern(run_module, dtype="float32"):
activation_lst = [None, "relu", "tanh", "sigmoid"]
activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
for a in activation_lst:
conv3d, dic, param_lst = get_conv3d(activation=a, dtype=dtype)
conv3d = tvm.IRModule.from_expr(conv3d)
Expand Down Expand Up @@ -905,7 +841,7 @@ def test_conv3d_transpose(run_module, dtype="float32"):


def test_conv3d_transpose_pattern(run_module, dtype="float32"):
activation_lst = [None, "relu", "tanh", "sigmoid"]
activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
for a in activation_lst:
conv3d, dic, param_lst = get_conv3d_transpose(activation=a, dtype=dtype)
conv3d = tvm.IRModule.from_expr(conv3d)
Expand Down
Loading