Skip to content

Commit 8613c79

Browse files
authored
[TIR] Enable Host Func Attribute for PrimFunc (#14020)
1 parent ac57b01 commit 8613c79

File tree

3 files changed

+90
-0
lines changed

3 files changed

+90
-0
lines changed

include/tvm/tir/function.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,13 @@ constexpr const char* kIsEntryFunc = "tir.is_entry_func";
325325
*/
326326
constexpr const char* kIsGlobalFunc = "tir.is_global_func";
327327

328+
/*!
329+
* \brief Mark the function as run on the host, mutually exclusive with kTarget.
330+
*
331+
* Type: Integer
332+
*/
333+
constexpr const char* kIsHostFunc = "tir.is_host_func";
334+
328335
} // namespace attr
329336
} // namespace tir
330337
} // namespace tvm

src/tir/transforms/primfunc_utils.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ namespace tir {
3030
namespace transform {
3131
transform::Pass BindTarget(Target target) {
3232
auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
33+
if (f->GetAttr<Integer>(tvm::tir::attr::kIsHostFunc) == 1) {
34+
return WithAttr(std::move(WithoutAttr(std::move(f), tvm::tir::attr::kIsHostFunc)),
35+
tvm::attr::kTarget, target->host.value_or(Target("llvm")));
36+
}
3337
return WithAttr(std::move(f), tvm::attr::kTarget, target);
3438
};
3539
return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.BindTarget", {});
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import tvm
18+
from tvm.script import ir as I
19+
from tvm.script import tir as T
20+
from tvm.meta_schedule.testing import te_workload
21+
22+
# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,missing-class-docstring,missing-function-docstring
23+
# fmt: off
24+
25+
@I.ir_module
26+
class Module:
27+
@T.prim_func
28+
def main(
29+
A: T.Buffer((729, 729), "float32"),
30+
B: T.Buffer((729, 729), "float32"),
31+
C: T.Buffer((729, 729), "float32"),
32+
):
33+
T.func_attr(
34+
{
35+
"global_symbol": "test",
36+
"target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}),
37+
"tir.noalias": True,
38+
}
39+
)
40+
# with T.block("root"):
41+
for i, j, k in T.grid(729, 729, 729):
42+
with T.block("C"):
43+
v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
44+
T.reads(A[v_i, v_k], B[v_k, v_j])
45+
T.writes(C[v_i, v_j])
46+
with T.init():
47+
C[v_i, v_j] = T.float32(0)
48+
C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
49+
50+
# fmt: on
51+
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,missing-class-docstring,missing-function-docstring
52+
53+
54+
def test_host_func():
55+
"""Test that host functions are not split."""
56+
# te schedule copied from test_tir_transform_split_host_device.py
57+
58+
func = tvm.te.create_prim_func(
59+
te_workload.matmul(729, 729, 729, in_dtype="float32", out_dtype="float32")
60+
)
61+
mod = tvm.ir.IRModule({"main": func})
62+
target = tvm.target.Target("cuda")
63+
mod = tvm.tir.transform.Apply(
64+
lambda f: f.with_attr(
65+
{
66+
"global_symbol": "test",
67+
"tir.is_host_func": 1,
68+
}
69+
)
70+
)(mod)
71+
mod = tvm.tir.transform.BindTarget(target)(mod)
72+
tvm.ir.assert_structural_equal(mod, Module)
73+
assert (
74+
"tir.is_host_func" not in mod["main"].attrs
75+
), """Target and is_host_func attributes should be mutually exclusive"""
76+
77+
78+
if __name__ == "__main__":
79+
test_host_func()

0 commit comments

Comments
 (0)