Functional iterator for single axis for Tensor (not TensorView)
Aman-Amith-Shastry committed Dec 20, 2024
1 parent aa5e061 commit 1c6330c
Showing 6 changed files with 363 additions and 30 deletions.
18 changes: 18 additions & 0 deletions src/tensor/data_buffer/data_owned.rs
@@ -43,6 +43,24 @@ impl<T: RawDataType> DataOwned<T> {
}
}

// pub fn new_from_ref(data: Vec<&T>) -> Self {
// if data.is_empty() {
// panic!("Tensor::from() failed, cannot create data buffer from empty data");
// }
//
// // take control of the data so that Rust doesn't drop it once the vector goes out of scope
// let mut data = ManuallyDrop::new(data);
//
// // safe to unwrap because we've checked length above
// let ptr = data.as_mut_ptr();
//
// Self {
// len: data.len(),
// capacity: data.capacity(),
// ptr: NonNull::new(ptr).unwrap(),
// }
// }

pub fn from(data: impl Flatten<T> + Homogenous) -> Self {
Self::new(data.flatten())
}
12 changes: 11 additions & 1 deletion src/tensor/data_buffer/data_view.rs
@@ -1,3 +1,4 @@
use std::mem::ManuallyDrop;
use crate::tensor::data_buffer::{DataBuffer, DataOwned};
use crate::tensor::dtype::RawDataType;
use std::ops::Index;
@@ -12,14 +13,23 @@ pub struct DataView<T: RawDataType> {
impl<T: RawDataType> DataView<T> {
pub(in crate::tensor) fn from_buffer<B>(value: &B, offset: usize, len: usize) -> Self
where
B: DataBuffer<DType=T>,
B: DataBuffer<DType = T>,
{
assert!(offset + len <= value.len());
Self {
ptr: unsafe { value.ptr().offset(offset as isize) },
len,
}
}

pub(in crate::tensor) fn from_vec_ref(vec: Vec<T>, offset: usize, len: usize) -> Self {
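// ManuallyDrop below prevents the Vec's buffer from being freed when `data`
// goes out of scope, so the view's pointer remains valid. Note that `offset`
// is only bounds-checked here; the stored pointer is the vector's first
// element (the call site in iterator_base.rs passes offset = 0).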
assert!(offset + len <= vec.len());
let mut data = ManuallyDrop::new(vec);
Self {
ptr: NonNull::new(data.as_mut_ptr()).unwrap(),
len,
}
}
}

impl<T: RawDataType> From<&DataOwned<T>> for DataView<T> {
116 changes: 116 additions & 0 deletions src/tensor/iterator/iterator_base.rs
@@ -0,0 +1,116 @@
use crate::data_buffer::{DataBuffer, DataOwned, DataView};

Check warning (Code scanning / clippy): unused import: DataOwned
use crate::dtype::RawDataType;
use crate::{tensor, TensorView};

Check warning (Code scanning / clippy): unused import: tensor

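// Iterates over a tensor along a single axis, yielding one TensorView per step.
// `data_buffer` borrows the source buffer, `axis` is the axis of iteration,
// `shape` and `stride` describe the source tensor, `indices` is the number of
// elements copied into each yielded view (source size / shape[axis]), and
// `iter_count` tracks the current position along the axis.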
#[non_exhaustive]
pub struct IteratorBase<'a, T, B>
where
T: RawDataType,
B: DataBuffer<DType = T>,
{
data_buffer: &'a B,
axis: usize,
shape: Vec<usize>,
stride: Vec<usize>,
indices: usize,
iter_count: isize,
}

impl<'a, T, B> IteratorBase<'a, T, B>
where
T: RawDataType,
B: DataBuffer<DType = T>,
{
pub(super) fn from(
data_buffer: &'a B,
axis: usize,
shape: Vec<usize>,
stride: Vec<usize>,
indices: usize,
) -> Self {
Self {
data_buffer,
axis,
shape,
stride,
indices,
iter_count: 0,
}
}
}

impl<'a, T, B> Iterator for IteratorBase<'a, T, B>

Check warning (Code scanning / clippy): the following explicit lifetimes could be elided: 'a (reported twice)
where
T: RawDataType,
B: DataBuffer<DType = T>,
{
type Item = TensorView<T>;

fn next(&mut self) -> Option<Self::Item> {
match self.iter_count < self.shape[self.axis] as isize {
false => None,
true => unsafe {
let mut ptr_offset = 0isize;
let mut data_vec: Vec<T> = Vec::new();

let mut new_shape = self.shape.clone();
let mut new_stride = self.stride.clone();

for i in 0..self.axis {

Check warning (Code scanning / clippy): the loop variable i is only used to index new_stride
new_stride[i] = new_stride[i] / new_shape[self.axis];

Check warning (Code scanning / clippy): manual implementation of an assign operation
}
new_shape.remove(self.axis);
new_stride.remove(self.axis);

let mut buffer_count: Vec<usize> = vec![0; self.axis + 1];

for _i in 0..self.indices {
// The offset is computed like a mixed-radix counter: each "digit" lives in
// buffer_count, and its base is the shape at that digit's axis, except the
// least-significant ("units") digit, whose base is the stride at the axis of
// iteration. When a digit reaches its base, the carry propagates to the next digit.
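// Worked example (hypothetical inputs): shape = [2, 3], stride = [3, 1], axis = 1.
// Each view copies indices = 6 / 3 = 2 elements. For iter_count = 0, the first
// offset is 0; the units digit then equals stride[1] = 1, so the carry bumps
// buffer_count[0] to 1 and resets ptr_offset to 1 * stride[0] = 3, giving the
// second offset 3 (the two elements of column 0).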

let mut curr_axis = self.axis as isize;
data_vec.push(
*self
.data_buffer
.const_ptr()
.offset(self.iter_count * self.stride[self.axis] as isize + ptr_offset),
);

buffer_count[curr_axis as usize] += 1;
ptr_offset += 1;
while curr_axis >= 0
&& ((curr_axis == self.axis as isize
&& buffer_count[curr_axis as usize] == self.stride[self.axis])
|| (curr_axis != self.axis as isize
&& buffer_count[curr_axis as usize]
== self.shape[curr_axis as usize]))
{
buffer_count[curr_axis as usize] = 0;
curr_axis -= 1;

if curr_axis < 0 {
break;
}
buffer_count[curr_axis as usize] += 1;
ptr_offset = (buffer_count[curr_axis as usize]
* self.stride[curr_axis as usize])
as isize;
}
}

let data_buffer = DataView::from_vec_ref(data_vec.clone(), 0, data_vec.len());

self.iter_count += 1;

Some(TensorView {
data: data_buffer,
shape: new_shape.clone(),
stride: new_stride.clone(),
ndims: new_shape.len(),
})
},
}
}
}
3 changes: 3 additions & 0 deletions src/tensor/iterator/mod.rs
@@ -4,5 +4,8 @@ pub mod buffer_iterator;
pub(super) mod collapse_contiguous;
pub mod flat_index_generator;
pub mod flat_iterator;
mod iterator_base;
mod tensor_iterator;

pub use flat_iterator::*;
pub use tensor_iterator::*;
26 changes: 26 additions & 0 deletions src/tensor/iterator/tensor_iterator.rs
@@ -0,0 +1,26 @@
use crate::data_buffer::{DataBuffer, DataOwned};
use crate::dtype::RawDataType;
use crate::iterator::iterator_base::IteratorBase;
use crate::{Axis, Tensor};

pub trait TensorIterator<T: RawDataType> {
type Buffer: DataBuffer<DType = T>;
fn iter(&self, axis: Axis) -> IteratorBase<T, Self::Buffer>;
}

impl<T: RawDataType> TensorIterator<T> for Tensor<T> {
type Buffer = DataOwned<T>;
fn iter(&self, axis: Axis) -> IteratorBase<T, Self::Buffer> {
assert!(
axis.0 < self.ndims,
"Axis must be smaller than number of dimensions!"
);
IteratorBase::from(
&self.data,
axis.0,
self.shape.clone(),
self.stride.clone(),
self.size() / self.shape[axis.0],
)
}
}
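
For orientation, a usage sketch of the API added in this commit follows. It is illustrative only: `Tensor::from` over nested Vecs, the `Axis(usize)` newtype, and the element values shown are assumptions drawn from the surrounding repository, not guaranteed by this diff alone.

// Sketch: iterate a 2x3 tensor along each axis via the new TensorIterator trait.
let t = Tensor::from(vec![vec![1, 2, 3], vec![4, 5, 6]]); // shape [2, 3]

// Axis 0 yields 2 views of shape [3]: the rows [1, 2, 3] and [4, 5, 6].
for row in t.iter(Axis(0)) {
    // each `row` is a TensorView over a copy of that row's elements
}

// Axis 1 yields 3 views of shape [2]: the columns [1, 4], [2, 5], [3, 6].
for col in t.iter(Axis(1)) {
    // each `col` is a TensorView over a copy of that column's elements
}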