|
17 | 17 | VLEN = np.vectorize(len) |
18 | 18 |
|
19 | 19 |
|
20 | | -def _is_object_field(arr, col): |
21 | | - return arr.dtype[col] == 'O' |
22 | | - |
23 | | - |
24 | 20 | def rec2array(rec, fields=None): |
25 | 21 | """Convert a record array into a ndarray with a homogeneous data type. |
26 | 22 |
|
@@ -103,45 +99,51 @@ def stretch(arr, fields=None): |
103 | 99 | dtype=[('scalar', '<i8'), ('array', '<f8')]) |
104 | 100 |
|
105 | 101 | """ |
106 | | - dt = [] |
107 | | - has_array_field = False |
108 | | - first_array = None |
| 102 | + dtype = [] |
| 103 | + len_array = None |
109 | 104 |
|
110 | 105 | if fields is None: |
111 | 106 | fields = arr.dtype.names |
112 | 107 |
|
113 | | - # Construct dtype |
| 108 | + # Construct dtype and check consistency |
114 | 109 | for field in fields: |
115 | | - if _is_object_field(arr, field): |
116 | | - dt.append((field, arr[field][0].dtype)) |
117 | | - has_array_field = True |
118 | | - if first_array is None: |
119 | | - first_array = field |
| 110 | + dt = arr.dtype[field] |
| 111 | + if dt == 'O' or len(dt.shape): |
| 112 | + if dt == 'O': |
| 113 | + # Variable-length array field |
| 114 | + lengths = VLEN(arr[field]) |
| 115 | + else: |
| 116 | + lengths = np.repeat(dt.shape[0], arr.shape[0]) |
| 117 | + # Fixed-length array field |
| 118 | + if len_array is None: |
| 119 | + len_array = lengths |
| 120 | + elif not np.array_equal(lengths, len_array): |
| 121 | + raise ValueError( |
| 122 | + "inconsistent lengths of array columns in input") |
| 123 | + if dt == 'O': |
| 124 | + dtype.append((field, arr[field][0].dtype)) |
| 125 | + else: |
| 126 | + dtype.append((field, arr[field].dtype, dt.shape[1:])) |
120 | 127 | else: |
121 | | - # Assume scalar |
122 | | - dt.append((field, arr[field].dtype)) |
| 128 | + # Scalar field |
| 129 | + dtype.append((field, dt)) |
123 | 130 |
|
124 | | - if not has_array_field: |
125 | | - raise RuntimeError("No array column specified") |
126 | | - |
127 | | - len_array = VLEN(arr[first_array]) |
128 | | - numrec = np.sum(len_array) |
129 | | - ret = np.empty(numrec, dtype=dt) |
| 131 | + if len_array is None: |
| 132 | + raise RuntimeError("no array column in input") |
130 | 133 |
|
| 134 | + # Build stretched output |
| 135 | + print dtype |
| 136 | + ret = np.empty(np.sum(len_array), dtype=dtype) |
131 | 137 | for field in fields: |
132 | | - if _is_object_field(arr, field): |
133 | | - # FIXME: this is rather inefficient since the stack |
134 | | - # is copied over to the return value |
135 | | - stack = np.hstack(arr[field]) |
136 | | - if len(stack) != numrec: |
137 | | - raise ValueError( |
138 | | - "Array lengths do not match: " |
139 | | - "expected %d but found %d in %s" % |
140 | | - (numrec, len(stack), field)) |
141 | | - ret[field] = stack |
| 138 | + dt = arr.dtype[field] |
| 139 | + if dt == 'O' or len(dt.shape) == 1: |
| 140 | + # Variable-length or 1D fixed-length array field |
| 141 | + ret[field] = np.hstack(arr[field]) |
| 142 | + elif len(dt.shape): |
| 143 | + # Multidimensional fixed-length array field |
| 144 | + ret[field] = np.vstack(arr[field]) |
142 | 145 | else: |
143 | | - # FIXME: this is rather inefficient since the repeat result |
144 | | - # is copied over to the return value |
| 146 | + # Scalar field |
145 | 147 | ret[field] = np.repeat(arr[field], len_array) |
146 | 148 |
|
147 | 149 | return ret |
|
0 commit comments