diff --git a/botorch/generation/gen.py b/botorch/generation/gen.py
index 7611691d27..d8eca5d646 100644
--- a/botorch/generation/gen.py
+++ b/botorch/generation/gen.py
@@ -532,7 +532,9 @@ def gen_candidates_torch(
         optimizer (Optimizer): The pytorch optimizer to use to perform
             candidate search.
         options: Options used to control the optimization. Includes
-            maxiter: Maximum number of iterations
+            optimizer_options: Dict of additional options to pass to the optimizer
+                (e.g. lr, weight_decay)
+            stopping_criterion_options: Dict of options for the stopping criterion.
             callback: A callback function accepting the current iteration, loss,
                 and gradients as arguments. This function is executed after
                 computing the loss and gradients, but before calling the optimizer.
@@ -580,11 +582,17 @@ def gen_candidates_torch(
         [i for i in range(clamped_candidates.shape[-1]) if i not in fixed_features],
     ]
     clamped_candidates = clamped_candidates.requires_grad_(True)
-    _optimizer = optimizer(params=[clamped_candidates], lr=options.get("lr", 0.025))
+
+    # Extract optimizer-specific options from the options dict
+    optimizer_options = options.pop("optimizer_options", {})
+    stopping_criterion_options = options.pop("stopping_criterion_options", {})
+
+    optimizer_options["lr"] = optimizer_options.get("lr", 0.025)
+    _optimizer = optimizer(params=[clamped_candidates], **optimizer_options)
 
     i = 0
     stop = False
-    stopping_criterion = ExpMAStoppingCriterion(**options)
+    stopping_criterion = ExpMAStoppingCriterion(**stopping_criterion_options)
     while not stop:
         i += 1
         with torch.no_grad():
diff --git a/test/generation/test_gen.py b/test/generation/test_gen.py
index dc5961038c..bc34b59f82 100644
--- a/test/generation/test_gen.py
+++ b/test/generation/test_gen.py
@@ -324,6 +324,37 @@ def test_gen_candidates_torch_timeout_behavior(self):
         self.assertFalse(any(issubclass(w.category, OptimizationWarning) for w in ws))
         self.assertTrue("Optimization timed out" in logs.output[-1])
 
+    def test_gen_candidates_torch_optimizer_with_optimizer_args(self):
+        """Test that the optimizer is constructed with the passed optimizer_options."""
+        self._setUp(double=False)
+        qEI = qExpectedImprovement(self.model, best_f=self.f_best)
+
+        # Create a mock optimizer class
+        mock_optimizer_class = mock.MagicMock()
+        mock_optimizer_instance = mock.MagicMock()
+        mock_optimizer_class.return_value = mock_optimizer_instance
+
+        gen_candidates_torch(
+            initial_conditions=self.initial_conditions,
+            acquisition_function=qEI,
+            lower_bounds=0,
+            upper_bounds=1,
+            optimizer=mock_optimizer_class,  # Pass the mock optimizer directly
+            options={
+                "optimizer_options": {"lr": 0.02, "weight_decay": 1e-5},
+                "stopping_criterion_options": {"maxiter": 1},
+            },
+        )
+
+        # Verify that the optimizer was called with the correct arguments
+        mock_optimizer_class.assert_called_once()
+        call_args = mock_optimizer_class.call_args
+        # Check that the params argument is present
+        self.assertIn("params", call_args.kwargs)
+        # Check optimizer options
+        self.assertEqual(call_args.kwargs["lr"], 0.02)
+        self.assertEqual(call_args.kwargs["weight_decay"], 1e-5)
+
     def test_gen_candidates_scipy_warns_opt_no_res(self):
         ckwargs = {"dtype": torch.float, "device": self.device}
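
Usage sketch (not part of the diff): the snippet below illustrates how a caller could pass the new `optimizer_options` and `stopping_criterion_options` keys to `gen_candidates_torch`. Only the two option keys and the default `lr=0.025` fallback come from the change above; the toy model, training data, and chosen option values are illustrative assumptions.

```python
# Minimal sketch, assuming a toy SingleTaskGP and random training data.
import torch
from botorch.acquisition import qExpectedImprovement
from botorch.generation.gen import gen_candidates_torch
from botorch.models import SingleTaskGP

# Toy model and acquisition function (illustrative only).
train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True).sin()
model = SingleTaskGP(train_X, train_Y)
acqf = qExpectedImprovement(model, best_f=train_Y.max())

candidates, acq_values = gen_candidates_torch(
    initial_conditions=torch.rand(3, 1, 2, dtype=torch.double),  # b x q x d
    acquisition_function=acqf,
    lower_bounds=0.0,
    upper_bounds=1.0,
    optimizer=torch.optim.Adam,
    options={
        # Forwarded to the optimizer constructor; lr falls back to 0.025 if omitted.
        "optimizer_options": {"lr": 0.01, "weight_decay": 1e-5},
        # Forwarded to ExpMAStoppingCriterion.
        "stopping_criterion_options": {"maxiter": 50},
    },
)
```

With the previous behavior, every entry in `options` (other than `lr` and a few recognized keys) was forwarded to `ExpMAStoppingCriterion`; splitting the dict into the two nested keys lets callers configure optimizer hyperparameters such as `weight_decay` without them leaking into the stopping criterion.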