From e0a687c7ecf3fb73aa354121182fdc88f81f9846 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 16 Feb 2023 18:10:57 -0800 Subject: [PATCH 1/9] Support kIsHostFunc. --- include/tvm/tir/function.h | 7 +++++++ src/tir/analysis/verify_memory.cc | 4 ++++ src/tir/transforms/primfunc_utils.cc | 4 ++++ src/tir/transforms/split_host_device.cc | 8 ++++++++ 4 files changed, 23 insertions(+) diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index e135c261990b..a2e8484e5d68 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -325,6 +325,13 @@ constexpr const char* kIsEntryFunc = "tir.is_entry_func"; */ constexpr const char* kIsGlobalFunc = "tir.is_global_func"; +/*! + * \brief Mark the function as run on the host. + * + * Type: Integer + */ +constexpr const char* kIsHostFunc = "tir.is_host_func"; + } // namespace attr } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 9d932d236355..911475a44478 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -177,6 +177,10 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { /// Interface of VerifyMemory pass std::vector VerifyMemory_(const PrimFunc& func) { + // skip the check if the function is host function. + if (func->GetAttr(tvm::tir::attr::kIsHostFunc) == 1) + return {}; + auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "VerifyMemory: Require the target attribute"; diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index d2bb259f9921..a985df626cdd 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -30,6 +30,10 @@ namespace tir { namespace transform { transform::Pass BindTarget(Target target) { auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { + if (f->GetAttr(tvm::tir::attr::kIsHostFunc) == 1) { + ICHECK(target->host.defined()); + return WithAttr(std::move(f), tvm::attr::kTarget, target->host.value()); + } return WithAttr(std::move(f), tvm::attr::kTarget, target); }; return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.BindTarget", {}); diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 2de7d38d7d57..8628e8ec0f1e 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -313,6 +313,7 @@ class HostDeviceSplitter : public StmtMutator { device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1)); + device_func = WithAttr(std::move(device_func), tir::attr::kIsHostFunc, Integer(0)); if (m.use_dyn_shmem_) { device_func = WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1)); @@ -346,6 +347,13 @@ class HostDeviceSplitter : public StmtMutator { }; PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { + // skip the check if the function is host function. + if (func->GetAttr(tvm::tir::attr::kIsHostFunc) == 1) { + // set the host target to None. + func = WithAttr(std::move(func), tvm::attr::kTarget, Target(nullptr)); + return std::move(func); + } + auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "SplitHostDevice: Require the target attribute"; auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); From de3513b33f609d5384300000eaee0200f4d09cb7 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 16 Feb 2023 22:16:45 -0800 Subject: [PATCH 2/9] Add unit test. --- src/tir/analysis/verify_memory.cc | 2 +- src/tir/transforms/split_host_device.cc | 2 +- tests/python/unittest/test_tir_host_func.py | 58 +++++++++++++++++++++ 3 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 tests/python/unittest/test_tir_host_func.py diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 911475a44478..def6558ac7ff 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -180,7 +180,7 @@ std::vector VerifyMemory_(const PrimFunc& func) { // skip the check if the function is host function. if (func->GetAttr(tvm::tir::attr::kIsHostFunc) == 1) return {}; - + auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "VerifyMemory: Require the target attribute"; diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 8628e8ec0f1e..955a84c217f9 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -353,7 +353,7 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { func = WithAttr(std::move(func), tvm::attr::kTarget, Target(nullptr)); return std::move(func); } - + auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "SplitHostDevice: Require the target attribute"; auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); diff --git a/tests/python/unittest/test_tir_host_func.py b/tests/python/unittest/test_tir_host_func.py new file mode 100644 index 000000000000..052e162d79bd --- /dev/null +++ b/tests/python/unittest/test_tir_host_func.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import tvm +from tvm import te +from tvm.script import ir as I +from tvm.script import tir as T + + +def test_host_func(): + """Test that host functions are not split.""" + # te schedule copied from test_tir_transform_split_host_device.py + m = te.size_var("m") + l = te.size_var("l") + A = te.placeholder((m, l), name="A") + + A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1") + A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2") + + s = te.create_schedule(A2.op) + xo, xi = s[A2].split(A2.op.axis[0], factor=8) + s[A2].bind(xo, te.thread_axis("blockIdx.x")) + s[A1].compute_at(s[A2], xo) + s[A1].set_scope("shared") + + mod = tvm.lower(s, [A, A2], name="f") + + assert len(mod.get_global_vars()) == 1, """Before split, expected 1 global function.""" + mod = tvm.tir.transform.Apply( + lambda f: f.with_attr( + { + "global_symbol": "test", + "target": tvm.target.Target("cuda"), + "tir.is_host_func": True, + } + ) + )(mod) + mod = tvm.tir.transform.SplitHostDevice()(mod) + assert len(mod.get_global_vars()) == 1, """Expected host function not to be splited.""" + + +if __name__ == "__main__": + test_host_func() From 8df9ff6e0bd34eeafe3aec54a9a2e997ea7b8893 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 16 Feb 2023 23:45:38 -0800 Subject: [PATCH 3/9] Revert verify memory pass. --- src/tir/analysis/verify_memory.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index def6558ac7ff..9d932d236355 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -177,10 +177,6 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { /// Interface of VerifyMemory pass std::vector VerifyMemory_(const PrimFunc& func) { - // skip the check if the function is host function. - if (func->GetAttr(tvm::tir::attr::kIsHostFunc) == 1) - return {}; - auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "VerifyMemory: Require the target attribute"; From 74155905f0755dee1f159217aab807eab199aacf Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 17 Feb 2023 15:40:03 -0800 Subject: [PATCH 4/9] Address comments. --- include/tvm/tir/function.h | 2 +- src/tir/transforms/primfunc_utils.cc | 6 +- src/tir/transforms/split_host_device.cc | 7 --- tests/python/unittest/test_tir_host_func.py | 63 ++++++++++++++------- 4 files changed, 47 insertions(+), 31 deletions(-) diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index a2e8484e5d68..48328263fb55 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -326,7 +326,7 @@ constexpr const char* kIsEntryFunc = "tir.is_entry_func"; constexpr const char* kIsGlobalFunc = "tir.is_global_func"; /*! - * \brief Mark the function as run on the host. + * \brief Mark the function as run on the host, mutually exclusive with kTarget. * * Type: Integer */ diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index a985df626cdd..cabd9bd62eb0 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -31,8 +31,10 @@ namespace transform { transform::Pass BindTarget(Target target) { auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { if (f->GetAttr(tvm::tir::attr::kIsHostFunc) == 1) { - ICHECK(target->host.defined()); - return WithAttr(std::move(f), tvm::attr::kTarget, target->host.value()); + return WithAttrs(std::move(f), Map{ + {tvm::attr::kTarget, target->host.value_or(Target("llvm"))}, + {tvm::tir::attr::kIsHostFunc, Integer(0)}, + }); } return WithAttr(std::move(f), tvm::attr::kTarget, target); }; diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 955a84c217f9..8f2d3914a435 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -347,13 +347,6 @@ class HostDeviceSplitter : public StmtMutator { }; PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { - // skip the check if the function is host function. - if (func->GetAttr(tvm::tir::attr::kIsHostFunc) == 1) { - // set the host target to None. - func = WithAttr(std::move(func), tvm::attr::kTarget, Target(nullptr)); - return std::move(func); - } - auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "SplitHostDevice: Require the target attribute"; auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); diff --git a/tests/python/unittest/test_tir_host_func.py b/tests/python/unittest/test_tir_host_func.py index 052e162d79bd..3ee7050b03b9 100644 --- a/tests/python/unittest/test_tir_host_func.py +++ b/tests/python/unittest/test_tir_host_func.py @@ -14,44 +14,65 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - - import tvm -from tvm import te from tvm.script import ir as I from tvm.script import tir as T +from tvm.meta_schedule.testing import te_workload +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,missing-class-docstring,missing-function-docstring +# fmt: off -def test_host_func(): - """Test that host functions are not split.""" - # te schedule copied from test_tir_transform_split_host_device.py - m = te.size_var("m") - l = te.size_var("l") - A = te.placeholder((m, l), name="A") +@I.ir_module +class Module: + @T.prim_func + def main( + A: T.Buffer((729, 729), "float32"), + B: T.Buffer((729, 729), "float32"), + C: T.Buffer((729, 729), "float32"), + ): + T.func_attr( + { + "global_symbol": "test", + "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), + "tir.is_host_func": True, + "tir.noalias": True, + } + ) + # with T.block("root"): + for i, j, k in T.grid(729, 729, 729): + with T.block("C"): + v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) + T.reads(A[v_i, v_k], B[v_k, v_j]) + T.writes(C[v_i, v_j]) + with T.init(): + C[v_i, v_j] = T.float32(0) + C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j] - A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1") - A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2") +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,missing-class-docstring,missing-function-docstring - s = te.create_schedule(A2.op) - xo, xi = s[A2].split(A2.op.axis[0], factor=8) - s[A2].bind(xo, te.thread_axis("blockIdx.x")) - s[A1].compute_at(s[A2], xo) - s[A1].set_scope("shared") - mod = tvm.lower(s, [A, A2], name="f") +def test_host_func(): + """Test that host functions are not split.""" + # te schedule copied from test_tir_transform_split_host_device.py - assert len(mod.get_global_vars()) == 1, """Before split, expected 1 global function.""" + func = tvm.te.create_prim_func( + te_workload.matmul(729, 729, 729, in_dtype="float32", out_dtype="float32") + ) + mod = tvm.ir.IRModule({"main": func}) + mod.show() + target = tvm.target.Target("cuda") mod = tvm.tir.transform.Apply( lambda f: f.with_attr( { "global_symbol": "test", - "target": tvm.target.Target("cuda"), + "target": target, "tir.is_host_func": True, } ) )(mod) - mod = tvm.tir.transform.SplitHostDevice()(mod) - assert len(mod.get_global_vars()) == 1, """Expected host function not to be splited.""" + mod = tvm.tir.transform.BindTarget(target)(mod) + tvm.ir.assert_structural_equal(mod, Module) if __name__ == "__main__": From ca9e97b7d5ba0bee7364d3ae66e8a014d0b3f6e3 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 17 Feb 2023 15:59:38 -0800 Subject: [PATCH 5/9] Make sure is cleared. --- tests/python/unittest/test_tir_host_func.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_tir_host_func.py b/tests/python/unittest/test_tir_host_func.py index 3ee7050b03b9..39370e0ee1d3 100644 --- a/tests/python/unittest/test_tir_host_func.py +++ b/tests/python/unittest/test_tir_host_func.py @@ -34,7 +34,7 @@ def main( { "global_symbol": "test", "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), - "tir.is_host_func": True, + "tir.is_host_func": 0, "tir.noalias": True, } ) @@ -60,19 +60,21 @@ def test_host_func(): te_workload.matmul(729, 729, 729, in_dtype="float32", out_dtype="float32") ) mod = tvm.ir.IRModule({"main": func}) - mod.show() target = tvm.target.Target("cuda") mod = tvm.tir.transform.Apply( lambda f: f.with_attr( { "global_symbol": "test", "target": target, - "tir.is_host_func": True, + "tir.is_host_func": 1, } ) )(mod) mod = tvm.tir.transform.BindTarget(target)(mod) tvm.ir.assert_structural_equal(mod, Module) + assert ( + mod["main"].attrs["tir.is_host_func"] == 0 + ), """Target and is_host_func attributes should be mutually exclusive""" if __name__ == "__main__": From 235f2140528c7928601cfd70a58d4025fdf0d5a9 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 17 Feb 2023 17:08:04 -0800 Subject: [PATCH 6/9] Fix linting. --- src/tir/transforms/primfunc_utils.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index cabd9bd62eb0..a824d3589f2c 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -31,10 +31,11 @@ namespace transform { transform::Pass BindTarget(Target target) { auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { if (f->GetAttr(tvm::tir::attr::kIsHostFunc) == 1) { - return WithAttrs(std::move(f), Map{ - {tvm::attr::kTarget, target->host.value_or(Target("llvm"))}, - {tvm::tir::attr::kIsHostFunc, Integer(0)}, - }); + return WithAttrs(std::move(f), + Map{ + {tvm::attr::kTarget, target->host.value_or(Target("llvm"))}, + {tvm::tir::attr::kIsHostFunc, Integer(0)}, + }); } return WithAttr(std::move(f), tvm::attr::kTarget, target); }; From 54dab404961fd077e0aa4e1f37b5da0984a33928 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 17 Feb 2023 19:01:03 -0800 Subject: [PATCH 7/9] Remove target attribute. --- tests/python/unittest/test_tir_host_func.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/unittest/test_tir_host_func.py b/tests/python/unittest/test_tir_host_func.py index 39370e0ee1d3..e7c99ee5fc4c 100644 --- a/tests/python/unittest/test_tir_host_func.py +++ b/tests/python/unittest/test_tir_host_func.py @@ -65,7 +65,6 @@ def test_host_func(): lambda f: f.with_attr( { "global_symbol": "test", - "target": target, "tir.is_host_func": 1, } ) From d6dea2d720403769e2b8b74b38f047b9231079ba Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 17 Feb 2023 19:12:49 -0800 Subject: [PATCH 8/9] Make attributesmutually exclusive. --- src/tir/transforms/primfunc_utils.cc | 7 ++----- tests/python/unittest/test_tir_host_func.py | 3 +-- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index a824d3589f2c..208077b492da 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -31,11 +31,8 @@ namespace transform { transform::Pass BindTarget(Target target) { auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { if (f->GetAttr(tvm::tir::attr::kIsHostFunc) == 1) { - return WithAttrs(std::move(f), - Map{ - {tvm::attr::kTarget, target->host.value_or(Target("llvm"))}, - {tvm::tir::attr::kIsHostFunc, Integer(0)}, - }); + return WithAttr(std::move(WithoutAttr(std::move(f), tvm::tir::attr::kIsHostFunc)), + tvm::attr::kTarget, target->host.value_or(Target("llvm"))); } return WithAttr(std::move(f), tvm::attr::kTarget, target); }; diff --git a/tests/python/unittest/test_tir_host_func.py b/tests/python/unittest/test_tir_host_func.py index e7c99ee5fc4c..ea0ad7ba4a8a 100644 --- a/tests/python/unittest/test_tir_host_func.py +++ b/tests/python/unittest/test_tir_host_func.py @@ -34,7 +34,6 @@ def main( { "global_symbol": "test", "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), - "tir.is_host_func": 0, "tir.noalias": True, } ) @@ -72,7 +71,7 @@ def test_host_func(): mod = tvm.tir.transform.BindTarget(target)(mod) tvm.ir.assert_structural_equal(mod, Module) assert ( - mod["main"].attrs["tir.is_host_func"] == 0 + "tir.is_host_func" not in mod["main"].attrs ), """Target and is_host_func attributes should be mutually exclusive""" From 54cb5a32332dc6074ee23421d325103d67abac43 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 17 Feb 2023 19:15:56 -0800 Subject: [PATCH 9/9] Remove unnecessary attribute. --- src/tir/transforms/split_host_device.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 8f2d3914a435..2de7d38d7d57 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -313,7 +313,6 @@ class HostDeviceSplitter : public StmtMutator { device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1)); - device_func = WithAttr(std::move(device_func), tir::attr::kIsHostFunc, Integer(0)); if (m.use_dyn_shmem_) { device_func = WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1));