@@ -10,14 +10,11 @@ use core::mem;
1010use core:: pin:: Pin ;
1111use core:: task:: { Context , Poll } ;
1212
13- use super :: { assert_future, TryFuture , TryMaybeDone } ;
13+ use super :: { assert_future, join_all , IntoFuture , TryFuture , TryMaybeDone } ;
1414
15- fn iter_pin_mut < T > ( slice : Pin < & mut [ T ] > ) -> impl Iterator < Item = Pin < & mut T > > {
16- // Safety: `std` _could_ make this unsound if it were to decide Pin's
17- // invariants aren't required to transmit through slices. Otherwise this has
18- // the same safety as a normal field pin projection.
19- unsafe { slice. get_unchecked_mut ( ) } . iter_mut ( ) . map ( |t| unsafe { Pin :: new_unchecked ( t) } )
20- }
15+ #[ cfg( not( futures_no_atomic_cas) ) ]
16+ use crate :: stream:: { FuturesOrdered , TryCollect , TryStreamExt } ;
17+ use crate :: TryFutureExt ;
2118
2219enum FinalState < E = ( ) > {
2320 Pending ,
@@ -31,17 +28,37 @@ pub struct TryJoinAll<F>
3128where
3229 F : TryFuture ,
3330{
34- elems : Pin < Box < [ TryMaybeDone < F > ] > > ,
31+ kind : TryJoinAllKind < F > ,
32+ }
33+
34+ enum TryJoinAllKind < F >
35+ where
36+ F : TryFuture ,
37+ {
38+ Small {
39+ elems : Pin < Box < [ TryMaybeDone < IntoFuture < F > > ] > > ,
40+ } ,
41+ #[ cfg( not( futures_no_atomic_cas) ) ]
42+ Big {
43+ fut : TryCollect < FuturesOrdered < IntoFuture < F > > , Vec < F :: Ok > > ,
44+ } ,
3545}
3646
3747impl < F > fmt:: Debug for TryJoinAll < F >
3848where
3949 F : TryFuture + fmt:: Debug ,
4050 F :: Ok : fmt:: Debug ,
4151 F :: Error : fmt:: Debug ,
52+ F :: Output : fmt:: Debug ,
4253{
4354 fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
44- f. debug_struct ( "TryJoinAll" ) . field ( "elems" , & self . elems ) . finish ( )
55+ match self . kind {
56+ TryJoinAllKind :: Small { ref elems } => {
57+ f. debug_struct ( "TryJoinAll" ) . field ( "elems" , elems) . finish ( )
58+ }
59+ #[ cfg( not( futures_no_atomic_cas) ) ]
60+ TryJoinAllKind :: Big { ref fut, .. } => fmt:: Debug :: fmt ( fut, f) ,
61+ }
4562 }
4663}
4764
@@ -83,15 +100,37 @@ where
83100/// assert_eq!(try_join_all(futures).await, Err(2));
84101/// # });
85102/// ```
86- pub fn try_join_all < I > ( i : I ) -> TryJoinAll < I :: Item >
103+ pub fn try_join_all < I > ( iter : I ) -> TryJoinAll < I :: Item >
87104where
88105 I : IntoIterator ,
89106 I :: Item : TryFuture ,
90107{
91- let elems: Box < [ _ ] > = i. into_iter ( ) . map ( TryMaybeDone :: Future ) . collect ( ) ;
92- assert_future :: < Result < Vec < <I :: Item as TryFuture >:: Ok > , <I :: Item as TryFuture >:: Error > , _ > (
93- TryJoinAll { elems : elems. into ( ) } ,
94- )
108+ let iter = iter. into_iter ( ) . map ( TryFutureExt :: into_future) ;
109+
110+ #[ cfg( futures_no_atomic_cas) ]
111+ {
112+ let kind = TryJoinAllKind :: Small {
113+ elems : iter. map ( TryMaybeDone :: Future ) . collect :: < Box < [ _ ] > > ( ) . into ( ) ,
114+ } ;
115+
116+ assert_future :: < Result < Vec < <I :: Item as TryFuture >:: Ok > , <I :: Item as TryFuture >:: Error > , _ > (
117+ TryJoinAll { kind } ,
118+ )
119+ }
120+
121+ #[ cfg( not( futures_no_atomic_cas) ) ]
122+ {
123+ let kind = match iter. size_hint ( ) . 1 {
124+ Some ( max) if max <= join_all:: SMALL => TryJoinAllKind :: Small {
125+ elems : iter. map ( TryMaybeDone :: Future ) . collect :: < Box < [ _ ] > > ( ) . into ( ) ,
126+ } ,
127+ _ => TryJoinAllKind :: Big { fut : iter. collect :: < FuturesOrdered < _ > > ( ) . try_collect ( ) } ,
128+ } ;
129+
130+ assert_future :: < Result < Vec < <I :: Item as TryFuture >:: Ok > , <I :: Item as TryFuture >:: Error > , _ > (
131+ TryJoinAll { kind } ,
132+ )
133+ }
95134}
96135
97136impl < F > Future for TryJoinAll < F >
@@ -101,36 +140,46 @@ where
101140 type Output = Result < Vec < F :: Ok > , F :: Error > ;
102141
103142 fn poll ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
104- let mut state = FinalState :: AllDone ;
105-
106- for elem in iter_pin_mut ( self . elems . as_mut ( ) ) {
107- match elem. try_poll ( cx) {
108- Poll :: Pending => state = FinalState :: Pending ,
109- Poll :: Ready ( Ok ( ( ) ) ) => { }
110- Poll :: Ready ( Err ( e) ) => {
111- state = FinalState :: Error ( e) ;
112- break ;
143+ match & mut self . kind {
144+ TryJoinAllKind :: Small { elems } => {
145+ let mut state = FinalState :: AllDone ;
146+
147+ for elem in join_all:: iter_pin_mut ( elems. as_mut ( ) ) {
148+ match elem. try_poll ( cx) {
149+ Poll :: Pending => state = FinalState :: Pending ,
150+ Poll :: Ready ( Ok ( ( ) ) ) => { }
151+ Poll :: Ready ( Err ( e) ) => {
152+ state = FinalState :: Error ( e) ;
153+ break ;
154+ }
155+ }
113156 }
114- }
115- }
116157
117- match state {
118- FinalState :: Pending => Poll :: Pending ,
119- FinalState :: AllDone => {
120- let mut elems = mem:: replace ( & mut self . elems , Box :: pin ( [ ] ) ) ;
121- let results =
122- iter_pin_mut ( elems. as_mut ( ) ) . map ( |e| e. take_output ( ) . unwrap ( ) ) . collect ( ) ;
123- Poll :: Ready ( Ok ( results) )
124- }
125- FinalState :: Error ( e) => {
126- let _ = mem:: replace ( & mut self . elems , Box :: pin ( [ ] ) ) ;
127- Poll :: Ready ( Err ( e) )
158+ match state {
159+ FinalState :: Pending => Poll :: Pending ,
160+ FinalState :: AllDone => {
161+ let mut elems = mem:: replace ( elems, Box :: pin ( [ ] ) ) ;
162+ let results = join_all:: iter_pin_mut ( elems. as_mut ( ) )
163+ . map ( |e| e. take_output ( ) . unwrap ( ) )
164+ . collect ( ) ;
165+ Poll :: Ready ( Ok ( results) )
166+ }
167+ FinalState :: Error ( e) => {
168+ let _ = mem:: replace ( elems, Box :: pin ( [ ] ) ) ;
169+ Poll :: Ready ( Err ( e) )
170+ }
171+ }
128172 }
173+ #[ cfg( not( futures_no_atomic_cas) ) ]
174+ TryJoinAllKind :: Big { fut } => Pin :: new ( fut) . poll ( cx) ,
129175 }
130176 }
131177}
132178
133- impl < F : TryFuture > FromIterator < F > for TryJoinAll < F > {
179+ impl < F > FromIterator < F > for TryJoinAll < F >
180+ where
181+ F : TryFuture ,
182+ {
134183 fn from_iter < T : IntoIterator < Item = F > > ( iter : T ) -> Self {
135184 try_join_all ( iter)
136185 }
0 commit comments