From 888db8dd64cf0bce18f54f846625ec5cbb96f220 Mon Sep 17 00:00:00 2001 From: GuoYu Yang <49673553+pilgrimygy@users.noreply.github.com> Date: Mon, 26 Jul 2021 23:59:41 +0800 Subject: [PATCH 1/2] update --- .../Policy Gradient/JuliaRL_SAC_Pendulum.jl | 2 +- .../neural_network_approximator.jl | 11 + .../src/ReinforcementLearningZoo.jl | 2 +- .../src/algorithms/offline_rl/CRR.jl | 241 ++++++++++++++++++ .../src/algorithms/offline_rl/common.jl | 93 +++++++ .../src/algorithms/offline_rl/offline_rl.jl | 2 + .../src/algorithms/policy_gradient/sac.jl | 30 +-- 7 files changed, 364 insertions(+), 17 deletions(-) create mode 100644 src/ReinforcementLearningZoo/src/algorithms/offline_rl/CRR.jl create mode 100644 src/ReinforcementLearningZoo/src/algorithms/offline_rl/common.jl diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl index 7df613dc1..c15e3c07d 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl @@ -71,7 +71,7 @@ function RL.Experiment( start_steps = 1000, start_policy = RandomPolicy(Space([-1.0..1.0 for _ in 1:na]); rng = rng), update_after = 1000, - update_every = 1, + update_freq = 1, automatic_entropy_tuning = true, lr_alpha = 0.003f0, action_dims = action_dims, diff --git a/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl b/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl index 43d70a735..388f4a0b2 100644 --- a/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl +++ b/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl @@ -111,6 +111,17 @@ function (model::GaussianNetwork)(state; is_sampling::Bool=false, is_return_log_ model(Random.GLOBAL_RNG, state; is_sampling=is_sampling, is_return_log_prob=is_return_log_prob) end +""" +This function is used to infer the probability of getting action `a` given state `s`. +""" +function (model::GaussianNetwork)(state, action) + x = model.pre(state) + μ, logσ = model.μ(x), model.logσ(x) + π_dist = Normal.(μ, exp.(logσ)) + logp_π = sum(logpdf.(π_dist, action), dims = 1) + logp_π -= sum((2.0f0 .* (log(2.0f0) .- action - softplus.(-2.0f0 * action))), dims = 1) +end + ##### # DuelingNetwork ##### diff --git a/src/ReinforcementLearningZoo/src/ReinforcementLearningZoo.jl b/src/ReinforcementLearningZoo/src/ReinforcementLearningZoo.jl index 3e64ebdde..47992ae63 100644 --- a/src/ReinforcementLearningZoo/src/ReinforcementLearningZoo.jl +++ b/src/ReinforcementLearningZoo/src/ReinforcementLearningZoo.jl @@ -3,7 +3,7 @@ module ReinforcementLearningZoo const RLZoo = ReinforcementLearningZoo export RLZoo -export GaussianNetwork +export GaussianNetwork, DuelingNetwork using CircularArrayBuffers using ReinforcementLearningBase diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CRR.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CRR.jl new file mode 100644 index 000000000..67d1ce0e5 --- /dev/null +++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CRR.jl @@ -0,0 +1,241 @@ +export CRRLearner + +""" + CRRLearner(;kwargs) + +See paper: [Critic Regularized Regression](https://arxiv.org/abs/2006.15134). 
+ +# Keyword arguments + +- `approximator`::[`ActorCritic`](@ref): used to get Q-values (Critic) and logits (Actor) of a state. +- `target_approximator`::[`ActorCritic`](@ref): similar to `approximator`, but used to estimate the target. +- `γ::Float32`, reward discount rate. +- `batch_size::Int=32` +- `policy_improvement_mode::Symbol=:exp`, type of the weight function f. Possible values: :binary/:exp. +- `ratio_upper_bound::Float32`, when `policy_improvement_mode` is ":exp", the value of the exp function is upper-bounded by this parameter. +- `beta::Float32`, when `policy_improvement_mode` is ":exp", this is the denominator of the exp function. +- `advantage_estimator::Symbol=:mean`, type of the advantage estimate \\hat{A}. Possible values: :mean/:max. +- `update_freq::Int`: the frequency of updating the `approximator`. +- `update_step::Int=0` +- `target_update_freq::Int`: the frequency of syncing `target_approximator`. +- `continuous::Bool`: type of action space. +- `m::Int`: if `continuous=true`, sample `m` actions to calculate advantage estimate. +- `rng = Random.GLOBAL_RNG` +""" +mutable struct CRRLearner{ + Aq<:ActorCritic, + At<:ActorCritic, + R<:AbstractRNG, +} <: AbstractLearner + approximator::Aq + target_approximator::At + γ::Float32 + batch_size::Int + policy_improvement_mode::Symbol + ratio_upper_bound::Float32 + beta::Float32 + advantage_estimator::Symbol + update_freq::Int + update_step::Int + target_update_freq::Int + continuous::Bool + m::Int + rng::R + # for logging + actor_loss::Float32 + critic_loss::Float32 +end + +function CRRLearner(; + approximator::Aq, + target_approximator::At, + γ::Float32 = 0.99f0, + batch_size::Int = 32, + policy_improvement_mode::Symbol = :binary, + ratio_upper_bound::Float32 = 20.0f0, + beta::Float32 = 1.0f0, + advantage_estimator::Symbol = :max, + update_freq::Int = 10, + update_step::Int = 0, + target_update_freq::Int = 100, + continuous::Bool, + m::Int = 4, + rng = Random.GLOBAL_RNG, +) where {Aq<:ActorCritic, At<:ActorCritic} + copyto!(approximator, target_approximator) + CRRLearner( + approximator, + target_approximator, + γ, + batch_size, + policy_improvement_mode, + ratio_upper_bound, + beta, + advantage_estimator, + update_freq, + update_step, + target_update_freq, + continuous, + m, + rng, + 0.0f0, + 0.0f0, + ) +end + +Flux.functor(x::CRRLearner) = (Q = x.approximator, Qₜ = x.target_approximator), +y -> begin + x = @set x.approximator = y.Q + x = @set x.target_approximator = y.Qₜ + x +end + +function (learner::CRRLearner)(env) + s = state(env) + s = Flux.unsqueeze(s, ndims(s) + 1) + s = send_to_device(device(learner), s) + if learner.continuous + learner.approximator.actor(s; is_sampling=true) |> vec |> send_to_host + else + learner.approximator.actor(s) |> vec |> send_to_host + end +end + +function RLBase.update!(learner::CRRLearner, batch::NamedTuple) + if learner.continuous + continuous_update!(learner, batch) + else + discrete_update!(learner, batch) + end +end + +function continuous_update!(learner::CRRLearner, batch::NamedTuple) + AC = learner.approximator + target_AC = learner.target_approximator + γ = learner.γ + beta = learner.beta + batch_size = learner.batch_size + policy_improvement_mode = learner.policy_improvement_mode + ratio_upper_bound = learner.ratio_upper_bound + advantage_estimator = learner.advantage_estimator + D = device(AC) + + s, a, r, t, s′ = (send_to_device(D, batch[x]) for x in SARTS) + + a = reshape(a, :, batch_size) + + target_a_t = target_AC.actor(s′; is_sampling=true) + target_q_input = vcat(s′, 
target_a_t) + target_q_t = target_AC.critic(target_q_input) + + target = r .+ γ .* (1 .- t) .* target_q_t + + q_t = Array{Float32}(undef, learner.m, batch_size) + for i in 1:learner.m + a_sample = AC.actor(learner.rng, s; is_sampling=true) + q_t[i, :] = AC.critic(vcat(s, a_sample)) + end + println(size(maximum(q_t, dims=1))) + + ps = Flux.params(AC) + gs = gradient(ps) do + # Critic loss + q_input = vcat(s, a) + qa_t = AC.critic(q_input) + println(size(qa_t)) + @assert 1 == 2 + critic_loss = Flux.Losses.logitcrossentropy(qa_t, target) + + a = atanh.(a) + log_π = AC.actor(s, a) + + + # Actor loss + if advantage_estimator == :max + advantage = qa_t .- maximum(q_t, dims=1) + elseif advantage_estimator == :mean + advantage = qa_t .- mean(q_t, dims=1) + else + error("Wrong parameter.") + end + println(size(advantage)) + + if policy_improvement_mode == :binary + actor_loss_coef = Float32.(advantage .> 0.0f0) + elseif policy_improvement_mode == :exp + actor_loss_coef = clamp.(exp.(advantage ./ beta), 0, ratio_upper_bound) + else + error("Wrong parameter.") + end + + actor_loss = mean(-log_π) + + ignore() do + learner.actor_loss = actor_loss + learner.critic_loss = critic_loss + end + + actor_loss + critic_loss + end + + update!(AC, gs) +end + +function discrete_update!(learner::CRRLearner, batch::NamedTuple) + AC = learner.approximator + target_AC = learner.target_approximator + γ = learner.γ + beta = learner.beta + batch_size = learner.batch_size + policy_improvement_mode = learner.policy_improvement_mode + ratio_upper_bound = learner.ratio_upper_bound + advantage_estimator = learner.advantage_estimator + D = device(AC) + + s, a, r, t, s′ = (send_to_device(D, batch[x]) for x in SARTS) + a = CartesianIndex.(a, 1:batch_size) + + target_a_t = softmax(target_AC.actor(s′)) + target_q_t = target_AC.critic(s′) + expected_target_q = sum(target_a_t .* target_q_t, dims=1) + + target = r .+ γ .* (1 .- t) .* expected_target_q + + ps = Flux.params(AC) + gs = gradient(ps) do + # Critic loss + q_t = AC.critic(s) + qa_t = q_t[a] + critic_loss = Flux.Losses.mse(qa_t, target) + + # Actor loss + a_t = softmax(AC.actor(s)) + + if advantage_estimator == :max + advantage = qa_t .- maximum(q_t, dims=1) + elseif advantage_estimator == :mean + advantage = qa_t .- mean(q_t, dims=1) + else + error("Wrong parameter.") + end + + if policy_improvement_mode == :binary + actor_loss_coef = Float32.(advantage .> 0.0f0) + elseif policy_improvement_mode == :exp + actor_loss_coef = clamp.(exp.(advantage ./ beta), 0, ratio_upper_bound) + else + error("Wrong parameter.") + end + + actor_loss = mean(-log.(a_t[a]) .* actor_loss_coef) + + ignore() do + learner.actor_loss = actor_loss + learner.critic_loss = critic_loss + end + + actor_loss + critic_loss + end + + update!(AC, gs) +end \ No newline at end of file diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/common.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/common.jl new file mode 100644 index 000000000..49961bd5a --- /dev/null +++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/common.jl @@ -0,0 +1,93 @@ +export OfflinePolicy, RLTransition + +struct RLTransition + state + action + reward + terminal + next_state +end + +Base.@kwdef struct OfflinePolicy{L,T} <: AbstractPolicy + learner::L + dataset::Vector{T} + continuous::Bool + batch_size::Int +end + +(π::OfflinePolicy)(env) = π(env, ActionStyle(env), action_space(env)) + +function (π::OfflinePolicy)(env, ::MinimalActionSet, ::Base.OneTo) + if π.continuous + π.learner(env) + else + 
findmax(π.learner(env))[2] + end +end +(π::OfflinePolicy)(env, ::FullActionSet, ::Base.OneTo) = findmax(π.learner(env), legal_action_space_mask(env))[2] + +function (π::OfflinePolicy)(env, ::MinimalActionSet, A) + if π.continuous + π.learner(env) + else + A[findmax(π.learner(env))[2]] + end +end +(π::OfflinePolicy)(env, ::FullActionSet, A) = A[findmax(π.learner(env), legal_action_space_mask(env))[2]] + +function RLBase.update!( + p::OfflinePolicy, + traj::AbstractTrajectory, + ::AbstractEnv, + ::PreActStage, +) + l = p.learner + l.update_step += 1 + + if in(:target_update_freq, fieldnames(typeof(l))) && l.update_step % l.target_update_freq == 0 + copyto!(l.target_approximator, l.approximator) + end + + l.update_step % l.update_freq == 0 || return + + inds, batch = sample(l.rng, p.dataset, p.batch_size) + + update!(l, batch) +end + +function StatsBase.sample(rng::AbstractRNG, dataset::Vector{T}, batch_size::Int) where {T} + valid_range = 1:length(dataset) + inds = rand(rng, valid_range, batch_size) + batch_data = dataset[inds] + s_length = size(batch_data[1].state)[1] + + s = Array{Float32}(undef, s_length, batch_size) + s′ = Array{Float32}(undef, s_length, batch_size) + a = [] + r = [] + t = [] + for (i, data) in enumerate(batch_data) + s[:, i] = data.state + push!(a, data.action) + s′[:, i] = data.next_state + push!(r, data.reward) + push!(t, data.terminal) + end + #a = reshape(a, :, batch_size) + batch = NamedTuple{SARTS}((s, a, r, t, s′)) + inds, batch +end + +""" + calculate_CQL_loss(q_value, action; method) + +See paper: [Conservative Q-Learning for Offline Reinforcement Learning](https://arxiv.org/abs/2006.04779) +""" +function calculate_CQL_loss(q_value::Matrix{T}, action::Vector{R}; method = "CQL(H)") where {T, R} + if method == "CQL(H)" + cql_loss = mean(log.(sum(exp.(q_value), dims=1)) .- q_value[action]) + else + @error Wrong method parameter + end + return cql_loss +end \ No newline at end of file diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/offline_rl.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/offline_rl.jl index 6c518749b..9847b6493 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/offline_rl.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/offline_rl.jl @@ -1 +1,3 @@ include("behavior_cloning.jl") +include("CRR.jl") +include("common.jl") \ No newline at end of file diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl index f567fac76..0c74c60fb 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl @@ -19,11 +19,11 @@ mutable struct SACPolicy{ start_steps::Int start_policy::P update_after::Int - update_every::Int + update_freq::Int automatic_entropy_tuning::Bool lr_alpha::Float32 target_entropy::Float32 - step::Int + update_step::Int rng::R # Logging reward_term::Float32 @@ -47,11 +47,11 @@ end - `batch_size = 32`, - `start_steps = 10000`, - `update_after = 1000`, -- `update_every = 50`, +- `update_freq = 50`, - `automatic_entropy_tuning::Bool = false`, whether to automatically tune the entropy. - `lr_alpha::Float32 = 0.003f0`, learning rate of tuning entropy. - `action_dims = 0`, the dimension of the action. if `automatic_entropy_tuning = true`, must enter this parameter. 
-- `step = 0`, +- `update_step = 0`, - `rng = Random.GLOBAL_RNG`, `policy` is expected to output a tuple `(μ, logσ)` of mean and @@ -73,11 +73,11 @@ function SACPolicy(; batch_size = 32, start_steps = 10000, update_after = 1000, - update_every = 50, + update_freq = 50, automatic_entropy_tuning = true, lr_alpha = 0.003f0, action_dims = 0, - step = 0, + update_step = 0, rng = Random.GLOBAL_RNG, ) copyto!(qnetwork1, target_qnetwork1) # force sync @@ -98,11 +98,11 @@ function SACPolicy(; start_steps, start_policy, update_after, - update_every, + update_freq, automatic_entropy_tuning, lr_alpha, Float32(-action_dims), - step, + update_step, rng, 0f0, 0f0, @@ -111,17 +111,17 @@ end # TODO: handle Training/Testing mode function (p::SACPolicy)(env) - p.step += 1 + p.update_step += 1 - if p.step <= p.start_steps + if p.update_step <= p.start_steps p.start_policy(env) else D = device(p.policy) s = state(env) s = Flux.unsqueeze(s, ndims(s) + 1) # trainmode: - action = dropdims(p.policy.model(s; is_sampling=true, is_return_log_prob=true)[1], dims=2) # Single action vec, drop second dim - + action = dropdims(p.policy(s; is_sampling=true), dims=2) # Single action vec, drop second dim + # testmode: # if testing dont sample an action, but act deterministically by # taking the "mean" action @@ -136,7 +136,7 @@ function RLBase.update!( ::PreActStage, ) length(traj) > p.update_after || return - p.step % p.update_every == 0 || return + p.update_step % p.update_freq == 0 || return inds, batch = sample(p.rng, traj, BatchSampler{SARTS}(p.batch_size)) update!(p, batch) end @@ -146,7 +146,7 @@ function RLBase.update!(p::SACPolicy, batch::NamedTuple{SARTS}) γ, τ, α = p.γ, p.τ, p.α - a′, log_π = p.policy.model(s′; is_sampling=true, is_return_log_prob=true) + a′, log_π = p.policy(s′; is_sampling=true, is_return_log_prob=true) q′_input = vcat(s′, a′) q′ = min.(p.target_qnetwork1(q′_input), p.target_qnetwork2(q′_input)) @@ -168,7 +168,7 @@ function RLBase.update!(p::SACPolicy, batch::NamedTuple{SARTS}) # Train Policy p_grad = gradient(Flux.params(p.policy)) do - a, log_π = p.policy.model(s; is_sampling=true, is_return_log_prob=true) + a, log_π = p.policy(s; is_sampling=true, is_return_log_prob=true) q_input = vcat(s, a) q = min.(p.qnetwork1(q_input), p.qnetwork2(q_input)) reward = mean(q) From 0ca0fe17b6bfe9c797617be8c47adfc99b482b19 Mon Sep 17 00:00:00 2001 From: pilgrim Date: Tue, 27 Jul 2021 10:52:38 +0800 Subject: [PATCH 2/2] Update --- .../src/algorithms/offline_rl/CRR.jl | 17 ++++++----------- .../src/algorithms/offline_rl/common.jl | 1 - 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CRR.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CRR.jl index 67d1ce0e5..2dbde0c34 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CRR.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CRR.jl @@ -135,21 +135,17 @@ function continuous_update!(learner::CRRLearner, batch::NamedTuple) a_sample = AC.actor(learner.rng, s; is_sampling=true) q_t[i, :] = AC.critic(vcat(s, a_sample)) end - println(size(maximum(q_t, dims=1))) ps = Flux.params(AC) gs = gradient(ps) do # Critic loss q_input = vcat(s, a) qa_t = AC.critic(q_input) - println(size(qa_t)) - @assert 1 == 2 + critic_loss = Flux.Losses.logitcrossentropy(qa_t, target) - a = atanh.(a) log_π = AC.actor(s, a) - # Actor loss if advantage_estimator == :max advantage = qa_t .- maximum(q_t, dims=1) @@ -158,17 +154,16 @@ function continuous_update!(learner::CRRLearner, 
batch::NamedTuple) else error("Wrong parameter.") end - println(size(advantage)) if policy_improvement_mode == :binary actor_loss_coef = Float32.(advantage .> 0.0f0) elseif policy_improvement_mode == :exp - actor_loss_coef = clamp.(exp.(advantage ./ beta), 0, ratio_upper_bound) + actor_loss_coef = clamp.(exp.(advantage ./ beta), 0.0f0, ratio_upper_bound) else error("Wrong parameter.") end - actor_loss = mean(-log_π) + actor_loss = mean(-log_π .* Zygote.dropgrad(actor_loss_coef)) ignore() do learner.actor_loss = actor_loss @@ -206,7 +201,7 @@ function discrete_update!(learner::CRRLearner, batch::NamedTuple) # Critic loss q_t = AC.critic(s) qa_t = q_t[a] - critic_loss = Flux.Losses.mse(qa_t, target) + critic_loss = Flux.Losses.logitcrossentropy(qa_t, target) # Actor loss a_t = softmax(AC.actor(s)) @@ -222,12 +217,12 @@ function discrete_update!(learner::CRRLearner, batch::NamedTuple) if policy_improvement_mode == :binary actor_loss_coef = Float32.(advantage .> 0.0f0) elseif policy_improvement_mode == :exp - actor_loss_coef = clamp.(exp.(advantage ./ beta), 0, ratio_upper_bound) + actor_loss_coef = clamp.(exp.(advantage ./ beta), 0.0f0, ratio_upper_bound) else error("Wrong parameter.") end - actor_loss = mean(-log.(a_t[a]) .* actor_loss_coef) + actor_loss = mean(-log.(a_t[a]) .* Zygote.dropgrad(actor_loss_coef)) ignore() do learner.actor_loss = actor_loss diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/common.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/common.jl index 49961bd5a..6f83021aa 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/common.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/common.jl @@ -73,7 +73,6 @@ function StatsBase.sample(rng::AbstractRNG, dataset::Vector{T}, batch_size::Int) push!(r, data.reward) push!(t, data.terminal) end - #a = reshape(a, :, batch_size) batch = NamedTuple{SARTS}((s, a, r, t, s′)) inds, batch end
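
The CRR docstring above describes the weight function f and the advantage estimate \hat{A} only in words; concretely, both `continuous_update!` and `discrete_update!` implement the actor update below (notation follows the paper; β is `beta`, c is `ratio_upper_bound`, and in the continuous branch the a_j are `m` actions sampled from the current actor):

\hat{A}(s,a) = Q_\theta(s,a) - \max_{1 \le j \le m} Q_\theta(s,a_j) \quad (\texttt{:max}), \qquad
\hat{A}(s,a) = Q_\theta(s,a) - \tfrac{1}{m}\sum_{j=1}^{m} Q_\theta(s,a_j) \quad (\texttt{:mean}), \qquad a_j \sim \pi_\phi(\cdot \mid s)

f(\hat{A}) = \mathbf{1}\big[\hat{A} > 0\big] \quad (\texttt{:binary}), \qquad
f(\hat{A}) = \min\big(e^{\hat{A}/\beta},\, c\big) \quad (\texttt{:exp})

\mathcal{L}_{\text{actor}} = \mathbb{E}_{(s,a)\sim\mathcal{D}}\big[-\log \pi_\phi(a \mid s)\, f\big(\hat{A}(s,a)\big)\big]

In the discrete branch the max/mean runs over the full Q-vector rather than over m sampled actions, and after the second commit the weight is wrapped in `Zygote.dropgrad`, so no gradient flows through f. The critic is regressed toward the one-step target r + γ(1 − d)·Q̄(s′, a′), with a′ drawn from the target actor in the continuous case and replaced by the expectation under the target policy in the discrete case.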
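
A minimal sketch of how the pieces added in `common.jl` and `CRR.jl` fit together for a discrete-action task. Only `RLTransition`, `OfflinePolicy` and `CRRLearner` come from this patch; the CartPole environment, the random behaviour policy used to fill the dataset, the layer sizes and the optimiser are illustrative assumptions, not part of the PR.

using ReinforcementLearning, Flux, Random

rng = MersenneTwister(123)
env = CartPoleEnv(; rng = rng)
ns, na = length(state(env)), length(action_space(env))

# Fill an offline dataset with a random behaviour policy (illustrative only).
dataset = RLTransition[]
for _ in 1:10_000
    s = copy(state(env))
    a = rand(rng, 1:na)
    env(a)
    push!(dataset, RLTransition(s, a, reward(env), is_terminated(env), copy(state(env))))
    is_terminated(env) && reset!(env)
end

# Actor returns logits, critic returns one Q-value per action.
create_net() = Chain(Dense(ns, 64, relu), Dense(64, 64, relu), Dense(64, na))

learner = CRRLearner(
    approximator = ActorCritic(actor = create_net(), critic = create_net(), optimizer = ADAM(3e-3)),
    target_approximator = ActorCritic(actor = create_net(), critic = create_net(), optimizer = ADAM(3e-3)),
    batch_size = 32,
    policy_improvement_mode = :exp,
    advantage_estimator = :mean,
    update_freq = 1,
    continuous = false,
    rng = rng,
)

π_offline = OfflinePolicy(
    learner = learner,
    dataset = dataset,
    continuous = false,
    batch_size = 32,
)

π_offline(env)   # greedy action from the current actor logits

When driven by `run`, the `update!(::OfflinePolicy, ...)` hook defined in `common.jl` increments `update_step`, syncs the target network every `target_update_freq` steps, and every `update_freq` steps samples `batch_size` transitions from `dataset` and calls `update!(learner, batch)`; nothing is ever collected from the environment or written back to the trajectory.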
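
For the continuous branch the actor has to be a `GaussianNetwork`, because `continuous_update!` relies on the call forms `actor(s; is_sampling = true)`, `actor(rng, s; is_sampling = true)` and the new `actor(s, a)` log-likelihood method added to `neural_network_approximator.jl` above. The sketch below assumes Pendulum-like dimensions, illustrative layer sizes and optimiser, and the usual keyword constructor of `GaussianNetwork`; none of that is fixed by this patch.

using ReinforcementLearning, Flux

ns, na = 3, 1            # e.g. observation / action sizes of PendulumEnv

gaussian_actor = GaussianNetwork(
    pre  = Chain(Dense(ns, 64, relu), Dense(64, 64, relu)),
    μ    = Dense(64, na),
    logσ = Dense(64, na),
)

# The critic scores (state, action) pairs, so its input width is ns + na and it
# returns a single value per column, matching the `vcat(s, a)` calls in CRR.jl.
q_critic = Chain(Dense(ns + na, 64, relu), Dense(64, 64, relu), Dense(64, 1))

learner = CRRLearner(
    approximator        = ActorCritic(actor = gaussian_actor, critic = q_critic, optimizer = ADAM(3e-3)),
    target_approximator = ActorCritic(actor = deepcopy(gaussian_actor), critic = deepcopy(q_critic), optimizer = ADAM(3e-3)),
    policy_improvement_mode = :exp,
    advantage_estimator = :mean,
    continuous = true,
    m = 4,               # number of sampled actions per state for the advantage estimate
    update_freq = 1,
)

# The new `(model::GaussianNetwork)(state, action)` method returns log π(a|s) for a
# pre-tanh action, which is why `continuous_update!` calls it with `atanh.(a)`:
s = rand(Float32, ns, 32)                  # a batch of 32 states
a = gaussian_actor(s; is_sampling = true)  # tanh-squashed actions, size (na, 32)
logp = gaussian_actor(s, atanh.(a))        # log-likelihoods, size (1, 32)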