diff --git a/NEWS.md b/NEWS.md
index c59b2822e..6276520b1 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,10 +1,13 @@
 # ReinforcementLearning.jl Release Notes
 
 ## ReinforcementLearning.jl@v0.10.0
+
 ### ReinforcementLearningCore.jl
+
 #### v0.8.2
 
 - Add GaussianNetwork and DuelingNetwork into ReinforcementLearningCore.jl as
   general components. [#370](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/pull/370)
+- Export `WeightedSoftmaxExplorer`. [#382](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/pull/382)
 
 ### ReinforcementLearningZoo.jl
@@ -12,6 +15,8 @@
 
 - Update the complete SAC implementation and modify some details based on the
   original paper. [#365](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/pull/365)
+- Add some extra keyword parameters for `BehaviorCloningPolicy` to use it
+  online. [#390](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/pull/390)
 
 ## ReinforcementLearning.jl@v0.9.0
diff --git a/src/ReinforcementLearningCore/src/policies/agents/trajectories/reservoir_trajectory.jl b/src/ReinforcementLearningCore/src/policies/agents/trajectories/reservoir_trajectory.jl
index b8f3deb09..c0a154679 100644
--- a/src/ReinforcementLearningCore/src/policies/agents/trajectories/reservoir_trajectory.jl
+++ b/src/ReinforcementLearningCore/src/policies/agents/trajectories/reservoir_trajectory.jl
@@ -32,3 +32,21 @@ function Base.push!(b::ReservoirTrajectory; kw...)
         end
     end
 end
+
+function RLBase.update!(
+    trajectory::ReservoirTrajectory,
+    policy::AbstractPolicy,
+    env::AbstractEnv,
+    ::PreActStage,
+    action,
+)
+    s = policy isa NamedPolicy ? state(env, nameof(policy)) : state(env)
+    if haskey(trajectory.buffer, :legal_actions_mask)
+        lasm =
+            policy isa NamedPolicy ? legal_action_space_mask(env, nameof(policy)) :
+            legal_action_space_mask(env)
+        push!(trajectory; :state => s, :action => action, :legal_actions_mask => lasm)
+    else
+        push!(trajectory; :state => s, :action => action)
+    end
+end
diff --git a/src/ReinforcementLearningCore/src/policies/q_based_policies/explorers/weighted_softmax_explorer.jl b/src/ReinforcementLearningCore/src/policies/q_based_policies/explorers/weighted_softmax_explorer.jl
index 2a0eff6b5..82a046269 100644
--- a/src/ReinforcementLearningCore/src/policies/q_based_policies/explorers/weighted_softmax_explorer.jl
+++ b/src/ReinforcementLearningCore/src/policies/q_based_policies/explorers/weighted_softmax_explorer.jl
@@ -25,4 +25,9 @@ function (s::WeightedSoftmaxExplorer)(values::AbstractVector{T}, mask) where {T}
     s(values)
 end
 
-RLBase.prob(s::WeightedSoftmaxExplorer, values) = softmax(values)
\ No newline at end of file
+RLBase.prob(s::WeightedSoftmaxExplorer, values) = softmax(values)
+
+function RLBase.prob(s::WeightedSoftmaxExplorer, values::AbstractVector{T}, mask) where {T}
+    values[.!mask] .= typemin(T)
+    prob(s, values)
+end
diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl
index 564f6c84e..20aa98d79 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl
@@ -1,5 +1,12 @@
 export BehaviorCloningPolicy
 
+mutable struct BehaviorCloningPolicy{A} <: AbstractPolicy
+    approximator::A
+    explorer::Any
+    sampler::BatchSampler{(:state, :action)}
+    min_reservoir_history::Int
+end
+
 """
     BehaviorCloningPolicy(;kw...)
 
@@ -7,18 +14,32 @@ export BehaviorCloningPolicy
 
 - `approximator`: calculate the logits of possible actions directly
 - `explorer=GreedyExplorer()`
-
+- `batch_size::Int = 32`
+- `min_reservoir_history::Int = 100`, number of transitions that should be experienced before updating the `approximator`.
+- `rng = Random.GLOBAL_RNG`
 """
-Base.@kwdef struct BehaviorCloningPolicy{A} <: AbstractPolicy
-    approximator::A
-    explorer::Any = GreedyExplorer()
+function BehaviorCloningPolicy(;
+    approximator::A,
+    explorer::Any = GreedyExplorer(),
+    batch_size::Int = 32,
+    min_reservoir_history::Int = 100,
+    rng = Random.GLOBAL_RNG
+) where {A}
+    sampler = BatchSampler{(:state, :action)}(batch_size; rng = rng)
+    BehaviorCloningPolicy(
+        approximator,
+        explorer,
+        sampler,
+        min_reservoir_history,
+    )
 end
 
 function (p::BehaviorCloningPolicy)(env::AbstractEnv)
     s = state(env)
     s_batch = Flux.unsqueeze(s, ndims(s) + 1)
-    logits = p.approximator(s_batch) |> vec # drop dimension
-    p.explorer(logits)
+    s_batch = send_to_device(device(p.approximator), s_batch)
+    logits = p.approximator(s_batch) |> vec |> send_to_host # drop dimension
+    typeof(ActionStyle(env)) == MinimalActionSet ? p.explorer(logits) : p.explorer(logits, legal_action_space_mask(env))
 end
 
 function RLBase.update!(p::BehaviorCloningPolicy, batch::NamedTuple{(:state, :action)})
@@ -31,3 +52,33 @@ function RLBase.update!(p::BehaviorCloningPolicy, batch::NamedTuple{(:state, :ac
     end
     update!(m, gs)
 end
+
+function RLBase.update!(p::BehaviorCloningPolicy, t::AbstractTrajectory)
+    (length(t) <= p.min_reservoir_history || length(t) <= p.sampler.batch_size) && return
+
+    _, batch = p.sampler(t)
+    RLBase.update!(p, send_to_device(device(p.approximator), batch))
+end
+
+function RLBase.prob(p::BehaviorCloningPolicy, env::AbstractEnv)
+    s = state(env)
+    s_batch = Flux.unsqueeze(s, ndims(s) + 1)
+    values = p.approximator(s_batch) |> vec |> send_to_host
+    typeof(ActionStyle(env)) == MinimalActionSet ? prob(p.explorer, values) : prob(p.explorer, values, legal_action_space_mask(env))
+end
+
+function RLBase.prob(p::BehaviorCloningPolicy, env::AbstractEnv, action)
+    A = action_space(env)
+    P = prob(p, env)
+    @assert length(A) == length(P)
+    if A isa Base.OneTo
+        P[action]
+    else
+        for (a, p) in zip(A, P)
+            if a == action
+                return p
+            end
+        end
+        @error "action[$action] is not found in action space[$(action_space(env))]"
+    end
+end
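A quick usage sketch (not part of the patch) of the new masked `prob` method on `WeightedSoftmaxExplorer`, assuming the default keyword constructor `WeightedSoftmaxExplorer()`; the values and mask below are made-up placeholders:

```julia
using ReinforcementLearning

explorer = WeightedSoftmaxExplorer()
values = [1.0, 2.0, 3.0]    # raw action values (placeholder)
mask = [true, false, true]  # the second action is illegal

# Masked-out entries are overwritten with typemin(Float64) == -Inf before the
# softmax, so they get zero probability; note that `values` is mutated in place.
prob(explorer, values, mask)  # ≈ [0.119, 0.0, 0.881]
```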
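And a sketch of constructing `BehaviorCloningPolicy` for online use with the new keyword arguments. This is also not part of the patch: `ns`, `na`, the network, and the optimizer are placeholders, and the `NeuralNetworkApproximator`/`Flux` usage is assumed to follow the existing RLCore conventions:

```julia
using ReinforcementLearning
using Flux

ns, na = 4, 2  # placeholder state / action dimensions

policy = BehaviorCloningPolicy(
    approximator = NeuralNetworkApproximator(
        model = Chain(Dense(ns, 64, relu), Dense(64, na)),
        optimizer = ADAM(),
    ),
    explorer = WeightedSoftmaxExplorer(),
    batch_size = 32,             # transitions per sampled batch
    min_reservoir_history = 100, # collect this many transitions before updating
)
```

When such a policy is paired with a `ReservoirTrajectory` that holds `:state` and `:action` traces (and optionally `:legal_actions_mask`), the new `update!(::ReservoirTrajectory, policy, env, ::PreActStage, action)` method above records each transition, and `update!(p::BehaviorCloningPolicy, t::AbstractTrajectory)` only starts sampling batches once more than `min_reservoir_history` transitions have been collected.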