From 041b5fc27b5f33935100ff654017a30ffcb0f12f Mon Sep 17 00:00:00 2001 From: John Muchovej <5000729+jmuchovej@users.noreply.github.com> Date: Mon, 5 Jun 2023 22:09:45 -0400 Subject: [PATCH 1/9] Migrating to EzXML; adding pretty-printing Migrating from LightXML to EzXML. Enabling states/actions/observations to mapped onto readable names rather than numeric-oriented values. --- Project.toml | 6 +- REQUIRE | 2 - src/POMDPXFiles.jl | 12 +- src/read.jl | 119 ---------- src/reader.jl | 41 ++++ src/writer.jl | 549 +++++++++++++++---------------------------- test/mypolicy.policy | 20 +- 7 files changed, 255 insertions(+), 494 deletions(-) delete mode 100644 REQUIRE delete mode 100644 src/read.jl create mode 100644 src/reader.jl diff --git a/Project.toml b/Project.toml index f9bf2a6..2f9bd81 100644 --- a/Project.toml +++ b/Project.toml @@ -4,13 +4,13 @@ repo = "https://github.com/JuliaPOMDP/POMDPXFiles.jl" version = "0.2.4" [deps] -LightXML = "9c8b4983-aa76-5018-a973-4c85ecc9e179" +EzXML = "8f5d6c58-4d21-5cfd-889c-e3ad7ee6a615" POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7" POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" +Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" [compat] -LightXML = "0.9" POMDPTools = "0.1" POMDPs = "0.7.3, 0.8, 0.9" julia = "1" @@ -20,4 +20,4 @@ POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["POMDPModels", "Test"] \ No newline at end of file +test = ["POMDPModels", "Test"] diff --git a/REQUIRE b/REQUIRE deleted file mode 100644 index c6df563..0000000 --- a/REQUIRE +++ /dev/null @@ -1,2 +0,0 @@ -julia 0.6 -LightXML diff --git a/src/POMDPXFiles.jl b/src/POMDPXFiles.jl index 042709e..8581c40 100644 --- a/src/POMDPXFiles.jl +++ b/src/POMDPXFiles.jl @@ -3,26 +3,20 @@ module POMDPXFiles using POMDPs using POMDPTools using ProgressMeter -import POMDPs: action, value - -# import o avoid naming conflict in POMDPs.jl (value is overloaded in LightXML) -import LightXML: parse_file, root, get_elements_by_tagname, attribute, content +import EzXML: Node, XMLDocument, ElementNode, addelement!, link!, write, readxml, root, findfirst, findall, nodecontent export AbstractPOMDPXFile, POMDPXFile, - MOMDPXFile, Alphas, POMDPAlphas, - read_pomdp, - action, - value + read_pomdp include("writer.jl") include("policy.jl") -include("read.jl") +include("reader.jl") end # module diff --git a/src/read.jl b/src/read.jl deleted file mode 100644 index da5606e..0000000 --- a/src/read.jl +++ /dev/null @@ -1,119 +0,0 @@ -# Parses policy xml file and returns alpha vectors and alpha actions -# Should handle any policy in the .policy format -# Should handle policies written by POMDPs.jl or APPL -# -# alpha_vectors is a matrix containing the alpha vectors -# Each row corresponds to a different alpha vector -# -# alpha_actions is a vector containing the list of action indices -# Each index corresponds to action associated with the alpha vector of this row -# These are 0-indexed... 0 means it is the first action -# -# TODO: Check that the input file exists and handle the case that it doesn't -# TODO: Handle sparse vectors -# - -function read_momdp(filename::String) - - # Parse the xml file - # TODO: Check that the file exists and handle the case that it doesn't - xdoc = parse_file(filename) - - # Get the root of the document (the Policy tag in this case) - policy_tag = root(xdoc) - #println(name(policy_tag)) # print the name of this tag - - # Determine expected number of vectors and their length - alphavector_tag = get_elements_by_tagname(policy_tag, "AlphaVector")[1] #length of 1 anyway - num_vectors = int(attribute(alphavector_tag, "numVectors")) - vector_length = int(attribute(alphavector_tag, "vectorLength")) - num_full_obs_states = int(attribute(alphavector_tag, "numObsValue")) - - # For debugging purposes... - #println("AlphaVector tag: # vectors, vector length: $(num_vectors), $(vector_length)") - - # Arrays with vector and sparse vector tags - vector_tags = get_elements_by_tagname(alphavector_tag, "Vector") - sparsevector_tags = get_elements_by_tagname(alphavector_tag, "SparseVector") - - num_vectors_check = length(vector_tags) + length(sparsevector_tags) # should be same as num_vectors - - # Initialize the gamma matrix. This is basically a matrix with the alpha - # vectors as rows. - #alpha_vectors = Array(Float64, num_vectors, vector_length) - alpha_vectors = Array{Float64}(vector_length, num_vectors) - alpha_actions = Array{String}(num_vectors) - observable_states = Array{String}(num_vectors) - gammarow = 1 - - # Fill in gamma - for vector in vector_tags - alpha = float(split(content(vector))) - #alpha_vectors[gammarow, :] = alpha - alpha_vectors[:,gammarow] = alpha - alpha_actions[gammarow] = attribute(vector, "action") - observable_states[gammarow] = attribute(vector, "obsValue") - gammarow += 1 - end - - # TODO: Handle sparse vectors - for vector in sparsevector_tags - # Turn these into vectors as well - end - - # Return alpha vectors and indices of actions - return alpha_vectors, int(alpha_actions), int(observable_states) -end - - -function read_pomdp(filename::String) - - # Parse the xml file - # TODO: Check that the file exists and handle the case that it doesn't - xdoc = parse_file(filename) - - # Get the root of the document (the Policy tag in this case) - policy_tag = root(xdoc) - #println(name(policy_tag)) # print the name of this tag - - # Determine expected number of vectors and their length - alphavector_tag = get_elements_by_tagname(policy_tag, "AlphaVector")[1] #length of 1 anyway - num_vectors = parse(Int64, attribute(alphavector_tag, "numVectors")) - vector_length = parse(Int64, attribute(alphavector_tag, "vectorLength")) - num_full_obs_states = parse(Int64, attribute(alphavector_tag, "numObsValue")) - - # For debugging purposes... - #println("AlphaVector tag: # vectors, vector length: $(num_vectors), $(vector_length)") - - # Arrays with vector and sparse vector tags - vector_tags = get_elements_by_tagname(alphavector_tag, "Vector") - sparsevector_tags = get_elements_by_tagname(alphavector_tag, "SparseVector") - - num_vectors_check = length(vector_tags) + length(sparsevector_tags) # should be same as num_vectors - - # Initialize the gamma matrix. This is basically a matrix with the alpha - # vectors as columns. - #alpha_vectors = Array(Float64, num_vectors, vector_length) - alpha_vectors = Array{Float64}(undef, vector_length, num_vectors) - alpha_actions = Array{String}(undef, num_vectors) - observable_states = Array{String}(undef, num_vectors) - gammarow = 1 - - # Fill in gamma - for vector in vector_tags - alpha = parse.(Float64, split(content(vector))) - #alpha_vectors[gammarow, :] = alpha - alpha_vectors[:,gammarow] = alpha - alpha_actions[gammarow] = attribute(vector, "action") - observable_states[gammarow] = attribute(vector, "obsValue") - gammarow += 1 - end - - # TODO: Handle sparse vectors - for vector in sparsevector_tags - # Turn these into vectors as well - end - - # Return alpha vectors and indices of actions - return alpha_vectors, [parse(Int64,s) for s in alpha_actions] -end diff --git a/src/reader.jl b/src/reader.jl new file mode 100644 index 0000000..2a503ba --- /dev/null +++ b/src/reader.jl @@ -0,0 +1,41 @@ +"""Parses a Policy XML file and returns `alphavectors` and `alphaactions` +This should be able to handle any policy in the `.policy` format, which includes + policies written by `POMDPs.jl` or `APPL`. + +`alphavectors` is a matrix containing the alpha vectors where each row corresponds to a + different alpha vector. + +`alphaactions` is a vector containing the list of action indices, where ach index + corresponds to action associated with the alpha vector row. +These are `0-indexed`, so `0` maps to the first action. +""" + +function read_pomdp(filename::String) + xml = readxml(open(filename, "r")) + policy = root(xml) # The node + + av_node = findfirst("AlphaVector", policy) + + alphavectors = Array{Real}[] + alphaactions = [] + + # Create alpha vector and alpha action lists + vectors = findall("Vector", av_node) + for vector in vectors + push!(alphavectors, parse.(Float64, split(nodecontent(vector)))) + push!(alphaactions, vector["action"]) + end + alphavectors = mapreduce(permutedims, vcat, alphavectors) + + n_vectors = parse(Int, av_node["numVectors"]) + vector_len = parse(Int, av_node["vectorLength"]) + @assert size(alphavectors) == (n_vectors, vector_len) + + # TODO handle sparsevectors + sparsevectors = findall("SparseVector", av_node) + for vector in sparsevectors + # Turn these into vectors + end + + return (alphavectors=alphavectors, alphaactions=parse.(Int64, alphaactions)) +end \ No newline at end of file diff --git a/src/writer.jl b/src/writer.jl index 7bbffa0..3ecf185 100644 --- a/src/writer.jl +++ b/src/writer.jl @@ -1,398 +1,241 @@ -################################################################# -# This file implements a .pomdpx file generator using the -# POMDPs.jl interface. -################################################################# +#!/usr/bin/env julia +using POMDPs +using POMDPTools +using ProgressMeter +import POMDPs: action, value +using POMDPModels +using EzXML +import EzXML: Node, XMLDocument, ElementNode, addelement!, link!, write +using Parameters +using ProgressMeter abstract type AbstractPOMDPXFile end -mutable struct POMDPXFile <: AbstractPOMDPXFile - file_name::AbstractString - description::AbstractString +@with_kw struct POMDPXFile <: AbstractPOMDPXFile + filename::String + description::String = "This is a POMDPX file for a POMDP" - state_name::AbstractString - action_name::AbstractString - reward_name::AbstractString - obs_name::AbstractString + state_name::String = "state" + action_name::String = "action" + obs_name::String = "observation" + reward_name::String = "reward" - initial_belief::Vector{Float64} - - #initial_belief::Vector{Float64} # belief over partially observed vars + pretty::Bool = false +end - function POMDPXFile(file_name::AbstractString; description::AbstractString="", - initial_belief::Vector{Float64}=Float64[]) +a_name( px::POMDPXFile) = px.action_name +s_name( px::POMDPXFile) = px.state_name +sp_name(px::POMDPXFile) = px.state_name * "p" +o_name( px::POMDPXFile) = px.obs_name +r_name( px::POMDPXFile) = px.reward_name + +function build_xml(p::POMDP, px::POMDPXFile) + n_states = length(states(p)) + n_actions = length(actions(p)) + n_obs = length(observations(p)) + n_nodes = 14 + n_states + n_states * n_actions * n_obs + n_states * n_actions * 2 + pbar = Progress(n_nodes; dt=0.01) + + doc = XMLDocument() + root = ElementNode("pomdpx") + root["version"] = 0.1 + root["id"] = replace(px.filename, ".pomdpx" => "") + root["xmlns:xsi"] = "http://www.w3.org/2001/XMLSchema-instance" + root["xsi:noNamespaceSchemaLocation"] = "https://raw.githubusercontent.com/JuliaPOMDP/sarsop/master/doc/POMDPX/pomdpx.xsd" + setroot!(doc, root) + + addelement!(root, "Description", px.description) + addelement!(root, "Discount", "$(discount(pomdp))") + next!(pbar) + + build_variables!(root, p, px, pbar) + build_initial_beliefs(root, p, px, pbar) + build_transitions!(root, p, px, pbar) + build_observations!(root, p, px, pbar) + build_rewards!(root, p, px, pbar) + + return doc +end - if isempty(description) - description = "This is a pomdpx file for a partially observable MDP" - end +function build_variables!(root::EzXML.Node, p::POMDP, px::POMDPXFile, pbar) + variables = ElementNode("Variable") + link!(root, variables) + + states_node = ElementNode("StateVar") + states_node["vnamePrev"] = s_name(px) + states_node["vnameCurr"] = sp_name(px) + states_node["fullyObs" ] = "false" + link!(variables, states_node) + + actions_node = ElementNode("ActionVar") + actions_node["vname"] = a_name(px) + link!(variables, actions_node) + + obs_node = ElementNode("ObsVar") + obs_node["vname"] = o_name(px) + link!(variables, obs_node) + + reward_node = ElementNode("RewardVar") + reward_node["vname"] = r_name(px) + link!(variables, reward_node) + next!(pbar) + + if !px.pretty + addelement!(states_node, "NumValues", "$(length(states(px.pomdp)))") + addelement!(actions_node, "NumValues", "$(length(actions(px.pomdp)))") + addelement!(obs_node, "NumValues", "$(length(observations(px.pomdp)))") + return + end - self = new() - self.file_name = file_name - self.description = description + all_states = join(["s_$(s)" for s=ordered_states(p)], " ") + addelement!(states_node, "ValueEnum", all_states) - self.state_name = "state" - self.action_name = "action" - self.reward_name = "reward" - self.obs_name = "observation" + all_actions = join(["a_$(a)" for a=ordered_actions(p)], " ") + addelement!(actions_node, "ValueEnum", all_actions) - self.initial_belief = initial_belief + all_observations = join(["o_$(o)" for o=ordered_observations(p)], " ") + addelement!(obs_node, "ValueEnum", all_observations) +end - return self +function param(label::String; prob::Real = -1, value::Real = -Inf) + entry = ElementNode("Entry") + addelement!(entry, "Instance", label) + if prob != -1 + addelement!(entry, "ProbTable", "$(prob)") end - + if !isinf(value) + addelement!(entry, "ValueTable", "$(value)") + end + return entry end +function build_initial_beliefs(root::EzXML.Node, p::POMDP, px::POMDPXFile, pbar) + ibstate = ElementNode("InitialStateBelief") + link!(root, ibstate) -function Base.write(pomdp::POMDP, pomdpx::AbstractPOMDPXFile) - file_name = pomdpx.file_name - description = pomdpx.description - discount_factor = discount(pomdp) - - # Open file to write to - out_file = open("$file_name", "w") - - pomdp_states = ordered_states(pomdp) - pomdp_pstates = ordered_states(pomdp) - acts = ordered_actions(pomdp) - obs = ordered_observations(pomdp) - # x = Number of next statement to track Progress - # Added approximately after every four lines written to file, 14 next statements outside loops - x = 14 + length(pomdp_states) + length(pomdp_states)*length(acts)*length(obs) + length(pomdp_states)*length(acts) + length(acts)*length(pomdp_pstates) - p1 = Progress(x, dt=0.01) - - # Header stuff for xml - write(out_file, "\n\n\n") - write(out_file, "\n\n\n") - - sleep(0.01) - next!(p1) - ############################################################################ - # DESCRIPTION - ############################################################################ - write(out_file, "\t $(description)\n\n\n") - - ############################################################################ - # DISCOUNT - ############################################################################ - write(out_file, "\t$(discount_factor)\n\n\n") - - ############################################################################ - # VARIABLES - ############################################################################ - write(out_file, "\t\n") - next!(p1) - # State Variables - str = state_xml(pomdp, pomdpx) - write(out_file, str) - next!(p1) - # Action Variables - str = action_xml(pomdp, pomdpx) - write(out_file, str) - next!(p1) - # Observation Variables - str = obs_var_xml(pomdp, pomdpx) - write(out_file, str) - next!(p1) - # Reward Variable - str = reward_var_xml(pomdp, pomdpx) - write(out_file, str) - write(out_file, "\t\n\n\n") - - next!(p1) - ############################################################################ - # INITIAL STATE BELIEF - ############################################################################ - belief_xml(pomdp, pomdpx, out_file, p1) - - - ############################################################################ - # STATE TRANSITION FUNCTION - ############################################################################ - trans_xml(pomdp, pomdpx, out_file, p1) - - - ############################################################################ - # OBS FUNCTION - ############################################################################ - obs_xml(pomdp, pomdpx, out_file, p1) - - - ############################################################################ - # REWARD FUNCTION - ############################################################################ - reward_xml(pomdp, pomdpx, out_file, p1) - - - # CLOSE POMDPX TAG AND FILE - write(out_file, "") - close(out_file) -end + condprob = ElementNode("CondProb") + link!(ibstate, condprob) + addelement!(condprob, "Var", s_name(px)) + addelement!(condprob, "Parent", "null") -############################################################################ -# function: state_xml -# input: pomdp model, pomdpx type -# output: string in xml format the defines the state varaibles -############################################################################ -function state_xml(pomdp::POMDP, pomdpx::POMDPXFile) - # defines state vars for a POMDP - n_s = length(states(pomdp)) - sname = pomdpx.state_name - str = "\t\t\n" - str = "$(str)\t\t\t$(n_s)\n" - str = "$(str)\t\t\n\n" - return str -end -############################################################################ - - - -############################################################################ -# function: obs_var_xml -# input: pomdp model, pomdpx type -# output: string in xml format the defines the observation varaibles -############################################################################ -function obs_var_xml(pomdp::POMDP, pomdpx::AbstractPOMDPXFile) - # defines observation vars for POMDP and MOMDP - n_o = length(observations(pomdp)) - oname = pomdpx.obs_name - str = "\t\t\n" - str = "$(str)\t\t\t$(n_o)\n" - str = "$(str)\t\t\n\n" - return str -end -############################################################################ - - - -############################################################################ -# function: action_xml -# input: pomdp model, pomdpx type -# output: string in xml format the defines the action varaibles -############################################################################ -function action_xml(pomdp::POMDP, pomdpx::AbstractPOMDPXFile) - # defines action vars for MDP, POMDP and MOMDP - n_a = length(actions(pomdp)) - aname = pomdpx.action_name - str = "\t\t\n" - str = "$(str)\t\t\t$(n_a)\n" - str = "$(str)\t\t\n\n" - return str + parameter = ElementNode("Parameter") + parameter["type"] = "TBL" + link!(condprob, parameter) + next!(pbar) + + for (prob, s) in initialstate(p) + sidx = stateindex(p, s) + label = px.pretty ? "s_$(s)" : "s$(sidx)" + link!(parameter, param(label; prob=prob)) + next!(pbar) + end end -############################################################################ +function build_transitions!(root::EzXML.Node, p::POMDP, px::POMDPXFile, pbar) + statetrans = ElementNode("StateTransitionFunction") + link!(root, statetrans) + condprob = ElementNode("CondProb") + link!(statetrans, condprob) -############################################################################ -# function: reward_var_xml -# input: pomdp model, pomdpx type -# output: string in xml format the defines the reward varaible -############################################################################ -function reward_var_xml(pomdp::POMDP, pomdpx::AbstractPOMDPXFile) - # defines reward var for MDP, POMDP and MOMDP - rname = pomdpx.reward_name - str = "\t\t\n\n" - return str -end -############################################################################ - - - -############################################################################ -# function: belief_xml -# input: pomdp model, pomdpx type, output file -# output: None, writes the initial belief to the output file -############################################################################ -function belief_xml(pomdp::POMDP, pomdpx::POMDPXFile, out_file::IOStream, p1) - belief = pomdpx.initial_belief - var = pomdpx.state_name - write(out_file, "\t\n") - str = "\t\t\n" - str = "$(str)\t\t\t$(var)0\n" - str = "$(str)\t\t\tnull\n" - str = "$(str)\t\t\t\n" - next!(p1) - - d = initialstate(pomdp) - for (i, s) in enumerate(ordered_states(pomdp)) - p = pdf(d, s) - str = "$(str)\t\t\t\t\n" - str = "$(str)\t\t\t\t\ts$(i-1)\n" - str = "$(str)\t\t\t\t\t$(p)\n" - str = "$(str)\t\t\t\t\n" - next!(p1) - end - str = "$(str)\t\t\t\n" - str = "$(str)\t\t\n" - write(out_file, str) - write(out_file, "\t\n\n\n") - next!(p1) -end -############################################################################ - - - -############################################################################ -# function: trans_xml -# input: pomdp model, pomdpx type, output file -# output: None, writes the transition probability table to the output file -############################################################################ -function trans_xml(pomdp::POMDP, pomdpx::POMDPXFile, out_file::IOStream, p1) - pomdp_states = ordered_states(pomdp) - pomdp_pstates = ordered_states(pomdp) - acts = ordered_actions(pomdp) - - aname = pomdpx.action_name - var = pomdpx.state_name - - write(out_file, "\t\n") - str = "\t\t\n" - str = "$(str)\t\t\t$(var)1\n" - str = "$(str)\t\t\t$(aname) $(var)0\n" - str = "$(str)\t\t\t\n" - write(out_file, str) - next!(p1) - for (i, s) in enumerate(pomdp_states) - if isterminal(pomdp, s) # if terminal, just remain in the same state - str = "\t\t\t\t\n" - str = "$(str)\t\t\t\t\t* s$(i-1) s$(i-1)\n" - str = "$(str)\t\t\t\t\t1.0\n" - str = "$(str)\t\t\t\t\n" - write(out_file, str) - for i = 1:length(acts)*length(pomdp_pstates) - next!(p1) + addelement!(condprob, "Var", sp_name(px)) + addelement!(condprob, "Parent", "$(a_name(px)) $(s_name(px))") + + parameter = ElementNode("Parameter") + link!(condprob, parameter) + next!(pbar) + + for s=states(p) + sidx = stateindex(p, s) + if isterminal(p, s) + label = px.pretty ? "* s_$(s) s_$(s)" : "* s$(sidx) s$(sidx)" + link!(parameter, param(label; prob=1.0)) + for _=1:(length(actions(p)) * length(states(p))) + next!(pbar) end - else - for (ai, a) in enumerate(acts) - d = transition(pomdp, s, a) - for (j, sp) in enumerate(pomdp_pstates) - p = pdf(d, sp) - if p > 0.0 - str = "\t\t\t\t\n" - str = "$(str)\t\t\t\t\ta$(ai-1) s$(i-1) s$(j-1)\n" - str = "$(str)\t\t\t\t\t$(p)\n" - str = "$(str)\t\t\t\t\n" - write(out_file, str) - end - next!(p1) - end + continue + end + + for a=actions(p), sp=states(p) + T = transition(p, s, a) + (aidx, spidx) = actionindex(p, a), stateindex(p, sp) + label = px.pretty ? "a_$(a) s_$(s) s_$(sp)" : "a$(aidx) s$(sidx) s$(spidx)" + if pdf(T, sp) > 0.0 + link!(parameter, param(label; prob=pdf(T, sp))) end + next!(pbar) end end - str = "\t\t\t\n" - str = "$(str)\t\t\n" - write(out_file, str) - write(out_file, "\t\n\n\n") - next!(p1) - return nothing end -############################################################################ +function build_observations!(root::EzXML.Node, p::POMDP, px::POMDPXFile, pbar) + obsfun = ElementNode("ObsFunction") + link!(root, obsfun) + condprob = ElementNode("CondProb") + link!(obsfun, condprob) -############################################################################ -# function: obs_xml -# input: pomdp model, pomdpx type, output file -# output: None, writes the observation probability table to the output file -############################################################################ -function obs_xml(pomdp::POMDP, pomdpx::POMDPXFile, out_file::IOStream, p1) - pomdp_states = ordered_states(pomdp) - acts = ordered_actions(pomdp) - obs = ordered_observations(pomdp) + addelement!(condprob, "Var", o_name(px)) + addelement!(condprob, "Parent", "$(a_name(px)) $(sp_name(px))") - aname = pomdpx.action_name - oname = pomdpx.obs_name - var = pomdpx.state_name + parameter = ElementNode("Paramtieer") + link!(condprob, parameter) + next!(pbar) - write(out_file, "\t\n") - str = "\t\t\n" - str = "$(str)\t\t\t$(oname)\n" - str = "$(str)\t\t\t$(aname) $(var)1\n" - str = "$(str)\t\t\t\n" - write(out_file, str) - next!(p1) - - try observation(pomdp, first(acts), first(pomdp_states)) + try observation(p, first(actions(p)), first(states(p))) catch ex if ex isa MethodError - @warn("""POMDPXFiles only supports observation distributions conditioned on a and sp. + @warn("""POMDPXFiles only supports observation distributions conditioned on `a` and `sp`. + + Check that there is an `observation(::P, ::A, ::S)` method available (or an (::A, ::S) method of the observation function for a QuickPOMDP). - Check that there is an `observation(::M, ::A, ::S)` method available (or an (::A, ::S) method of the observation function for a QuickPOMDP). - This warning is designed to give a helpful hint to fix errors, but may not always be relevant. - """, M=typeof(pomdp), S=typeof(first(pomdp_states)), A=typeof(first(acts))) + """, P=typeof(p), S=eltype(states(p)), A=eltype(actions(p))) end rethrow(ex) end - for (i, s) in enumerate(pomdp_states) - for (ai, a) in enumerate(acts) - d = observation(pomdp, a, s) - for (oi, o) in enumerate(obs) - p = pdf(d, o) - if p > 0.0 - str = "\t\t\t\t\n" - str = "$(str)\t\t\t\t\ta$(ai-1) s$(i-1) o$(oi-1)\n" - str = "$(str)\t\t\t\t\t$(p)\n" - str = "$(str)\t\t\t\t\n" - write(out_file, str) - end - next!(p1) - end + for a=actions(p), sp=states(p), o=observations(p) + O = observation(p, a, sp) + (aidx, spidx, oidx) = (actionindex(p, a), stateindex(p, sp), obsindex(p, o)) + label = px.pretty ? "a_$(a) s_$(sp) o_$(o)" : "a$(aidx) s$(spidx) o$(oidx)" + + if pdf(O, o) > 0. + link!(parameter, param(label; prob=pdf(O, o))) end + next!(pbar) end - write(out_file, "\t\t\t\n") - write(out_file, "\t\t\n") - write(out_file, "\t\n") - next!(p1) end -############################################################################ - - - -############################################################################ -# function: reward_xml -# input: pomdp model, pomdpx type, output file -# output: None, writes the reward function to the output file -############################################################################ -function reward_xml(pomdp::POMDP, pomdpx::POMDPXFile, out_file::IOStream, p1) - pomdp_states = ordered_states(pomdp) - acts = ordered_actions(pomdp) - rew = StateActionReward(pomdp) - - aname = pomdpx.action_name - var = pomdpx.state_name - rname = pomdpx.reward_name - - write(out_file, "\t\n") - str = "\t\t\n" - str = "$(str)\t\t\t$(rname)\n" - str = "$(str)\t\t\t$(aname) $(var)0\n" - str = "$(str)\t\t\t\n" - write(out_file, str) - next!(p1) - - for (i, s) in enumerate(pomdp_states) - if !isterminal(pomdp, s) - for (ai, a) in enumerate(acts) - r = rew(s, a) - str = "\t\t\t\t\n" - str = "$(str)\t\t\t\t\ta$(ai-1) s$(i-1)\n" - str = "$(str)\t\t\t\t\t$(r)\n" - str = "$(str)\t\t\t\t\n" - write(out_file, str) - next!(p1) - end - else - for i = 1:length(acts) - next!(p1) - end + +function build_rewards!(root::EzXML.Node, p::POMDP, px::POMDPXFile, pbar) + rewardfunc = ElementNode("RewardFunction") + link!(root, rewardfunc) + + func = ElementNode("Func") + link!(rewardfunc, func) + + addelement!(func, "Var", r_name(px)) + addelement!(func, "Parent", "$(a_name(px)) $(s_name(px))") + + parameter = ElementNode("Parameter") + link!(func, parameter) + next!(pbar) + + for a=actions(p), s=states(p) + (aidx, sidx) = (actionindex(p, a), stateindex(p, s)) + label= px.pretty ? "a_$(a) s_$(s)" : "a$(aidx) s$(sidx)" + + if !isterminal(p, s) + link!(parameter, param(label; value=reward(p, s, a))) end + next!(pbar) end - - write(out_file, "\t\t\t\n\t\t\n") - write(out_file, "\t\n\n") - next!(p1) end -############################################################################ + +function Base.write(p::POMDP, px::POMDPXFile) + file = open(px.filename, "w") + doc = build_xml(p, px) + prettyprint(file, doc) + close(file) +end \ No newline at end of file diff --git a/test/mypolicy.policy b/test/mypolicy.policy index 422a3fa..132889f 100644 --- a/test/mypolicy.policy +++ b/test/mypolicy.policy @@ -1,9 +1,13 @@ - - --81.5975 28.4025 -3.01448 24.6954 -24.6954 3.01452 -28.4025 -81.5975 -19.3711 19.3711 - + + + -81.5975 28.4025 + 3.01448 24.6954 + 24.6954 3.01452 + 28.4025 -81.5975 + 19.3711 19.3711 + + \ No newline at end of file From 2138d8f90fa5a41ff7ef73af221b965b3dbd8953 Mon Sep 17 00:00:00 2001 From: John Muchovej <5000729+jmuchovej@users.noreply.github.com> Date: Mon, 5 Jun 2023 23:20:04 -0400 Subject: [PATCH 2/9] Cleaning up imports; checking tests Cleaning up imports from the stand-alone scripts I used to prototype. Also checking that tests pass. --- src/POMDPXFiles.jl | 4 ++-- src/reader.jl | 4 ++-- src/writer.jl | 37 ++++++++++++++----------------------- test/runtests.jl | 10 +++++----- 4 files changed, 23 insertions(+), 32 deletions(-) diff --git a/src/POMDPXFiles.jl b/src/POMDPXFiles.jl index 8581c40..a410bb1 100644 --- a/src/POMDPXFiles.jl +++ b/src/POMDPXFiles.jl @@ -3,8 +3,9 @@ module POMDPXFiles using POMDPs using POMDPTools using ProgressMeter +using Parameters -import EzXML: Node, XMLDocument, ElementNode, addelement!, link!, write, readxml, root, findfirst, findall, nodecontent +import EzXML: Node, XMLDocument, ElementNode, addelement!, link!, write, readxml, root, findfirst, findall, nodecontent, setroot!, prettyprint export AbstractPOMDPXFile, @@ -14,7 +15,6 @@ export read_pomdp - include("writer.jl") include("policy.jl") include("reader.jl") diff --git a/src/reader.jl b/src/reader.jl index 2a503ba..137142f 100644 --- a/src/reader.jl +++ b/src/reader.jl @@ -25,11 +25,11 @@ function read_pomdp(filename::String) push!(alphavectors, parse.(Float64, split(nodecontent(vector)))) push!(alphaactions, vector["action"]) end - alphavectors = mapreduce(permutedims, vcat, alphavectors) + alphavectors = transpose(mapreduce(permutedims, vcat, alphavectors)) n_vectors = parse(Int, av_node["numVectors"]) vector_len = parse(Int, av_node["vectorLength"]) - @assert size(alphavectors) == (n_vectors, vector_len) + @assert size(alphavectors) == (vector_len, n_vectors) # TODO handle sparsevectors sparsevectors = findall("SparseVector", av_node) diff --git a/src/writer.jl b/src/writer.jl index 3ecf185..8a7f59f 100644 --- a/src/writer.jl +++ b/src/writer.jl @@ -1,14 +1,3 @@ -#!/usr/bin/env julia -using POMDPs -using POMDPTools -using ProgressMeter -import POMDPs: action, value -using POMDPModels -using EzXML -import EzXML: Node, XMLDocument, ElementNode, addelement!, link!, write -using Parameters -using ProgressMeter - abstract type AbstractPOMDPXFile end @with_kw struct POMDPXFile <: AbstractPOMDPXFile @@ -45,7 +34,7 @@ function build_xml(p::POMDP, px::POMDPXFile) setroot!(doc, root) addelement!(root, "Description", px.description) - addelement!(root, "Discount", "$(discount(pomdp))") + addelement!(root, "Discount", "$(discount(p))") next!(pbar) build_variables!(root, p, px, pbar) @@ -57,7 +46,7 @@ function build_xml(p::POMDP, px::POMDPXFile) return doc end -function build_variables!(root::EzXML.Node, p::POMDP, px::POMDPXFile, pbar) +function build_variables!(root::Node, p::POMDP, px::POMDPXFile, pbar) variables = ElementNode("Variable") link!(root, variables) @@ -81,9 +70,9 @@ function build_variables!(root::EzXML.Node, p::POMDP, px::POMDPXFile, pbar) next!(pbar) if !px.pretty - addelement!(states_node, "NumValues", "$(length(states(px.pomdp)))") - addelement!(actions_node, "NumValues", "$(length(actions(px.pomdp)))") - addelement!(obs_node, "NumValues", "$(length(observations(px.pomdp)))") + addelement!(states_node, "NumValues", "$(length(states(p)))") + addelement!(actions_node, "NumValues", "$(length(actions(p)))") + addelement!(obs_node, "NumValues", "$(length(observations(p)))") return end @@ -109,7 +98,7 @@ function param(label::String; prob::Real = -1, value::Real = -Inf) return entry end -function build_initial_beliefs(root::EzXML.Node, p::POMDP, px::POMDPXFile, pbar) +function build_initial_beliefs(root::Node, p::POMDP, px::POMDPXFile, pbar) ibstate = ElementNode("InitialStateBelief") link!(root, ibstate) @@ -124,15 +113,16 @@ function build_initial_beliefs(root::EzXML.Node, p::POMDP, px::POMDPXFile, pbar) link!(condprob, parameter) next!(pbar) - for (prob, s) in initialstate(p) + init_states = initialstate(p) + for s in states(p) sidx = stateindex(p, s) label = px.pretty ? "s_$(s)" : "s$(sidx)" - link!(parameter, param(label; prob=prob)) + link!(parameter, param(label; prob=pdf(init_states, s))) next!(pbar) end end -function build_transitions!(root::EzXML.Node, p::POMDP, px::POMDPXFile, pbar) +function build_transitions!(root::Node, p::POMDP, px::POMDPXFile, pbar) statetrans = ElementNode("StateTransitionFunction") link!(root, statetrans) @@ -169,7 +159,7 @@ function build_transitions!(root::EzXML.Node, p::POMDP, px::POMDPXFile, pbar) end end -function build_observations!(root::EzXML.Node, p::POMDP, px::POMDPXFile, pbar) +function build_observations!(root::Node, p::POMDP, px::POMDPXFile, pbar) obsfun = ElementNode("ObsFunction") link!(root, obsfun) @@ -208,7 +198,7 @@ function build_observations!(root::EzXML.Node, p::POMDP, px::POMDPXFile, pbar) end end -function build_rewards!(root::EzXML.Node, p::POMDP, px::POMDPXFile, pbar) +function build_rewards!(root::Node, p::POMDP, px::POMDPXFile, pbar) rewardfunc = ElementNode("RewardFunction") link!(root, rewardfunc) @@ -222,12 +212,13 @@ function build_rewards!(root::EzXML.Node, p::POMDP, px::POMDPXFile, pbar) link!(func, parameter) next!(pbar) + reward_fn = StateActionReward(p) for a=actions(p), s=states(p) (aidx, sidx) = (actionindex(p, a), stateindex(p, s)) label= px.pretty ? "a_$(a) s_$(s)" : "a$(aidx) s$(sidx)" if !isterminal(p, s) - link!(parameter, param(label; value=reward(p, s, a))) + link!(parameter, param(label; value=reward_fn(s, a))) end next!(pbar) end diff --git a/test/runtests.jl b/test/runtests.jl index 35ea892..280199e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,9 +5,9 @@ using POMDPModels using Test @testset "basic" begin - file_name = "tiger_test.pomdpx" + filename = "tiger_test.pomdpx" pomdp = TigerPOMDP() - pomdpx = POMDPXFile(file_name) + pomdpx = POMDPXFile(; filename=filename) write(pomdp, pomdpx) av, aa = read_pomdp("mypolicy.policy") @@ -15,7 +15,7 @@ using Test @test aa == [1,0,0,2,0] end -@testset "a, sp observation warning" begin +@testset "a, sp observation warning" begin struct BadObsPOMDP <: POMDP{Int,Int,Int} end POMDPs.states(m::BadObsPOMDP) = 1:2 POMDPs.actions(m::BadObsPOMDP) = 1:2 @@ -30,11 +30,11 @@ end POMDPs.obsindex(m::BadObsPOMDP, s) = s @test_throws MethodError cd(mktempdir()) do - write(BadObsPOMDP(), POMDPXFile("bad_obs_test.pomdpx")) + write(BadObsPOMDP(), POMDPXFile(; filename="bad_obs_test.pomdpx")) end POMDPs.observation(m::BadObsPOMDP, a, sp) = Deterministic(1) cd(mktempdir()) do - write(BadObsPOMDP(), POMDPXFile("bad_obs_test.pomdpx")) + write(BadObsPOMDP(), POMDPXFile(; filename="bad_obs_test.pomdpx")) end end From 68dce8f5837a9258b635b98fb9bfc6f1fdb65379 Mon Sep 17 00:00:00 2001 From: John Muchovej <5000729+jmuchovej@users.noreply.github.com> Date: Tue, 6 Jun 2023 12:30:12 -0400 Subject: [PATCH 3/9] Reorganize pretty printing; force use Base.string --- src/writer.jl | 69 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 45 insertions(+), 24 deletions(-) diff --git a/src/writer.jl b/src/writer.jl index 8a7f59f..318402e 100644 --- a/src/writer.jl +++ b/src/writer.jl @@ -12,9 +12,12 @@ abstract type AbstractPOMDPXFile end pretty::Bool = false end +POMDPXFile(filename::String, pretty::Bool) = POMDPXFile(; filename=filename, pretty=pretty) +POMDPXFile(filename::String) = POMDPXFile(filename, false) + a_name( px::POMDPXFile) = px.action_name -s_name( px::POMDPXFile) = px.state_name -sp_name(px::POMDPXFile) = px.state_name * "p" +s_name( px::POMDPXFile) = "$(px.state_name)0" +sp_name(px::POMDPXFile) = "$(px.state_name)1" o_name( px::POMDPXFile) = px.obs_name r_name( px::POMDPXFile) = px.reward_name @@ -69,21 +72,20 @@ function build_variables!(root::Node, p::POMDP, px::POMDPXFile, pbar) link!(variables, reward_node) next!(pbar) - if !px.pretty + if px.pretty + all_states = join(["s_$(string(s))" for s=ordered_states(p)], " ") + addelement!(states_node, "ValueEnum", all_states) + + all_actions = join(["a_$(string(a))" for a=ordered_actions(p)], " ") + addelement!(actions_node, "ValueEnum", all_actions) + + all_observations = join(["o_$(string(o))" for o=ordered_observations(p)], " ") + addelement!(obs_node, "ValueEnum", all_observations) + else addelement!(states_node, "NumValues", "$(length(states(p)))") addelement!(actions_node, "NumValues", "$(length(actions(p)))") addelement!(obs_node, "NumValues", "$(length(observations(p)))") - return end - - all_states = join(["s_$(s)" for s=ordered_states(p)], " ") - addelement!(states_node, "ValueEnum", all_states) - - all_actions = join(["a_$(a)" for a=ordered_actions(p)], " ") - addelement!(actions_node, "ValueEnum", all_actions) - - all_observations = join(["o_$(o)" for o=ordered_observations(p)], " ") - addelement!(obs_node, "ValueEnum", all_observations) end function param(label::String; prob::Real = -1, value::Real = -Inf) @@ -115,8 +117,12 @@ function build_initial_beliefs(root::Node, p::POMDP, px::POMDPXFile, pbar) init_states = initialstate(p) for s in states(p) - sidx = stateindex(p, s) - label = px.pretty ? "s_$(s)" : "s$(sidx)" + if px.pretty + label = "s_$(string(s))" + else + label = "s$(stateindex(p, s))" + end + link!(parameter, param(label; prob=pdf(init_states, s))) next!(pbar) end @@ -137,20 +143,29 @@ function build_transitions!(root::Node, p::POMDP, px::POMDPXFile, pbar) next!(pbar) for s=states(p) - sidx = stateindex(p, s) if isterminal(p, s) - label = px.pretty ? "* s_$(s) s_$(s)" : "* s$(sidx) s$(sidx)" + if px.pretty + label = "* s_$(string(s)) s_$(string(s))" + else + label = "* s$(stateindex(p, s)) s$(stateindex(p, s))" + end + link!(parameter, param(label; prob=1.0)) for _=1:(length(actions(p)) * length(states(p))) next!(pbar) end + continue end for a=actions(p), sp=states(p) + if px.pretty + label = "a_$(string(a)) s_$(string(s)) s_$(string(sp))" + else + label = "a$(actionindex(p, a)) s$(stateindex(p, s)) s$(stateindex(p, sp))" + end + T = transition(p, s, a) - (aidx, spidx) = actionindex(p, a), stateindex(p, sp) - label = px.pretty ? "a_$(a) s_$(s) s_$(sp)" : "a$(aidx) s$(sidx) s$(spidx)" if pdf(T, sp) > 0.0 link!(parameter, param(label; prob=pdf(T, sp))) end @@ -187,10 +202,13 @@ function build_observations!(root::Node, p::POMDP, px::POMDPXFile, pbar) end for a=actions(p), sp=states(p), o=observations(p) - O = observation(p, a, sp) - (aidx, spidx, oidx) = (actionindex(p, a), stateindex(p, sp), obsindex(p, o)) - label = px.pretty ? "a_$(a) s_$(sp) o_$(o)" : "a$(aidx) s$(spidx) o$(oidx)" + if px.pretty + label = "a_$(string(a)) s_$(string(sp)) o_$(string(o))" + else + label = "a$(actionindex(p, a)) s$(stateindex(p, sp)) o$(obsindex(p, o))" + end + O = observation(p, a, sp) if pdf(O, o) > 0. link!(parameter, param(label; prob=pdf(O, o))) end @@ -214,8 +232,11 @@ function build_rewards!(root::Node, p::POMDP, px::POMDPXFile, pbar) reward_fn = StateActionReward(p) for a=actions(p), s=states(p) - (aidx, sidx) = (actionindex(p, a), stateindex(p, s)) - label= px.pretty ? "a_$(a) s_$(s)" : "a$(aidx) s$(sidx)" + if px.pretty + label = "a_$(string(a)) s_$(string(s))" + else + label = "a$(actionindex(p, a)) s$(stateindex(p, s))" + end if !isterminal(p, s) link!(parameter, param(label; value=reward_fn(s, a))) From d6891fa5f1b50f45e23bd1e6e4aaa6e184fb20f4 Mon Sep 17 00:00:00 2001 From: John Muchovej <5000729+jmuchovej@users.noreply.github.com> Date: Fri, 9 Jun 2023 12:12:29 -0400 Subject: [PATCH 4/9] Bringing impl closer to that of POMDPFiles --- src/writer.jl | 89 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 53 insertions(+), 36 deletions(-) diff --git a/src/writer.jl b/src/writer.jl index 318402e..d71922c 100644 --- a/src/writer.jl +++ b/src/writer.jl @@ -4,10 +4,14 @@ abstract type AbstractPOMDPXFile end filename::String description::String = "This is a POMDPX file for a POMDP" - state_name::String = "state" - action_name::String = "action" - obs_name::String = "observation" - reward_name::String = "reward" + s_var_name::String = "state" + a_var_name::String = "action" + o_var_name::String = "observation" + r_var_name::String = "reward" + + s_name::Function = normalize + a_name::Function = normalize + o_name::Function = normalize pretty::Bool = false end @@ -15,11 +19,11 @@ end POMDPXFile(filename::String, pretty::Bool) = POMDPXFile(; filename=filename, pretty=pretty) POMDPXFile(filename::String) = POMDPXFile(filename, false) -a_name( px::POMDPXFile) = px.action_name -s_name( px::POMDPXFile) = "$(px.state_name)0" -sp_name(px::POMDPXFile) = "$(px.state_name)1" -o_name( px::POMDPXFile) = px.obs_name -r_name( px::POMDPXFile) = px.reward_name +a_var_name( px::POMDPXFile) = px.action_var_name +s_var_name( px::POMDPXFile) = "$(px.state_var_name)0" +sp_var_name(px::POMDPXFile) = "$(px.state_var_name)1" +o_var_name( px::POMDPXFile) = px.obs_var_name +r_var_name( px::POMDPXFile) = px.reward_var_name function build_xml(p::POMDP, px::POMDPXFile) n_states = length(states(p)) @@ -41,10 +45,10 @@ function build_xml(p::POMDP, px::POMDPXFile) next!(pbar) build_variables!(root, p, px, pbar) - build_initial_beliefs(root, p, px, pbar) - build_transitions!(root, p, px, pbar) - build_observations!(root, p, px, pbar) - build_rewards!(root, p, px, pbar) + build_initial_beliefs(root, p, px, sname, pbar) + build_transitions!(root, p, px, sname, aname, pbar) + build_observations!(root, p, px, sname, aname, oname, pbar) + build_rewards!(root, p, px, sname, aname, pbar) return doc end @@ -54,32 +58,32 @@ function build_variables!(root::Node, p::POMDP, px::POMDPXFile, pbar) link!(root, variables) states_node = ElementNode("StateVar") - states_node["vnamePrev"] = s_name(px) - states_node["vnameCurr"] = sp_name(px) + states_node["vnamePrev"] = s_var_name(px) + states_node["vnameCurr"] = sp_var_name(px) states_node["fullyObs" ] = "false" link!(variables, states_node) actions_node = ElementNode("ActionVar") - actions_node["vname"] = a_name(px) + actions_node["vname"] = a_var_name(px) link!(variables, actions_node) obs_node = ElementNode("ObsVar") - obs_node["vname"] = o_name(px) + obs_node["vname"] = o_var_name(px) link!(variables, obs_node) reward_node = ElementNode("RewardVar") - reward_node["vname"] = r_name(px) + reward_node["vname"] = r_var_name(px) link!(variables, reward_node) next!(pbar) if px.pretty - all_states = join(["s_$(string(s))" for s=ordered_states(p)], " ") + all_states = join(["s_$(px.s_name(s))" for s=ordered_states(p)], " ") addelement!(states_node, "ValueEnum", all_states) - all_actions = join(["a_$(string(a))" for a=ordered_actions(p)], " ") + all_actions = join(["a_$(px.a_name(a))" for a=ordered_actions(p)], " ") addelement!(actions_node, "ValueEnum", all_actions) - all_observations = join(["o_$(string(o))" for o=ordered_observations(p)], " ") + all_observations = join(["o_$(px.o_name(o))" for o=ordered_observations(p)], " ") addelement!(obs_node, "ValueEnum", all_observations) else addelement!(states_node, "NumValues", "$(length(states(p)))") @@ -107,7 +111,7 @@ function build_initial_beliefs(root::Node, p::POMDP, px::POMDPXFile, pbar) condprob = ElementNode("CondProb") link!(ibstate, condprob) - addelement!(condprob, "Var", s_name(px)) + addelement!(condprob, "Var", s_var_name(px)) addelement!(condprob, "Parent", "null") parameter = ElementNode("Parameter") @@ -118,7 +122,7 @@ function build_initial_beliefs(root::Node, p::POMDP, px::POMDPXFile, pbar) init_states = initialstate(p) for s in states(p) if px.pretty - label = "s_$(string(s))" + label = "s_$(px.s_name(s))" else label = "s$(stateindex(p, s))" end @@ -135,8 +139,8 @@ function build_transitions!(root::Node, p::POMDP, px::POMDPXFile, pbar) condprob = ElementNode("CondProb") link!(statetrans, condprob) - addelement!(condprob, "Var", sp_name(px)) - addelement!(condprob, "Parent", "$(a_name(px)) $(s_name(px))") + addelement!(condprob, "Var", sp_var_name(px)) + addelement!(condprob, "Parent", "$(a_var_name(px)) $(s_var_name(px))") parameter = ElementNode("Parameter") link!(condprob, parameter) @@ -145,7 +149,7 @@ function build_transitions!(root::Node, p::POMDP, px::POMDPXFile, pbar) for s=states(p) if isterminal(p, s) if px.pretty - label = "* s_$(string(s)) s_$(string(s))" + label = "* s_$(px.s_name(s)) s_$(px.s_name(s))" else label = "* s$(stateindex(p, s)) s$(stateindex(p, s))" end @@ -160,7 +164,7 @@ function build_transitions!(root::Node, p::POMDP, px::POMDPXFile, pbar) for a=actions(p), sp=states(p) if px.pretty - label = "a_$(string(a)) s_$(string(s)) s_$(string(sp))" + label = "a_$(px.a_name(a)) s_$(px.s_name(s)) s_$(px.s_name(sp))" else label = "a$(actionindex(p, a)) s$(stateindex(p, s)) s$(stateindex(p, sp))" end @@ -181,10 +185,10 @@ function build_observations!(root::Node, p::POMDP, px::POMDPXFile, pbar) condprob = ElementNode("CondProb") link!(obsfun, condprob) - addelement!(condprob, "Var", o_name(px)) - addelement!(condprob, "Parent", "$(a_name(px)) $(sp_name(px))") + addelement!(condprob, "Var", o_var_name(px)) + addelement!(condprob, "Parent", "$(a_var_name(px)) $(sp_var_name(px))") - parameter = ElementNode("Paramtieer") + parameter = ElementNode("Parameter") link!(condprob, parameter) next!(pbar) @@ -203,7 +207,7 @@ function build_observations!(root::Node, p::POMDP, px::POMDPXFile, pbar) for a=actions(p), sp=states(p), o=observations(p) if px.pretty - label = "a_$(string(a)) s_$(string(sp)) o_$(string(o))" + label = "a_$(px.a_name(a)) s_$(px.s_name(sp)) o_$(px.o_name(o))" else label = "a$(actionindex(p, a)) s$(stateindex(p, sp)) o$(obsindex(p, o))" end @@ -223,8 +227,8 @@ function build_rewards!(root::Node, p::POMDP, px::POMDPXFile, pbar) func = ElementNode("Func") link!(rewardfunc, func) - addelement!(func, "Var", r_name(px)) - addelement!(func, "Parent", "$(a_name(px)) $(s_name(px))") + addelement!(func, "Var", r_var_name(px)) + addelement!(func, "Parent", "$(a_var_name(px)) $(s_var_name(px))") parameter = ElementNode("Parameter") link!(func, parameter) @@ -233,7 +237,7 @@ function build_rewards!(root::Node, p::POMDP, px::POMDPXFile, pbar) reward_fn = StateActionReward(p) for a=actions(p), s=states(p) if px.pretty - label = "a_$(string(a)) s_$(string(s))" + label = "a_$(px.a_name(a)) s_$(px.s_name(s))" else label = "a$(actionindex(p, a)) s$(stateindex(p, s))" end @@ -248,6 +252,19 @@ end function Base.write(p::POMDP, px::POMDPXFile) file = open(px.filename, "w") doc = build_xml(p, px) - prettyprint(file, doc) + EzXML.prettyprint(file, doc) + close(file) +end + +function normalize(s) + s = string(s) + clean = replace(s, r"[^a-zA-Z0-9]" => "_") + return replace(clean, r"_+" => "_") +end + +function prettyprint(p::POMDP, px::POMDPXFile) + file = open(px.filename, "w") + doc = build_xml(p, px) + EzXML.prettyprint(file, doc) close(file) -end \ No newline at end of file +end From f6a4dfb3fea09de4f3dd6d0ccfd3ab84a0073eeb Mon Sep 17 00:00:00 2001 From: John Muchovej <5000729+jmuchovej@users.noreply.github.com> Date: Fri, 9 Jun 2023 12:16:33 -0400 Subject: [PATCH 5/9] Typos in `var_name` access. --- src/writer.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/writer.jl b/src/writer.jl index d71922c..1d538e1 100644 --- a/src/writer.jl +++ b/src/writer.jl @@ -19,11 +19,11 @@ end POMDPXFile(filename::String, pretty::Bool) = POMDPXFile(; filename=filename, pretty=pretty) POMDPXFile(filename::String) = POMDPXFile(filename, false) -a_var_name( px::POMDPXFile) = px.action_var_name -s_var_name( px::POMDPXFile) = "$(px.state_var_name)0" -sp_var_name(px::POMDPXFile) = "$(px.state_var_name)1" -o_var_name( px::POMDPXFile) = px.obs_var_name -r_var_name( px::POMDPXFile) = px.reward_var_name +a_var_name( px::POMDPXFile) = px.a_var_name +s_var_name( px::POMDPXFile) = "$(px.s_var_name)0" +sp_var_name(px::POMDPXFile) = "$(px.s_var_name)1" +o_var_name( px::POMDPXFile) = px.o_var_name +r_var_name( px::POMDPXFile) = px.r_var_name function build_xml(p::POMDP, px::POMDPXFile) n_states = length(states(p)) From 77b9dd5905c452e23f8a7661c6944df768a9d3a2 Mon Sep 17 00:00:00 2001 From: John Muchovej <5000729+jmuchovej@users.noreply.github.com> Date: Fri, 9 Jun 2023 12:19:12 -0400 Subject: [PATCH 6/9] Forgot to remove `[s|a|o]name` params in `build_!` funcs --- src/writer.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/writer.jl b/src/writer.jl index 1d538e1..c9012a7 100644 --- a/src/writer.jl +++ b/src/writer.jl @@ -45,10 +45,10 @@ function build_xml(p::POMDP, px::POMDPXFile) next!(pbar) build_variables!(root, p, px, pbar) - build_initial_beliefs(root, p, px, sname, pbar) - build_transitions!(root, p, px, sname, aname, pbar) - build_observations!(root, p, px, sname, aname, oname, pbar) - build_rewards!(root, p, px, sname, aname, pbar) + build_initial_beliefs(root, p, px, pbar) + build_transitions!(root, p, px, pbar) + build_observations!(root, p, px, pbar) + build_rewards!(root, p, px, pbar) return doc end From 2075bcbcde1e246d86763d08bac31d1eb6c0d4b4 Mon Sep 17 00:00:00 2001 From: John Muchovej <5000729+jmuchovej@users.noreply.github.com> Date: Fri, 9 Jun 2023 12:21:21 -0400 Subject: [PATCH 7/9] EzXML wasn't imported correctly? --- src/POMDPXFiles.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/POMDPXFiles.jl b/src/POMDPXFiles.jl index a410bb1..6268562 100644 --- a/src/POMDPXFiles.jl +++ b/src/POMDPXFiles.jl @@ -5,6 +5,7 @@ using POMDPTools using ProgressMeter using Parameters +import EzXML import EzXML: Node, XMLDocument, ElementNode, addelement!, link!, write, readxml, root, findfirst, findall, nodecontent, setroot!, prettyprint export From 63d063d0c7a6ea003ccd0e8566a2c4726549fc4c Mon Sep 17 00:00:00 2001 From: John Muchovej <5000729+jmuchovej@users.noreply.github.com> Date: Fri, 9 Jun 2023 12:28:41 -0400 Subject: [PATCH 8/9] Forgot to `-1` on indices b/c C++ starts at 0. --- src/writer.jl | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/writer.jl b/src/writer.jl index c9012a7..9574c6c 100644 --- a/src/writer.jl +++ b/src/writer.jl @@ -124,7 +124,8 @@ function build_initial_beliefs(root::Node, p::POMDP, px::POMDPXFile, pbar) if px.pretty label = "s_$(px.s_name(s))" else - label = "s$(stateindex(p, s))" + sidx = stateindex(p, s) + label = "s$(sidx - 1)" end link!(parameter, param(label; prob=pdf(init_states, s))) @@ -151,7 +152,8 @@ function build_transitions!(root::Node, p::POMDP, px::POMDPXFile, pbar) if px.pretty label = "* s_$(px.s_name(s)) s_$(px.s_name(s))" else - label = "* s$(stateindex(p, s)) s$(stateindex(p, s))" + sidx = stateindex(p, s) + label = "* s$(sidx - 1) s$(sidx - 1)" end link!(parameter, param(label; prob=1.0)) @@ -166,7 +168,8 @@ function build_transitions!(root::Node, p::POMDP, px::POMDPXFile, pbar) if px.pretty label = "a_$(px.a_name(a)) s_$(px.s_name(s)) s_$(px.s_name(sp))" else - label = "a$(actionindex(p, a)) s$(stateindex(p, s)) s$(stateindex(p, sp))" + (aidx, sidx, spidx) = (actionindex(p, a), stateindex(p, sp), stateindex(p, sp)) + label = "a$(aidx - 1) s$(sidx - 1) s$(spidx - 1)" end T = transition(p, s, a) @@ -209,7 +212,8 @@ function build_observations!(root::Node, p::POMDP, px::POMDPXFile, pbar) if px.pretty label = "a_$(px.a_name(a)) s_$(px.s_name(sp)) o_$(px.o_name(o))" else - label = "a$(actionindex(p, a)) s$(stateindex(p, sp)) o$(obsindex(p, o))" + (aidx, spidx, oidx) = (actionindex(p, a), stateindex(p, sp), obsindex(p, o)) + label = "a$(aidx - 1) s$(spidx - 1) o$(oidx - 1)" end O = observation(p, a, sp) @@ -239,7 +243,8 @@ function build_rewards!(root::Node, p::POMDP, px::POMDPXFile, pbar) if px.pretty label = "a_$(px.a_name(a)) s_$(px.s_name(s))" else - label = "a$(actionindex(p, a)) s$(stateindex(p, s))" + (aidx, sidx) = (actionindex(p, a), stateindex(p, s)) + label = "a$(aidx - 1) s$(sidx - 1)" end if !isterminal(p, s) From 22e8b526377e6b6795fc7281cae855424f050ff4 Mon Sep 17 00:00:00 2001 From: John Muchovej <5000729+jmuchovej@users.noreply.github.com> Date: Fri, 9 Jun 2023 12:37:54 -0400 Subject: [PATCH 9/9] Add fallback values for unspecified T/O/R --- src/writer.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/writer.jl b/src/writer.jl index 9574c6c..8af1451 100644 --- a/src/writer.jl +++ b/src/writer.jl @@ -147,6 +147,8 @@ function build_transitions!(root::Node, p::POMDP, px::POMDPXFile, pbar) link!(condprob, parameter) next!(pbar) + link!(parameter, param("* * *"; prob=0.0)) + for s=states(p) if isterminal(p, s) if px.pretty @@ -173,9 +175,7 @@ function build_transitions!(root::Node, p::POMDP, px::POMDPXFile, pbar) end T = transition(p, s, a) - if pdf(T, sp) > 0.0 - link!(parameter, param(label; prob=pdf(T, sp))) - end + link!(parameter, param(label; prob=pdf(T, sp))) next!(pbar) end end @@ -195,6 +195,8 @@ function build_observations!(root::Node, p::POMDP, px::POMDPXFile, pbar) link!(condprob, parameter) next!(pbar) + link!(parameter, param("* * *"; prob=0.0)) + try observation(p, first(actions(p)), first(states(p))) catch ex if ex isa MethodError @@ -217,9 +219,7 @@ function build_observations!(root::Node, p::POMDP, px::POMDPXFile, pbar) end O = observation(p, a, sp) - if pdf(O, o) > 0. - link!(parameter, param(label; prob=pdf(O, o))) - end + link!(parameter, param(label; prob=pdf(O, o))) next!(pbar) end end @@ -238,6 +238,8 @@ function build_rewards!(root::Node, p::POMDP, px::POMDPXFile, pbar) link!(func, parameter) next!(pbar) + link!(parameter, param("* *"; value=0.0)) + reward_fn = StateActionReward(p) for a=actions(p), s=states(p) if px.pretty