Skip to content

Commit

Permalink
Merge pull request #6 from BhavyeMathur/tensor-lifetimes
Browse files Browse the repository at this point in the history
Tensor lifetimes
  • Loading branch information
BhavyeMathur authored Jan 5, 2025
2 parents fddbf3d + c5fbaec commit 792f71d
Show file tree
Hide file tree
Showing 23 changed files with 431 additions and 148 deletions.
8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,18 @@ edition = "2021"
name = "fill_f32"
path = "benches/fill_f32.rs"

[[bin]]
name = "fill_f32_slice"
path = "benches/fill_f32_slice.rs"

[dependencies]
cpu-time = "1.0.0"
bitflags = "2.6.0"

[lints.rust]
private_bounds = "allow"
dead_code = "allow"

[profile.release]
debug = true
strip = false
33 changes: 33 additions & 0 deletions benches/fill_f32_slice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import numpy as np
import torch

from perfprofiler import *


class TensorFill(TimingSuite):
    """Benchmark suite that times filling column 0 of an (n, 2) float32 array.

    The same operation is measured for NumPy, PyTorch (CPU) and the Rust
    `fill_f32_slice` executable.
    """

    # NOTE(review): all three benchmark methods below are named `run`; presumably
    # each decorator registers the implementation under its label before the name
    # is rebound by the next definition — confirm against perfprofiler.

    def __init__(self, n):
        # Number of rows; also passed to the Rust executable as its size argument.
        self.n = n

        # Column 0 of an (n, 2) array, taken as a slice of the full array.
        self.ndarray: np.ndarray = np.zeros((n, 2), dtype="float32")
        self.ndarray_slice = self.ndarray[:, 0]

        # Same layout for the PyTorch tensor.
        self.tensor_cpu = torch.zeros((n, 2), dtype=torch.float32)
        self.tensor_cpu_slice = self.tensor_cpu[:, 0]

    @measure_performance("NumPy")
    def run(self):
        # In-place fill of the NumPy column slice.
        self.ndarray_slice.fill(5.0)

    @measure_performance("PyTorch CPU")
    def run(self):
        # In-place fill of the PyTorch column slice (trailing underscore = in-place).
        self.tensor_cpu_slice.fill_(5.0)

    @measure_rust_performance("Chela CPU", target="fill_f32_slice")
    def run(self, executable):
        # Delegates to the compiled Rust benchmark, passing the row count.
        return self.run_rust(executable, self.n)


if __name__ == "__main__":
    # Powers of two from 2**9 (512) through 2**24 (16,777,216).
    sizes = [2 ** n for n in range(9, 25)]
    # NOTE(review): n=10 is presumably the number of repetitions per size —
    # confirm against perfprofiler's `profile_each`.
    results = TensorFill.profile_each(sizes, n=10)
    plot_results(sizes, results, "tensor.fill() CPU time vs length")
21 changes: 21 additions & 0 deletions benches/fill_f32_slice.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use chela::*;
use std::env;

use cpu_time::ProcessTime;


/// Times a single `fill` over column 0 of a freshly zeroed `[size, 2]` tensor.
///
/// The tensor is allocated and sliced before the clock starts, so only the
/// fill itself is measured. Returns the elapsed time reported by
/// `ProcessTime`, in nanoseconds.
fn profile(size: usize) -> u128 {
    let zeros = Tensor::zeros([size, 2]);
    let mut column = zeros.slice(s![.., 0]);

    let timer = ProcessTime::now();
    column.fill(5_f32);
    let elapsed = timer.elapsed();

    elapsed.as_nanos()
}

/// Entry point of the benchmark binary.
///
/// Reads an optional row count from the first command-line argument
/// (defaulting to 65536 when absent), runs [`profile`] once and prints the
/// measured time in nanoseconds on stdout.
///
/// An argument that is not a valid `usize` is reported on stderr and the
/// process exits with status 2 instead of panicking.
fn main() {
    // Only the first argument is needed, so avoid collecting every argument.
    let size: usize = match env::args().nth(1) {
        None => 65536,
        Some(arg) => match arg.parse::<usize>() {
            Ok(size) => size,
            Err(err) => {
                eprintln!("invalid size argument {arg:?}: {err}");
                std::process::exit(2)
            }
        },
    };

    println!("{}", profile(size));
}
22 changes: 13 additions & 9 deletions src/axis/index.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::axis::indexer::Indexer;
use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
use crate::axis::indexer_impl::IndexerImpl;
use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};

#[derive(Clone)]
pub enum Index {
Expand All @@ -25,18 +25,22 @@ impl Indexer for Index {
Index::RangeToInclusive(index) => index.index_of_first_element(),
}
}

/// Returns `true` only for the `Index::Usize` variant; every range variant
/// yields `false`.
// NOTE(review): presumably a `true` result means the indexed axis is dropped
// from the resulting shape — confirm at the call site.
fn collapse_dimension(&self) -> bool {
    matches!(self, Index::Usize(_))
}
}

impl IndexerImpl for Index {
fn len(&self, axis: usize, shape: &[usize]) -> usize {
fn indexed_length(&self, len: usize) -> usize {
match self {
Index::Usize(index) => IndexerImpl::len(index, axis, shape),
Index::Range(index) => IndexerImpl::len(index, axis, shape),
Index::RangeFrom(index) => IndexerImpl::len(index, axis, shape),
Index::RangeFull(index) => IndexerImpl::len(index, axis, shape),
Index::RangeInclusive(index) => IndexerImpl::len(index, axis, shape),
Index::RangeTo(index) => IndexerImpl::len(index, axis, shape),
Index::RangeToInclusive(index) => IndexerImpl::len(index, axis, shape),
Index::Usize(index) => IndexerImpl::indexed_length(index, len),
Index::Range(index) => IndexerImpl::indexed_length(index, len),
Index::RangeFrom(index) => IndexerImpl::indexed_length(index, len),
Index::RangeFull(index) => IndexerImpl::indexed_length(index, len),
Index::RangeInclusive(index) => IndexerImpl::indexed_length(index, len),
Index::RangeTo(index) => IndexerImpl::indexed_length(index, len),
Index::RangeToInclusive(index) => IndexerImpl::indexed_length(index, len),
}
}
}
27 changes: 8 additions & 19 deletions src/axis/indexer.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,22 @@
use crate::Axis;

use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
use crate::axis::indexer_impl::IndexerImpl;
use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};

pub(crate) trait Indexer: IndexerImpl + Clone {
fn indexed_shape_and_stride(&self, axis: &Axis, shape: &[usize], stride: &[usize]) -> (Vec<usize>, Vec<usize>) {
let mut shape = shape.to_vec();
let mut stride = stride.to_vec();

let axis = axis.0;
let len = self.len(axis, &shape);

if len == 0 {
shape.remove(axis);
stride.remove(axis);
} else {
shape[axis] = len;
}
fn index_of_first_element(&self) -> usize;

(shape, stride)
fn collapse_dimension(&self) -> bool {
false
}

fn index_of_first_element(&self) -> usize;
}

/// Indexing an axis with a bare `usize` selects a single position.
impl Indexer for usize {
    /// Always `true` for a single integer index.
    // NOTE(review): presumably this removes the indexed axis from the
    // resulting view — confirm where `collapse_dimension` is consumed.
    fn collapse_dimension(&self) -> bool {
        true
    }

    /// The first (and only) selected position is the index itself.
    fn index_of_first_element(&self) -> usize {
        *self
    }
}
impl Indexer for Range<usize> {
fn index_of_first_element(&self) -> usize {
Expand Down
24 changes: 12 additions & 12 deletions src/axis/indexer_impl.rs
Original file line number Diff line number Diff line change
@@ -1,46 +1,46 @@
use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};

pub(super) trait IndexerImpl {
fn len(&self, axis: usize, shape: &[usize]) -> usize;
pub(crate) trait IndexerImpl {
fn indexed_length(&self, axis_length: usize) -> usize;
}

impl IndexerImpl for usize {
fn len(&self, _axis: usize, _shape: &[usize]) -> usize {
0
fn indexed_length(&self, _axis_length: usize) -> usize {
1
}
}
impl IndexerImpl for Range<usize> {
fn len(&self, _axis: usize, _shape: &[usize]) -> usize {
fn indexed_length(&self, _axis_length: usize) -> usize {
self.end - self.start
}
}

impl IndexerImpl for RangeFull {
fn len(&self, axis: usize, shape: &[usize]) -> usize {
shape[axis]
fn indexed_length(&self, axis_length: usize) -> usize {
axis_length
}
}

impl IndexerImpl for RangeFrom<usize> {
fn len(&self, axis: usize, shape: &[usize]) -> usize {
shape[axis] - self.start
fn indexed_length(&self, axis_length: usize) -> usize {
axis_length - self.start
}
}

impl IndexerImpl for RangeTo<usize> {
fn len(&self, _axis: usize, _shape: &[usize]) -> usize {
fn indexed_length(&self, _axis_length: usize) -> usize {
self.end
}
}

impl IndexerImpl for RangeInclusive<usize> {
fn len(&self, _axis: usize, _shape: &[usize]) -> usize {
fn indexed_length(&self, _axis_length: usize) -> usize {
self.end() - self.start() + 1
}
}

impl IndexerImpl for RangeToInclusive<usize> {
fn len(&self, _axis: usize, _shape: &[usize]) -> usize {
fn indexed_length(&self, _axis_length: usize) -> usize {
self.end + 1
}
}
11 changes: 7 additions & 4 deletions src/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::marker::PhantomData;
use std::ptr::NonNull;

pub mod constructors;
pub mod dtype;

Expand All @@ -12,19 +15,19 @@ pub mod equals;
mod flags;

use crate::dtype::RawDataType;
use crate::tensor::flags::TensorFlags;

pub use iterator::*;

use crate::tensor::flags::TensorFlags;
use std::ptr::NonNull;

#[derive(Debug)]
pub struct Tensor<T: RawDataType> {
pub struct Tensor<'a, T: RawDataType> {
ptr: NonNull<T>,
len: usize,
capacity: usize,

shape: Vec<usize>,
stride: Vec<usize>,
flags: TensorFlags,

_marker: PhantomData<&'a T>,
}
13 changes: 5 additions & 8 deletions src/tensor/clone.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
use crate::dtype::RawDataType;
use crate::iterator::collapse_contiguous::collapse_contiguous;
use crate::iterator::collapse_contiguous::{collapse_to_uniform_stride};
use crate::iterator::flat_index_generator::FlatIndexGenerator;
use crate::Tensor;
use std::ptr::copy_nonoverlapping;

impl<T: RawDataType> Clone for Tensor<T> {
fn clone(&self) -> Self {
impl<'a, T: RawDataType> Tensor<'a, T> {
pub fn clone<'b>(&'a self) -> Tensor<'b, T> {
unsafe { Tensor::from_contiguous_owned_buffer(self.shape.clone(), self.clone_data()) }
}
}


impl<T: RawDataType> Tensor<T> {
pub(super) fn clone_data(&self) -> Vec<T> {
if self.is_contiguous() {
return unsafe { self.clone_data_contiguous() };
Expand All @@ -33,7 +30,7 @@ impl<T: RawDataType> Tensor<T> {
let size = self.size();
let mut data = Vec::with_capacity(size);

let (mut shape, mut stride) = collapse_contiguous(&self.shape, &self.stride);
let (mut shape, mut stride) = collapse_to_uniform_stride(&self.shape, &self.stride);

// safe to unwrap because if stride has no elements, this would be a scalar tensor
// however, scalar tensors are contiguously stored so this method wouldn't be called
Expand All @@ -57,7 +54,7 @@ impl<T: RawDataType> Tensor<T> {
let mut dst = data.as_mut_ptr();

for i in FlatIndexGenerator::from(&shape, &stride) {
copy_nonoverlapping(src.offset(i), dst, contiguous_stride);
copy_nonoverlapping(src.add(i), dst, contiguous_stride);
dst = dst.add(contiguous_stride);
}

Expand Down
6 changes: 4 additions & 2 deletions src/tensor/constructors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ fn stride_from_shape(shape: &[usize]) -> Vec<usize> {
stride
}

impl<T: RawDataType> Tensor<T> {
impl<T: RawDataType> Tensor<'_, T> {
/// Safety: ensure data is non-empty and shape & stride matches data buffer
pub(super) unsafe fn from_owned_buffer(shape: Vec<usize>, stride: Vec<usize>, data: Vec<T>) -> Self {
// take control of the data so that Rust doesn't drop it once the vector goes out of scope
Expand All @@ -37,6 +37,8 @@ impl<T: RawDataType> Tensor<T> {
shape,
stride,
flags: TensorFlags::Owned | TensorFlags::Contiguous,

_marker: Default::default(),
}
}

Expand Down Expand Up @@ -91,7 +93,7 @@ impl<T: RawDataType> Tensor<T> {
}
}

impl<T: RawDataType> Drop for Tensor<T> {
impl<T: RawDataType> Drop for Tensor<'_, T> {
fn drop(&mut self) {
if self.flags.contains(TensorFlags::Owned) {
// drops the data
Expand Down
2 changes: 1 addition & 1 deletion src/tensor/equals.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::dtype::RawDataType;
use crate::Tensor;

impl<T1, T2> PartialEq<Tensor<T1>> for Tensor<T2>
impl<T1, T2> PartialEq<Tensor<'_, T1>> for Tensor<'_, T2>
where
T1: RawDataType,
T2: RawDataType + From<T1>,
Expand Down
32 changes: 19 additions & 13 deletions src/tensor/fill.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
use crate::dtype::RawDataType;
use crate::iterator::collapse_contiguous::collapse_to_uniform_stride;
use crate::Tensor;

impl<T: RawDataType> Tensor<T> {
/// Safety: expects tensor buffer is contiguously stored
unsafe fn fill_contiguous(&mut self, value: T) {
let mut ptr = self.ptr.as_ptr();
let end_ptr = ptr.add(self.len);
/// Writes `value` into `n` elements spaced `stride` elements apart, starting
/// at `start`.
///
/// Existing contents are overwritten without being read or dropped.
///
/// # Safety
///
/// For every `i` in `0..n`, `start.add(i * stride)` must be properly aligned,
/// valid for writes of `T`, and inside the same allocation as `start`.
unsafe fn fill_strided<T: Copy>(start: *mut T, value: T, stride: usize, n: usize) {
    for i in 0..n {
        // Offset from the base pointer on every iteration instead of advancing
        // a cursor. Stepping a cursor by `stride` after the final write can move
        // it more than one element past the end of the allocation (for example
        // when the view starts at a non-zero offset), which `pointer::add`
        // forbids even if the resulting pointer is never dereferenced.
        std::ptr::write(start.add(i * stride), value);
    }
}

while ptr != end_ptr {
std::ptr::write(ptr, value);
ptr = ptr.add(1);
}
unsafe fn fill_shape_and_stride<T: Copy>(mut start: *mut T, value: T, shape: &[usize], stride: &[usize]) {
if shape.len() == 1 {
return fill_strided(start, value, stride[0], shape[0]);
}

fn fill_non_contiguous(&mut self, value: T) {
todo!()
for _ in 0..shape[0] {
fill_shape_and_stride(start, value, &shape[1..], &stride[1..]);
start = start.add(stride[0]);
}
}

impl<T: RawDataType> Tensor<'_, T> {
pub fn fill(&mut self, value: T) {
if self.is_contiguous() {
return unsafe { self.fill_contiguous(value) };
return unsafe { fill_strided(self.ptr.as_ptr(), value, 1, self.len); };
}
self.fill_non_contiguous(value)

let (shape, stride) = collapse_to_uniform_stride(&self.shape, &self.stride);
unsafe { fill_shape_and_stride(self.ptr.as_ptr(), value, &shape, &stride); }
}
}
4 changes: 2 additions & 2 deletions src/tensor/index_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::dtype::RawDataType;
use crate::Tensor;
use std::ops::Index;

impl<T: RawDataType, const D: usize> Index<[usize; D]> for Tensor<T> {
impl<T: RawDataType, const D: usize> Index<[usize; D]> for Tensor<'_, T> {
type Output = T;

fn index(&self, index: [usize; D]) -> &Self::Output {
Expand All @@ -17,7 +17,7 @@ impl<T: RawDataType, const D: usize> Index<[usize; D]> for Tensor<T> {
}
}

impl<T: RawDataType> Index<usize> for Tensor<T> {
impl<T: RawDataType> Index<usize> for Tensor<'_, T> {
type Output = T;

fn index(&self, index: usize) -> &Self::Output {
Expand Down
Loading

0 comments on commit 792f71d

Please sign in to comment.