Skip to content

Commit

Permalink
Merge branch 're-arch-support-extra-ops' of github.com:uTensor/utenso…
Browse files Browse the repository at this point in the history
…r_cgen into re-arch-support-extra-ops
  • Loading branch information
mbartling committed Jun 12, 2020
2 parents 3464cfd + 05e8882 commit e79e647
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,39 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
tensor_var_map=tensor_var_map,
)

@OperatorFactory.register
class _ConvOperator(_CommonParams):
op_type = "ConvOperator"

@classmethod
@must_return_type(Hashable)
def get_constructor_parameters(cls, op_info):

strides = [
1,
op_info.op_attr['StrideW'],
op_info.op_attr['StrideH'],
1,
]
padding = cls._PADDING_MAP[op_info.op_attr['Padding']]
strides_str = ','.join(map(str, strides))
return ("{{ {} }}".format(strides_str), padding)

def get_declare_snippet(self, op_var_name, tensor_var_map):
return DeclareOpSnippet(
op=self,
templ_dtypes=[self.out_dtypes[0]],
op_var_name=op_var_name,
)

def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
return ConvOpEvalSnippet(
op_info=op_info,
templ_dtypes=[self.out_dtypes[0]],
op_name=op_var_name,
tensor_var_map=tensor_var_map,
)


@OperatorFactory.register
class _QuantizedFullyConnectedOperator(_CommonParams):
Expand Down Expand Up @@ -842,3 +875,142 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
op_name=op_var_name,
tensor_var_map=tensor_var_map,
)

@OperatorFactory.register
class _BatchNormOperator(_CommonParams):
op_type = "BatchNormOperator"

@classmethod
@must_return_type(Hashable)
def get_constructor_parameters(cls, op_info):
strides = [
1,
op_info.op_attr['StrideW'],
op_info.op_attr['StrideH'],
1,
]
padding = cls._PADDING_MAP[op_info.op_attr['Padding']]
strides_str = ','.join(map(str, strides))
return ("{{ {} }}".format(strides_str), padding)

def get_declare_snippet(self, op_var_name, tensor_var_map):
return DeclareOpSnippet(
op=self,
templ_dtypes=[self.out_dtypes[0]],
op_var_name=op_var_name,
)

def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
return BatchNormSnippet(
op_info=op_info,
templ_dtypes=[self.out_dtypes[0]],
op_name=op_var_name,
tensor_var_map=tensor_var_map,
)

@OperatorFactory.register
class _MeanOperator(_CommonParams):
op_type = "MeanOperator"

@classmethod
@must_return_type(Hashable)
def get_constructor_parameters(cls, op_info):
keep_dims = str(op_info.op_attr["keep_dims"])
return (" {} ".format(keep_dims), )

def get_declare_snippet(self, op_var_name, tensor_var_map):
return DeclareOpSnippet(
op=self,
templ_dtypes=[self.out_dtypes[0]],
op_var_name=op_var_name,
)

def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
return BatchNormSnippet(
op_info=op_info,
templ_dtypes=[self.out_dtypes[0]],
op_name=op_var_name,
tensor_var_map=tensor_var_map,
)

@OperatorFactory.register
class _SoftmaxOperator(_CommonParams):
op_type = "SoftmaxOperator"

@classmethod
@must_return_type(Hashable)
def get_constructor_parameters(cls, op_info):
Beta = op_info.op_attr["Beta"]
return (" %f " % Beta,)

def get_declare_snippet(self, op_var_name, tensor_var_map):
return DeclareOpSnippet(
op=self,
templ_dtypes=[self.out_dtypes[0]],
op_var_name=op_var_name,
)

def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
return BatchNormSnippet(
op_info=op_info,
templ_dtypes=[self.out_dtypes[0]],
op_name=op_var_name,
tensor_var_map=tensor_var_map,
)

@OperatorFactory.register
class _MulOperator(_Operator):
op_type = 'MulOperator'

def get_declare_snippet(self, op_var_name, tensor_var_map):
return DeclareOpSnippet(
op=self,
templ_dtypes=[self.in_dtypes[0]],
op_var_name=op_var_name,
)

def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
return MulOpEvalSnippet(
op_info=op_info,
templ_dtypes=[self.in_dtypes[0]],
op_name=op_var_name,
tensor_var_map=tensor_var_map,
)

@OperatorFactory.register
class _SubOperator(_Operator):
op_type = 'SubOperator'

def get_declare_snippet(self, op_var_name, tensor_var_map):
return DeclareOpSnippet(
op=self,
templ_dtypes=[self.in_dtypes[0]],
op_var_name=op_var_name,
)

def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
return SubOpEvalSnippet(
op_info=op_info,
templ_dtypes=[self.in_dtypes[0]],
op_name=op_var_name,
tensor_var_map=tensor_var_map,
)

@OperatorFactory.register
class _SigmoidOperator(_Operator):
op_type = 'SigmoidOperator'

def get_declare_snippet(self, op_var_name, tensor_var_map):
return DeclareOpSnippet(
op=self,
templ_dtypes=[self.in_dtypes[0]],
op_var_name=op_var_name,
)

def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
return SigmoidOpEvalSnippet(
op_info=op_info,
templ_dtypes=[self.in_dtypes[0]],
op_name=op_var_name,
tensor_var_map=tensor_var_map,
)
8 changes: 8 additions & 0 deletions utensor_cgen/backend/utensor/snippets/rearch/_snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@
"MaxPoolEvalSnippet",
"QuantizedFullyConnectedSnippet",
"MissingOpEvalSnippet",
"BatchNormSnippet",
"TimeSlotContainer",
"MulOpEvalSnippet",
"SubOpEvalSnippet",
"ConvOpEvalSnippet",
"MeanOpEvalSnippet",
"SoftmaxOpEvalSnippet",
"SigmoidOpEvalSnippet",
"SimpleContainer",
]

Expand Down Expand Up @@ -256,6 +258,7 @@ class SoftmaxOpEvalSnippet(OpEvalSnippet):
__inputs__ = ['input']
__outputs__ = ['output']

<<<<<<< HEAD

class MissingOpEvalSnippet(OpEvalSnippet):
__template_name__ = "snippets/rearch/op_missing.cpp"
Expand All @@ -277,6 +280,11 @@ def __init__(self, op_info, tensor_var_map):
]
self.template_vars['output_tensors'] = op_info.output_tensors[:]
self.template_vars['quant_params_map'] = quant_params_map
=======
class SigmoidOpEvalSnippet(OpEvalSnippet):
__inputs__ = ['in']
__outputs__ = ['out']
>>>>>>> 05e8882f5f9fd828586bbb708782c9e173677041


class TimeSlotContainer(SnippetBase):
Expand Down
1 change: 1 addition & 0 deletions utensor_cgen/legalizer/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class _OpTypeRename(object):
"Mean": "MeanOperator",
"Softmax": "SoftmaxOperator",
"Sigmoid": "SigmoidOperator",
"Logistic": "SigmoidOperator",
}

@classmethod
Expand Down

0 comments on commit e79e647

Please sign in to comment.