
input_transform Normalize does not seem to work properly with condition_on_observations #1435

matthewcarbone opened this issue Oct 1, 2022 · 15 comments


matthewcarbone commented Oct 1, 2022

🐛 Bug

After running model.condition_on_observations(new_x, new_y), where the original model was instantiated with Normalize(d), that model fails during retraining. I believe this is a bug, but I'm honestly not sure.

To reproduce

Step 1: initialize dummy data

import botorch
import gpytorch
import numpy as np
import torch

from botorch.models.transforms.input import Normalize

np.random.seed(123)
torch.manual_seed(123)

# use regular spaced points on the interval [0, 1]
train_x = torch.linspace(0, 1, 15)

# training data needs to be explicitly multi-dimensional
train_x = train_x.unsqueeze(1)

# sample observed values and add some synthetic noise
train_y = torch.sin(train_x * (2 * np.pi)) + 0.15 * torch.randn_like(train_x)

Step 2: initialization/training, works just fine

model = botorch.models.SingleTaskGP(
    train_X=train_x, train_Y=train_y, input_transform=Normalize(1, transform_on_eval=True)
)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
botorch.fit.fit_gpytorch_mll(mll)

Step 3: condition

new_x = torch.FloatTensor(np.array([1.25, 1.5]).reshape(-1, 1))
new_y = torch.FloatTensor(np.array([-1.0, -2.0]).reshape(-1, 1))
model = model.condition_on_observations(new_x, new_y)

Step 4: attempt retraining to further tune hyperparameters/lengthscales and whatnot

mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
botorch.fit.fit_gpytorch_mll(mll)  # fails

Stack trace/error message

MDNotImplementedError                     Traceback (most recent call last)
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/utils/dispatcher.py:88, in Dispatcher.__call__(self, *args, **kwargs)
     87 try:
---> 88     return func(*args, **kwargs)
     89 except MDNotImplementedError:
     90     # Traverses registered methods in order, yields whenever a match is found

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/fit.py:320, in _fit_multioutput_independent(mll, _, __, sequential, **kwargs)
    315 if (  # incompatible models
    316     not sequential
    317     or mll.model.num_outputs == 1
    318     or mll.likelihood is not getattr(mll.model, "likelihood", None)
    319 ):
--> 320     raise MDNotImplementedError  # defer to generic
    322 # TODO: Unpacking of OutcomeTransforms not yet supported. Targets are often
    323 # pre-transformed in __init__, so try fitting with outcome_transform hidden

MDNotImplementedError: 

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Input In [20], in <cell line: 2>()
      1 mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
----> 2 botorch.fit.fit_gpytorch_mll(mll)

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/fit.py:114, in fit_gpytorch_mll(mll, optimizer, optimizer_kwargs, **kwargs)
    111 if optimizer is not None:  # defer to per-method defaults
    112     kwargs["optimizer"] = optimizer
--> 114 return dispatcher(
    115     mll,
    116     type(mll.likelihood),
    117     type(mll.model),
    118     optimizer_kwargs=optimizer_kwargs,
    119     **kwargs,
    120 )

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/utils/dispatcher.py:95, in Dispatcher.__call__(self, *args, **kwargs)
     93 for func in funcs:
     94     try:
---> 95         return func(*args, **kwargs)
     96     except MDNotImplementedError:
     97         pass

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/fit.py:240, in _fit_fallback(mll, _, __, optimizer, optimizer_kwargs, max_attempts, warning_filter, caught_exception_types, **ignore)
    238 with catch_warnings(record=True) as warning_list, debug(True):
    239     simplefilter("always", category=OptimizationWarning)
--> 240     mll, _ = optimizer(mll, **optimizer_kwargs)
    242 # Resolve warning messages and determine whether or not to retry
    243 done = True

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/fit.py:142, in fit_gpytorch_scipy(mll, bounds, method, options, track_iterations, approx_mll, scipy_objective, module_to_array_func, module_from_array_func)
    140 cb = store_iteration if track_iterations else None
    141 with gpt_settings.fast_computations(log_prob=approx_mll):
--> 142     res = minimize(
    143         scipy_objective,
    144         x0,
    145         args=(mll, property_dict),
    146         bounds=bounds,
    147         method=method,
    148         jac=True,
    149         options=options,
    150         callback=cb,
    151     )
    152     iterations = []
    153     if track_iterations:

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_minimize.py:692, in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
    689     res = _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,
    690                              **options)
    691 elif meth == 'l-bfgs-b':
--> 692     res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
    693                            callback=callback, **options)
    694 elif meth == 'tnc':
    695     res = _minimize_tnc(fun, x0, args, jac, bounds, callback=callback,
    696                         **options)

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_lbfgsb_py.py:308, in _minimize_lbfgsb(fun, x0, args, jac, bounds, disp, maxcor, ftol, gtol, eps, maxfun, maxiter, iprint, callback, maxls, finite_diff_rel_step, **unknown_options)
    305     else:
    306         iprint = disp
--> 308 sf = _prepare_scalar_function(fun, x0, jac=jac, args=args, epsilon=eps,
    309                               bounds=new_bounds,
    310                               finite_diff_rel_step=finite_diff_rel_step)
    312 func_and_grad = sf.fun_and_grad
    314 fortran_int = _lbfgsb.types.intvar.dtype

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_optimize.py:263, in _prepare_scalar_function(fun, x0, jac, args, bounds, epsilon, finite_diff_rel_step, hess)
    259     bounds = (-np.inf, np.inf)
    261 # ScalarFunction caches. Reuse of fun(x) during grad
    262 # calculation reduces overall function evaluations.
--> 263 sf = ScalarFunction(fun, x0, args, grad, hess,
    264                     finite_diff_rel_step, bounds, epsilon=epsilon)
    266 return sf

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:158, in ScalarFunction.__init__(self, fun, x0, args, grad, hess, finite_diff_rel_step, finite_diff_bounds, epsilon)
    155     self.f = fun_wrapped(self.x)
    157 self._update_fun_impl = update_fun
--> 158 self._update_fun()
    160 # Gradient evaluation
    161 if callable(grad):

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:251, in ScalarFunction._update_fun(self)
    249 def _update_fun(self):
    250     if not self.f_updated:
--> 251         self._update_fun_impl()
    252         self.f_updated = True

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:155, in ScalarFunction.__init__.<locals>.update_fun()
    154 def update_fun():
--> 155     self.f = fun_wrapped(self.x)

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:137, in ScalarFunction.__init__.<locals>.fun_wrapped(x)
    133 self.nfev += 1
    134 # Send a copy because the user may overwrite it.
    135 # Overwriting results in undefined behaviour because
    136 # fun(self.x) will change self.x, with the two no longer linked.
--> 137 fx = fun(np.copy(x), *args)
    138 # Make sure the function returns a true scalar
    139 if not np.isscalar(fx):

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_optimize.py:76, in MemoizeJac.__call__(self, x, *args)
     74 def __call__(self, x, *args):
     75     """ returns the the function value """
---> 76     self._compute_if_needed(x, *args)
     77     return self._value

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_optimize.py:70, in MemoizeJac._compute_if_needed(self, x, *args)
     68 if not np.all(x == self.x) or self._value is None or self.jac is None:
     69     self.x = np.asarray(x).copy()
---> 70     fg = self.fun(x, *args)
     71     self.jac = fg[1]
     72     self._value = fg[0]

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/utils.py:227, in _scipy_objective_and_grad(x, mll, property_dict)
    225     loss = -mll(*args).sum()
    226 except RuntimeError as e:
--> 227     return _handle_numerical_errors(error=e, x=x)
    228 loss.backward()
    230 i = 0

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/utils.py:256, in _handle_numerical_errors(error, x)
    250 if (
    251     isinstance(error, NanError)
    252     or "singular" in error_message  # old pytorch message
    253     or "input is not positive-definite" in error_message  # since pytorch #63864
    254 ):
    255     return float("nan"), np.full_like(x, "nan")
--> 256 raise error

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/utils.py:225, in _scipy_objective_and_grad(x, mll, property_dict)
    223     output = mll.model(*train_inputs)
    224     args = [output, train_targets] + _get_extra_mll_args(mll)
--> 225     loss = -mll(*args).sum()
    226 except RuntimeError as e:
    227     return _handle_numerical_errors(error=e, x=x)

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/gpytorch/module.py:30, in Module.__call__(self, *inputs, **kwargs)
     29 def __call__(self, *inputs, **kwargs):
---> 30     outputs = self.forward(*inputs, **kwargs)
     31     if isinstance(outputs, list):
     32         return [_validate_module_outputs(output) for output in outputs]

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py:64, in ExactMarginalLogLikelihood.forward(self, function_dist, target, *params)
     62 # Get the log prob of the marginal distribution
     63 output = self.likelihood(function_dist, *params)
---> 64 res = output.log_prob(target)
     65 res = self._add_other_terms(res, params)
     67 # Scale by the amount of data we have

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/gpytorch/distributions/multivariate_normal.py:147, in MultivariateNormal.log_prob(self, value)
    145 def log_prob(self, value):
    146     if settings.fast_computations.log_prob.off():
--> 147         return super().log_prob(value)
    149     if self._validate_args:
    150         self._validate_sample(value)

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/torch/distributions/multivariate_normal.py:211, in MultivariateNormal.log_prob(self, value)
    209 if self._validate_args:
    210     self._validate_sample(value)
--> 211 diff = value - self.loc
    212 M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
    213 half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)

RuntimeError: The size of tensor a (17) must match the size of tensor b (15) at non-singleton dimension 0

Expected Behavior

I think the second training procedure is supposed to work, right? It would seem sensible for Normalize to be updated with the new training data passed during conditioning.

System information

BoTorch version: 0.7.2
GPyTorch version: 1.9.0
Torch version: 1.12.0
Computer OS: macOS 12.5.1 (Mac M1 Max)

matthewcarbone added the bug label Oct 1, 2022

Balandat commented Oct 1, 2022

Hmm, interesting. It might be that some of the input transformations don't play well with the condition_on_observations call. @saitcakmak, you've probably got the best understanding of the input transforms; could you please take a look?


Balandat commented Oct 1, 2022

I assume this works fine if you don't use an input transform?

@matthewcarbone
Author

@Balandat that is correct. In fact, it actually works with the Standardize outcome transform.
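
For reference, the variant I mean (a sketch, assuming the standard import path for Standardize):

from botorch.models.transforms.outcome import Standardize

# Outcome transform only: with this model, conditioning followed by
# retraining runs without the size-mismatch error above.
model = botorch.models.SingleTaskGP(
    train_X=train_x, train_Y=train_y, outcome_transform=Standardize(m=1)
)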

@saitcakmak
Contributor

The issue seems to be that model._original_train_inputs is a 15 x 1-dim tensor (even after conditioning), so when we call mll.train(), model.train_inputs gets reverted back to model._original_train_inputs, losing the X's we just conditioned on.

This will be fixed in #1372, as it gets rid of _original_train_inputs and all the related hacks. In the meantime, you can achieve the warm-starting behavior by creating a new model and loading the hyperparameters using new_model.load_state_dict(old_model.state_dict()).
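
Something like this (an untested sketch, reusing train_x, train_y, new_x, and new_y from the repro above):

old_model = model  # the model fitted in step 2 (skip the condition_on_observations call)
full_x = torch.cat([train_x, new_x])
full_y = torch.cat([train_y, new_y])
new_model = botorch.models.SingleTaskGP(
    train_X=full_x,
    train_Y=full_y,
    input_transform=Normalize(1, transform_on_eval=True),
)
# Warm-start: copy hyperparameters (lengthscale, noise, etc.) from the old model
new_model.load_state_dict(old_model.state_dict())
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=new_model.likelihood, model=new_model)
botorch.fit.fit_gpytorch_mll(mll)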

@matthewcarbone
Author

@saitcakmak thank you, I'll give your hack a try for now!


matthewcarbone commented Oct 5, 2022

@saitcakmak so just to be clear as to what you mean here: the idea would be to create an entirely new model with all of the same required objects (e.g. the transforms) but with the new training data (in my case a 17 x 1 tensor), and then to load state from the old model, setting the hyperparameters properly. Am I interpreting all of this correctly?

Thanks!

Edit: never mind, what I just asked about does appear to work regardless!

@saitcakmak
Contributor

Yep, just repeat

model = botorch.models.SingleTaskGP(
    train_X=train_x, train_Y=train_y, input_transform=Normalize(1, transform_on_eval=True)
)

with the updated training data and use the bit from the above comment to transfer over the hyperparameters.

@matthewcarbone
Author

@saitcakmak the problem, though, is that when initializing as you show above, the correct parameters are set for the transforms but not the lengthscales. When you then call load_state_dict, it sets the old parameters for the input and outcome transforms (i.e., it sets the right lengthscales but the wrong transform parameters). This also seems to mess something up with the posterior. Even manually setting model.input_transform._buffer to the old ordered dict from the previous model doesn't work properly. The only way I can seem to fix this is by retraining.

Not entirely sure how to fix this on my end, seems pretty complicated since the posterior method somehow also contains the transforms. I'm not following it so well.

TL;DR: can I set the model lengthscales without overwriting the transform parameters, and without re-calling train?

Also, an aside: does calling fit_gpytorch_mll on an already-trained model use the existing lengthscales and whatnot as initial guesses, thus perhaps speeding up the training?

@saitcakmak
Contributor

If you do

new_model = botorch.models.SingleTaskGP(
    train_X=train_x, train_Y=train_y, input_transform=Normalize(1, transform_on_eval=True)
)
new_model.load_state_dict(old_model.state_dict())

it will update all buffers & parameters of the GP and its submodules with the corresponding values from old_model. If you then train this model with fit_gpytorch_mll, it should use these as the starting values for the training. Since the model gets called in train mode during training, the input transform buffers (e.g. for Normalize) should also get updated. As you noticed, this doesn't work with the outcome transforms, since they're only called in train mode while initializing the model. One way to get around this is to exclude the outcome transforms from the state dict loading.

new_model.load_state_dict(
    {k: v for k, v in old_model.state_dict().items() if "outcome_transform" not in k},
    strict=False,  # needed since the state dict is now missing certain keys
)

can I set the model length scales without overwriting the transform parameters, without recalling train?

The above bit should work for this. Though, if you don't want to retrain / call train on the model, the original bug should not be an issue, so you can also just use condition_on_observations. That bug will only happen if you call model.train() after condition_on_observations.
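
For example (an untested sketch, starting from the fitted model of step 2, before the reassignment in step 3):

model.eval()  # prediction/conditioning only; never re-enter train mode
conditioned = model.condition_on_observations(new_x, new_y)
test_x = torch.linspace(0, 2, 25).unsqueeze(1)
posterior = conditioned.posterior(test_x)
print(posterior.mean.shape)  # e.g. torch.Size([25, 1])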

Also, an aside: does calling fit_gpytorch_mll on an already-trained model use the existing lengthscales and whatnot as initial guesses, thus perhaps speeding up the training?

Yes, for the first model training attempt, it should use the model parameters as the initial values. If the first attempt fails and it has to retry, then it will randomly sample from the priors. Intuitively, this should speed things up.
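
A quick way to see this (a sketch; the attribute path assumes the default SingleTaskGP ScaleKernel/Matern setup):

print(new_model.covar_module.base_kernel.lengthscale)  # warm-started value
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=new_model.likelihood, model=new_model)
botorch.fit.fit_gpytorch_mll(mll)  # the first attempt starts from the current values
print(new_model.covar_module.base_kernel.lengthscale)  # refined value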


matthewcarbone commented Oct 11, 2022

@saitcakmak Your suggestion does work, but I'd like to bring another set of weird behaviors to your attention:

print(train_x)
"""
tensor([[0.0000],
        [0.7143],
        [1.4286],
        [2.1429],
        [2.8571]])
"""

print(train_y)
"""
tensor([-1.6720, 45.1938, 72.6386, 93.8865, 79.5389])
"""

model = botorch.models.SingleTaskGP(
    train_X=train_x,
    train_Y=train_y,
    input_transform=Normalize(1),
    outcome_transform=Standardize(1)
)

model.train_inputs[0]
"""
tensor([[0.0000],
        [0.7143],
        [1.4286],
        [2.1429],
        [2.8571]])
"""

model.train_targets
"""
tensor([-1.5798, -0.3373,  0.3903,  0.9536,  0.5732])
"""

For some reason, the inputs are not being scaled immediately like the targets. This is making retrieving the current model's training data from the model itself really challenging, especially after reconditioning. Any thoughts?

EDIT: I think I've nailed down the source of this. It looks like the behavior of the training inputs changes depending on whether the model is in train() or eval() mode, whereas the targets are mode-independent! Is this intended?

@saitcakmak
Contributor

I think I've nailed down the source of this. It looks like the behavior of the training inputs changes depending on whether the model is in train() or eval() mode, whereas the targets are mode-independent! Is this intended?

Looks like you figured it out. We transform the outcomes before passing them to GPyTorch, so train_targets will always show the transformed outcomes. For train_inputs, the story is a bit more complicated: currently, it will be untransformed if the model is in train mode, and transformed if the model is in eval mode (with the originals available at _original_train_inputs). This was a hack to make sure we can apply some of the input transforms a bit more selectively to train / test inputs. I'll clean all of that up in #1372 (train_inputs will always be untransformed, and input transforms will be handled in GPyTorch), though I have some other things I need to work on before I get to that.
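
A minimal sketch of that mode dependence (continuing from the Normalize + Standardize model above; behavior as of BoTorch 0.7.x):

model.train()
print(model.train_inputs[0])         # untransformed (original scale) in train mode

model.eval()
print(model.train_inputs[0])         # normalized to [0, 1] in eval mode
print(model._original_train_inputs)  # originals kept around by the current hack
print(model.train_targets)           # always the transformed (standardized) targets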

@matthewcarbone
Author

Ahh, so that's what _original_train_inputs is for. I never noticed that it doesn't change during transforms. OK, all of this sounds good. Looking forward to the BoTorch + GPyTorch update!

@matthewcarbone
Author

@saitcakmak funny that I come back to this exactly a year later. Have there been any updates on merging #1372? 😁

Looks like the GPyTorch side of the changes was ready to merge but never went in. Thanks!

@esantorella
Member

Hi @matthewcarbone , refactoring input transforms is still in progress. #1372 can't be merged in as-is since there are a good number of other changes that need to be made to ensure that such a large change works smoothly.

@matthewcarbone
Author

@esantorella ok no problem. I'm happy to help contribute somehow if there's any opportunity. Thanks! 👍
