@@ -50,16 +50,20 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
5050 const int w_dim1 = args[13 ];
5151 const int w_dim2 = args[14 ];
5252 const int w_dim3 = args[15 ];
53- void *out_shape = args[16 ];
53+ const int n_group = args[16 ];
54+ void *out_shape = args[17 ];
5455
5556 MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal ();
57+ assert (n_group > 0 && " Group Size > 0 is expected" );
58+ if (n_group > 1 )
59+ assert (mode > 1 && " Group /Depthwise Conv mode when num of groups > 1" );
5660 // Set Mode
5761 entry_ptr->conv_entry .mode = static_cast <miopenConvolutionMode_t>(mode);
5862 // Set Ctx
5963 entry_ptr->conv_entry .ctx = TVMContext{kDLROCM , 0 };
6064 // Set Data Type
6165 entry_ptr->conv_entry .data_type = static_cast <miopenDataType_t>(
62- dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at
66+ dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf), int32, int8 at
6367 // this moment.
6468 // Set Desc
6569 MIOPEN_CALL (miopenInitConvolutionDescriptor (entry_ptr->conv_entry .conv_desc ,
@@ -70,11 +74,13 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
7074 stride_w,
7175 dilation_h,
7276 dilation_w));
77+ if (n_group > 1 )
78+ MIOPEN_CALL (miopenSetConvolutionGroupCount (entry_ptr->conv_entry .conv_desc , n_group));
7379 // Set Filter
7480 MIOPEN_CALL (miopenSet4dTensorDescriptor (entry_ptr->conv_entry .filter_desc ,
7581 entry_ptr->conv_entry .data_type ,
7682 w_dim0,
77- w_dim1,
83+ w_dim1/n_group ,
7884 w_dim2,
7985 w_dim3));
8086 // Set Input
0 commit comments