그로킹 심층 강화학습 중 7장 내용인 "조금 더 효율적인 방법으로 목표에 도달하기"에 대한 내용입니다.
Note: 실행을 위해 아래의 패키지들을 설치해주기 바랍니다.
#collapse
!pip install tqdm numpy scikit-learn pyglet setuptools && \
!pip install gym asciinema pandas tabulate tornado==5.* PyBullet && \
!pip install git+https://github.com/pybox2d/pybox2d#egg=Box2D && \
!pip install git+https://github.com/mimoralea/gym-bandits#egg=gym-bandits && \
!pip install git+https://github.com/mimoralea/gym-walk#egg=gym-walk && \
!pip install git+https://github.com/mimoralea/gym-aima#egg=gym-aima && \
!pip install gym[atari]
import warnings ; warnings.filterwarnings('ignore')
import itertools
import gym, gym_walk, gym_aima
import numpy as np
from tabulate import tabulate
from pprint import pprint
from tqdm import tqdm_notebook as tqdm
from mpl_toolkits.mplot3d import Axes3D
from itertools import cycle, count
import random
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
SEEDS = (12, 34, 56, 78, 90)
%matplotlib inline
plt.style.use('fivethirtyeight')
params = {
'figure.figsize': (15, 8),
'font.size': 24,
'legend.fontsize': 20,
'axes.titlesize': 28,
'axes.labelsize': 24,
'xtick.labelsize': 20,
'ytick.labelsize': 20
}
pylab.rcParams.update(params)
np.set_printoptions(suppress=True)
def value_iteration(P, gamma=1.0, theta=1e-10):
V = np.zeros(len(P), dtype=np.float64)
while True:
Q = np.zeros((len(P), len(P[0])), dtype=np.float64)
for s in range(len(P)):
for a in range(len(P[s])):
for prob, next_state, reward, done in P[s][a]:
Q[s][a] += prob * (reward + gamma * V[next_state] * (not done))
if np.max(np.abs(V - np.max(Q, axis=1))) < theta:
break
V = np.max(Q, axis=1)
pi = lambda s: {s:a for s, a in enumerate(np.argmax(Q, axis=1))}[s]
return Q, V, pi
def print_policy(pi, P, action_symbols=('<', 'v', '>', '^'), n_cols=4, title='정책:'):
print(title)
arrs = {k:v for k,v in enumerate(action_symbols)}
for s in range(len(P)):
a = pi(s)
print("| ", end="")
if np.all([done for action in P[s].values() for _, _, _, done in action]):
print("".rjust(9), end=" ")
else:
print(str(s).zfill(2), arrs[a].rjust(6), end=" ")
if (s + 1) % n_cols == 0: print("|")
def print_state_value_function(V, P, n_cols=4, prec=3, title='상태-가치 함수:'):
print(title)
for s in range(len(P)):
v = V[s]
print("| ", end="")
if np.all([done for action in P[s].values() for _, _, _, done in action]):
print("".rjust(9), end=" ")
else:
print(str(s).zfill(2), '{}'.format(np.round(v, prec)).rjust(6), end=" ")
if (s + 1) % n_cols == 0: print("|")
def print_action_value_function(Q,
optimal_Q=None,
action_symbols=('<', '>'),
prec=3,
title='행동-가치 함수:'):
vf_types=('',) if optimal_Q is None else ('', '*', 'er')
headers = ['s',] + [' '.join(i) for i in list(itertools.product(vf_types, action_symbols))]
print(title)
states = np.arange(len(Q))[..., np.newaxis]
arr = np.hstack((states, np.round(Q, prec)))
if not (optimal_Q is None):
arr = np.hstack((arr, np.round(optimal_Q, prec), np.round(optimal_Q-Q, prec)))
print(tabulate(arr, headers, tablefmt="fancy_grid"))
def get_policy_metrics(env, gamma, pi, goal_state, optimal_Q,
n_episodes=100, max_steps=200):
random.seed(123); np.random.seed(123) ; env.seed(123)
reached_goal, episode_reward, episode_regret = [], [], []
for _ in range(n_episodes):
state, done, steps = env.reset(), False, 0
episode_reward.append(0.0)
episode_regret.append(0.0)
while not done and steps < max_steps:
action = pi(state)
regret = np.max(optimal_Q[state]) - optimal_Q[state][action]
episode_regret[-1] += regret
state, reward, done, _ = env.step(action)
episode_reward[-1] += (gamma**steps * reward)
steps += 1
reached_goal.append(state == goal_state)
results = np.array((np.sum(reached_goal)/len(reached_goal)*100,
np.mean(episode_reward),
np.mean(episode_regret)))
return results
def get_metrics_from_tracks(env, gamma, goal_state, optimal_Q, pi_track, coverage=0.1):
total_samples = len(pi_track)
n_samples = int(total_samples * coverage)
samples_e = np.linspace(0, total_samples, n_samples, endpoint=True, dtype=np.int)
metrics = []
for e, pi in enumerate(tqdm(pi_track)):
if e in samples_e:
metrics.append(get_policy_metrics(
env,
gamma=gamma,
pi=lambda s: pi[s],
goal_state=goal_state,
optimal_Q=optimal_Q))
else:
metrics.append(metrics[-1])
metrics = np.array(metrics)
success_rate_ma, mean_return_ma, mean_regret_ma = np.apply_along_axis(moving_average, axis=0, arr=metrics).T
return success_rate_ma, mean_return_ma, mean_regret_ma
def rmse(x, y, dp=4):
return np.round(np.sqrt(np.mean((x - y)**2)), dp)
def moving_average(a, n=100) :
ret = np.cumsum(a, dtype=float)
ret[n:] = ret[n:] - ret[:-n]
return ret[n - 1:] / n
def plot_value_function(title, V_track, V_true=None, log=False, limit_value=0.05, limit_items=5):
np.random.seed(123)
per_col = 25
linecycler = cycle(["-","--",":","-."])
legends = []
valid_values = np.argwhere(V_track[-1] > limit_value).squeeze()
items_idxs = np.random.choice(valid_values,
min(len(valid_values), limit_items),
replace=False)
# 첫번째 참값을 뽑아냅니다.
if V_true is not None:
for i, state in enumerate(V_track.T):
if i not in items_idxs:
continue
if state[-1] < limit_value:
continue
label = 'v*({})'.format(i)
plt.axhline(y=V_true[i], color='k', linestyle='-', linewidth=1)
plt.text(int(len(V_track)*1.02), V_true[i]+.01, label)
# 이에 대한 추정치를 계산합니다.
for i, state in enumerate(V_track.T):
if i not in items_idxs:
continue
if state[-1] < limit_value:
continue
line_type = next(linecycler)
label = 'V({})'.format(i)
p, = plt.plot(state, line_type, label=label, linewidth=3)
legends.append(p)
legends.reverse()
ls = []
for loc, idx in enumerate(range(0, len(legends), per_col)):
subset = legends[idx:idx+per_col]
l = plt.legend(subset, [p.get_label() for p in subset],
loc='center right', bbox_to_anchor=(1.25, 0.5))
ls.append(l)
[plt.gca().add_artist(l) for l in ls[:-1]]
if log: plt.xscale('log')
plt.title(title)
plt.ylabel('State-value function')
plt.xlabel('Episodes (log scale)' if log else 'Episodes')
plt.show()
def plot_transition_model(T_track, episode = 0):
fig = plt.figure(figsize=(20,10))
ax = fig.add_subplot(111, projection='3d')
ax.view_init(elev=20, azim=50)
color_left = '#008fd5' # ax._get_lines.get_next_color()
color_right = '#fc4f30' #ax._get_lines.get_next_color()
left_prob = np.divide(T_track[episode][:,0].T,
T_track[episode][:,0].sum(axis=1).T).T
left_prob = np.nan_to_num(left_prob, 0)
right_prob = np.divide(T_track[episode][:,1].T,
T_track[episode][:,1].sum(axis=1).T).T
right_prob = np.nan_to_num(right_prob, 0)
for s in np.arange(9):
ax.bar3d(s+0.1, np.arange(9)+0.1, np.zeros(9),
np.zeros(9)+0.3,
np.zeros(9)+0.3,
left_prob[s],
color=color_left,
alpha=0.75,
shade=True)
ax.bar3d(s+0.1, np.arange(9)+0.1, left_prob[s],
np.zeros(9)+0.3,
np.zeros(9)+0.3,
right_prob[s],
color=color_right,
alpha=0.75,
shade=True)
ax.tick_params(axis='x', which='major', pad=10)
ax.tick_params(axis='y', which='major', pad=10)
ax.tick_params(axis='z', which='major', pad=10)
ax.xaxis.set_rotate_label(False)
ax.yaxis.set_rotate_label(False)
ax.zaxis.set_rotate_label(False)
ax.set_xticks(np.arange(9))
ax.set_yticks(np.arange(9))
plt.title('SWS learned MDP after {} episodes'.format(episode+1))
ax.set_xlabel('Initial\nstate', labelpad=75, rotation=0)
ax.set_ylabel('Landing\nstate', labelpad=75, rotation=0)
ax.set_zlabel('Transition\nprobabilities', labelpad=75, rotation=0)
left_proxy = plt.Rectangle((0, 0), 1, 1, fc=color_left)
right_proxy = plt.Rectangle((0, 0), 1, 1, fc=color_right)
plt.legend((left_proxy, right_proxy),
('Left', 'Right'),
bbox_to_anchor=(0.15, 0.9),
borderaxespad=0.)
ax.dist = 12
#plt.gcf().subplots_adjust(left=0.1, right=0.9)
plt.tight_layout()
plt.show()
def plot_model_state_sampling(planning, algo='Dyna-Q'):
fig = plt.figure(figsize=(20,10))
color_left = '#008fd5' # ax._get_lines.get_next_color()
color_right = '#fc4f30' #ax._get_lines.get_next_color()
for s in np.arange(9):
actions = planning[np.where(planning[:,0]==s)[0], 1]
left = len(actions[actions == 0])
right = len(actions[actions == 1])
plt.bar(s, right, 0.2, color=color_right)
plt.bar(s, left, 0.2, color=color_left, bottom=right)
plt.title('States samples from {}\nlearned model of SWS environment'.format(algo))
plt.xticks(range(9))
plt.xlabel('Initial states sampled', labelpad=20)
plt.ylabel('Count', labelpad=50, rotation=0)
left_proxy = plt.Rectangle((0, 0), 1, 1, fc=color_left)
right_proxy = plt.Rectangle((0, 0), 1, 1, fc=color_right)
plt.legend((left_proxy, right_proxy),
('Left', 'Right'),
bbox_to_anchor=(0.99, 1.1),
borderaxespad=0.)
#plt.gcf().subplots_adjust(left=0.1, right=0.9)
plt.tight_layout()
plt.show()
def plot_model_state_7(planning, algo='Dyna-Q'):
fig = plt.figure(figsize=(20,10))
color_left = '#008fd5' # ax._get_lines.get_next_color()
color_right = '#fc4f30' #ax._get_lines.get_next_color()
state_7 = planning[np.where(planning[:,0]==7)]
for sp in [6, 7, 8]:
actions = state_7[np.where(state_7[:,3]==sp)[0], 1]
left = len(actions[actions == 0])
right = len(actions[actions == 1])
plt.bar(sp, right, 0.2, color=color_right)
plt.bar(sp, left, 0.2, color=color_left, bottom=right)
plt.title('Next states samples by {}\nin SWS environment from state 7'.format(algo))
plt.xticks([6,7,8])
plt.xlabel('Landing states', labelpad=20)
plt.ylabel('Count', labelpad=50, rotation=0)
left_proxy = plt.Rectangle((0, 0), 1, 1, fc=color_left)
right_proxy = plt.Rectangle((0, 0), 1, 1, fc=color_right)
plt.legend((left_proxy, right_proxy),
('Left', 'Right'),
bbox_to_anchor=(0.99, 1.1),
borderaxespad=0.)
#plt.gcf().subplots_adjust(left=0.1, right=0.9)
plt.tight_layout()
plt.show()
def decay_schedule(init_value, min_value, decay_ratio, max_steps, log_start=-2, log_base=10):
decay_steps = int(max_steps * decay_ratio)
rem_steps = max_steps - decay_steps
values = np.logspace(log_start, 0, decay_steps, base=log_base, endpoint=True)[::-1]
values = (values - values.min()) / (values.max() - values.min())
values = (init_value - min_value) * values + min_value
values = np.pad(values, (0, rem_steps), 'edge')
return values
env = gym.make('SlipperyWalkSeven-v0')
init_state = env.reset()
goal_state = 8
gamma = 0.99
n_episodes = 3000
P = env.env.P
n_cols, svf_prec, err_prec, avf_prec=9, 4, 2, 3
action_symbols=('<', '>')
limit_items, limit_value = 5, 0.0
cu_limit_items, cu_limit_value, cu_episodes = 10, 0.0, 100
plt.plot(decay_schedule(0.5, 0.01, 0.5, n_episodes),
'-', linewidth=2,
label='Alpha schedule')
plt.plot(decay_schedule(1.0, 0.1, 0.9, n_episodes),
':', linewidth=2,
label='Epsilon schedule')
plt.legend(loc=1, ncol=1)
plt.title('Alpha and epsilon schedules')
plt.xlabel('Episodes')
plt.ylabel('Hyperparameter values')
plt.xticks(rotation=45)
plt.show()
optimal_Q, optimal_V, optimal_pi = value_iteration(P, gamma=gamma)
print_state_value_function(optimal_V, P, n_cols=n_cols, prec=svf_prec, title='Optimal state-value function:')
print()
print_action_value_function(optimal_Q,
None,
action_symbols=action_symbols,
prec=avf_prec,
title='Optimal action-value function:')
print()
print_policy(optimal_pi, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_op, mean_return_op, mean_regret_op = get_policy_metrics(
env, gamma=gamma, pi=optimal_pi, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_op, mean_return_op, mean_regret_op))
Optimal state-value function: | | 01 0.5637 | 02 0.763 | 03 0.8449 | 04 0.8892 | 05 0.922 | 06 0.9515 | 07 0.9806 | | Optimal action-value function: ╒═════╤═══════╤═══════╕ │ s │ < │ > │ ╞═════╪═══════╪═══════╡ │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┤ │ 1 │ 0.312 │ 0.564 │ ├─────┼───────┼───────┤ │ 2 │ 0.67 │ 0.763 │ ├─────┼───────┼───────┤ │ 3 │ 0.803 │ 0.845 │ ├─────┼───────┼───────┤ │ 4 │ 0.864 │ 0.889 │ ├─────┼───────┼───────┤ │ 5 │ 0.901 │ 0.922 │ ├─────┼───────┼───────┤ │ 6 │ 0.932 │ 0.952 │ ├─────┼───────┼───────┤ │ 7 │ 0.961 │ 0.981 │ ├─────┼───────┼───────┤ │ 8 │ 0 │ 0 │ ╘═════╧═══════╧═══════╛ 정책: | | 01 > | 02 > | 03 > | 04 > | 05 > | 06 > | 07 > | | Reaches goal 96.00%. Obtains an average return of 0.8548. Regret of 0.0000
def sarsa_lambda(env,
gamma=1.0,
init_alpha=0.5,
min_alpha=0.01,
alpha_decay_ratio=0.5,
init_epsilon=1.0,
min_epsilon=0.1,
epsilon_decay_ratio=0.9,
lambda_=0.5,
replacing_traces=True,
n_episodes=3000):
nS, nA = env.observation_space.n, env.action_space.n
pi_track = []
Q = np.zeros((nS, nA), dtype=np.float64)
Q_track = np.zeros((n_episodes, nS, nA),
dtype=np.float64)
E = np.zeros((nS, nA), dtype=np.float64)
select_action = lambda state, Q, epsilon: \
np.argmax(Q[state]) \
if np.random.random() > epsilon \
else np.random.randint(len(Q[state]))
alphas = decay_schedule(
init_alpha, min_alpha,
alpha_decay_ratio, n_episodes)
epsilons = decay_schedule(
init_epsilon, min_epsilon,
epsilon_decay_ratio, n_episodes)
for e in tqdm(range(n_episodes), leave=False):
E.fill(0)
state, done = env.reset(), False
action = select_action(state, Q, epsilons[e])
while not done:
next_state, reward, done, _ = env.step(action)
next_action = select_action(next_state, Q, epsilons[e])
td_target = reward + gamma * Q[next_state][next_action] * (not done)
td_error = td_target - Q[state][action]
if replacing_traces: E[state].fill(0)
E[state][action] = E[state][action] + 1
if replacing_traces: E.clip(0, 1, out=E)
Q = Q + alphas[e] * td_error * E
E = gamma * lambda_ * E
state, action = next_state, next_action
Q_track[e] = Q
pi_track.append(np.argmax(Q, axis=1))
V = np.max(Q, axis=1)
pi = lambda s: {s:a for s, a in enumerate(np.argmax(Q, axis=1))}[s]
return Q, V, pi, Q_track, pi_track
Q_rsls, V_rsls, Q_track_rsls = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_rsl, V_rsl, pi_rsl, Q_track_rsl, pi_track_rsl = sarsa_lambda(env, gamma=gamma, n_episodes=n_episodes)
Q_rsls.append(Q_rsl) ; V_rsls.append(V_rsl) ; Q_track_rsls.append(Q_track_rsl)
Q_rsl, V_rsl, Q_track_rsl = np.mean(Q_rsls, axis=0), np.mean(V_rsls, axis=0), np.mean(Q_track_rsls, axis=0)
del Q_rsls ; del V_rsls ; del Q_track_rsls
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
print_state_value_function(V_rsl, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Sarsa(λ) replacing:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_rsl - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_rsl, optimal_V)))
print()
print_action_value_function(Q_rsl,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Sarsa(λ) replacing action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_rsl, optimal_Q)))
print()
print_policy(pi_rsl, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_rsl, mean_return_rsl, mean_regret_rsl = get_policy_metrics(
env, gamma=gamma, pi=pi_rsl, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_rsl, mean_return_rsl, mean_regret_rsl))
State-value function found by Sarsa(λ) replacing: | | 01 0.4672 | 02 0.6985 | 03 0.8056 | 04 0.8656 | 05 0.9102 | 06 0.9436 | 07 0.9773 | | Optimal state-value function: | | 01 0.5637 | 02 0.763 | 03 0.8449 | 04 0.8892 | 05 0.922 | 06 0.9515 | 07 0.9806 | | State-value function errors: | | 01 -0.1 | 02 -0.06 | 03 -0.04 | 04 -0.02 | 05 -0.01 | 06 -0.01 | 07 -0.0 | | State-value function RMSE: 0.0419 Sarsa(λ) replacing action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╕ │ s │ < │ > │ * < │ * > │ er < │ er > │ ╞═════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╡ │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 1 │ 0.195 │ 0.467 │ 0.312 │ 0.564 │ 0.117 │ 0.097 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 2 │ 0.512 │ 0.698 │ 0.67 │ 0.763 │ 0.158 │ 0.065 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 3 │ 0.713 │ 0.806 │ 0.803 │ 0.845 │ 0.091 │ 0.039 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 4 │ 0.825 │ 0.866 │ 0.864 │ 0.889 │ 0.038 │ 0.024 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 5 │ 0.878 │ 0.91 │ 0.901 │ 0.922 │ 0.023 │ 0.012 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 6 │ 0.921 │ 0.944 │ 0.932 │ 0.952 │ 0.012 │ 0.008 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 7 │ 0.954 │ 0.977 │ 0.961 │ 0.981 │ 0.008 │ 0.003 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 8 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╛ Action-value function RMSE: 0.06 정책: | | 01 > | 02 > | 03 > | 04 > | 05 > | 06 > | 07 > | | Reaches goal 96.00%. Obtains an average return of 0.8548. Regret of 0.0000
Q_asls, V_asls, Q_track_asls = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_asl, V_asl, pi_asl, Q_track_asl, pi_track_asl = sarsa_lambda(env, gamma=gamma,
replacing_traces=False,
n_episodes=n_episodes)
Q_asls.append(Q_asl) ; V_asls.append(V_asl) ; Q_track_asls.append(Q_track_asl)
Q_asl, V_asl, Q_track_asl = np.mean(Q_asls, axis=0), np.mean(V_asls, axis=0), np.mean(Q_track_asls, axis=0)
del Q_asls ; del V_asls ; del Q_track_asls
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
print_state_value_function(V_asl, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Sarsa(λ) accumulating:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_asl - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_asl, optimal_V)))
print()
print_action_value_function(Q_asl,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Sarsa(λ) accumulating action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_asl, optimal_Q)))
print()
print_policy(pi_asl, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_asl, mean_return_asl, mean_regret_asl = get_policy_metrics(
env, gamma=gamma, pi=pi_asl, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_asl, mean_return_asl, mean_regret_asl))
State-value function found by Sarsa(λ) accumulating: | | 01 0.4814 | 02 0.7085 | 03 0.8168 | 04 0.8683 | 05 0.9082 | 06 0.9443 | 07 0.9783 | | Optimal state-value function: | | 01 0.5637 | 02 0.763 | 03 0.8449 | 04 0.8892 | 05 0.922 | 06 0.9515 | 07 0.9806 | | State-value function errors: | | 01 -0.08 | 02 -0.05 | 03 -0.03 | 04 -0.02 | 05 -0.01 | 06 -0.01 | 07 -0.0 | | State-value function RMSE: 0.0353 Sarsa(λ) accumulating action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╕ │ s │ < │ > │ * < │ * > │ er < │ er > │ ╞═════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╡ │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 1 │ 0.207 │ 0.481 │ 0.312 │ 0.564 │ 0.105 │ 0.082 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 2 │ 0.539 │ 0.709 │ 0.67 │ 0.763 │ 0.131 │ 0.055 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 3 │ 0.727 │ 0.817 │ 0.803 │ 0.845 │ 0.077 │ 0.028 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 4 │ 0.829 │ 0.868 │ 0.864 │ 0.889 │ 0.035 │ 0.021 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 5 │ 0.881 │ 0.908 │ 0.901 │ 0.922 │ 0.02 │ 0.014 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 6 │ 0.918 │ 0.944 │ 0.932 │ 0.952 │ 0.014 │ 0.007 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 7 │ 0.954 │ 0.978 │ 0.961 │ 0.981 │ 0.007 │ 0.002 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 8 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╛ Action-value function RMSE: 0.0511 정책: | | 01 > | 02 > | 03 > | 04 > | 05 > | 06 > | 07 > | | Reaches goal 96.00%. Obtains an average return of 0.8548. Regret of 0.0000
def q_lambda(env,
gamma=1.0,
init_alpha=0.5,
min_alpha=0.01,
alpha_decay_ratio=0.5,
init_epsilon=1.0,
min_epsilon=0.1,
epsilon_decay_ratio=0.9,
lambda_=0.5,
replacing_traces=True,
n_episodes=3000):
nS, nA = env.observation_space.n, env.action_space.n
pi_track = []
Q = np.zeros((nS, nA), dtype=np.float64)
Q_track = np.zeros((n_episodes, nS, nA), dtype=np.float64)
E = np.zeros((nS, nA), dtype=np.float64)
select_action = lambda state, Q, epsilon: \
np.argmax(Q[state]) \
if np.random.random() > epsilon \
else np.random.randint(len(Q[state]))
alphas = decay_schedule(
init_alpha, min_alpha,
alpha_decay_ratio, n_episodes)
epsilons = decay_schedule(
init_epsilon, min_epsilon,
epsilon_decay_ratio, n_episodes)
for e in tqdm(range(n_episodes), leave=False):
E.fill(0)
state, done = env.reset(), False
action = select_action(state, Q, epsilons[e])
while not done:
next_state, reward, done, _ = env.step(action)
next_action = select_action(next_state, Q, epsilons[e])
next_action_is_greedy = Q[next_state][next_action] == Q[next_state].max()
td_target = reward + gamma * Q[next_state].max() * (not done)
td_error = td_target - Q[state][action]
if replacing_traces: E[state].fill(0)
E[state][action] = E[state][action] + 1
if replacing_traces: E.clip(0, 1, out=E)
Q = Q + alphas[e] * td_error * E
if next_action_is_greedy:
E = gamma * lambda_ * E
else:
E.fill(0)
state, action = next_state, next_action
Q_track[e] = Q
pi_track.append(np.argmax(Q, axis=1))
V = np.max(Q, axis=1)
pi = lambda s: {s:a for s, a in enumerate(np.argmax(Q, axis=1))}[s]
return Q, V, pi, Q_track, pi_track
Q_rqlls, V_rqlls, Q_track_rqlls = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_rqll, V_rqll, pi_rqll, Q_track_rqll, pi_track_rqll = q_lambda(env, gamma=gamma, n_episodes=n_episodes)
Q_rqlls.append(Q_rqll) ; V_rqlls.append(V_rqll) ; Q_track_rqlls.append(Q_track_rqll)
Q_rqll, V_rqll, Q_track_rqll = np.mean(Q_rqlls, axis=0), np.mean(V_rqlls, axis=0), np.mean(Q_track_rqlls, axis=0)
del Q_rqlls ; del V_rqlls ; del Q_track_rqlls
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
print_state_value_function(V_rqll, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Q(λ) replacing:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_rqll - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_rqll, optimal_V)))
print()
print_action_value_function(Q_rqll,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Q(λ) replacing action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_rqll, optimal_Q)))
print()
print_policy(pi_rqll, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_rqll, mean_return_rqll, mean_regret_rqll = get_policy_metrics(
env, gamma=gamma, pi=pi_rqll, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_rqll, mean_return_rqll, mean_regret_rqll))
State-value function found by Q(λ) replacing: | | 01 0.5641 | 02 0.7718 | 03 0.8443 | 04 0.8878 | 05 0.9231 | 06 0.9537 | 07 0.9817 | | Optimal state-value function: | | 01 0.5637 | 02 0.763 | 03 0.8449 | 04 0.8892 | 05 0.922 | 06 0.9515 | 07 0.9806 | | State-value function errors: | | 01 0.0 | 02 0.01 | 03 -0.0 | 04 -0.0 | 05 0.0 | 06 0.0 | 07 0.0 | | State-value function RMSE: 0.0031 Q(λ) replacing action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╕ │ s │ < │ > │ * < │ * > │ er < │ er > │ ╞═════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╡ │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 1 │ 0.286 │ 0.564 │ 0.312 │ 0.564 │ 0.026 │ -0 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 2 │ 0.661 │ 0.772 │ 0.67 │ 0.763 │ 0.01 │ -0.009 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 3 │ 0.805 │ 0.844 │ 0.803 │ 0.845 │ -0.001 │ 0.001 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 4 │ 0.865 │ 0.888 │ 0.864 │ 0.889 │ -0.001 │ 0.001 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 5 │ 0.901 │ 0.923 │ 0.901 │ 0.922 │ 0 │ -0.001 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 6 │ 0.933 │ 0.954 │ 0.932 │ 0.952 │ -0.001 │ -0.002 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 7 │ 0.962 │ 0.982 │ 0.961 │ 0.981 │ -0.001 │ -0.001 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 8 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╛ Action-value function RMSE: 0.007 정책: | | 01 > | 02 > | 03 > | 04 > | 05 > | 06 > | 07 > | | Reaches goal 96.00%. Obtains an average return of 0.8548. Regret of 0.0000
Q_aqlls, V_aqlls, Q_track_aqlls = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_aqll, V_aqll, pi_aqll, Q_track_aqll, pi_track_aqll = q_lambda(env, gamma=gamma,
replacing_traces=False,
n_episodes=n_episodes)
Q_aqlls.append(Q_aqll) ; V_aqlls.append(V_aqll) ; Q_track_aqlls.append(Q_track_aqll)
Q_aqll, V_aqll, Q_track_aqll = np.mean(Q_aqlls, axis=0), np.mean(V_aqlls, axis=0), np.mean(Q_track_aqlls, axis=0)
del Q_aqlls ; del V_aqlls ; del Q_track_aqlls
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
print_state_value_function(V_aqll, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Q(λ) accumulating:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_aqll - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_aqll, optimal_V)))
print()
print_action_value_function(Q_aqll,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Q(λ) accumulating action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_aqll, optimal_Q)))
print()
print_policy(pi_aqll, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_aqll, mean_return_aqll, mean_regret_aqll = get_policy_metrics(
env, gamma=gamma, pi=pi_aqll, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_aqll, mean_return_aqll, mean_regret_aqll))
State-value function found by Q(λ) accumulating: | | 01 0.5853 | 02 0.7684 | 03 0.8461 | 04 0.8894 | 05 0.9223 | 06 0.952 | 07 0.9803 | | Optimal state-value function: | | 01 0.5637 | 02 0.763 | 03 0.8449 | 04 0.8892 | 05 0.922 | 06 0.9515 | 07 0.9806 | | State-value function errors: | | 01 0.02 | 02 0.01 | 03 0.0 | 04 0.0 | 05 0.0 | 06 0.0 | 07 -0.0 | | State-value function RMSE: 0.0074 Q(λ) accumulating action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╕ │ s │ < │ > │ * < │ * > │ er < │ er > │ ╞═════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╡ │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 1 │ 0.329 │ 0.585 │ 0.312 │ 0.564 │ -0.017 │ -0.022 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 2 │ 0.674 │ 0.768 │ 0.67 │ 0.763 │ -0.004 │ -0.005 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 3 │ 0.802 │ 0.846 │ 0.803 │ 0.845 │ 0.001 │ -0.001 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 4 │ 0.865 │ 0.889 │ 0.864 │ 0.889 │ -0.001 │ -0 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 5 │ 0.902 │ 0.922 │ 0.901 │ 0.922 │ -0 │ -0 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 6 │ 0.932 │ 0.952 │ 0.932 │ 0.952 │ 0 │ -0 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 7 │ 0.962 │ 0.98 │ 0.961 │ 0.981 │ -0.001 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 8 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╛ Action-value function RMSE: 0.0068 정책: | | 01 > | 02 > | 03 > | 04 > | 05 > | 06 > | 07 > | | Reaches goal 96.00%. Obtains an average return of 0.8548. Regret of 0.0000
def dyna_q(env,
gamma=1.0,
init_alpha=0.5,
min_alpha=0.01,
alpha_decay_ratio=0.5,
init_epsilon=1.0,
min_epsilon=0.1,
epsilon_decay_ratio=0.9,
n_planning=3,
n_episodes=3000):
nS, nA = env.observation_space.n, env.action_space.n
pi_track, T_track, R_track, planning_track = [], [], [], []
Q = np.zeros((nS, nA), dtype=np.float64)
T_count = np.zeros((nS, nA, nS), dtype=np.int)
R_model = np.zeros((nS, nA, nS), dtype=np.float64)
Q_track = np.zeros((n_episodes, nS, nA), dtype=np.float64)
select_action = lambda state, Q, epsilon: \
np.argmax(Q[state]) \
if np.random.random() > epsilon \
else np.random.randint(len(Q[state]))
alphas = decay_schedule(
init_alpha, min_alpha,
alpha_decay_ratio, n_episodes)
epsilons = decay_schedule(
init_epsilon, min_epsilon,
epsilon_decay_ratio, n_episodes)
for e in tqdm(range(n_episodes), leave=False):
state, done = env.reset(), False
while not done:
action = select_action(state, Q, epsilons[e])
next_state, reward, done, _ = env.step(action)
T_count[state][action][next_state] += 1
r_diff = reward - R_model[state][action][next_state]
R_model[state][action][next_state] += (r_diff / T_count[state][action][next_state])
td_target = reward + gamma * Q[next_state].max() * (not done)
td_error = td_target - Q[state][action]
Q[state][action] = Q[state][action] + alphas[e] * td_error
backup_next_state = next_state
for _ in range(n_planning):
if Q.sum() == 0: break
visited_states = np.where(np.sum(T_count, axis=(1, 2)) > 0)[0]
state = np.random.choice(visited_states)
actions_taken = np.where(np.sum(T_count[state], axis=1) > 0)[0]
action = np.random.choice(actions_taken)
probs = T_count[state][action]/T_count[state][action].sum()
next_state = np.random.choice(np.arange(nS), size=1, p=probs)[0]
reward = R_model[state][action][next_state]
planning_track.append((state, action, reward, next_state))
td_target = reward + gamma * Q[next_state].max()
td_error = td_target - Q[state][action]
Q[state][action] = Q[state][action] + alphas[e] * td_error
state = backup_next_state
T_track.append(T_count.copy())
R_track.append(R_model.copy())
Q_track[e] = Q
pi_track.append(np.argmax(Q, axis=1))
V = np.max(Q, axis=1)
pi = lambda s: {s:a for s, a in enumerate(np.argmax(Q, axis=1))}[s]
return Q, V, pi, Q_track, pi_track, T_track, R_track, np.array(planning_track)
Q_dqs, V_dqs, Q_track_dqs = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_dq, V_dq, pi_dq, Q_track_dq, pi_track_dq, T_track_dq, R_track_dq, planning_dq = dyna_q(
env, gamma=gamma, n_episodes=n_episodes)
Q_dqs.append(Q_dq) ; V_dqs.append(V_dq) ; Q_track_dqs.append(Q_track_dq)
Q_dq, V_dq, Q_track_dq = np.mean(Q_dqs, axis=0), np.mean(V_dqs, axis=0), np.mean(Q_track_dqs, axis=0)
del Q_dqs ; del V_dqs ; del Q_track_dqs
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
print_state_value_function(V_dq, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Dyna-Q:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_dq - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_dq, optimal_V)))
print()
print_action_value_function(Q_dq,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Dyna-Q action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_dq, optimal_Q)))
print()
print_policy(pi_dq, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_dq, mean_return_dq, mean_regret_dq = get_policy_metrics(
env, gamma=gamma, pi=pi_dq, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_dq, mean_return_dq, mean_regret_dq))
State-value function found by Dyna-Q: | | 01 0.5576 | 02 0.7725 | 03 0.8452 | 04 0.8896 | 05 0.9212 | 06 0.9515 | 07 0.9821 | | Optimal state-value function: | | 01 0.5637 | 02 0.763 | 03 0.8449 | 04 0.8892 | 05 0.922 | 06 0.9515 | 07 0.9806 | | State-value function errors: | | 01 -0.01 | 02 0.01 | 03 0.0 | 04 0.0 | 05 -0.0 | 06 -0.0 | 07 0.0 | | State-value function RMSE: 0.0038 Dyna-Q action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╕ │ s │ < │ > │ * < │ * > │ er < │ er > │ ╞═════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╡ │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 1 │ 0.299 │ 0.558 │ 0.312 │ 0.564 │ 0.013 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 2 │ 0.685 │ 0.773 │ 0.67 │ 0.763 │ -0.015 │ -0.01 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 3 │ 0.806 │ 0.845 │ 0.803 │ 0.845 │ -0.003 │ -0 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 4 │ 0.862 │ 0.89 │ 0.864 │ 0.889 │ 0.001 │ -0 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 5 │ 0.901 │ 0.921 │ 0.901 │ 0.922 │ 0 │ 0.001 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 6 │ 0.933 │ 0.951 │ 0.932 │ 0.952 │ -0.001 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 7 │ 0.962 │ 0.982 │ 0.961 │ 0.981 │ -0.001 │ -0.002 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 8 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╛ Action-value function RMSE: 0.0054 정책: | | 01 > | 02 > | 03 > | 04 > | 05 > | 06 > | 07 > | | Reaches goal 96.00%. Obtains an average return of 0.8548. Regret of 0.0000
plot_transition_model(T_track_dq, episode=0)
plot_transition_model(T_track_dq, episode=9)
plot_transition_model(T_track_dq, episode=99)
plot_transition_model(T_track_dq, episode=len(T_track_dq)-1)
plot_model_state_sampling(planning_dq, algo='Dyna-Q')
plot_model_state_7(planning_dq, algo='Dyna-Q')
def trajectory_sampling(env,
gamma=1.0,
init_alpha=0.5,
min_alpha=0.01,
alpha_decay_ratio=0.5,
init_epsilon=1.0,
min_epsilon=0.1,
epsilon_decay_ratio=0.9,
max_trajectory_depth=100,
planning_freq=5,
greedy_planning=True,
n_episodes=3000):
nS, nA = env.observation_space.n, env.action_space.n
pi_track, T_track, R_track, planning_track = [], [], [], []
Q = np.zeros((nS, nA), dtype=np.float64)
T_count = np.zeros((nS, nA, nS), dtype=np.int)
R_model = np.zeros((nS, nA, nS), dtype=np.float64)
Q_track = np.zeros((n_episodes, nS, nA), dtype=np.float64)
select_action = lambda state, Q, epsilon: \
np.argmax(Q[state]) \
if np.random.random() > epsilon \
else np.random.randint(len(Q[state]))
alphas = decay_schedule(
init_alpha, min_alpha,
alpha_decay_ratio, n_episodes)
epsilons = decay_schedule(
init_epsilon, min_epsilon,
epsilon_decay_ratio, n_episodes)
for e in tqdm(range(n_episodes), leave=False):
state, done = env.reset(), False
while not done:
action = select_action(state, Q, epsilons[e])
next_state, reward, done, _ = env.step(action)
T_count[state][action][next_state] += 1
r_diff = reward - R_model[state][action][next_state]
R_model[state][action][next_state] += (r_diff / T_count[state][action][next_state])
td_target = reward + gamma * Q[next_state].max() * (not done)
td_error = td_target - Q[state][action]
Q[state][action] = Q[state][action] + alphas[e] * td_error
backup_next_state = next_state
if e % planning_freq == 0:
for _ in range(max_trajectory_depth):
if Q.sum() == 0: break
action = Q[state].argmax() if greedy_planning else \
select_action(state, Q, epsilons[e])
if not T_count[state][action].sum(): break
probs = T_count[state][action]/T_count[state][action].sum()
next_state = np.random.choice(np.arange(nS), size=1, p=probs)[0]
reward = R_model[state][action][next_state]
planning_track.append((state, action, reward, next_state))
td_target = reward + gamma * Q[next_state].max()
td_error = td_target - Q[state][action]
Q[state][action] = Q[state][action] + alphas[e] * td_error
state = next_state
state = backup_next_state
T_track.append(T_count.copy())
R_track.append(R_model.copy())
Q_track[e] = Q
pi_track.append(np.argmax(Q, axis=1))
V = np.max(Q, axis=1)
pi = lambda s: {s:a for s, a in enumerate(np.argmax(Q, axis=1))}[s]
return Q, V, pi, Q_track, pi_track, T_track, R_track, np.array(planning_track)
Q_tss, V_tss, Q_track_tss = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_ts, V_ts, pi_ts, Q_track_ts, pi_track_ts, T_track_ts, R_track_ts, planning_ts = trajectory_sampling(
env, gamma=gamma, n_episodes=n_episodes)
Q_tss.append(Q_ts) ; V_tss.append(V_ts) ; Q_track_tss.append(Q_track_ts)
Q_ts, V_ts, Q_track_ts = np.mean(Q_tss, axis=0), np.mean(V_tss, axis=0), np.mean(Q_track_tss, axis=0)
del Q_tss ; del V_tss ; del Q_track_tss
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
0%| | 0/3000 [00:00<?, ?it/s]
print_state_value_function(V_ts, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Trajectory Sampling:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_ts - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_ts, optimal_V)))
print()
print_action_value_function(Q_ts,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Trajectory Sampling action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_ts, optimal_Q)))
print()
print_policy(pi_ts, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_ts, mean_return_ts, mean_regret_ts = get_policy_metrics(
env, gamma=gamma, pi=pi_ts, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_ts, mean_return_ts, mean_regret_ts))
State-value function found by Trajectory Sampling: | | 01 0.562 | 02 0.7616 | 03 0.8434 | 04 0.8869 | 05 0.9219 | 06 0.9515 | 07 0.981 | | Optimal state-value function: | | 01 0.5637 | 02 0.763 | 03 0.8449 | 04 0.8892 | 05 0.922 | 06 0.9515 | 07 0.9806 | | State-value function errors: | | 01 -0.0 | 02 -0.0 | 03 -0.0 | 04 -0.0 | 05 -0.0 | 06 -0.0 | 07 0.0 | | State-value function RMSE: 0.0012 Trajectory Sampling action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╕ │ s │ < │ > │ * < │ * > │ er < │ er > │ ╞═════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╡ │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 1 │ 0.302 │ 0.562 │ 0.312 │ 0.564 │ 0.01 │ 0.002 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 2 │ 0.665 │ 0.762 │ 0.67 │ 0.763 │ 0.005 │ 0.001 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 3 │ 0.802 │ 0.843 │ 0.803 │ 0.845 │ 0.001 │ 0.002 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 4 │ 0.865 │ 0.887 │ 0.864 │ 0.889 │ -0.002 │ 0.002 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 5 │ 0.903 │ 0.922 │ 0.901 │ 0.922 │ -0.001 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 6 │ 0.934 │ 0.952 │ 0.932 │ 0.952 │ -0.002 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 7 │ 0.962 │ 0.981 │ 0.961 │ 0.981 │ -0.001 │ -0 │ ├─────┼───────┼───────┼───────┼───────┼────────┼────────┤ │ 8 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╛ Action-value function RMSE: 0.0029 정책: | | 01 > | 02 > | 03 > | 04 > | 05 > | 06 > | 07 > | | Reaches goal 96.00%. Obtains an average return of 0.8548. Regret of 0.0000
plot_model_state_sampling(planning_ts, algo='Trajectory Sampling')
plot_model_state_7(planning_ts, algo='Trajectory Sampling')
plot_value_function(
'Sarsa(λ) replacing estimates through time vs. true values',
np.max(Q_track_rsl, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Sarsa(λ) replacing estimates through time vs. true values (log scale)',
np.max(Q_track_rsl, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Sarsa(λ) replacing estimates through time (close up)',
np.max(Q_track_rsl, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
plot_value_function(
'Sarsa(λ) accumulating estimates through time vs. true values',
np.max(Q_track_asl, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Sarsa(λ) accumulating estimates through time vs. true values (log scale)',
np.max(Q_track_asl, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Sarsa(λ) accumulating estimates through time (close up)',
np.max(Q_track_asl, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
plot_value_function(
'Q(λ) replacing estimates through time vs. true values',
np.max(Q_track_rqll, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Q(λ) replacing estimates through time vs. true values (log scale)',
np.max(Q_track_rqll, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Q(λ) replacing estimates through time (close up)',
np.max(Q_track_rqll, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
plot_value_function(
'Q(λ) accumulating estimates through time vs. true values',
np.max(Q_track_aqll, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Q(λ) accumulating estimates through time vs. true values (log scale)',
np.max(Q_track_aqll, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Q(λ) accumulating estimates through time (close up)',
np.max(Q_track_aqll, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
plot_value_function(
'Dyna-Q estimates through time vs. true values',
np.max(Q_track_dq, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Dyna-Q estimates through time vs. true values (log scale)',
np.max(Q_track_dq, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Dyna-Q estimates through time (close up)',
np.max(Q_track_dq, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
plot_value_function(
'Trajectory Sampling estimates through time vs. true values',
np.max(Q_track_ts, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Trajectory Sampling estimates through time vs. true values (log scale)',
np.max(Q_track_ts, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Trajectory Sampling estimates through time (close up)',
np.max(Q_track_ts, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
rsl_success_rate_ma, rsl_mean_return_ma, rsl_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_rsl)
0%| | 0/3000 [00:00<?, ?it/s]
asl_success_rate_ma, asl_mean_return_ma, asl_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_asl)
0%| | 0/3000 [00:00<?, ?it/s]
rqll_success_rate_ma, rqll_mean_return_ma, rqll_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_rqll)
0%| | 0/3000 [00:00<?, ?it/s]
aqll_success_rate_ma, aqll_mean_return_ma, aqll_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_aqll)
0%| | 0/3000 [00:00<?, ?it/s]
dq_success_rate_ma, dq_mean_return_ma, dq_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_dq)
0%| | 0/3000 [00:00<?, ?it/s]
ts_success_rate_ma, ts_mean_return_ma, ts_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_ts)
0%| | 0/3000 [00:00<?, ?it/s]
plt.axhline(y=success_rate_op, color='k', linestyle='-', linewidth=1)
plt.text(int(len(rsl_success_rate_ma)*1.02), success_rate_op*1.01, 'π*')
plt.plot(rsl_success_rate_ma, '-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(asl_success_rate_ma, '--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(rqll_success_rate_ma, ':', linewidth=2, label='Q(λ) replacing')
plt.plot(aqll_success_rate_ma, '-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(dq_success_rate_ma, '-', linewidth=2, label='Dyna-Q')
plt.plot(ts_success_rate_ma, '--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=4, ncol=1)
plt.title('Policy success rate (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Success rate %')
plt.ylim(-1, 101)
plt.xticks(rotation=45)
plt.show()
plt.axhline(y=mean_return_op, color='k', linestyle='-', linewidth=1)
plt.text(int(len(rsl_mean_return_ma)*1.02), mean_return_op*1.01, 'π*')
plt.plot(rsl_mean_return_ma, '-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(asl_mean_return_ma, '--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(rqll_mean_return_ma, ':', linewidth=2, label='Q(λ) replacing')
plt.plot(aqll_mean_return_ma, '-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(dq_mean_return_ma, '-', linewidth=2, label='Dyna-Q')
plt.plot(ts_mean_return_ma, '--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=4, ncol=1)
plt.title('Policy episode return (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Return (Gt:T)')
plt.xticks(rotation=45)
plt.show()
plt.plot(rsl_mean_regret_ma, '-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(asl_mean_regret_ma, '--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(rqll_mean_regret_ma, ':', linewidth=2, label='Q(λ) replacing')
plt.plot(aqll_mean_regret_ma, '-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(dq_mean_regret_ma, '-', linewidth=2, label='Dyna-Q')
plt.plot(ts_mean_regret_ma, '--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=1, ncol=1)
plt.title('Policy episode regret (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Regret (q* - Q)')
plt.xticks(rotation=45)
plt.show()
plt.axhline(y=optimal_V[init_state], color='k', linestyle='-', linewidth=1)
plt.text(int(len(Q_track_rsl)*1.05), optimal_V[init_state]+.01, 'v*({})'.format(init_state))
plt.plot(moving_average(np.max(Q_track_rsl, axis=2).T[init_state]),
'-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(moving_average(np.max(Q_track_asl, axis=2).T[init_state]),
'--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(moving_average(np.max(Q_track_rqll, axis=2).T[init_state]),
':', linewidth=2, label='Q(λ) replacing')
plt.plot(moving_average(np.max(Q_track_aqll, axis=2).T[init_state]),
'-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(moving_average(np.max(Q_track_dq, axis=2).T[init_state]),
'-', linewidth=2, label='Dyna-Q')
plt.plot(moving_average(np.max(Q_track_ts, axis=2).T[init_state]),
'--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=4, ncol=1)
plt.title('Estimated expected return (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Estimated value of initial state V({})'.format(init_state))
plt.xticks(rotation=45)
plt.show()
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_rsl, axis=2) - optimal_V), axis=1)),
'-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_asl, axis=2) - optimal_V), axis=1)),
'--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_rqll, axis=2) - optimal_V), axis=1)),
':', linewidth=2, label='Q(λ) replacing')
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_aqll, axis=2) - optimal_V), axis=1)),
'-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_dq, axis=2) - optimal_V), axis=1)),
'-', linewidth=2, label='Dyna-Q')
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_ts, axis=2) - optimal_V), axis=1)),
'--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=1, ncol=1)
plt.title('State-value function estimation error (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Mean Absolute Error MAE(V, v*)')
plt.xticks(rotation=45)
plt.show()
plt.plot(moving_average(np.mean(np.abs(Q_track_rsl - optimal_Q), axis=(1,2))),
'-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(moving_average(np.mean(np.abs(Q_track_asl - optimal_Q), axis=(1,2))),
'--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(moving_average(np.mean(np.abs(Q_track_rqll - optimal_Q), axis=(1,2))),
':', linewidth=2, label='Q(λ) replacing')
plt.plot(moving_average(np.mean(np.abs(Q_track_aqll - optimal_Q), axis=(1,2))),
'-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(moving_average(np.mean(np.abs(Q_track_dq - optimal_Q), axis=(1,2))),
'-', linewidth=2, label='Dyna-Q')
plt.plot(moving_average(np.mean(np.abs(Q_track_ts - optimal_Q), axis=(1,2))),
'--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=1, ncol=1)
plt.title('Action-value function estimation error (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Mean Absolute Error MAE(Q, q*)')
plt.xticks(rotation=45)
plt.show()
env = gym.make('FrozenLake-v0')
init_state = env.reset()
goal_state = 15
gamma = 0.99
n_episodes = 10000
P = env.env.P
n_cols, svf_prec, err_prec, avf_prec=4, 4, 2, 3
action_symbols=('<', 'v', '>', '^')
limit_items, limit_value = 5, 0.0
cu_limit_items, cu_limit_value, cu_episodes = 10, 0.01, 2000
plt.plot(decay_schedule(0.5, 0.01, 0.5, n_episodes),
'-', linewidth=2,
label='Alpha schedule')
plt.plot(decay_schedule(1.0, 0.1, 0.9, n_episodes),
':', linewidth=2,
label='Epsilon schedule')
plt.legend(loc=1, ncol=1)
plt.title('Alpha and epsilon schedules')
plt.xlabel('Episodes')
plt.ylabel('Hyperparameter values')
plt.xticks(rotation=45)
plt.show()
optimal_Q, optimal_V, optimal_pi = value_iteration(P, gamma=gamma)
print_state_value_function(optimal_V, P, n_cols=n_cols, prec=svf_prec, title='Optimal state-value function:')
print()
print_action_value_function(optimal_Q,
None,
action_symbols=action_symbols,
prec=avf_prec,
title='Optimal action-value function:')
print()
print_policy(optimal_pi, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_op, mean_return_op, mean_regret_op = get_policy_metrics(
env, gamma=gamma, pi=optimal_pi, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_op, mean_return_op, mean_regret_op))
Optimal state-value function: | 00 0.542 | 01 0.4988 | 02 0.4707 | 03 0.4569 | | 04 0.5585 | | 06 0.3583 | | | 08 0.5918 | 09 0.6431 | 10 0.6152 | | | | 13 0.7417 | 14 0.8628 | | Optimal action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╕ │ s │ < │ v │ > │ ^ │ ╞═════╪═══════╪═══════╪═══════╪═══════╡ │ 0 │ 0.542 │ 0.528 │ 0.528 │ 0.522 │ ├─────┼───────┼───────┼───────┼───────┤ │ 1 │ 0.343 │ 0.334 │ 0.32 │ 0.499 │ ├─────┼───────┼───────┼───────┼───────┤ │ 2 │ 0.438 │ 0.434 │ 0.424 │ 0.471 │ ├─────┼───────┼───────┼───────┼───────┤ │ 3 │ 0.306 │ 0.306 │ 0.302 │ 0.457 │ ├─────┼───────┼───────┼───────┼───────┤ │ 4 │ 0.558 │ 0.38 │ 0.374 │ 0.363 │ ├─────┼───────┼───────┼───────┼───────┤ │ 5 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┤ │ 6 │ 0.358 │ 0.203 │ 0.358 │ 0.155 │ ├─────┼───────┼───────┼───────┼───────┤ │ 7 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┤ │ 8 │ 0.38 │ 0.408 │ 0.397 │ 0.592 │ ├─────┼───────┼───────┼───────┼───────┤ │ 9 │ 0.44 │ 0.643 │ 0.448 │ 0.398 │ ├─────┼───────┼───────┼───────┼───────┤ │ 10 │ 0.615 │ 0.497 │ 0.403 │ 0.33 │ ├─────┼───────┼───────┼───────┼───────┤ │ 11 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┤ │ 12 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┤ │ 13 │ 0.457 │ 0.53 │ 0.742 │ 0.497 │ ├─────┼───────┼───────┼───────┼───────┤ │ 14 │ 0.733 │ 0.863 │ 0.821 │ 0.781 │ ├─────┼───────┼───────┼───────┼───────┤ │ 15 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╛ 정책: | 00 < | 01 ^ | 02 ^ | 03 ^ | | 04 < | | 06 < | | | 08 ^ | 09 v | 10 < | | | | 13 > | 14 v | | Reaches goal 74.00%. Obtains an average return of 0.5116. Regret of 0.0000
Q_rsls, V_rsls, Q_track_rsls = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_rsl, V_rsl, pi_rsl, Q_track_rsl, pi_track_rsl = sarsa_lambda(env, gamma=gamma, n_episodes=n_episodes)
Q_rsls.append(Q_rsl) ; V_rsls.append(V_rsl) ; Q_track_rsls.append(Q_track_rsl)
Q_rsl, V_rsl, Q_track_rsl = np.mean(Q_rsls, axis=0), np.mean(V_rsls, axis=0), np.mean(Q_track_rsls, axis=0)
del Q_rsls ; del V_rsls ; del Q_track_rsls
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
print_state_value_function(V_rsl, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Sarsa(λ) replacing:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_rsl - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_rsl, optimal_V)))
print()
print_action_value_function(Q_rsl,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Sarsa(λ) replacing action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_rsl, optimal_Q)))
print()
print_policy(pi_rsl, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_rsl, mean_return_rsl, mean_regret_rsl = get_policy_metrics(
env, gamma=gamma, pi=pi_rsl, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_rsl, mean_return_rsl, mean_regret_rsl))
State-value function found by Sarsa(λ) replacing: | 00 0.2941 | 01 0.2414 | 02 0.2168 | 03 0.133 | | 04 0.3138 | | 06 0.2152 | | | 08 0.3585 | 09 0.4465 | 10 0.4496 | | | | 13 0.5839 | 14 0.7726 | | Optimal state-value function: | 00 0.542 | 01 0.4988 | 02 0.4707 | 03 0.4569 | | 04 0.5585 | | 06 0.3583 | | | 08 0.5918 | 09 0.6431 | 10 0.6152 | | | | 13 0.7417 | 14 0.8628 | | State-value function errors: | 00 -0.25 | 01 -0.26 | 02 -0.25 | 03 -0.32 | | 04 -0.24 | | 06 -0.14 | | | 08 -0.23 | 09 -0.2 | 10 -0.17 | | | | 13 -0.16 | 14 -0.09 | | State-value function RMSE: 0.1822 Sarsa(λ) replacing action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╤════════╤════════╕ │ s │ < │ v │ > │ ^ │ * < │ * v │ * > │ * ^ │ er < │ er v │ er > │ er ^ │ ╞═════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╪════════╪════════╡ │ 0 │ 0.294 │ 0.27 │ 0.271 │ 0.265 │ 0.542 │ 0.528 │ 0.528 │ 0.522 │ 0.248 │ 0.258 │ 0.256 │ 0.257 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 1 │ 0.118 │ 0.126 │ 0.105 │ 0.241 │ 0.343 │ 0.334 │ 0.32 │ 0.499 │ 0.225 │ 0.208 │ 0.215 │ 0.257 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 2 │ 0.217 │ 0.136 │ 0.139 │ 0.133 │ 0.438 │ 0.434 │ 0.424 │ 0.471 │ 0.221 │ 0.298 │ 0.285 │ 0.338 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 3 │ 0.045 │ 0.047 │ 0.036 │ 0.133 │ 0.306 │ 0.306 │ 0.302 │ 0.457 │ 0.262 │ 0.259 │ 0.265 │ 0.324 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 4 │ 0.314 │ 0.218 │ 0.209 │ 0.198 │ 0.558 │ 0.38 │ 0.374 │ 0.363 │ 0.245 │ 0.162 │ 0.165 │ 0.165 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 5 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 6 │ 0.162 │ 0.116 │ 0.202 │ 0.049 │ 0.358 │ 0.203 │ 0.358 │ 0.155 │ 0.196 │ 0.087 │ 0.156 │ 0.107 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 7 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 8 │ 0.214 │ 0.255 │ 0.252 │ 0.358 │ 0.38 │ 0.408 │ 0.397 │ 0.592 │ 0.166 │ 0.152 │ 0.144 │ 0.233 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 9 │ 0.291 │ 0.447 │ 0.326 │ 0.259 │ 0.44 │ 0.643 │ 0.448 │ 0.398 │ 0.149 │ 0.197 │ 0.122 │ 0.139 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 10 │ 0.45 │ 0.354 │ 0.289 │ 0.184 │ 0.615 │ 0.497 │ 0.403 │ 0.33 │ 0.166 │ 0.143 │ 0.114 │ 0.147 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 11 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 12 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 13 │ 0.304 │ 0.418 │ 0.584 │ 0.365 │ 0.457 │ 0.53 │ 0.742 │ 0.497 │ 0.153 │ 0.112 │ 0.158 │ 0.131 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 14 │ 0.557 │ 0.773 │ 0.713 │ 0.64 │ 0.733 │ 0.863 │ 0.821 │ 0.781 │ 0.175 │ 0.09 │ 0.108 │ 0.141 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 15 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╧════════╧════════╛ Action-value function RMSE: 0.167 정책: | 00 < | 01 ^ | 02 < | 03 ^ | | 04 < | | 06 > | | | 08 ^ | 09 v | 10 < | | | | 13 > | 14 v | | Reaches goal 70.00%. Obtains an average return of 0.4864. Regret of 0.0156
Q_asls, V_asls, Q_track_asls = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_asl, V_asl, pi_asl, Q_track_asl, pi_track_asl = sarsa_lambda(env, gamma=gamma,
replacing_traces=False,
n_episodes=n_episodes)
Q_asls.append(Q_asl) ; V_asls.append(V_asl) ; Q_track_asls.append(Q_track_asl)
Q_asl, V_asl, Q_track_asl = np.mean(Q_asls, axis=0), np.mean(V_asls, axis=0), np.mean(Q_track_asls, axis=0)
del Q_asls ; del V_asls ; del Q_track_asls
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
print_state_value_function(V_asl, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Sarsa(λ) accumulating:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_asl - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_asl, optimal_V)))
print()
print_action_value_function(Q_asl,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Sarsa(λ) accumulating action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_asl, optimal_Q)))
print()
print_policy(pi_asl, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_asl, mean_return_asl, mean_regret_asl = get_policy_metrics(
env, gamma=gamma, pi=pi_asl, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_asl, mean_return_asl, mean_regret_asl))
State-value function found by Sarsa(λ) accumulating: | 00 0.2872 | 01 0.2453 | 02 0.2138 | 03 0.1526 | | 04 0.3114 | | 06 0.2142 | | | 08 0.3617 | 09 0.4517 | 10 0.4699 | | | | 13 0.5917 | 14 0.7812 | | Optimal state-value function: | 00 0.542 | 01 0.4988 | 02 0.4707 | 03 0.4569 | | 04 0.5585 | | 06 0.3583 | | | 08 0.5918 | 09 0.6431 | 10 0.6152 | | | | 13 0.7417 | 14 0.8628 | | State-value function errors: | 00 -0.25 | 01 -0.25 | 02 -0.26 | 03 -0.3 | | 04 -0.25 | | 06 -0.14 | | | 08 -0.23 | 09 -0.19 | 10 -0.15 | | | | 13 -0.15 | 14 -0.08 | | State-value function RMSE: 0.1784 Sarsa(λ) accumulating action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╤════════╤════════╕ │ s │ < │ v │ > │ ^ │ * < │ * v │ * > │ * ^ │ er < │ er v │ er > │ er ^ │ ╞═════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╪════════╪════════╡ │ 0 │ 0.287 │ 0.274 │ 0.275 │ 0.271 │ 0.542 │ 0.528 │ 0.528 │ 0.522 │ 0.255 │ 0.253 │ 0.253 │ 0.252 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 1 │ 0.125 │ 0.127 │ 0.11 │ 0.245 │ 0.343 │ 0.334 │ 0.32 │ 0.499 │ 0.218 │ 0.207 │ 0.21 │ 0.254 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 2 │ 0.214 │ 0.153 │ 0.142 │ 0.14 │ 0.438 │ 0.434 │ 0.424 │ 0.471 │ 0.224 │ 0.28 │ 0.282 │ 0.33 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 3 │ 0.047 │ 0.048 │ 0.043 │ 0.153 │ 0.306 │ 0.306 │ 0.302 │ 0.457 │ 0.26 │ 0.258 │ 0.259 │ 0.304 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 4 │ 0.311 │ 0.213 │ 0.197 │ 0.199 │ 0.558 │ 0.38 │ 0.374 │ 0.363 │ 0.247 │ 0.166 │ 0.177 │ 0.164 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 5 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 6 │ 0.206 │ 0.125 │ 0.176 │ 0.049 │ 0.358 │ 0.203 │ 0.358 │ 0.155 │ 0.153 │ 0.078 │ 0.182 │ 0.106 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 7 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 8 │ 0.223 │ 0.264 │ 0.241 │ 0.362 │ 0.38 │ 0.408 │ 0.397 │ 0.592 │ 0.156 │ 0.144 │ 0.155 │ 0.23 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 9 │ 0.292 │ 0.452 │ 0.326 │ 0.264 │ 0.44 │ 0.643 │ 0.448 │ 0.398 │ 0.148 │ 0.191 │ 0.122 │ 0.134 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 10 │ 0.47 │ 0.35 │ 0.295 │ 0.183 │ 0.615 │ 0.497 │ 0.403 │ 0.33 │ 0.145 │ 0.147 │ 0.108 │ 0.148 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 11 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 12 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 13 │ 0.314 │ 0.424 │ 0.592 │ 0.395 │ 0.457 │ 0.53 │ 0.742 │ 0.497 │ 0.143 │ 0.106 │ 0.15 │ 0.102 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 14 │ 0.557 │ 0.781 │ 0.705 │ 0.66 │ 0.733 │ 0.863 │ 0.821 │ 0.781 │ 0.175 │ 0.082 │ 0.116 │ 0.121 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 15 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╧════════╧════════╛ Action-value function RMSE: 0.1633 정책: | 00 < | 01 ^ | 02 < | 03 ^ | | 04 < | | 06 > | | | 08 ^ | 09 v | 10 < | | | | 13 > | 14 v | | Reaches goal 70.00%. Obtains an average return of 0.4864. Regret of 0.0156
Q_rqlls, V_rqlls, Q_track_rqlls = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_rqll, V_rqll, pi_rqll, Q_track_rqll, pi_track_rqll = q_lambda(env, gamma=gamma, n_episodes=n_episodes)
Q_rqlls.append(Q_rqll) ; V_rqlls.append(V_rqll) ; Q_track_rqlls.append(Q_track_rqll)
Q_rqll, V_rqll, Q_track_rqll = np.mean(Q_rqlls, axis=0), np.mean(V_rqlls, axis=0), np.mean(Q_track_rqlls, axis=0)
del Q_rqlls ; del V_rqlls ; del Q_track_rqlls
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
print_state_value_function(V_rqll, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Q(λ) replacing:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_rqll - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_rqll, optimal_V)))
print()
print_action_value_function(Q_rqll,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Q(λ) replacing action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_rqll, optimal_Q)))
print()
print_policy(pi_rqll, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_rqll, mean_return_rqll, mean_regret_rqll = get_policy_metrics(
env, gamma=gamma, pi=pi_rqll, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_rqll, mean_return_rqll, mean_regret_rqll))
State-value function found by Q(λ) replacing: | 00 0.5205 | 01 0.4758 | 02 0.4481 | 03 0.4346 | | 04 0.5376 | | 06 0.3382 | | | 08 0.5704 | 09 0.6248 | 10 0.6006 | | | | 13 0.7198 | 14 0.8505 | | Optimal state-value function: | 00 0.542 | 01 0.4988 | 02 0.4707 | 03 0.4569 | | 04 0.5585 | | 06 0.3583 | | | 08 0.5918 | 09 0.6431 | 10 0.6152 | | | | 13 0.7417 | 14 0.8628 | | State-value function errors: | 00 -0.02 | 01 -0.02 | 02 -0.02 | 03 -0.02 | | 04 -0.02 | | 06 -0.02 | | | 08 -0.02 | 09 -0.02 | 10 -0.01 | | | | 13 -0.02 | 14 -0.01 | | State-value function RMSE: 0.0167 Q(λ) replacing action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╤════════╤════════╕ │ s │ < │ v │ > │ ^ │ * < │ * v │ * > │ * ^ │ er < │ er v │ er > │ er ^ │ ╞═════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╪════════╪════════╡ │ 0 │ 0.521 │ 0.511 │ 0.512 │ 0.505 │ 0.542 │ 0.528 │ 0.528 │ 0.522 │ 0.021 │ 0.017 │ 0.016 │ 0.017 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 1 │ 0.344 │ 0.336 │ 0.305 │ 0.476 │ 0.343 │ 0.334 │ 0.32 │ 0.499 │ -0 │ -0.002 │ 0.015 │ 0.023 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 2 │ 0.422 │ 0.426 │ 0.415 │ 0.448 │ 0.438 │ 0.434 │ 0.424 │ 0.471 │ 0.016 │ 0.008 │ 0.01 │ 0.023 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 3 │ 0.295 │ 0.296 │ 0.286 │ 0.435 │ 0.306 │ 0.306 │ 0.302 │ 0.457 │ 0.011 │ 0.01 │ 0.016 │ 0.022 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 4 │ 0.538 │ 0.348 │ 0.357 │ 0.362 │ 0.558 │ 0.38 │ 0.374 │ 0.363 │ 0.021 │ 0.032 │ 0.017 │ 0.002 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 5 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 6 │ 0.326 │ 0.185 │ 0.318 │ 0.147 │ 0.358 │ 0.203 │ 0.358 │ 0.155 │ 0.033 │ 0.018 │ 0.041 │ 0.008 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 7 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 8 │ 0.364 │ 0.402 │ 0.38 │ 0.57 │ 0.38 │ 0.408 │ 0.397 │ 0.592 │ 0.015 │ 0.006 │ 0.016 │ 0.021 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 9 │ 0.434 │ 0.625 │ 0.439 │ 0.402 │ 0.44 │ 0.643 │ 0.448 │ 0.398 │ 0.006 │ 0.018 │ 0.009 │ -0.004 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 10 │ 0.601 │ 0.51 │ 0.413 │ 0.344 │ 0.615 │ 0.497 │ 0.403 │ 0.33 │ 0.015 │ -0.014 │ -0.01 │ -0.013 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 11 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 12 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 13 │ 0.441 │ 0.517 │ 0.72 │ 0.489 │ 0.457 │ 0.53 │ 0.742 │ 0.497 │ 0.016 │ 0.012 │ 0.022 │ 0.008 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 14 │ 0.734 │ 0.851 │ 0.824 │ 0.784 │ 0.733 │ 0.863 │ 0.821 │ 0.781 │ -0.001 │ 0.012 │ -0.003 │ -0.003 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 15 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╧════════╧════════╛ Action-value function RMSE: 0.0137 정책: | 00 < | 01 ^ | 02 ^ | 03 ^ | | 04 < | | 06 < | | | 08 ^ | 09 v | 10 < | | | | 13 > | 14 v | | Reaches goal 74.00%. Obtains an average return of 0.5116. Regret of 0.0000
Q_aqlls, V_aqlls, Q_track_aqlls = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_aqll, V_aqll, pi_aqll, Q_track_aqll, pi_track_aqll = q_lambda(env, gamma=gamma,
replacing_traces=False,
n_episodes=n_episodes)
Q_aqlls.append(Q_aqll) ; V_aqlls.append(V_aqll) ; Q_track_aqlls.append(Q_track_aqll)
Q_aqll, V_aqll, Q_track_aqll = np.mean(Q_aqlls, axis=0), np.mean(V_aqlls, axis=0), np.mean(Q_track_aqlls, axis=0)
del Q_aqlls ; del V_aqlls ; del Q_track_aqlls
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
print_state_value_function(V_aqll, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Q(λ) accumulating:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_aqll - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_aqll, optimal_V)))
print()
print_action_value_function(Q_aqll,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Q(λ) accumulating action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_aqll, optimal_Q)))
print()
print_policy(pi_aqll, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_aqll, mean_return_aqll, mean_regret_aqll = get_policy_metrics(
env, gamma=gamma, pi=pi_aqll, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_aqll, mean_return_aqll, mean_regret_aqll))
State-value function found by Q(λ) accumulating: | 00 0.5208 | 01 0.4669 | 02 0.4335 | 03 0.4211 | | 04 0.5382 | | 06 0.321 | | | 08 0.5731 | 09 0.6253 | 10 0.5811 | | | | 13 0.7373 | 14 0.862 | | Optimal state-value function: | 00 0.542 | 01 0.4988 | 02 0.4707 | 03 0.4569 | | 04 0.5585 | | 06 0.3583 | | | 08 0.5918 | 09 0.6431 | 10 0.6152 | | | | 13 0.7417 | 14 0.8628 | | State-value function errors: | 00 -0.02 | 01 -0.03 | 02 -0.04 | 03 -0.04 | | 04 -0.02 | | 06 -0.04 | | | 08 -0.02 | 09 -0.02 | 10 -0.03 | | | | 13 -0.0 | 14 -0.0 | | State-value function RMSE: 0.0221 Q(λ) accumulating action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╤════════╤════════╕ │ s │ < │ v │ > │ ^ │ * < │ * v │ * > │ * ^ │ er < │ er v │ er > │ er ^ │ ╞═════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╪════════╪════════╡ │ 0 │ 0.521 │ 0.5 │ 0.497 │ 0.495 │ 0.542 │ 0.528 │ 0.528 │ 0.522 │ 0.021 │ 0.028 │ 0.03 │ 0.027 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 1 │ 0.331 │ 0.307 │ 0.302 │ 0.467 │ 0.343 │ 0.334 │ 0.32 │ 0.499 │ 0.013 │ 0.028 │ 0.018 │ 0.032 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 2 │ 0.421 │ 0.416 │ 0.41 │ 0.433 │ 0.438 │ 0.434 │ 0.424 │ 0.471 │ 0.018 │ 0.017 │ 0.014 │ 0.037 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 3 │ 0.287 │ 0.297 │ 0.294 │ 0.421 │ 0.306 │ 0.306 │ 0.302 │ 0.457 │ 0.019 │ 0.009 │ 0.008 │ 0.036 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 4 │ 0.538 │ 0.352 │ 0.348 │ 0.343 │ 0.558 │ 0.38 │ 0.374 │ 0.363 │ 0.02 │ 0.027 │ 0.026 │ 0.02 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 5 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 6 │ 0.307 │ 0.188 │ 0.318 │ 0.13 │ 0.358 │ 0.203 │ 0.358 │ 0.155 │ 0.051 │ 0.015 │ 0.041 │ 0.025 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 7 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 8 │ 0.379 │ 0.393 │ 0.387 │ 0.573 │ 0.38 │ 0.408 │ 0.397 │ 0.592 │ 0.001 │ 0.014 │ 0.01 │ 0.019 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 9 │ 0.424 │ 0.625 │ 0.419 │ 0.388 │ 0.44 │ 0.643 │ 0.448 │ 0.398 │ 0.016 │ 0.018 │ 0.029 │ 0.01 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 10 │ 0.581 │ 0.493 │ 0.396 │ 0.317 │ 0.615 │ 0.497 │ 0.403 │ 0.33 │ 0.034 │ 0.004 │ 0.007 │ 0.014 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 11 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 12 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 13 │ 0.44 │ 0.514 │ 0.737 │ 0.491 │ 0.457 │ 0.53 │ 0.742 │ 0.497 │ 0.017 │ 0.016 │ 0.004 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 14 │ 0.721 │ 0.862 │ 0.805 │ 0.764 │ 0.733 │ 0.863 │ 0.821 │ 0.781 │ 0.012 │ 0.001 │ 0.016 │ 0.017 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 15 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╧════════╧════════╛ Action-value function RMSE: 0.0183 정책: | 00 < | 01 ^ | 02 ^ | 03 ^ | | 04 < | | 06 < | | | 08 ^ | 09 v | 10 < | | | | 13 > | 14 v | | Reaches goal 74.00%. Obtains an average return of 0.5116. Regret of 0.0000
Q_dqs, V_dqs, Q_track_dqs = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_dq, V_dq, pi_dq, Q_track_dq, pi_track_dq, T_track_dq, R_track_dq, planning_dq = dyna_q(
env, gamma=gamma, n_episodes=n_episodes)
Q_dqs.append(Q_dq) ; V_dqs.append(V_dq) ; Q_track_dqs.append(Q_track_dq)
Q_dq, V_dq, Q_track_dq = np.mean(Q_dqs, axis=0), np.mean(V_dqs, axis=0), np.mean(Q_track_dqs, axis=0)
del Q_dqs ; del V_dqs ; del Q_track_dqs
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
print_state_value_function(V_dq, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Dyna-Q:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_dq - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_dq, optimal_V)))
print()
print_action_value_function(Q_dq,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Dyna-Q action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_dq, optimal_Q)))
print()
print_policy(pi_dq, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_dq, mean_return_dq, mean_regret_dq = get_policy_metrics(
env, gamma=gamma, pi=pi_dq, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_dq, mean_return_dq, mean_regret_dq))
State-value function found by Dyna-Q: | 00 0.5278 | 01 0.4863 | 02 0.4592 | 03 0.4437 | | 04 0.5448 | | 06 0.3628 | | | 08 0.579 | 09 0.6339 | 10 0.604 | | | | 13 0.7354 | 14 0.8566 | | Optimal state-value function: | 00 0.542 | 01 0.4988 | 02 0.4707 | 03 0.4569 | | 04 0.5585 | | 06 0.3583 | | | 08 0.5918 | 09 0.6431 | 10 0.6152 | | | | 13 0.7417 | 14 0.8628 | | State-value function errors: | 00 -0.01 | 01 -0.01 | 02 -0.01 | 03 -0.01 | | 04 -0.01 | | 06 0.0 | | | 08 -0.01 | 09 -0.01 | 10 -0.01 | | | | 13 -0.01 | 14 -0.01 | | State-value function RMSE: 0.0091 Dyna-Q action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╤════════╤════════╕ │ s │ < │ v │ > │ ^ │ * < │ * v │ * > │ * ^ │ er < │ er v │ er > │ er ^ │ ╞═════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╪════════╪════════╡ │ 0 │ 0.528 │ 0.514 │ 0.514 │ 0.509 │ 0.542 │ 0.528 │ 0.528 │ 0.522 │ 0.014 │ 0.013 │ 0.014 │ 0.014 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 1 │ 0.322 │ 0.335 │ 0.314 │ 0.486 │ 0.343 │ 0.334 │ 0.32 │ 0.499 │ 0.022 │ -0.001 │ 0.006 │ 0.013 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 2 │ 0.432 │ 0.427 │ 0.416 │ 0.459 │ 0.438 │ 0.434 │ 0.424 │ 0.471 │ 0.006 │ 0.007 │ 0.009 │ 0.012 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 3 │ 0.298 │ 0.303 │ 0.286 │ 0.444 │ 0.306 │ 0.306 │ 0.302 │ 0.457 │ 0.008 │ 0.003 │ 0.015 │ 0.013 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 4 │ 0.545 │ 0.371 │ 0.36 │ 0.361 │ 0.558 │ 0.38 │ 0.374 │ 0.363 │ 0.014 │ 0.009 │ 0.014 │ 0.002 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 5 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 6 │ 0.336 │ 0.238 │ 0.358 │ 0.134 │ 0.358 │ 0.203 │ 0.358 │ 0.155 │ 0.023 │ -0.035 │ 0 │ 0.022 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 7 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 8 │ 0.37 │ 0.419 │ 0.366 │ 0.579 │ 0.38 │ 0.408 │ 0.397 │ 0.592 │ 0.009 │ -0.011 │ 0.03 │ 0.013 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 9 │ 0.446 │ 0.634 │ 0.446 │ 0.401 │ 0.44 │ 0.643 │ 0.448 │ 0.398 │ -0.006 │ 0.009 │ 0.002 │ -0.002 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 10 │ 0.604 │ 0.488 │ 0.393 │ 0.348 │ 0.615 │ 0.497 │ 0.403 │ 0.33 │ 0.011 │ 0.009 │ 0.01 │ -0.017 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 11 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 12 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 13 │ 0.472 │ 0.5 │ 0.735 │ 0.509 │ 0.457 │ 0.53 │ 0.742 │ 0.497 │ -0.015 │ 0.029 │ 0.006 │ -0.012 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 14 │ 0.718 │ 0.857 │ 0.813 │ 0.786 │ 0.733 │ 0.863 │ 0.821 │ 0.781 │ 0.015 │ 0.006 │ 0.008 │ -0.005 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 15 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╧════════╧════════╛ Action-value function RMSE: 0.0115 정책: | 00 < | 01 ^ | 02 ^ | 03 ^ | | 04 < | | 06 > | | | 08 ^ | 09 v | 10 < | | | | 13 > | 14 v | | Reaches goal 72.00%. Obtains an average return of 0.4936. Regret of 0.0000
Q_tss, V_tss, Q_track_tss = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_ts, V_ts, pi_ts, Q_track_ts, pi_track_ts, T_track_ts, R_track_ts, planning_ts = trajectory_sampling(
env, gamma=gamma, n_episodes=n_episodes)
Q_tss.append(Q_ts) ; V_tss.append(V_ts) ; Q_track_tss.append(Q_track_ts)
Q_ts, V_ts, Q_track_ts = np.mean(Q_tss, axis=0), np.mean(V_tss, axis=0), np.mean(Q_track_tss, axis=0)
del Q_tss ; del V_tss ; del Q_track_tss
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
0%| | 0/10000 [00:00<?, ?it/s]
print_state_value_function(V_ts, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Trajectory Sampling:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_ts - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_ts, optimal_V)))
print()
print_action_value_function(Q_ts,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Trajectory Sampling action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_ts, optimal_Q)))
print()
print_policy(pi_ts, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_ts, mean_return_ts, mean_regret_ts = get_policy_metrics(
env, gamma=gamma, pi=pi_ts, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_ts, mean_return_ts, mean_regret_ts))
State-value function found by Trajectory Sampling: | 00 0.5377 | 01 0.4944 | 02 0.4656 | 03 0.4507 | | 04 0.5536 | | 06 0.3605 | | | 08 0.5857 | 09 0.6343 | 10 0.5992 | | | | 13 0.7337 | 14 0.8616 | | Optimal state-value function: | 00 0.542 | 01 0.4988 | 02 0.4707 | 03 0.4569 | | 04 0.5585 | | 06 0.3583 | | | 08 0.5918 | 09 0.6431 | 10 0.6152 | | | | 13 0.7417 | 14 0.8628 | | State-value function errors: | 00 -0.0 | 01 -0.0 | 02 -0.01 | 03 -0.01 | | 04 -0.0 | | 06 0.0 | | | 08 -0.01 | 09 -0.01 | 10 -0.02 | | | | 13 -0.01 | 14 -0.0 | | State-value function RMSE: 0.006 Trajectory Sampling action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╤════════╤════════╕ │ s │ < │ v │ > │ ^ │ * < │ * v │ * > │ * ^ │ er < │ er v │ er > │ er ^ │ ╞═════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╪════════╪════════╡ │ 0 │ 0.538 │ 0.522 │ 0.52 │ 0.515 │ 0.542 │ 0.528 │ 0.528 │ 0.522 │ 0.004 │ 0.006 │ 0.008 │ 0.007 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 1 │ 0.345 │ 0.324 │ 0.313 │ 0.494 │ 0.343 │ 0.334 │ 0.32 │ 0.499 │ -0.001 │ 0.01 │ 0.007 │ 0.004 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 2 │ 0.432 │ 0.426 │ 0.419 │ 0.466 │ 0.438 │ 0.434 │ 0.424 │ 0.471 │ 0.006 │ 0.008 │ 0.005 │ 0.005 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 3 │ 0.3 │ 0.29 │ 0.284 │ 0.451 │ 0.306 │ 0.306 │ 0.302 │ 0.457 │ 0.006 │ 0.016 │ 0.017 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 4 │ 0.554 │ 0.374 │ 0.382 │ 0.342 │ 0.558 │ 0.38 │ 0.374 │ 0.363 │ 0.005 │ 0.006 │ -0.008 │ 0.021 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 5 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 6 │ 0.333 │ 0.175 │ 0.335 │ 0.168 │ 0.358 │ 0.203 │ 0.358 │ 0.155 │ 0.025 │ 0.028 │ 0.023 │ -0.012 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 7 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 8 │ 0.365 │ 0.413 │ 0.404 │ 0.586 │ 0.38 │ 0.408 │ 0.397 │ 0.592 │ 0.015 │ -0.006 │ -0.008 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 9 │ 0.426 │ 0.634 │ 0.439 │ 0.373 │ 0.44 │ 0.643 │ 0.448 │ 0.398 │ 0.014 │ 0.009 │ 0.009 │ 0.025 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 10 │ 0.599 │ 0.485 │ 0.405 │ 0.331 │ 0.615 │ 0.497 │ 0.403 │ 0.33 │ 0.016 │ 0.012 │ -0.002 │ -0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 11 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 12 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 13 │ 0.461 │ 0.537 │ 0.734 │ 0.501 │ 0.457 │ 0.53 │ 0.742 │ 0.497 │ -0.004 │ -0.008 │ 0.008 │ -0.004 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 14 │ 0.73 │ 0.862 │ 0.813 │ 0.776 │ 0.733 │ 0.863 │ 0.821 │ 0.781 │ 0.002 │ 0.001 │ 0.008 │ 0.005 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 15 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╧════════╧════════╛ Action-value function RMSE: 0.0095 정책: | 00 < | 01 ^ | 02 ^ | 03 ^ | | 04 < | | 06 < | | | 08 ^ | 09 v | 10 < | | | | 13 > | 14 v | | Reaches goal 74.00%. Obtains an average return of 0.5116. Regret of 0.0000
plot_value_function(
'Sarsa(λ) replacing estimates through time vs. true values',
np.max(Q_track_rsl, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Sarsa(λ) replacing estimates through time vs. true values (log scale)',
np.max(Q_track_rsl, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Sarsa(λ) replacing estimates through time (close up)',
np.max(Q_track_rsl, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
plot_value_function(
'Sarsa(λ) accumulating estimates through time vs. true values',
np.max(Q_track_asl, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Sarsa(λ) accumulating estimates through time vs. true values (log scale)',
np.max(Q_track_asl, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Sarsa(λ) accumulating estimates through time (close up)',
np.max(Q_track_asl, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
plot_value_function(
'Q(λ) replacing estimates through time vs. true values',
np.max(Q_track_rqll, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Q(λ) replacing estimates through time vs. true values (log scale)',
np.max(Q_track_rqll, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Q(λ) replacing estimates through time (close up)',
np.max(Q_track_rqll, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
plot_value_function(
'Q(λ) accumulating estimates through time vs. true values',
np.max(Q_track_aqll, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Q(λ) accumulating estimates through time vs. true values (log scale)',
np.max(Q_track_aqll, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Q(λ) accumulating estimates through time (close up)',
np.max(Q_track_aqll, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
plot_value_function(
'Dyna-Q estimates through time vs. true values',
np.max(Q_track_dq, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Dyna-Q estimates through time vs. true values (log scale)',
np.max(Q_track_dq, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Dyna-Q estimates through time (close up)',
np.max(Q_track_dq, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
plot_value_function(
'Trajectory Sampling estimates through time vs. true values',
np.max(Q_track_ts, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Trajectory Sampling estimates through time vs. true values (log scale)',
np.max(Q_track_ts, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Trajectory Sampling estimates through time (close up)',
np.max(Q_track_ts, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
rsl_success_rate_ma, rsl_mean_return_ma, rsl_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_rsl)
0%| | 0/10000 [00:00<?, ?it/s]
asl_success_rate_ma, asl_mean_return_ma, asl_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_asl)
0%| | 0/10000 [00:00<?, ?it/s]
rqll_success_rate_ma, rqll_mean_return_ma, rqll_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_rqll)
0%| | 0/10000 [00:00<?, ?it/s]
aqll_success_rate_ma, aqll_mean_return_ma, aqll_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_aqll)
0%| | 0/10000 [00:00<?, ?it/s]
dq_success_rate_ma, dq_mean_return_ma, dq_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_dq)
0%| | 0/10000 [00:00<?, ?it/s]
ts_success_rate_ma, ts_mean_return_ma, ts_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_ts)
0%| | 0/10000 [00:00<?, ?it/s]
plt.axhline(y=success_rate_op, color='k', linestyle='-', linewidth=1)
plt.text(int(len(rsl_success_rate_ma)*1.02), success_rate_op*1.01, 'π*')
plt.plot(rsl_success_rate_ma, '-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(asl_success_rate_ma, '--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(rqll_success_rate_ma, ':', linewidth=2, label='Q(λ) replacing')
plt.plot(aqll_success_rate_ma, '-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(dq_success_rate_ma, '-', linewidth=2, label='Dyna-Q')
plt.plot(ts_success_rate_ma, '--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=4, ncol=1)
plt.title('Policy success rate (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Success rate %')
plt.ylim(-1, 101)
plt.xticks(rotation=45)
plt.show()
plt.axhline(y=mean_return_op, color='k', linestyle='-', linewidth=1)
plt.text(int(len(rsl_mean_return_ma)*1.02), mean_return_op*1.01, 'π*')
plt.plot(rsl_mean_return_ma, '-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(asl_mean_return_ma, '--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(rqll_mean_return_ma, ':', linewidth=2, label='Q(λ) replacing')
plt.plot(aqll_mean_return_ma, '-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(dq_mean_return_ma, '-', linewidth=2, label='Dyna-Q')
plt.plot(ts_mean_return_ma, '--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=4, ncol=1)
plt.title('Policy episode return (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Return (Gt:T)')
plt.xticks(rotation=45)
plt.show()
plt.plot(rsl_mean_regret_ma, '-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(asl_mean_regret_ma, '--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(rqll_mean_regret_ma, ':', linewidth=2, label='Q(λ) replacing')
plt.plot(aqll_mean_regret_ma, '-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(dq_mean_regret_ma, '-', linewidth=2, label='Dyna-Q')
plt.plot(ts_mean_regret_ma, '--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=1, ncol=1)
plt.title('Policy episode regret (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Regret (q* - Q)')
plt.xticks(rotation=45)
plt.show()
plt.axhline(y=optimal_V[init_state], color='k', linestyle='-', linewidth=1)
plt.text(int(len(Q_track_rsl)*1.05), optimal_V[init_state]+.01, 'v*({})'.format(init_state))
plt.plot(moving_average(np.max(Q_track_rsl, axis=2).T[init_state]),
'-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(moving_average(np.max(Q_track_asl, axis=2).T[init_state]),
'--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(moving_average(np.max(Q_track_rqll, axis=2).T[init_state]),
':', linewidth=2, label='Q(λ) replacing')
plt.plot(moving_average(np.max(Q_track_aqll, axis=2).T[init_state]),
'-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(moving_average(np.max(Q_track_dq, axis=2).T[init_state]),
'-', linewidth=2, label='Dyna-Q')
plt.plot(moving_average(np.max(Q_track_ts, axis=2).T[init_state]),
'--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=4, ncol=1)
plt.title('Estimated expected return (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Estimated value of initial state V({})'.format(init_state))
plt.xticks(rotation=45)
plt.show()
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_rsl, axis=2) - optimal_V), axis=1)),
'-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_asl, axis=2) - optimal_V), axis=1)),
'--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_rqll, axis=2) - optimal_V), axis=1)),
':', linewidth=2, label='Q(λ) replacing')
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_aqll, axis=2) - optimal_V), axis=1)),
'-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_dq, axis=2) - optimal_V), axis=1)),
'-', linewidth=2, label='Dyna-Q')
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_ts, axis=2) - optimal_V), axis=1)),
'--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=1, ncol=1)
plt.title('State-value function estimation error (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Mean Absolute Error MAE(V, v*)')
plt.xticks(rotation=45)
plt.show()
plt.plot(moving_average(np.mean(np.abs(Q_track_rsl - optimal_Q), axis=(1,2))),
'-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(moving_average(np.mean(np.abs(Q_track_asl - optimal_Q), axis=(1,2))),
'--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(moving_average(np.mean(np.abs(Q_track_rqll - optimal_Q), axis=(1,2))),
':', linewidth=2, label='Q(λ) replacing')
plt.plot(moving_average(np.mean(np.abs(Q_track_aqll - optimal_Q), axis=(1,2))),
'-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(moving_average(np.mean(np.abs(Q_track_dq - optimal_Q), axis=(1,2))),
'-', linewidth=2, label='Dyna-Q')
plt.plot(moving_average(np.mean(np.abs(Q_track_ts - optimal_Q), axis=(1,2))),
'--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=1, ncol=1)
plt.title('Action-value function estimation error (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Mean Absolute Error MAE(Q, q*)')
plt.xticks(rotation=45)
plt.show()
env = gym.make('FrozenLake8x8-v0')
init_state = env.reset()
goal_state = 63
gamma = 0.99
n_episodes = 30000
P = env.env.P
n_cols, svf_prec, err_prec, avf_prec=8, 4, 2, 3
action_symbols=('<', 'v', '>', '^')
limit_items, limit_value = 5, 0.025
cu_limit_items, cu_limit_value, cu_episodes = 10, 0.0, 5000
plt.plot(decay_schedule(0.5, 0.01, 0.5, n_episodes),
'-', linewidth=2,
label='Alpha schedule')
plt.plot(decay_schedule(1.0, 0.1, 0.9, n_episodes),
':', linewidth=2,
label='Epsilon schedule')
plt.legend(loc=1, ncol=1)
plt.title('Alpha and epsilon schedules')
plt.xlabel('Episodes')
plt.ylabel('Hyperparameter values')
plt.xticks(rotation=45)
plt.show()
optimal_Q, optimal_V, optimal_pi = value_iteration(P, gamma=gamma)
print_state_value_function(optimal_V, P, n_cols=n_cols, prec=svf_prec, title='Optimal state-value function:')
print()
print_action_value_function(optimal_Q,
None,
action_symbols=action_symbols,
prec=avf_prec,
title='Optimal action-value function:')
print()
print_policy(optimal_pi, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_op, mean_return_op, mean_regret_op = get_policy_metrics(
env, gamma=gamma, pi=optimal_pi, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_op, mean_return_op, mean_regret_op))
Optimal state-value function: | 00 0.4146 | 01 0.4272 | 02 0.4461 | 03 0.4683 | 04 0.4924 | 05 0.5166 | 06 0.5353 | 07 0.541 | | 08 0.4117 | 09 0.4212 | 10 0.4375 | 11 0.4584 | 12 0.4832 | 13 0.5135 | 14 0.5458 | 15 0.5574 | | 16 0.3968 | 17 0.3938 | 18 0.3755 | | 20 0.4217 | 21 0.4938 | 22 0.5612 | 23 0.5859 | | 24 0.3693 | 25 0.353 | 26 0.3065 | 27 0.2004 | 28 0.3008 | | 30 0.569 | 31 0.6283 | | 32 0.3327 | 33 0.2914 | 34 0.1973 | | 36 0.2893 | 37 0.362 | 38 0.5348 | 39 0.6897 | | 40 0.3061 | | | 43 0.0863 | 44 0.2139 | 45 0.2727 | | 47 0.772 | | 48 0.2889 | | 50 0.0577 | 51 0.0475 | | 53 0.2505 | | 55 0.8778 | | 56 0.2804 | 57 0.2008 | 58 0.1273 | | 60 0.2396 | 61 0.4864 | 62 0.7371 | | Optimal action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╕ │ s │ < │ v │ > │ ^ │ ╞═════╪═══════╪═══════╪═══════╪═══════╡ │ 0 │ 0.41 │ 0.414 │ 0.414 │ 0.415 │ ├─────┼───────┼───────┼───────┼───────┤ │ 1 │ 0.417 │ 0.423 │ 0.427 │ 0.425 │ ├─────┼───────┼───────┼───────┼───────┤ │ 2 │ 0.433 │ 0.44 │ 0.446 │ 0.443 │ ├─────┼───────┼───────┼───────┼───────┤ │ 3 │ 0.453 │ 0.461 │ 0.468 │ 0.464 │ ├─────┼───────┼───────┼───────┼───────┤ │ 4 │ 0.477 │ 0.484 │ 0.492 │ 0.488 │ ├─────┼───────┼───────┼───────┼───────┤ │ 5 │ 0.502 │ 0.509 │ 0.517 │ 0.51 │ ├─────┼───────┼───────┼───────┼───────┤ │ 6 │ 0.527 │ 0.529 │ 0.535 │ 0.526 │ ├─────┼───────┼───────┼───────┼───────┤ │ 7 │ 0.539 │ 0.539 │ 0.541 │ 0.534 │ ├─────┼───────┼───────┼───────┼───────┤ │ 8 │ 0.404 │ 0.406 │ 0.407 │ 0.412 │ ├─────┼───────┼───────┼───────┼───────┤ │ 9 │ 0.407 │ 0.41 │ 0.415 │ 0.421 │ ├─────┼───────┼───────┼───────┼───────┤ │ 10 │ 0.41 │ 0.414 │ 0.422 │ 0.437 │ ├─────┼───────┼───────┼───────┼───────┤ │ 11 │ 0.299 │ 0.304 │ 0.314 │ 0.458 │ ├─────┼───────┼───────┼───────┼───────┤ │ 12 │ 0.453 │ 0.46 │ 0.471 │ 0.483 │ ├─────┼───────┼───────┼───────┼───────┤ │ 13 │ 0.493 │ 0.503 │ 0.514 │ 0.51 │ ├─────┼───────┼───────┼───────┼───────┤ │ 14 │ 0.531 │ 0.539 │ 0.546 │ 0.53 │ ├─────┼───────┼───────┼───────┼───────┤ │ 15 │ 0.552 │ 0.557 │ 0.556 │ 0.543 │ ├─────┼───────┼───────┼───────┼───────┤ │ 16 │ 0.389 │ 0.383 │ 0.388 │ 0.397 │ ├─────┼───────┼───────┼───────┼───────┤ │ 17 │ 0.386 │ 0.371 │ 0.379 │ 0.394 │ ├─────┼───────┼───────┼───────┼───────┤ │ 18 │ 0.375 │ 0.231 │ 0.246 │ 0.274 │ ├─────┼───────┼───────┼───────┼───────┤ │ 19 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┤ │ 20 │ 0.259 │ 0.262 │ 0.422 │ 0.322 │ ├─────┼───────┼───────┼───────┼───────┤ │ 21 │ 0.309 │ 0.324 │ 0.355 │ 0.494 │ ├─────┼───────┼───────┼───────┼───────┤ │ 22 │ 0.531 │ 0.544 │ 0.561 │ 0.536 │ ├─────┼───────┼───────┼───────┼───────┤ │ 23 │ 0.576 │ 0.586 │ 0.585 │ 0.562 │ ├─────┼───────┼───────┼───────┼───────┤ │ 24 │ 0.363 │ 0.348 │ 0.357 │ 0.369 │ ├─────┼───────┼───────┼───────┼───────┤ │ 25 │ 0.348 │ 0.319 │ 0.327 │ 0.353 │ ├─────┼───────┼───────┼───────┼───────┤ │ 26 │ 0.306 │ 0.248 │ 0.255 │ 0.307 │ ├─────┼───────┼───────┼───────┼───────┤ │ 27 │ 0.101 │ 0.2 │ 0.099 │ 0.2 │ ├─────┼───────┼───────┼───────┼───────┤ │ 28 │ 0.301 │ 0.162 │ 0.235 │ 0.205 │ ├─────┼───────┼───────┼───────┼───────┤ │ 29 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┤ │ 30 │ 0.362 │ 0.384 │ 0.569 │ 0.393 │ ├─────┼───────┼───────┼───────┼───────┤ │ 31 │ 0.609 │ 0.623 │ 0.628 │ 0.588 │ ├─────┼───────┼───────┼───────┼───────┤ │ 32 │ 0.333 │ 0.307 │ 0.319 │ 0.328 │ ├─────┼───────┼───────┼───────┼───────┤ │ 33 │ 0.226 │ 0.175 │ 0.182 │ 0.291 │ ├─────┼───────┼───────┼───────┼───────┤ │ 34 │ 0.197 │ 0.096 │ 0.101 │ 0.197 │ ├─────┼───────┼───────┼───────┼───────┤ │ 35 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┤ │ 36 │ 0.17 │ 0.19 │ 0.289 │ 0.219 │ ├─────┼───────┼───────┼───────┼───────┤ │ 37 │ 0.185 │ 0.362 │ 0.266 │ 0.272 │ ├─────┼───────┼───────┼───────┼───────┤ │ 38 │ 0.307 │ 0.347 │ 0.415 │ 0.535 │ ├─────┼───────┼───────┼───────┼───────┤ │ 39 │ 0.639 │ 0.659 │ 0.69 │ 0.611 │ ├─────┼───────┼───────┼───────┼───────┤ │ 40 │ 0.306 │ 0.196 │ 0.205 │ 0.211 │ ├─────┼───────┼───────┼───────┼───────┤ │ 41 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┤ │ 42 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┤ │ 43 │ 0.016 │ 0.086 │ 0.086 │ 0.071 │ ├─────┼───────┼───────┼───────┼───────┤ │ 44 │ 0.124 │ 0.118 │ 0.185 │ 0.214 │ ├─────┼───────┼───────┼───────┼───────┤ │ 45 │ 0.273 │ 0.153 │ 0.202 │ 0.19 │ ├─────┼───────┼───────┼───────┼───────┤ │ 46 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┤ │ 47 │ 0.517 │ 0.544 │ 0.772 │ 0.482 │ ├─────┼───────┼───────┼───────┼───────┤ │ 48 │ 0.289 │ 0.188 │ 0.194 │ 0.196 │ ├─────┼───────┼───────┼───────┼───────┤ │ 49 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┤ │ 50 │ 0.042 │ 0.058 │ 0.058 │ 0.016 │ ├─────┼───────┼───────┼───────┼───────┤ │ 51 │ 0.048 │ 0.019 │ 0.028 │ 0.048 │ ├─────┼───────┼───────┼───────┼───────┤ │ 52 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┤ │ 53 │ 0.251 │ 0.161 │ 0.251 │ 0.09 │ ├─────┼───────┼───────┼───────┼───────┤ │ 54 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┤ │ 55 │ 0.588 │ 0.623 │ 0.878 │ 0.544 │ ├─────┼───────┼───────┼───────┼───────┤ │ 56 │ 0.28 │ 0.251 │ 0.254 │ 0.254 │ ├─────┼───────┼───────┼───────┼───────┤ │ 57 │ 0.159 │ 0.201 │ 0.108 │ 0.135 │ ├─────┼───────┼───────┼───────┼───────┤ │ 58 │ 0.127 │ 0.108 │ 0.061 │ 0.085 │ ├─────┼───────┼───────┼───────┼───────┤ │ 59 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┤ │ 60 │ 0.079 │ 0.24 │ 0.24 │ 0.161 │ ├─────┼───────┼───────┼───────┼───────┤ │ 61 │ 0.322 │ 0.483 │ 0.486 │ 0.405 │ ├─────┼───────┼───────┼───────┼───────┤ │ 62 │ 0.404 │ 0.737 │ 0.577 │ 0.494 │ ├─────┼───────┼───────┼───────┼───────┤ │ 63 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╛ 정책: | 00 ^ | 01 > | 02 > | 03 > | 04 > | 05 > | 06 > | 07 > | | 08 ^ | 09 ^ | 10 ^ | 11 ^ | 12 ^ | 13 > | 14 > | 15 v | | 16 ^ | 17 ^ | 18 < | | 20 > | 21 ^ | 22 > | 23 v | | 24 ^ | 25 ^ | 26 ^ | 27 v | 28 < | | 30 > | 31 > | | 32 < | 33 ^ | 34 < | | 36 > | 37 v | 38 ^ | 39 > | | 40 < | | | 43 v | 44 ^ | 45 < | | 47 > | | 48 < | | 50 v | 51 < | | 53 < | | 55 > | | 56 < | 57 v | 58 < | | 60 v | 61 > | 62 v | | Reaches goal 81.00%. Obtains an average return of 0.3994. Regret of 0.0000
Q_rsls, V_rsls, Q_track_rsls = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_rsl, V_rsl, pi_rsl, Q_track_rsl, pi_track_rsl = sarsa_lambda(env, gamma=gamma, n_episodes=n_episodes)
Q_rsls.append(Q_rsl) ; V_rsls.append(V_rsl) ; Q_track_rsls.append(Q_track_rsl)
Q_rsl, V_rsl, Q_track_rsl = np.mean(Q_rsls, axis=0), np.mean(V_rsls, axis=0), np.mean(Q_track_rsls, axis=0)
del Q_rsls ; del V_rsls ; del Q_track_rsls
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
print_state_value_function(V_rsl, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Sarsa(λ) replacing:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_rsl - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_rsl, optimal_V)))
print()
print_action_value_function(Q_rsl,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Sarsa(λ) replacing action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_rsl, optimal_Q)))
print()
print_policy(pi_rsl, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_rsl, mean_return_rsl, mean_regret_rsl = get_policy_metrics(
env, gamma=gamma, pi=pi_rsl, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_rsl, mean_return_rsl, mean_regret_rsl))
State-value function found by Sarsa(λ) replacing: | 00 0.2416 | 01 0.2499 | 02 0.2651 | 03 0.2854 | 04 0.3088 | 05 0.3341 | 06 0.353 | 07 0.3573 | | 08 0.2375 | 09 0.2447 | 10 0.2578 | 11 0.2771 | 12 0.299 | 13 0.3286 | 14 0.3617 | 15 0.3746 | | 16 0.2115 | 17 0.2091 | 18 0.1956 | | 20 0.2412 | 21 0.3088 | 22 0.3711 | 23 0.4023 | | 24 0.1504 | 25 0.1559 | 26 0.139 | 27 0.0886 | 28 0.1472 | | 30 0.3798 | 31 0.4486 | | 32 0.0533 | 33 0.0487 | 34 0.0364 | | 36 0.1404 | 37 0.2004 | 38 0.3508 | 39 0.5192 | | 40 0.0066 | | | 43 0.0341 | 44 0.099 | 45 0.136 | | 47 0.6163 | | 48 0.0019 | | 50 0.0009 | 51 0.0062 | | 53 0.122 | | 55 0.7921 | | 56 0.0005 | 57 0.0001 | 58 0.0002 | | 60 0.0668 | 61 0.2667 | 62 0.5133 | | Optimal state-value function: | 00 0.4146 | 01 0.4272 | 02 0.4461 | 03 0.4683 | 04 0.4924 | 05 0.5166 | 06 0.5353 | 07 0.541 | | 08 0.4117 | 09 0.4212 | 10 0.4375 | 11 0.4584 | 12 0.4832 | 13 0.5135 | 14 0.5458 | 15 0.5574 | | 16 0.3968 | 17 0.3938 | 18 0.3755 | | 20 0.4217 | 21 0.4938 | 22 0.5612 | 23 0.5859 | | 24 0.3693 | 25 0.353 | 26 0.3065 | 27 0.2004 | 28 0.3008 | | 30 0.569 | 31 0.6283 | | 32 0.3327 | 33 0.2914 | 34 0.1973 | | 36 0.2893 | 37 0.362 | 38 0.5348 | 39 0.6897 | | 40 0.3061 | | | 43 0.0863 | 44 0.2139 | 45 0.2727 | | 47 0.772 | | 48 0.2889 | | 50 0.0577 | 51 0.0475 | | 53 0.2505 | | 55 0.8778 | | 56 0.2804 | 57 0.2008 | 58 0.1273 | | 60 0.2396 | 61 0.4864 | 62 0.7371 | | State-value function errors: | 00 -0.17 | 01 -0.18 | 02 -0.18 | 03 -0.18 | 04 -0.18 | 05 -0.18 | 06 -0.18 | 07 -0.18 | | 08 -0.17 | 09 -0.18 | 10 -0.18 | 11 -0.18 | 12 -0.18 | 13 -0.18 | 14 -0.18 | 15 -0.18 | | 16 -0.19 | 17 -0.18 | 18 -0.18 | | 20 -0.18 | 21 -0.19 | 22 -0.19 | 23 -0.18 | | 24 -0.22 | 25 -0.2 | 26 -0.17 | 27 -0.11 | 28 -0.15 | | 30 -0.19 | 31 -0.18 | | 32 -0.28 | 33 -0.24 | 34 -0.16 | | 36 -0.15 | 37 -0.16 | 38 -0.18 | 39 -0.17 | | 40 -0.3 | | | 43 -0.05 | 44 -0.11 | 45 -0.14 | | 47 -0.16 | | 48 -0.29 | | 50 -0.06 | 51 -0.04 | | 53 -0.13 | | 55 -0.09 | | 56 -0.28 | 57 -0.2 | 58 -0.13 | | 60 -0.17 | 61 -0.22 | 62 -0.22 | | State-value function RMSE: 0.1666 Sarsa(λ) replacing action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╤════════╤════════╕ │ s │ < │ v │ > │ ^ │ * < │ * v │ * > │ * ^ │ er < │ er v │ er > │ er ^ │ ╞═════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╪════════╪════════╡ │ 0 │ 0.235 │ 0.237 │ 0.237 │ 0.242 │ 0.41 │ 0.414 │ 0.414 │ 0.415 │ 0.175 │ 0.176 │ 0.176 │ 0.173 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 1 │ 0.239 │ 0.244 │ 0.25 │ 0.246 │ 0.417 │ 0.423 │ 0.427 │ 0.425 │ 0.178 │ 0.179 │ 0.177 │ 0.179 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 2 │ 0.25 │ 0.257 │ 0.265 │ 0.259 │ 0.433 │ 0.44 │ 0.446 │ 0.443 │ 0.183 │ 0.183 │ 0.181 │ 0.184 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 3 │ 0.266 │ 0.275 │ 0.284 │ 0.281 │ 0.453 │ 0.461 │ 0.468 │ 0.464 │ 0.187 │ 0.186 │ 0.184 │ 0.183 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 4 │ 0.29 │ 0.299 │ 0.309 │ 0.302 │ 0.477 │ 0.484 │ 0.492 │ 0.488 │ 0.186 │ 0.186 │ 0.184 │ 0.185 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 5 │ 0.316 │ 0.323 │ 0.334 │ 0.324 │ 0.502 │ 0.509 │ 0.517 │ 0.51 │ 0.186 │ 0.186 │ 0.182 │ 0.186 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 6 │ 0.345 │ 0.347 │ 0.353 │ 0.344 │ 0.527 │ 0.529 │ 0.535 │ 0.526 │ 0.182 │ 0.182 │ 0.182 │ 0.182 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 7 │ 0.356 │ 0.356 │ 0.355 │ 0.354 │ 0.539 │ 0.539 │ 0.541 │ 0.534 │ 0.183 │ 0.183 │ 0.186 │ 0.179 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 8 │ 0.215 │ 0.215 │ 0.217 │ 0.237 │ 0.404 │ 0.406 │ 0.407 │ 0.412 │ 0.189 │ 0.19 │ 0.19 │ 0.174 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 9 │ 0.223 │ 0.223 │ 0.227 │ 0.245 │ 0.407 │ 0.41 │ 0.415 │ 0.421 │ 0.183 │ 0.188 │ 0.188 │ 0.177 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 10 │ 0.224 │ 0.227 │ 0.233 │ 0.258 │ 0.41 │ 0.414 │ 0.422 │ 0.437 │ 0.186 │ 0.187 │ 0.189 │ 0.18 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 11 │ 0.184 │ 0.177 │ 0.19 │ 0.277 │ 0.299 │ 0.304 │ 0.314 │ 0.458 │ 0.114 │ 0.127 │ 0.124 │ 0.181 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 12 │ 0.262 │ 0.271 │ 0.278 │ 0.299 │ 0.453 │ 0.46 │ 0.471 │ 0.483 │ 0.191 │ 0.189 │ 0.193 │ 0.184 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 13 │ 0.302 │ 0.308 │ 0.327 │ 0.322 │ 0.493 │ 0.503 │ 0.514 │ 0.51 │ 0.191 │ 0.195 │ 0.186 │ 0.188 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 14 │ 0.346 │ 0.351 │ 0.362 │ 0.347 │ 0.531 │ 0.539 │ 0.546 │ 0.53 │ 0.185 │ 0.187 │ 0.184 │ 0.183 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 15 │ 0.366 │ 0.372 │ 0.369 │ 0.361 │ 0.552 │ 0.557 │ 0.556 │ 0.543 │ 0.186 │ 0.186 │ 0.187 │ 0.182 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 16 │ 0.12 │ 0.114 │ 0.118 │ 0.212 │ 0.389 │ 0.383 │ 0.388 │ 0.397 │ 0.269 │ 0.269 │ 0.269 │ 0.185 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 17 │ 0.127 │ 0.117 │ 0.121 │ 0.209 │ 0.386 │ 0.371 │ 0.379 │ 0.394 │ 0.259 │ 0.254 │ 0.258 │ 0.185 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 18 │ 0.196 │ 0.07 │ 0.079 │ 0.092 │ 0.375 │ 0.231 │ 0.246 │ 0.274 │ 0.18 │ 0.161 │ 0.166 │ 0.182 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 19 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 20 │ 0.107 │ 0.119 │ 0.241 │ 0.148 │ 0.259 │ 0.262 │ 0.422 │ 0.322 │ 0.152 │ 0.143 │ 0.181 │ 0.174 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 21 │ 0.156 │ 0.178 │ 0.204 │ 0.309 │ 0.309 │ 0.324 │ 0.355 │ 0.494 │ 0.152 │ 0.146 │ 0.151 │ 0.185 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 22 │ 0.343 │ 0.353 │ 0.371 │ 0.347 │ 0.531 │ 0.544 │ 0.561 │ 0.536 │ 0.188 │ 0.192 │ 0.19 │ 0.189 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 23 │ 0.388 │ 0.394 │ 0.398 │ 0.377 │ 0.576 │ 0.586 │ 0.585 │ 0.562 │ 0.188 │ 0.192 │ 0.187 │ 0.185 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 24 │ 0.047 │ 0.04 │ 0.049 │ 0.15 │ 0.363 │ 0.348 │ 0.357 │ 0.369 │ 0.315 │ 0.308 │ 0.308 │ 0.219 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 25 │ 0.053 │ 0.046 │ 0.051 │ 0.156 │ 0.348 │ 0.319 │ 0.327 │ 0.353 │ 0.295 │ 0.273 │ 0.277 │ 0.197 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 26 │ 0.068 │ 0.046 │ 0.056 │ 0.139 │ 0.306 │ 0.248 │ 0.255 │ 0.307 │ 0.238 │ 0.201 │ 0.199 │ 0.168 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 27 │ 0.018 │ 0.053 │ 0.021 │ 0.079 │ 0.101 │ 0.2 │ 0.099 │ 0.2 │ 0.083 │ 0.148 │ 0.078 │ 0.121 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 28 │ 0.147 │ 0.049 │ 0.081 │ 0.065 │ 0.301 │ 0.162 │ 0.235 │ 0.205 │ 0.154 │ 0.112 │ 0.154 │ 0.14 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 29 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 30 │ 0.241 │ 0.253 │ 0.38 │ 0.271 │ 0.362 │ 0.384 │ 0.569 │ 0.393 │ 0.12 │ 0.131 │ 0.189 │ 0.122 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 31 │ 0.418 │ 0.431 │ 0.449 │ 0.4 │ 0.609 │ 0.623 │ 0.628 │ 0.588 │ 0.191 │ 0.191 │ 0.18 │ 0.188 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 32 │ 0.012 │ 0.009 │ 0.013 │ 0.053 │ 0.333 │ 0.307 │ 0.319 │ 0.328 │ 0.321 │ 0.298 │ 0.306 │ 0.275 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 33 │ 0.014 │ 0.008 │ 0.01 │ 0.049 │ 0.226 │ 0.175 │ 0.182 │ 0.291 │ 0.212 │ 0.167 │ 0.171 │ 0.243 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 34 │ 0.022 │ 0.003 │ 0.006 │ 0.023 │ 0.197 │ 0.096 │ 0.101 │ 0.197 │ 0.175 │ 0.093 │ 0.095 │ 0.174 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 35 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 36 │ 0.05 │ 0.068 │ 0.14 │ 0.083 │ 0.17 │ 0.19 │ 0.289 │ 0.219 │ 0.119 │ 0.122 │ 0.149 │ 0.136 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 37 │ 0.074 │ 0.2 │ 0.145 │ 0.135 │ 0.185 │ 0.362 │ 0.266 │ 0.272 │ 0.112 │ 0.162 │ 0.121 │ 0.137 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 38 │ 0.187 │ 0.216 │ 0.285 │ 0.351 │ 0.307 │ 0.347 │ 0.415 │ 0.535 │ 0.12 │ 0.131 │ 0.131 │ 0.184 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 39 │ 0.454 │ 0.481 │ 0.519 │ 0.428 │ 0.639 │ 0.659 │ 0.69 │ 0.611 │ 0.184 │ 0.178 │ 0.171 │ 0.184 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 40 │ 0.005 │ 0.001 │ 0.002 │ 0.005 │ 0.306 │ 0.196 │ 0.205 │ 0.211 │ 0.301 │ 0.196 │ 0.203 │ 0.206 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 41 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 42 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 43 │ 0 │ 0.013 │ 0.015 │ 0.021 │ 0.016 │ 0.086 │ 0.086 │ 0.071 │ 0.015 │ 0.073 │ 0.072 │ 0.05 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 44 │ 0.033 │ 0.028 │ 0.059 │ 0.091 │ 0.124 │ 0.118 │ 0.185 │ 0.214 │ 0.09 │ 0.09 │ 0.127 │ 0.123 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 45 │ 0.136 │ 0.045 │ 0.073 │ 0.059 │ 0.273 │ 0.153 │ 0.202 │ 0.19 │ 0.137 │ 0.108 │ 0.129 │ 0.131 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 46 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 47 │ 0.413 │ 0.467 │ 0.616 │ 0.366 │ 0.517 │ 0.544 │ 0.772 │ 0.482 │ 0.104 │ 0.077 │ 0.156 │ 0.117 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 48 │ 0.001 │ 0 │ 0.001 │ 0 │ 0.289 │ 0.188 │ 0.194 │ 0.196 │ 0.288 │ 0.188 │ 0.192 │ 0.196 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 49 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 50 │ 0 │ 0 │ 0.001 │ 0 │ 0.042 │ 0.058 │ 0.058 │ 0.016 │ 0.042 │ 0.058 │ 0.057 │ 0.016 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 51 │ 0.003 │ 0 │ 0.001 │ 0.003 │ 0.048 │ 0.019 │ 0.028 │ 0.048 │ 0.044 │ 0.019 │ 0.028 │ 0.044 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 52 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 53 │ 0.067 │ 0.053 │ 0.097 │ 0.015 │ 0.251 │ 0.161 │ 0.251 │ 0.09 │ 0.184 │ 0.108 │ 0.154 │ 0.075 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 54 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 55 │ 0.519 │ 0.573 │ 0.792 │ 0.434 │ 0.588 │ 0.623 │ 0.878 │ 0.544 │ 0.069 │ 0.05 │ 0.086 │ 0.11 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 56 │ 0 │ 0 │ 0 │ 0 │ 0.28 │ 0.251 │ 0.254 │ 0.254 │ 0.28 │ 0.251 │ 0.254 │ 0.254 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 57 │ 0 │ 0 │ 0 │ 0 │ 0.159 │ 0.201 │ 0.108 │ 0.135 │ 0.159 │ 0.201 │ 0.108 │ 0.135 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 58 │ 0 │ 0 │ 0 │ 0 │ 0.127 │ 0.108 │ 0.061 │ 0.085 │ 0.127 │ 0.108 │ 0.061 │ 0.085 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 59 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 60 │ 0.003 │ 0.009 │ 0.055 │ 0.021 │ 0.079 │ 0.24 │ 0.24 │ 0.161 │ 0.076 │ 0.23 │ 0.185 │ 0.14 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 61 │ 0.04 │ 0.13 │ 0.176 │ 0.138 │ 0.322 │ 0.483 │ 0.486 │ 0.405 │ 0.282 │ 0.353 │ 0.31 │ 0.267 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 62 │ 0.043 │ 0.394 │ 0.131 │ 0.228 │ 0.404 │ 0.737 │ 0.577 │ 0.494 │ 0.361 │ 0.343 │ 0.445 │ 0.266 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 63 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╧════════╧════════╛ Action-value function RMSE: 0.17 정책: | 00 ^ | 01 > | 02 > | 03 > | 04 > | 05 > | 06 > | 07 < | | 08 ^ | 09 ^ | 10 ^ | 11 ^ | 12 ^ | 13 ^ | 14 > | 15 > | | 16 ^ | 17 ^ | 18 < | | 20 > | 21 ^ | 22 > | 23 v | | 24 ^ | 25 ^ | 26 ^ | 27 ^ | 28 < | | 30 > | 31 > | | 32 ^ | 33 ^ | 34 ^ | | 36 > | 37 v | 38 ^ | 39 > | | 40 ^ | | | 43 ^ | 44 ^ | 45 < | | 47 > | | 48 > | | 50 > | 51 ^ | | 53 > | | 55 > | | 56 > | 57 v | 58 < | | 60 > | 61 v | 62 v | | Reaches goal 92.00%. Obtains an average return of 0.4264. Regret of 0.0207
Q_asls, V_asls, Q_track_asls = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_asl, V_asl, pi_asl, Q_track_asl, pi_track_asl = sarsa_lambda(env, gamma=gamma,
replacing_traces=False,
n_episodes=n_episodes)
Q_asls.append(Q_asl) ; V_asls.append(V_asl) ; Q_track_asls.append(Q_track_asl)
Q_asl, V_asl, Q_track_asl = np.mean(Q_asls, axis=0), np.mean(V_asls, axis=0), np.mean(Q_track_asls, axis=0)
del Q_asls ; del V_asls ; del Q_track_asls
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
print_state_value_function(V_asl, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Sarsa(λ) accumulating:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_asl - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_asl, optimal_V)))
print()
print_action_value_function(Q_asl,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Sarsa(λ) accumulating action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_asl, optimal_Q)))
print()
print_policy(pi_asl, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_asl, mean_return_asl, mean_regret_asl = get_policy_metrics(
env, gamma=gamma, pi=pi_asl, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_asl, mean_return_asl, mean_regret_asl))
State-value function found by Sarsa(λ) accumulating: | 00 0.2398 | 01 0.248 | 02 0.2628 | 03 0.2829 | 04 0.3103 | 05 0.3389 | 06 0.3578 | 07 0.3628 | | 08 0.2364 | 09 0.2427 | 10 0.2557 | 11 0.2745 | 12 0.2989 | 13 0.3309 | 14 0.365 | 15 0.3766 | | 16 0.2147 | 17 0.2118 | 18 0.1967 | | 20 0.2444 | 21 0.3088 | 22 0.3763 | 23 0.4018 | | 24 0.1644 | 25 0.1636 | 26 0.1449 | 27 0.0939 | 28 0.1546 | | 30 0.3848 | 31 0.4434 | | 32 0.0544 | 33 0.0521 | 34 0.0354 | | 36 0.1426 | 37 0.2024 | 38 0.3576 | 39 0.5073 | | 40 0.0075 | | | 43 0.0276 | 44 0.0905 | 45 0.123 | | 47 0.6022 | | 48 0.0016 | | 50 0.0009 | 51 0.0063 | | 53 0.1092 | | 55 0.7683 | | 56 0.0007 | 57 0.0002 | 58 0.0001 | | 60 0.0747 | 61 0.2227 | 62 0.4008 | | Optimal state-value function: | 00 0.4146 | 01 0.4272 | 02 0.4461 | 03 0.4683 | 04 0.4924 | 05 0.5166 | 06 0.5353 | 07 0.541 | | 08 0.4117 | 09 0.4212 | 10 0.4375 | 11 0.4584 | 12 0.4832 | 13 0.5135 | 14 0.5458 | 15 0.5574 | | 16 0.3968 | 17 0.3938 | 18 0.3755 | | 20 0.4217 | 21 0.4938 | 22 0.5612 | 23 0.5859 | | 24 0.3693 | 25 0.353 | 26 0.3065 | 27 0.2004 | 28 0.3008 | | 30 0.569 | 31 0.6283 | | 32 0.3327 | 33 0.2914 | 34 0.1973 | | 36 0.2893 | 37 0.362 | 38 0.5348 | 39 0.6897 | | 40 0.3061 | | | 43 0.0863 | 44 0.2139 | 45 0.2727 | | 47 0.772 | | 48 0.2889 | | 50 0.0577 | 51 0.0475 | | 53 0.2505 | | 55 0.8778 | | 56 0.2804 | 57 0.2008 | 58 0.1273 | | 60 0.2396 | 61 0.4864 | 62 0.7371 | | State-value function errors: | 00 -0.17 | 01 -0.18 | 02 -0.18 | 03 -0.19 | 04 -0.18 | 05 -0.18 | 06 -0.18 | 07 -0.18 | | 08 -0.18 | 09 -0.18 | 10 -0.18 | 11 -0.18 | 12 -0.18 | 13 -0.18 | 14 -0.18 | 15 -0.18 | | 16 -0.18 | 17 -0.18 | 18 -0.18 | | 20 -0.18 | 21 -0.19 | 22 -0.18 | 23 -0.18 | | 24 -0.2 | 25 -0.19 | 26 -0.16 | 27 -0.11 | 28 -0.15 | | 30 -0.18 | 31 -0.18 | | 32 -0.28 | 33 -0.24 | 34 -0.16 | | 36 -0.15 | 37 -0.16 | 38 -0.18 | 39 -0.18 | | 40 -0.3 | | | 43 -0.06 | 44 -0.12 | 45 -0.15 | | 47 -0.17 | | 48 -0.29 | | 50 -0.06 | 51 -0.04 | | 53 -0.14 | | 55 -0.11 | | 56 -0.28 | 57 -0.2 | 58 -0.13 | | 60 -0.16 | 61 -0.26 | 62 -0.34 | | State-value function RMSE: 0.1702 Sarsa(λ) accumulating action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╤════════╤════════╕ │ s │ < │ v │ > │ ^ │ * < │ * v │ * > │ * ^ │ er < │ er v │ er > │ er ^ │ ╞═════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╪════════╪════════╡ │ 0 │ 0.234 │ 0.236 │ 0.236 │ 0.24 │ 0.41 │ 0.414 │ 0.414 │ 0.415 │ 0.175 │ 0.177 │ 0.177 │ 0.175 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 1 │ 0.238 │ 0.243 │ 0.247 │ 0.245 │ 0.417 │ 0.423 │ 0.427 │ 0.425 │ 0.179 │ 0.18 │ 0.18 │ 0.18 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 2 │ 0.249 │ 0.255 │ 0.263 │ 0.258 │ 0.433 │ 0.44 │ 0.446 │ 0.443 │ 0.183 │ 0.185 │ 0.183 │ 0.185 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 3 │ 0.266 │ 0.274 │ 0.281 │ 0.28 │ 0.453 │ 0.461 │ 0.468 │ 0.464 │ 0.187 │ 0.187 │ 0.188 │ 0.184 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 4 │ 0.288 │ 0.299 │ 0.31 │ 0.301 │ 0.477 │ 0.484 │ 0.492 │ 0.488 │ 0.188 │ 0.186 │ 0.182 │ 0.186 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 5 │ 0.315 │ 0.321 │ 0.339 │ 0.324 │ 0.502 │ 0.509 │ 0.517 │ 0.51 │ 0.187 │ 0.188 │ 0.178 │ 0.185 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 6 │ 0.345 │ 0.347 │ 0.358 │ 0.344 │ 0.527 │ 0.529 │ 0.535 │ 0.526 │ 0.182 │ 0.182 │ 0.177 │ 0.181 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 7 │ 0.359 │ 0.356 │ 0.36 │ 0.353 │ 0.539 │ 0.539 │ 0.541 │ 0.534 │ 0.18 │ 0.183 │ 0.181 │ 0.18 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 8 │ 0.22 │ 0.219 │ 0.221 │ 0.236 │ 0.404 │ 0.406 │ 0.407 │ 0.412 │ 0.184 │ 0.187 │ 0.185 │ 0.175 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 9 │ 0.225 │ 0.226 │ 0.229 │ 0.243 │ 0.407 │ 0.41 │ 0.415 │ 0.421 │ 0.181 │ 0.184 │ 0.186 │ 0.178 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 10 │ 0.225 │ 0.226 │ 0.232 │ 0.256 │ 0.41 │ 0.414 │ 0.422 │ 0.437 │ 0.185 │ 0.188 │ 0.19 │ 0.182 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 11 │ 0.178 │ 0.182 │ 0.188 │ 0.275 │ 0.299 │ 0.304 │ 0.314 │ 0.458 │ 0.121 │ 0.122 │ 0.126 │ 0.184 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 12 │ 0.26 │ 0.269 │ 0.281 │ 0.299 │ 0.453 │ 0.46 │ 0.471 │ 0.483 │ 0.193 │ 0.191 │ 0.19 │ 0.184 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 13 │ 0.301 │ 0.309 │ 0.324 │ 0.326 │ 0.493 │ 0.503 │ 0.514 │ 0.51 │ 0.192 │ 0.194 │ 0.189 │ 0.184 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 14 │ 0.348 │ 0.355 │ 0.365 │ 0.348 │ 0.531 │ 0.539 │ 0.546 │ 0.53 │ 0.183 │ 0.183 │ 0.181 │ 0.182 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 15 │ 0.369 │ 0.375 │ 0.372 │ 0.361 │ 0.552 │ 0.557 │ 0.556 │ 0.543 │ 0.183 │ 0.183 │ 0.184 │ 0.182 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 16 │ 0.128 │ 0.124 │ 0.13 │ 0.215 │ 0.389 │ 0.383 │ 0.388 │ 0.397 │ 0.261 │ 0.259 │ 0.258 │ 0.182 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 17 │ 0.133 │ 0.118 │ 0.127 │ 0.212 │ 0.386 │ 0.371 │ 0.379 │ 0.394 │ 0.253 │ 0.253 │ 0.252 │ 0.182 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 18 │ 0.197 │ 0.069 │ 0.086 │ 0.1 │ 0.375 │ 0.231 │ 0.246 │ 0.274 │ 0.179 │ 0.162 │ 0.159 │ 0.175 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 19 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 20 │ 0.115 │ 0.109 │ 0.244 │ 0.15 │ 0.259 │ 0.262 │ 0.422 │ 0.322 │ 0.144 │ 0.153 │ 0.177 │ 0.172 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 21 │ 0.153 │ 0.168 │ 0.194 │ 0.309 │ 0.309 │ 0.324 │ 0.355 │ 0.494 │ 0.155 │ 0.156 │ 0.16 │ 0.185 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 22 │ 0.345 │ 0.351 │ 0.376 │ 0.351 │ 0.531 │ 0.544 │ 0.561 │ 0.536 │ 0.186 │ 0.193 │ 0.185 │ 0.185 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 23 │ 0.393 │ 0.397 │ 0.4 │ 0.381 │ 0.576 │ 0.586 │ 0.585 │ 0.562 │ 0.184 │ 0.189 │ 0.184 │ 0.182 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 24 │ 0.053 │ 0.044 │ 0.052 │ 0.164 │ 0.363 │ 0.348 │ 0.357 │ 0.369 │ 0.31 │ 0.304 │ 0.305 │ 0.205 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 25 │ 0.059 │ 0.048 │ 0.053 │ 0.164 │ 0.348 │ 0.319 │ 0.327 │ 0.353 │ 0.289 │ 0.272 │ 0.274 │ 0.189 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 26 │ 0.066 │ 0.05 │ 0.053 │ 0.145 │ 0.306 │ 0.248 │ 0.255 │ 0.307 │ 0.239 │ 0.198 │ 0.202 │ 0.162 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 27 │ 0.023 │ 0.085 │ 0.025 │ 0.051 │ 0.101 │ 0.2 │ 0.099 │ 0.2 │ 0.079 │ 0.116 │ 0.074 │ 0.149 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 28 │ 0.155 │ 0.046 │ 0.085 │ 0.073 │ 0.301 │ 0.162 │ 0.235 │ 0.205 │ 0.146 │ 0.115 │ 0.15 │ 0.133 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 29 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 30 │ 0.237 │ 0.26 │ 0.385 │ 0.266 │ 0.362 │ 0.384 │ 0.569 │ 0.393 │ 0.125 │ 0.124 │ 0.184 │ 0.127 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 31 │ 0.419 │ 0.429 │ 0.443 │ 0.402 │ 0.609 │ 0.623 │ 0.628 │ 0.588 │ 0.189 │ 0.194 │ 0.185 │ 0.186 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 32 │ 0.015 │ 0.01 │ 0.015 │ 0.054 │ 0.333 │ 0.307 │ 0.319 │ 0.328 │ 0.318 │ 0.297 │ 0.304 │ 0.273 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 33 │ 0.012 │ 0.008 │ 0.011 │ 0.052 │ 0.226 │ 0.175 │ 0.182 │ 0.291 │ 0.214 │ 0.167 │ 0.17 │ 0.239 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 34 │ 0.031 │ 0.004 │ 0.006 │ 0.015 │ 0.197 │ 0.096 │ 0.101 │ 0.197 │ 0.166 │ 0.092 │ 0.095 │ 0.182 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 35 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 36 │ 0.054 │ 0.067 │ 0.143 │ 0.083 │ 0.17 │ 0.19 │ 0.289 │ 0.219 │ 0.116 │ 0.123 │ 0.147 │ 0.136 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 37 │ 0.071 │ 0.202 │ 0.136 │ 0.14 │ 0.185 │ 0.362 │ 0.266 │ 0.272 │ 0.114 │ 0.16 │ 0.131 │ 0.132 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 38 │ 0.177 │ 0.199 │ 0.287 │ 0.358 │ 0.307 │ 0.347 │ 0.415 │ 0.535 │ 0.13 │ 0.148 │ 0.128 │ 0.177 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 39 │ 0.462 │ 0.471 │ 0.507 │ 0.431 │ 0.639 │ 0.659 │ 0.69 │ 0.611 │ 0.177 │ 0.188 │ 0.182 │ 0.181 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 40 │ 0.005 │ 0.001 │ 0.002 │ 0.005 │ 0.306 │ 0.196 │ 0.205 │ 0.211 │ 0.301 │ 0.195 │ 0.203 │ 0.206 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 41 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 42 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 43 │ 0 │ 0.012 │ 0.022 │ 0.01 │ 0.016 │ 0.086 │ 0.086 │ 0.071 │ 0.015 │ 0.074 │ 0.064 │ 0.061 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 44 │ 0.029 │ 0.029 │ 0.056 │ 0.085 │ 0.124 │ 0.118 │ 0.185 │ 0.214 │ 0.095 │ 0.09 │ 0.13 │ 0.129 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 45 │ 0.119 │ 0.041 │ 0.079 │ 0.065 │ 0.273 │ 0.153 │ 0.202 │ 0.19 │ 0.153 │ 0.113 │ 0.123 │ 0.125 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 46 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 47 │ 0.419 │ 0.451 │ 0.602 │ 0.359 │ 0.517 │ 0.544 │ 0.772 │ 0.482 │ 0.098 │ 0.094 │ 0.17 │ 0.123 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 48 │ 0.001 │ 0 │ 0.001 │ 0.001 │ 0.289 │ 0.188 │ 0.194 │ 0.196 │ 0.288 │ 0.188 │ 0.193 │ 0.196 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 49 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 50 │ 0 │ 0 │ 0 │ 0 │ 0.042 │ 0.058 │ 0.058 │ 0.016 │ 0.042 │ 0.058 │ 0.057 │ 0.015 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 51 │ 0.003 │ 0 │ 0.003 │ 0.003 │ 0.048 │ 0.019 │ 0.028 │ 0.048 │ 0.045 │ 0.019 │ 0.026 │ 0.045 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 52 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 53 │ 0.077 │ 0.04 │ 0.083 │ 0.02 │ 0.251 │ 0.161 │ 0.251 │ 0.09 │ 0.174 │ 0.121 │ 0.168 │ 0.07 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 54 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 55 │ 0.524 │ 0.57 │ 0.768 │ 0.459 │ 0.588 │ 0.623 │ 0.878 │ 0.544 │ 0.064 │ 0.053 │ 0.109 │ 0.085 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 56 │ 0.001 │ 0 │ 0 │ 0 │ 0.28 │ 0.251 │ 0.254 │ 0.254 │ 0.28 │ 0.251 │ 0.254 │ 0.254 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 57 │ 0 │ 0 │ 0 │ 0 │ 0.159 │ 0.201 │ 0.108 │ 0.135 │ 0.159 │ 0.201 │ 0.108 │ 0.135 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 58 │ 0 │ 0 │ 0 │ 0 │ 0.127 │ 0.108 │ 0.061 │ 0.085 │ 0.127 │ 0.108 │ 0.061 │ 0.085 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 59 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 60 │ 0.007 │ 0.05 │ 0.049 │ 0.013 │ 0.079 │ 0.24 │ 0.24 │ 0.161 │ 0.073 │ 0.19 │ 0.191 │ 0.148 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 61 │ 0.046 │ 0.115 │ 0.159 │ 0.108 │ 0.322 │ 0.483 │ 0.486 │ 0.405 │ 0.276 │ 0.368 │ 0.328 │ 0.297 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 62 │ 0.056 │ 0.208 │ 0.258 │ 0.133 │ 0.404 │ 0.737 │ 0.577 │ 0.494 │ 0.348 │ 0.529 │ 0.319 │ 0.361 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 63 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╧════════╧════════╛ Action-value function RMSE: 0.171 정책: | 00 ^ | 01 ^ | 02 > | 03 ^ | 04 > | 05 > | 06 > | 07 < | | 08 ^ | 09 ^ | 10 ^ | 11 ^ | 12 ^ | 13 > | 14 > | 15 v | | 16 ^ | 17 ^ | 18 < | | 20 > | 21 ^ | 22 > | 23 > | | 24 ^ | 25 ^ | 26 ^ | 27 v | 28 < | | 30 > | 31 > | | 32 ^ | 33 ^ | 34 < | | 36 > | 37 v | 38 ^ | 39 > | | 40 < | | | 43 > | 44 ^ | 45 < | | 47 > | | 48 < | | 50 v | 51 > | | 53 < | | 55 > | | 56 > | 57 < | 58 ^ | | 60 v | 61 v | 62 > | | Reaches goal 85.00%. Obtains an average return of 0.3755. Regret of 0.0430
Q_rqlls, V_rqlls, Q_track_rqlls = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_rqll, V_rqll, pi_rqll, Q_track_rqll, pi_track_rqll = q_lambda(env, gamma=gamma, n_episodes=n_episodes)
Q_rqlls.append(Q_rqll) ; V_rqlls.append(V_rqll) ; Q_track_rqlls.append(Q_track_rqll)
Q_rqll, V_rqll, Q_track_rqll = np.mean(Q_rqlls, axis=0), np.mean(V_rqlls, axis=0), np.mean(Q_track_rqlls, axis=0)
del Q_rqlls ; del V_rqlls ; del Q_track_rqlls
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
print_state_value_function(V_rqll, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Q(λ) replacing:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_rqll - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_rqll, optimal_V)))
print()
print_action_value_function(Q_rqll,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Q(λ) replacing action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_rqll, optimal_Q)))
print()
print_policy(pi_rqll, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_rqll, mean_return_rqll, mean_regret_rqll = get_policy_metrics(
env, gamma=gamma, pi=pi_rqll, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_rqll, mean_return_rqll, mean_regret_rqll))
State-value function found by Q(λ) replacing: | 00 0.3966 | 01 0.4085 | 02 0.428 | 03 0.4494 | 04 0.4714 | 05 0.4951 | 06 0.5117 | 07 0.5166 | | 08 0.3931 | 09 0.4027 | 10 0.4195 | 11 0.4402 | 12 0.463 | 13 0.4898 | 14 0.5203 | 15 0.5309 | | 16 0.3733 | 17 0.3721 | 18 0.3533 | | 20 0.3995 | 21 0.4694 | 22 0.534 | 23 0.5577 | | 24 0.3414 | 25 0.3258 | 26 0.2845 | 27 0.1791 | 28 0.2781 | | 30 0.5409 | 31 0.5953 | | 32 0.2965 | 33 0.2623 | 34 0.1737 | | 36 0.2607 | 37 0.3286 | 38 0.5083 | 39 0.6558 | | 40 0.2546 | | | 43 0.0707 | 44 0.1817 | 45 0.2299 | | 47 0.7506 | | 48 0.2161 | | 50 0.0149 | 51 0.0303 | | 53 0.1888 | | 55 0.8732 | | 56 0.194 | 57 0.0796 | 58 0.034 | | 60 0.098 | 61 0.3324 | 62 0.509 | | Optimal state-value function: | 00 0.4146 | 01 0.4272 | 02 0.4461 | 03 0.4683 | 04 0.4924 | 05 0.5166 | 06 0.5353 | 07 0.541 | | 08 0.4117 | 09 0.4212 | 10 0.4375 | 11 0.4584 | 12 0.4832 | 13 0.5135 | 14 0.5458 | 15 0.5574 | | 16 0.3968 | 17 0.3938 | 18 0.3755 | | 20 0.4217 | 21 0.4938 | 22 0.5612 | 23 0.5859 | | 24 0.3693 | 25 0.353 | 26 0.3065 | 27 0.2004 | 28 0.3008 | | 30 0.569 | 31 0.6283 | | 32 0.3327 | 33 0.2914 | 34 0.1973 | | 36 0.2893 | 37 0.362 | 38 0.5348 | 39 0.6897 | | 40 0.3061 | | | 43 0.0863 | 44 0.2139 | 45 0.2727 | | 47 0.772 | | 48 0.2889 | | 50 0.0577 | 51 0.0475 | | 53 0.2505 | | 55 0.8778 | | 56 0.2804 | 57 0.2008 | 58 0.1273 | | 60 0.2396 | 61 0.4864 | 62 0.7371 | | State-value function errors: | 00 -0.02 | 01 -0.02 | 02 -0.02 | 03 -0.02 | 04 -0.02 | 05 -0.02 | 06 -0.02 | 07 -0.02 | | 08 -0.02 | 09 -0.02 | 10 -0.02 | 11 -0.02 | 12 -0.02 | 13 -0.02 | 14 -0.03 | 15 -0.03 | | 16 -0.02 | 17 -0.02 | 18 -0.02 | | 20 -0.02 | 21 -0.02 | 22 -0.03 | 23 -0.03 | | 24 -0.03 | 25 -0.03 | 26 -0.02 | 27 -0.02 | 28 -0.02 | | 30 -0.03 | 31 -0.03 | | 32 -0.04 | 33 -0.03 | 34 -0.02 | | 36 -0.03 | 37 -0.03 | 38 -0.03 | 39 -0.03 | | 40 -0.05 | | | 43 -0.02 | 44 -0.03 | 45 -0.04 | | 47 -0.02 | | 48 -0.07 | | 50 -0.04 | 51 -0.02 | | 53 -0.06 | | 55 -0.0 | | 56 -0.09 | 57 -0.12 | 58 -0.09 | | 60 -0.14 | 61 -0.15 | 62 -0.23 | | State-value function RMSE: 0.051 Q(λ) replacing action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╤════════╤════════╕ │ s │ < │ v │ > │ ^ │ * < │ * v │ * > │ * ^ │ er < │ er v │ er > │ er ^ │ ╞═════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╪════════╪════════╡ │ 0 │ 0.387 │ 0.39 │ 0.393 │ 0.394 │ 0.41 │ 0.414 │ 0.414 │ 0.415 │ 0.023 │ 0.024 │ 0.021 │ 0.021 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 1 │ 0.393 │ 0.399 │ 0.408 │ 0.401 │ 0.417 │ 0.423 │ 0.427 │ 0.425 │ 0.023 │ 0.024 │ 0.019 │ 0.024 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 2 │ 0.409 │ 0.415 │ 0.428 │ 0.418 │ 0.433 │ 0.44 │ 0.446 │ 0.443 │ 0.023 │ 0.025 │ 0.018 │ 0.025 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 3 │ 0.428 │ 0.434 │ 0.449 │ 0.438 │ 0.453 │ 0.461 │ 0.468 │ 0.464 │ 0.025 │ 0.027 │ 0.019 │ 0.027 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 4 │ 0.45 │ 0.457 │ 0.471 │ 0.458 │ 0.477 │ 0.484 │ 0.492 │ 0.488 │ 0.027 │ 0.027 │ 0.021 │ 0.029 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 5 │ 0.475 │ 0.478 │ 0.495 │ 0.48 │ 0.502 │ 0.509 │ 0.517 │ 0.51 │ 0.027 │ 0.03 │ 0.021 │ 0.029 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 6 │ 0.501 │ 0.502 │ 0.512 │ 0.499 │ 0.527 │ 0.529 │ 0.535 │ 0.526 │ 0.026 │ 0.027 │ 0.024 │ 0.026 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 7 │ 0.51 │ 0.511 │ 0.516 │ 0.508 │ 0.539 │ 0.539 │ 0.541 │ 0.534 │ 0.029 │ 0.029 │ 0.025 │ 0.026 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 8 │ 0.376 │ 0.378 │ 0.378 │ 0.393 │ 0.404 │ 0.406 │ 0.407 │ 0.412 │ 0.028 │ 0.028 │ 0.029 │ 0.019 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 9 │ 0.383 │ 0.386 │ 0.39 │ 0.403 │ 0.407 │ 0.41 │ 0.415 │ 0.421 │ 0.024 │ 0.024 │ 0.025 │ 0.019 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 10 │ 0.388 │ 0.39 │ 0.397 │ 0.419 │ 0.41 │ 0.414 │ 0.422 │ 0.437 │ 0.022 │ 0.024 │ 0.025 │ 0.018 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 11 │ 0.277 │ 0.279 │ 0.294 │ 0.44 │ 0.299 │ 0.304 │ 0.314 │ 0.458 │ 0.022 │ 0.024 │ 0.02 │ 0.018 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 12 │ 0.429 │ 0.432 │ 0.441 │ 0.463 │ 0.453 │ 0.46 │ 0.471 │ 0.483 │ 0.024 │ 0.028 │ 0.03 │ 0.02 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 13 │ 0.465 │ 0.472 │ 0.485 │ 0.483 │ 0.493 │ 0.503 │ 0.514 │ 0.51 │ 0.028 │ 0.031 │ 0.029 │ 0.027 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 14 │ 0.507 │ 0.512 │ 0.52 │ 0.506 │ 0.531 │ 0.539 │ 0.546 │ 0.53 │ 0.025 │ 0.027 │ 0.025 │ 0.024 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 15 │ 0.522 │ 0.531 │ 0.525 │ 0.518 │ 0.552 │ 0.557 │ 0.556 │ 0.543 │ 0.029 │ 0.027 │ 0.031 │ 0.024 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 16 │ 0.359 │ 0.356 │ 0.359 │ 0.373 │ 0.389 │ 0.383 │ 0.388 │ 0.397 │ 0.029 │ 0.026 │ 0.029 │ 0.023 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 17 │ 0.357 │ 0.349 │ 0.354 │ 0.372 │ 0.386 │ 0.371 │ 0.379 │ 0.394 │ 0.029 │ 0.022 │ 0.025 │ 0.022 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 18 │ 0.353 │ 0.212 │ 0.236 │ 0.252 │ 0.375 │ 0.231 │ 0.246 │ 0.274 │ 0.022 │ 0.019 │ 0.009 │ 0.023 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 19 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 20 │ 0.25 │ 0.249 │ 0.399 │ 0.304 │ 0.259 │ 0.262 │ 0.422 │ 0.322 │ 0.008 │ 0.013 │ 0.022 │ 0.018 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 21 │ 0.301 │ 0.323 │ 0.33 │ 0.469 │ 0.309 │ 0.324 │ 0.355 │ 0.494 │ 0.007 │ 0.001 │ 0.024 │ 0.024 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 22 │ 0.506 │ 0.521 │ 0.534 │ 0.513 │ 0.531 │ 0.544 │ 0.561 │ 0.536 │ 0.025 │ 0.023 │ 0.027 │ 0.023 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 23 │ 0.552 │ 0.558 │ 0.553 │ 0.541 │ 0.576 │ 0.586 │ 0.585 │ 0.562 │ 0.025 │ 0.028 │ 0.031 │ 0.022 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 24 │ 0.329 │ 0.322 │ 0.327 │ 0.341 │ 0.363 │ 0.348 │ 0.357 │ 0.369 │ 0.034 │ 0.026 │ 0.03 │ 0.029 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 25 │ 0.317 │ 0.296 │ 0.302 │ 0.325 │ 0.348 │ 0.319 │ 0.327 │ 0.353 │ 0.031 │ 0.023 │ 0.025 │ 0.028 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 26 │ 0.277 │ 0.232 │ 0.242 │ 0.283 │ 0.306 │ 0.248 │ 0.255 │ 0.307 │ 0.028 │ 0.016 │ 0.013 │ 0.024 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 27 │ 0.089 │ 0.175 │ 0.084 │ 0.171 │ 0.101 │ 0.2 │ 0.099 │ 0.2 │ 0.012 │ 0.025 │ 0.015 │ 0.029 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 28 │ 0.278 │ 0.151 │ 0.209 │ 0.193 │ 0.301 │ 0.162 │ 0.235 │ 0.205 │ 0.023 │ 0.01 │ 0.025 │ 0.012 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 29 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 30 │ 0.339 │ 0.352 │ 0.541 │ 0.371 │ 0.362 │ 0.384 │ 0.569 │ 0.393 │ 0.023 │ 0.032 │ 0.028 │ 0.022 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 31 │ 0.584 │ 0.589 │ 0.595 │ 0.568 │ 0.609 │ 0.623 │ 0.628 │ 0.588 │ 0.025 │ 0.034 │ 0.033 │ 0.02 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 32 │ 0.294 │ 0.274 │ 0.282 │ 0.292 │ 0.333 │ 0.307 │ 0.319 │ 0.328 │ 0.039 │ 0.033 │ 0.037 │ 0.036 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 33 │ 0.209 │ 0.154 │ 0.175 │ 0.262 │ 0.226 │ 0.175 │ 0.182 │ 0.291 │ 0.018 │ 0.021 │ 0.007 │ 0.029 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 34 │ 0.168 │ 0.088 │ 0.092 │ 0.162 │ 0.197 │ 0.096 │ 0.101 │ 0.197 │ 0.029 │ 0.008 │ 0.009 │ 0.035 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 35 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 36 │ 0.156 │ 0.166 │ 0.261 │ 0.199 │ 0.17 │ 0.19 │ 0.289 │ 0.219 │ 0.014 │ 0.024 │ 0.029 │ 0.02 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 37 │ 0.165 │ 0.329 │ 0.244 │ 0.248 │ 0.185 │ 0.362 │ 0.266 │ 0.272 │ 0.021 │ 0.033 │ 0.023 │ 0.024 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 38 │ 0.291 │ 0.324 │ 0.401 │ 0.508 │ 0.307 │ 0.347 │ 0.415 │ 0.535 │ 0.017 │ 0.023 │ 0.014 │ 0.027 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 39 │ 0.626 │ 0.637 │ 0.656 │ 0.589 │ 0.639 │ 0.659 │ 0.69 │ 0.611 │ 0.012 │ 0.022 │ 0.034 │ 0.023 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 40 │ 0.255 │ 0.131 │ 0.163 │ 0.197 │ 0.306 │ 0.196 │ 0.205 │ 0.211 │ 0.052 │ 0.065 │ 0.042 │ 0.014 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 41 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 42 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 43 │ 0.004 │ 0.065 │ 0.046 │ 0.042 │ 0.016 │ 0.086 │ 0.086 │ 0.071 │ 0.012 │ 0.022 │ 0.04 │ 0.029 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 44 │ 0.111 │ 0.096 │ 0.154 │ 0.182 │ 0.124 │ 0.118 │ 0.185 │ 0.214 │ 0.013 │ 0.023 │ 0.031 │ 0.032 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 45 │ 0.23 │ 0.118 │ 0.16 │ 0.163 │ 0.273 │ 0.153 │ 0.202 │ 0.19 │ 0.043 │ 0.035 │ 0.042 │ 0.027 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 46 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 47 │ 0.492 │ 0.552 │ 0.751 │ 0.476 │ 0.517 │ 0.544 │ 0.772 │ 0.482 │ 0.025 │ -0.008 │ 0.021 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 48 │ 0.216 │ 0.1 │ 0.126 │ 0.133 │ 0.289 │ 0.188 │ 0.194 │ 0.196 │ 0.073 │ 0.088 │ 0.068 │ 0.063 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 49 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 50 │ 0.003 │ 0.009 │ 0.008 │ 0.001 │ 0.042 │ 0.058 │ 0.058 │ 0.016 │ 0.039 │ 0.049 │ 0.05 │ 0.015 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 51 │ 0.026 │ 0.001 │ 0.011 │ 0.008 │ 0.048 │ 0.019 │ 0.028 │ 0.048 │ 0.022 │ 0.018 │ 0.018 │ 0.039 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 52 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 53 │ 0.149 │ 0.068 │ 0.163 │ 0.057 │ 0.251 │ 0.161 │ 0.251 │ 0.09 │ 0.101 │ 0.092 │ 0.087 │ 0.033 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 54 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 55 │ 0.605 │ 0.594 │ 0.873 │ 0.556 │ 0.588 │ 0.623 │ 0.878 │ 0.544 │ -0.017 │ 0.029 │ 0.005 │ -0.011 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 56 │ 0.194 │ 0.109 │ 0.123 │ 0.125 │ 0.28 │ 0.251 │ 0.254 │ 0.254 │ 0.086 │ 0.143 │ 0.131 │ 0.129 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 57 │ 0.054 │ 0.064 │ 0.017 │ 0.035 │ 0.159 │ 0.201 │ 0.108 │ 0.135 │ 0.105 │ 0.137 │ 0.092 │ 0.1 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 58 │ 0.033 │ 0.014 │ 0.004 │ 0.009 │ 0.127 │ 0.108 │ 0.061 │ 0.085 │ 0.095 │ 0.095 │ 0.057 │ 0.077 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 59 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 60 │ 0.009 │ 0.044 │ 0.082 │ 0.022 │ 0.079 │ 0.24 │ 0.24 │ 0.161 │ 0.07 │ 0.195 │ 0.158 │ 0.139 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 61 │ 0.097 │ 0.195 │ 0.286 │ 0.144 │ 0.322 │ 0.483 │ 0.486 │ 0.405 │ 0.225 │ 0.288 │ 0.201 │ 0.261 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 62 │ 0.078 │ 0.28 │ 0.296 │ 0.229 │ 0.404 │ 0.737 │ 0.577 │ 0.494 │ 0.326 │ 0.457 │ 0.281 │ 0.264 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 63 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╧════════╧════════╛ Action-value function RMSE: 0.0656 정책: | 00 ^ | 01 > | 02 > | 03 > | 04 > | 05 > | 06 > | 07 > | | 08 ^ | 09 ^ | 10 ^ | 11 ^ | 12 ^ | 13 > | 14 > | 15 v | | 16 ^ | 17 ^ | 18 < | | 20 > | 21 ^ | 22 > | 23 v | | 24 ^ | 25 ^ | 26 ^ | 27 v | 28 < | | 30 > | 31 > | | 32 ^ | 33 ^ | 34 ^ | | 36 > | 37 v | 38 ^ | 39 > | | 40 < | | | 43 > | 44 ^ | 45 < | | 47 > | | 48 < | | 50 < | 51 < | | 53 > | | 55 > | | 56 < | 57 < | 58 v | | 60 > | 61 v | 62 v | | Reaches goal 82.00%. Obtains an average return of 0.4156. Regret of 0.0032
Q_aqlls, V_aqlls, Q_track_aqlls = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_aqll, V_aqll, pi_aqll, Q_track_aqll, pi_track_aqll = q_lambda(env, gamma=gamma,
replacing_traces=False,
n_episodes=n_episodes)
Q_aqlls.append(Q_aqll) ; V_aqlls.append(V_aqll) ; Q_track_aqlls.append(Q_track_aqll)
Q_aqll, V_aqll, Q_track_aqll = np.mean(Q_aqlls, axis=0), np.mean(V_aqlls, axis=0), np.mean(Q_track_aqlls, axis=0)
del Q_aqlls ; del V_aqlls ; del Q_track_aqlls
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
print_state_value_function(V_aqll, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Q(λ) accumulating:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_aqll - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_aqll, optimal_V)))
print()
print_action_value_function(Q_aqll,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Q(λ) accumulating action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_aqll, optimal_Q)))
print()
print_policy(pi_aqll, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_aqll, mean_return_aqll, mean_regret_aqll = get_policy_metrics(
env, gamma=gamma, pi=pi_aqll, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_aqll, mean_return_aqll, mean_regret_aqll))
State-value function found by Q(λ) accumulating: | 00 0.3924 | 01 0.4057 | 02 0.4252 | 03 0.4471 | 04 0.4706 | 05 0.4929 | 06 0.5099 | 07 0.5143 | | 08 0.3896 | 09 0.3995 | 10 0.4173 | 11 0.4381 | 12 0.4627 | 13 0.4907 | 14 0.5207 | 15 0.5321 | | 16 0.3708 | 17 0.3686 | 18 0.3491 | | 20 0.3985 | 21 0.4725 | 22 0.5346 | 23 0.5601 | | 24 0.3436 | 25 0.3289 | 26 0.2823 | 27 0.1806 | 28 0.2756 | | 30 0.5408 | 31 0.6102 | | 32 0.2998 | 33 0.2637 | 34 0.1776 | | 36 0.2505 | 37 0.313 | 38 0.5038 | 39 0.6802 | | 40 0.2487 | | | 43 0.061 | 44 0.1707 | 45 0.2132 | | 47 0.7706 | | 48 0.2174 | | 50 0.0083 | 51 0.0241 | | 53 0.1686 | | 55 0.8835 | | 56 0.2001 | 57 0.0772 | 58 0.022 | | 60 0.0605 | 61 0.3062 | 62 0.4714 | | Optimal state-value function: | 00 0.4146 | 01 0.4272 | 02 0.4461 | 03 0.4683 | 04 0.4924 | 05 0.5166 | 06 0.5353 | 07 0.541 | | 08 0.4117 | 09 0.4212 | 10 0.4375 | 11 0.4584 | 12 0.4832 | 13 0.5135 | 14 0.5458 | 15 0.5574 | | 16 0.3968 | 17 0.3938 | 18 0.3755 | | 20 0.4217 | 21 0.4938 | 22 0.5612 | 23 0.5859 | | 24 0.3693 | 25 0.353 | 26 0.3065 | 27 0.2004 | 28 0.3008 | | 30 0.569 | 31 0.6283 | | 32 0.3327 | 33 0.2914 | 34 0.1973 | | 36 0.2893 | 37 0.362 | 38 0.5348 | 39 0.6897 | | 40 0.3061 | | | 43 0.0863 | 44 0.2139 | 45 0.2727 | | 47 0.772 | | 48 0.2889 | | 50 0.0577 | 51 0.0475 | | 53 0.2505 | | 55 0.8778 | | 56 0.2804 | 57 0.2008 | 58 0.1273 | | 60 0.2396 | 61 0.4864 | 62 0.7371 | | State-value function errors: | 00 -0.02 | 01 -0.02 | 02 -0.02 | 03 -0.02 | 04 -0.02 | 05 -0.02 | 06 -0.03 | 07 -0.03 | | 08 -0.02 | 09 -0.02 | 10 -0.02 | 11 -0.02 | 12 -0.02 | 13 -0.02 | 14 -0.03 | 15 -0.03 | | 16 -0.03 | 17 -0.03 | 18 -0.03 | | 20 -0.02 | 21 -0.02 | 22 -0.03 | 23 -0.03 | | 24 -0.03 | 25 -0.02 | 26 -0.02 | 27 -0.02 | 28 -0.03 | | 30 -0.03 | 31 -0.02 | | 32 -0.03 | 33 -0.03 | 34 -0.02 | | 36 -0.04 | 37 -0.05 | 38 -0.03 | 39 -0.01 | | 40 -0.06 | | | 43 -0.03 | 44 -0.04 | 45 -0.06 | | 47 -0.0 | | 48 -0.07 | | 50 -0.05 | 51 -0.02 | | 53 -0.08 | | 55 0.01 | | 56 -0.08 | 57 -0.12 | 58 -0.11 | | 60 -0.18 | 61 -0.18 | 62 -0.27 | | State-value function RMSE: 0.0581 Q(λ) accumulating action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╤════════╤════════╕ │ s │ < │ v │ > │ ^ │ * < │ * v │ * > │ * ^ │ er < │ er v │ er > │ er ^ │ ╞═════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╪════════╪════════╡ │ 0 │ 0.383 │ 0.389 │ 0.386 │ 0.39 │ 0.41 │ 0.414 │ 0.414 │ 0.415 │ 0.026 │ 0.024 │ 0.027 │ 0.025 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 1 │ 0.391 │ 0.396 │ 0.406 │ 0.398 │ 0.417 │ 0.423 │ 0.427 │ 0.425 │ 0.026 │ 0.027 │ 0.021 │ 0.027 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 2 │ 0.407 │ 0.412 │ 0.425 │ 0.413 │ 0.433 │ 0.44 │ 0.446 │ 0.443 │ 0.026 │ 0.028 │ 0.021 │ 0.03 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 3 │ 0.427 │ 0.432 │ 0.447 │ 0.433 │ 0.453 │ 0.461 │ 0.468 │ 0.464 │ 0.026 │ 0.029 │ 0.021 │ 0.031 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 4 │ 0.448 │ 0.454 │ 0.471 │ 0.456 │ 0.477 │ 0.484 │ 0.492 │ 0.488 │ 0.028 │ 0.031 │ 0.022 │ 0.032 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 5 │ 0.473 │ 0.475 │ 0.493 │ 0.478 │ 0.502 │ 0.509 │ 0.517 │ 0.51 │ 0.03 │ 0.034 │ 0.024 │ 0.031 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 6 │ 0.499 │ 0.499 │ 0.51 │ 0.496 │ 0.527 │ 0.529 │ 0.535 │ 0.526 │ 0.029 │ 0.03 │ 0.025 │ 0.029 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 7 │ 0.506 │ 0.512 │ 0.503 │ 0.501 │ 0.539 │ 0.539 │ 0.541 │ 0.534 │ 0.033 │ 0.027 │ 0.038 │ 0.032 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 8 │ 0.376 │ 0.377 │ 0.378 │ 0.39 │ 0.404 │ 0.406 │ 0.407 │ 0.412 │ 0.028 │ 0.029 │ 0.029 │ 0.022 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 9 │ 0.382 │ 0.384 │ 0.388 │ 0.399 │ 0.407 │ 0.41 │ 0.415 │ 0.421 │ 0.025 │ 0.026 │ 0.028 │ 0.022 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 10 │ 0.384 │ 0.389 │ 0.393 │ 0.417 │ 0.41 │ 0.414 │ 0.422 │ 0.437 │ 0.026 │ 0.025 │ 0.029 │ 0.02 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 11 │ 0.281 │ 0.29 │ 0.294 │ 0.438 │ 0.299 │ 0.304 │ 0.314 │ 0.458 │ 0.017 │ 0.014 │ 0.02 │ 0.02 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 12 │ 0.425 │ 0.432 │ 0.439 │ 0.463 │ 0.453 │ 0.46 │ 0.471 │ 0.483 │ 0.028 │ 0.027 │ 0.033 │ 0.021 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 13 │ 0.466 │ 0.472 │ 0.491 │ 0.474 │ 0.493 │ 0.503 │ 0.514 │ 0.51 │ 0.027 │ 0.03 │ 0.023 │ 0.036 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 14 │ 0.504 │ 0.511 │ 0.521 │ 0.501 │ 0.531 │ 0.539 │ 0.546 │ 0.53 │ 0.027 │ 0.028 │ 0.025 │ 0.029 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 15 │ 0.52 │ 0.531 │ 0.524 │ 0.513 │ 0.552 │ 0.557 │ 0.556 │ 0.543 │ 0.031 │ 0.026 │ 0.032 │ 0.029 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 16 │ 0.36 │ 0.356 │ 0.359 │ 0.371 │ 0.389 │ 0.383 │ 0.388 │ 0.397 │ 0.029 │ 0.027 │ 0.029 │ 0.026 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 17 │ 0.356 │ 0.347 │ 0.355 │ 0.369 │ 0.386 │ 0.371 │ 0.379 │ 0.394 │ 0.031 │ 0.024 │ 0.024 │ 0.025 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 18 │ 0.349 │ 0.206 │ 0.229 │ 0.266 │ 0.375 │ 0.231 │ 0.246 │ 0.274 │ 0.026 │ 0.025 │ 0.017 │ 0.008 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 19 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 20 │ 0.242 │ 0.247 │ 0.398 │ 0.304 │ 0.259 │ 0.262 │ 0.422 │ 0.322 │ 0.016 │ 0.015 │ 0.023 │ 0.018 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 21 │ 0.287 │ 0.31 │ 0.336 │ 0.473 │ 0.309 │ 0.324 │ 0.355 │ 0.494 │ 0.021 │ 0.014 │ 0.019 │ 0.021 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 22 │ 0.505 │ 0.517 │ 0.535 │ 0.513 │ 0.531 │ 0.544 │ 0.561 │ 0.536 │ 0.025 │ 0.027 │ 0.027 │ 0.024 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 23 │ 0.549 │ 0.56 │ 0.552 │ 0.539 │ 0.576 │ 0.586 │ 0.585 │ 0.562 │ 0.028 │ 0.026 │ 0.032 │ 0.023 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 24 │ 0.329 │ 0.319 │ 0.328 │ 0.344 │ 0.363 │ 0.348 │ 0.357 │ 0.369 │ 0.033 │ 0.029 │ 0.029 │ 0.026 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 25 │ 0.317 │ 0.289 │ 0.3 │ 0.329 │ 0.348 │ 0.319 │ 0.327 │ 0.353 │ 0.031 │ 0.03 │ 0.027 │ 0.024 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 26 │ 0.272 │ 0.23 │ 0.231 │ 0.28 │ 0.306 │ 0.248 │ 0.255 │ 0.307 │ 0.033 │ 0.018 │ 0.024 │ 0.027 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 27 │ 0.091 │ 0.172 │ 0.084 │ 0.17 │ 0.101 │ 0.2 │ 0.099 │ 0.2 │ 0.011 │ 0.029 │ 0.015 │ 0.031 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 28 │ 0.276 │ 0.143 │ 0.22 │ 0.185 │ 0.301 │ 0.162 │ 0.235 │ 0.205 │ 0.025 │ 0.018 │ 0.014 │ 0.02 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 29 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 30 │ 0.349 │ 0.368 │ 0.541 │ 0.373 │ 0.362 │ 0.384 │ 0.569 │ 0.393 │ 0.012 │ 0.015 │ 0.028 │ 0.019 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 31 │ 0.583 │ 0.592 │ 0.609 │ 0.565 │ 0.609 │ 0.623 │ 0.628 │ 0.588 │ 0.025 │ 0.031 │ 0.02 │ 0.023 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 32 │ 0.288 │ 0.267 │ 0.276 │ 0.299 │ 0.333 │ 0.307 │ 0.319 │ 0.328 │ 0.044 │ 0.039 │ 0.043 │ 0.029 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 33 │ 0.203 │ 0.164 │ 0.17 │ 0.264 │ 0.226 │ 0.175 │ 0.182 │ 0.291 │ 0.023 │ 0.011 │ 0.011 │ 0.028 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 34 │ 0.161 │ 0.101 │ 0.098 │ 0.167 │ 0.197 │ 0.096 │ 0.101 │ 0.197 │ 0.036 │ -0.005 │ 0.003 │ 0.03 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 35 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 36 │ 0.154 │ 0.17 │ 0.251 │ 0.193 │ 0.17 │ 0.19 │ 0.289 │ 0.219 │ 0.016 │ 0.02 │ 0.039 │ 0.026 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 37 │ 0.154 │ 0.313 │ 0.231 │ 0.254 │ 0.185 │ 0.362 │ 0.266 │ 0.272 │ 0.031 │ 0.049 │ 0.036 │ 0.018 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 38 │ 0.294 │ 0.323 │ 0.414 │ 0.504 │ 0.307 │ 0.347 │ 0.415 │ 0.535 │ 0.013 │ 0.024 │ 0.001 │ 0.031 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 39 │ 0.62 │ 0.634 │ 0.68 │ 0.59 │ 0.639 │ 0.659 │ 0.69 │ 0.611 │ 0.018 │ 0.024 │ 0.009 │ 0.021 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 40 │ 0.249 │ 0.154 │ 0.168 │ 0.175 │ 0.306 │ 0.196 │ 0.205 │ 0.211 │ 0.057 │ 0.042 │ 0.037 │ 0.036 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 41 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 42 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 43 │ 0.005 │ 0.048 │ 0.056 │ 0.046 │ 0.016 │ 0.086 │ 0.086 │ 0.071 │ 0.011 │ 0.038 │ 0.03 │ 0.024 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 44 │ 0.107 │ 0.089 │ 0.156 │ 0.171 │ 0.124 │ 0.118 │ 0.185 │ 0.214 │ 0.017 │ 0.029 │ 0.03 │ 0.043 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 45 │ 0.212 │ 0.113 │ 0.148 │ 0.162 │ 0.273 │ 0.153 │ 0.202 │ 0.19 │ 0.061 │ 0.04 │ 0.054 │ 0.028 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 46 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 47 │ 0.498 │ 0.534 │ 0.771 │ 0.498 │ 0.517 │ 0.544 │ 0.772 │ 0.482 │ 0.02 │ 0.011 │ 0.001 │ -0.016 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 48 │ 0.217 │ 0.114 │ 0.129 │ 0.147 │ 0.289 │ 0.188 │ 0.194 │ 0.196 │ 0.072 │ 0.074 │ 0.065 │ 0.049 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 49 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 50 │ 0.001 │ 0 │ 0.007 │ 0.001 │ 0.042 │ 0.058 │ 0.058 │ 0.016 │ 0.041 │ 0.057 │ 0.051 │ 0.015 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 51 │ 0.01 │ 0.001 │ 0.016 │ 0.008 │ 0.048 │ 0.019 │ 0.028 │ 0.048 │ 0.038 │ 0.018 │ 0.013 │ 0.04 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 52 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 53 │ 0.158 │ 0.069 │ 0.138 │ 0.066 │ 0.251 │ 0.161 │ 0.251 │ 0.09 │ 0.093 │ 0.091 │ 0.112 │ 0.024 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 54 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 55 │ 0.6 │ 0.616 │ 0.884 │ 0.528 │ 0.588 │ 0.623 │ 0.878 │ 0.544 │ -0.012 │ 0.007 │ -0.006 │ 0.016 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 56 │ 0.2 │ 0.116 │ 0.131 │ 0.131 │ 0.28 │ 0.251 │ 0.254 │ 0.254 │ 0.08 │ 0.135 │ 0.123 │ 0.123 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 57 │ 0.029 │ 0.058 │ 0.008 │ 0.03 │ 0.159 │ 0.201 │ 0.108 │ 0.135 │ 0.129 │ 0.143 │ 0.1 │ 0.105 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 58 │ 0.009 │ 0.014 │ 0.001 │ 0.008 │ 0.127 │ 0.108 │ 0.061 │ 0.085 │ 0.118 │ 0.095 │ 0.06 │ 0.077 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 59 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 60 │ 0.002 │ 0.028 │ 0.028 │ 0.014 │ 0.079 │ 0.24 │ 0.24 │ 0.161 │ 0.077 │ 0.211 │ 0.212 │ 0.146 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 61 │ 0.097 │ 0.146 │ 0.301 │ 0.152 │ 0.322 │ 0.483 │ 0.486 │ 0.405 │ 0.225 │ 0.337 │ 0.186 │ 0.253 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 62 │ 0.071 │ 0.397 │ 0.191 │ 0.164 │ 0.404 │ 0.737 │ 0.577 │ 0.494 │ 0.332 │ 0.34 │ 0.386 │ 0.33 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 63 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╧════════╧════════╛ Action-value function RMSE: 0.0683 정책: | 00 ^ | 01 > | 02 > | 03 > | 04 > | 05 > | 06 > | 07 v | | 08 ^ | 09 ^ | 10 ^ | 11 ^ | 12 ^ | 13 > | 14 > | 15 v | | 16 ^ | 17 ^ | 18 < | | 20 > | 21 ^ | 22 > | 23 v | | 24 ^ | 25 ^ | 26 ^ | 27 v | 28 < | | 30 > | 31 v | | 32 ^ | 33 ^ | 34 < | | 36 > | 37 v | 38 ^ | 39 > | | 40 < | | | 43 ^ | 44 ^ | 45 < | | 47 > | | 48 < | | 50 ^ | 51 ^ | | 53 > | | 55 > | | 56 < | 57 ^ | 58 v | | 60 > | 61 > | 62 v | | Reaches goal 81.00%. Obtains an average return of 0.4145. Regret of 0.0444
Q_dqs, V_dqs, Q_track_dqs = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_dq, V_dq, pi_dq, Q_track_dq, pi_track_dq, T_track_dq, R_track_dq, planning_dq = dyna_q(
env, gamma=gamma, n_episodes=n_episodes)
Q_dqs.append(Q_dq) ; V_dqs.append(V_dq) ; Q_track_dqs.append(Q_track_dq)
Q_dq, V_dq, Q_track_dq = np.mean(Q_dqs, axis=0), np.mean(V_dqs, axis=0), np.mean(Q_track_dqs, axis=0)
del Q_dqs ; del V_dqs ; del Q_track_dqs
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
print_state_value_function(V_dq, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Dyna-Q:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_dq - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_dq, optimal_V)))
print()
print_action_value_function(Q_dq,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Dyna-Q action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_dq, optimal_Q)))
print()
print_policy(pi_dq, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_dq, mean_return_dq, mean_regret_dq = get_policy_metrics(
env, gamma=gamma, pi=pi_dq, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_dq, mean_return_dq, mean_regret_dq))
State-value function found by Dyna-Q: | 00 0.41 | 01 0.4221 | 02 0.4414 | 03 0.462 | 04 0.4856 | 05 0.5085 | 06 0.5245 | 07 0.5297 | | 08 0.4068 | 09 0.4164 | 10 0.4328 | 11 0.4536 | 12 0.4777 | 13 0.5048 | 14 0.5342 | 15 0.5457 | | 16 0.3926 | 17 0.3905 | 18 0.3749 | | 20 0.419 | 21 0.4876 | 22 0.5497 | 23 0.5729 | | 24 0.3644 | 25 0.3516 | 26 0.311 | 27 0.2079 | 28 0.3035 | | 30 0.5569 | 31 0.6136 | | 32 0.3292 | 33 0.2942 | 34 0.2039 | | 36 0.2907 | 37 0.3604 | 38 0.5213 | 39 0.6729 | | 40 0.3028 | | | 43 0.0912 | 44 0.2171 | 45 0.2795 | | 47 0.7556 | | 48 0.2855 | | 50 0.063 | 51 0.0517 | | 53 0.2658 | | 55 0.8627 | | 56 0.2764 | 57 0.1989 | 58 0.1282 | | 60 0.245 | 61 0.5248 | 62 0.7573 | | Optimal state-value function: | 00 0.4146 | 01 0.4272 | 02 0.4461 | 03 0.4683 | 04 0.4924 | 05 0.5166 | 06 0.5353 | 07 0.541 | | 08 0.4117 | 09 0.4212 | 10 0.4375 | 11 0.4584 | 12 0.4832 | 13 0.5135 | 14 0.5458 | 15 0.5574 | | 16 0.3968 | 17 0.3938 | 18 0.3755 | | 20 0.4217 | 21 0.4938 | 22 0.5612 | 23 0.5859 | | 24 0.3693 | 25 0.353 | 26 0.3065 | 27 0.2004 | 28 0.3008 | | 30 0.569 | 31 0.6283 | | 32 0.3327 | 33 0.2914 | 34 0.1973 | | 36 0.2893 | 37 0.362 | 38 0.5348 | 39 0.6897 | | 40 0.3061 | | | 43 0.0863 | 44 0.2139 | 45 0.2727 | | 47 0.772 | | 48 0.2889 | | 50 0.0577 | 51 0.0475 | | 53 0.2505 | | 55 0.8778 | | 56 0.2804 | 57 0.2008 | 58 0.1273 | | 60 0.2396 | 61 0.4864 | 62 0.7371 | | State-value function errors: | 00 -0.0 | 01 -0.01 | 02 -0.0 | 03 -0.01 | 04 -0.01 | 05 -0.01 | 06 -0.01 | 07 -0.01 | | 08 -0.0 | 09 -0.0 | 10 -0.0 | 11 -0.0 | 12 -0.01 | 13 -0.01 | 14 -0.01 | 15 -0.01 | | 16 -0.0 | 17 -0.0 | 18 -0.0 | | 20 -0.0 | 21 -0.01 | 22 -0.01 | 23 -0.01 | | 24 -0.0 | 25 -0.0 | 26 0.0 | 27 0.01 | 28 0.0 | | 30 -0.01 | 31 -0.01 | | 32 -0.0 | 33 0.0 | 34 0.01 | | 36 0.0 | 37 -0.0 | 38 -0.01 | 39 -0.02 | | 40 -0.0 | | | 43 0.0 | 44 0.0 | 45 0.01 | | 47 -0.02 | | 48 -0.0 | | 50 0.01 | 51 0.0 | | 53 0.02 | | 55 -0.02 | | 56 -0.0 | 57 -0.0 | 58 0.0 | | 60 0.01 | 61 0.04 | 62 0.02 | | State-value function RMSE: 0.009 Dyna-Q action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╤════════╤════════╕ │ s │ < │ v │ > │ ^ │ * < │ * v │ * > │ * ^ │ er < │ er v │ er > │ er ^ │ ╞═════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╪════════╪════════╡ │ 0 │ 0.405 │ 0.409 │ 0.409 │ 0.41 │ 0.41 │ 0.414 │ 0.414 │ 0.415 │ 0.005 │ 0.005 │ 0.005 │ 0.005 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 1 │ 0.412 │ 0.419 │ 0.422 │ 0.421 │ 0.417 │ 0.423 │ 0.427 │ 0.425 │ 0.005 │ 0.004 │ 0.005 │ 0.004 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 2 │ 0.428 │ 0.434 │ 0.441 │ 0.438 │ 0.433 │ 0.44 │ 0.446 │ 0.443 │ 0.005 │ 0.006 │ 0.005 │ 0.005 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 3 │ 0.448 │ 0.455 │ 0.462 │ 0.459 │ 0.453 │ 0.461 │ 0.468 │ 0.464 │ 0.005 │ 0.006 │ 0.007 │ 0.005 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 4 │ 0.471 │ 0.478 │ 0.486 │ 0.481 │ 0.477 │ 0.484 │ 0.492 │ 0.488 │ 0.006 │ 0.006 │ 0.007 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 5 │ 0.496 │ 0.502 │ 0.508 │ 0.501 │ 0.502 │ 0.509 │ 0.517 │ 0.51 │ 0.007 │ 0.007 │ 0.008 │ 0.009 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 6 │ 0.518 │ 0.519 │ 0.525 │ 0.517 │ 0.527 │ 0.529 │ 0.535 │ 0.526 │ 0.01 │ 0.01 │ 0.011 │ 0.009 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 7 │ 0.529 │ 0.528 │ 0.529 │ 0.523 │ 0.539 │ 0.539 │ 0.541 │ 0.534 │ 0.01 │ 0.011 │ 0.012 │ 0.01 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 8 │ 0.399 │ 0.4 │ 0.403 │ 0.407 │ 0.404 │ 0.406 │ 0.407 │ 0.412 │ 0.005 │ 0.005 │ 0.004 │ 0.005 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 9 │ 0.403 │ 0.406 │ 0.41 │ 0.416 │ 0.407 │ 0.41 │ 0.415 │ 0.421 │ 0.003 │ 0.004 │ 0.005 │ 0.005 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 10 │ 0.407 │ 0.411 │ 0.418 │ 0.433 │ 0.41 │ 0.414 │ 0.422 │ 0.437 │ 0.003 │ 0.003 │ 0.004 │ 0.005 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 11 │ 0.287 │ 0.302 │ 0.31 │ 0.454 │ 0.299 │ 0.304 │ 0.314 │ 0.458 │ 0.012 │ 0.002 │ 0.004 │ 0.005 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 12 │ 0.448 │ 0.458 │ 0.465 │ 0.478 │ 0.453 │ 0.46 │ 0.471 │ 0.483 │ 0.005 │ 0.002 │ 0.006 │ 0.005 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 13 │ 0.485 │ 0.495 │ 0.505 │ 0.503 │ 0.493 │ 0.503 │ 0.514 │ 0.51 │ 0.007 │ 0.008 │ 0.009 │ 0.007 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 14 │ 0.522 │ 0.531 │ 0.534 │ 0.521 │ 0.531 │ 0.539 │ 0.546 │ 0.53 │ 0.009 │ 0.008 │ 0.012 │ 0.009 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 15 │ 0.54 │ 0.545 │ 0.545 │ 0.532 │ 0.552 │ 0.557 │ 0.556 │ 0.543 │ 0.012 │ 0.012 │ 0.011 │ 0.01 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 16 │ 0.384 │ 0.378 │ 0.384 │ 0.393 │ 0.389 │ 0.383 │ 0.388 │ 0.397 │ 0.005 │ 0.005 │ 0.004 │ 0.004 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 17 │ 0.382 │ 0.369 │ 0.378 │ 0.39 │ 0.386 │ 0.371 │ 0.379 │ 0.394 │ 0.005 │ 0.002 │ 0.002 │ 0.003 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 18 │ 0.375 │ 0.228 │ 0.248 │ 0.278 │ 0.375 │ 0.231 │ 0.246 │ 0.274 │ 0.001 │ 0.003 │ -0.003 │ -0.003 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 19 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 20 │ 0.268 │ 0.262 │ 0.419 │ 0.312 │ 0.259 │ 0.262 │ 0.422 │ 0.322 │ -0.01 │ 0 │ 0.003 │ 0.01 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 21 │ 0.307 │ 0.315 │ 0.358 │ 0.488 │ 0.309 │ 0.324 │ 0.355 │ 0.494 │ 0.002 │ 0.009 │ -0.003 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 22 │ 0.522 │ 0.534 │ 0.55 │ 0.528 │ 0.531 │ 0.544 │ 0.561 │ 0.536 │ 0.008 │ 0.01 │ 0.012 │ 0.008 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 23 │ 0.566 │ 0.572 │ 0.572 │ 0.552 │ 0.576 │ 0.586 │ 0.585 │ 0.562 │ 0.01 │ 0.014 │ 0.012 │ 0.011 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 24 │ 0.359 │ 0.345 │ 0.356 │ 0.364 │ 0.363 │ 0.348 │ 0.357 │ 0.369 │ 0.004 │ 0.004 │ 0.001 │ 0.005 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 25 │ 0.347 │ 0.32 │ 0.328 │ 0.351 │ 0.348 │ 0.319 │ 0.327 │ 0.353 │ 0.001 │ -0 │ -0 │ 0.002 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 26 │ 0.305 │ 0.248 │ 0.26 │ 0.305 │ 0.306 │ 0.248 │ 0.255 │ 0.307 │ 0 │ -0 │ -0.005 │ 0.001 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 27 │ 0.095 │ 0.199 │ 0.099 │ 0.208 │ 0.101 │ 0.2 │ 0.099 │ 0.2 │ 0.007 │ 0.002 │ 0.001 │ -0.007 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 28 │ 0.303 │ 0.16 │ 0.239 │ 0.207 │ 0.301 │ 0.162 │ 0.235 │ 0.205 │ -0.003 │ 0.002 │ -0.004 │ -0.002 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 29 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 30 │ 0.362 │ 0.375 │ 0.557 │ 0.388 │ 0.362 │ 0.384 │ 0.569 │ 0.393 │ 0 │ 0.009 │ 0.012 │ 0.005 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 31 │ 0.599 │ 0.612 │ 0.612 │ 0.579 │ 0.609 │ 0.623 │ 0.628 │ 0.588 │ 0.01 │ 0.011 │ 0.016 │ 0.009 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 32 │ 0.328 │ 0.306 │ 0.317 │ 0.326 │ 0.333 │ 0.307 │ 0.319 │ 0.328 │ 0.005 │ 0.001 │ 0.002 │ 0.002 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 33 │ 0.225 │ 0.168 │ 0.18 │ 0.294 │ 0.226 │ 0.175 │ 0.182 │ 0.291 │ 0.002 │ 0.007 │ 0.001 │ -0.003 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 34 │ 0.196 │ 0.092 │ 0.099 │ 0.196 │ 0.197 │ 0.096 │ 0.101 │ 0.197 │ 0.002 │ 0.005 │ 0.002 │ 0.002 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 35 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 36 │ 0.174 │ 0.191 │ 0.291 │ 0.213 │ 0.17 │ 0.19 │ 0.289 │ 0.219 │ -0.004 │ -0.001 │ -0.001 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 37 │ 0.196 │ 0.36 │ 0.256 │ 0.273 │ 0.185 │ 0.362 │ 0.266 │ 0.272 │ -0.01 │ 0.002 │ 0.01 │ -0.001 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 38 │ 0.308 │ 0.325 │ 0.403 │ 0.521 │ 0.307 │ 0.347 │ 0.415 │ 0.535 │ -0.001 │ 0.022 │ 0.012 │ 0.014 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 39 │ 0.624 │ 0.64 │ 0.673 │ 0.604 │ 0.639 │ 0.659 │ 0.69 │ 0.611 │ 0.015 │ 0.019 │ 0.017 │ 0.008 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 40 │ 0.303 │ 0.207 │ 0.202 │ 0.212 │ 0.306 │ 0.196 │ 0.205 │ 0.211 │ 0.003 │ -0.01 │ 0.003 │ -0.002 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 41 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 42 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 43 │ 0.02 │ 0.082 │ 0.09 │ 0.069 │ 0.016 │ 0.086 │ 0.086 │ 0.071 │ -0.004 │ 0.004 │ -0.003 │ 0.001 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 44 │ 0.133 │ 0.13 │ 0.177 │ 0.217 │ 0.124 │ 0.118 │ 0.185 │ 0.214 │ -0.009 │ -0.012 │ 0.009 │ -0.003 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 45 │ 0.28 │ 0.156 │ 0.213 │ 0.204 │ 0.273 │ 0.153 │ 0.202 │ 0.19 │ -0.007 │ -0.003 │ -0.011 │ -0.014 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 46 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 47 │ 0.51 │ 0.531 │ 0.756 │ 0.472 │ 0.517 │ 0.544 │ 0.772 │ 0.482 │ 0.008 │ 0.014 │ 0.016 │ 0.01 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 48 │ 0.285 │ 0.196 │ 0.186 │ 0.197 │ 0.289 │ 0.188 │ 0.194 │ 0.196 │ 0.003 │ -0.008 │ 0.008 │ -0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 49 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 50 │ 0.038 │ 0.051 │ 0.06 │ 0.016 │ 0.042 │ 0.058 │ 0.058 │ 0.016 │ 0.004 │ 0.006 │ -0.002 │ -0.001 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 51 │ 0.049 │ 0.023 │ 0.029 │ 0.05 │ 0.048 │ 0.019 │ 0.028 │ 0.048 │ -0.001 │ -0.004 │ -0.001 │ -0.003 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 52 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 53 │ 0.26 │ 0.213 │ 0.258 │ 0.078 │ 0.251 │ 0.161 │ 0.251 │ 0.09 │ -0.009 │ -0.052 │ -0.007 │ 0.012 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 54 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 55 │ 0.588 │ 0.626 │ 0.863 │ 0.547 │ 0.588 │ 0.623 │ 0.878 │ 0.544 │ -0 │ -0.003 │ 0.015 │ -0.002 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 56 │ 0.276 │ 0.246 │ 0.255 │ 0.251 │ 0.28 │ 0.251 │ 0.254 │ 0.254 │ 0.004 │ 0.005 │ -0 │ 0.003 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 57 │ 0.149 │ 0.199 │ 0.106 │ 0.127 │ 0.159 │ 0.201 │ 0.108 │ 0.135 │ 0.01 │ 0.002 │ 0.002 │ 0.007 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 58 │ 0.124 │ 0.103 │ 0.056 │ 0.09 │ 0.127 │ 0.108 │ 0.061 │ 0.085 │ 0.004 │ 0.005 │ 0.005 │ -0.004 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 59 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 60 │ 0.068 │ 0.216 │ 0.213 │ 0.186 │ 0.079 │ 0.24 │ 0.24 │ 0.161 │ 0.011 │ 0.023 │ 0.027 │ -0.025 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 61 │ 0.34 │ 0.503 │ 0.515 │ 0.422 │ 0.322 │ 0.483 │ 0.486 │ 0.405 │ -0.018 │ -0.02 │ -0.028 │ -0.017 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 62 │ 0.401 │ 0.757 │ 0.497 │ 0.49 │ 0.404 │ 0.737 │ 0.577 │ 0.494 │ 0.003 │ -0.02 │ 0.08 │ 0.004 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 63 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╧════════╧════════╛ Action-value function RMSE: 0.0096 정책: | 00 ^ | 01 ^ | 02 > | 03 > | 04 > | 05 > | 06 > | 07 < | | 08 ^ | 09 ^ | 10 ^ | 11 ^ | 12 ^ | 13 > | 14 > | 15 v | | 16 ^ | 17 ^ | 18 < | | 20 > | 21 ^ | 22 > | 23 v | | 24 ^ | 25 ^ | 26 ^ | 27 ^ | 28 < | | 30 > | 31 v | | 32 < | 33 ^ | 34 < | | 36 > | 37 v | 38 ^ | 39 > | | 40 < | | | 43 > | 44 ^ | 45 < | | 47 > | | 48 < | | 50 v | 51 < | | 53 < | | 55 > | | 56 < | 57 v | 58 v | | 60 v | 61 v | 62 v | | Reaches goal 77.00%. Obtains an average return of 0.3837. Regret of 0.0511
Q_tss, V_tss, Q_track_tss = [], [], []
for seed in tqdm(SEEDS, desc='All seeds', leave=True):
random.seed(seed); np.random.seed(seed) ; env.seed(seed)
Q_ts, V_ts, pi_ts, Q_track_ts, pi_track_ts, T_track_ts, R_track_ts, planning_ts = trajectory_sampling(
env, gamma=gamma, n_episodes=n_episodes)
Q_tss.append(Q_ts) ; V_tss.append(V_ts) ; Q_track_tss.append(Q_track_ts)
Q_ts, V_ts, Q_track_ts = np.mean(Q_tss, axis=0), np.mean(V_tss, axis=0), np.mean(Q_track_tss, axis=0)
del Q_tss ; del V_tss ; del Q_track_tss
All seeds: 0%| | 0/5 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
0%| | 0/30000 [00:00<?, ?it/s]
print_state_value_function(V_ts, P, n_cols=n_cols,
prec=svf_prec, title='State-value function found by Trajectory Sampling:')
print_state_value_function(optimal_V, P, n_cols=n_cols,
prec=svf_prec, title='Optimal state-value function:')
print_state_value_function(V_ts - optimal_V, P, n_cols=n_cols,
prec=err_prec, title='State-value function errors:')
print('State-value function RMSE: {}'.format(rmse(V_ts, optimal_V)))
print()
print_action_value_function(Q_ts,
optimal_Q,
action_symbols=action_symbols,
prec=avf_prec,
title='Trajectory Sampling action-value function:')
print('Action-value function RMSE: {}'.format(rmse(Q_ts, optimal_Q)))
print()
print_policy(pi_ts, P, action_symbols=action_symbols, n_cols=n_cols)
success_rate_ts, mean_return_ts, mean_regret_ts = get_policy_metrics(
env, gamma=gamma, pi=pi_ts, goal_state=goal_state, optimal_Q=optimal_Q)
print('Reaches goal {:.2f}%. Obtains an average return of {:.4f}. Regret of {:.4f}'.format(
success_rate_ts, mean_return_ts, mean_regret_ts))
State-value function found by Trajectory Sampling: | 00 0.4086 | 01 0.4223 | 02 0.4415 | 03 0.4633 | 04 0.4881 | 05 0.5137 | 06 0.5315 | 07 0.5364 | | 08 0.406 | 09 0.4152 | 10 0.432 | 11 0.4539 | 12 0.4792 | 13 0.5106 | 14 0.542 | 15 0.5522 | | 16 0.3938 | 17 0.3907 | 18 0.3746 | | 20 0.4216 | 21 0.4925 | 22 0.5553 | 23 0.5811 | | 24 0.3668 | 25 0.3534 | 26 0.3077 | 27 0.1988 | 28 0.2935 | | 30 0.564 | 31 0.6245 | | 32 0.3268 | 33 0.2918 | 34 0.1974 | | 36 0.2699 | 37 0.3413 | 38 0.5286 | 39 0.6887 | | 40 0.3005 | | | 43 0.0695 | 44 0.192 | 45 0.2442 | | 47 0.775 | | 48 0.283 | | 50 0.0351 | 51 0.0344 | | 53 0.1925 | | 55 0.8773 | | 56 0.2748 | 57 0.181 | 58 0.0862 | | 60 0.1042 | 61 0.3408 | 62 0.5623 | | Optimal state-value function: | 00 0.4146 | 01 0.4272 | 02 0.4461 | 03 0.4683 | 04 0.4924 | 05 0.5166 | 06 0.5353 | 07 0.541 | | 08 0.4117 | 09 0.4212 | 10 0.4375 | 11 0.4584 | 12 0.4832 | 13 0.5135 | 14 0.5458 | 15 0.5574 | | 16 0.3968 | 17 0.3938 | 18 0.3755 | | 20 0.4217 | 21 0.4938 | 22 0.5612 | 23 0.5859 | | 24 0.3693 | 25 0.353 | 26 0.3065 | 27 0.2004 | 28 0.3008 | | 30 0.569 | 31 0.6283 | | 32 0.3327 | 33 0.2914 | 34 0.1973 | | 36 0.2893 | 37 0.362 | 38 0.5348 | 39 0.6897 | | 40 0.3061 | | | 43 0.0863 | 44 0.2139 | 45 0.2727 | | 47 0.772 | | 48 0.2889 | | 50 0.0577 | 51 0.0475 | | 53 0.2505 | | 55 0.8778 | | 56 0.2804 | 57 0.2008 | 58 0.1273 | | 60 0.2396 | 61 0.4864 | 62 0.7371 | | State-value function errors: | 00 -0.01 | 01 -0.0 | 02 -0.0 | 03 -0.0 | 04 -0.0 | 05 -0.0 | 06 -0.0 | 07 -0.0 | | 08 -0.01 | 09 -0.01 | 10 -0.01 | 11 -0.0 | 12 -0.0 | 13 -0.0 | 14 -0.0 | 15 -0.01 | | 16 -0.0 | 17 -0.0 | 18 -0.0 | | 20 -0.0 | 21 -0.0 | 22 -0.01 | 23 -0.0 | | 24 -0.0 | 25 0.0 | 26 0.0 | 27 -0.0 | 28 -0.01 | | 30 -0.0 | 31 -0.0 | | 32 -0.01 | 33 0.0 | 34 0.0 | | 36 -0.02 | 37 -0.02 | 38 -0.01 | 39 -0.0 | | 40 -0.01 | | | 43 -0.02 | 44 -0.02 | 45 -0.03 | | 47 0.0 | | 48 -0.01 | | 50 -0.02 | 51 -0.01 | | 53 -0.06 | | 55 -0.0 | | 56 -0.01 | 57 -0.02 | 58 -0.04 | | 60 -0.14 | 61 -0.15 | 62 -0.17 | | State-value function RMSE: 0.0352 Trajectory Sampling action-value function: ╒═════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤═══════╤════════╤════════╤════════╤════════╕ │ s │ < │ v │ > │ ^ │ * < │ * v │ * > │ * ^ │ er < │ er v │ er > │ er ^ │ ╞═════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪═══════╪════════╪════════╪════════╪════════╡ │ 0 │ 0.405 │ 0.407 │ 0.407 │ 0.408 │ 0.41 │ 0.414 │ 0.414 │ 0.415 │ 0.004 │ 0.007 │ 0.006 │ 0.007 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 1 │ 0.413 │ 0.417 │ 0.422 │ 0.418 │ 0.417 │ 0.423 │ 0.427 │ 0.425 │ 0.004 │ 0.006 │ 0.005 │ 0.007 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 2 │ 0.429 │ 0.434 │ 0.441 │ 0.436 │ 0.433 │ 0.44 │ 0.446 │ 0.443 │ 0.003 │ 0.006 │ 0.005 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 3 │ 0.449 │ 0.456 │ 0.463 │ 0.457 │ 0.453 │ 0.461 │ 0.468 │ 0.464 │ 0.004 │ 0.005 │ 0.005 │ 0.007 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 4 │ 0.473 │ 0.48 │ 0.488 │ 0.48 │ 0.477 │ 0.484 │ 0.492 │ 0.488 │ 0.004 │ 0.005 │ 0.004 │ 0.008 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 5 │ 0.498 │ 0.502 │ 0.514 │ 0.502 │ 0.502 │ 0.509 │ 0.517 │ 0.51 │ 0.004 │ 0.007 │ 0.003 │ 0.008 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 6 │ 0.519 │ 0.521 │ 0.532 │ 0.518 │ 0.527 │ 0.529 │ 0.535 │ 0.526 │ 0.008 │ 0.009 │ 0.004 │ 0.008 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 7 │ 0.526 │ 0.527 │ 0.536 │ 0.525 │ 0.539 │ 0.539 │ 0.541 │ 0.534 │ 0.013 │ 0.012 │ 0.005 │ 0.008 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 8 │ 0.4 │ 0.402 │ 0.402 │ 0.406 │ 0.404 │ 0.406 │ 0.407 │ 0.412 │ 0.004 │ 0.004 │ 0.004 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 9 │ 0.404 │ 0.407 │ 0.41 │ 0.415 │ 0.407 │ 0.41 │ 0.415 │ 0.421 │ 0.003 │ 0.003 │ 0.005 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 10 │ 0.408 │ 0.412 │ 0.421 │ 0.432 │ 0.41 │ 0.414 │ 0.422 │ 0.437 │ 0.003 │ 0.002 │ 0.001 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 11 │ 0.294 │ 0.295 │ 0.317 │ 0.454 │ 0.299 │ 0.304 │ 0.314 │ 0.458 │ 0.005 │ 0.009 │ -0.003 │ 0.005 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 12 │ 0.45 │ 0.455 │ 0.465 │ 0.479 │ 0.453 │ 0.46 │ 0.471 │ 0.483 │ 0.003 │ 0.004 │ 0.006 │ 0.004 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 13 │ 0.489 │ 0.497 │ 0.511 │ 0.499 │ 0.493 │ 0.503 │ 0.514 │ 0.51 │ 0.004 │ 0.006 │ 0.003 │ 0.011 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 14 │ 0.526 │ 0.529 │ 0.542 │ 0.524 │ 0.531 │ 0.539 │ 0.546 │ 0.53 │ 0.005 │ 0.009 │ 0.004 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 15 │ 0.541 │ 0.552 │ 0.542 │ 0.537 │ 0.552 │ 0.557 │ 0.556 │ 0.543 │ 0.011 │ 0.005 │ 0.014 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 16 │ 0.385 │ 0.38 │ 0.384 │ 0.394 │ 0.389 │ 0.383 │ 0.388 │ 0.397 │ 0.003 │ 0.003 │ 0.004 │ 0.003 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 17 │ 0.382 │ 0.37 │ 0.378 │ 0.391 │ 0.386 │ 0.371 │ 0.379 │ 0.394 │ 0.004 │ 0.002 │ 0.002 │ 0.003 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 18 │ 0.375 │ 0.228 │ 0.241 │ 0.287 │ 0.375 │ 0.231 │ 0.246 │ 0.274 │ 0.001 │ 0.003 │ 0.004 │ -0.012 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 19 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 20 │ 0.249 │ 0.255 │ 0.422 │ 0.337 │ 0.259 │ 0.262 │ 0.422 │ 0.322 │ 0.009 │ 0.007 │ 0 │ -0.014 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 21 │ 0.301 │ 0.315 │ 0.356 │ 0.492 │ 0.309 │ 0.324 │ 0.355 │ 0.494 │ 0.007 │ 0.009 │ -0.001 │ 0.001 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 22 │ 0.527 │ 0.538 │ 0.555 │ 0.53 │ 0.531 │ 0.544 │ 0.561 │ 0.536 │ 0.004 │ 0.006 │ 0.006 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 23 │ 0.566 │ 0.579 │ 0.569 │ 0.557 │ 0.576 │ 0.586 │ 0.585 │ 0.562 │ 0.01 │ 0.007 │ 0.016 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 24 │ 0.354 │ 0.343 │ 0.351 │ 0.367 │ 0.363 │ 0.348 │ 0.357 │ 0.369 │ 0.009 │ 0.005 │ 0.006 │ 0.002 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 25 │ 0.339 │ 0.314 │ 0.323 │ 0.353 │ 0.348 │ 0.319 │ 0.327 │ 0.353 │ 0.009 │ 0.005 │ 0.005 │ -0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 26 │ 0.294 │ 0.246 │ 0.252 │ 0.302 │ 0.306 │ 0.248 │ 0.255 │ 0.307 │ 0.011 │ 0.001 │ 0.004 │ 0.004 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 27 │ 0.096 │ 0.173 │ 0.105 │ 0.191 │ 0.101 │ 0.2 │ 0.099 │ 0.2 │ 0.005 │ 0.027 │ -0.006 │ 0.01 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 28 │ 0.293 │ 0.149 │ 0.224 │ 0.198 │ 0.301 │ 0.162 │ 0.235 │ 0.205 │ 0.007 │ 0.012 │ 0.011 │ 0.007 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 29 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 30 │ 0.356 │ 0.391 │ 0.564 │ 0.382 │ 0.362 │ 0.384 │ 0.569 │ 0.393 │ 0.006 │ -0.007 │ 0.005 │ 0.011 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 31 │ 0.597 │ 0.604 │ 0.624 │ 0.582 │ 0.609 │ 0.623 │ 0.628 │ 0.588 │ 0.012 │ 0.019 │ 0.005 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 32 │ 0.319 │ 0.298 │ 0.308 │ 0.32 │ 0.333 │ 0.307 │ 0.319 │ 0.328 │ 0.014 │ 0.009 │ 0.011 │ 0.007 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 33 │ 0.21 │ 0.167 │ 0.159 │ 0.292 │ 0.226 │ 0.175 │ 0.182 │ 0.291 │ 0.016 │ 0.008 │ 0.023 │ -0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 34 │ 0.173 │ 0.112 │ 0.102 │ 0.184 │ 0.197 │ 0.096 │ 0.101 │ 0.197 │ 0.025 │ -0.016 │ -0.001 │ 0.013 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 35 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 36 │ 0.165 │ 0.174 │ 0.27 │ 0.208 │ 0.17 │ 0.19 │ 0.289 │ 0.219 │ 0.005 │ 0.016 │ 0.019 │ 0.011 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 37 │ 0.169 │ 0.341 │ 0.257 │ 0.269 │ 0.185 │ 0.362 │ 0.266 │ 0.272 │ 0.016 │ 0.021 │ 0.01 │ 0.003 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 38 │ 0.31 │ 0.346 │ 0.419 │ 0.529 │ 0.307 │ 0.347 │ 0.415 │ 0.535 │ -0.002 │ 0.001 │ -0.004 │ 0.006 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 39 │ 0.636 │ 0.654 │ 0.689 │ 0.608 │ 0.639 │ 0.659 │ 0.69 │ 0.611 │ 0.003 │ 0.004 │ 0.001 │ 0.003 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 40 │ 0.3 │ 0.186 │ 0.199 │ 0.203 │ 0.306 │ 0.196 │ 0.205 │ 0.211 │ 0.006 │ 0.011 │ 0.006 │ 0.007 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 41 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 42 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 43 │ 0.01 │ 0.062 │ 0.06 │ 0.047 │ 0.016 │ 0.086 │ 0.086 │ 0.071 │ 0.006 │ 0.024 │ 0.026 │ 0.024 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 44 │ 0.113 │ 0.103 │ 0.158 │ 0.192 │ 0.124 │ 0.118 │ 0.185 │ 0.214 │ 0.011 │ 0.015 │ 0.027 │ 0.022 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 45 │ 0.244 │ 0.123 │ 0.175 │ 0.176 │ 0.273 │ 0.153 │ 0.202 │ 0.19 │ 0.028 │ 0.03 │ 0.027 │ 0.014 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 46 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 47 │ 0.499 │ 0.556 │ 0.775 │ 0.475 │ 0.517 │ 0.544 │ 0.772 │ 0.482 │ 0.018 │ -0.012 │ -0.003 │ 0.007 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 48 │ 0.283 │ 0.167 │ 0.188 │ 0.201 │ 0.289 │ 0.188 │ 0.194 │ 0.196 │ 0.006 │ 0.021 │ 0.005 │ -0.005 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 49 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 50 │ 0.018 │ 0.01 │ 0.027 │ 0.003 │ 0.042 │ 0.058 │ 0.058 │ 0.016 │ 0.024 │ 0.048 │ 0.031 │ 0.013 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 51 │ 0.021 │ 0.007 │ 0.016 │ 0.029 │ 0.048 │ 0.019 │ 0.028 │ 0.048 │ 0.026 │ 0.012 │ 0.012 │ 0.019 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 52 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 53 │ 0.173 │ 0.093 │ 0.165 │ 0.075 │ 0.251 │ 0.161 │ 0.251 │ 0.09 │ 0.077 │ 0.068 │ 0.086 │ 0.015 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 54 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 55 │ 0.587 │ 0.658 │ 0.877 │ 0.543 │ 0.588 │ 0.623 │ 0.878 │ 0.544 │ 0.001 │ -0.035 │ 0 │ 0.001 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 56 │ 0.275 │ 0.231 │ 0.23 │ 0.237 │ 0.28 │ 0.251 │ 0.254 │ 0.254 │ 0.006 │ 0.02 │ 0.024 │ 0.017 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 57 │ 0.125 │ 0.163 │ 0.076 │ 0.074 │ 0.159 │ 0.201 │ 0.108 │ 0.135 │ 0.034 │ 0.037 │ 0.032 │ 0.06 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 58 │ 0.072 │ 0.05 │ 0.02 │ 0.032 │ 0.127 │ 0.108 │ 0.061 │ 0.085 │ 0.056 │ 0.058 │ 0.041 │ 0.054 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 59 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 60 │ 0.003 │ 0.065 │ 0.056 │ 0.024 │ 0.079 │ 0.24 │ 0.24 │ 0.161 │ 0.076 │ 0.175 │ 0.184 │ 0.136 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 61 │ 0.164 │ 0.198 │ 0.318 │ 0.216 │ 0.322 │ 0.483 │ 0.486 │ 0.405 │ 0.158 │ 0.285 │ 0.168 │ 0.189 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 62 │ 0.109 │ 0.499 │ 0.19 │ 0.23 │ 0.404 │ 0.737 │ 0.577 │ 0.494 │ 0.295 │ 0.238 │ 0.386 │ 0.264 │ ├─────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┤ │ 63 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ 0 │ ╘═════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧═══════╧════════╧════════╧════════╧════════╛ Action-value function RMSE: 0.0517 정책: | 00 v | 01 > | 02 > | 03 > | 04 > | 05 > | 06 > | 07 > | | 08 ^ | 09 ^ | 10 ^ | 11 ^ | 12 ^ | 13 > | 14 > | 15 v | | 16 ^ | 17 ^ | 18 < | | 20 > | 21 ^ | 22 > | 23 v | | 24 ^ | 25 ^ | 26 ^ | 27 ^ | 28 < | | 30 > | 31 v | | 32 ^ | 33 ^ | 34 < | | 36 > | 37 v | 38 ^ | 39 > | | 40 < | | | 43 > | 44 ^ | 45 < | | 47 > | | 48 < | | 50 > | 51 ^ | | 53 < | | 55 > | | 56 < | 57 < | 58 v | | 60 v | 61 < | 62 v | | Reaches goal 79.00%. Obtains an average return of 0.4173. Regret of 0.0514
plot_value_function(
'Sarsa(λ) replacing estimates through time vs. true values',
np.max(Q_track_rsl, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Sarsa(λ) replacing estimates through time vs. true values (log scale)',
np.max(Q_track_rsl, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Sarsa(λ) replacing estimates through time (close up)',
np.max(Q_track_rsl, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
plot_value_function(
'Sarsa(λ) accumulating estimates through time vs. true values',
np.max(Q_track_asl, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Sarsa(λ) accumulating estimates through time vs. true values (log scale)',
np.max(Q_track_asl, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Sarsa(λ) accumulating estimates through time (close up)',
np.max(Q_track_asl, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
plot_value_function(
'Q(λ) replacing estimates through time vs. true values',
np.max(Q_track_rqll, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Q(λ) replacing estimates through time vs. true values (log scale)',
np.max(Q_track_rqll, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Q(λ) replacing estimates through time (close up)',
np.max(Q_track_rqll, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
plot_value_function(
'Q(λ) accumulating estimates through time vs. true values',
np.max(Q_track_aqll, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Q(λ) accumulating estimates through time vs. true values (log scale)',
np.max(Q_track_aqll, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Q(λ) accumulating estimates through time (close up)',
np.max(Q_track_aqll, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
plot_value_function(
'Dyna-Q estimates through time vs. true values',
np.max(Q_track_dq, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Dyna-Q estimates through time vs. true values (log scale)',
np.max(Q_track_dq, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Dyna-Q estimates through time (close up)',
np.max(Q_track_dq, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
plot_value_function(
'Trajectory Sampling estimates through time vs. true values',
np.max(Q_track_ts, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=False)
plot_value_function(
'Trajectory Sampling estimates through time vs. true values (log scale)',
np.max(Q_track_ts, axis=2),
optimal_V,
limit_items=limit_items,
limit_value=limit_value,
log=True)
plot_value_function(
'Trajectory Sampling estimates through time (close up)',
np.max(Q_track_ts, axis=2)[:cu_episodes],
None,
limit_items=cu_limit_items,
limit_value=cu_limit_value,
log=False)
rsl_success_rate_ma, rsl_mean_return_ma, rsl_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_rsl, coverage=0.05)
0%| | 0/30000 [00:00<?, ?it/s]
asl_success_rate_ma, asl_mean_return_ma, asl_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_asl, coverage=0.05)
0%| | 0/30000 [00:00<?, ?it/s]
rqll_success_rate_ma, rqll_mean_return_ma, rqll_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_rqll, coverage=0.05)
0%| | 0/30000 [00:00<?, ?it/s]
aqll_success_rate_ma, aqll_mean_return_ma, aqll_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_aqll, coverage=0.05)
0%| | 0/30000 [00:00<?, ?it/s]
dq_success_rate_ma, dq_mean_return_ma, dq_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_dq, coverage=0.05)
0%| | 0/30000 [00:00<?, ?it/s]
ts_success_rate_ma, ts_mean_return_ma, ts_mean_regret_ma = get_metrics_from_tracks(
env, gamma, goal_state, optimal_Q, pi_track_ts, coverage=0.05)
0%| | 0/30000 [00:00<?, ?it/s]
plt.axhline(y=success_rate_op, color='k', linestyle='-', linewidth=1)
plt.text(int(len(rsl_success_rate_ma)*1.02), success_rate_op*1.01, 'π*')
plt.plot(rsl_success_rate_ma, '-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(asl_success_rate_ma, '--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(rqll_success_rate_ma, ':', linewidth=2, label='Q(λ) replacing')
plt.legend(loc=4, ncol=1)
plt.title('Policy success rate (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Success rate %')
plt.ylim(-1, 101)
plt.xticks(rotation=45)
plt.show()
plt.axhline(y=success_rate_op, color='k', linestyle='-', linewidth=1)
plt.text(int(len(rsl_success_rate_ma)*1.02), success_rate_op*1.01, 'π*')
plt.plot(aqll_success_rate_ma, '-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(dq_success_rate_ma, '-', linewidth=2, label='Dyna-Q')
plt.plot(ts_success_rate_ma, '--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=4, ncol=1)
plt.title('Policy success rate (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Success rate %')
plt.ylim(-1, 101)
plt.xticks(rotation=45)
plt.show()
plt.axhline(y=mean_return_op, color='k', linestyle='-', linewidth=1)
plt.text(int(len(rsl_mean_return_ma)*1.02), mean_return_op*1.01, 'π*')
plt.plot(rsl_mean_return_ma, '-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(asl_mean_return_ma, '--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(rqll_mean_return_ma, ':', linewidth=2, label='Q(λ) replacing')
plt.legend(loc=4, ncol=1)
plt.title('Policy episode return (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Return (Gt:T)')
plt.xticks(rotation=45)
plt.show()
plt.axhline(y=mean_return_op, color='k', linestyle='-', linewidth=1)
plt.text(int(len(rsl_mean_return_ma)*1.02), mean_return_op*1.01, 'π*')
plt.plot(aqll_mean_return_ma, '-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(dq_mean_return_ma, '-', linewidth=2, label='Dyna-Q')
plt.plot(ts_mean_return_ma, '--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=4, ncol=1)
plt.title('Policy episode return (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Return (Gt:T)')
plt.xticks(rotation=45)
plt.show()
plt.plot(rsl_mean_regret_ma, '-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(asl_mean_regret_ma, '--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(rqll_mean_regret_ma, ':', linewidth=2, label='Q(λ) replacing')
plt.legend(loc=1, ncol=1)
plt.title('Policy episode regret (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Regret (q* - Q)')
plt.xticks(rotation=45)
plt.show()
plt.plot(aqll_mean_regret_ma, '-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(dq_mean_regret_ma, '-', linewidth=2, label='Dyna-Q')
plt.plot(ts_mean_regret_ma, '--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=1, ncol=1)
plt.title('Policy episode regret (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Regret (q* - Q)')
plt.xticks(rotation=45)
plt.show()
plt.axhline(y=optimal_V[init_state], color='k', linestyle='-', linewidth=1)
plt.text(int(len(Q_track_rsl)*1.05), optimal_V[init_state]+.01, 'v*({})'.format(init_state))
plt.plot(moving_average(np.max(Q_track_rsl, axis=2).T[init_state]),
'-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(moving_average(np.max(Q_track_asl, axis=2).T[init_state]),
'--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(moving_average(np.max(Q_track_rqll, axis=2).T[init_state]),
':', linewidth=2, label='Q(λ) replacing')
plt.plot(moving_average(np.max(Q_track_aqll, axis=2).T[init_state]),
'-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(moving_average(np.max(Q_track_dq, axis=2).T[init_state]),
'-', linewidth=2, label='Dyna-Q')
plt.plot(moving_average(np.max(Q_track_ts, axis=2).T[init_state]),
'--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=4, ncol=1)
plt.title('Estimated expected return (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Estimated value of initial state V({})'.format(init_state))
plt.xticks(rotation=45)
plt.show()
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_rsl, axis=2) - optimal_V), axis=1)),
'-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_asl, axis=2) - optimal_V), axis=1)),
'--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_rqll, axis=2) - optimal_V), axis=1)),
':', linewidth=2, label='Q(λ) replacing')
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_aqll, axis=2) - optimal_V), axis=1)),
'-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_dq, axis=2) - optimal_V), axis=1)),
'-', linewidth=2, label='Dyna-Q')
plt.plot(moving_average(np.mean(np.abs(np.max(Q_track_ts, axis=2) - optimal_V), axis=1)),
'--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=1, ncol=1)
plt.title('State-value function estimation error (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Mean Absolute Error MAE(V, v*)')
plt.xticks(rotation=45)
plt.show()
plt.plot(moving_average(np.mean(np.abs(Q_track_rsl - optimal_Q), axis=(1,2))),
'-', linewidth=2, label='Sarsa(λ) replacing')
plt.plot(moving_average(np.mean(np.abs(Q_track_asl - optimal_Q), axis=(1,2))),
'--', linewidth=2, label='Sarsa(λ) accumulating')
plt.plot(moving_average(np.mean(np.abs(Q_track_rqll - optimal_Q), axis=(1,2))),
':', linewidth=2, label='Q(λ) replacing')
plt.plot(moving_average(np.mean(np.abs(Q_track_aqll - optimal_Q), axis=(1,2))),
'-.', linewidth=2, label='Q(λ) accumulating')
plt.plot(moving_average(np.mean(np.abs(Q_track_dq - optimal_Q), axis=(1,2))),
'-', linewidth=2, label='Dyna-Q')
plt.plot(moving_average(np.mean(np.abs(Q_track_ts - optimal_Q), axis=(1,2))),
'--', linewidth=2, label='Trajectory Sampling')
plt.legend(loc=1, ncol=1)
plt.title('Action-value function estimation error (ma 100)')
plt.xlabel('Episodes')
plt.ylabel('Mean Absolute Error MAE(Q, q*)')
plt.xticks(rotation=45)
plt.show()