Commit
Merge branch 'main' into weekly-bump
ko3n1g authored Jan 21, 2025
2 parents 1e4b7d9 + 499161e commit c906a50
Showing 4 changed files with 40 additions and 27 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/mcore-tag-bump-bot.yml
@@ -13,8 +13,8 @@ jobs:
source-ref: main
build-arg: MCORE_TAG
dockerfile: Dockerfile.ci
- base-branch: main
+ base-branch: weekly-bump
cicd-label: Run CICD
pr-reviewers: 'pablo-garay'
secrets:
- PAT: ${{ secrets.PAT }}
+ PAT: ${{ secrets.PAT }}
@@ -44,6 +44,7 @@ def __init__(
trust_remote_code=False,
default_dtype=torch.bfloat16,
load_in_4bit=False,
attn_implementation="sdpa",
):
super().__init__()
self.save_hyperparameters()
@@ -58,6 +59,7 @@ def __init__(
self.trust_remote_code = trust_remote_code
self.default_dtype = default_dtype
self.load_in_4bit = load_in_4bit
+ self.attn_implementation = attn_implementation

@property
def tokenizer(self):
@@ -82,14 +84,18 @@ def configure_model(self):
torch_dtype='auto',
trust_remote_code=self.trust_remote_code,
load_in_4bit=self.load_in_4bit,
+ attn_implementation=self.attn_implementation,
)
else:
from transformers import AutoConfig

config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=self.trust_remote_code)
dtype = getattr(config, 'torch_dtype', self.default_dtype)
self.model = AutoModelForCausalLM.from_config(
- config, torch_dtype=dtype, trust_remote_code=self.trust_remote_code
+ config,
+ torch_dtype=dtype,
+ trust_remote_code=self.trust_remote_code,
+ attn_implementation=self.attn_implementation,
)

# Apply FSDP2 and TP to the model
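
Note: the new attn_implementation argument is threaded through both Hugging Face loading paths shown above. Below is a minimal sketch of the resulting call pattern, outside NeMo; the build_model helper, its defaults, and the model id are illustrative placeholders, not the wrapper class from this diff.

import torch
from transformers import AutoConfig, AutoModelForCausalLM

def build_model(model_name, load_pretrained=True, trust_remote_code=False,
                load_in_4bit=False, attn_implementation="sdpa"):
    # Pretrained path: the attention backend is chosen at load time
    # ("eager", "sdpa", or "flash_attention_2").
    if load_pretrained:
        return AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype='auto',
            trust_remote_code=trust_remote_code,
            load_in_4bit=load_in_4bit,
            attn_implementation=attn_implementation,
        )
    # From-config path: the same knob is forwarded, so randomly initialized
    # models also use the requested attention kernel.
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    dtype = getattr(config, 'torch_dtype', torch.bfloat16)
    return AutoModelForCausalLM.from_config(
        config,
        torch_dtype=dtype,
        trust_remote_code=trust_remote_code,
        attn_implementation=attn_implementation,
    )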
22 changes: 15 additions & 7 deletions nemo/collections/llm/peft/lora.py
Original file line number Diff line number Diff line change
@@ -52,7 +52,7 @@ class LinearAdapter(nn.Linear):
orig_linear (nn.Module): the linear module to augment.
dim (int): lora's dim in_features -> dim -> out_features.
alpha (int): lora's scaling alpha.
- dropout (float): dropout prob (default: 0.1).
+ dropout (float): dropout prob (default: 0.0).
dropout_position (str): where to apply dropout rel. to lora (choices= ['pre', 'post'], default=post)
lora_A_init_method (str): init method for lora_A (choices= ['xavier', 'uniform'])
lora_dtype (torch.dtype): weight's dtype, by default will use orig_linear's but if they
@@ -64,7 +64,7 @@ def __init__(
orig_linear,
dim=8,
alpha=32,
- dropout=0.1,
+ dropout=0.0,
dropout_position='post',
lora_A_init_method='xavier',
lora_dtype=None,
@@ -82,14 +82,22 @@ def __init__(
if orig_linear.bias is not None:
self.bias.data.copy_(orig_linear.bias.data)
# initialize the adapter
- LinearAdapter._init_adapter(self)
+ LinearAdapter._init_adapter(
+ self,
+ dim=dim,
+ alpha=alpha,
+ dropout=dropout,
+ dropout_position=dropout_position,
+ lora_A_init_method=lora_A_init_method,
+ lora_dtype=lora_dtype,
+ )

@staticmethod
def _init_adapter(
obj,
dim=8,
alpha=32,
- dropout=0.1,
+ dropout=0.0,
dropout_position='post',
lora_A_init_method='xavier',
lora_dtype=None,
@@ -101,7 +109,7 @@ def _init_adapter(
obj (LinearAdapter | nn.Module): input module to adapt.
dim (int): lora's dim in_features -> dim -> out_features.
alpha (int): lora's scaling alpha.
- dropout (float): dropout prob (default: 0.1).
+ dropout (float): dropout prob (default: 0.0).
dropout_position (str): where to apply dropout rel. to lora (choices= ['pre', 'post'], default=post)
lora_A_init_method (str): init method for lora_A (choices= ['xavier', 'uniform'])
lora_dtype (torch.dtype): weight's dtype, by default will use orig_linear's but if they
@@ -155,7 +163,7 @@ def patch_linear_module(
orig_linear,
dim=8,
alpha=32,
- dropout=0.1,
+ dropout=0.0,
dropout_position='post',
lora_A_init_method='xavier',
lora_dtype=None,
@@ -175,7 +183,7 @@ def patch_linear_module(
orig_linear (nn.Linear): the module we add adapter to.
dim (int, optional): Lora dim. Defaults to 8.
alpha (int, optional): Lora alpha scale. Defaults to 32.
- dropout (float, optional): dropout prob. Defaults to 0.1.
+ dropout (float, optional): dropout prob. Defaults to 0.0.
dropout_position (str, optional): location to apply dropout wrt lora.
Defaults to 'post' (choices: 'pre', 'post').
lora_A_init_method (str, optional): lora_a init method. Defaults to 'xavier'.
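
Note: two things change in lora.py. The dropout default drops from 0.1 to 0.0 across LinearAdapter, _init_adapter, and patch_linear_module, and the constructor now forwards its arguments to _init_adapter instead of calling it with only self, so caller-supplied dim/alpha/dropout settings actually take effect. A hedged usage sketch follows; the import path is taken from the diff, the layer sizes are arbitrary.

import torch.nn as nn
from nemo.collections.llm.peft.lora import patch_linear_module

base = nn.Linear(1024, 1024)

# dropout now defaults to 0.0; pass dropout=0.1 explicitly to recover the old default.
lora_layer = patch_linear_module(base, dim=8, alpha=32, dropout=0.1)

# With the forwarding fix, non-default dim/alpha/dropout are no longer silently
# replaced by the built-in defaults inside _init_adapter.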
33 changes: 16 additions & 17 deletions nemo/lightning/resume.py
Original file line number Diff line number Diff line change
@@ -103,23 +103,7 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
if isinstance(trainer, fl.Fabric):
raise NotImplementedError("Fabric is not supported yet.")

- trainer_ckpt_path = self.get_trainer_ckpt_path(model)
- if trainer_ckpt_path:
- trainer.ckpt_path = trainer_ckpt_path
- trainer.checkpoint_callback.last_model_path = trainer_ckpt_path
- # Load artifacts
- if getattr(self.restore_config, 'load_artifacts', False):
- if isinstance(trainer_ckpt_path, AdapterPath):
- # load tokenizer from the base model during peft resume, in case the first peft checkpoint
- # is deleted before the current peft checkpoint is saved
- context_path = trainer_ckpt_path.base_model_path / "context"
- if not context_path.exists():
- context_path = trainer_ckpt_path.base_model_path
- else:
- context_path = self.get_context_path(model)
- model = _try_restore_tokenizer(model, context_path)
-
- elif self.restore_config:
+ if self.restore_config:
new_path = self._extract_path(
model=model,
path=self.restore_config.path,
@@ -139,6 +123,21 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):

_try_restore_tokenizer(model, context_path)

+ elif (trainer_ckpt_path := self.get_trainer_ckpt_path(model)) is not None:
+ trainer.ckpt_path = trainer_ckpt_path
+ trainer.checkpoint_callback.last_model_path = trainer_ckpt_path
+ # Load artifacts
+ if getattr(self.restore_config, 'load_artifacts', False):
+ if isinstance(trainer_ckpt_path, AdapterPath):
+ # load tokenizer from the base model during peft resume, in case the first peft checkpoint
+ # is deleted before the current peft checkpoint is saved
+ context_path = trainer_ckpt_path.base_model_path / "context"
+ if not context_path.exists():
+ context_path = trainer_ckpt_path.base_model_path
+ else:
+ context_path = self.get_context_path(model)
+ model = _try_restore_tokenizer(model, context_path)

def _extract_path(
self, model: Optional[io.ConnectorMixin], path: str, adapter_path: Optional[str] = None
) -> BasePath:
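
Note: the resume.py change swaps the precedence of the two resume sources. restore_config is now checked first, and the auto-detected trainer checkpoint (the walrus-bound trainer_ckpt_path) is only consulted when no restore_config is set. Below is a self-contained sketch of that precedence using stand-in values rather than the NeMo objects; the function name and paths are illustrative.

def resolve_resume_source(restore_config, trainer_ckpt_path):
    # Mirrors the new if/elif ordering in setup(): an explicit restore_config
    # wins; otherwise fall back to the trainer checkpoint, if any.
    if restore_config:
        return "restore_config", restore_config
    elif trainer_ckpt_path is not None:
        return "trainer_ckpt", trainer_ckpt_path
    return "fresh_start", None

print(resolve_resume_source({"path": "/some/base/model"}, "/ckpts/last"))  # restore_config wins
print(resolve_resume_source(None, "/ckpts/last"))                          # trainer checkpoint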
