Skip to content

Commit 4122a6a

Browse files
Hzfengsytqchenvinx13MasterJH5574
authored
[TensorIR] CreatePrimFunc from TE (#7987)
Co-authored-by: Tianqi Chen <[email protected]> Co-authored-by: Wuwei Lin <[email protected]> Co-authored-by: Ruihang Lai <[email protected]>
1 parent 254563a commit 4122a6a

File tree

5 files changed

+650
-0
lines changed

5 files changed

+650
-0
lines changed

include/tvm/tir/var.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ class IterVar : public ObjectRef {
298298
inline operator PrimExpr() const;
299299

300300
TVM_DEFINE_OBJECT_REF_METHODS(IterVar, ObjectRef, IterVarNode);
301+
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterVarNode);
301302
};
302303

303304
// inline implementations

python/tvm/te/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from .tag import tag_scope
3434
from .operation import placeholder, compute, scan, extern, var, size_var
3535
from .operation import thread_axis, reduce_axis
36+
from .operation import create_prim_func
3637

3738
from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp
3839
from .autodiff import gradient

python/tvm/te/operation.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
""" Operation class for computation declaration."""
1818
# pylint: disable=invalid-name
1919
from numbers import Integral as _Integral
20+
from typing import List
2021

2122
import tvm._ffi
2223
import tvm.tir
@@ -426,3 +427,52 @@ def reduce_axis(dom, name="rv", thread_tag="", span=None):
426427
An iteration variable representing the value.
427428
"""
428429
return tvm.tir.IterVar(dom, name, 2, thread_tag, span)
430+
431+
432+
def create_prim_func(ops: List[_tensor.Tensor]) -> tvm.tir.PrimFunc:
433+
"""Create a TensorIR PrimFunc from tensor expression
434+
Parameters
435+
----------
436+
ops : List[Tensor]
437+
The source expression.
438+
439+
Example
440+
-------
441+
We define a matmul kernel using following code:
442+
443+
.. code-block:: python
444+
445+
import tvm
446+
from tvm import te
447+
448+
A = te.placeholder((128, 128), name="A")
449+
B = te.placeholder((128, 128), name="B")
450+
C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), name="C")
451+
func = create_prim_func([A, B, C])
452+
print(tvm.script.asscript(func))
453+
454+
If we want to use TensorIR schedule to do transformations on such kernel,
455+
we need to use `create_prim_func([A, B, C])` to create a schedulable PrimFunc.
456+
The generated function looks like:
457+
458+
.. code-block:: python
459+
460+
@tvm.script.tir
461+
def tir_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
462+
A = tir.match_buffer(a, (128, 128))
463+
B = tir.match_buffer(b, (128, 128))
464+
C = tir.match_buffer(c, (128, 128))
465+
466+
with tir.block([128, 128, tir.reduce_axis(0, 128)]) as [i, j, k]:
467+
with tir.init():
468+
C[i, j] = 0.0
469+
C[i, j] += A[i, k] * B[j, k]
470+
471+
Returns
472+
-------
473+
func : tir.PrimFunc
474+
The created function.
475+
"""
476+
if not isinstance(ops, list):
477+
ops = [ops]
478+
return _ffi_api.CreatePrimFunc(ops)
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include <tvm/runtime/registry.h>
21+
#include <tvm/tir/function.h>
22+
#include <tvm/tir/stmt_functor.h>
23+
24+
#include <algorithm>
25+
26+
#include "../schedule/graph.h"
27+
28+
namespace tvm {
29+
namespace tir {
30+
31+
/*! \brief The helper mutator that transforms ProducerLoad to BufferLoad */
32+
class ProducerToBufferTransformer : public StmtExprMutator {
33+
public:
34+
explicit ProducerToBufferTransformer(const std::unordered_map<te::Tensor, Buffer>& tensor2buffers)
35+
: tensor2buffers_(tensor2buffers) {}
36+
37+
PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
38+
te::Tensor tensor = Downcast<te::Tensor>(op->producer);
39+
auto it = tensor2buffers_.find(tensor);
40+
ICHECK(it != tensor2buffers_.end()) << "IndexError: Cannot find the tensor " << tensor;
41+
const Buffer& buffer = it->second;
42+
return BufferLoad(buffer, op->indices);
43+
}
44+
45+
private:
46+
/*! \brief The Map from Operations to buffers */
47+
const std::unordered_map<te::Tensor, Buffer>& tensor2buffers_;
48+
};
49+
50+
/*! \brief Helper data structural to store informations. */
51+
struct CreateFuncInfo {
52+
/*! \brief The Tensor arg_list. */
53+
Array<te::Tensor> arg_list;
54+
/*! \brief The map from each Tensor to its corresponding buffer. */
55+
std::unordered_map<te::Tensor, Buffer> tensor2buffers;
56+
/*! \brief The transformer from ProducerLoad to BufferLoad. */
57+
ProducerToBufferTransformer transformer;
58+
/*! \brief The buffers should be allocated at function root. */
59+
Array<Buffer> root_alloc;
60+
/*! \brief The count map to make block name unique. */
61+
std::unordered_map<String, int> name_count;
62+
63+
explicit CreateFuncInfo(Array<te::Tensor> arg_list)
64+
: arg_list(std::move(arg_list)), transformer(tensor2buffers) {}
65+
66+
bool IsArg(const te::Tensor& tensor) const {
67+
return std::any_of(arg_list.begin(), arg_list.end(),
68+
[&tensor](const te::Tensor& arg) { return tensor == arg; });
69+
}
70+
71+
String GetUniqueName(const String& prefix) {
72+
String unique_prefix = prefix;
73+
auto it = name_count.find(prefix);
74+
while (name_count.count(unique_prefix)) {
75+
unique_prefix = prefix + "_" + std::to_string(++it->second);
76+
}
77+
name_count[unique_prefix] = 0;
78+
return unique_prefix;
79+
}
80+
};
81+
82+
BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::Tensor& tensor,
83+
Array<PrimExpr> bindings, PrimExpr expr_body,
84+
CreateFuncInfo* info) {
85+
// Step 1. Push_back data_par axis and reduce_axis into block_vars.
86+
Array<IterVar> iter_vars;
87+
std::unordered_map<const VarNode*, PrimExpr> var_map;
88+
iter_vars.reserve(compute_op->axis.size() + compute_op->reduce_axis.size());
89+
auto f_push_block_vars = [&iter_vars, &var_map](const Array<IterVar>& iters) {
90+
for (IterVar iter_var : iters) {
91+
// Create new var
92+
Var new_var(iter_var->var->name_hint, iter_var->var->dtype);
93+
var_map[iter_var->var.get()] = new_var;
94+
95+
IterVarNode* iter_var_node = iter_var.CopyOnWrite();
96+
iter_var_node->dom = Range::FromMinExtent(iter_var->dom->min, iter_var->dom->extent);
97+
iter_var_node->var = new_var;
98+
iter_vars.push_back(iter_var);
99+
}
100+
};
101+
f_push_block_vars(compute_op->axis);
102+
f_push_block_vars(compute_op->reduce_axis);
103+
104+
// Step 2. Declare buffer and update op2buffers
105+
Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint());
106+
info->tensor2buffers[tensor] = buffer;
107+
108+
// Step 3. Add Buffer to root_alloc
109+
if (!info->IsArg(tensor)) {
110+
info->root_alloc.push_back(buffer);
111+
}
112+
113+
// Step 4. Calculate indices for BufferStore
114+
Array<PrimExpr> indices;
115+
indices.reserve(compute_op->axis.size());
116+
for (const IterVar& iter_var : compute_op->axis) {
117+
auto it = var_map.find(iter_var->var.get());
118+
ICHECK(it != var_map.end());
119+
indices.push_back(it->second);
120+
}
121+
122+
// Step 5. Create block body.
123+
Optional<Stmt> init = NullOpt;
124+
Stmt body;
125+
if (const auto* reduce = expr_body.as<ReduceNode>()) {
126+
// Case 1. Reduce compute
127+
ICHECK_EQ(reduce->source.size(), 1);
128+
const PrimExpr& lhs = BufferLoad(buffer, indices);
129+
const PrimExpr& rhs = Substitute(info->transformer(reduce->source[0]), var_map);
130+
ICHECK(lhs->dtype == rhs->dtype);
131+
body = BufferStore(buffer, reduce->combiner.get()->operator()({lhs}, {rhs})[0], indices);
132+
init = BufferStore(buffer, reduce->combiner->identity_element[0], indices);
133+
} else {
134+
// Case 2. Data parallel compute
135+
body = BufferStore(buffer, Substitute(info->transformer(expr_body), var_map), indices);
136+
}
137+
138+
// Step 6. Add script_parsing_detect_access attr for auto complete the whole IR.
139+
Map<String, ObjectRef> annotations = compute_op->attrs;
140+
annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3));
141+
142+
// Step 7. Create Block and BlockRealize.
143+
return BlockRealize(/*iter_values=*/std::move(bindings),
144+
/*predicate=*/Bool(true),
145+
/*block=*/
146+
Block(/*iter_vars=*/std::move(iter_vars),
147+
/*reads=*/{},
148+
/*writes=*/{},
149+
/*name_hint=*/info->GetUniqueName(tensor->GetNameHint()),
150+
/*body=*/std::move(body),
151+
/*init=*/std::move(init),
152+
/*alloc_buffers=*/{},
153+
/*match_buffers=*/{},
154+
/*annotations=*/std::move(annotations)));
155+
}
156+
157+
Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info) {
158+
// Step 1. Creating loop vars for block bindings.
159+
Array<IterVar> axes = compute_op->axis;
160+
axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end());
161+
Array<PrimExpr> bindings;
162+
for (size_t i = 0; i < axes.size(); ++i) {
163+
bindings.push_back(Var("i" + std::to_string(i)));
164+
}
165+
// Step 2. Generate block bodies.
166+
Array<Stmt> seq_stmt;
167+
for (int i = 0; i < compute_op->num_outputs(); ++i) {
168+
const te::Tensor& tensor = compute_op.output(i);
169+
PrimExpr expr_body = compute_op->body[i];
170+
seq_stmt.push_back(
171+
GenerateBlockFromTensor(compute_op, tensor, bindings, std::move(expr_body), info));
172+
}
173+
Stmt body = SeqStmt::Flatten(seq_stmt);
174+
175+
// Step 3. Generate loop nesting.
176+
for (size_t i = axes.size(); i > 0; --i) {
177+
const IterVar& axis = axes[i - 1];
178+
const Var& loop_var = Downcast<Var>(bindings[i - 1]);
179+
body = For(loop_var, axis->dom->min, axis->dom->extent, ForKind::kSerial, body);
180+
}
181+
182+
return body;
183+
}
184+
185+
Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* info) {
186+
// Step 1. Check all inputs are visited before and update var_map.
187+
std::unordered_map<const VarNode*, PrimExpr> var_map;
188+
ICHECK_EQ(extern_op->inputs.size(), extern_op->input_placeholders.size());
189+
for (size_t i = 0; i < extern_op->inputs.size(); ++i) {
190+
const Buffer& placeholder = extern_op->input_placeholders[i];
191+
const te::Tensor& input_tensor = extern_op->inputs[i];
192+
auto it = info->tensor2buffers.find(input_tensor);
193+
ICHECK(it != info->tensor2buffers.end());
194+
var_map[placeholder->data.get()] = it->second->data;
195+
}
196+
197+
// Step 2. Update info with its output tensor and placeholder buffer.
198+
ICHECK_EQ(extern_op->num_outputs(), extern_op->output_placeholders.size());
199+
for (int i = 0; i < extern_op->num_outputs(); ++i) {
200+
const Buffer& placeholder = extern_op->output_placeholders[i];
201+
const te::Tensor& output_tensor = extern_op.output(i);
202+
info->tensor2buffers[output_tensor] = placeholder;
203+
if (!info->IsArg(output_tensor)) {
204+
info->root_alloc.push_back(placeholder);
205+
}
206+
}
207+
208+
// Step 3. Collect Access Region
209+
Array<BufferRegion> reads, writes;
210+
for (const te::Tensor& tensor : extern_op->inputs) {
211+
// We have ICHECK before so it is not needed here.
212+
reads.push_back(BufferRegion::FullRegion(info->tensor2buffers[tensor]));
213+
}
214+
for (const Buffer& buffer : extern_op->output_placeholders) {
215+
writes.push_back(BufferRegion::FullRegion(buffer));
216+
}
217+
218+
Stmt body = Substitute(extern_op->body, var_map);
219+
220+
// Step 4. Generate opaque block as body.
221+
return BlockRealize(/*iter_values=*/{},
222+
/*predicate=*/Bool(true),
223+
/*block=*/
224+
Block(/*iter_vars=*/{},
225+
/*reads=*/std::move(reads),
226+
/*writes=*/std::move(writes),
227+
/*name_hint=*/info->GetUniqueName(extern_op->name),
228+
/*body=*/std::move(body),
229+
/*init=*/NullOpt,
230+
/*alloc_buffers=*/{},
231+
/*match_buffers=*/{},
232+
/*annotations=*/extern_op->attrs));
233+
}
234+
235+
/*! \brief Use Tensor Expression to create a schedulable TensorIR func. */
236+
PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
237+
// Step 1. Create tensor read graph.
238+
Array<te::Operation> arg_ops;
239+
for (const te::Tensor& arg : arg_list) {
240+
arg_ops.push_back(arg->op);
241+
}
242+
te::ReadGraph g = te::CreateReadGraph(arg_ops);
243+
Array<te::Operation> order = te::PostDFSOrder(arg_ops, g);
244+
245+
// Step 2. Checking all Operations are supported.
246+
for (const te::Operation& op : order) {
247+
if (!(op->IsInstance<te::PlaceholderOpNode>() || op->IsInstance<te::ComputeOpNode>() ||
248+
op->IsInstance<te::ExternOpNode>()))
249+
LOG(FATAL) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". "
250+
<< "Only te.placeholder and te.compute are allowed for now.";
251+
}
252+
253+
// Infomations used in CreatePrimFunc and its sub-funtions.
254+
CreateFuncInfo info(arg_list);
255+
// Root body stmts.
256+
Array<Stmt> root_stmts;
257+
258+
// Step 3. Rewrite compute stages into blocks.
259+
for (const te::Operation& op : order) {
260+
if (const auto* placeholder = op.as<te::PlaceholderOpNode>()) {
261+
// Case 1. PlaceholderOp (te.placeholder)
262+
ICHECK_EQ(op->num_outputs(), 1);
263+
const te::Tensor& tensor = op.output(0);
264+
// Check op is in op list
265+
ICHECK(info.IsArg(tensor));
266+
const Buffer& buffer = decl_buffer(placeholder->shape, placeholder->dtype, placeholder->name);
267+
info.tensor2buffers[tensor] = buffer;
268+
} else if (const auto* compute_op = op.as<te::ComputeOpNode>()) {
269+
// Case 2. ComputeOp (te.compute)
270+
root_stmts.push_back(GenerateStmtFromCompute(GetRef<te::ComputeOp>(compute_op), &info));
271+
} else if (const auto extern_op = op.as<te::ExternOpNode>()) {
272+
// Case 3. ExternOp (te.extern)
273+
root_stmts.push_back(GenerateStmtFromExternOp(GetRef<te::ExternOp>(extern_op), &info));
274+
} else {
275+
ICHECK(false) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". "
276+
<< "Only te.placeholder and te.compute are allowed for now.";
277+
}
278+
}
279+
280+
// Step 4. Create func and complete it.
281+
Array<Var> parameters;
282+
Map<Var, Buffer> buffer_map;
283+
for (const te::Tensor& tensor : arg_list) {
284+
Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle()));
285+
parameters.push_back(arg);
286+
auto it = info.tensor2buffers.find(tensor);
287+
ICHECK(it != info.tensor2buffers.end());
288+
buffer_map.Set(arg, it->second);
289+
}
290+
PrimFunc func = PrimFunc(/*params=*/std::move(parameters),
291+
/*body=*/SeqStmt::Flatten(root_stmts),
292+
/*ret_type=*/VoidType(),
293+
/*buffer_map=*/std::move(buffer_map));
294+
295+
const auto* complete = runtime::Registry::Get("script.Complete");
296+
ICHECK(complete);
297+
298+
return (*complete)(func, info.root_alloc);
299+
} // namespace tir
300+
301+
TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed([](const Array<te::Tensor>& tensors) {
302+
return CreatePrimFunc(tensors);
303+
});
304+
305+
} // namespace tir
306+
} // namespace tvm

0 commit comments

Comments
 (0)