#!/usr/bin/env python
# coding: utf-8
# Before you turn this problem in, make sure everything runs as expected. First, **restart the kernel** (in the menubar, select Kernel$\rightarrow$Restart) and then **run all cells** (in the menubar, select Cell$\rightarrow$Run All).
#
# Make sure you fill in any place that says `YOUR CODE HERE` or "YOUR ANSWER HERE", as well as your name and collaborators below:
# In[ ]:
# Assignment header fields: put your name and any collaborators in these strings.
NAME: str = ""
COLLABORATORS: str = ""
# ---
#
# *This notebook contains material from [PyRosetta](https://RosettaCommons.github.io/PyRosetta.notebooks);
# content is available [on Github](https://github.com/RosettaCommons/PyRosetta.notebooks.git).*
#
# < [Examples Using the `dask` Module](http://nbviewer.jupyter.org/github/RosettaCommons/PyRosetta.notebooks/blob/master/notebooks/16.04-dask.delayed-Via-Slurm.ipynb) | [Contents](toc.ipynb) | [Index](index.ipynb) | [Appendix A: Command Reference](http://nbviewer.jupyter.org/github/RosettaCommons/PyRosetta.notebooks/blob/master/notebooks/A.00-Appendix-A.ipynb) >
# # Part I: Parallelized Global Ligand Docking with `pyrosetta.distributed`
# *Warning*: This notebook uses `pyrosetta.distributed.viewer` code, which runs in `jupyter notebook` and might not run if you're using `jupyterlab`.
# *Note:* This Jupyter notebook uses parallelization and is **not** meant to be executed within a Google Colab environment.
# *Note:* This Jupyter notebook requires the PyRosetta distributed layer, which is obtained by building PyRosetta with the `--serialization` flag or by installing PyRosetta from the RosettaCommons conda channel.
#
# **Please see the setup instructions in Chapter 16.00**
# In[ ]:
import logging
logging.basicConfig(level=logging.INFO)
import json
import matplotlib
import matplotlib.pyplot as plt
get_ipython().run_line_magic('matplotlib', 'inline')
import numpy as np
import os
import pandas as pd
import pyrosetta
import pyrosetta.distributed.dask
import pyrosetta.distributed.io as io
import pyrosetta.distributed.packed_pose as packed_pose
import pyrosetta.distributed.tasks.rosetta_scripts as rosetta_scripts
import pyrosetta.distributed.tasks.score as score
import pyrosetta.distributed.viewer as viewer
import seaborn
seaborn.set()
import sys
from dask_jobqueue import SLURMCluster
from dask.distributed import Client, progress, as_completed
from IPython import display
# Stop right away when the Colab runtime module is loaded: the notebook relies
# on parallelization that is not set up for that environment.
if "google.colab" in sys.modules:
    print("This Jupyter notebook uses parallelization and is therefore not set up for the Google Colab environment.")
    raise SystemExit(0)
# Setup PyRosetta command line flags:
# In[ ]:
# Ligand parameter file handed to PyRosetta through the `-extra_res_fa` flag.
ligand_params = "inputs/TPA.am1-bcc.fa.params"
# One flag per line; the empty first and last entries keep the leading and
# trailing newlines of the flag string.
flags = "\n".join(
    [
        "",
        f"-extra_res_fa {ligand_params}",
        "-ignore_unrecognized_res 1",
        "-out:level 200",
        "",
    ]
)
# Initialize PyRosetta in this session; the same `flags` string is also
# forwarded to the dask workers when the cluster is created.
pyrosetta.distributed.init(flags)
# Setup `dask` workers to run ligand docking simulations:
# In[ ]:
# With the DEBUG environment variable set, no cluster is created and both
# handles are left as None; otherwise a SLURM-backed dask cluster is started.
if os.getenv("DEBUG"):
    cluster = client = None
else:
    # Change to your scratch directory
    scratch_dir = os.path.join("/net/scratch", os.environ["USER"])
    # Destination of each SLURM job's output, passed to the scheduler via `-o`.
    slurm_log = os.path.join(scratch_dir, "slurm-%j.out")
    cluster = SLURMCluster(
        cores=1,
        processes=1,
        job_cpu=1,
        memory="3GB",
        queue="short",
        walltime="02:59:00",
        local_directory=scratch_dir,
        job_extra=[f"-o {slurm_log}"],
        # Workers receive the same PyRosetta init flags as this session.
        extra=pyrosetta.distributed.dask.worker_extra(init_flags=flags),
    )
    n_workers = 20
    cluster.scale(n_workers)
    client = Client(cluster)
# In[ ]:
client
# Setup global ligand docking RosettaScripts protocol within `pyrosetta.distributed`:
# In[ ]:
# RosettaScripts protocol text handed to the task wrapper below.
# NOTE(review): the string contains only a newline here — presumably the XML
# protocol was lost when the notebook was exported; confirm against the
# original notebook before running.
xml = """
"""
# Wrap the protocol in a task object and run its setup step up front.
xml_obj = rosetta_scripts.SingleoutputRosettaScriptsTask(xml)
xml_obj.setup()
# Setup input pose as `PackedPose` object:
# In[ ]:
# Read the input structure file through `pyrosetta.distributed.io` (imported as `io`).
pose_obj = io.pose_from_file(filename="inputs/test_lig.pdb")
# Submit 100 global ligand docking trajectories, very similar to using command line `-nstruct` flag:
# In[ ]:
# Launch 100 docking trajectories on the dask cluster and wait for all of them.
# Skipped when the DEBUG environment variable is set, where `client` is None.
if not os.getenv("DEBUG"):
    # pure=False marks the task as impure. With dask's default (pure=True),
    # submissions with the same function and arguments can be assigned the
    # same task key and be computed only once, so every "trajectory" would
    # share one result. Docking is stochastic, so each call must run separately.
    futures = [client.submit(xml_obj, pose_obj, pure=False) for _ in range(100)]
    # `.result()` blocks until the corresponding task finishes and returns its
    # output to this session; results stay in submission order.
    results = [future.result() for future in futures]
# As results accumulate, you may wish to keep an eye on the progress bar in the `dask` dashboard.
#
# Calling `future.result()` transfers the `PackedPose` objects back to this Jupyter session, so we can inspect the scores in memory!
# In[ ]:
# Build the score table from the docking results; DEBUG mode has no results,
# so it gets an empty frame instead.
if os.getenv("DEBUG"):
    df = pd.DataFrame()
else:
    score_records = packed_pose.to_dict(results)
    df = pd.DataFrame.from_records(score_records)
# Show the first ten rows.
df.head(10)
# Now plot the ligand binding energy landscape:
# In[ ]:
# Scatter `interfE` against `rmsd_chX` for every docked model (skipped in DEBUG mode).
if not os.getenv("DEBUG"):
    matplotlib.rcParams.update({"figure.figsize": [12.0, 8.0]})
    seaborn.scatterplot(data=df, x="rmsd_chX", y="interfE")
# Let's look at the lowest energy model according to `interfE`!
# In[ ]:
# Display the single docked model with the lowest `interfE` (skipped in DEBUG mode).
if not os.getenv("DEBUG"):
    # sort_values() sorts ascending, so the lowest energy is the FIRST entry.
    # Taking the last entry would select the highest-energy model instead.
    lowest_energy_df = df["interfE"].sort_values()
    lowest_energy_index = lowest_energy_df.index[0]
    # The index label doubles as a position in `results`; this relies on `df`
    # having been built from `results` in order with the default integer index.
    lowest_energy_pose = results[lowest_energy_index]
    view = viewer.init(lowest_energy_pose)
    view.add(viewer.setStyle())
    # Second style pass applies a stick representation to the `hetflag` selection.
    view.add(viewer.setStyle(command=({"hetflag": True}, {"stick": {"colorscheme": "brownCarbon", "radius": 0.2}})))
    view.add(viewer.setHydrogenBonds())
    # Zoom to chain "X".
    view.add(viewer.setZoomTo(residue_selector=pyrosetta.rosetta.core.select.residue_selector.ChainSelector("X")))
    view()
# View the five lowest energy poses according to `interfE`:
# In[ ]:
# Display the five models with the lowest `interfE` (skipped in DEBUG mode).
if not os.getenv("DEBUG"):
    # Rows sorted ascending by `interfE`, truncated to the first five.
    top_five_df = df.sort_values(by="interfE").head(5)
    # Convert the selected rows back into poses and keep just the pose objects.
    pose_lookup = packed_pose.dict_to_pose(top_five_df.to_dict())
    lowest_energy_poses = list(pose_lookup.values())
    # Stick representation applied to the `hetflag` selection.
    ligand_style = ({"hetflag": True}, {"stick": {"colorscheme": "brownCarbon", "radius": 0.2}})
    chain_x = pyrosetta.rosetta.core.select.residue_selector.ChainSelector("X")
    view = viewer.init(lowest_energy_poses)
    # Viewer modules are added in the same order as before.
    for module in (
        viewer.setStyle(),
        viewer.setStyle(command=ligand_style),
        viewer.setHydrogenBonds(),
        viewer.setZoomTo(residue_selector=chain_x),
    ):
        view.add(module)
    view()
# If you wish to save any `PackedPose` objects as `.pdb` files:
# In[ ]:
# for i, p in enumerate(results):
#     with open("outputs/RESULT_%i.pdb" % i, "w") as f:
#         f.write(io.to_pdbstring(p))
# If you wish to save a scorefile:
# In[ ]:
# with open(os.path.join("outputs", "ligand_docking_scores.fasc"), "w") as f:
#     for result in results:
#         json.dump(result.scores, f)
# # Part II: Parallelized Global Ligand Docking with `dask.distributed.as_completed` and `pyrosetta.distributed`
# Example using `dask.distributed.as_completed()` function:
#
# "Keep submitting batches of global ligand docks until at least one has a ligand RMSD within 0.4 Angstroms of the input ligand coordinates.":
# In[ ]:
from IPython import display
import matplotlib.pyplot as plt
# Adaptive docking loop: submit one batch of trajectories, process results as
# they finish, and keep submitting batches of `n_workers` until the lowest
# `rmsd_chX` seen so far is <= 0.4. In DEBUG mode nothing is submitted and
# `df` is an empty frame.
if not os.getenv("DEBUG"):
    # A one-colour palette keeps the scatter colour the same on every redraw.
    with seaborn.color_palette("Blues_d", n_colors=1):
        # `nstruct` tracks how many trajectories have been submitted so far.
        nstruct = n_workers
        # pure=False: docking is stochastic, so each submission must be its own
        # task. With dask's default (pure=True), calls with the same function
        # and arguments can share one task key and be computed only once.
        futures = [client.submit(xml_obj, pose_obj, pure=False) for _ in range(nstruct)]
        # with_results=True yields (future, result) pairs in completion order.
        seq = as_completed(futures, with_results=True)
        results = []
        for i, (future, result) in enumerate(seq, start=1):
            # Update dataset
            results.append(result)
            df = pd.DataFrame.from_records(packed_pose.to_dict(results))
            # Only the minimum is needed, so no full sort is required.
            lowest_rmsd_chX = df["rmsd_chX"].min()
            # Update display
            display.clear_output(wait=True)
            print(f"After {i} dock(s), the lowest rmsd_chX is {lowest_rmsd_chX}")
            # Clear the figure first; otherwise every iteration stacks another
            # full scatter layer on top of the previous ones on the same axes.
            plt.clf()
            seaborn.scatterplot(x="rmsd_chX", y="interfE", data=df)
            display.display(plt.gcf())
            # Submit more futures if condition is not met: everything submitted
            # so far has finished and no model has rmsd_chX <= 0.4 yet.
            if (i >= nstruct) and (not lowest_rmsd_chX <= 0.4):
                nstruct += n_workers
                for j in range(n_workers):
                    seq.add(client.submit(xml_obj, pose_obj, pure=False))
else:
    df = pd.DataFrame()
# View resulting scores in the order they completed:
# In[ ]:
df
#
# < [Examples Using the `dask` Module](http://nbviewer.jupyter.org/github/RosettaCommons/PyRosetta.notebooks/blob/master/notebooks/16.04-dask.delayed-Via-Slurm.ipynb) | [Contents](toc.ipynb) | [Index](index.ipynb) | [Appendix A: Command Reference](http://nbviewer.jupyter.org/github/RosettaCommons/PyRosetta.notebooks/blob/master/notebooks/A.00-Appendix-A.ipynb) >