importall POMDPs
using POMDPToolbox
using Distributions
using Parameters
using Plots
using StaticArrays
# One-dimensional discrete "light-dark" POMDP. The POMDP type parameters give
# Int states, Int actions and Float64 observations. The keyword constructor
# comes from Parameters.@with_kw, e.g. SimpleLightDark(radius=30).
#   - Non-terminal states are -radius:radius; radius + 1 is the terminal state.
#   - Action 0 moves to the terminal state. It earns correct_r at state 0 and
#     incorrect_r anywhere else. Every other action costs -1.0.
#   - Observations are Normal around the new state. The standard deviation
#     grows with the distance from light_loc.
@with_kw struct SimpleLightDark <: POMDPs.POMDP{Int,Int,Float64}
    discount::Float64 = 0.95      # value returned by `discount(p)`
    correct_r::Float64 = 100.0    # reward for action 0 taken at state 0
    incorrect_r::Float64 = -100.0 # reward for action 0 taken at any other state
    light_loc::Int = 10           # state with the least observation noise
    radius::Int = 60              # non-terminal states span -radius:radius
end
# Discount factor, taken directly from the problem's `discount` field.
function discount(p::SimpleLightDark)
    return p.discount
end

# A state is terminal exactly when it falls outside -radius:radius. This
# includes radius + 1, the state that action 0 leads to.
isterminal(p::SimpleLightDark, s::Number) = s ∉ (-p.radius):p.radius
# Movement actions (steps of ±10, ±1). Action 0 ends the episode; see
# `transition` and `reward`.
const ACTIONS = [-10, -1, 0, 1, 10]

actions(p::SimpleLightDark) = ACTIONS

function n_actions(p::SimpleLightDark)
    return length(ACTIONS)
end

# Reverse lookup from action value to its 1-based position in ACTIONS,
# built once when the file is loaded.
const ACTION_INDS = Dict(zip(ACTIONS, eachindex(ACTIONS)))

action_index(p::SimpleLightDark, a::Int) = ACTION_INDS[a]
# State space: the interior -radius:radius followed by the terminal state
# radius + 1. The parentheses only make the precedence explicit; `:` already
# binds more loosely than `+`.
states(p::SimpleLightDark) = (-p.radius):(p.radius + 1)

function n_states(p::SimpleLightDark)
    return length(states(p))
end

# 1-based position of `s` in `states(p)`: -radius -> 1, radius + 1 -> 2radius + 2.
state_index(p::SimpleLightDark, s::Int) = 1 + p.radius + s
# Deterministic dynamics, returned as a one-point SparseCat.
# Action 0 jumps to the terminal state radius + 1. Any other action shifts
# the state by `a`, clamped to the interior [-radius, radius].
function transition(p::SimpleLightDark, s::Int, a::Int)
    sp = a == 0 ? p.radius + 1 : clamp(s + a, -p.radius, p.radius)
    return SparseCat(SVector(sp), SVector(1.0))
end
# Noisy position reading centred on the new state `sp`. The standard deviation
# grows linearly with the distance from `light_loc`. The 0.0001 offset keeps it
# strictly positive even when sp == light_loc.
function observation(p::SimpleLightDark, sp)
    σ = abs(sp - p.light_loc) + 0.0001
    return Normal(sp, σ)
end
# Every movement action costs -1.0. Action 0 pays correct_r at state 0 and
# incorrect_r at any other state.
function reward(p::SimpleLightDark, s, a)
    a == 0 || return -1.0
    return s == 0 ? p.correct_r : p.incorrect_r
end
# Uniform distribution over div(-radius, 2):div(radius, 2), which is -30:30 for
# the default radius of 60.
function initial_state_distribution(p::SimpleLightDark)
    lo = div(-p.radius, 2)
    hi = div(p.radius, 2)
    n = 2*hi + 1
    return SparseCat(lo:hi, fill(1.0/n, n))
end;
using Plots
"""
    plothist(pomdp, hist, heading="LightDark"; tmax=80, smin=-10, smax=20)

Plot one `SimpleLightDark` simulation history.

The plot has these layers:
- A filled-contour background of `abs(s - pomdp.light_loc)` over time and state.
- The belief at every step. Marker size grows with the square root of each
  state's probability.
- A dashed line at the goal state 0.
- The true state trajectory.

The title is `heading` followed by `discounted_reward(hist)`.

# Arguments
- `pomdp`: the problem. `light_loc` is read from it and `isterminal` is called on it.
- `hist`: a simulation history accepted by `state_hist`, `belief_hist` and
  `discounted_reward`.
- `heading`: title prefix.

# Keywords
- `tmax`: right end of the time axis (default `80`).
- `smin`, `smax`: integer range of states whose belief probability is drawn.
  They are also the y-axis limits (defaults `-10` and `20`).

# Returns
The plot created by `contour`, with the other series added in place.
"""
function plothist(pomdp, hist, heading="LightDark"; tmax=80, smin=-10, smax=20)
    # True-state trajectory: drop the last entry of the state history and any
    # terminal states.
    vsh = collect(filter(s -> !isterminal(pomdp, s), state_hist(hist)[1:end-1]))
    bh = belief_hist(hist)

    # One scatter point per (step t, state s) pair with s in smin:smax.
    # The belief at step t is bh[t+1].
    pts = Int[]
    pss = Int[]
    pws = Float64[]
    for t in 0:length(bh)-1
        b = bh[t+1]
        for s in smin:smax
            w = 10.0*sqrt(pdf(b, s))
            # Enlarge tiny-but-nonzero markers to size 1 so that every state
            # with positive probability stays visible. Zero stays zero.
            if 0.0 < w < 1.0
                w = 1.0
            end
            push!(pts, t)
            push!(pss, s)
            push!(pws, w)
        end
    end

    # Background grid. The state band is light_loc ± 11 (-1..21 for the default
    # light_loc = 10), so the shading stays centred on the light for any
    # light_loc. Outside the band the black `bg_inside` presumably shows through.
    T = linspace(0.0, tmax)
    S = linspace(pomdp.light_loc - 11.0, pomdp.light_loc + 11.0)
    # White at the light (distance 0), fading to black far from it.
    inv_grays = cgrad([RGB(1.0, 1.0, 1.0), RGB(0.0, 0.0, 0.0)])
    plt = contour(T, S, (t, s) -> abs(s - pomdp.light_loc),
        bg_inside=:black,
        fill=true,
        xlim=(0, tmax),
        ylim=(smin, smax),
        color=inv_grays,
        xlabel="Time",
        ylabel="State",
        cbar=false,
        legend=:topright,
        title=@sprintf("%s (Reward: %8.2f)", heading, discounted_reward(hist))
    )
    plot!(plt, [0, tmax], [0, 0], linewidth=1, color="green", label="Goal", line=:dash)
    scatter!(plt, pts, pss, color="lightblue", label="Belief Particles", markersize=pws, marker=stroke(0.1, 0.3))
    plot!(plt, 0:length(vsh)-1, vsh, linewidth=3, color="orangered", label="Trajectory")
    return plt
end;
using ParticleFilters
# Fixed-seed RNG shared by the particle filter and this simulation.
rng = MersenneTwister(7)
p = SimpleLightDark()
# SIR particle filter with 10000 particles; reused as the belief updater by
# the later cells.
pf = SIRParticleFilter(p, 10000, rng=rng)
# Baseline: random-walk policy for at most 80 steps, starting from state 1.
# NOTE(review): the initial state distribution is passed as `initial_obs`.
# Presumably `sim` forwards it to the updater as the initial belief; confirm
# against POMDPToolbox's `sim`.
h = sim(p, updater=pf, initial_state=1, initial_obs=initial_state_distribution(p), max_steps=80, rng=rng) do b
    # Ignore the belief `b` and step -1 or +1 uniformly at random.
    return rand(rng, [-1,1])
end
plothist(p, h)
using QMDP
# Offline QMDP policy for the same problem. verbose=true prints one
# residual line per solver iteration.
solver = QMDPSolver()
policy = solve(solver, p, verbose=true);
[Iteration 1 ] residual: 100 | iteration runtime: 0.084 ms, ( 8.41E-05 s total) [Iteration 2 ] residual: 95 | iteration runtime: 0.040 ms, ( 0.000124 s total) [Iteration 3 ] residual: 90.3 | iteration runtime: 0.029 ms, ( 0.000153 s total) [Iteration 4 ] residual: 85.7 | iteration runtime: 0.026 ms, ( 0.000179 s total) [Iteration 5 ] residual: 81.5 | iteration runtime: 0.026 ms, ( 0.000205 s total) [Iteration 6 ] residual: 77.4 | iteration runtime: 0.026 ms, ( 0.000231 s total) [Iteration 7 ] residual: 73.5 | iteration runtime: 0.026 ms, ( 0.000256 s total) [Iteration 8 ] residual: 8.17 | iteration runtime: 0.026 ms, ( 0.000282 s total) [Iteration 9 ] residual: 3.98 | iteration runtime: 0.026 ms, ( 0.000308 s total) [Iteration 10 ] residual: 0 | iteration runtime: 0.027 ms, ( 0.000335 s total)
# Re-seed the shared RNG before this run.
srand(rng, 14)
# History recorder reused by every later `simulate` call in this file.
# It allows at most 80 steps and always starts from state 1.
# NOTE(review): passing `initial_state` to HistoryRecorder prints a deprecation
# warning. The library now expects the initial state as the last argument to
# `simulate(...)`. Changing it here also changes the later POMCP/POMCPOW runs.
hr = HistoryRecorder(max_steps=80, initial_state=1)
# Run the QMDP policy, using the particle filter as the belief updater.
h = simulate(hr, p, policy, pf)
plothist(p, h, "QMDP")
WARNING: The initial_state argument for HistoryRecorder is deprecated. The initial state should be specified as the last argument to simulate(...).
using BasicPOMCP
using DiscreteValueIteration
srand(rng, 7)
# Per-decision search budget: tree depth, and a time limit (presumably in seconds).
# max_depth is reused by the POMCPOW cell.
max_depth = 20
max_time = 0.01
# Only referenced by the commented-out FOValue alternatives.
ro = ValueIterationSolver()
# tree_queries=typemax(Int): presumably the time limit is what ends each search.
# tree_in_info=true makes action_info return the tree (used via i[:tree]).
# NOTE(review): RandomSolver() gets no rng here, so rollouts may not use `rng`.
# The wall-clock budget makes runs non-reproducible anyway.
sol = POMCPSolver(max_depth=max_depth,
    max_time=max_time,
    c=100.0,
    tree_queries=typemax(Int),
    # estimate_value=FOValue(ro),
    estimate_value=FORollout(RandomSolver()),
    rng=rng,
    tree_in_info=true
)
planner = solve(sol, p)
h = simulate(hr, p, planner, pf)
# Label the plot with the planner name, as the QMDP cell does.
plothist(p, h, "POMCP")
using D3Trees
# Query the POMCP planner once from the initial belief and show the search
# tree it returns in `i[:tree]`. The tree is available because the solver
# was built with tree_in_info=true.
b = initial_state_distribution(p)
a, i = action_info(planner, b)
# init_expand=2: presumably the number of levels shown expanded at first.
D3Tree(i[:tree], init_expand=2)
Attempting to display the tree. If the tree is large, this may take some time.
Note: D3Trees.jl requires an internet connection. If no tree appears, please check your connection. To help fix this, please see this issue. You may also diagnose errors with the javascript console (Ctrl-Shift-J in chrome).
using POMCPOW
srand(rng, 7)
# Same 0.01 time budget as the POMCP run; max_depth comes from that cell.
max_time=0.01
# POMCPOW settings:
# - Action widening is turned off (enable_action_pw=false).
# - k_observation and alpha_observation presumably control observation
#   progressive widening.
# - tree_in_info=true exposes the tree for the visualisation cell.
sol = POMCPOWSolver(tree_queries=10_000_000,
    criterion=MaxUCB(90.0),
    max_depth=max_depth,
    max_time=max_time,
    enable_action_pw=false,
    k_observation=5.0,
    alpha_observation=1/15.0,
    # estimate_value=FOValue(ro),
    estimate_value=FORollout(RandomSolver()),
    check_repeat_obs=false,
    tree_in_info=true,
    rng=rng
)
planner = solve(sol, p)
h = simulate(hr, p, planner, pf)
# Label the plot with the planner name, as the QMDP cell does.
plothist(p, h, "POMCPOW")
using D3Trees
# Query the POMCPOW planner once from the initial belief and show the search
# tree it returns in `i[:tree]`. The tree is available because the solver
# was built with tree_in_info=true.
b = initial_state_distribution(p)
a, i = action_info(planner, b)
# init_expand=2: presumably the number of levels shown expanded at first.
D3Tree(i[:tree], init_expand=2)
Attempting to display the tree. If the tree is large, this may take some time.
Note: D3Trees.jl requires an internet connection. If no tree appears, please check your connection. To help fix this, please see this issue. You may also diagnose errors with the javascript console (Ctrl-Shift-J in chrome).