@@ -11,28 +11,37 @@ use std::{
1111use pyo3_macros:: { pyclass, pymethods} ;
1212
1313use crate :: {
14- coroutine:: { cancel :: ThrowCallback , waker:: AsyncioWaker } ,
15- exceptions:: { PyAttributeError , PyRuntimeError , PyStopIteration } ,
14+ coroutine:: waker:: CoroutineWaker ,
15+ exceptions:: { PyAttributeError , PyGeneratorExit , PyRuntimeError , PyStopIteration } ,
1616 panic:: PanicException ,
17- types:: { string:: PyStringMethods , PyIterator , PyString } ,
18- Bound , IntoPyObject , IntoPyObjectExt , Py , PyAny , PyErr , PyObject , PyResult , Python ,
17+ types:: { string:: PyStringMethods , PyString } ,
18+ Bound , IntoPyObject , IntoPyObjectExt , Py , PyErr , PyObject , PyResult , Python ,
1919} ;
2020
21- pub ( crate ) mod cancel;
21+ mod asyncio;
22+ mod awaitable;
23+ mod cancel;
2224mod waker;
2325
24- pub use cancel:: CancelHandle ;
26+ pub use awaitable:: await_in_coroutine;
27+ pub use cancel:: { CancelHandle , ThrowCallback } ;
2528
2629const COROUTINE_REUSED_ERROR : & str = "cannot reuse already awaited coroutine" ;
2730
31+ pub ( crate ) enum CoroOp {
32+ Send ( PyObject ) ,
33+ Throw ( PyObject ) ,
34+ Close ,
35+ }
36+
2837/// Python coroutine wrapping a [`Future`].
2938#[ pyclass( crate = "crate" ) ]
3039pub struct Coroutine {
3140 name : Option < Py < PyString > > ,
3241 qualname_prefix : Option < & ' static str > ,
3342 throw_callback : Option < ThrowCallback > ,
3443 future : Option < Pin < Box < dyn Future < Output = PyResult < PyObject > > + Send > > > ,
35- waker : Option < Arc < AsyncioWaker > > ,
44+ waker : Option < Arc < CoroutineWaker > > ,
3645}
3746
3847// Safety: `Coroutine` is allowed to be `Sync` even though the future is not,
@@ -71,55 +80,58 @@ impl Coroutine {
7180 }
7281 }
7382
74- fn poll ( & mut self , py : Python < ' _ > , throw : Option < PyObject > ) -> PyResult < PyObject > {
83+ fn poll_inner ( & mut self , py : Python < ' _ > , mut op : CoroOp ) -> PyResult < PyObject > {
7584 // raise if the coroutine has already been run to completion
7685 let future_rs = match self . future {
7786 Some ( ref mut fut) => fut,
7887 None => return Err ( PyRuntimeError :: new_err ( COROUTINE_REUSED_ERROR ) ) ,
7988 } ;
80- // reraise thrown exception it
81- match ( throw, & self . throw_callback ) {
82- ( Some ( exc) , Some ( cb) ) => cb. throw ( exc) ,
83- ( Some ( exc) , None ) => {
84- self . close ( ) ;
85- return Err ( PyErr :: from_value ( exc. into_bound ( py) ) ) ;
86- }
87- ( None , _) => { }
89+ // if the future is not pending on a Python awaitable,
90+ // execute throw callback or complete on close
91+ if !matches ! ( self . waker, Some ( ref w) if w. is_delegated( py) ) {
92+ match op {
93+ send @ CoroOp :: Send ( _) => op = send,
94+ CoroOp :: Throw ( exc) => match & self . throw_callback {
95+ Some ( cb) => {
96+ cb. throw ( exc. clone_ref ( py) ) ;
97+ op = CoroOp :: Send ( py. None ( ) ) ;
98+ }
99+ None => return Err ( PyErr :: from_value ( exc. into_bound ( py) ) ) ,
100+ } ,
101+ CoroOp :: Close => return Err ( PyGeneratorExit :: new_err ( py. None ( ) ) ) ,
102+ } ;
88103 }
89104 // create a new waker, or try to reset it in place
90105 if let Some ( waker) = self . waker . as_mut ( ) . and_then ( Arc :: get_mut) {
91- waker. reset ( ) ;
106+ waker. reset ( op ) ;
92107 } else {
93- self . waker = Some ( Arc :: new ( AsyncioWaker :: new ( ) ) ) ;
108+ self . waker = Some ( Arc :: new ( CoroutineWaker :: new ( op ) ) ) ;
94109 }
95- let waker = Waker :: from ( self . waker . clone ( ) . unwrap ( ) ) ;
96- // poll the Rust future and forward its results if ready
110+ // poll the future and forward its results if ready; otherwise, yield from waker
97111 // polling is UnwindSafe because the future is dropped in case of panic
112+ let waker = Waker :: from ( self . waker . clone ( ) . unwrap ( ) ) ;
98113 let poll = || future_rs. as_mut ( ) . poll ( & mut Context :: from_waker ( & waker) ) ;
99114 match panic:: catch_unwind ( panic:: AssertUnwindSafe ( poll) ) {
100- Ok ( Poll :: Ready ( res ) ) => {
101- self . close ( ) ;
102- return Err ( PyStopIteration :: new_err ( ( res? , ) ) ) ;
103- }
104- Err ( err ) => {
105- self . close ( ) ;
106- return Err ( PanicException :: from_panic_payload ( err ) ) ;
107- }
108- _ => { }
115+ Err ( err ) => Err ( PanicException :: from_panic_payload ( err ) ) ,
116+ // See #4407, `PyStopIteration::new_err` argument must be wrap in a tuple,
117+ // otherwise, when a tuple is returned, its fields would be expanded as error
118+ // arguments
119+ Ok ( Poll :: Ready ( res ) ) => Err ( PyStopIteration :: new_err ( ( res? , ) ) ) ,
120+ Ok ( Poll :: Pending ) => match self . waker . as_ref ( ) . unwrap ( ) . yield_ ( py ) {
121+ Ok ( to_yield ) => Ok ( to_yield ) ,
122+ Err ( err ) => Err ( err ) ,
123+ } ,
109124 }
110- // otherwise, initialize the waker `asyncio.Future`
111- if let Some ( future) = self . waker . as_ref ( ) . unwrap ( ) . initialize_future ( py) ? {
112- // `asyncio.Future` must be awaited; fortunately, it implements `__iter__ = __await__`
113- // and will yield itself if its result has not been set in polling above
114- if let Some ( future) = PyIterator :: from_object ( future) . unwrap ( ) . next ( ) {
115- // future has not been leaked into Python for now, and Rust code can only call
116- // `set_result(None)` in `Wake` implementation, so it's safe to unwrap
117- return Ok ( future. unwrap ( ) . into ( ) ) ;
118- }
125+ }
126+
127+ fn poll ( & mut self , py : Python < ' _ > , op : CoroOp ) -> PyResult < PyObject > {
128+ let result = self . poll_inner ( py, op) ;
129+ if result. is_err ( ) {
130+ // the Rust future is dropped, and the field set to `None`
131+ // to indicate the coroutine has been run to completion
132+ drop ( self . future . take ( ) ) ;
119133 }
120- // if waker has been waken during future polling, this is roughly equivalent to
121- // `await asyncio.sleep(0)`, so just yield `None`.
122- Ok ( py. None ( ) )
134+ result
123135 }
124136}
125137
@@ -145,25 +157,27 @@ impl Coroutine {
145157 }
146158 }
147159
148- fn send ( & mut self , py : Python < ' _ > , _value : & Bound < ' _ , PyAny > ) -> PyResult < PyObject > {
149- self . poll ( py, None )
160+ fn send ( & mut self , py : Python < ' _ > , value : PyObject ) -> PyResult < PyObject > {
161+ self . poll ( py, CoroOp :: Send ( value ) )
150162 }
151163
152164 fn throw ( & mut self , py : Python < ' _ > , exc : PyObject ) -> PyResult < PyObject > {
153- self . poll ( py, Some ( exc) )
165+ self . poll ( py, CoroOp :: Throw ( exc) )
154166 }
155167
156- fn close ( & mut self ) {
157- // the Rust future is dropped, and the field set to `None`
158- // to indicate the coroutine has been run to completion
159- drop ( self . future . take ( ) ) ;
168+ fn close ( & mut self , py : Python < ' _ > ) -> PyResult < ( ) > {
169+ match self . poll ( py, CoroOp :: Close ) {
170+ Ok ( _) => Ok ( ( ) ) ,
171+ Err ( err) if err. is_instance_of :: < PyGeneratorExit > ( py) => Ok ( ( ) ) ,
172+ Err ( err) => Err ( err) ,
173+ }
160174 }
161175
162176 fn __await__ ( self_ : Py < Self > ) -> Py < Self > {
163177 self_
164178 }
165179
166180 fn __next__ ( & mut self , py : Python < ' _ > ) -> PyResult < PyObject > {
167- self . poll ( py, None )
181+ self . poll ( py, CoroOp :: Send ( py . None ( ) ) )
168182 }
169183}
0 commit comments