@@ -35,29 +35,32 @@ using namespace runtime;
3535TVM_REGISTER_GLOBAL (" tvm.contrib.miopen.conv2d.setup" )
3636.set_body([](TVMArgs args, TVMRetValue *ret) {
3737 const int mode = args[0 ];
38- const int pad_h = args[1 ];
39- const int pad_w = args[2 ];
40- const int stride_h = args[3 ];
41- const int stride_w = args[4 ];
42- const int dilation_h = args[5 ];
43- const int dilation_w = args[6 ];
44- const int x_dim0 = args[7 ];
45- const int x_dim1 = args[8 ];
46- const int x_dim2 = args[9 ];
47- const int x_dim3 = args[10 ];
48- const int w_dim0 = args[11 ];
49- const int w_dim1 = args[12 ];
50- const int w_dim2 = args[13 ];
51- const int w_dim3 = args[14 ];
52- void *out_shape = args[15 ];
38+ const int dtype = args[1 ];
39+ const int pad_h = args[2 ];
40+ const int pad_w = args[3 ];
41+ const int stride_h = args[4 ];
42+ const int stride_w = args[5 ];
43+ const int dilation_h = args[6 ];
44+ const int dilation_w = args[7 ];
45+ const int x_dim0 = args[8 ];
46+ const int x_dim1 = args[9 ];
47+ const int x_dim2 = args[10 ];
48+ const int x_dim3 = args[11 ];
49+ const int w_dim0 = args[12 ];
50+ const int w_dim1 = args[13 ];
51+ const int w_dim2 = args[14 ];
52+ const int w_dim3 = args[15 ];
53+ void *out_shape = args[16 ];
5354
5455 MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal ();
5556 // Set Mode
5657 entry_ptr->conv_entry .mode = static_cast <miopenConvolutionMode_t>(mode);
5758 // Set Ctx
5859 entry_ptr->conv_entry .ctx = TVMContext{kDLROCM , 0 };
5960 // Set Data Type
60- entry_ptr->conv_entry .data_type = miopenFloat; // MIOpen only suppports fp32
61+ entry_ptr->conv_entry .data_type = static_cast <miopenDataType_t>(
62+ dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at
63+ // this moment.
6164 // Set Desc
6265 MIOPEN_CALL (miopenInitConvolutionDescriptor (entry_ptr->conv_entry .conv_desc ,
6366 entry_ptr->conv_entry .mode ,
@@ -170,16 +173,17 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
170173TVM_REGISTER_GLOBAL (" tvm.contrib.miopen.conv2d.forward" )
171174.set_body([](TVMArgs args, TVMRetValue *ret) {
172175 const int mode = args[0 ];
173- const int pad_h = args[1 ];
174- const int pad_w = args[2 ];
175- const int stride_h = args[3 ];
176- const int stride_w = args[4 ];
177- const int dilation_h = args[5 ];
178- const int dilation_w = args[6 ];
179- const int algo = args[7 ];
180- const DLTensor *x = args[8 ];
181- const DLTensor *w = args[9 ];
182- const DLTensor *y = args[10 ];
176+ const int dtype = args[1 ];
177+ const int pad_h = args[2 ];
178+ const int pad_w = args[3 ];
179+ const int stride_h = args[4 ];
180+ const int stride_w = args[5 ];
181+ const int dilation_h = args[6 ];
182+ const int dilation_w = args[7 ];
183+ const int algo = args[8 ];
184+ const DLTensor *x = args[9 ];
185+ const DLTensor *w = args[10 ];
186+ const DLTensor *y = args[11 ];
183187
184188 MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal ();
185189 entry_ptr->conv_entry .fwd_algo = static_cast <miopenConvFwdAlgorithm_t>(algo);
@@ -188,7 +192,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward")
188192 // Set Ctx
189193 entry_ptr->conv_entry .ctx = x->ctx ;
190194 // Set Data Type
191- entry_ptr->conv_entry .data_type = miopenFloat; // MIOpen only suppports fp32
195+ entry_ptr->conv_entry .data_type = static_cast <miopenDataType_t>(
196+ dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at
197+ // this moment.
192198 // Set Desc
193199 MIOPEN_CALL (miopenInitConvolutionDescriptor (entry_ptr->conv_entry .conv_desc ,
194200 entry_ptr->conv_entry .mode ,
0 commit comments