From b8e8fec28805d7358bafe0eb3977cccf9ddbb2f8 Mon Sep 17 00:00:00 2001 From: Daniel Jiang Date: Fri, 27 May 2022 04:02:24 -0700 Subject: [PATCH] X_pending for FixedFeatureAcquisition Summary: Handle X_pending properly in FixedFeatureAcquisition Reviewed By: Balandat Differential Revision: D36466941 fbshipit-source-id: 5d92df9769af14357aa6528c376dac32d301e052 --- botorch/acquisition/fixed_feature.py | 21 ++++++++++++++++- test/acquisition/test_fixed_feature.py | 32 +++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/botorch/acquisition/fixed_feature.py b/botorch/acquisition/fixed_feature.py index dd95710fc2..13e32987e7 100644 --- a/botorch/acquisition/fixed_feature.py +++ b/botorch/acquisition/fixed_feature.py @@ -12,7 +12,7 @@ from __future__ import annotations from numbers import Number -from typing import List, Sequence, Union +from typing import List, Optional, Sequence, Union import torch from botorch.acquisition.acquisition import AcquisitionFunction @@ -124,6 +124,25 @@ def forward(self, X: Tensor): X_full = self._construct_X_full(X) return self.acq_func(X_full) + @property + def X_pending(self): + r"""Return the `X_pending` of the base acquisition function.""" + try: + return self.acq_func.X_pending + except (ValueError, AttributeError): + raise ValueError( + f"Base acquisition function {type(self.acq_func).__name__} " + "does not have an `X_pending` attribute." + ) + + @X_pending.setter + def X_pending(self, X_pending: Optional[Tensor]): + r"""Sets the `X_pending` of the base acquisition function.""" + if X_pending is not None: + self.acq_func.X_pending = self._construct_X_full(X_pending) + else: + self.acq_func.X_pending = X_pending + def _construct_X_full(self, X: Tensor) -> Tensor: r"""Constructs the full input for the base acquisition function. diff --git a/test/acquisition/test_fixed_feature.py b/test/acquisition/test_fixed_feature.py index e0d231f074..398f86370b 100644 --- a/test/acquisition/test_fixed_feature.py +++ b/test/acquisition/test_fixed_feature.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import torch +from botorch.acquisition.analytic import ExpectedImprovement from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction from botorch.acquisition.monte_carlo import qExpectedImprovement from botorch.models import SingleTaskGP @@ -16,8 +17,9 @@ def test_fixed_features(self): train_X = torch.rand(5, 3, device=self.device) train_Y = train_X.norm(dim=-1, keepdim=True) model = SingleTaskGP(train_X, train_Y).to(device=self.device).eval() - qEI = qExpectedImprovement(model, best_f=0.0) for q in [1, 2]: + qEI = qExpectedImprovement(model, best_f=0.0) + # test single point test_X = torch.rand(q, 3, device=self.device) qEI_ff = FixedFeatureAcquisitionFunction( @@ -63,6 +65,25 @@ def test_fixed_features(self): qei_ff = qEI_ff(test_X[..., [1]]) self.assertTrue(torch.allclose(qei, qei_ff)) + # test X_pending + X_pending = torch.rand(2, 3, device=self.device) + qEI.set_X_pending(X_pending) + qEI_ff = FixedFeatureAcquisitionFunction( + qEI, d=3, columns=[2], values=test_X[..., -1:] + ) + self.assertTrue(torch.allclose(qEI.X_pending, qEI_ff.X_pending)) + + # test setting X_pending from qEI_ff + # (set target value to be last dim of X_pending and check if the + # constructed X_pending on qEI is the full X_pending) + X_pending = torch.rand(2, 3, device=self.device) + qEI.X_pending = None + qEI_ff = FixedFeatureAcquisitionFunction( + qEI, d=3, columns=[2], values=X_pending[..., -1:] + ) + qEI_ff.set_X_pending(X_pending[..., :-1]) + self.assertTrue(torch.allclose(qEI.X_pending, X_pending)) + # test gradient test_X = torch.rand(1, 3, device=self.device, requires_grad=True) test_X_ff = test_X[..., :-1].detach().clone().requires_grad_(True) @@ -92,3 +113,12 @@ def test_fixed_features(self): # test error b/c of incompatible input shapes with self.assertRaises(ValueError): qEI_ff(test_X) + + # test error when there is no X_pending (analytic EI) + test_X = torch.rand(q, 3, device=self.device) + analytic_EI = ExpectedImprovement(model, best_f=0.0) + EI_ff = FixedFeatureAcquisitionFunction( + analytic_EI, d=3, columns=[2], values=test_X[..., -1:] + ) + with self.assertRaises(ValueError): + EI_ff.X_pending