In large-scale data processing, customization is often necessary to extend the native capabilities of Spark. Python User-Defined Functions (UDFs) and User-Defined Table Functions (UDTFs) offer a way to perform complex transformations and computations using Python, seamlessly integrating them into Spark’s distributed environment.
In this section, we’ll explore how to write and use UDFs and UDTFs in Python, leveraging PySpark to perform complex data transformations that go beyond Spark’s built-in functions.
There are two main categories of UDFs supported in PySpark: Scalar Python UDFs and Pandas UDFs.
Building on the Pandas UDF implementation, there are also the Pandas Function APIs: Map (i.e., mapInPandas) and (Co)Grouped Map (i.e., applyInPandas), as well as an Arrow Function API, mapInArrow. A brief sketch of the first two appears below.
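To give a feel for the Pandas Function APIs, here is a minimal sketch of mapInPandas and applyInPandas. The example DataFrame and its column names (id, v) are assumptions chosen purely for illustration.
import pandas as pd
from typing import Iterator

df = spark.createDataFrame([(1, 21.0), (1, 24.0), (2, 30.0)], ["id", "v"])

# Map: the function receives an iterator of pandas DataFrames (one per batch)
# and returns an iterator of pandas DataFrames.
def keep_large_values(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for pdf in batches:
        yield pdf[pdf.v > 21.0]

df.mapInPandas(keep_large_values, schema="id long, v double").show()

# Grouped Map: the function receives one pandas DataFrame per group
# and returns a pandas DataFrame.
def subtract_group_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    return pdf.assign(v=pdf.v - pdf.v.mean())

df.groupBy("id").applyInPandas(subtract_group_mean, schema="id long, v double").show()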
In the code below, we've created a simple scalar Python UDF.
from pyspark.sql.functions import udf
@udf(returnType='int')
def slen(s: str):
    return len(s)
Scalar Python UDFs rely on cloudpickle for serialization and deserialization and can run into performance bottlenecks, particularly with large data inputs and outputs. Arrow-optimized Python UDFs were introduced to significantly improve performance.
At the core of this optimization lies Apache Arrow, a standardized cross-language columnar in-memory data representation. By harnessing Arrow, these UDFs bypass the traditional, slower methods of data (de)serialization, leading to swift data exchange between JVM and Python processes. With Apache Arrow's rich type system, these optimized UDFs offer a more consistent and standardized way to handle type coercion.
We can control whether to enable Arrow optimization for an individual UDF by using the useArrow boolean parameter of functions.udf, as shown below:
from pyspark.sql.functions import udf
@udf(returnType='int', useArrow=True) # An Arrow Python UDF
def arrow_slen(s: str):
    ...
In addition, we can enable Arrow optimization for all UDFs of an entire SparkSession via the Spark configuration spark.sql.execution.pythonUDF.arrow.enabled, as shown below:
spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", True)
@udf(returnType='int') # An Arrow Python UDF
def arrow_slen(s: str):
    ...
In Python, we can invoke a UDF directly on column(s), just like a built-in Spark function, as shown below.
data = [("Alice",), ("Bob",), ("Charlie",)]
df = spark.createDataFrame(data, ["name"])
df.withColumn("name_length", slen(df["name"])).show()
+-------+-----------+
|   name|name_length|
+-------+-----------+
|  Alice|          5|
|    Bob|          3|
|Charlie|          7|
+-------+-----------+
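If we also want to call the same UDF from SQL, we can register it on the session. This is a small sketch; the registered name "slen" and the inline VALUES data are illustrative choices.
spark.udf.register("slen", slen)
spark.sql("SELECT name, slen(name) AS name_length FROM VALUES ('Alice'), ('Bob') AS t(name)").show()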
In the code below, we've created a Pandas UDF which takes one pandas.Series and outputs one pandas.Series:
import pandas as pd
from pyspark.sql.functions import pandas_udf
@pandas_udf("string")
def to_upper(s: pd.Series) -> pd.Series:
    return s.str.upper()
df = spark.createDataFrame([("John Doe",)], ("name",))
df.select(to_upper("name")).show()
+--------------+
|to_upper(name)|
+--------------+
|      JOHN DOE|
+--------------+
Similar to a Scalar Python UDF, we can also invoke a pandas UDF directly on column(s):
data = [("Alice",), ("Bob",), ("Charlie",)]
df = spark.createDataFrame(data, ["name"])
df.withColumn("name_length", to_upper(df["name"])).show()
+-------+-----------+ | name|name_length| +-------+-----------+ | Alice| ALICE| | Bob| BOB| |Charlie| CHARLIE| +-------+-----------+
A scalar Python UDF can also take multiple columns of different types as input. The example below combines a string column and an array column, counting the vowels in the text and doubling every number in the list.
from pyspark.sql.functions import udf
data = [
("Hello World", [1, 2, 3]),
("PySpark is Fun", [4, 5, 6]),
("PySpark Rocks", [7, 8, 9])
]
df = spark.createDataFrame(data, ["text_column", "list_column"])
@udf(returnType="string")
def process_row(text: str, numbers):
    vowels_count = sum(1 for char in text if char in "aeiouAEIOU")
    doubled = [x * 2 for x in numbers]
    return f"Vowels: {vowels_count}, Doubled: {doubled}"
df.withColumn("process_row", process_row(df["text_column"], df["list_column"])).show(truncate=False)
+--------------+-----------+--------------------------------+
|text_column   |list_column|process_row                     |
+--------------+-----------+--------------------------------+
|Hello World   |[1, 2, 3]  |Vowels: 3, Doubled: [2, 4, 6]   |
|PySpark is Fun|[4, 5, 6]  |Vowels: 3, Doubled: [8, 10, 12] |
|PySpark Rocks |[7, 8, 9]  |Vowels: 2, Doubled: [14, 16, 18]|
+--------------+-----------+--------------------------------+
A Pandas UDF can likewise take multiple columns as input and return a struct. The example below computes the mean and sum of a numeric column and transforms a text column, returning all three values as one struct column.
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StructType, StructField, DoubleType, StringType
import pandas as pd
data = [
(10.0, "Spark"),
(20.0, "Big Data"),
(30.0, "AI"),
(40.0, "Machine Learning"),
(50.0, "Deep Learning")
]
df = spark.createDataFrame(data, ["numeric_column", "text_column"])
# Schema for the result
schema = StructType([
StructField("mean_value", DoubleType(), True),
StructField("sum_value", DoubleType(), True),
StructField("processed_text", StringType(), True)
])
@pandas_udf(schema)
def compute_stats_and_transform_string(numeric_col: pd.Series, text_col: pd.Series) -> pd.DataFrame:
    mean_value = numeric_col.mean()
    sum_value = numeric_col.sum()
    # Reverse the string if its length is greater than 5, otherwise capitalize it
    processed_text = text_col.apply(lambda x: x[::-1] if len(x) > 5 else x.upper())
    result_df = pd.DataFrame({
        "mean_value": [mean_value] * len(text_col),
        "sum_value": [sum_value] * len(text_col),
        "processed_text": processed_text
    })
    return result_df
df.withColumn("result", compute_stats_and_transform_string(df["numeric_column"], df["text_column"])).show(truncate=False)
+--------------+----------------+------------------------------+
|numeric_column|text_column     |result                        |
+--------------+----------------+------------------------------+
|10.0          |Spark           |{10.0, 10.0, SPARK}           |
|20.0          |Big Data        |{20.0, 20.0, ataD giB}        |
|30.0          |AI              |{30.0, 30.0, AI}              |
|40.0          |Machine Learning|{40.0, 40.0, gninraeL enihcaM}|
|50.0          |Deep Learning   |{50.0, 50.0, gninraeL peeD}   |
+--------------+----------------+------------------------------+
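Note that a pandas UDF operates on batches of rows rather than on the whole column at once, so mean_value and sum_value are computed per batch; that is why each row in this small example ends up showing values computed from its own batch. To work with the struct result, we can expand its fields into top-level columns; the sketch below simply reuses the names from the example above.
result_df = df.withColumn(
    "result",
    compute_stats_and_transform_string(df["numeric_column"], df["text_column"])
)
# Expand the struct's fields into top-level columns
result_df.select("numeric_column", "result.*").show(truncate=False)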
A Python user-defined table function (UDTF) is a new kind of function that returns a table as output instead of a single scalar result value. Once registered, a UDTF can appear in the FROM clause of a SQL query.
In short, if you want a function that generates multiple rows and columns, and want to leverage the rich Python ecosystem, Python UDTFs are for you.
Python UDTFs vs Python UDFs: While Python UDFs in Spark are designed to each accept zero or more scalar values as input and return a single value as output, UDTFs offer more flexibility: they can return multiple rows and columns, extending the capabilities of UDFs. They are particularly helpful when a single input needs to expand into several output rows, as the examples later in this section show.
Python UDTFs vs SQL UDTFs: SQL UDTFs are efficient and versatile, but Python offers a richer set of libraries and tools, enabling advanced transformations and computations (e.g., statistical functions or machine learning inference).
In the code below, we've created a simple UDTF that takes a start and an end integer as input and, for each number in that range, produces a row with two columns: the original number and its square.
Note the use of the yield statement: a Python UDTF emits each output row as a tuple (or a Row object) so that the results can be processed properly. Also note that the declared return type must be a StructType, or a DDL string describing a struct, such as the "num: int, squared: int" used below.
from pyspark.sql.functions import udtf
@udtf(returnType="num: int, squared: int")
class SquareNumbers:
    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            yield (num, num * num)
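As a variant, output rows can also be emitted as Row objects rather than tuples, which makes the field names explicit; the class name SquareNumbersAsRows below is just an illustrative choice.
from pyspark.sql import Row
from pyspark.sql.functions import udtf

@udtf(returnType="num: int, squared: int")
class SquareNumbersAsRows:
    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            # Emit each output row as a Row object instead of a tuple
            yield Row(num=num, squared=num * num)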
Apache Arrow is an in-memory columnar data format that allows for efficient data transfers between Java and Python processes. It can significantly boost performance when the UDTF outputs many rows. Arrow optimization can be enabled by passing useArrow=True, for example:
from pyspark.sql.functions import udtf
@udtf(returnType="num: int, squared: int", useArrow=True)
class SquareNumbers:
    ...
In Python, we can invoke a UDTF directly using the class name, as shown below.
from pyspark.sql.functions import lit
SquareNumbers(lit(1), lit(3)).show()
+---+-------+
|num|squared|
+---+-------+
|  1|      1|
|  2|      4|
|  3|      9|
+---+-------+
We can also register the Python UDTF and then use it in SQL as a table-valued function in the FROM clause of a query:
spark.udtf.register("square_numbers", SquareNumbers)
spark.sql("SELECT * FROM square_numbers(1, 3)").show()
A UDTF can emit more than two columns. The next example produces, for each number in a range, its square, cube, and factorial.
from pyspark.sql.functions import lit, udtf
import math
@udtf(returnType="num: int, square: int, cube: int, factorial: int")
class GenerateComplexNumbers:
    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            yield (num, num ** 2, num ** 3, math.factorial(num))
GenerateComplexNumbers(lit(1), lit(5)).show()
+---+------+----+---------+
|num|square|cube|factorial|
+---+------+----+---------+
|  1|     1|   1|        1|
|  2|     4|   8|        2|
|  3|     9|  27|        6|
|  4|    16|  64|       24|
|  5|    25| 125|      120|
+---+------+----+---------+
UDTFs are also a natural fit for splitting one input into many rows. The example below breaks a sentence into words and reports each word's length and whether it is a palindrome.
from pyspark.sql.functions import lit, udtf
@udtf(returnType="word: string, length: int, is_palindrome: boolean")
class ProcessWords:
    def eval(self, sentence: str):
        words = sentence.split()  # Split the sentence into words
        for word in words:
            is_palindrome = word == word[::-1]  # Check whether the word is a palindrome
            yield (word, len(word), is_palindrome)
ProcessWords(lit("hello world")).show()
+-----+------+-------------+
| word|length|is_palindrome|
+-----+------+-------------+
|hello|     5|        false|
|world|     5|        false|
+-----+------+-------------+
Finally, because UDTFs are plain Python, they can lean on the standard library. The example below uses the json module to flatten a JSON string into one row per key, along with each value and its Python type.
import json
from pyspark.sql.functions import lit, udtf
@udtf(returnType="key: string, value: string, value_type: string")
class ParseJSON:
    def eval(self, json_str: str):
        try:
            json_data = json.loads(json_str)
            for key, value in json_data.items():
                value_type = type(value).__name__
                yield (key, str(value), value_type)
        except json.JSONDecodeError:
            yield ("Invalid JSON", "", "")
ParseJSON(lit('{"name": "Alice", "age": 25, "is_student": false}')).show()
+----------+-----+----------+
|       key|value|value_type|
+----------+-----+----------+
|      name|Alice|       str|
|       age|   25|       int|
|is_student|False|      bool|
+----------+-----+----------+