Does it support diffusion models now? #449

Open
DENGBOYU-REX opened this issue Dec 24, 2024 · 0 comments

DENGBOYU-REX commented Dec 24, 2024

Thanks for your great work!
I want to apply it to a VAE model, but I ran into some problems.
Here is my code:
```python
import torch
import torch_pruning as tp
from torch import nn
from PIL import Image
from diffusers import AutoencoderKL
from tqdm.auto import tqdm
import numpy as np

device = "cuda:6" if torch.cuda.is_available() else "cpu"

vae_model_id = "stabilityai/stable-diffusion-2-inpainting/vae"
vae = AutoencoderKL.from_pretrained(vae_model_id).to(torch.float16).to(device)
model = vae
example_inputs = torch.randn(1, 3, 224, 224, dtype=torch.float16).to(device)
imp = tp.importance.MagnitudeImportance()

ignored_layers = []
for m in model.modules():
    # Copied from the ResNet example: the VAE has no 1000-way Linear layer,
    # so this filter matches nothing and no layer is actually ignored.
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m)  # DO NOT prune the final classifier!

def forward_fn(model, inputs):
    return model(inputs)[0]

iterative_steps = 3  # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model.encoder,  # only the encoder is traced and pruned
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5,  # target: remove 50% of the channels
    ignored_layers=ignored_layers,
    forward_fn=forward_fn,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
pruner.step()
# This call raises the AssertionError shown in the traceback.
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")
```
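Side note on the parameters above: the snippet sets `iterative_steps=3` but calls `pruner.step()` only once. Assuming torch_pruning's default pruning-ratio scheduler is linear (worth checking against the installed version), one step reaches only a third of the 50% target. The pure-Python sketch below reproduces that arithmetic; `linear_schedule` and `channels_kept` are hypothetical helpers, and the truncating rounding rule is an assumption, not the library's exact formula.

```python
def linear_schedule(target_ratio: float, steps: int) -> list[float]:
    """Cumulative pruning ratio after each step (index 0 = before pruning).

    Assumes a linear schedule: the full target is reached only after
    `steps` calls to `pruner.step()`.
    """
    return [i / steps * target_ratio for i in range(steps + 1)]


def channels_kept(initial_channels: int, ratio: float) -> int:
    """Channels left in a layer at a given cumulative ratio.

    Assumed rounding: truncate toward zero.
    """
    return int(initial_channels * (1 - ratio))


if __name__ == "__main__":
    # ch_sparsity=0.5, iterative_steps=3, as in the snippet above
    for step, ratio in enumerate(linear_schedule(0.5, 3)):
        print(f"after step {step}: ratio={ratio:.3f}, "
              f"128 -> {channels_kept(128, ratio)} channels")
```

Under these assumptions a 128-channel layer would shrink to 106 channels after the single `pruner.step()`, and reach 64 only after the third step.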


```
AssertionError                            Traceback (most recent call last)
Cell In[5], line 1
----> 1 macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
      2 print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")

File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch_pruning/utils/op_counter.py:35, in count_ops_and_params(model, example_inputs, layer_wise)
     33     _ = flops_model(**example_inputs)
     34 else:
---> 35     _ = flops_model(example_inputs)
     36 flops_count, params_count, _layer_flops, _layer_params = flops_model.compute_average_flops_cost()
     37 layer_flops = {}

File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch/nn/modules/module.py:1844, in Module._call_impl(self, *args, **kwargs)
   1841     return inner()
   1843 try:
-> 1844     return inner()
   1845 except Exception:
   1846     # run always called hooks if they have not already been run
   1847     # For now only forward hooks have the always_call option but perhaps
   1848     # this functionality should be added to full backward hooks as well.
   1849     for hook_id, hook in _global_forward_hooks.items():

File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch/nn/modules/module.py:1790, in Module._call_impl.<locals>.inner()
   1787     bw_hook = BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1788     args = bw_hook.setup_input_hook(args)
-> 1790 result = forward_call(*args, **kwargs)
   1791 if _global_forward_hooks or self._forward_hooks:
   1792     for hook_id, hook in (
   1793         *_global_forward_hooks.items(),
   1794         *self._forward_hooks.items(),
   1795     ):
   1796         # mark that always called hook is run
...
--> 136 assert hidden_states.shape[1] == self.channels
    138 if self.norm is not None:
    139     hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

AssertionError:
```

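It seems the error is caused by the change in channel count: the failing assertion compares the channel dimension of the incoming tensor (`hidden_states.shape[1]`) with the `channels` value the block stored when it was built, and pruning appears to have changed the former but not the latter.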
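That mechanism can be reproduced without torch or diffusers. The minimal sketch below is pure Python; `ToyConv`, `ToyDownsample`, `prune_in_channels` and `resync_channels` are hypothetical stand-ins, not torch_pruning or diffusers APIs. A block caches its channel count as a plain integer, a simulated pruner resizes only the convolution, and the cached value goes stale until it is copied back.

```python
from dataclasses import dataclass


@dataclass
class ToyConv:
    """Hypothetical stand-in for nn.Conv2d: only channel counts matter here."""
    in_channels: int
    out_channels: int


class ToyDownsample:
    """Hypothetical block that, like the one in the traceback, caches its
    channel count as a plain int at construction time."""

    def __init__(self, channels: int):
        self.channels = channels  # cached once; a pruner never touches it
        self.conv = ToyConv(channels, channels)

    def forward(self, input_channels: int) -> int:
        # Same check as the failing line in the traceback.
        assert input_channels == self.channels
        return self.conv.out_channels


def prune_in_channels(block: ToyDownsample, keep: int) -> None:
    """Simulated structural pruning: resizes the conv's input side but
    knows nothing about the block's cached `channels` attribute."""
    block.conv.in_channels = keep


def resync_channels(block: ToyDownsample) -> None:
    """Hypothetical workaround: copy the real size back into the attribute."""
    block.channels = block.conv.in_channels


if __name__ == "__main__":
    block = ToyDownsample(128)
    prune_in_channels(block, 64)  # upstream layer now emits 64 channels
    try:
        block.forward(64)
    except AssertionError:
        print("stale `channels` attribute -> AssertionError")
    resync_channels(block)
    print("after resync:", block.forward(64))
```

If the real block follows the same pattern, one untested workaround would be to walk the pruned encoder's modules after `pruner.step()` and reset each such block's `channels` from its convolution's `in_channels`; this is an assumption about the block's internals, not a documented torch_pruning feature.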
