diff --git a/src/tensor/equals.rs b/src/tensor/equals.rs index b4dab62..f6784f4 100644 --- a/src/tensor/equals.rs +++ b/src/tensor/equals.rs @@ -1,6 +1,6 @@ use crate::data_buffer::DataBuffer; use crate::dtype::RawDataType; -use crate::iterator::flat_iterator::FlatIterator; +use crate::iterator::iterators::FlatIterator; use crate::TensorBase; impl PartialEq> for TensorBase diff --git a/src/tensor/iterator/flat_iterator.rs b/src/tensor/iterator/iterators.rs similarity index 54% rename from src/tensor/iterator/flat_iterator.rs rename to src/tensor/iterator/iterators.rs index dfe7b23..0f64e12 100644 --- a/src/tensor/iterator/flat_iterator.rs +++ b/src/tensor/iterator/iterators.rs @@ -1,7 +1,10 @@ +use crate::data_buffer::DataBuffer; use crate::dtype::RawDataType; use crate::iterator::buffer_iterator::BufferIterator; use crate::iterator::flat_index_generator::FlatIndexGenerator; -use crate::{Tensor, TensorView}; +use crate::tensor_iterator::TensorIterator; +use crate::traits::haslength::HasLength; +use crate::{AxisType, Tensor, TensorBase, TensorView}; use std::ops::Range; pub trait FlatIterator { @@ -23,3 +26,21 @@ impl FlatIterator for TensorView { BufferIterator::from(self, indices) } } + +impl TensorBase +where + B: DataBuffer, + T: RawDataType, +{ + pub fn iter(&self) -> TensorIterator { + TensorIterator::from(self, [0]) + } + + pub fn iter_along(&self, axis: impl AxisType) -> TensorIterator { + TensorIterator::from(self, [axis.usize()]) + } + + pub fn nditer(&self, axes: impl IntoIterator + HasLength + Clone) -> TensorIterator { + TensorIterator::from(self, axes) + } +} diff --git a/src/tensor/iterator/mod.rs b/src/tensor/iterator/mod.rs index 4a11641..69261a3 100644 --- a/src/tensor/iterator/mod.rs +++ b/src/tensor/iterator/mod.rs @@ -1,9 +1,9 @@ pub mod buffer_iterator; pub mod flat_index_generator; -pub mod flat_iterator; +pub mod iterators; pub mod tensor_iterator; pub(super) mod collapse_contiguous; mod util; -pub use flat_iterator::*; +pub use iterators::*; diff --git a/src/tensor/iterator/tensor_iterator.rs b/src/tensor/iterator/tensor_iterator.rs index de4a3a1..565dd81 100644 --- a/src/tensor/iterator/tensor_iterator.rs +++ b/src/tensor/iterator/tensor_iterator.rs @@ -2,7 +2,7 @@ use crate::data_buffer::DataBuffer; use crate::dtype::RawDataType; use crate::iterator::util::split_by_indices; use crate::traits::haslength::HasLength; -use crate::{AxisType, TensorBase, TensorView}; +use crate::{TensorBase, TensorView}; #[non_exhaustive] pub struct TensorIterator @@ -20,24 +20,6 @@ where size: usize, } -impl TensorBase -where - B: DataBuffer, - T: RawDataType, -{ - pub fn iter(&self) -> TensorIterator { - TensorIterator::from(self, [0]) - } - - pub fn iter_along(&self, axis: impl AxisType) -> TensorIterator { - TensorIterator::from(self, [axis.usize()]) - } - - pub fn nditer(&self, axes: impl IntoIterator + HasLength + Clone) -> TensorIterator { - TensorIterator::from(self, axes) - } -} - impl TensorIterator where T: RawDataType, diff --git a/tests/tensor.rs b/tests/tensor.rs index c942c2b..b333442 100644 --- a/tests/tensor.rs +++ b/tests/tensor.rs @@ -319,7 +319,6 @@ fn flat_iter() { assert_eq!(slice, [10, 11, 12, 16, 17, 18, 22, 23, 24]); let b = a.slice(s![1]); - println!("{:?}", b); let slice: Vec<_> = b.flat_iter().collect(); assert_eq!(slice, [16, 17, 18, 19, 20, 21]);