Merge pull request #2 from BhavyeMathur/dev
Zeros, ones, and full constructors
BhavyeMathur authored Dec 16, 2024
2 parents 8be2e16 + aa5e061 commit 69cc82b
Showing 3 changed files with 199 additions and 7 deletions.
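
In short, this commit adds three constructors to Tensor<T>: full(n, shape), zeros(shape), and ones(shape); zeros and ones delegate to full through a T: From<bool> bound. An illustrative usage sketch (not part of the diff itself), with the shapes, fill values, and expected strides taken from the new tests in tests/tensor.rs:

let a = Tensor::full(3, vec![2, 3]);               // 2x3 tensor of 3s; stride [3, 1]
let b: Tensor<u8> = Tensor::ones(vec![3, 5, 3]);   // every element 1u8; stride [15, 3, 1]
let c: Tensor<i32> = Tensor::zeros(vec![3, 5, 3]); // every element 0i32; stride [15, 3, 1]
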
53 changes: 49 additions & 4 deletions src/tensor/constructors.rs
@@ -31,12 +31,54 @@ impl<T: RawDataType> Tensor<T> {
ndims: D,
}
}

pub fn full(n: T, shape: Vec<usize>) -> Self {
assert!(!shape.is_empty(), "Cannot create a zero-dimension tensor!");

let vector_ns = vec![n; shape.iter().product()];

let mut stride = vec![0; shape.len()];

let ndims = shape.len();

let mut p = 1;
for i in (0..ndims).rev() {
stride[i] = p;
p *= shape[i];
}

Self {
data: DataOwned::new(vector_ns),
shape,
stride,
ndims,
}
}

pub fn zeros(shape: Vec<usize>) -> Self
where
T: RawDataType + From<bool>,
{
Self::full(false.into(), shape)
}

pub fn ones(shape: Vec<usize>) -> Self
where
T: RawDataType + From<bool>,
{
Self::full(true.into(), shape)
}
}

impl<T: RawDataType> TensorView<T> {
- pub(super) fn from<B>(tensor: &TensorBase<B>, offset: usize, shape: Vec<usize>, stride: Vec<usize>) -> Self
+ pub(super) fn from<B>(
+ tensor: &TensorBase<B>,
+ offset: usize,
+ shape: Vec<usize>,
+ stride: Vec<usize>,
+ ) -> Self
where
- B: DataBuffer<DType=T>,
+ B: DataBuffer<DType = T>,
{
let ndims = shape.len();

@@ -46,9 +88,12 @@ impl<T: RawDataType> TensorView<T> {
// }
//
// the following code is equivalent to the above loop
- let len = shape.iter().zip(stride.iter())
+ let len = shape
+ .iter()
+ .zip(stride.iter())
.map(|(&axis_length, &axis_stride)| axis_stride * (axis_length - 1))
- .sum::<usize>() + 1;
+ .sum::<usize>()
+ + 1;

let data = DataView::from_buffer(&tensor.data, offset, len);

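
Two pieces of arithmetic in this file are worth spelling out. The new full constructor builds row-major strides by walking the shape right to left: the last axis gets stride 1 and each earlier axis gets the product of all later axis lengths, so shape [3, 5, 3] yields strides [15, 3, 1], matching the assertions in the new tests. The reformatted expression in TensorView::from computes how many buffer elements a view spans: the sum over axes of stride * (length - 1), plus 1 for the first element. A standalone sketch of both calculations (illustrative only; the function names are hypothetical and not part of the crate):

// Illustrative helpers, not the crate's code.
fn row_major_strides(shape: &[usize]) -> Vec<usize> {
    // Accumulate the product of the axis lengths already visited,
    // right to left; the innermost axis always has stride 1.
    let mut stride = vec![0; shape.len()];
    let mut p = 1;
    for i in (0..shape.len()).rev() {
        stride[i] = p;
        p *= shape[i];
    }
    stride
}

fn view_span(shape: &[usize], stride: &[usize]) -> usize {
    // Index of the view's last element relative to its offset, plus one.
    shape.iter()
        .zip(stride)
        .map(|(&len, &step)| step * (len - 1))
        .sum::<usize>() + 1
}

For example, row_major_strides(&[3, 5, 3]) returns [15, 3, 1], and view_span(&[2, 3], &[3, 1]) returns 6, the extent of a contiguous 2x3 view.
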
2 changes: 1 addition & 1 deletion src/tensor/dtype.rs
@@ -15,4 +15,4 @@ impl RawDataType for i128 {}
impl RawDataType for f32 {}
impl RawDataType for f64 {}

- impl RawDataType for bool {}
+ impl RawDataType for bool {}
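
The zeros and ones constructors obtain their fill values through the T: From<bool> bound (false for zeros, true for ones) rather than a dedicated zero/one trait. A minimal standalone illustration of that idea, with hypothetical helper names that are not part of the crate:

// Illustrative only: mirrors how zeros() and ones() derive their fill values.
fn zero_of<T: From<bool>>() -> T {
    T::from(false)
}

fn one_of<T: From<bool>>() -> T {
    T::from(true)
}

// e.g. zero_of::<u8>() == 0 and one_of::<i32>() == 1
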
151 changes: 149 additions & 2 deletions tests/tensor.rs
@@ -420,16 +420,163 @@ fn unsqueeze_random_dimension_first_axis() {

#[test]
fn unsqueeze_random_dimension_axis_1() {
- let a: Tensor<i32> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
+ let a = Tensor::from([[1, 2, 3], [4, 5, 6]]);
let b = a.unsqueeze(Axis(1));
assert_eq!(b.shape(), &[2, 1, 3]);
assert_eq!(b.stride(), &[3, 3, 1]);
}

#[test]
fn unsqueeze_random_dimension_last_axis() {
- let a: Tensor<i32> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
+ let a = Tensor::from([[1, 2, 3], [4, 5, 6]]);
let b = a.unsqueeze(Axis(2));
assert_eq!(b.shape(), &[2, 3, 1]);
assert_eq!(b.stride(), &[3, 1, 1]);
}

#[test]
fn full_i32() {
let a = Tensor::full(3, vec![2, 3]);
assert_eq!(a.shape(), &[2, 3]);
assert_eq!(a.stride(), &[3, 1]);
let b = a.flatten();
let b_len = *b.len();
for i in 0..b_len {
assert_eq!(b[i], 3i32);
}
}

#[test]
fn full_f64() {
let a = Tensor::full(3.2, vec![4, 6, 2]);
assert_eq!(a.shape(), &[4, 6, 2]);
let b = a.flatten();
let b_len = *b.len();
for i in 0..b_len {
assert_eq!(b[i], 3.2f64);
}
}

#[test]
fn full_bool() {
let a: Tensor<bool> = Tensor::full(true, vec![3, 5, 3]);
assert_eq!(a.shape(), &[3, 5, 3]);
assert_eq!(a.stride(), &[15, 3, 1]);
let b = a.flatten();
let b_len = *b.len();
for i in 0..b_len {
assert_eq!(b[i], true);
}
}

#[test]
fn ones_u8() {
let a: Tensor<u8> = Tensor::ones(vec![3, 5, 3]);
assert_eq!(a.shape(), &[3, 5, 3]);
assert_eq!(a.stride(), &[15, 3, 1]);
let b = a.flatten();
let b_len = *b.len();
for i in 0..b_len {
assert_eq!(b[i], 1u8);
}
}

#[test]
fn ones_i32() {
let a: Tensor<i32> = Tensor::ones(vec![3, 5, 3]);
assert_eq!(a.shape(), &[3, 5, 3]);
assert_eq!(a.stride(), &[15, 3, 1]);
let b = a.flatten();
let b_len = *b.len();
for i in 0..b_len {
assert_eq!(b[i], 1i32);
}
}

#[test]
fn ones_1d() {
let a: Tensor<u8> = Tensor::ones(vec![4]);
assert_eq!(a.shape(), &[4]);
let a_len = *a.len();
for i in 0..a_len {
assert_eq!(a[i], 1u8);
}
}

#[test]
fn ones_f64() {
let a: Tensor<f64> = Tensor::ones(vec![4]);
assert_eq!(a.shape(), &[4]);
let a_len = *a.len();
for i in 0..a_len {
assert_eq!(a[i], 1f64);
}
}

#[test]
fn ones_bool() {
let a: Tensor<bool> = Tensor::ones(vec![3, 5, 3]);
assert_eq!(a.shape(), &[3, 5, 3]);
assert_eq!(a.stride(), &[15, 3, 1]);
let b = a.flatten();
let b_len = *b.len();
for i in 0..b_len {
assert_eq!(b[i], true);
}
}

#[test]
fn zeroes_u8() {
let a: Tensor<u8> = Tensor::zeros(vec![3, 5, 3]);
assert_eq!(a.shape(), &[3, 5, 3]);
assert_eq!(a.stride(), &[15, 3, 1]);
let b = a.flatten();
let b_len = *b.len();
for i in 0..b_len {
assert_eq!(b[i], 0u8);
}
}

#[test]
fn zeroes_i32() {
let a: Tensor<i32> = Tensor::zeros(vec![3, 5, 3]);
assert_eq!(a.shape(), &[3, 5, 3]);
assert_eq!(a.stride(), &[15, 3, 1]);
let b = a.flatten();
let b_len = *b.len();
for i in 0..b_len {
assert_eq!(b[i], 0i32);
}
}

#[test]
fn zeroes_1d() {
let a: Tensor<u8> = Tensor::zeros(vec![4]);
assert_eq!(a.shape(), &[4]);
let a_len = *a.len();
for i in 0..a_len {
assert_eq!(a[i], 0u8);
}
}

#[test]
fn zeroes_f64() {
let a: Tensor<f64> = Tensor::zeros(vec![4]);
assert_eq!(a.shape(), &[4]);
let a_len = *a.len();
for i in 0..a_len {
assert_eq!(a[i], 0f64);
}
}

#[test]
fn zeroes_bool() {
let a: Tensor<bool> = Tensor::zeros(vec![3, 5, 3]);
assert_eq!(a.shape(), &[3, 5, 3]);
assert_eq!(a.stride(), &[15, 3, 1]);
let b = a.flatten();
let b_len = *b.len();
for i in 0..b_len {
assert_eq!(b[i], false);
}
}
