diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index c87a7162b070..6581f10a2f56 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -40,10 +40,15 @@
 from tvm.relay.expr import GlobalVar
 from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 
+from tvm.relay.analysis import analysis as _analysis
+from tvm.relay import expr as _expr
+
+
 from ... import _ffi_api
 from ...dataflow_pattern import wildcard, is_op, is_expr, rewrite, DFPatternCallback
 from .register import register_pattern_table
+
 
 logger = logging.getLogger("DNNL")
 
 
@@ -139,12 +144,22 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
     data = wildcard()
     weight = wildcard()
     bias = wildcard()
+
     dense = is_op("nn.dense")(data, weight)
     if with_bias:
         dense_out = is_op("add")(dense, bias)
     else:
         dense_out = dense
-    if with_eltwise:
+    if with_eltwise == "gelu":
+        const1 = wildcard()
+        const2 = wildcard()
+        const3 = wildcard()
+        div = is_op("divide")(dense_out, const1)
+        erf_val = is_op("erf")(div)
+        added_erf_val = is_op("add")(erf_val, const2)
+        mul_val = is_op("multiply")(dense_out, added_erf_val)
+        dense_out = is_op("multiply")(mul_val, const3)
+    elif with_eltwise:
         dense_out = is_op(with_eltwise)(dense_out)
     return dense_out
 
@@ -176,7 +191,7 @@ def make_dnnl_pattern(op_name, with_bias, with_eltwise):
         dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise))
     else:
         logger.warning(
-            "Currently, only conv1d, conv2d, conv2d_transpose, conv3d_transpose and "
+            "Currently, only conv1d, conv2d, conv2d_transpose, conv3d_transpose, "
             "dense op are supported, but got %s.",
             op_name,
         )
@@ -193,12 +208,12 @@ def pattern_table():
     dnnl_patterns : List[dnnl_pattern]
         Created patterns.
     """
-    elt_list = ["nn.relu", "tanh", "sigmoid", None]
+    elt_list = ["nn.relu", "tanh", "sigmoid", "gelu", None]
    dnnl_patterns = []
    for with_bias in [True, False]:
        for elt in elt_list:
            if not with_bias and not elt:
-                return dnnl_patterns
+                continue
            for conv_name in [
                "nn.conv1d",
                "nn.conv2d",
@@ -206,7 +221,8 @@
                "nn.conv2d_transpose",
                "nn.conv3d_transpose",
            ]:
-                dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt))
+                if elt != "gelu":
+                    dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt))
            dnnl_patterns.append(make_dnnl_pattern("nn.dense", with_bias, elt))
    return dnnl_patterns
 
@@ -339,6 +355,7 @@ def tag2layout(input_data, is_weight=False, conv_type="Conv1D"):
            res += i
        else:
            raise ValueError("Unsupport layout format: %s" % input_data)
+
    return res
 
 
@@ -594,3 +611,99 @@
    """
    mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
    return mod
+
+
+class DenseReshapeBiasGeluRewrite(DFPatternCallback):
+    """
+    A callback to reorder reshape operators when the graph matches one of the patterns below:
+
+    Pattern #1:
+    1 %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */,
+        units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */;
+    2 %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */;
+    3 %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63)
+        /* ty=Tensor[(1, 3136, 64), float32] */;
+
+    Pattern #2:
+    1 %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */,
+        units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */;
+    2 %77 = reshape(%76, newshape=[1, 3136, 512]) /* ty=Tensor[(1, 3136, 512), float32] */;
+    3 %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77)
+        /* ty=Tensor[(1, 3136, 512), float32] */;
+    4 %79 = divide(%78, 1.41421f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
+    5 %80 = erf(%79) /* ty=Tensor[(1, 3136, 512), float32] */;
+    6 %81 = add(%80, 1f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
+    7 %82 = multiply(%78, %81) /* ty=Tensor[(1, 3136, 512), float32] */;
+    8 %83 = multiply(%82, 0.5f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
+    """
+
+    def __init__(self, has_gelu=True):
+        super(DenseReshapeBiasGeluRewrite, self).__init__()
+        self.data = wildcard()
+        self.weight = wildcard()
+        self.bias = wildcard()
+        self.const1 = wildcard()
+        self.const2 = wildcard()
+        self.const3 = wildcard()
+
+        self.attr_map = {}
+        self.has_gelu = has_gelu
+
+        den = is_op("nn.dense")(self.data, self.weight)
+        re_den = is_op("reshape")(den)
+        added = is_op("add")(self.bias, re_den)
+        if self.has_gelu:
+            divisor = is_op("divide")(added, self.const1)
+            val_erf = is_op("erf")(divisor)
+            added_erf = is_op("add")(val_erf, self.const2)
+            mul1 = is_op("multiply")(added, added_erf)
+            mul2 = is_op("multiply")(mul1, self.const3)
+            self.pattern = mul2
+        else:
+            self.pattern = added
+
+    def get_attr(self, pre):
+        """Recursively retrieve attributes from the reshape operator."""
+
+        def visit_func(expr):
+            if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"):
+                new_attrs = {}
+                for k in expr.attrs.keys():
+                    new_attrs[k] = expr.attrs[k]
+                self.attr_map["reshape"] = new_attrs
+
+        _analysis.post_order_visit(pre, visit_func)
+
+    def callback(self, pre, post, node_map):
+        self.get_attr(pre)
+
+        data = node_map[self.data][0]
+        weight = node_map[self.weight][0]
+        bias = node_map[self.bias][0]
+
+        den = relay.op.nn.dense(data, weight)
+        added = relay.op.add(bias, den)
+        if not self.has_gelu:
+            return relay.op.reshape(added, self.attr_map["reshape"]["newshape"])
+
+        const1 = node_map[self.const1][0]
+        const2 = node_map[self.const2][0]
+        const3 = node_map[self.const3][0]
+
+        divisor = relay.op.divide(added, const1)
+        val_erf = relay.op.erf(divisor)
+        added_erf = relay.op.add(val_erf, const2)
+        mul1 = relay.op.multiply(added, added_erf)
+        mul2 = relay.op.multiply(mul1, const3)
+        return relay.op.reshape(mul2, self.attr_map["reshape"]["newshape"])
+
+
+def rewrite_dense_bias_gelu_reshape_last(mod):
+    """Rewrite the input graph to reorder reshape operators so that
+    we can perform dense_bias_gelu/dense_bias fusion and then offload
+    them to the BYOC part.
+ """ + mod["main"] = rewrite( + [DenseReshapeBiasGeluRewrite(), DenseReshapeBiasGeluRewrite(has_gelu=False)], mod["main"] + ) + return mod diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 41480ed33b0a..927cd12ae0fb 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -454,6 +454,7 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { ICHECK_NE(pattern_name, ""); std::vector op_list; size_t pos = 0, start = 0; + while ((pos = pattern_name.find(interval, start)) != std::string::npos) { std::string op_name = pattern_name.substr(start, pos - start); if (op_name.find("dnnl") != std::string::npos) { @@ -508,8 +509,7 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { call = GetRootCall(fn->body.as(), op_list.size() - 1, op_list); ICHECK(call->op.as()) << "Not op node"; } else if (name.find("dnnl.dense") != std::string::npos) { - std::vector op_list = ParsingOpList(name); - call = GetRootCall(fn->body.as(), op_list.size() - 1, op_list); + call = GetRootCall(fn->body.as(), 10, "nn.dense"); ICHECK(call->op.as()) << "Not op node"; } else { LOG(FATAL) << "Unrecognized DNNL pattern: " << name; diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 360f366a162e..70080254c414 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -480,6 +480,38 @@ inline const CallNode* GetRootCall(const CallNode* current_call, const std::stri return GetRootCall(next_call, op_name); } +/*! + * \brief Retrieve the expected "root" op nested inside a fused call, such as conv2d in + * relu(add(conv2d)) + * \param call A Relay call node. Typically nn.relu when called the first time. + * \param max_depth The maximum number of calls before the root op, counting from current_call. + * \param op_name The name of expected "root" op in this fused call. + * \return A CallNode corresponding to the root op + */ +inline const CallNode* GetRootCall(const CallNode* current_call, int max_depth, + const std::string& op_name) { + ICHECK(current_call && max_depth >= 0); + + if (max_depth == 0) { + ICHECK(current_call && IsOp(current_call, op_name)); + return current_call; + } + if (IsOp(current_call, op_name)) { + return current_call; + } + + ICHECK_GT(current_call->args.size(), 0); + + size_t valid_node_idx = 0; + while (valid_node_idx < current_call->args.size() && + current_call->args[valid_node_idx].as()) { + valid_node_idx++; + } + + const auto* next_call = current_call->args[valid_node_idx].as(); + return GetRootCall(next_call, max_depth - 1, op_name); +} + /*! * \brief Get the external symbol of the Relay function name. * diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index db8f25e2a6ea..5045f3323af7 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -83,7 +83,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase { // Find proper dnnl::memory buffers std::unordered_map mem_args; for (const auto& kvp : arg_reqs) mem_args[kvp.first] = mem_solver(kvp.second); - prim.execute(stream_, mem_args); } } @@ -143,6 +142,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { std::regex relu_pat(".*_relu.*"); std::regex tanh_pat(".*_tanh.*"); std::regex sigmoid_pat(".*_sigmoid.*"); + std::regex gelu_pat(".*_gelu.*"); // Parsing post-ops. 
     dnnl::post_ops ops;
@@ -155,7 +155,12 @@
     if (std::regex_match(op_name, sigmoid_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
     }
-    attr.set_post_ops(ops);
+    if (std::regex_match(op_name, gelu_pat)) {
+      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
+    }
+    if (ops.len() != 0) {
+      attr.set_post_ops(ops);
+    }
 
     // Parsing bias_add.
     return std::regex_match(op_name, bias_add_pat) ? true : false;
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index 3e4e831aa594..c884665421cb 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -19,6 +19,7 @@
 import numpy as np
 import sys
 import subprocess
+import math
 
 import tvm
 from tvm import relay
@@ -56,7 +57,7 @@ def bf16_supported():
     return _bf16_supported
 
 
-def partition_for_dnnl(mod, params=None, alter_layout=True):
+def partition_for_dnnl(mod, params=None, alter_layout=True, prune_subgraphs=True):
     """Partition the graph greedily offloading supported operators to DNNL.
 
     Parameters
@@ -112,6 +113,7 @@
         mod = alter_layout_seq(mod)
 
     mod = dnnl.rewrite_layer_norm(mod)
+    mod = dnnl.rewrite_dense_bias_gelu_reshape_last(mod)
 
     byoc_seq = tvm.transform.Sequential(
         [
@@ -121,9 +123,11 @@
            transform.PartitionGraph(),
        ]
    )
+
    with tvm.transform.PassContext(opt_level=3):
        mod = byoc_seq(mod)
-        mod = dnnl.prune_dnnl_subgraphs(mod)
+    if prune_subgraphs:
+        mod = dnnl.prune_dnnl_subgraphs(mod)
    return mod
 
 
@@ -150,16 +154,15 @@ def assert_result_dict_holds(result_dict):
            tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3)
 
 
-def run_and_verify(mod, input, params, target, run_module, subgraph_num=None, test_bf16=True):
-    def check_dnnl_used(mod, subgraph_num=None):
-        num_dnnl_subgraphs = sum(
-            [1 if "dnnl" in gv.name_hint else 0 for gv in mod.get_global_vars()]
-        )
-        if subgraph_num:
-            assert num_dnnl_subgraphs == subgraph_num
-        else:
-            assert num_dnnl_subgraphs >= 1
+def check_dnnl_used(mod, subgraph_num=None):
+    num_dnnl_subgraphs = sum([1 if "dnnl" in gv.name_hint else 0 for gv in mod.get_global_vars()])
+    if subgraph_num:
+        assert num_dnnl_subgraphs == subgraph_num
+    else:
+        assert num_dnnl_subgraphs >= 1
+
 
+def run_and_verify(mod, input, params, target, run_module, subgraph_num=None, test_bf16=True):
    dev = tvm.cpu()
    result_dict = dict()
    for mode in ["graph", "vm"]:
@@ -170,6 +173,7 @@ def check_dnnl_used(mod, subgraph_num=None):
        ]
        if test_bf16 and bf16_supported():
            configs += [(True, False, True), (True, True, True)]
+
        for use_dnnl, alter_layout, use_bf16 in configs:
            result_key = (
                mode
@@ -585,21 +589,56 @@ def get_conv3d_transpose_bias(
    return out, dic, param_lst
 
 
-def get_dense(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"):
+def gelu_helper(data):
+    const1 = relay.const(math.sqrt(2.0))
+    const2 = relay.const(1.0)
+    const3 = relay.const(0.5)
+    divisor = relay.op.divide(data, const1)
+    val_erf = relay.op.erf(divisor)
+    added_erf = relay.op.add(val_erf, const2)
+    mul1 = relay.op.multiply(data, added_erf)
+    out = relay.op.multiply(mul1, const3)
+    return out
+
+
+def get_dense(
+    x_shape=(1, 16), k_shape=(32, 16), activation=None, has_reshape=False, dtype="float32"
+):
    x = relay.var("x", shape=(x_shape), dtype=dtype)
    kernel = relay.var("kernel", shape=(k_shape), dtype=dtype)
    out = relay.nn.dense(x, kernel, units=k_shape[0])
+    # out = relay.nn.dense(x, kernel, units=None)
+    if has_reshape:
+        out = relay.reshape(out, newshape=(1, x_shape[0], k_shape[0]))
    dic = {"x": x_shape, "kernel": k_shape}
    param_lst = ["kernel"]
+
+    if activation == "gelu":
+        out = gelu_helper(out)
    return out, dic, param_lst
 
 
-def get_dense_bias(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"):
-    dense, dic, param_lst = get_dense(x_shape=x_shape, k_shape=k_shape, dtype=dtype)
+def get_dense_bias(
+    x_shape=(1, 16),
+    k_shape=(32, 16),
+    activation=None,
+    has_reshape=False,
+    use_add=False,
+    dtype="float32",
+):
+    dense, dic, param_lst = get_dense(
+        x_shape=x_shape, k_shape=k_shape, has_reshape=has_reshape, dtype=dtype
+    )
    bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype)
-    out = relay.nn.bias_add(dense, bias)
+    if use_add:
+        out = relay.add(dense, bias)
+    else:
+        out = relay.nn.bias_add(dense, bias)
    dic["bias"] = (k_shape[0],)
    param_lst += ["bias"]
+
+    if activation == "gelu":
+        out = gelu_helper(out)
    return out, dic, param_lst
 
 
@@ -891,6 +930,11 @@ def test_dense(run_module, dtype="float32"):
    config = dense, dic, param_lst
    run_and_verify_func(config, run_module=run_module, dtype=dtype)
 
+    dense, dic, param_lst = get_dense(x_shape, k_shape, activation="gelu", dtype=dtype)
+    dense = tvm.IRModule.from_expr(dense)
+    config = dense, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
 
 def test_dense_pattern(run_module, dtype="float32"):
    x_shape = (1, 16)
@@ -906,6 +950,11 @@
    config = dense_bias, dic, param_lst
    run_and_verify_func(config, run_module=run_module, dtype=dtype)
 
+    dense_bias, dic, param_lst = get_dense_bias(x_shape, k_shape, activation="gelu", dtype=dtype)
+    dense_bias = tvm.IRModule.from_expr(dense_bias)
+    config = dense_bias, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
 
 def test_pool2d(run_module, dtype="float32"):
    def get_graph(
@@ -1053,5 +1102,29 @@ def test_layer_norm(run_module, dtype="float32"):
    run_and_verify_func(config, run_module=run_module, dtype=dtype)
 
 
+def test_rewrite_dense_bias_gelu_reshape_last(run_module, dtype="float32"):
+    def get_graph(act=None):
+        x_shape = (1, 16)
+        k_shape = (32, 16)
+
+        dense_bias, dic, param_lst = get_dense_bias(
+            x_shape, k_shape, activation=act, has_reshape=True, use_add=True, dtype=dtype
+        )
+        dense_bias = tvm.IRModule.from_expr(dense_bias)
+        processed_dense_bias = partition_for_dnnl(
+            dense_bias, params=None, alter_layout=False, prune_subgraphs=False
+        )
+        check_dnnl_used(processed_dense_bias, 1)
+
+        return dense_bias, dic, param_lst
+
+    run_and_verify_func(
+        get_graph("gelu"), subgraph_num=1, run_module=run_module, dtype=dtype, test_bf16=False
+    )
+    run_and_verify_func(
+        get_graph(), subgraph_num=1, run_module=run_module, dtype=dtype, test_bf16=False
+    )
+
+
 if __name__ == "__main__":
    tvm.testing.main()
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 761a430997b0..dedeae56e9da 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -928,9 +928,9 @@ def test_dnnl_fuse():
    ) = (
        dnnl_patterns[1],
        dnnl_patterns[13],
-        dnnl_patterns[19],
-        dnnl_patterns[25],
-        dnnl_patterns[37],
+        dnnl_patterns[20],
+        dnnl_patterns[26],
+        dnnl_patterns[38],
    )
 
    def get_blocks(
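
Usage sketch (not part of the diff above): a minimal example of how the new rewrite is expected to be driven once this patch is applied. It builds the dense -> reshape -> bias-add -> erf-based gelu chain documented in DenseReshapeBiasGeluRewrite and then calls rewrite_dense_bias_gelu_reshape_last directly; the variable names and tensor shapes are illustrative only.

import math

import tvm
from tvm import relay
from tvm.relay.op.contrib import dnnl

# Inputs for a small dense layer; the shapes are arbitrary illustrative values.
x = relay.var("x", shape=(1, 16), dtype="float32")
w = relay.var("w", shape=(32, 16), dtype="float32")
b = relay.var("b", shape=(32,), dtype="float32")

# dense -> reshape -> bias add -> erf-based gelu, with the bias as the first
# operand of the add, matching the pattern registered in DenseReshapeBiasGeluRewrite.
dense = relay.nn.dense(x, w, units=32)
reshaped = relay.reshape(dense, newshape=(1, 1, 32))
biased = relay.add(b, reshaped)
erf = relay.erf(relay.divide(biased, relay.const(math.sqrt(2.0))))
gelu = relay.multiply(relay.multiply(biased, relay.add(erf, relay.const(1.0))), relay.const(0.5))

mod = tvm.IRModule.from_expr(gelu)
mod = dnnl.rewrite_dense_bias_gelu_reshape_last(mod)
# After the rewrite the reshape is the outermost op, so dense + bias + gelu can be
# matched as a single DNNL composite pattern during partitioning.
print(mod["main"])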