Skip to content

Commit 9d249da

Browse files
author
Peter Yeh
committed
enable group conv through miopen
1 parent 8a2f10e commit 9d249da

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

python/tvm/contrib/miopen.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ def conv2d_forward(x,
5050
dilation_h=1,
5151
dilation_w=1,
5252
conv_mode=0,
53-
data_type=1):
53+
data_type=1,
54+
group_count=1):
5455
"""Create an extern op that compute 2D convolution with MIOpen
5556
5657
Parameters
@@ -77,13 +78,16 @@ def conv2d_forward(x,
7778
data_type: int
7879
0: miopenHalf (fp16)
7980
1: miopenFloat (fp32)
80-
81+
group_count: int
82+
number of convolution groups (default 1); a value greater than 1 forces conv_mode to 2 (miopenGroupConv)
8183
Returns
8284
-------
8385
y: Tensor
8486
The result tensor
8587
"""
86-
assert (conv_mode == 0 or conv_mode == 1), "0: miopenConvolution / 1: miopenTranspose"
88+
assert (0 <= conv_mode <= 2), "0: miopenConvolution / 1: miopenTranspose / 2: miopenGroupConv"
89+
if group_count > 1:
90+
conv_mode = 2
8791
oshape = np.zeros((len(x.shape)), dtype=np.int32)
8892
xshape = x.shape
8993
wshape = w.shape
@@ -104,6 +108,7 @@ def conv2d_forward(x,
104108
wshape[1].value,
105109
wshape[2].value,
106110
wshape[3].value,
111+
group_count,
107112
_get_np_int32_array_handle(oshape))
108113

109114
return _api.extern(

src/contrib/miopen/conv_forward.cc

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,20 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
5050
const int w_dim1 = args[13];
5151
const int w_dim2 = args[14];
5252
const int w_dim3 = args[15];
53-
void *out_shape = args[16];
53+
const int n_group = args[16];
54+
void *out_shape = args[17];
5455

5556
MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
57+
assert(n_group > 0 && "Group Size > 0 is expected");
58+
if (n_group > 1)
59+
assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1");
5660
// Set Mode
5761
entry_ptr->conv_entry.mode = static_cast<miopenConvolutionMode_t>(mode);
5862
// Set Ctx
5963
entry_ptr->conv_entry.ctx = TVMContext{kDLROCM, 0};
6064
// Set Data Type
6165
entry_ptr->conv_entry.data_type = static_cast<miopenDataType_t>(
62-
dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at
66+
dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf), int32, int8 at
6367
// this moment.
6468
// Set Desc
6569
MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc,
@@ -70,11 +74,13 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
7074
stride_w,
7175
dilation_h,
7276
dilation_w));
77+
if (n_group > 1)
78+
MIOPEN_CALL(miopenSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, n_group));
7379
// Set Filter
7480
MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc,
7581
entry_ptr->conv_entry.data_type,
7682
w_dim0,
77-
w_dim1,
83+
w_dim1/n_group,
7884
w_dim2,
7985
w_dim3));
8086
// Set Input

0 commit comments

Comments
 (0)