33import datetime as dt
44import itertools
55import warnings
6+ from collections import ChainMap
67from collections .abc import Callable , Generator , Hashable , Sequence
78from functools import partial
89from numbers import Number
@@ -710,59 +711,66 @@ def interpolate_variable(
710711 func , kwargs = _get_interpolator_nd (method , ** kwargs )
711712
712713 in_coords , result_coords = zip (* (v for v in indexes_coords .values ()), strict = True )
713- # broadcast out manually to minize confusing behaviour
714- broadcast_result_coords = broadcast_variables (* result_coords )
715- result_dims = broadcast_result_coords [0 ].dims
716714
717715 # input coordinates along which we are interpolation are core dimensions
718716 # the corresponding output coordinates may or may not have the same name,
719717 # so `all_in_core_dims` is also `exclude_dims`
720718 all_in_core_dims = set (indexes_coords )
721719
720+ result_dims = OrderedSet (itertools .chain (* (_ .dims for _ in result_coords )))
721+ result_sizes = ChainMap (* (_ .sizes for _ in result_coords ))
722+
722723 # any dimensions on the output that are present on the input, but are not being
723- # interpolated along are broadcast or loop dimensions along which we automatically
724- # vectorize. Consider the problem in
725- # https://github.com/pydata/xarray/issues/6799#issuecomment-2474126217
724+ # interpolated along are dimensions along which we automatically vectorize.
725+ # Consider the problem in https://github.com/pydata/xarray/issues/6799#issuecomment-2474126217
726726 # In the following, dimension names are listed out in [].
727727 # # da[time, q, lat, lon].interp(q=bar[lat,lon]). Here `lat`, `lon`
728- # are input dimensions, present on the output, along which we vectorize.
729- # We track these as "result broadcast dimensions" .
728+ # are input dimensions, present on the output, but are not the coordinates
729+ # we are explicitly interpolating. These are the dimensions along which we vectorize .
730730 # `q` is the only input core dimensions, and changes size (disappears)
731731 # so it is in exclude_dims.
732- result_broadcast_dims = set (
733- itertools .chain (dim for dim in result_dims if dim not in all_in_core_dims )
734- )
732+ vectorize_dims = (result_dims - all_in_core_dims ) & set (var .dims )
735733
736734 # remove any output broadcast dimensions from the list of core dimensions
737- output_core_dims = tuple (d for d in result_dims if d not in result_broadcast_dims )
735+ output_core_dims = tuple (d for d in result_dims if d not in vectorize_dims )
738736 input_core_dims = (
739737 # all coordinates on the input that we interpolate along
740738 [tuple (indexes_coords )]
741739 # the input coordinates are always 1D at the moment, so we just need to list out their names
742740 + [tuple (_ .dims ) for _ in in_coords ]
743741 # The last set of inputs are the coordinates we are interpolating to.
744- # These have been broadcast already for ease.
745- + [output_core_dims ] * len (result_coords )
742+ + [
743+ tuple (d for d in coord .dims if d not in vectorize_dims )
744+ for coord in result_coords
745+ ]
746746 )
747- output_sizes = {k : broadcast_result_coords [ 0 ]. sizes [k ] for k in output_core_dims }
747+ output_sizes = {k : result_sizes [k ] for k in output_core_dims }
748748
749749 # scipy.interpolate.interp1d always forces to float.
750750 dtype = float if not issubclass (var .dtype .type , np .inexact ) else var .dtype
751751 result = apply_ufunc (
752752 _interpnd ,
753753 var ,
754754 * in_coords ,
755- * broadcast_result_coords ,
755+ * result_coords ,
756756 input_core_dims = input_core_dims ,
757757 output_core_dims = [output_core_dims ],
758758 exclude_dims = all_in_core_dims ,
759759 dask = "parallelized" ,
760- kwargs = dict (interp_func = func , interp_kwargs = kwargs ),
760+ kwargs = dict (
761+ interp_func = func ,
762+ interp_kwargs = kwargs ,
763+ # we leave broadcasting up to dask if possible
764+ # but we need broadcasted values in _interpnd, so propagate that
765+ # context (dimension names), and broadcast there
766+ # This would be unnecessary if we could tell apply_ufunc
767+ # to insert size-1 broadcast dimensions
768+ result_coord_core_dims = input_core_dims [- len (result_coords ) :],
769+ ),
761770 # TODO: deprecate and have the user rechunk themselves
762771 dask_gufunc_kwargs = dict (output_sizes = output_sizes , allow_rechunk = True ),
763772 output_dtypes = [dtype ],
764- # if there are any broadcast dims on the result, we must vectorize on them
765- vectorize = bool (result_broadcast_dims ),
773+ vectorize = bool (vectorize_dims ),
766774 keep_attrs = True ,
767775 )
768776 return result
@@ -787,7 +795,11 @@ def _interp1d(
787795
788796
789797def _interpnd (
790- data : np .ndarray , * coords : np .ndarray , interp_func : InterpCallable , interp_kwargs
798+ data : np .ndarray ,
799+ * coords : np .ndarray ,
800+ interp_func : InterpCallable ,
801+ interp_kwargs ,
802+ result_coord_core_dims ,
791803) -> np .ndarray :
792804 """
793805 Core nD array interpolation routine.
@@ -801,10 +813,12 @@ def _interpnd(
801813 # Convert everything to Variables, since that makes applying
802814 # `_localize` and `_floatize_x` much easier
803815 x = [Variable ([f"dim_{ nconst + dim } " ], _x ) for dim , _x in enumerate (coords [:n_x ])]
804- new_x = [
805- Variable ([f"dim_{ ndim + dim } " for dim in range (_x .ndim )], _x )
806- for _x in coords [n_x :]
807- ]
816+ new_x = broadcast_variables (
817+ * (
818+ Variable (dims , _x )
819+ for dims , _x in zip (result_coord_core_dims , coords [n_x :], strict = True )
820+ )
821+ )
808822 var = Variable ([f"dim_{ dim } " for dim in range (ndim )], data )
809823
810824 if interp_kwargs .get ("method" ) in ["linear" , "nearest" ]:
0 commit comments