dask Module
dask library to parallelize code
Note: This Jupyter notebook uses parallelization and is not meant to be executed within a Google Colab environment.
Note: This Jupyter notebook requires the PyRosetta distributed layer, which is obtained by building PyRosetta with the --serialization flag or installing PyRosetta from the RosettaCommons conda channel. Please see Chapter 16.00 for setup instructions.
import dask
import dask.array as da
import graphviz
import logging
logging.basicConfig(level=logging.INFO)
import numpy as np
import os
import pyrosetta
import pyrosetta.distributed
import pyrosetta.distributed.dask
import pyrosetta.distributed.io as io
import random
import sys
from dask.distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
from IPython.display import Image
if 'google.colab' in sys.modules:
    print("This Jupyter notebook uses parallelization and is therefore not set up for the Google Colab environment.")
    sys.exit(0)
Initialize PyRosetta within this Jupyter notebook using custom command line PyRosetta flags:
flags = """-out:level 100
-ignore_unrecognized_res 1
-ignore_waters 0
-detect_disulf 0 # Do not automatically detect disulfides
""" # These can be unformatted for user convenience, but no spaces in file paths!
pyrosetta.distributed.init(flags)
If you are running this example on a high-performance computing (HPC) cluster with SLURM scheduling, use the SLURMCluster class described below. For more information, visit https://jobqueue.dask.org/en/latest/generated/dask_jobqueue.SLURMCluster.html. Note: If you are running this example on an HPC cluster with a job scheduler other than SLURM, dask_jobqueue also works with other job schedulers: http://jobqueue.dask.org/en/latest/api.html
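For instance, on a PBS-scheduled cluster the setup looks analogous. Here is a minimal sketch assuming dask_jobqueue's PBSCluster; the queue name, walltime, and scratch directory are placeholders for your site's configuration:

from dask_jobqueue import PBSCluster

# Hypothetical PBS configuration; the queue, walltime, and
# local_directory values below are placeholders for your site.
pbs_cluster = PBSCluster(
    cores=1,
    processes=1,
    memory="4GB",
    queue="short",
    walltime="02:59:00",
    local_directory="/tmp",
    extra=pyrosetta.distributed.dask.worker_extra(init_flags=flags),
)
pbs_cluster.scale(4)
pbs_client = Client(pbs_cluster)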
The SLURMCluster class in the dask_jobqueue module is very useful! In this case, we are requesting four workers using cluster.scale(4), and specifying each worker to have:

- cores=1
- processes=1
- job_cpu=1
- memory="4GB"
- queue="short"
- walltime="02:59:00" (just under the 3-hour limit of the "short" queue)
- a local_directory option giving the workers a scratch directory for temporary files
- a job_extra option passing additional arguments to sbatch (here, where to write the SLURM log files)
- an extra=pyrosetta.distributed.dask.worker_extra(init_flags=flags) option, which pre-initializes PyRosetta with our custom flags on each worker
if not os.getenv("DEBUG"):
    scratch_dir = os.path.join("/net/scratch", os.environ["USER"])
    cluster = SLURMCluster(
        cores=1,
        processes=1,
        job_cpu=1,
        memory="4GB",
        queue="short",
        walltime="02:59:00",
        local_directory=scratch_dir,
        job_extra=["-o {}".format(os.path.join(scratch_dir, "slurm-%j.out"))],
        extra=pyrosetta.distributed.dask.worker_extra(init_flags=flags)
    )
    cluster.scale(4)
    client = Client(cluster)
else:
    cluster = None
    client = None
Note: The actual sbatch script submitted to the SLURM scheduler under the hood is:
if not os.getenv("DEBUG"):
    print(cluster.job_script())
Otherwise, if you are running this example locally on your laptop, you can still spawn workers and take advantage of the dask module:
# cluster = LocalCluster(n_workers=1, threads_per_worker=1)
# client = Client(cluster)
Open the dask dashboard, which shows diagnostic information about the current state of your cluster and helps track progress, identify performance issues, and debug failures:
client
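If the notebook renders the client object without a clickable link, you can also print the dashboard URL directly; dashboard_link is a standard attribute of the dask.distributed Client:

# Print the URL of the scheduler's diagnostic dashboard
# (only available when a cluster was actually created above).
if client is not None:
    print(client.dashboard_link)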
First, consider a toy example written as ordinary serial Python code:

def inc(x):
    return x + 1

def double(x):
    return x * 2

def add(x, y):
    return x + y

output = []
for x in range(10):
    a = inc(x)
    b = double(x)
    c = add(a, b)
    output.append(c)

total = sum(output)
print(total)
With a slight modification, we can parallelize it on the HPC cluster using the dask module:
output = []
for x in range(10):
    a = dask.delayed(inc)(x)
    b = dask.delayed(double)(x)
    c = dask.delayed(add)(a, b)
    output.append(c)

delayed = dask.delayed(sum)(output)
print(delayed)
We used the dask.delayed function to wrap the function calls that we want to turn into tasks. None of the inc, double, add, or sum calls have happened yet. Instead, the object delayed is a Delayed object that contains a task graph of the entire computation to be executed.
Let's visualize the task graph to see clear opportunities for parallel execution.
if not os.getenv("DEBUG"):
    delayed.visualize()
We can now compute this lazy result to execute the graph in parallel:
if not os.getenv("DEBUG"):
    total = delayed.compute()
    print(total)
We can also use dask.delayed as a Python function decorator, with identical performance:
@dask.delayed
def inc(x):
    return x + 1

@dask.delayed
def double(x):
    return x * 2

@dask.delayed
def add(x, y):
    return x + y

output = []
for x in range(10):
    a = inc(x)
    b = double(x)
    c = add(a, b)
    output.append(c)

total = dask.delayed(sum)(output).compute()
print(total)
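If you need several independent results rather than one final reduction, it is usually cheaper to compute them in a single call so the scheduler can share work between them. A small sketch using dask.compute from the standard dask API:

# Compute all ten per-iteration results in one scheduler pass,
# instead of calling .compute() on each Delayed object separately.
results = dask.compute(*output)
print(results)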
We can also use the dask.array library, which implements a subset of the NumPy ndarray interface using blocked algorithms, cutting up the large array into many small arrays that can be processed in parallel. See the dask.array documentation: http://docs.dask.org/en/latest/array.html, along with that of dask.bag, dask.dataframe, dask.delayed, Futures, etc.
if not os.getenv("DEBUG"):
    x = da.random.random((10000, 10000, 10), chunks=(1000, 1000, 5))
    y = da.random.random((10000, 10000, 10), chunks=(1000, 1000, 5))
    z = (da.arcsin(x) + da.arccos(y)).sum(axis=(1, 2))
    z.compute()
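The Futures interface mentioned above is the eager counterpart to dask.delayed: client.submit() starts a task immediately and returns a Future. A minimal sketch, using an illustrative plain function (not the @dask.delayed-decorated inc defined earlier) and assuming the client created above is connected to live workers:

# A plain (undecorated) function for eager submission.
def plain_inc(x):
    return x + 1

if client is not None:
    future = client.submit(plain_inc, 10)  # schedules immediately, returns a Future
    print(future.result())                 # blocks until the result arrives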
The dask dashboard allows visualizing parallel computation, including progress bars for tasks. Here is a snapshot of the dask dashboard while executing the previous cell:
Image(filename="inputs/dask_dashboard_example.png")
For more info on interpreting the dask dashboard, see: https://distributed.dask.org/en/latest/web.html
dask.delayed with PyRosetta
Let's look at a simple example of sending PyRosetta jobs to the dask-worker, and the dask-worker sending the results back to this Jupyter Notebook.
We will use the crystal structure of the de novo mini protein gEHEE_06 from PDB ID 5JG9.
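The tasks below pass structures between the notebook and the workers as PackedPose objects, the serialization-friendly container in pyrosetta.distributed, unpacking to a full Pose inside each task. A minimal sketch of that round trip (the poly-alanine sequence here is just for illustration):

# PackedPose is pickle-friendly and cheap to ship to dask workers;
# convert to a full Pose only inside the task that needs it.
packed = io.pose_from_sequence("AAAA")  # illustrative PackedPose
working_pose = io.to_pose(packed)       # full Pose for manipulation
repacked = io.to_packed(working_pose)   # re-pack before returning a result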
@dask.delayed
def mutate(ppose, target, new_res):
    import pyrosetta
    pose = io.to_pose(ppose)
    mutate_residue = pyrosetta.rosetta.protocols.simple_moves.MutateResidue(target=target, new_res=new_res)
    mutate_residue.apply(pose)
    return io.to_packed(pose)

@dask.delayed
def refine(ppose):
    import pyrosetta
    pose = io.to_pose(ppose)
    scorefxn = pyrosetta.create_score_function("ref2015_cart")
    mm = pyrosetta.rosetta.core.kinematics.MoveMap()
    mm.set_bb(True)
    mm.set_chi(True)
    min_mover = pyrosetta.rosetta.protocols.minimization_packing.MinMover()
    min_mover.set_movemap(mm)
    min_mover.score_function(scorefxn)
    min_mover.min_type("lbfgs_armijo_nonmonotone")
    min_mover.cartesian(True)
    min_mover.tolerance(0.01)
    min_mover.max_iter(200)
    min_mover.apply(pose)
    return io.to_packed(pose)

@dask.delayed
def score(ppose):
    import pyrosetta
    pose = io.to_pose(ppose)
    scorefxn = pyrosetta.create_score_function("ref2015")
    total_score = scorefxn(pose)
    return pose, total_score
if not os.getenv("DEBUG"):
    pose = pyrosetta.io.pose_from_file("inputs/5JG9.clean.pdb")
    keep_chA = pyrosetta.rosetta.protocols.grafting.simple_movers.KeepRegionMover(
        res_start=str(pose.chain_begin(1)), res_end=str(pose.chain_end(1))
    )
    keep_chA.apply(pose)

    #kwargs = {"extra_options": pyrosetta.distributed._normflags(flags)}

    output = []
    for target in random.sample(range(1, pose.size() + 1), 10):
        if pose.sequence()[target - 1] != "C":
            for new_res in ["ALA", "TRP"]:
                a = mutate(io.to_packed(pose), target, new_res)
                b = refine(a)
                c = score(b)
                output.append((target, new_res, c[0], c[1]))

    delayed_obj = dask.delayed(np.argmin)([x[-1] for x in output])
    delayed_obj.visualize()
    print(output)
if not os.getenv("DEBUG"):
    delayed_result = delayed_obj.persist()
    progress(delayed_result)
The dask progress bar allows visualizing parallelization directly within the Jupyter notebook. Here is a snapshot of the dask progress bar while executing the previous cell:
Image(filename="inputs/dask_progress_bar_example.png")
if not os.getenv("DEBUG"):
    result = delayed_result.compute()
    print("The mutation with the lowest energy is residue {0} at position {1}".format(output[result][1], output[result][0]))
Note: For best practices while using dask.delayed, see: http://docs.dask.org/en/latest/delayed-best-practices.html