-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Functional iterator for single axis for Tensor (not TensorView) #3
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
use crate::data_buffer::{DataBuffer, DataOwned, DataView}; | ||
use crate::dtype::RawDataType; | ||
use crate::{tensor, TensorView}; | ||
Check warning Code scanning / clippy unused import: tensor Warning
unused import: tensor
|
||
|
||
#[non_exhaustive] | ||
pub struct IteratorBase<'a, T, B> | ||
where | ||
T: RawDataType, | ||
B: DataBuffer<DType = T>, | ||
{ | ||
data_buffer: &'a B, | ||
axis: usize, | ||
shape: Vec<usize>, | ||
stride: Vec<usize>, | ||
indices: usize, | ||
iter_count: isize, | ||
} | ||
|
||
impl<'a, T, B> IteratorBase<'a, T, B> | ||
where | ||
T: RawDataType, | ||
B: DataBuffer<DType = T>, | ||
{ | ||
pub(super) fn from( | ||
data_buffer: &'a B, | ||
axis: usize, | ||
shape: Vec<usize>, | ||
stride: Vec<usize>, | ||
indices: usize, | ||
) -> Self { | ||
Self { | ||
data_buffer, | ||
axis, | ||
shape, | ||
stride, | ||
indices, | ||
iter_count: 0, | ||
} | ||
} | ||
} | ||
|
||
impl<'a, T, B> Iterator for IteratorBase<'a, T, B> | ||
Check warning Code scanning / clippy the following explicit lifetimes could be elided: 'a Warning
the following explicit lifetimes could be elided: 'a
Check warning Code scanning / clippy the following explicit lifetimes could be elided: 'a Warning
the following explicit lifetimes could be elided: 'a
|
||
where | ||
T: RawDataType, | ||
B: DataBuffer<DType = T>, | ||
{ | ||
type Item = TensorView<T>; | ||
|
||
fn next(&mut self) -> Option<Self::Item> { | ||
match self.iter_count < self.shape[self.axis] as isize { | ||
false => None, | ||
true => unsafe { | ||
let mut ptr_offset = 0isize; | ||
let mut data_vec: Vec<T> = Vec::new(); | ||
|
||
let mut new_shape = self.shape.clone(); | ||
let mut new_stride = self.stride.clone(); | ||
|
||
for i in 0..self.axis { | ||
Check warning Code scanning / clippy the loop variable i is only used to index new_stride Warning
the loop variable i is only used to index new\_stride
|
||
new_stride[i] = new_stride[i] / new_shape[self.axis]; | ||
Check warning Code scanning / clippy manual implementation of an assign operation Warning
manual implementation of an assign operation
|
||
} | ||
new_shape.remove(self.axis); | ||
new_stride.remove(self.axis); | ||
|
||
let mut buffer_count: Vec<usize> = vec![0; self.axis + 1]; | ||
|
||
for _i in 0..self.indices { | ||
// Calculating offset on each iteration works like a counter, where each digit is an element | ||
// in an array/vector with a base corresponding to the shape at the index of the digit. | ||
// In the 'units' place, the 'base' is the stride at the axis of iteration. | ||
// These 'digits' are maintained in buffer_count | ||
|
||
let mut curr_axis = self.axis as isize; | ||
data_vec.push( | ||
*self | ||
.data_buffer | ||
.const_ptr() | ||
.offset(self.iter_count * self.stride[self.axis] as isize + ptr_offset), | ||
); | ||
|
||
buffer_count[curr_axis as usize] += 1; | ||
ptr_offset += 1; | ||
while curr_axis >= 0 | ||
&& ((curr_axis == self.axis as isize | ||
&& buffer_count[curr_axis as usize] == self.stride[self.axis]) | ||
|| (curr_axis != self.axis as isize | ||
&& buffer_count[curr_axis as usize] | ||
== self.shape[curr_axis as usize])) | ||
{ | ||
buffer_count[curr_axis as usize] = 0; | ||
curr_axis -= 1; | ||
|
||
if curr_axis < 0 { | ||
break; | ||
} | ||
buffer_count[curr_axis as usize] += 1; | ||
ptr_offset = (buffer_count[curr_axis as usize] | ||
* self.stride[curr_axis as usize]) | ||
as isize; | ||
} | ||
} | ||
|
||
let data_buffer = DataView::from_vec_ref(data_vec.clone(), 0, data_vec.len()); | ||
|
||
self.iter_count += 1; | ||
|
||
Some(TensorView { | ||
data: data_buffer, | ||
shape: new_shape.clone(), | ||
stride: new_stride.clone(), | ||
ndims: new_shape.len(), | ||
}) | ||
}, | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
use crate::data_buffer::{DataBuffer, DataOwned}; | ||
use crate::dtype::RawDataType; | ||
use crate::iterator::iterator_base::IteratorBase; | ||
use crate::{Axis, Tensor}; | ||
|
||
pub trait TensorIterator<T: RawDataType> { | ||
type Buffer: DataBuffer<DType = T>; | ||
fn iter(&self, axis: Axis) -> IteratorBase<T, Self::Buffer>; | ||
} | ||
|
||
impl<T: RawDataType> TensorIterator<T> for Tensor<T> { | ||
type Buffer = DataOwned<T>; | ||
fn iter(&self, axis: Axis) -> IteratorBase<T, Self::Buffer> { | ||
assert!( | ||
axis.0 < self.ndims, | ||
"Axis must be smaller than number of dimensions!" | ||
); | ||
IteratorBase::from( | ||
&self.data, | ||
axis.0, | ||
self.shape.clone(), | ||
self.stride.clone(), | ||
self.size() / self.shape[axis.0], | ||
) | ||
} | ||
} |
Check warning
Code scanning / clippy
unused import: DataOwned Warning