diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index 63f0c9a15..647f02694 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -41,10 +41,10 @@ def _clone_layer(self): else: clone = self.layer.weight.data.clone() - if isinstance(clone, nn.Conv2d): + if isinstance(self.layer, nn.Conv2d): clone = clone.flatten(1) - if isinstance(clone, transformers.pytorch_utils.Conv1D): + if isinstance(self.layer, transformers.pytorch_utils.Conv1D): clone = clone.t() return clone.to(device=self.device, dtype=torch.float)