Skip to content

Commit bcba3dd

Browse files
committed
add test case
1 parent c484dd8 commit bcba3dd

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

tests/python/relay/test_ir_parser.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,45 @@ def test_load_prelude():
910910
tvm.parser.parse(mod.astext())
911911

912912

913+
def test_call_attrs():
914+
def get_func(shape, dtype):
915+
x0 = relay.var("data", shape=shape, dtype=dtype)
916+
w0 = relay.var("weight", shape=shape, dtype=dtype)
917+
a = relay.nn.dense(x0, w0)
918+
b = relay.nn.relu(a)
919+
d = relay.add(b, relay.const(1.0, dtype=dtype))
920+
return relay.Function([x0, w0], d)
921+
922+
# build relay graph
923+
shape = (2, 4)
924+
dtype = "float32"
925+
sub_func = get_func(shape, dtype)
926+
p0 = relay.var("p0", shape=shape, dtype=dtype)
927+
p1 = relay.var("p1", shape=shape, dtype=dtype)
928+
attr = tvm.ir.make_node("attrs.TestAttrs", name="func_call_attrs")
929+
call = relay.Call(sub_func, [p0, p1], attrs=attr)
930+
func = relay.Function([p0, p1], call)
931+
932+
# build relay module
933+
mod = tvm.IRModule()
934+
mod["main"] = func
935+
mod = tvm.relay.transform.InferType()(mod)
936+
937+
# assert equal
938+
program = """
939+
def @main(%p0: Tensor[(2, 4), float32], %p1: Tensor[(2, 4), float32]) {
940+
%2 = fn (%data: Tensor[(2, 4), float32], %weight: Tensor[(2, 4), float32]) {
941+
%0 = nn.dense(%data, %weight, units=None);
942+
%1 = nn.relu(%0);
943+
add(%1, 1f)
944+
};
945+
%2(%p0, %p1, name="func_call_attrs", attrs_type_key="attrs.TestAttrs")
946+
}
947+
"""
948+
parsed = parse_module(program)
949+
assert_graph_equal(parsed, mod)
950+
951+
913952
if __name__ == "__main__":
914953
import sys
915954

0 commit comments

Comments
 (0)