From 41fdc86b4a80e520964733d11732fae30182f6f2 Mon Sep 17 00:00:00 2001 From: Bhavye Mathur Date: Mon, 16 Dec 2024 15:17:14 -0500 Subject: [PATCH] Implemented equals operator between tensors --- README.md | 4 ++ src/tensor.rs | 3 ++ src/tensor/clone.rs | 4 +- src/tensor/dtype.rs | 2 +- src/tensor/equals.rs | 21 ++++++++ src/tensor/iterator/buffer_iterator.rs | 44 ++++++++++++++++ ...ex_iterator.rs => flat_index_generator.rs} | 8 +-- src/tensor/iterator/flat_iterator.rs | 51 ++++++------------- src/tensor/iterator/mod.rs | 22 ++------ tests/tensor.rs | 12 ++--- 10 files changed, 105 insertions(+), 66 deletions(-) create mode 100644 src/tensor/equals.rs create mode 100644 src/tensor/iterator/buffer_iterator.rs rename src/tensor/iterator/{flat_index_iterator.rs => flat_index_generator.rs} (87%) diff --git a/README.md b/README.md index 697f153..a47be95 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,10 @@ A Rust library for linear algebra and machine learning! ## Changelog +### Dec 16, 2024 + +- Implemented `PartialEq` for `TensorBase` + ### Dec 14, 2024 - Added a flat Iterator for `Tensor` and `TensorView` diff --git a/src/tensor.rs b/src/tensor.rs index 3fbdac7..a6dc351 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -9,6 +9,9 @@ pub mod iterator; pub mod flatten; pub mod clone; pub mod squeeze; +pub mod equals; + +pub use iterator::*; use crate::tensor::data_buffer::{DataBuffer, DataOwned, DataView}; diff --git a/src/tensor/clone.rs b/src/tensor/clone.rs index 3cb2936..fafe0f9 100644 --- a/src/tensor/clone.rs +++ b/src/tensor/clone.rs @@ -1,7 +1,7 @@ use crate::data_buffer::{DataBuffer, DataOwned}; use crate::dtype::RawDataType; use crate::iterator::collapse_contiguous::collapse_contiguous; -use crate::iterator::flat_index_iterator::FlatIndexIterator; +use crate::iterator::flat_index_generator::FlatIndexGenerator; use crate::{Tensor, TensorBase, TensorView}; use std::ptr::copy_nonoverlapping; @@ -41,7 +41,7 @@ impl TensorClone for TensorView { let src = self.data.const_ptr(); let mut dst = data.as_mut_ptr(); - for i in FlatIndexIterator::from(&shape, &stride) { + for i in FlatIndexGenerator::from(&shape, &stride) { unsafe { copy_nonoverlapping(src.offset(i), dst, contiguous_stride); dst = dst.add(contiguous_stride); diff --git a/src/tensor/dtype.rs b/src/tensor/dtype.rs index 4455d42..c8a4041 100644 --- a/src/tensor/dtype.rs +++ b/src/tensor/dtype.rs @@ -1,4 +1,4 @@ -pub trait RawDataType: Clone + Copy {} +pub trait RawDataType: Clone + Copy + PartialEq {} impl RawDataType for u8 {} impl RawDataType for u16 {} diff --git a/src/tensor/equals.rs b/src/tensor/equals.rs new file mode 100644 index 0000000..8367ab3 --- /dev/null +++ b/src/tensor/equals.rs @@ -0,0 +1,21 @@ +use crate::data_buffer::DataBuffer; +use crate::dtype::RawDataType; +use crate::iterator::flat_iterator::FlatIterator; +use crate::TensorBase; + +impl PartialEq> for TensorBase +where + TensorBase: FlatIterator, + TensorBase: FlatIterator, + B1: DataBuffer, + B2: DataBuffer, + T1: RawDataType, + T2: RawDataType + From, +{ + fn eq(&self, other: &TensorBase) -> bool { + if self.shape != other.shape { + false; + } + self.flat_iter().zip(other.flat_iter()).all(|(a, b)| a == b.into()) + } +} diff --git a/src/tensor/iterator/buffer_iterator.rs b/src/tensor/iterator/buffer_iterator.rs new file mode 100644 index 0000000..2d1f2fc --- /dev/null +++ b/src/tensor/iterator/buffer_iterator.rs @@ -0,0 +1,44 @@ +use crate::data_buffer::DataBuffer; +use crate::dtype::RawDataType; +use crate::TensorBase; + +#[non_exhaustive] +pub struct BufferIterator +where + T: RawDataType, + I: Iterator, +{ + ptr: *const T, + indices: I, +} + +impl BufferIterator +where + T: RawDataType, + I: Iterator, +{ + pub(super) fn from(tensor: &TensorBase, indices: I) -> Self + where + B: DataBuffer, + { + Self { + ptr: tensor.data.const_ptr(), + indices, + } + } +} + +impl Iterator for BufferIterator +where + T: RawDataType, + I: Iterator, +{ + type Item = T; + + fn next(&mut self) -> Option { + match self.indices.next() { + None => None, + Some(i) => Some(unsafe { *self.ptr.offset(i) }) + } + } +} diff --git a/src/tensor/iterator/flat_index_iterator.rs b/src/tensor/iterator/flat_index_generator.rs similarity index 87% rename from src/tensor/iterator/flat_index_iterator.rs rename to src/tensor/iterator/flat_index_generator.rs index 3f2c05f..1a2b4fa 100644 --- a/src/tensor/iterator/flat_index_iterator.rs +++ b/src/tensor/iterator/flat_index_generator.rs @@ -1,7 +1,7 @@ use crate::iterator::collapse_contiguous::collapse_contiguous; #[non_exhaustive] -pub struct FlatIndexIterator +pub struct FlatIndexGenerator { shape: Vec, stride: Vec, @@ -14,8 +14,8 @@ pub struct FlatIndexIterator flat_index: usize, } -impl FlatIndexIterator { - pub(in crate::tensor) fn from(shape: &Vec, stride: &Vec) -> Self { +impl FlatIndexGenerator { + pub(in crate::tensor) fn from(shape: &[usize], stride: &[usize]) -> Self { let (shape, stride) = collapse_contiguous(shape, stride); let ndims = shape.len(); let size = shape.iter().product(); @@ -32,7 +32,7 @@ impl FlatIndexIterator { } } -impl Iterator for FlatIndexIterator { +impl Iterator for FlatIndexGenerator { type Item = isize; fn next(&mut self) -> Option { diff --git a/src/tensor/iterator/flat_iterator.rs b/src/tensor/iterator/flat_iterator.rs index 2bcde2d..dfe7b23 100644 --- a/src/tensor/iterator/flat_iterator.rs +++ b/src/tensor/iterator/flat_iterator.rs @@ -1,44 +1,25 @@ -use crate::data_buffer::DataBuffer; use crate::dtype::RawDataType; -use crate::TensorBase; +use crate::iterator::buffer_iterator::BufferIterator; +use crate::iterator::flat_index_generator::FlatIndexGenerator; +use crate::{Tensor, TensorView}; +use std::ops::Range; -#[non_exhaustive] -pub struct FlatIterator -where - T: RawDataType, - I: Iterator, -{ - ptr: *const T, - indices: I, +pub trait FlatIterator { + type Indices: Iterator; + fn flat_iter(&self) -> BufferIterator; } -impl FlatIterator -where - T: RawDataType, - I: Iterator, -{ - pub(super) fn from(tensor: &TensorBase, indices: I) -> Self - where - B: DataBuffer, - { - Self { - ptr: tensor.data.const_ptr(), - indices, - } +impl FlatIterator for Tensor { + type Indices = Range; + fn flat_iter(&self) -> BufferIterator { + BufferIterator::from(self, 0..self.size() as isize) } } -impl Iterator for FlatIterator -where - T: RawDataType, - I: Iterator, -{ - type Item = T; - - fn next(&mut self) -> Option { - match self.indices.next() { - None => None, - Some(i) => Some(unsafe { *self.ptr.offset(i) }) - } +impl FlatIterator for TensorView { + type Indices = FlatIndexGenerator; + fn flat_iter(&self) -> BufferIterator { + let indices = FlatIndexGenerator::from(&self.shape, &self.stride); + BufferIterator::from(self, indices) } } diff --git a/src/tensor/iterator/mod.rs b/src/tensor/iterator/mod.rs index 7d44098..29d8ee3 100644 --- a/src/tensor/iterator/mod.rs +++ b/src/tensor/iterator/mod.rs @@ -1,22 +1,8 @@ use crate::dtype::RawDataType; -use crate::iterator::flat_index_iterator::FlatIndexIterator; -use crate::iterator::flat_iterator::FlatIterator; -use crate::{Tensor, TensorView}; -use std::ops::Range; -pub mod flat_iterator; +pub mod buffer_iterator; pub(super) mod collapse_contiguous; -pub mod flat_index_iterator; - -impl Tensor { - pub fn flat_iter(&self) -> FlatIterator> { - FlatIterator::from(self, 0..self.size() as isize) - } -} +pub mod flat_index_generator; +pub mod flat_iterator; -impl TensorView { - pub fn flat_iter(&self) -> FlatIterator { - let indices = FlatIndexIterator::from(&self.shape, &self.stride); - FlatIterator::from(self, indices) - } -} +pub use flat_iterator::*; diff --git a/tests/tensor.rs b/tests/tensor.rs index 2b46a70..004d2ec 100644 --- a/tests/tensor.rs +++ b/tests/tensor.rs @@ -342,9 +342,9 @@ fn flatten() { assert_eq!(b.len(), &18); assert_eq!(b.ndims(), 1); - assert_eq!(b[0], 10); - assert_eq!(b[5], 15); - assert_eq!(b[17], 27); + let correct = Tensor::from([10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27]); + assert_eq!(b, correct); let b = a.slice(s![.., 0]).flatten(); assert_eq!(b.shape(), &[9]); @@ -362,8 +362,8 @@ fn flatten() { assert_eq!(b.len(), &4); assert_eq!(b.ndims(), 1); - assert_eq!(b[0], 14); - assert_eq!(b[3], 21); + let correct = Tensor::from([14, 15, 20, 21]); + assert_eq!(b, correct); } #[test] @@ -432,4 +432,4 @@ fn unsqueeze_random_dimension_last_axis() { let b = a.unsqueeze(Axis(2)); assert_eq!(b.shape(), &[2, 3, 1]); assert_eq!(b.stride(), &[3, 1, 1]); -} \ No newline at end of file +}