@@ -2440,6 +2440,144 @@ def return_inside_dataflow(A: R.Tensor([16], "float16")):
24402440
24412441 tvm .ir .assert_structural_equal (output_then_return , return_inside_dataflow )
24422442
2443+ def test_symbolic_shape_variables_are_size_var ():
2444+ """Symbolic variables inferred from shapes are SizeVar
2445+ The indices in `R.strided_slice` follow Python's conventions for
2446+ negative indices. Absent any additional information, a slice
2447+ `arr[0:i]` would either have length `i` when `i >= 0`, or length
2448+ `len(arr) + i` when `i < 0`.
2449+ In this case, though, the dynamic `extent` variable is known to be
2450+ non-negative, because negative values may not be used as the
2451+ dimensions of `R.Tensor` or `R.Shape`. Because Relax struct
2452+ inference is performed while TVMScript is being parsed, this
2453+ constraint must be exposed during TVMScript parsing in order to
2454+ correctly infer the resulting StructInfo.
2455+ """
2456+
2457+ @R .function (private = True )
2458+ def inferred_sinfo (A : R .Tensor (["extent" ])):
2459+ extent = T .int64 ()
2460+ output = R .strided_slice (A , [0 ], [0 ], [extent ])
2461+ return output
2462+
2463+ @R .function (private = True )
2464+ def expected (A : R .Tensor (["extent" ])) -> R .Tensor (["extent" ]):
2465+ extent = T .int64 ()
2466+ output : R .Tensor ([extent ]) = R .strided_slice (A , [0 ], [0 ], [extent ])
2467+ return output
2468+
2469+ tvm .ir .assert_structural_equal (inferred_sinfo , expected )
2470+
2471+ assert isinstance (inferred_sinfo .params [0 ].struct_info .shape [0 ], tir .SizeVar )
2472+
2473+
2474+ def test_symbolic_variables_from_prim_value_may_be_negative ():
2475+ """Symbolic variables inferred from R.Prim are Var
2476+ Not all symbolic variables represent shapes. While a
2477+ `relax::PrimValue` can be the source of definition for a TIR
2478+ variable, a `relax::PrimValue` may not represent a shape, and may
2479+ be negative.
2480+ This test is similar to
2481+ `test_symbolic_shape_variables_are_size_var`, except that the
2482+ `extent` variable is defined by a `R.Prim` argument, and not by a
2483+ `R.Tensor` argument. As a result, we do not know whether `extent`
2484+ is negative, and cannot simplify expressions that depend on
2485+ `extent<0`.
2486+ """
2487+
2488+ @R .function (private = True )
2489+ def inferred_sinfo (A : R .Tensor ([16 ]), _ : R .Prim (value = "extent" )):
2490+ extent = T .int64 ()
2491+ output = R .strided_slice (A , [0 ], [0 ], [extent ])
2492+ return output
2493+
2494+ @R .function (private = True )
2495+ def expected (A : R .Tensor ([16 ]), _ : R .Prim (value = "extent" )):
2496+ extent = T .int64 ()
2497+ output : R .Tensor (
2498+ [T .min (T .max (T .if_then_else (extent < 0 , extent + 16 , extent ), 0 ), 16 )]
2499+ ) = R .strided_slice (A , [0 ], [0 ], [extent ])
2500+ return output
2501+
2502+ tvm .ir .assert_structural_equal (inferred_sinfo , expected )
2503+
2504+ assert not isinstance (inferred_sinfo .params [1 ].struct_info .value , tir .SizeVar )
2505+
2506+
2507+ def test_other_arguments_may_cause_prim_value_to_define_size_var ():
2508+ """Other arguments may cause R.Prim to hold SizeVar
2509+ This test is similar to
2510+ `test_symbolic_variables_from_prim_value_may_be_negative`, except
2511+ that `extent` also appears in a `R.Shape`. While the
2512+ `R.Prim(value="extent")` occurs first in the parameter list, and
2513+ is the source of definition, the presence of `extent` in `R.Shape`
2514+ parameter shows that it is a `SizeVar`.
2515+ """
2516+
2517+ @R .function (private = True )
2518+ def inferred_sinfo (
2519+ A : R .Tensor ([16 ]),
2520+ _prim : R .Prim (value = "extent" ),
2521+ _shape : R .Shape (
2522+ ["extent" ],
2523+ ),
2524+ ):
2525+ extent = T .int64 ()
2526+ output = R .strided_slice (A , [0 ], [0 ], [extent ])
2527+ return output
2528+
2529+ @R .function (private = True )
2530+ def expected (
2531+ A : R .Tensor ([16 ]),
2532+ _prim : R .Prim (value = "extent" ),
2533+ _shape : R .Shape (["extent" ]),
2534+ ):
2535+ extent = T .int64 ()
2536+ output : R .Tensor ([T .min (extent , 16 )]) = R .strided_slice (A , [0 ], [0 ], [extent ])
2537+ return output
2538+
2539+ tvm .ir .assert_structural_equal (inferred_sinfo , expected )
2540+
2541+ assert isinstance (inferred_sinfo .params [1 ].struct_info .value , tir .SizeVar )
2542+
2543+
2544+ @pytest .mark .xfail (reason = "Bug: Implicit bounds not provided when parsing" )
2545+ def test_known_positive_expressions ():
2546+ """Expressions may be known as non-negative
2547+ The variable `N` is not defined as a shape variable, and may be
2548+ either positive or negative. However, the expression `N+16` is
2549+ used as the shape of a tensor, and is therefore known not to be
2550+ negative. Later use of the expression `N+16 < 0` may therefore be
2551+ simplified.
2552+ This test is currently marked as failing. When using
2553+ `relax::BlockBuilder::VisitWithNewScope` is provided with
2554+ parameters, it can mark shape expressions as non-negative, in
2555+ addition to individual variables. However, this is not currently
2556+ used for TVMScript parsing.
2557+ """
2558+
2559+ @R .function (private = True )
2560+ def inferred_sinfo (
2561+ A : R .Tensor (["N + 16" ]),
2562+ _ : R .Prim (value = "N" ),
2563+ ):
2564+ N = T .int64 ()
2565+ output = R .strided_slice (A , [0 ], [0 ], [N + 16 ])
2566+ return output
2567+
2568+ @R .function (private = True )
2569+ def expected (
2570+ A : R .Tensor (["N + 16" ]),
2571+ _ : R .Prim (value = "N" ),
2572+ ):
2573+ N = T .int64 ()
2574+ output : R .Tensor ([N + 16 ]) = R .strided_slice (A , [0 ], [0 ], [N + 16 ])
2575+ return output
2576+
2577+ tvm .ir .assert_structural_equal (inferred_sinfo , expected )
2578+
2579+ assert not isinstance (inferred_sinfo .params [1 ].struct_info .value , tir .SizeVar )
2580+
24432581
24442582if __name__ == "__main__" :
24452583 tvm .testing .main ()
0 commit comments