Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/script/printer/tir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,14 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<
/*args=*/args);
}

ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
const IRDocsifier& d) {
Map<String, ExprDoc> attrs = BufferAttrs(buffer, p, frame, d);
ExprDoc shape = attrs.Get("shape").value();
ExprDoc dtype = attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype));
return TIR("Buffer")->Call({shape, dtype}, {}, {});
}

Array<Doc> BufferIndices(const Array<PrimExpr>& indices, const ObjectPath& p,
const IRDocsifier& d) {
int n = indices.size();
Expand Down
105 changes: 105 additions & 0 deletions src/script/printer/tir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/device_api.h>
#include <tvm/tir/stmt_functor.h>

#include "./utils.h"

namespace tvm {
Expand All @@ -34,16 +37,115 @@ String FindFunctionName(const IRDocsifier& d, const tir::PrimFunc& f) {
return "main";
}

bool IsSimpleBuffer(const tir::Buffer& buf) {
if (!buf->strides.empty()) {
return false;
}
for (const PrimExpr& shp_i : buf->shape) {
if (!tir::UndefinedVars(shp_i).empty()) {
return false;
}
}
for (const PrimExpr& stride_i : buf->strides) {
if (!tir::UndefinedVars(stride_i).empty()) {
return false;
}
}
if (!tir::UndefinedVars(buf->elem_offset).empty()) {
return false;
} else if (buf->elem_offset->IsInstance<IntImmNode>()) {
IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
if (elem_offset->value != 0) {
return false;
}
}
return buf.scope() == "global" && buf->data_alignment == runtime::kAllocAlignment &&
buf->offset_factor == 1 && buf->buffer_type == tir::BufferType::kDefault &&
!buf->axis_separators.size();
}

int CountVarOccurrence(const tir::PrimFunc& f, const tir::Var& v) {
class OccurrenceCounter : public tir::StmtExprVisitor {
public:
int count = 0;
const tir::VarNode* v = nullptr;

void VisitExpr_(const tir::VarNode* op) final {
if (op == v) {
++count;
}
tir::StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const tir::BufferStoreNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitStmt_(op);
}

void VisitExpr_(const tir::BufferLoadNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const tir::DeclBufferNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitStmt_(op);
}

void VisitBuffer(const tir::BufferNode* buffer) {
VisitExpr(buffer->data);
for (const PrimExpr& shape_i : buffer->shape) {
VisitExpr(shape_i);
}
for (const PrimExpr& stride_i : buffer->strides) {
VisitExpr(stride_i);
}
VisitExpr(buffer->elem_offset);
}
};

OccurrenceCounter counter;
counter.v = v.get();
counter(f->body);
for (const tir::Var& v : f->params) {
counter(v);
}
for (const auto& pair : f->buffer_map) {
counter(pair.first);
counter.VisitBuffer(pair.second.get());
}
return counter.count;
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::PrimFunc>("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc {
With<TIRFrame> frame(MakeDispatchFrame(d, func, func));
int n_args = func->params.size();
std::unordered_map<const tir::VarNode*, int> buffer_data_counter;
for (const auto& pair : func->buffer_map) {
const tir::VarNode* data_var = pair.second->data.get();
if (!buffer_data_counter.count(data_var)) {
buffer_data_counter.insert({data_var, 0});
}
++buffer_data_counter.at(data_var);
}
// Step 1. Handle `func->params`
Array<AssignDoc> args;
args.reserve(n_args);
std::unordered_set<const tir::BufferNode*> buffer_inlined;
for (int i = 0; i < n_args; ++i) {
tir::Var var = func->params[i];
ObjectPath var_p = p->Attr("params")->ArrayIndex(i);
if (CountVarOccurrence(func, var) == 2 && func->buffer_map.count(var)) {
tir::Buffer buffer = func->buffer_map[var];
if (IsSimpleBuffer(buffer) && buffer_data_counter.at(buffer->data.get()) == 1) {
ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var);
args.push_back(AssignDoc(DefineBuffer(buffer, *frame, d), NullOpt,
BufferAttn(buffer, buffer_p, *frame, d)));
buffer_inlined.insert(buffer.get());
continue;
}
}
ExprDoc a = d->AsDoc<ExprDoc>(var->type_annotation, var_p->Attr("type_annotation"));
args.push_back(AssignDoc(DefineVar(var, *frame, d), NullOpt, a));
}
Expand All @@ -58,6 +160,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
tir::Var param = func->params[i];
if (func->buffer_map.count(param)) {
tir::Buffer buffer = func->buffer_map[param];
if (buffer_inlined.count(buffer.get())) {
continue;
}
ExprDoc param = args[i]->lhs;
ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param);
ExprDoc lhs =
Expand Down
11 changes: 11 additions & 0 deletions src/script/printer/tir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,17 @@ inline void ReprPrintTIR(const ObjectRef& obj, ReprPrinter* p) {
ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<ExprDoc>& args,
const ObjectPath& p, const Frame& frame, const IRDocsifier& d);

/*!
* \brief Declare and define a buffer as annotation
* \param buffer The buffer to be defined
* \param p The object path
* \param f The frame
* \param d The IRDocsifier
* \return The ExprDoc corresponding to the buffer declaration
*/
ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
const IRDocsifier& d);

} // namespace printer
} // namespace script
} // namespace tvm
Expand Down
52 changes: 50 additions & 2 deletions tests/python/unittest/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,56 @@ def test_prim_func():
func,
expected="""
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
T.evaluate(0)""",
)


def test_prim_func_no_sugar_inlined_buffer():
a = tir.Var("a", "handle")
b = tir.Var("b", "handle")
func = tir.PrimFunc(
params=[a, b],
ret_type=None,
buffer_map={
a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"),
b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"),
},
body=tir.Evaluate(a),
)
_assert_print(
func,
expected="""
@T.prim_func
def main(a: T.handle, B: T.Buffer((256, 256), "float32")):
A = T.match_buffer(a, (128, 128))
T.evaluate(a)
""",
)


def test_prim_func_no_sugar_shared_buffer_data():
a = tir.Var("a", "handle")
b = tir.Var("b", "handle")
buffer_data = tir.decl_buffer(shape=[128, 128], dtype="float32", name="A").data
func = tir.PrimFunc(
params=[a, b],
ret_type=None,
buffer_map={
a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A", data=buffer_data),
b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B", data=buffer_data),
},
body=tir.Evaluate(0),
)
_assert_print(
func,
expected="""
@T.prim_func
def main(a: T.handle, b: T.handle):
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (256, 256))
T.evaluate(0)""",
B = T.match_buffer(b, (256, 256), data=A.data)
T.evaluate(0)
""",
)


Expand Down Expand Up @@ -641,6 +687,8 @@ def main():

if __name__ == "__main__":
test_prim_func()
test_prim_func_no_sugar_inlined_buffer()
test_prim_func_no_sugar_shared_buffer_data()
test_block_realize()
test_block()
test_buffer()
Expand Down