src/OpenAIGym.jl: 90 additions & 94 deletions (184 changes)
@@ -11,88 +11,64 @@ import Reinforce:
KeyboardAction, KeyboardActionSet

export
gym,
pygym,
GymEnv,
test_env
test_env,
PyAny

const _py_envs = Dict{String,Any}()

# --------------------------------------------------------------

abstract AbstractGymEnv <: AbstractEnvironment
abstract type AbstractGymEnv <: AbstractEnvironment end

"A simple wrapper around the OpenAI gym environments to add to the Reinforce framework"
type GymEnv <: AbstractGymEnv
mutable struct GymEnv{T} <: AbstractGymEnv
name::String
pyenv # the python "env" object
state
pyenv::PyObject # the python "env" object
pystep::PyObject # the python env.step function
pyreset::PyObject # the python env.reset function
pystate::PyObject # the state array object referenced by the PyArray state.o
pystepres::PyObject # used to make stepping the env slightly more efficient
pytplres::PyObject # used to make stepping the env slightly more efficient
info::PyObject # store it as a PyObject for speed, since often unused
state::T
reward::Float64
total_reward::Float64
actions::AbstractSet
done::Bool
info::Dict
GymEnv(name,pyenv) = new(name,pyenv)
end
GymEnv(name) = gym(name)

function Reinforce.reset!(env::GymEnv)
env.state = env.pyenv[:reset]()
env.reward = 0.0
env.actions = actions(env, nothing)
env.done = false
end

"A simple wrapper around the OpenAI gym environments to add to the Reinforce framework"
type UniverseEnv <: AbstractGymEnv
name::String
pyenv # the python "env" object
state
reward
actions::AbstractSet
done
info::Dict
UniverseEnv(name,pyenv) = new(name,pyenv)
end
UniverseEnv(name) = gym(name)

function Reinforce.reset!(env::UniverseEnv)
env.state = env.pyenv[:reset]()
env.reward = [0.0]
env.actions = actions(env, nothing)
env.done = [false]
function GymEnv{T}(name, pyenv, pystate, state::T) where T
env = new{T}(name, pyenv, pyenv["step"], pyenv["reset"],
pystate, PyNULL(), PyNULL(), PyNULL(), state)
reset!(env)
env
end
end

function gym(name::AbstractString)
function GymEnv(name; stateT=PyArray)
env = if name in ("Soccer-v0", "SoccerEmptyGoal-v0")
@pyimport gym_soccer
get!(_py_envs, name) do
GymEnv(name, pygym[:make](name))
end
elseif split(name, ".")[1] in ("flashgames", "wob")
@pyimport universe
@pyimport universe.wrappers as wrappers
if !isdefined(OpenAIGym, :vnc_event)
global const vnc_event = PyCall.pywrap(PyCall.pyimport("universe.spaces.vnc_event"))
end
Base.copy!(pysoccer, pyimport("gym_soccer"))
get!(_py_envs, name) do
pyenv = wrappers.SafeActionSpace(pygym[:make](name))
pyenv[:configure](remotes=1) # automatically creates a local docker container
# pyenv[:configure](remotes="vnc://localhost:5900+15900")
o = UniverseEnv(name, pyenv)
# finalizer(o, o.pyenv[:close]())
sleep(2)
o
GymEnv(name, pygym[:make](name), stateT)
end
else
GymEnv(name, pygym[:make](name))
GymEnv(name, pygym[:make](name), stateT)
end
reset!(env)
env
end

function GymEnv(name, pyenv, stateT)
pystate = pycall(pyenv["reset"], PyObject)
state = convert(stateT, pystate)
T = typeof(state)
GymEnv{T}(name, pyenv, pystate, state)
end


# --------------------------------------------------------------

render(env::AbstractGymEnv, args...) = env.pyenv[:render]()
render(env::AbstractGymEnv, args...; kwargs...) =
pycall(env.pyenv[:render], PyAny; kwargs...)

# --------------------------------------------------------------

@@ -116,14 +92,6 @@ function actionset(A::PyObject)
# # error("Unsupported shape for IntervalSet: $(A[:shape])")
# [IntervalSet{Float64}(lo[i], hi[i]) for i=1:length(lo)]
# end
elseif haskey(A, :buttonmasks)
# assumed VNC actions... keys to press, buttons to mask, and screen position
# keyboard = DiscreteSet(A[:keys])
keyboard = KeyboardActionSet(A[:keys])
buttons = DiscreteSet(Int[bm for bm in A[:buttonmasks]])
width,height = A[:screen_shape]
mouse = MouseActionSet(width, height, buttons)
TupleSet(keyboard, mouse)
elseif haskey(A, :actions)
# Hardcoded
TupleSet(DiscreteSet(A[:actions]))
@@ -134,55 +102,83 @@ function actionset(A::PyObject)
end
end


function Reinforce.actions(env::AbstractGymEnv, s′)
actionset(env.pyenv[:action_space])
end

pyaction(a::Vector) = Any[pyaction(ai) for ai=a]
pyaction(a::KeyboardAction) = Any[a.key]
pyaction(a::MouseAction) = Any[vnc_event.PointerEvent(a.x, a.y, a.button)]
pyaction(a) = a

function Reinforce.step!(env::GymEnv, s, a)
# info("Going to take action: $a")
pyact = pyaction(a)
s′, r, env.done, env.info = env.pyenv[:step](pyact)
env.reward, env.state = r, s′
"""
`reset!` for PyArray state types
"""
function Reinforce.reset!(env::GymEnv{T}) where T <: PyArray
setdata!(env.state, pycall!(env.pystate, env.pyreset, PyObject))
return gymreset!(env)
end

function Reinforce.step!(env::UniverseEnv, s, a)
info("Going to take action: $a")
pyact = Any[pyaction(a)]
s′, r, env.done, env.info = env.pyenv[:step](pyact)
env.reward, env.state = r, s′
"""
`reset!` for non-PyArray state types
"""
function Reinforce.reset!(env::GymEnv{T}) where T
pycall!(env.pystate, env.pyreset, PyObject)
env.state = convert(T, env.pystate)
return gymreset!(env)
end

Reinforce.finished(env::GymEnv, s′) = env.done
Reinforce.finished(env::UniverseEnv, s′) = all(env.done)
function gymreset!(env::GymEnv{T}) where T
env.reward = 0.0
env.total_reward = 0.0
env.actions = actions(env, nothing)
env.done = false
return env.state
end

"""
`step!` for PyArray state
"""
function Reinforce.step!(env::GymEnv{T}, a) where T <: PyArray
pyact = pyaction(a)
pycall!(env.pystepres, env.pystep, PyObject, pyact)

# --------------------------------------------------------------
unsafe_gettpl!(env.pystate, env.pystepres, PyObject, 0)
setdata!(env.state, env.pystate)

return gymstep!(env)
end

function test_env(name::String = "CartPole-v0")
env = gym(name)
for sars′ in Episode(env, RandomPolicy())
render(env)
end
"""
`step!` for non-PyArray state
"""
function Reinforce.step!(env::GymEnv{T}, a) where T
pyact = pyaction(a)
pycall!(env.pystepres, env.pystep, PyObject, pyact)

unsafe_gettpl!(env.pystate, env.pystepres, PyObject, 0)
env.state = convert(T, env.pystate)

return gymstep!(env)
end

@inline function gymstep!(env)
r = unsafe_gettpl!(env.pytplres, env.pystepres, Float64, 1)
env.done = unsafe_gettpl!(env.pytplres, env.pystepres, Bool, 2)
unsafe_gettpl!(env.info, env.pystepres, PyObject, 3)
env.total_reward += r
return (r, env.state)
end

Reinforce.finished(env::GymEnv) = env.done
Reinforce.finished(env::GymEnv, s′) = env.done

# --------------------------------------------------------------

function __init__()
@static if is_linux()
# due to a ssl library bug, I have to first load the ssl lib here
condadir = Pkg.dir("Conda","deps","usr","lib")
Libdl.dlopen(joinpath(condadir, "libssl.so"))
Libdl.dlopen(joinpath(condadir, "python2.7", "lib-dynload", "_ssl.so"))
end
global const pygym = PyNULL()
global const pysoccer = PyNULL()

global const pygym = pyimport("gym")
function __init__()
# the copy! puts the gym module into `pygym`, handling python ref-counting
Base.copy!(pygym, pyimport("gym"))
end

end # module
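For orientation, a minimal usage sketch of the wrapper as it stands after this change (the environment names and the random-action loop are illustrative; `reset!`, `step!` and `actions` come from Reinforce and are available after `using OpenAIGym`, as in the tests below):

using OpenAIGym

env = GymEnv("CartPole-v0")                    # observations are wrapped as a PyArray by default
# env = GymEnv("Blackjack-v0", stateT=PyAny)   # tuple-style observations convert more naturally

s = reset!(env)                                # returns the initial state
while !env.done
    a = rand(env.actions)                      # sample from the wrapped action_space
    r, s′ = step!(env, a)                      # step! no longer takes the current state
end
println("episode reward: ", env.total_reward)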
test/runtests.jl: 132 additions & 9 deletions (141 changes)
@@ -1,13 +1,136 @@
using OpenAIGym
using Base.Test
using OpenAIGym, Compat.Test
using PyCall

# write your own tests here
@test 1 == 1
"""
`function time_steps(env::GymEnv{T}, num_eps::Int) where T`

if isinteractive()
env = GymEnv("CartPole-v0")
for i=1:5
R = run_episode(()->nothing, env, RandomPolicy())
info("Episode $i finished. Total reward: $R")
run through num_eps eps, recording the time taken for each step and
how many steps were made. Doesn't time the `reset!` or the first step of each
episode (since higher chance that it's slower/faster than the rest, and we want
to compare the average time taken for each step as fairly as possible)
"""
function time_steps(env::GymEnv, num_eps::Int)
t = 0.0
steps = 0
for i in 1:num_eps
reset!(env)
# step!(env, rand(env.actions)) # ignore the first step - it might be slow?
t += (@elapsed steps += epstep(env))
end
steps, t
end

"""
Steps through an episode until it's `done`.
Assumes the env has already been `reset!`.
"""
function epstep(env::GymEnv)
steps = 0
while !env.done
steps += 1
r, s = step!(env, rand(env.actions))
end
steps
end

@testset "Gym Basics" begin

pong = GymEnv("Pong-v4")
pongnf = GymEnv("PongNoFrameskip-v4")
pacman = GymEnv("MsPacman-v4")
pacmannf = GymEnv("MsPacmanNoFrameskip-v4")
cartpole = GymEnv("CartPole-v0")
bj = GymEnv("Blackjack-v0", stateT=PyAny)

allenvs = [pong, pongnf, pacman, pacmannf, cartpole, bj]
eps2trial = Dict(pong=>2, pongnf=>1, pacman=>2, pacmannf=>1, cartpole=>400, bj=>30000)
atarienvs = [pong, pongnf, pacman, pacmannf]
envs = allenvs

@testset "envs load" begin
# check they all work - no errors == no worries
println("------------------------------ Check envs load ------------------------------")
for (i, env) in enumerate(envs)
a = rand(env.actions) |> OpenAIGym.pyaction
action_type = a |> PyObject |> pytypeof
println("env.pyenv: $(env.pyenv) action_type: $action_type (e.g. $a)")
time_steps(env, 1)
@test !ispynull(env.pyenv)
println("------------------------------")
end
end

@testset "julia speed test" begin
println("------------------------------ Begin Julia Speed Check ------------------------------")
for env in envs
num_eps = eps2trial[env]
steps, t = time_steps(env, num_eps)
println("env.pyenv: $(env.pyenv) num_eps: $num_eps t: $t steps: $steps")
println("microsecs/step (lower is better): ", t*1e6/steps)
println("------------------------------")
end
println("------------------------------ End Julia Speed Check ------------------------------\n")
end

@testset "python speed test" begin
println("------------------------------ Begin Python Speed Check ------------------------------")
py"""
import gym
import numpy as np

pong = gym.make("Pong-v4")
pongnf = gym.make("PongNoFrameskip-v4")
pacman = gym.make("MsPacman-v4");
pacmannf = gym.make("MsPacmanNoFrameskip-v4");
cartpole = gym.make("CartPole-v0")
bj = gym.make("Blackjack-v0")

allenvs = [pong, pongnf, pacman, pacmannf, cartpole, bj]
eps2trial = {pong: 2, pongnf: 1, pacman: 2, pacmannf: 1, cartpole: 400, bj: 30000}
atarienvs = [pong, pongnf, pacman, pacmannf];

envs = allenvs

import time
class Timer(object):
elapsed = 0.0
def __init__(self, name=None):
self.name = name

def __enter__(self):
self.tstart = time.time()

def __exit__(self, type, value, traceback):
Timer.elapsed = time.time() - self.tstart

def time_steps(env, num_eps):
t = 0.0
steps = 0
for i in range(num_eps):
env.reset()
with Timer():
steps += epstep(env)
t += Timer.elapsed
return steps, t

def epstep(env):
steps = 0
while True:
steps += 1
action = env.action_space.sample()
state, reward, done, info = env.step(action)
if done == True:
break
return steps

for env in envs:
num_eps = eps2trial[env]
with Timer():
steps, s = time_steps(env, num_eps)
t = Timer.elapsed
print(f"{env} num_eps: {num_eps} t: {t} steps: {steps} \n microsecs/step (lower is better): {t*1e6/steps}")
print("------------------------------")
"""
println("------------------------------ End Python Speed Check ------------------------------")
end
end
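A note on running the suite: the benchmarks above assume the Python `gym` package (including the Atari environments they use) is importable from the Python that PyCall is built against. With that in place, the whole file runs through the usual test entry point, e.g.

Pkg.test("OpenAIGym")   # runs test/runtests.jl, including both speed checks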