@@ -17,23 +17,58 @@ mod windows;
1717use std:: iter:: FromIterator ;
1818use std:: marker:: PhantomData ;
1919use std:: ptr;
20+ use std:: slice:: { self , Iter as SliceIter , IterMut as SliceIterMut } ;
2021use alloc:: vec:: Vec ;
2122
23+ use crate :: imp_prelude:: * ;
2224use crate :: Ix1 ;
2325
24- use super :: { ArrayBase , ArrayView , ArrayViewMut , Axis , Data , NdProducer , RemoveAxis } ;
25- use super :: { Dimension , Ix , Ixs } ;
26+ use super :: { NdProducer , RemoveAxis } ;
2627
2728pub use self :: chunks:: { ExactChunks , ExactChunksIter , ExactChunksIterMut , ExactChunksMut } ;
2829pub use self :: lanes:: { Lanes , LanesMut } ;
2930pub use self :: windows:: Windows ;
3031pub use self :: into_iter:: IntoIter ;
3132
32- use std:: slice:: { self , Iter as SliceIter , IterMut as SliceIterMut } ;
33+ use crate :: dimension;
34+
35+ /// No traversal optmizations that would change element order or axis dimensions are permitted.
36+ ///
37+ /// This option is suitable for example for the indexed iterator.
38+ pub ( crate ) enum NoOptimization { }
39+
40+ /// Preserve element iteration order, but modify dimensions if profitable; for example we can
41+ /// change from shape [10, 1] to [1, 10], because that axis has len == 1, without consequence here.
42+ ///
43+ /// This option is suitable for example for the default .iter() iterator.
44+ pub ( crate ) enum PreserveOrder { }
45+
46+ /// Allow use of arbitrary element iteration order
47+ ///
48+ /// This option is suitable for example for an arbitrary order iterator.
49+ pub ( crate ) enum ArbitraryOrder { }
50+
51+ pub ( crate ) trait OrderOption {
52+ const ALLOW_REMOVE_REDUNDANT_AXES : bool = false ;
53+ const ALLOW_ARBITRARY_ORDER : bool = false ;
54+ }
55+
56+ impl OrderOption for NoOptimization { }
57+
58+ impl OrderOption for PreserveOrder {
59+ const ALLOW_REMOVE_REDUNDANT_AXES : bool = true ;
60+ }
61+
62+ impl OrderOption for ArbitraryOrder {
63+ const ALLOW_REMOVE_REDUNDANT_AXES : bool = true ;
64+ const ALLOW_ARBITRARY_ORDER : bool = true ;
65+ }
3366
3467/// Base for iterators over all axes.
3568///
3669/// Iterator element type is `*mut A`.
70+ ///
71+ /// `F` is for layout/iteration order flags
3772pub ( crate ) struct Baseiter < A , D > {
3873 ptr : * mut A ,
3974 dim : D ,
@@ -46,12 +81,43 @@ impl<A, D: Dimension> Baseiter<A, D> {
4681 /// to be correct to avoid performing an unsafe pointer offset while
4782 /// iterating.
4883 #[ inline]
49- pub unsafe fn new ( ptr : * mut A , len : D , stride : D ) -> Baseiter < A , D > {
84+ pub unsafe fn new ( ptr : * mut A , dim : D , strides : D ) -> Baseiter < A , D > {
85+ Self :: new_with_order :: < NoOptimization > ( ptr, dim, strides)
86+ }
87+ }
88+
89+ impl < A , D : Dimension > Baseiter < A , D > {
90+ /// Creating a Baseiter is unsafe because shape and stride parameters need
91+ /// to be correct to avoid performing an unsafe pointer offset while
92+ /// iterating.
93+ #[ inline]
94+ pub unsafe fn new_with_order < Flags : OrderOption > ( mut ptr : * mut A , mut dim : D , mut strides : D )
95+ -> Baseiter < A , D >
96+ {
97+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
98+ if Flags :: ALLOW_ARBITRARY_ORDER {
99+ // iterate in memory order; merge axes if possible
100+ // make all axes positive and put the pointer back to the first element in memory
101+ let offset = dimension:: offset_from_ptr_to_memory ( & dim, & strides) ;
102+ ptr = ptr. offset ( offset) ;
103+ for i in 0 ..strides. ndim ( ) {
104+ let s = strides. get_stride ( Axis ( i) ) ;
105+ if s < 0 {
106+ strides. set_stride ( Axis ( i) , -s) ;
107+ }
108+ }
109+ dimension:: sort_axes_to_standard ( & mut dim, & mut strides) ;
110+ }
111+ if Flags :: ALLOW_REMOVE_REDUNDANT_AXES {
112+ // preserve element order but shift dimensions
113+ dimension:: merge_axes_from_the_back ( & mut dim, & mut strides) ;
114+ dimension:: squeeze ( & mut dim, & mut strides) ;
115+ }
50116 Baseiter {
51117 ptr,
52- index : len . first_index ( ) ,
53- dim : len ,
54- strides : stride ,
118+ index : dim . first_index ( ) ,
119+ dim,
120+ strides,
55121 }
56122 }
57123}
@@ -1499,3 +1565,147 @@ where
14991565 debug_assert_eq ! ( size, result. len( ) ) ;
15001566 result
15011567}
1568+
1569+ #[ cfg( test) ]
1570+ #[ cfg( feature = "std" ) ]
1571+ mod tests {
1572+ use crate :: prelude:: * ;
1573+ use super :: Baseiter ;
1574+ use super :: { ArbitraryOrder , PreserveOrder , NoOptimization } ;
1575+ use itertools:: assert_equal;
1576+ use itertools:: Itertools ;
1577+
1578+ // 3-d axis swaps
1579+ fn swaps ( ) -> impl Iterator < Item =Vec < ( usize , usize ) > > {
1580+ vec ! [
1581+ vec![ ] ,
1582+ vec![ ( 0 , 1 ) ] ,
1583+ vec![ ( 0 , 2 ) ] ,
1584+ vec![ ( 1 , 2 ) ] ,
1585+ vec![ ( 0 , 1 ) , ( 1 , 2 ) ] ,
1586+ vec![ ( 0 , 1 ) , ( 0 , 2 ) ] ,
1587+ ] . into_iter ( )
1588+ }
1589+
1590+ // 3-d axis inverts
1591+ fn inverts ( ) -> impl Iterator < Item =Vec < Axis > > {
1592+ vec ! [
1593+ vec![ ] ,
1594+ vec![ Axis ( 0 ) ] ,
1595+ vec![ Axis ( 1 ) ] ,
1596+ vec![ Axis ( 2 ) ] ,
1597+ vec![ Axis ( 0 ) , Axis ( 1 ) ] ,
1598+ vec![ Axis ( 0 ) , Axis ( 2 ) ] ,
1599+ vec![ Axis ( 1 ) , Axis ( 2 ) ] ,
1600+ vec![ Axis ( 0 ) , Axis ( 1 ) , Axis ( 2 ) ] ,
1601+ ] . into_iter ( )
1602+ }
1603+
1604+ #[ test]
1605+ fn test_arbitrary_order ( ) {
1606+ for swap in swaps ( ) {
1607+ for invert in inverts ( ) {
1608+ for & slice in & [ false , true ] {
1609+ // pattern is 0, 1; 4, 5; 8, 9; etc..
1610+ let mut a = Array :: from_iter ( 0 ..24 ) . into_shape ( ( 3 , 4 , 2 ) ) . unwrap ( ) ;
1611+ if slice {
1612+ a. slice_collapse ( s ! [ .., ..; 2 , ..] ) ;
1613+ }
1614+ for & ( i, j) in & swap {
1615+ a. swap_axes ( i, j) ;
1616+ }
1617+ for & i in & invert {
1618+ a. invert_axis ( i) ;
1619+ }
1620+ unsafe {
1621+ // Should have in-memory order for arbitrary order
1622+ let iter = Baseiter :: new_with_order :: < ArbitraryOrder > ( a. as_mut_ptr ( ) ,
1623+ a. dim , a. strides ) ;
1624+ if !slice {
1625+ assert_equal ( iter. map ( |ptr| * ptr) , 0 ..a. len ( ) ) ;
1626+ } else {
1627+ assert_eq ! ( iter. map( |ptr| * ptr) . collect_vec( ) ,
1628+ ( 0 ..a. len( ) * 2 ) . filter( |& x| ( x / 2 ) % 2 == 0 ) . collect_vec( ) ) ;
1629+ }
1630+ }
1631+ }
1632+ }
1633+ }
1634+ }
1635+
1636+ #[ test]
1637+ fn test_logical_order ( ) {
1638+ for swap in swaps ( ) {
1639+ for invert in inverts ( ) {
1640+ for & slice in & [ false , true ] {
1641+ let mut a = Array :: from_iter ( 0 ..24 ) . into_shape ( ( 3 , 4 , 2 ) ) . unwrap ( ) ;
1642+ for & ( i, j) in & swap {
1643+ a. swap_axes ( i, j) ;
1644+ }
1645+ for & i in & invert {
1646+ a. invert_axis ( i) ;
1647+ }
1648+ if slice {
1649+ a. slice_collapse ( s ! [ .., ..; 2 , ..] ) ;
1650+ }
1651+
1652+ unsafe {
1653+ let mut iter = Baseiter :: new_with_order :: < NoOptimization > ( a. as_mut_ptr ( ) ,
1654+ a. dim , a. strides ) ;
1655+ let mut index = Dim ( [ 0 , 0 , 0 ] ) ;
1656+ let mut elts = 0 ;
1657+ while let Some ( elt) = iter. next ( ) {
1658+ assert_eq ! ( * elt, a[ index] ) ;
1659+ if let Some ( index_) = a. raw_dim ( ) . next_for ( index) {
1660+ index = index_;
1661+ }
1662+ elts += 1 ;
1663+ }
1664+ assert_eq ! ( elts, a. len( ) ) ;
1665+ }
1666+ }
1667+ }
1668+ }
1669+ }
1670+
1671+ #[ test]
1672+ fn test_preserve_order ( ) {
1673+ for swap in swaps ( ) {
1674+ for invert in inverts ( ) {
1675+ for & slice in & [ false , true ] {
1676+ let mut a = Array :: from_iter ( 0 ..20 ) . into_shape ( ( 2 , 10 , 1 ) ) . unwrap ( ) ;
1677+ for & ( i, j) in & swap {
1678+ a. swap_axes ( i, j) ;
1679+ }
1680+ for & i in & invert {
1681+ a. invert_axis ( i) ;
1682+ }
1683+ if slice {
1684+ a. slice_collapse ( s ! [ .., ..; 2 , ..] ) ;
1685+ }
1686+
1687+ unsafe {
1688+ let mut iter = Baseiter :: new_with_order :: < PreserveOrder > (
1689+ a. as_mut_ptr ( ) , a. dim , a. strides ) ;
1690+
1691+ // check that axes have been merged (when it's easy to check)
1692+ if a. shape ( ) == & [ 2 , 10 , 1 ] && invert. is_empty ( ) {
1693+ assert_eq ! ( iter. dim, Dim ( [ 1 , 1 , 20 ] ) ) ;
1694+ }
1695+
1696+ let mut index = Dim ( [ 0 , 0 , 0 ] ) ;
1697+ let mut elts = 0 ;
1698+ while let Some ( elt) = iter. next ( ) {
1699+ assert_eq ! ( * elt, a[ index] ) ;
1700+ if let Some ( index_) = a. raw_dim ( ) . next_for ( index) {
1701+ index = index_;
1702+ }
1703+ elts += 1 ;
1704+ }
1705+ assert_eq ! ( elts, a. len( ) ) ;
1706+ }
1707+ }
1708+ }
1709+ }
1710+ }
1711+ }
0 commit comments