Here is a better rendering of this notebook on nbviewer.
The class of actor-critic algorithms is known for its practical success. Pearl includes several popular algorithms in this class, such as PPO, DDPG, SAC, and TD3.
In this tutorial, we will demonstrate how to use the TD3 algorithm with Pearl. Using another algorithm in this class is just as simple; you only need to swap in the appropriate Policy Learner.
Pearl also supports safe training, which allows an algorithm designer to optimize a reward function while adhering to additional constraints. We will show how to use safe training by instantiating the Safety Module with the Reward Constrained Safety Module.
%load_ext autoreload
%autoreload 2
If you haven't installed Pearl yet, run the following cell to install it. Otherwise, you can skip it.
%pip uninstall Pearl -y
%rm -rf Pearl
!git clone https://github.com/facebookresearch/Pearl.git
%cd Pearl
%pip install .
%cd ..
Successfully installed AutoROM.accept-rom-license-0.6.1 Pearl-0.1.0 ale-py-0.8.1 autorom-0.4.2 farama-notifications-0.0.4 glfw-2.7.0 gymnasium-0.29.1 mujoco-3.1.2 parameterized-0.9.0 shimmy-0.2.1
from pearl.utils.functional_utils.experimentation.set_seed import set_seed
from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import FIFOOffPolicyReplayBuffer
from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning
from pearl.pearl_agent import PearlAgent
from pearl.user_envs.wrappers.gym_avg_torque_cost import GymAvgTorqueWrapper
from pearl.utils.instantiations.environments.gym_environment import GymEnvironment
import gymnasium as gym
from pearl.policy_learners.sequential_decision_making.td3 import TD3
from pearl.neural_networks.sequential_decision_making.actor_networks import VanillaContinuousActorNetwork
from pearl.neural_networks.sequential_decision_making.q_value_networks import VanillaQValueNetwork
from pearl.policy_learners.exploration_modules.common.normal_distribution_exploration import (
    NormalDistributionExploration,
)
from pearl.safety_modules.reward_constrained_safety_module import (
    RCSafetyModuleCostCriticContinuousAction,
)
from matplotlib import pyplot as plt
import torch
import numpy as np
set_seed(0)
Let's dive into the code. First, we create a MuJoCo environment with an additional torque cost. The environment is designed as a wrapper on top of Gym, so at each step it returns the usual reward together with an additional cost. The cost function we use is $c(s,a) = \|a\|^2$, the squared norm of the action. This quantity is a proxy for the average power exerted by the agent, which algorithm designers often wish to restrict. In later sections, we will use this cost to place additional constraints on the agent.
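To make the cost concrete, here is a minimal illustrative sketch of what such a wrapper could look like. This is a hypothetical stand-in, not Pearl's actual GymAvgTorqueWrapper, whose implementation may differ in details (for instance, by averaging over the action dimensions):

import numpy as np
import gymnasium as gym

class SquaredTorqueCostWrapper(gym.Wrapper):
    # Illustrative only: attaches c(s, a) = ||a||^2 to each step's info dict.
    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)
        # Squared L2 norm of the action, a proxy for instantaneous power.
        info["cost"] = float(np.sum(np.square(np.asarray(action))))
        return observation, reward, terminated, truncated, info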
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create a HalfCheetah environment with an additional torque cost via GymAvgTorqueWrapper
env = GymEnvironment(GymAvgTorqueWrapper(gym.make("HalfCheetah-v4")))
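As a quick sanity check, we can step the wrapped environment once and inspect the result. Note the API assumptions here: this relies on Pearl's GymEnvironment returning the initial observation together with the action space from reset(), and on step() returning an ActionResult carrying a cost field; if your Pearl version differs, adjust accordingly.

# Sanity check (assumes reset() -> (observation, action_space) and that
# step() returns an ActionResult with reward and cost fields).
observation, action_space = env.reset()
result = env.step(action_space.sample())
print("reward:", result.reward, "cost:", result.cost)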
With the environment ready, we now test the performance of the TD3 algorithm. We set the Policy Learner to be TD3 and configure the replay buffer, but do not add any special safety module at this point.
The following code demonstrates the performance of TD3 in the environment with the additional torque cost. As the learning progresses, we track and print both the cumulative reward and the cumulative cost.
# Set up the TD3 agent
td3_agent = PearlAgent(
    policy_learner=TD3(
        state_dim=env.observation_space.shape[0],
        action_space=env.action_space,
        actor_hidden_dims=[256, 256],
        critic_hidden_dims=[256, 256],
        training_rounds=1,
        batch_size=256,
        actor_network_type=VanillaContinuousActorNetwork,
        critic_network_type=VanillaQValueNetwork,
        actor_soft_update_tau=0.005,
        critic_soft_update_tau=0.005,
        actor_learning_rate=1e-3,
        critic_learning_rate=3e-4,
        discount_factor=0.99,
        actor_update_freq=2,
        actor_update_noise=0.2,
        actor_update_noise_clip=0.5,
        exploration_module=NormalDistributionExploration(
            mean=0.0,
            std_dev=0.1,
        ),
    ),
    replay_buffer=FIFOOffPolicyReplayBuffer(
        capacity=100000,
        has_cost_available=True,
    ),
    safety_module=None,
)
# Run TD3 on the environment
number_of_steps = 200000
print_every_x_steps = 20000
record_period = 1000
td3_info = online_learning(
    td3_agent,
    env,
    number_of_steps=number_of_steps,
    print_every_x_steps=print_every_x_steps,
    record_period=record_period,
)
episode 20, step 20000, agent=PearlAgent with TD3, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 49.970514269312844 return_cost: 806.2538956999779
episode 40, step 40000, agent=PearlAgent with TD3, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: -14.736769429975539 return_cost: 787.3026733063161
episode 60, step 60000, agent=PearlAgent with TD3, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 1430.6730238944292 return_cost: 813.093893378973
episode 80, step 80000, agent=PearlAgent with TD3, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 2032.4104736470617 return_cost: 794.4528419971466
episode 100, step 100000, agent=PearlAgent with TD3, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 2657.43493815884 return_cost: 793.5689518451691
episode 120, step 120000, agent=PearlAgent with TD3, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 2887.5689703784883 return_cost: 798.2803127765656
episode 140, step 140000, agent=PearlAgent with TD3, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 3716.0947114322335 return_cost: 803.4706705212593
episode 160, step 160000, agent=PearlAgent with TD3, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 3800.822272643447 return_cost: 800.8424595594406
episode 180, step 180000, agent=PearlAgent with TD3, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 3853.1925881505013 return_cost: 805.224086523056
episode 200, step 200000, agent=PearlAgent with TD3, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 4108.776982480951 return_cost: 812.8720943331718
The modular design of Pearl makes it easy to combine the TD3 algorithm with a safety module. One such generic safety module we developed is based on the Reward Constrained Policy Optimization (RCPO) framework [1]. This approach is versatile and can be applied to any actor-critic algorithm.
To integrate TD3 with the RC safety module, instantiate PearlAgent with TD3 as before and set the safety_module parameter to RCSafetyModuleCostCriticContinuousAction. The following code demonstrates this.
[1] Reward Constrained Policy Optimization, C. Tessler, D. Mankowitz, S. Mannor, 2019, https://arxiv.org/abs/1805.11074.
# Set up the RCTD3 agent: TD3 combined with the reward-constrained (RC) safety module
rctd3_agent = PearlAgent(
    policy_learner=TD3(
        state_dim=env.observation_space.shape[0],
        action_space=env.action_space,
        actor_hidden_dims=[256, 256],
        critic_hidden_dims=[256, 256],
        training_rounds=1,
        batch_size=256,
        actor_network_type=VanillaContinuousActorNetwork,
        critic_network_type=VanillaQValueNetwork,
        actor_soft_update_tau=0.005,
        critic_soft_update_tau=0.005,
        actor_learning_rate=1e-3,
        critic_learning_rate=3e-4,
        discount_factor=0.99,
        actor_update_freq=2,
        actor_update_noise=0.2,
        actor_update_noise_clip=0.5,
        exploration_module=NormalDistributionExploration(
            mean=0.0,
            std_dev=0.1,
        ),
    ),
    replay_buffer=FIFOOffPolicyReplayBuffer(
        capacity=100000,
        has_cost_available=True,
    ),
    safety_module=RCSafetyModuleCostCriticContinuousAction(
        state_dim=env.observation_space.shape[0],
        action_space=env.action_space,
        critic_hidden_dims=[256, 256],
        constraint_value=0.4,              # constraint threshold alpha
        lambda_constraint_ub_value=200.0,  # upper bound on the Lagrange multiplier
        lr_lambda=1e-3,                    # learning rate for the Lagrange multiplier
    ),
)
# Run RCTD3 on the environment
number_of_steps = 200000
print_every_x_steps = 20000
record_period = 1000
rctd3_info = online_learning(
    rctd3_agent,
    env,
    number_of_steps=number_of_steps,
    print_every_x_steps=print_every_x_steps,
    record_period=record_period,
)
episode 20, step 20000, agent=PearlAgent with TD3, RCSafetyModuleCostCriticContinuousAction, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 946.3039617876057 return_cost: 612.9130087792873
episode 40, step 40000, agent=PearlAgent with TD3, RCSafetyModuleCostCriticContinuousAction, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 1735.8694469803013 return_cost: 557.5508826673031
episode 60, step 60000, agent=PearlAgent with TD3, RCSafetyModuleCostCriticContinuousAction, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 2104.1888732789084 return_cost: 462.54283828660846
episode 80, step 80000, agent=PearlAgent with TD3, RCSafetyModuleCostCriticContinuousAction, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 2863.3212767671794 return_cost: 478.60841024667025
episode 100, step 100000, agent=PearlAgent with TD3, RCSafetyModuleCostCriticContinuousAction, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 3986.5186753571033 return_cost: 469.34288113564253
episode 120, step 120000, agent=PearlAgent with TD3, RCSafetyModuleCostCriticContinuousAction, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 4531.529906626791 return_cost: 458.11196183785796
episode 140, step 140000, agent=PearlAgent with TD3, RCSafetyModuleCostCriticContinuousAction, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 4801.1537653543055 return_cost: 442.866158567369
episode 160, step 160000, agent=PearlAgent with TD3, RCSafetyModuleCostCriticContinuousAction, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 5312.129231594212 return_cost: 431.6710966601968
episode 180, step 180000, agent=PearlAgent with TD3, RCSafetyModuleCostCriticContinuousAction, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 5500.471403837204 return_cost: 425.44032253697515
episode 200, step 200000, agent=PearlAgent with TD3, RCSafetyModuleCostCriticContinuousAction, FIFOOffPolicyReplayBuffer, env=HalfCheetah-v4 return: 5173.863486928865 return_cost: 419.25983215868473
Next, we elaborate on the RCPO framework and its Pearl implementation. RCPO is based on the Lagrangian formulation of constrained MDPs (CMDPs). Effectively, it translates the solution of a CMDP into a min-max optimization problem. The policy player maximizes an MDP with the modified reward $$ r_\lambda(s,a) = r(s,a) - \lambda \, (c(s,a)-\alpha), $$ where $r(s,a)$ and $c(s,a)$ are the immediate reward and cost, $\alpha$ is the constraint threshold, and $\lambda \geq 0$ is the Lagrange multiplier. The opposing player updates $\lambda$, increasing it as long as the agent does not satisfy the constraint.
In Pearl, when the safety module is instantiated with the RC safety module, the base actor optimizes the effective reward $r_\lambda(s,a)$ while the safety module updates the Lagrange multiplier.
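The λ update itself is simple. The following is a hedged sketch, not Pearl's exact code, of the projected gradient-ascent step implied by the hyperparameters passed to RCSafetyModuleCostCriticContinuousAction above: λ grows at rate lr_lambda while the estimated cost exceeds constraint_value, and is clipped to the interval [0, lambda_constraint_ub_value].

import torch

# Sketch of the RCPO Lagrange-multiplier update (illustrative, not Pearl's code).
constraint_value = 0.4  # alpha, as passed to the safety module above
lr_lambda = 1e-3        # step size for lambda, as passed above
lambda_ub = 200.0       # lambda_constraint_ub_value, as passed above

lambda_constraint = torch.tensor(0.0)

def update_lambda(cost_estimate: torch.Tensor) -> torch.Tensor:
    # One projected gradient-ascent step on lambda * (J_c - alpha):
    # lambda increases while the constraint is violated, and is projected
    # back onto [0, lambda_ub].
    global lambda_constraint
    violation = cost_estimate.mean() - constraint_value
    lambda_constraint = (lambda_constraint + lr_lambda * violation).clamp(0.0, lambda_ub)
    return lambda_constraint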
Let us now compare the results. As we can see, the cumulative cost of the RCTD3 agent is much better controlled than that of the TD3 agent. In this particular problem, the cumulative reward of the RCTD3 agent is also higher, possibly due to the extra regularization the RC module provides; this is not guaranteed in general, but it can happen.
# Plot the cumulative return and cumulative cost
steps = np.arange(200) * record_period
plt.plot(steps, td3_info["return"], label='TD3')
plt.plot(steps, rctd3_info["return"], label='RCTD3')
plt.xlabel('Time step')
plt.ylabel('Cumulative return')
plt.title("Cumulative return")
plt.legend()
plt.show()
# Create the second plot
plt.plot(steps, td3_info["return_cost"], label='TD3')
plt.plot(steps, rctd3_info["return_cost"], label='RCTD3')
plt.xlabel('Time step')
plt.ylabel('Cumulative cost')
plt.title("Cumulative cost")
plt.legend()
plt.show()
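As a rough sanity check on the constraint, we can convert the final cumulative episode cost into a per-step cost and compare it against constraint_value. This assumes the default 1000-step horizon of HalfCheetah-v4 (an assumption of this check, not something the API returns):

# Average per-step cost over the last recorded episodes vs. the constraint.
episode_length = 1000  # default HalfCheetah-v4 horizon (assumed here)
per_step_cost = np.mean(rctd3_info["return_cost"][-10:]) / episode_length
print("RCTD3 per-step cost:", per_step_cost, "(constraint_value: 0.4)")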
In this tutorial, we showed how to use the TD3 algorithm and its safe variant, which we refer to as RCTD3, in Pearl. We also elaborated on the RCPO framework and empirically compared the performance of TD3 and RCTD3.