-
Notifications
You must be signed in to change notification settings - Fork 17
Pyo3 and numpy update #958
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,16 +23,16 @@ fn create_computation_graph_from_py_bytes(computation: Vec<u8>) -> Computation { | |
|
|
||
| fn pyobj_to_value(py: Python, obj: PyObject) -> PyResult<Value> { | ||
| let obj_ref = obj.as_ref(py); | ||
| if obj_ref.is_instance::<PyString>()? { | ||
| if obj_ref.is_instance_of::<PyString>()? { | ||
| let string_value: String = obj.extract(py)?; | ||
| Ok(Value::HostString(Box::new(HostString( | ||
| string_value, | ||
| HostPlacement::from("fake"), | ||
| )))) | ||
| } else if obj_ref.is_instance::<PyFloat>()? { | ||
| } else if obj_ref.is_instance_of::<PyFloat>()? { | ||
| let float_value: f64 = obj.extract(py)?; | ||
| Ok(Value::Float64(Box::new(float_value))) | ||
| } else if obj_ref.is_instance::<PyArrayDyn<f32>>()? { | ||
| } else if obj_ref.is_instance_of::<PyArrayDyn<f32>>()? { | ||
| // NOTE: this passes for any inner dtype, since python's isinstance will | ||
| // only do a shallow typecheck. inside the pyobj_tensor_to_value we do further | ||
| // introspection on the array & its dtype to map to the correct kind of Value | ||
|
|
@@ -65,19 +65,40 @@ fn pyobj_tensor_to_host_bit_tensor(py: Python, obj: &PyObject) -> HostBitTensor | |
| fn pyobj_tensor_to_value(py: Python, obj: &PyObject) -> Result<Value, anyhow::Error> { | ||
| let dtype_obj = obj.getattr(py, "dtype")?; | ||
| let dtype: &PyArrayDescr = dtype_obj.cast_as(py).unwrap(); | ||
| let np_dtype = dtype.get_datatype().unwrap(); | ||
| match np_dtype { | ||
| numpy::DataType::Float32 => Ok(Value::from(pyobj_tensor_to_host_tensor::<f32>(py, obj))), | ||
| numpy::DataType::Float64 => Ok(Value::from(pyobj_tensor_to_host_tensor::<f64>(py, obj))), | ||
| numpy::DataType::Int8 => Ok(Value::from(pyobj_tensor_to_host_tensor::<i8>(py, obj))), | ||
| numpy::DataType::Int16 => Ok(Value::from(pyobj_tensor_to_host_tensor::<i16>(py, obj))), | ||
| numpy::DataType::Int32 => Ok(Value::from(pyobj_tensor_to_host_tensor::<i32>(py, obj))), | ||
| numpy::DataType::Int64 => Ok(Value::from(pyobj_tensor_to_host_tensor::<i64>(py, obj))), | ||
| numpy::DataType::Uint8 => Ok(Value::from(pyobj_tensor_to_host_tensor::<u8>(py, obj))), | ||
| numpy::DataType::Uint16 => Ok(Value::from(pyobj_tensor_to_host_tensor::<u16>(py, obj))), | ||
| numpy::DataType::Uint32 => Ok(Value::from(pyobj_tensor_to_host_tensor::<u32>(py, obj))), | ||
| numpy::DataType::Uint64 => Ok(Value::from(pyobj_tensor_to_host_tensor::<u64>(py, obj))), | ||
| numpy::DataType::Bool => Ok(Value::from(pyobj_tensor_to_host_bit_tensor(py, obj))), | ||
| match dtype { | ||
| dt if dt.is_equiv_to(numpy::dtype::<f32>(py)) => { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Alternatively we can use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fwiw I also thought |
||
| Ok(Value::from(pyobj_tensor_to_host_tensor::<f32>(py, obj))) | ||
| } | ||
| dt if dt.is_equiv_to(numpy::dtype::<f64>(py)) => { | ||
| Ok(Value::from(pyobj_tensor_to_host_tensor::<f64>(py, obj))) | ||
| } | ||
| dt if dt.is_equiv_to(numpy::dtype::<i8>(py)) => { | ||
| Ok(Value::from(pyobj_tensor_to_host_tensor::<i8>(py, obj))) | ||
| } | ||
| dt if dt.is_equiv_to(numpy::dtype::<i16>(py)) => { | ||
| Ok(Value::from(pyobj_tensor_to_host_tensor::<i16>(py, obj))) | ||
| } | ||
| dt if dt.is_equiv_to(numpy::dtype::<i32>(py)) => { | ||
| Ok(Value::from(pyobj_tensor_to_host_tensor::<i32>(py, obj))) | ||
| } | ||
| dt if dt.is_equiv_to(numpy::dtype::<i64>(py)) => { | ||
| Ok(Value::from(pyobj_tensor_to_host_tensor::<i64>(py, obj))) | ||
| } | ||
| dt if dt.is_equiv_to(numpy::dtype::<u8>(py)) => { | ||
| Ok(Value::from(pyobj_tensor_to_host_tensor::<u8>(py, obj))) | ||
| } | ||
| dt if dt.is_equiv_to(numpy::dtype::<u16>(py)) => { | ||
| Ok(Value::from(pyobj_tensor_to_host_tensor::<u16>(py, obj))) | ||
| } | ||
| dt if dt.is_equiv_to(numpy::dtype::<u32>(py)) => { | ||
| Ok(Value::from(pyobj_tensor_to_host_tensor::<u32>(py, obj))) | ||
| } | ||
| dt if dt.is_equiv_to(numpy::dtype::<u64>(py)) => { | ||
| Ok(Value::from(pyobj_tensor_to_host_tensor::<u64>(py, obj))) | ||
| } | ||
| dt if dt.is_equiv_to(numpy::dtype::<bool>(py)) => { | ||
| Ok(Value::from(pyobj_tensor_to_host_bit_tensor(py, obj))) | ||
| } | ||
| otherwise => Err(anyhow::Error::msg(format!( | ||
| "Unsupported numpy datatype {:?}", | ||
| otherwise | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This enum was removed as a result of this comment: PyO3/rust-numpy#256 (comment)
Not much of a conversation and/or migration guide. I think I followed the suggested path though, judging by the comments and the changes in that PR.