Skip to content

Commit

Permalink
Renamed flat_iterator.rs to iterators.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
BhavyeMathur committed Dec 22, 2024
1 parent b4314de commit cc9be2f
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 24 deletions.
2 changes: 1 addition & 1 deletion src/tensor/equals.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::data_buffer::DataBuffer;
use crate::dtype::RawDataType;
use crate::iterator::flat_iterator::FlatIterator;
use crate::iterator::iterators::FlatIterator;
use crate::TensorBase;

impl<B1, T1, B2, T2> PartialEq<TensorBase<B1>> for TensorBase<B2>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use crate::data_buffer::DataBuffer;
use crate::dtype::RawDataType;
use crate::iterator::buffer_iterator::BufferIterator;
use crate::iterator::flat_index_generator::FlatIndexGenerator;
use crate::{Tensor, TensorView};
use crate::tensor_iterator::TensorIterator;
use crate::traits::haslength::HasLength;
use crate::{AxisType, Tensor, TensorBase, TensorView};
use std::ops::Range;

pub trait FlatIterator<T: RawDataType> {
Expand All @@ -23,3 +26,21 @@ impl<T: RawDataType> FlatIterator<T> for TensorView<T> {
BufferIterator::from(self, indices)
}
}

impl<B, T> TensorBase<B>
where
B: DataBuffer<crate::tensor::data_buffer::buffer::DType=T>,
T: RawDataType,
{
pub fn iter(&self) -> TensorIterator<T> {
TensorIterator::from(self, [0])
}

pub fn iter_along(&self, axis: impl AxisType) -> TensorIterator<T> {
TensorIterator::from(self, [axis.usize()])
}

pub fn nditer(&self, axes: impl IntoIterator<Item=usize> + HasLength + Clone) -> TensorIterator<T> {
TensorIterator::from(self, axes)
}
}
4 changes: 2 additions & 2 deletions src/tensor/iterator/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pub mod buffer_iterator;
pub mod flat_index_generator;
pub mod flat_iterator;
pub mod iterators;
pub mod tensor_iterator;

pub(super) mod collapse_contiguous;
mod util;

pub use flat_iterator::*;
pub use iterators::*;
20 changes: 1 addition & 19 deletions src/tensor/iterator/tensor_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::data_buffer::DataBuffer;
use crate::dtype::RawDataType;
use crate::iterator::util::split_by_indices;
use crate::traits::haslength::HasLength;
use crate::{AxisType, TensorBase, TensorView};
use crate::{TensorBase, TensorView};

#[non_exhaustive]
pub struct TensorIterator<T>
Expand All @@ -20,24 +20,6 @@ where
size: usize,
}

impl<B, T> TensorBase<B>
where
B: DataBuffer<DType=T>,
T: RawDataType,
{
pub fn iter(&self) -> TensorIterator<T> {
TensorIterator::from(self, [0])
}

pub fn iter_along(&self, axis: impl AxisType) -> TensorIterator<T> {
TensorIterator::from(self, [axis.usize()])
}

pub fn nditer(&self, axes: impl IntoIterator<Item=usize> + HasLength + Clone) -> TensorIterator<T> {
TensorIterator::from(self, axes)
}
}

impl<T> TensorIterator<T>
where
T: RawDataType,
Expand Down
1 change: 0 additions & 1 deletion tests/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ fn flat_iter() {
assert_eq!(slice, [10, 11, 12, 16, 17, 18, 22, 23, 24]);

let b = a.slice(s![1]);
println!("{:?}", b);
let slice: Vec<_> = b.flat_iter().collect();
assert_eq!(slice, [16, 17, 18, 19, 20, 21]);

Expand Down

0 comments on commit cc9be2f

Please sign in to comment.