Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,8 @@ constexpr const char* reduce_scope = "reduce_scope";
constexpr const char* pragma_scope_prefix = "pragma_";
/*! \brief Import llvm source or file into the final code gen module */
constexpr const char* pragma_import_llvm = "pragma_import_llvm";
/*! \brief Try to modify the AST to support Tensor Core */
constexpr const char* pragma_tensor_core = "pragma_tensor_core";
/*!
* \brief Mark of prefetch scope, value=offset,
* run prefetch of Tensor on the current loop scope
Expand Down
14 changes: 14 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,20 @@ Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer,
int cache_line_size,
bool create_bound_attribute = false);

/*!
* \brief Try to modify the AST to support TensorCore
*
* \param stmt The stmt to be trasnformed.
* \param schedule The original schedule.
* \param extern_buffer Map specifies external
* buffer assignment of input and outputs.
* \return Transformed stmt.
*/
Stmt RewriteForTensorCore(Stmt stmt,
Schedule schedule,
Map<Tensor, Buffer> extern_buffer);

/*!
* \brief Verify if there is any argument bound to compact buffer.
*
Expand Down
1 change: 1 addition & 0 deletions python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def lower(sch,
binds, arg_list = get_binds(args, compact, binds)

# Phase 1
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
Expand Down
6 changes: 6 additions & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ TVM_REGISTER_API("ir_pass.StorageFlatten")
}
});

TVM_REGISTER_API("ir_pass.RewriteForTensorCore")
.set_body_typed<Stmt(const Stmt&, const Schedule&, const Map<Tensor, Buffer>&)>
([](const Stmt& stmt, const Schedule& schedule, const Map<Tensor, Buffer>& extern_buffer) {
return RewriteForTensorCore(stmt, schedule, extern_buffer);
});

TVM_REGISTER_API("ir_pass.AttrsEqual")
.set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
return AttrsEqual()(lhs, rhs);
Expand Down
Loading