importall POMDPs
using POMDPToolbox
using Distributions
using Parameters
using Plots
using StaticArrays
# Parameters of a one-dimensional light-dark problem with integer states,
# integer actions and real-valued observations.
@with_kw struct SimpleLightDark <: POMDPs.POMDP{Int,Int,Float64}
    # Discount factor reported by `discount`.
    discount::Float64       = 0.95
    # Reward for taking action 0 while in state 0.
    correct_r::Float64      = 100.0
    # Reward for taking action 0 in any other state.
    incorrect_r::Float64    = -100.0
    # State at which the observation noise is smallest.
    light_loc::Int          = 10
    # Non-terminal states are the integers in -radius:radius.
    radius::Int             = 60
end
"""
    discount(p::SimpleLightDark)

Return the discount factor stored in `p`.
"""
function discount(p::SimpleLightDark)
    return p.discount
end

"""
    isterminal(p::SimpleLightDark, s::Number)

Return `true` when `s` is not one of the integers in `-p.radius:p.radius`.
"""
function isterminal(p::SimpleLightDark, s::Number)
    return !in(s, -p.radius:p.radius)
end
# Every action available in the problem. `transition` adds a non-zero action to
# the current state and treats 0 as the move to the terminal state.
const ACTIONS = [-10, -1, 0, 1, 10]

# Lookup table from an action value to its 1-based position in `ACTIONS`.
const ACTION_INDS = Dict(a => i for (i, a) in enumerate(ACTIONS))

"""
    actions(p::SimpleLightDark)

Return the shared action list `ACTIONS`.
"""
function actions(p::SimpleLightDark)
    return ACTIONS
end

"""
    n_actions(p::SimpleLightDark)

Return the number of entries in `actions(p)`.
"""
function n_actions(p::SimpleLightDark)
    return length(actions(p))
end

"""
    action_index(p::SimpleLightDark, a::Int)

Return the 1-based position of action `a`, looked up in `ACTION_INDS`.
"""
function action_index(p::SimpleLightDark, a::Int)
    return ACTION_INDS[a]
end
"""
    states(p::SimpleLightDark)

Return the integer range running from `-p.radius` up to `p.radius + 1`.
"""
function states(p::SimpleLightDark)
    # The upper bound is `p.radius + 1` (`+` binds tighter than `:`); that extra
    # state is the one `transition` moves to when action 0 is taken.
    return (-p.radius):(p.radius + 1)
end

"""
    n_states(p::SimpleLightDark)

Return the number of entries in `states(p)`.
"""
function n_states(p::SimpleLightDark)
    return length(states(p))
end

"""
    state_index(p::SimpleLightDark, s::Int)

Return the 1-based index of state `s`, so that `-p.radius` maps to 1.
"""
function state_index(p::SimpleLightDark, s::Int)
    offset = p.radius + 1
    return s + offset
end
"""
    transition(p::SimpleLightDark, s::Int, a::Int)

Return a `SparseCat` that puts probability 1.0 on a single successor state.

Action 0 always leads to state `p.radius + 1`; any other action moves to
`s + a`, clamped to the interval `[-p.radius, p.radius]`.
"""
function transition(p::SimpleLightDark, s::Int, a::Int)
    next_state = a == 0 ? p.radius + 1 : clamp(s + a, -p.radius, p.radius)
    return SparseCat(SVector(next_state), SVector(1.0))
end
"""
    observation(p::SimpleLightDark, sp)

Return a `Normal` distribution centred on `sp` whose second parameter grows with
the distance between `sp` and `p.light_loc`.
"""
function observation(p::SimpleLightDark, sp)
    # The small constant keeps the spread strictly positive at the light itself.
    noise_std = abs(sp - p.light_loc) + 0.0001
    return Normal(sp, noise_std)
end
"""
    reward(p::SimpleLightDark, s, a)

Return the reward for taking action `a` in state `s`.

Every non-zero action costs `-1.0`. Action 0 yields `p.correct_r` when `s == 0`
and `p.incorrect_r` otherwise.
"""
function reward(p::SimpleLightDark, s, a)
    if a != 0
        return -1.0
    end
    return s == 0 ? p.correct_r : p.incorrect_r
end
"""
    initial_state_distribution(p::SimpleLightDark)

Return a `SparseCat` with equal weights over the states
`div(-p.radius, 2):div(p.radius, 2)`.
"""
function initial_state_distribution(p::SimpleLightDark)
    half = div(p.radius, 2)
    n = 2 * half + 1
    # Every state in the support receives the same probability 1/n.
    weights = fill(1.0 / n, n)
    return SparseCat(div(-p.radius, 2):half, weights)
end;
using Plots
"""
    plothist(pomdp, hist, heading="LightDark")

Build and return a plot of the simulation history `hist`.

The figure shows a filled contour of the distance to `pomdp.light_loc`, a dashed
"Goal" line at state 0, one marker per (time, state) pair sized from the belief
pdf at that state, and the trajectory of visited non-terminal states. The title
combines `heading` with `discounted_reward(hist)`.
"""
function plothist(pomdp, hist, heading="LightDark")
    tmax = 80
    smin = -10
    smax = 20

    # Visited states, leaving out the last recorded entry and any terminal state.
    visited = state_hist(hist)[1:end-1]
    trajectory = collect(filter(s -> !isterminal(pomdp, s), visited))
    beliefs = belief_hist(hist)

    # One marker per (time step, state) pair inside the plotted state window.
    times = Int[]
    locations = Int[]
    sizes = Float64[]
    for (k, b) in enumerate(beliefs)
        for s in smin:smax
            w = 10.0 * sqrt(pdf(b, s))
            push!(times, k - 1)
            push!(locations, s)
            # Sizes strictly between 0 and 1 are raised to 1.0.
            push!(sizes, (0.0 < w < 1.0) ? 1.0 : w)
        end
    end

    T = linspace(0.0, tmax)
    S = linspace(-1.0, 21.0)
    white_to_black = cgrad([RGB(1.0, 1.0, 1.0), RGB(0.0, 0.0, 0.0)])
    title_text = @sprintf("%s (Reward: %8.2f)", heading, discounted_reward(hist))

    plt = contour(T, S, (t, s) -> abs(s - pomdp.light_loc),
                  bg_inside=:black,
                  fill=true,
                  xlim=(0, tmax),
                  ylim=(smin, smax),
                  color=white_to_black,
                  xlabel="Time",
                  ylabel="State",
                  cbar=false,
                  legend=:topright,
                  title=title_text
                 )
    plot!(plt, [0, tmax], [0, 0],
          linewidth=1, color="green", label="Goal", line=:dash)
    scatter!(plt, times, locations,
             color="lightblue", label="Belief Particles",
             markersize=sizes, marker=stroke(0.1, 0.3))
    plot!(plt, 0:length(trajectory)-1, trajectory,
          linewidth=3, color="orangered", label="Trajectory")
    return plt
end;
using ParticleFilters
# Demo run: simulate the problem with a particle-filter belief updater and a
# random policy, then plot the resulting history.

# Seeded generator passed to the particle filter, the simulator and the policy.
rng = MersenneTwister(7)
p = SimpleLightDark()
# NOTE(review): 10000 is presumably the number of particles - confirm against
# the ParticleFilters documentation.
pf = SIRParticleFilter(p, 10000, rng=rng)
# Simulate for at most 80 steps, starting from state 1.
# NOTE(review): the initial state distribution is passed through `initial_obs`;
# presumably it acts as the starting belief for the updater - confirm against
# the POMDPToolbox `sim` documentation.
h = sim(p, updater=pf, initial_state=1, initial_obs=initial_state_distribution(p), max_steps=80, rng=rng) do b
# The policy ignores its argument `b` and picks -1 or 1 at random.
return rand(rng, [-1,1])
end
plothist(p, h)