dask Module
dask library to parallelize code
Note: This Jupyter notebook uses parallelization and is not meant to be executed within a Google Colab environment.
Note: This Jupyter notebook requires the PyRosetta distributed layer, which is obtained by building PyRosetta with the --serialization flag or by installing PyRosetta from the RosettaCommons conda channel.
Please see Chapter 15.00 for setup instructions.
import dask
import dask.array as da
import graphviz
import logging
logging.basicConfig(level=logging.INFO)
import numpy as np
import os
import pyrosetta
import pyrosetta.distributed
import pyrosetta.distributed.dask
import pyrosetta.distributed.io as io
import random
import sys
from dask.distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
from IPython.display import Image
if 'google.colab' in sys.modules:
    print("This Jupyter notebook uses parallelization and is therefore not set up for the Google Colab environment.")
    sys.exit(0)
Initialize PyRosetta within this Jupyter notebook using custom command line PyRosetta flags:
flags = """-out:level 100
-ignore_unrecognized_res 1
-ignore_waters 0
-detect_disulf 0 # Do not automatically detect disulfides
""" # These can be unformatted for user convenience, but no spaces in file paths!
pyrosetta.distributed.init(flags)
INFO:pyrosetta.distributed:maybe_init performing pyrosetta initialization: {'extra_options': '-out:level 100 -ignore_unrecognized_res 1 -ignore_waters 0 -detect_disulf 0', 'silent': True}
INFO:pyrosetta.rosetta:Found rosetta database at: /Users/jklima/opt/miniconda3/envs/PyRosetta.notebooks/lib/python3.7/site-packages/pyrosetta/database; using it....
INFO:pyrosetta.rosetta:PyRosetta-4 2020 [Rosetta PyRosetta4.conda.mac.python37.Release 2020.02+release.22ef835b4a2647af94fcd6421a85720f07eddf12 2020-01-05T17:31:56] retrieved from: http://www.pyrosetta.org
(C) Copyright Rosetta Commons Member Institutions. Created in JHU by Sergey Lyskov and PyRosetta Team.
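The log shows that the multi-line flags string reaches PyRosetta as a single line, with the inline comment removed. The sketch below reproduces that normalization in plain Python. normalize_flags is a hypothetical helper written for illustration; it is not the function pyrosetta.distributed uses internally.

```python
def normalize_flags(flags: str) -> str:
    """Collapse a multi-line flags string into one command-line string.

    Hypothetical helper: strips '#' comments and collapses all whitespace.
    """
    tokens = []
    for line in flags.splitlines():
        # Drop everything after a '#' (inline or full-line comment)
        code = line.split("#", 1)[0]
        # Splitting on whitespace is why file paths must not contain spaces
        tokens.extend(code.split())
    return " ".join(tokens)


flags = """-out:level 100
-ignore_unrecognized_res 1
-ignore_waters 0
-detect_disulf 0 # Do not automatically detect disulfides
"""
print(normalize_flags(flags))
# -out:level 100 -ignore_unrecognized_res 1 -ignore_waters 0 -detect_disulf 0
```

Because tokens are split on whitespace, a path such as "my inputs/file.pdb" would be broken into two tokens, which is what the "no spaces in file paths!" comment in the flags cell warns about.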
If you are running this example on a high-performance computing (HPC) cluster with SLURM scheduling, use the SLURMCluster class described below. For more information, visit https://jobqueue.dask.org/en/latest/generated/dask_jobqueue.SLURMCluster.html.
Note: If your HPC cluster uses a job scheduler other than SLURM, dask_jobqueue also supports other job schedulers: http://jobqueue.dask.org/en/latest/api.html
The SLURMCluster class in the dask_jobqueue module is very useful! In this case, we request four workers using cluster.scale(4) and specify that each worker has:
- cores=1
- processes=1
- job_cpu=1
- memory="4GB"
- queue="short"
- walltime="02:59:00"
- the local_directory option, pointing the worker's local storage at a scratch directory
- the job_extra option, used here to write each SLURM job's log file to that scratch directory
- the extra=pyrosetta.distributed.dask.worker_extra(init_flags=flags) option, which makes each worker initialize PyRosetta with the same flags as this notebook
if not os.getenv("DEBUG"):
    scratch_dir = os.path.join("/net/scratch", os.environ["USER"])
    cluster = SLURMCluster(
        cores=1,
        processes=1,
        job_cpu=1,
        memory="4GB",
        queue="short",
        walltime="02:59:00",
        local_directory=scratch_dir,
        job_extra=["-o {}".format(os.path.join(scratch_dir, "slurm-%j.out"))],
        extra=pyrosetta.distributed.dask.worker_extra(init_flags=flags)
    )
    cluster.scale(4)
    client = Client(cluster)
else:
    cluster = None
    client = None
Note: The actual sbatch script submitted to the Slurm scheduler under the hood was:
print(cluster.job_script())
#!/bin/bash
#!/usr/bin/env bash
#SBATCH -J dask-worker
#SBATCH -p short
#SBATCH -n 1
#SBATCH --cpus-per-task=1
#SBATCH --mem=4G
#SBATCH -t 02:59:00
#SBATCH -o /net/scratch/klimaj/slurm-%j.out
JOB_ID=${SLURM_JOB_ID%;*}
/home/klimaj/anaconda3/envs/pyrosetta-code-school/bin/python -m distributed.cli.dask_worker tcp://172.16.131.31:27327 --nthreads 1 --memory-limit 4.00GB --name dask-worker--${JOB_ID}-- --death-timeout 60 --local-directory /net/scratch/klimaj --preload pyrosetta.distributed.dask.worker ' -out:level 100 -ignore_unrecognized_res 1 -ignore_waters 0 -detect_disulf 0'
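Reading the script above against the SLURMCluster arguments shows how they turn into #SBATCH directives: queue becomes -p, job_cpu becomes --cpus-per-task, memory="4GB" becomes --mem=4G, walltime becomes -t, and the job_extra entry supplies -o. The sketch below rebuilds that header by hand. sbatch_header is a hypothetical helper for illustration only; dask_jobqueue generates the real script itself, and cluster.job_script() remains the way to inspect it.

```python
def sbatch_header(job_name, queue, ntasks, cpus_per_task, mem, walltime, log_path):
    """Build #SBATCH directive lines (hypothetical helper, for illustration only)."""
    return [
        f"#SBATCH -J {job_name}",
        f"#SBATCH -p {queue}",                       # from queue="short"
        f"#SBATCH -n {ntasks}",
        f"#SBATCH --cpus-per-task={cpus_per_task}",  # from job_cpu=1
        f"#SBATCH --mem={mem}",                      # from memory="4GB"
        f"#SBATCH -t {walltime}",                    # from walltime="02:59:00"
        f"#SBATCH -o {log_path}",                    # from job_extra=["-o ..."]
    ]


header = sbatch_header(
    job_name="dask-worker",
    queue="short",
    ntasks=1,
    cpus_per_task=1,
    mem="4G",
    walltime="02:59:00",
    log_path="/net/scratch/USER/slurm-%j.out",  # placeholder user name
)
print("\n".join(header))
```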
Otherwise, if you are running this example locally on your laptop, you can still spawn workers and take advantage of the dask module:
# cluster = LocalCluster(n_workers=2, threads_per_worker=2)
# client = Client(cluster)
Open the dask dashboard, which shows diagnostic information about the current state of your cluster and helps track progress, identify performance issues, and debug failures:
client
Consider the following simple serial Python code:
def inc(x):
    return x + 1

def double(x):
    return x + 2

def add(x, y):
    return x + y

output = []
for x in range(10):
    a = inc(x)
    b = double(x)
    c = add(a, b)
    output.append(c)

total = sum(output)
print(total)
120
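As a quick check on the printed total: with double(x) as defined here (it returns x + 2), each loop iteration contributes (x + 1) + (x + 2) = 2x + 3, so the sum over x = 0, ..., 9 has a simple closed form.

```python
# inc(x) + double(x) = (x + 1) + (x + 2) = 2x + 3 for the functions defined above
n = 10
closed_form = 2 * (n * (n - 1) // 2) + 3 * n  # 2 * 45 + 3 * 10
loop_total = sum((x + 1) + (x + 2) for x in range(n))
print(closed_form, loop_total)  # 120 120
```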
With a slight modification, we can parallelize it on the HPC cluster using the dask module:
output = []
for x in range(10):
    a = dask.delayed(inc)(x)
    b = dask.delayed(double)(x)
    c = dask.delayed(add)(a, b)
    output.append(c)

delayed = dask.delayed(sum)(output)
print(delayed)
Delayed('sum-94562113-2e9a-483f-b671-a288c49a6fd7')
We used the dask.delayed function to wrap the function calls that we want to turn into tasks. None of the inc, double, add, or sum calls have happened yet. Instead, the object delayed is a Delayed object that contains a task graph of the entire computation to be executed.
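To make "task graph" concrete, here is a toy version of the same computation, modeled loosely on dask's graph specification (a dict mapping keys to tuples of a callable and its arguments, where an argument may name another key). build_graph and get are hypothetical helpers; this is a sequential illustration of the idea, not how dask's schedulers are implemented.

```python
def inc(x):
    return x + 1


def double(x):
    return x + 2  # mirrors the notebook's double(), which adds 2


def add(x, y):
    return x + y


def build_graph(n):
    """Describe the computation as a dict: key -> (function, *arguments)."""
    graph = {}
    for x in range(n):
        graph[f"inc-{x}"] = (inc, x)
        graph[f"double-{x}"] = (double, x)
        graph[f"add-{x}"] = (add, f"inc-{x}", f"double-{x}")
    graph["sum"] = (sum, [f"add-{x}" for x in range(n)])
    return graph


def get(graph, key, cache=None):
    """Evaluate `key`, first evaluating any task keys it refers to."""
    if cache is None:
        cache = {}
    if key in cache:
        return cache[key]
    func, *args = graph[key]

    def resolve(arg):
        if isinstance(arg, list):
            return [resolve(a) for a in arg]
        if isinstance(arg, str) and arg in graph:
            return get(graph, arg, cache)
        return arg

    cache[key] = func(*[resolve(a) for a in args])
    return cache[key]


graph = build_graph(10)   # describing the work computes nothing yet
print(len(graph))         # 31 tasks: 10 inc, 10 double, 10 add, 1 sum
print(get(graph, "sum"))  # 120
```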
Let's visualize the task graph to see clear opportunities for parallel execution.
delayed.visualize()
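The picture drawn by visualize() can be summarized by grouping tasks by their depth in the graph: tasks at the same depth never depend on one another, so they could in principle all run at the same time. levels is a hypothetical helper written for this illustration; it uses only the dependency structure of the inc/double/add/sum computation in this notebook.

```python
def dependencies(n):
    """Map each task key to the keys it depends on."""
    deps = {}
    for x in range(n):
        deps[f"inc-{x}"] = []
        deps[f"double-{x}"] = []
        deps[f"add-{x}"] = [f"inc-{x}", f"double-{x}"]
    deps["sum"] = [f"add-{x}" for x in range(n)]
    return deps


def levels(deps):
    """Group tasks into stages by depth; no task depends on one in its own stage."""
    level = {}

    def depth(key):
        if key not in level:
            # Leaves get depth 0; otherwise one more than the deepest dependency
            level[key] = 1 + max((depth(d) for d in deps[key]), default=-1)
        return level[key]

    stages = {}
    for key in deps:
        stages.setdefault(depth(key), []).append(key)
    return [stages[i] for i in sorted(stages)]


for i, stage in enumerate(levels(dependencies(10))):
    print(i, len(stage))
# 0 20
# 1 10
# 2 1
```

The 20 inc and double calls form the widest stage, which is where parallel workers help most; the final sum is a single task.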
We can now compute this lazy result to execute the graph in parallel:
if not os.getenv("DEBUG"):
    total = delayed.compute()
    print(total)
120
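For readers without a cluster, the same execution order can be expressed with Python's standard-library concurrent.futures module. This is a swapped-in technique for illustration, not dask: a thread pool runs the independent inc and double calls first, then each add once its two inputs are ready. Threads give little speedup for pure-Python CPU-bound work because of the GIL, so treat this as a picture of the scheduling, not a performance recipe.

```python
from concurrent.futures import ThreadPoolExecutor


def inc(x):
    return x + 1


def double(x):
    return x + 2  # mirrors the notebook's double()


def add(x, y):
    return x + y


with ThreadPoolExecutor(max_workers=4) as pool:
    # Stage 0: all inc/double calls are independent, so submit them together
    incs = [pool.submit(inc, x) for x in range(10)]
    doubles = [pool.submit(double, x) for x in range(10)]
    # Stage 1: each add waits only on its own pair of results
    adds = [pool.submit(add, a.result(), b.result()) for a, b in zip(incs, doubles)]
    # Stage 2: reduce the results
    total = sum(f.result() for f in adds)

print(total)  # 120
```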
We can also use dask.delayed as a Python function decorator for identical performance:
@dask.delayed
def inc(x):
    return x + 1

@dask.delayed
def double(x):
    return x + 2

@dask.delayed
def add(x, y):
    return x + y

output = []
for x in range(10):
    a = inc(x)
    b = double(x)
    c = add(a, b)
    output.append(c)

total = dask.delayed(sum)(output).compute()
print(total)
120
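The decorator form works because calling a decorated function records the call instead of running it. The toy below shows that mechanism in plain Python. deferred and Deferred are hypothetical names for this sketch; unlike dask, it does not share or cache repeated subexpressions and runs everything sequentially.

```python
import functools


class Deferred:
    """A recorded function call whose arguments may themselves be Deferred."""

    def __init__(self, func, args):
        self.func = func
        self.args = args

    def compute(self):
        # Evaluate dependencies first (depth-first), then this call
        def resolve(arg):
            if isinstance(arg, Deferred):
                return arg.compute()
            if isinstance(arg, list):
                return [resolve(a) for a in arg]
            return arg

        return self.func(*[resolve(a) for a in self.args])


def deferred(func):
    """Hypothetical decorator: calling the function records it instead of running it."""
    @functools.wraps(func)
    def wrapper(*args):
        return Deferred(func, args)
    return wrapper


@deferred
def inc(x):
    return x + 1


@deferred
def double(x):
    return x + 2


@deferred
def add(x, y):
    return x + y


output = [add(inc(x), double(x)) for x in range(10)]
lazy_total = deferred(sum)(output)
print(type(lazy_total).__name__)  # Deferred
print(lazy_total.compute())       # 120
```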
We can also use the dask.array library, which implements a subset of the NumPy ndarray interface using blocked algorithms, cutting a large array into many small arrays that can be processed in parallel.
See the dask.array documentation: http://docs.dask.org/en/latest/array.html, along with that of dask.bag, dask.dataframe, dask.delayed, Futures, etc.
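The blocked idea can be sketched in plain NumPy on a small array: compute each block's contribution independently and add the partial sums into the result. blocked_row_sums and its chunk sizes are hypothetical choices for this illustration; dask.array performs this kind of decomposition automatically and schedules the blocks in parallel, which this sequential loop does not.

```python
import numpy as np


def blocked_row_sums(x, y, chunk=(4, 4, 5)):
    """Compute (arcsin(x) + arccos(y)).sum(axis=(1, 2)) one block at a time."""
    ci, cj, ck = chunk
    out = np.zeros(x.shape[0])
    for i in range(0, x.shape[0], ci):
        for j in range(0, x.shape[1], cj):
            for k in range(0, x.shape[2], ck):
                xb = x[i:i + ci, j:j + cj, k:k + ck]
                yb = y[i:i + ci, j:j + cj, k:k + ck]
                # Each block is independent; add its partial sums over axes 1 and 2
                out[i:i + ci] += (np.arcsin(xb) + np.arccos(yb)).sum(axis=(1, 2))
    return out


rng = np.random.default_rng(0)
x = rng.random((10, 10, 10))
y = rng.random((10, 10, 10))
direct = (np.arcsin(x) + np.arccos(y)).sum(axis=(1, 2))
print(np.allclose(blocked_row_sums(x, y), direct))  # True
```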
if not os.getenv("DEBUG"):
    x = da.random.random((10000, 10000, 10), chunks=(1000, 1000, 5))
    y = da.random.random((10000, 10000, 10), chunks=(1000, 1000, 5))
    z = (da.arcsin(x) + da.arccos(y)).sum(axis=(1, 2))
    z.compute()
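Some quick arithmetic shows why chunking matters for the cell above. Each of x and y holds 10000 × 10000 × 10 float64 values, about 8 GB if held in memory at once, which is more than the 4GB requested per worker in the SLURMCluster configuration. With chunks=(1000, 1000, 5), each array is instead split into 200 blocks of 40 MB. The prod helper below is defined locally so the snippet also runs on Python 3.7, the version shown in the initialization log.

```python
from functools import reduce
from operator import mul


def prod(values):
    """Product of an iterable of integers (math.prod needs Python 3.8+)."""
    return reduce(mul, values, 1)


shape = (10000, 10000, 10)
chunks = (1000, 1000, 5)
itemsize = 8  # bytes per float64 value

blocks_per_axis = [-(-s // c) for s, c in zip(shape, chunks)]  # ceiling division
n_chunks = prod(blocks_per_axis)
chunk_bytes = prod(chunks) * itemsize
total_bytes = prod(shape) * itemsize

print(blocks_per_axis)          # [10, 10, 2]
print(n_chunks)                 # 200
print(chunk_bytes / 1e6, "MB")  # 40.0 MB
print(total_bytes / 1e9, "GB")  # 8.0 GB
```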
The dask dashboard allows visualizing parallel computation, including progress bars for tasks. Here is a snapshot of the dask dashboard while executing the previous cell:
Image(filename="inputs/dask_dashboard_example.png")
For more info on interpreting the dask dashboard, see: https://distributed.dask.org/en/latest/web.html
dask.delayed with PyRosetta
Let's look at a simple example of sending PyRosetta jobs to the dask-workers and having the dask-workers send the results back to this Jupyter notebook.
We will use the crystal structure of the de novo mini protein gEHEE_06 from PDB ID 5JG9.
@dask.delayed
def mutate(ppose, target, new_res):
    import pyrosetta
    # Unpack the serialized ("packed") pose into a regular Pose on the worker
    pose = io.to_pose(ppose)
    mutate_mover = pyrosetta.rosetta.protocols.simple_moves.MutateResidue(target=target, new_res=new_res)
    mutate_mover.apply(pose)
    # Pack the pose again so it can be sent to the next task
    return io.to_packed(pose)

@dask.delayed
def refine(ppose):
    import pyrosetta
    pose = io.to_pose(ppose)
    scorefxn = pyrosetta.create_score_function("ref2015_cart")
    # Allow both backbone and side-chain degrees of freedom to move
    mm = pyrosetta.rosetta.core.kinematics.MoveMap()
    mm.set_bb(True)
    mm.set_chi(True)
    # Cartesian-space minimization with the ref2015_cart score function
    min_mover = pyrosetta.rosetta.protocols.minimization_packing.MinMover()
    min_mover.set_movemap(mm)
    min_mover.score_function(scorefxn)
    min_mover.min_type("lbfgs_armijo_nonmonotone")
    min_mover.cartesian(True)
    min_mover.tolerance(0.01)
    min_mover.max_iter(200)
    min_mover.apply(pose)
    return io.to_packed(pose)

@dask.delayed
def score(ppose):
    import pyrosetta
    pose = io.to_pose(ppose)
    scorefxn = pyrosetta.create_score_function("ref2015")
    total_score = scorefxn(pose)
    return pose, total_score
pose = pyrosetta.io.pose_from_file("inputs/5JG9.clean.pdb")
# Keep only chain A
keep_chA = pyrosetta.rosetta.protocols.grafting.simple_movers.KeepRegionMover(
    res_start=str(pose.chain_begin(1)), res_end=str(pose.chain_end(1))
)
keep_chA.apply(pose)

#kwargs = {"extra_options": pyrosetta.distributed._normflags(flags)}
output = []
# Sample 10 random positions, skip cysteines, and queue ALA and TRP mutations
for target in random.sample(range(1, pose.size() + 1), 10):
    if pose.sequence()[target - 1] != "C":
        for new_res in ["ALA", "TRP"]:
            a = mutate(io.to_packed(pose), target, new_res)
            b = refine(a)
            c = score(b)
            output.append((target, new_res, c[0], c[1]))
delayed_obj = dask.delayed(np.argmin)([x[-1] for x in output])
delayed_obj.visualize()
print(output)
[(2, 'ALA', Delayed('getitem-d62580447670f937930519b29c2baea8'), Delayed('getitem-d56a296644eab884118ac84afb48c2a8')), (2, 'TRP', Delayed('getitem-3c6ddc615e5da4ff0cf6039ca00ac2b5'), Delayed('getitem-345b66a261acdf50094627f6274b1fc8')), (26, 'ALA', Delayed('getitem-d7333db345fa381a24c89c0807246f0c'), Delayed('getitem-55a09f3d55bc37204bef331712229d28')), (26, 'TRP', Delayed('getitem-946bcf7f947ce02258389e68405ee2e1'), Delayed('getitem-07adb31be33410bca532f4054c8a1a5e')), (43, 'ALA', Delayed('getitem-949c48690d4f050f42dd4d2e29b30c82'), Delayed('getitem-d29dfc3f14c9b935d48566cc59049ec9')), (43, 'TRP', Delayed('getitem-fc490282bbbbd16fe34fa2d4c6d520aa'), Delayed('getitem-dcc772e00c793a5ac35b04a75d08d83e')), (6, 'ALA', Delayed('getitem-0719523591d13f3d03027eafe39601a2'), Delayed('getitem-bec53a87987b24987bb60dd592363146')), (6, 'TRP', Delayed('getitem-a9999dd6b96693b15a0ac592b8e70372'), Delayed('getitem-93b4f2dda91a3e7e7fb7f04d9ab284ed')), (33, 'ALA', Delayed('getitem-c6c5b689a40217e1046b02dd5f691872'), Delayed('getitem-2ecd74ae56aabf11074d00d7c3dda42e')), (33, 'TRP', Delayed('getitem-2d008fec24605d8a46ac9d7bceb40bc4'), Delayed('getitem-d20da5ac43a23483bff7eb260ed25581')), (34, 'ALA', Delayed('getitem-994999abf29e66e929ae58f75a639a2d'), Delayed('getitem-20a8edebdd30b7431c95085d41538fd1')), (34, 'TRP', Delayed('getitem-bc45278e9b0164a9e1ea0f0161a1b288'), Delayed('getitem-b885caa1c171a70f9c22fbc956c47752')), (27, 'ALA', Delayed('getitem-08d66481eee1d24f998b76e7991bf9f8'), Delayed('getitem-1a3616b6b45ccde31858cba7a238284c')), (27, 'TRP', Delayed('getitem-00e8704786a85250c643bb606798f508'), Delayed('getitem-1bb9ae95236335ce46ecf06f7bf44a6c')), (8, 'ALA', Delayed('getitem-a688e7c968474d18257a29179df5fa4a'), Delayed('getitem-b1d54752a55c6c1a62d85f7053909be6')), (8, 'TRP', Delayed('getitem-4d1f87ec243c901daf133c0c08227e87'), Delayed('getitem-5ba961661b56e61be81e3ddb45616026'))]
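The printed list has 16 entries rather than 20 because, in that run, 2 of the 10 sampled positions were cysteines and were skipped. The job-enumeration logic of the loop above can be checked without PyRosetta; mutation_jobs and the toy sequence below are hypothetical, for illustration only (the sequence is not the 5JG9 sequence).

```python
import random


def mutation_jobs(sequence, n_positions, new_residues=("ALA", "TRP"), seed=None):
    """List (position, new_res) jobs, skipping cysteines.

    Positions are 1-based, like Pose residue numbers.
    """
    rng = random.Random(seed)
    jobs = []
    for target in rng.sample(range(1, len(sequence) + 1), n_positions):
        if sequence[target - 1] != "C":
            for new_res in new_residues:
                jobs.append((target, new_res))
    return jobs


# Hypothetical toy sequence with three cysteines (positions 3, 8, and 16)
sequence = "MKCVTEKCGPSLEVRC"
jobs = mutation_jobs(sequence, n_positions=10, seed=1)
print(len(jobs), jobs[:2])
```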
if not os.getenv("DEBUG"):
    # Start executing the graph on the cluster in the background and show a progress bar
    delayed_result = delayed_obj.persist()
    progress(delayed_result)
else:
    delayed_result = None
if not os.getenv("DEBUG"):
    result = delayed_result.compute()
    print("The mutation with the lowest energy is residue {0} at position {1}".format(output[result][1], output[result][0]))
The mutation with the lowest energy is residue ALA at position 32
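In the cell above, output[result][1] and output[result][0] pull the residue and position out of the winning tuple by index. The plain-Python sketch below shows, on placeholder numbers, that taking the argmin of the score column and calling min with a key over the tuples select the same entry. The scores are made up for illustration and are not PyRosetta results.

```python
# Hypothetical (position, residue, score) entries; the scores are placeholders
results = [
    (2, "ALA", 0.5),
    (2, "TRP", -1.5),
    (26, "ALA", -2.0),
    (26, "TRP", 1.0),
]

# Index-based lookup, as in the notebook cell above (equivalent to np.argmin)
scores = [score for _, _, score in results]
best_index = min(range(len(scores)), key=scores.__getitem__)
position, residue, best_score = results[best_index]

# Selecting the tuple directly avoids the index bookkeeping
assert min(results, key=lambda r: r[-1]) == results[best_index]

print(f"The mutation with the lowest energy is residue {residue} at position {position}")
# The mutation with the lowest energy is residue ALA at position 26
```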
Note: For best practices while using dask.delayed, see: http://docs.dask.org/en/latest/delayed-best-practices.html