diff --git a/utensor_cgen/backend/utensor/code_generator/rearch/_operators/_impls.py b/utensor_cgen/backend/utensor/code_generator/rearch/_operators/_impls.py index c6dbed26..d0488a5a 100644 --- a/utensor_cgen/backend/utensor/code_generator/rearch/_operators/_impls.py +++ b/utensor_cgen/backend/utensor/code_generator/rearch/_operators/_impls.py @@ -544,6 +544,39 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map): tensor_var_map=tensor_var_map, ) +@OperatorFactory.register +class _ConvOperator(_CommonParams): + op_type = "ConvOperator" + + @classmethod + @must_return_type(Hashable) + def get_constructor_parameters(cls, op_info): + + strides = [ + 1, + op_info.op_attr['StrideW'], + op_info.op_attr['StrideH'], + 1, + ] + padding = cls._PADDING_MAP[op_info.op_attr['Padding']] + strides_str = ','.join(map(str, strides)) + return ("{{ {} }}".format(strides_str), padding) + + def get_declare_snippet(self, op_var_name, tensor_var_map): + return DeclareOpSnippet( + op=self, + templ_dtypes=[self.out_dtypes[0]], + op_var_name=op_var_name, + ) + + def get_eval_snippet(self, op_var_name, op_info, tensor_var_map): + return ConvOpEvalSnippet( + op_info=op_info, + templ_dtypes=[self.out_dtypes[0]], + op_name=op_var_name, + tensor_var_map=tensor_var_map, + ) + @OperatorFactory.register class _QuantizedFullyConnectedOperator(_CommonParams): @@ -842,3 +875,142 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map): op_name=op_var_name, tensor_var_map=tensor_var_map, ) + +@OperatorFactory.register +class _BatchNormOperator(_CommonParams): + op_type = "BatchNormOperator" + + @classmethod + @must_return_type(Hashable) + def get_constructor_parameters(cls, op_info): + strides = [ + 1, + op_info.op_attr['StrideW'], + op_info.op_attr['StrideH'], + 1, + ] + padding = cls._PADDING_MAP[op_info.op_attr['Padding']] + strides_str = ','.join(map(str, strides)) + return ("{{ {} }}".format(strides_str), padding) + + def get_declare_snippet(self, op_var_name, tensor_var_map): + return DeclareOpSnippet( + op=self, + templ_dtypes=[self.out_dtypes[0]], + op_var_name=op_var_name, + ) + + def get_eval_snippet(self, op_var_name, op_info, tensor_var_map): + return BatchNormSnippet( + op_info=op_info, + templ_dtypes=[self.out_dtypes[0]], + op_name=op_var_name, + tensor_var_map=tensor_var_map, + ) + +@OperatorFactory.register +class _MeanOperator(_CommonParams): + op_type = "MeanOperator" + + @classmethod + @must_return_type(Hashable) + def get_constructor_parameters(cls, op_info): + keep_dims = str(op_info.op_attr["keep_dims"]) + return (" {} ".format(keep_dims), ) + + def get_declare_snippet(self, op_var_name, tensor_var_map): + return DeclareOpSnippet( + op=self, + templ_dtypes=[self.out_dtypes[0]], + op_var_name=op_var_name, + ) + + def get_eval_snippet(self, op_var_name, op_info, tensor_var_map): + return BatchNormSnippet( + op_info=op_info, + templ_dtypes=[self.out_dtypes[0]], + op_name=op_var_name, + tensor_var_map=tensor_var_map, + ) + +@OperatorFactory.register +class _SoftmaxOperator(_CommonParams): + op_type = "SoftmaxOperator" + + @classmethod + @must_return_type(Hashable) + def get_constructor_parameters(cls, op_info): + Beta = op_info.op_attr["Beta"] + return (" %f " % Beta,) + + def get_declare_snippet(self, op_var_name, tensor_var_map): + return DeclareOpSnippet( + op=self, + templ_dtypes=[self.out_dtypes[0]], + op_var_name=op_var_name, + ) + + def get_eval_snippet(self, op_var_name, op_info, tensor_var_map): + return BatchNormSnippet( + op_info=op_info, + templ_dtypes=[self.out_dtypes[0]], + op_name=op_var_name, + tensor_var_map=tensor_var_map, + ) + +@OperatorFactory.register +class _MulOperator(_Operator): + op_type = 'MulOperator' + + def get_declare_snippet(self, op_var_name, tensor_var_map): + return DeclareOpSnippet( + op=self, + templ_dtypes=[self.in_dtypes[0]], + op_var_name=op_var_name, + ) + + def get_eval_snippet(self, op_var_name, op_info, tensor_var_map): + return MulOpEvalSnippet( + op_info=op_info, + templ_dtypes=[self.in_dtypes[0]], + op_name=op_var_name, + tensor_var_map=tensor_var_map, + ) + +@OperatorFactory.register +class _SubOperator(_Operator): + op_type = 'SubOperator' + + def get_declare_snippet(self, op_var_name, tensor_var_map): + return DeclareOpSnippet( + op=self, + templ_dtypes=[self.in_dtypes[0]], + op_var_name=op_var_name, + ) + + def get_eval_snippet(self, op_var_name, op_info, tensor_var_map): + return SubOpEvalSnippet( + op_info=op_info, + templ_dtypes=[self.in_dtypes[0]], + op_name=op_var_name, + tensor_var_map=tensor_var_map, + ) + +@OperatorFactory.register +class _SigmoidOperator(_Operator): + op_type = 'SigmoidOperator' + + def get_declare_snippet(self, op_var_name, tensor_var_map): + return DeclareOpSnippet( + op=self, + templ_dtypes=[self.in_dtypes[0]], + op_var_name=op_var_name, + ) + + def get_eval_snippet(self, op_var_name, op_info, tensor_var_map): + return SigmoidOpEvalSnippet( + op_info=op_info, + templ_dtypes=[self.in_dtypes[0]], + op_name=op_var_name, + tensor_var_map=tensor_var_map, + ) diff --git a/utensor_cgen/backend/utensor/snippets/rearch/_snippets.py b/utensor_cgen/backend/utensor/snippets/rearch/_snippets.py index c40b4613..0010416b 100644 --- a/utensor_cgen/backend/utensor/snippets/rearch/_snippets.py +++ b/utensor_cgen/backend/utensor/snippets/rearch/_snippets.py @@ -29,12 +29,14 @@ "MaxPoolEvalSnippet", "QuantizedFullyConnectedSnippet", "MissingOpEvalSnippet", + "BatchNormSnippet", "TimeSlotContainer", "MulOpEvalSnippet", "SubOpEvalSnippet", "ConvOpEvalSnippet", "MeanOpEvalSnippet", "SoftmaxOpEvalSnippet", + "SigmoidOpEvalSnippet", "SimpleContainer", ] @@ -256,6 +258,7 @@ class SoftmaxOpEvalSnippet(OpEvalSnippet): __inputs__ = ['input'] __outputs__ = ['output'] +<<<<<<< HEAD class MissingOpEvalSnippet(OpEvalSnippet): __template_name__ = "snippets/rearch/op_missing.cpp" @@ -277,6 +280,11 @@ def __init__(self, op_info, tensor_var_map): ] self.template_vars['output_tensors'] = op_info.output_tensors[:] self.template_vars['quant_params_map'] = quant_params_map +======= +class SigmoidOpEvalSnippet(OpEvalSnippet): + __inputs__ = ['in'] + __outputs__ = ['out'] +>>>>>>> 05e8882f5f9fd828586bbb708782c9e173677041 class TimeSlotContainer(SnippetBase): diff --git a/utensor_cgen/legalizer/tflite.py b/utensor_cgen/legalizer/tflite.py index 5e0ec445..8eec014a 100644 --- a/utensor_cgen/legalizer/tflite.py +++ b/utensor_cgen/legalizer/tflite.py @@ -38,6 +38,7 @@ class _OpTypeRename(object): "Mean": "MeanOperator", "Softmax": "SoftmaxOperator", "Sigmoid": "SigmoidOperator", + "Logistic": "SigmoidOperator", } @classmethod