fix lora finetune
gau-nernst committed Nov 3, 2024
1 parent 864c6fb · commit 1fed859
Showing 1 changed file with 9 additions and 5 deletions.
recipes/lora_finetune_single_device.py (9 additions, 5 deletions)

@@ -266,7 +266,7 @@ def setup(self, cfg: DictConfig) -> None:
                 if self._resume_from_checkpoint
                 else None
             ),
-            quantizer_cfg=cfg.get("quantizer", None),
+            mixed_precision_cfg=cfg.get("mixed_precision", None),
         )
 
         self._tokenizer = config.instantiate(cfg.tokenizer)
@@ -411,7 +411,7 @@ def _setup_model(
         compile_model: bool,
         base_model_state_dict: Dict[str, Any],
         lora_weights_state_dict: Optional[Dict[str, Any]] = None,
-        quantizer_cfg: Optional[DictConfig] = None,
+        mixed_precision_cfg: Optional[DictConfig] = None,
     ) -> nn.Module:
         with training.set_default_dtype(self._dtype), self._device:
             model = config.instantiate(cfg_model)
@@ -433,9 +433,13 @@
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )
 
-        if quantizer_cfg is not None:
-            log.info(f"Preparing model with {quantizer_cfg._component_}")
-            quantizer = config.instantiate(quantizer_cfg)
+        if mixed_precision_cfg is not None and mixed_precision_cfg.get(
+            "enabled", False
+        ):
+            log.info(f"Preparing model with {mixed_precision_cfg._component_}")
+            cfg = mixed_precision_cfg.copy()
+            cfg.pop("enabled", None)
+            quantizer = config.instantiate(cfg)
             model = quantizer.prepare(model)
 
         base_missing, base_unexpected = model.load_state_dict(
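For context, a minimal, self-contained sketch of how a hypothetical `mixed_precision` section of the recipe config would flow through the new branch. The `_component_` path and the `DummyQuantizer` stand-in are illustrative assumptions, not names taken from this repository; only the `enabled` gating mirrors the diff above.

from omegaconf import DictConfig, OmegaConf


class DummyQuantizer:
    """Stand-in for whatever `_component_` would resolve to via config.instantiate()."""

    def prepare(self, model):
        # A real quantizer would swap modules or attach hooks here.
        return model


# Hypothetical recipe config entry, e.g. from YAML:
#   mixed_precision:
#     _component_: path.to.SomeMixedPrecisionQuantizer   # illustrative path
#     enabled: True
mixed_precision_cfg: DictConfig = OmegaConf.create(
    {"_component_": "path.to.SomeMixedPrecisionQuantizer", "enabled": True}
)

model = object()  # placeholder for the instantiated model

if mixed_precision_cfg is not None and mixed_precision_cfg.get("enabled", False):
    cfg = mixed_precision_cfg.copy()
    cfg.pop("enabled", None)  # strip the flag so it is not passed to the quantizer
    quantizer = DummyQuantizer()  # the recipe calls config.instantiate(cfg) here
    model = quantizer.prepare(model)

Keeping `enabled` as a separate key lets a config ship the `mixed_precision` section permanently and toggle it per run, while the `copy()`/`pop()` pair avoids mutating the original config and keeps the flag out of the quantizer's constructor arguments.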
