Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 35 additions & 53 deletions src/MutationFunctions.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module MutationFunctionsModule

using DispatchDoctor: @unstable
using Random: default_rng, AbstractRNG
using DynamicExpressions:
AbstractExpressionNode,
Expand Down Expand Up @@ -40,6 +41,18 @@ function with_contents_for_mutation(ex::AbstractExpression, new_contents, ::Noth
return with_contents(ex, new_contents)
end

function apply_tree_mutation(
ex::AbstractExpression,
rng::AbstractRNG,
mutation_func::F,
args::Vararg{Any,N};
kwargs...,
) where {F<:Function,N}
tree, context = get_contents_for_mutation(ex, rng)
mutated_tree = mutation_func(tree, args..., rng; kwargs...)
return with_contents_for_mutation(ex, mutated_tree, context)
end

"""
random_node(tree::AbstractNode; filter::F=Returns(true))

Expand All @@ -58,9 +71,7 @@ end

"""Swap operands in binary operator for ops like pow and divide"""
function swap_operands(ex::AbstractExpression, rng::AbstractRNG=default_rng())
tree, context = get_contents_for_mutation(ex, rng)
ex = with_contents_for_mutation(ex, swap_operands(tree, rng), context)
return ex
return apply_tree_mutation(ex, rng, swap_operands)
end
function swap_operands(tree::AbstractNode, rng::AbstractRNG=default_rng())
if !any(node -> node.degree == 2, tree)
Expand All @@ -73,17 +84,13 @@ end

"""Randomly convert an operator into another one (binary->binary; unary->unary)"""
function mutate_operator(
ex::AbstractExpression{T}, options::AbstractOptions, rng::AbstractRNG=default_rng()
) where {T<:DATA_TYPE}
tree, context = get_contents_for_mutation(ex, rng)
ex = with_contents_for_mutation(ex, mutate_operator(tree, options, rng), context)
return ex
ex::AbstractExpression, options::AbstractOptions, rng::AbstractRNG=default_rng()
)
return apply_tree_mutation(ex, rng, mutate_operator, options)
end
function mutate_operator(
tree::AbstractExpressionNode{T},
options::AbstractOptions,
rng::AbstractRNG=default_rng(),
) where {T}
tree::AbstractExpressionNode, options::AbstractOptions, rng::AbstractRNG=default_rng()
)
if !(has_operators(tree))
return tree
end
Expand All @@ -98,16 +105,12 @@ end

"""Randomly perturb a constant"""
function mutate_constant(
ex::AbstractExpression{T},
ex::AbstractExpression,
temperature,
options::AbstractOptions,
rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE}
tree, context = get_contents_for_mutation(ex, rng)
ex = with_contents_for_mutation(
ex, mutate_constant(tree, temperature, options, rng), context
)
return ex
)
return apply_tree_mutation(ex, rng, mutate_constant, temperature, options)
end
function mutate_constant(
tree::AbstractExpressionNode{T},
Expand Down Expand Up @@ -145,17 +148,15 @@ end

"""Add a random unary/binary operation to the end of a tree"""
function append_random_op(
ex::AbstractExpression{T},
ex::AbstractExpression,
options::AbstractOptions,
nfeatures::Int,
rng::AbstractRNG=default_rng();
make_new_bin_op::Union{Bool,Nothing}=nothing,
) where {T<:DATA_TYPE}
tree, context = get_contents_for_mutation(ex, rng)
ex = with_contents_for_mutation(
ex, append_random_op(tree, options, nfeatures, rng; make_new_bin_op), context
)
return apply_tree_mutation(
ex, rng, append_random_op, options, nfeatures; make_new_bin_op=make_new_bin_op
)
return ex
end
function append_random_op(
tree::AbstractExpressionNode{T},
Expand Down Expand Up @@ -195,11 +196,7 @@ function insert_random_op(
nfeatures::Int,
rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE}
tree, context = get_contents_for_mutation(ex, rng)
ex = with_contents_for_mutation(
ex, insert_random_op(tree, options, nfeatures, rng), context
)
return ex
return apply_tree_mutation(ex, rng, insert_random_op, options, nfeatures)
end
function insert_random_op(
tree::AbstractExpressionNode{T},
Expand Down Expand Up @@ -231,11 +228,7 @@ function prepend_random_op(
nfeatures::Int,
rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE}
tree, context = get_contents_for_mutation(ex, rng)
ex = with_contents_for_mutation(
ex, prepend_random_op(tree, options, nfeatures, rng), context
)
return ex
return apply_tree_mutation(ex, rng, prepend_random_op, options, nfeatures)
end
function prepend_random_op(
tree::AbstractExpressionNode{T},
Expand Down Expand Up @@ -294,11 +287,7 @@ function delete_random_op!(
nfeatures::Int,
rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE}
tree, context = get_contents_for_mutation(ex, rng)
ex = with_contents_for_mutation(
ex, delete_random_op!(tree, options, nfeatures, rng), context
)
return ex
return apply_tree_mutation(ex, rng, delete_random_op!, options, nfeatures)
end
function delete_random_op!(
tree::AbstractExpressionNode{T},
Expand Down Expand Up @@ -345,18 +334,15 @@ function delete_random_op!(
return tree
end

function randomize_tree(
# TODO: For some reason this is unstable on Julia 1.10.
@unstable function randomize_tree(
ex::AbstractExpression,
curmaxsize::Int,
options::AbstractOptions,
nfeatures::Int,
rng::AbstractRNG=default_rng(),
)
tree, context = get_contents_for_mutation(ex, rng)
ex = with_contents_for_mutation(
ex, randomize_tree(tree, curmaxsize, options, nfeatures, rng), context
)
return ex
return apply_tree_mutation(ex, rng, randomize_tree, curmaxsize, options, nfeatures)
end
function randomize_tree(
::AbstractExpressionNode{T},
Expand Down Expand Up @@ -472,8 +458,7 @@ function get_two_nodes_without_loop(tree::AbstractNode, rng::AbstractRNG; max_at
end

function form_random_connection!(ex::AbstractExpression, rng::AbstractRNG=default_rng())
tree, context = get_contents_for_mutation(ex, rng)
return with_contents_for_mutation(ex, form_random_connection!(tree, rng), context)
return apply_tree_mutation(ex, rng, form_random_connection!)
end
function form_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rng())
if length(tree) < 5
Expand All @@ -496,8 +481,7 @@ function form_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rn
end

function break_random_connection!(ex::AbstractExpression, rng::AbstractRNG=default_rng())
tree, context = get_contents_for_mutation(ex, rng)
return with_contents_for_mutation(ex, break_random_connection!(tree, rng), context)
return apply_tree_mutation(ex, rng, break_random_connection!)
end
function break_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rng())
tree.degree == 0 && return tree
Expand All @@ -515,9 +499,7 @@ function is_valid_rotation_node(node::AbstractNode)
end

function randomly_rotate_tree!(ex::AbstractExpression, rng::AbstractRNG=default_rng())
tree, context = get_contents_for_mutation(ex, rng)
rotated_tree = randomly_rotate_tree!(tree, rng)
return with_contents_for_mutation(ex, rotated_tree, context)
return apply_tree_mutation(ex, rng, randomly_rotate_tree!)
end
function randomly_rotate_tree!(tree::AbstractNode, rng::AbstractRNG=default_rng())
num_rotation_nodes = count(is_valid_rotation_node, tree)
Expand Down
18 changes: 14 additions & 4 deletions src/ProgressBars.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,21 @@ using ..UtilsModule: AnnotatedString
mutable struct WrappedProgressBar
bar::Progress
postfix::Vector{Tuple{AnnotatedString,AnnotatedString}}
clean_postfix::Vector{Tuple{AnnotatedString,AnnotatedString}}
last_update::Float64

function WrappedProgressBar(n::Integer, niterations::Integer; kwargs...)
init_vector = Tuple{AnnotatedString,AnnotatedString}[]
kwargs = (; kwargs..., desc="Evolving for $niterations iterations...")
last_update = time()
if get(ENV, "SYMBOLIC_REGRESSION_TEST", "false") == "true"
# For testing, create a progress bar that writes to devnull
output = devnull
return new(Progress(n; output, kwargs...), init_vector)
return new(
Progress(n; output, kwargs...), init_vector, copy(init_vector), last_update
)
end
return new(Progress(n; kwargs...), init_vector)
return new(Progress(n; kwargs...), init_vector, copy(init_vector), last_update)
end
end

Expand All @@ -34,8 +39,13 @@ end
"""Iterate a progress bar."""
function manually_iterate!(pbar::WrappedProgressBar)
width = barlen(pbar)
postfix = map(Fix{2}(format_for_meter, width), pbar.postfix)
next!(pbar.bar; showvalues=postfix, valuecolor=:none)
last_update = pbar.last_update
update_interval = 0.005
if time() - last_update > update_interval
pbar.clean_postfix = map(Fix{2}(format_for_meter, width), pbar.postfix)
pbar.last_update = time()
end
next!(pbar.bar; showvalues=pbar.clean_postfix, valuecolor=:none)
return nothing
end

Expand Down
Loading
Loading