importall POMDPs
using POMDPToolbox
using Distributions
using Parameters
using Plots
using StaticArrays

# A simple 1-D "light-dark" POMDP on the integers: observation noise grows with
# distance from the light at light_loc, action 0 ends the episode (+100 if the
# agent is exactly at state 0, -100 otherwise), and every move costs -1.
@with_kw struct SimpleLightDark <: POMDPs.POMDP{Int,Int,Float64}
    discount::Float64 = 0.95
    correct_r::Float64 = 100.0
    incorrect_r::Float64 = -100.0
    light_loc::Int = 10
    radius::Int = 60
end

discount(p::SimpleLightDark) = p.discount
isterminal(p::SimpleLightDark, s::Number) = !(s in -p.radius:p.radius)

const ACTIONS = [-10, -1, 0, 1, 10]
actions(p::SimpleLightDark) = ACTIONS
n_actions(p::SimpleLightDark) = length(actions(p))
const ACTION_INDS = Dict(a=>i for (i,a) in enumerate(actions(SimpleLightDark())))
action_index(p::SimpleLightDark, a::Int) = ACTION_INDS[a]

states(p::SimpleLightDark) = -p.radius:p.radius+1
n_states(p::SimpleLightDark) = length(states(p))
state_index(p::SimpleLightDark, s::Int) = s + p.radius + 1

function transition(p::SimpleLightDark, s::Int, a::Int)
    if a == 0
        # action 0 moves deterministically to the terminal state radius+1
        return SparseCat(SVector(p.radius+1), SVector(1.0))
    else
        return SparseCat(SVector(clamp(s+a, -p.radius, p.radius)), SVector(1.0))
    end
end

# observation noise is proportional to the distance from the light location
observation(p::SimpleLightDark, sp) = Normal(sp, abs(sp - p.light_loc) + 0.0001)

function reward(p::SimpleLightDark, s, a)
    if a == 0
        return s == 0 ? p.correct_r : p.incorrect_r
    else
        return -1.0
    end
end

function initial_state_distribution(p::SimpleLightDark)
    ps = ones(2*div(p.radius, 2) + 1)
    ps /= length(ps)
    return SparseCat(div(-p.radius, 2):div(p.radius, 2), ps)
end;

using Plots

# Plot one simulated history: background shading shows distance from the light,
# light-blue dots show belief particles, and the orange line is the trajectory.
function plothist(pomdp, hist, heading="LightDark")
    tmax = 80
    smin = -10
    smax = 20
    vsh = collect(filter(s->!isterminal(pomdp, s), state_hist(hist)[1:end-1]))
    bh = belief_hist(hist)
    pts = Int[]
    pss = Int[]
    pws = Float64[]
    for t in 0:length(bh)-1
        b = bh[t+1]
        for s in smin:smax
            w = 10.0*sqrt(pdf(b, s))
            if 0.0 < w
                push!(pts, t)
                push!(pss, s)
                push!(pws, w)
            end
        end
    end
    inv_grays = cgrad([:white, :black]) # assumed white-to-black gradient for the background
    p = contour(0:tmax, smin:smax, (t,s)->abs(s - pomdp.light_loc), # darker farther from the light
                bg_inside=:black,
                fill=true,
                xlim=(0, tmax),
                ylim=(smin, smax),
                color=inv_grays,
                xlabel="Time",
                ylabel="State",
                cbar=false,
                legend=:topright,
                title=@sprintf("%s (Reward: %8.2f)", heading, discounted_reward(hist)))
    plot!(p, [0, tmax], [0, 0], linewidth=1, color="green", label="Goal", line=:dash)
    scatter!(p, pts, pss, color="lightblue", label="Belief Particles", markersize=pws, marker=stroke(0.1, 0.3))
    plot!(p, 0:length(vsh)-1, vsh, linewidth=3, color="orangered", label="Trajectory")
    return p
end;

using ParticleFilters

rng = MersenneTwister(7)

p = SimpleLightDark()
# sequential importance resampling particle filter with 10,000 particles
pf = SIRParticleFilter(p, 10000, rng=rng)

# baseline: a trivial policy that moves randomly by +/-1
h = sim(p, updater=pf, initial_state=1, initial_obs=initial_state_distribution(p), max_steps=80, rng=rng) do b
    return rand(rng, [-1, 1])
end
plothist(p, h)

using QMDP

# QMDP: uses Q-values of the fully observable MDP, so it ignores the value of
# information gathering
solver = QMDPSolver()
policy = solve(solver, p, verbose=true);

srand(rng, 14)
hr = HistoryRecorder(max_steps=80, initial_state=1)
h = simulate(hr, p, policy, pf)
plothist(p, h, "QMDP")

using BasicPOMCP
using DiscreteValueIteration

srand(rng, 7)
max_depth = 20
max_time = 0.01
ro = ValueIterationSolver()
sol = POMCPSolver(max_depth=max_depth,
                  max_time=max_time,
                  c=100.0,
                  tree_queries=typemax(Int),
                  # estimate_value=FOValue(ro),
                  estimate_value=FORollout(RandomSolver()),
                  rng=rng,
                  tree_in_info=true)
planner = solve(sol, p)
h = simulate(hr, p, planner, pf)
plothist(p, h)

using D3Trees

# inspect the search tree for the first action from the initial belief
b = initial_state_distribution(p)
a, i = action_info(planner, b)
D3Tree(i[:tree], init_expand=2)
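# --- Hedged comparison sketch (not part of the original example) -------------
# A rough way to compare the planners above beyond a single run: average the
# discounted reward over several independent simulations. The helper name
# `mean_discounted_reward`, the 20-run default, and the per-run seeding are
# illustrative assumptions, not anything prescribed by POMDPs.jl.
function mean_discounted_reward(pomdp, pol, n::Int=20; max_steps=80)
    total = 0.0
    for i in 1:n
        run_rng = MersenneTwister(i)                        # assumed per-run seed
        up = SIRParticleFilter(pomdp, 10_000, rng=run_rng)  # fresh filter each run
        rec = HistoryRecorder(max_steps=max_steps, rng=run_rng)
        total += discounted_reward(simulate(rec, pomdp, pol, up))
    end
    return total/n
end
# Example usage (after the cells above have defined `policy` and `planner`):
# mean_discounted_reward(p, policy)   # QMDP
# mean_discounted_reward(p, planner)  # POMCP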
using POMCPOW

srand(rng, 7)
max_time = 0.01
# POMCPOW: Monte Carlo tree search with observation widening, so it can handle
# the continuous observations of this problem directly
sol = POMCPOWSolver(tree_queries=10_000_000,
                    criterion=MaxUCB(90.0),
                    max_depth=max_depth,
                    max_time=max_time,
                    enable_action_pw=false,
                    k_observation=5.0,
                    alpha_observation=1/15.0,
                    # estimate_value=FOValue(ro),
                    estimate_value=FORollout(RandomSolver()),
                    check_repeat_obs=false,
                    tree_in_info=true,
                    rng=rng)
planner = solve(sol, p)
h = simulate(hr, p, planner, pf)
plothist(p, h)

using D3Trees

b = initial_state_distribution(p)
a, i = action_info(planner, b)
D3Tree(i[:tree], init_expand=2)
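# --- Hedged sketch: stepping a planner by hand (not part of the original) ----
# This mirrors what HistoryRecorder/simulate do internally: query the planner,
# sample a transition and an observation from the model defined above, and run
# a particle-filter belief update. The helper name `manual_rollout` and the
# 40-step cap are illustrative assumptions.
function manual_rollout(pomdp, plan, up, rng; max_steps=40)
    d0 = initial_state_distribution(pomdp)
    b = initialize_belief(up, d0)
    s = rand(rng, d0)
    r_total = 0.0
    for t in 1:max_steps
        isterminal(pomdp, s) && break
        a = action(plan, b)                       # planner query
        sp = rand(rng, transition(pomdp, s, a))   # sample next state
        o = rand(rng, observation(pomdp, sp))     # sample noisy observation
        r_total += discount(pomdp)^(t-1) * reward(pomdp, s, a)
        b = update(up, b, a, o)                   # particle-filter update
        s = sp
    end
    return r_total
end
# manual_rollout(p, planner, pf, MersenneTwister(1))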