Skip to content

Commit d64bf6b

Browse files
minminsun authored and tqchen committed
Auto TensorCore CodeGen (#4234)
* Add Auto TensorCore TensorCore Unit Test * Rebase to tvm master branch & Add auto tensor core * Code Refine * Add tensor core switch by pragma * Add pragma in tensor core example code * Get real tile size to replace hard coded 16 * support more than 2 dimensions (e.g. batchmatmul) for buffer bind scope * support batch matmul * Move cuda env check to tensor_core.cc * Code refine for tensor_core.cc * Refine comments * Some refinements of code and comment * Update TensorCore UT to pass the CPU test * remove redundant code * matmul's storage align for different layout * Add support for different position of type cast * Add formal tutorial for auto tensorcore codegen * move tensorcore check up to tutorial code * code and doc refine * comment out tune_and_evaluate in tutorial * fix cpplint error
1 parent 281f643 commit d64bf6b

File tree

7 files changed

+1920
-0
lines changed

7 files changed

+1920
-0
lines changed

include/tvm/ir.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,6 +1248,8 @@ constexpr const char* reduce_scope = "reduce_scope";
12481248
constexpr const char* pragma_scope_prefix = "pragma_";
12491249
/*! \brief Import llvm source or file into the final code gen module */
12501250
constexpr const char* pragma_import_llvm = "pragma_import_llvm";
1251+
/*! \brief Try to modify the AST to support Tensor Core */
1252+
constexpr const char* pragma_tensor_core = "pragma_tensor_core";
12511253
/*!
12521254
* \brief Mark of prefetch scope, value=offset,
12531255
* run prefetch of Tensor on the current loop scope

include/tvm/ir_pass.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,20 @@ Stmt StorageFlatten(Stmt stmt,
206206
Map<Tensor, Buffer> extern_buffer,
207207
int cache_line_size,
208208
bool create_bound_attribute = false);
209+
210+
/*!
211+
* \brief Try to modify the AST to support TensorCore
212+
*
213+
* \param stmt The stmt to be transformed.
214+
* \param schedule The original schedule.
215+
* \param extern_buffer Map specifies external
216+
* buffer assignment of input and outputs.
217+
* \return Transformed stmt.
218+
*/
219+
Stmt RewriteForTensorCore(Stmt stmt,
220+
Schedule schedule,
221+
Map<Tensor, Buffer> extern_buffer);
222+
209223
/*!
210224
* \brief Verify if there is any argument bound to compact buffer.
211225
*

python/tvm/build_module.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ def lower(sch,
387387
binds, arg_list = get_binds(args, compact, binds)
388388

389389
# Phase 1
390+
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
390391
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
391392
stmt = ir_pass.CanonicalSimplify(stmt)
392393
for f in lower_phase1:

src/api/api_pass.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,12 @@ TVM_REGISTER_API("ir_pass.StorageFlatten")
9494
}
9595
});
9696

97+
TVM_REGISTER_API("ir_pass.RewriteForTensorCore")
98+
.set_body_typed<Stmt(const Stmt&, const Schedule&, const Map<Tensor, Buffer>&)>
99+
([](const Stmt& stmt, const Schedule& schedule, const Map<Tensor, Buffer>& extern_buffer) {
100+
return RewriteForTensorCore(stmt, schedule, extern_buffer);
101+
});
102+
97103
TVM_REGISTER_API("ir_pass.AttrsEqual")
98104
.set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
99105
return AttrsEqual()(lhs, rhs);

0 commit comments

Comments
 (0)