Auto TensorCore CodeGen #4106
```diff
@@ -207,6 +207,12 @@ Stmt StorageFlatten(Stmt stmt,
                     int cache_line_size,
                     bool create_bound_attribute = false);
 
+Stmt TensorCore(Stmt stmt,
+                Schedule schedule,
+                double cuda_compute_capability,
```
Review comment: Can we use an integer here instead of a double?
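If the reviewer's suggestion were adopted, the declaration might look like the sketch below. The `major * 10 + minor` encoding (70 for compute capability 7.0, 90 for CUDA 9.0) is an assumption for illustration, not a choice made in this PR:

```cpp
// Hypothetical integer variant of the new pass declaration.
// Encoding versions as major*10 + minor is an illustrative assumption.
Stmt TensorCore(Stmt stmt,
                Schedule schedule,
                int cuda_compute_capability,  // e.g. 70 for sm_70
                int cuda_version,             // e.g. 90 for CUDA 9.0
                Map<Tensor, Buffer> extern_buffer);
```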
```diff
+                double cuda_version,
+                Map<Tensor, Buffer> extern_buffer);
+
 /*!
  * \brief Remove No Op from the Stmt.
  * \param stmt The stmt to be transformed
```
```diff
@@ -378,6 +378,18 @@ def lower(sch,
     for f in lower_phase0:
         stmt = f(stmt)
     # Phase 1
+    try:
+        # device_type 2 for GPU
+        # choose device 0
+        # attr type 4 for CUDA Compute Capability
+        cuda_compute_capability = _api_internal._GetDeviceAttr(2, 0, 4)
+        from tvm.contrib.nvcc import find_cuda_path, get_cuda_version
```
Review comment: It seems unnecessary to put this `from ... import` inside the `try`, and it may cause another problem: if `tvm.contrib.nvcc` ever changes, this module would silently set `cuda_compute_capability` to `None` instead of reporting the error.

Reply: Yeah, we had the same concerns about this part when submitting this PR. It would be better to move the CUDA version and capability check somewhere inside the TensorCore pass.
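A minimal sketch of what the reply proposes: perform the capability/version gate inside the pass itself, so `lower()` would not need the `try`/`except` at all. The body below is illustrative, not the PR's actual implementation; only the early-exit guard is the point, and the thresholds come from the check already in this diff:

```cpp
Stmt TensorCore(Stmt stmt,
                Schedule schedule,
                double cuda_compute_capability,
                double cuda_version,
                Map<Tensor, Buffer> extern_buffer) {
  // TensorCore intrinsics need Volta (compute capability >= 7.0) and
  // CUDA >= 9.0; on anything older, return the IR unchanged rather than
  // relying on the Python caller to skip the pass.
  if (cuda_compute_capability < 7.0 || cuda_version < 9.0) {
    return stmt;
  }
  // ... the actual matmul-to-wmma rewrite would go here ...
  return stmt;
}
```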
```diff
+        cuda_version = float(get_cuda_version(find_cuda_path()))
+    except:
+        cuda_compute_capability = None
+    if cuda_compute_capability and float(cuda_compute_capability) >= 7.0 and cuda_version >= 9.0:
+        stmt = ir_pass.TensorCore(stmt, sch, float(cuda_compute_capability), float(cuda_version), binds)
+
     stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
     stmt = ir_pass.CanonicalSimplify(stmt)
     for f in lower_phase1:
```
```diff
@@ -94,6 +94,13 @@ TVM_REGISTER_API("ir_pass.StorageFlatten")
     }
   });
 
+TVM_REGISTER_API("ir_pass.TensorCore")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    if (args.size() == 5) {
+      *ret = TensorCore(args[0], args[1], args[2], args[3], args[4]);
+    }
```
Review comment: You should also handle the case where `args.size()` is not 5 and set the `*ret` value.
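One way to follow this suggestion, assuming the `CHECK_EQ` logging macro TVM uses elsewhere in its codebase; a sketch, not the PR's final code:

```cpp
TVM_REGISTER_API("ir_pass.TensorCore")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    // Fail loudly on wrong arity instead of silently leaving *ret unset.
    CHECK_EQ(args.size(), 5)
        << "ir_pass.TensorCore expects 5 arguments, got " << args.size();
    *ret = TensorCore(args[0], args[1], args[2], args[3], args[4]);
  });
```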
```diff
+  });
+
 TVM_REGISTER_API("ir_pass.AttrsEqual")
 .set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
     return AttrsEqual()(lhs, rhs);
```
Review comment: Just curious, why aren't they symmetric?