@@ -531,7 +531,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
531531
532532
533533@numba_funcify .register (DimShuffle )
534- def numba_funcify_DimShuffle (op , ** kwargs ):
534+ def numba_funcify_DimShuffle (op , node , ** kwargs ):
535535 shuffle = tuple (op .shuffle )
536536 transposition = tuple (op .transposition )
537537 augment = tuple (op .augment )
@@ -560,16 +560,26 @@ def transpose(x):
560560 # To avoid this compile-time error, we omit the expression altogether.
561561 if len (shuffle ) > 0 :
562562
563- @numba_basic .numba_njit
564- def find_shape (array_shape ):
565- shape = shape_template
566- j = 0
567- for i in range (ndim_new_shape ):
568- if i not in augment :
569- length = array_shape [j ]
570- shape = numba_basic .tuple_setitem (shape , i , length )
571- j = j + 1
572- return shape
563+ # Use the statically known shape if available
564+ if all (length is not None for length in node .outputs [0 ].type .shape ):
565+ shape = node .outputs [0 ].type .shape
566+
567+ @numba_basic .numba_njit
568+ def find_shape (array_shape ):
569+ return shape
570+
571+ else :
572+
573+ @numba_basic .numba_njit
574+ def find_shape (array_shape ):
575+ shape = shape_template
576+ j = 0
577+ for i in range (ndim_new_shape ):
578+ if i not in augment :
579+ length = array_shape [j ]
580+ shape = numba_basic .tuple_setitem (shape , i , length )
581+ j = j + 1
582+ return shape
573583
574584 else :
575585
0 commit comments