Skip to content

Commit e0b6cbf

Browse files
committed
feat: add coroutine::await_in_coroutine to await awaitables in coroutine context
1 parent ad5f6d4 commit e0b6cbf

File tree

14 files changed

+727
-188
lines changed

14 files changed

+727
-188
lines changed

Cargo.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ unindent = { version = "0.2.1", optional = true }
3131
# support crate for multiple-pymethods feature
3232
inventory = { version = "0.3.0", optional = true }
3333

34+
# support crate for experimental async feature
35+
scoped-tls-hkt = { version = "0.1", optional = true }
36+
3437
# crate integrations that can be added using the eponymous features
3538
anyhow = { version = "1.0.1", optional = true }
3639
chrono = { version = "0.4.25", default-features = false, optional = true }
@@ -41,7 +44,7 @@ hashbrown = { version = ">= 0.14.5, < 0.16", optional = true }
4144
indexmap = { version = ">= 2.5.0, < 3", optional = true }
4245
num-bigint = { version = "0.4.2", optional = true }
4346
num-complex = { version = ">= 0.4.6, < 0.5", optional = true }
44-
num-rational = {version = "0.4.1", optional = true }
47+
num-rational = { version = "0.4.1", optional = true }
4548
rust_decimal = { version = "1.15", default-features = false, optional = true }
4649
serde = { version = "1.0", optional = true }
4750
smallvec = { version = "1.0", optional = true }
@@ -63,7 +66,7 @@ rayon = "1.6.1"
6366
futures = "0.3.28"
6467
tempfile = "3.12.0"
6568
static_assertions = "1.1.0"
66-
uuid = {version = "1.10.0", features = ["v4"] }
69+
uuid = { version = "1.10.0", features = ["v4"] }
6770

6871
[build-dependencies]
6972
pyo3-build-config = { path = "pyo3-build-config", version = "=0.23.3", features = ["resolve-config"] }
@@ -72,7 +75,7 @@ pyo3-build-config = { path = "pyo3-build-config", version = "=0.23.3", features
7275
default = ["macros"]
7376

7477
# Enables support for `async fn` for `#[pyfunction]` and `#[pymethods]`.
75-
experimental-async = ["macros", "pyo3-macros/experimental-async"]
78+
experimental-async = ["macros", "pyo3-macros/experimental-async", "scoped-tls-hkt"]
7679

7780
# Enables pyo3::inspect module and additional type information on FromPyObject
7881
# and IntoPy traits

guide/src/SUMMARY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
- [Mapping of Rust types to Python types](conversions/tables.md)
2626
- [Conversion traits](conversions/traits.md)
2727
- [Using `async` and `await`](async-await.md)
28+
- [Awaiting Python awaitables](async-await/awaiting_python_awaitables)
2829
- [Parallelism](parallelism.md)
2930
- [Supporting Free-Threaded Python](free-threading.md)
3031
- [Debugging](debugging.md)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Awaiting Python awaitables
2+
3+
Python awaitable can be awaited on Rust side
4+
using [`await_in_coroutine`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/function.await_in_coroutine).
5+
6+
```rust
7+
# # ![allow(dead_code)]
8+
# #[cfg(feature = "experimental-async")] {
9+
use pyo3::{prelude::*, coroutine::await_in_coroutine};
10+
11+
#[pyfunction]
12+
async fn wrap_awaitable(awaitable: PyObject) -> PyResult<PyObject> {
13+
Python::with_gil(|gil| await_in_coroutine(awaitable.bind(gil)))?.await
14+
}
15+
# }
16+
```
17+
18+
Behind the scene, `await_in_coroutine` calls the `__await__` method of the Python awaitable (or `__iter__` for
19+
generator-based coroutine).
20+
21+
## Restrictions
22+
23+
As the name suggests, `await_in_coroutine` resulting future can only be awaited in coroutine context. Otherwise, it
24+
panics.
25+
26+
```rust
27+
# # ![allow(dead_code)]
28+
# #[cfg(feature = "experimental-async")] {
29+
use pyo3::{prelude::*, coroutine::await_in_coroutine};
30+
31+
#[pyfunction]
32+
fn block_on(awaitable: PyObject) -> PyResult<PyObject> {
33+
let future = Python::with_gil(|gil| await_in_coroutine(awaitable.bind(gil)))?;
34+
futures::executor::block_on(future) // ERROR: Python awaitable must be awaited in coroutine context
35+
}
36+
# }
37+
```
38+
39+
The future must also be the only one to be awaited at a time; it means that it's forbidden to await it in a `select!`.
40+
Otherwise, it panics.
41+
42+
```rust
43+
# # ![allow(dead_code)]
44+
# #[cfg(feature = "experimental-async")] {
45+
use futures::FutureExt;
46+
use pyo3::{prelude::*, coroutine::await_in_coroutine};
47+
48+
#[pyfunction]
49+
async fn select(awaitable: PyObject) -> PyResult<PyObject> {
50+
let future = Python::with_gil(|gil| await_in_coroutine(awaitable.bind(gil)))?;
51+
futures::select_biased! {
52+
_ = std::future::pending::<()>().fuse() => unreachable!(),
53+
res = future.fuse() => res, // ERROR: Python awaitable mixed with Rust future
54+
}
55+
}
56+
# }
57+
```
58+
59+
These restrictions exist because awaiting a `await_in_coroutine` future strongly binds it to the
60+
enclosing coroutine. The coroutine will then delegate its `send`/`throw`/`close` methods to the
61+
awaited future. If it was awaited in a `select!`, `Coroutine::send` would no able to know if
62+
the value passed would have to be delegated or not.

newsfragments/3611.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add `coroutine::await_in_coroutine` to await awaitables in coroutine context

pyo3-ffi/src/abstract_.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
use crate::object::*;
2-
use crate::pyport::Py_ssize_t;
1+
use std::os::raw::{c_char, c_int};
2+
33
#[cfg(any(Py_3_12, all(Py_3_8, not(Py_LIMITED_API))))]
44
use libc::size_t;
5-
use std::os::raw::{c_char, c_int};
5+
6+
use crate::{object::*, pyport::Py_ssize_t};
67

78
#[inline]
89
#[cfg(all(not(Py_3_13), not(PyPy)))] // CPython exposed as a function in 3.13, in object.h
@@ -143,7 +144,11 @@ extern "C" {
143144
pub fn PyIter_Next(arg1: *mut PyObject) -> *mut PyObject;
144145
#[cfg(all(not(PyPy), Py_3_10))]
145146
#[cfg_attr(PyPy, link_name = "PyPyIter_Send")]
146-
pub fn PyIter_Send(iter: *mut PyObject, arg: *mut PyObject, presult: *mut *mut PyObject);
147+
pub fn PyIter_Send(
148+
iter: *mut PyObject,
149+
arg: *mut PyObject,
150+
presult: *mut *mut PyObject,
151+
) -> c_int;
147152

148153
#[cfg_attr(PyPy, link_name = "PyPyNumber_Check")]
149154
pub fn PyNumber_Check(o: *mut PyObject) -> c_int;

src/coroutine.rs

Lines changed: 63 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,37 @@ use std::{
1111
use pyo3_macros::{pyclass, pymethods};
1212

1313
use crate::{
14-
coroutine::{cancel::ThrowCallback, waker::AsyncioWaker},
15-
exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration},
14+
coroutine::waker::CoroutineWaker,
15+
exceptions::{PyAttributeError, PyGeneratorExit, PyRuntimeError, PyStopIteration},
1616
panic::PanicException,
17-
types::{string::PyStringMethods, PyIterator, PyString},
18-
Bound, IntoPyObject, IntoPyObjectExt, Py, PyAny, PyErr, PyObject, PyResult, Python,
17+
types::{string::PyStringMethods, PyString},
18+
Bound, IntoPyObject, IntoPyObjectExt, Py, PyErr, PyObject, PyResult, Python,
1919
};
2020

21-
pub(crate) mod cancel;
21+
mod asyncio;
22+
mod awaitable;
23+
mod cancel;
2224
mod waker;
2325

24-
pub use cancel::CancelHandle;
26+
pub use awaitable::await_in_coroutine;
27+
pub use cancel::{CancelHandle, ThrowCallback};
2528

2629
const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";
2730

31+
pub(crate) enum CoroOp {
32+
Send(PyObject),
33+
Throw(PyObject),
34+
Close,
35+
}
36+
2837
/// Python coroutine wrapping a [`Future`].
2938
#[pyclass(crate = "crate")]
3039
pub struct Coroutine {
3140
name: Option<Py<PyString>>,
3241
qualname_prefix: Option<&'static str>,
3342
throw_callback: Option<ThrowCallback>,
3443
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
35-
waker: Option<Arc<AsyncioWaker>>,
44+
waker: Option<Arc<CoroutineWaker>>,
3645
}
3746

3847
// Safety: `Coroutine` is allowed to be `Sync` even though the future is not,
@@ -71,55 +80,58 @@ impl Coroutine {
7180
}
7281
}
7382

74-
fn poll(&mut self, py: Python<'_>, throw: Option<PyObject>) -> PyResult<PyObject> {
83+
fn poll_inner(&mut self, py: Python<'_>, mut op: CoroOp) -> PyResult<PyObject> {
7584
// raise if the coroutine has already been run to completion
7685
let future_rs = match self.future {
7786
Some(ref mut fut) => fut,
7887
None => return Err(PyRuntimeError::new_err(COROUTINE_REUSED_ERROR)),
7988
};
80-
// reraise thrown exception it
81-
match (throw, &self.throw_callback) {
82-
(Some(exc), Some(cb)) => cb.throw(exc),
83-
(Some(exc), None) => {
84-
self.close();
85-
return Err(PyErr::from_value(exc.into_bound(py)));
86-
}
87-
(None, _) => {}
89+
// if the future is not pending on a Python awaitable,
90+
// execute throw callback or complete on close
91+
if !matches!(self.waker, Some(ref w) if w.is_delegated(py)) {
92+
match op {
93+
send @ CoroOp::Send(_) => op = send,
94+
CoroOp::Throw(exc) => match &self.throw_callback {
95+
Some(cb) => {
96+
cb.throw(exc.clone_ref(py));
97+
op = CoroOp::Send(py.None());
98+
}
99+
None => return Err(PyErr::from_value(exc.into_bound(py))),
100+
},
101+
CoroOp::Close => return Err(PyGeneratorExit::new_err(py.None())),
102+
};
88103
}
89104
// create a new waker, or try to reset it in place
90105
if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) {
91-
waker.reset();
106+
waker.reset(op);
92107
} else {
93-
self.waker = Some(Arc::new(AsyncioWaker::new()));
108+
self.waker = Some(Arc::new(CoroutineWaker::new(op)));
94109
}
95-
let waker = Waker::from(self.waker.clone().unwrap());
96-
// poll the Rust future and forward its results if ready
110+
// poll the future and forward its results if ready; otherwise, yield from waker
97111
// polling is UnwindSafe because the future is dropped in case of panic
112+
let waker = Waker::from(self.waker.clone().unwrap());
98113
let poll = || future_rs.as_mut().poll(&mut Context::from_waker(&waker));
99114
match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
100-
Ok(Poll::Ready(res)) => {
101-
self.close();
102-
return Err(PyStopIteration::new_err((res?,)));
103-
}
104-
Err(err) => {
105-
self.close();
106-
return Err(PanicException::from_panic_payload(err));
107-
}
108-
_ => {}
115+
Err(err) => Err(PanicException::from_panic_payload(err)),
116+
// See #4407, `PyStopIteration::new_err` argument must be wrap in a tuple,
117+
// otherwise, when a tuple is returned, its fields would be expanded as error
118+
// arguments
119+
Ok(Poll::Ready(res)) => Err(PyStopIteration::new_err((res?,))),
120+
Ok(Poll::Pending) => match self.waker.as_ref().unwrap().yield_(py) {
121+
Ok(to_yield) => Ok(to_yield),
122+
Err(err) => Err(err),
123+
},
109124
}
110-
// otherwise, initialize the waker `asyncio.Future`
111-
if let Some(future) = self.waker.as_ref().unwrap().initialize_future(py)? {
112-
// `asyncio.Future` must be awaited; fortunately, it implements `__iter__ = __await__`
113-
// and will yield itself if its result has not been set in polling above
114-
if let Some(future) = PyIterator::from_object(future).unwrap().next() {
115-
// future has not been leaked into Python for now, and Rust code can only call
116-
// `set_result(None)` in `Wake` implementation, so it's safe to unwrap
117-
return Ok(future.unwrap().into());
118-
}
125+
}
126+
127+
fn poll(&mut self, py: Python<'_>, op: CoroOp) -> PyResult<PyObject> {
128+
let result = self.poll_inner(py, op);
129+
if result.is_err() {
130+
// the Rust future is dropped, and the field set to `None`
131+
// to indicate the coroutine has been run to completion
132+
drop(self.future.take());
119133
}
120-
// if waker has been waken during future polling, this is roughly equivalent to
121-
// `await asyncio.sleep(0)`, so just yield `None`.
122-
Ok(py.None())
134+
result
123135
}
124136
}
125137

@@ -145,25 +157,27 @@ impl Coroutine {
145157
}
146158
}
147159

148-
fn send(&mut self, py: Python<'_>, _value: &Bound<'_, PyAny>) -> PyResult<PyObject> {
149-
self.poll(py, None)
160+
fn send(&mut self, py: Python<'_>, value: PyObject) -> PyResult<PyObject> {
161+
self.poll(py, CoroOp::Send(value))
150162
}
151163

152164
fn throw(&mut self, py: Python<'_>, exc: PyObject) -> PyResult<PyObject> {
153-
self.poll(py, Some(exc))
165+
self.poll(py, CoroOp::Throw(exc))
154166
}
155167

156-
fn close(&mut self) {
157-
// the Rust future is dropped, and the field set to `None`
158-
// to indicate the coroutine has been run to completion
159-
drop(self.future.take());
168+
fn close(&mut self, py: Python<'_>) -> PyResult<()> {
169+
match self.poll(py, CoroOp::Close) {
170+
Ok(_) => Ok(()),
171+
Err(err) if err.is_instance_of::<PyGeneratorExit>(py) => Ok(()),
172+
Err(err) => Err(err),
173+
}
160174
}
161175

162176
fn __await__(self_: Py<Self>) -> Py<Self> {
163177
self_
164178
}
165179

166180
fn __next__(&mut self, py: Python<'_>) -> PyResult<PyObject> {
167-
self.poll(py, None)
181+
self.poll(py, CoroOp::Send(py.None()))
168182
}
169183
}

0 commit comments

Comments
 (0)