@@ -244,7 +244,7 @@ class BinaryBase(OnnxOpConverter):
244244 relax_op : Callable = None
245245
246246 @classmethod
247- def _impl_v1 (cls , bb , inputs , attr , params ):
247+ def base_impl (cls , bb , inputs , attr , params ):
248248 if cls .numpy_op is None or cls .relax_op is None :
249249 raise ValueError ("Numpy and Relax operators must be defined for BinaryBase." )
250250 if all ([isinstance (inp , relax .Constant ) for inp in inputs ]):
@@ -274,83 +274,131 @@ class Add(BinaryBase):
274274 numpy_op = _np .add
275275 relax_op = relax .op .add
276276
277+ @classmethod
278+ def _impl_v1 (cls , bb , inputs , attr , params ):
279+ return cls .base_impl (bb , inputs , attr , params )
280+
277281
278282class Sub (BinaryBase ):
279283 """Converts an onnx Sub node into an equivalent Relax expression."""
280284
281285 numpy_op = _np .subtract
282286 relax_op = relax .op .subtract
283287
288+ @classmethod
289+ def _impl_v1 (cls , bb , inputs , attr , params ):
290+ return cls .base_impl (bb , inputs , attr , params )
291+
284292
285293class Mul (BinaryBase ):
286294 """Converts an onnx Mul node into an equivalent Relax expression."""
287295
288296 numpy_op = _np .multiply
289297 relax_op = relax .op .multiply
290298
299+ @classmethod
300+ def _impl_v1 (cls , bb , inputs , attr , params ):
301+ return cls .base_impl (bb , inputs , attr , params )
302+
291303
292304class Div (BinaryBase ):
293305 """Converts an onnx Div node into an equivalent Relax expression."""
294306
295307 numpy_op = _np .divide
296308 relax_op = relax .op .divide
297309
310+ @classmethod
311+ def _impl_v1 (cls , bb , inputs , attr , params ):
312+ return cls .base_impl (bb , inputs , attr , params )
313+
298314
299315class Pow (BinaryBase ):
300316 """Converts an onnx Pow node into an equivalent Relax expression."""
301317
302318 numpy_op = _np .power
303319 relax_op = relax .op .power
304320
321+ @classmethod
322+ def _impl_v1 (cls , bb , inputs , attr , params ):
323+ return cls .base_impl (bb , inputs , attr , params )
324+
305325
306326class And (BinaryBase ):
307327 """Converts an onnx And node into an equivalent Relax expression."""
308328
309329 numpy_op = _np .logical_and
310330 relax_op = relax .op .logical_and
311331
332+ @classmethod
333+ def _impl_v1 (cls , bb , inputs , attr , params ):
334+ return cls .base_impl (bb , inputs , attr , params )
335+
312336
313337class Or (BinaryBase ):
314338 """Converts an onnx Or node into an equivalent Relax expression."""
315339
316340 numpy_op = _np .logical_or
317341 relax_op = relax .op .logical_or
318342
343+ @classmethod
344+ def _impl_v1 (cls , bb , inputs , attr , params ):
345+ return cls .base_impl (bb , inputs , attr , params )
346+
319347
320348class Xor (BinaryBase ):
321349 """Converts an onnx Xor node into an equivalent Relax expression."""
322350
323351 numpy_op = _np .logical_xor
324352 relax_op = relax .op .logical_xor
325353
354+ @classmethod
355+ def _impl_v1 (cls , bb , inputs , attr , params ):
356+ return cls .base_impl (bb , inputs , attr , params )
357+
326358
327359class Less (BinaryBase ):
328360 """Converts an onnx Less node into an equivalent Relax expression."""
329361
330362 numpy_op = _np .less
331363 relax_op = relax .op .less
332364
365+ @classmethod
366+ def _impl_v1 (cls , bb , inputs , attr , params ):
367+ return cls .base_impl (bb , inputs , attr , params )
368+
333369
334370class LessOrEqual (BinaryBase ):
335371 """Converts an onnx LessEqual node into an equivalent Relax expression."""
336372
337373 numpy_op = _np .less_equal
338374 relax_op = relax .op .less_equal
339375
376+ @classmethod
377+ def _impl_v1 (cls , bb , inputs , attr , params ):
378+ return cls .base_impl (bb , inputs , attr , params )
379+
340380
341381class Greater (BinaryBase ):
342382 """Converts an onnx Greater node into an equivalent Relax expression."""
343383
344384 numpy_op = _np .greater
345385 relax_op = relax .op .greater
346386
387+ @classmethod
388+ def _impl_v1 (cls , bb , inputs , attr , params ):
389+ return cls .base_impl (bb , inputs , attr , params )
390+
347391
348392class GreaterOrEqual (BinaryBase ):
349393 """Converts an onnx GreaterEqual node into an equivalent Relax expression."""
350394
351395 numpy_op = _np .greater_equal
352396 relax_op = relax .op .greater_equal
353397
398+ @classmethod
399+ def _impl_v1 (cls , bb , inputs , attr , params ):
400+ return cls .base_impl (bb , inputs , attr , params )
401+
354402
355403class Equal (OnnxOpConverter ):
356404 """Converts an onnx Equal node into an equivalent Relax expression."""
@@ -374,39 +422,77 @@ class BitwiseBase(BinaryBase):
374422 """Converts an onnx BitwiseBase node into an equivalent Relax expression."""
375423
376424 @classmethod
377- def base_impl (cls , bb , inputs , attr , params , py_func , relax_op ):
425+ def base_impl (cls , bb , inputs , attr , params ):
378426 valid_types = ["int8" , "int16" , "int32" , "int64" , "uint8" , "uint16" , "uint32" , "uint64" ]
379427 for num , inp in enumerate (inputs ):
380428 if inp .struct_info .dtype not in valid_types :
381429 raise ValueError (
382430 f"Bitwise operations expect all inputs to have integer types, "
383431 f"got { inp .struct_info .dtype } for input { num } "
384432 )
385- return BinaryBase .base_impl (bb , inputs , attr , params , py_func , relax_op )
433+ return super () .base_impl (bb , inputs , attr , params )
386434
387435
388436class BitwiseAnd (BitwiseBase ):
389437 """Converts an onnx BitwiseAnd node into an equivalent Relax expression."""
390438
439+ numpy_op = _np .bitwise_and
440+ relax_op = relax .op .bitwise_and
441+
391442 @classmethod
392443 def _impl_v18 (cls , bb , inputs , attr , params ):
393- return cls .base_impl (bb , inputs , attr , params , lambda x , y : x & y , relax . op . bitwise_and )
444+ return cls .base_impl (bb , inputs , attr , params )
394445
395446
396447class BitwiseOr (BitwiseBase ):
397448 """Converts an onnx BitwiseOr node into an equivalent Relax expression."""
398449
450+ numpy_op = _np .bitwise_or
451+ relax_op = relax .op .bitwise_or
452+
399453 @classmethod
400454 def _impl_v18 (cls , bb , inputs , attr , params ):
401- return cls .base_impl (bb , inputs , attr , params , lambda x , y : x | y , relax . op . bitwise_or )
455+ return cls .base_impl (bb , inputs , attr , params )
402456
403457
404458class BitwiseXor (BitwiseBase ):
405459 """Converts an onnx BitwiseXor node into an equivalent Relax expression."""
406460
461+ numpy_op = _np .bitwise_xor
462+ relax_op = relax .op .bitwise_xor
463+
407464 @classmethod
408465 def _impl_v18 (cls , bb , inputs , attr , params ):
409- return cls .base_impl (bb , inputs , attr , params , lambda x , y : x ^ y , relax .op .bitwise_xor )
466+ return cls .base_impl (bb , inputs , attr , params )
467+
468+
469+ class BitwiseNot (BitwiseBase ):
470+ """Converts an onnx BitwiseNot node into an equivalent Relax expression."""
471+
472+ numpy_op = _np .bitwise_not
473+ relax_op = relax .op .bitwise_not
474+
475+ @classmethod
476+ def _impl_v18 (cls , bb , inputs , attr , params ):
477+ return cls .base_impl (bb , inputs , attr , params )
478+
479+
480+ class BitShift (BitwiseBase ):
481+ """Converts an onnx BitShift node into an equivalent Relax expression."""
482+
483+ @classmethod
484+ def _impl_v11 (cls , bb , inputs , attr , params ):
485+ direction = attr .get ("direction" , "LEFT" ).decode ("ascii" )
486+ if direction == "LEFT" :
487+ cls .numpy_op = _np .left_shift
488+ cls .relax_op = relax .op .left_shift
489+ elif direction == "RIGHT" :
490+ cls .numpy_op = _np .right_shift
491+ cls .relax_op = relax .op .right_shift
492+ else :
493+ raise ValueError ("Unsupported Shift Direction: " + direction )
494+
495+ return cls .base_impl (bb , inputs , attr , params )
410496
411497
412498class Sigmoid (OnnxOpConverter ):
@@ -2652,8 +2738,8 @@ def _get_convert_map():
26522738 "BitwiseAnd" : BitwiseAnd ,
26532739 "BitwiseOr" : BitwiseOr ,
26542740 "BitwiseXor" : BitwiseXor ,
2655- # "BitwiseNot": BitwiseNot,
2656- # "BitwiseShift ": BitwiseShift ,
2741+ "BitwiseNot" : BitwiseNot ,
2742+ "BitShift " : BitShift ,
26572743 "And" : And ,
26582744 "Or" : Or ,
26592745 "Xor" : Xor ,
0 commit comments