-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b02b601
commit cab5345
Showing
32 changed files
with
3,534 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# TensorTrait::tile | ||
|
||
```rust | ||
fn tile(self: @Tensor<T>, repeats: Span<usize>) -> Tensor<T>; | ||
``` | ||
|
||
Constructs a tensor by tiling a given tensor. This is the same as function tile in Numpy, but no broadcast. | ||
|
||
For example A = [[1, 2], [3, 4]], B = [1, 2], tile(A, B) = [[1, 2, 1, 2], [3, 4, 3, 4]] | ||
|
||
## Args | ||
|
||
* `tensor`(`@Tensor<T>`) - Input tensor of any shape. | ||
* `repeats`(Span<usize>) - 1D usize array of the same length as input's dimension number, includes numbers of repeated copies along input's dimensions. | ||
|
||
## Returns | ||
|
||
* Output tensor of the same dimensions and type as tensor input. output_dim[i] = input_dim[i] * repeats[i]. | ||
|
||
## Examples | ||
|
||
```rust | ||
use orion::operators::tensor::{I32Tensor, I32TensorAdd}; | ||
use core::array::{ArrayTrait, SpanTrait}; | ||
use orion::operators::tensor::{TensorTrait, Tensor}; | ||
use orion::utils::{assert_eq, assert_seq_eq}; | ||
use orion::operators::tensor::I32TensorPartialEq; | ||
|
||
|
||
fn example() -> Tensor<i32> { | ||
let mut shape = ArrayTrait::<usize>::new(); | ||
shape.append(1); | ||
shape.append(2); | ||
|
||
let mut data = ArrayTrait::new(); | ||
data.append(2); | ||
data.append(1); | ||
let input_0 = TensorTrait::new(shape.span(), data.span()); | ||
|
||
return input_0.tile(array![1, 4].span()); | ||
} | ||
>>> [[2, 1, 2, 1, 2, 1, 2, 1]] | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import numpy as np | ||
from nodegen.node import RunAll | ||
from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait, get_data_statement | ||
|
||
|
||
class Tile(RunAll): | ||
|
||
@staticmethod | ||
# We test here with fp8x23 implementation. | ||
def fp8x23(): | ||
x = np.random.randint(-3, 3, (2, 2, 4, 5)).astype(np.float64) | ||
k = np.random.randint(0, 5, (4)) | ||
y = np.tile(x, k) | ||
|
||
x = Tensor(Dtype.FP8x23, x.shape, to_fp( | ||
x.flatten(), FixedImpl.FP8x23)) | ||
y = Tensor(Dtype.FP8x23, y.shape, to_fp( | ||
y.flatten(), FixedImpl.FP8x23)) | ||
|
||
name = "tile_fp8x23" | ||
make_test([x], y, f"input_0.tile(array!{k.tolist()}.span())", name) | ||
|
||
@staticmethod | ||
# We test here with fp16x16 implementation. | ||
def fp16x16(): | ||
x = np.random.randint(-3, 3, (4, 7, 9)).astype(np.float64) | ||
k = np.random.randint(0, 5, (3)) | ||
y = np.tile(x, k) | ||
|
||
x = Tensor(Dtype.FP16x16, x.shape, to_fp( | ||
x.flatten(), FixedImpl.FP16x16)) | ||
y = Tensor(Dtype.FP16x16, y.shape, to_fp( | ||
y.flatten(), FixedImpl.FP16x16)) | ||
|
||
name = "tile_fp16x16" | ||
make_test([x], y, f"input_0.tile(array!{k.tolist()}.span())", name) | ||
|
||
@staticmethod | ||
# We test here with i8 implementation. | ||
def i8(): | ||
x = np.random.randint(0, 6, (5)).astype(np.int8) | ||
k = np.random.randint(0, 5, (1)) | ||
y = np.tile(x, k) | ||
|
||
x = Tensor(Dtype.I8, x.shape, x.flatten()) | ||
y = Tensor(Dtype.I8, y.shape, y.flatten()) | ||
|
||
name = "tile_i8" | ||
make_test([x], y, f"input_0.tile(array!{k.tolist()}.span())", name) | ||
|
||
@staticmethod | ||
# We test here with i32 implementation. | ||
def i32(): | ||
x = np.random.randint(0, 6, (5, 8)).astype(np.int32) | ||
k = np.random.randint(0, 5, (2)) | ||
y = np.tile(x, k) | ||
|
||
x = Tensor(Dtype.I32, x.shape, x.flatten()) | ||
y = Tensor(Dtype.I32, y.shape, y.flatten()) | ||
|
||
name = "tile_i32" | ||
make_test([x], y, f"input_0.tile(array!{k.tolist()}.span())", name) | ||
|
||
@staticmethod | ||
# We test here with u32 implementation. | ||
def u32(): | ||
x = np.random.randint(0, 6, (1, 2)).astype(np.uint32) | ||
k = np.random.randint(0, 5, (2)) | ||
y = np.tile(x, k) | ||
|
||
x = Tensor(Dtype.U32, x.shape, x.flatten()) | ||
y = Tensor(Dtype.U32, y.shape, y.flatten()) | ||
|
||
name = "tile_u32" | ||
make_test([x], y, f"input_0.tile(array!{k.tolist()}.span())", name) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,3 +68,4 @@ mod hann_window; | |
mod hamming_window; | ||
mod blackman_window; | ||
mod scatter_nd; | ||
mod tile; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
use orion::numbers::fixed_point::core::FixedTrait; | ||
use orion::numbers::NumberTrait; | ||
use orion::operators::tensor::core::{Tensor, TensorTrait}; | ||
|
||
fn tile< | ||
T, | ||
MAG, | ||
impl TTensor: TensorTrait<T>, | ||
impl TNumber: NumberTrait<T, MAG>, | ||
impl TAdd: Add<T>, | ||
impl TSub: Sub<T>, | ||
impl TMul: Mul<T>, | ||
impl TDiv: Div<T>, | ||
impl TTensorAdd: Add<Tensor<T>>, | ||
impl TPartialOrd: PartialOrd<T>, | ||
impl TAddEq: AddEq<T>, | ||
impl TCopy: Copy<T>, | ||
impl TDrop: Drop<T>, | ||
>(self: Tensor<T>, repeats: Span<usize>) -> Tensor<T> { | ||
let mut tensor = self; | ||
let len = (tensor.shape).len(); | ||
let mut i: usize = 0; | ||
while i != len { | ||
let mut k = len - i - 1; | ||
let mut arr: Array<Tensor<T>> = array![]; | ||
let mut j: usize = 0; | ||
if (*repeats.at(k) == 0) { | ||
tensor = TensorTrait::<T>::new(array![0].span(), array![].span()); | ||
i = len; | ||
} else { | ||
while j != *repeats.at(k) { | ||
arr.append(tensor); | ||
j += 1; | ||
}; | ||
if (arr.len() > 1) { | ||
tensor = TensorTrait::concat(arr.span(), k); | ||
} | ||
i += 1; | ||
} | ||
}; | ||
tensor | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
mod input_0; | ||
mod output_0; | ||
|
||
|
||
use orion::operators::tensor::FP16x16TensorPartialEq; | ||
use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; | ||
use orion::operators::tensor::{TensorTrait, Tensor}; | ||
use core::array::{ArrayTrait, SpanTrait}; | ||
use orion::utils::{assert_eq, assert_seq_eq}; | ||
|
||
#[test] | ||
#[available_gas(2000000000)] | ||
fn test_tile_fp16x16() { | ||
let input_0 = input_0::input_0(); | ||
let z_0 = output_0::output_0(); | ||
|
||
let y_0 = input_0.tile(array![4, 1, 2].span()); | ||
|
||
assert_eq(y_0, z_0); | ||
} |
Oops, something went wrong.