Commit db3127b

1

Yizhen committed Oct 5, 2024
1 parent 94021d3 commit db3127b
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion src/lmflow/pipeline/finetuner.py
@@ -31,6 +31,7 @@
    send_example_telemetry,
)
import numpy as np
import torch

import lmflow.optim.optimizers as optim
from lmflow.args import OptimizerNames
@@ -40,7 +41,8 @@


logger = logging.getLogger(__name__)

torch.manual_seed(42)
np.random.seed(42)

class Finetuner(BaseTuner):
"""
@@ -544,6 +546,12 @@ def on_step_begin(self, args, state, control, **kwargs):
    # Check if it's time to switch active layers, including at step 0
    if state.global_step % self.interval_steps == 0:
        self.switch_active_layers()

    layers = eval('self.' + self.layers_attribute) # Re-fetch layer references
    self.previous_params = {
        name: param.clone().detach()
        for name, param in layers[self.active_layers_indices[0]].named_parameters()
    }

    def switch_active_layers(self):
        # First, disable gradients for all layers
@@ -558,6 +566,15 @@ def switch_active_layers(self):
        for idx in self.active_layers_indices:
            for param in layers[idx].parameters():
                param.requires_grad = True

    def on_step_end(self, args, state, control, **kwargs):
        layers = eval('self.' + self.layers_attribute) # Re-fetch layer references
        for name, param in layers[self.active_layers_indices[0]].named_parameters():
            if torch.equal(param, self.previous_params[name]):
                print(f"No change in parameter: {name}")
            else:
                print(f"Parameter updated: {name}")


# Instantiate the callback
dynamic_layer_activation_callback = DynamicLayerActivationCallback(
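Editor's note, not part of the commit: the callback resolves the layer list with eval('self.' + self.layers_attribute). Assuming layers_attribute is a dotted attribute path such as 'model.model.layers' (an assumption inferred from how the eval string is built), operator.attrgetter from the standard library resolves the same path without eval. A hedged sketch with a hypothetical stand-in object:

import operator
from types import SimpleNamespace

# Hypothetical stand-in for the callback and its nested model attributes.
callback = SimpleNamespace(
    layers_attribute="model.model.layers",
    model=SimpleNamespace(model=SimpleNamespace(layers=["layer0", "layer1"])),
)

# Equivalent to eval('callback.' + callback.layers_attribute), without eval.
layers = operator.attrgetter(callback.layers_attribute)(callback)
print(layers)  # ['layer0', 'layer1']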
