diff --git a/src/tensor/constructors.rs b/src/tensor/constructors.rs index da95f34..c2a5a3a 100644 --- a/src/tensor/constructors.rs +++ b/src/tensor/constructors.rs @@ -31,12 +31,54 @@ impl Tensor { ndims: D, } } + + pub fn full(n: T, shape: Vec) -> Self { + assert!(!shape.is_empty(), "Cannot create a zero-dimension tensor!"); + + let vector_ns = vec![n; shape.iter().product()]; + + let mut stride = vec![0; shape.len()]; + + let ndims = shape.len(); + + let mut p = 1; + for i in (0..ndims).rev() { + stride[i] = p; + p *= shape[i]; + } + + Self { + data: DataOwned::new(vector_ns), + shape, + stride, + ndims, + } + } + + pub fn zeros(shape: Vec) -> Self + where + T: RawDataType + From, + { + Self::full(false.into(), shape) + } + + pub fn ones(shape: Vec) -> Self + where + T: RawDataType + From, + { + Self::full(true.into(), shape) + } } impl TensorView { - pub(super) fn from(tensor: &TensorBase, offset: usize, shape: Vec, stride: Vec) -> Self + pub(super) fn from( + tensor: &TensorBase, + offset: usize, + shape: Vec, + stride: Vec, + ) -> Self where - B: DataBuffer, + B: DataBuffer, { let ndims = shape.len(); @@ -46,9 +88,12 @@ impl TensorView { // } // // the following code is equivalent to the above loop - let len = shape.iter().zip(stride.iter()) + let len = shape + .iter() + .zip(stride.iter()) .map(|(&axis_length, &axis_stride)| axis_stride * (axis_length - 1)) - .sum::() + 1; + .sum::() + + 1; let data = DataView::from_buffer(&tensor.data, offset, len); diff --git a/src/tensor/dtype.rs b/src/tensor/dtype.rs index c8a4041..4b3c89a 100644 --- a/src/tensor/dtype.rs +++ b/src/tensor/dtype.rs @@ -15,4 +15,4 @@ impl RawDataType for i128 {} impl RawDataType for f32 {} impl RawDataType for f64 {} -impl RawDataType for bool {} +impl RawDataType for bool {} \ No newline at end of file diff --git a/tests/tensor.rs b/tests/tensor.rs index 004d2ec..c2df1c9 100644 --- a/tests/tensor.rs +++ b/tests/tensor.rs @@ -420,7 +420,7 @@ fn unsqueeze_random_dimension_first_axis() { #[test] fn unsqueeze_random_dimension_axis_1() { - let a: Tensor = Tensor::from([[1, 2, 3], [4, 5, 6]]); + let a = Tensor::from([[1, 2, 3], [4, 5, 6]]); let b = a.unsqueeze(Axis(1)); assert_eq!(b.shape(), &[2, 1, 3]); assert_eq!(b.stride(), &[3, 3, 1]); @@ -428,8 +428,155 @@ fn unsqueeze_random_dimension_axis_1() { #[test] fn unsqueeze_random_dimension_last_axis() { - let a: Tensor = Tensor::from([[1, 2, 3], [4, 5, 6]]); + let a = Tensor::from([[1, 2, 3], [4, 5, 6]]); let b = a.unsqueeze(Axis(2)); assert_eq!(b.shape(), &[2, 3, 1]); assert_eq!(b.stride(), &[3, 1, 1]); } + +#[test] +fn full_i32(){ + let a = Tensor::full(3, vec![2, 3]); + assert_eq!(a.shape(), &[2, 3]); + assert_eq!(a.stride(), &[3, 1]); + let b = a.flatten(); + let b_len = b.len().clone(); + for i in 0..b_len { + assert_eq!(b[i], 3i32); + } +} + +#[test] +fn full_f64(){ + let a = Tensor::full(3.2,vec![4, 6, 2]); + assert_eq!(a.shape(), &[4, 6, 2]); + let b = a.flatten(); + let b_len = *b.len(); + for i in 0..b_len { + assert_eq!(b[i], 3.2f64); + } +} + +#[test] +fn full_bool(){ + let a: Tensor = Tensor::full(true,vec![3, 5, 3]); + assert_eq!(a.shape(), &[3, 5, 3]); + assert_eq!(a.stride(), &[15, 3, 1]); + let b = a.flatten(); + let b_len = *b.len(); + for i in 0..b_len { + assert_eq!(b[i], true); + } +} + +#[test] +fn ones_u8(){ + let a: Tensor = Tensor::ones(vec![3, 5, 3]); + assert_eq!(a.shape(), &[3, 5, 3]); + assert_eq!(a.stride(), &[15, 3, 1]); + let b = a.flatten(); + let b_len = *b.len(); + for i in 0..b_len { + assert_eq!(b[i], 1u8); + } +} + +#[test] +fn ones_i32(){ + let a: Tensor = Tensor::ones(vec![3, 5, 3]); + assert_eq!(a.shape(), &[3, 5, 3]); + assert_eq!(a.stride(), &[15, 3, 1]); + let b = a.flatten(); + let b_len = *b.len(); + for i in 0..b_len { + assert_eq!(b[i], 1i32); + } +} + +#[test] +fn ones_1d(){ + let a: Tensor = Tensor::ones(vec![4]); + assert_eq!(a.shape(), &[4]); + let a_len = *a.len(); + for i in 0..a_len { + assert_eq!(a[i], 1u8); + } +} + +#[test] +fn ones_f64(){ + let a: Tensor = Tensor::ones(vec![4]); + assert_eq!(a.shape(), &[4]); + let a_len = *a.len(); + for i in 0..a_len { + assert_eq!(a[i], 1f64); + } +} + +#[test] +fn ones_bool(){ + let a: Tensor = Tensor::ones(vec![3, 5, 3]); + assert_eq!(a.shape(), &[3, 5, 3]); + assert_eq!(a.stride(), &[15, 3, 1]); + let b = a.flatten(); + let b_len = *b.len(); + for i in 0..b_len { + assert_eq!(b[i], true); + } +} + +#[test] +fn zeroes_u8(){ + let a: Tensor = Tensor::zeros(vec![3, 5, 3]); + assert_eq!(a.shape(), &[3, 5, 3]); + assert_eq!(a.stride(), &[15, 3, 1]); + let b = a.flatten(); + let b_len = *b.len(); + for i in 0..b_len { + assert_eq!(b[i], 0u8); + } +} + +#[test] +fn zeroes_i32(){ + let a: Tensor = Tensor::zeros(vec![3, 5, 3]); + assert_eq!(a.shape(), &[3, 5, 3]); + assert_eq!(a.stride(), &[15, 3, 1]); + let b = a.flatten(); + let b_len = *b.len(); + for i in 0..b_len { + assert_eq!(b[i], 0i32); + } +} + +#[test] +fn zeroes_1d(){ + let a: Tensor = Tensor::zeros(vec![4]); + assert_eq!(a.shape(), &[4]); + let a_len = *a.len(); + for i in 0..a_len { + assert_eq!(a[i], 0u8); + } +} + +#[test] +fn zeroes_f64(){ + let a: Tensor = Tensor::zeros(vec![4]); + assert_eq!(a.shape(), &[4]); + let a_len = *a.len(); + for i in 0..a_len { + assert_eq!(a[i], 0f64); + } +} + +#[test] +fn zeroes_bool(){ + let a: Tensor = Tensor::zeros(vec![3, 5, 3]); + assert_eq!(a.shape(), &[3, 5, 3]); + assert_eq!(a.stride(), &[15, 3, 1]); + let b = a.flatten(); + let b_len = *b.len(); + for i in 0..b_len { + assert_eq!(b[i], false); + } +} \ No newline at end of file