1010from xarray .core .alignment import align , broadcast
1111from xarray .core .computation import apply_ufunc , dot
1212from xarray .core .pycompat import is_duck_dask_array
13- from xarray .core .types import Dims , T_Xarray
13+ from xarray .core .types import Dims , T_DataArray , T_Xarray
1414from xarray .util .deprecation_helpers import _deprecate_positional_args
1515
1616# Weighted quantile methods are a subset of the numpy supported quantile methods.
@@ -145,7 +145,7 @@ class Weighted(Generic[T_Xarray]):
145145
146146 __slots__ = ("obj" , "weights" )
147147
148- def __init__ (self , obj : T_Xarray , weights : DataArray ) -> None :
148+ def __init__ (self , obj : T_Xarray , weights : T_DataArray ) -> None :
149149 """
150150 Create a Weighted object
151151
@@ -189,7 +189,7 @@ def _weight_check(w):
189189 _weight_check (weights .data )
190190
191191 self .obj : T_Xarray = obj
192- self .weights : DataArray = weights
192+ self .weights : T_DataArray = weights
193193
194194 def _check_dim (self , dim : Dims ):
195195 """raise an error if any dimension is missing"""
@@ -208,11 +208,11 @@ def _check_dim(self, dim: Dims):
208208
209209 @staticmethod
210210 def _reduce (
211- da : DataArray ,
212- weights : DataArray ,
211+ da : T_DataArray ,
212+ weights : T_DataArray ,
213213 dim : Dims = None ,
214214 skipna : bool | None = None ,
215- ) -> DataArray :
215+ ) -> T_DataArray :
216216 """reduce using dot; equivalent to (da * weights).sum(dim, skipna)
217217
218218 for internal use only
@@ -230,7 +230,7 @@ def _reduce(
230230 # DataArray (if `weights` has additional dimensions)
231231 return dot (da , weights , dim = dim )
232232
233- def _sum_of_weights (self , da : DataArray , dim : Dims = None ) -> DataArray :
233+ def _sum_of_weights (self , da : T_DataArray , dim : Dims = None ) -> T_DataArray :
234234 """Calculate the sum of weights, accounting for missing values"""
235235
236236 # we need to mask data values that are nan; else the weights are wrong
@@ -255,10 +255,10 @@ def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray:
255255
256256 def _sum_of_squares (
257257 self ,
258- da : DataArray ,
258+ da : T_DataArray ,
259259 dim : Dims = None ,
260260 skipna : bool | None = None ,
261- ) -> DataArray :
261+ ) -> T_DataArray :
262262 """Reduce a DataArray by a weighted ``sum_of_squares`` along some dimension(s)."""
263263
264264 demeaned = da - da .weighted (self .weights ).mean (dim = dim )
@@ -267,20 +267,20 @@ def _sum_of_squares(
267267
268268 def _weighted_sum (
269269 self ,
270- da : DataArray ,
270+ da : T_DataArray ,
271271 dim : Dims = None ,
272272 skipna : bool | None = None ,
273- ) -> DataArray :
273+ ) -> T_DataArray :
274274 """Reduce a DataArray by a weighted ``sum`` along some dimension(s)."""
275275
276276 return self ._reduce (da , self .weights , dim = dim , skipna = skipna )
277277
278278 def _weighted_mean (
279279 self ,
280- da : DataArray ,
280+ da : T_DataArray ,
281281 dim : Dims = None ,
282282 skipna : bool | None = None ,
283- ) -> DataArray :
283+ ) -> T_DataArray :
284284 """Reduce a DataArray by a weighted ``mean`` along some dimension(s)."""
285285
286286 weighted_sum = self ._weighted_sum (da , dim = dim , skipna = skipna )
@@ -291,10 +291,10 @@ def _weighted_mean(
291291
292292 def _weighted_var (
293293 self ,
294- da : DataArray ,
294+ da : T_DataArray ,
295295 dim : Dims = None ,
296296 skipna : bool | None = None ,
297- ) -> DataArray :
297+ ) -> T_DataArray :
298298 """Reduce a DataArray by a weighted ``var`` along some dimension(s)."""
299299
300300 sum_of_squares = self ._sum_of_squares (da , dim = dim , skipna = skipna )
@@ -305,21 +305,21 @@ def _weighted_var(
305305
306306 def _weighted_std (
307307 self ,
308- da : DataArray ,
308+ da : T_DataArray ,
309309 dim : Dims = None ,
310310 skipna : bool | None = None ,
311- ) -> DataArray :
311+ ) -> T_DataArray :
312312 """Reduce a DataArray by a weighted ``std`` along some dimension(s)."""
313313
314- return cast ("DataArray " , np .sqrt (self ._weighted_var (da , dim , skipna )))
314+ return cast ("T_DataArray " , np .sqrt (self ._weighted_var (da , dim , skipna )))
315315
316316 def _weighted_quantile (
317317 self ,
318- da : DataArray ,
318+ da : T_DataArray ,
319319 q : ArrayLike ,
320320 dim : Dims = None ,
321321 skipna : bool | None = None ,
322- ) -> DataArray :
322+ ) -> T_DataArray :
323323 """Apply a weighted ``quantile`` to a DataArray along some dimension(s)."""
324324
325325 def _get_h (n : float , q : np .ndarray , method : QUANTILE_METHODS ) -> np .ndarray :
0 commit comments