@@ -529,7 +529,6 @@ def shuffle(self) -> None:
529529 """
530530 from xarray .core .dataarray import DataArray
531531 from xarray .core .dataset import Dataset
532- from xarray .core .duck_array_ops import shuffle_array
533532
534533 (grouper ,) = self .groupers
535534 dim = self ._group_dim
@@ -538,6 +537,8 @@ def shuffle(self) -> None:
538537 if all (isinstance (idx , slice ) for idx in self ._group_indices ):
539538 return
540539
540+ indices : tuple [list [int ]] = self ._group_indices # type: ignore[assignment]
541+
541542 was_array = isinstance (self ._obj , DataArray )
542543 as_dataset = self ._obj ._to_temp_dataset () if was_array else self ._obj
543544
@@ -546,21 +547,22 @@ def shuffle(self) -> None:
546547 if dim not in var .dims :
547548 shuffled [name ] = var
548549 continue
549- shuffled_data = shuffle_array (
550- var ._data , list (self ._group_indices ), axis = var .get_axis_num (dim )
551- )
552- shuffled [name ] = var ._replace (data = shuffled_data )
550+ shuffled [name ] = var ._shuffle (indices = list (indices ), dim = dim )
553551
554552 # Replace self._group_indices with slices
555553 slices = []
556554 start = 0
557555 for idxr in self ._group_indices :
556+ if TYPE_CHECKING :
557+ assert not isinstance (idxr , slice )
558558 slices .append (slice (start , start + len (idxr )))
559559 start += len (idxr )
560560 # TODO: we have now broken the invariant
561561 # self._group_indices ≠ self.groupers[0].group_indices
562562 self ._group_indices = tuple (slices )
563563 if was_array :
564+ if TYPE_CHECKING :
565+ assert isinstance (self ._obj , DataArray )
564566 self ._obj = self ._obj ._from_temp_dataset (shuffled )
565567 else :
566568 self ._obj = shuffled
0 commit comments