Skip to content

Commit 82559d0

Browse files
petrex authored and wweic committed
Enable miopen transpose convolution and fp16 support (apache#3952)
* Enable miopen transpose convolution and fp16 support
* linter
1 parent 17d9575 commit 82559d0

File tree

4 files changed

+46
-32
lines changed

4 files changed

+46
-32
lines changed

python/tvm/contrib/miopen.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ def conv2d_forward(x,
4949
pad_w=0,
5050
dilation_h=1,
5151
dilation_w=1,
52-
conv_mode=0):
52+
conv_mode=0,
53+
data_type=1):
5354
"""Create an extern op that compute 2D convolution with MIOpen
5455
5556
Parameters
@@ -73,18 +74,22 @@ def conv2d_forward(x,
7374
conv_mode: int
7475
0: miopenConvolution
7576
1: miopenTranspose
77+
data_type: int
78+
0: miopenHalf (fp16)
79+
1: miopenFloat (fp32)
7680
7781
Returns
7882
-------
7983
y: Tensor
8084
The result tensor
8185
"""
82-
assert conv_mode == 0, "Transpose convolutions not supported yet."
86+
assert (conv_mode == 0 or conv_mode == 1), "0: miopenConvolution / 1: miopenTranspose"
8387
oshape = np.zeros((len(x.shape)), dtype=np.int32)
8488
xshape = x.shape
8589
wshape = w.shape
8690
setup_func = _get_global_func("tvm.contrib.miopen.conv2d.setup")
8791
algo = setup_func(conv_mode,
92+
data_type,
8893
pad_h,
8994
pad_w,
9095
stride_h,
@@ -106,6 +111,7 @@ def conv2d_forward(x,
106111
lambda ins, outs: _intrin.call_packed(
107112
"tvm.contrib.miopen.conv2d.forward",
108113
conv_mode,
114+
data_type,
109115
pad_h,
110116
pad_w,
111117
stride_h,

src/contrib/miopen/conv_forward.cc

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -35,29 +35,32 @@ using namespace runtime;
3535
TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
3636
.set_body([](TVMArgs args, TVMRetValue *ret) {
3737
const int mode = args[0];
38-
const int pad_h = args[1];
39-
const int pad_w = args[2];
40-
const int stride_h = args[3];
41-
const int stride_w = args[4];
42-
const int dilation_h = args[5];
43-
const int dilation_w = args[6];
44-
const int x_dim0 = args[7];
45-
const int x_dim1 = args[8];
46-
const int x_dim2 = args[9];
47-
const int x_dim3 = args[10];
48-
const int w_dim0 = args[11];
49-
const int w_dim1 = args[12];
50-
const int w_dim2 = args[13];
51-
const int w_dim3 = args[14];
52-
void *out_shape = args[15];
38+
const int dtype = args[1];
39+
const int pad_h = args[2];
40+
const int pad_w = args[3];
41+
const int stride_h = args[4];
42+
const int stride_w = args[5];
43+
const int dilation_h = args[6];
44+
const int dilation_w = args[7];
45+
const int x_dim0 = args[8];
46+
const int x_dim1 = args[9];
47+
const int x_dim2 = args[10];
48+
const int x_dim3 = args[11];
49+
const int w_dim0 = args[12];
50+
const int w_dim1 = args[13];
51+
const int w_dim2 = args[14];
52+
const int w_dim3 = args[15];
53+
void *out_shape = args[16];
5354

5455
MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
5556
// Set Mode
5657
entry_ptr->conv_entry.mode = static_cast<miopenConvolutionMode_t>(mode);
5758
// Set Ctx
5859
entry_ptr->conv_entry.ctx = TVMContext{kDLROCM, 0};
5960
// Set Data Type
60-
entry_ptr->conv_entry.data_type = miopenFloat; // MIOpen only suppports fp32
61+
entry_ptr->conv_entry.data_type = static_cast<miopenDataType_t>(
62+
dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at
63+
// this moment.
6164
// Set Desc
6265
MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc,
6366
entry_ptr->conv_entry.mode,
@@ -170,16 +173,17 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
170173
TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward")
171174
.set_body([](TVMArgs args, TVMRetValue *ret) {
172175
const int mode = args[0];
173-
const int pad_h = args[1];
174-
const int pad_w = args[2];
175-
const int stride_h = args[3];
176-
const int stride_w = args[4];
177-
const int dilation_h = args[5];
178-
const int dilation_w = args[6];
179-
const int algo = args[7];
180-
const DLTensor *x = args[8];
181-
const DLTensor *w = args[9];
182-
const DLTensor *y = args[10];
176+
const int dtype = args[1];
177+
const int pad_h = args[2];
178+
const int pad_w = args[3];
179+
const int stride_h = args[4];
180+
const int stride_w = args[5];
181+
const int dilation_h = args[6];
182+
const int dilation_w = args[7];
183+
const int algo = args[8];
184+
const DLTensor *x = args[9];
185+
const DLTensor *w = args[10];
186+
const DLTensor *y = args[11];
183187

184188
MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
185189
entry_ptr->conv_entry.fwd_algo = static_cast<miopenConvFwdAlgorithm_t>(algo);
@@ -188,7 +192,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward")
188192
// Set Ctx
189193
entry_ptr->conv_entry.ctx = x->ctx;
190194
// Set Data Type
191-
entry_ptr->conv_entry.data_type = miopenFloat; // MIOpen only suppports fp32
195+
entry_ptr->conv_entry.data_type = static_cast<miopenDataType_t>(
196+
dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at
197+
// this moment.
192198
// Set Desc
193199
MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc,
194200
entry_ptr->conv_entry.mode,

tests/python/contrib/test_miopen.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ def test_conv2d():
5050
pad_w,
5151
dilation_h,
5252
dilation_w,
53-
conv_mode=0)
53+
conv_mode=0,
54+
data_type=1)
5455

5556
yshape = [x.value for x in Y.shape]
5657
import topi
@@ -65,7 +66,7 @@ def verify():
6566
y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx)
6667
f(x, w, y)
6768

68-
Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w))
69+
Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w), (dilation_h, dilation_w))
6970
with tvm.target.rocm():
7071
s_ref = topi.generic.schedule_conv2d_nchw([Y_ref])
7172
f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm")

topi/python/topi/rocm/conv2d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ def conv2d_rocm(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
7878
pad_w,
7979
dilation_h,
8080
dilation_w,
81-
conv_mode=0)
81+
conv_mode=0,
82+
data_type=1)
8283

8384
return conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
8485

0 commit comments

Comments
 (0)