diff --git a/distarray/dist/maps.py b/distarray/dist/maps.py index 9fd98671..1ac96114 100644 --- a/distarray/dist/maps.py +++ b/distarray/dist/maps.py @@ -612,7 +612,17 @@ def __new__(cls, context, shape, dist=None, grid_shape=None, targets=None): # list of `ClientMap` objects, one per dimension. maps = [map_from_sizes(*args) for args in zip(shape, dist, grid_shape)] - return cls.from_maps(context=context, maps=maps, targets=targets) + + self = cls.from_maps(context=context, maps=maps, targets=targets) + + # TODO: FIXME: this is a workaround. The reason we slice here is to + # return a distribution with no empty local shapes. The `from_maps()` + # classmethod should be fixed to ensure no empty local arrays are + # created in the first place. That will remove the need to slice the + # distribution to remove empty localshapes. + if all(d in ('n', 'b') for d in self.dist): + self = self.slice((slice(None),)*self.ndim) + return self @classmethod def from_global_dim_data(cls, context, global_dim_data, targets=None): diff --git a/distarray/dist/tests/test_maps.py b/distarray/dist/tests/test_maps.py index 2e552899..c7823bd3 100644 --- a/distarray/dist/tests/test_maps.py +++ b/distarray/dist/tests/test_maps.py @@ -300,5 +300,23 @@ def test_all_n_dist(self): self.context.ones(distribution) +class TestNoEmptyLocals(ContextTestCase): + + def test_no_empty_local_arrays_4_targets(self): + for n in range(1, 20): + dist = Distribution(self.context, shape=(n,), + dist=('b',), + targets=self.context.targets[:4]) + for ls in dist.localshapes(): + self.assertNotIn(0, ls) + + def test_no_empty_local_arrays_3_targets(self): + for n in range(1, 20): + dist = Distribution(self.context, shape=(n,), + dist=('b',), + targets=self.context.targets[:3]) + for ls in dist.localshapes(): + self.assertNotIn(0, ls) + if __name__ == '__main__': unittest.main(verbosity=2)