diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py
index 00fbe35291..ed8143d851 100644
--- a/botorch/acquisition/utils.py
+++ b/botorch/acquisition/utils.py
@@ -11,6 +11,7 @@
 from __future__ import annotations
 
 import math
+import warnings
 from collections.abc import Callable
 
 import torch
@@ -24,12 +25,16 @@
     DeprecationError,
     UnsupportedError,
 )
+from botorch.exceptions.warnings import BotorchWarning
 from botorch.models.fully_bayesian import MCMC_DIM
 from botorch.models.model import Model
 from botorch.sampling.base import MCSampler
 from botorch.sampling.get_sampler import get_sampler
 from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
-from botorch.utils.objective import compute_feasibility_indicator
+from botorch.utils.objective import (
+    compute_feasibility_indicator,
+    compute_smoothed_feasibility_indicator,
+)
 from botorch.utils.sampling import optimize_posterior_samples
 from botorch.utils.transforms import is_ensemble, normalize_indices
 from gpytorch.models import GP
@@ -150,6 +155,13 @@ def compute_best_feasible_objective(
             raise ValueError(
                 "Must specify `X_baseline` when no feasible observation exists."
             )
+        warnings.warn(
+            "When all training points are infeasible, it is better to use "
+            "q(Log)ProbabilityOfFeasibility.",
+            BotorchWarning,
+            stacklevel=2,
+        )
+
         infeasible_value = _estimate_objective_lower_bound(
             model=model,
             objective=objective,
@@ -171,8 +183,9 @@ def _estimate_objective_lower_bound(
     posterior_transform: PosteriorTransform | None,
     X: Tensor,
 ) -> Tensor:
-    """Estimates a lower bound on the objective values by evaluating the model at convex
-    combinations of `X`, returning the 6-sigma lower bound of the computed statistics.
+    """Estimates a lower bound on the objective values by evaluating the model at
+    uniformly random points in the bounding box of `X`, returning the 6-sigma lower
+    bound of the computed statistics.
 
     Args:
         model: A fitted model.
@@ -183,19 +196,19 @@ def _estimate_objective_lower_bound(
     Returns:
         A `m`-dimensional Tensor of lower bounds of the objectives.
     """
-    convex_weights = torch.rand(
-        32,
-        X.shape[-2],
-        dtype=X.dtype,
-        device=X.device,
-    )
-    weights_sum = convex_weights.sum(dim=0, keepdim=True)
-    convex_weights = convex_weights / weights_sum
+    # we do not have access to `bounds` here, so we infer the bounding box
+    # from data, expanding by 10% in each direction
+    X_lb = X.min(dim=-2).values
+    X_ub = X.max(dim=-2).values
+    X_range = X_ub - X_lb
+    X_padding = 0.1 * X_range
+    uniform_samples = torch.rand(32, X.shape[-1], dtype=X.dtype, device=X.device)
+    X_samples = X_lb - X_padding + uniform_samples * (X_range + 2 * X_padding)
     # infeasible cost M is such that -M < min_x f(x), thus
     # 0 < min_x f(x) - (-M), so we should take -M as a lower
     # bound on the best feasible objective
     return -get_infeasible_cost(
-        X=convex_weights @ X,
+        X=X_samples,
         model=model,
         objective=objective,
         posterior_transform=posterior_transform,
@@ -235,7 +248,19 @@ def objective(Y: Tensor, X: Tensor | None = None):
             return Y.squeeze(-1)
 
     posterior = model.posterior(X, posterior_transform=posterior_transform)
-    lb = objective(posterior.mean - 6 * posterior.variance.clamp_min(0).sqrt(), X=X)
+    # We check both the upper and lower bound of the posterior, since the objective
+    # may be increasing or decreasing. For objectives that are neither (e.g. absolute
+    # distance from a target), this should still provide a good bound.
+    six_stdv = 6 * posterior.variance.clamp_min(0).sqrt()
+    lb = torch.stack(
+        [
+            objective(posterior.mean - six_stdv, X=X),
+            objective(posterior.mean + six_stdv, X=X),
+        ],
+        dim=0,
+    )
+    lb = lb.min(dim=0).values
+
     if lb.ndim < posterior.mean.ndim:
         lb = lb.unsqueeze(-1)
     # Take outcome-wise min. Looping in to handle batched models.
@@ -311,6 +336,7 @@ def _prune_inferior_shared_processing(
         samples=samples,
         marginalize_dim=marginalize_dim,
     )
+
     return max_points, obj_vals, infeas
 
 
@@ -374,7 +400,24 @@ def prune_inferior_points(
         sampler=sampler,
         marginalize_dim=marginalize_dim,
     )
-    if infeas.any():
+    if infeas.all():
+        # if no points are feasible, keep the point closest to being feasible
+        with torch.no_grad():
+            posterior = model.posterior(X=X, posterior_transform=posterior_transform)
+            if sampler is None:
+                sampler = get_sampler(
+                    posterior=posterior, sample_shape=torch.Size([num_samples])
+                )
+            samples = sampler(posterior)
+        # use the probability of feasibility as the objective for computing best points
+        obj_vals = compute_smoothed_feasibility_indicator(
+            constraints=constraints,
+            samples=samples,
+            eta=1e-3,
+            log=True,
+        )
+
+    elif infeas.any():
         # set infeasible points to worse than worst objective across all samples
         # Use clone() here to avoid deprecated `index_put_` on an expanded tensor
         obj_vals = obj_vals.clone()
diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py
index b8115ba0af..f071545de4 100644
--- a/test/acquisition/test_utils.py
+++ b/test/acquisition/test_utils.py
@@ -32,6 +32,7 @@
     DeprecationError,
     UnsupportedError,
 )
+from botorch.exceptions.warnings import BotorchWarning
 from botorch.models import SingleTaskGP
 from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
 from gpytorch.distributions import MultivariateNormal
@@ -154,14 +155,17 @@ def test_compute_best_feasible_objective(self):
             def objective(Y, X):
                 return Y.squeeze(-1) - 5.0
 
-            best_f = compute_best_feasible_objective(
-                samples=samples,
-                obj=obj,
-                constraints=[lambda X: torch.ones_like(X[..., 0])],
-                model=mm,
-                X_baseline=X,
-                objective=objective,
-            )
+            with self.assertWarnsRegex(
+                BotorchWarning, "ProbabilityOfFeasibility"
+            ):
+                best_f = compute_best_feasible_objective(
+                    samples=samples,
+                    obj=obj,
+                    constraints=[lambda X: torch.ones_like(X[..., 0])],
+                    model=mm,
+                    X_baseline=X,
+                    objective=objective,
+                )
             expected_best_f = torch.full(
                 sample_shape + batch_shape,
                 -get_infeasible_cost(X=X, model=mm, objective=objective).item(),
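The padded bounding-box sampling introduced in `_estimate_objective_lower_bound` can be tried in isolation. Below is a minimal sketch, assuming only plain PyTorch and a made-up toy `X` standing in for `X_baseline`; it is not part of the diff.

import torch

# toy baseline inputs standing in for `X` (in the diff, `X_baseline`)
X = torch.tensor([[0.1, 0.4], [0.3, 0.9], [0.7, 0.2]])

# infer the bounding box from the data and expand it by 10% in each direction
X_lb = X.min(dim=-2).values
X_ub = X.max(dim=-2).values
X_range = X_ub - X_lb
X_padding = 0.1 * X_range

# draw 32 uniform samples inside the padded box
uniform_samples = torch.rand(32, X.shape[-1], dtype=X.dtype, device=X.device)
X_samples = X_lb - X_padding + uniform_samples * (X_range + 2 * X_padding)

print(X_samples.shape)  # torch.Size([32, 2]); each row lies in the padded box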