Skip to content

Commit de96d81

Browse files
committed
Provide an efficient way to go from Array<Py<T>, D> to PyArray<PyObject, D>.
1 parent 4aa655e commit de96d81

File tree

4 files changed

+58
-13
lines changed

4 files changed

+58
-13
lines changed

src/array.rs

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use ndarray::*;
44
use num_traits::AsPrimitive;
55
use pyo3::{
66
ffi, prelude::*, pyobject_native_type_info, pyobject_native_type_named, type_object,
7-
types::PyAny, AsPyPointer, PyDowncastError, PyNativeType, PyResult,
7+
types::PyAny, AsPyPointer, PyClass, PyDowncastError, PyNativeType, PyResult,
88
};
99
use std::{
1010
cell::Cell,
@@ -637,23 +637,25 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
637637
}
638638
}
639639

640-
/// Construct PyArray from
641-
/// [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html).
640+
/// Construct PyArray from [`ndarray::Array`]
642641
///
643-
/// This method uses internal [`Vec`](https://doc.rust-lang.org/std/vec/struct.Vec.html)
644-
/// of `ndarray::Array` as numpy array.
642+
/// This method uses internal [`Vec`] of `ndarray::Array` as NumPy array.
645643
///
646644
/// # Example
645+
///
647646
/// ```
648-
/// # #[macro_use] extern crate ndarray;
647+
/// use ndarray::array;
649648
/// use numpy::PyArray;
649+
///
650650
/// pyo3::Python::with_gil(|py| {
651651
/// let pyarray = PyArray::from_owned_array(py, array![[1, 2], [3, 4]]);
652652
/// assert_eq!(pyarray.readonly().as_array(), array![[1, 2], [3, 4]]);
653653
/// });
654654
/// ```
655655
pub fn from_owned_array<'py>(py: Python<'py>, arr: Array<T, D>) -> &'py Self {
656-
IntoPyArray::into_pyarray(arr, py)
656+
let (strides, dims) = (arr.npy_strides(), arr.raw_dim());
657+
let data_ptr = arr.as_ptr();
658+
unsafe { PyArray::from_raw_parts(py, dims, strides.as_ptr(), data_ptr, arr) }
657659
}
658660

659661
/// Get the immutable reference of the specified element, with checking the passed index is valid.
@@ -850,6 +852,51 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
850852
}
851853
}
852854

855+
impl<D: Dimension> PyArray<PyObject, D> {
856+
/// Construct PyArray containing objects from [`ndarray::Array`]
857+
///
858+
/// This method uses internal [`Vec`] of `ndarray::Array` as NumPy array.
859+
///
860+
/// # Example
861+
///
862+
/// ```
863+
/// use ndarray::array;
864+
/// use pyo3::{pyclass, Py, Python};
865+
/// use numpy::PyArray;
866+
///
867+
/// #[pyclass]
868+
/// struct CustomElement {
869+
/// foo: i32,
870+
/// bar: f64,
871+
/// }
872+
///
873+
/// Python::with_gil(|py| {
874+
/// let array = array![
875+
/// Py::new(py, CustomElement {
876+
/// foo: 1,
877+
/// bar: 2.0,
878+
/// }).unwrap(),
879+
/// Py::new(py, CustomElement {
880+
/// foo: 3,
881+
/// bar: 4.0,
882+
/// }).unwrap(),
883+
/// ];
884+
///
885+
/// let pyarray = PyArray::from_owned_object_array(py, array);
886+
///
887+
/// assert!(pyarray.readonly().get(0).unwrap().as_ref(py).is_instance::<CustomElement>().unwrap());
888+
/// });
889+
/// ```
890+
pub fn from_owned_object_array<'py, T: PyClass>(
891+
py: Python<'py>,
892+
arr: Array<Py<T>, D>,
893+
) -> &'py Self {
894+
let (strides, dims) = (arr.npy_strides(), arr.raw_dim());
895+
let data_ptr = arr.as_ptr() as *const PyObject;
896+
unsafe { PyArray::from_raw_parts(py, dims, strides.as_ptr(), data_ptr, arr) }
897+
}
898+
}
899+
853900
impl<T: Copy + Element> PyArray<T, Ix0> {
854901
/// Get the element of zero-dimensional PyArray.
855902
///

src/convert.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,7 @@ where
6363
type Item = A;
6464
type Dim = D;
6565
fn into_pyarray<'py>(self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
66-
let (strides, dims) = (self.npy_strides(), self.raw_dim());
67-
let data_ptr = self.as_ptr();
68-
unsafe { PyArray::from_raw_parts(py, dims, strides.as_ptr(), data_ptr, self) }
66+
PyArray::from_owned_array(py, self)
6967
}
7068
}
7169

src/dtype.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ impl PyArrayDescr {
126126
/// as only the data type can be checked during downcasting, but not the dynamic type of each element stored in the array.
127127
/// This is not possible as the elements and hence their dynamic type could be changed later on via aliasing Python references to the array.
128128
/// Therefore, the only safe type to store in object arrays is `Py<PyAny>` also known as `PyObject`.
129+
///
130+
/// You can however create `ndarray::Array<Py<T>, D>` and turn that into a NumPy array safely and efficiently using [`from_owned_object_array`][crate::PyArray::from_owned_object_array].
129131
pub unsafe trait Element: Clone + Send {
130132
/// Flag that indicates whether this type is trivially copyable.
131133
///

src/slice_container.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ use pyo3::pyclass_slots::PyClassDummySlot;
77
use pyo3::type_object::{LazyStaticType, PyTypeInfo};
88
use pyo3::{ffi, types::PyAny, PyCell};
99

10-
use crate::dtype::Element;
11-
1210
/// Utility type to safely store Box<[_]> or Vec<_> on the Python heap
1311
pub(crate) struct PySliceContainer {
1412
ptr: *mut u8,
@@ -69,7 +67,7 @@ impl<T: Send> From<Vec<T>> for PySliceContainer {
6967

7068
impl<A, D> From<ArrayBase<OwnedRepr<A>, D>> for PySliceContainer
7169
where
72-
A: Element,
70+
A: Send,
7371
D: Dimension,
7472
{
7573
fn from(data: ArrayBase<OwnedRepr<A>, D>) -> Self {

0 commit comments

Comments
 (0)