Skip to content

Commit

Permalink
Merge pull request #620 from gizatechxyz/less-equal
Browse files Browse the repository at this point in the history
Refactor LessEqual
  • Loading branch information
raphaelDkhn authored Mar 25, 2024
2 parents dc86183 + 0db954f commit d816a66
Show file tree
Hide file tree
Showing 56 changed files with 1,210 additions and 1,189 deletions.
8 changes: 4 additions & 4 deletions docs/framework/operators/tensor/tensor.less_equal.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#tensor.less_equal

```rust
fn less_equal(self: @Tensor<T>, other: @Tensor<T>) -> Tensor<usize>;
fn less_equal(self: @Tensor<T>, other: @Tensor<T>) -> Tensor<i32>;
```

Check if each element of the first tensor is less than or equal to the corresponding element of the second tensor.
Expand All @@ -20,7 +20,7 @@ The input tensors must have either:

## Returns

A new `Tensor<usize>` of booleans (0 or 1) with the same shape as the broadcasted inputs.
A new `Tensor<i32>` of booleans (0 or 1) with the same shape as the broadcasted inputs.

## Examples

Expand All @@ -31,7 +31,7 @@ use core::array::{ArrayTrait, SpanTrait};

use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor};

fn less_equal_example() -> Tensor<usize> {
fn less_equal_example() -> Tensor<i32> {
let tensor_1 = TensorTrait::<u32>::new(
shape: array![3, 3, 3].span(), data: array![0, 1, 2, 3, 4, 5, 6, 7, 8].span(),
);
Expand All @@ -53,7 +53,7 @@ use core::array::{ArrayTrait, SpanTrait};

use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor};

fn less_equal_example() -> Tensor<usize> {
fn less_equal_example() -> Tensor<i32> {
let tensor_1 = TensorTrait::<u32>::new(
shape: array![3, 3, 3].span(), data: array![0, 1, 2, 3, 4, 5, 6, 7, 8].span(),
);
Expand Down
20 changes: 10 additions & 10 deletions nodegen/node/less_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def default():

x = Tensor(Dtype.U32, x.shape, x.flatten())
y = Tensor(Dtype.U32, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.I32, z.shape, z.flatten())

name = "less_equal_u32"
make_test([x, y], z, "input_0.less_equal(@input_1)", name)
Expand All @@ -25,7 +25,7 @@ def broadcast():

x = Tensor(Dtype.U32, x.shape, x.flatten())
y = Tensor(Dtype.U32, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.I32, z.shape, z.flatten())

name = "less_equal_u32_broadcast"
make_test([x, y], z, "input_0.less_equal(@input_1)", name)
Expand All @@ -42,7 +42,7 @@ def default():

x = Tensor(Dtype.I32, x.shape, x.flatten())
y = Tensor(Dtype.I32, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.I32, z.shape, z.flatten())

name = "less_equal_i32"
make_test([x, y], z, "input_0.less_equal(@input_1)", name)
Expand All @@ -54,7 +54,7 @@ def broadcast():

x = Tensor(Dtype.I32, x.shape, x.flatten())
y = Tensor(Dtype.I32, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.I32, z.shape, z.flatten())

name = "less_equal_i32_broadcast"
make_test([x, y], z, "input_0.less_equal(@input_1)", name)
Expand All @@ -71,7 +71,7 @@ def default():

x = Tensor(Dtype.I8, x.shape, x.flatten())
y = Tensor(Dtype.I8, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.I32, z.shape, z.flatten())

name = "less_equal_i8"
make_test([x, y], z, "input_0.less_equal(@input_1)", name)
Expand All @@ -83,7 +83,7 @@ def broadcast():

x = Tensor(Dtype.I8, x.shape, x.flatten())
y = Tensor(Dtype.I8, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.I32, z.shape, z.flatten())

name = "less_equal_i8_broadcast"
make_test([x, y], z, "input_0.less_equal(@input_1)", name)
Expand All @@ -102,7 +102,7 @@ def default():
x.flatten(), FixedImpl.FP8x23))
y = Tensor(Dtype.FP8x23, y.shape, to_fp(
y.flatten(), FixedImpl.FP8x23))
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.I32, z.shape, z.flatten())

name = "less_equal_fp8x23"
make_test([x, y], z, "input_0.less_equal(@input_1)", name)
Expand All @@ -116,7 +116,7 @@ def broadcast():
x.flatten(), FixedImpl.FP8x23))
y = Tensor(Dtype.FP8x23, y.shape, to_fp(
y.flatten(), FixedImpl.FP8x23))
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.I32, z.shape, z.flatten())

name = "less_equal_fp8x23_broadcast"
make_test([x, y], z, "input_0.less_equal(@input_1)", name)
Expand All @@ -135,7 +135,7 @@ def default():
x.flatten(), FixedImpl.FP16x16))
y = Tensor(Dtype.FP16x16, y.shape, to_fp(
y.flatten(), FixedImpl.FP16x16))
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.I32, z.shape, z.flatten())

name = "less_equal_fp16x16"
make_test([x, y], z, "input_0.less_equal(@input_1)", name)
Expand All @@ -149,7 +149,7 @@ def broadcast():
x.flatten(), FixedImpl.FP16x16))
y = Tensor(Dtype.FP16x16, y.shape, to_fp(
y.flatten(), FixedImpl.FP16x16))
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.I32, z.shape, z.flatten())

name = "less_equal_fp16x16_broadcast"
make_test([x, y], z, "input_0.less_equal(@input_1)", name)
Expand Down
10 changes: 5 additions & 5 deletions src/operators/tensor/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -1321,7 +1321,7 @@ trait TensorTrait<T> {
/// #tensor.less_equal
///
/// ```rust
/// fn less_equal(self: @Tensor<T>, other: @Tensor<T>) -> Tensor<usize>;
/// fn less_equal(self: @Tensor<T>, other: @Tensor<T>) -> Tensor<i32>;
/// ```
///
/// Check if each element of the first tensor is less than or equal to the corresponding element of the second tensor.
Expand All @@ -1340,7 +1340,7 @@ trait TensorTrait<T> {
///
/// ## Returns
///
/// A new `Tensor<usize>` of booleans (0 or 1) with the same shape as the broadcasted inputs.
/// A new `Tensor<i32>` of booleans (0 or 1) with the same shape as the broadcasted inputs.
///
/// ## Examples
///
Expand All @@ -1351,7 +1351,7 @@ trait TensorTrait<T> {
///
/// use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor};
///
/// fn less_equal_example() -> Tensor<usize> {
/// fn less_equal_example() -> Tensor<i32> {
/// let tensor_1 = TensorTrait::<u32>::new(
/// shape: array![3, 3, 3].span(), data: array![0, 1, 2, 3, 4, 5, 6, 7, 8].span(),
/// );
Expand All @@ -1373,7 +1373,7 @@ trait TensorTrait<T> {
///
/// use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor};
///
/// fn less_equal_example() -> Tensor<usize> {
/// fn less_equal_example() -> Tensor<i32> {
/// let tensor_1 = TensorTrait::<u32>::new(
/// shape: array![3, 3, 3].span(), data: array![0, 1, 2, 3, 4, 5, 6, 7, 8].span(),
/// );
Expand All @@ -1386,7 +1386,7 @@ trait TensorTrait<T> {
/// >>> [1,1,1,0,0,0,1,1,1]
/// ```
///
fn less_equal(self: @Tensor<T>, other: @Tensor<T>) -> Tensor<usize>;
fn less_equal(self: @Tensor<T>, other: @Tensor<T>) -> Tensor<i32>;
/// #tensor.abs
///
/// ```rust
Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_bool.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl BoolTensor of TensorTrait<bool> {
panic(array!['not supported!'])
}

fn less_equal(self: @Tensor<bool>, other: @Tensor<bool>) -> Tensor<usize> {
fn less_equal(self: @Tensor<bool>, other: @Tensor<bool>) -> Tensor<i32> {
panic(array!['not supported!'])
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ impl Complex64Tensor of TensorTrait<complex64> {
panic(array!['not supported!'])
}

fn less_equal(self: @Tensor<complex64>, other: @Tensor<complex64>) -> Tensor<usize> {
fn less_equal(self: @Tensor<complex64>, other: @Tensor<complex64>) -> Tensor<i32> {
panic(array!['not supported!'])
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_fp16x16.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl FP16x16Tensor of TensorTrait<FP16x16> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<FP16x16>, other: @Tensor<FP16x16>) -> Tensor<usize> {
fn less_equal(self: @Tensor<FP16x16>, other: @Tensor<FP16x16>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl FP16x16WTensor of TensorTrait<FP16x16W> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<FP16x16W>, other: @Tensor<FP16x16W>) -> Tensor<usize> {
fn less_equal(self: @Tensor<FP16x16W>, other: @Tensor<FP16x16W>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_fp32x32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl FP32x32Tensor of TensorTrait<FP32x32> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<FP32x32>, other: @Tensor<FP32x32>) -> Tensor<usize> {
fn less_equal(self: @Tensor<FP32x32>, other: @Tensor<FP32x32>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_fp64x64.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl FP64x64Tensor of TensorTrait<FP64x64> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<FP64x64>, other: @Tensor<FP64x64>) -> Tensor<usize> {
fn less_equal(self: @Tensor<FP64x64>, other: @Tensor<FP64x64>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_fp8x23.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl FP8x23Tensor of TensorTrait<FP8x23> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<FP8x23>, other: @Tensor<FP8x23>) -> Tensor<usize> {
fn less_equal(self: @Tensor<FP8x23>, other: @Tensor<FP8x23>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl FP8x23WTensor of TensorTrait<FP8x23W> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<FP8x23W>, other: @Tensor<FP8x23W>) -> Tensor<usize> {
fn less_equal(self: @Tensor<FP8x23W>, other: @Tensor<FP8x23W>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_i32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl I32Tensor of TensorTrait<i32> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<i32>, other: @Tensor<i32>) -> Tensor<usize> {
fn less_equal(self: @Tensor<i32>, other: @Tensor<i32>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_i8.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ impl I8Tensor of TensorTrait<i8> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<i8>, other: @Tensor<i8>) -> Tensor<usize> {
fn less_equal(self: @Tensor<i8>, other: @Tensor<i8>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_u32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ impl U32Tensor of TensorTrait<u32> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<u32>, other: @Tensor<u32>) -> Tensor<usize> {
fn less_equal(self: @Tensor<u32>, other: @Tensor<u32>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
7 changes: 3 additions & 4 deletions src/operators/tensor/math/less_equal.cairo
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
use orion::operators::tensor::core::{Tensor, TensorTrait, unravel_index};
use orion::operators::tensor::{core::{Tensor, TensorTrait, unravel_index}, I32Tensor};
use orion::operators::tensor::helpers::{
broadcast_shape, broadcast_index_mapping, len_from_shape, check_compatibility
};

/// Cf: TensorTrait::less_equal docstring
fn less_equal<
T,
impl UsizeFTensor: TensorTrait<usize>,
impl TPartialOrd: PartialOrd<T>,
impl TCopy: Copy<T>,
impl TDrop: Drop<T>
>(
y: @Tensor<T>, z: @Tensor<T>
) -> Tensor<usize> {
) -> Tensor<i32> {
let broadcasted_shape = broadcast_shape(*y.shape, *z.shape);
let mut result: Array<usize> = array![];
let mut result: Array<i32> = array![];

let num_elements = len_from_shape(broadcasted_shape);

Expand Down
Loading

0 comments on commit d816a66

Please sign in to comment.