66 * to you under the Apache License, Version 2.0 (the 
77 * "License"); you may not use this file except in compliance 
88 * with the License.  You may obtain a copy of the License at 
9-  *   
9+  * 
1010 *   http://www.apache.org/licenses/LICENSE-2.0 
11-  *   
11+  * 
1212 * Unless required by applicable law or agreed to in writing, 
1313 * software distributed under the License is distributed on an 
1414 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 
2121 *  Copyright (c) 2017 by Contributors 
2222 * \file Use external miopen utils function 
2323 */  
24+ #include  < tvm/runtime/device_api.h> 
2425#include  < tvm/runtime/registry.h> 
2526#include  < tvm/runtime/util.h> 
26- #include  < tvm/runtime/device_api.h> 
2727#include  " miopen_utils.h" 
2828
2929namespace  tvm  {
@@ -33,211 +33,167 @@ namespace miopen {
3333using  namespace  runtime ; 
3434
3535TVM_REGISTER_GLOBAL (" tvm.contrib.miopen.conv2d.setup" 
36- .set_body([](TVMArgs args, TVMRetValue *ret) {
37-   const  int  mode = args[0 ];
38-   const  int  dtype = args[1 ];
39-   const  int  pad_h = args[2 ];
40-   const  int  pad_w = args[3 ];
41-   const  int  stride_h = args[4 ];
42-   const  int  stride_w = args[5 ];
43-   const  int  dilation_h = args[6 ];
44-   const  int  dilation_w = args[7 ];
45-   const  int  x_dim0 = args[8 ];
46-   const  int  x_dim1 = args[9 ];
47-   const  int  x_dim2 = args[10 ];
48-   const  int  x_dim3 = args[11 ];
49-   const  int  w_dim0 = args[12 ];
50-   const  int  w_dim1 = args[13 ];
51-   const  int  w_dim2 = args[14 ];
52-   const  int  w_dim3 = args[15 ];
53-   void  *out_shape = args[16 ];
54- 
55-   MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal ();
56-   //  Set Mode
57-   entry_ptr->conv_entry .mode  = static_cast <miopenConvolutionMode_t>(mode);
58-   //  Set Ctx
59-   entry_ptr->conv_entry .ctx  = TVMContext{kDLROCM , 0 };
60-   //  Set Data Type
61-   entry_ptr->conv_entry .data_type  = static_cast <miopenDataType_t>(dtype);  //  MIOpen suppports fp32(miopenFloat), fp16(miopenHalf) at this moment.
62-   //  Set Desc
63-   MIOPEN_CALL (miopenInitConvolutionDescriptor (entry_ptr->conv_entry .conv_desc ,
64-                                               entry_ptr->conv_entry .mode ,
65-                                               pad_h,
66-                                               pad_w,
67-                                               stride_h,
68-                                               stride_w,
69-                                               dilation_h,
70-                                               dilation_w));
71-   //  Set Filter
72-   MIOPEN_CALL (miopenSet4dTensorDescriptor (entry_ptr->conv_entry .filter_desc ,
73-                                           entry_ptr->conv_entry .data_type ,
74-                                           w_dim0,
75-                                           w_dim1,
76-                                           w_dim2,
77-                                           w_dim3));
78-   //  Set Input
79-   MIOPEN_CALL (miopenSet4dTensorDescriptor (entry_ptr->conv_entry .input_desc ,
80-                                           entry_ptr->conv_entry .data_type ,
81-                                           x_dim0,
82-                                           x_dim1,
83-                                           x_dim2,
84-                                           x_dim3));
85- 
86-   //  Set Output shape
87-   MIOPEN_CALL (miopenGetConvolutionForwardOutputDim (entry_ptr->conv_entry .conv_desc ,
88-                                                    entry_ptr->conv_entry .input_desc ,
89-                                                    entry_ptr->conv_entry .filter_desc ,
90-                                                    static_cast <int *>(out_shape),
91-                                                    static_cast <int *>(out_shape) + 1 ,
92-                                                    static_cast <int *>(out_shape) + 2 ,
93-                                                    static_cast <int *>(out_shape) + 3 ));
94- 
95-   const  int  *oshape = static_cast <int *>(out_shape);
96-   //  Set Output
97-   MIOPEN_CALL (miopenSet4dTensorDescriptor (entry_ptr->conv_entry .output_desc ,
98-                                           entry_ptr->conv_entry .data_type ,
99-                                           oshape[0 ],
100-                                           oshape[1 ],
101-                                           oshape[2 ],
102-                                           oshape[3 ]));
103- 
104-   //  Set workspace
105-   size_t  workspace_size = 0 ;
106-   MIOPEN_CALL (miopenConvolutionForwardGetWorkSpaceSize (entry_ptr->handle ,
107-                                                        entry_ptr->conv_entry .filter_desc ,
108-                                                        entry_ptr->conv_entry .input_desc ,
109-                                                        entry_ptr->conv_entry .conv_desc ,
110-                                                        entry_ptr->conv_entry .output_desc ,
111-                                                        &workspace_size));
112-   entry_ptr->conv_entry .UpdateWorkspace (workspace_size);
113- 
114-   const  size_t  input_size = x_dim0 * x_dim1 * x_dim2 * x_dim3;
115-   const  size_t  filter_size = w_dim0 * w_dim1 * w_dim2 * w_dim3;
116-   const  size_t  output_size = oshape[0 ] * oshape[1 ] * oshape[2 ] * oshape[3 ];
117- 
118-   runtime::DeviceAPI* rocm_api = entry_ptr->conv_entry .rocm_api ;
119-   float * input_buf = static_cast <float *>(rocm_api->AllocWorkspace (entry_ptr->conv_entry .ctx ,
120-                                                                   input_size * sizeof (float )));
121-   float * filter_buf = static_cast <float *>(rocm_api->AllocWorkspace (entry_ptr->conv_entry .ctx ,
122-                                                                    filter_size * sizeof (float )));
123-   float * output_buf = static_cast <float *>(rocm_api->AllocWorkspace (entry_ptr->conv_entry .ctx ,
124-                                                                    output_size * sizeof (float )));
125- 
126-   const  int  request_algo_count = 4 ;
127-   const  bool  exhaustive_search = false ;
128-   void * workspace = entry_ptr->conv_entry .workspace ;
129-   if  (workspace_size == 0 ) workspace = nullptr ;
130-   int  returned_algo_count = 0 ;
131-   miopenConvAlgoPerf_t perfs[4 ];
132- 
133-   MIOPEN_CALL (miopenFindConvolutionForwardAlgorithm (entry_ptr->handle ,
134-                                                     entry_ptr->conv_entry .input_desc ,
135-                                                     input_buf,
136-                                                     entry_ptr->conv_entry .filter_desc ,
137-                                                     filter_buf,
138-                                                     entry_ptr->conv_entry .conv_desc ,
139-                                                     entry_ptr->conv_entry .output_desc ,
140-                                                     output_buf,
141-                                                     request_algo_count,
142-                                                     &returned_algo_count,
143-                                                     perfs,
144-                                                     workspace,
145-                                                     workspace_size,
146-                                                     exhaustive_search));
147- 
148-   rocm_api->FreeWorkspace (entry_ptr->conv_entry .ctx , input_buf);
149-   rocm_api->FreeWorkspace (entry_ptr->conv_entry .ctx , filter_buf);
150-   rocm_api->FreeWorkspace (entry_ptr->conv_entry .ctx , output_buf);
151- 
152-   const  std::vector<std::string> fwd_algo_names{
153-       " miopenConvolutionFwdAlgoGEMM" 
154-       " miopenConvolutionFwdAlgoDirect" 
155-       " miopenConvolutionFwdAlgoFFT" 
156-       " miopenConvolutionFwdAlgoWinograd" 
157-   };
158-   const  auto  best_algo = perfs[0 ].fwd_algo ;
159-   LOG (INFO) << " \t MIOpen Found " 
160-             << "  fwd algorithms, choosing " 
161-   for  (int  i = 0 ; i < returned_algo_count; ++i) {
162-     LOG (INFO) << " \t\t " " ) " fwd_algo ]
163-               << "  - time: " time  << "  ms" 
164-               << " , Memory: " memory ;
165-   }
166-   //  Set Algo
167-   ret[0 ] = static_cast <int >(best_algo);
168- });
169- 
36+     .set_body([](TVMArgs args, TVMRetValue* ret) {
37+       const  int  mode = args[0 ];
38+       const  int  dtype = args[1 ];
39+       const  int  pad_h = args[2 ];
40+       const  int  pad_w = args[3 ];
41+       const  int  stride_h = args[4 ];
42+       const  int  stride_w = args[5 ];
43+       const  int  dilation_h = args[6 ];
44+       const  int  dilation_w = args[7 ];
45+       const  int  x_dim0 = args[8 ];
46+       const  int  x_dim1 = args[9 ];
47+       const  int  x_dim2 = args[10 ];
48+       const  int  x_dim3 = args[11 ];
49+       const  int  w_dim0 = args[12 ];
50+       const  int  w_dim1 = args[13 ];
51+       const  int  w_dim2 = args[14 ];
52+       const  int  w_dim3 = args[15 ];
53+       void * out_shape = args[16 ];
54+ 
55+       MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal ();
56+       //  Set Mode
57+       entry_ptr->conv_entry .mode  = static_cast <miopenConvolutionMode_t>(mode);
58+       //  Set Ctx
59+       entry_ptr->conv_entry .ctx  = TVMContext{kDLROCM , 0 };
60+       //  Set Data Type
61+       entry_ptr->conv_entry .data_type  = static_cast <miopenDataType_t>(
62+           dtype);  //  MIOpen suppports fp32(miopenFloat), fp16(miopenHalf) at
63+                    //  this moment.
64+       //  Set Desc
65+       MIOPEN_CALL (miopenInitConvolutionDescriptor (
66+           entry_ptr->conv_entry .conv_desc , entry_ptr->conv_entry .mode , pad_h,
67+           pad_w, stride_h, stride_w, dilation_h, dilation_w));
68+       //  Set Filter
69+       MIOPEN_CALL (miopenSet4dTensorDescriptor (entry_ptr->conv_entry .filter_desc ,
70+                                               entry_ptr->conv_entry .data_type ,
71+                                               w_dim0, w_dim1, w_dim2, w_dim3));
72+       //  Set Input
73+       MIOPEN_CALL (miopenSet4dTensorDescriptor (entry_ptr->conv_entry .input_desc ,
74+                                               entry_ptr->conv_entry .data_type ,
75+                                               x_dim0, x_dim1, x_dim2, x_dim3));
76+ 
77+       //  Set Output shape
78+       MIOPEN_CALL (miopenGetConvolutionForwardOutputDim (
79+           entry_ptr->conv_entry .conv_desc , entry_ptr->conv_entry .input_desc ,
80+           entry_ptr->conv_entry .filter_desc , static_cast <int *>(out_shape),
81+           static_cast <int *>(out_shape) + 1 , static_cast <int *>(out_shape) + 2 ,
82+           static_cast <int *>(out_shape) + 3 ));
83+ 
84+       const  int * oshape = static_cast <int *>(out_shape);
85+       //  Set Output
86+       MIOPEN_CALL (miopenSet4dTensorDescriptor (
87+           entry_ptr->conv_entry .output_desc , entry_ptr->conv_entry .data_type ,
88+           oshape[0 ], oshape[1 ], oshape[2 ], oshape[3 ]));
89+ 
90+       //  Set workspace
91+       size_t  workspace_size = 0 ;
92+       MIOPEN_CALL (miopenConvolutionForwardGetWorkSpaceSize (
93+           entry_ptr->handle , entry_ptr->conv_entry .filter_desc ,
94+           entry_ptr->conv_entry .input_desc , entry_ptr->conv_entry .conv_desc ,
95+           entry_ptr->conv_entry .output_desc , &workspace_size));
96+       entry_ptr->conv_entry .UpdateWorkspace (workspace_size);
97+ 
98+       const  size_t  input_size = x_dim0 * x_dim1 * x_dim2 * x_dim3;
99+       const  size_t  filter_size = w_dim0 * w_dim1 * w_dim2 * w_dim3;
100+       const  size_t  output_size = oshape[0 ] * oshape[1 ] * oshape[2 ] * oshape[3 ];
101+ 
102+       runtime::DeviceAPI* rocm_api = entry_ptr->conv_entry .rocm_api ;
103+       float * input_buf = static_cast <float *>(rocm_api->AllocWorkspace (
104+           entry_ptr->conv_entry .ctx , input_size * sizeof (float )));
105+       float * filter_buf = static_cast <float *>(rocm_api->AllocWorkspace (
106+           entry_ptr->conv_entry .ctx , filter_size * sizeof (float )));
107+       float * output_buf = static_cast <float *>(rocm_api->AllocWorkspace (
108+           entry_ptr->conv_entry .ctx , output_size * sizeof (float )));
109+ 
110+       const  int  request_algo_count = 4 ;
111+       const  bool  exhaustive_search = false ;
112+       void * workspace = entry_ptr->conv_entry .workspace ;
113+       if  (workspace_size == 0 ) workspace = nullptr ;
114+       int  returned_algo_count = 0 ;
115+       miopenConvAlgoPerf_t perfs[4 ];
116+ 
117+       MIOPEN_CALL (miopenFindConvolutionForwardAlgorithm (
118+           entry_ptr->handle , entry_ptr->conv_entry .input_desc , input_buf,
119+           entry_ptr->conv_entry .filter_desc , filter_buf,
120+           entry_ptr->conv_entry .conv_desc , entry_ptr->conv_entry .output_desc ,
121+           output_buf, request_algo_count, &returned_algo_count, perfs,
122+           workspace, workspace_size, exhaustive_search));
123+ 
124+       rocm_api->FreeWorkspace (entry_ptr->conv_entry .ctx , input_buf);
125+       rocm_api->FreeWorkspace (entry_ptr->conv_entry .ctx , filter_buf);
126+       rocm_api->FreeWorkspace (entry_ptr->conv_entry .ctx , output_buf);
127+ 
128+       const  std::vector<std::string> fwd_algo_names{
129+           " miopenConvolutionFwdAlgoGEMM" " miopenConvolutionFwdAlgoDirect" 
130+           " miopenConvolutionFwdAlgoFFT" " miopenConvolutionFwdAlgoWinograd" 
131+       };
132+       const  auto  best_algo = perfs[0 ].fwd_algo ;
133+       LOG (INFO) << " \t MIOpen Found " 
134+                 << "  fwd algorithms, choosing " 
135+       for  (int  i = 0 ; i < returned_algo_count; ++i) {
136+         LOG (INFO) << " \t\t " " ) " fwd_algo ]
137+                   << "  - time: " time  << "  ms" 
138+                   << " , Memory: " memory ;
139+       }
140+       //  Set Algo
141+       ret[0 ] = static_cast <int >(best_algo);
142+     });
170143
171144TVM_REGISTER_GLOBAL (" tvm.contrib.miopen.conv2d.forward" 
172- .set_body([](TVMArgs args, TVMRetValue *ret) {
173-   const  int  mode = args[0 ];
174-   const  int  dtype = args[1 ];
175-   const  int  pad_h = args[2 ];
176-   const  int  pad_w = args[3 ];
177-   const  int  stride_h = args[4 ];
178-   const  int  stride_w = args[5 ];
179-   const  int  dilation_h = args[6 ];
180-   const  int  dilation_w = args[7 ];
181-   const  int  algo = args[8 ];
182-   const  DLTensor *x = args[9 ];
183-   const  DLTensor *w = args[10 ];
184-   const  DLTensor *y = args[11 ];
185- 
186-   MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal ();
187-   entry_ptr->conv_entry .fwd_algo  = static_cast <miopenConvFwdAlgorithm_t>(algo);
188-   //  Set Mode
189-   entry_ptr->conv_entry .mode  = static_cast <miopenConvolutionMode_t>(mode);
190-   //  Set Ctx
191-   entry_ptr->conv_entry .ctx  = x->ctx ;
192-   //  Set Data Type
193-   entry_ptr->conv_entry .data_type  = static_cast <miopenDataType_t>(dtype);  //  MIOpen suppports fp32(miopenFloat), fp16(miopenHalf) at this moment.
194-   //  Set Desc
195-   MIOPEN_CALL (miopenInitConvolutionDescriptor (entry_ptr->conv_entry .conv_desc ,
196-                                               entry_ptr->conv_entry .mode ,
197-                                               pad_h,
198-                                               pad_w,
199-                                               stride_h,
200-                                               stride_w,
201-                                               dilation_h,
202-                                               dilation_w));
203-   //  Set Filter
204-   MIOPEN_CALL (miopenSet4dTensorDescriptor (entry_ptr->conv_entry .filter_desc ,
205-                                           entry_ptr->conv_entry .data_type ,
206-                                           w->shape [0 ],
207-                                           w->shape [1 ],
208-                                           w->shape [2 ],
209-                                           w->shape [3 ]));
210-   //  Set Input
211-   MIOPEN_CALL (miopenSet4dTensorDescriptor (entry_ptr->conv_entry .input_desc ,
212-                                           entry_ptr->conv_entry .data_type ,
213-                                           x->shape [0 ],
214-                                           x->shape [1 ],
215-                                           x->shape [2 ],
216-                                           x->shape [3 ]));
217-   //  Set Output
218-   MIOPEN_CALL (miopenSet4dTensorDescriptor (entry_ptr->conv_entry .output_desc ,
219-                                           entry_ptr->conv_entry .data_type ,
220-                                           y->shape [0 ],
221-                                           y->shape [1 ],
222-                                           y->shape [2 ],
223-                                           y->shape [3 ]));
224- 
225-   const  float  alpha = 1 .f ;
226-   const  float  beta = 0 .f ;
227-   MIOPEN_CALL (miopenConvolutionForward (entry_ptr->handle ,
228-                                        &alpha,
229-                                        entry_ptr->conv_entry .input_desc ,
230-                                        x->data ,
231-                                        entry_ptr->conv_entry .filter_desc ,
232-                                        w->data ,
233-                                        entry_ptr->conv_entry .conv_desc ,
234-                                        entry_ptr->conv_entry .fwd_algo ,
235-                                        &beta,
236-                                        entry_ptr->conv_entry .output_desc ,
237-                                        y->data ,
238-                                        entry_ptr->conv_entry .workspace ,
239-                                        entry_ptr->conv_entry .workspace_size ));
240- });
145+     .set_body([](TVMArgs args, TVMRetValue* ret) {
146+       const  int  mode = args[0 ];
147+       const  int  dtype = args[1 ];
148+       const  int  pad_h = args[2 ];
149+       const  int  pad_w = args[3 ];
150+       const  int  stride_h = args[4 ];
151+       const  int  stride_w = args[5 ];
152+       const  int  dilation_h = args[6 ];
153+       const  int  dilation_w = args[7 ];
154+       const  int  algo = args[8 ];
155+       const  DLTensor* x = args[9 ];
156+       const  DLTensor* w = args[10 ];
157+       const  DLTensor* y = args[11 ];
158+ 
159+       MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal ();
160+       entry_ptr->conv_entry .fwd_algo  =
161+           static_cast <miopenConvFwdAlgorithm_t>(algo);
162+       //  Set Mode
163+       entry_ptr->conv_entry .mode  = static_cast <miopenConvolutionMode_t>(mode);
164+       //  Set Ctx
165+       entry_ptr->conv_entry .ctx  = x->ctx ;
166+       //  Set Data Type
167+       entry_ptr->conv_entry .data_type  = static_cast <miopenDataType_t>(
168+           dtype);  //  MIOpen suppports fp32(miopenFloat), fp16(miopenHalf) at
169+                    //  this moment.
170+       //  Set Desc
171+       MIOPEN_CALL (miopenInitConvolutionDescriptor (
172+           entry_ptr->conv_entry .conv_desc , entry_ptr->conv_entry .mode , pad_h,
173+           pad_w, stride_h, stride_w, dilation_h, dilation_w));
174+       //  Set Filter
175+       MIOPEN_CALL (miopenSet4dTensorDescriptor (
176+           entry_ptr->conv_entry .filter_desc , entry_ptr->conv_entry .data_type ,
177+           w->shape [0 ], w->shape [1 ], w->shape [2 ], w->shape [3 ]));
178+       //  Set Input
179+       MIOPEN_CALL (miopenSet4dTensorDescriptor (
180+           entry_ptr->conv_entry .input_desc , entry_ptr->conv_entry .data_type ,
181+           x->shape [0 ], x->shape [1 ], x->shape [2 ], x->shape [3 ]));
182+       //  Set Output
183+       MIOPEN_CALL (miopenSet4dTensorDescriptor (
184+           entry_ptr->conv_entry .output_desc , entry_ptr->conv_entry .data_type ,
185+           y->shape [0 ], y->shape [1 ], y->shape [2 ], y->shape [3 ]));
186+ 
187+       const  float  alpha = 1 .f ;
188+       const  float  beta = 0 .f ;
189+       MIOPEN_CALL (miopenConvolutionForward (
190+           entry_ptr->handle , &alpha, entry_ptr->conv_entry .input_desc , x->data ,
191+           entry_ptr->conv_entry .filter_desc , w->data ,
192+           entry_ptr->conv_entry .conv_desc , entry_ptr->conv_entry .fwd_algo ,
193+           &beta, entry_ptr->conv_entry .output_desc , y->data ,
194+           entry_ptr->conv_entry .workspace ,
195+           entry_ptr->conv_entry .workspace_size ));
196+     });
241197
242198}  //  namespace miopen
243199}  //  namespace contrib
0 commit comments