I am a big fan of unit-testing.
“Testing Is Not About Finding Bugs. We believe that the major benefits of testing happen when you think about and write the tests, not when you run them.”
The Pragmatic Programmer, David Thomas and Andrew Hunt
Reading The Pragmatic Programmer and Refactoring completely changed the way I viewed unit-testing.
Instead of seeing testing as a chore to complete after I have finished my pipelines, I see it as a powerful tool to improve the design of my code, reduce coupling, iterate more quickly and build trust with others in my work.
However, writing good tests for data applications can be difficult.
Unlike traditional software applications with relatively well-defined inputs, data applications in production depend on large and constantly changing input data. It can be very challenging to represent this data in a test environment in enough detail to cover all edge cases. Some people argue there is little point in unit-testing data pipelines, and focus on data validation techniques instead.
I strongly believe in implementing both unit-testing and data validation in your data pipelines. Unit-testing isn’t just about finding bugs; it is about creating better designed code and building trust with colleagues and end users.
If you can get in the habit of writing tests, you will write better designed code, save time in the long run and reduce the pain of pipelines failing or giving incorrect results in production.
Challenges with Unit testing PySpark code
A good unit-test should have the following characteristics:
- Focused. Each test should test a single behaviour/functionality.
- Fast. Tests should run quickly, allowing you to iterate and gain feedback quickly.
- Isolated. Each test should run independently of other tests and not depend on external factors in order to run successfully.
- Concise. Creating a test shouldn’t require lots of boiler-plate code to mock or create complex objects in order for the test to run.
When it comes to writing unit-tests for PySpark pipelines, writing focused, fast, isolated and concise tests can be a challenge.
Here are some of the hurdles to overcome…
Writing testable code in the first place
PySpark pipelines tend to be written as one giant function responsible for multiple transformations. For example:
def main(spark, country):
    """Example function for processing data received from different countries"""
    # fetch input data
    ...
    # preprocess based on individual country requirements
    if country == 'US':
        # preprocessing for US
        ...
    elif country == 'GB':
        # preprocessing for UK
        ...
    else:
        # more preprocessing etc..
        ...
    # join dataframes together
    ...
    # run some calculations/aggregate
    ...
    # save results
    ...
It is logical to think about transformations in this way, and in many ways it is easier to reason about and read.
But when you start trying to write a test for this function, you quickly realise it is very difficult to write a test that covers all of its functionality.
This is because the function is highly coupled and there are many different paths the function can take.
Even if you did write a test that verified the input and output data were as expected, if the test failed for any reason it would be very difficult to understand which part of the very long function was at fault.
Instead, you should break your transformations into blocks of reusable functions, each responsible for a single task. You can then write a unit-test for each individual function (task). When each of these unit-tests passes, you can be more confident in the output of the final pipeline when you compose all the functions together.
Writing tests is a good practice and forces you to think about design principles. If it is difficult to test your code then you probably need to rethink the design of your code.
Speed
Spark is optimised to work on very large data and uses lazy evaluation and distributed processing.
When you run a PySpark pipeline, spark will evaluate the entire pipeline and calculate an optimised ‘plan’ to perform the computation across a distributed cluster.
The planning comes with significant overhead. This makes sense when you are processing terabytes of data on a distributed cluster of machines, but when working on a small dataset on a single machine it can be surprisingly slow, especially compared with what you might have experienced with Pandas.
Without optimising your SparkSession configuration parameters your unit-tests will take an agonizingly long time to run.
Dependency on a Spark Session
To run PySpark code in your unit-test, you need a SparkSession.
As stated above, ideally each test should be isolated from others and not require complex external objects. Unfortunately, there is no escaping the requirement to initiate a spark session for your unit-tests.
Creating a spark session is the first hurdle to overcome when writing a unit-test for your PySpark pipeline.
How should you create a SparkSession for your tests?
Initiating a new spark session for each test would dramatically increase the time to run the tests and introduce a ton of boiler-plate code to your tests.
Efficiently creating and sharing a SparkSession across your tests is vital to keep the run time of your tests at an acceptable level.
Data
Your tests will require input data.
There are two main problems with creating example data for big data pipelines.
The first is size. Obviously, you cannot run your tests against the full dataset that will be used in production. You have to use a much smaller subset.
But, by using a small dataset, you run into the second problem which is providing enough test data to cover all the edge cases you want to handle.
It is really hard to mock realistic data for testing. However, you can use smaller, targeted datasets for your tests. So what is the most efficient way to pass example data to your PySpark unit-tests?
Steps to unit-test your PySpark code with Pytest
Let’s work through an example using PyTest.
Full code is available in the e4ds-snippets GitHub repository.
Example code
Here is an example PySpark pipeline to process some bank transactions and classify them as debit account or credit account transactions:
Each transaction record comes with an account ID. We will use this account ID to join to an account information table, which tells us whether the account ID belongs to a debit or credit account.
import pyspark.sql.functions as F
from pyspark.sql import DataFrame


def classify_debit_credit_transactions(
    transactionsDf: DataFrame, accountDf: DataFrame
) -> DataFrame:
    """Join transactions with account information and classify as debit/credit"""

    # normalise strings
    transactionsDf = transactionsDf.withColumn(
        "transaction_information_cleaned",
        F.regexp_replace(F.col("transaction_information"), r"[^A-Z0-9]+", ""),
    )

    # join on customer account using first 9 characters
    transactions_accountsDf = transactionsDf.join(
        accountDf,
        on=F.substring(F.col("transaction_information_cleaned"), 1, 9)
        == F.col("account_number"),
        how="inner",
    )

    # classify transactions as from debit or credit account customers
    credit_account_ids = ["100", "101", "102"]
    debit_account_ids = ["200", "201", "202"]
    transactions_accountsDf = transactions_accountsDf.withColumn(
        "business_line",
        F.when(F.col("business_line_id").isin(credit_account_ids), F.lit("credit"))
        .when(F.col("business_line_id").isin(debit_account_ids), F.lit("debit"))
        .otherwise(F.lit("other")),
    )

    return transactions_accountsDf
There are a few issues with this example pipeline:
- Difficult to read. Lots of complex logic in one place. For example, regex replacements, joining on substrings etc.
- Difficult to test. A single function is responsible for multiple actions.
- Difficult to reuse. The debit/credit classification is business logic which currently cannot be reused across the project.
Refactor into smaller logical units
Let’s first refactor the code into individual functions, then compose the functions together in the main classify_debit_credit_transactions function.
We can then write a test for each individual function to ensure it is behaving as expected.
While this increases the overall number of lines of code, the code is easier to test and we can now reuse the functions across other parts of the project.
import pyspark.sql.functions as F
from pyspark.sql import DataFrame


def classify_debit_credit_transactions(
    transactionsDf: DataFrame, accountsDf: DataFrame
) -> DataFrame:
    """Join transactions with account information and classify as debit/credit"""
    transactionsDf = normalise_transaction_information(transactionsDf)
    transactions_accountsDf = join_transactionsDf_to_accountsDf(
        transactionsDf, accountsDf
    )
    transactions_accountsDf = apply_debit_credit_business_classification(
        transactions_accountsDf
    )
    return transactions_accountsDf


def normalise_transaction_information(transactionsDf: DataFrame) -> DataFrame:
    """Remove special characters from transaction information"""
    return transactionsDf.withColumn(
        "transaction_information_cleaned",
        F.regexp_replace(F.col("transaction_information"), r"[^A-Z0-9]+", ""),
    )


def join_transactionsDf_to_accountsDf(
    transactionsDf: DataFrame, accountsDf: DataFrame
) -> DataFrame:
    """Join transactions to accounts information"""
    return transactionsDf.join(
        accountsDf,
        on=F.substring(F.col("transaction_information_cleaned"), 1, 9)
        == F.col("account_number"),
        how="inner",
    )


def apply_debit_credit_business_classification(
    transactions_accountsDf: DataFrame,
) -> DataFrame:
    """Classify transactions as coming from debit or credit account customers"""
    # TODO: move to config file
    # same account IDs as the original pipeline, so the refactor keeps its behaviour
    CREDIT_ACCOUNT_IDS = ["100", "101", "102"]
    DEBIT_ACCOUNT_IDS = ["200", "201", "202"]
    return transactions_accountsDf.withColumn(
        "business_line",
        F.when(F.col("business_line_id").isin(CREDIT_ACCOUNT_IDS), F.lit("credit"))
        .when(F.col("business_line_id").isin(DEBIT_ACCOUNT_IDS), F.lit("debit"))
        .otherwise(F.lit("other")),
    )
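The main function is now just the three steps applied in order. If you want to make that ordering explicit, one option (a sketch, not part of the original code) is a small hypothetical compose_steps helper built on functools.reduce, with functools.partial used to fix extra arguments such as accountsDf. PySpark 3.0 and later also offers DataFrame.transform for the same chaining style, e.g. df.transform(normalise_transaction_information).

```python
from functools import reduce
from typing import Callable, TypeVar

T = TypeVar("T")


def compose_steps(data: T, *steps: Callable[[T], T]) -> T:
    """Hypothetical helper: feed the output of each step into the next one."""
    # reduce starts from `data` and applies each step to the running result
    return reduce(lambda acc, step: step(acc), steps, data)
```

For example, compose_steps(transactionsDf, normalise_transaction_information, partial(join_transactionsDf_to_accountsDf, accountsDf=accountsDf), apply_debit_credit_business_classification) would run the same three steps as classify_debit_credit_transactions. The helper itself knows nothing about Spark, so it can be tested with plain Python values.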
Creating a reusable SparkSession using Fixtures
Before writing our unit-tests, we need to create a SparkSession which we can reuse across all our tests.
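To do this, we create a PyTest fixture in a conftest.py file.
Pytest fixtures are objects that tests request by naming them as arguments, and fixtures defined in conftest.py are available to every test in that directory without an import. With scope="session", the fixture is created once and then reused across all the tests in the run. This is particularly useful for complex objects like the SparkSession, which have a significant overhead to create.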
# conftest.py
from pyspark.sql import SparkSession
import pytest


@pytest.fixture(scope="session")
def spark():
    spark = (
        SparkSession.builder.master("local[1]")
        .appName("local-tests")
        .config("spark.executor.cores", "1")
        .config("spark.executor.instances", "1")
        .config("spark.sql.shuffle.partitions", "1")
        .config("spark.driver.bindAddress", "127.0.0.1")
        .getOrCreate()
    )
    yield spark
    spark.stop()
It is important to set a number of configuration parameters in order to optimise the SparkSession for processing small data on a single machine for testing:
- master = local[1]: run Spark on the local machine with one thread
- spark.executor.cores = 1: set the number of cores per executor to one
- spark.executor.instances = 1: set the number of executors to one
- spark.sql.shuffle.partitions = 1: use a single partition when shuffling data for joins and aggregations (the default is 200)
- spark.driver.bindAddress = 127.0.0.1: (optional) explicitly specify the driver bind address. Useful if your machine also has a live connection to a remote cluster
These config parameters essentially tell Spark that you are processing on a single machine and that it should not try to distribute the computation. This will save a significant amount of time in both the planning of the pipeline execution and the computation itself.
Note that it is recommended to yield the spark session instead of using return. Read the PyTest documentation for more information. Using yield also allows you to perform any clean up actions after your tests have run (e.g. deleting any local temp directories, databases or tables etc.).
Creating unit-tests for the code
Now let’s write some tests for our code.
I find it most efficient to organise my PySpark unit tests with the following structure:
- Create the input dataframe
- Create the output dataframe using the function we want to test
- Specify the expected output values
- Compare the results
I also try to ensure the test covers positive test cases and at least one negative test case.
from src.data_processing import (
    apply_debit_credit_business_classification,
    classify_debit_credit_transactions,
    join_transactionsDf_to_accountsDf,
    normalise_transaction_information,
)


def test_classify_debit_credit_transactions(spark):
    # create input dataframes
    transactionsDf = spark.createDataFrame(
        data=[
            ("1", 1000.00, "123-456-789"),
            ("3", 3000.00, "222222222EUR"),
        ],
        schema=["transaction_id", "amount", "transaction_information"],
    )
    accountsDf = spark.createDataFrame(
        data=[
            ("123456789", "101"),
            ("222222222", "202"),
            ("000000000", "302"),
        ],
        schema=["account_number", "business_line_id"],
    )

    # output dataframe after applying function
    output = classify_debit_credit_transactions(transactionsDf, accountsDf)

    # expected outputs in the target column
    expected_classifications = ["credit", "debit"]

    # assert results are as expected
    # (sort by transaction_id, since row order after a join is not guaranteed)
    assert output.count() == 2
    assert [
        row.business_line for row in output.orderBy("transaction_id").collect()
    ] == expected_classifications


def test_normalise_transaction_information(spark):
    data = ["123-456-789", "123456789", "123456789EUR", "TEXT*?WITH.*CHARACTERS"]
    test_df = spark.createDataFrame(data, "string").toDF("transaction_information")

    expected = ["123456789", "123456789", "123456789EUR", "TEXTWITHCHARACTERS"]

    output = normalise_transaction_information(test_df)
    assert [row.transaction_information_cleaned for row in output.collect()] == expected


def test_join_transactionsDf_to_accountsDf(spark):
    data = ["123456789", "222222222EUR"]
    transactionsDf = spark.createDataFrame(data, "string").toDF(
        "transaction_information_cleaned"
    )

    data = [
        "123456789",  # match
        "222222222",  # match
        "000000000",  # no-match
    ]
    accountsDf = spark.createDataFrame(data, "string").toDF("account_number")

    output = join_transactionsDf_to_accountsDf(transactionsDf, accountsDf)

    assert output.count() == 2


def test_apply_debit_credit_business_classification(spark):
    data = [
        "101",  # credit
        "202",  # debit
        "000",  # other
    ]
    df = spark.createDataFrame(data, "string").toDF("business_line_id")

    output = apply_debit_credit_business_classification(df)

    expected = ["credit", "debit", "other"]
    assert [row.business_line for row in output.collect()] == expected
We now have a unit-test for each component in the PySpark pipeline.
As each test reuses the same SparkSession, the overhead of running multiple tests is significantly reduced.
Tips for unit testing PySpark code
Keep the unit-tests isolated
Be careful not to modify your spark session during a test (e.g. creating a table but not deleting it afterwards).
Try to keep the creation of data close to where it is used.
You could use fixtures to share dataframes or even load test data from csv files etc. However, in my experience, it is easier and more readable to create data as required for each individual test.
Create test dataframes with the minimum required information
When creating dataframes with test data, only create the columns relevant to the transformation.
You don’t need all the other columns which might be present in the production data.
This helps keep tests concise and more readable, as it is clear which columns are required and affected by the function. If you find you need a big dataframe with many columns in order to carry out a transformation, you are probably trying to do too much at once.
This is just a guideline. Your own use case might require more complicated test data, but if possible keep it small, concise and localised to the test.
Remember to call an action in order to trigger the PySpark computation
PySpark uses lazy evaluation. You need to call an ‘action’ (e.g. collect, count etc.) during your test in order to compute a result that you can compare to the expected output.
Don't run all PySpark tests if you don't need to
PySpark tests generally take longer than normal unit tests to run as there is overhead to calculate a computation plan and then execute it.
During development, make use of some of Pytest’s features, such as the -k flag to run single tests, or just run the tests in a single file. Then only run the full test suite before committing your code.
Test positive and negative outcomes
For example, when testing a joining condition you should include data which should not satisfy the join condition. This helps ensure you are excluding the right data as well as including the right data.
Happy testing!