Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: faster and less memory-intensive model [re]quantization #290

Closed

Conversation

latentCall145
Copy link
Contributor

What does this PR do?

Currently optimum.quanto's quantize/requantize functions run slowly for large models as quantized modules (e.g. QLinear) are initialized with random weights which immediately get replaced with pretrained weights. This also causes these methods to use more CPU RAM than necessary (which is especially visible with models whose weights are lazily loaded). This PR essentially makes quantize/requantize run instantly while using less RAM for lazily-loaded models.

Repro

from optimum.quanto import quantize, requantize, freeze, quantization_map, qint8
from safetensors.torch import save_model, load_file
from diffusers import FluxPipeline
from copy import deepcopy
import torch.nn as nn
import psutil
import torch
import json
import time

torch.random.manual_seed(0)
def free_mem(): return psutil.virtual_memory().available # get current amount of free RAM

# using Flux as its weights are lazily loaded
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe_copy = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
# text_encoder_2 = T5 encoder
model = pipe.text_encoder_2
model_copy = pipe_copy.text_encoder_2

# test quantization speed and memory usage
start_mem = free_mem()
tic = time.perf_counter()
quantize(model, weights=qint8)
print(f'quantize duration: {time.perf_counter() - tic:.3f} s, extra ram usage: {(start_mem - free_mem())/1e6:.3f} MB')

# test model freezing speed and memory usage
start_mem = free_mem()
tic = time.perf_counter()
freeze(model)
print(f'freeze duration: {time.perf_counter() - tic:.3f} s, extra ram usage: {(start_mem - free_mem())/1e6:.3f} MB')

# save model + qmap
save_model(model, '/tmp/model.sft')
with open('/tmp/qmap.json', 'w') as f:
    json.dump(quantization_map(model), f)

# load weight files
state_dict = load_file('/tmp/model.sft')
with open('/tmp/qmap.json', 'r') as f:
    qmap = json.load(f)

# test requantization speed and memory usage
start_mem = free_mem()
tic = time.perf_counter()
requantize(model_copy, state_dict, qmap)
print(f'requantize duration: {time.perf_counter() - tic:.3f} s, extra ram usage: {(start_mem - free_mem())/1e6:.3f} MB')

# make sure that quantized and requantized models have same parameters
for (n1, p1), (n2, p2) in zip(model.named_parameters(), model_copy.named_parameters()):
    assert (p1-p2).sum() == 0

Results for above code (before fix)

Op Duration (s) CPU RAM Usage (MB)
quantize 41.66 9211.67
requantize 44.948 1999.847

Results for above code (after fix)

Op Duration (s) CPU RAM Usage (MB)
quantize 0.024 3.867
requantize 0.130 265.83

Before submitting

  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you run all tests locally and make sure they pass.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@latentCall145 latentCall145 requested a review from dacorvo as a code owner August 22, 2024 04:20
if qmodule is None:
return None
with torch.no_grad():
qmodule.weight.copy_(module.weight)
qmodule.weight = module.weight
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not entirely sure if the copy is necessary. The copy_ call uses extra RAM if module.weight is lazily loaded (i.e. module.weight hasn't been loaded yet but will load because of the copy_ call) which has caused my computer to run out of memory in the past (i.e. loading 12B model with 32 GB RAM)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The copy_ is not strictly necessary, since so far the models are quantized in place. I had planned to change this, but it has proven very convenient and memory efficient.
The only issue I could see is when using tied weights should we want to modify them during calibration.
Anyway, when freezing, new quantized weights are created and even if the weights were still tied they would be untied.

@@ -200,13 +201,15 @@ def from_module(
activations: Optional[qtype] = None,
optimizer: Optional[Optimizer] = None,
):
qmodule = cls.qcreate(module, weights, activations, optimizer)
with init_empty_weights():
Copy link
Contributor Author

@latentCall145 latentCall145 Aug 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

init_empty_weights() is called since it prevents random weight initialization from happening, which was the main cause of the slow performance of quantize/requantize

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your solution is correct, but accelerate is only an optional dependency: see an alternate solution in #291.

@@ -133,12 +133,11 @@ def move_tensor(t, device):
setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu")))
for name, param in m.named_buffers(recurse=False):
setattr(m, name, move_tensor(param, "cpu"))
# Freeze model and move to target device
freeze(model)
Copy link
Contributor Author

@latentCall145 latentCall145 Aug 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why freeze(model) was called here. In quantize(), it's called so all of the weights get set to their quantized versions, but we're already setting quantized weights in requantize() via model.load_state_dict(), so I don't think the freeze(model) call does anything here.

Copy link
Collaborator

@dacorvo dacorvo Aug 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're correct, but only because when loading the state dict we force an assign if the module is unfrozen (see line 186 of qmodule.py)

Copy link
Collaborator

@dacorvo dacorvo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for tracking down the memory issues and useless operations: this is a very valuable contribution. Since I would like to avoid a direct dependency to accelerate, could you rebase on the branch I referenced in the comments ?

@@ -200,13 +201,15 @@ def from_module(
activations: Optional[qtype] = None,
optimizer: Optional[Optimizer] = None,
):
qmodule = cls.qcreate(module, weights, activations, optimizer)
with init_empty_weights():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your solution is correct, but accelerate is only an optional dependency: see an alternate solution in #291.

if qmodule is None:
return None
with torch.no_grad():
qmodule.weight.copy_(module.weight)
qmodule.weight = module.weight
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The copy_ is not strictly necessary, since so far the models are quantized in place. I had planned to change this, but it has proven very convenient and memory efficient.
The only issue I could see is when using tied weights should we want to modify them during calibration.
Anyway, when freezing, new quantized weights are created and even if the weights were still tied they would be untied.

@@ -133,12 +133,11 @@ def move_tensor(t, device):
setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu")))
for name, param in m.named_buffers(recurse=False):
setattr(m, name, move_tensor(param, "cpu"))
# Freeze model and move to target device
freeze(model)
Copy link
Collaborator

@dacorvo dacorvo Aug 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're correct, but only because when loading the state dict we force an assign if the module is unfrozen (see line 186 of qmodule.py)

@dacorvo
Copy link
Collaborator

dacorvo commented Aug 28, 2024

Rebased and merged as #297

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants