Integrate scheduler run and show results on xquant (#1131)
Integrate scheduler run and show results on xquant.
Add a flag in DebugConfig to run the scheduler that searches for the minimal max-cut, and save the result in the model's metadata. This information can then be displayed in the xquant graph.
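A minimal usage sketch of the new flag (the import paths and the Keras PTQ facade call below are assumptions based on MCT's public API, not part of this diff; the model and data generator are placeholders):

import model_compression_toolkit as mct

# Enable the new DebugConfig flag so core_runner simulates the scheduler run.
debug_config = mct.core.DebugConfig(simulate_scheduler=True)
core_config = mct.core.CoreConfig(debug_config=debug_config)

# float_model and representative_data_gen are user-provided placeholders.
quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(
    float_model,
    representative_data_gen,
    core_config=core_config)
# If the target platform model has add_metadata enabled, the scheduling results
# (operators scheduling, max cut, cuts, fused-nodes mapping) are saved in the model's metadata
# and can be shown in the xquant graph.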

---------

Co-authored-by: reuvenp <[email protected]>
reuvenperetz and reuvenp authored Aug 7, 2024
1 parent dbe59db commit c70c464
Showing 33 changed files with 824 additions and 116 deletions.
15 changes: 14 additions & 1 deletion model_compression_toolkit/constants.py
@@ -130,4 +130,17 @@
MP_DEFAULT_NUM_SAMPLES = 32

# Pruning constants
PRUNING_NUM_SCORE_APPROXIMATIONS = 32
PRUNING_NUM_SCORE_APPROXIMATIONS = 32

# Scheduling information fields
OPERATORS_SCHEDULING = 'operators_scheduling'
MAX_CUT = 'max_cut'
CUTS = 'cuts'
FUSED_NODES_MAPPING = 'fused_nodes_mapping'
OP_ORDER = 'op_order'
OP_RECORD = 'op_record'
MEM_ELEMENTS = 'mem_elements'
SHAPE = 'shape'
NODE_NAME = 'node_name'
TOTAL_SIZE = 'total_size'
NODE_OUTPUT_INDEX = 'node_output_index'
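Illustrative only (an assumption, not taken from this diff): these field names are expected to key the scheduling section of the model metadata, roughly along these lines, with hypothetical node names and sizes:

from model_compression_toolkit.constants import (
    OPERATORS_SCHEDULING, MAX_CUT, CUTS, FUSED_NODES_MAPPING,
    OP_ORDER, OP_RECORD, MEM_ELEMENTS, NODE_NAME, SHAPE, TOTAL_SIZE, NODE_OUTPUT_INDEX)

# Hypothetical scheduling-info entry for a small model.
scheduling_info_example = {
    OPERATORS_SCHEDULING: ['conv1', 'relu1', 'conv2'],
    MAX_CUT: 1228800.0,
    CUTS: [{OP_ORDER: ['conv1'], OP_RECORD: ['conv1'],
            MEM_ELEMENTS: [{NODE_NAME: 'conv1', SHAPE: [1, 112, 112, 64],
                            TOTAL_SIZE: 802816.0, NODE_OUTPUT_INDEX: 0}]}],
    FUSED_NODES_MAPPING: {'conv2': 'FusedNode_conv2_relu2'},
}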
135 changes: 135 additions & 0 deletions model_compression_toolkit/core/common/fusion/graph_fuser.py
@@ -0,0 +1,135 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Dict, List

from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.graph.base_graph import OutTensor


class FusedLayerType:
"""
Used to represent the type of fused layers, since __name__
is accessed when the graph is displayed.
"""
def __init__(self):
self.__name__ = 'FusedLayer'
class GraphFuser:

def create_fused_graph(self, graph: Graph) -> Dict[str, str]:
"""
GraphFuser is responsible for fusing nodes in a networkx graph.
The fusion process involves:
1. Creating new fused nodes to represent the groups of nodes to be fused.
2. Updating the graph structure to replace the original nodes with fused nodes.
3. Maintaining a mapping of original node names to their fused node names.
Args:
graph: Graph whose nodes should be fused.
Returns:
Mapping of original node names to their fused node names
"""
fused_nodes_mapping = {}
# Iterate through each group of nodes to be fused
for fused_nodes_list in graph.fused_nodes:
new_fused_node = self._create_fused_node(fused_nodes_list)
self._replace_nodes_with_fused_node(graph, fused_nodes_list, new_fused_node)
# Update the mapping to keep track of which original nodes are now part of which fused nodes
for node in fused_nodes_list:
fused_nodes_mapping[node.name] = new_fused_node.name
return fused_nodes_mapping

def _create_fused_node(self, nodes: List[BaseNode]) -> BaseNode:
"""
Create a new node that represents the fusion of the given nodes.
Args:
nodes: Nodes to be fused into a single fused node.
Returns:
Node that represents the nodes to be fused.
"""
# Create a new node with a name that reflects its components
# Use the input shape of the first node and output shape of the last node
fused_node = BaseNode(name='FusedNode_' + '_'.join([node.name for node in nodes]),
framework_attr={},
input_shape=nodes[0].input_shape,
output_shape=nodes[-1].output_shape,
weights={},
layer_class=FusedLayerType)

# Preserve the final activation quantization configuration
# This is important for maintaining the correct behavior of the fused node
fused_node.final_activation_quantization_cfg = nodes[-1].final_activation_quantization_cfg

return fused_node

def _replace_nodes_with_fused_node(self,
graph: Graph,
nodes_to_fuse: List[BaseNode],
fused_node: BaseNode):
"""
Replace the specified nodes in the graph with a new fused node.
Args:
graph: Graph in which nodes_to_fuse are replaced with fused_node.
nodes_to_fuse: List of nodes to replace with a new fused node.
fused_node: Node to add in place of the nodes in nodes_to_fuse.
"""
if not nodes_to_fuse:
return

first_node = nodes_to_fuse[0]
last_node = nodes_to_fuse[-1]

# Update incoming edges: Connect predecessors of the first node to the fused node
for predecessor in graph.get_prev_nodes(first_node):
e_attr = graph.get_edge_data(predecessor, first_node)
graph.add_edge(predecessor, fused_node, **(e_attr[0]))
graph.remove_edge(predecessor, first_node)

# Update outgoing edges: Connect the fused node to successors of the last node
for successor in graph.get_next_nodes(last_node):
e_attr = graph.get_edge_data(last_node, successor)
graph.add_edge(fused_node, successor, **(e_attr[0]))
graph.remove_edge(last_node, successor)

# Remove internal edges between fused nodes
# This step is necessary to maintain graph consistency
for current_node in nodes_to_fuse[:-1]:
subsequent_nodes = graph.get_next_nodes(current_node)
for next_node in subsequent_nodes:
assert next_node in nodes_to_fuse # Ensure we're not removing edges outside the fusion
graph.remove_edge(current_node, next_node)

# Handle the case where fused nodes are part of the graph's outputs
graph_output_tensors = graph.get_outputs()
graph_output_nodes = [ot.node for ot in graph_output_tensors]
for node in nodes_to_fuse:
if node in graph_output_nodes:
# If a fused node was an output, update the graph's outputs to use the new fused node
node_to_remove_index = graph_output_nodes.index(node)
graph_output_tensors[node_to_remove_index] = OutTensor(node=fused_node,
node_out_index=graph_output_tensors[
node_to_remove_index].node_out_index)
graph.remove_node(node, new_graph_outputs=graph_output_tensors)
else:
# Remove the original node from the graph
graph.remove_node(node)

# Finally, add the new fused node to the graph
graph.add_node(fused_node)
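For reference, this is how the commit drives the new class later in runner.py; tg stands for the quantized graph available inside core_runner:

import copy
from model_compression_toolkit.core.common.fusion.graph_fuser import GraphFuser

# Fuse a deep copy so the original graph is left untouched (as done in the runner.py hunk below).
graph_to_fuse = copy.deepcopy(tg)
fused_nodes_mapping = GraphFuser().create_fused_graph(graph_to_fuse)
# fused_nodes_mapping maps each original node name to the fused node that now contains it,
# e.g. {'conv1': 'FusedNode_conv1_relu1', 'relu1': 'FusedNode_conv1_relu1'} (illustrative names).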
@@ -12,13 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from collections import namedtuple

from typing import Tuple, List

from model_compression_toolkit.constants import OPERATORS_SCHEDULING, MAX_CUT, CUTS, FUSED_NODES_MAPPING
from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.core.common.graph.memory_graph.cut import Cut
from model_compression_toolkit.core.common.graph.memory_graph.max_cut_astar import MaxCutAstar
from model_compression_toolkit.core.common.graph.memory_graph.memory_graph import MemoryGraph

SchedulerInfo = namedtuple('SchedulerInfo', [OPERATORS_SCHEDULING, MAX_CUT, CUTS, FUSED_NODES_MAPPING])

def compute_graph_max_cut(memory_graph: MemoryGraph,
n_iter: int = 50,
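A short sketch of how the new SchedulerInfo is filled and read; the construction mirrors the runner.py hunk further down, and the access lines are illustrative:

schedule, max_cut, cuts = compute_graph_max_cut(memory_graph)
scheduler_info = SchedulerInfo(operators_scheduling=schedule,
                               max_cut=float(max_cut),
                               cuts=cuts,
                               fused_nodes_mapping=fused_nodes_mapping)
# Fields are readable by name; _asdict() yields a plain dict whose keys are the constant
# values defined in constants.py (useful if the info is serialized into model metadata,
# which is an assumption here).
print(scheduler_info.max_cut, len(scheduler_info.operators_scheduling))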
@@ -25,14 +25,17 @@ class DebugConfig:
"""
def __init__(self,
analyze_similarity: bool = False,
network_editor: List[EditRule] = []):
network_editor: List[EditRule] = [],
simulate_scheduler: bool = False):
"""
Args:
analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is
enabled) or not. Can be used to pinpoint problematic layers in the quantization process.
network_editor (List[EditRule]): A list of rules and actions to edit the network for quantization.
simulate_scheduler (bool): Whether to simulate the scheduler's behaviour to compute the operators' execution order and memory cuts.
"""
self.analyze_similarity = analyze_similarity
self.network_editor = network_editor
self.simulate_scheduler = simulate_scheduler
@@ -26,9 +26,11 @@
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.node_def_pb2 import NodeDef
from tensorboard.compat.proto.step_stats_pb2 import StepStats, NodeExecStats, DeviceStepStats, AllocatorMemoryUsed
from tensorboard.compat.proto.summary_pb2 import HistogramProto
from tensorboard.compat.proto.summary_pb2 import HistogramProto, SummaryMetadata
from tensorboard.compat.proto.summary_pb2 import Summary
from tensorboard.compat.proto.tensor_pb2 import TensorProto
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
from tensorboard.summary.writer.event_file_writer import EventFileWriter
from typing import List, Any, Dict
from networkx import topological_sort
@@ -497,6 +499,32 @@ def add_figure(self,
er.add_event(event)
er.flush()

def add_text(self,
text: str,
main_tag_name: str):
"""
Add a text summary to the TensorBoard log.
Args:
text: The text content to be added to the summary.
main_tag_name: The name of the tag under which the text will be grouped in TensorBoard.
"""
plugin_data = SummaryMetadata.PluginData(
plugin_name="text", content=TextPluginData(version=0).SerializeToString()
)
smd = SummaryMetadata(plugin_data=plugin_data)
tensor = TensorProto(
dtype="DT_STRING",
string_val=[text.encode(encoding="utf_8")],
tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]),
)
event = Event(summary=Summary(value=[Summary.Value(tag=main_tag_name, metadata=smd, tensor=tensor)]))

# Get the event writer for this tag name
er = self.__get_event_writer_by_tag_name(main_tag_name)
er.add_event(event)
er.flush()

def init_tensorboard_writer(fw_info: FrameworkInfo) -> TensorboardWriter:
"""
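A usage sketch for the new add_text method (the writer instance, tag name, and logged content are illustrative, not taken from this diff):

# tb_w is a TensorboardWriter instance; log the computed operator schedule as plain text.
if tb_w is not None:
    schedule_text = ' -> '.join(node.name for node in scheduler_info.operators_scheduling)
    tb_w.add_text(text=schedule_text, main_tag_name='operators_scheduling')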
22 changes: 21 additions & 1 deletion model_compression_toolkit/core/runner.py
@@ -12,13 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from collections import namedtuple

import copy

from typing import Callable, Tuple, Any, List, Dict

import numpy as np

from model_compression_toolkit.core.common import FrameworkInfo
from model_compression_toolkit.core.common.fusion.graph_fuser import GraphFuser

from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import compute_graph_max_cut, \
SchedulerInfo
from model_compression_toolkit.core.common.graph.memory_graph.memory_graph import MemoryGraph
from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import \
requires_mixed_precision
@@ -174,7 +181,20 @@ def core_runner(in_model: Any,
if tb_w is not None:
finalize_bitwidth_in_tb(tb_w, weights_conf_nodes_bitwidth, activation_conf_nodes_bitwidth)

return tg, bit_widths_config, hessian_info_service
scheduler_info = None
if core_config.debug_config.simulate_scheduler:
graph_to_fuse = copy.deepcopy(tg)
fused_nodes_mapping = GraphFuser().create_fused_graph(graph_to_fuse)
memory_graph = MemoryGraph(graph_to_fuse)
schedule, max_cut, cuts = compute_graph_max_cut(memory_graph)
scheduler_info = SchedulerInfo(
operators_scheduling=schedule,
max_cut=float(max_cut),
cuts=cuts,
fused_nodes_mapping=fused_nodes_mapping
)

return tg, bit_widths_config, hessian_info_service, scheduler_info


def _set_final_resource_utilization(graph: Graph,
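Caller-side sketch of the new return contract (core_runner_kwargs is a placeholder for the full argument list shown in the facade hunks below): core_runner now returns a 4-tuple, and the scheduling info is None unless the debug flag is enabled:

tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(**core_runner_kwargs)
if scheduling_info is not None:
    # Only populated when core_config.debug_config.simulate_scheduler is True.
    print(scheduling_info.max_cut, len(scheduling_info.operators_scheduling))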
24 changes: 13 additions & 11 deletions model_compression_toolkit/gptq/keras/quantization_facade.py
@@ -31,7 +31,7 @@
from model_compression_toolkit.gptq.runner import gptq_runner
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
from model_compression_toolkit.metadata import get_versions_dict
from model_compression_toolkit.metadata import get_versions_dict, create_model_metadata

LR_DEFAULT = 0.15
LR_REST_DEFAULT = 1e-4
@@ -208,15 +208,15 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da

fw_impl = GPTQKerasImplemantation()

tg, bit_widths_config, hessian_info_service = core_runner(in_model=in_model,
representative_data_gen=representative_data_gen,
core_config=core_config,
fw_info=DEFAULT_KERAS_INFO,
fw_impl=fw_impl,
tpc=target_platform_capabilities,
target_resource_utilization=target_resource_utilization,
tb_w=tb_w,
running_gptq=True)
tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model,
representative_data_gen=representative_data_gen,
core_config=core_config,
fw_info=DEFAULT_KERAS_INFO,
fw_impl=fw_impl,
tpc=target_platform_capabilities,
target_resource_utilization=target_resource_utilization,
tb_w=tb_w,
running_gptq=True)

float_graph = copy.deepcopy(tg)

@@ -242,7 +242,9 @@

exportable_model, user_info = get_exportable_keras_model(tg_gptq)
if target_platform_capabilities.tp_model.add_metadata:
exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities))
exportable_model = add_metadata(exportable_model,
create_model_metadata(tpc=target_platform_capabilities,
scheduling_info=scheduling_info))
return exportable_model, user_info

else:
24 changes: 13 additions & 11 deletions model_compression_toolkit/gptq/pytorch/quantization_facade.py
@@ -31,7 +31,7 @@
from model_compression_toolkit.core import CoreConfig
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
MixedPrecisionQuantizationConfig
from model_compression_toolkit.metadata import get_versions_dict
from model_compression_toolkit.metadata import get_versions_dict, create_model_metadata

LR_DEFAULT = 1e-4
LR_REST_DEFAULT = 1e-4
@@ -177,15 +177,15 @@ def pytorch_gradient_post_training_quantization(model: Module,
# ---------------------- #
# Core Runner
# ---------------------- #
graph, bit_widths_config, hessian_info_service = core_runner(in_model=model,
representative_data_gen=representative_data_gen,
core_config=core_config,
fw_info=DEFAULT_PYTORCH_INFO,
fw_impl=fw_impl,
tpc=target_platform_capabilities,
target_resource_utilization=target_resource_utilization,
tb_w=tb_w,
running_gptq=True)
graph, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=model,
representative_data_gen=representative_data_gen,
core_config=core_config,
fw_info=DEFAULT_PYTORCH_INFO,
fw_impl=fw_impl,
tpc=target_platform_capabilities,
target_resource_utilization=target_resource_utilization,
tb_w=tb_w,
running_gptq=True)

float_graph = copy.deepcopy(graph)

@@ -212,7 +212,9 @@

exportable_model, user_info = get_exportable_pytorch_model(graph_gptq)
if target_platform_capabilities.tp_model.add_metadata:
exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities))
exportable_model = add_metadata(exportable_model,
create_model_metadata(tpc=target_platform_capabilities,
scheduling_info=scheduling_info))
return exportable_model, user_info

