@@ -224,7 +224,9 @@ def is_int64_overflow_possible(shape: Shape) -> bool:
224224 return the_prod >= lib .i8max
225225
226226
227- def decons_group_index (comp_labels , shape : Shape ):
227+ def _decons_group_index (
228+ comp_labels : npt .NDArray [np .intp ], shape : Shape
229+ ) -> list [npt .NDArray [np .intp ]]:
228230 # reconstruct labels
229231 if is_int64_overflow_possible (shape ):
230232 # at some point group indices are factorized,
@@ -233,7 +235,7 @@ def decons_group_index(comp_labels, shape: Shape):
233235
234236 label_list = []
235237 factor = 1
236- y = 0
238+ y = np . array ( 0 )
237239 x = comp_labels
238240 for i in reversed (range (len (shape ))):
239241 labels = (x - y ) % (factor * shape [i ]) // factor
@@ -245,24 +247,32 @@ def decons_group_index(comp_labels, shape: Shape):
245247
246248
247249def decons_obs_group_ids (
248- comp_ids : npt .NDArray [np .intp ], obs_ids , shape : Shape , labels , xnull : bool
249- ):
250+ comp_ids : npt .NDArray [np .intp ],
251+ obs_ids : npt .NDArray [np .intp ],
252+ shape : Shape ,
253+ labels : Sequence [npt .NDArray [np .signedinteger ]],
254+ xnull : bool ,
255+ ) -> list [npt .NDArray [np .intp ]]:
250256 """
251257 Reconstruct labels from observed group ids.
252258
253259 Parameters
254260 ----------
255261 comp_ids : np.ndarray[np.intp]
262+ obs_ids: np.ndarray[np.intp]
263+ shape : tuple[int]
264+ labels : Sequence[np.ndarray[np.signedinteger]]
256265 xnull : bool
257266 If nulls are excluded; i.e. -1 labels are passed through.
258267 """
259268 if not xnull :
260- lift = np .fromiter (((a == - 1 ).any () for a in labels ), dtype = "i8" )
261- shape = np .asarray (shape , dtype = "i8" ) + lift
269+ lift = np .fromiter (((a == - 1 ).any () for a in labels ), dtype = np .intp )
270+ arr_shape = np .asarray (shape , dtype = np .intp ) + lift
271+ shape = tuple (arr_shape )
262272
263273 if not is_int64_overflow_possible (shape ):
264274 # obs ids are deconstructable! take the fast route!
265- out = decons_group_index (obs_ids , shape )
275+ out = _decons_group_index (obs_ids , shape )
266276 return out if xnull or not lift .any () else [x - y for x , y in zip (out , lift )]
267277
268278 indexer = unique_label_indices (comp_ids )
0 commit comments