diff --git a/src/tensor/data_buffer/buffer.rs b/src/tensor/data_buffer/buffer.rs new file mode 100644 index 0000000..9d856ac --- /dev/null +++ b/src/tensor/data_buffer/buffer.rs @@ -0,0 +1,63 @@ +use crate::data_buffer::{DataOwned, DataView}; +use crate::dtype::RawDataType; +use std::ops::Index; +use std::ptr::NonNull; + +pub trait DataBuffer: Index { + type DType: RawDataType; + + fn len(&self) -> usize; + + fn ptr(&self) -> NonNull; + + fn const_ptr(&self) -> *const Self::DType; + + fn to_view(&self) -> DataView; + + // fn clone(&self) -> DataOwned; +} + +// Two kinds of data buffers +// DataOwned: owns its data & responsible for cleaning it up +// DataView: reference to data owned by another buffer + +impl DataBuffer for DataOwned { + type DType = T; + + fn len(&self) -> usize { + self.len + } + + fn ptr(&self) -> NonNull { + self.ptr + } + + fn const_ptr(&self) -> *const T { + self.ptr.as_ptr() + } + + fn to_view(&self) -> DataView { + let ptr = self.ptr; + let len = self.len; + DataView { ptr, len } + } +} +impl DataBuffer for DataView { + type DType = T; + + fn len(&self) -> usize { + self.len + } + + fn ptr(&self) -> NonNull { + self.ptr + } + + fn const_ptr(&self) -> *const T { + self.ptr.as_ptr() + } + + fn to_view(&self) -> DataView { + (*self).clone() + } +} diff --git a/src/tensor/data_buffer/mod.rs b/src/tensor/data_buffer/mod.rs index ec4dce1..35fbe80 100644 --- a/src/tensor/data_buffer/mod.rs +++ b/src/tensor/data_buffer/mod.rs @@ -1,70 +1,56 @@ pub(super) mod clone; pub(super) mod data_owned; pub(super) mod data_view; +pub(super) mod buffer; +pub(super) use crate::data_buffer::buffer::DataBuffer; pub(super) use crate::data_buffer::data_owned::DataOwned; pub(super) use crate::data_buffer::data_view::DataView; use crate::tensor::dtype::RawDataType; use std::ops::Index; -use std::ptr::NonNull; -pub trait DataBuffer: Index { - type DType: RawDataType; +#[cfg(test)] +mod tests { + use crate::data_buffer::DataOwned; - fn len(&self) -> usize; + #[test] + fn from_vector() { + let arr = DataOwned::from(vec![0, 50, 100]); + assert_eq!(arr.len(), &3); - fn ptr(&self) -> NonNull; + let arr = DataOwned::from(vec![vec![50], vec![50], vec![50]]); + assert_eq!(arr.len(), &3); - fn const_ptr(&self) -> *const Self::DType; + let arr = DataOwned::from(vec![vec![vec![50]], vec![vec![50]]]); + assert_eq!(arr.len(), &2); - fn to_view(&self) -> DataView; - - // fn clone(&self) -> DataOwned; -} - -// Two kinds of data buffers -// DataOwned: owns its data & responsible for cleaning it up -// DataView: reference to data owned by another buffer - -impl DataBuffer for DataOwned { - type DType = T; - - fn len(&self) -> usize { - self.len - } - - fn ptr(&self) -> NonNull { - self.ptr + let arr = DataOwned::from(vec![vec![vec![50, 50, 50]], vec![vec![50, 50, 50]]]); + assert_eq!(arr.len(), &6); } - fn const_ptr(&self) -> *const T { - self.ptr.as_ptr() - } + #[test] + fn from_array() { + let arr = DataOwned::from([500, 50, 100]); + assert_eq!(arr.len(), &3); - fn to_view(&self) -> DataView { - let ptr = self.ptr; - let len = self.len; - DataView { ptr, len } - } -} -impl DataBuffer for DataView { - type DType = T; + let arr = DataOwned::from([[500], [50], [100]]); + assert_eq!(arr.len(), &3); - fn len(&self) -> usize { - self.len - } + let arr = DataOwned::from([[[500], [50], [30]], [[50], [0], [0]]]); + assert_eq!(arr.len(), &6); - fn ptr(&self) -> NonNull { - self.ptr + let arr = DataOwned::from([[[50, 50]], [[50, 50]]]); + assert_eq!(arr.len(), &4); } - fn const_ptr(&self) -> *const T { - self.ptr.as_ptr() - } + #[test] + fn from_inhomogeneous_vector() { + let arr = DataOwned::from(vec![vec![50, 50], vec![50]]); + assert_eq!(arr.len(), &3); - fn to_view(&self) -> DataView { - (*self).clone() + let arr = DataOwned::from(vec![vec![vec![50, 50]], vec![vec![50]], vec![vec![50]]]); + assert_eq!(arr.len(), &4); } } diff --git a/tests/data_owned.rs b/tests/data_owned.rs deleted file mode 100644 index 84764ed..0000000 --- a/tests/data_owned.rs +++ /dev/null @@ -1,40 +0,0 @@ -use chela::tensor::data_owned::*; - -#[test] -fn from_vector() { - let arr = DataOwned::from(vec![0, 50, 100]); - assert_eq!(arr.len(), &3); - - let arr = DataOwned::from(vec![vec![50], vec![50], vec![50]]); - assert_eq!(arr.len(), &3); - - let arr = DataOwned::from(vec![vec![vec![50]], vec![vec![50]]]); - assert_eq!(arr.len(), &2); - - let arr = DataOwned::from(vec![vec![vec![50, 50, 50]], vec![vec![50, 50, 50]]]); - assert_eq!(arr.len(), &6); -} - -#[test] -fn from_array() { - let arr = DataOwned::from([500, 50, 100]); - assert_eq!(arr.len(), &3); - - let arr = DataOwned::from([[500], [50], [100]]); - assert_eq!(arr.len(), &3); - - let arr = DataOwned::from([[[500], [50], [30]], [[50], [0], [0]]]); - assert_eq!(arr.len(), &6); - - let arr = DataOwned::from([[[50, 50]], [[50, 50]]]); - assert_eq!(arr.len(), &4); -} - -#[test] -fn from_inhomogeneous_vector() { - let arr = DataOwned::from(vec![vec![50, 50], vec![50]]); - assert_eq!(arr.len(), &3); - - let arr = DataOwned::from(vec![vec![vec![50, 50]], vec![vec![50]], vec![vec![50]]]); - assert_eq!(arr.len(), &4); -}