Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions lib/iris/fileformats/netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,9 +1190,9 @@ def write(

# Ensure that attributes are CF compliant and if possible to make them
# compliant.
self.check_attribute_compliance(cube, cube.lazy_data())
self.check_attribute_compliance(cube, cube.dtype)
for coord in cube.coords():
self.check_attribute_compliance(coord, coord.points)
self.check_attribute_compliance(coord, coord.dtype)

# Get suitable dimension names.
mesh_dimensions, cube_dimensions = self._get_dim_names(cube)
Expand Down Expand Up @@ -1280,16 +1280,14 @@ def write(
warnings.warn(msg)

@staticmethod
def check_attribute_compliance(container, data):
def check_attribute_compliance(container, data_dtype):
def _coerce_value(val_attr, val_attr_value, data_dtype):
val_attr_tmp = np.array(val_attr_value, dtype=data_dtype)
if (val_attr_tmp != val_attr_value).any():
msg = '"{}" is not of a suitable value ({})'
raise ValueError(msg.format(val_attr, val_attr_value))
return val_attr_tmp

data_dtype = data.dtype

# Ensure that conflicting attributes are not provided.
if (
container.attributes.get("valid_min") is not None
Expand Down
6 changes: 6 additions & 0 deletions lib/iris/tests/integration/test_netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,16 @@ def test_lazy_preserved_save(self):
)
acube = iris.load_cube(fpath, "air_temperature")
self.assertTrue(acube.has_lazy_data())
# Also check a coord with lazy points + bounds.
self.assertTrue(acube.coord("forecast_period").has_lazy_points())
self.assertTrue(acube.coord("forecast_period").has_lazy_bounds())
with self.temp_filename(".nc") as nc_path:
with Saver(nc_path, "NETCDF4") as saver:
saver.write(acube)
# Check that cube data is not realised, also coord points + bounds.
self.assertTrue(acube.has_lazy_data())
self.assertTrue(acube.coord("forecast_period").has_lazy_points())
self.assertTrue(acube.coord("forecast_period").has_lazy_bounds())


@tests.skip_data
Expand Down
28 changes: 15 additions & 13 deletions lib/iris/tests/unit/fileformats/netcdf/test_Saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ class _Common__check_attribute_compliance:

def setUp(self):
self.container = mock.Mock(name="container", attributes={})
self.data = self.array_lib.array(1, dtype="int32")
self.data_dtype = np.dtype("int32")

patch = mock.patch("netCDF4.Dataset")
_ = patch.start()
Expand All @@ -629,7 +629,7 @@ def assertAttribute(self, value):
def check_attribute_compliance_call(self, value):
self.set_attribute(value)
with Saver(mock.Mock(), "NETCDF4") as saver:
saver.check_attribute_compliance(self.container, self.data)
saver.check_attribute_compliance(self.container, self.data_dtype)


class Test_check_attribute_compliance__valid_range(
Expand All @@ -642,10 +642,10 @@ def attribute(self):
def test_valid_range_type_coerce(self):
value = self.array_lib.array([1, 2], dtype="float")
self.check_attribute_compliance_call(value)
self.assertAttribute(self.data.dtype)
self.assertAttribute(self.data_dtype)

def test_valid_range_unsigned_int8_data_signed_range(self):
self.data = self.data.astype("uint8")
self.data_dtype = np.dtype("uint8")
value = self.array_lib.array([1, 2], dtype="int8")
self.check_attribute_compliance_call(value)
self.assertAttribute(value.dtype)
Expand All @@ -658,7 +658,7 @@ def test_valid_range_cannot_coerce(self):

def test_valid_range_not_numpy_array(self):
# Ensure we handle the case when not a numpy array is provided.
self.data = self.data.astype("int8")
self.data_dtype = np.dtype("int8")
value = [1, 2]
self.check_attribute_compliance_call(value)
self.assertAttribute(np.int64)
Expand All @@ -674,10 +674,10 @@ def attribute(self):
def test_valid_range_type_coerce(self):
value = self.array_lib.array(1, dtype="float")
self.check_attribute_compliance_call(value)
self.assertAttribute(self.data.dtype)
self.assertAttribute(self.data_dtype)

def test_valid_range_unsigned_int8_data_signed_range(self):
self.data = self.data.astype("uint8")
self.data_dtype = np.dtype("uint8")
value = self.array_lib.array(1, dtype="int8")
self.check_attribute_compliance_call(value)
self.assertAttribute(value.dtype)
Expand All @@ -690,7 +690,7 @@ def test_valid_range_cannot_coerce(self):

def test_valid_range_not_numpy_array(self):
# Ensure we handle the case when not a numpy array is provided.
self.data = self.data.astype("int8")
self.data_dtype = np.dtype("int8")
value = 1
self.check_attribute_compliance_call(value)
self.assertAttribute(np.int64)
Expand All @@ -706,10 +706,10 @@ def attribute(self):
def test_valid_range_type_coerce(self):
value = self.array_lib.array(2, dtype="float")
self.check_attribute_compliance_call(value)
self.assertAttribute(self.data.dtype)
self.assertAttribute(self.data_dtype)

def test_valid_range_unsigned_int8_data_signed_range(self):
self.data = self.data.astype("uint8")
self.data_dtype = np.dtype("uint8")
value = self.array_lib.array(2, dtype="int8")
self.check_attribute_compliance_call(value)
self.assertAttribute(value.dtype)
Expand All @@ -722,7 +722,7 @@ def test_valid_range_cannot_coerce(self):

def test_valid_range_not_numpy_array(self):
# Ensure we handle the case when not a numpy array is provided.
self.data = self.data.astype("int8")
self.data_dtype = np.dtype("int8")
value = 2
self.check_attribute_compliance_call(value)
self.assertAttribute(np.int64)
Expand All @@ -733,13 +733,15 @@ class Test_check_attribute_compliance__exception_handling(
):
def test_valid_range_and_valid_min_valid_max_provided(self):
# Conflicting attributes should raise a suitable exception.
self.data = self.data.astype("int8")
self.data_dtype = np.dtype("int8")
self.container.attributes["valid_range"] = [1, 2]
self.container.attributes["valid_min"] = [1]
msg = 'Both "valid_range" and "valid_min"'
with Saver(mock.Mock(), "NETCDF4") as saver:
with self.assertRaisesRegex(ValueError, msg):
saver.check_attribute_compliance(self.container, self.data)
saver.check_attribute_compliance(
self.container, self.data_dtype
)


class Test__cf_coord_identity(tests.IrisTest):
Expand Down
4 changes: 4 additions & 0 deletions lib/iris/tests/unit/fileformats/netcdf/test_Saver__lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,7 @@ def test_lazy_streamed_bounds(self):
self.cube.replace_coord(lazy_coord)
self.save_common(self.cube)
self.assertTrue(self.store_watch.called)


if __name__ == "__main__":
tests.main()