-
Notifications
You must be signed in to change notification settings - Fork 0
/
issuse595.patch
52 lines (51 loc) · 2.67 KB
/
issuse595.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc
index 94561720..5f290afe 100644
--- a/mace/ops/reshape.cc
+++ b/mace/ops/reshape.cc
@@ -160,13 +160,10 @@ void RegisterReshape(OpRegistryBase *op_registry) {
if (op->output_shape_size() != op->output_size()) {
return {DeviceType::CPU, DeviceType::GPU};
}
-
- auto tensor_shape_info = context->tensor_shape_info();
- const std::string &input_0 = op->input(0);
- const auto out_dims_size =
- op->output_shape(0).dims_size();
- if (4 == tensor_shape_info->at(input_0).size()
- && (out_dims_size == 4 || out_dims_size == 2)) {
+ int has_data_format =
+ ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
+ *op, "has_data_format", 0);
+ if (has_data_format) {
return {DeviceType::CPU, DeviceType::GPU};
}
return {DeviceType::CPU};
diff --git a/tools/python/transform/transformer.py b/tools/python/transform/transformer.py
index 2c8901a9..fc76bb36 100644
--- a/tools/python/transform/transformer.py
+++ b/tools/python/transform/transformer.py
@@ -1402,12 +1402,21 @@ class Transformer(base_converter.ConverterInterface):
def is_transposable_data_format_ops(self, op):
if op.type == MaceOp.Reshape:
input_op = self._producer[op.input[0]]
- out_dims_len = len(op.output_shape[0].dims)
+ input_dims = input_op.output_shape[0].dims
+ output_dims = op.output_shape[0].dims
+ tranposable = True
if len(input_op.output_shape) != 1 or \
- len(input_op.output_shape[0].dims) != 4 \
- or (out_dims_len != 4 and out_dims_len != 2):
+ len(input_dims) != 4 or len(output_dims) != 4:
+ tranposable = False
+ else:
+ in_b, in_h, in_w, in_c = self.sort_feature_map_shape(
+ input_dims, ConverterUtil.data_format(input_op))
+ ou_b, ou_h, ou_w, ou_c = self.sort_feature_map_shape(
+ output_dims, ConverterUtil.data_format(op))
+ tranposable = (in_b == ou_b and in_c == ou_c)
+ if not tranposable:
print("In this model, reshape is not transposable op.")
- return False
+ return tranposable
return op.type in MaceTransposableDataFormatOps
def update_data_format(self):