     tensor,
     uint_dtypes,
 )
-from pytensor.tensor.utils import as_list, normalize_reduce_axis
+from pytensor.tensor.utils import normalize_reduce_axis
 from pytensor.tensor.variable import (
     TensorVariable,
     _tensor_py_operators,
@@ -1919,133 +1919,6 @@ def dense_dot(a, b): |
     return _dot(a, b)
 
 
-def _tensordot_as_dot(a, b, axes, dot, batched):
-    """
-    Reduces a tensor dot product to a matrix or vector dot product. Based
-    on code from Tijmen Tieleman's gnumpy
-    (http://www.cs.toronto.edu/~tijmen/gnumpy.html).
-
-    Please see the documentation of tensordot for the meaning of the a, b
-    and axes arguments.
-
-    :param dot: a function that accepts two symbolic variables and computes
-                the appropriate dot product (e.g. dot, batched_dot)
-    :type dot: function
-
-    :param batched: whether to treat the first axis of a and b as a batch
-                    axis. If so, this axis will be preserved in the output,
-                    allowing this function to be used also for batched
-                    tensor dot products.
-    :type batched: boolean
-
-    :returns: a tensor with shape equal to the concatenation of a's shape
-              (less any dimensions that were summed over) and b's shape
-              (less the first dimension and any dimensions that were summed
-              over).
-    :rtype: symbolic tensor
-    """
-    a, b = as_tensor_variable(a), as_tensor_variable(b)
-
-    if not np.isscalar(axes) and len(axes) != 2:
-        raise ValueError(
-            "Axes should be an integer or a "
-            f"list/tuple of len 2 ({axes} was provided)"
-        )
-
-    # if 'axes' is a number of axes to multiply and sum over (trailing axes
-    # of a, leading axes of b), we can just reshape and use dot.
-    elif np.isscalar(axes):
-        axes = int(axes)
-
-        for operand_name, operand in (("a", a), ("b", b)):
-            if axes > operand.ndim:
-                raise ValueError(
-                    f"axes can not be larger than the dimension of {operand_name} "
-                    f"({operand_name}.ndim={operand.ndim}, axes={axes})"
-                )
-            if batched and axes == operand.ndim:
-                raise ValueError(
-                    "axes to sum over must not include the batch axis "
-                    f"of {operand_name} ({operand_name}.ndim={operand.ndim}, axes={axes})"
-                )
-
-        batch_axes = 1 if batched else 0
-        a_outaxes = slice(0, a.ndim - axes)
-        b_outaxes = slice(batch_axes + axes, b.ndim)
-        outshape = concatenate([a.shape[a_outaxes], b.shape[b_outaxes]])
-        outbcast = a.broadcastable[a_outaxes] + b.broadcastable[b_outaxes]
-        outndim = len(outbcast)
-
-        a_shape = [1] * 2
-        b_shape = [1] * 2
-
-        # compute total size of summed axes
-        for i in range(0, axes):
-            a_shape[1] *= a.shape[-(i + 1)]
-            b_shape[0] *= b.shape[batch_axes + i]
-        # compute total size of other axes
-        for i in range(0, a.ndim - axes - batch_axes):
-            a_shape[0] *= a.shape[batch_axes + i]
-        for i in range(0, b.ndim - axes - batch_axes):
-            b_shape[1] *= b.shape[-(i + 1)]
-
-        if batched:
-            a_shape.insert(0, a.shape[0])
-            b_shape.insert(0, b.shape[0])
-
-        a_reshaped = a.reshape(a_shape)
-        b_reshaped = b.reshape(b_shape)
-
-        out_reshaped = dot(a_reshaped, b_reshaped)
-        out = out_reshaped.reshape(outshape, ndim=outndim)
-        # Make sure the broadcastable pattern of the result is correct,
-        # since some shape information can be lost in the reshapes.
-        if out.type.broadcastable != outbcast:
-            out = specify_broadcastable(
-                out, *(ax for (ax, b) in enumerate(outbcast) if b)
-            )
-        return out
-
-    # if 'axes' is a list, transpose a and b such that the summed axes of a
-    # are last and the summed axes of b are first.
-    else:
-        axes = [as_list(axes_) for axes_ in axes]
-
-        if len(axes[0]) != len(axes[1]):
-            raise ValueError("Axes elements must have the same length.")
-
-        for i, (operand_name, operand) in enumerate((("a", a), ("b", b))):
-            if len(axes[i]) > operand.ndim:
-                raise ValueError(
-                    f"axes[{i}] should be array_like with length less than "
-                    f"the dimensions of {operand_name} ({operand_name}.ndim={operand.ndim}, len(axes[0])={len(axes[i])})."
-                )
-            if len(axes[i]) > 0 and np.max(axes[i]) >= operand.ndim:
-                raise ValueError(
-                    f"axes[{i}] contains dimensions greater than or equal "
-                    f"to {operand_name}.ndim ({operand_name}.ndim={operand.ndim}, max(axes[0])={np.max(np.array(axes[i]))})."
-                )
-            if batched and 0 in axes[i]:
-                raise ValueError(
-                    "axes to sum over must not contain the batch axis "
-                    f"(axes[{i}]={axes[i]})"
-                )
-
-        batch_axes = [0] if batched else []
-        other_axes = [
-            [x for x in range(operand.ndim) if x not in axes[i] and x not in batch_axes]
-            for i, operand in enumerate((a, b))
-        ]
-
-        a_shuffled = a.dimshuffle(batch_axes + other_axes[0] + axes[0])
-        b_shuffled = b.dimshuffle(batch_axes + axes[1] + other_axes[1])
-
-        # now that a and b are in the right order, recur with integer axes
-        return _tensordot_as_dot(
-            a_shuffled, b_shuffled, len(axes[0]), dot=dot, batched=batched
-        )
-
-
 def tensordot(
     a: TensorLike, b: TensorLike, axes: int | Sequence[Sequence[int]] = 2
 ) -> TensorVariable:
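
For reference, a minimal NumPy sketch (not part of this commit; shapes and names are illustrative) of the reduction the removed `_tensordot_as_dot` helper performed: the summed axes of `a` are moved to the end and the summed axes of `b` to the front, after which the tensor dot product becomes one reshape per operand, a single matrix dot, and a reshape of the result back to the expected output shape.

```python
import numpy as np

# Sketch only: mirror the reduction from the removed helper using NumPy.
rng = np.random.default_rng(0)
a = rng.normal(size=(2, 3, 4))
b = rng.normal(size=(4, 5, 3))
axes = ([1, 2], [2, 0])  # pair a's axes 1, 2 with b's axes 2, 0

# Move the summed axes of `a` to the end and those of `b` to the front.
a_t = np.moveaxis(a, axes[0], [-2, -1])  # shape (2, 3, 4)
b_t = np.moveaxis(b, axes[1], [0, 1])    # shape (3, 4, 5)

# Collapse the summed axes and perform a single matrix dot.
k = a_t.shape[-2] * a_t.shape[-1]        # total size of the summed axes
out = (a_t.reshape(-1, k) @ b_t.reshape(k, -1)).reshape(
    a_t.shape[:-2] + b_t.shape[2:]
)  # shape (2, 5)

np.testing.assert_allclose(out, np.tensordot(a, b, axes=axes))
```

The batched variant of the helper preserved a leading batch axis on both operands and applied the same reshapes to the remaining dimensions.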
|