@@ -910,6 +910,45 @@ def test_load_prelude():
910910 tvm .parser .parse (mod .astext ())
911911
912912
913+ def test_call_attrs ():
914+ def get_func (shape , dtype ):
915+ x0 = relay .var ("data" , shape = shape , dtype = dtype )
916+ w0 = relay .var ("weight" , shape = shape , dtype = dtype )
917+ a = relay .nn .dense (x0 , w0 )
918+ b = relay .nn .relu (a )
919+ d = relay .add (b , relay .const (1.0 , dtype = dtype ))
920+ return relay .Function ([x0 , w0 ], d )
921+
922+ # build relay graph
923+ shape = (2 , 4 )
924+ dtype = "float32"
925+ sub_func = get_func (shape , dtype )
926+ p0 = relay .var ("p0" , shape = shape , dtype = dtype )
927+ p1 = relay .var ("p1" , shape = shape , dtype = dtype )
928+ attr = tvm .ir .make_node ("attrs.TestAttrs" , name = "func_call_attrs" )
929+ call = relay .Call (sub_func , [p0 , p1 ], attrs = attr )
930+ func = relay .Function ([p0 , p1 ], call )
931+
932+ # build relay module
933+ mod = tvm .IRModule ()
934+ mod ["main" ] = func
935+ mod = tvm .relay .transform .InferType ()(mod )
936+
937+ # assert equal
938+ program = """
939+ def @main(%p0: Tensor[(2, 4), float32], %p1: Tensor[(2, 4), float32]) {
940+ %2 = fn (%data: Tensor[(2, 4), float32], %weight: Tensor[(2, 4), float32]) {
941+ %0 = nn.dense(%data, %weight, units=None);
942+ %1 = nn.relu(%0);
943+ add(%1, 1f)
944+ };
945+ %2(%p0, %p1, name="func_call_attrs", attrs_type_key="attrs.TestAttrs")
946+ }
947+ """
948+ parsed = parse_module (program )
949+ assert_graph_equal (parsed , mod )
950+
951+
913952if __name__ == "__main__" :
914953 import sys
915954
0 commit comments