Description

As outlined in the title, closure-based versions of schedule-free optimizers don't handle Parameters constructed with requires_grad=False well.
The reason for this is that, for the closure-based variants, the history-state z is initialized upon the first call to optimizer.step for all parameters, regardless of whether they received gradients (or even require gradients at all). As a consequence, every subsequent step lerps the parameters back to that very first history-state, no matter what is done with them in the meantime.
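To make the lerp-back concrete, here is a deliberately simplified sketch (not the actual schedulefree update rule, which also maintains the y/x interpolation and its own per-step weighting): a frozen parameter whose z was captured at the first step keeps getting averaged back towards that stale z, so manual updates made between steps leak away.

import torch

# Simplified sketch of the failure mode (NOT the real schedulefree update):
# z is captured once, then every "step" averages the live parameter
# back towards that stale z.
p = torch.nn.Parameter(torch.ones(1), requires_grad=False)
z = p.detach().clone()                  # history-state z, captured at the first step()

for t in range(1, 4):
    p.data.mul_(0.5)                    # manual "EMA-style" update between steps
    ckp1 = 1.0 / t                      # simplified averaging weight
    p.data.lerp_(z, ckp1)               # the step pulls p back towards the initial z
    print(f"step {t}: {p.item():.4f}")  # pulled back towards 1.0 instead of decaying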
A typical use case that is impossible with closure-based schedule-free optimizers would be to register parameters on the module with requires_grad=False and "train" them by estimating/following some kind of distribution (EMA for self-supervised training comes to mind). Later on, these parameters might be switched to being optimized with gradient descent.
MWE

A dummy module is constructed where one parameter has requires_grad=True and one does not. The parameter that does not require gradients is "updated" manually based on a dummy heuristic. Surprisingly, asserting equal state of the model parameters after a small number of parameter updates throws an error:
from collections import defaultdict

import torch
from schedulefree import (
    AdamWScheduleFree,
    AdamWScheduleFreeClosure,
    AdamWScheduleFreePaper,
    AdamWScheduleFreeReference,
    RAdamScheduleFree,
    RAdamScheduleFreeClosure,
    ScheduleFreeWrapper,
    ScheduleFreeWrapperReference,
    SGDScheduleFree,
    SGDScheduleFreeClosure,
    SGDScheduleFreeReference,
)


class SimpleModule(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.param_no_grad = torch.nn.Parameter(torch.ones(1), requires_grad=False)
        self.param_grad = torch.nn.Parameter(torch.ones(1), requires_grad=True)

    def forward(self, x: torch.Tensor):
        return x * self.param_grad + self.param_no_grad


for (optim_closure_cls, optim_cls) in (
    (AdamWScheduleFreeClosure, AdamWScheduleFree),
    (RAdamScheduleFreeClosure, RAdamScheduleFree),
    (SGDScheduleFreeClosure, SGDScheduleFree),
):
    print(f"Processing {optim_closure_cls.__name__, optim_cls.__name__}")

    ### Closure based Optimizer
    module = SimpleModule()
    optim_closure = optim_closure_cls(module.parameters())
    closure_weights = defaultdict(dict)

    def closure():
        optim_closure.zero_grad()
        loss = module(torch.ones(1))
        loss.backward()
        return loss

    # parameter fixed, "trained" e.g. by means of EMA or something else
    for i in range(3):
        optim_closure.step(closure)
        module.param_no_grad.copy_(module.param_no_grad * 0.5)
        for (name, param) in module.named_parameters():
            closure_weights[i][name] = torch.clone(param)

    # unfix parameter
    module.param_no_grad.requires_grad = True
    for i in range(3, 6):
        optim_closure.step(closure)
        for (name, param) in module.named_parameters():
            closure_weights[i][name] = torch.clone(param)
        print(f"grad: {module.param_grad.item():4f}, no_grad: {module.param_no_grad.item():4f}")

    ### Non-Closure based Optimizer
    module = SimpleModule()
    optim = optim_cls(module.parameters())
    non_closure_weights = defaultdict(dict)

    def closure():
        optim.zero_grad()
        loss = module(torch.ones(1))
        loss.backward()
        return loss

    # parameter fixed, "trained" e.g. by means of EMA or something else
    for i in range(3):
        optim.train()
        optim.step(closure)
        module.param_no_grad.copy_(module.param_no_grad * 0.5)
        optim.eval()
        for (name, param) in module.named_parameters():
            non_closure_weights[i][name] = torch.clone(param)

    # parameter unfixed
    module.param_no_grad.requires_grad = True
    for i in range(3, 6):
        optim.train()
        optim.step(closure)
        optim.eval()
        for (name, param) in module.named_parameters():
            non_closure_weights[i][name] = torch.clone(param)
        print(f"grad: {module.param_grad.item():4f}, no_grad: {module.param_no_grad.item():4f}")

    for k, v in closure_weights.items():
        for (param_closure, param_non_closure) in zip(v.values(), non_closure_weights[k].values()):
            torch.testing.assert_close(param_closure, param_non_closure)
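As a possible stopgap until a proper fix lands (my assumption, not necessarily what #56 implements): keep permanently frozen parameters out of the closure-based optimizer so that no history-state z is ever created for them. This obviously forfeits the "unfreeze later" use case described above, since the excluded parameter will never be updated by the optimizer.

# Stopgap sketch (assumption, not the fix from #56): hand the closure-based
# optimizer only the parameters that require gradients, so no stale z is
# created for the frozen parameter.
module = SimpleModule()
optim_closure = AdamWScheduleFreeClosure(
    [p for p in module.parameters() if p.requires_grad]
)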
Will be closed by #56.