@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Hashable, Iterable, overload
+from typing import TYPE_CHECKING, Any, Hashable, Iterable, cast, overload
 
 import pandas as pd
 
@@ -14,42 +14,41 @@
     merge_attrs,
     merge_collected,
 )
+from .types import T_DataArray, T_Dataset
 from .variable import Variable
 from .variable import concat as concat_vars
 
 if TYPE_CHECKING:
-    from .dataarray import DataArray
-    from .dataset import Dataset
     from .types import CombineAttrsOptions, CompatOptions, ConcatOptions, JoinOptions
 
 
 @overload
 def concat(
-    objs: Iterable[Dataset],
-    dim: Hashable | DataArray | pd.Index,
+    objs: Iterable[T_Dataset],
+    dim: Hashable | T_DataArray | pd.Index,
     data_vars: ConcatOptions | list[Hashable] = "all",
     coords: ConcatOptions | list[Hashable] = "different",
     compat: CompatOptions = "equals",
     positions: Iterable[Iterable[int]] | None = None,
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
-) -> Dataset:
+) -> T_Dataset:
     ...
 
 
 @overload
 def concat(
-    objs: Iterable[DataArray],
-    dim: Hashable | DataArray | pd.Index,
+    objs: Iterable[T_DataArray],
+    dim: Hashable | T_DataArray | pd.Index,
     data_vars: ConcatOptions | list[Hashable] = "all",
     coords: ConcatOptions | list[Hashable] = "different",
     compat: CompatOptions = "equals",
     positions: Iterable[Iterable[int]] | None = None,
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
-) -> DataArray:
+) -> T_DataArray:
     ...
 
 
@@ -402,7 +401,7 @@ def process_subset_opt(opt, subset):
 
 # determine dimensional coordinate names and a dict mapping name to DataArray
 def _parse_datasets(
-    datasets: Iterable[Dataset],
+    datasets: Iterable[T_Dataset],
 ) -> tuple[dict[Hashable, Variable], dict[Hashable, int], set[Hashable], set[Hashable]]:
 
     dims: set[Hashable] = set()
@@ -429,16 +428,16 @@ def _parse_datasets(
 
 
 def _dataset_concat(
-    datasets: list[Dataset],
-    dim: str | DataArray | pd.Index,
+    datasets: list[T_Dataset],
+    dim: str | T_DataArray | pd.Index,
     data_vars: str | list[str],
     coords: str | list[str],
     compat: CompatOptions,
     positions: Iterable[Iterable[int]] | None,
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
-) -> Dataset:
+) -> T_Dataset:
     """
     Concatenate a sequence of datasets along a new or existing dimension
     """
@@ -482,7 +481,8 @@ def _dataset_concat(
 
     # case where concat dimension is a coordinate or data_var but not a dimension
     if (dim in coord_names or dim in data_names) and dim not in dim_names:
-        datasets = [ds.expand_dims(dim) for ds in datasets]
+        # TODO: Overriding type because .expand_dims has incorrect typing:
+        datasets = [cast(T_Dataset, ds.expand_dims(dim)) for ds in datasets]
 
     # determine which variables to concatenate
     concat_over, equals, concat_dim_lengths = _calc_concat_over(
@@ -590,7 +590,7 @@ def get_indexes(name):
             # preserves original variable order
             result_vars[name] = result_vars.pop(name)
 
-    result = Dataset(result_vars, attrs=result_attrs)
+    result = type(datasets[0])(result_vars, attrs=result_attrs)
 
     absent_coord_names = coord_names - set(result.variables)
     if absent_coord_names:
@@ -618,16 +618,16 @@ def get_indexes(name):
 
 
 def _dataarray_concat(
-    arrays: Iterable[DataArray],
-    dim: str | DataArray | pd.Index,
+    arrays: Iterable[T_DataArray],
+    dim: str | T_DataArray | pd.Index,
     data_vars: str | list[str],
     coords: str | list[str],
     compat: CompatOptions,
     positions: Iterable[Iterable[int]] | None,
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
-) -> DataArray:
+) -> T_DataArray:
     from .dataarray import DataArray
 
     arrays = list(arrays)
@@ -650,7 +650,8 @@ def _dataarray_concat(
             if compat == "identical":
                 raise ValueError("array names not identical")
             else:
-                # TODO: Overriding type because .rename has incorrect typing:
-                arr = arr.rename(name)
+                arr = cast(T_DataArray, arr.rename(name))
         datasets.append(arr._to_temp_dataset())
 
     ds = _dataset_concat(
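Note: the diff above replaces concrete Dataset/DataArray annotations with the TypeVars T_Dataset/T_DataArray imported from .types, builds the result via type(datasets[0])(...) so subclasses of Dataset round-trip through concat, and uses typing.cast() where intermediate methods (.expand_dims, .rename) still advertise the base class. Below is a minimal, self-contained sketch of that pattern under the assumption that T_Dataset is a TypeVar bound to Dataset; the Dataset/CustomDataset classes here are placeholders for illustration, not xarray's real types.

from __future__ import annotations

from typing import TypeVar, cast


class Dataset:
    """Stand-in for xarray's Dataset (illustration only)."""


class CustomDataset(Dataset):
    """A hypothetical user subclass that should survive concat()."""


# Bound TypeVar: any Dataset subclass is accepted, and the same
# subclass is promised back to the caller by the annotations.
T_Dataset = TypeVar("T_Dataset", bound=Dataset)


def concat(datasets: list[T_Dataset]) -> T_Dataset:
    # Construct the result from the runtime class of the first input,
    # mirroring `type(datasets[0])(result_vars, ...)` in the diff.
    result = type(datasets[0])()
    # cast() is a no-op at runtime; it only adjusts the static type,
    # which is how the diff papers over methods whose stubs still
    # return the base class.
    return cast(T_Dataset, result)


out = concat([CustomDataset(), CustomDataset()])
assert isinstance(out, CustomDataset)  # subclass preserved at runtime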