# Owner(s): ["module: onnx"]
"""Test onnxscript support in the PyTorch-ONNX converter with onnxruntime."""

from typing import List

import onnx_test_common
import onnxscript
import torch
from onnxscript.onnx_types import FLOAT
from torch.onnx._internal import jit_utils
from torch.testing._internal import common_utils


class TestONNXScriptRuntime(onnx_test_common._TestONNXRuntime):
    # opset version 15 is used because:
    # 1. local functions are supported from opset 15 onwards
    # 2. onnx-script requires users to pin the opset used inside a local function
    opset_version = 15

    def test_selu_from_onnxscript_example(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        model = torch.nn.SELU()

        from onnxscript.onnx_opset import opset15 as op

        # TODO(titaiwang): make an official domain for onnxscript usage
        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)
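
        # onnxscript.script() compiles the decorated Python function into an
        # ONNX function proto, so Selu below becomes a reusable function in
        # the exported model.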
        @onnxscript.script(custom_opset)
        def Selu(X):
            # TODO: onnx/ort doesn't support default values for now
            # move this when they do
            alpha = 1.67326  # automatically wrapped as a Constant
            gamma = 1.0507
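            # CastLike casts the Python constants to X's element type, keeping
            # the function generic over floating-point dtypes.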
            alphaX = op.CastLike(alpha, X)
            gammaX = op.CastLike(gamma, X)
            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
            pos = gammaX * X
            zero = op.CastLike(0, X)
            return op.Where(X <= zero, neg, pos)
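
        # The symbolic below bridges TorchScript and onnxscript: g.onnxscript_op
        # inserts a call to Selu into the exported graph, and setType propagates
        # the input's type so the result is not left untyped.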
        def custom_selu(g: jit_utils.GraphContext, X):
            return g.onnxscript_op(Selu, X).setType(X.type())
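
        # Route aten::selu through the custom symbolic when exporting at this
        # opset version.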
        torch.onnx.register_custom_op_symbolic(
            symbolic_name="aten::selu",
            symbolic_fn=custom_selu,
            opset_version=self.opset_version,
        )

        self.run_test(model, x)

    def test_layer_norm(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)

        class N(torch.nn.Module):
            def __init__(self, prob):
                super().__init__()
                self.dropout = torch.nn.Dropout(prob)

            def forward(self, x):
                return self.dropout(x)

        class M(torch.nn.Module):
            def __init__(self, num_layers):
                super().__init__()
                self.num_layers = num_layers
                self.lns = torch.nn.ModuleList(
                    [torch.nn.LayerNorm(3, eps=i) for i in range(num_layers)]
                )
                self.celu1 = torch.nn.CELU(1.0)
                self.celu2 = torch.nn.CELU(2.0)
                self.dropout = N(0.5)

            def forward(self, x, y, z):
                res1 = self.celu1(x)
                res2 = self.celu2(y)
                for ln in self.lns:
                    z = ln(z)
                return res1 + res2, self.dropout(z)

        model = M(3)

        from onnxscript.onnx_opset import opset15 as op

        custom_opset = onnxscript.values.Opset(domain="onnxscript", version=1)

        @onnxscript.script(custom_opset)
        def layer_norm(
            X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float
        ):
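            # computes (X - mean) / sqrt(var + eps) * weight + bias, reducing
            # over the trailing `axes` dimensions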
            mean = op.ReduceMean(X, axes=axes)
            D = X - mean  # op.Sub(X, mean)
            DD = D * D  # op.Mul(D, D)
            var = op.ReduceMean(DD, axes=axes)
            vareps = var + eps  # op.Add(var, eps)
            stddev = op.Sqrt(vareps)
            invstddev = op.Reciprocal(stddev)
            normalized = D * invstddev  # op.Mul(D, invstddev)
            normalizedw = op.CastLike(
                normalized, weight
            )  # type issue if this Op is missing
            normalizedscaled = normalizedw * weight  # op.Mul(normalized, weight)
            return normalizedscaled + bias
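
        # parse_args converts the raw TorchScript inputs per descriptor:
        # "v" keeps the graph value, "is" expects a list of ints, "f" a float,
        # and "none" passes the (expected-None) argument through untouched.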
        @torch.onnx.symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
        def custom_layer_norm(
            g, input, normalized_shape, weight, bias, eps, cudnn_enable
        ):
            # TODO: move the comprehension into the local function once it's
            # supported by onnxscript
            axes = [-i for i in range(len(normalized_shape), 0, -1)]
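            # e.g. LayerNorm(3) has a normalized_shape of length 1, so
            # axes == [-1] (normalize over the last dimension only)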
            return g.onnxscript_op(
                layer_norm, input, weight, bias, axes_i=axes, eps_f=eps
            ).setType(input.type())
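
        # The _i / _f keyword suffixes mark attribute types for the exporter
        # (int list and float, respectively), following the g.op convention.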
        torch.onnx.register_custom_op_symbolic(
            symbolic_name="aten::layer_norm",
            symbolic_fn=custom_layer_norm,
            opset_version=self.opset_version,
        )

        self.run_test(model, (x, y, z))


if __name__ == "__main__":
    common_utils.run_tests()