Skip to content

Commit

Permalink
Refactor stack functions to use a mutable list for improved performance
Browse files Browse the repository at this point in the history
  • Loading branch information
ulises-jeremias committed Mar 24, 2024
1 parent e39672d commit cab1954
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions src/stack.v
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ pub fn dstack[T](ts []&Tensor[T]) !&Tensor[T] {
return error('dstack was given arrays with more than two dimensions')
}
if first_tensor.rank() == 1 {
next_ts := ts.map(it.reshape[T]([1, it.size, 1])!)
mut next_ts := []&Tensor[T]{cap: ts.len}
for t in ts {
next_ts << t.reshape[T]([1, t.size, 1])!
}
return concatenate[T](next_ts, axis: 2)
} else {
mut next_ts := []&Tensor[T]{cap: ts.len}
Expand All @@ -49,7 +52,10 @@ pub fn column_stack[T](ts []&Tensor[T]) !&Tensor[T] {
}

if first_tensor.rank() == 1 {
next_ts := ts.map(it.reshape[T]([it.size, 1])!)
mut next_ts := []&Tensor[T]{cap: ts.len}
for t in ts {
next_ts << t.reshape[T]([t.size, 1])!
}
return concatenate[T](next_ts, axis: 1)
}

Expand All @@ -59,7 +65,10 @@ pub fn column_stack[T](ts []&Tensor[T]) !&Tensor[T] {
// stack join a sequence of arrays along a new axis.
pub fn stack[T](ts []&Tensor[T], data AxisData) !&Tensor[T] {
assert_shape[T](ts[0].shape, ts)!
expanded := ts.map(it.expand_dims[T](data)!)
mut expanded := []&Tensor[T]{cap: ts.len}
for t in ts {
expanded << t.expand_dims[T](data)!
}
return concatenate[T](expanded, data)
}

Expand Down

0 comments on commit cab1954

Please sign in to comment.