Skip to content

Commit

Permalink
Add old dropout trans back, update tests and add subcmd
Browse files Browse the repository at this point in the history
  • Loading branch information
dboyliao committed Jun 12, 2019
1 parent e90771a commit 2671be1
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ test_%:
fi;

clean:
rm -f tests_log.txt
rm -f tests_log.txt *.pdf
44 changes: 37 additions & 7 deletions tests/test_transformer/test_dropout/test_dropout_transormer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import tensorflow as tf
from utensor_cgen.frontend.tensorflow import GraphDefParser
from utensor_cgen.transformer.ns_transformer import DropoutTransformer
from utensor_cgen.transformer.ns_transformer import (DropoutTransformer,
DropoutTransformerV2)


def test_dropout_trans_1(droput_graph_tuple):
def test_dropout_trans_1_1(droput_graph_tuple):
(graph_def,
(keep_prob_name, dropout_output_name),
output_nodes) = droput_graph_tuple
Expand Down Expand Up @@ -35,19 +36,48 @@ def test_dropout_trans_1(droput_graph_tuple):
# expecting the same outputs with keep_prob == 1.0
assert (output_1 == output_2).all()

def test_dropout_trans_1_2(droput_graph_tuple):
(graph_def,
(keep_prob_name, dropout_output_name),
output_nodes) = droput_graph_tuple
ugraph = GraphDefParser.parse(graph_def, output_nodes=output_nodes)
transformer = DropoutTransformerV2()
new_ugraph = transformer.transform(ugraph)
for op in new_ugraph.ops_info.values():
assert op.ugraph
out_op = new_ugraph.ops_info[output_nodes[0]]
assert set([str(op.name) for op in out_op.input_nodes]) == set(['x', 'bias'])
# all dropout nodes should be gone
graph_1 = tf.Graph()
graph_2 = tf.Graph()
with graph_1.as_default():
tf.import_graph_def(ugraph.graph_def, name='')
with graph_2.as_default():
tf.import_graph_def(new_ugraph.graph_def, name='')
with tf.Session(graph=graph_1):
keep_prob = graph_1.get_tensor_by_name(keep_prob_name)
dropout_output = graph_1.get_tensor_by_name(dropout_output_name)
output = graph_1.get_tensor_by_name(output_nodes[0]+":0")
# test the dropout ops are gone
assert keep_prob.op.name not in new_ugraph.ops_info
assert dropout_output.op.name not in new_ugraph.ops_info
output_1 = output.eval({keep_prob:1.0})
with tf.Session(graph=graph_2):
output = graph_2.get_tensor_by_name(output_nodes[0]+":0")
output_2 = output.eval()
# expecting the same outputs with keep_prob == 1.0
assert (output_1 == output_2).all()

def test_dropout_trans_2(dropout_graph_tuple2):
graph_def, output_nodes = dropout_graph_tuple2
ugraph = GraphDefParser.parse(graph_def, output_nodes=output_nodes)
trans = DropoutTransformer()
trans = DropoutTransformerV2()
new_ugraph = trans.transform(ugraph)
assert len(new_ugraph.ops_info) == 1
assert 'x' in new_ugraph.ops_info


def test_dropout_vgg(vgg_ugraph):
trans = DropoutTransformer()
from utensor_cgen.ir.misc.graph_viz import viz_graph
viz_graph(trans.pattern_ugraph, 'dropout_pattern')
trans = DropoutTransformerV2()
new_ugraph = trans.transform(vgg_ugraph)
for op_name in new_ugraph.ops_info:
assert not op_name.startswith('dropout')
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ def test_pipeline_1(methods):
pipeline = TransformerPipeline(methods)
assert len(pipeline.pipeline) == len(methods)
for transformer, (method_name, _) in zip(pipeline.pipeline, methods):
assert isinstance(transformer, pipeline._TRANSFORMER_MAP[method_name])
assert isinstance(transformer, pipeline.TRANSFORMER_MAP[method_name])

def test_pipeline_2(methods):
pipeline = TransformerPipeline(methods)
assert len(pipeline.pipeline) == len(methods)
for transformer, (method_name, _) in zip(pipeline.pipeline, methods):
assert isinstance(transformer, pipeline._TRANSFORMER_MAP[method_name])
assert isinstance(transformer, pipeline.TRANSFORMER_MAP[method_name])

def test_pipeline_3(methods):
pipeline = TransformerPipeline(methods)
assert len(pipeline.pipeline) == len(methods)
for transformer, (method_name, _) in zip(pipeline.pipeline, methods):
assert isinstance(transformer, pipeline._TRANSFORMER_MAP[method_name])
assert isinstance(transformer, pipeline.TRANSFORMER_MAP[method_name])
17 changes: 16 additions & 1 deletion utensor_cgen/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def _get_pb_model_name(path):
def cli():
pass


@cli.command(name='convert', help='convert graph to cpp/hpp files')
@click.help_option('-h', '--help')
@click.argument('pb_file', required=True, metavar='MODEL.pb')
Expand Down Expand Up @@ -84,6 +83,22 @@ def convert_graph(pb_file, output, data_dir, embed_data_dir, save_graph,
save_graph, debug_comment)
generator.generate(model_path)

@cli.command(name='list-trans-methods', help='list all available graph transformation')
@click.help_option('-h', '--help')
@click.option('--verbose', is_flag=True)
def list_trans_methods(verbose):
from utensor_cgen.transformer import TransformerPipeline

if verbose:
for name, trans_cls in TransformerPipeline.TRANSFORMER_MAP.items():
click.secho(name, fg='white', bold=True)
click.secho(trans_cls.__doc__, fg='yellow', bold=True)
else:
click.secho(
str(TransformerPipeline.all_transform_methods()),
fg='white', bold=True
)
return 0

@cli.command(name='show', help='show node names in the pb file')
@click.help_option('-h', '--help')
Expand Down
111 changes: 109 additions & 2 deletions utensor_cgen/transformer/ns_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
Transformers that get rid of namescope/nodes which are not needed
for inference
"""
import re
from collections import defaultdict
from copy import deepcopy

import numpy as np

import tensorflow as tf
from utensor_cgen.frontend.tensorflow import GraphDefParser
from utensor_cgen.ir import OperationInfo, uTensorGraph
from utensor_cgen.matcher import uTensorGraphMatcher
from utensor_cgen.utils import prune_graph, topologic_order_graph
from utensor_cgen.utils import (parse_tensor_name, prune_graph,
topologic_order_graph)

from .base import Transformer

Expand Down Expand Up @@ -47,11 +51,114 @@ def transform(self, ugraph):

return ugraph


class DropoutTransformer(Transformer):
"""Remove Dropout Op
"""Dropout removal transformer
Pros
====
- Insensitive to the dropout layer pattern so it works across different
versions of tensorflow
Cons
====
- naming constrains on the dropout layers, layer name must starts with 'dropout'
and the keep_prob op must be with name starts with 'keep_prop'
"""
METHOD_NAME = 'dropout'
KWARGS_NAMESCOPE = '_utensor_dropout'
TARGET_NODENAME_PATTERN = re.compile(r'(dropout[_\w\d]*)/.*')

def transform(self, ugraph):
new_graph = uTensorGraph(output_nodes=ugraph.output_nodes)
dropout_input_map = self._find_input(ugraph)
new_ops_info = {}
for node_name in ugraph.ops_info:
match = self.TARGET_NODENAME_PATTERN.match(node_name)
if match:
# ignore all dropout nodes
continue
# replace inputs with dropout inputs
op_info = ugraph.ops_info[node_name]
in_t_infos = [deepcopy(t_info, {'ugraph': new_graph})
for t_info in op_info.input_tensors]
out_t_infos = [deepcopy(t_info, {'ugraph': new_graph})
for t_info in op_info.output_tensors]
op_attr = deepcopy(op_info.op_attr)
for i, t_info in enumerate(in_t_infos):
op_name = parse_tensor_name(t_info.name)[0]
match = self.TARGET_NODENAME_PATTERN.match(op_name)
if match:
name_scope = match.group(1)
# assume there should be only on input except keep_prob
dropout_in_tensor = dropout_input_map[name_scope]
in_t_infos.pop(i)
in_t_infos.insert(i, dropout_in_tensor)
new_op_info = OperationInfo(name=op_info.name,
input_tensors=in_t_infos,
n_inputs=len(in_t_infos),
output_tensors=out_t_infos,
n_outputs=len(out_t_infos),
op_type=op_info.op_type,
backend=op_info.backend,
op_attr=op_attr,
ugraph=new_graph)
new_ops_info[node_name] = new_op_info
new_graph.ops_info = new_ops_info
new_graph._backend = ugraph._backend
return new_graph

def _find_dropout_clusters(self, ugraph):
clusters = defaultdict(lambda: [])
for node_name in ugraph.topo_order:
match = self.TARGET_NODENAME_PATTERN.match(node_name)
if match:
name_scope = match.group(1)
clusters[name_scope].append(node_name)
return dict(clusters)

def _find_input(self, ugraph):
"""dropout_name --> input_tensor_info
input_tensor_info := the tensor info of a tensor which is not generated
in the dropout namescope but is consumed by ops in
dropout namescope with name not starts with 'keep_prob'
"""
clusters = self._find_dropout_clusters(ugraph)
input_map = {}
for node_name in ugraph.topo_order:
match = self.TARGET_NODENAME_PATTERN.match(node_name)
if match:
name_scope = match.group(1)
cluster = clusters[name_scope]
op_info = ugraph.ops_info[node_name]
for in_tensor_info in op_info.input_tensors:
in_op_name = parse_tensor_name(in_tensor_info.name)[0]
if in_op_name not in cluster and not in_op_name.startswith('keep_prob'):
input_map[name_scope] = in_tensor_info
# assuming there is only one input for dropout
break
return input_map


class DropoutTransformerV2(Transformer):
"""Dropout removal transformer version 2
Implemented with subgraph matcher
Pros
====
- no naming requirements on the dropout layer and keep prob op
Cons
====
- sensitive to the dropout layer pattern. The pattern of dropout
layer may differ across different version of tensorflow so this
transformer may fail to match the dropout layer if the given graph
is not using the same version
"""
METHOD_NAME = 'dropout_v2'
KWARGS_NAMESCOPE = '_utensor_dropout_v2'

@property
def pattern_ugraph(self):
Expand Down
15 changes: 8 additions & 7 deletions utensor_cgen/transformer/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
from .linear_reoder import (Linear_Reorder_Transformer,
LinearReorderTransformerV2)
from .ns_transformer import (BatchNormTransformer, BiasAddTransformer,
DropoutTransformer, FakeGatherV2Transformer,
InlineTransformer)
DropoutTransformer, DropoutTransformerV2,
FakeGatherV2Transformer, InlineTransformer)
from .optimizer import IdOpRemoveOptimizer, RefCntOptimizer
from .quantize import QuantizeTransformer


class TransformerPipeline(object):

_TRANSFORMER_MAP = {
TRANSFORMER_MAP = {
RefCntOptimizer.METHOD_NAME: RefCntOptimizer,
DropoutTransformer.METHOD_NAME: DropoutTransformer,
DropoutTransformerV2.METHOD_NAME: DropoutTransformerV2,
BatchNormTransformer.METHOD_NAME: BatchNormTransformer,
QuantizeTransformer.METHOD_NAME: QuantizeTransformer,
InlineTransformer.METHOD_NAME: InlineTransformer,
Expand All @@ -37,7 +38,7 @@ def __init__(self, methods):
"""
self._pipeline = []
for method, kwargs in methods:
trans_cls = self._TRANSFORMER_MAP.get(method, None)
trans_cls = self.TRANSFORMER_MAP.get(method, None)
if trans_cls is None:
raise ValueError("Unknown transformation method: {}".format(method))
transformer = trans_cls(**kwargs)
Expand All @@ -54,12 +55,12 @@ def pipeline(self):

@classmethod
def all_transform_methods(cls):
return list(cls._TRANSFORMER_MAP.keys())
return list(cls.TRANSFORMER_MAP.keys())

@classmethod
def register_transformer(cls, trans_cls, overwrite=False):
assert issubclass(trans_cls, Transformer), \
"expecting Transformer type, get %s" % trans_cls
assert trans_cls.METHOD_NAME not in cls._TRANSFORMER_MAP or overwrite, \
assert trans_cls.METHOD_NAME not in cls.TRANSFORMER_MAP or overwrite, \
"Registering existing transformer without overwriting"
cls._TRANSFORMER_MAP[trans_cls.METHOD_NAME] = trans_cls
cls.TRANSFORMER_MAP[trans_cls.METHOD_NAME] = trans_cls

0 comments on commit 2671be1

Please sign in to comment.