@@ -9132,34 +9132,39 @@ def polyfit(
91329132 variables [sing .name ] = sing
91339133
91349134 # If we have a coordinate get its underlying dimension.
9135- true_dim = self .coords [dim ].dims [ 0 ]
9135+ ( true_dim ,) = self .coords [dim ].dims
91369136
9137- for name , da in self .data_vars .items ():
9138- if true_dim not in da .dims :
9137+ other_coords = {
9138+ dim : self ._variables [dim ]
9139+ for dim in set (self .dims ) - {true_dim }
9140+ if dim in self ._variables
9141+ }
9142+ present_dims = set ()
9143+ for name , var in self ._variables .items ():
9144+ if name in self ._coord_names or name in self .dims :
9145+ continue
9146+ if true_dim not in var .dims :
91399147 continue
91409148
9141- if is_duck_dask_array (da . data ) and (
9149+ if is_duck_dask_array (var . _data ) and (
91429150 rank != order or full or skipna is None
91439151 ):
91449152 # Current algorithm with dask and skipna=False neither supports
91459153 # deficient ranks nor does it output the "full" info (issue dask/dask#6516)
91469154 skipna_da = True
91479155 elif skipna is None :
9148- skipna_da = bool (np .any (da .isnull ()))
9149-
9150- dims_to_stack = [dimname for dimname in da .dims if dimname != true_dim ]
9151- stacked_coords : dict [Hashable , DataArray ] = {}
9152- if dims_to_stack :
9153- stacked_dim = utils .get_temp_dimname (dims_to_stack , "stacked" )
9154- rhs = da .transpose (true_dim , * dims_to_stack ).stack (
9155- {stacked_dim : dims_to_stack }
9156- )
9157- stacked_coords = {stacked_dim : rhs [stacked_dim ]}
9158- scale_da = scale [:, np .newaxis ]
9156+ skipna_da = bool (np .any (var .isnull ()))
9157+
9158+ if var .ndim > 1 :
9159+ rhs = var .transpose (true_dim , ...)
9160+ other_dims = rhs .dims [1 :]
9161+ scale_da = scale .reshape (- 1 , * ((1 ,) * len (other_dims )))
91599162 else :
9160- rhs = da
9163+ rhs = var
91619164 scale_da = scale
9165+ other_dims = ()
91629166
9167+ present_dims .update (* other_dims )
91639168 if w is not None :
91649169 rhs = rhs * w [:, np .newaxis ]
91659170
@@ -9179,26 +9184,15 @@ def polyfit(
91799184 # Thus a ReprObject => polyfit was called on a DataArray
91809185 name = ""
91819186
9182- coeffs = DataArray (
9183- coeffs / scale_da ,
9184- dims = [degree_dim ] + list (stacked_coords .keys ()),
9185- coords = {degree_dim : np .arange (order )[::- 1 ], ** stacked_coords },
9186- name = name + "polyfit_coefficients" ,
9187- )
9188- if dims_to_stack :
9189- coeffs = coeffs .unstack (stacked_dim )
9190- variables [coeffs .name ] = coeffs
9187+ coeffs = Variable (data = coeffs / scale_da , dims = (degree_dim ,) + other_dims )
9188+ variables [name + "polyfit_coefficients" ] = coeffs
91919189
91929190 if full or (cov is True ):
9193- residuals = DataArray (
9194- residuals if dims_to_stack else residuals .squeeze (),
9195- dims = list (stacked_coords .keys ()),
9196- coords = stacked_coords ,
9197- name = name + "polyfit_residuals" ,
9191+ residuals = Variable (
9192+ data = residuals if var .ndim > 1 else residuals .squeeze (),
9193+ dims = other_dims ,
91989194 )
9199- if dims_to_stack :
9200- residuals = residuals .unstack (stacked_dim )
9201- variables [residuals .name ] = residuals
9195+ variables [name + "polyfit_residuals" ] = residuals
92029196
92039197 if cov :
92049198 Vbase = np .linalg .inv (np .dot (lhs .T , lhs ))
@@ -9214,7 +9208,18 @@ def polyfit(
92149208 covariance = DataArray (Vbase , dims = ("cov_i" , "cov_j" )) * fac
92159209 variables [name + "polyfit_covariance" ] = covariance
92169210
9217- return type (self )(data_vars = variables , attrs = self .attrs .copy ())
9211+ return type (self )(
9212+ data_vars = variables ,
9213+ coords = {
9214+ degree_dim : np .arange (order )[::- 1 ],
9215+ ** {
9216+ name : coord
9217+ for name , coord in other_coords .items ()
9218+ if name in present_dims
9219+ },
9220+ },
9221+ attrs = self .attrs .copy (),
9222+ )
92189223
92199224 def pad (
92209225 self ,
0 commit comments