diff --git a/src/operators/tensor/math/gather_elements.cairo b/src/operators/tensor/math/gather_elements.cairo index cc8b9ae20..e4b624e42 100644 --- a/src/operators/tensor/math/gather_elements.cairo +++ b/src/operators/tensor/math/gather_elements.cairo @@ -1,7 +1,9 @@ +use core::option::OptionTrait; +use core::traits::TryInto; use alexandria_data_structures::array_ext::SpanTraitExt; use orion::numbers::NumberTrait; -use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{{TensorTrait, Tensor}, core::{unravel_index, stride}}; /// Cf: TensorTrait::gather_elements docstring fn gather_elements, impl TCopy: Copy, impl TDrop: Drop,>( @@ -19,71 +21,45 @@ fn gather_elements, impl TCopy: Copy, im }; assert(axis < (*self.shape).len(), 'axis out of dimensions'); - let axis_shape = *(*self.shape).at(axis); - - // Adjust indices that are negative - let mut adjusted_indices = array![]; - let mut indices_data = indices.data.clone(); - loop { - match indices_data.pop_front() { - Option::Some(index) => { - let adjusted_index: usize = if *index < 0 { - let val: u32 = (axis_shape.try_into().unwrap() + *index).try_into().unwrap(); - val - } else { - let val: u32 = (*index).try_into().unwrap(); - val - }; - assert(adjusted_index >= 0 && adjusted_index < axis_shape, 'Index out of bounds'); - adjusted_indices.append(adjusted_index); - }, - Option::None => { break; } - }; - }; + let data_strides = stride(*self.shape); let mut output_data = array![]; - let mut data_shape_clone = (*self.shape).clone(); - let mut multiplier = 1; - let mut looper = 1; - let mut ind = 0; - loop { - match data_shape_clone.pop_front() { - Option::Some(val) => { - if ind >= axis { - multiplier *= *val; - } - if ind > axis { - looper *= *val; - } - ind += 1; - }, - Option::None => { break; } - }; - }; + let mut i: usize = 0; + while i < indices + .data + .len() { + let indice = *indices.data.at(i); + let adjusted_indice: u32 = if indice < 0 { + ((*(*self.shape).at(axis)).try_into().unwrap() + indice).try_into().unwrap() + } else { + indice.try_into().unwrap() + }; - let inner_loop = multiplier / axis_shape; - let mut adjusted_indices_iter = adjusted_indices.clone(); + assert(adjusted_indice < (*(*self.shape).at(axis)), 'Index out of bounds'); - let mut i: usize = 0; - loop { - match adjusted_indices_iter.pop_front() { - Option::Some(indice) => { - let value = if axis == 0 { - indice * inner_loop + (i % inner_loop) - } else if axis == (*self.shape).len() - 1 { - indice + axis_shape * (i / axis_shape) - } else { - indice * looper - + (i % looper) - + (multiplier / axis_shape) * (i / (multiplier / axis_shape)) + let multidim_index = unravel_index(i, indices.shape); + let mut flat_index_for_data = 0; + + let mut j: usize = 0; + while j < multidim_index + .len() { + let dim_index = *multidim_index.at(j); + if j == axis { + flat_index_for_data += adjusted_indice * (*data_strides.at(j)); + } else { + flat_index_for_data += (dim_index * *data_strides.at(j)) + } + j += 1; }; - output_data.append(*self.data[value]); - i += 1; - }, - Option::None => { break; } + assert( + flat_index_for_data < (*self.data).len().try_into().unwrap(), + 'Flat index out of bounds' + ); + + output_data.append(*(*self.data).at(flat_index_for_data)); + i += 1; }; - }; TensorTrait::::new(indices.shape, output_data.span()) }