
check device before sync (ModelCloud#796)
* check device before sync

* cleanup
LRL-ModelCloud authored Dec 6, 2024
1 parent 26961ce commit cd44c6e
Showing 3 changed files with 8 additions and 7 deletions.
gptqmodel/models/loader.py (2 changes: 1 addition & 1 deletion)
@@ -12,7 +12,6 @@
 from transformers.modeling_utils import no_init_weights
 from transformers.utils.generic import ContextManagers

-from ._const import DEVICE, SUPPORTED_MODELS, get_best_device, is_torch_support_xpu
 from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear
 from ..nn_modules.qlinear.ipex import IPEXQuantLinear, ipex_dtype
 from ..quantization import QuantizeConfig
@@ -25,6 +24,7 @@
 from ..utils.model import (auto_dtype_from_config, convert_gptq_v1_to_v2_format, find_layers,
                            get_checkpoints, get_moe_layer_modules, gptqmodel_post_init, make_quant,
                            simple_dispatch_model, verify_model_hash, verify_sharded_model_hashes)
+from ._const import DEVICE, SUPPORTED_MODELS, get_best_device, is_torch_support_xpu

 logger = setup_logger()

gptqmodel/quantization/gptq.py (5 changes: 3 additions & 2 deletions)
@@ -183,10 +183,11 @@ def fasterquant(
                 logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
                 logger.debug(torch.sum(Losses))

-        if torch.cuda.is_available():
+        if self.dev.type == "cuda":
             torch.cuda.synchronize()
-        if hasattr(torch, "xpu") and torch.xpu.is_available():
+        elif self.dev.type == "xpu":
             torch.xpu.synchronize()

         duration = time.time() - tick
         avg_loss = torch.sum(Losses).item() / self.nsamples

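The substantive fix is the gptq.py hunk above. Previously, fasterquant synchronized CUDA whenever torch.cuda.is_available() returned True, even when quantization was running on CPU or XPU, and probed XPU the same way. Branching on self.dev.type instead synchronizes only the device that actually ran the kernels, so the duration measured immediately afterwards reflects the real work. A minimal standalone sketch of the pattern (the sync_if_needed helper is illustrative, not part of GPTQModel):

import time
import torch

def sync_if_needed(dev: torch.device) -> None:
    # Synchronize only the backend that `dev` belongs to, rather than
    # probing global availability (torch.cuda.is_available()), which can
    # trigger a sync on a device the computation never touched.
    if dev.type == "cuda":
        torch.cuda.synchronize()
    elif dev.type == "xpu":
        # Assumes a PyTorch build with Intel XPU support.
        torch.xpu.synchronize()
    # CPU and other eager backends have nothing pending to wait for.

# Usage: time an op on whichever device the tensor lives on.
x = torch.randn(1024, 1024)
tick = time.time()
y = x @ x
sync_if_needed(x.device)
print(f"matmul: {time.time() - tick:.4f}s on {x.device.type}")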
gptqmodel/utils/model.py (8 changes: 4 additions & 4 deletions)
@@ -20,17 +20,17 @@
 from transformers import AutoConfig, PretrainedConfig
 from transformers.utils.hub import cached_file

-from .backend import BACKEND
-from .importer import select_quant_linear
-from .logger import setup_logger
-from .progress import ProgressBar
 from ..models._const import CPU, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH, EXPERT_INDEX_PLACEHOLDER, SUPPORTED_MODELS
 from ..nn_modules.qlinear import BaseQuantLinear
 from ..nn_modules.qlinear.exllama import ExllamaQuantLinear
 from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear
 from ..nn_modules.qlinear.ipex import IPEXQuantLinear
 from ..nn_modules.qlinear.torch import TorchQuantLinear
 from ..quantization import FORMAT, QuantizeConfig
+from .backend import BACKEND
+from .importer import select_quant_linear
+from .logger import setup_logger
+from .progress import ProgressBar

 logger = setup_logger()

