Skip to content

Commit 2990e8e

Browse files
HzfengsytqchenjunrushaoMasterJH5574
authored and committed
[TensorIR][PASS][M1c] PlanUpdateBufferAllocationLocation (apache#7873)
Co-authored-by: Tianqi Chen <[email protected]> Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Ruihang Lai <[email protected]>
1 parent 9ba9e8f commit 2990e8e

File tree

4 files changed

+318
-0
lines changed

4 files changed

+318
-0
lines changed

include/tvm/tir/transform.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,14 @@ TVM_DLL Pass HoistIfThenElse();
352352
*/
353353
TVM_DLL Pass LowerInitBlock();
354354

355+
/*!
356+
* \brief Locate the buffer allocation to the exact position (usually
357+
* the LCA of the buffer accesses). This pass will inject an opaque block
358+
* with alloc_buffers at the allocation site.
359+
* \return The pass.
360+
*/
361+
TVM_DLL Pass PlanAndUpdateBufferAllocationLocation();
362+
355363
} // namespace transform
356364
} // namespace tir
357365
} // namespace tvm

python/tvm/tir/transform/transform.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,3 +547,16 @@ def LowerInitBlock():
547547
The result pass
548548
"""
549549
return _ffi_api.LowerInitBlock()
550+
551+
552+
def PlanAndUpdateBufferAllocationLocation():
    """Create the pass that places each buffer allocation at its exact site.

    The site is usually the LCA of the buffer's accesses. The pass injects
    an opaque block carrying ``alloc_buffers`` at the allocation site.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    plan_pass = _ffi_api.PlanAndUpdateBufferAllocationLocation()
    return plan_pass
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \brief Plan where buffers are to be allocated and update the AST.
22+
* \file plan_update_buffer_allocation_location.cc
23+
*/
24+
25+
#include <tvm/tir/analysis.h>
26+
#include <tvm/tir/stmt_functor.h>
27+
#include <tvm/tir/transform.h>
28+
29+
namespace tvm {
30+
namespace tir {
31+
32+
class BufferAllocationLocator : public StmtExprMutator {
33+
public:
34+
explicit BufferAllocationLocator(const PrimFunc& func) {
35+
Map<Buffer, Stmt> buffer_lca = DetectBufferAccessLCA(func);
36+
std::unordered_set<const BufferNode*> arg_buffers;
37+
for (const auto& kv : func->buffer_map) {
38+
const Buffer& buffer = kv.second;
39+
arg_buffers.emplace(buffer.get());
40+
buffer_data_to_buffer_.Set(buffer->data, buffer);
41+
}
42+
// create buffers to be allocated at each stmts
43+
for (const auto& kv : buffer_lca) {
44+
const Buffer& buffer = kv.first;
45+
const StmtNode* stmt = kv.second.get();
46+
if (arg_buffers.count(buffer.get())) {
47+
continue;
48+
}
49+
alloc_buffers_[stmt].push_back(buffer);
50+
}
51+
}
52+
53+
private:
54+
Stmt VisitStmt_(const ForNode* op) final {
55+
auto it = alloc_buffers_.find(op);
56+
if (it == alloc_buffers_.end()) {
57+
return StmtMutator::VisitStmt_(op);
58+
}
59+
for (const Buffer& buf : it->second) {
60+
buffer_data_to_buffer_.Set(buf->data, buf);
61+
}
62+
Stmt stmt = StmtMutator::VisitStmt_(op);
63+
op = stmt.as<ForNode>();
64+
ICHECK(op != nullptr);
65+
for (const Buffer& buf : it->second) {
66+
buffer_data_to_buffer_.erase(buf->data);
67+
}
68+
Stmt body = InjectOpaqueBlock(op->body, it->second);
69+
ObjectPtr<ForNode> n = CopyOnWrite(op);
70+
n->body = std::move(body);
71+
return Stmt(n);
72+
}
73+
74+
Stmt VisitStmt_(const BlockNode* op) final {
75+
ICHECK(!op->init.defined());
76+
bool is_root = is_root_;
77+
is_root_ = false;
78+
Array<Buffer> alloc_buffers;
79+
auto it = alloc_buffers_.find(op);
80+
if (it != alloc_buffers_.end()) {
81+
alloc_buffers = it->second;
82+
for (const Buffer& buf : it->second) {
83+
buffer_data_to_buffer_.Set(buf->data, buf);
84+
}
85+
}
86+
Stmt stmt = StmtMutator::VisitStmt_(op);
87+
op = stmt.as<BlockNode>();
88+
ICHECK(op != nullptr);
89+
90+
// Ignore buffer allocated inside the block when getting access region.
91+
if (it != alloc_buffers_.end()) {
92+
for (const Buffer& buf : it->second) {
93+
buffer_data_to_buffer_.erase(buf->data);
94+
}
95+
}
96+
97+
ObjectPtr<BlockNode> n = CopyOnWrite(op);
98+
n->alloc_buffers = std::move(alloc_buffers);
99+
// The read/write regions of root block are always empty.
100+
if (!is_root) {
101+
// Recalculate block access region
102+
CollectReadWrite(GetRef<Block>(op), &n->reads, &n->writes);
103+
}
104+
105+
return Stmt(n);
106+
}
107+
108+
Stmt VisitStmt_(const BufferRealizeNode* op) final {
109+
ICHECK(false) << "Internal Error: BufferRealizeNode is not allowed in TensorIR.";
110+
throw;
111+
}
112+
113+
Stmt InjectOpaqueBlock(Stmt body, const Array<Buffer>& alloc_buffers) {
114+
ICHECK(!alloc_buffers.empty());
115+
Block opaque_block(/*iter_vars=*/{},
116+
/*reads=*/{},
117+
/*writes=*/{},
118+
/*name_hint=*/"",
119+
/*body=*/std::move(body),
120+
/*init=*/NullOpt,
121+
/*alloc_buffers=*/alloc_buffers);
122+
ObjectPtr<BlockNode> n = CopyOnWrite(opaque_block.get());
123+
CollectReadWrite(opaque_block, &n->reads, &n->writes);
124+
BlockRealize realize({}, Bool(true), Block(n));
125+
return std::move(realize);
126+
}
127+
128+
void CollectReadWrite(const Block& block, Array<BufferRegion>* reads,
129+
Array<BufferRegion>* writes) {
130+
Array<Array<BufferRegion>> access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
131+
*reads = access[0];
132+
*writes = access[1];
133+
for (const auto& opaque_access : access[2]) {
134+
reads->push_back(opaque_access);
135+
writes->push_back(opaque_access);
136+
}
137+
}
138+
139+
/*! \brief The map from stmt to the buffers to be allocated under it. */
140+
std::unordered_map<const StmtNode*, Array<Buffer>> alloc_buffers_;
141+
/*! \brief The buffer already allocated during recursive visiting. */
142+
Map<Var, Buffer> buffer_data_to_buffer_;
143+
/*! \brief indicate the whether the block is root. */
144+
bool is_root_{true};
145+
};
146+
147+
/*!
 * \brief Rewrite the body of a PrimFunc with BufferAllocationLocator.
 * \param func The function to transform.
 * \return The function with its body replaced by the rewritten statement.
 */
PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
  PrimFuncNode* mutable_func = func.CopyOnWrite();
  // The locator is built from the function before its body is replaced.
  BufferAllocationLocator rewriter(func);
  mutable_func->body = rewriter(mutable_func->body);
  return func;
}
153+
154+
namespace transform {
155+
156+
Pass PlanAndUpdateBufferAllocationLocation() {
157+
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
158+
return PlanAndUpdateBufferAllocationLocation(std::move(f));
159+
};
160+
return CreatePrimFuncPass(pass_func, 0, "tir.PlanAndUpdateBufferAllocationLocation", {});
161+
}
162+
163+
TVM_REGISTER_GLOBAL("tir.transform.PlanAndUpdateBufferAllocationLocation")
164+
.set_body_typed(PlanAndUpdateBufferAllocationLocation);
165+
166+
} // namespace transform
167+
168+
} // namespace tir
169+
} // namespace tvm
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import tvm
18+
from tvm import tir
19+
from tvm.script import ty
20+
21+
22+
def _check(original, transformed):
    """Apply the pass to ``original`` and assert the result matches ``transformed``.

    ``original`` is wrapped in an IRModule, run through
    PlanAndUpdateBufferAllocationLocation, and its "main" function is compared
    with ``transformed`` using structural equality.
    """
    module = tvm.IRModule.from_expr(original)
    plan_pass = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()
    module = plan_pass(module)
    tvm.ir.assert_structural_equal(module["main"], transformed)
27+
28+
29+
# Input function for test_elementwise.
# B is created with tir.alloc_buffer right after the two match_buffer calls;
# it is written from A in the first j_0 loop and read into C in the second.
# NOTE(review): indentation is not preserved in this listing, so loop nesting
# has to be inferred from statement order.
@tvm.script.tir
30+
def element_func(a: ty.handle, c: ty.handle) -> None:
31+
A = tir.match_buffer(a, (16, 16))
32+
C = tir.match_buffer(c, (16, 16))
33+
B = tir.alloc_buffer((16, 16))
34+
for i_0 in range(0, 16):
35+
for j_0 in range(0, 16):
36+
with tir.block([16, 16]) as [i, j]:
37+
B[i, j] = A[i, j] + 1.0
38+
for j_0 in range(0, 16):
39+
with tir.block([16, 16]) as [i, j]:
40+
C[i, j] = B[i, j] * 2.0
41+
42+
43+
# Expected result for test_elementwise.
# B's tir.alloc_buffer now follows the i_0 loop header, inside a block with no
# iter vars (tir.block([])) that reads A[i_0, 0:16] and writes C[i_0, 0:16].
# The two inner blocks carry explicit tir.bind calls for i and j.
# NOTE(review): indentation is not preserved in this listing, so nesting has to
# be inferred from statement order.
@tvm.script.tir
44+
def transformed_element_func(a: ty.handle, c: ty.handle) -> None:
45+
A = tir.match_buffer(a, [16, 16])
46+
C = tir.match_buffer(c, [16, 16])
47+
48+
for i_0 in range(0, 16):
49+
with tir.block([]):
50+
tir.reads([A[i_0, 0:16]])
51+
tir.writes([C[i_0, 0:16]])
52+
B = tir.alloc_buffer([16, 16])
53+
for j_0 in tir.serial(0, 16):
54+
with tir.block([16, 16], "") as [i, j]:
55+
tir.bind(i, i_0)
56+
tir.bind(j, j_0)
57+
B[i, j] = A[i, j] + 1.0
58+
for j_0 in tir.serial(0, 16):
59+
with tir.block([16, 16], "") as [i, j]:
60+
tir.bind(i, i_0)
61+
tir.bind(j, j_0)
62+
C[i, j] = B[i, j] * 2.0
63+
64+
65+
# Input function for test_locate_buffer_allocation.
# A is allocated first and zero-filled by a 128x128 block. Buffers B, C and D
# are all allocated together right after the header of the block with the
# reduce axis k, before any statement that uses them.
# NOTE(review): indentation is not preserved in this listing, so nesting has to
# be inferred from statement order.
@tvm.script.tir
66+
def original_func() -> None:
67+
A = tir.alloc_buffer((128, 128), "float32")
68+
with tir.block([128, 128]) as [i, j]:
69+
A[i, j] = tir.float32(0)
70+
with tir.block([32, 32, tir.reduce_axis(0, 32)]) as [i, j, k]:
71+
B = tir.alloc_buffer((128, 128), "float32")
72+
C = tir.alloc_buffer((128, 128), "float32")
73+
D = tir.alloc_buffer((128, 128), "float32")
74+
if k == 0:
75+
for ii, jj in tir.grid(4, 4):
76+
B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
77+
for ii, jj in tir.grid(4, 4):
78+
for kk in range(0, 4):
79+
B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk]
80+
for kk in range(0, 4):
81+
B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk]
82+
83+
84+
# Expected result for test_locate_buffer_allocation.
# B is still allocated directly after the header of the block with the reduce
# axis k. C's allocation now sits in a block with no iter vars that follows the
# second tir.grid(4, 4) loop header, and D's allocation sits in another such
# block that follows the second kk loop header. Each of those blocks declares
# explicit tir.reads / tir.writes regions.
# NOTE(review): indentation is not preserved in this listing, so nesting has to
# be inferred from statement order.
@tvm.script.tir
85+
def transformed_func() -> None:
86+
A = tir.alloc_buffer([128, 128])
87+
with tir.block([128, 128], "") as [i, j]:
88+
A[i, j] = tir.float32(0)
89+
with tir.block([32, 32, tir.reduce_axis(0, 32)], "") as [i, j, k]:
90+
B = tir.alloc_buffer([128, 128])
91+
if k == 0:
92+
for ii, jj in tir.grid(4, 4):
93+
B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
94+
for ii, jj in tir.grid(4, 4):
95+
with tir.block([], ""):
96+
tir.reads([B[((i * 4) + ii), ((j * 4) + jj)]])
97+
tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
98+
C = tir.alloc_buffer([128, 128])
99+
for kk in tir.serial(0, 4):
100+
B[((i * 4) + ii), ((j * 4) + jj)] = (
101+
B[((i * 4) + ii), ((j * 4) + jj)] + C[((i * 4) + ii), ((k * 4) + kk)]
102+
)
103+
for kk in tir.serial(0, 4):
104+
with tir.block([], ""):
105+
tir.reads(
106+
[
107+
B[((i * 4) + ii), ((j * 4) + jj)],
108+
C[((i * 4) + ii), ((k * 4) + kk)],
109+
]
110+
)
111+
tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
112+
D = tir.alloc_buffer([128, 128])
113+
B[((i * 4) + ii), ((j * 4) + jj)] = B[((i * 4) + ii), ((j * 4) + jj)] + (
114+
D[((j * 4) + jj), ((k * 4) + kk)] * C[((i * 4) + ii), ((k * 4) + kk)]
115+
)
116+
117+
118+
def test_elementwise():
    """The pass applied to element_func must yield transformed_element_func."""
    _check(original=element_func, transformed=transformed_element_func)
120+
121+
122+
def test_locate_buffer_allocation():
    """The pass applied to original_func must yield transformed_func."""
    _check(original=original_func, transformed=transformed_func)
124+
125+
126+
# Run both tests directly when this file is executed as a script.
if __name__ == "__main__":
127+
test_elementwise()
128+
test_locate_buffer_allocation()

0 commit comments

Comments
 (0)