Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QuanitztionModifier] initialize per channel scales and zps to correct shape #1738

Merged
merged 12 commits into from
Oct 13, 2023

Conversation

bfineran
Copy link
Contributor

fixes bug found by @anmarques
pytorch does not initialize channel wise qat scales and zero points to their final shape, just to (1,). this adds a patch to expand them to their expected shape

test_plan:
@anmarques to fully reproduce and verify

existing unit tests pass

manual inspection of initialized scales and zero points:

from sparseml.pytorch.sparsification import QuantizationModifier
from sparseml.pytorch.models import resnet50

r = resnet50()
QuantizationModifier(scheme=dict(weights=dict(strategy="channel"))).apply(r)
weight_quants = [(n, mod) for n, mod in r.named_modules() if "weight_fake_quant" in n]

@bfineran bfineran self-assigned this Sep 26, 2023
Satrat
Satrat previously approved these changes Sep 28, 2023
anmarques
anmarques previously approved these changes Oct 5, 2023
Copy link
Member

@anmarques anmarques left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tested on Llama2

@anmarques anmarques merged commit 1a2bddf into main Oct 13, 2023
11 checks passed
@anmarques anmarques deleted the channel-wise-scale-init branch October 13, 2023 15:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants