Skip to content
This repository was archived by the owner on Jan 7, 2023. It is now read-only.

Commit 655418b

Browse files
committed
stretch: handle fixed-length array columns
1 parent 3d8c4fe commit 655418b

File tree

1 file changed

+35
-33
lines changed

1 file changed

+35
-33
lines changed

root_numpy/_utils.py

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@
1717
VLEN = np.vectorize(len)
1818

1919

20-
def _is_object_field(arr, col):
21-
return arr.dtype[col] == 'O'
22-
23-
2420
def rec2array(rec, fields=None):
2521
"""Convert a record array into a ndarray with a homogeneous data type.
2622
@@ -103,45 +99,51 @@ def stretch(arr, fields=None):
10399
dtype=[('scalar', '<i8'), ('array', '<f8')])
104100
105101
"""
106-
dt = []
107-
has_array_field = False
108-
first_array = None
102+
dtype = []
103+
len_array = None
109104

110105
if fields is None:
111106
fields = arr.dtype.names
112107

113-
# Construct dtype
108+
# Construct dtype and check consistency
114109
for field in fields:
115-
if _is_object_field(arr, field):
116-
dt.append((field, arr[field][0].dtype))
117-
has_array_field = True
118-
if first_array is None:
119-
first_array = field
110+
dt = arr.dtype[field]
111+
if dt == 'O' or len(dt.shape):
112+
if dt == 'O':
113+
# Variable-length array field
114+
lengths = VLEN(arr[field])
115+
else:
116+
lengths = np.repeat(dt.shape[0], arr.shape[0])
117+
# Fixed-length array field
118+
if len_array is None:
119+
len_array = lengths
120+
elif not np.array_equal(lengths, len_array):
121+
raise ValueError(
122+
"inconsistent lengths of array columns in input")
123+
if dt == 'O':
124+
dtype.append((field, arr[field][0].dtype))
125+
else:
126+
dtype.append((field, arr[field].dtype, dt.shape[1:]))
120127
else:
121-
# Assume scalar
122-
dt.append((field, arr[field].dtype))
128+
# Scalar field
129+
dtype.append((field, dt))
123130

124-
if not has_array_field:
125-
raise RuntimeError("No array column specified")
126-
127-
len_array = VLEN(arr[first_array])
128-
numrec = np.sum(len_array)
129-
ret = np.empty(numrec, dtype=dt)
131+
if len_array is None:
132+
raise RuntimeError("no array column in input")
130133

134+
# Build stretched output
135+
print dtype
136+
ret = np.empty(np.sum(len_array), dtype=dtype)
131137
for field in fields:
132-
if _is_object_field(arr, field):
133-
# FIXME: this is rather inefficient since the stack
134-
# is copied over to the return value
135-
stack = np.hstack(arr[field])
136-
if len(stack) != numrec:
137-
raise ValueError(
138-
"Array lengths do not match: "
139-
"expected %d but found %d in %s" %
140-
(numrec, len(stack), field))
141-
ret[field] = stack
138+
dt = arr.dtype[field]
139+
if dt == 'O' or len(dt.shape) == 1:
140+
# Variable-length or 1D fixed-length array field
141+
ret[field] = np.hstack(arr[field])
142+
elif len(dt.shape):
143+
# Multidimensional fixed-length array field
144+
ret[field] = np.vstack(arr[field])
142145
else:
143-
# FIXME: this is rather inefficient since the repeat result
144-
# is copied over to the return value
146+
# Scalar field
145147
ret[field] = np.repeat(arr[field], len_array)
146148

147149
return ret

0 commit comments

Comments
 (0)