Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions distarray/dist/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,11 @@ def from_dim_data_per_rank(cls, context, dim_data_per_rank, targets=None):
self.maps = [_map_from_axis_dim_dicts(axis_dim_dicts) for
axis_dim_dicts in axis_dim_dicts_per_axis]

# check for empty localarrays
sizes = self.localsizes()
if 0 in sizes:
raise ValueError("A localarray has zero size")

return self

@classmethod
Expand Down Expand Up @@ -567,6 +572,12 @@ def from_shape(cls, context, shape, dist=None, grid_shape=None,
# List of `ClientMap` objects, one per dimension.
self.maps = [map_from_sizes(*args)
for args in zip(self.shape, self.dist, self.grid_shape)]

# check for empty localarrays
sizes = self.localsizes()
if 0 in sizes:
raise ValueError("A localarray has zero size")

return self

def __init__(self, context, global_dim_data, targets=None):
Expand Down Expand Up @@ -673,6 +684,11 @@ def __init__(self, context, global_dim_data, targets=None):
nelts = reduce(operator.mul, self.grid_shape, 1)
self.rank_from_coords = np.arange(nelts).reshape(self.grid_shape)

# check for empty localarrays
sizes = self.localsizes()
if 0 in sizes:
raise ValueError("A localarray has zero size")

def __getitem__(self, idx):
return self.maps[idx]

Expand Down Expand Up @@ -782,3 +798,10 @@ def reduce(self, axes):

def localshapes(self):
return shapes_from_dim_data_per_rank(self.get_dim_data_per_rank())

def localsizes(self):
lshapes = shapes_from_dim_data_per_rank(self.get_dim_data_per_rank())
sizes = []
for shape in lshapes:
sizes.append(reduce(operator.mul, shape, 1))
return tuple(sizes)