|
2 | 2 | use ndarray::*; |
3 | 3 | use npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API}; |
4 | 4 | use num_traits::AsPrimitive; |
5 | | -use pyo3::{exceptions::TypeError, ffi, prelude::*, types::PyObjectRef}; |
| 5 | +use pyo3::{ffi, prelude::*, types::PyObjectRef}; |
6 | 6 | use pyo3::{PyDowncastError, PyObjectWithToken, ToPyPointer}; |
7 | 7 | use std::iter::ExactSizeIterator; |
8 | 8 | use std::marker::PhantomData; |
@@ -116,30 +116,21 @@ impl<'a, T, D> ::std::convert::From<&'a PyArray<T, D>> for &'a PyObjectRef { |
116 | 116 | } |
117 | 117 |
|
118 | 118 | impl<'a, T: TypeNum, D: Dimension> FromPyObject<'a> for &'a PyArray<T, D> { |
119 | | - // here we do type-check twice |
| 119 | + // here we do type-check three times |
120 | 120 | // 1. Checks if the object is PyArray |
121 | 121 | // 2. Checks if the data type of the array is T |
| 122 | + // 3. Checks if the dimension is same as D |
122 | 123 | fn extract(ob: &'a PyObjectRef) -> PyResult<Self> { |
123 | 124 | let array = unsafe { |
124 | 125 | if npyffi::PyArray_Check(ob.as_ptr()) == 0 { |
125 | 126 | return Err(PyDowncastError.into()); |
126 | 127 | } |
127 | | - if let Some(ndim) = D::NDIM { |
128 | | - let ptr = ob.as_ptr() as *mut npyffi::PyArrayObject; |
129 | | - if (*ptr).nd as usize != ndim { |
130 | | - return Err(PyErr::new::<TypeError, _>(format!( |
131 | | - "specified dim was {}, but actual dim was {}", |
132 | | - ndim, |
133 | | - (*ptr).nd |
134 | | - ))); |
135 | | - } |
136 | | - } |
137 | 128 | &*(ob as *const PyObjectRef as *const PyArray<T, D>) |
138 | 129 | }; |
139 | 130 | array |
140 | 131 | .type_check() |
141 | 132 | .map(|_| array) |
142 | | - .into_pyresult_with(|| "FromPyObject::extract typecheck failed") |
| 133 | + .into_pyresult_with(|| "[FromPyObject::extract] typecheck failed") |
143 | 134 | } |
144 | 135 | } |
145 | 136 |
|
@@ -398,6 +389,27 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> { |
398 | 389 | } |
399 | 390 | } |
400 | 391 |
|
| 392 | + /// Get the immutable view of the internal data of `PyArray`, as slice. |
| 393 | + /// # Example |
| 394 | + /// ``` |
| 395 | + /// # extern crate pyo3; extern crate numpy; fn main() { |
| 396 | + /// use numpy::PyArray; |
| 397 | + /// let gil = pyo3::Python::acquire_gil(); |
| 398 | + /// let py_array = PyArray::arange(gil.python(), 0, 4, 1).reshape([2, 2]).unwrap(); |
| 399 | + /// assert_eq!(py_array.as_slice(), &[0, 1, 2, 3]); |
| 400 | + /// # } |
| 401 | + /// ``` |
| 402 | + pub fn as_slice(&self) -> &[T] { |
| 403 | + self.type_check_assert(); |
| 404 | + unsafe { ::std::slice::from_raw_parts(self.data(), self.len()) } |
| 405 | + } |
| 406 | + |
| 407 | + /// Get the mmutable view of the internal data of `PyArray`, as slice. |
| 408 | + pub fn as_slice_mut(&self) -> &mut [T] { |
| 409 | + self.type_check_assert(); |
| 410 | + unsafe { ::std::slice::from_raw_parts_mut(self.data(), self.len()) } |
| 411 | + } |
| 412 | + |
401 | 413 | /// Construct PyArray from `ndarray::ArrayBase`. |
402 | 414 | /// |
403 | 415 | /// This method allocates memory in Python's heap via numpy api, and then copies all elements |
@@ -584,6 +596,22 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> { |
584 | 596 | let python = self.py(); |
585 | 597 | unsafe { PyArray::from_borrowed_ptr(python, self.as_ptr()) } |
586 | 598 | } |
| 599 | + |
| 600 | + fn type_check_assert(&self) { |
| 601 | + let type_check = self.type_check(); |
| 602 | + assert!(type_check.is_ok(), "{:?}", type_check); |
| 603 | + } |
| 604 | + |
| 605 | + fn type_check(&self) -> Result<(), ErrorKind> { |
| 606 | + let truth = self.typenum(); |
| 607 | + let dim = self.shape().len(); |
| 608 | + let dim_ok = D::NDIM.map(|n| n == dim).unwrap_or(true); |
| 609 | + if T::is_same_type(truth) && dim_ok { |
| 610 | + Ok(()) |
| 611 | + } else { |
| 612 | + Err(ErrorKind::to_rust(truth, dim, T::npy_data_type(), D::NDIM)) |
| 613 | + } |
| 614 | + } |
587 | 615 | } |
588 | 616 |
|
589 | 617 | impl<T: TypeNum> PyArray<T, Ix1> { |
@@ -828,41 +856,6 @@ impl<T: TypeNum, D> PyArray<T, D> { |
828 | 856 | NpyDataType::from_i32(self.typenum()) |
829 | 857 | } |
830 | 858 |
|
831 | | - fn type_check_assert(&self) { |
832 | | - let type_check = self.type_check(); |
833 | | - assert!(type_check.is_ok(), "{:?}", type_check); |
834 | | - } |
835 | | - |
836 | | - fn type_check(&self) -> Result<(), ErrorKind> { |
837 | | - let truth = self.typenum(); |
838 | | - if T::is_same_type(truth) { |
839 | | - Ok(()) |
840 | | - } else { |
841 | | - Err(ErrorKind::to_rust(truth, T::npy_data_type())) |
842 | | - } |
843 | | - } |
844 | | - |
845 | | - /// Get the immutable view of the internal data of `PyArray`, as slice. |
846 | | - /// # Example |
847 | | - /// ``` |
848 | | - /// # extern crate pyo3; extern crate numpy; fn main() { |
849 | | - /// use numpy::PyArray; |
850 | | - /// let gil = pyo3::Python::acquire_gil(); |
851 | | - /// let py_array = PyArray::arange(gil.python(), 0, 4, 1).reshape([2, 2]).unwrap(); |
852 | | - /// assert_eq!(py_array.as_slice(), &[0, 1, 2, 3]); |
853 | | - /// # } |
854 | | - /// ``` |
855 | | - pub fn as_slice(&self) -> &[T] { |
856 | | - self.type_check_assert(); |
857 | | - unsafe { ::std::slice::from_raw_parts(self.data(), self.len()) } |
858 | | - } |
859 | | - |
860 | | - /// Get the mmutable view of the internal data of `PyArray`, as slice. |
861 | | - pub fn as_slice_mut(&self) -> &mut [T] { |
862 | | - self.type_check_assert(); |
863 | | - unsafe { ::std::slice::from_raw_parts_mut(self.data(), self.len()) } |
864 | | - } |
865 | | - |
866 | 859 | /// Copies self into `other`, performing a data-type conversion if necessary. |
867 | 860 | /// # Example |
868 | 861 | /// ``` |
|
0 commit comments