@@ -2158,62 +2158,79 @@ def tensordot(
     a = as_tensor_variable(a)
     b = as_tensor_variable(b)
     runtime_shape_a = a.shape
-    bcast_a = a.broadcastable
     static_shape_a = a.type.shape
-    ndim_a = a.ndim
+    ndim_a = a.type.ndim
     runtime_shape_b = b.shape
-    bcast_b = b.broadcastable
     static_shape_b = b.type.shape
-    ndim_b = b.ndim
+    ndim_b = b.type.ndim
     if na != nb:
         raise ValueError(
             "The number of axes supplied for tensordot must be equal for each tensor. "
             f"Got {na} and {nb} respectively."
         )
     axes_a = list(normalize_axis_tuple(axes_a, ndim_a))
     axes_b = list(normalize_axis_tuple(axes_b, ndim_b))
+
+    # The operation is only valid if the original dimensions match in length
+    # The ravelling of the dimensions to coerce the operation into a single dot
+    # could mask such errors, so we add an Assert if needed.
     must_assert_runtime = False
-    for k in range(na):
-        ax_a = axes_a[k]
-        ax_b = axes_b[k]
-        if (bcast_a[ax_a] != bcast_b[ax_b]) or (
+    for ax_a, ax_b in zip(axes_a, axes_b, strict=True):
+        if (
             static_shape_a[ax_a] is not None
             and static_shape_b[ax_b] is not None
             and static_shape_a[ax_a] != static_shape_b[ax_b]
         ):
             raise ValueError(
-                "Input arrays have inconsistent broadcastable pattern or type shape along the axes "
+                "Input arrays have inconsistent type shape along the axes "
                 "that are to be reduced with tensordot."
             )
         elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None:
             if must_assert_runtime:
                 a = Assert(
                     "Input array shape along reduced axes of tensordot are not equal"
-                )(a, eq(a.shape[ax_a], b.shape[ax_b]))
+                )(a, eq(runtime_shape_a[ax_a], runtime_shape_b[ax_b]))
             must_assert_runtime = True
 
-    # Move the axes to sum over to the end of "a"
-    # and to the front of "b"
-    notin = [k for k in range(ndim_a) if k not in axes_a]
-    newaxes_a = notin + axes_a
-    N2 = 1
-    for axis in axes_a:
-        N2 *= runtime_shape_a[axis]
-    newshape_a = (-1, N2)
-    olda = [runtime_shape_a[axis] for axis in notin]
-
-    notin = [k for k in range(ndim_b) if k not in axes_b]
-    newaxes_b = axes_b + notin
-    N2 = 1
-    for axis in axes_b:
-        N2 *= runtime_shape_b[axis]
-    newshape_b = (N2, -1)
-    oldb = [runtime_shape_b[axis] for axis in notin]
-
-    at = a.transpose(newaxes_a).reshape(newshape_a)
-    bt = b.transpose(newaxes_b).reshape(newshape_b)
-    res = _dot(at, bt)
-    return res.reshape(olda + oldb)
+    # Convert tensordot into a stacked dot product.
+    # We stack the summed axes and the non-summed axes of each tensor separately,
+    # and place the summed axes at the end of a and the beginning of b
+    non_summed_axes_a = [k for k in range(ndim_a) if k not in axes_a]
+    non_summed_dims_a = [runtime_shape_a[axis] for axis in non_summed_axes_a]
+    transpose_axes_a = non_summed_axes_a + axes_a
+    # We only need a reshape when we need to combine summed or non-summed dims,
+    # or introduce a new dimension (expand_dims) when doing a non-scalar outer product (axes=0)
+    a_needs_reshape = (ndim_a != 0) and (
+        (len(non_summed_axes_a) > 1) or (len(axes_a) != 1)
+    )
+
+    non_summed_axes_b = [k for k in range(ndim_b) if k not in axes_b]
+    non_summed_dims_b = [runtime_shape_b[axis] for axis in non_summed_axes_b]
+    transpose_axes_b = axes_b + non_summed_axes_b
+    b_needs_reshape = (ndim_b != 0) and (
+        (len(non_summed_axes_b) > 1) or (len(axes_b) != 1)
+    )
+
+    # summed_size_a and summed_size_b must be the same,
+    # but to facilitate reasoning about useless reshapes we compute both from their shapes
+    at = a.transpose(transpose_axes_a)
+    if a_needs_reshape:
+        non_summed_size_a = variadic_mul(*non_summed_dims_a)
+        summed_size_a = variadic_mul(*[runtime_shape_a[axis] for axis in axes_a])
+        at = at.reshape((non_summed_size_a, summed_size_a))
+
+    bt = b.transpose(transpose_axes_b)
+    if b_needs_reshape:
+        non_summed_size_b = variadic_mul(*non_summed_dims_b)
+        summed_size_b = variadic_mul(*[runtime_shape_b[axis] for axis in axes_b])
+        bt = bt.reshape((summed_size_b, non_summed_size_b))
+
+    res = dot(at, bt)
+
+    if a_needs_reshape or b_needs_reshape:
+        res = res.reshape(non_summed_dims_a + non_summed_dims_b)
+
+    return res
 
 
 def outer(x, y):
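
For reference, below is a minimal NumPy-only sketch (not part of the commit; the variable names are illustrative) of the transpose + reshape + dot equivalence that the rewritten `tensordot` relies on, checked against `np.tensordot`:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(2, 3, 4))
b = rng.normal(size=(4, 3, 5))
axes_a, axes_b = [1, 2], [1, 0]  # pair a's axes of size (3, 4) with b's axes of size (3, 4)

non_summed_a = [ax for ax in range(a.ndim) if ax not in axes_a]
non_summed_b = [ax for ax in range(b.ndim) if ax not in axes_b]

# Summed axes go to the end of `a` and the front of `b`, then each tensor is
# collapsed to a matrix so the whole contraction becomes one matrix product.
at = a.transpose(non_summed_a + axes_a).reshape(
    -1, int(np.prod([a.shape[ax] for ax in axes_a]))
)
bt = b.transpose(axes_b + non_summed_b).reshape(
    int(np.prod([b.shape[ax] for ax in axes_b])), -1
)

res = (at @ bt).reshape(
    [a.shape[ax] for ax in non_summed_a] + [b.shape[ax] for ax in non_summed_b]
)
np.testing.assert_allclose(res, np.tensordot(a, b, axes=(axes_a, axes_b)))

# When each tensor has at most one summed and one non-summed axis, the
# transpose/reshape steps are no-ops and can be skipped; detecting this is
# what the a_needs_reshape / b_needs_reshape flags in the diff are for.
m = rng.normal(size=(2, 3))
n = rng.normal(size=(3, 5))
np.testing.assert_allclose(m @ n, np.tensordot(m, n, axes=([1], [0])))
```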