%load_ext autoreload %autoreload 2 %pip uninstall Pearl -y %rm -rf Pearl !git clone https://github.com/facebookresearch/Pearl.git %cd Pearl %pip install . %cd .. from pearl.utils.functional_utils.experimentation.set_seed import set_seed from pearl.policy_learners.sequential_decision_making.deep_q_learning import DeepQLearning from pearl.replay_buffers import BasicReplayBuffer from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning from pearl.pearl_agent import PearlAgent from pearl.utils.instantiations.environments.gym_environment import GymEnvironment from pearl.utils.instantiations.environments.environments import ( OneHotObservationsFromDiscrete, ) from pearl.utils.instantiations.spaces.discrete import DiscreteSpace import torch import matplotlib.pyplot as plt import numpy as np from pearl.action_representation_modules.one_hot_action_representation_module import ( OneHotActionTensorRepresentationModule, ) set_seed(0) number_of_steps = 20000 record_period = 400 """ This test is checking if DQN will eventually solve FrozenLake-v1 whose observations need to be wrapped in a one-hot representation. """ env = OneHotObservationsFromDiscrete( GymEnvironment( "FrozenLake-v1", is_slippery=False, map_name="4x4", ) ) action_representation_module = OneHotActionTensorRepresentationModule( max_number_actions= env.action_space.n, ) assert isinstance(env.action_space, DiscreteSpace) state_dim = env.observation_space.n agent = PearlAgent( policy_learner=DeepQLearning( state_dim=state_dim, action_space=env.action_space, hidden_dims=[64, 64], training_rounds=1, action_representation_module=action_representation_module ), replay_buffer=BasicReplayBuffer(1000), ) info = online_learning( agent=agent, env=env, number_of_steps=number_of_steps, print_every_x_steps=100, record_period=record_period, learn_after_episode=False, ) torch.save(info["return"], "DQN-return.pt") plt.plot(record_period * np.arange(len(info["return"])), info["return"], label="DQN") plt.xlabel("steps") plt.ylabel("return") plt.legend() plt.show()