@@ -1026,6 +1026,52 @@ def local_Shape_of_SpecifyShape(fgraph, node):
10261026 return [stack (shape ).astype (np .int64 )]
10271027
10281028
1029+ @register_canonicalize
1030+ @register_specialize
1031+ @node_rewriter ([SpecifyShape ])
1032+ def local_specify_shape_lift (fgraph , node ):
1033+ """Lift SpecifyShape of Elemwise towards the inputs."""
1034+ inp , * shape = node .inputs
1035+ if inp .owner and isinstance (inp .owner .op , Elemwise ):
1036+ if len (inp .owner .outputs ) != 1 :
1037+ return None
1038+
1039+ elem_inps = inp .owner .inputs
1040+ if len (elem_inps ) == 1 :
1041+ new_elem_inps = [specify_shape (elem_inps [0 ], shape )]
1042+ else :
1043+ # Rewrite does not support case where specify_shape provides new broadcastable information,
1044+ # As that may require a specify_shape for each input
1045+ out_broadcastable = node .outputs [0 ].type .broadcastable
1046+ if out_broadcastable != inp .type .broadcastable :
1047+ return None
1048+
1049+ # All non-broadcastable dimensions of inputs must match the non-broadcastbale specify_shape dims
1050+ # We look for a sufficient input to assign all the specify_shape dims
1051+ # We could consider distributing the SpecifyShape across multiple inputs, when none is sufficient
1052+
1053+ nonbcast_dims = {
1054+ i
1055+ for i , (dim , bcast ) in enumerate (zip (shape , out_broadcastable ))
1056+ if (not bcast and not NoneConst .equals (dim ))
1057+ }
1058+ new_elem_inps = elem_inps .copy ()
1059+ for i , elem_inp in enumerate (elem_inps ):
1060+ if all (
1061+ bcast_dim is False
1062+ for dim , bcast_dim in enumerate (elem_inp .type .broadcastable )
1063+ if dim in nonbcast_dims
1064+ ):
1065+ new_elem_inps [i ] = specify_shape (elem_inp , shape )
1066+ break
1067+ else : # no-break, no sufficient candidate found
1068+ return None
1069+
1070+ new_out = inp .owner .op .make_node (* new_elem_inps ).outputs
1071+ copy_stack_trace (node .outputs , new_out )
1072+ return new_out
1073+
1074+
10291075@register_useless
10301076@register_canonicalize
10311077@node_rewriter ([Shape_i ])
0 commit comments