Skip to content

Commit b632238

Browse files
author
Siyuan Feng
committed
[Relax] Support left_shift and right_shift op
Introduced left_shift and right_shift op in Relax with ONNX frontend support.
1 parent accd582 commit b632238

File tree

10 files changed

+182
-8
lines changed

10 files changed

+182
-8
lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 94 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ class BinaryBase(OnnxOpConverter):
244244
relax_op: Callable = None
245245

246246
@classmethod
247-
def _impl_v1(cls, bb, inputs, attr, params):
247+
def base_impl(cls, bb, inputs, attr, params):
248248
if cls.numpy_op is None or cls.relax_op is None:
249249
raise ValueError("Numpy and Relax operators must be defined for BinaryBase.")
250250
if all([isinstance(inp, relax.Constant) for inp in inputs]):
@@ -274,83 +274,131 @@ class Add(BinaryBase):
274274
numpy_op = _np.add
275275
relax_op = relax.op.add
276276

277+
@classmethod
278+
def _impl_v1(cls, bb, inputs, attr, params):
279+
return cls.base_impl(bb, inputs, attr, params)
280+
277281

278282
class Sub(BinaryBase):
279283
"""Converts an onnx Sub node into an equivalent Relax expression."""
280284

281285
numpy_op = _np.subtract
282286
relax_op = relax.op.subtract
283287

288+
@classmethod
289+
def _impl_v1(cls, bb, inputs, attr, params):
290+
return cls.base_impl(bb, inputs, attr, params)
291+
284292

285293
class Mul(BinaryBase):
286294
"""Converts an onnx Mul node into an equivalent Relax expression."""
287295

288296
numpy_op = _np.multiply
289297
relax_op = relax.op.multiply
290298

299+
@classmethod
300+
def _impl_v1(cls, bb, inputs, attr, params):
301+
return cls.base_impl(bb, inputs, attr, params)
302+
291303

292304
class Div(BinaryBase):
293305
"""Converts an onnx Div node into an equivalent Relax expression."""
294306

295307
numpy_op = _np.divide
296308
relax_op = relax.op.divide
297309

310+
@classmethod
311+
def _impl_v1(cls, bb, inputs, attr, params):
312+
return cls.base_impl(bb, inputs, attr, params)
313+
298314

299315
class Pow(BinaryBase):
300316
"""Converts an onnx Pow node into an equivalent Relax expression."""
301317

302318
numpy_op = _np.power
303319
relax_op = relax.op.power
304320

321+
@classmethod
322+
def _impl_v1(cls, bb, inputs, attr, params):
323+
return cls.base_impl(bb, inputs, attr, params)
324+
305325

306326
class And(BinaryBase):
307327
"""Converts an onnx And node into an equivalent Relax expression."""
308328

309329
numpy_op = _np.logical_and
310330
relax_op = relax.op.logical_and
311331

332+
@classmethod
333+
def _impl_v1(cls, bb, inputs, attr, params):
334+
return cls.base_impl(bb, inputs, attr, params)
335+
312336

313337
class Or(BinaryBase):
314338
"""Converts an onnx Or node into an equivalent Relax expression."""
315339

316340
numpy_op = _np.logical_or
317341
relax_op = relax.op.logical_or
318342

343+
@classmethod
344+
def _impl_v1(cls, bb, inputs, attr, params):
345+
return cls.base_impl(bb, inputs, attr, params)
346+
319347

320348
class Xor(BinaryBase):
321349
"""Converts an onnx Xor node into an equivalent Relax expression."""
322350

323351
numpy_op = _np.logical_xor
324352
relax_op = relax.op.logical_xor
325353

354+
@classmethod
355+
def _impl_v1(cls, bb, inputs, attr, params):
356+
return cls.base_impl(bb, inputs, attr, params)
357+
326358

327359
class Less(BinaryBase):
328360
"""Converts an onnx Less node into an equivalent Relax expression."""
329361

330362
numpy_op = _np.less
331363
relax_op = relax.op.less
332364

365+
@classmethod
366+
def _impl_v1(cls, bb, inputs, attr, params):
367+
return cls.base_impl(bb, inputs, attr, params)
368+
333369

334370
class LessOrEqual(BinaryBase):
335371
"""Converts an onnx LessEqual node into an equivalent Relax expression."""
336372

337373
numpy_op = _np.less_equal
338374
relax_op = relax.op.less_equal
339375

376+
@classmethod
377+
def _impl_v1(cls, bb, inputs, attr, params):
378+
return cls.base_impl(bb, inputs, attr, params)
379+
340380

341381
class Greater(BinaryBase):
342382
"""Converts an onnx Greater node into an equivalent Relax expression."""
343383

344384
numpy_op = _np.greater
345385
relax_op = relax.op.greater
346386

387+
@classmethod
388+
def _impl_v1(cls, bb, inputs, attr, params):
389+
return cls.base_impl(bb, inputs, attr, params)
390+
347391

348392
class GreaterOrEqual(BinaryBase):
349393
"""Converts an onnx GreaterEqual node into an equivalent Relax expression."""
350394

351395
numpy_op = _np.greater_equal
352396
relax_op = relax.op.greater_equal
353397

398+
@classmethod
399+
def _impl_v1(cls, bb, inputs, attr, params):
400+
return cls.base_impl(bb, inputs, attr, params)
401+
354402

355403
class Equal(OnnxOpConverter):
356404
"""Converts an onnx Equal node into an equivalent Relax expression."""
@@ -374,39 +422,77 @@ class BitwiseBase(BinaryBase):
374422
"""Converts an onnx BitwiseBase node into an equivalent Relax expression."""
375423

376424
@classmethod
377-
def base_impl(cls, bb, inputs, attr, params, py_func, relax_op):
425+
def base_impl(cls, bb, inputs, attr, params):
378426
valid_types = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]
379427
for num, inp in enumerate(inputs):
380428
if inp.struct_info.dtype not in valid_types:
381429
raise ValueError(
382430
f"Bitwise operations expect all inputs to have integer types, "
383431
f"got {inp.struct_info.dtype} for input {num}"
384432
)
385-
return BinaryBase.base_impl(bb, inputs, attr, params, py_func, relax_op)
433+
return super().base_impl(bb, inputs, attr, params)
386434

387435

388436
class BitwiseAnd(BitwiseBase):
389437
"""Converts an onnx BitwiseAnd node into an equivalent Relax expression."""
390438

439+
numpy_op = _np.bitwise_and
440+
relax_op = relax.op.bitwise_and
441+
391442
@classmethod
392443
def _impl_v18(cls, bb, inputs, attr, params):
393-
return cls.base_impl(bb, inputs, attr, params, lambda x, y: x & y, relax.op.bitwise_and)
444+
return cls.base_impl(bb, inputs, attr, params)
394445

395446

396447
class BitwiseOr(BitwiseBase):
397448
"""Converts an onnx BitwiseOr node into an equivalent Relax expression."""
398449

450+
numpy_op = _np.bitwise_or
451+
relax_op = relax.op.bitwise_or
452+
399453
@classmethod
400454
def _impl_v18(cls, bb, inputs, attr, params):
401-
return cls.base_impl(bb, inputs, attr, params, lambda x, y: x | y, relax.op.bitwise_or)
455+
return cls.base_impl(bb, inputs, attr, params)
402456

403457

404458
class BitwiseXor(BitwiseBase):
405459
"""Converts an onnx BitwiseXor node into an equivalent Relax expression."""
406460

461+
numpy_op = _np.bitwise_xor
462+
relax_op = relax.op.bitwise_xor
463+
407464
@classmethod
408465
def _impl_v18(cls, bb, inputs, attr, params):
409-
return cls.base_impl(bb, inputs, attr, params, lambda x, y: x ^ y, relax.op.bitwise_xor)
466+
return cls.base_impl(bb, inputs, attr, params)
467+
468+
469+
class BitwiseNot(BitwiseBase):
470+
"""Converts an onnx BitwiseNot node into an equivalent Relax expression."""
471+
472+
numpy_op = _np.bitwise_not
473+
relax_op = relax.op.bitwise_not
474+
475+
@classmethod
476+
def _impl_v18(cls, bb, inputs, attr, params):
477+
return cls.base_impl(bb, inputs, attr, params)
478+
479+
480+
class BitShift(BitwiseBase):
481+
"""Converts an onnx BitShift node into an equivalent Relax expression."""
482+
483+
@classmethod
484+
def _impl_v11(cls, bb, inputs, attr, params):
485+
direction = attr.get("direction", "LEFT").decode("ascii")
486+
if direction == "LEFT":
487+
cls.numpy_op = _np.left_shift
488+
cls.relax_op = relax.op.left_shift
489+
elif direction == "RIGHT":
490+
cls.numpy_op = _np.right_shift
491+
cls.relax_op = relax.op.right_shift
492+
else:
493+
raise ValueError("Unsupported Shift Direction: " + direction)
494+
495+
return cls.base_impl(bb, inputs, attr, params)
410496

411497

412498
class Sigmoid(OnnxOpConverter):
@@ -2652,8 +2738,8 @@ def _get_convert_map():
26522738
"BitwiseAnd": BitwiseAnd,
26532739
"BitwiseOr": BitwiseOr,
26542740
"BitwiseXor": BitwiseXor,
2655-
# "BitwiseNot": BitwiseNot,
2656-
# "BitwiseShift": BitwiseShift,
2741+
"BitwiseNot": BitwiseNot,
2742+
"BitShift": BitShift,
26572743
"And": And,
26582744
"Or": Or,
26592745
"Xor": Xor,

python/tvm/relax/op/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
floor_divide,
5353
greater,
5454
greater_equal,
55+
left_shift,
5556
less,
5657
less_equal,
5758
logical_and,
@@ -62,6 +63,7 @@
6263
multiply,
6364
not_equal,
6465
power,
66+
right_shift,
6567
subtract,
6668
)
6769
from .create import (

python/tvm/relax/op/binary.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,3 +386,35 @@ def bitwise_xor(x1: Expr, x2: Expr) -> Expr:
386386
The computed result.
387387
"""
388388
return _ffi_api.bitwise_xor(x1, x2)
389+
390+
391+
def left_shift(x1: Expr, x2: Expr) -> Expr:
392+
"""Bitwise Shift Left
393+
Parameters
394+
----------
395+
x1 : relax.Expr
396+
The input tensor to be shifted.
397+
x2 : relax.Expr
398+
The number of positions to shift.
399+
Returns
400+
-------
401+
result : relax.Expr
402+
The computed result.
403+
"""
404+
return _ffi_api.left_shift(x1, x2)
405+
406+
407+
def right_shift(x1: Expr, x2: Expr) -> Expr:
408+
"""Bitwise Shift Right
409+
Parameters
410+
----------
411+
x1 : relax.Expr
412+
The input tensor to be shifted.
413+
x2 : relax.Expr
414+
The number of positions to shift.
415+
Returns
416+
-------
417+
result : relax.Expr
418+
The computed result.
419+
"""
420+
return _ffi_api.right_shift(x1, x2)

python/tvm/relax/transform/legalize_ops/binary.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def binary_call_te(bb: BlockBuilder, call: Call) -> Expr:
6262
register_legalize("relax.bitwise_and", _binary(topi.bitwise_and))
6363
register_legalize("relax.bitwise_or", _binary(topi.bitwise_or))
6464
register_legalize("relax.bitwise_xor", _binary(topi.bitwise_xor))
65+
register_legalize("relax.left_shift", _binary(topi.left_shift))
66+
register_legalize("relax.right_shift", _binary(topi.right_shift))
6567

6668
# logical
6769
register_legalize("relax.logical_and", _binary(topi.logical_and))

python/tvm/script/ir_builder/relax/ir.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
isinf,
103103
isnan,
104104
layout_transform,
105+
left_shift,
105106
less,
106107
less_equal,
107108
linear,
@@ -133,6 +134,7 @@
133134
quantize,
134135
repeat,
135136
reshape,
137+
right_shift,
136138
round,
137139
rsqrt,
138140
scatter_elements,
@@ -773,6 +775,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
773775
"isinf",
774776
"isnan",
775777
"layout_transform",
778+
"left_shift",
776779
"less",
777780
"less_equal",
778781
"linear",
@@ -809,6 +812,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
809812
"repeat",
810813
"reshape",
811814
"rewriter",
815+
"right_shift",
812816
"tensor_to_shape",
813817
"shape_to_tensor",
814818
"rocm",

src/relax/op/distributed/binary.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(logical_xor);
6868
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_and);
6969
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_or);
7070
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_xor);
71+
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(left_shift);
72+
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(right_shift);
7173

7274
} // namespace distributed
7375
} // namespace relax

src/relax/op/tensor/binary.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(logical_xor);
207207
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_and);
208208
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_or);
209209
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_xor);
210+
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(left_shift);
211+
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(right_shift);
210212

211213
} // namespace relax
212214
} // namespace tvm

src/relax/op/tensor/binary.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,12 @@ Expr bitwise_or(Expr x1, Expr x2);
129129
/*! \brief Broadcasted element-wise bitwise xor */
130130
Expr bitwise_xor(Expr x1, Expr x2);
131131

132+
/*! \brief Broadcasted element-wise bitwise shift left */
133+
Expr left_shift(Expr x1, Expr x2);
134+
135+
/*! \brief Broadcasted element-wise bitwise shift right */
136+
Expr right_shift(Expr x1, Expr x2);
137+
132138
} // namespace relax
133139
} // namespace tvm
134140

0 commit comments

Comments
 (0)