Skip to content

Commit 9cd432e

Browse files
improve signature of ffi::PyIter_Send & add PyIterator::send
1 parent 82ab509 commit 9cd432e

File tree

2 files changed

+68
-3
lines changed

2 files changed

+68
-3
lines changed

pyo3-ffi/src/abstract_.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,11 @@ extern "C" {
120120
pub fn PyIter_Next(arg1: *mut PyObject) -> *mut PyObject;
121121
#[cfg(all(not(PyPy), Py_3_10))]
122122
#[cfg_attr(PyPy, link_name = "PyPyIter_Send")]
123-
pub fn PyIter_Send(iter: *mut PyObject, arg: *mut PyObject, presult: *mut *mut PyObject);
123+
pub fn PyIter_Send(
124+
iter: *mut PyObject,
125+
arg: *mut PyObject,
126+
presult: *mut *mut PyObject,
127+
) -> PySendResult;
124128

125129
#[cfg_attr(PyPy, link_name = "PyPyNumber_Check")]
126130
pub fn PyNumber_Check(o: *mut PyObject) -> c_int;

src/types/iterator.rs

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,31 @@ impl PyIterator {
5252
}
5353
}
5454

55+
#[derive(Debug)]
56+
pub enum PySendResult<'py> {
57+
Next(Bound<'py, PyAny>),
58+
Return(Bound<'py, PyAny>),
59+
}
60+
61+
impl<'py> Bound<'py, PyIterator> {
62+
/// Sends a value into the iterator.
63+
#[inline]
64+
#[cfg(all(not(PyPy), Py_3_10))]
65+
pub fn send(&self, value: &Bound<'py, PyAny>) -> PyResult<PySendResult<'py>> {
66+
let py = self.py();
67+
let mut result = std::ptr::null_mut();
68+
match unsafe { ffi::PyIter_Send(self.as_ptr(), value.as_ptr(), &mut result) } {
69+
ffi::PySendResult::PYGEN_ERROR => Err(PyErr::fetch(py)),
70+
ffi::PySendResult::PYGEN_RETURN => Ok(PySendResult::Return(unsafe {
71+
result.assume_owned_unchecked(py).to_owned()
72+
})),
73+
ffi::PySendResult::PYGEN_NEXT => Ok(PySendResult::Next(unsafe {
74+
result.assume_owned_unchecked(py).to_owned()
75+
})),
76+
}
77+
}
78+
}
79+
5580
impl<'py> Iterator for Bound<'py, PyIterator> {
5681
type Item = PyResult<Bound<'py, PyAny>>;
5782

@@ -105,9 +130,9 @@ impl PyTypeCheck for PyIterator {
105130

106131
#[cfg(test)]
107132
mod tests {
108-
use super::PyIterator;
133+
use super::{PyIterator, PySendResult};
109134
use crate::exceptions::PyTypeError;
110-
use crate::types::{PyAnyMethods, PyDict, PyList, PyListMethods};
135+
use crate::types::{PyAnyMethods, PyDict, PyList, PyListMethods, PyNone};
111136
use crate::{ffi, IntoPyObject, Python};
112137

113138
#[test]
@@ -201,6 +226,42 @@ def fibonacci(target):
201226
});
202227
}
203228

229+
#[test]
230+
#[cfg(all(not(PyPy), Py_3_10))]
231+
fn send_generator() {
232+
let generator = ffi::c_str!(
233+
r#"
234+
def gen():
235+
value = None
236+
while(True):
237+
value = yield value
238+
if value is None:
239+
return
240+
"#
241+
);
242+
243+
Python::with_gil(|py| {
244+
let context = PyDict::new(py);
245+
py.run(generator, None, Some(&context)).unwrap();
246+
247+
let generator = py.eval(ffi::c_str!("gen()"), None, Some(&context)).unwrap();
248+
249+
let one = 1i32.into_pyobject(py).unwrap();
250+
assert!(matches!(
251+
generator.try_iter().unwrap().send(&PyNone::get(py)).unwrap(),
252+
PySendResult::Next(value) if value.is_none()
253+
));
254+
assert!(matches!(
255+
generator.try_iter().unwrap().send(&one).unwrap(),
256+
PySendResult::Next(value) if value.is(&one)
257+
));
258+
assert!(matches!(
259+
generator.try_iter().unwrap().send(&PyNone::get(py)).unwrap(),
260+
PySendResult::Return(value) if value.is_none()
261+
));
262+
});
263+
}
264+
204265
#[test]
205266
fn fibonacci_generator_bound() {
206267
use crate::types::any::PyAnyMethods;

0 commit comments

Comments
 (0)