Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Functional iterator for single axis for Tensor (not TensorView) #3

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/tensor/data_buffer/data_owned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,24 @@ impl<T: RawDataType> DataOwned<T> {
}
}

// pub fn new_from_ref(data: Vec<&T>) -> Self {
// if data.is_empty() {
// panic!("Tensor::from() failed, cannot create data buffer from empty data");
// }
//
// // take control of the data so that Rust doesn't drop it once the vector goes out of scope
// let mut data = ManuallyDrop::new(data);
//
// // safe to unwrap because we've checked length above
// let ptr = data.as_mut_ptr();
//
// Self {
// len: data.len(),
// capacity: data.capacity(),
// ptr: NonNull::new(ptr).unwrap(),
// }
// }

/// Constructs an owned data buffer from any nested collection that can be
/// flattened into a contiguous `Vec<T>` (e.g. nested vectors of equal length,
/// as enforced by the `Homogenous` bound).
pub fn from(data: impl Flatten<T> + Homogenous) -> Self {
    let flattened = data.flatten();
    Self::new(flattened)
}
Expand Down
12 changes: 11 additions & 1 deletion src/tensor/data_buffer/data_view.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::mem::ManuallyDrop;
use crate::tensor::data_buffer::{DataBuffer, DataOwned};
use crate::tensor::dtype::RawDataType;
use std::ops::Index;
Expand All @@ -12,14 +13,23 @@ pub struct DataView<T: RawDataType> {
impl<T: RawDataType> DataView<T> {
/// Creates a view into an existing data buffer, starting `offset` elements in
/// and spanning `len` elements.
///
/// # Panics
/// Panics if the range `offset..offset + len` is out of bounds for `value`.
pub(in crate::tensor) fn from_buffer<B>(value: &B, offset: usize, len: usize) -> Self
where
    B: DataBuffer<DType = T>,
{
    // Overflow-safe bounds check: `offset + len` can wrap around in release
    // builds, which would let an out-of-bounds view slip past the assertion.
    assert!(
        offset <= value.len() && len <= value.len() - offset,
        "view range out of bounds of the underlying buffer"
    );
    Self {
        // SAFETY: the assertion above guarantees `offset` lies within the
        // buffer's allocation, so the resulting pointer stays in bounds.
        ptr: unsafe { value.ptr().add(offset) },
        len,
    }
}

/// Builds a `DataView` that takes over the allocation of `vec`, viewing `len`
/// elements starting `offset` elements into it.
///
/// Ownership of the allocation is released from `vec` via `ManuallyDrop`, so
/// Rust will not free it when `vec` goes out of scope.
/// NOTE(review): nothing visible here ever reclaims that allocation — unless a
/// `Drop` impl elsewhere frees it, every call leaks the buffer. Confirm.
///
/// # Panics
/// Panics if the range `offset..offset + len` is out of bounds for `vec`.
pub(in crate::tensor) fn from_vec_ref(vec: Vec<T>, offset: usize, len: usize) -> Self {
    // Overflow-safe bounds check (`offset + len` could wrap in release builds).
    assert!(
        offset <= vec.len() && len <= vec.len() - offset,
        "view range out of bounds of the vector"
    );
    // Take control of the data so Rust doesn't drop it when `vec` goes out of scope.
    let mut data = ManuallyDrop::new(vec);
    Self {
        // Bug fix: the view must begin `offset` elements into the buffer;
        // previously `offset` was validated but never applied to the pointer.
        // SAFETY: the assertion above keeps `offset` within the allocation.
        ptr: NonNull::new(unsafe { data.as_mut_ptr().add(offset) }).unwrap(),
        len,
    }
}
}

impl<T: RawDataType> From<&DataOwned<T>> for DataView<T> {
Expand Down
116 changes: 116 additions & 0 deletions src/tensor/iterator/iterator_base.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use crate::data_buffer::{DataBuffer, DataOwned, DataView};

Check warning

Code scanning / clippy

unused import: DataOwned Warning

unused import: DataOwned
use crate::dtype::RawDataType;
use crate::{tensor, TensorView};

Check warning

Code scanning / clippy

unused import: tensor Warning

unused import: tensor

/// Iterator over the sub-tensors of a tensor taken along a single axis.
///
/// Each call to `next` gathers one slice of the source tensor (at the next
/// index along `axis`) and yields it as a `TensorView`.
#[non_exhaustive]
pub struct IteratorBase<'a, T, B>
where
    T: RawDataType,
    B: DataBuffer<DType = T>,
{
    // Buffer holding the source tensor's elements, borrowed for lifetime 'a.
    data_buffer: &'a B,
    // The axis being iterated along.
    axis: usize,
    // Shape of the source tensor.
    shape: Vec<usize>,
    // Stride of the source tensor.
    stride: Vec<usize>,
    // Number of elements in each yielded slice (caller passes
    // size / shape[axis]; see `TensorIterator::iter`).
    indices: usize,
    // How many slices have been yielded so far; compared against shape[axis].
    iter_count: isize,
}

impl<'a, T, B> IteratorBase<'a, T, B>
where
    T: RawDataType,
    B: DataBuffer<DType = T>,
{
    /// Creates an axis iterator over `data_buffer`, starting at index 0.
    ///
    /// `shape` and `stride` describe the source tensor, `axis` is the
    /// dimension to iterate along, and `indices` is the number of elements in
    /// each yielded slice (the caller passes `size / shape[axis]`).
    pub(super) fn from(
        data_buffer: &'a B,
        axis: usize,
        shape: Vec<usize>,
        stride: Vec<usize>,
        indices: usize,
    ) -> Self {
        Self {
            data_buffer,
            axis,
            shape,
            stride,
            indices,
            iter_count: 0,
        }
    }
}

impl<'a, T, B> Iterator for IteratorBase<'a, T, B>

Check warning

Code scanning / clippy

the following explicit lifetimes could be elided: 'a Warning

the following explicit lifetimes could be elided: 'a

Check warning

Code scanning / clippy

the following explicit lifetimes could be elided: 'a Warning

the following explicit lifetimes could be elided: 'a
where
T: RawDataType,
B: DataBuffer<DType = T>,
{
type Item = TensorView<T>;

fn next(&mut self) -> Option<Self::Item> {
match self.iter_count < self.shape[self.axis] as isize {
false => None,
true => unsafe {
let mut ptr_offset = 0isize;
let mut data_vec: Vec<T> = Vec::new();

let mut new_shape = self.shape.clone();
let mut new_stride = self.stride.clone();

for i in 0..self.axis {

Check warning

Code scanning / clippy

the loop variable i is only used to index new_stride Warning

the loop variable i is only used to index new\_stride
new_stride[i] = new_stride[i] / new_shape[self.axis];

Check warning

Code scanning / clippy

manual implementation of an assign operation Warning

manual implementation of an assign operation
}
new_shape.remove(self.axis);
new_stride.remove(self.axis);

let mut buffer_count: Vec<usize> = vec![0; self.axis + 1];

for _i in 0..self.indices {
// Calculating offset on each iteration works like a counter, where each digit is an element
// in an array/vector with a base corresponding to the shape at the index of the digit.
// In the 'units' place, the 'base' is the stride at the axis of iteration.
// These 'digits' are maintained in buffer_count

let mut curr_axis = self.axis as isize;
data_vec.push(
*self
.data_buffer
.const_ptr()
.offset(self.iter_count * self.stride[self.axis] as isize + ptr_offset),
);

buffer_count[curr_axis as usize] += 1;
ptr_offset += 1;
while curr_axis >= 0
&& ((curr_axis == self.axis as isize
&& buffer_count[curr_axis as usize] == self.stride[self.axis])
|| (curr_axis != self.axis as isize
&& buffer_count[curr_axis as usize]
== self.shape[curr_axis as usize]))
{
buffer_count[curr_axis as usize] = 0;
curr_axis -= 1;

if curr_axis < 0 {
break;
}
buffer_count[curr_axis as usize] += 1;
ptr_offset = (buffer_count[curr_axis as usize]
* self.stride[curr_axis as usize])
as isize;
}
}

let data_buffer = DataView::from_vec_ref(data_vec.clone(), 0, data_vec.len());

self.iter_count += 1;

Some(TensorView {
data: data_buffer,
shape: new_shape.clone(),
stride: new_stride.clone(),
ndims: new_shape.len(),
})
},
}
}
}
3 changes: 3 additions & 0 deletions src/tensor/iterator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,8 @@ pub mod buffer_iterator;
pub(super) mod collapse_contiguous;
pub mod flat_index_generator;
pub mod flat_iterator;
mod iterator_base;
mod tensor_iterator;

pub use flat_iterator::*;
pub use tensor_iterator::*;
26 changes: 26 additions & 0 deletions src/tensor/iterator/tensor_iterator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use crate::data_buffer::{DataBuffer, DataOwned};
use crate::dtype::RawDataType;
use crate::iterator::iterator_base::IteratorBase;
use crate::{Axis, Tensor};

/// Iteration over the sub-tensors of a tensor along a single axis.
pub trait TensorIterator<T: RawDataType> {
    /// The data-buffer type backing the tensor being iterated.
    type Buffer: DataBuffer<DType = T>;
    /// Returns an iterator yielding one `TensorView` per index along `axis`.
    ///
    /// # Panics
    /// Implementations may panic when `axis` is out of range (the `Tensor`
    /// impl asserts `axis < ndims`).
    fn iter(&self, axis: Axis) -> IteratorBase<T, Self::Buffer>;
}

impl<T: RawDataType> TensorIterator<T> for Tensor<T> {
type Buffer = DataOwned<T>;
fn iter(&self, axis: Axis) -> IteratorBase<T, Self::Buffer> {
assert!(
axis.0 < self.ndims,
"Axis must be smaller than number of dimensions!"
);
IteratorBase::from(
&self.data,
axis.0,
self.shape.clone(),
self.stride.clone(),
self.size() / self.shape[axis.0],
)
}
}
Loading
Loading