Only compute fusion reg loss if fusion layer is trained (#505)
calpt authored Mar 2, 2023 · 1 parent 1c48e10 · commit 954d782
Showing 3 changed files with 86 additions and 19 deletions.
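The situation this change guards against can be sketched as follows (a minimal sketch using the adapter-transformers API as exercised in the tests below; the adapter names, config sizes, and import paths are illustrative and not part of the diff):

from transformers import AutoAdapterModel, BertConfig
from transformers.adapters.composition import Fuse, Stack

# Small illustrative config, mirroring get_model_config() in the tests below.
config = BertConfig(hidden_size=32, num_hidden_layers=4, num_attention_heads=4, intermediate_size=37)
model = AutoAdapterModel.from_config(config)

# A fusion layer over "a"/"b" is present but frozen; only adapter "c" is trained.
model.add_adapter("a")
model.add_adapter("b")
model.add_adapter_fusion(Fuse("a", "b"))
model.add_adapter("c")
model.add_classification_head("c")
model.train_adapter("c")
model.active_adapters = Stack(Fuse("a", "b"), "c")

# Previously a regularization term was still built from the frozen fusion value
# matrix and the trainer called .backward() on it, which fails because nothing in
# that term requires grad. With this change the frozen fusion layer is skipped,
# the method returns None, and the trainer only backpropagates a non-None loss.
reg_loss = model.base_model.get_fusion_regularization_loss()
assert reg_loss is None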
10 changes: 7 additions & 3 deletions src/transformers/adapters/model_mixin.py
@@ -858,15 +858,19 @@ def forward_context(self, context: ForwardContext, *args, **kwargs):
             context.adapter_fusion_attentions = defaultdict(dict)
 
     def get_fusion_regularization_loss(self):
-        reg_loss = 0.0
+        reg_loss = None
 
         target = torch.zeros((self.config.hidden_size, self.config.hidden_size)).fill_diagonal_(1.0).to(self.device)
         for i, layer in self.iter_layers():
             for module in layer.modules():
                 if isinstance(module, AdapterLayer):
                     for _, layer_fusion in module.adapter_fusion_layer.items():
-                        if hasattr(layer_fusion, "value"):
-                            reg_loss += 0.01 * (target - layer_fusion.value.weight).pow(2).sum()
+                        if hasattr(layer_fusion, "value") and layer_fusion.value.weight.requires_grad:
+                            layer_reg_loss = 0.01 * (target - layer_fusion.value.weight).pow(2).sum()
+                            if reg_loss is None:
+                                reg_loss = layer_reg_loss
+                            else:
+                                reg_loss += layer_reg_loss
 
         return reg_loss
 
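Taken out of the model mixin, the guarded computation amounts to the following plain-PyTorch sketch (not the library code; fusion_value_weights is a hypothetical stand-in for the value matrices collected from the AdapterLayer modules above):

import torch

def fusion_regularization_loss(fusion_value_weights, hidden_size, device="cpu"):
    # Pull each *trainable* fusion value matrix toward the identity,
    # and return None if no fusion layer is trainable at all.
    reg_loss = None
    target = torch.zeros((hidden_size, hidden_size)).fill_diagonal_(1.0).to(device)
    for weight in fusion_value_weights:
        if weight.requires_grad:  # frozen fusion layers no longer contribute
            layer_reg_loss = 0.01 * (target - weight).pow(2).sum()
            reg_loss = layer_reg_loss if reg_loss is None else reg_loss + layer_reg_loss
    return reg_loss

# Caller side, mirroring the trainer callback change below: only backpropagate
# if at least one fusion value matrix was actually trainable.
trainable = torch.nn.Parameter(torch.eye(4))   # requires_grad=True
frozen = torch.eye(4)                          # requires_grad=False, e.g. a frozen fusion layer
loss = fusion_regularization_loss([trainable, frozen], hidden_size=4)
if loss is not None:
    loss.backward()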
3 changes: 2 additions & 1 deletion src/transformers/adapters/trainer.py
@@ -259,7 +259,8 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
         model = kwargs.pop("model")
         if self.trainer.train_adapter_fusion:
             fusion_reg_loss = model.base_model.get_fusion_regularization_loss()
-            fusion_reg_loss.backward()
+            if fusion_reg_loss is not None:
+                fusion_reg_loss.backward()
 
 
 class Seq2SeqAdapterTrainer(AdapterTrainer, Seq2SeqTrainer):
92 changes: 77 additions & 15 deletions tests_adapters/test_adapter_trainer.py
@@ -21,6 +21,14 @@
 
 
 class TestAdapterTrainer(unittest.TestCase):
+    def get_model_config(self):
+        return BertConfig(
+            hidden_size=32,
+            num_hidden_layers=4,
+            num_attention_heads=4,
+            intermediate_size=37,
+        )
+
     def test_resume_training(self):
 
         tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
@@ -29,7 +37,7 @@ def test_resume_training(self):
         )
         train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
 
-        model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
+        model = AutoModelForSequenceClassification.from_config(self.get_model_config())
         model.add_adapter("adapter")
         model.add_adapter("additional_adapter")
         model.set_active_adapters("adapter")
@@ -52,7 +60,7 @@
 
         trainer.train()
         # create second model that should resume the training of the first
-        model_resume = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
+        model_resume = AutoModelForSequenceClassification.from_config(self.get_model_config())
         model_resume.add_adapter("adapter")
         model_resume.add_adapter("additional_adapter")
         model_resume.set_active_adapters("adapter")
@@ -78,7 +86,7 @@ def test_resume_training_with_fusion(self):
         )
         train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
 
-        model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
+        model = AutoModelForSequenceClassification.from_config(self.get_model_config())
         model.add_adapter("adapter")
         model.add_adapter("additional_adapter")
         model.add_adapter_fusion(Fuse("adapter", "additional_adapter"))
@@ -101,7 +109,7 @@
         )
 
         trainer.train()
-        model_resume = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
+        model_resume = AutoModelForSequenceClassification.from_config(self.get_model_config())
         model_resume.add_adapter("adapter")
         model_resume.add_adapter("additional_adapter")
         model_resume.add_adapter_fusion(Fuse("adapter", "additional_adapter"))
@@ -155,7 +163,7 @@ def test_training_load_best_model_at_end_full_model(self):
         train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
         eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
 
-        model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
+        model = AutoModelForSequenceClassification.from_config(self.get_model_config())
         model.add_adapter("adapter")
         model.train_adapter("adapter")
 
@@ -189,7 +197,7 @@ def test_training_load_best_model_at_end_adapter(self):
         train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
         eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
 
-        model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
+        model = AutoModelForSequenceClassification.from_config(self.get_model_config())
         model.add_adapter("adapter")
         model.train_adapter("adapter")
 
@@ -221,7 +229,7 @@ def test_training_load_best_model_at_end_fusion(self):
         train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
         eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
 
-        model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
+        model = AutoModelForSequenceClassification.from_config(self.get_model_config())
         model.add_adapter("fuse_adapter_1")
         model.add_adapter("fuse_adapter_2")
         model.add_adapter_fusion(Fuse("fuse_adapter_1", "fuse_adapter_2"))
@@ -254,7 +262,7 @@ def test_reloading_prediction_head(self):
         )
         train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
 
-        model = AutoAdapterModel.from_pretrained("bert-base-uncased")
+        model = AutoAdapterModel.from_config(self.get_model_config())
 
         model.add_classification_head("adapter", num_labels=3)
         model.add_classification_head("dummy", num_labels=2)
@@ -288,7 +296,7 @@ def test_reloading_prediction_head(self):
 
         trainer.train()
         # create second model that should resume the training of the first
-        model_resume = AutoAdapterModel.from_pretrained("bert-base-uncased")
+        model_resume = AutoAdapterModel.from_config(self.get_model_config())
 
         model_resume.add_classification_head("adapter", num_labels=3)
         model_resume.add_classification_head("dummy", num_labels=2)
@@ -323,7 +331,7 @@ def test_general(self):
         )
         train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
 
-        model = AutoAdapterModel.from_pretrained("bert-base-uncased")
+        model = AutoAdapterModel.from_config(self.get_model_config())
 
         model.add_classification_head("task", num_labels=3)
 
@@ -364,6 +372,61 @@ def test_general(self):
         self.assertEqual("task", model.active_head)
         self.assertEqual(Stack("task"), model.active_adapters)
 
+    def test_train_with_frozen_adapter_fusion(self):
+        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        data_args = GlueDataTrainingArguments(
+            task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
+        )
+        train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
+
+        model = AutoAdapterModel.from_config(self.get_model_config())
+
+        model.add_adapter("a")
+        model.add_adapter("b")
+
+        adapter_setup = Fuse("a", "b")
+
+        model.add_adapter_fusion(adapter_setup, set_active=True)
+
+        model.add_adapter("c")
+        model.add_classification_head("c")
+
+        model.train_adapter("c")
+
+        model.active_adapters = Stack(Fuse("a", "b"), "c")
+
+        # Since our config has a value matrix, make sure it is regularized.
+        # We do this by patching the fusion regularization function.
+        regularization_called = False
+        orig_fusion_regularization_loss = model.base_model.get_fusion_regularization_loss
+
+        def patched_fusion_reg_loss():
+            nonlocal regularization_called
+            regularization_called = True
+            return orig_fusion_regularization_loss()
+
+        model.base_model.get_fusion_regularization_loss = patched_fusion_reg_loss
+
+        with TemporaryDirectory() as tempdir:
+            training_args = TrainingArguments(
+                output_dir=tempdir,
+                do_train=True,
+                learning_rate=0.1,
+                logging_steps=1,
+                max_steps=1,
+                save_steps=1,
+                remove_unused_columns=False,
+            )
+            trainer = AdapterTrainer(
+                model=model,
+                args=training_args,
+                train_dataset=train_dataset,
+            )
+
+            trainer.train()
+
+        self.assertTrue(regularization_called)
+
     @require_ray
     def test_hyperparameter_search_works_with_AdapterTrainer(self):
         tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
@@ -375,12 +438,13 @@ def test_hyperparameter_search_works_with_AdapterTrainer(self):
 
         def hp_space(params):
             from ray import tune
+
             return {
                 "learning_rate": tune.choice([0.1, 0.2]),
             }
 
         def model_init(trail=None):
-            model = AutoAdapterModel.from_pretrained("bert-base-uncased")
+            model = AutoAdapterModel.from_config(self.get_model_config())
 
             model.add_classification_head("task", num_labels=3)
 
@@ -406,12 +470,10 @@ def model_init(trail=None):
             model_init=model_init,
             args=training_args,
             train_dataset=train_dataset,
-            eval_dataset=eval_dataset
+            eval_dataset=eval_dataset,
         )
 
-        trainer.hyperparameter_search(
-            direction="minimize", hp_space=hp_space, backend="ray", n_trials=2
-        )
+        trainer.hyperparameter_search(direction="minimize", hp_space=hp_space, backend="ray", n_trials=2)
 
 
 if __name__ == "__main__":
