Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ONNX Rewriter and IR to simplify the mnb_to_qdq pass #1482

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
364 changes: 111 additions & 253 deletions olive/passes/onnx/mnb_to_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict

import ml_dtypes

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'ml_dtypes' is not used.

Check warning

Code scanning / lintrunner

PYLINT/W0611 Warning

Unused import ml_dtypes (unused-import)
See unused-import.

Check warning

Code scanning / lintrunner

RUFF/F401 Warning

ml\_dtypes imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import
import numpy as np
import onnx
from onnxscript import ir
from onnxscript.rewriter import pattern as orp

from olive.hardware.accelerator import AcceleratorSpec
from olive.model import ONNXModelHandler
Expand Down Expand Up @@ -62,257 +64,124 @@
) -> ONNXModelHandler:
output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name)

# create a dag from the model
dag = OnnxDAG.from_model_path(model.model_path)
# remove unnecessary identity nodes
dag.remove_identity_nodes()

# if matmulnbits zero point is the following, then the zero point is not needed in the DQ node
default_mnb_zp = 8 if config["use_int4"] else 0
int_np_dtype = np.int8 if config["use_int4"] else np.uint8
int_elem_type = onnx.TensorProto.INT4 if config["use_int4"] else onnx.TensorProto.UINT4

num_modified = 0
for node_name in dag.get_node_names():
op_type = dag.get_node_op_type(node_name)
if op_type != "MatMulNBits":
continue

node_inputs = dag.get_node_inputs(node_name)
# only deal with the constant matmul case for now
if not all(dag.is_initializer(i_name) and not dag.is_input(i_name) for i_name in node_inputs[1:]):
continue

graph_idx = dag.get_graph_idx(node_name)
# 2 Step
# 1. pattern replacement
# 2. Repacking
ir_model = ir.serde.deserialize_model(model.load_model())

def mat_mul_n_bits_pattern(op, input_A, q_weight, q_scales, q_zeros, g_idx, bias):
# bias is an optional input
return op.MatMulNBits(
input_A,
q_weight,
q_scales,
q_zeros,
g_idx,
bias,
_outputs=["mat_mul_n_bits_out"], # Bind the output to the name "mat_mul_n_bits_out"
)

# original output proto
node_output = dag.get_node_outputs(node_name)[0]
is_model_output = dag.is_output(node_output)
node_output_proto = None
if dag.get_io(node_output).proto:
node_output_proto = dag.get_io(node_output).proto[-1]
node_attributes = dag.get_node_attributes(node_name)
K = node_attributes["K"] # noqa: N806
N = node_attributes["N"] # noqa: N806
block_size = node_attributes["block_size"]
num_k_blocks = math.ceil(K / block_size)
def _is_initializer(context, value: ir.Value) -> bool:
graph: ir.Graph = context.graph
return value in graph.initializers.values()

def mat_mul_n_bits_pattern_check(context, *, q_weight, g_idx, mat_mul_n_bits_out: ir.Value, **_) -> bool:
Copy link
Contributor

@jambayk jambayk Nov 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does q_weight here match for the input right before g_idx or it is whatever it is in the mat_mul_n_bits_pattern signature? The input before g_idx is qzero and can be optional. we want to check the second input

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The inputs of the pattern-function (mat_mul_n_bits_pattern) are bound to values in the graph, and these values are passed in as keyword-arguments to the rewrite function here. So, the order here doesn't really matter, though I usually just copy-paste and use the same argument list for both.

if not _is_initializer(context, q_weight):
return False
node: ir.Node = mat_mul_n_bits_out.producer()
block_size = node.attributes["block_size"].as_int()
k = node.attributes["K"].as_int()
if not _is_initializer(g_idx, q_weight):
return False
g_idx = g_idx.constant_value.numpy()
trivial_g_idx = np.arange(k, dtype=np.int32) // block_size
if not np.array_equal(g_idx, trivial_g_idx):
Fixed Show fixed Hide fixed

Check warning

Code scanning / lintrunner

RUFF/SIM103 Warning

Return the negated condition directly.
See https://docs.astral.sh/ruff/rules/needless-bool
# TODO: We can log why the pattern is not matched here

Check warning

Code scanning / lintrunner

RUFF/TD002 Warning

Missing author in TODO; try: # TODO(<author_name>): ... or # TODO @<author_name>: ....
See https://docs.astral.sh/ruff/rules/missing-todo-author
return False
return True

def mat_mul_n_bits_replacement(
op,
*,
input_A: ir.Value,
q_weight: ir.Value,
q_scales: ir.Value,
q_zeros: ir.Value,
bias: ir.Value,
mat_mul_n_bits_out: ir.Value,
**_,
):
node: ir.Node = mat_mul_n_bits_out.producer()
# TODO(justinchuby): Keep the old name of the node
k: int = node.attributes["K"].as_int()
block_size: int = node.attributes["block_size"].as_int()
num_k_blocks = math.ceil(k / block_size)
# will make this a per-axis DQ if num_k_blocks == 1
# - originally per-axis K == block_size
# - originally blockwise but K <= block_size
is_per_axis = num_k_blocks == 1

# only deal with 4 bits (int4) for now
if node_attributes["bits"] != 4:
logger.debug("%s uses %d bits, only 4 bits is supported", node_name, node_attributes["bits"])
continue

# we can only deal with trivial g_idx, dequantize linear does not support g_idx
if len(node_inputs) >= 5 and node_inputs[4]:
g_idx = dag.get_initializer_np_array(node_inputs[4])
trivial_g_idx = np.arange(K, dtype=np.int32) // block_size
if not np.array_equal(g_idx, trivial_g_idx):
continue

# name for the DQ node
dq_name = self._get_new_node_name(dag, node_name, "DequantizeLinear")
# weight, scales, zeros
# (name, new_name, unpacked column size)
quant_inputs = [
(node_inputs[1], f"{dq_name}.qweight", K),
(node_inputs[2], f"{dq_name}.scales", num_k_blocks),
]
if len(node_inputs) >= 4 and node_inputs[3]:
quant_inputs.append((node_inputs[3], f"{dq_name}.qzeros", num_k_blocks))
dq_inputs = []

for qi_name, new_qi_name, unpacked_col_size in quant_inputs:
# get the np array
# weight: uint8, scales: float32, zeros: uint8
qi = dag.get_initializer_np_array(qi_name)
# reshape to 2D
qi = qi.reshape(N, -1)

# there are cases where unpack and repack is not needed: no transpose + no padding
# but will still do it for simplicity
if qi.dtype == np.uint8:
qi = self._unpack_on_row(qi)
# remove padding if any
qi = qi[:, :unpacked_col_size]

# Make 1-D scale or qzero if per-axis
if new_qi_name.endswith((".scales", ".qzeros")) and is_per_axis:
qi = qi.flatten()

# skip if is a no-op zero point
if not config["add_zero_point"] and new_qi_name.endswith(".qzeros") and np.all(qi == default_mnb_zp):
continue

if not config["use_transpose_op"]:
# becomes K X N
qi = qi.T

if qi.dtype == np.uint8:
if config["use_int4"]:
# no worries about making signed since the values only use 4 bits
qi = qi.astype(np.int8)
# subtract 8 to make it signed
# no worries here again since the values are in the range 0-15 and numpy uses 2's complement
qi -= 8

# pack in the format expected by onnx and create the tensor
tensor = onnx.helper.make_tensor(
new_qi_name,
int_elem_type,
qi.shape,
self._pack_on_flat(qi).tobytes(),
raw=True,
)
else:
tensor = onnx.numpy_helper.from_array(qi, name=new_qi_name)

# add the initializer
dag.add_initializer(tensor, graph_idx)
# add the input name
dq_inputs.append(new_qi_name)
# DQ default zp is 0 but MatMulNBits is 8, so we need to add a zero tensor with all 8s
# no need to add for int4 if add_zero_point is False
if len(dq_inputs) == 2 and (config["add_zero_point"] or not config["use_int4"]):
zp_name = f"{dq_name}.qzeros"
zp_shape = (
[N] if is_per_axis else ([N, num_k_blocks] if config["use_transpose_op"] else [num_k_blocks, N])
)
zp_tensor = onnx.helper.make_tensor(
zp_name,
int_elem_type,
zp_shape,
# no zp in matmulnbits is equivalent to 8 uint4 and 0 int4 in DQ
self._pack_on_flat(np.zeros(N * num_k_blocks, dtype=int_np_dtype) + 8 - default_mnb_zp).tobytes(),
raw=True,
)
dag.add_initializer(zp_tensor, graph_idx)
dq_inputs.append(zp_name)

# onnx dtype for the float tensors (scale, dequantized weight, matmul inputs+outputs)
float_elem_type = onnx.helper.np_dtype_to_tensor_dtype(dag.get_initializer_np_array(node_inputs[2]).dtype)

# new nodes and value infos to add to the graph
# ensure that the node names and output names are unique
# will add the new nodes, make consumers use the new output and remove the node
# if output is a model output, rename it back to the original name
new_nodes = []
new_value_infos = []

# DequantizeLinear
dq_name = self._get_new_node_name(dag, node_name, "DequantizeLinear")
dq_output = f"{dq_name}/output_0"
new_nodes.append(
onnx.helper.make_node(
"DequantizeLinear",
dq_inputs,
[dq_output],
name=dq_name,
block_size=None if is_per_axis else block_size,
# for some reason block_wise and per-axis appear to use swapped axis
# flip the axis if it is per-axis
axis=(1 if config["use_transpose_op"] else 0) ^ (1 if is_per_axis else 0),
)
)
new_value_infos.append(
onnx.helper.make_tensor_value_info(
dq_output, float_elem_type, shape=[N, K] if config["use_transpose_op"] else [K, N]
)
# DequantizeLinear -> Transpose -> MatMul -> Add (optional)
dq = op.DequantizeLinear(
q_weight,
q_scales,
q_zeros,
block_size=None if is_per_axis else block_size,
# for some reason block_wise and per-axis appear to use swapped axis
# flip the axis if it is per-axis
axis=config["use_transpose_op"] or is_per_axis,
jambayk marked this conversation as resolved.
Show resolved Hide resolved
)
# TODO(justinchuby): Improve the way we mark something that needs repacking
dq.producer().meta["needs_repacking"] = True
dq.producer().meta["K"] = k
dq.producer().meta["N"] = node.attributes["N"].as_int()
if config["use_transpose_node"]:
dq = op.Transpose(dq, perm=[1, 0])
matmul = op.MatMul(input_A, dq)
if bias is not None:
matmul = op.Add(matmul, bias)
return matmul

replace_mat_mul_n_bits = orp.RewriteRule(

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable replace_mat_mul_n_bits is not used.

Check warning

Code scanning / lintrunner

PYLINT/W0612 Warning

Unused variable 'replace_mat_mul_n_bits' (unused-variable)
See unused-variable.

Check warning

Code scanning / lintrunner

RUFF/F841 Warning

Local variable replace\_mat\_mul\_n\_bits is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable
mat_mul_n_bits_pattern,
mat_mul_n_bits_pattern_check,
mat_mul_n_bits_replacement,
)
# TODO(justinchuby): Call the rewriter with replace_mat_mul_n_bits

# 2. Repack the quantized weights
for node in ir_model.graph:
if "needs_repacking" not in node.meta:
continue

if config["use_transpose_op"]:
# Transpose
transpose_name = self._get_new_node_name(dag, node_name, "Transpose")
transpose_output = f"{transpose_name}/output_0"
new_nodes.append(
onnx.helper.make_node(
"Transpose", [dq_output], [transpose_output], name=transpose_name, perm=[1, 0]
)
)
new_value_infos.append(
onnx.helper.make_tensor_value_info(transpose_output, float_elem_type, shape=[K, N])
)
matmul_input = transpose_output
else:
matmul_input = dq_output
# Add Logic handling input 3

# MatMul
matmul_name = self._get_new_node_name(dag, node_name, "MatMul")
matmul_output = f"{matmul_name}/output_0"
new_nodes.append(
onnx.helper.make_node("MatMul", [node_inputs[0], matmul_input], [matmul_output], name=matmul_name)
unpacked_weight_arrays = _unpack_weights(

Check failure

Code scanning / lintrunner

PYLINT/E0602 Error

Undefined variable '_unpack_weights' (undefined-variable)
See undefined-variable.

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name \_unpack\_weights.
See https://docs.astral.sh/ruff/rules/undefined-name
node.meta["K"],
node.meta["N"],
node.inputs[1].const_value.numpy(),
node.inputs[2].const_value.numpy(),
node.inputs[3].const_value.numpy(),
)
if node_output_proto:
# the output shape is the same as the original MatMulNBits node
matmul_output_proto = onnx.ValueInfoProto()
matmul_output_proto.CopyFrom(node_output_proto)
matmul_output_proto.name = matmul_output
new_value_infos.append(matmul_output_proto)
final_name = matmul_name
final_output = matmul_output

if len(node_inputs) >= 5 and node_inputs[4]:
# Bias Add
# it has bias
bias_i_name = node_inputs[4]
new_bias_i_name = bias_i_name.replace("MatMulNBits", "MatMul")
bias_initiaizer = onnx.numpy_helper.from_array(
dag.get_initializer_np_array(bias_i_name), name=new_bias_i_name
)
dag.add_initializer(bias_initiaizer, graph_idx)

bias_name = self._get_new_node_name(dag, node_name, "Add")
bias_output = f"{bias_name}/output_0"
new_nodes.append(
onnx.helper.make_node("Add", [matmul_output, new_bias_i_name], [bias_output], name=bias_name)
)
if node_output_proto:
# the output shape is the same as the original MatMulNBits node
bias_output_proto = onnx.ValueInfoProto()
bias_output_proto.CopyFrom(node_output_proto)
bias_output_proto.name = bias_output
new_value_infos.append(bias_output_proto)
final_name = bias_name
final_output = bias_output

for node in new_nodes:
dag.add_node(node, graph_idx)

# change the input of the consumers
for consumer in dag.get_consumers(node_name):
dag.replace_node_input(consumer, node_output, final_output)

# add the new value infos
for vi in new_value_infos:
dag.add_value_info(vi, graph_idx)

# remove the node
if is_model_output:
dag.remove_output(node_output)
dag.remove_node(node_name)

# rename to original name if it is a model output
if is_model_output:
dag.rename_node_output(final_name, final_output, node_output)
dag.make_output(node_output)

num_modified += 1

if num_modified == 0:
logger.info("No MatMulNBits nodes found. Returning the original model.")
return model

dag.update()
logger.debug("Modified %d MatMulNBits nodes", num_modified)
# this might not work for all models but will just update the opset version to 21
# if there is an issue, try the logic in OnnxOpVersionConversion
dag.model.opset_import[0].version = max(21, dag.model.opset_import[0].version)

# save the model to the output path and return the model
return model_proto_to_olive_model(dag.model, output_model_path, config)
node.inputs[1].const_value = ir.Tensor(unpacked_weight_arrays[0])
node.inputs[2].const_value = ir.Tensor(unpacked_weight_arrays[1])
if len(unpacked_weight_arrays) == 3:
# TODO(justinchuby): Specify a name to input_3
input_3 = ir.Value(None)
input_3.const_value = ir.Tensor(unpacked_weight_arrays[2])
# TODO(justinchuby): Ensure the node has three inputs
node.replace_input_with(3, input_3)
ir_model.graph.register_initializer(input_3)

# Clear the meta data
del node.meta["needs_repacking"]
del node.meta["K"]
del node.meta["N"]

# TODO(justinchuby): Register and remove initializers
ir_model.opset_imports[""] = max(21, ir_model.opset_imports[""])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Use a more robust version conversion process


return model_proto_to_olive_model(ir.serde.serialize_model(ir_model), output_model_path, config)

@staticmethod
def _get_new_node_name(dag: OnnxDAG, old_name: str, op_type: str):
Expand Down Expand Up @@ -346,14 +215,3 @@
# mask out the first 4 bits
tensor &= 0xF
return tensor.reshape(tensor.shape[0], -1)

@staticmethod
def _pack_on_flat(tensor: "NDArray") -> "NDArray":
"""Pack two uint4 into one uint8 on a flattened tensor."""
tensor = tensor.flatten()

if len(tensor) % 2:
tensor = np.pad(tensor, (0, 1), mode="constant")

# right 4 bits are the first uint4
return (tensor[0::2] & 0xF) | ((tensor[1::2] & 0xF) << 4)
Loading