From a3acad4b331aa194dcc2e8ccdee3b730a457aca0 Mon Sep 17 00:00:00 2001
From: Sam Lishak
Date: Fri, 15 Nov 2024 15:52:08 +0000
Subject: [PATCH 1/7] Add missing properties to `PosteriorMeanModel`

---
 botorch/models/deterministic.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/botorch/models/deterministic.py b/botorch/models/deterministic.py
index 31de8d29df..7ec568caed 100644
--- a/botorch/models/deterministic.py
+++ b/botorch/models/deterministic.py
@@ -162,6 +162,16 @@ def __init__(self, model: Model) -> None:
     def forward(self, X: Tensor) -> Tensor:
         return self.model.posterior(X).mean
 
+    @property
+    def num_outputs(self) -> int:
+        r"""The number of outputs of the model."""
+        return self.model.num_outputs
+
+    @property
+    def batch_shape(self) -> torch.Size:
+        r"""The batch shape of the model."""
+        return self.model.batch_shape
+
 
 class FixedSingleSampleModel(DeterministicModel):
     r"""

From 06f76bda446c9bc7db00296982201ac91fc56070 Mon Sep 17 00:00:00 2001
From: Sam Lishak
Date: Wed, 20 Nov 2024 11:19:14 +0000
Subject: [PATCH 2/7] Add qPM and qPSTD acquisition functions, update docstrings

---
 botorch/acquisition/monte_carlo.py | 142 +++++++++++++++++++++++++++--
 1 file changed, 136 insertions(+), 6 deletions(-)

diff --git a/botorch/acquisition/monte_carlo.py b/botorch/acquisition/monte_carlo.py
index 17cf53fd14..c59087a9e0 100644
--- a/botorch/acquisition/monte_carlo.py
+++ b/botorch/acquisition/monte_carlo.py
@@ -747,7 +747,7 @@ class qSimpleRegret(SampleReducingMCAcquisitionFunction):
     non-negative. `qSimpleRegret` acquisition values can be negative, so we instead use
     a `ConstrainedMCObjective` which applies constraints to the objectives (e.g. before
     computing the acquisition function) and shifts negative objective values using
-    by an infeasible cost to ensure non-negativity (before applying constraints and
+    an infeasible cost to ensure non-negativity (before applying constraints and
     shifting them back).
 
     Example:
@@ -813,11 +813,11 @@ class qUpperConfidenceBound(SampleReducingMCAcquisitionFunction):
     `SampleReducingMCAcquisitionFunction` computes the acquisition values on the sample
     level and then weights the sample-level acquisition values by a soft feasibility
     indicator. Hence, it expects non-log acquisition function values to be
-    non-negative. `qSimpleRegret` acquisition values can be negative, so we instead use
-    a `ConstrainedMCObjective` which applies constraints to the objectives (e.g. before
-    computing the acquisition function) and shifts negative objective values using
-    by an infeasible cost to ensure non-negativity (before applying constraints and
-    shifting them back).
+    non-negative. `qUpperConfidenceBound` acquisition values can be negative, so we
+    instead use a `ConstrainedMCObjective` which applies constraints to the objectives
+    (e.g. before computing the acquisition function) and shifts negative objective
+    values using an infeasible cost to ensure non-negativity (before applying
+    constraints and shifting them back).
 
     Example:
     >>> model = SingleTaskGP(train_X, train_Y)
@@ -887,3 +887,133 @@ class qLowerConfidenceBound(qUpperConfidenceBound):
     def _get_beta_prime(self, beta: float) -> float:
         """Multiply beta prime by -1 to get the lower confidence bound."""
         return -super()._get_beta_prime(beta=beta)
+
+
+class qPosteriorMean(SampleReducingMCAcquisitionFunction):
+    r"""MC-based batch Posterior Mean.
+
+    Constraints should be provided as a `ConstrainedMCObjective`.
+    Passing `constraints` as an argument is not supported. This is because
+    `SampleReducingMCAcquisitionFunction` computes the acquisition values on the sample
+    level and then weights the sample-level acquisition values by a soft feasibility
+    indicator. Hence, it expects non-log acquisition function values to be
+    non-negative. `qPosteriorMean` acquisition values can be negative, so we instead use
+    a `ConstrainedMCObjective` which applies constraints to the objectives (e.g. before
+    computing the acquisition function) and shifts negative objective values using
+    an infeasible cost to ensure non-negativity (before applying constraints and
+    shifting them back).
+
+    Example:
+        >>> model = SingleTaskGP(train_X, train_Y)
+        >>> sampler = SobolQMCNormalSampler(1024)
+        >>> qPM = qPosteriorMean(model, sampler)
+        >>> qpm = qPM(test_X)
+    """
+
+    def __init__(
+        self,
+        model: Model,
+        sampler: MCSampler | None = None,
+        objective: MCAcquisitionObjective | None = None,
+        posterior_transform: PosteriorTransform | None = None,
+        X_pending: Tensor | None = None,
+    ) -> None:
+        r"""q-Posterior Mean.
+
+        Args:
+            model: A fitted model.
+            sampler: The sampler used to draw base samples. See `MCAcquisitionFunction`
+                more details.
+            objective: The MCAcquisitionObjective under which the samples are
+                evaluated. Defaults to `IdentityMCObjective()`.
+            posterior_transform: A PosteriorTransform (optional).
+            X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points that have
+                points that have been submitted for function evaluation but have not yet
+                been evaluated. Concatenated into X upon forward call. Copied and set to
+                have no gradient.
+        """
+        super().__init__(
+            model=model,
+            sampler=sampler,
+            objective=objective,
+            posterior_transform=posterior_transform,
+            X_pending=X_pending,
+        )
+
+    def _sample_forward(self, obj: Tensor) -> Tensor:
+        r"""Evaluate qPosteriorMean per sample on the candidate set `X`.
+
+        Args:
+            obj: A `sample_shape x batch_shape x q`-dim Tensor of MC objective values.
+
+        Returns:
+            A `sample_shape x batch_shape x q`-dim Tensor of acquisition values.
+        """
+        mean = obj.mean(dim=0, keepdim=True).broadcast_to(obj.shape)
+        return mean
+
+
+class qPosteriorStandardDeviation(SampleReducingMCAcquisitionFunction):
+    r"""MC-based batch Posterior Standard Deviation.
+
+    An acquisition function for pure exploration.
+
+    Example:
+        >>> model = SingleTaskGP(train_X, train_Y)
+        >>> sampler = SobolQMCNormalSampler(1024)
+        >>> qPSTD = qPosteriorStandardDeviation(model, sampler)
+        >>> std = qPSTD(test_X)
+    """
+
+    def __init__(
+        self,
+        model: Model,
+        sampler: MCSampler | None = None,
+        objective: MCAcquisitionObjective | None = None,
+        posterior_transform: PosteriorTransform | None = None,
+        X_pending: Tensor | None = None,
+        constraints: list[Callable[[Tensor], Tensor]] | None = None,
+        eta: Tensor | float = 1e-3,
+    ) -> None:
+        r"""q-Posterior Mean.
+
+        Args:
+            model: A fitted model.
+            sampler: The sampler used to draw base samples. See `MCAcquisitionFunction`
+                more details.
+            objective: The MCAcquisitionObjective under which the samples are
+                evaluated. Defaults to `IdentityMCObjective()`.
+            posterior_transform: A PosteriorTransform (optional).
+            X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points that have
+                points that have been submitted for function evaluation but have not yet
+                been evaluated. Concatenated into X upon forward call. Copied and set to
+                have no gradient.
+            constraints: A list of constraint callables which map a Tensor of posterior
+                samples of dimension `sample_shape x batch-shape x q x m`-dim to a
+                `sample_shape x batch-shape x q`-dim Tensor. The associated constraints
+                are considered satisfied if the output is less than zero.
+            eta: Temperature parameter(s) governing the smoothness of the sigmoid
+                approximation to the constraint indicators. For more details, on this
+                parameter, see the docs of `compute_smoothed_feasibility_indicator`.
+        """
+        super().__init__(
+            model=model,
+            sampler=sampler,
+            objective=objective,
+            posterior_transform=posterior_transform,
+            X_pending=X_pending,
+            constraints=constraints,
+            eta=eta,
+        )
+
+    def _sample_forward(self, obj: Tensor) -> Tensor:
+        r"""Evaluate qPosteriorStandardDeviation per sample on the candidate set `X`.
+
+        Args:
+            obj: A `sample_shape x batch_shape x q`-dim Tensor of MC objective values.
+
+        Returns:
+            A `sample_shape x batch_shape x q`-dim Tensor of acquisition values.
+        """
+        mean = obj.mean(dim=0)
+        return (obj - mean).abs()

From fc8e40f040f7db9a6a8d3a5f81c0495e124185cc Mon Sep 17 00:00:00 2001
From: Sam Lishak
Date: Wed, 20 Nov 2024 16:53:05 +0000
Subject: [PATCH 3/7] Remove qPosteriorMean

---
 botorch/acquisition/monte_carlo.py | 66 +-----------------------------
 1 file changed, 1 insertion(+), 65 deletions(-)

diff --git a/botorch/acquisition/monte_carlo.py b/botorch/acquisition/monte_carlo.py
index c59087a9e0..f9a946f7c9 100644
--- a/botorch/acquisition/monte_carlo.py
+++ b/botorch/acquisition/monte_carlo.py
@@ -889,70 +889,6 @@ def _get_beta_prime(self, beta: float) -> float:
         return -super()._get_beta_prime(beta=beta)
 
 
-class qPosteriorMean(SampleReducingMCAcquisitionFunction):
-    r"""MC-based batch Posterior Mean.
-
-    Constraints should be provided as a `ConstrainedMCObjective`.
-    Passing `constraints` as an argument is not supported. This is because
-    `SampleReducingMCAcquisitionFunction` computes the acquisition values on the sample
-    level and then weights the sample-level acquisition values by a soft feasibility
-    indicator. Hence, it expects non-log acquisition function values to be
-    non-negative. `qPosteriorMean` acquisition values can be negative, so we instead use
-    a `ConstrainedMCObjective` which applies constraints to the objectives (e.g. before
-    computing the acquisition function) and shifts negative objective values using
-    an infeasible cost to ensure non-negativity (before applying constraints and
-    shifting them back).
-
-    Example:
-        >>> model = SingleTaskGP(train_X, train_Y)
-        >>> sampler = SobolQMCNormalSampler(1024)
-        >>> qPM = qPosteriorMean(model, sampler)
-        >>> qpm = qPM(test_X)
-    """
-
-    def __init__(
-        self,
-        model: Model,
-        sampler: MCSampler | None = None,
-        objective: MCAcquisitionObjective | None = None,
-        posterior_transform: PosteriorTransform | None = None,
-        X_pending: Tensor | None = None,
-    ) -> None:
-        r"""q-Posterior Mean.
-
-        Args:
-            model: A fitted model.
-            sampler: The sampler used to draw base samples. See `MCAcquisitionFunction`
-                more details.
-            objective: The MCAcquisitionObjective under which the samples are
-                evaluated. Defaults to `IdentityMCObjective()`.
-            posterior_transform: A PosteriorTransform (optional).
-            X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points that have
-                points that have been submitted for function evaluation but have not yet
-                been evaluated. Concatenated into X upon forward call. Copied and set to
-                have no gradient.
-        """
-        super().__init__(
-            model=model,
-            sampler=sampler,
-            objective=objective,
-            posterior_transform=posterior_transform,
-            X_pending=X_pending,
-        )
-
-    def _sample_forward(self, obj: Tensor) -> Tensor:
-        r"""Evaluate qPosteriorMean per sample on the candidate set `X`.
-
-        Args:
-            obj: A `sample_shape x batch_shape x q`-dim Tensor of MC objective values.
-
-        Returns:
-            A `sample_shape x batch_shape x q`-dim Tensor of acquisition values.
-        """
-        mean = obj.mean(dim=0, keepdim=True).broadcast_to(obj.shape)
-        return mean
-
-
 class qPosteriorStandardDeviation(SampleReducingMCAcquisitionFunction):
     r"""MC-based batch Posterior Standard Deviation.
 
@@ -975,7 +911,7 @@ def __init__(
         constraints: list[Callable[[Tensor], Tensor]] | None = None,
         eta: Tensor | float = 1e-3,
     ) -> None:
-        r"""q-Posterior Mean.
+        r"""q-Posterior Standard Deviation.
 
         Args:
             model: A fitted model.

From 135eba22d396889af31b91464cd0c5396d0af94e Mon Sep 17 00:00:00 2001
From: Sam Lishak
Date: Fri, 29 Nov 2024 17:09:50 +0000
Subject: [PATCH 4/7] Update acquisition __init__

---
 botorch/acquisition/__init__.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/botorch/acquisition/__init__.py b/botorch/acquisition/__init__.py
index 87c6030183..862897fe11 100644
--- a/botorch/acquisition/__init__.py
+++ b/botorch/acquisition/__init__.py
@@ -55,7 +55,9 @@
 from botorch.acquisition.monte_carlo import (
     MCAcquisitionFunction,
     qExpectedImprovement,
+    qLowerConfidenceBound,
     qNoisyExpectedImprovement,
+    qPosteriorStandardDeviation,
     qProbabilityOfImprovement,
     qSimpleRegret,
     qUpperConfidenceBound,
@@ -120,6 +122,8 @@
     "qNegIntegratedPosteriorVariance",
     "qProbabilityOfImprovement",
     "qSimpleRegret",
+    "qPosteriorStandardDeviation",
+    "qLowerConfidenceBound",
     "qUpperConfidenceBound",
     "ConstrainedMCObjective",
     "GenericMCObjective",

From 4b526d7b7c20f8f930602421452c75f487a3ed2f Mon Sep 17 00:00:00 2001
From: Sam Lishak
Date: Fri, 29 Nov 2024 17:52:18 +0000
Subject: [PATCH 5/7] Correct scaling of qPSTD

---
 botorch/acquisition/monte_carlo.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/botorch/acquisition/monte_carlo.py b/botorch/acquisition/monte_carlo.py
index f9a946f7c9..9e39fc6db4 100644
--- a/botorch/acquisition/monte_carlo.py
+++ b/botorch/acquisition/monte_carlo.py
@@ -941,6 +941,7 @@ def __init__(
             constraints=constraints,
             eta=eta,
         )
+        self._scale = math.sqrt(math.pi / 2)
 
     def _sample_forward(self, obj: Tensor) -> Tensor:
         r"""Evaluate qPosteriorStandardDeviation per sample on the candidate set `X`.
@@ -952,4 +953,4 @@ def _sample_forward(self, obj: Tensor) -> Tensor:
             A `sample_shape x batch_shape x q`-dim Tensor of acquisition values.
         """
         mean = obj.mean(dim=0)
-        return (obj - mean).abs()
+        return (obj - mean).abs() * self._scale

From 1caf62ca4d967cfb51f0e05759d3985f814ff90d Mon Sep 17 00:00:00 2001
From: Sam Lishak
Date: Tue, 3 Dec 2024 17:09:33 +0000
Subject: [PATCH 6/7] Added unit test for qPSTD

---
 test/acquisition/test_monte_carlo.py | 101 +++++++++++++++++++++++++++
 1 file changed, 101 insertions(+)

diff --git a/test/acquisition/test_monte_carlo.py b/test/acquisition/test_monte_carlo.py
index c3c9810fc6..0262fd04be 100644
--- a/test/acquisition/test_monte_carlo.py
+++ b/test/acquisition/test_monte_carlo.py
@@ -20,6 +20,7 @@
     qExpectedImprovement,
     qLowerConfidenceBound,
     qNoisyExpectedImprovement,
+    qPosteriorStandardDeviation,
     qProbabilityOfImprovement,
     qSimpleRegret,
     qUpperConfidenceBound,
@@ -1009,6 +1010,105 @@ def test_beta_prime(self):
         super().test_beta_prime(negate=True)
 
 
+class TestQPosteriorStandardDeviation(BotorchTestCase):
+    def test_q_pstd(self):
+        for dtype in (torch.float, torch.double):
+            # the event shape is `b x q x t` = 1 x 1 x 1
+            samples = torch.zeros(1, 1, 1, device=self.device, dtype=dtype)
+            mm = MockModel(MockPosterior(samples=samples))
+            # X is `q x d` = 1 x 1. X is a dummy and unused b/c of mocking
+            X = torch.zeros(1, 1, device=self.device, dtype=dtype)
+
+            # basic test
+            sampler = IIDNormalSampler(sample_shape=torch.Size([2]))
+            acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
+            res = acqf(X)
+            self.assertEqual(res.item(), 0.0)
+
+            # basic test
+            sampler = IIDNormalSampler(sample_shape=torch.Size([2]), seed=12345)
+            acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
+            res = acqf(X)
+            self.assertEqual(res.item(), 0.0)
+            self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 1, 1]))
+            bs = acqf.sampler.base_samples.clone()
+            res = acqf(X)
+            self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))
+
+            # basic test, qmc
+            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2]))
+            acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
+            res = acqf(X)
+            self.assertEqual(res.item(), 0.0)
+            self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 1, 1]))
+            bs = acqf.sampler.base_samples.clone()
+            acqf(X)
+            self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))
+
+            # basic test for X_pending and warning
+            acqf.set_X_pending()
+            self.assertIsNone(acqf.X_pending)
+            acqf.set_X_pending(None)
+            self.assertIsNone(acqf.X_pending)
+            acqf.set_X_pending(X)
+            self.assertEqual(acqf.X_pending, X)
+            mm._posterior._samples = mm._posterior._samples.expand(1, 2, 1)
+            res = acqf(X)
+            X2 = torch.zeros(
+                1, 1, 1, device=self.device, dtype=dtype, requires_grad=True
+            )
+            with warnings.catch_warnings(record=True) as ws:
+                acqf.set_X_pending(X2)
+            self.assertEqual(acqf.X_pending, X2)
+            self.assertEqual(sum(issubclass(w.category, BotorchWarning) for w in ws), 1)
+
+    def test_q_pstd_batch(self):
+        # the event shape is `b x q x t` = 2 x 2 x 1
+        for dtype in (torch.float, torch.double):
+            samples = torch.zeros(2, 2, 1, device=self.device, dtype=dtype)
+            samples[0, 0, 0] = 1.0
+            mm = MockModel(MockPosterior(samples=samples))
+            # X is a dummy and unused b/c of mocking
+            X = torch.zeros(2, 2, 1, device=self.device, dtype=dtype)
+
+            # test batch mode
+            sampler = IIDNormalSampler(sample_shape=torch.Size([8]))
+            acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
+            res = acqf(X)
+            self.assertEqual(res[0].item(), 0.0)
+            self.assertEqual(res[1].item(), 0.0)
+
+            # test batch mode
+            sampler = IIDNormalSampler(sample_shape=torch.Size([2]), seed=12345)
+            acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
+            res = acqf(X)  # 1-dim batch
+            self.assertEqual(res[0].item(), 0.0)
+            self.assertEqual(res[1].item(), 0.0)
+            self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 2, 1]))
+            bs = acqf.sampler.base_samples.clone()
+            acqf(X)
+            self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))
+            res = acqf(X.expand(2, -1, 1))  # 2-dim batch
+            self.assertEqual(res[0].item(), 0.0)
+            self.assertEqual(res[1].item(), 0.0)
+            # the base samples should have the batch dim collapsed
+            self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 2, 1]))
+            bs = acqf.sampler.base_samples.clone()
+            acqf(X.expand(2, -1, 1))
+            self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))
+
+            # test batch mode, qmc
+            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2]))
+            acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
+            res = acqf(X)
+            self.assertEqual(res[0].item(), 0.0)
+            self.assertEqual(res[1].item(), 0.0)
+            self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 2, 1]))
+            bs = acqf.sampler.base_samples.clone()
+            acqf(X)
+            self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))
+
+
 class TestMCAcquisitionFunctionWithConstraints(BotorchTestCase):
     def test_mc_acquisition_function_with_constraints(self):
         for dtype in (torch.float, torch.double):
@@ -1033,6 +1133,7 @@ def _test_mc_acquisition_function_with_constraints(self, dtype: torch.dtype):
             # cache_root=True not supported by MockModel, see test_cache_root
             partial(qNoisyExpectedImprovement, cache_root=False, **nei_args),
             partial(qNoisyExpectedImprovement, cache_root=True, **nei_args),
+            partial(qPosteriorStandardDeviation, model=mm),
         ]:
             acqf = acqf_constructor()
             mm._posterior._samples = (

From 87e8fe218967927b3340fb9a287ffb7ca2c7ad69 Mon Sep 17 00:00:00 2001
From: Sam Lishak
Date: Tue, 7 Jan 2025 11:59:50 +0000
Subject: [PATCH 7/7] Update qPSTD tests, allow MockPosterior to return non-expanded raw samples

---
 botorch/utils/testing.py             |  6 +++++-
 test/acquisition/test_monte_carlo.py | 32 ++++++++++++++++++----------
 2 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/botorch/utils/testing.py b/botorch/utils/testing.py
index effc04276b..ba83761c32 100644
--- a/botorch/utils/testing.py
+++ b/botorch/utils/testing.py
@@ -347,7 +347,11 @@ def rsample(
         do a shape check but return the same mock samples."""
         if sample_shape is None:
             sample_shape = torch.Size()
-        return self._samples.expand(sample_shape + self._samples.shape)
+        extended_shape = self._extended_shape(sample_shape)
+        if self._samples.shape == extended_shape:
+            return self._samples
+        else:
+            return self._samples.expand(extended_shape)
 
     def rsample_from_base_samples(
         self,
diff --git a/test/acquisition/test_monte_carlo.py b/test/acquisition/test_monte_carlo.py
index 0262fd04be..296b9ec912 100644
--- a/test/acquisition/test_monte_carlo.py
+++ b/test/acquisition/test_monte_carlo.py
@@ -1012,35 +1012,44 @@ def test_beta_prime(self):
 
 class TestQPosteriorStandardDeviation(BotorchTestCase):
     def test_q_pstd(self):
+        n_samples = 128
         for dtype in (torch.float, torch.double):
             # the event shape is `b x q x t` = 1 x 1 x 1
-            samples = torch.zeros(1, 1, 1, device=self.device, dtype=dtype)
-            mm = MockModel(MockPosterior(samples=samples))
+            torch.manual_seed(0)
+            samples = torch.randn(n_samples, 1, 1, 1, device=self.device, dtype=dtype)
+            std = samples.std(dim=0, correction=0).item()
+            mm = MockModel(
+                MockPosterior(samples=samples, base_shape=torch.Size([1, 1, 1]))
+            )
             # X is `q x d` = 1 x 1. X is a dummy and unused b/c of mocking
             X = torch.zeros(1, 1, device=self.device, dtype=dtype)
 
             # basic test
-            sampler = IIDNormalSampler(sample_shape=torch.Size([2]))
+            sampler = IIDNormalSampler(sample_shape=torch.Size([n_samples]))
             acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
             res = acqf(X)
-            self.assertEqual(res.item(), 0.0)
+            self.assertAllClose(res.item(), std, rtol=0.02, atol=0)
 
             # basic test
-            sampler = IIDNormalSampler(sample_shape=torch.Size([2]), seed=12345)
+            sampler = IIDNormalSampler(sample_shape=torch.Size([n_samples]), seed=12345)
             acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
             res = acqf(X)
-            self.assertEqual(res.item(), 0.0)
-            self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 1, 1]))
+            self.assertAllClose(res.item(), std, rtol=0.02, atol=0)
+            self.assertEqual(
+                acqf.sampler.base_samples.shape, torch.Size([n_samples, 1, 1, 1])
+            )
             bs = acqf.sampler.base_samples.clone()
             res = acqf(X)
             self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))
 
             # basic test, qmc
-            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2]))
+            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([n_samples]))
             acqf = qPosteriorStandardDeviation(model=mm, sampler=sampler)
             res = acqf(X)
-            self.assertEqual(res.item(), 0.0)
-            self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 1, 1]))
+            self.assertAllClose(res.item(), std, rtol=0.02, atol=0)
+            self.assertEqual(
+                acqf.sampler.base_samples.shape, torch.Size([n_samples, 1, 1, 1])
+            )
             bs = acqf.sampler.base_samples.clone()
             acqf(X)
             self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))
@@ -1052,7 +1061,8 @@ def test_q_pstd(self):
             self.assertIsNone(acqf.X_pending)
             acqf.set_X_pending(X)
             self.assertEqual(acqf.X_pending, X)
-            mm._posterior._samples = mm._posterior._samples.expand(1, 2, 1)
+            mm._posterior._base_shape = torch.Size([1, 2, 1])
+            mm._posterior._samples = mm._posterior._samples.expand(n_samples, 1, 2, 1)
             res = acqf(X)
             X2 = torch.zeros(
                 1, 1, 1, device=self.device, dtype=dtype, requires_grad=True