diff --git a/src/MutationFunctions.jl b/src/MutationFunctions.jl index 4f9111b5f..b58501894 100644 --- a/src/MutationFunctions.jl +++ b/src/MutationFunctions.jl @@ -1,5 +1,6 @@ module MutationFunctionsModule +using DispatchDoctor: @unstable using Random: default_rng, AbstractRNG using DynamicExpressions: AbstractExpressionNode, @@ -40,6 +41,18 @@ function with_contents_for_mutation(ex::AbstractExpression, new_contents, ::Noth return with_contents(ex, new_contents) end +function apply_tree_mutation( + ex::AbstractExpression, + rng::AbstractRNG, + mutation_func::F, + args::Vararg{Any,N}; + kwargs..., +) where {F<:Function,N} + tree, context = get_contents_for_mutation(ex, rng) + mutated_tree = mutation_func(tree, args..., rng; kwargs...) + return with_contents_for_mutation(ex, mutated_tree, context) +end + """ random_node(tree::AbstractNode; filter::F=Returns(true)) @@ -58,9 +71,7 @@ end """Swap operands in binary operator for ops like pow and divide""" function swap_operands(ex::AbstractExpression, rng::AbstractRNG=default_rng()) - tree, context = get_contents_for_mutation(ex, rng) - ex = with_contents_for_mutation(ex, swap_operands(tree, rng), context) - return ex + return apply_tree_mutation(ex, rng, swap_operands) end function swap_operands(tree::AbstractNode, rng::AbstractRNG=default_rng()) if !any(node -> node.degree == 2, tree) @@ -73,17 +84,13 @@ end """Randomly convert an operator into another one (binary->binary; unary->unary)""" function mutate_operator( - ex::AbstractExpression{T}, options::AbstractOptions, rng::AbstractRNG=default_rng() -) where {T<:DATA_TYPE} - tree, context = get_contents_for_mutation(ex, rng) - ex = with_contents_for_mutation(ex, mutate_operator(tree, options, rng), context) - return ex + ex::AbstractExpression, options::AbstractOptions, rng::AbstractRNG=default_rng() +) + return apply_tree_mutation(ex, rng, mutate_operator, options) end function mutate_operator( - tree::AbstractExpressionNode{T}, - options::AbstractOptions, - rng::AbstractRNG=default_rng(), -) where {T} + tree::AbstractExpressionNode, options::AbstractOptions, rng::AbstractRNG=default_rng() +) if !(has_operators(tree)) return tree end @@ -98,16 +105,12 @@ end """Randomly perturb a constant""" function mutate_constant( - ex::AbstractExpression{T}, + ex::AbstractExpression, temperature, options::AbstractOptions, rng::AbstractRNG=default_rng(), -) where {T<:DATA_TYPE} - tree, context = get_contents_for_mutation(ex, rng) - ex = with_contents_for_mutation( - ex, mutate_constant(tree, temperature, options, rng), context - ) - return ex +) + return apply_tree_mutation(ex, rng, mutate_constant, temperature, options) end function mutate_constant( tree::AbstractExpressionNode{T}, @@ -145,17 +148,15 @@ end """Add a random unary/binary operation to the end of a tree""" function append_random_op( - ex::AbstractExpression{T}, + ex::AbstractExpression, options::AbstractOptions, nfeatures::Int, rng::AbstractRNG=default_rng(); make_new_bin_op::Union{Bool,Nothing}=nothing, -) where {T<:DATA_TYPE} - tree, context = get_contents_for_mutation(ex, rng) - ex = with_contents_for_mutation( - ex, append_random_op(tree, options, nfeatures, rng; make_new_bin_op), context +) + return apply_tree_mutation( + ex, rng, append_random_op, options, nfeatures; make_new_bin_op=make_new_bin_op ) - return ex end function append_random_op( tree::AbstractExpressionNode{T}, @@ -195,11 +196,7 @@ function insert_random_op( nfeatures::Int, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} - tree, context = get_contents_for_mutation(ex, rng) - ex = with_contents_for_mutation( - ex, insert_random_op(tree, options, nfeatures, rng), context - ) - return ex + return apply_tree_mutation(ex, rng, insert_random_op, options, nfeatures) end function insert_random_op( tree::AbstractExpressionNode{T}, @@ -231,11 +228,7 @@ function prepend_random_op( nfeatures::Int, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} - tree, context = get_contents_for_mutation(ex, rng) - ex = with_contents_for_mutation( - ex, prepend_random_op(tree, options, nfeatures, rng), context - ) - return ex + return apply_tree_mutation(ex, rng, prepend_random_op, options, nfeatures) end function prepend_random_op( tree::AbstractExpressionNode{T}, @@ -294,11 +287,7 @@ function delete_random_op!( nfeatures::Int, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} - tree, context = get_contents_for_mutation(ex, rng) - ex = with_contents_for_mutation( - ex, delete_random_op!(tree, options, nfeatures, rng), context - ) - return ex + return apply_tree_mutation(ex, rng, delete_random_op!, options, nfeatures) end function delete_random_op!( tree::AbstractExpressionNode{T}, @@ -345,18 +334,15 @@ function delete_random_op!( return tree end -function randomize_tree( +# TODO: For some reason this is unstable on Julia 1.10. +@unstable function randomize_tree( ex::AbstractExpression, curmaxsize::Int, options::AbstractOptions, nfeatures::Int, rng::AbstractRNG=default_rng(), ) - tree, context = get_contents_for_mutation(ex, rng) - ex = with_contents_for_mutation( - ex, randomize_tree(tree, curmaxsize, options, nfeatures, rng), context - ) - return ex + return apply_tree_mutation(ex, rng, randomize_tree, curmaxsize, options, nfeatures) end function randomize_tree( ::AbstractExpressionNode{T}, @@ -472,8 +458,7 @@ function get_two_nodes_without_loop(tree::AbstractNode, rng::AbstractRNG; max_at end function form_random_connection!(ex::AbstractExpression, rng::AbstractRNG=default_rng()) - tree, context = get_contents_for_mutation(ex, rng) - return with_contents_for_mutation(ex, form_random_connection!(tree, rng), context) + return apply_tree_mutation(ex, rng, form_random_connection!) end function form_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rng()) if length(tree) < 5 @@ -496,8 +481,7 @@ function form_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rn end function break_random_connection!(ex::AbstractExpression, rng::AbstractRNG=default_rng()) - tree, context = get_contents_for_mutation(ex, rng) - return with_contents_for_mutation(ex, break_random_connection!(tree, rng), context) + return apply_tree_mutation(ex, rng, break_random_connection!) end function break_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rng()) tree.degree == 0 && return tree @@ -515,9 +499,7 @@ function is_valid_rotation_node(node::AbstractNode) end function randomly_rotate_tree!(ex::AbstractExpression, rng::AbstractRNG=default_rng()) - tree, context = get_contents_for_mutation(ex, rng) - rotated_tree = randomly_rotate_tree!(tree, rng) - return with_contents_for_mutation(ex, rotated_tree, context) + return apply_tree_mutation(ex, rng, randomly_rotate_tree!) end function randomly_rotate_tree!(tree::AbstractNode, rng::AbstractRNG=default_rng()) num_rotation_nodes = count(is_valid_rotation_node, tree) diff --git a/src/ProgressBars.jl b/src/ProgressBars.jl index c32b0c82f..63c39d52a 100644 --- a/src/ProgressBars.jl +++ b/src/ProgressBars.jl @@ -9,16 +9,21 @@ using ..UtilsModule: AnnotatedString mutable struct WrappedProgressBar bar::Progress postfix::Vector{Tuple{AnnotatedString,AnnotatedString}} + clean_postfix::Vector{Tuple{AnnotatedString,AnnotatedString}} + last_update::Float64 function WrappedProgressBar(n::Integer, niterations::Integer; kwargs...) init_vector = Tuple{AnnotatedString,AnnotatedString}[] kwargs = (; kwargs..., desc="Evolving for $niterations iterations...") + last_update = time() if get(ENV, "SYMBOLIC_REGRESSION_TEST", "false") == "true" # For testing, create a progress bar that writes to devnull output = devnull - return new(Progress(n; output, kwargs...), init_vector) + return new( + Progress(n; output, kwargs...), init_vector, copy(init_vector), last_update + ) end - return new(Progress(n; kwargs...), init_vector) + return new(Progress(n; kwargs...), init_vector, copy(init_vector), last_update) end end @@ -34,8 +39,13 @@ end """Iterate a progress bar.""" function manually_iterate!(pbar::WrappedProgressBar) width = barlen(pbar) - postfix = map(Fix{2}(format_for_meter, width), pbar.postfix) - next!(pbar.bar; showvalues=postfix, valuecolor=:none) + last_update = pbar.last_update + update_interval = 0.005 + if time() - last_update > update_interval + pbar.clean_postfix = map(Fix{2}(format_for_meter, width), pbar.postfix) + pbar.last_update = time() + end + next!(pbar.bar; showvalues=pbar.clean_postfix, valuecolor=:none) return nothing end diff --git a/src/RegularizedEvolution.jl b/src/RegularizedEvolution.jl index ad1223d5c..8c2625842 100644 --- a/src/RegularizedEvolution.jl +++ b/src/RegularizedEvolution.jl @@ -8,8 +8,162 @@ using ..MutateModule: next_generation, crossover_generation using ..RecorderModule: @recorder using ..UtilsModule: argmin_fast -# Pass through the population several times, replacing the oldest -# with the fittest of a small subsample +function setup_member_recording!(record::RecordType, members, options::AbstractOptions) + if !haskey(record, "mutations") + record["mutations"] = RecordType() + end + + for member in members + if !haskey(record["mutations"], "$(member.ref)") + record["mutations"]["$(member.ref)"] = RecordType( + "events" => Vector{RecordType}(), + "tree" => string_tree(member.tree, options), + "cost" => member.cost, + "loss" => member.loss, + "parent" => member.parent, + ) + end + end +end + +""" + handle_mutation!(pop, dataset, running_search_statistics, options, record, temperature, curmaxsize) + +Perform mutation on a selected member and replace the oldest population member with the result. +Returns the number of evaluations performed. +""" +function handle_mutation!( + pop::P, + dataset::Dataset{T,L}, + running_search_statistics::RunningSearchStatistics, + options::AbstractOptions, + record::RecordType, + temperature, + curmaxsize::Int, +) where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:Population{T,L}} + # Select best member from tournament + allstar = best_of_sample(pop, running_search_statistics, options) + + # Perform mutation + mutation_recorder = RecordType() + baby, mutation_accepted, num_evals = next_generation( + dataset, + allstar, + temperature, + curmaxsize, + running_search_statistics, + options; + tmp_recorder=mutation_recorder, + ) + + # Skip if mutation failed and we're configured to skip failures + if !mutation_accepted && options.skip_mutation_failures + return num_evals + end + + # Find oldest member to replace + oldest = argmin_fast([pop.members[member].birth for member in 1:(pop.n)]) + + # Record mutation events + @recorder begin + members_to_record = [allstar, baby, pop.members[oldest]] + setup_member_recording!(record, members_to_record, options) + + mutate_event = RecordType( + "type" => "mutate", + "time" => time(), + "child" => baby.ref, + "mutation" => mutation_recorder, + ) + death_event = RecordType("type" => "death", "time" => time()) + + push!(record["mutations"]["$(allstar.ref)"]["events"], mutate_event) + push!(record["mutations"]["$(pop.members[oldest].ref)"]["events"], death_event) + end + + # Replace the oldest member with the new baby + pop.members[oldest] = baby + + return num_evals +end + +""" + handle_crossover!(pop, dataset, running_search_statistics, options, record, curmaxsize) + +Perform crossover between two selected members and replace the two oldest population members with the results. +Returns the number of evaluations performed. +""" +function handle_crossover!( + pop::P, + dataset::Dataset{T,L}, + running_search_statistics::RunningSearchStatistics, + options::AbstractOptions, + record::RecordType, + curmaxsize::Int, +) where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:Population{T,L}} + # Select the two parents + allstar1 = best_of_sample(pop, running_search_statistics, options) + allstar2 = best_of_sample(pop, running_search_statistics, options) + + # Perform crossover + crossover_recorder = RecordType() + baby1, baby2, crossover_accepted, num_evals = crossover_generation( + allstar1, allstar2, dataset, curmaxsize, options; recorder=crossover_recorder + ) + + # Skip if crossover failed and we're configured to skip failures + if !crossover_accepted && options.skip_mutation_failures + return num_evals + end + + # Find the two oldest members to replace + oldest1 = argmin_fast([pop.members[member].birth for member in 1:(pop.n)]) + BT = typeof(first(pop.members).birth) + oldest2 = argmin_fast([ + i == oldest1 ? typemax(BT) : pop.members[i].birth for i in 1:(pop.n) + ]) + + # Record crossover events + @recorder begin + members_to_record = [ + allstar1, allstar2, baby1, baby2, pop.members[oldest1], pop.members[oldest2] + ] + setup_member_recording!(record, members_to_record, options) + + crossover_event = RecordType( + "type" => "crossover", + "time" => time(), + "parent1" => allstar1.ref, + "parent2" => allstar2.ref, + "child1" => baby1.ref, + "child2" => baby2.ref, + "details" => crossover_recorder, + ) + death_event1 = RecordType("type" => "death", "time" => time()) + death_event2 = RecordType("type" => "death", "time" => time()) + + push!(record["mutations"]["$(allstar1.ref)"]["events"], crossover_event) + push!(record["mutations"]["$(allstar2.ref)"]["events"], crossover_event) + push!(record["mutations"]["$(pop.members[oldest1].ref)"]["events"], death_event1) + push!(record["mutations"]["$(pop.members[oldest2].ref)"]["events"], death_event2) + end + + # Replace old members with new ones + pop.members[oldest1] = baby1 + pop.members[oldest2] = baby2 + + return num_evals +end + +""" + reg_evol_cycle(dataset, pop, temperature, curmaxsize, running_search_statistics, options, record) + +Pass through the population several times, replacing the oldest with the fittest +members from a small subsample based on tournament selection. + +This implements the regularized evolution algorithm, alternating between mutation and +crossover operations based on the crossover probability. +""" function reg_evol_cycle( dataset::Dataset{T,L}, pop::P, @@ -20,137 +174,28 @@ function reg_evol_cycle( record::RecordType, )::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:Population{T,L}} num_evals = 0.0 + + # Calculate number of evolution cycles based on population size and tournament size n_evol_cycles = ceil(Int, pop.n / options.tournament_selection_n) - for i in 1:n_evol_cycles + # Perform multiple cycles of selection and replacement + for _ in 1:n_evol_cycles if rand() > options.crossover_probability - allstar = best_of_sample(pop, running_search_statistics, options) - mutation_recorder = RecordType() - baby, mutation_accepted, tmp_num_evals = next_generation( + # Mutation case + num_evals += handle_mutation!( + pop, dataset, - allstar, + running_search_statistics, + options, + record, temperature, curmaxsize, - running_search_statistics, - options; - tmp_recorder=mutation_recorder, ) - num_evals += tmp_num_evals - - if !mutation_accepted && options.skip_mutation_failures - # Skip this mutation rather than replacing oldest member with unchanged member - continue - end - - oldest = argmin_fast([pop.members[member].birth for member in 1:(pop.n)]) - - @recorder begin - if !haskey(record, "mutations") - record["mutations"] = RecordType() - end - for member in [allstar, baby, pop.members[oldest]] - if !haskey(record["mutations"], "$(member.ref)") - record["mutations"]["$(member.ref)"] = RecordType( - "events" => Vector{RecordType}(), - "tree" => string_tree(member.tree, options), - "cost" => member.cost, - "loss" => member.loss, - "parent" => member.parent, - ) - end - end - mutate_event = RecordType( - "type" => "mutate", - "time" => time(), - "child" => baby.ref, - "mutation" => mutation_recorder, - ) - death_event = RecordType("type" => "death", "time" => time()) - - # Put in random key rather than vector; otherwise there are collisions! - push!(record["mutations"]["$(allstar.ref)"]["events"], mutate_event) - push!( - record["mutations"]["$(pop.members[oldest].ref)"]["events"], death_event - ) - end - - pop.members[oldest] = baby - - else # Crossover - allstar1 = best_of_sample(pop, running_search_statistics, options) - allstar2 = best_of_sample(pop, running_search_statistics, options) - - crossover_recorder = RecordType() - baby1, baby2, crossover_accepted, tmp_num_evals = crossover_generation( - allstar1, - allstar2, - dataset, - curmaxsize, - options; - recorder=crossover_recorder, + else + # Crossover case + num_evals += handle_crossover!( + pop, dataset, running_search_statistics, options, record, curmaxsize ) - num_evals += tmp_num_evals - - if !crossover_accepted && options.skip_mutation_failures - continue - end - - # Find the oldest members to replace: - oldest1 = argmin_fast([pop.members[member].birth for member in 1:(pop.n)]) - BT = typeof(first(pop.members).birth) - oldest2 = argmin_fast([ - i == oldest1 ? typemax(BT) : pop.members[i].birth for i in 1:(pop.n) - ]) - - @recorder begin - if !haskey(record, "mutations") - record["mutations"] = RecordType() - end - for member in [ - allstar1, - allstar2, - baby1, - baby2, - pop.members[oldest1], - pop.members[oldest2], - ] - if !haskey(record["mutations"], "$(member.ref)") - record["mutations"]["$(member.ref)"] = RecordType( - "events" => Vector{RecordType}(), - "tree" => string_tree(member.tree, options), - "cost" => member.cost, - "loss" => member.loss, - "parent" => member.parent, - ) - end - end - crossover_event = RecordType( - "type" => "crossover", - "time" => time(), - "parent1" => allstar1.ref, - "parent2" => allstar2.ref, - "child1" => baby1.ref, - "child2" => baby2.ref, - "details" => crossover_recorder, - ) - death_event1 = RecordType("type" => "death", "time" => time()) - death_event2 = RecordType("type" => "death", "time" => time()) - - push!(record["mutations"]["$(allstar1.ref)"]["events"], crossover_event) - push!(record["mutations"]["$(allstar2.ref)"]["events"], crossover_event) - push!( - record["mutations"]["$(pop.members[oldest1].ref)"]["events"], - death_event1, - ) - push!( - record["mutations"]["$(pop.members[oldest2].ref)"]["events"], - death_event2, - ) - end - - # Replace old members with new ones: - pop.members[oldest1] = baby1 - pop.members[oldest2] = baby2 end end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 5fcf57f63..2d7505660 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -810,6 +810,319 @@ function _warmup_search!( end return nothing end +""" + _initialize_parallel_workers!(state, nout, options) + +Initialize parallel workers for populations across output dimensions. +""" +function _initialize_parallel_workers!( + state::AbstractSearchState{T,L,N}, + nout::Int, + options::AbstractOptions, + ropt::AbstractRuntimeOptions, +) where {T,L,N} + if ropt.parallelism in (:multiprocessing, :multithreading) + for j in 1:nout, i in 1:(options.populations) + # Start listening for each population to finish: + t = Base.errormonitor( + @async put!(state.channels[j][i], fetch(state.worker_output[j][i])) + ) + push!(state.tasks[j], t) + end + end +end + +""" + _process_population_results!(state, (j, i), datasets, options, resource_monitor, equation_speed, ropt, progress_bar) + +Process results from a completed population iteration and launch the next. +""" +function _process_population_results!( + state::AbstractSearchState{T,L,N}, + (j, i), + datasets::Vector, + options::AbstractOptions, + resource_monitor::ResourceMonitor, + equation_speed::Vector{Float32}, + ropt::AbstractRuntimeOptions, + progress_bar, +) where {T,L,N} + # Take the fetch operation from the channel since its ready + (cur_pop, best_seen, cur_record, cur_num_evals) = if ropt.parallelism in + ( + :multiprocessing, :multithreading + ) + take!( + state.channels[j][i] + ) + else + state.worker_output[j][i] + end::DefaultWorkerOutputType{Population{T,L,N},HallOfFame{T,L,N}} + + # Update state with current population results + state.last_pops[j][i] = copy(cur_pop) + state.best_sub_pops[j][i] = best_sub_pop(cur_pop; topn=options.topn) + @recorder state.record[] = recursive_merge(state.record[], cur_record) + state.num_evals[j][i] += cur_num_evals + + # Get current dataset and maxsize + dataset = datasets[j] + cur_maxsize = state.cur_maxsizes[j] + + # Update frequency statistics for adaptive parsimony + for member in cur_pop.members + size = compute_complexity(member, options) + update_frequencies!(state.all_running_search_statistics[j]; size) + end + + # Update hall of fame with new expressions + update_hall_of_fame!(state.halls_of_fame[j], cur_pop.members, options) + update_hall_of_fame!( + state.halls_of_fame[j], best_seen.members[best_seen.exists], options + ) + + # Calculate pareto frontier for this output dimension + dominating = calculate_pareto_frontier(state.halls_of_fame[j]) + + # Save results to file if requested + if options.save_to_file + save_to_file(dominating, length(datasets), j, dataset, options, ropt) + end + + # Handle migrations between populations + _perform_migration!(state, j, cur_pop, dominating, options) + + # Decrement remaining cycles and check if this output is done + state.cycles_remaining[j] -= 1 + if state.cycles_remaining[j] == 0 + return false + end + + # Start the next iteration for this population + _launch_next_iteration!(state, j, i, cur_pop, dataset, cur_maxsize, options, ropt) + + # Update maxsize and parsimony window + total_cycles = ropt.niterations * options.populations + state.cur_maxsizes[j] = get_cur_maxsize(; + options, total_cycles, cycles_remaining=state.cycles_remaining[j] + ) + move_window!(state.all_running_search_statistics[j]) + + # Update UI and logs + if !isnothing(progress_bar) + head_node_occupation = estimate_work_fraction(resource_monitor) + update_progress_bar!( + progress_bar, + only(state.halls_of_fame), + only(datasets), + options, + equation_speed, + head_node_occupation, + ropt.parallelism, + ) + end + if ropt.logger !== nothing + logging_callback!(ropt.logger; state, datasets, ropt, options) + end + + return true +end + +""" + _perform_migration!(state, j, cur_pop, dominating, options) + +Handle population migration between different populations and from hall of fame. +""" +function _perform_migration!( + state::AbstractSearchState, j::Int, cur_pop, dominating, options::AbstractOptions +) + # Population migration (from best members of all populations) + if options.migration + best_of_each = Population([ + member for pop in state.best_sub_pops[j] for member in pop.members + ]) + migrate!(best_of_each.members => cur_pop, options; frac=options.fraction_replaced) + end + + # Hall of fame migration (from best overall expressions) + if options.hof_migration && length(dominating) > 0 + migrate!(dominating => cur_pop, options; frac=options.fraction_replaced_hof) + end +end + +""" + _launch_next_iteration!(state, j, i, cur_pop, dataset, cur_maxsize, options, ropt) + +Set up and launch the next iteration for a specific population. +""" +function _launch_next_iteration!( + state::AbstractSearchState{T,L,N}, + j::Int, + i::Int, + cur_pop::Population{T,L,N}, + dataset, + cur_maxsize::Int, + options::AbstractOptions, + ropt::AbstractRuntimeOptions, +) where {T,L,N} + # Get worker assignment and iteration number + worker_idx = assign_next_worker!( + state.worker_assignment; out=j, pop=i, parallelism=ropt.parallelism, state.procs + ) + + iteration = if options.use_recorder + key = "out$(j)_pop$(i)" + find_iteration_from_record(key, state.record[]) + 1 + else + 0 + end + + # Clone statistics and population for next iteration + c_rss = deepcopy(state.all_running_search_statistics[j]) + in_pop = copy(cur_pop) + + # Launch the next search cycle + state.worker_output[j][i] = @sr_spawner( + begin + _dispatch_s_r_cycle( + in_pop, + dataset, + options; + pop=i, + out=j, + iteration, + ropt.verbosity, + cur_maxsize, + running_search_statistics=c_rss, + ) + end, + parallelism = ropt.parallelism, + worker_idx = worker_idx + ) + + # For parallel modes, set up the task to collect results + if ropt.parallelism in (:multiprocessing, :multithreading) + state.tasks[j][i] = Base.errormonitor( + @async put!(state.channels[j][i], fetch(state.worker_output[j][i])) + ) + end +end + +""" + _update_search_statistics!(equation_speed, state, last_speed_recording_time, num_evals_last) + +Update search statistics including evaluation speed. +""" +function _update_search_statistics!( + equation_speed::Vector{Float32}, + state::AbstractSearchState, + last_speed_recording_time::Float64, + num_evals_last::Float64, +) + elapsed_since_speed_recording = time() - last_speed_recording_time + if elapsed_since_speed_recording > 1.0 + current_eval_sum = sum(sum, state.num_evals) + num_evals_since_last = current_eval_sum - num_evals_last + current_speed = num_evals_since_last / elapsed_since_speed_recording + push!(equation_speed, current_speed) + + # Keep a running average of the last 20 seconds + average_over_m_measurements = 20 + if length(equation_speed) > average_over_m_measurements + deleteat!(equation_speed, 1) + end + + return time(), current_eval_sum + else + return last_speed_recording_time, num_evals_last + end +end + +""" + _print_status_update(ropt, options, state, datasets, equation_speed, last_print_time, resource_monitor) + +Print status updates to the console at regular intervals. +""" +function _print_status_update( + ropt::AbstractRuntimeOptions, + options::AbstractOptions, + state::AbstractSearchState, + datasets, + equation_speed::Vector{Float32}, + last_print_time::Float64, + resource_monitor::ResourceMonitor, +) + print_every_n_seconds = 5 + elapsed = time() - last_print_time + + if elapsed > print_every_n_seconds + if ropt.verbosity > 0 && !ropt.progress && length(equation_speed) > 0 + head_node_occupation = estimate_work_fraction(resource_monitor) + total_cycles = ropt.niterations * options.populations + print_search_state( + state.halls_of_fame, + datasets; + options, + equation_speed, + total_cycles, + state.cycles_remaining, + head_node_occupation, + parallelism=ropt.parallelism, + width=options.terminal_width, + ) + end + return time() + else + return last_print_time + end +end + +""" + _should_continue_search(state, start_time, options) + +Check if the search should continue by evaluating all stopping conditions. +Returns true if search should continue, false if it should stop. +""" +function _should_continue_search( + state::AbstractSearchState, start_time::Float64, options::AbstractOptions +) + has_cycles_remaining = sum(state.cycles_remaining) > 0 + early_stop = any(( + check_for_loss_threshold(state.halls_of_fame, options), + check_for_user_quit(state.stdin_reader), + check_for_timeout(start_time, options), + check_max_evals(state.num_evals, options), + )) + return has_cycles_remaining && !early_stop +end + +""" + _check_population_ready(state, j, i, ropt) + +Check if a specific population is ready for processing. +""" +function _check_population_ready( + state::AbstractSearchState, (j, i), ropt::AbstractRuntimeOptions +) + # Check if error on population + if ropt.parallelism in (:multiprocessing, :multithreading) + if istaskfailed(state.tasks[j][i]) + fetch(state.tasks[j][i]) + error("Task failed for population") + end + end + + # Non-blocking check if a population is ready + population_ready = if ropt.parallelism in (:multiprocessing, :multithreading) + isready(state.channels[j][i]) + else + true + end + + # Don't start more if this output has finished its cycles + return population_ready && (state.cycles_remaining[j] > 0) +end + function _main_search_loop!( state::AbstractSearchState{T,L,N}, datasets, @@ -819,8 +1132,9 @@ function _main_search_loop!( ropt.verbosity > 0 && @info "Started!" nout = length(datasets) start_time = time() + + # Setup progress bar if requested progress_bar = if ropt.progress - #TODO: need to iterate this on the max cycles remaining! sum_cycle_remaining = sum(state.cycles_remaining) WrappedProgressBar( sum_cycle_remaining, ropt.niterations; barlen=options.terminal_width @@ -829,227 +1143,74 @@ function _main_search_loop!( nothing end + # Initialize timers and statistics last_print_time = time() last_speed_recording_time = time() num_evals_last = sum(sum, state.num_evals) - num_evals_since_last = sum(sum, state.num_evals) - num_evals_last # i.e., start at 0 - print_every_n_seconds = 5 equation_speed = Float32[] - if ropt.parallelism in (:multiprocessing, :multithreading) - for j in 1:nout, i in 1:(options.populations) - # Start listening for each population to finish: - t = Base.errormonitor( - @async put!(state.channels[j][i], fetch(state.worker_output[j][i])) - ) - push!(state.tasks[j], t) - end - end + # Initialize parallel workers + _initialize_parallel_workers!(state, nout, options, ropt) + + # Setup resource monitoring kappa = 0 resource_monitor = ResourceMonitor(; - # Storing n times as many monitoring intervals as populations seems like it will - # help get accurate resource estimates: max_recordings=options.populations * 100 * nout, start_reporting_at=options.populations * 3 * nout, window_size=options.populations * 2 * nout, ) - while sum(state.cycles_remaining) > 0 + + # Main search loop + while _should_continue_search(state, start_time, options) + # Select next population to process kappa += 1 if kappa > options.populations * nout kappa = 1 end - # nout, populations: - j, i = state.task_order[kappa] - - # Check if error on population: - if ropt.parallelism in (:multiprocessing, :multithreading) - if istaskfailed(state.tasks[j][i]) - fetch(state.tasks[j][i]) - error("Task failed for population") - end - end - # Non-blocking check if a population is ready: - population_ready = if ropt.parallelism in (:multiprocessing, :multithreading) - # TODO: Implement type assertions based on parallelism. - isready(state.channels[j][i]) - else - true - end + (j, i) = state.task_order[kappa] + + # Check if population is ready and record state + population_ready = _check_population_ready(state, (j, i), ropt) record_channel_state!(resource_monitor, population_ready) - # Don't start more if this output has finished its cycles: - # TODO - this might skip extra cycles? - population_ready &= (state.cycles_remaining[j] > 0) + # Process population results if ready if population_ready - # Take the fetch operation from the channel since its ready - (cur_pop, best_seen, cur_record, cur_num_evals) = if ropt.parallelism in - ( - :multiprocessing, :multithreading + continue_search = _process_population_results!( + state, + (j, i), + datasets, + options, + resource_monitor, + equation_speed, + ropt, + progress_bar, ) - take!( - state.channels[j][i] - ) - else - state.worker_output[j][i] - end::DefaultWorkerOutputType{Population{T,L,N},HallOfFame{T,L,N}} - state.last_pops[j][i] = copy(cur_pop) - state.best_sub_pops[j][i] = best_sub_pop(cur_pop; topn=options.topn) - @recorder state.record[] = recursive_merge(state.record[], cur_record) - state.num_evals[j][i] += cur_num_evals - dataset = datasets[j] - cur_maxsize = state.cur_maxsizes[j] - - for member in cur_pop.members - size = compute_complexity(member, options) - update_frequencies!(state.all_running_search_statistics[j]; size) - end - #! format: off - update_hall_of_fame!(state.halls_of_fame[j], cur_pop.members, options) - update_hall_of_fame!(state.halls_of_fame[j], best_seen.members[best_seen.exists], options) - #! format: on - - # Dominating pareto curve - must be better than all simpler equations - dominating = calculate_pareto_frontier(state.halls_of_fame[j]) - - if options.save_to_file - save_to_file(dominating, nout, j, dataset, options, ropt) - end - ################################################################### - # Migration ####################################################### - if options.migration - best_of_each = Population([ - member for pop in state.best_sub_pops[j] for member in pop.members - ]) - migrate!( - best_of_each.members => cur_pop, options; frac=options.fraction_replaced - ) - end - if options.hof_migration && length(dominating) > 0 - migrate!(dominating => cur_pop, options; frac=options.fraction_replaced_hof) - end - ################################################################### - - state.cycles_remaining[j] -= 1 - if state.cycles_remaining[j] == 0 + if !continue_search break end - worker_idx = assign_next_worker!( - state.worker_assignment; - out=j, - pop=i, - parallelism=ropt.parallelism, - state.procs, - ) - iteration = if options.use_recorder - key = "out$(j)_pop$(i)" - find_iteration_from_record(key, state.record[]) + 1 - else - 0 - end - - c_rss = deepcopy(state.all_running_search_statistics[j]) - in_pop = copy(cur_pop::Population{T,L,N}) - state.worker_output[j][i] = @sr_spawner( - begin - _dispatch_s_r_cycle( - in_pop, - dataset, - options; - pop=i, - out=j, - iteration, - ropt.verbosity, - cur_maxsize, - running_search_statistics=c_rss, - ) - end, - parallelism = ropt.parallelism, - worker_idx = worker_idx - ) - if ropt.parallelism in (:multiprocessing, :multithreading) - state.tasks[j][i] = Base.errormonitor( - @async put!(state.channels[j][i], fetch(state.worker_output[j][i])) - ) - end - - total_cycles = ropt.niterations * options.populations - state.cur_maxsizes[j] = get_cur_maxsize(; - options, total_cycles, cycles_remaining=state.cycles_remaining[j] - ) - move_window!(state.all_running_search_statistics[j]) - if !isnothing(progress_bar) - head_node_occupation = estimate_work_fraction(resource_monitor) - update_progress_bar!( - progress_bar, - only(state.halls_of_fame), - only(datasets), - options, - equation_speed, - head_node_occupation, - ropt.parallelism, - ) - end - if ropt.logger !== nothing - logging_callback!(ropt.logger; state, datasets, ropt, options) - end end + + # Allow other tasks to run yield() - ################################################################ - ## Search statistics - elapsed_since_speed_recording = time() - last_speed_recording_time - if elapsed_since_speed_recording > 1.0 - num_evals_since_last, num_evals_last = let s = sum(sum, state.num_evals) - s - num_evals_last, s - end - current_speed = num_evals_since_last / elapsed_since_speed_recording - push!(equation_speed, current_speed) - average_over_m_measurements = 20 # 20 second running average - if length(equation_speed) > average_over_m_measurements - deleteat!(equation_speed, 1) - end - last_speed_recording_time = time() - end - ################################################################ - - ################################################################ - ## Printing code - elapsed = time() - last_print_time - # Update if time has passed - if elapsed > print_every_n_seconds - if ropt.verbosity > 0 && !ropt.progress && length(equation_speed) > 0 - - # Dominating pareto curve - must be better than all simpler equations - head_node_occupation = estimate_work_fraction(resource_monitor) - total_cycles = ropt.niterations * options.populations - print_search_state( - state.halls_of_fame, - datasets; - options, - equation_speed, - total_cycles, - state.cycles_remaining, - head_node_occupation, - parallelism=ropt.parallelism, - width=options.terminal_width, - ) - end - last_print_time = time() - end - ################################################################ - - ################################################################ - ## Early stopping code - if any(( - check_for_loss_threshold(state.halls_of_fame, options), - check_for_user_quit(state.stdin_reader), - check_for_timeout(start_time, options), - check_max_evals(state.num_evals, options), - )) - break - end - ################################################################ + # Update statistics periodically + last_speed_recording_time, num_evals_last = _update_search_statistics!( + equation_speed, state, last_speed_recording_time, num_evals_last + ) + + # Print status updates periodically + last_print_time = _print_status_update( + ropt, + options, + state, + datasets, + equation_speed, + last_print_time, + resource_monitor, + ) end + + # Cleanup if !isnothing(progress_bar) finish!(progress_bar) end