Skip to content

Commit

Permalink
Update expected scheme to account for new order (name first)
Browse files Browse the repository at this point in the history
  • Loading branch information
anmarques committed Oct 13, 2023
1 parent ee9de3c commit 0949cd0
Showing 1 changed file with 8 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from sparseml.pytorch.sparsification.quantization.quantize import (
is_qat_helper_module,
is_quantizable_module,
_match_submodule_name_or_type,
)
from tests.sparseml.pytorch.helpers import (
ConvNet,
Expand Down Expand Up @@ -66,23 +67,16 @@ def _assert_observers_eq(observer_1, observer_2):
_assert_observers_eq(qconfig_1.weight, qconfig_2.weight)


def _test_quantized_module(base_model, modifier, module, name):
def _test_quantized_module(base_model, modifier, module, name, override_key):
# check quant scheme and configs are set
quantization_scheme = getattr(module, "quantization_scheme", None)
qconfig = getattr(module, "qconfig", None)
assert quantization_scheme is not None
assert qconfig is not None

# if module type is overwritten in by scheme_overrides, check scheme set correctly
# name takes precedence over class
module_type_name = module.__class__.__name__
print(name, module_type_name)
print(modifier.scheme_overrides.keys())
if name in modifier.scheme_overrides:
expected_scheme = modifier.scheme_overrides[name]
assert quantization_scheme == expected_scheme
elif module_type_name in modifier.scheme_overrides:
expected_scheme = modifier.scheme_overrides[module_type_name]
if override_key is not None:
expected_scheme = modifier.scheme_overrides[override_key]
assert quantization_scheme == expected_scheme

is_quant_wrapper = isinstance(module, torch_quantization.QuantWrapper)
Expand Down Expand Up @@ -154,7 +148,10 @@ def _test_qat_applied(modifier, model):
_test_qat_wrapped_module(model, name)
elif is_quantizable:
# check each target module is quantized
_test_quantized_module(model, modifier, module, name)
override_key = _match_submodule_name_or_type(
module, name, list(modifier.scheme_overrides.keys()),
)
_test_quantized_module(model, modifier, module, name, override_key)
else:
# check all non-target modules are not quantized
assert not hasattr(module, "quantization_scheme")
Expand Down

0 comments on commit 0949cd0

Please sign in to comment.