@@ -212,8 +212,6 @@ def concatenate_managers(
212212 for placement , join_units in concat_plan :
213213 unit = join_units [0 ]
214214 blk = unit .block
215- # Assertion disabled for performance
216- # assert len(join_units) == len(mgrs_indexers)
217215
218216 if len (join_units ) == 1 :
219217 values = blk .values
@@ -331,20 +329,27 @@ def _get_mgr_concatenation_plan(mgr: BlockManager):
331329 plan : list of (BlockPlacement, JoinUnit) tuples
332330
333331 """
332+ # Calculate post-reindex shape , save for item axis which will be separate
333+ # for each block anyway.
334+ mgr_shape_list = list (mgr .shape )
335+ mgr_shape = tuple (mgr_shape_list )
334336
335337 if mgr .is_single_block :
336338 blk = mgr .blocks [0 ]
337- return [(blk .mgr_locs , JoinUnit (blk ))]
339+ return [(blk .mgr_locs , JoinUnit (blk , mgr_shape ))]
338340
339341 blknos = mgr .blknos
340342 blklocs = mgr .blklocs
341343
342344 plan = []
343345 for blkno , placements in libinternals .get_blkno_placements (blknos , group = False ):
344346
345- # Assertions disabled for performance; these should always hold
346- # assert placements.is_slice_like
347- # assert blkno != -1
347+ assert placements .is_slice_like
348+ assert blkno != - 1
349+
350+ shape_list = list (mgr_shape )
351+ shape_list [0 ] = len (placements )
352+ shape = tuple (shape_list )
348353
349354 blk = mgr .blocks [blkno ]
350355 ax0_blk_indexer = blklocs [placements .indexer ]
@@ -374,16 +379,19 @@ def _get_mgr_concatenation_plan(mgr: BlockManager):
374379
375380 # Assertions disabled for performance
376381 # assert blk._mgr_locs.as_slice == placements.as_slice
377- unit = JoinUnit (blk )
382+ # assert blk.shape[0] == shape[0]
383+ unit = JoinUnit (blk , shape )
378384
379385 plan .append ((placements , unit ))
380386
381387 return plan
382388
383389
384390class JoinUnit :
385- def __init__ (self , block : Block ) -> None :
391+ def __init__ (self , block : Block , shape : Shape ):
392+ # Passing shape explicitly is required for cases when block is None.
386393 self .block = block
394+ self .shape = shape
387395
388396 def __repr__ (self ) -> str :
389397 return f"{ type (self ).__name__ } ({ repr (self .block )} )"
@@ -396,11 +404,22 @@ def is_na(self) -> bool:
396404 return False
397405
398406 def get_reindexed_values (self , empty_dtype : DtypeObj ) -> ArrayLike :
407+ values : ArrayLike
408+
399409 if self .is_na :
400- return make_na_array (empty_dtype , self .block . shape )
410+ return make_na_array (empty_dtype , self .shape )
401411
402412 else :
403- return self .block .values
413+
414+ if not self .block ._can_consolidate :
415+ # preserve these for validation in concat_compat
416+ return self .block .values
417+
418+ # No dtype upcasting is done here, it will be performed during
419+ # concatenation itself.
420+ values = self .block .values
421+
422+ return values
404423
405424
406425def make_na_array (dtype : DtypeObj , shape : Shape ) -> ArrayLike :
@@ -539,9 +558,6 @@ def _is_uniform_join_units(join_units: list[JoinUnit]) -> bool:
539558 first = join_units [0 ].block
540559 if first .dtype .kind == "V" :
541560 return False
542- elif len (join_units ) == 1 :
543- # only use this path when there is something to concatenate
544- return False
545561 return (
546562 # exclude cases where a) ju.block is None or b) we have e.g. Int64+int64
547563 all (type (ju .block ) is type (first ) for ju in join_units )
@@ -554,8 +570,13 @@ def _is_uniform_join_units(join_units: list[JoinUnit]) -> bool:
554570 or ju .block .dtype .kind in ["b" , "i" , "u" ]
555571 for ju in join_units
556572 )
557- # this also precludes any blocks with dtype.kind == "V", since
558- # we excluded that case for `first` above.
573+ and
574+ # no blocks that would get missing values (can lead to type upcasts)
575+ # unless we're an extension dtype.
576+ all (not ju .is_na or ju .block .is_extension for ju in join_units )
577+ and
578+ # only use this path when there is something to concatenate
579+ len (join_units ) > 1
559580 )
560581
561582
@@ -577,7 +598,10 @@ def _trim_join_unit(join_unit: JoinUnit, length: int) -> JoinUnit:
577598 extra_block = join_unit .block .getitem_block (slice (length , None ))
578599 join_unit .block = join_unit .block .getitem_block (slice (length ))
579600
580- return JoinUnit (block = extra_block )
601+ extra_shape = (join_unit .shape [0 ] - length ,) + join_unit .shape [1 :]
602+ join_unit .shape = (length ,) + join_unit .shape [1 :]
603+
604+ return JoinUnit (block = extra_block , shape = extra_shape )
581605
582606
583607def _combine_concat_plans (plans ):
0 commit comments