Skip to content
This repository was archived by the owner on Jan 7, 2023. It is now read-only.

Commit cc2f178

Browse files
committed
stretch: handle fixed-length array columns
1 parent 3d8c4fe commit cc2f178

File tree

2 files changed

+75
-77
lines changed

2 files changed

+75
-77
lines changed

root_numpy/_utils.py

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@
1717
VLEN = np.vectorize(len)
1818

1919

20-
def _is_object_field(arr, col):
21-
return arr.dtype[col] == 'O'
22-
23-
2420
def rec2array(rec, fields=None):
2521
"""Convert a record array into a ndarray with a homogeneous data type.
2622
@@ -103,45 +99,50 @@ def stretch(arr, fields=None):
10399
dtype=[('scalar', '<i8'), ('array', '<f8')])
104100
105101
"""
106-
dt = []
107-
has_array_field = False
108-
first_array = None
102+
dtype = []
103+
len_array = None
109104

110105
if fields is None:
111106
fields = arr.dtype.names
112107

113-
# Construct dtype
108+
# Construct dtype and check consistency
114109
for field in fields:
115-
if _is_object_field(arr, field):
116-
dt.append((field, arr[field][0].dtype))
117-
has_array_field = True
118-
if first_array is None:
119-
first_array = field
110+
dt = arr.dtype[field]
111+
if dt == 'O' or len(dt.shape):
112+
if dt == 'O':
113+
# Variable-length array field
114+
lengths = VLEN(arr[field])
115+
else:
116+
lengths = np.repeat(dt.shape[0], arr.shape[0])
117+
# Fixed-length array field
118+
if len_array is None:
119+
len_array = lengths
120+
elif not np.array_equal(lengths, len_array):
121+
raise ValueError(
122+
"inconsistent lengths of array columns in input")
123+
if dt == 'O':
124+
dtype.append((field, arr[field][0].dtype))
125+
else:
126+
dtype.append((field, arr[field].dtype, dt.shape[1:]))
120127
else:
121-
# Assume scalar
122-
dt.append((field, arr[field].dtype))
128+
# Scalar field
129+
dtype.append((field, dt))
123130

124-
if not has_array_field:
125-
raise RuntimeError("No array column specified")
126-
127-
len_array = VLEN(arr[first_array])
128-
numrec = np.sum(len_array)
129-
ret = np.empty(numrec, dtype=dt)
131+
if len_array is None:
132+
raise RuntimeError("no array column in input")
130133

134+
# Build stretched output
135+
ret = np.empty(np.sum(len_array), dtype=dtype)
131136
for field in fields:
132-
if _is_object_field(arr, field):
133-
# FIXME: this is rather inefficient since the stack
134-
# is copied over to the return value
135-
stack = np.hstack(arr[field])
136-
if len(stack) != numrec:
137-
raise ValueError(
138-
"Array lengths do not match: "
139-
"expected %d but found %d in %s" %
140-
(numrec, len(stack), field))
141-
ret[field] = stack
137+
dt = arr.dtype[field]
138+
if dt == 'O' or len(dt.shape) == 1:
139+
# Variable-length or 1D fixed-length array field
140+
ret[field] = np.hstack(arr[field])
141+
elif len(dt.shape):
142+
# Multidimensional fixed-length array field
143+
ret[field] = np.vstack(arr[field])
142144
else:
143-
# FIXME: this is rather inefficient since the repeat result
144-
# is copied over to the return value
145+
# Scalar field
145146
ret[field] = np.repeat(arr[field], len_array)
146147

147148
return ret

root_numpy/tests.py

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -574,56 +574,53 @@ def test_fill_graph():
574574

575575

576576
def test_stretch():
577-
nrec = 5
578-
arr = np.empty(nrec,
577+
arr = np.empty(5,
579578
dtype=[
580579
('scalar', np.int),
581-
('df1', 'O'),
582-
('df2', 'O'),
583-
('df3', 'O')])
584-
585-
for i in range(nrec):
586-
df1 = np.array(range(i + 1), dtype=np.float)
587-
df2 = np.array(range(i + 1), dtype=np.int) * 2
588-
df3 = np.array(range(i + 1), dtype=np.double) * 3
589-
arr[i] = (i, df1, df2, df3)
580+
('vl1', 'O'),
581+
('vl2', 'O'),
582+
('vl3', 'O'),
583+
('fl1', np.int, (2, 2)),
584+
('fl2', np.float, (2, 3)),
585+
('fl3', np.double, (3, 2))])
586+
587+
for i in range(arr.shape[0]):
588+
vl1 = np.array(range(i + 1), dtype=np.int)
589+
vl2 = np.array(range(i + 2), dtype=np.float) * 2
590+
vl3 = np.array(range(2), dtype=np.double) * 3
591+
fl1 = np.array(range(4), dtype=np.int).reshape((2, 2))
592+
fl2 = np.array(range(6), dtype=np.float).reshape((2, 3))
593+
fl3 = np.array(range(6), dtype=np.double).reshape((3, 2))
594+
arr[i] = (i, vl1, vl2, vl3, fl1, fl2, fl3)
595+
596+
# no array columns included
597+
assert_raises(RuntimeError, rnp.stretch, arr, ['scalar',])
590598

591-
stretched = rnp.stretch(
592-
arr, ['scalar', 'df1', 'df2', 'df3'])
599+
# lengths don't match
600+
assert_raises(ValueError, rnp.stretch, arr, ['scalar', 'vl1', 'vl2',])
601+
assert_raises(ValueError, rnp.stretch, arr, ['scalar', 'fl1', 'fl3',])
602+
assert_raises(ValueError, rnp.stretch, arr)
593603

604+
# variable-length stretch
605+
stretched = rnp.stretch(arr, ['scalar', 'vl1',])
594606
assert_equal(stretched.dtype,
595-
[('scalar', np.int),
596-
('df1', np.float),
597-
('df2', np.int),
598-
('df3', np.double)])
599-
assert_equal(stretched.size, 15)
600-
601-
assert_almost_equal(stretched['df1'][14], 4.0)
602-
assert_almost_equal(stretched['df2'][14], 8)
603-
assert_almost_equal(stretched['df3'][14], 12.0)
604-
assert_almost_equal(stretched['scalar'][14], 4)
605-
assert_almost_equal(stretched['scalar'][13], 4)
606-
assert_almost_equal(stretched['scalar'][12], 4)
607-
assert_almost_equal(stretched['scalar'][11], 4)
608-
assert_almost_equal(stretched['scalar'][10], 4)
609-
assert_almost_equal(stretched['scalar'][9], 3)
610-
611-
arr = np.empty(1, dtype=[('scalar', np.int),])
612-
arr[0] = (1,)
613-
assert_raises(RuntimeError, rnp.stretch, arr, ['scalar',])
607+
[('scalar', np.int),
608+
('vl1', np.int)])
609+
assert_equal(stretched.shape[0], 15)
610+
assert_array_equal(
611+
stretched['scalar'],
612+
np.repeat(arr['scalar'], np.vectorize(len)(arr['vl1'])))
614613

615-
nrec = 5
616-
arr = np.empty(nrec,
617-
dtype=[
618-
('scalar', np.int),
619-
('df1', 'O'),
620-
('df2', 'O')])
621-
622-
for i in range(nrec):
623-
df1 = np.array(range(i + 1), dtype=np.float)
624-
df2 = np.array(range(i + 2), dtype=np.int) * 2
625-
arr[i] = (i, df1, df2)
626-
assert_raises(ValueError, rnp.stretch, arr, ['scalar', 'df1', 'df2'])
614+
# fixed-length stretch
615+
stretched = rnp.stretch(arr, ['scalar', 'vl3', 'fl1', 'fl2',])
616+
assert_equal(stretched.dtype,
617+
[('scalar', np.int),
618+
('vl3', np.double),
619+
('fl1', np.int, (2,)),
620+
('fl2', np.float, (3,))])
621+
assert_equal(stretched.shape[0], 10)
622+
assert_array_equal(
623+
stretched['scalar'], np.repeat(arr['scalar'], 2))
627624

628625

629626
def test_blockwise_inner_join():

0 commit comments

Comments
 (0)