diff --git a/src/accelerate.rs b/src/accelerate.rs new file mode 100644 index 0000000..54c4687 --- /dev/null +++ b/src/accelerate.rs @@ -0,0 +1 @@ +pub(super) mod cblas; diff --git a/src/accelerate/cblas.rs b/src/accelerate/cblas.rs new file mode 100644 index 0000000..adb7a0f --- /dev/null +++ b/src/accelerate/cblas.rs @@ -0,0 +1,9 @@ +use std::ffi::{c_double, c_float, c_int}; + +#[cfg(target_vendor = "apple")] +#[link(name = "cblas")] +extern { + pub(crate) fn catlas_sset(N: c_int, alpha: c_float, X: *const c_float, incX: c_int); + + pub(crate) fn catlas_dset(N: c_int, alpha: c_double, X: *const c_double, incX: c_int); +} diff --git a/src/lib.rs b/src/lib.rs index f5a9573..ed9502b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ pub mod axis; pub mod tensor; mod traits; +mod accelerate; pub use axis::*; pub use tensor::*; diff --git a/src/tensor.rs b/src/tensor.rs index a6dc351..c926b2b 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -6,6 +6,7 @@ pub mod index_impl; pub mod shape; pub mod slice; pub mod iterator; +pub mod fill; pub mod flatten; pub mod clone; pub mod squeeze; diff --git a/src/tensor/data_buffer/fill.rs b/src/tensor/data_buffer/fill.rs new file mode 100644 index 0000000..b2ab832 --- /dev/null +++ b/src/tensor/data_buffer/fill.rs @@ -0,0 +1,25 @@ +use crate::accelerate::cblas::{catlas_dset, catlas_sset}; +use crate::data_buffer::{DataBuffer, DataOwned}; +use crate::dtype::RawDataType; +use std::ffi::c_int; + +pub(in crate::tensor) trait Fill +where + T: RawDataType, +{ + fn fill(&self, value: T); +} + +impl Fill for DataOwned { + #[cfg(target_vendor = "apple")] + fn fill(&self, value: f32) { + unsafe { catlas_sset(self.len as c_int, value, self.const_ptr(), 1) } + } +} + +impl Fill for DataOwned { + #[cfg(target_vendor = "apple")] + fn fill(&self, value: f64) { + unsafe { catlas_dset(self.len as c_int, value, self.const_ptr(), 1) } + } +} diff --git a/src/tensor/data_buffer/mod.rs b/src/tensor/data_buffer/mod.rs index 2fae4e0..ad8e23c 100644 --- a/src/tensor/data_buffer/mod.rs +++ b/src/tensor/data_buffer/mod.rs @@ -2,6 +2,7 @@ pub(super) mod clone; pub(super) mod data_owned; pub(super) mod data_view; pub(super) mod buffer; +pub(super) mod fill; pub(super) use crate::data_buffer::buffer::DataBuffer; pub(super) use crate::data_buffer::data_owned::DataOwned; diff --git a/src/tensor/fill.rs b/src/tensor/fill.rs new file mode 100644 index 0000000..6e5f7c5 --- /dev/null +++ b/src/tensor/fill.rs @@ -0,0 +1,37 @@ +use crate::data_buffer::fill::Fill; +use crate::data_buffer::DataBuffer; +use crate::dtype::RawDataType; +use crate::TensorBase; + +impl TensorBase +where + B: DataBuffer + Fill, + T: RawDataType, +{ + pub fn fill(&self, value: T) { + self.data.fill(value) + } +} + +#[cfg(test)] +mod tests { + use crate::{FlatIterator, Tensor}; + + #[test] + fn test_fill_f32() { + let a: Tensor = Tensor::zeros([3, 5, 3]); + + assert!(a.flat_iter().all(|x| x == 0.0)); + a.fill(25.0); + assert!(a.flat_iter().all(|x| x == 25.0)); + } + + #[test] + fn test_fill_f64() { + let a: Tensor = Tensor::zeros([15]); + + assert!(a.flat_iter().all(|x| x == 0.0)); + a.fill(20.0); + assert!(a.flat_iter().all(|x| x == 20.0)); + } +}