Skip to content

Commit

Permalink
add issue595.patch
Browse files Browse the repository at this point in the history
N/A

Signed-off-by: Luxuhui <[email protected]>
  • Loading branch information
lu229 committed Mar 24, 2020
1 parent 7e673ce commit 2d42a53
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions issuse595.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc
index 94561720..5f290afe 100644
--- a/mace/ops/reshape.cc
+++ b/mace/ops/reshape.cc
@@ -160,13 +160,10 @@ void RegisterReshape(OpRegistryBase *op_registry) {
if (op->output_shape_size() != op->output_size()) {
return {DeviceType::CPU, DeviceType::GPU};
}
-
- auto tensor_shape_info = context->tensor_shape_info();
- const std::string &input_0 = op->input(0);
- const auto out_dims_size =
- op->output_shape(0).dims_size();
- if (4 == tensor_shape_info->at(input_0).size()
- && (out_dims_size == 4 || out_dims_size == 2)) {
+ int has_data_format =
+ ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
+ *op, "has_data_format", 0);
+ if (has_data_format) {
return {DeviceType::CPU, DeviceType::GPU};
}
return {DeviceType::CPU};
diff --git a/tools/python/transform/transformer.py b/tools/python/transform/transformer.py
index 2c8901a9..fc76bb36 100644
--- a/tools/python/transform/transformer.py
+++ b/tools/python/transform/transformer.py
@@ -1402,12 +1402,21 @@ class Transformer(base_converter.ConverterInterface):
def is_transposable_data_format_ops(self, op):
if op.type == MaceOp.Reshape:
input_op = self._producer[op.input[0]]
- out_dims_len = len(op.output_shape[0].dims)
+ input_dims = input_op.output_shape[0].dims
+ output_dims = op.output_shape[0].dims
+ tranposable = True
if len(input_op.output_shape) != 1 or \
- len(input_op.output_shape[0].dims) != 4 \
- or (out_dims_len != 4 and out_dims_len != 2):
+ len(input_dims) != 4 or len(output_dims) != 4:
+ tranposable = False
+ else:
+ in_b, in_h, in_w, in_c = self.sort_feature_map_shape(
+ input_dims, ConverterUtil.data_format(input_op))
+ ou_b, ou_h, ou_w, ou_c = self.sort_feature_map_shape(
+ output_dims, ConverterUtil.data_format(op))
+ tranposable = (in_b == ou_b and in_c == ou_c)
+ if not tranposable:
print("In this model, reshape is not transposable op.")
- return False
+ return tranposable
return op.type in MaceTransposableDataFormatOps

def update_data_format(self):

0 comments on commit 2d42a53

Please sign in to comment.