Skip to content

Commit 6e32204

Browse files
committed
preserve vlen string dtypes, allow vlen string fill_values
1 parent 99f8446 commit 6e32204

File tree

5 files changed

+51
-46
lines changed

5 files changed

+51
-46
lines changed

xarray/backends/h5netcdf_.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -266,15 +266,6 @@ def prepare_variable(
266266
dtype = _get_datatype(variable, raise_on_invalid_encoding=check_encoding)
267267

268268
fillvalue = attrs.pop("_FillValue", None)
269-
if dtype is str and fillvalue is not None:
270-
raise NotImplementedError(
271-
"h5netcdf does not yet support setting a fill value for "
272-
"variable-length strings "
273-
"(https://github.com/h5netcdf/h5netcdf/issues/37). "
274-
f"Either remove '_FillValue' from encoding on variable {name!r} "
275-
"or set {'dtype': 'S1'} in encoding to use the fixed width "
276-
"NC_CHAR type."
277-
)
278269

279270
if dtype is str:
280271
dtype = h5py.special_dtype(vlen=str)

xarray/backends/netCDF4_.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -490,16 +490,6 @@ def prepare_variable(
490490

491491
fill_value = attrs.pop("_FillValue", None)
492492

493-
if datatype is str and fill_value is not None:
494-
raise NotImplementedError(
495-
"netCDF4 does not yet support setting a fill value for "
496-
"variable-length strings "
497-
"(https://github.com/Unidata/netcdf4-python/issues/730). "
498-
f"Either remove '_FillValue' from encoding on variable {name!r} "
499-
"or set {'dtype': 'S1'} in encoding to use the fixed width "
500-
"NC_CHAR type."
501-
)
502-
503493
encoding = _extract_nc4_variable_encoding(
504494
variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims
505495
)

xarray/coding/variables.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,3 +547,15 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
547547

548548
def decode(self):
549549
raise NotImplementedError()
550+
551+
552+
class ObjectStringCoder(VariableCoder):
    """Coder that casts object-dtype variables to ``str`` on decode.

    Only decoding is implemented; ``encode`` always raises.
    """

    def encode(self):
        """Encoding is not supported by this coder.

        Raises
        ------
        NotImplementedError
            Always.
        """
        # Raise the exception instead of returning the exception class, so a
        # caller can never mistake the result for an encoded variable.
        raise NotImplementedError()

    def decode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Cast ``variable`` to its encoded dtype when that dtype is ``str``.

        Parameters
        ----------
        variable : Variable
            Variable to decode.
        name : T_Name, optional
            Accepted for signature compatibility; not used by this method.

        Returns
        -------
        Variable
            ``variable.astype(variable.encoding["dtype"])`` if the variable has
            object dtype and its ``"dtype"`` encoding equals ``str``; otherwise
            the input variable unchanged.
        """
        # A missing "dtype" entry yields False, which never compares equal to str.
        if variable.dtype == object and variable.encoding.get("dtype", False) == str:
            variable = variable.astype(variable.encoding["dtype"])
        return variable

xarray/conventions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,10 @@ def decode_cf_variable(
265265
var = strings.CharacterArrayCoder().decode(var, name=name)
266266
var = strings.EncodedStringCoder().decode(var)
267267

268+
if original_dtype == object:
269+
var = variables.ObjectStringCoder().decode(var)
270+
original_dtype = var.dtype
271+
268272
if mask_and_scale:
269273
for coder in [
270274
variables.UnsignedIntegerCoder(),

xarray/tests/test_backends.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -866,12 +866,13 @@ def test_roundtrip_empty_vlen_string_array(self) -> None:
866866
assert check_vlen_dtype(original["a"].dtype) == str
867867
with self.roundtrip(original) as actual:
868868
assert_identical(original, actual)
869-
assert object == actual["a"].dtype
870-
assert actual["a"].dtype == original["a"].dtype
871-
# only check metadata for capable backends
872-
# eg. NETCDF3 based backends do not roundtrip metadata
873-
if actual["a"].dtype.metadata is not None:
874-
assert check_vlen_dtype(actual["a"].dtype) == str
869+
if np.issubdtype(actual["a"].dtype, object):
870+
# only check metadata for capable backends
871+
# eg. NETCDF3 based backends do not roundtrip metadata
872+
if actual["a"].dtype.metadata is not None:
873+
assert check_vlen_dtype(actual["a"].dtype) == str
874+
else:
875+
assert actual["a"].dtype == np.dtype("<U1")
875876

876877
@pytest.mark.parametrize(
877878
"decoded_fn, encoded_fn",
@@ -1376,32 +1377,39 @@ def test_write_groups(self) -> None:
13761377
with self.open(tmp_file, group="data/2") as actual2:
13771378
assert_identical(data2, actual2)
13781379

1379-
def test_encoding_kwarg_vlen_string(self) -> None:
1380-
for input_strings in [[b"foo", b"bar", b"baz"], ["foo", "bar", "baz"]]:
1381-
original = Dataset({"x": input_strings})
1382-
expected = Dataset({"x": ["foo", "bar", "baz"]})
1383-
kwargs = dict(encoding={"x": {"dtype": str}})
1384-
with self.roundtrip(original, save_kwargs=kwargs) as actual:
1385-
assert actual["x"].encoding["dtype"] is str
1386-
assert_identical(actual, expected)
1387-
1388-
def test_roundtrip_string_with_fill_value_vlen(self) -> None:
1380+
@pytest.mark.parametrize(
1381+
"input_strings, is_bytes",
1382+
[
1383+
([b"foo", b"bar", b"baz"], True),
1384+
(["foo", "bar", "baz"], False),
1385+
(["foó", "bár", "baź"], False),
1386+
],
1387+
)
1388+
def test_encoding_kwarg_vlen_string(
1389+
self, input_strings: list[str], is_bytes: bool
1390+
) -> None:
1391+
original = Dataset({"x": input_strings})
1392+
1393+
expected_string = ["foo", "bar", "baz"] if is_bytes else input_strings
1394+
expected = Dataset({"x": expected_string})
1395+
kwargs = dict(encoding={"x": {"dtype": str}})
1396+
with self.roundtrip(original, save_kwargs=kwargs) as actual:
1397+
assert actual["x"].encoding["dtype"] == "<U3"
1398+
assert actual["x"].dtype == "<U3"
1399+
assert_identical(actual, expected)
1400+
1401+
@pytest.mark.parametrize("fill_value", ["XXX", "", "bár"])
1402+
def test_roundtrip_string_with_fill_value_vlen(self, fill_value: str) -> None:
13891403
values = np.array(["ab", "cdef", np.nan], dtype=object)
13901404
expected = Dataset({"x": ("t", values)})
13911405

1392-
# netCDF4-based backends don't support an explicit fillvalue
1393-
# for variable length strings yet.
1394-
# https://github.com/Unidata/netcdf4-python/issues/730
1395-
# https://github.com/h5netcdf/h5netcdf/issues/37
1396-
original = Dataset({"x": ("t", values, {}, {"_FillValue": "XXX"})})
1397-
with pytest.raises(NotImplementedError):
1398-
with self.roundtrip(original) as actual:
1399-
assert_identical(expected, actual)
1406+
original = Dataset({"x": ("t", values, {}, {"_FillValue": fill_value})})
1407+
with self.roundtrip(original) as actual:
1408+
assert_identical(expected, actual)
14001409

14011410
original = Dataset({"x": ("t", values, {}, {"_FillValue": ""})})
1402-
with pytest.raises(NotImplementedError):
1403-
with self.roundtrip(original) as actual:
1404-
assert_identical(expected, actual)
1411+
with self.roundtrip(original) as actual:
1412+
assert_identical(expected, actual)
14051413

14061414
def test_roundtrip_character_array(self) -> None:
14071415
with create_tmp_file() as tmp_file:

0 commit comments

Comments
 (0)