Skip to content

Commit 72d542e

Browse files
authored
[Bugfix][ONNX] Skip constant If node generated by PyTorch (#17383)
* [Bugfix][VTA] Fix FSIM compile error on macOS. VTA FSIM could not be built on macOS, for it leverages malloc.h and memalign, yet both have been deprecated and are not provided by macOS. This issue was captured in #13173. This commit stops including malloc.h in VTA Runtime as stdlib.h has provided functions we need. This commit uses posix_memalign instead of memalign. It is a portable standard function. * Fix format. * [Bugfix][ONNX] Skip constant If node generated by PyTorch This commit adds a check for If nodes for ONNX frontend of Relay to skip the broadcast if the predicate is constant. Sometimes PyTorch to ONNX inserts silly if nodes that produce dynamic ranks, and ONNX frontend of TVM would broadcast the lower dimensions between branches, which is irrational for some cases, e.g. 5×5×3×4 to 5×5×3×4×1. The predicate of silly if might be constant and reasonable to skip to avoid the broadcast problem. This issue was captured in #16898. * Fix format.
1 parent 425e15b commit 72d542e

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4565,6 +4565,23 @@ def _impl_v1(cls, inputs, attr, params):
45654565
"Attempting to unify ranks but this may produce incorrect results."
45664566
)
45674567
warnings.warn(warning_msg)
4568+
# Skip constant If node to avoid irrational broadcast
4569+
if isinstance(inputs[0], tvm.relay.expr.Constant):
4570+
predicate = inputs[0].data.asnumpy()[0]
4571+
node_name = attr["tvm_custom"]["name"]
4572+
warn_msg_begin = f"Predicate of If node {node_name} is always "
4573+
if predicate == np.bool_(True):
4574+
warnings.warn(
4575+
warn_msg_begin
4576+
+ "true so only then branch would be executed. Removing else branch. "
4577+
)
4578+
else_expr = then_expr
4579+
elif predicate == np.bool_(False):
4580+
warnings.warn(
4581+
warn_msg_begin
4582+
+ "false so only else branch would be executed. Removing then branch. "
4583+
)
4584+
then_expr = else_expr
45684585
if len(then_shape) < len(else_shape):
45694586
then_expr = _op.broadcast_to_like(then_expr, else_expr)
45704587
else:
@@ -6529,6 +6546,7 @@ def _impl_v11(cls, inputs, attr, params):
65296546
# compatible operators that do NOT require any conversion.
65306547
_identity_list = []
65316548

6549+
65326550
# _convert_map defines maps of name to converter functor(callable)
65336551
# for 1 to 1 mapping, use Renamer if nothing but name is different
65346552
# use AttrCvt if attributes need to be converted

0 commit comments

Comments
 (0)