@@ -103,15 +103,31 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
         if ("nn.conv2d" == op_name) {
           Conv2d(nid);
         } else if ("dnnl.conv2d_relu" == op_name) {
-          Conv2d(nid, true, false);
+          Conv2d(nid, true, false, dnnl::algorithm::eltwise_relu);
+        } else if ("dnnl.conv2d_tanh" == op_name) {
+          Conv2d(nid, true, false, dnnl::algorithm::eltwise_tanh);
+        } else if ("dnnl.conv2d_sigmoid" == op_name) {
+          Conv2d(nid, true, false, dnnl::algorithm::eltwise_logistic);
+        } else if ("dnnl.conv2d_bias" == op_name) {
+          Conv2d(nid, false, true);
         } else if ("dnnl.conv2d_bias_relu" == op_name) {
-          Conv2d(nid, true, true);
+          Conv2d(nid, true, true, dnnl::algorithm::eltwise_relu);
+        } else if ("dnnl.conv2d_bias_tanh" == op_name) {
+          Conv2d(nid, true, true, dnnl::algorithm::eltwise_tanh);
+        } else if ("dnnl.conv2d_bias_sigmoid" == op_name) {
+          Conv2d(nid, true, true, dnnl::algorithm::eltwise_logistic);
         } else if ("nn.dense" == op_name) {
           Dense(nid);
+        } else if ("dnnl.dense_bias" == op_name) {
+          Dense(nid, true);
         } else if ("nn.batch_norm" == op_name) {
           BatchNorm(nid);
         } else if ("nn.relu" == op_name) {
-          Relu(nid);
+          Eltwise(nid, dnnl::algorithm::eltwise_relu);
+        } else if ("tanh" == op_name) {
+          Eltwise(nid, dnnl::algorithm::eltwise_tanh);
+        } else if ("sigmoid" == op_name) {
+          Eltwise(nid, dnnl::algorithm::eltwise_logistic);
         } else if ("add" == op_name) {
           Binary(nid, dnnl::algorithm::binary_add);
         } else if ("multiply" == op_name) {
@@ -150,7 +166,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     return entry_out_mem_[eid].first;
   }
 
-  void Conv2d(const size_t& nid, const bool has_relu = false, const bool has_bias = false) {
+  void Conv2d(const size_t& nid, const bool has_elt = false, const bool has_bias = false,
+              dnnl::algorithm algo = dnnl::algorithm::eltwise_relu) {
     auto node = nodes_[nid];
 
     // Setup attributes.
@@ -159,24 +176,29 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
     dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
     std::vector<std::string> str_strides = node.GetAttr<std::vector<std::string>>("strides");
+    std::vector<std::string> str_dilates = node.GetAttr<std::vector<std::string>>("dilation");
     std::vector<std::string> str_padding = node.GetAttr<std::vector<std::string>>("padding");
     dnnl::memory::dim groups = std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);
 
-    dnnl::memory::dim N = input_shape[0],       // batch size
-        IC = input_shape[1],                    // input channels
-        IH = input_shape[2],                    // input height
-        IW = input_shape[3],                    // input width
-        OC = weight_shape[0],                   // output channels
-        KH = weight_shape[2],                   // weight height
-        KW = weight_shape[3],                   // weight width
-        PW_L = std::stoi(str_padding[1]),       // width padding: left
-        PW_R = std::stoi(str_padding[3]),       // width padding: right
-        PH_L = std::stoi(str_padding[0]),       // height padding: top
-        PH_R = std::stoi(str_padding[2]),       // height padding: bottom
-        SH = std::stoi(str_strides[0]),         // height-wise stride
-        SW = std::stoi(str_strides[1]),         // weight-wise stride
-        OH = (IH - KH + PH_L + PH_R) / SH + 1,  // output height
-        OW = (IW - KW + PW_L + PW_R) / SW + 1;  // output width
+    dnnl::memory::dim N = input_shape[0],            // batch size
+        IC = input_shape[1],                         // input channels
+        IH = input_shape[2],                         // input height
+        IW = input_shape[3],                         // input width
+        OC = weight_shape[0],                        // output channels
+        KH = weight_shape[2],                        // weight height
+        KW = weight_shape[3],                        // weight width
+        PW_L = std::stoi(str_padding[1]),            // width padding: left
+        PW_R = std::stoi(str_padding[3]),            // width padding: right
+        PH_L = std::stoi(str_padding[0]),            // height padding: top
+        PH_R = std::stoi(str_padding[2]),            // height padding: bottom
+        SH = std::stoi(str_strides[0]),              // height-wise stride
+        SW = std::stoi(str_strides[1]),              // weight-wise stride
+        DH = std::stoi(str_dilates[0]) - 1,          // height-wise dilate
+        DW = std::stoi(str_dilates[1]) - 1,          // weight-wise dilate
+        DKH = 1 + (KH - 1) * (DH + 1),               // dilated weight height
+        DKW = 1 + (KW - 1) * (DW + 1),               // dilated weight width
+        OH = (IH - DKH + PH_L + PH_R) / SH + 1,      // output height
+        OW = (IW - DKW + PW_L + PW_R) / SW + 1;      // output width
 
     // Memory shapes.
     dnnl::memory::dims src_dims = {N, IC, IH, IW};
@@ -187,6 +209,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     dnnl::memory::dims bias_dims = {OC};
     dnnl::memory::dims dst_dims = {N, OC, OH, OW};
     dnnl::memory::dims strides_dims = {SH, SW};
+    dnnl::memory::dims dilates_dims = {DH, DW};
     dnnl::memory::dims padding_dims_l = {PH_L, PW_L};
     dnnl::memory::dims padding_dims_r = {PH_R, PW_R};
 
@@ -199,13 +222,14 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     // Covn2d description.
     auto conv_desc = dnnl::convolution_forward::desc(
         dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, conv_src_md,
-        conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, padding_dims_l, padding_dims_r);
+        conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, dilates_dims, padding_dims_l,
+        padding_dims_r);
 
-    // Enable ReLU
+    // Enable elementwise post-ops
     dnnl::primitive_attr attr;
-    if (has_relu) {
+    if (has_elt) {
       dnnl::post_ops ops;
-      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
+      ops.append_eltwise(1.f, algo, 0.f, 0.f);
       attr.set_post_ops(ops);
     }
 
@@ -245,7 +269,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
                          {DNNL_ARG_DST, conv2d_dst_memory}});
   }
 
-  void Dense(const size_t& nid) {
+  void Dense(const size_t& nid, const bool has_bias = false) {
     auto node = nodes_[nid];
 
     // Setup attributes.
@@ -281,9 +305,18 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     // Memories.
     auto data_memory = BindDNNLMemory(data_entry, data_md);
     auto weight_memory = BindDNNLMemory(weight_entry, weight_md);
+
+    // Bias memory.
     auto bias_memory = dnnl::memory(bias_md, engine_);
-    float bias[OC] = {0};
-    write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float));
+    if (has_bias) {
+      auto bias_entry = node.GetInputs()[2];
+      BindDNNLMemory(bias_entry, bias_memory);
+    } else {
+      float bias[OC] = {0};
+      write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float));
+    }
+
+    // Output memory.
     JSONGraphNodeEntry out_entry(nid, 0);
     auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc());
 
@@ -335,20 +368,20 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
                          {DNNL_ARG_VARIANCE, variance_memory}});
   }
 
-  void Relu(const size_t& nid) {
+  void Eltwise(const size_t& nid, dnnl::algorithm algo) {
     auto node = nodes_[nid];
 
     auto data_entry = node.GetInputs()[0];
     dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
     dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32);
 
-    auto relu_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference,
-                                                 dnnl::algorithm::eltwise_relu, data_md, 0);
-    auto relu_prim_desc = dnnl::eltwise_forward::primitive_desc(relu_desc, engine_);
-    ICHECK(data_md == relu_prim_desc.dst_desc());
+    auto elt_desc =
+        dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, data_md, 0);
+    auto elt_prim_desc = dnnl::eltwise_forward::primitive_desc(elt_desc, engine_);
+    ICHECK(data_md == elt_prim_desc.dst_desc());
 
-    auto relu = dnnl::eltwise_forward(relu_prim_desc);
-    net_.push_back(relu);
+    auto elt = dnnl::eltwise_forward(elt_prim_desc);
+    net_.push_back(elt);
 
     auto data_memory = BindDNNLMemory(data_entry, data_md);
     JSONGraphNodeEntry out_entry(nid, 0);
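
Note on the dilation handling added in Conv2d above: oneDNN treats a dilation value of 0 as a dense kernel, while the "dilation" attribute coming from the graph uses 1 for the dense case, which is why the runtime stores DH = dilation - 1 and derives the dilated kernel extent DKH/DKW before computing output sizes. Below is a minimal standalone sketch of that arithmetic with assumed example values (32x32 input, 3x3 kernel, stride 1, padding 1, dilation 2); it is illustrative only and not part of the runtime:

// Standalone sketch of the dilated output-size arithmetic used in Conv2d above.
// Assumed example values: 32x32 input, 3x3 kernel, stride 1, padding 1, dilation 2.
#include <cstdio>

int main() {
  const int IH = 32, KH = 3, PH_L = 1, PH_R = 1, SH = 1;
  const int dilation = 2;                            // graph attribute: 1 == dense kernel
  const int DH = dilation - 1;                       // oneDNN convention: 0 == dense kernel
  const int DKH = 1 + (KH - 1) * (DH + 1);           // effective (dilated) kernel height
  const int OH = (IH - DKH + PH_L + PH_R) / SH + 1;  // output height
  std::printf("DKH = %d, OH = %d\n", DKH, OH);       // prints: DKH = 5, OH = 30
  return 0;
}

With these example numbers the effective kernel height is 5 and the output height is 30, matching the OH/OW expressions in the diff.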