Skip to content

Commit

Permalink
fix gather_elements
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Mar 27, 2024
1 parent 9c62277 commit 57871a4
Showing 1 changed file with 35 additions and 59 deletions.
94 changes: 35 additions & 59 deletions src/operators/tensor/math/gather_elements.cairo
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use core::option::OptionTrait;
use core::traits::TryInto;
use alexandria_data_structures::array_ext::SpanTraitExt;

use orion::numbers::NumberTrait;
use orion::operators::tensor::{TensorTrait, Tensor};
use orion::operators::tensor::{{TensorTrait, Tensor}, core::{unravel_index, stride}};

/// Cf: TensorTrait::gather_elements docstring
fn gather_elements<T, impl TTensorTrait: TensorTrait<T>, impl TCopy: Copy<T>, impl TDrop: Drop<T>,>(
Expand All @@ -19,71 +21,45 @@ fn gather_elements<T, impl TTensorTrait: TensorTrait<T>, impl TCopy: Copy<T>, im
};
assert(axis < (*self.shape).len(), 'axis out of dimensions');

let axis_shape = *(*self.shape).at(axis);

// Adjust indices that are negative
let mut adjusted_indices = array![];
let mut indices_data = indices.data.clone();
loop {
match indices_data.pop_front() {
Option::Some(index) => {
let adjusted_index: usize = if *index < 0 {
let val: u32 = (axis_shape.try_into().unwrap() + *index).try_into().unwrap();
val
} else {
let val: u32 = (*index).try_into().unwrap();
val
};
assert(adjusted_index >= 0 && adjusted_index < axis_shape, 'Index out of bounds');
adjusted_indices.append(adjusted_index);
},
Option::None => { break; }
};
};
let data_strides = stride(*self.shape);

let mut output_data = array![];
let mut data_shape_clone = (*self.shape).clone();
let mut multiplier = 1;
let mut looper = 1;
let mut ind = 0;
loop {
match data_shape_clone.pop_front() {
Option::Some(val) => {
if ind >= axis {
multiplier *= *val;
}
if ind > axis {
looper *= *val;
}
ind += 1;
},
Option::None => { break; }
};
};
let mut i: usize = 0;
while i < indices
.data
.len() {
let indice = *indices.data.at(i);
let adjusted_indice: u32 = if indice < 0 {
((*(*self.shape).at(axis)).try_into().unwrap() + indice).try_into().unwrap()
} else {
indice.try_into().unwrap()
};

let inner_loop = multiplier / axis_shape;
let mut adjusted_indices_iter = adjusted_indices.clone();
assert(adjusted_indice < (*(*self.shape).at(axis)), 'Index out of bounds');

let mut i: usize = 0;
loop {
match adjusted_indices_iter.pop_front() {
Option::Some(indice) => {
let value = if axis == 0 {
indice * inner_loop + (i % inner_loop)
} else if axis == (*self.shape).len() - 1 {
indice + axis_shape * (i / axis_shape)
} else {
indice * looper
+ (i % looper)
+ (multiplier / axis_shape) * (i / (multiplier / axis_shape))
let multidim_index = unravel_index(i, indices.shape);
let mut flat_index_for_data = 0;

let mut j: usize = 0;
while j < multidim_index
.len() {
let dim_index = *multidim_index.at(j);
if j == axis {
flat_index_for_data += adjusted_indice * (*data_strides.at(j));
} else {
flat_index_for_data += (dim_index * *data_strides.at(j))
}
j += 1;
};

output_data.append(*self.data[value]);
i += 1;
},
Option::None => { break; }
assert(
flat_index_for_data < (*self.data).len().try_into().unwrap(),
'Flat index out of bounds'
);

output_data.append(*(*self.data).at(flat_index_for_data));
i += 1;
};
};

TensorTrait::<T>::new(indices.shape, output_data.span())
}

0 comments on commit 57871a4

Please sign in to comment.