|
57 | 57 |
|
58 | 58 | from xarray.core.dataarray import DataArray |
59 | 59 | from xarray.core.dataset import Dataset |
60 | | - from xarray.core.types import GroupIndex, GroupIndices, GroupInput, GroupKey |
| 60 | + from xarray.core.types import ( |
| 61 | + GroupIndex, |
| 62 | + GroupIndices, |
| 63 | + GroupInput, |
| 64 | + GroupKey, |
| 65 | + T_Chunks, |
| 66 | + ) |
61 | 67 | from xarray.core.utils import Frozen |
62 | 68 | from xarray.groupers import EncodedGroups, Grouper |
63 | 69 |
|
@@ -676,6 +682,76 @@ def sizes(self) -> Mapping[Hashable, int]: |
676 | 682 | self._sizes = self._obj.isel({self._group_dim: index}).sizes |
677 | 683 | return self._sizes |
678 | 684 |
|
| 685 | + def shuffle_to_chunks(self, chunks: T_Chunks = None) -> T_Xarray: |
| 686 | + """ |
| 687 | + Sort or "shuffle" the underlying object. |
| 688 | +
|
| 689 | + "Shuffle" means the object is sorted so that all group members occur sequentially, |
| 690 | + in the same chunk. Multiple groups may occur in the same chunk. |
| 691 | + This method is particularly useful for chunked arrays (e.g. dask, cubed), |
| 692 | + especially when you need to map a function that requires all members of a group |
| 693 | + to be present in a single chunk. For chunked array types, the order of appearance |
| 694 | + is not guaranteed, but will depend on the input chunking. |
| 695 | +
|
| 696 | + Parameters |
| 697 | + ---------- |
| 698 | + chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional |
| 699 | + How to adjust chunks along dimensions not present in the array being grouped by. |
| 700 | +
|
| 701 | + Returns |
| 702 | + ------- |
| 703 | + DataArray or Dataset |
| 704 | +
|
| 705 | + Examples |
| 706 | + -------- |
| 707 | + >>> import dask.array |
| 708 | + >>> da = xr.DataArray( |
| 709 | + ... dims="x", |
| 710 | + ... data=dask.array.arange(10, chunks=3), |
| 711 | + ... coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]}, |
| 712 | + ... name="a", |
| 713 | + ... ) |
| 714 | + >>> shuffled = da.groupby("x").shuffle_to_chunks() |
| 715 | + >>> shuffled |
| 716 | + <xarray.DataArray 'a' (x: 10)> Size: 80B |
| 717 | + dask.array<shuffle, shape=(10,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray> |
| 718 | + Coordinates: |
| 719 | + * x (x) int64 80B 0 1 1 1 2 2 2 3 3 3 |
| 720 | +
|
| 721 | + >>> shuffled.groupby("x").quantile(q=0.5).compute() |
| 722 | + <xarray.DataArray 'a' (x: 4)> Size: 32B |
| 723 | + array([9., 3., 4., 5.]) |
| 724 | + Coordinates: |
| 725 | + quantile float64 8B 0.5 |
| 726 | + * x (x) int64 32B 0 1 2 3 |
| 727 | +
|
| 728 | + See Also |
| 729 | + -------- |
| 730 | + dask.dataframe.DataFrame.shuffle |
| 731 | + dask.array.shuffle |
| 732 | + """ |
| 733 | + self._raise_if_by_is_chunked() |
| 734 | + return self._shuffle_obj(chunks) |
| 735 | + |
| 736 | + def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray: |
| 737 | + from xarray.core.dataarray import DataArray |
| 738 | + |
| 739 | + was_array = isinstance(self._obj, DataArray) |
| 740 | + as_dataset = self._obj._to_temp_dataset() if was_array else self._obj |
| 741 | + |
| 742 | + for grouper in self.groupers: |
| 743 | + if grouper.name not in as_dataset._variables: |
| 744 | + as_dataset.coords[grouper.name] = grouper.group |
| 745 | + |
| 746 | + shuffled = as_dataset._shuffle( |
| 747 | + dim=self._group_dim, indices=self.encoded.group_indices, chunks=chunks |
| 748 | + ) |
| 749 | + unstacked: Dataset = self._maybe_unstack(shuffled) |
| 750 | + if was_array: |
| 751 | + return self._obj._from_temp_dataset(unstacked) |
| 752 | + else: |
| 753 | + return unstacked # type: ignore[return-value] |
| 754 | + |
679 | 755 | def map( |
680 | 756 | self, |
681 | 757 | func: Callable, |
@@ -896,7 +972,9 @@ def _maybe_unstack(self, obj): |
896 | 972 | # and `inserted_dims` |
897 | 973 | # if multiple groupers all share the same single dimension, then |
898 | 974 | # we don't stack/unstack. Do that manually now. |
899 | | - obj = obj.unstack(*self.encoded.unique_coord.dims) |
| 975 | + dims_to_unstack = self.encoded.unique_coord.dims |
| 976 | + if all(dim in obj.dims for dim in dims_to_unstack): |
| 977 | + obj = obj.unstack(*dims_to_unstack) |
900 | 978 | to_drop = [ |
901 | 979 | grouper.name |
902 | 980 | for grouper in self.groupers |
|
0 commit comments