
Commit 0cc56e8: Fix lint

1 parent: a4e1a5d

3 files changed: +39 -73 lines


python/tvm/relay/op/contrib/dnnl.py

Lines changed: 36 additions & 70 deletions
@@ -33,7 +33,6 @@
 check the attributes of the op and decide if it should be offloaded to DNNL.
 """
 import logging
-import math

 import tvm.ir
 from tvm import relay
@@ -46,10 +45,9 @@


 from ... import _ffi_api
-from ...dataflow_pattern import wildcard, is_op, is_constant, rewrite, DFPatternCallback, is_expr
+from ...dataflow_pattern import wildcard, is_op, rewrite, DFPatternCallback
 from .register import register_pattern_table

-import re

 logger = logging.getLogger("DNNL")

@@ -549,18 +547,29 @@ def visit_call(self, call):

 class DenseReshapeBiasGeluRewrite(DFPatternCallback):
     """
-    A callback to reorder reshape operators when the patten is as below:
-    1 %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */, units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */;
+    A callback to reorder reshape operators when the patterns are as below:
+
+    Pattern #1:
+    1 %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */,
+            units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */;
+    2 %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */;
+    3 %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63)
+            /* ty=Tensor[(1, 3136, 64), float32] */;
+
+    Pattern #2:
+    1 %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */,
+            units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */;
     2 %77 = reshape(%76, newshape=[1, 3136, 512]) /* ty=Tensor[(1, 3136, 512), float32] */;
-    3 %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77) /* ty=Tensor[(1, 3136, 512), float32] */;
+    3 %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77)
+            /* ty=Tensor[(1, 3136, 512), float32] */;
     4 %79 = divide(%78, 1.41421f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
     5 %80 = erf(%79) /* ty=Tensor[(1, 3136, 512), float32] */;
     6 %81 = add(%80, 1f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
     7 %82 = multiply(%78, %81) /* ty=Tensor[(1, 3136, 512), float32] */;
     8 %83 = multiply(%82, 0.5f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
     """

-    def __init__(self):
+    def __init__(self, has_gelu=True):
         super(DenseReshapeBiasGeluRewrite, self).__init__()
         self.data = wildcard()
         self.weight = wildcard()
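Steps 4-8 in Pattern #2 above are the erf-based GELU; the constant 1.41421 is sqrt(2). For orientation, a minimal scalar sketch of the same computation (illustrative only, not part of the commit):

from math import erf, sqrt

def gelu_erf(x):
    """Scalar reference for docstring steps 4-8: 0.5 * x * (1 + erf(x / sqrt(2)))."""
    t = erf(x / sqrt(2.0))      # steps 4-5: divide by ~1.41421, then erf
    return x * (t + 1.0) * 0.5  # steps 6-8: add 1, multiply by x, scale by 0.5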
@@ -570,29 +579,30 @@ def __init__(self):
         self.const3 = wildcard()

         self.attr_map = {}
+        self.has_gelu = has_gelu

         den = is_op("nn.dense")(self.data, self.weight)
         re_den = is_op("reshape")(den)
         added = is_op("add")(self.bias, re_den)
-        divisor = is_op("divide")(added, self.const1)
-        val_erf = is_op("erf")(divisor)
-        added_erf = is_op("add")(val_erf, self.const2)
-        mul1 = is_op("multiply")(added, added_erf)
-        mul2 = is_op("multiply")(mul1, self.const3)
-        self.pattern = mul2
+        if self.has_gelu:
+            divisor = is_op("divide")(added, self.const1)
+            val_erf = is_op("erf")(divisor)
+            added_erf = is_op("add")(val_erf, self.const2)
+            mul1 = is_op("multiply")(added, added_erf)
+            mul2 = is_op("multiply")(mul1, self.const3)
+            self.pattern = mul2
+        else:
+            self.pattern = added

     def get_attr(self, pre):
+        """Recursively retrieve attributes from reshape operator."""
+
         def visit_func(expr):
             if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"):
                 new_attrs = {}
                 for k in expr.attrs.keys():
                     new_attrs[k] = expr.attrs[k]
                 self.attr_map["reshape"] = new_attrs
-            elif isinstance(expr, _expr.Call) and expr.op == relay.op.get("nn.dense"):
-                new_attrs = {}
-                for k in expr.attrs.keys():
-                    new_attrs[k] = expr.attrs[k]
-                self.attr_map["nn.dense"] = new_attrs

         _analysis.post_order_visit(pre, visit_func)

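get_attr above relies on relay.analysis.post_order_visit, which applies a visitor function to every sub-expression. A standalone sketch of the same attribute harvest (illustrative, not part of the commit):

from tvm import relay
from tvm.relay import analysis as _analysis
from tvm.relay import expr as _expr

def collect_reshape_attrs(expr):
    """Record the attrs of any reshape call, the way get_attr does."""
    attrs = {}

    def visit(node):
        if isinstance(node, _expr.Call) and node.op == relay.op.get("reshape"):
            attrs["newshape"] = node.attrs["newshape"]

    _analysis.post_order_visit(expr, visit)
    return attrs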
@@ -602,12 +612,16 @@ def callback(self, pre, post, node_map):
         data = node_map[self.data][0]
         weight = node_map[self.weight][0]
         bias = node_map[self.bias][0]
+
+        den = relay.op.nn.dense(data, weight)
+        added = relay.op.add(bias, den)
+        if not self.has_gelu:
+            return relay.op.reshape(added, self.attr_map["reshape"]["newshape"])
+
         const1 = node_map[self.const1][0]
         const2 = node_map[self.const2][0]
         const3 = node_map[self.const3][0]
-
-        den = relay.op.nn.dense(data, weight)
-        added = relay.op.add(bias, den)
+
         divisor = relay.op.divide(added, const1)
         val_erf = relay.op.erf(divisor)
         added_erf = relay.op.add(val_erf, const2)
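The callback's effect on a Pattern #1-shaped graph: add(bias, reshape(dense(x, w))) becomes reshape(add(bias, dense(x, w))), keeping dense and the bias add two-dimensional so they can fuse. A runnable sketch (shapes made up; assumes DenseReshapeBiasGeluRewrite from this file is in scope):

from tvm import relay
from tvm.relay.dataflow_pattern import rewrite

x = relay.var("x", shape=(3136, 64), dtype="float32")
w = relay.var("w", shape=(64, 64), dtype="float32")
b = relay.var("b", shape=(64,), dtype="float32")

before = relay.add(b, relay.reshape(relay.nn.dense(x, w), (1, 3136, 64)))
after = rewrite(DenseReshapeBiasGeluRewrite(has_gelu=False), before)
print(after)  # expected: reshape(add(b, nn.dense(x, w)), newshape=[1, 3136, 64])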
@@ -624,57 +638,9 @@ def rewrite_dense_bias_gelu_reshape_last(mod):
     return mod


-class DenseReshapeBiasRewrite(DFPatternCallback):
-    """
-    A callback to reorder reshape operators when the patten is as below:
-    1 %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */, units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */;
-    2 %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */;
-    3 %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63) /* ty=Tensor[(1, 3136, 64), float32] */;
-    """
-
-    def __init__(self):
-        super(DenseReshapeBiasRewrite, self).__init__()
-        self.data = wildcard()
-        self.weight = wildcard()
-        self.bias = wildcard()
-
-        self.attr_map = {}
-
-        den = is_op("nn.dense")(self.data, self.weight)
-        re_den = is_op("reshape")(den)
-        added = is_op("add")(self.bias, re_den)
-        self.pattern = added
-
-    def get_attr(self, pre):
-        def visit_func(expr):
-            if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"):
-                new_attrs = {}
-                for k in expr.attrs.keys():
-                    new_attrs[k] = expr.attrs[k]
-                self.attr_map["reshape"] = new_attrs
-            elif isinstance(expr, _expr.Call) and expr.op == relay.op.get("nn.dense"):
-                new_attrs = {}
-                for k in expr.attrs.keys():
-                    new_attrs[k] = expr.attrs[k]
-                self.attr_map["nn.dense"] = new_attrs
-
-        _analysis.post_order_visit(pre, visit_func)
-
-    def callback(self, pre, post, node_map):
-        self.get_attr(pre)
-
-        data = node_map[self.data][0]
-        weight = node_map[self.weight][0]
-        bias = node_map[self.bias][0]
-
-        den = relay.op.nn.dense(data, weight)
-        added = relay.op.add(bias, den)
-        return relay.op.reshape(added, self.attr_map["reshape"]["newshape"])
-
-
 def rewrite_dense_bias_reshape_last(mod):
     """Rewrite the input graph to reorder reshape operators so that
     we can perform dense_bias fusion and then offload them to byoc part.
     """
-    mod["main"] = rewrite(DenseReshapeBiasRewrite(), mod["main"])
+    mod["main"] = rewrite(DenseReshapeBiasGeluRewrite(has_gelu=False), mod["main"])
     return mod
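A hypothetical end-to-end driver for the two module-level helpers (the module construction here is illustrative, not from the commit):

import tvm
from tvm import relay

x = relay.var("x", shape=(3136, 64), dtype="float32")
w = relay.var("w", shape=(64, 64), dtype="float32")
b = relay.var("b", shape=(64,), dtype="float32")
body = relay.add(b, relay.reshape(relay.nn.dense(x, w), (1, 3136, 64)))

mod = tvm.IRModule.from_expr(relay.Function([x, w, b], body))
mod = rewrite_dense_bias_gelu_reshape_last(mod)  # no match here: graph has no erf-GELU tail
mod = rewrite_dense_bias_reshape_last(mod)       # moves the reshape below the bias add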

src/runtime/contrib/dnnl/dnnl_json_runtime.cc

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     if (std::regex_match(op_name, gelu_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
     }
-    if (ops.len() != 0){
+    if (ops.len() != 0) {
       attr.set_post_ops(ops);
     }

tests/python/contrib/test_dnnl.py

Lines changed: 2 additions & 2 deletions
(The two hunks below appear to differ only in whitespace, i.e. the lint fixes; the indentation difference does not survive this rendering.)

@@ -900,7 +900,7 @@ def test_dense(run_module, dtype="float32"):
     config = dense, dic, param_lst
     run_and_verify_func(config, run_module=run_module, dtype=dtype)

-    dense, dic, param_lst = get_dense(x_shape, k_shape, activation="gelu", dtype=dtype)
+    dense, dic, param_lst = get_dense(x_shape, k_shape, activation="gelu", dtype=dtype)
     dense = tvm.IRModule.from_expr(dense)
     config = dense, dic, param_lst
     run_and_verify_func(config, run_module=run_module, dtype=dtype)

@@ -920,7 +920,7 @@ def test_dense_pattern(run_module, dtype="float32"):
     config = dense_bias, dic, param_lst
     run_and_verify_func(config, run_module=run_module, dtype=dtype)

-    dense_bias, dic, param_lst = get_dense_bias(x_shape, k_shape, activation="gelu", dtype=dtype)
+    dense_bias, dic, param_lst = get_dense_bias(x_shape, k_shape, activation="gelu", dtype=dtype)
     dense_bias = tvm.IRModule.from_expr(dense_bias)
     config = dense_bias, dic, param_lst
     run_and_verify_func(config, run_module=run_module, dtype=dtype)
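For orientation, a hypothetical sketch of what a helper like get_dense_bias plausibly builds (the real helper is defined elsewhere in this test file; names and shapes here are assumptions):

from tvm import relay

def get_dense_bias_sketch(x_shape=(1, 16), k_shape=(32, 16), dtype="float32"):
    # Hypothetical reconstruction: dense followed by bias_add, the graph
    # shape that the dense_bias pattern and the rewrites above target.
    x = relay.var("x", shape=x_shape, dtype=dtype)
    w = relay.var("w", shape=k_shape, dtype=dtype)
    b = relay.var("b", shape=(k_shape[0],), dtype=dtype)
    out = relay.nn.bias_add(relay.nn.dense(x, w), b)
    dic = {"x": x_shape, "w": k_shape, "b": (k_shape[0],)}
    param_lst = ["w", "b"]
    return out, dic, param_lst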
