diff --git a/Project.toml b/Project.toml index f9bf2a6..2f9bd81 100644 --- a/Project.toml +++ b/Project.toml @@ -4,13 +4,13 @@ repo = "https://github.com/JuliaPOMDP/POMDPXFiles.jl" version = "0.2.4" [deps] -LightXML = "9c8b4983-aa76-5018-a973-4c85ecc9e179" +EzXML = "8f5d6c58-4d21-5cfd-889c-e3ad7ee6a615" POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7" POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" +Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" [compat] -LightXML = "0.9" POMDPTools = "0.1" POMDPs = "0.7.3, 0.8, 0.9" julia = "1" @@ -20,4 +20,4 @@ POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["POMDPModels", "Test"] \ No newline at end of file +test = ["POMDPModels", "Test"] diff --git a/REQUIRE b/REQUIRE deleted file mode 100644 index c6df563..0000000 --- a/REQUIRE +++ /dev/null @@ -1,2 +0,0 @@ -julia 0.6 -LightXML diff --git a/src/POMDPXFiles.jl b/src/POMDPXFiles.jl index 042709e..6268562 100644 --- a/src/POMDPXFiles.jl +++ b/src/POMDPXFiles.jl @@ -3,26 +3,21 @@ module POMDPXFiles using POMDPs using POMDPTools using ProgressMeter -import POMDPs: action, value - -# import o avoid naming conflict in POMDPs.jl (value is overloaded in LightXML) -import LightXML: parse_file, root, get_elements_by_tagname, attribute, content +using Parameters +import EzXML +import EzXML: Node, XMLDocument, ElementNode, addelement!, link!, write, readxml, root, findfirst, findall, nodecontent, setroot!, prettyprint export AbstractPOMDPXFile, POMDPXFile, - MOMDPXFile, Alphas, POMDPAlphas, - read_pomdp, - action, - value - + read_pomdp include("writer.jl") include("policy.jl") -include("read.jl") +include("reader.jl") end # module diff --git a/src/read.jl b/src/read.jl deleted file mode 100644 index da5606e..0000000 --- a/src/read.jl +++ /dev/null @@ -1,119 +0,0 @@ -# Parses policy xml file and returns alpha vectors and alpha actions -# Should handle any policy in the .policy format -# Should handle policies written by POMDPs.jl or APPL -# -# alpha_vectors is a matrix containing the alpha vectors -# Each row corresponds to a different alpha vector -# -# alpha_actions is a vector containing the list of action indices -# Each index corresponds to action associated with the alpha vector of this row -# These are 0-indexed... 0 means it is the first action -# -# TODO: Check that the input file exists and handle the case that it doesn't -# TODO: Handle sparse vectors -# - -function read_momdp(filename::String) - - # Parse the xml file - # TODO: Check that the file exists and handle the case that it doesn't - xdoc = parse_file(filename) - - # Get the root of the document (the Policy tag in this case) - policy_tag = root(xdoc) - #println(name(policy_tag)) # print the name of this tag - - # Determine expected number of vectors and their length - alphavector_tag = get_elements_by_tagname(policy_tag, "AlphaVector")[1] #length of 1 anyway - num_vectors = int(attribute(alphavector_tag, "numVectors")) - vector_length = int(attribute(alphavector_tag, "vectorLength")) - num_full_obs_states = int(attribute(alphavector_tag, "numObsValue")) - - # For debugging purposes... - #println("AlphaVector tag: # vectors, vector length: $(num_vectors), $(vector_length)") - - # Arrays with vector and sparse vector tags - vector_tags = get_elements_by_tagname(alphavector_tag, "Vector") - sparsevector_tags = get_elements_by_tagname(alphavector_tag, "SparseVector") - - num_vectors_check = length(vector_tags) + length(sparsevector_tags) # should be same as num_vectors - - # Initialize the gamma matrix. This is basically a matrix with the alpha - # vectors as rows. - #alpha_vectors = Array(Float64, num_vectors, vector_length) - alpha_vectors = Array{Float64}(vector_length, num_vectors) - alpha_actions = Array{String}(num_vectors) - observable_states = Array{String}(num_vectors) - gammarow = 1 - - # Fill in gamma - for vector in vector_tags - alpha = float(split(content(vector))) - #alpha_vectors[gammarow, :] = alpha - alpha_vectors[:,gammarow] = alpha - alpha_actions[gammarow] = attribute(vector, "action") - observable_states[gammarow] = attribute(vector, "obsValue") - gammarow += 1 - end - - # TODO: Handle sparse vectors - for vector in sparsevector_tags - # Turn these into vectors as well - end - - # Return alpha vectors and indices of actions - return alpha_vectors, int(alpha_actions), int(observable_states) -end - - -function read_pomdp(filename::String) - - # Parse the xml file - # TODO: Check that the file exists and handle the case that it doesn't - xdoc = parse_file(filename) - - # Get the root of the document (the Policy tag in this case) - policy_tag = root(xdoc) - #println(name(policy_tag)) # print the name of this tag - - # Determine expected number of vectors and their length - alphavector_tag = get_elements_by_tagname(policy_tag, "AlphaVector")[1] #length of 1 anyway - num_vectors = parse(Int64, attribute(alphavector_tag, "numVectors")) - vector_length = parse(Int64, attribute(alphavector_tag, "vectorLength")) - num_full_obs_states = parse(Int64, attribute(alphavector_tag, "numObsValue")) - - # For debugging purposes... - #println("AlphaVector tag: # vectors, vector length: $(num_vectors), $(vector_length)") - - # Arrays with vector and sparse vector tags - vector_tags = get_elements_by_tagname(alphavector_tag, "Vector") - sparsevector_tags = get_elements_by_tagname(alphavector_tag, "SparseVector") - - num_vectors_check = length(vector_tags) + length(sparsevector_tags) # should be same as num_vectors - - # Initialize the gamma matrix. This is basically a matrix with the alpha - # vectors as columns. - #alpha_vectors = Array(Float64, num_vectors, vector_length) - alpha_vectors = Array{Float64}(undef, vector_length, num_vectors) - alpha_actions = Array{String}(undef, num_vectors) - observable_states = Array{String}(undef, num_vectors) - gammarow = 1 - - # Fill in gamma - for vector in vector_tags - alpha = parse.(Float64, split(content(vector))) - #alpha_vectors[gammarow, :] = alpha - alpha_vectors[:,gammarow] = alpha - alpha_actions[gammarow] = attribute(vector, "action") - observable_states[gammarow] = attribute(vector, "obsValue") - gammarow += 1 - end - - # TODO: Handle sparse vectors - for vector in sparsevector_tags - # Turn these into vectors as well - end - - # Return alpha vectors and indices of actions - return alpha_vectors, [parse(Int64,s) for s in alpha_actions] -end diff --git a/src/reader.jl b/src/reader.jl new file mode 100644 index 0000000..137142f --- /dev/null +++ b/src/reader.jl @@ -0,0 +1,41 @@ +"""Parses a Policy XML file and returns `alphavectors` and `alphaactions` +This should be able to handle any policy in the `.policy` format, which includes + policies written by `POMDPs.jl` or `APPL`. + +`alphavectors` is a matrix containing the alpha vectors where each row corresponds to a + different alpha vector. + +`alphaactions` is a vector containing the list of action indices, where ach index + corresponds to action associated with the alpha vector row. +These are `0-indexed`, so `0` maps to the first action. +""" + +function read_pomdp(filename::String) + xml = readxml(open(filename, "r")) + policy = root(xml) # The node + + av_node = findfirst("AlphaVector", policy) + + alphavectors = Array{Real}[] + alphaactions = [] + + # Create alpha vector and alpha action lists + vectors = findall("Vector", av_node) + for vector in vectors + push!(alphavectors, parse.(Float64, split(nodecontent(vector)))) + push!(alphaactions, vector["action"]) + end + alphavectors = transpose(mapreduce(permutedims, vcat, alphavectors)) + + n_vectors = parse(Int, av_node["numVectors"]) + vector_len = parse(Int, av_node["vectorLength"]) + @assert size(alphavectors) == (vector_len, n_vectors) + + # TODO handle sparsevectors + sparsevectors = findall("SparseVector", av_node) + for vector in sparsevectors + # Turn these into vectors + end + + return (alphavectors=alphavectors, alphaactions=parse.(Int64, alphaactions)) +end \ No newline at end of file diff --git a/src/writer.jl b/src/writer.jl index 7bbffa0..8af1451 100644 --- a/src/writer.jl +++ b/src/writer.jl @@ -1,398 +1,277 @@ -################################################################# -# This file implements a .pomdpx file generator using the -# POMDPs.jl interface. -################################################################# - abstract type AbstractPOMDPXFile end -mutable struct POMDPXFile <: AbstractPOMDPXFile - file_name::AbstractString - description::AbstractString +@with_kw struct POMDPXFile <: AbstractPOMDPXFile + filename::String + description::String = "This is a POMDPX file for a POMDP" - state_name::AbstractString - action_name::AbstractString - reward_name::AbstractString - obs_name::AbstractString + s_var_name::String = "state" + a_var_name::String = "action" + o_var_name::String = "observation" + r_var_name::String = "reward" - initial_belief::Vector{Float64} + s_name::Function = normalize + a_name::Function = normalize + o_name::Function = normalize - #initial_belief::Vector{Float64} # belief over partially observed vars + pretty::Bool = false +end - function POMDPXFile(file_name::AbstractString; description::AbstractString="", - initial_belief::Vector{Float64}=Float64[]) +POMDPXFile(filename::String, pretty::Bool) = POMDPXFile(; filename=filename, pretty=pretty) +POMDPXFile(filename::String) = POMDPXFile(filename, false) + +a_var_name( px::POMDPXFile) = px.a_var_name +s_var_name( px::POMDPXFile) = "$(px.s_var_name)0" +sp_var_name(px::POMDPXFile) = "$(px.s_var_name)1" +o_var_name( px::POMDPXFile) = px.o_var_name +r_var_name( px::POMDPXFile) = px.r_var_name + +function build_xml(p::POMDP, px::POMDPXFile) + n_states = length(states(p)) + n_actions = length(actions(p)) + n_obs = length(observations(p)) + n_nodes = 14 + n_states + n_states * n_actions * n_obs + n_states * n_actions * 2 + pbar = Progress(n_nodes; dt=0.01) + + doc = XMLDocument() + root = ElementNode("pomdpx") + root["version"] = 0.1 + root["id"] = replace(px.filename, ".pomdpx" => "") + root["xmlns:xsi"] = "http://www.w3.org/2001/XMLSchema-instance" + root["xsi:noNamespaceSchemaLocation"] = "https://raw.githubusercontent.com/JuliaPOMDP/sarsop/master/doc/POMDPX/pomdpx.xsd" + setroot!(doc, root) + + addelement!(root, "Description", px.description) + addelement!(root, "Discount", "$(discount(p))") + next!(pbar) + + build_variables!(root, p, px, pbar) + build_initial_beliefs(root, p, px, pbar) + build_transitions!(root, p, px, pbar) + build_observations!(root, p, px, pbar) + build_rewards!(root, p, px, pbar) + + return doc +end - if isempty(description) - description = "This is a pomdpx file for a partially observable MDP" - end +function build_variables!(root::Node, p::POMDP, px::POMDPXFile, pbar) + variables = ElementNode("Variable") + link!(root, variables) + + states_node = ElementNode("StateVar") + states_node["vnamePrev"] = s_var_name(px) + states_node["vnameCurr"] = sp_var_name(px) + states_node["fullyObs" ] = "false" + link!(variables, states_node) + + actions_node = ElementNode("ActionVar") + actions_node["vname"] = a_var_name(px) + link!(variables, actions_node) + + obs_node = ElementNode("ObsVar") + obs_node["vname"] = o_var_name(px) + link!(variables, obs_node) + + reward_node = ElementNode("RewardVar") + reward_node["vname"] = r_var_name(px) + link!(variables, reward_node) + next!(pbar) + + if px.pretty + all_states = join(["s_$(px.s_name(s))" for s=ordered_states(p)], " ") + addelement!(states_node, "ValueEnum", all_states) + + all_actions = join(["a_$(px.a_name(a))" for a=ordered_actions(p)], " ") + addelement!(actions_node, "ValueEnum", all_actions) + + all_observations = join(["o_$(px.o_name(o))" for o=ordered_observations(p)], " ") + addelement!(obs_node, "ValueEnum", all_observations) + else + addelement!(states_node, "NumValues", "$(length(states(p)))") + addelement!(actions_node, "NumValues", "$(length(actions(p)))") + addelement!(obs_node, "NumValues", "$(length(observations(p)))") + end +end - self = new() - self.file_name = file_name - self.description = description +function param(label::String; prob::Real = -1, value::Real = -Inf) + entry = ElementNode("Entry") + addelement!(entry, "Instance", label) + if prob != -1 + addelement!(entry, "ProbTable", "$(prob)") + end + if !isinf(value) + addelement!(entry, "ValueTable", "$(value)") + end + return entry +end - self.state_name = "state" - self.action_name = "action" - self.reward_name = "reward" - self.obs_name = "observation" +function build_initial_beliefs(root::Node, p::POMDP, px::POMDPXFile, pbar) + ibstate = ElementNode("InitialStateBelief") + link!(root, ibstate) - self.initial_belief = initial_belief + condprob = ElementNode("CondProb") + link!(ibstate, condprob) - return self - end + addelement!(condprob, "Var", s_var_name(px)) + addelement!(condprob, "Parent", "null") -end + parameter = ElementNode("Parameter") + parameter["type"] = "TBL" + link!(condprob, parameter) + next!(pbar) + init_states = initialstate(p) + for s in states(p) + if px.pretty + label = "s_$(px.s_name(s))" + else + sidx = stateindex(p, s) + label = "s$(sidx - 1)" + end -function Base.write(pomdp::POMDP, pomdpx::AbstractPOMDPXFile) - file_name = pomdpx.file_name - description = pomdpx.description - discount_factor = discount(pomdp) - - # Open file to write to - out_file = open("$file_name", "w") - - pomdp_states = ordered_states(pomdp) - pomdp_pstates = ordered_states(pomdp) - acts = ordered_actions(pomdp) - obs = ordered_observations(pomdp) - # x = Number of next statement to track Progress - # Added approximately after every four lines written to file, 14 next statements outside loops - x = 14 + length(pomdp_states) + length(pomdp_states)*length(acts)*length(obs) + length(pomdp_states)*length(acts) + length(acts)*length(pomdp_pstates) - p1 = Progress(x, dt=0.01) - - # Header stuff for xml - write(out_file, "\n\n\n") - write(out_file, "\n\n\n") - - sleep(0.01) - next!(p1) - ############################################################################ - # DESCRIPTION - ############################################################################ - write(out_file, "\t $(description)\n\n\n") - - ############################################################################ - # DISCOUNT - ############################################################################ - write(out_file, "\t$(discount_factor)\n\n\n") - - ############################################################################ - # VARIABLES - ############################################################################ - write(out_file, "\t\n") - next!(p1) - # State Variables - str = state_xml(pomdp, pomdpx) - write(out_file, str) - next!(p1) - # Action Variables - str = action_xml(pomdp, pomdpx) - write(out_file, str) - next!(p1) - # Observation Variables - str = obs_var_xml(pomdp, pomdpx) - write(out_file, str) - next!(p1) - # Reward Variable - str = reward_var_xml(pomdp, pomdpx) - write(out_file, str) - write(out_file, "\t\n\n\n") - - next!(p1) - ############################################################################ - # INITIAL STATE BELIEF - ############################################################################ - belief_xml(pomdp, pomdpx, out_file, p1) - - - ############################################################################ - # STATE TRANSITION FUNCTION - ############################################################################ - trans_xml(pomdp, pomdpx, out_file, p1) - - - ############################################################################ - # OBS FUNCTION - ############################################################################ - obs_xml(pomdp, pomdpx, out_file, p1) - - - ############################################################################ - # REWARD FUNCTION - ############################################################################ - reward_xml(pomdp, pomdpx, out_file, p1) - - - # CLOSE POMDPX TAG AND FILE - write(out_file, "") - close(out_file) + link!(parameter, param(label; prob=pdf(init_states, s))) + next!(pbar) + end end +function build_transitions!(root::Node, p::POMDP, px::POMDPXFile, pbar) + statetrans = ElementNode("StateTransitionFunction") + link!(root, statetrans) -############################################################################ -# function: state_xml -# input: pomdp model, pomdpx type -# output: string in xml format the defines the state varaibles -############################################################################ -function state_xml(pomdp::POMDP, pomdpx::POMDPXFile) - # defines state vars for a POMDP - n_s = length(states(pomdp)) - sname = pomdpx.state_name - str = "\t\t\n" - str = "$(str)\t\t\t$(n_s)\n" - str = "$(str)\t\t\n\n" - return str -end -############################################################################ - - - -############################################################################ -# function: obs_var_xml -# input: pomdp model, pomdpx type -# output: string in xml format the defines the observation varaibles -############################################################################ -function obs_var_xml(pomdp::POMDP, pomdpx::AbstractPOMDPXFile) - # defines observation vars for POMDP and MOMDP - n_o = length(observations(pomdp)) - oname = pomdpx.obs_name - str = "\t\t\n" - str = "$(str)\t\t\t$(n_o)\n" - str = "$(str)\t\t\n\n" - return str -end -############################################################################ - - - -############################################################################ -# function: action_xml -# input: pomdp model, pomdpx type -# output: string in xml format the defines the action varaibles -############################################################################ -function action_xml(pomdp::POMDP, pomdpx::AbstractPOMDPXFile) - # defines action vars for MDP, POMDP and MOMDP - n_a = length(actions(pomdp)) - aname = pomdpx.action_name - str = "\t\t\n" - str = "$(str)\t\t\t$(n_a)\n" - str = "$(str)\t\t\n\n" - return str -end -############################################################################ + condprob = ElementNode("CondProb") + link!(statetrans, condprob) + addelement!(condprob, "Var", sp_var_name(px)) + addelement!(condprob, "Parent", "$(a_var_name(px)) $(s_var_name(px))") + parameter = ElementNode("Parameter") + link!(condprob, parameter) + next!(pbar) -############################################################################ -# function: reward_var_xml -# input: pomdp model, pomdpx type -# output: string in xml format the defines the reward varaible -############################################################################ -function reward_var_xml(pomdp::POMDP, pomdpx::AbstractPOMDPXFile) - # defines reward var for MDP, POMDP and MOMDP - rname = pomdpx.reward_name - str = "\t\t\n\n" - return str -end -############################################################################ - - - -############################################################################ -# function: belief_xml -# input: pomdp model, pomdpx type, output file -# output: None, writes the initial belief to the output file -############################################################################ -function belief_xml(pomdp::POMDP, pomdpx::POMDPXFile, out_file::IOStream, p1) - belief = pomdpx.initial_belief - var = pomdpx.state_name - write(out_file, "\t\n") - str = "\t\t\n" - str = "$(str)\t\t\t$(var)0\n" - str = "$(str)\t\t\tnull\n" - str = "$(str)\t\t\t\n" - next!(p1) - - d = initialstate(pomdp) - for (i, s) in enumerate(ordered_states(pomdp)) - p = pdf(d, s) - str = "$(str)\t\t\t\t\n" - str = "$(str)\t\t\t\t\ts$(i-1)\n" - str = "$(str)\t\t\t\t\t$(p)\n" - str = "$(str)\t\t\t\t\n" - next!(p1) - end - str = "$(str)\t\t\t\n" - str = "$(str)\t\t\n" - write(out_file, str) - write(out_file, "\t\n\n\n") - next!(p1) -end -############################################################################ - - - -############################################################################ -# function: trans_xml -# input: pomdp model, pomdpx type, output file -# output: None, writes the transition probability table to the output file -############################################################################ -function trans_xml(pomdp::POMDP, pomdpx::POMDPXFile, out_file::IOStream, p1) - pomdp_states = ordered_states(pomdp) - pomdp_pstates = ordered_states(pomdp) - acts = ordered_actions(pomdp) - - aname = pomdpx.action_name - var = pomdpx.state_name - - write(out_file, "\t\n") - str = "\t\t\n" - str = "$(str)\t\t\t$(var)1\n" - str = "$(str)\t\t\t$(aname) $(var)0\n" - str = "$(str)\t\t\t\n" - write(out_file, str) - next!(p1) - for (i, s) in enumerate(pomdp_states) - if isterminal(pomdp, s) # if terminal, just remain in the same state - str = "\t\t\t\t\n" - str = "$(str)\t\t\t\t\t* s$(i-1) s$(i-1)\n" - str = "$(str)\t\t\t\t\t1.0\n" - str = "$(str)\t\t\t\t\n" - write(out_file, str) - for i = 1:length(acts)*length(pomdp_pstates) - next!(p1) + link!(parameter, param("* * *"; prob=0.0)) + + for s=states(p) + if isterminal(p, s) + if px.pretty + label = "* s_$(px.s_name(s)) s_$(px.s_name(s))" + else + sidx = stateindex(p, s) + label = "* s$(sidx - 1) s$(sidx - 1)" end - else - for (ai, a) in enumerate(acts) - d = transition(pomdp, s, a) - for (j, sp) in enumerate(pomdp_pstates) - p = pdf(d, sp) - if p > 0.0 - str = "\t\t\t\t\n" - str = "$(str)\t\t\t\t\ta$(ai-1) s$(i-1) s$(j-1)\n" - str = "$(str)\t\t\t\t\t$(p)\n" - str = "$(str)\t\t\t\t\n" - write(out_file, str) - end - next!(p1) - end + + link!(parameter, param(label; prob=1.0)) + for _=1:(length(actions(p)) * length(states(p))) + next!(pbar) end + + continue + end + + for a=actions(p), sp=states(p) + if px.pretty + label = "a_$(px.a_name(a)) s_$(px.s_name(s)) s_$(px.s_name(sp))" + else + (aidx, sidx, spidx) = (actionindex(p, a), stateindex(p, sp), stateindex(p, sp)) + label = "a$(aidx - 1) s$(sidx - 1) s$(spidx - 1)" + end + + T = transition(p, s, a) + link!(parameter, param(label; prob=pdf(T, sp))) + next!(pbar) end end - str = "\t\t\t\n" - str = "$(str)\t\t\n" - write(out_file, str) - write(out_file, "\t\n\n\n") - next!(p1) - return nothing end -############################################################################ +function build_observations!(root::Node, p::POMDP, px::POMDPXFile, pbar) + obsfun = ElementNode("ObsFunction") + link!(root, obsfun) + condprob = ElementNode("CondProb") + link!(obsfun, condprob) -############################################################################ -# function: obs_xml -# input: pomdp model, pomdpx type, output file -# output: None, writes the observation probability table to the output file -############################################################################ -function obs_xml(pomdp::POMDP, pomdpx::POMDPXFile, out_file::IOStream, p1) - pomdp_states = ordered_states(pomdp) - acts = ordered_actions(pomdp) - obs = ordered_observations(pomdp) + addelement!(condprob, "Var", o_var_name(px)) + addelement!(condprob, "Parent", "$(a_var_name(px)) $(sp_var_name(px))") - aname = pomdpx.action_name - oname = pomdpx.obs_name - var = pomdpx.state_name + parameter = ElementNode("Parameter") + link!(condprob, parameter) + next!(pbar) - write(out_file, "\t\n") - str = "\t\t\n" - str = "$(str)\t\t\t$(oname)\n" - str = "$(str)\t\t\t$(aname) $(var)1\n" - str = "$(str)\t\t\t\n" - write(out_file, str) - next!(p1) + link!(parameter, param("* * *"; prob=0.0)) - try observation(pomdp, first(acts), first(pomdp_states)) + try observation(p, first(actions(p)), first(states(p))) catch ex if ex isa MethodError - @warn("""POMDPXFiles only supports observation distributions conditioned on a and sp. + @warn("""POMDPXFiles only supports observation distributions conditioned on `a` and `sp`. + + Check that there is an `observation(::P, ::A, ::S)` method available (or an (::A, ::S) method of the observation function for a QuickPOMDP). - Check that there is an `observation(::M, ::A, ::S)` method available (or an (::A, ::S) method of the observation function for a QuickPOMDP). - This warning is designed to give a helpful hint to fix errors, but may not always be relevant. - """, M=typeof(pomdp), S=typeof(first(pomdp_states)), A=typeof(first(acts))) + """, P=typeof(p), S=eltype(states(p)), A=eltype(actions(p))) end rethrow(ex) end - for (i, s) in enumerate(pomdp_states) - for (ai, a) in enumerate(acts) - d = observation(pomdp, a, s) - for (oi, o) in enumerate(obs) - p = pdf(d, o) - if p > 0.0 - str = "\t\t\t\t\n" - str = "$(str)\t\t\t\t\ta$(ai-1) s$(i-1) o$(oi-1)\n" - str = "$(str)\t\t\t\t\t$(p)\n" - str = "$(str)\t\t\t\t\n" - write(out_file, str) - end - next!(p1) - end + for a=actions(p), sp=states(p), o=observations(p) + if px.pretty + label = "a_$(px.a_name(a)) s_$(px.s_name(sp)) o_$(px.o_name(o))" + else + (aidx, spidx, oidx) = (actionindex(p, a), stateindex(p, sp), obsindex(p, o)) + label = "a$(aidx - 1) s$(spidx - 1) o$(oidx - 1)" end + + O = observation(p, a, sp) + link!(parameter, param(label; prob=pdf(O, o))) + next!(pbar) end - write(out_file, "\t\t\t\n") - write(out_file, "\t\t\n") - write(out_file, "\t\n") - next!(p1) end -############################################################################ - - - -############################################################################ -# function: reward_xml -# input: pomdp model, pomdpx type, output file -# output: None, writes the reward function to the output file -############################################################################ -function reward_xml(pomdp::POMDP, pomdpx::POMDPXFile, out_file::IOStream, p1) - pomdp_states = ordered_states(pomdp) - acts = ordered_actions(pomdp) - rew = StateActionReward(pomdp) - - aname = pomdpx.action_name - var = pomdpx.state_name - rname = pomdpx.reward_name - - write(out_file, "\t\n") - str = "\t\t\n" - str = "$(str)\t\t\t$(rname)\n" - str = "$(str)\t\t\t$(aname) $(var)0\n" - str = "$(str)\t\t\t\n" - write(out_file, str) - next!(p1) - - for (i, s) in enumerate(pomdp_states) - if !isterminal(pomdp, s) - for (ai, a) in enumerate(acts) - r = rew(s, a) - str = "\t\t\t\t\n" - str = "$(str)\t\t\t\t\ta$(ai-1) s$(i-1)\n" - str = "$(str)\t\t\t\t\t$(r)\n" - str = "$(str)\t\t\t\t\n" - write(out_file, str) - next!(p1) - end + +function build_rewards!(root::Node, p::POMDP, px::POMDPXFile, pbar) + rewardfunc = ElementNode("RewardFunction") + link!(root, rewardfunc) + + func = ElementNode("Func") + link!(rewardfunc, func) + + addelement!(func, "Var", r_var_name(px)) + addelement!(func, "Parent", "$(a_var_name(px)) $(s_var_name(px))") + + parameter = ElementNode("Parameter") + link!(func, parameter) + next!(pbar) + + link!(parameter, param("* *"; value=0.0)) + + reward_fn = StateActionReward(p) + for a=actions(p), s=states(p) + if px.pretty + label = "a_$(px.a_name(a)) s_$(px.s_name(s))" else - for i = 1:length(acts) - next!(p1) - end + (aidx, sidx) = (actionindex(p, a), stateindex(p, s)) + label = "a$(aidx - 1) s$(sidx - 1)" + end + + if !isterminal(p, s) + link!(parameter, param(label; value=reward_fn(s, a))) end + next!(pbar) end +end + +function Base.write(p::POMDP, px::POMDPXFile) + file = open(px.filename, "w") + doc = build_xml(p, px) + EzXML.prettyprint(file, doc) + close(file) +end + +function normalize(s) + s = string(s) + clean = replace(s, r"[^a-zA-Z0-9]" => "_") + return replace(clean, r"_+" => "_") +end - write(out_file, "\t\t\t\n\t\t\n") - write(out_file, "\t\n\n") - next!(p1) +function prettyprint(p::POMDP, px::POMDPXFile) + file = open(px.filename, "w") + doc = build_xml(p, px) + EzXML.prettyprint(file, doc) + close(file) end -############################################################################ diff --git a/test/mypolicy.policy b/test/mypolicy.policy index 422a3fa..132889f 100644 --- a/test/mypolicy.policy +++ b/test/mypolicy.policy @@ -1,9 +1,13 @@ - - --81.5975 28.4025 -3.01448 24.6954 -24.6954 3.01452 -28.4025 -81.5975 -19.3711 19.3711 - + + + -81.5975 28.4025 + 3.01448 24.6954 + 24.6954 3.01452 + 28.4025 -81.5975 + 19.3711 19.3711 + + \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 35ea892..280199e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,9 +5,9 @@ using POMDPModels using Test @testset "basic" begin - file_name = "tiger_test.pomdpx" + filename = "tiger_test.pomdpx" pomdp = TigerPOMDP() - pomdpx = POMDPXFile(file_name) + pomdpx = POMDPXFile(; filename=filename) write(pomdp, pomdpx) av, aa = read_pomdp("mypolicy.policy") @@ -15,7 +15,7 @@ using Test @test aa == [1,0,0,2,0] end -@testset "a, sp observation warning" begin +@testset "a, sp observation warning" begin struct BadObsPOMDP <: POMDP{Int,Int,Int} end POMDPs.states(m::BadObsPOMDP) = 1:2 POMDPs.actions(m::BadObsPOMDP) = 1:2 @@ -30,11 +30,11 @@ end POMDPs.obsindex(m::BadObsPOMDP, s) = s @test_throws MethodError cd(mktempdir()) do - write(BadObsPOMDP(), POMDPXFile("bad_obs_test.pomdpx")) + write(BadObsPOMDP(), POMDPXFile(; filename="bad_obs_test.pomdpx")) end POMDPs.observation(m::BadObsPOMDP, a, sp) = Deterministic(1) cd(mktempdir()) do - write(BadObsPOMDP(), POMDPXFile("bad_obs_test.pomdpx")) + write(BadObsPOMDP(), POMDPXFile(; filename="bad_obs_test.pomdpx")) end end