diff --git a/src/tensor/data_buffer/data_owned.rs b/src/tensor/data_buffer/data_owned.rs
index 5cbf156..5edc257 100644
--- a/src/tensor/data_buffer/data_owned.rs
+++ b/src/tensor/data_buffer/data_owned.rs
@@ -43,6 +43,24 @@ impl<T: RawDataType> DataOwned<T> {
         }
     }
 
+    // pub fn new_from_ref(data: Vec<&T>) -> Self {
+    //     if data.is_empty() {
+    //         panic!("Tensor::from() failed, cannot create data buffer from empty data");
+    //     }
+    //
+    //     // take control of the data so that Rust doesn't drop it once the vector goes out of scope
+    //     let mut data = ManuallyDrop::new(data);
+    //
+    //     // safe to unwrap because we've checked above that the data is non-empty
+    //     let ptr = data.as_mut_ptr();
+    //
+    //     Self {
+    //         len: data.len(),
+    //         capacity: data.capacity(),
+    //         ptr: NonNull::new(ptr).unwrap(),
+    //     }
+    // }
+
     pub fn from(data: impl Flatten + Homogenous) -> Self {
         Self::new(data.flatten())
     }
diff --git a/src/tensor/data_buffer/data_view.rs b/src/tensor/data_buffer/data_view.rs
index 00ae571..7a1df97 100644
--- a/src/tensor/data_buffer/data_view.rs
+++ b/src/tensor/data_buffer/data_view.rs
@@ -1,3 +1,4 @@
+use std::mem::ManuallyDrop;
 use crate::tensor::data_buffer::{DataBuffer, DataOwned};
 use crate::tensor::dtype::RawDataType;
 use std::ops::Index;
@@ -12,7 +13,7 @@ pub struct DataView<T: RawDataType> {
 impl<T: RawDataType> DataView<T> {
     pub(in crate::tensor) fn from_buffer<B>(value: &B, offset: usize, len: usize) -> Self
     where
-        B: DataBuffer,
+        B: DataBuffer<DType = T>,
     {
         assert!(offset + len <= value.len());
         Self {
@@ -20,6 +21,15 @@
             len,
         }
     }
+
+    pub(in crate::tensor) fn from_vec_ref(vec: Vec<T>, offset: usize, len: usize) -> Self {
+        assert!(offset + len <= vec.len());
+        let mut data = ManuallyDrop::new(vec);
+        Self {
+            ptr: NonNull::new(unsafe { data.as_mut_ptr().add(offset) }).unwrap(),
+            len,
+        }
+    }
 }
 
 impl<T: RawDataType> From<&DataOwned<T>> for DataView<T> {
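A note on `from_vec_ref`: wrapping the incoming `Vec` in `ManuallyDrop` hands its heap buffer over to the view without the allocation being freed when the vector goes out of scope. Because `DataView` stores no capacity, that buffer can never be reassembled into a `Vec` and released, so each call leaks its allocation unless ownership is reclaimed elsewhere. A minimal standalone sketch of the same hand-over pattern (the `RawBuffer` type is hypothetical, not part of this change), extended with the `Drop` impl needed to give the memory back:

```rust
use std::mem::ManuallyDrop;
use std::ptr::NonNull;

/// Hypothetical owner of a Vec's raw parts, for illustration only.
struct RawBuffer<T> {
    ptr: NonNull<T>,
    len: usize,
    capacity: usize,
}

impl<T> RawBuffer<T> {
    fn from_vec(vec: Vec<T>) -> Self {
        // Suppress the Vec's destructor so its heap buffer outlives this scope.
        let mut data = ManuallyDrop::new(vec);
        Self {
            // A Vec's pointer is never null, so this unwrap cannot fail.
            ptr: NonNull::new(data.as_mut_ptr()).unwrap(),
            len: data.len(),
            capacity: data.capacity(),
        }
    }
}

impl<T> Drop for RawBuffer<T> {
    fn drop(&mut self) {
        // Rebuild the Vec from its raw parts so the allocation is freed exactly once.
        unsafe { drop(Vec::from_raw_parts(self.ptr.as_ptr(), self.len, self.capacity)) }
    }
}

fn main() {
    let buffer = RawBuffer::from_vec(vec![1, 2, 3]);
    assert_eq!(unsafe { *buffer.ptr.as_ptr() }, 1);
} // buffer dropped here; the allocation is returned
```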
diff --git a/src/tensor/iterator/iterator_base.rs b/src/tensor/iterator/iterator_base.rs
new file mode 100644
index 0000000..a724e72
--- /dev/null
+++ b/src/tensor/iterator/iterator_base.rs
@@ -0,0 +1,116 @@
+use crate::data_buffer::{DataBuffer, DataView};
+use crate::dtype::RawDataType;
+use crate::TensorView;
+
+#[non_exhaustive]
+pub struct IteratorBase<'a, T, B>
+where
+    T: RawDataType,
+    B: DataBuffer<DType = T>,
+{
+    data_buffer: &'a B,
+    axis: usize,
+    shape: Vec<usize>,
+    stride: Vec<usize>,
+    indices: usize,
+    iter_count: isize,
+}
+
+impl<'a, T, B> IteratorBase<'a, T, B>
+where
+    T: RawDataType,
+    B: DataBuffer<DType = T>,
+{
+    pub(super) fn from(
+        data_buffer: &'a B,
+        axis: usize,
+        shape: Vec<usize>,
+        stride: Vec<usize>,
+        indices: usize,
+    ) -> Self {
+        Self {
+            data_buffer,
+            axis,
+            shape,
+            stride,
+            indices,
+            iter_count: 0,
+        }
+    }
+}
+
+impl<'a, T, B> Iterator for IteratorBase<'a, T, B>
+where
+    T: RawDataType,
+    B: DataBuffer<DType = T>,
+{
+    type Item = TensorView<T>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        match self.iter_count < self.shape[self.axis] as isize {
+            false => None,
+            true => unsafe {
+                let mut ptr_offset = 0isize;
+                let mut data_vec: Vec<T> = Vec::new();
+
+                let mut new_shape = self.shape.clone();
+                let mut new_stride = self.stride.clone();
+
+                for i in 0..self.axis {
+                    new_stride[i] = new_stride[i] / new_shape[self.axis];
+                }
+                new_shape.remove(self.axis);
+                new_stride.remove(self.axis);
+
+                let mut buffer_count: Vec<usize> = vec![0; self.axis + 1];
+
+                for _i in 0..self.indices {
+                    // Computing the offset on each iteration works like incrementing a
+                    // counter whose 'digits' are maintained in buffer_count: digit i has
+                    // a base given by the shape at index i, except in the 'units' place,
+                    // where the base is the stride at the axis of iteration.
+                    let mut curr_axis = self.axis as isize;
+                    data_vec.push(
+                        *self
+                            .data_buffer
+                            .const_ptr()
+                            .offset(self.iter_count * self.stride[self.axis] as isize + ptr_offset),
+                    );
+
+                    buffer_count[curr_axis as usize] += 1;
+                    ptr_offset += 1;
+                    while curr_axis >= 0
+                        && ((curr_axis == self.axis as isize
+                            && buffer_count[curr_axis as usize] == self.stride[self.axis])
+                            || (curr_axis != self.axis as isize
+                                && buffer_count[curr_axis as usize]
+                                    == self.shape[curr_axis as usize]))
+                    {
+                        buffer_count[curr_axis as usize] = 0;
+                        curr_axis -= 1;
+
+                        if curr_axis < 0 {
+                            break;
+                        }
+                        buffer_count[curr_axis as usize] += 1;
+                        ptr_offset = (buffer_count[curr_axis as usize]
+                            * self.stride[curr_axis as usize])
+                            as isize;
+                    }
+                }
+
+                // The loop pushes exactly self.indices elements, so no clone is needed.
+                let data_buffer = DataView::from_vec_ref(data_vec, 0, self.indices);
+
+                self.iter_count += 1;
+
+                Some(TensorView {
+                    data: data_buffer,
+                    ndims: new_shape.len(),
+                    shape: new_shape,
+                    stride: new_stride,
+                })
+            },
+        }
+    }
+}
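The offset bookkeeping in `next` is the densest part of this change. The comment describes it as a counter whose digits live in `buffer_count`. The same idea can be shown in a simpler standalone form (a hypothetical helper, not code from this PR): fix one index along `axis` and enumerate the flat offsets of the remaining elements with a mixed-radix counter that skips the fixed axis.

```rust
// Hypothetical illustration of the mixed-radix counter; not part of the crate.
fn subview_offsets(shape: &[usize], stride: &[usize], axis: usize, index: usize) -> Vec<usize> {
    let mut offsets = Vec::new();
    let mut counter = vec![0usize; shape.len()]; // one 'digit' per dimension
    counter[axis] = index; // the digit at the iteration axis stays fixed
    loop {
        // Flat offset = sum over all dimensions of digit * stride.
        offsets.push(counter.iter().zip(stride).map(|(c, s)| c * s).sum());
        // Increment from the least significant digit, skipping `axis`.
        let mut dim = shape.len();
        loop {
            if dim == 0 {
                return offsets; // carried past the most significant digit: done
            }
            dim -= 1;
            if dim == axis {
                continue; // never advance the fixed axis
            }
            counter[dim] += 1;
            if counter[dim] < shape[dim] {
                break; // no carry needed
            }
            counter[dim] = 0; // digit overflowed its base: reset and carry left
        }
    }
}

fn main() {
    // Shape [3, 2, 3] with strides [6, 3, 1] (the three_dimension_iteration tensor
    // below): fixing index 1 on axis 0 selects the middle 2x3 block.
    assert_eq!(subview_offsets(&[3, 2, 3], &[6, 3, 1], 0, 1), [6, 7, 8, 9, 10, 11]);
    // Fixing index 0 on axis 1 picks the first row of every 2x3 block.
    assert_eq!(subview_offsets(&[3, 2, 3], &[6, 3, 1], 1, 0), [0, 1, 2, 6, 7, 8, 12, 13, 14]);
}
```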
+ ); + IteratorBase::from( + &self.data, + axis.0, + self.shape.clone(), + self.stride.clone(), + self.size() / self.shape[axis.0], + ) + } +} diff --git a/tests/tensor.rs b/tests/tensor.rs index c2df1c9..51b7d90 100644 --- a/tests/tensor.rs +++ b/tests/tensor.rs @@ -200,10 +200,7 @@ fn slice_along_nd() { #[test] fn slice_homogenous() { - let a = Tensor::from([ - [[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]], - ]); + let a = Tensor::from([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]); let slice = a.slice([1, 1]); @@ -312,7 +309,10 @@ fn flat_iter() { ]); let slice: Vec<_> = a.flat_iter().collect(); - assert_eq!(slice, [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]); + assert_eq!( + slice, + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27] + ); let b = a.slice(s![.., 0]); let slice: Vec<_> = b.flat_iter().collect(); @@ -342,8 +342,9 @@ fn flatten() { assert_eq!(b.len(), &18); assert_eq!(b.ndims(), 1); - let correct = Tensor::from([10, 11, 12, 13, 14, 15, 16, 17, 18, 19, - 20, 21, 22, 23, 24, 25, 26, 27]); + let correct = Tensor::from([ + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, + ]); assert_eq!(b, correct); let b = a.slice(s![.., 0]).flatten(); @@ -368,9 +369,7 @@ fn flatten() { #[test] fn squeeze_first_dimension() { - let a = Tensor::from([ - [[[1, 2, 3], [4, 5, 6]]], - ]); + let a = Tensor::from([[[[1, 2, 3], [4, 5, 6]]]]); let b = a.squeeze(); assert_eq!(b.shape(), &[2, 3]); assert_eq!(b.stride(), &[3, 1]); @@ -378,9 +377,7 @@ fn squeeze_first_dimension() { #[test] fn squeeze_multiple_dimensions() { - let a = Tensor::from([ - [[[[1, 2, 3]], [[4, 5, 6]]]], - ]); + let a = Tensor::from([[[[[1, 2, 3]], [[4, 5, 6]]]]]); let b = a.squeeze(); assert_eq!(b.shape(), &[2, 3]); assert_eq!(b.stride(), &[3, 1]); @@ -435,7 +432,7 @@ fn unsqueeze_random_dimension_last_axis() { } #[test] -fn full_i32(){ +fn full_i32() { let a = Tensor::full(3, vec![2, 3]); assert_eq!(a.shape(), &[2, 3]); assert_eq!(a.stride(), &[3, 1]); @@ -447,8 +444,8 @@ fn full_i32(){ } #[test] -fn full_f64(){ - let a = Tensor::full(3.2,vec![4, 6, 2]); +fn full_f64() { + let a = Tensor::full(3.2, vec![4, 6, 2]); assert_eq!(a.shape(), &[4, 6, 2]); let b = a.flatten(); let b_len = *b.len(); @@ -458,8 +455,8 @@ fn full_f64(){ } #[test] -fn full_bool(){ - let a: Tensor = Tensor::full(true,vec![3, 5, 3]); +fn full_bool() { + let a: Tensor = Tensor::full(true, vec![3, 5, 3]); assert_eq!(a.shape(), &[3, 5, 3]); assert_eq!(a.stride(), &[15, 3, 1]); let b = a.flatten(); @@ -470,7 +467,7 @@ fn full_bool(){ } #[test] -fn ones_u8(){ +fn ones_u8() { let a: Tensor = Tensor::ones(vec![3, 5, 3]); assert_eq!(a.shape(), &[3, 5, 3]); assert_eq!(a.stride(), &[15, 3, 1]); @@ -482,7 +479,7 @@ fn ones_u8(){ } #[test] -fn ones_i32(){ +fn ones_i32() { let a: Tensor = Tensor::ones(vec![3, 5, 3]); assert_eq!(a.shape(), &[3, 5, 3]); assert_eq!(a.stride(), &[15, 3, 1]); @@ -494,7 +491,7 @@ fn ones_i32(){ } #[test] -fn ones_1d(){ +fn ones_1d() { let a: Tensor = Tensor::ones(vec![4]); assert_eq!(a.shape(), &[4]); let a_len = *a.len(); @@ -504,7 +501,7 @@ fn ones_1d(){ } #[test] -fn ones_f64(){ +fn ones_f64() { let a: Tensor = Tensor::ones(vec![4]); assert_eq!(a.shape(), &[4]); let a_len = *a.len(); @@ -514,7 +511,7 @@ fn ones_f64(){ } #[test] -fn ones_bool(){ +fn ones_bool() { let a: Tensor = Tensor::ones(vec![3, 5, 3]); assert_eq!(a.shape(), &[3, 5, 3]); assert_eq!(a.stride(), &[15, 3, 1]); @@ -526,7 +523,7 @@ fn ones_bool(){ } #[test] -fn zeroes_u8(){ +fn zeroes_u8() { 
@@ -526,7 +523,7 @@ fn ones_bool() {
 }
 
 #[test]
-fn zeroes_u8(){
+fn zeroes_u8() {
     let a: Tensor<u8> = Tensor::zeros(vec![3, 5, 3]);
     assert_eq!(a.shape(), &[3, 5, 3]);
     assert_eq!(a.stride(), &[15, 3, 1]);
@@ -538,7 +535,7 @@ fn zeroes_u8() {
 }
 
 #[test]
-fn zeroes_i32(){
+fn zeroes_i32() {
     let a: Tensor<i32> = Tensor::zeros(vec![3, 5, 3]);
     assert_eq!(a.shape(), &[3, 5, 3]);
     assert_eq!(a.stride(), &[15, 3, 1]);
@@ -550,7 +547,7 @@ fn zeroes_i32() {
 }
 
 #[test]
-fn zeroes_1d(){
+fn zeroes_1d() {
     let a: Tensor<i32> = Tensor::zeros(vec![4]);
     assert_eq!(a.shape(), &[4]);
     let a_len = *a.len();
@@ -560,7 +557,7 @@ fn zeroes_1d() {
 }
 
 #[test]
-fn zeroes_f64(){
+fn zeroes_f64() {
     let a: Tensor<f64> = Tensor::zeros(vec![4]);
     assert_eq!(a.shape(), &[4]);
     let a_len = *a.len();
@@ -570,7 +567,7 @@ fn zeroes_f64() {
 }
 
 #[test]
-fn zeroes_bool(){
+fn zeroes_bool() {
     let a: Tensor<bool> = Tensor::zeros(vec![3, 5, 3]);
     assert_eq!(a.shape(), &[3, 5, 3]);
     assert_eq!(a.stride(), &[15, 3, 1]);
@@ -579,4 +576,167 @@ fn zeroes_bool() {
     for i in 0..b_len {
         assert_eq!(b[i], false);
     }
-}
\ No newline at end of file
+}
+
+#[test]
+fn basic_iteration() {
+    let a = Tensor::from([[1, 2, 3], [4, 5, 6]]);
+    // println!("{:?}", a.shape());
+    for row in a.iter(Axis(0)) {
+        print!("{:?} ", row[0]);
+        print!("{:?} ", row[1]);
+        println!("{:?}", row[2]);
+    }
+}
+
+#[test]
+fn single_dimension_iteration() {
+    let a = Tensor::from([1, 2, 3, 4, 5, 6]);
+    // println!("{:?}", a.shape());
+    let v_iter: Vec<_> = a.iter(Axis(0)).collect();
+    println!("{:?}", v_iter);
+}
+
+#[test]
+fn three_dimension_iteration() {
+    let a = Tensor::from([
+        [[10, 11, 12], [13, 14, 15]],
+        [[16, 17, 18], [19, 20, 21]],
+        [[22, 23, 24], [25, 26, 27]],
+    ]);
+
+    // a.shape() = [3, 2, 3]
+    // a.stride() = [6, 3, 1]
+
+    // for t in a.iter(Axis(0)) {
+    //     print!("{:?} ", t[[0, 0]]);
+    //     print!("{:?} ", t[[0, 1]]);
+    //     println!("{:?} ", t[[0, 2]]);
+    //
+    //     print!("{:?} ", t[[1, 0]]);
+    //     print!("{:?} ", t[[1, 1]]);
+    //     println!("{:?} ", t[[1, 2]]);
+    //
+    //     println!();
+    // }
+
+    // for t in a.iter(Axis(1)) {
+    //     print!("{:?} ", t[[0, 0]]);
+    //     print!("{:?} ", t[[0, 1]]);
+    //     println!("{:?} ", t[[0, 2]]);
+    //
+    //     print!("{:?} ", t[[1, 0]]);
+    //     print!("{:?} ", t[[1, 1]]);
+    //     println!("{:?} ", t[[1, 2]]);
+    //
+    //     print!("{:?} ", t[[2, 0]]);
+    //     print!("{:?} ", t[[2, 1]]);
+    //     println!("{:?} ", t[[2, 2]]);
+    //     println!();
+    // }
+
+    for t in a.iter(Axis(2)) {
+        println!("{:?} ", t.shape());
+
+        print!("{:?} ", t[[0, 0]]);
+        println!("{:?} ", t[[0, 1]]);
+
+        print!("{:?} ", t[[1, 0]]);
+        println!("{:?} ", t[[1, 1]]);
+
+        print!("{:?} ", t[[2, 0]]);
+        println!("{:?} ", t[[2, 1]]);
+        println!();
+    }
+}
+
+#[test]
+fn four_dimension_iteration() {
+    let a = [
+        [
+            [[1, 2, 3], [4, 5, 6]],
+            [[7, 8, 9], [10, 11, 12]],
+            [[13, 14, 15], [16, 17, 18]],
+            [[19, 20, 21], [22, 23, 24]],
+        ],
+        [
+            [[25, 26, 27], [28, 29, 30]],
+            [[31, 32, 33], [34, 35, 36]],
+            [[37, 38, 39], [40, 41, 42]],
+            [[43, 44, 45], [46, 47, 48]],
+        ],
+        [
+            [[49, 50, 51], [52, 53, 54]],
+            [[55, 56, 57], [58, 59, 60]],
+            [[61, 62, 63], [64, 65, 66]],
+            [[67, 68, 69], [70, 71, 72]],
+        ],
+    ];
+
+    let tensor = Tensor::from(a);
+
+    println!("{:?}", tensor.shape());
+    println!("{:?}", tensor.stride());
+
+    for t in tensor.iter(Axis(0)) {
+        print!("{:?} ", t[[0, 0, 0]]);
+        print!("{:?} ", t[[0, 0, 1]]);
+        print!("{:?} ", t[[0, 0, 2]]);
+
+        print!("{:?} ", t[[0, 1, 0]]);
+        print!("{:?} ", t[[0, 1, 1]]);
+        print!("{:?} ", t[[0, 1, 2]]);
+
+        print!("{:?} ", t[[1, 0, 0]]);
+        print!("{:?} ", t[[1, 0, 1]]);
+        print!("{:?} ", t[[1, 0, 2]]);
+
+        print!("{:?} ", t[[1, 1, 0]]);
+        print!("{:?} ", t[[1, 1, 1]]);
+        print!("{:?} ", t[[1, 1, 2]]);
+
+        print!("{:?} ", t[[2, 0, 0]]);
+        print!("{:?} ", t[[2, 0, 1]]);
+        print!("{:?} ", t[[2, 0, 2]]);
+
+        print!("{:?} ", t[[2, 1, 0]]);
+        print!("{:?} ", t[[2, 1, 1]]);
+        print!("{:?} ", t[[2, 1, 2]]);
+
+        print!("{:?} ", t[[3, 0, 0]]);
+        print!("{:?} ", t[[3, 0, 1]]);
+        print!("{:?} ", t[[3, 0, 2]]);
+
+        print!("{:?} ", t[[3, 1, 0]]);
+        print!("{:?} ", t[[3, 1, 1]]);
+        print!("{:?} ", t[[3, 1, 2]]);
+
+        println!();
+    }
+
+    // for t in tensor.iter(Axis(1)) {
+    //     print!("{:?} ", t[[0, 0, 0]]);
+    //     print!("{:?} ", t[[0, 0, 1]]);
+    //     print!("{:?} ", t[[0, 0, 2]]);
+    //     print!("{:?} ", t[[0, 1, 0]]);
+    //     print!("{:?} ", t[[0, 1, 1]]);
+    //     print!("{:?} ", t[[0, 1, 2]]);
+    //
+    //     print!("{:?} ", t[[1, 0, 0]]);
+    //     print!("{:?} ", t[[1, 0, 1]]);
+    //     print!("{:?} ", t[[1, 0, 2]]);
+    //     print!("{:?} ", t[[1, 1, 0]]);
+    //     print!("{:?} ", t[[1, 1, 1]]);
+    //     print!("{:?} ", t[[1, 1, 2]]);
+    //
+    //     print!("{:?} ", t[[2, 0, 0]]);
+    //     print!("{:?} ", t[[2, 0, 1]]);
+    //     print!("{:?} ", t[[2, 0, 2]]);
+    //     print!("{:?} ", t[[2, 1, 0]]);
+    //     print!("{:?} ", t[[2, 1, 1]]);
+    //     print!("{:?} ", t[[2, 1, 2]]);
+    //
+    //     println!();
+    // }
+}
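One observation on the new tests: they print but never assert, so they can only catch panics. A hypothetical assert-based companion for tests/tensor.rs, under the assumption (implied by the print-based tests above) that iterating `Axis(k)` yields one `TensorView` per index along that axis, with axis `k` removed from the view's shape:

```rust
#[test]
fn basic_iteration_asserts() {
    let a = Tensor::from([[1, 2, 3], [4, 5, 6]]);

    // Axis(0) should yield one view of shape [3] per row.
    let rows: Vec<_> = a.iter(Axis(0)).collect();
    assert_eq!(rows.len(), 2);
    assert_eq!(rows[0].shape(), &[3]);
    assert_eq!([rows[0][0], rows[0][1], rows[0][2]], [1, 2, 3]);
    assert_eq!([rows[1][0], rows[1][1], rows[1][2]], [4, 5, 6]);

    // Axis(1) should yield one view of shape [2] per column.
    let cols: Vec<_> = a.iter(Axis(1)).collect();
    assert_eq!(cols.len(), 3);
    assert_eq!([cols[0][0], cols[0][1]], [1, 4]);
    assert_eq!([cols[2][0], cols[2][1]], [3, 6]);
}
```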