Closure-based versions of Optimizers don't handle Parameters with requires_grad=False well #55

Open
ORippler opened this issue Jan 14, 2025 · 0 comments · May be fixed by #56
ORippler commented Jan 14, 2025

Description

As outlined in the title, closure-based versions of schedule-free optimizers don't handle Parameters constructed with requires_grad=False well.
The reason is that, for the closure-based variants, the history state z is initialized upon the first call to optimizer.step for all parameters, regardless of whether they received gradients (or even require gradients). As a consequence, the optimizer always lerps back to that very first history state, no matter what is done with the parameters in the meantime.
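
A minimal sketch of that mechanism (simplified, not the library's actual implementation; closure_step_sketch, ckp1 and the state dict are made-up names):

import torch

def closure_step_sketch(state, param, grad, lr=0.1, ckp1=0.5):
    # z is captured on the very first step() for *every* parameter,
    # even if it never receives a gradient
    if "z" not in state:
        state["z"] = param.detach().clone()
    if grad is not None:
        # only gradient-carrying parameters ever move z
        state["z"].sub_(grad, alpha=lr)
    # x is pulled back toward z regardless of what happened to param in between
    param.data.lerp_(state["z"], ckp1)

p = torch.nn.Parameter(torch.ones(1), requires_grad=False)
state = {}
closure_step_sketch(state, p, grad=None)  # z is frozen at 1.0 here
with torch.no_grad():
    p.copy_(p * 0.5)                      # manual "EMA-style" update
closure_step_sketch(state, p, grad=None)
print(p)                                  # 0.75: pulled back toward the stale z = 1.0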

A typical use-case that is impossible with closure-based schedule-free optimizers: register parameters on the module with requires_grad=False and "train" them by estimating/following some kind of distribution (EMA for self-supervised training comes to mind). Later on, these parameters might be switched to being optimized with gradient descent.
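
A short sketch of that pattern (hypothetical teacher/student names, not part of the MWE below):

import torch

# the "teacher" starts out frozen and is updated via EMA rather than gradients
teacher = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
student = torch.nn.Parameter(torch.randn(4), requires_grad=True)

with torch.no_grad():
    teacher.lerp_(student, 0.01)  # EMA-style update while frozen

teacher.requires_grad_(True)      # later: hand the parameter over to gradient descent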

Will be closed by #56

MWE

A dummy module is constructed where one parameter has requires_grad=True and the other does not. The parameter that does not require gradients is "updated" manually based on some dummy heuristic. Surprisingly, asserting that the model parameters of the closure-based and non-closure-based variants are equal after a small number of parameter updates throws an error.

from collections import defaultdict

import torch

from schedulefree import (
    AdamWScheduleFree,
    AdamWScheduleFreeClosure,
    RAdamScheduleFree,
    RAdamScheduleFreeClosure,
    SGDScheduleFree,
    SGDScheduleFreeClosure,
)


class SimpleModule(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.param_no_grad = torch.nn.Parameter(torch.ones(1), requires_grad=False)
        self.param_grad = torch.nn.Parameter(torch.ones(1), requires_grad=True)

    def forward(self, x: torch.Tensor):
        return x * self.param_grad + self.param_no_grad

for optim_closure_cls, optim_cls in (
    (AdamWScheduleFreeClosure, AdamWScheduleFree),
    (RAdamScheduleFreeClosure, RAdamScheduleFree),
    (SGDScheduleFreeClosure, SGDScheduleFree),
):
    print(f"Processing {optim_closure_cls.__name__, optim_cls.__name__}")
    ### Closure based Optimizer
    module = SimpleModule()
    optim_closure = optim_closure_cls(module.parameters())
    closure_weights = defaultdict(dict)
    def closure():
        optim_closure.zero_grad()
        loss = module(torch.ones(1))
        loss.backward()
        return loss


    # parameter fixed, "trained" e.g. by means of EMA or something else
    for i in range(3):
        optim_closure.step(closure)
        module.param_no_grad.copy_(module.param_no_grad * 0.5)
        for (name, param) in module.named_parameters():
            closure_weights[i][name] = torch.clone(param)

    # unfix parameter
    module.param_no_grad.requires_grad = True
    for i in range(3, 6):
        optim_closure.step(closure)
        for (name, param) in module.named_parameters():
            closure_weights[i][name] = torch.clone(param)
    print(f"grad: {module.param_grad.item():4f}, no_grad: {module.param_no_grad.item():4f}")


    ### Non-Closure based Optimizer
    module = SimpleModule()
    optim = optim_cls(module.parameters())
    non_closure_weights = defaultdict(dict)
    def closure():
        optim.zero_grad()
        loss = module(torch.ones(1))
        loss.backward()
        return loss


    # parameter fixed, "trained" e.g. by means of EMA or something else
    for i in range(3):
        optim.train()
        optim.step(closure)
        module.param_no_grad.copy_(module.param_no_grad * 0.5)
        optim.eval()
        for (name, param) in module.named_parameters():
            non_closure_weights[i][name] = torch.clone(param)

    # parameter unfixed
    module.param_no_grad.requires_grad = True
    for i in range(3, 6):
        optim.train()
        optim.step(closure)
        optim.eval()
        for (name, param) in module.named_parameters():
            non_closure_weights[i][name] = torch.clone(param)
    print(f"grad: {module.param_grad.item():4f}, no_grad: {module.param_no_grad.item():4f}")

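    # The closure-based and non-closure-based variants should track each other,
    # but the assertion below fails for param_no_grad with the closure optimizers.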
    for k, v in closure_weights.items():
        for (param_closure, param_non_closure) in zip(v.values(), non_closure_weights[k].values()):
            torch.testing.assert_close(param_closure, param_non_closure)
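
For completeness, one possible workaround (a sketch only, not the fix proposed in #56): give the closure-based optimizer only the parameters that currently require gradients, manage the frozen parameter by hand, and register it via add_param_group once it is unfrozen. This assumes the closure-based optimizers inherit add_param_group from torch.optim.Optimizer and initialize the history state z for newly added parameters on their next step.

import torch
from schedulefree import AdamWScheduleFreeClosure

module = SimpleModule()  # the module from the MWE above
# only register parameters that require gradients right now
optim_closure = AdamWScheduleFreeClosure(
    [p for p in module.parameters() if p.requires_grad]
)

# ... train, updating module.param_no_grad manually (EMA etc.) ...

# once the parameter should be trained by gradient descent, unfreeze it and
# hand it to the optimizer; z is then initialized from its current value
module.param_no_grad.requires_grad = True
optim_closure.add_param_group({"params": [module.param_no_grad]})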