Skip to content

Commit

Permalink
tag manually checked hot functions with @[direct_array_access]
Browse files Browse the repository at this point in the history
  • Loading branch information
spytheman committed Dec 17, 2024
1 parent b758d5f commit b977bab
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 1 deletion.
4 changes: 4 additions & 0 deletions datasets/mnist_test.v
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
import vtl.datasets

// test_mnist checks that the MNIST dataset loads successfully and that all
// four tensors come back with the expected shapes: 60k training and 10k test
// images of 28x28 pixels, plus one label per image.
fn test_mnist() {
	// make stdout unbuffered so the progress lines below show up promptly
	// (useful for spotting where a CI run hangs during the download/load)
	unbuffer_stdout()
	println('start')
	ds := datasets.load_mnist()!
	println('mnist dataset loaded')

	assert ds.train_features.shape == [60000, 28, 28]
	assert ds.test_features.shape == [10000, 28, 28]
	assert ds.train_labels.shape == [60000]
	assert ds.test_labels.shape == [10000]
	println('done')
}
1 change: 1 addition & 0 deletions src/creation.v
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ pub fn from_1d[T](arr []T, params TensorData) !&Tensor[T] {

// from_2d takes a two dimensional array of floating point values
// and returns a two-dimensional Tensor if possible
@[direct_array_access]
pub fn from_2d[T](a [][]T, params TensorData) !&Tensor[T] {
mut arr := []T{cap: a.len * a[0].len}
for i in 0 .. a.len {
Expand Down
3 changes: 3 additions & 0 deletions src/fun.v
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ pub fn (t &Tensor[T]) as_strided[T](shape []int, strides []int) !&Tensor[T] {

// transpose permutes the axes of an tensor in a specified
// order and returns a view of the data
@[direct_array_access]
pub fn (t &Tensor[T]) transpose[T](order []int) !&Tensor[T] {
mut ret := t.view()
n := order.len
Expand Down Expand Up @@ -187,6 +188,7 @@ fn fabs(x f64) f64 {
}

// slice returns a tensor from a variadic list of indexing operations
@[direct_array_access]
pub fn (t &Tensor[T]) slice[T](idx ...[]int) !&Tensor[T] {
mut newshape := t.shape.clone()
mut newstrides := t.strides.clone()
Expand Down Expand Up @@ -266,6 +268,7 @@ pub fn (t &Tensor[T]) slice[T](idx ...[]int) !&Tensor[T] {

// slice_hilo returns a view of an array from a list of starting
// indices and a list of closing indices.
@[direct_array_access]
pub fn (t &Tensor[T]) slice_hilo[T](idx1 []int, idx2 []int) !&Tensor[T] {
mut newshape := t.shape.clone()
mut newstrides := t.strides.clone()
Expand Down
2 changes: 1 addition & 1 deletion src/fun_logical.v
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ pub fn (t &Tensor[T]) array_equiv[T](other &Tensor[T]) bool {
return true
}

@[inline]
@[direct_array_access; inline]
fn handle_equal[T](vals []T, _ []int) bool {
mut equal := true
for v in vals {
Expand Down
1 change: 1 addition & 0 deletions src/iter.v
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ fn handle_flatten_iteration[T](mut s TensorIterator[T]) T {
return val
}

@[direct_array_access]
fn tensor_backstrides[T](t &Tensor[T]) []int {
rank := t.rank()
shape := t.shape
Expand Down
5 changes: 5 additions & 0 deletions src/split.v
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ module vtl
// integer that does not equally divide the axis. For an array of length
// l that should be split into n sections, it returns l % n sub-arrays of
// size l//n + 1 and the rest of size l//n.
@[direct_array_access]
pub fn (t &Tensor[T]) array_split[T](ind int, axis int) ![]&Tensor[T] {
ntotal := t.shape[axis]
neach := ntotal / ind
Expand Down Expand Up @@ -125,7 +126,11 @@ pub fn (t &Tensor[T]) dsplit_expl[T](ind []int) ![]&Tensor[T] {

// splitter implements a generic splitting function that contains the underlying functionality
// for all split operations
@[direct_array_access]
fn (t &Tensor[T]) splitter[T](axis int, n int, div_points []int) ![]&Tensor[T] {
if n > 0 && div_points.len <= n {
return error('splitter error, div_points.len <= n')
}
mut subary := []&Tensor[T]{}
sary := t.swapaxes(axis, 0)!
for i in 0 .. n {
Expand Down
5 changes: 5 additions & 0 deletions src/util.v
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ pub fn (mut t Tensor[T]) ensure_memory[T]() {

// assert_shape_off_axis ensures that the shapes of Tensors match
// for concatenation, except along the axis being joined
@[direct_array_access]
fn assert_shape_off_axis[T](ts []&Tensor[T], axis int, shape []int) ![]int {
mut retshape := shape.clone()
for t in ts {
Expand Down Expand Up @@ -91,6 +92,7 @@ fn assert_shape[T](shape []int, ts []&Tensor[T]) ! {

// is_col_major_contiguous checks if an array is contiguous with a col-major
// memory layout
@[direct_array_access]
fn is_col_major_contiguous(shape []int, strides []int, ndims int) bool {
if ndims == 0 {
return true
Expand All @@ -114,6 +116,7 @@ fn is_col_major_contiguous(shape []int, strides []int, ndims int) bool {

// is_row_major_contiguous checks if an array is contiguous with a row-major
// memory layout
@[direct_array_access]
fn is_row_major_contiguous(shape []int, strides []int, ndims int) bool {
if ndims == 0 {
return true
Expand Down Expand Up @@ -150,6 +153,7 @@ fn clip_axis(axis int, size int) !int {
}

// strides_from_shape returns the strides from a shape and memory format
@[direct_array_access]
fn strides_from_shape(shape []int, memory MemoryFormat) []int {
mut accum := 1
mut result := []int{len: shape.len}
Expand Down Expand Up @@ -207,6 +211,7 @@ fn shape_with_autosize(shape []int, size int) !([]int, int) {

// filter_shape_not_strides removes 0 size dimensions from the shape
// and strides of an array
@[direct_array_access]
fn filter_shape_not_strides(shape []int, strides []int) !([]int, []int) {
mut newshape := []int{}
mut newstrides := []int{}
Expand Down

0 comments on commit b977bab

Please sign in to comment.