! pip install "ray==2.2.0"
To give you an example, take the following function, which retrieves and processes data from a database. Our dummy database is a plain Python list containing the words of the title of this book. To simulate the idea that accessing and processing data from the database is costly, the function calls time.sleep, which pauses execution for a given amount of time.
import time

database = [
    "Learning", "Ray",
    "Flexible", "Distributed", "Python", "for", "Machine", "Learning"
]


def retrieve(item):
    time.sleep(item / 10.)
    return item, database[item]
Our database has eight items in total. If we were to retrieve all items sequentially, how long should that take? For the item with index 5 we wait for half a second (5 / 10.), and so on. In total, we can expect a runtime of around (0+1+2+3+4+5+6+7)/10. = 2.8 seconds. Let’s see if that’s what we actually get:
def print_runtime(input_data, start_time):
    print(f'Runtime: {time.time() - start_time:.2f} seconds, data:')
    print(*input_data, sep="\n")


start = time.time()
data = [retrieve(item) for item in range(8)]
print_runtime(data, start)
Runtime: 2.81 seconds, data:
(0, 'Learning')
(1, 'Ray')
(2, 'Flexible')
(3, 'Distributed')
(4, 'Python')
(5, 'for')
(6, 'Machine')
(7, 'Learning')
The total time it takes to run the function is 2.81 seconds, but this may vary on your individual computer. Note that this plain Python version does not run the retrieve calls in parallel.
That the calls ran strictly one after another should not come as a surprise: a list comprehension evaluates each retrieve call to completion before starting the next one, so the 2.8 seconds we measured is effectively the worst case. It may be frustrating that a program that mostly just "sleeps" during its runtime is this slow. Python's Global Interpreter Lock (GIL), which gets enough criticism as it is, is often named as the culprit, but time.sleep releases the GIL while it waits; the runtime here comes from running everything sequentially in a single process.
It’s reasonable to assume that such a task can benefit from parallelization. Perfectly distributed, the runtime should not take much longer than the longest subtask, namely 7/10. = 0.7 seconds. So, let’s see how you can extend this example to run on Ray. To do so, you start by using the @ray.remote decorator.
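To make the two estimates concrete, here is a small back-of-the-envelope sketch. It is not Ray code and does not model how Ray schedules work: the helper estimated_runtime is a hypothetical name, and it assumes tasks are handed out in order to whichever worker frees up first. With one worker it reproduces the sequential 2.8 seconds, and with eight workers it gives the 0.7 seconds of the longest subtask.

```python
import heapq


def estimated_runtime(delays, num_workers):
    """Estimate wall-clock time when tasks are handed, in order, to whichever
    worker becomes free first (a simplification, not Ray's actual scheduler)."""
    # Each heap entry is the time at which one worker becomes free.
    free_at = [0.0] * num_workers
    heapq.heapify(free_at)
    finish = 0.0
    for delay in delays:
        start = heapq.heappop(free_at)  # earliest available worker
        end = start + delay
        heapq.heappush(free_at, end)
        finish = max(finish, end)
    return finish


# Same sleep durations as in retrieve(): item / 10 seconds for items 0..7.
delays = [item / 10. for item in range(8)]

for workers in (1, 2, 4, 8):
    print(f"{workers} worker(s): ~{estimated_runtime(delays, workers):.1f} s")
```

The values in between show why the number of available worker processes matters: with fewer workers than tasks, some tasks have to wait for a free worker before they can even start.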
You may have noticed that the retrieve definition accesses items from database directly. This works well on a local Ray cluster, but consider how it would function on an actual cluster with multiple computers. A Ray cluster has a head node with a driver process and multiple worker nodes with worker processes executing tasks. By default, Ray creates as many worker processes as there are CPU cores on the machine. In this scenario the database is defined on the driver only, but the worker processes need access to it to run the retrieve task. Fortunately, Ray has a simple solution for sharing objects between the driver and workers, or between workers: the put function places data into Ray's distributed object store. In the retrieve_task definition, we explicitly include a db argument, to which we will later pass the db_object_ref object.
import ray

# By utilizing the object store in this manner, you can allow Ray to manage data
# access throughout the entire cluster. We will discuss the specifics of how values
# are transmitted between nodes and within workers when discussing Ray's infrastructure.
db_object_ref = ray.put(database)


@ray.remote
def retrieve_task(item, db):
    time.sleep(item / 10.)
    return item, db[item]
In this way, the function retrieve_task becomes a so-called Ray task. In essence, a Ray task is a function that gets executed in a different process than the one it was called from, potentially on a different machine.
Ray is convenient to use because you can keep writing your Python code as usual, without significantly changing your approach or programming style. Applying the @ray.remote decorator directly to the retrieve function would be the usual way to use a decorator, but for clarity we left the original code untouched and defined a separate retrieve_task instead.
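To retrieve database entries and measure performance, what changes do you need to make to the code? Not a lot, actually. Instead of calling retrieve directly, you call retrieve_task.remote(item, db_object_ref) for each item, which immediately returns an object reference rather than a result. Calling ray.get on those references then blocks until the results are available.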
Fetching all results with a single blocking ray.get call has drawbacks. It would be more efficient to let the driver process do other work while waiting, and to process results as they complete rather than waiting for all items to finish. In addition, if one of the database items cannot be retrieved, for example because of a deadlock in the database connection, the driver will hang indefinitely. To prevent this, it is a good idea to set reasonable timeouts using the wait function. For example, if we do not want to wait longer than ten times the longest data retrieval task (10 * 0.7 = 7 seconds), we can pass timeout=7.0 to ray.wait. The call then returns after at most that long with whatever has finished, split into finished and still-pending references; it does not cancel the pending tasks.
start = time.time()
object_references = [
    retrieve_task.remote(item, db_object_ref) for item in range(8)
]
all_data = []

while len(object_references) > 0:
    finished, object_references = ray.wait(
        object_references, num_returns=2, timeout=7.0
    )
    data = ray.get(finished)
    print_runtime(data, start)
    all_data.extend(data)

print_runtime(all_data, start)
Runtime: 1.10 seconds, data:
(0, 'Learning')
(1, 'Ray')
Runtime: 1.61 seconds, data:
(2, 'Flexible')
(3, 'Distributed')
Runtime: 2.52 seconds, data:
(4, 'Python')
(5, 'for')
Runtime: 3.83 seconds, data:
(6, 'Machine')
(7, 'Learning')
Runtime: 3.83 seconds, data:
(0, 'Learning')
(1, 'Ray')
(2, 'Flexible')
(3, 'Distributed')
(4, 'Python')
(5, 'for')
(6, 'Machine')
(7, 'Learning')
We will use Python to implement the MapReduce algorithm for a word-count example and use Ray to parallelize the computation. To get a feel for what we are working with, we begin by loading some example data.
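import subprocess
import sys

# Run "import this" in a child interpreter and capture the printed text as bytes.
# sys.executable is used so that this also works where "python" is not on the PATH.
zen_of_python = subprocess.check_output([sys.executable, "-c", "import this"])
corpus = zen_of_python.split()  # whitespace-separated tokens, each a bytes object

num_partitions = 3
chunk = len(corpus) // num_partitions
partitions = [
    corpus[i * chunk: (i + 1) * chunk] for i in range(num_partitions)
]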
We use the Zen of Python, a collection of guidelines from the Python community, as our data for this exercise. It is hidden as an "Easter egg" and is printed when you type "import this" in a Python session. While it is worth reading these guidelines as a Python programmer, for this exercise we only care about counting the words they contain. Note that calling split() without arguments splits the text on whitespace, so each element of corpus is a single word (as a bytes object) that we treat as a tiny "document". We then divide these documents into three partitions.
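One detail of the slicing above is worth knowing: because chunk is computed with integer division, any tokens beyond num_partitions * chunk are left out whenever the corpus length is not divisible by the number of partitions. The following sketch shows one way to avoid that. The helper name split_evenly and the sample text are made up for this illustration and are not part of the original example.

```python
def split_evenly(items, num_partitions):
    """Split items into contiguous slices whose sizes differ by at most one."""
    base, extra = divmod(len(items), num_partitions)
    partitions, start = [], 0
    for i in range(num_partitions):
        # The first `extra` slices each take one additional item.
        size = base + (1 if i < extra else 0)
        partitions.append(items[start:start + size])
        start += size
    return partitions


# Eight tokens, which do not divide evenly into three partitions.
words = b"Beautiful is better than ugly. Explicit is better".split()
parts = split_evenly(words, 3)

print([len(p) for p in parts])  # [3, 3, 2]
```

Every token ends up in exactly one partition, so no words are silently dropped from the count.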
To define the map phase, we need a map function that we apply to each document. In this scenario, we want to output the pair (word, 1) for every word found in a document. For simple text documents this looks as follows; the same code works for both str and bytes input, since both provide lower() and split().
def map_function(document):
    for word in document.lower().split():
        yield word, 1
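As a quick usage sketch, here is what map_function yields for a single line of text. The function is repeated so that the snippet runs on its own, and the sample inputs are chosen purely for illustration.

```python
def map_function(document):
    for word in document.lower().split():
        yield word, 1


# map_function is a generator, so we materialize it with list().
pairs = list(map_function("Beautiful is better than ugly."))
print(pairs)
# [('beautiful', 1), ('is', 1), ('better', 1), ('than', 1), ('ugly.', 1)]

# bytes input works the same way and yields bytes keys.
bytes_pairs = list(map_function(b"Simple"))
print(bytes_pairs)
# [(b'simple', 1)]
```

Note that punctuation stays attached to the words, which is why keys such as 'ugly.' show up in the counts.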
We apply the map phase to a whole collection of documents with an apply_map function, which we mark as a Ray task using the @ray.remote decorator. We call apply_map once for each of the three sets of document data (num_partitions=3). Each call returns three lists, one per partition, and a pair is placed in one of these lists according to the first letter of its word. We do this so that Ray can shuffle the results of the map phase and distribute them to the appropriate nodes for us.
import ray


@ray.remote
def apply_map(corpus, num_partitions=3):
    # One output list per reducer partition.
    map_results = [list() for _ in range(num_partitions)]
    for document in corpus:
        for result in map_function(document):
            # Route each (word, 1) pair by the first letter of the word.
            first_letter = result[0].decode("utf-8")[0]
            word_index = ord(first_letter) % num_partitions
            map_results[word_index].append(result)
    return map_results
For text corpora that fit on a single machine, a distributed map phase is unnecessary. However, when the data needs to be divided across multiple nodes, the map phase becomes useful. To apply the map phase to our corpus in parallel, we use a remote call on apply_map, just as in the previous examples. The main difference is that we now also tell Ray, via the num_returns argument, that we want three return values (one per partition), so each call yields three object references instead of one.
map_results = [
    apply_map.options(num_returns=num_partitions)
    .remote(data, num_partitions)
    for data in partitions
]

for i in range(num_partitions):
    mapper_results = ray.get(map_results[i])
    for j, result in enumerate(mapper_results):
        print(f"Mapper {i}, return value {j}: {result[:2]}")
Mapper 0, return value 0: [(b'of', 1), (b'is', 1)]
Mapper 0, return value 1: [(b'python,', 1), (b'peters', 1)]
Mapper 0, return value 2: [(b'the', 1), (b'zen', 1)]
Mapper 1, return value 0: [(b'unless', 1), (b'in', 1)]
Mapper 1, return value 1: [(b'although', 1), (b'practicality', 1)]
Mapper 1, return value 2: [(b'beats', 1), (b'errors', 1)]
Mapper 2, return value 0: [(b'is', 1), (b'is', 1)]
Mapper 2, return value 1: [(b'although', 1), (b'a', 1)]
Mapper 2, return value 2: [(b'better', 1), (b'than', 1)]
We can make it so that all pairs from the j-th return value end up on the same node for the reduce phase. Let’s discuss this phase next.
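The reason this works is the routing rule inside apply_map: the partition index depends only on the first letter of a word, so the same word always lands in the same return value, no matter which mapper saw it. The sketch below isolates that rule. The helper name partition_index is hypothetical, and it uses plain strings instead of bytes to keep the example short.

```python
num_partitions = 3


def partition_index(word, num_partitions):
    """Same rule as in apply_map: code point of the first letter, modulo
    the number of partitions."""
    return ord(word[0]) % num_partitions


sample_words = ["is", "of", "python,", "although", "the", "zen", "better"]
for word in sample_words:
    print(f"{word!r} -> partition {partition_index(word, num_partitions)}")

# Group the sample words by their partition to mirror the mapper output.
routed = {j: [] for j in range(num_partitions)}
for word in sample_words:
    routed[partition_index(word, num_partitions)].append(word)
print(routed)
# {0: ['is', 'of'], 1: ['python,', 'although'], 2: ['the', 'zen', 'better']}
```

This matches the mapper output shown above, where words starting with "i" or "o" appear in return value 0 and words starting with "t", "z", or "b" appear in return value 2.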
In the reduce phase we can create a dictionary that sums up all word occurrences on each partition:
@ray.remote
def apply_reduce(*results):
    reduce_results = dict()
    for res in results:
        for key, value in res:
            if key not in reduce_results:
                reduce_results[key] = 0
            reduce_results[key] += value
    return reduce_results
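To see what a reducer computes without starting Ray, you can run the same logic as an ordinary function. In this sketch, reduce_pairs is a hypothetical non-remote copy of the body of apply_reduce, and the two input lists are made-up stand-ins for the same return value coming from two different mappers.

```python
def reduce_pairs(*results):
    """Non-remote copy of the apply_reduce logic, for local experimentation."""
    reduce_results = dict()
    for res in results:
        for key, value in res:
            if key not in reduce_results:
                reduce_results[key] = 0
            reduce_results[key] += value
    return reduce_results


# Pairs for the same partition, coming from two different mappers.
from_mapper_0 = [("is", 1), ("of", 1), ("is", 1)]
from_mapper_1 = [("in", 1), ("is", 1)]

totals = reduce_pairs(from_mapper_0, from_mapper_1)
print(totals)
# {'is': 3, 'of': 1, 'in': 1}
```

Each reducer receives one list per mapper as a separate positional argument, which is why the function takes *results.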
We can take the j-th return value from each mapper and send it to the j-th reducer using the following method. It's important to note that this code works for larger datasets that don't fit on one machine because we are passing references to the data using Ray objects rather than the actual data itself. Both the map and reduce phases can be run on any Ray cluster and the data shuffling is also handled by Ray.
outputs = []
for i in range(num_partitions):
    outputs.append(
        apply_reduce.remote(*[partition[i] for partition in map_results])
    )

counts = {k: v for output in ray.get(outputs) for k, v in output.items()}

sorted_counts = sorted(counts.items(), key=lambda item: item[1], reverse=True)
for count in sorted_counts:
    print(f"{count[0].decode('utf-8')}: {count[1]}")
is: 10
better: 8
than: 8
the: 6
to: 5
of: 3
although: 3
be: 3
unless: 2
one: 2
if: 2
implementation: 2
idea.: 2
special: 2
should: 2
do: 2
may: 2
a: 2
never: 2
way: 2
explain,: 2
ugly.: 1
implicit.: 1
complex.: 1
complex: 1
complicated.: 1
flat: 1
readability: 1
counts.: 1
cases: 1
rules.: 1
in: 1
face: 1
refuse: 1
one--: 1
only: 1
--obvious: 1
it.: 1
obvious: 1
first: 1
often: 1
*right*: 1
it's: 1
it: 1
idea: 1
--: 1
let's: 1
python,: 1
peters: 1
simple: 1
sparse: 1
dense.: 1
aren't: 1
practicality: 1
purity.: 1
pass: 1
silently.: 1
silenced.: 1
ambiguity,: 1
guess.: 1
and: 1
preferably: 1
at: 1
you're: 1
dutch.: 1
good: 1
are: 1
great: 1
more: 1
zen: 1
by: 1
tim: 1
beautiful: 1
explicit: 1
nested.: 1
enough: 1
break: 1
beats: 1
errors: 1
explicitly: 1
temptation: 1
there: 1
that: 1
not: 1
now: 1
never.: 1
now.: 1
hard: 1
bad: 1
easy: 1
namespaces: 1
honking: 1
those!: 1
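Merging the reducer outputs with a plain dictionary comprehension is only safe because every word is routed to exactly one reducer, so no key can appear in two outputs and be overwritten. The following sketch checks that property on a few entries taken from the output above and shows a more defensive merge with collections.Counter. The variable names are chosen for this illustration.

```python
from collections import Counter

# A few entries from the output above, grouped the way the reducers see them.
reducer_outputs = [
    {"is": 10, "of": 3, "unless": 2},    # first letter: ord % 3 == 0
    {"although": 3, "python,": 1},       # first letter: ord % 3 == 1
    {"better": 8, "than": 8, "the": 6},  # first letter: ord % 3 == 2
]

# Check that no word appears in more than one reducer's output.
all_keys = [key for output in reducer_outputs for key in output]
keys_are_disjoint = len(all_keys) == len(set(all_keys))

# Plain merge, as in the dictionary comprehension used for counts.
counts = {k: v for output in reducer_outputs for k, v in output.items()}

# Defensive alternative: Counter.update adds counts instead of overwriting.
summed = Counter()
for output in reducer_outputs:
    summed.update(output)

print(keys_are_disjoint, counts == dict(summed))
# True True
```

If you ever change the routing so that a word can reach several reducers, the Counter variant still produces correct totals, whereas the plain merge would keep only the last value seen.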
To gain a thorough understanding of how to scale MapReduce tasks across multiple nodes using Ray, including memory management, we suggest reading this insightful blog post on the topic.
The important part about this MapReduce example is to realize how flexible Ray’s programming model really is. Surely, a production-grade MapReduce implementation takes a bit more effort. But being able to reproduce common algorithms like this one quickly goes a long way. Keep in mind that in the earlier phases of MapReduce, say around 2010, this paradigm was often the only thing you had to express your workloads. With Ray, a whole range of interesting distributed computing patterns become accessible to any intermediate Python programmer.