@@ -125,15 +125,14 @@ def compile(
125125 else dspy .settings .max_errors
126126 )
127127
128- # Update max demos if specified
129- initial_max_bootstrapped_demos = self .max_bootstrapped_demos
130- if max_bootstrapped_demos is not None :
131- self .max_bootstrapped_demos = max_bootstrapped_demos
132- initial_max_labeled_demos = self .max_labeled_demos
133- if max_labeled_demos is not None :
134- self .max_labeled_demos = max_labeled_demos
128+ effective_max_bootstrapped_demos = (
129+ max_bootstrapped_demos if max_bootstrapped_demos is not None else self .max_bootstrapped_demos
130+ )
131+ effective_max_labeled_demos = (
132+ max_labeled_demos if max_labeled_demos is not None else self .max_labeled_demos
133+ )
135134
136- zeroshot_opt = (self . max_bootstrapped_demos == 0 ) and (self . max_labeled_demos == 0 )
135+ zeroshot_opt = (effective_max_bootstrapped_demos == 0 ) and (effective_max_labeled_demos == 0 )
137136
138137 # If auto is None, and num_trials is not provided (but num_candidates is), raise an error that suggests a good num_trials value
139138 if self .auto is None and (self .num_candidates is not None and num_trials is None ):
@@ -159,13 +158,42 @@ def compile(
159158 # Set training & validation sets
160159 trainset , valset = self ._set_and_validate_datasets (trainset , valset )
161160
161+ num_instruct_candidates = (
162+ self .num_instruct_candidates
163+ if self .num_instruct_candidates is not None
164+ else self .num_candidates
165+ )
166+ num_fewshot_candidates = (
167+ self .num_fewshot_candidates
168+ if self .num_fewshot_candidates is not None
169+ else self .num_candidates
170+ )
171+
162172 # Set hyperparameters based on run mode (if set)
163- num_trials , valset , minibatch = self ._set_hyperparams_from_run_mode (
164- student , num_trials , minibatch , zeroshot_opt , valset
173+ (
174+ num_trials ,
175+ valset ,
176+ minibatch ,
177+ num_instruct_candidates ,
178+ num_fewshot_candidates ,
179+ ) = self ._set_hyperparams_from_run_mode (
180+ student ,
181+ num_trials ,
182+ minibatch ,
183+ zeroshot_opt ,
184+ valset ,
185+ num_instruct_candidates ,
186+ num_fewshot_candidates ,
165187 )
166188
167189 if self .auto :
168- self ._print_auto_run_settings (num_trials , minibatch , valset )
190+ self ._print_auto_run_settings (
191+ num_trials ,
192+ minibatch ,
193+ valset ,
194+ num_fewshot_candidates ,
195+ num_instruct_candidates ,
196+ )
169197
170198 if minibatch and minibatch_size > len (valset ):
171199 raise ValueError (f"Minibatch size cannot exceed the size of the valset. Valset size: { len (valset )} ." )
@@ -183,7 +211,17 @@ def compile(
183211 )
184212
185213 # Step 1: Bootstrap few-shot examples
186- demo_candidates = self ._bootstrap_fewshot_examples (program , trainset , seed , teacher )
214+ demo_candidates = self ._bootstrap_fewshot_examples (
215+ program ,
216+ trainset ,
217+ seed ,
218+ teacher ,
219+ num_fewshot_candidates = num_fewshot_candidates ,
220+ max_bootstrapped_demos = effective_max_bootstrapped_demos ,
221+ max_labeled_demos = effective_max_labeled_demos ,
222+ max_errors = effective_max_errors ,
223+ metric_threshold = self .metric_threshold ,
224+ )
187225
188226 # Step 2: Propose instruction candidates
189227 instruction_candidates = self ._propose_instructions (
@@ -195,6 +233,7 @@ def compile(
195233 data_aware_proposer ,
196234 tip_aware_proposer ,
197235 fewshot_aware_proposer ,
236+ num_instruct_candidates = num_instruct_candidates ,
198237 )
199238
200239 # If zero-shot, discard demos
@@ -215,10 +254,6 @@ def compile(
215254 seed ,
216255 )
217256
218- # Reset max demos
219- self .max_bootstrapped_demos = initial_max_bootstrapped_demos
220- self .max_labeled_demos = initial_max_labeled_demos
221-
222257 return best_program
223258
224259 def _set_random_seeds (self , seed ):
@@ -237,13 +272,17 @@ def _set_num_trials_from_num_candidates(self, program, zeroshot_opt, num_candida
237272 def _set_hyperparams_from_run_mode (
238273 self ,
239274 program : Any ,
240- num_trials : int ,
275+ num_trials : int | None ,
241276 minibatch : bool ,
242277 zeroshot_opt : bool ,
243278 valset : list ,
244- ) -> tuple [int , list , bool ]:
279+ num_instruct_candidates : int | None ,
280+ num_fewshot_candidates : int | None ,
281+ ) -> tuple [int , list , bool , int , int ]:
245282 if self .auto is None :
246- return num_trials , valset , minibatch
283+ if num_instruct_candidates is None or num_fewshot_candidates is None :
284+ raise ValueError ("num_candidates must be provided when auto is None." )
285+ return num_trials , valset , minibatch , num_instruct_candidates , num_fewshot_candidates
247286
248287 auto_settings = AUTO_RUN_SETTINGS [self .auto ]
249288
@@ -253,12 +292,12 @@ def _set_hyperparams_from_run_mode(
253292 # Set num instruct candidates to 1/2 of N if optimizing with few-shot examples, otherwise set to N
254293 # This is because we've found that it's generally better to spend optimization budget on few-shot examples
255294 # when they are allowed.
256- self . num_instruct_candidates = auto_settings ["n" ] if zeroshot_opt else int (auto_settings ["n" ] * 0.5 )
257- self . num_fewshot_candidates = auto_settings ["n" ]
295+ num_instruct_candidates = auto_settings ["n" ] if zeroshot_opt else int (auto_settings ["n" ] * 0.5 )
296+ num_fewshot_candidates = auto_settings ["n" ]
258297
259298 num_trials = self ._set_num_trials_from_num_candidates (program , zeroshot_opt , auto_settings ["n" ])
260299
261- return num_trials , valset , minibatch
300+ return num_trials , valset , minibatch , num_instruct_candidates , num_fewshot_candidates
262301
263302 def _set_and_validate_datasets (self , trainset : list , valset : list | None ):
264303 if not trainset :
@@ -277,13 +316,20 @@ def _set_and_validate_datasets(self, trainset: list, valset: list | None):
277316
278317 return trainset , valset
279318
def _print_auto_run_settings(
    self,
    num_trials: int,
    minibatch: bool,
    valset: list,
    num_fewshot_candidates: int,
    num_instruct_candidates: int,
):
    """Log a summary of the settings chosen for an auto run.

    Emits a single INFO record listing the upper-cased auto mode together
    with the trial count, the minibatch flag, both candidate counts and the
    size of the validation set.

    Args:
        num_trials: Number of optimization trials to report.
        minibatch: Whether minibatching is enabled.
        valset: Validation examples; only its length is reported.
        num_fewshot_candidates: Few-shot candidate count to report.
        num_instruct_candidates: Instruction candidate count to report.
    """
    # Callers are expected to invoke this only when self.auto is set;
    # .upper() would fail on None.
    mode_label = self.auto.upper()
    valset_size = len(valset)

    summary = (
        f"\nRUNNING WITH THE FOLLOWING {mode_label} AUTO RUN SETTINGS:"
        f"\nnum_trials: {num_trials}"
        f"\nminibatch: {minibatch}"
        f"\nnum_fewshot_candidates: {num_fewshot_candidates}"
        f"\nnum_instruct_candidates: {num_instruct_candidates}"
        f"\nvalset size: {valset_size}\n"
    )
    logger.info(summary)
289335
@@ -296,18 +342,19 @@ def _estimate_lm_calls(
296342 minibatch_full_eval_steps : int ,
297343 valset : list ,
298344 program_aware_proposer : bool ,
345+ num_instruct_candidates : int ,
299346 ) -> tuple [str , str ]:
300347 num_predictors = len (program .predictors ())
301348
302349 # Estimate prompt model calls
303350 estimated_prompt_model_calls = (
304351 10 # Data summarizer calls
305- + self . num_instruct_candidates * num_predictors # Candidate generation
352+ + num_instruct_candidates * num_predictors # Candidate generation
306353 + (num_predictors + 1 if program_aware_proposer else 0 ) # Program-aware proposer
307354 )
308355 prompt_model_line = (
309356 f"{ YELLOW } - Prompt Generation: { BLUE } { BOLD } 10{ ENDC } { YELLOW } data summarizer calls + "
310- f"{ BLUE } { BOLD } { self . num_instruct_candidates } { ENDC } { YELLOW } * "
357+ f"{ BLUE } { BOLD } { num_instruct_candidates } { ENDC } { YELLOW } * "
311358 f"{ BLUE } { BOLD } { num_predictors } { ENDC } { YELLOW } lm calls in program "
312359 f"+ ({ BLUE } { BOLD } { num_predictors + 1 } { ENDC } { YELLOW } ) lm calls in program-aware proposer "
313360 f"= { BLUE } { BOLD } { estimated_prompt_model_calls } { ENDC } { YELLOW } prompt model calls{ ENDC } "
@@ -334,38 +381,48 @@ def _estimate_lm_calls(
334381
335382 return prompt_model_line , task_model_line
336383
337- def _bootstrap_fewshot_examples (self , program : Any , trainset : list , seed : int , teacher : Any ) -> list | None :
384+ def _bootstrap_fewshot_examples (
385+ self ,
386+ program : Any ,
387+ trainset : list ,
388+ seed : int ,
389+ teacher : Any ,
390+ * ,
391+ num_fewshot_candidates : int ,
392+ max_bootstrapped_demos : int ,
393+ max_labeled_demos : int ,
394+ max_errors : int | None ,
395+ metric_threshold : float | None ,
396+ ) -> list | None :
338397 logger .info ("\n ==> STEP 1: BOOTSTRAP FEWSHOT EXAMPLES <==" )
339- if self . max_bootstrapped_demos > 0 :
398+ if max_bootstrapped_demos > 0 :
340399 logger .info (
341400 "These will be used as few-shot example candidates for our program and for creating instructions.\n "
342401 )
343402 else :
344403 logger .info ("These will be used for informing instruction proposal.\n " )
345404
346- logger .info (f"Bootstrapping N={ self . num_fewshot_candidates } sets of demonstrations..." )
405+ logger .info (f"Bootstrapping N={ num_fewshot_candidates } sets of demonstrations..." )
347406
348- zeroshot = self . max_bootstrapped_demos == 0 and self . max_labeled_demos == 0
407+ zeroshot = max_bootstrapped_demos == 0 and max_labeled_demos == 0
349408
350- # try:
351- effective_max_errors = (
352- self .max_errors if self .max_errors is not None else dspy .settings .max_errors
353- )
409+ if max_errors is None :
410+ max_errors = dspy .settings .max_errors
354411
355412 demo_candidates = create_n_fewshot_demo_sets (
356413 student = program ,
357- num_candidate_sets = self . num_fewshot_candidates ,
414+ num_candidate_sets = num_fewshot_candidates ,
358415 trainset = trainset ,
359- max_labeled_demos = (LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else self . max_labeled_demos ),
416+ max_labeled_demos = (LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else max_labeled_demos ),
360417 max_bootstrapped_demos = (
361- BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else self . max_bootstrapped_demos
418+ BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else max_bootstrapped_demos
362419 ),
363420 metric = self .metric ,
364- max_errors = effective_max_errors ,
421+ max_errors = max_errors ,
365422 teacher = teacher ,
366423 teacher_settings = self .teacher_settings ,
367424 seed = seed ,
368- metric_threshold = self . metric_threshold ,
425+ metric_threshold = metric_threshold ,
369426 rng = self .rng ,
370427 )
371428 # NOTE: Bootstrapping is essential to MIPRO!
@@ -387,6 +444,7 @@ def _propose_instructions(
387444 data_aware_proposer : bool ,
388445 tip_aware_proposer : bool ,
389446 fewshot_aware_proposer : bool ,
447+ num_instruct_candidates : int ,
390448 ) -> dict [int , list [str ]]:
391449 logger .info ("\n ==> STEP 2: PROPOSE INSTRUCTION CANDIDATES <==" )
392450 logger .info (
@@ -411,12 +469,12 @@ def _propose_instructions(
411469 init_temperature = self .init_temperature ,
412470 )
413471
414- logger .info (f"\n Proposing N={ self . num_instruct_candidates } instructions...\n " )
472+ logger .info (f"\n Proposing N={ num_instruct_candidates } instructions...\n " )
415473 instruction_candidates = proposer .propose_instructions_for_program (
416474 trainset = trainset ,
417475 program = program ,
418476 demo_candidates = demo_candidates ,
419- N = self . num_instruct_candidates ,
477+ N = num_instruct_candidates ,
420478 trial_logs = {},
421479 )
422480
0 commit comments