Skip to content

Commit

Permalink
Implemented equals operator between tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
BhavyeMathur committed Dec 16, 2024
1 parent 093faeb commit 41fdc86
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 66 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ A Rust library for linear algebra and machine learning!

## Changelog

### Dec 16, 2024

- Implemented `PartialEq` for `TensorBase`

### Dec 14, 2024

- Added a flat Iterator for `Tensor` and `TensorView`
Expand Down
3 changes: 3 additions & 0 deletions src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ pub mod iterator;
pub mod flatten;
pub mod clone;
pub mod squeeze;
pub mod equals;

pub use iterator::*;

use crate::tensor::data_buffer::{DataBuffer, DataOwned, DataView};

Expand Down
4 changes: 2 additions & 2 deletions src/tensor/clone.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::data_buffer::{DataBuffer, DataOwned};
use crate::dtype::RawDataType;
use crate::iterator::collapse_contiguous::collapse_contiguous;
use crate::iterator::flat_index_iterator::FlatIndexIterator;
use crate::iterator::flat_index_generator::FlatIndexGenerator;
use crate::{Tensor, TensorBase, TensorView};
use std::ptr::copy_nonoverlapping;

Expand Down Expand Up @@ -41,7 +41,7 @@ impl<T: RawDataType> TensorClone<T> for TensorView<T> {
let src = self.data.const_ptr();
let mut dst = data.as_mut_ptr();

for i in FlatIndexIterator::from(&shape, &stride) {
for i in FlatIndexGenerator::from(&shape, &stride) {
unsafe {
copy_nonoverlapping(src.offset(i), dst, contiguous_stride);
dst = dst.add(contiguous_stride);
Expand Down
2 changes: 1 addition & 1 deletion src/tensor/dtype.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub trait RawDataType: Clone + Copy {}
pub trait RawDataType: Clone + Copy + PartialEq {}

impl RawDataType for u8 {}
impl RawDataType for u16 {}
Expand Down
21 changes: 21 additions & 0 deletions src/tensor/equals.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use crate::data_buffer::DataBuffer;
use crate::dtype::RawDataType;
use crate::iterator::flat_iterator::FlatIterator;
use crate::TensorBase;

impl<B1, T1, B2, T2> PartialEq<TensorBase<B1>> for TensorBase<B2>
where
TensorBase<B1>: FlatIterator<T1>,
TensorBase<B2>: FlatIterator<T2>,
B1: DataBuffer<DType=T1>,
B2: DataBuffer<DType=T2>,
T1: RawDataType,
T2: RawDataType + From<T1>,
{
fn eq(&self, other: &TensorBase<B1>) -> bool {
if self.shape != other.shape {
false;
}
self.flat_iter().zip(other.flat_iter()).all(|(a, b)| a == b.into())
}
}
44 changes: 44 additions & 0 deletions src/tensor/iterator/buffer_iterator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use crate::data_buffer::DataBuffer;
use crate::dtype::RawDataType;
use crate::TensorBase;

#[non_exhaustive]
pub struct BufferIterator<T, I>
where
T: RawDataType,
I: Iterator<Item=isize>,
{
ptr: *const T,
indices: I,
}

impl<T, I> BufferIterator<T, I>
where
T: RawDataType,
I: Iterator<Item=isize>,
{
pub(super) fn from<B>(tensor: &TensorBase<B>, indices: I) -> Self
where
B: DataBuffer<DType=T>,
{
Self {
ptr: tensor.data.const_ptr(),
indices,
}
}
}

impl<T, I> Iterator for BufferIterator<T, I>
where
T: RawDataType,
I: Iterator<Item=isize>,
{
type Item = T;

fn next(&mut self) -> Option<Self::Item> {
match self.indices.next() {
None => None,
Some(i) => Some(unsafe { *self.ptr.offset(i) })
}
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::iterator::collapse_contiguous::collapse_contiguous;

#[non_exhaustive]
pub struct FlatIndexIterator
pub struct FlatIndexGenerator
{
shape: Vec<usize>,
stride: Vec<usize>,
Expand All @@ -14,8 +14,8 @@ pub struct FlatIndexIterator
flat_index: usize,
}

impl FlatIndexIterator {
pub(in crate::tensor) fn from(shape: &Vec<usize>, stride: &Vec<usize>) -> Self {
impl FlatIndexGenerator {
pub(in crate::tensor) fn from(shape: &[usize], stride: &[usize]) -> Self {
let (shape, stride) = collapse_contiguous(shape, stride);
let ndims = shape.len();
let size = shape.iter().product();
Expand All @@ -32,7 +32,7 @@ impl FlatIndexIterator {
}
}

impl Iterator for FlatIndexIterator {
impl Iterator for FlatIndexGenerator {
type Item = isize;

fn next(&mut self) -> Option<Self::Item> {
Expand Down
51 changes: 16 additions & 35 deletions src/tensor/iterator/flat_iterator.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,25 @@
use crate::data_buffer::DataBuffer;
use crate::dtype::RawDataType;
use crate::TensorBase;
use crate::iterator::buffer_iterator::BufferIterator;
use crate::iterator::flat_index_generator::FlatIndexGenerator;
use crate::{Tensor, TensorView};
use std::ops::Range;

#[non_exhaustive]
pub struct FlatIterator<T, I>
where
T: RawDataType,
I: Iterator<Item=isize>,
{
ptr: *const T,
indices: I,
pub trait FlatIterator<T: RawDataType> {
type Indices: Iterator<Item=isize>;
fn flat_iter(&self) -> BufferIterator<T, Self::Indices>;
}

impl<T, I> FlatIterator<T, I>
where
T: RawDataType,
I: Iterator<Item=isize>,
{
pub(super) fn from<B>(tensor: &TensorBase<B>, indices: I) -> Self
where
B: DataBuffer<DType=T>,
{
Self {
ptr: tensor.data.const_ptr(),
indices,
}
impl<T: RawDataType> FlatIterator<T> for Tensor<T> {
type Indices = Range<isize>;
fn flat_iter(&self) -> BufferIterator<T, Self::Indices> {
BufferIterator::from(self, 0..self.size() as isize)
}
}

impl<T, I> Iterator for FlatIterator<T, I>
where
T: RawDataType,
I: Iterator<Item=isize>,
{
type Item = T;

fn next(&mut self) -> Option<Self::Item> {
match self.indices.next() {
None => None,
Some(i) => Some(unsafe { *self.ptr.offset(i) })
}
impl<T: RawDataType> FlatIterator<T> for TensorView<T> {
type Indices = FlatIndexGenerator;
fn flat_iter(&self) -> BufferIterator<T, Self::Indices> {
let indices = FlatIndexGenerator::from(&self.shape, &self.stride);
BufferIterator::from(self, indices)
}
}
22 changes: 4 additions & 18 deletions src/tensor/iterator/mod.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,8 @@
use crate::dtype::RawDataType;
use crate::iterator::flat_index_iterator::FlatIndexIterator;
use crate::iterator::flat_iterator::FlatIterator;
use crate::{Tensor, TensorView};
use std::ops::Range;

pub mod flat_iterator;
pub mod buffer_iterator;
pub(super) mod collapse_contiguous;
pub mod flat_index_iterator;

impl<T: RawDataType> Tensor<T> {
pub fn flat_iter(&self) -> FlatIterator<T, Range<isize>> {
FlatIterator::from(self, 0..self.size() as isize)
}
}
pub mod flat_index_generator;
pub mod flat_iterator;

impl<T: RawDataType> TensorView<T> {
pub fn flat_iter(&self) -> FlatIterator<T, FlatIndexIterator> {
let indices = FlatIndexIterator::from(&self.shape, &self.stride);
FlatIterator::from(self, indices)
}
}
pub use flat_iterator::*;
12 changes: 6 additions & 6 deletions tests/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,9 @@ fn flatten() {
assert_eq!(b.len(), &18);
assert_eq!(b.ndims(), 1);

assert_eq!(b[0], 10);
assert_eq!(b[5], 15);
assert_eq!(b[17], 27);
let correct = Tensor::from([10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 24, 25, 26, 27]);
assert_eq!(b, correct);

let b = a.slice(s![.., 0]).flatten();
assert_eq!(b.shape(), &[9]);
Expand All @@ -362,8 +362,8 @@ fn flatten() {
assert_eq!(b.len(), &4);
assert_eq!(b.ndims(), 1);

assert_eq!(b[0], 14);
assert_eq!(b[3], 21);
let correct = Tensor::from([14, 15, 20, 21]);
assert_eq!(b, correct);
}

#[test]
Expand Down Expand Up @@ -432,4 +432,4 @@ fn unsqueeze_random_dimension_last_axis() {
let b = a.unsqueeze(Axis(2));
assert_eq!(b.shape(), &[2, 3, 1]);
assert_eq!(b.stride(), &[3, 1, 1]);
}
}

0 comments on commit 41fdc86

Please sign in to comment.