I wanted to unit test TEPass with something like:
```python
import tvm


def test_lower_primitive():
    input_mod = tvm.parser.parse(
        """
        #[version = "0.0.5"]
        def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
          %0 = fn(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], Primitive=1) -> Tensor[(5, 7), float32] {
            add(%x, %y)
          };
          %0(%a, %a)
        }
        """,
        "from_string", None, None,
    )
    # `transform` applies the pass under test (TEPass) to the input module.
    actual_mod = transform(input_mod)
    expected_mod = tvm.parser.parse(
        """
        #[version = "0.0.5"]
        def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
          %0 = (%a, %a);
          call_lowered(@test_fused_add, %0, metadata={relay_attrs={Primitive=1},all_prim_fn_vars=[@test_fused_add]})
        }
        def @test_fused_add(...) { ... }
        """,
        "from_string", None, None,
    )
    tvm.ir.assert_structural_equal(actual_mod, expected_mod, map_free_vars=True)
```
First, the call_lowered metadata attributes can't be expressed in the form written above, so they have to be bound to a meta table entry instead:
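```python
# Look up the GlobalVar that the pass created in the actual module.
test_fused_add = actual_mod.get_global_var('test_fused_add')
call_lowered_attrs = {
    "relay_attrs": tvm.ir.make_node("DictAttrs", Primitive=tvm.tir.IntImm("int32", 1)),
    "all_prim_fn_vars": [test_fused_add],
}
# Meta table holding the call_lowered attributes for the expected module.
metadata = {
    "attrs": [call_lowered_attrs],
}
```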
That's OK, but the global var bound to 'test_fused_add' is still wrong: it needs to be the same object as the one created to represent the `@test_fused_add` definition in the expected module.
I think we should have a structural_equal mode that compares global vars on name_hint alone. We almost have that already, but somewhere in the logic the `map_free_vars` option got reset to False, which triggered the failure.
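To make the proposal concrete, here is a toy model in plain Python (no TVM required) contrasting comparison of global vars by object identity with comparison by `name_hint`. `GlobalVar`, `Call`, and `toy_structural_equal` are hypothetical stand-ins invented for illustration; they are not TVM's API and do not reflect how `tvm.ir.structural_equal` or its `map_free_vars` handling is actually implemented.

```python
from dataclasses import dataclass
from typing import Tuple, Union


# Hypothetical stand-in for a global var: eq=False keeps Python's
# identity-based equality, so two separately created objects differ
# even when they share a name_hint.
@dataclass(eq=False)
class GlobalVar:
    name_hint: str


# Hypothetical call node: a callee plus a tuple of arguments.
@dataclass(eq=False)
class Call:
    op: Union[GlobalVar, str]
    args: Tuple = ()


def toy_structural_equal(lhs, rhs, compare_global_vars_by_name=False):
    """Recursively compare two toy IR trees.

    Global vars match by object identity unless compare_global_vars_by_name
    is True, in which case equal name_hints suffice (the proposed mode).
    """
    if isinstance(lhs, GlobalVar) and isinstance(rhs, GlobalVar):
        if compare_global_vars_by_name:
            return lhs.name_hint == rhs.name_hint
        return lhs is rhs
    if isinstance(lhs, Call) and isinstance(rhs, Call):
        return (
            len(lhs.args) == len(rhs.args)
            and toy_structural_equal(lhs.op, rhs.op, compare_global_vars_by_name)
            and all(
                toy_structural_equal(a, b, compare_global_vars_by_name)
                for a, b in zip(lhs.args, rhs.args)
            )
        )
    # Leaves (operator names, variable names as strings) compare by value.
    return type(lhs) is type(rhs) and lhs == rhs


if __name__ == "__main__":
    # Each module gets its own object for @test_fused_add, as when the
    # actual and expected modules are produced independently.
    actual_gv = GlobalVar("test_fused_add")
    expected_gv = GlobalVar("test_fused_add")
    actual = Call("call_lowered", (actual_gv, "%0"))
    expected = Call("call_lowered", (expected_gv, "%0"))
    print(toy_structural_equal(actual, expected))  # False
    print(toy_structural_equal(actual, expected, compare_global_vars_by_name=True))  # True
```

The identity check fails only because the two `test_fused_add` objects are distinct, which is the situation described above; matching on `name_hint` sidesteps the need to thread the same object through both modules.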