6 changes: 6 additions & 0 deletions botorch/acquisition/logei.py
@@ -282,6 +282,7 @@ def __init__(
tau_relu: float = TAU_RELU,
marginalize_dim: int | None = None,
incremental: bool = True,
infeasible_obj: Tensor | float | None = None,
@renzph renzph Sep 15, 2025

Should this also be added to qNoisyExpectedImprovement?

Contributor Author

Yes - if the changes look good to whoever reviews, I will also apply them to qNEI and investigate q(Log)NEHVI too. Just wanted to get confirmation that the changes were good before making them elsewhere :)

) -> None:
r"""q-Noisy Expected Improvement.

@@ -324,6 +325,9 @@ def __init__(
incremental: Whether to compute incremental EI over the pending points
or compute EI of the joint batch improvement (including pending
points).
infeasible_obj: A Tensor or float used as the best objective value when
no feasible points exist. If None, a lower bound on the objective
values is automatically estimated from the GP posterior.

TODO: similar to qNEHVI, when we are using sequential greedy candidate
selection, we could incorporate pending points X_baseline and compute
@@ -333,6 +337,7 @@
# TODO: separate out baseline variables initialization and other functions
# in qNEI to avoid duplication of both code and work at runtime.
self.incremental = incremental
self.infeasible_obj = infeasible_obj

super().__init__(
model=model,
@@ -570,6 +575,7 @@ def _compute_best_feasible_objective(self, samples: Tensor, obj: Tensor) -> Tensor:
objective=self.objective,
posterior_transform=self.posterior_transform,
X_baseline=self.X_baseline,
infeasible_obj=self.infeasible_obj,
)


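For context, here is a minimal sketch of how the new parameter might be used once this change lands. The model, data, objective, and constraint below are illustrative placeholders, not part of the PR:

```python
import torch
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.acquisition.objective import GenericMCObjective
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy problem: outcome 0 is the objective, outcome 1 is a constraint
# (feasible iff <= 0, following the BoTorch convention).
train_X = torch.rand(20, 2, dtype=torch.float64)
train_Y = torch.stack(
    [train_X.sum(dim=-1), 0.5 - train_X.prod(dim=-1)], dim=-1
)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

acqf = qLogNoisyExpectedImprovement(
    model=model,
    X_baseline=train_X,
    objective=GenericMCObjective(lambda Y, X=None: Y[..., 0]),
    constraints=[lambda Y: Y[..., 1]],
    # New in this PR: a fixed fallback for the incumbent when no baseline
    # point is feasible; with the default None, a lower bound is instead
    # estimated from the GP posterior.
    infeasible_obj=-10.0,
)
log_ei = acqf(torch.rand(5, 1, 2, dtype=torch.float64))  # 5 candidates, q=1
```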
62 changes: 47 additions & 15 deletions botorch/acquisition/utils.py
@@ -29,7 +29,10 @@
from botorch.sampling.base import MCSampler
from botorch.sampling.get_sampler import get_sampler
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
from botorch.utils.objective import compute_feasibility_indicator
from botorch.utils.objective import (
compute_feasibility_indicator,
compute_smoothed_feasibility_indicator,
)
from botorch.utils.sampling import optimize_posterior_samples
from botorch.utils.transforms import is_ensemble, normalize_indices
from gpytorch.models import GP
@@ -171,8 +174,9 @@ def _estimate_objective_lower_bound(
posterior_transform: PosteriorTransform | None,
X: Tensor,
) -> Tensor:
"""Estimates a lower bound on the objective values by evaluating the model at convex
combinations of `X`, returning the 6-sigma lower bound of the computed statistics.
"""Estimates a lower bound on the objective values by evaluating the at uniformly
random points in the bounding box of `X`, returning the 6-sigma lower bound of the
computed statistics.

Args:
model: A fitted model.
Expand All @@ -183,19 +187,19 @@ def _estimate_objective_lower_bound(
Returns:
A `m`-dimensional Tensor of lower bounds of the objectives.
"""
convex_weights = torch.rand(
32,
X.shape[-2],
dtype=X.dtype,
device=X.device,
)
weights_sum = convex_weights.sum(dim=0, keepdim=True)
convex_weights = convex_weights / weights_sum
# we do not have access to `bounds` here, so we infer the bounding box
# from data, expanding by 10% in each direction
X_lb = X.min(dim=-2).values
X_ub = X.max(dim=-2).values
X_range = X_ub - X_lb
X_padding = 0.1 * X_range
uniform_samples = torch.rand(32, X.shape[-1], dtype=X.dtype, device=X.device)
X_samples = X_lb - X_padding + uniform_samples * (X_range + 2 * X_padding)
# infeasible cost M is such that -M < min_x f(x), thus
# 0 < min_x f(x) - (-M), so we should take -M as a lower
# bound on the best feasible objective
return -get_infeasible_cost(
X=convex_weights @ X,
X=X_samples,
model=model,
objective=objective,
posterior_transform=posterior_transform,
@@ -235,8 +239,19 @@ def objective(Y: Tensor, X: Tensor | None = None):
return Y.squeeze(-1)

posterior = model.posterior(X, posterior_transform=posterior_transform)
lb = objective(posterior.mean - 6 * posterior.variance.clamp_min(0).sqrt(), X=X)
if lb.ndim < posterior.mean.ndim:
# We check both the lower and upper bounds of the posterior, since the objective
# may be increasing or decreasing. For objectives that are neither (e.g. absolute
# distance from a target), this should still provide a good bound.
six_stdv = 6 * posterior.variance.clamp_min(0).sqrt()
lb = torch.stack(
[
objective(posterior.mean - six_stdv, X=X),
objective(posterior.mean + six_stdv, X=X),
],
dim=0,
)

if lb.ndim - 1 < posterior.mean.ndim:
Contributor

why this change?

Contributor Author

We add an extra dimension in the line above by stacking mean - 6*std and mean + 6*std, which changes the check below. Maybe it would be better for me to be explicit above, taking the minimum of mean +- 6*std instead of stacking, and leaving this line unchanged?
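For reference, the min-based variant described in the reply might look like the following. This is a self-contained sketch with stand-ins for the surrounding variables (`mean`, `variance`, `objective`, `X` are hypothetical here), not part of the diff:

```python
import torch

# Stand-ins for the surrounding code: a posterior mean/variance with one
# outcome dimension, and a trivial objective.
mean = torch.randn(10, 1)
variance = torch.rand(10, 1)
objective = lambda Y, X=None: Y.squeeze(-1)
X = None

six_stdv = 6 * variance.clamp_min(0).sqrt()
# Elementwise minimum of the objective at mean +/- 6*std, instead of stacking,
# so `lb` keeps the same number of dimensions as the posterior mean:
lb = torch.minimum(
    objective(mean - six_stdv, X=X),
    objective(mean + six_stdv, X=X),
)
if lb.ndim < mean.ndim:  # the original check, unchanged
    lb = lb.unsqueeze(-1)
```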

lb = lb.unsqueeze(-1)
# Take outcome-wise min, looping to handle batched models.
while lb.dim() > 1:
@@ -374,7 +389,24 @@ def prune_inferior_points(
sampler=sampler,
marginalize_dim=marginalize_dim,
)
if infeas.any():
if infeas.all():
# if no points are feasible, keep the point closest to being feasible
with torch.no_grad():
posterior = model.posterior(X=X, posterior_transform=posterior_transform)
if sampler is None:
sampler = get_sampler(
posterior=posterior, sample_shape=torch.Size([num_samples])
)
samples = sampler(posterior)
# use the probability of feasibility as the objective for computing best points
obj_vals = compute_smoothed_feasibility_indicator(
constraints=constraints,
samples=samples,
eta=1e-3,
log=True,
)

elif infeas.any():
# set infeasible points to worse than worst objective across all samples
# Use clone() here to avoid deprecated `index_put_` on an expanded tensor
obj_vals = obj_vals.clone()
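To make the new fallback branch concrete, here is a self-contained sketch of ranking an all-infeasible candidate set by its (log) smoothed probability of feasibility. The samples and constraint are toy placeholders:

```python
import torch
from botorch.utils.objective import compute_smoothed_feasibility_indicator

# Toy posterior samples: 128 MC samples x 10 candidate points x 2 outcomes.
samples = torch.randn(128, 10, 2)

# One constraint on the second outcome; negative values imply feasibility.
constraints = [lambda Y: Y[..., 1]]

# Log smoothed feasibility indicator, shape 128 x 10. As in the branch above,
# this replaces the objective values when every point is infeasible, so the
# points closest to feasibility are the ones that get kept.
log_feas = compute_smoothed_feasibility_indicator(
    constraints=constraints, samples=samples, eta=1e-3, log=True
)
closest = log_feas.mean(dim=0).argmax()  # candidate closest to feasibility
```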