Skip to content

Commit

Permalink
zeros, ones, full constructors now accept arrays as shape
Browse files Browse the repository at this point in the history
  • Loading branch information
BhavyeMathur committed Dec 30, 2024
1 parent 50a1e34 commit bd52256
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 90 deletions.
9 changes: 5 additions & 4 deletions src/tensor/constructors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::tensor::{Tensor, TensorBase, TensorView};
use crate::traits::flatten::Flatten;
use crate::traits::nested::Nested;
use crate::traits::shape::Shape;

use crate::traits::to_vec::ToVec;

// calculates the stride from the tensor's shape
// shape [5, 3, 2, 1] -> stride [10, 2, 1, 1]
Expand Down Expand Up @@ -39,7 +39,8 @@ impl<T: RawDataType> Tensor<T> {
}
}

pub fn full(n: T, shape: Vec<usize>) -> Self {
pub fn full(n: T, shape: impl ToVec<usize>) -> Self {
let shape = shape.to_vec();
assert!(!shape.is_empty(), "Cannot create a zero-dimension tensor!");

let vector_ns = vec![n; shape.iter().product()];
Expand All @@ -54,14 +55,14 @@ impl<T: RawDataType> Tensor<T> {
}
}

pub fn zeros(shape: Vec<usize>) -> Self
pub fn zeros(shape: impl ToVec<usize>) -> Self
where
T: RawDataType + From<bool>,
{
Self::full(false.into(), shape)
}

pub fn ones(shape: Vec<usize>) -> Self
pub fn ones(shape: impl ToVec<usize>) -> Self
where
T: RawDataType + From<bool>,
{
Expand Down
1 change: 1 addition & 0 deletions src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ pub(crate) mod homogenous;
pub(crate) mod nested;
pub(crate) mod shape;
pub(crate) mod haslength;
pub(crate) mod to_vec;
2 changes: 1 addition & 1 deletion src/traits/flatten.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::fmt::Debug;

use std::ptr::copy_nonoverlapping;

pub trait Flatten<A: RawDataType> {
pub(crate) trait Flatten<A: RawDataType> {
fn flatten(self) -> Vec<A>;
}

Expand Down
2 changes: 1 addition & 1 deletion src/traits/haslength.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub trait HasLength {
pub(crate) trait HasLength {
fn len(&self) -> usize;
}

Expand Down
2 changes: 1 addition & 1 deletion src/traits/homogenous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
use crate::recursive_trait_base_cases;
use crate::traits::shape::Shape;

pub trait Homogenous {
pub(crate) trait Homogenous {
fn check_homogenous(&self) -> bool;
}

Expand Down
2 changes: 1 addition & 1 deletion src/traits/nested.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::tensor::dtype::RawDataType;

pub trait Nested<const D: usize> {}
pub(crate) trait Nested<const D: usize> {}

impl<T> Nested<1> for Vec<T> where T: RawDataType {}
impl<T> Nested<2> for Vec<T> where T: Nested<1> {}
Expand Down
2 changes: 1 addition & 1 deletion src/traits/shape.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::recursive_trait_base_cases;
use crate::traits::homogenous::Homogenous;

pub trait Shape: Homogenous {
pub(crate) trait Shape: Homogenous {
fn shape(&self) -> Vec<usize>;
}

Expand Down
15 changes: 15 additions & 0 deletions src/traits/to_vec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
pub(crate) trait ToVec<T> {
fn to_vec(self) -> Vec<T>;
}

impl<T> ToVec<T> for Vec<T> {
fn to_vec(self) -> Vec<T> {
self
}
}

impl<T, const N: usize> ToVec<T> for [T; N] {
fn to_vec(self) -> Vec<T> {
Vec::from(self)
}
}
114 changes: 33 additions & 81 deletions tests/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,150 +434,102 @@ fn unsqueeze_random_dimension_last_axis() {
}

#[test]
fn full_i32(){
let a = Tensor::full(3, vec![2, 3]);
fn full_i32() {
let a = Tensor::full(3, [2, 3]);
assert_eq!(a.shape(), &[2, 3]);
assert_eq!(a.stride(), &[3, 1]);
let b = a.flatten();
let b_len = b.len().clone();
for i in 0..b_len {
assert_eq!(b[i], 3i32);
}
assert!(a.flat_iter().all(|x| x == 3));
}

#[test]
fn full_f64(){
let a = Tensor::full(3.2,vec![4, 6, 2]);
fn full_f64() {
let a = Tensor::full(3.2, [4, 6, 2]);
assert_eq!(a.shape(), &[4, 6, 2]);
let b = a.flatten();
let b_len = *b.len();
for i in 0..b_len {
assert_eq!(b[i], 3.2f64);
}
assert!(a.flat_iter().all(|x| x == 3.2));
}

#[test]
fn full_bool(){
let a: Tensor<bool> = Tensor::full(true,vec![3, 5, 3]);
fn full_bool() {
let a: Tensor<bool> = Tensor::full(true, vec![3, 5, 3]);
assert_eq!(a.shape(), &[3, 5, 3]);
assert_eq!(a.stride(), &[15, 3, 1]);
let b = a.flatten();
let b_len = *b.len();
for i in 0..b_len {
assert_eq!(b[i], true);
}
assert!(a.flat_iter().all(|x| x == true));
}

#[test]
fn ones_u8(){
let a: Tensor<u8> = Tensor::ones(vec![3, 5, 3]);
fn ones_u8() {
let a: Tensor<u8> = Tensor::ones([3, 5, 3]);
assert_eq!(a.shape(), &[3, 5, 3]);
assert_eq!(a.stride(), &[15, 3, 1]);
let b = a.flatten();
let b_len = *b.len();
for i in 0..b_len {
assert_eq!(b[i], 1u8);
}
assert!(a.flat_iter().all(|x| x == 1));
}

#[test]
fn ones_i32(){
fn ones_i32() {
let a: Tensor<i32> = Tensor::ones(vec![3, 5, 3]);
assert_eq!(a.shape(), &[3, 5, 3]);
assert_eq!(a.stride(), &[15, 3, 1]);
let b = a.flatten();
let b_len = *b.len();
for i in 0..b_len {
assert_eq!(b[i], 1i32);
}
assert!(a.flat_iter().all(|x| x == 1));
}

#[test]
fn ones_1d(){
let a: Tensor<u8> = Tensor::ones(vec![4]);
fn ones_1d() {
let a: Tensor<u8> = Tensor::ones([4]);
assert_eq!(a.shape(), &[4]);
let a_len = *a.len();
for i in 0..a_len {
assert_eq!(a[i], 1u8);
}
assert!(a.flat_iter().all(|x| x == 1));
}

#[test]
fn ones_f64(){
fn ones_f64() {
let a: Tensor<f64> = Tensor::ones(vec![4]);
assert_eq!(a.shape(), &[4]);
let a_len = *a.len();
for i in 0..a_len {
assert_eq!(a[i], 1f64);
}
assert!(a.flat_iter().all(|x| x == 1.0));
}

#[test]
fn ones_bool(){
fn ones_bool() {
let a: Tensor<bool> = Tensor::ones(vec![3, 5, 3]);
assert_eq!(a.shape(), &[3, 5, 3]);
assert_eq!(a.stride(), &[15, 3, 1]);
let b = a.flatten();
let b_len = *b.len();
for i in 0..b_len {
assert_eq!(b[i], true);
}
assert!(a.flat_iter().all(|x| x == true));
}

#[test]
fn zeroes_u8(){
let a: Tensor<u8> = Tensor::zeros(vec![3, 5, 3]);
fn zeroes_u8() {
let a: Tensor<u8> = Tensor::zeros([3, 5, 3]);
assert_eq!(a.shape(), &[3, 5, 3]);
assert_eq!(a.stride(), &[15, 3, 1]);
let b = a.flatten();
let b_len = *b.len();
for i in 0..b_len {
assert_eq!(b[i], 0u8);
}
assert!(a.flat_iter().all(|x| x == 0));
}

#[test]
fn zeroes_i32(){
fn zeroes_i32() {
let a: Tensor<i32> = Tensor::zeros(vec![3, 5, 3]);
assert_eq!(a.shape(), &[3, 5, 3]);
assert_eq!(a.stride(), &[15, 3, 1]);
let b = a.flatten();
let b_len = *b.len();
for i in 0..b_len {
assert_eq!(b[i], 0i32);
}
assert!(a.flat_iter().all(|x| x == 0));
}

#[test]
fn zeroes_1d(){
let a: Tensor<u8> = Tensor::zeros(vec![4]);
fn zeroes_1d() {
let a: Tensor<u8> = Tensor::zeros([4]);
assert_eq!(a.shape(), &[4]);
let a_len = *a.len();
for i in 0..a_len {
assert_eq!(a[i], 0u8);
}
assert!(a.flat_iter().all(|x| x == 0));
}

#[test]
fn zeroes_f64(){
fn zeroes_f64() {
let a: Tensor<f64> = Tensor::zeros(vec![4]);
assert_eq!(a.shape(), &[4]);
let a_len = *a.len();
for i in 0..a_len {
assert_eq!(a[i], 0f64);
}
assert!(a.flat_iter().all(|x| x == 0.0));
}

#[test]
fn zeroes_bool(){
fn zeroes_bool() {
let a: Tensor<bool> = Tensor::zeros(vec![3, 5, 3]);
assert_eq!(a.shape(), &[3, 5, 3]);
assert_eq!(a.stride(), &[15, 3, 1]);
let b = a.flatten();
let b_len = *b.len();
for i in 0..b_len {
assert_eq!(b[i], false);
}
assert!(a.flat_iter().all(|x| x == false));
}
#[test]
fn iterate() {
Expand Down

0 comments on commit bd52256

Please sign in to comment.