@@ -701,6 +701,33 @@ where
701701 }
702702}
703703
704+ /// Attempt to merge axes if possible, starting from the back
705+ ///
706+ /// Given axes [Axis(0), Axis(1), Axis(2), Axis(3)] this attempts
707+ /// to merge all axes one by one into Axis(3); when/if this fails,
708+ /// it attempts to merge the rest of the axes together into the next
709+ /// axis in line, for example a result could be:
710+ ///
711+ /// [1, Axis(0) + Axis(1), 1, Axis(2) + Axis(3)] where `+` would
712+ /// mean axes were merged.
713+ pub ( crate ) fn merge_axes_from_the_back < D > ( dim : & mut D , strides : & mut D )
714+ where
715+ D : Dimension ,
716+ {
717+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
718+ match dim. ndim ( ) {
719+ 0 | 1 => { }
720+ n => {
721+ let mut last = n - 1 ;
722+ for i in ( 0 ..last) . rev ( ) {
723+ if !merge_axes ( dim, strides, Axis ( i) , Axis ( last) ) {
724+ last = i;
725+ }
726+ }
727+ }
728+ }
729+ }
730+
704731/// Move the axis which has the smallest absolute stride and a length
705732/// greater than one to be the last axis.
706733pub fn move_min_stride_axis_to_last < D > ( dim : & mut D , strides : & mut D )
@@ -765,12 +792,40 @@ where
765792 * strides = new_strides;
766793}
767794
795+
796+ /// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
797+ /// stride
798+ ///
799+ /// The axes are sorted according to the .abs() of their stride.
800+ pub ( crate ) fn sort_axes_to_standard < D > ( dim : & mut D , strides : & mut D )
801+ where
802+ D : Dimension ,
803+ {
804+ debug_assert ! ( dim. ndim( ) > 1 ) ;
805+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
806+ // bubble sort axes
807+ let mut changed = true ;
808+ while changed {
809+ changed = false ;
810+ for i in 0 ..dim. ndim ( ) - 1 {
811+ // make sure higher stride axes sort before.
812+ if strides. get_stride ( Axis ( i) ) . abs ( ) < strides. get_stride ( Axis ( i + 1 ) ) . abs ( ) {
813+ changed = true ;
814+ dim. slice_mut ( ) . swap ( i, i + 1 ) ;
815+ strides. slice_mut ( ) . swap ( i, i + 1 ) ;
816+ }
817+ }
818+ }
819+ }
820+
821+
768822#[ cfg( test) ]
769823mod test {
770824 use super :: {
771825 arith_seq_intersect, can_index_slice, can_index_slice_not_custom, extended_gcd,
772826 max_abs_offset_check_overflow, slice_min_max, slices_intersect,
773827 solve_linear_diophantine_eq, IntoDimension , squeeze,
828+ merge_axes_from_the_back,
774829 } ;
775830 use crate :: error:: { from_kind, ErrorKind } ;
776831 use crate :: slice:: Slice ;
@@ -1119,4 +1174,26 @@ mod test {
11191174 assert_eq ! ( d, dans) ;
11201175 assert_eq ! ( s, sans) ;
11211176 }
1177+
1178+ #[ test]
1179+ fn test_merge_axes_from_the_back ( ) {
1180+ let dyndim = Dim :: < & [ usize ] > ;
1181+
1182+ let mut d = Dim ( [ 3 , 4 , 5 ] ) ;
1183+ let mut s = Dim ( [ 20 , 5 , 1 ] ) ;
1184+ merge_axes_from_the_back ( & mut d, & mut s) ;
1185+ assert_eq ! ( d, Dim ( [ 1 , 1 , 60 ] ) ) ;
1186+ assert_eq ! ( s, Dim ( [ 20 , 5 , 1 ] ) ) ;
1187+
1188+ let mut d = Dim ( [ 3 , 4 , 5 , 2 ] ) ;
1189+ let mut s = Dim ( [ 80 , 20 , 2 , 1 ] ) ;
1190+ merge_axes_from_the_back ( & mut d, & mut s) ;
1191+ assert_eq ! ( d, Dim ( [ 1 , 12 , 1 , 10 ] ) ) ;
1192+ assert_eq ! ( s, Dim ( [ 80 , 20 , 2 , 1 ] ) ) ;
1193+ let mut d = d. into_dyn ( ) ;
1194+ let mut s = s. into_dyn ( ) ;
1195+ squeeze ( & mut d, & mut s) ;
1196+ assert_eq ! ( d, dyndim( & [ 12 , 10 ] ) ) ;
1197+ assert_eq ! ( s, dyndim( & [ 20 , 1 ] ) ) ;
1198+ }
11221199}
0 commit comments