diff --git a/issuse595.patch b/issuse595.patch new file mode 100644 index 0000000..de62e49 --- /dev/null +++ b/issuse595.patch @@ -0,0 +1,52 @@ +diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc +index 94561720..5f290afe 100644 +--- a/mace/ops/reshape.cc ++++ b/mace/ops/reshape.cc +@@ -160,13 +160,10 @@ void RegisterReshape(OpRegistryBase *op_registry) { + if (op->output_shape_size() != op->output_size()) { + return {DeviceType::CPU, DeviceType::GPU}; + } +- +- auto tensor_shape_info = context->tensor_shape_info(); +- const std::string &input_0 = op->input(0); +- const auto out_dims_size = +- op->output_shape(0).dims_size(); +- if (4 == tensor_shape_info->at(input_0).size() +- && (out_dims_size == 4 || out_dims_size == 2)) { ++ int has_data_format = ++ ProtoArgHelper::GetOptionalArg( ++ *op, "has_data_format", 0); ++ if (has_data_format) { + return {DeviceType::CPU, DeviceType::GPU}; + } + return {DeviceType::CPU}; +diff --git a/tools/python/transform/transformer.py b/tools/python/transform/transformer.py +index 2c8901a9..fc76bb36 100644 +--- a/tools/python/transform/transformer.py ++++ b/tools/python/transform/transformer.py +@@ -1402,12 +1402,21 @@ class Transformer(base_converter.ConverterInterface): + def is_transposable_data_format_ops(self, op): + if op.type == MaceOp.Reshape: + input_op = self._producer[op.input[0]] +- out_dims_len = len(op.output_shape[0].dims) ++ input_dims = input_op.output_shape[0].dims ++ output_dims = op.output_shape[0].dims ++ tranposable = True + if len(input_op.output_shape) != 1 or \ +- len(input_op.output_shape[0].dims) != 4 \ +- or (out_dims_len != 4 and out_dims_len != 2): ++ len(input_dims) != 4 or len(output_dims) != 4: ++ tranposable = False ++ else: ++ in_b, in_h, in_w, in_c = self.sort_feature_map_shape( ++ input_dims, ConverterUtil.data_format(input_op)) ++ ou_b, ou_h, ou_w, ou_c = self.sort_feature_map_shape( ++ output_dims, ConverterUtil.data_format(op)) ++ tranposable = (in_b == ou_b and in_c == ou_c) ++ if not tranposable: + print("In this model, reshape is not transposable op.") +- return False ++ return tranposable + return op.type in MaceTransposableDataFormatOps + + def update_data_format(self):