allow for teacher to be passed in as instantiated model (#2170) (#2172)
Sara Adkins authored Mar 11, 2024
1 parent 69a99e1 commit 9ae35bd
Showing 3 changed files with 11 additions and 11 deletions.
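
A minimal usage sketch of what the commit title describes (not code from the diff): distill_teacher now lives on ModelArguments and main() reads the teacher from model_args, so an already-instantiated model can be handed in directly. The checkpoint paths below are hypothetical.

# Minimal usage sketch (not part of the commit); checkpoint paths are hypothetical.
from transformers import AutoModelForCausalLM

from sparseml.transformers.finetune.model_args import ModelArguments

teacher = AutoModelForCausalLM.from_pretrained("./teacher_checkpoint")
model_args = ModelArguments(
    model="./student_checkpoint",
    distill_teacher=teacher,  # a model object instead of a path string
)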
6 changes: 6 additions & 0 deletions src/sparseml/transformers/finetune/model_args.py
@@ -30,6 +30,12 @@ class ModelArguments:
             )
         },
     )
+    distill_teacher: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Teacher model (a trained text generation model)",
+        },
+    )
     config_name: Optional[str] = field(
         default=None,
         metadata={
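For the string-path route, the new field is also reachable from the command line, assuming these dataclasses are parsed with transformers.HfArgumentParser (typical for HF-style finetuning scripts, and an assumption here); the paths are hypothetical.

# Sketch of CLI parsing for the new field, assuming HfArgumentParser is used;
# paths are hypothetical.
from transformers import HfArgumentParser

from sparseml.transformers.finetune.model_args import ModelArguments

parser = HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(
    args=["--model", "./student_checkpoint", "--distill_teacher", "./teacher_checkpoint"]
)
print(model_args.distill_teacher)  # "./teacher_checkpoint"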
10 changes: 5 additions & 5 deletions src/sparseml/transformers/finetune/text_generation.py
@@ -155,10 +155,10 @@ def intialize_model_from_path(
     )
     teacher_config = (
         AutoConfig.from_pretrained(
-            training_args.distill_teacher,
+            model_args.distill_teacher,
             use_auth_token=True if model_args.use_auth_token else None,
         )
-        if training_args.distill_teacher
+        if model_args.distill_teacher
         else None
     )

@@ -208,11 +208,11 @@ def intialize_model_from_path(
 
     teacher = (
         SparseAutoModel.text_generation_from_pretrained(
-            model_name_or_path=training_args.distill_teacher,
+            model_name_or_path=model_args.distill_teacher,
             sequence_length=None,  # use model default
             **teacher_kwargs,
         )
-        if training_args.distill_teacher is not None
+        if model_args.distill_teacher is not None
         else None
     )

@@ -289,7 +289,7 @@ def main(
 
     # Detecting last checkpoint.
     last_checkpoint = None
-    teacher = None
+    teacher = model_args.distill_teacher
     model_path = None
     model = model_args.model
     # Load tokenizer
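To illustrate why the main() change enables an instantiated teacher, here is a simplified dispatch sketch; it is not the actual sparseml code, resolve_teacher and load_from_path are hypothetical names, and the real string case goes through intialize_model_from_path and SparseAutoModel.text_generation_from_pretrained as shown above.

# Simplified, illustrative sketch of the post-commit flow; not the real implementation.
from torch.nn import Module


def resolve_teacher(distill_teacher, load_from_path):
    """Return a teacher model, loading only when a path/stub string is given."""
    if distill_teacher is None:
        return None  # no distillation requested
    if isinstance(distill_teacher, Module):
        # Already an instantiated model: use it as-is (what this commit enables).
        return distill_teacher
    # Otherwise treat it as a path / model identifier and load it.
    return load_from_path(distill_teacher)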
6 changes: 0 additions & 6 deletions src/sparseml/transformers/finetune/training_args.py
@@ -32,12 +32,6 @@ class TrainingArguments(HFTrainingArgs):
     arguments
     """
 
-    distill_teacher: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Teacher model (a trained text generation model)",
-        },
-    )
     best_model_after_epoch: int = field(
         default=None,
         metadata={"help": "Epoch after which best model will be saved."},
