Closed
Description
A small eager optimization we can do to avoid useless dimshuffles: just return `a` if `not axis`, before `dim_it` is created.
`pytensor/pytensor/tensor/basic.py`, lines 4187 to 4206 in a76172e:
```python
def expand_dims(
    a: np.ndarray | TensorVariable, axis: tuple[int, ...]
) -> TensorVariable:
    """Expand the shape of an array.
    Insert a new axis that will appear at the `axis` position in the expanded
    array shape.
    """
    a = as_tensor(a)
    if not isinstance(axis, tuple | list):
        axis = (axis,)
    out_ndim = len(axis) + a.ndim
    axis = np.core.numeric.normalize_axis_tuple(axis, out_ndim)
    dim_it = iter(range(a.ndim))
    pattern = ["x" if ax in axis else next(dim_it) for ax in range(out_ndim)]
    return a.dimshuffle(pattern)
```
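For illustration, here is a NumPy-only sketch of the proposed change, written so it runs without PyTensor installed. It mirrors the `expand_dims` above under a few assumptions: `np.asarray` stands in for `as_tensor`, the order-preserving `dimshuffle` pattern is emulated with `reshape`, and a hypothetical `_normalize_axes` helper replaces `np.core.numeric.normalize_axis_tuple`. The only new logic is the early return just before `dim_it`.

```python
import numpy as np


def _normalize_axes(axis, out_ndim):
    # Hypothetical stand-in for NumPy's normalize_axis_tuple: wrap negative
    # axes and reject out-of-range or repeated ones.
    normalized = []
    for ax in axis:
        if not -out_ndim <= ax < out_ndim:
            raise ValueError(f"axis {ax} is out of bounds for ndim {out_ndim}")
        normalized.append(ax % out_ndim)
    if len(set(normalized)) != len(normalized):
        raise ValueError("repeated axis")
    return tuple(normalized)


def expand_dims(a, axis):
    """NumPy analogue of the PyTensor expand_dims above, with the early return."""
    a = np.asarray(a)  # assumption: plays the role of as_tensor(a)

    if not isinstance(axis, (tuple, list)):
        axis = (axis,)

    out_ndim = len(axis) + a.ndim
    axis = _normalize_axes(axis, out_ndim)

    # Proposed optimization: with no axes to insert, hand back the input
    # itself instead of building an identity "dimshuffle".
    if not axis:
        return a

    # Same pattern construction as the original: "x" marks a new length-1
    # axis, integers keep the existing axes in their original order.
    dim_it = iter(range(a.ndim))
    pattern = ["x" if ax in axis else next(dim_it) for ax in range(out_ndim)]

    # Because existing axes never move, this pattern is equivalent to a reshape.
    new_shape = tuple(1 if p == "x" else a.shape[p] for p in pattern)
    return a.reshape(new_shape)
```

With an empty `axis`, the input object itself comes back (`expand_dims(x, ()) is x`), which is the kind of no-op shortcut the issue asks for in place of an identity `DimShuffle`.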
`squeeze` already does this:
`pytensor/pytensor/tensor/extra_ops.py`, lines 591 to 593 in a76172e:
```python
if not axis:
    # Nothing to do
    return _x
```