diff --git a/python/tvm/contrib/miopen.py b/python/tvm/contrib/miopen.py index 616bd5a420ac..e062ac1e735e 100644 --- a/python/tvm/contrib/miopen.py +++ b/python/tvm/contrib/miopen.py @@ -50,7 +50,8 @@ def conv2d_forward(x, dilation_h=1, dilation_w=1, conv_mode=0, - data_type=1): + data_type=1, + group_count=1): """Create an extern op that compute 2D convolution with MIOpen Parameters @@ -77,13 +78,16 @@ def conv2d_forward(x, data_type: int 0: miopenHalf (fp16) 1: miopenFloat (fp32) - + group_count: int + number of groups Returns ------- y: Tensor The result tensor """ - assert (conv_mode == 0 or conv_mode == 1), "0: miopenConvolution / 1: miopenTranspose" + assert (0 <= conv_mode <= 2), "0: miopenConvolution / 1: miopenTranspose / 2: miopenGroupConv" + if group_count > 1: + conv_mode = 2 oshape = np.zeros((len(x.shape)), dtype=np.int32) xshape = x.shape wshape = w.shape @@ -104,6 +108,7 @@ def conv2d_forward(x, wshape[1].value, wshape[2].value, wshape[3].value, + group_count, _get_np_int32_array_handle(oshape)) return _api.extern( diff --git a/src/contrib/miopen/conv_forward.cc b/src/contrib/miopen/conv_forward.cc index baac86b8603d..6479d7d0906a 100644 --- a/src/contrib/miopen/conv_forward.cc +++ b/src/contrib/miopen/conv_forward.cc @@ -50,16 +50,20 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") const int w_dim1 = args[13]; const int w_dim2 = args[14]; const int w_dim3 = args[15]; - void *out_shape = args[16]; + const int n_group = args[16]; + void *out_shape = args[17]; MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); + assert(n_group > 0 && "Group Size > 0 is expected"); + if (n_group > 1) + assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1"); // Set Mode entry_ptr->conv_entry.mode = static_cast(mode); // Set Ctx entry_ptr->conv_entry.ctx = TVMContext{kDLROCM, 0}; // Set Data Type entry_ptr->conv_entry.data_type = static_cast( - dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at + dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf), int32, int8 at // this moment. // Set Desc MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc, @@ -70,11 +74,13 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") stride_w, dilation_h, dilation_w)); + if (n_group > 1) + MIOPEN_CALL(miopenSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, n_group)); // Set Filter MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.data_type, w_dim0, - w_dim1, + w_dim1/n_group, w_dim2, w_dim3)); // Set Input