diff --git a/examples/CartPole/DQN.jl b/examples/CartPole/DQN.jl index a3f1203..6139567 100644 --- a/examples/CartPole/DQN.jl +++ b/examples/CartPole/DQN.jl @@ -1,11 +1,10 @@ using Flux, Gym using Flux.Optimise: Optimiser -using Flux.Tracker: data using Statistics: mean using DataStructures: CircularBuffer using Distributions: sample using Printf -#using CuArrays +using StatsBase # Load game environment env = make("CartPole-v0", :human_pane) @@ -48,7 +47,7 @@ remember(state, action, reward, next_state, done) = push!(memory, (data(state), action, reward, data(next_state), done)) function action(state, train=true) - train && rand() <= get_ϵ(e) && (return Gym.sample(env.action_space)) + train && rand() <= get_ϵ(e) && (return Gym.sample(env._env.action_space)) act_values = model(state |> gpu) return Flux.onecold(act_values) end @@ -56,7 +55,7 @@ end function replay() global ϵ batch_size = min(BATCH_SIZE, length(memory)) - minibatch = sample(memory, batch_size, replace = false) + minibatch = StatsBase.sample(memory, batch_size, replace = false) x = [] y = []