Thanks for your great work!
I want to apply it to a VAE model, but I ran into some problems.
Here is my code:
```python
import torch
import torch_pruning as tp
from torch import nn
from PIL import Image
from diffusers import AutoencoderKL
from tqdm.auto import tqdm
import numpy as np

device = "cuda:6" if torch.cuda.is_available() else "cpu"

vae_model_id = "stabilityai/stable-diffusion-2-inpainting/vae"
vae = AutoencoderKL.from_pretrained(vae_model_id).to(torch.float16).to(device)
model = vae

example_inputs = torch.randn(1, 3, 224, 224, dtype=torch.float16).to(device)
imp = tp.importance.MagnitudeImportance()

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m)  # DO NOT prune the final classifier!

def forward_fn(model, inputs):
    return model(inputs)[0]

iterative_steps = 3  # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model.encoder,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5,  # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
    forward_fn=forward_fn,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
pruner.step()
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")
```
```
AssertionError Traceback (most recent call last)
Cell In[5], line 1
----> 1 macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
2 print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")
File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch_pruning/utils/op_counter.py:35, in count_ops_and_params(model, example_inputs, layer_wise)
33 _ = flops_model(**example_inputs)
34 else:
---> 35 _ = flops_model(example_inputs)
36 flops_count, params_count, _layer_flops, _layer_params = flops_model.compute_average_flops_cost()
37 layer_flops = {}
File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch/nn/modules/module.py:1844, in Module._call_impl(self, *args, **kwargs)
1841 return inner()
1843 try:
-> 1844 return inner()
1845 except Exception:
1846 # run always called hooks if they have not already been run
1847 # For now only forward hooks have the always_call option but perhaps
1848 # this functionality should be added to full backward hooks as well.
1849 for hook_id, hook in _global_forward_hooks.items():
File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch/nn/modules/module.py:1790, in Module._call_impl.<locals>.inner()
1787 bw_hook = BackwardHook(self, full_backward_hooks, backward_pre_hooks)
1788 args = bw_hook.setup_input_hook(args)
-> 1790 result = forward_call(*args, **kwargs)
1791 if _global_forward_hooks or self._forward_hooks:
1792 for hook_id, hook in (
1793 *_global_forward_hooks.items(),
1794 *self._forward_hooks.items(),
1795 ):
1796 # mark that always called hook is run
...
--> 136 assert hidden_states.shape[1] == self.channels
138 if self.norm is not None:
139 hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
AssertionError:
```
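To narrow this down, a minimal check I plan to run is the pruned encoder on its own (a sketch, assuming the encoder accepts the same example input):

```python
# Sketch: run just the pruned encoder to see whether the failing
# assertion lives in the encoder or elsewhere in the VAE.
with torch.no_grad():
    out = model.encoder(example_inputs)
print(out.shape)
```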
It seems the error is raised because pruning changed the channel shapes, which trips the `assert hidden_states.shape[1] == self.channels` check.
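If that is the case, a possible workaround (a rough sketch I have not verified; it assumes the assertion comes from diffusers' `Downsample2D`/`Upsample2D` blocks, which cache the expected channel count in `self.channels`) would be to re-sync those cached attributes with the pruned convolutions:

```python
# Hypothetical fix sketch: after pruning, update the cached channel counts
# so they match the actual (pruned) conv shapes. Import paths may differ
# across diffusers versions (older releases keep these classes in
# diffusers.models.resnet).
from diffusers.models.downsampling import Downsample2D
from diffusers.models.upsampling import Upsample2D

for m in model.modules():
    if isinstance(m, (Downsample2D, Upsample2D)):
        conv = getattr(m, "conv", None)
        if isinstance(conv, torch.nn.Conv2d):
            m.channels = conv.in_channels      # the value the assert compares against
            m.out_channels = conv.out_channels
```

Alternatively, these modules could simply be added to `ignored_layers` so their channel counts are never changed.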