| 
 | 1 | + | 
 | 2 | +use crate::{Dimension, Order, ShapeError, ErrorKind};  | 
 | 3 | +use crate::dimension::sequence::{Sequence, SequenceMut, Forward, Reverse};  | 
 | 4 | + | 
 | 5 | +#[inline]  | 
 | 6 | +pub(crate) fn reshape_dim<D, E>(from: &D, strides: &D, to: &E, order: Order)  | 
 | 7 | +    -> Result<E, ShapeError>  | 
 | 8 | +where  | 
 | 9 | +    D: Dimension,  | 
 | 10 | +    E: Dimension,  | 
 | 11 | +{  | 
 | 12 | +    debug_assert_eq!(from.ndim(), strides.ndim());  | 
 | 13 | +    let mut to_strides = E::zeros(to.ndim());  | 
 | 14 | +    match order {  | 
 | 15 | +        Order::RowMajor => {  | 
 | 16 | +            reshape_dim_c(&Forward(from), &Forward(strides),  | 
 | 17 | +                          &Forward(to), Forward(&mut to_strides))?;  | 
 | 18 | +        }  | 
 | 19 | +        Order::ColumnMajor => {  | 
 | 20 | +            reshape_dim_c(&Reverse(from), &Reverse(strides),  | 
 | 21 | +                          &Reverse(to), Reverse(&mut to_strides))?;  | 
 | 22 | +        }  | 
 | 23 | +    }  | 
 | 24 | +    Ok(to_strides)  | 
 | 25 | +}  | 
 | 26 | + | 
 | 27 | +/// Try to reshape an array with dimensions `from_dim` and strides `from_strides` to the new  | 
 | 28 | +/// dimension `to_dim`, while keeping the same layout of elements in memory. The strides needed  | 
 | 29 | +/// if this is possible are stored into `to_strides`.  | 
 | 30 | +///  | 
 | 31 | +/// This function uses RowMajor index ordering if the inputs are read in the forward direction  | 
 | 32 | +/// (index 0 is axis 0 etc) and ColumnMajor index ordering if the inputs are read in reversed  | 
 | 33 | +/// direction (as made possible with the Sequence trait).  | 
 | 34 | +///   | 
 | 35 | +/// Preconditions:  | 
 | 36 | +///  | 
 | 37 | +/// 1. from_dim and to_dim are valid dimensions (product of all non-zero axes  | 
 | 38 | +/// fits in isize::MAX).  | 
 | 39 | +/// 2. from_dim and to_dim are don't have any axes that are zero (that should be handled before  | 
 | 40 | +///    this function).  | 
 | 41 | +/// 3. `to_strides` should be an all-zeros or all-ones dimension of the right dimensionality  | 
 | 42 | +/// (but it will be overwritten after successful exit of this function).  | 
 | 43 | +///  | 
 | 44 | +/// This function returns:  | 
 | 45 | +///  | 
 | 46 | +/// - IncompatibleShape if the two shapes are not of matching number of elements  | 
 | 47 | +/// - IncompatibleLayout if the input shape and stride can not be remapped to the output shape  | 
 | 48 | +///   without moving the array data into a new memory layout.  | 
 | 49 | +/// - Ok if the from dim could be mapped to the new to dim.  | 
 | 50 | +fn reshape_dim_c<D, E, E2>(from_dim: &D, from_strides: &D, to_dim: &E, mut to_strides: E2)  | 
 | 51 | +    -> Result<(), ShapeError>  | 
 | 52 | +where  | 
 | 53 | +    D: Sequence<Output=usize>,  | 
 | 54 | +    E: Sequence<Output=usize>,  | 
 | 55 | +    E2: SequenceMut<Output=usize>,  | 
 | 56 | +{  | 
 | 57 | +    // cursor indexes into the from and to dimensions  | 
 | 58 | +    let mut fi = 0;  // index into `from_dim`  | 
 | 59 | +    let mut ti = 0;  // index into `to_dim`.  | 
 | 60 | + | 
 | 61 | +    while fi < from_dim.len() && ti < to_dim.len() {  | 
 | 62 | +        let mut fd = from_dim[fi];  | 
 | 63 | +        let mut fs = from_strides[fi] as isize;  | 
 | 64 | +        let mut td = to_dim[ti];  | 
 | 65 | + | 
 | 66 | +        if fd == td {  | 
 | 67 | +            to_strides[ti] = from_strides[fi];  | 
 | 68 | +            fi += 1;  | 
 | 69 | +            ti += 1;  | 
 | 70 | +            continue  | 
 | 71 | +        }  | 
 | 72 | + | 
 | 73 | +        if fd == 1 {  | 
 | 74 | +            fi += 1;  | 
 | 75 | +            continue;  | 
 | 76 | +        }  | 
 | 77 | + | 
 | 78 | +        if td == 1 {  | 
 | 79 | +            to_strides[ti] = 1;  | 
 | 80 | +            ti += 1;  | 
 | 81 | +            continue;  | 
 | 82 | +        }  | 
 | 83 | + | 
 | 84 | +        if fd == 0 || td == 0 {  | 
 | 85 | +            debug_assert!(false, "zero dim not handled by this function");  | 
 | 86 | +            return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));  | 
 | 87 | +        }  | 
 | 88 | + | 
 | 89 | +        // stride times element count is to be distributed out over a combination of axes.  | 
 | 90 | +        let mut fstride_whole = fs * (fd as isize);  | 
 | 91 | +        let mut fd_product = fd;  // cumulative product of axis lengths in the combination (from)  | 
 | 92 | +        let mut td_product = td;  // cumulative product of axis lengths in the combination (to)  | 
 | 93 | + | 
 | 94 | +        // The two axis lengths are not a match, so try to combine multiple axes  | 
 | 95 | +        // to get it to match up.  | 
 | 96 | +        while fd_product != td_product {  | 
 | 97 | +            if fd_product < td_product {  | 
 | 98 | +                // Take another axis on the from side  | 
 | 99 | +                fi += 1;  | 
 | 100 | +                if fi >= from_dim.len() {  | 
 | 101 | +                    return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));  | 
 | 102 | +                }  | 
 | 103 | +                fd = from_dim[fi];  | 
 | 104 | +                fd_product *= fd;  | 
 | 105 | +                if fd > 1 {  | 
 | 106 | +                    let fs_old = fs;  | 
 | 107 | +                    fs = from_strides[fi] as isize;  | 
 | 108 | +                    // check if this axis and the next are contiguous together  | 
 | 109 | +                    if fs_old != fd as isize * fs {  | 
 | 110 | +                        return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout));  | 
 | 111 | +                    }  | 
 | 112 | +                }  | 
 | 113 | +            } else {  | 
 | 114 | +                // Take another axis on the `to` side  | 
 | 115 | +                // First assign the stride to the axis we leave behind  | 
 | 116 | +                fstride_whole /= td as isize;  | 
 | 117 | +                to_strides[ti] = fstride_whole as usize;  | 
 | 118 | +                ti += 1;  | 
 | 119 | +                if ti >= to_dim.len() {  | 
 | 120 | +                    return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));  | 
 | 121 | +                }  | 
 | 122 | + | 
 | 123 | +                td = to_dim[ti];  | 
 | 124 | +                td_product *= td;  | 
 | 125 | +            }  | 
 | 126 | +        }  | 
 | 127 | + | 
 | 128 | +        fstride_whole /= td as isize;  | 
 | 129 | +        to_strides[ti] = fstride_whole as usize;  | 
 | 130 | + | 
 | 131 | +        fi += 1;  | 
 | 132 | +        ti += 1;  | 
 | 133 | +    }  | 
 | 134 | + | 
 | 135 | +    // skip past 1-dims at the end  | 
 | 136 | +    while fi < from_dim.len() && from_dim[fi] == 1 {  | 
 | 137 | +        fi += 1;  | 
 | 138 | +    }  | 
 | 139 | + | 
 | 140 | +    while ti < to_dim.len() && to_dim[ti] == 1 {  | 
 | 141 | +        to_strides[ti] = 1;  | 
 | 142 | +        ti += 1;  | 
 | 143 | +    }  | 
 | 144 | + | 
 | 145 | +    if fi < from_dim.len() || ti < to_dim.len() {  | 
 | 146 | +        return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));  | 
 | 147 | +    }  | 
 | 148 | + | 
 | 149 | +    Ok(())  | 
 | 150 | +}  | 
 | 151 | + | 
 | 152 | +#[cfg(feature = "std")]  | 
 | 153 | +#[test]  | 
 | 154 | +fn test_reshape() {  | 
 | 155 | +    use crate::Dim;  | 
 | 156 | + | 
 | 157 | +    macro_rules! test_reshape {  | 
 | 158 | +        (fail $order:ident from $from:expr, $stride:expr, to $to:expr) => {  | 
 | 159 | +            let res = reshape_dim(&Dim($from), &Dim($stride), &Dim($to), Order::$order);  | 
 | 160 | +            println!("Reshape {:?} {:?} to {:?}, order {:?}\n  => {:?}",  | 
 | 161 | +                     $from, $stride, $to, Order::$order, res);  | 
 | 162 | +            let _res = res.expect_err("Expected failed reshape");  | 
 | 163 | +        };  | 
 | 164 | +        (ok $order:ident from $from:expr, $stride:expr, to $to:expr, $to_stride:expr) => {{  | 
 | 165 | +            let res = reshape_dim(&Dim($from), &Dim($stride), &Dim($to), Order::$order);  | 
 | 166 | +            println!("Reshape {:?} {:?} to {:?}, order {:?}\n  => {:?}",  | 
 | 167 | +                     $from, $stride, $to, Order::$order, res);  | 
 | 168 | +            println!("default stride for from dim: {:?}", Dim($from).default_strides());  | 
 | 169 | +            println!("default stride for to dim: {:?}", Dim($to).default_strides());  | 
 | 170 | +            let res = res.expect("Expected successful reshape");  | 
 | 171 | +            assert_eq!(res, Dim($to_stride), "mismatch in strides");  | 
 | 172 | +        }};  | 
 | 173 | +    }  | 
 | 174 | + | 
 | 175 | +    test_reshape!(ok C from [1, 2, 3], [6, 3, 1], to [1, 2, 3], [6, 3, 1]);  | 
 | 176 | +    test_reshape!(ok C from [1, 2, 3], [6, 3, 1], to [2, 3], [3, 1]);  | 
 | 177 | +    test_reshape!(ok C from [1, 2, 3], [6, 3, 1], to [6], [1]);  | 
 | 178 | +    test_reshape!(fail C from [1, 2, 3], [6, 3, 1], to [1]);  | 
 | 179 | +    test_reshape!(fail F from [1, 2, 3], [6, 3, 1], to [1]);  | 
 | 180 | + | 
 | 181 | +    test_reshape!(ok C from [6], [1], to [3, 2], [2, 1]);  | 
 | 182 | +    test_reshape!(ok C from [3, 4, 5], [20, 5, 1], to [4, 15], [15, 1]);  | 
 | 183 | + | 
 | 184 | +    test_reshape!(ok C from [4, 4, 4], [16, 4, 1], to [16, 4], [4, 1]);  | 
 | 185 | + | 
 | 186 | +    test_reshape!(ok C from [4, 4], [4, 1], to [2, 2, 4, 1], [8, 4, 1, 1]);  | 
 | 187 | +    test_reshape!(ok C from [4, 4], [4, 1], to [2, 2, 4], [8, 4, 1]);  | 
 | 188 | +    test_reshape!(ok C from [4, 4], [4, 1], to [2, 2, 2, 2], [8, 4, 2, 1]);  | 
 | 189 | + | 
 | 190 | +    test_reshape!(ok C from [4, 4], [4, 1], to [2, 2, 1, 4], [8, 4, 1, 1]);  | 
 | 191 | + | 
 | 192 | +    test_reshape!(ok C from [4, 4, 4], [16, 4, 1], to [16, 4], [4, 1]);  | 
 | 193 | +    test_reshape!(ok C from [3, 4, 4], [16, 4, 1], to [3, 16], [16, 1]);  | 
 | 194 | + | 
 | 195 | +    test_reshape!(ok C from [4, 4], [8, 1], to [2, 2, 2, 2], [16, 8, 2, 1]);  | 
 | 196 | + | 
 | 197 | +    test_reshape!(fail C from [4, 4], [8, 1], to [2, 1, 4, 2]);  | 
 | 198 | + | 
 | 199 | +    test_reshape!(ok C from [16], [4], to [2, 2, 4], [32, 16, 4]);  | 
 | 200 | +    test_reshape!(ok C from [16], [-4isize as usize], to [2, 2, 4],  | 
 | 201 | +                  [-32isize as usize, -16isize as usize, -4isize as usize]);  | 
 | 202 | +    test_reshape!(ok F from [16], [4], to [2, 2, 4], [4, 8, 16]);  | 
 | 203 | +    test_reshape!(ok F from [16], [-4isize as usize], to [2, 2, 4],  | 
 | 204 | +                  [-4isize as usize, -8isize as usize, -16isize as usize]);  | 
 | 205 | + | 
 | 206 | +    test_reshape!(ok C from [3, 4, 5], [20, 5, 1], to [12, 5], [5, 1]);  | 
 | 207 | +    test_reshape!(ok C from [3, 4, 5], [20, 5, 1], to [4, 15], [15, 1]);  | 
 | 208 | +    test_reshape!(fail F from [3, 4, 5], [20, 5, 1], to [4, 15]);  | 
 | 209 | +    test_reshape!(ok C from [3, 4, 5, 7], [140, 35, 7, 1], to [28, 15], [15, 1]);  | 
 | 210 | + | 
 | 211 | +    // preserve stride if shape matches  | 
 | 212 | +    test_reshape!(ok C from [10], [2], to [10], [2]);  | 
 | 213 | +    test_reshape!(ok F from [10], [2], to [10], [2]);  | 
 | 214 | +    test_reshape!(ok C from [2, 10], [1, 2], to [2, 10], [1, 2]);  | 
 | 215 | +    test_reshape!(ok F from [2, 10], [1, 2], to [2, 10], [1, 2]);  | 
 | 216 | +    test_reshape!(ok C from [3, 4, 5], [20, 5, 1], to [3, 4, 5], [20, 5, 1]);  | 
 | 217 | +    test_reshape!(ok F from [3, 4, 5], [20, 5, 1], to [3, 4, 5], [20, 5, 1]);  | 
 | 218 | + | 
 | 219 | +    test_reshape!(ok C from [3, 4, 5], [4, 1, 1], to [12, 5], [1, 1]);  | 
 | 220 | +    test_reshape!(ok F from [3, 4, 5], [1, 3, 12], to [12, 5], [1, 12]);  | 
 | 221 | +    test_reshape!(ok F from [3, 4, 5], [1, 3, 1], to [12, 5], [1, 1]);  | 
 | 222 | + | 
 | 223 | +    // broadcast shapes  | 
 | 224 | +    test_reshape!(ok C from [3, 4, 5, 7], [0, 0, 7, 1], to [12, 35], [0, 1]);  | 
 | 225 | +    test_reshape!(fail C from [3, 4, 5, 7], [0, 0, 7, 1], to [28, 15]);  | 
 | 226 | + | 
 | 227 | +    // one-filled shapes  | 
 | 228 | +    test_reshape!(ok C from [10], [1], to [1, 10, 1, 1, 1], [1, 1, 1, 1, 1]);  | 
 | 229 | +    test_reshape!(ok F from [10], [1], to [1, 10, 1, 1, 1], [1, 1, 1, 1, 1]);  | 
 | 230 | +    test_reshape!(ok C from [1, 10], [10, 1], to [1, 10, 1, 1, 1], [10, 1, 1, 1, 1]);  | 
 | 231 | +    test_reshape!(ok F from [1, 10], [10, 1], to [1, 10, 1, 1, 1], [10, 1, 1, 1, 1]);  | 
 | 232 | +    test_reshape!(ok C from [1, 10], [1, 1], to [1, 5, 1, 1, 2], [1, 2, 2, 2, 1]);  | 
 | 233 | +    test_reshape!(ok F from [1, 10], [1, 1], to [1, 5, 1, 1, 2], [1, 1, 5, 5, 5]);  | 
 | 234 | +    test_reshape!(ok C from [10, 1, 1, 1, 1], [1, 1, 1, 1, 1], to [10], [1]);  | 
 | 235 | +    test_reshape!(ok F from [10, 1, 1, 1, 1], [1, 1, 1, 1, 1], to [10], [1]);  | 
 | 236 | +    test_reshape!(ok C from [1, 5, 1, 2, 1], [1, 2, 1, 1, 1], to [10], [1]);  | 
 | 237 | +    test_reshape!(fail F from [1, 5, 1, 2, 1], [1, 2, 1, 1, 1], to [10]);  | 
 | 238 | +    test_reshape!(ok F from [1, 5, 1, 2, 1], [1, 1, 1, 5, 1], to [10], [1]);  | 
 | 239 | +    test_reshape!(fail C from [1, 5, 1, 2, 1], [1, 1, 1, 5, 1], to [10]);  | 
 | 240 | +}  | 
 | 241 | + | 
0 commit comments