Skip to content

Commit

Permalink
Accelerate BLAS binding for Tensor::fill float types
Browse files Browse the repository at this point in the history
  • Loading branch information
BhavyeMathur committed Dec 30, 2024
1 parent bd52256 commit ab5c267
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/accelerate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub(super) mod cblas;
9 changes: 9 additions & 0 deletions src/accelerate/cblas.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use std::ffi::{c_double, c_float, c_int};

#[cfg(target_vendor = "apple")]
#[link(name = "cblas")]
extern {
pub(crate) fn catlas_sset(N: c_int, alpha: c_float, X: *const c_float, incX: c_int);

pub(crate) fn catlas_dset(N: c_int, alpha: c_double, X: *const c_double, incX: c_int);
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod axis;
pub mod tensor;
mod traits;
mod accelerate;

pub use axis::*;
pub use tensor::*;
1 change: 1 addition & 0 deletions src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub mod index_impl;
pub mod shape;
pub mod slice;
pub mod iterator;
pub mod fill;
pub mod flatten;
pub mod clone;
pub mod squeeze;
Expand Down
25 changes: 25 additions & 0 deletions src/tensor/data_buffer/fill.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use crate::accelerate::cblas::{catlas_dset, catlas_sset};
use crate::data_buffer::{DataBuffer, DataOwned};
use crate::dtype::RawDataType;
use std::ffi::c_int;

pub(in crate::tensor) trait Fill<T>
where
T: RawDataType,
{
fn fill(&self, value: T);
}

impl Fill<f32> for DataOwned<f32> {
#[cfg(target_vendor = "apple")]
fn fill(&self, value: f32) {
unsafe { catlas_sset(self.len as c_int, value, self.const_ptr(), 1) }
}
}

impl Fill<f64> for DataOwned<f64> {
#[cfg(target_vendor = "apple")]
fn fill(&self, value: f64) {
unsafe { catlas_dset(self.len as c_int, value, self.const_ptr(), 1) }
}
}
1 change: 1 addition & 0 deletions src/tensor/data_buffer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub(super) mod clone;
pub(super) mod data_owned;
pub(super) mod data_view;
pub(super) mod buffer;
pub(super) mod fill;

pub(super) use crate::data_buffer::buffer::DataBuffer;
pub(super) use crate::data_buffer::data_owned::DataOwned;
Expand Down
37 changes: 37 additions & 0 deletions src/tensor/fill.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use crate::data_buffer::fill::Fill;
use crate::data_buffer::DataBuffer;
use crate::dtype::RawDataType;
use crate::TensorBase;

impl<B, T> TensorBase<B>
where
B: DataBuffer<DType=T> + Fill<T>,
T: RawDataType,
{
pub fn fill(&self, value: T) {
self.data.fill(value)
}
}

#[cfg(test)]
mod tests {
use crate::{FlatIterator, Tensor};

#[test]
fn test_fill_f32() {
let a: Tensor<f32> = Tensor::zeros([3, 5, 3]);

assert!(a.flat_iter().all(|x| x == 0.0));
a.fill(25.0);
assert!(a.flat_iter().all(|x| x == 25.0));
}

#[test]
fn test_fill_f64() {
let a: Tensor<f64> = Tensor::zeros([15]);

assert!(a.flat_iter().all(|x| x == 0.0));
a.fill(20.0);
assert!(a.flat_iter().all(|x| x == 20.0));
}
}

0 comments on commit ab5c267

Please sign in to comment.