diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index 259ef2d1715a..681c6cd7842d 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -100,6 +100,9 @@ class Schedule : public NodeRef { * \return reference to self. */ Schedule& reorder(const Array& order); // NOLINT(*) + Schedule& tile(IterVar x_parent, IterVar y_parent, IterVar* p_x_outer, + IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner, + Expr x_factor, Expr y_factor); // NOLINT(*) }; /*! diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index dee3f3309481..02f73660f0e3 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -107,3 +107,8 @@ def reorder(self, *args): The order to be ordered """ _function_internal._ScheduleReorder(self, args) + + def tile(self, x_parent, y_parent, x_factor, y_factor): + x_outer, y_outer, x_inner, y_inner = _function_internal._ScheduleTile( + self, x_parent, y_parent, x_factor, y_factor) + return x_outer, y_outer, x_inner, y_inner diff --git a/src/c_api/c_api_lang.cc b/src/c_api/c_api_lang.cc index 6ee137a7bcf4..49cc9c642e83 100644 --- a/src/c_api/c_api_lang.cc +++ b/src/c_api/c_api_lang.cc @@ -151,5 +151,13 @@ TVM_REGISTER_API(_ScheduleReorder) .reorder(args.at(1)); }); +TVM_REGISTER_API(_ScheduleTile) + .set_body([](const ArgStack& args, RetValue *ret) { + IterVar x_outer, y_outer, x_inner, y_inner; + args.at(0).operator Schedule() + .tile(args.at(1), args.at(2), &x_outer, &y_outer, + &x_inner, &y_inner, args.at(3), args.at(4)); + *ret = Array({x_outer, y_outer, x_inner, y_inner}); + }); } // namespace tvm diff --git a/src/lang/schedule.cc b/src/lang/schedule.cc index 47f5ee744285..1628e5ef3bc5 100644 --- a/src/lang/schedule.cc +++ b/src/lang/schedule.cc @@ -148,6 +148,16 @@ Schedule& Schedule::reorder(const Array& order) { // NOLINT(*) return *this; } +Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent, IterVar* p_x_outer, + IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner, + Expr x_factor, Expr y_factor) { // NOLINT(*) + + split(x_parent, p_x_outer, p_x_inner, x_factor); + split(y_parent, p_y_outer, p_y_inner, y_factor); + reorder(Array({*p_x_inner, *p_y_inner, *p_x_outer, *p_y_outer})); + return *this; +} + IterVarRelation SplitNode::make( IterVar parent, IterVar outer, IterVar inner, Expr factor) { diff --git a/tests/python/test_schedule.py b/tests/python/test_schedule.py index 773be8b55c65..850781a72a0b 100644 --- a/tests/python/test_schedule.py +++ b/tests/python/test_schedule.py @@ -34,8 +34,18 @@ def test_reorder(): sch_T.reorder(*order) assert tuple(sch_T.leaf_iter_vars) == order +def test_tile(): + m = tvm.Var('m') + n = tvm.Var('n') + A = tvm.placeholder((m, n), name='A') + T = tvm.compute((m, n), lambda i, j: A[i, j]) + + sch_T = tvm.Schedule(T.op, scope="shared") + xo, yo, xi, yi = sch_T.tile(T.op.dim_var[0], T.op.dim_var[1], x_factor=10, y_factor=5) + assert tuple(sch_T.leaf_iter_vars) == (xi, yi, xo, yo) if __name__ == "__main__": test_schedule_create() test_reorder() + test_tile()