@@ -11,7 +11,8 @@ use crate::imp_prelude::*;
1111
1212/// # Methods for Dynamic-Dimensional Arrays
1313impl < A , S > ArrayBase < S , IxDyn >
14- where S : Data < Elem = A >
14+ where
15+ S : Data < Elem = A > ,
1516{
1617 /// Insert new array axis of length 1 at `axis`, modifying the shape and
1718 /// strides in-place.
@@ -29,8 +30,7 @@ where S: Data<Elem = A>
2930 /// assert_eq!(a.shape(), &[2, 1, 3]);
3031 /// ```
3132 #[ track_caller]
32- pub fn insert_axis_inplace ( & mut self , axis : Axis )
33- {
33+ pub fn insert_axis_inplace ( & mut self , axis : Axis ) {
3434 assert ! ( axis. index( ) <= self . ndim( ) ) ;
3535 self . dim = self . dim . insert_axis ( axis) ;
3636 self . strides = self . strides . insert_axis ( axis) ;
@@ -52,10 +52,62 @@ where S: Data<Elem = A>
5252 /// assert_eq!(a.shape(), &[2]);
5353 /// ```
5454 #[ track_caller]
55- pub fn index_axis_inplace ( & mut self , axis : Axis , index : usize )
56- {
55+ pub fn index_axis_inplace ( & mut self , axis : Axis , index : usize ) {
5756 self . collapse_axis ( axis, index) ;
5857 self . dim = self . dim . remove_axis ( axis) ;
5958 self . strides = self . strides . remove_axis ( axis) ;
6059 }
60+
61+ /// Remove axes of length 1 and return the modified array.
62+ ///
63+ /// If the array has more the one dimension, the result array will always
64+ /// have at least one dimension, even if it has a length of 1.
65+ ///
66+ /// ```
67+ /// use ndarray::{arr2, arr3};
68+ ///
69+ /// let a = arr3(&[[[1, 2, 3]], [[4, 5, 6]]]).into_dyn();
70+ /// assert_eq!(a.shape(), &[2, 1, 3]);
71+ /// let b = a.squeeze();
72+ /// assert_eq!(b, arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn());
73+ /// assert_eq!(b.shape(), &[2, 3]);
74+ ///
75+ /// let c = arr2(&[[1]]).into_dyn();
76+ /// assert_eq!(c.shape(), &[1, 1]);
77+ /// let d = c.squeeze();
78+ /// assert_eq!(d, arr1(&[1]).into_dyn());
79+ /// assert_eq!(d.shape(), &[1]);
80+ /// ```
81+ #[ track_caller]
82+ pub fn squeeze ( self ) -> Self {
83+ let mut out = self ;
84+ for axis in ( 0 ..out. shape ( ) . len ( ) ) . rev ( ) {
85+ if out. shape ( ) [ axis] == 1 && out. shape ( ) . len ( ) > 1 {
86+ out = out. remove_axis ( Axis ( axis) ) ;
87+ }
88+ }
89+ out
90+ }
91+ }
92+
93+ #[ cfg( test) ]
94+ mod tests {
95+ use crate :: { arr1, arr2, arr3} ;
96+
97+ #[ test]
98+ fn test_squeeze ( ) {
99+ let a = arr3 ( & [ [ [ 1 , 2 , 3 ] ] , [ [ 4 , 5 , 6 ] ] ] ) . into_dyn ( ) ;
100+ assert_eq ! ( a. shape( ) , & [ 2 , 1 , 3 ] ) ;
101+
102+ let b = a. squeeze ( ) ;
103+ assert_eq ! ( b, arr2( & [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] ) . into_dyn( ) ) ;
104+ assert_eq ! ( b. shape( ) , & [ 2 , 3 ] ) ;
105+
106+ let c = arr2 ( & [ [ 1 ] ] ) . into_dyn ( ) ;
107+ assert_eq ! ( c. shape( ) , & [ 1 , 1 ] ) ;
108+
109+ let d = c. squeeze ( ) ;
110+ assert_eq ! ( d, arr1( & [ 1 ] ) . into_dyn( ) ) ;
111+ assert_eq ! ( d. shape( ) , & [ 1 ] ) ;
112+ }
61113}
0 commit comments