From 7114524b66d649b548ca47e0c71e22849038e93c Mon Sep 17 00:00:00 2001 From: Aman-Amith-Shastry Date: Sun, 15 Dec 2024 21:25:08 -0500 Subject: [PATCH 1/3] Additional constructors --- src/tensor/constructors.rs | 29 +++++++++++++++++++++++++++++ tests/tensor.rs | 10 ++++++++-- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/src/tensor/constructors.rs b/src/tensor/constructors.rs index 8f33aca..37bd69c 100644 --- a/src/tensor/constructors.rs +++ b/src/tensor/constructors.rs @@ -31,6 +31,35 @@ impl Tensor { ndims: D, } } + + pub fn full(n: T, shape: Vec) -> Self { + let vector_ns = vec![n; shape.iter().product()]; + + let mut stride = vec![0; shape.len()]; + + let ndims = shape.len(); + + let mut p = 1; + for i in (0..ndims).rev() { + stride[i] = p; + p *= shape[i]; + } + + Self{ + data: DataOwned::new(vector_ns), + shape, + stride, + ndims + } + } + + pub fn zeros(shape: Vec) -> Self { + Self::full(0, shape) + } + + pub fn ones(shape: Vec) -> Self { + Self::full(1, shape) + } } impl TensorView { diff --git a/tests/tensor.rs b/tests/tensor.rs index 2b46a70..3618cab 100644 --- a/tests/tensor.rs +++ b/tests/tensor.rs @@ -420,7 +420,7 @@ fn unsqueeze_random_dimension_first_axis() { #[test] fn unsqueeze_random_dimension_axis_1() { - let a: Tensor = Tensor::from([[1, 2, 3], [4, 5, 6]]); + let a = Tensor::from([[1, 2, 3], [4, 5, 6]]); let b = a.unsqueeze(Axis(1)); assert_eq!(b.shape(), &[2, 1, 3]); assert_eq!(b.stride(), &[3, 3, 1]); @@ -428,8 +428,14 @@ fn unsqueeze_random_dimension_axis_1() { #[test] fn unsqueeze_random_dimension_last_axis() { - let a: Tensor = Tensor::from([[1, 2, 3], [4, 5, 6]]); + let a = Tensor::from([[1, 2, 3], [4, 5, 6]]); let b = a.unsqueeze(Axis(2)); assert_eq!(b.shape(), &[2, 3, 1]); assert_eq!(b.stride(), &[3, 1, 1]); +} + +#[test] +fn full_n(){ + let a = Tensor::full(3, vec![2, 3]); + assert_eq!(a.shape(), &[2, 3]); } \ No newline at end of file From b70f0f355b06787dbfa0820a5c0c598502b9c3e1 Mon Sep 17 00:00:00 2001 From: Aman-Amith-Shastry Date: Mon, 16 Dec 2024 16:44:49 -0500 Subject: [PATCH 2/3] Fixed boolean implementation for zeroes and ones --- src/tensor/constructors.rs | 8 ++++---- tests/tensor.rs | 29 +++++++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/src/tensor/constructors.rs b/src/tensor/constructors.rs index eb950c2..0a18934 100644 --- a/src/tensor/constructors.rs +++ b/src/tensor/constructors.rs @@ -57,16 +57,16 @@ impl Tensor { pub fn zeros(shape: Vec) -> Self where - T: RawDataType + From, + T: RawDataType + From, { - Self::full(0.into(), shape) + Self::full(false.into(), shape) } pub fn ones(shape: Vec) -> Self where - T: RawDataType + From, + T: RawDataType + From, { - Self::full(1.into(), shape) + Self::full(true.into(), shape) } } diff --git a/tests/tensor.rs b/tests/tensor.rs index d5e60e0..121701e 100644 --- a/tests/tensor.rs +++ b/tests/tensor.rs @@ -442,7 +442,7 @@ fn full_i32(){ let b = a.flatten(); let b_len = b.len().clone(); for i in 0..b_len { - assert_eq!(b[i], 3); + assert_eq!(b[i], 3i32); } } @@ -453,7 +453,7 @@ fn full_f64(){ let b = a.flatten(); let b_len = *b.len(); for i in 0..b_len { - assert_eq!(b[i], 3.2); + assert_eq!(b[i], 3.2f64); } } @@ -468,6 +468,7 @@ fn full_bool(){ assert_eq!(b[i], true); } } + #[test] fn ones_u8(){ let a: Tensor = Tensor::ones(vec![3, 5, 3]); @@ -512,6 +513,18 @@ fn ones_f64(){ } } +#[test] +fn ones_bool(){ + let a: Tensor = Tensor::ones(vec![3, 5, 3]); + assert_eq!(a.shape(), &[3, 5, 3]); + assert_eq!(a.stride(), &[15, 3, 1]); + let b = a.flatten(); + let b_len = *b.len(); + for i in 0..b_len { + assert_eq!(b[i], true); + } +} + #[test] fn zeroes_u8(){ let a: Tensor = Tensor::zeros(vec![3, 5, 3]); @@ -555,3 +568,15 @@ fn zeroes_f64(){ assert_eq!(a[i], 0f64); } } + +#[test] +fn zeroes_bool(){ + let a: Tensor = Tensor::zeros(vec![3, 5, 3]); + assert_eq!(a.shape(), &[3, 5, 3]); + assert_eq!(a.stride(), &[15, 3, 1]); + let b = a.flatten(); + let b_len = *b.len(); + for i in 0..b_len { + assert_eq!(b[i], false); + } +} \ No newline at end of file From aa5e061c0fd5ee52e13b141f487d970089e7f8c7 Mon Sep 17 00:00:00 2001 From: Aman-Amith-Shastry Date: Mon, 16 Dec 2024 17:48:02 -0500 Subject: [PATCH 3/3] Addressed warning of comparing length of array to 0 --- src/tensor/constructors.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensor/constructors.rs b/src/tensor/constructors.rs index 0a18934..c2a5a3a 100644 --- a/src/tensor/constructors.rs +++ b/src/tensor/constructors.rs @@ -33,7 +33,7 @@ impl Tensor { } pub fn full(n: T, shape: Vec) -> Self { - assert!(shape.len() > 0, "Cannot create a zero-dimension tensor!"); + assert!(!shape.is_empty(), "Cannot create a zero-dimension tensor!"); let vector_ns = vec![n; shape.iter().product()];