3636 register_useless ,
3737 topo_constant_folding ,
3838)
39+ from pytensor .tensor .rewriting .elemwise import apply_local_dimshuffle_lift
3940from pytensor .tensor .shape import (
4041 Reshape ,
4142 Shape ,
@@ -757,40 +758,36 @@ def apply(self, fgraph):
757758pytensor .compile .mode .optdb .register ("UnShapeOpt" , UnShapeOptimizer (), position = 10 )
758759
759760
@register_useless
@register_canonicalize
@node_rewriter([Reshape])
def local_useless_expand_dims_in_reshape(fgraph, node):
    """
    Drop an expand_dims `DimShuffle` that feeds directly into a `Reshape`:

      reshape(expand_dims(vector, axis=0), shp) => reshape(vector, shp)
      reshape(expand_dims(matrix, axis=(0, 2)), shp) => reshape(matrix, shp)

    Only the ``"x"`` entries of the `DimShuffle` order are removed. Whatever
    remains of the order (e.g. an implicit, useless squeeze) is re-applied
    as its own `DimShuffle`, since that is part of the canonical form of
    the graph.

    Returns ``False`` when the reshaped input is not the output of a
    `DimShuffle` whose ``augment`` attribute is non-empty; otherwise returns
    a one-element list holding the replacement variable.
    """
    reshape_input, target_shape = node.inputs

    producer = reshape_input.owner
    if producer is None:
        return False

    dimshuffle_op = producer.op
    if not isinstance(dimshuffle_op, DimShuffle):
        return False
    if not dimshuffle_op.augment:
        # The DimShuffle introduces no new axes: nothing for this rewrite to remove.
        return False

    (base,) = producer.inputs

    # The original order with every inserted ("x") axis stripped out.
    kept_order = tuple(entry for entry in dimshuffle_op.new_order if entry != "x")
    identity_order = tuple(range(base.type.ndim))
    if kept_order != identity_order:
        # Something other than pure axis insertion is left over; keep it in the graph.
        base = base.dimshuffle(kept_order)

    replacement = base.reshape(target_shape)
    copy_stack_trace(node.outputs[0], replacement)
    return [replacement]
794791
795792
796793@register_canonicalize ("shape_unsafe" )
@@ -920,10 +917,10 @@ def local_useless_reshape(fgraph, node):
920917
921918 shape_feature = getattr (fgraph , "shape_feature" , None )
922919
923- # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for -1
924- # or cases where all but one dimension are provably preserved
920+ # Match case where at least (n-1) entries correspond to the original shape:
921+ # Reshape(x, [x.shape[0], ..., x.shape[-1]]), or Reshape(x, [x.shape[0], y, x.shape[2], ... x.shape[-1]])
922+ # Where y can be -1 or anything with an unknown value, since the only valid reshape is still a no-op reshape.
925923 output_shape_is = _unpack_shape_vector (output_shape )
926-
927924 nb_m1 = 0
928925 shape_match = [False ] * inp .type .ndim
929926 for dim in range (inp .type .ndim ):
@@ -935,48 +932,136 @@ def local_useless_reshape(fgraph, node):
935932 nb_m1 += 1
936933
937934 if nb_m1 <= 1 and all (shape_match ):
938- return [inp ]
935+ return [inp ] # This is provably correct
939936
940937 # There is one missing match, but all other dimensions match
938+ # Such as x.type.shape == (3, 5, None) and output_shape == (3, 5, y)
941939 if (nb_m1 == 0 ) and (shape_match .count (False ) == 1 ):
942- return [inp ]
940+ return [inp ] # This could mask a shape error
943941
944942 return False
945943
946944
@register_canonicalize("shape_unsafe")
@node_rewriter([Reshape])
def local_reshape_to_dimshuffle(fgraph, node):
    r"""Move length-1 (broadcastable) dimensions out of `Reshape` nodes.

    Squeezing an input before reshaping it is always valid. Likewise, `1`
    entries can always be removed from the target shape and re-added with an
    expand_dims applied after the (smaller) reshape.

    This canonical form separates what is unique to reshaping (mixing
    dimensions) from what a `DimShuffle` expresses more legibly (squeeze and
    expand_dims). Rewrites that target `DimShuffle` but not `Reshape` can then
    simplify further, and useless reshapes become easier to remove.

    For example:
        - reshape(col, (m, n)) -> reshape(squeeze(col, axis=1), (m, n))
        - reshape(col, (1, m, n)) -> expand_dims(reshape(squeeze(col, axis=1), (m, n)), axis=0)
        - reshape(x, (1, m, 1, n, 1, 1)) -> expand_dims(reshape(x, (m, n)), axis=(0, 2, 4, 5))

    Returns ``None`` when there is nothing to squeeze and nothing to expand;
    otherwise returns a one-element list holding the replacement variable.
    """
    inp, output_shape = node.inputs
    (output,) = node.outputs

    # Every statically broadcastable input axis gets squeezed away.
    input_bcast = inp.type.broadcastable
    axes_to_squeeze = [axis for axis, is_bcast in enumerate(input_bcast) if is_bcast]

    remaining_shape = []
    if all(input_bcast) or all(output.type.broadcastable):
        # Trivial case: all input or all output dimensions are known to be
        # broadcastable, so no actual reshape is left to perform.
        axes_to_expand = tuple(range(output.type.ndim))
    else:
        axes_to_expand = []
        for axis, length in enumerate(_unpack_shape_vector(output_shape)):
            # A constant -1 can be an implicit expand_dims, but proving it
            # requires checking that the other dimensions already account for
            # the full size of the array, e.g. np.zeros((2, 2, 2)).reshape((8, -1)).
            # The static output shape has already worked this out for some
            # (but not all) cases, so it is used as the evidence here.
            is_unit_length = isinstance(length, Constant) and (
                length.data == 1
                or (length.data == -1 and output.type.shape[axis] == 1)
            )
            if is_unit_length:
                axes_to_expand.append(axis)
            else:
                remaining_shape.append(length)

    if not (axes_to_squeeze or axes_to_expand):
        return None

    result = inp.squeeze(axes_to_squeeze)

    if remaining_shape:
        result = result.reshape(remaining_shape)
        copy_stack_trace(output, result)

    result = expand_dims(result, axes_to_expand)

    if not remaining_shape:
        # No reshape sits in between: eagerly merge the consecutive
        # squeeze and expand_dims.
        result = apply_local_dimshuffle_lift(fgraph, result)

    copy_stack_trace(output, result)
    return [result]
9781012
9791013
@register_specialize
@node_rewriter([Reshape])
def local_fuse_squeeze_reshape(fgraph, node):
    r"""If there is a squeeze right before a reshape, merge them.

    This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization.

    Returns ``None`` when the reshaped input is not the output of a
    `DimShuffle` flagged as ``is_squeeze``; otherwise returns a one-element
    list with the reshape applied directly to the squeeze's input.
    """
    x, new_shape = node.inputs

    if not (
        x.owner is not None
        and isinstance(x.owner.op, DimShuffle)
        and x.owner.op.is_squeeze
    ):
        return None

    # A reshape can always subsume a squeeze.
    unsqueezed_x = x.owner.inputs[0]
    new_reshaped_x = unsqueezed_x.reshape(new_shape)
    # Keep the stack trace of the replaced node, like the other rewrites in this file do.
    copy_stack_trace(node.outputs[0], new_reshaped_x)
    return [new_reshaped_x]
1032+
@register_specialize
@node_rewriter([DimShuffle])
def local_fuse_expand_dims_reshape(fgraph, node):
    r"""If there is an expand_dims right after a reshape, merge them.

    This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization.

    Returns ``None`` when the node is not an expand_dims, when its input is
    not produced by a `Reshape`, or when that reshape output has more than one
    client. Otherwise returns a one-element list holding a single reshape
    whose target shape has the extra `1` entries inserted.
    """
    expand_op = node.op
    if not expand_op.is_expand_dims:
        return None

    reshape_out = node.inputs[0]
    producer = reshape_out.owner
    if not (producer and isinstance(producer.op, Reshape)):
        return None

    # The reshape is used elsewhere, don't fuse as it can sometimes require a copy.
    # Example: `x = pt.matrix(); y = x.T.reshape(-1); out = y[:, None] * y[None, :]`
    if len(fgraph.clients[reshape_out]) > 1:
        return None

    base, shape_vector = producer.inputs

    # Fold the expanded axes into the reshape target as explicit length-1 entries.
    fused_shape = list(_unpack_shape_vector(shape_vector))
    for axis in expand_op.augment:
        fused_shape.insert(axis, 1)

    fused = base.reshape(fused_shape)
    copy_stack_trace(node.outputs[0], fused)
    return [fused]
1063+
1064+
9801065@register_canonicalize
9811066@register_specialize
9821067@node_rewriter ([Reshape ])
0 commit comments