-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathpytorch_helper.py
85 lines (70 loc) · 3.28 KB
/
pytorch_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import io
import torch.onnx
import onnx
from caffe2.python.onnx.backend import Caffe2Backend
from caffe2.python.core import BlobReference, Net
# Running counter used to build the default blob-name prefix
# ('pytorch_import_<N>/') when PyTorchModule is called without prefix_name.
_next_idx = 0
# Net.Clone takes a dict for blob renaming rather than a callable.
# A callable would be more flexible, so we wrap one in a
# dict-like object here.
class _FakeDict(object):
def __init__(self, fn):
self.fn = fn
def get(self, name, _):
return self.fn(name)
def PyTorchModule(helper, model, sample_arguments, caffe2_inputs, prefix_name=None):
    """
    Embed an ONNX-exportable PyTorch Model into a Caffe2 model being built.

    The model is exported to ONNX in memory, converted to a Caffe2
    (init_net, predict_net) pair, and both nets are appended to ``helper``
    with their blobs renamed so they do not clash with existing blobs.

    Arguments:
        helper (ModelHelper): the model helper where this imported network
            should be inserted. The converted predict net is appended to
            ``helper.net`` and the converted init net to
            ``helper.param_init_net``.
        model (torch.nn.Module): the model to be exported
        sample_arguments (tuple of arguments): the inputs to
            the model, e.g., such that ``model(*args)`` is a valid
            invocation of the model.  Any non-Variable arguments will
            be hard-coded into the exported model; any Variable arguments
            will become inputs of the exported model, in the order they
            occur in args.  If args is a Variable, this is equivalent
            to having called it with a 1-ary tuple of that Variable.
            (Note: passing keyword arguments to the model is not currently
            supported.  Give us a shout if you need it.)
        caffe2_inputs (list of str or caffe2.python.core.BlobReference): the
            caffe2 Blobs that should be inputs to this network. There must
            be exactly one entry for each input of the exported graph that
            has no initializer, in graph order (normally one per Variable
            in sample_arguments).
        prefix_name: prefix to add to the name of every blob of the imported
            network that is not one of its inputs; if None then a fresh
            prefix ``pytorch_import_N/`` is used.

    Returns:
        A tuple of caffe2.python.core.BlobReference objects referring to the
        model's outputs, one per output of the exported graph. A tuple is
        returned even when the model has a single output.

    Raises:
        ValueError: if the number of ``caffe2_inputs`` differs from the
            number of uninitialized inputs of the exported graph.
    """
    global _next_idx
    if prefix_name is None:
        prefix_name = 'pytorch_import_' + str(_next_idx) + '/'
        _next_idx += 1

    # TODO: handle the case where model cannot be exported
    # and embed as a Python op in Caffe2
    f = io.BytesIO()
    torch.onnx.export(
        model, sample_arguments, f, export_params=True)
    onnx_model = onnx.load(io.BytesIO(f.getvalue()))
    init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(
        onnx_model)

    # Graph inputs backed by an initializer are exported parameters; the
    # remaining ones are the real inputs the caller must supply blobs for.
    initialized = {x.name for x in onnx_model.graph.initializer}
    uninitialized_names = [x.name for x in onnx_model.graph.input
                           if x.name not in initialized]
    # Number the real inputs among themselves (not by their position in the
    # full graph.input list) so that the index is always a valid position in
    # caffe2_inputs, even if parameters are interleaved with real inputs.
    uninitialized_inputs = {
        name: i for i, name in enumerate(uninitialized_names)}

    if len(uninitialized_inputs) != len(caffe2_inputs):
        raise ValueError('Expected {} inputs but found {}'.format(
            len(uninitialized_inputs), len(caffe2_inputs)))

    def remap_blob_name(name):
        # Real inputs are wired to the caller's blobs; every other blob is
        # moved under the prefix to keep it unique within the helper's nets.
        if name in uninitialized_inputs:
            idx = uninitialized_inputs[name]
            return str(caffe2_inputs[idx])
        return prefix_name + name

    predict_net = Net(predict_net).Clone('anon', _FakeDict(remap_blob_name))
    helper.net.AppendNet(predict_net)

    init_net = Net(init_net).Clone('anon', _FakeDict(remap_blob_name))
    helper.param_init_net.AppendNet(init_net)

    results = tuple(BlobReference(remap_blob_name(x.name), helper.net)
                    for x in onnx_model.graph.output)
    return results