I am a big fan of unit-testing.

“Testing Is Not About Finding Bugs. We believe that the major benefits of testing happen when you think about and write the tests, not when you run them.”

The Pragmatic Programmer, David Thomas and Andrew Hunt

Reading The Pragmatic Programmer and Refactoring completely changed the way I viewed unit-testing.

Instead of seeing testing as a chore to complete after I have finished my pipelines, I see it as a powerful tool to improve the design of my code, reduce coupling, iterate more quickly and build trust with others in my work.

However, writing good tests for data applications can be difficult.

Unlike traditional software applications with relatively well-defined inputs, data applications in production depend on large and constantly changing input data. It can be very challenging to accurately represent this data in a test environment in enough detail to cover all edge cases. Some people argue there is little point in unit-testing data pipelines, and focus on data validation techniques instead.

I strongly believe in implementing both unit-testing and data validation in your data pipelines. Unit-testing isn’t just about finding bugs; it is about creating better designed code and building trust with colleagues and end users.

If you can get in the habit of writing tests, you will write better designed code, save time in the long run and reduce the pain of pipelines failing or giving incorrect results in production.

Challenges with unit-testing PySpark code

A good unit-test should have the following characteristics:

  • Focused. Each test should test a single behaviour/functionality.
  • Fast. Tests should run quickly so you can iterate and get feedback rapidly.
  • Isolated. Each test should be responsible for testing a specific piece of functionality and should not depend on external factors in order to run successfully.
  • Concise. Creating a test shouldn’t require lots of boilerplate code to mock/create complex objects in order for the test to run.

When it comes to writing unit-tests for PySpark pipelines, writing focused, fast, isolated and concise tests can be a challenge.

Here are some of the hurdles to overcome…

Writing testable code in the first place

PySpark pipelines tend to be written as one giant function responsible for multiple transformations. For example:

def main(spark, country):
    """Example function for processing data received from different countries"""

    # fetch input data
    ...

    # preprocess based on individual country requirements
    if country == 'US':
        ...  # preprocessing for US
    elif country == 'GB':
        ...  # preprocessing for UK
        ...  # more preprocessing etc.

    # join dataframes together

    # run some calculations/aggregate

    # save results

It is logical to think about transformations in this way, and in many ways it is easier to reason about and read.

But, when you start trying to write a test for this function you quickly realise it is very difficult to write a test to cover all functionality.

This is because the function is highly coupled and there are many different paths it can take.

Even if you did write a test that verified the input and output data were as expected, if the test failed for any reason it would be very difficult to understand which part of the very long function was at fault.

Instead, you should break your transformations into blocks of reusable functions, each responsible for a single task. You can then write a unit-test for each individual function (task). When each of these unit-tests passes, you can be more confident in the output of the final pipeline when you compose all the functions together.

Writing tests is a good practice and forces you to think about design principles. If it is difficult to test your code then you probably need to rethink the design of your code.
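To make the idea concrete without needing Spark, here is a minimal, library-agnostic sketch that uses plain Python lists of dicts in place of DataFrames. The country rules (preprocess_us, preprocess_gb) and the shared drop_zero_amounts step are hypothetical; the point is that the if/elif chain becomes a lookup of small functions that can each be tested on their own and then chained together. In PySpark, DataFrame.transform can play a similar chaining role.

```python
from functools import reduce
from typing import Callable, Dict, List

Record = Dict[str, object]
Step = Callable[[List[Record]], List[Record]]


def preprocess_us(records: List[Record]) -> List[Record]:
    """Hypothetical US rule: amounts arrive in cents, so convert to dollars."""
    return [{**r, "amount": r["amount"] / 100} for r in records]


def preprocess_gb(records: List[Record]) -> List[Record]:
    """Hypothetical GB rule: normalise the legacy code 'UK' to 'GB'."""
    return [{**r, "country": "GB" if r["country"] == "UK" else r["country"]} for r in records]


# A lookup of small, single-purpose functions replaces the if/elif chain
PREPROCESSORS: Dict[str, Step] = {"US": preprocess_us, "GB": preprocess_gb}


def drop_zero_amounts(records: List[Record]) -> List[Record]:
    """Hypothetical shared step: remove records with a zero amount."""
    return [r for r in records if r["amount"] != 0]


def compose(*steps: Step) -> Step:
    """Chain steps left to right, feeding each step's output into the next."""
    def pipeline(records: List[Record]) -> List[Record]:
        return reduce(lambda acc, step: step(acc), steps, records)
    return pipeline


def main(records: List[Record], country: str) -> List[Record]:
    """Compose the country-specific step with the shared steps."""
    return compose(PREPROCESSORS[country], drop_zero_amounts)(records)


# Each step can be checked on its own...
assert preprocess_gb([{"country": "UK", "amount": 5}]) == [{"country": "GB", "amount": 5}]
# ...and so can the composed pipeline
assert main([{"country": "US", "amount": 250}, {"country": "US", "amount": 0}], "US") == [
    {"country": "US", "amount": 2.5}
]
```

Because each step only does one thing, a failing test points directly at the step that broke, rather than at one long function.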

Spark is not optimised for small data

Spark is optimised to work on very large data and uses lazy evaluation and distributed processing.

When you run a PySpark pipeline, Spark evaluates the entire pipeline and calculates an optimised ‘plan’ to perform the computation across a distributed cluster.

The planning comes with significant overhead. This makes sense when you are processing terabytes of data on a distributed cluster of machines, but when working on a small dataset on a single machine it can be surprisingly slow, especially compared with what you might have experienced with Pandas.

Without optimising your SparkSession configuration parameters, your unit-tests will take an agonisingly long time to run.

Dependency on a Spark Session

To run PySpark code in your unit-test, you need a SparkSession.

As stated above, ideally each test should be isolated from others and not require complex external objects. Unfortunately, there is no escaping the requirement to initiate a spark session for your unit-tests.

Creating a spark session is the first hurdle to overcome when writing a unit-test for your PySpark pipeline.

How should you create a SparkSession for your tests?

Initiating a new spark session for each test would dramatically increase the time to run the tests and introduce a ton of boiler-plate code to your tests.

Efficiently creating and sharing a SparkSession across your tests is vital to keep the performance of your tests at an acceptable level.

Creating test data

Your tests will require input data.

There are two main problems with creating example data for big data pipelines.

The first is size. Obviously, you cannot run your tests against the full dataset that will be used in production. You have to use a much smaller subset.

But, by using a small dataset, you run into the second problem which is providing enough test data to cover all the edge cases you want to handle.

It is really hard to mock realistic data for testing. However, you can use smaller, targeted datasets for your tests. So what is the most efficient way to pass example data to your PySpark unit-tests?

Steps to unit-test your PySpark code with Pytest

Let’s work through an example using Pytest.

💻 Full code is available in the e4ds-snippets GitHub repository

Example code

Here is an example PySpark pipeline that processes some bank transactions and classifies them as debit account or credit account transactions.

Each transaction record comes with an account ID. We will use this account ID to join to an account information table, which records whether each account ID belongs to a debit or a credit account.
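Since most of the logic in this pipeline lives in the normalisation and join-key expressions, it can help to see that logic on its own first. The sketch below mirrors it in plain Python so it runs without Spark, assuming that Spark's regexp_replace with the pattern [^A-Z0-9]+ behaves like Python's re.sub for simple ASCII inputs like these. Spark's substring(col, 1, 9) is 1-based, so it corresponds to the Python slice [:9]. Note that the pattern also strips lowercase letters. The helper names normalise and join_key are hypothetical.

```python
import re

# Same pattern as the PySpark pipeline: drop anything that is not A-Z or 0-9.
# Assumption: mirrors Spark's regexp_replace for plain ASCII input.
NON_ALPHANUMERIC = re.compile(r"[^A-Z0-9]+")


def normalise(transaction_information: str) -> str:
    """Hypothetical plain-Python mirror of the regexp_replace normalisation step."""
    return NON_ALPHANUMERIC.sub("", transaction_information)


def join_key(transaction_information: str) -> str:
    """First 9 characters of the cleaned string, i.e. substring(col, 1, 9) in Spark."""
    return normalise(transaction_information)[:9]


print(join_key("123-456-789"))   # 123456789
print(join_key("222222222EUR"))  # 222222222
print(normalise("ref: 123"))     # 123 (lowercase letters are removed too)
```

Working out expected values like this by hand is also a quick way to choose the inputs for the unit-tests later on.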

import pyspark.sql.functions as F
from pyspark.sql import DataFrame

def classify_debit_credit_transactions(
    transactionsDf: DataFrame, accountDf: DataFrame
) -> DataFrame:
    """Join transactions with account information and classify as debit/credit"""

    # normalise strings
    transactionsDf = transactionsDf.withColumn(
        "transaction_information_cleaned",
        F.regexp_replace(F.col("transaction_information"), r"[^A-Z0-9]+", ""),
    )

    # join on customer account using first 9 characters
    transactions_accountsDf = transactionsDf.join(
        accountDf,
        on=F.substring(F.col("transaction_information_cleaned"), 1, 9)
        == F.col("account_number"),
        how="inner",
    )

    # classify transactions as from debit or credit account customers
    credit_account_ids = ["100", "101", "102"]
    debit_account_ids = ["200", "201", "202"]
    transactions_accountsDf = transactions_accountsDf.withColumn(
        "business_line",
        F.when(F.col("business_line_id").isin(credit_account_ids), F.lit("credit"))
        .when(F.col("business_line_id").isin(debit_account_ids), F.lit("debit"))
        .otherwise(F.lit("other")),
    )

    return transactions_accountsDf

There are a few issues with this example pipeline:

  • Difficult to read. Lots of complex logic in one place, for example regex replacements and joins on substrings.
  • Difficult to test. Single function responsible for multiple actions
  • Difficult to reuse. The debit/credit classification is business logic which currently cannot be reused across the project

Refactor into smaller logical units

Let’s first refactor the code into individual functions, then compose the functions together for the main classify_debit_credit_transactions function.

We can then write a test for each individual function to ensure it behaves as expected.

While this increases the overall number of lines of code, it is easier to test and we can now reuse the functions across other parts of the project.

import pyspark.sql.functions as F
from pyspark.sql import DataFrame

def classify_debit_credit_transactions(
    transactionsDf: DataFrame, accountsDf: DataFrame
) -> DataFrame:
    """Join transactions with account information and classify as debit/credit"""

    transactionsDf = normalise_transaction_information(transactionsDf)

    transactions_accountsDf = join_transactionsDf_to_accountsDf(
        transactionsDf, accountsDf
    )

    transactions_accountsDf = apply_debit_credit_business_classification(
        transactions_accountsDf
    )

    return transactions_accountsDf

def normalise_transaction_information(transactionsDf: DataFrame) -> DataFrame:
    """Remove special characters from transaction information"""
    return transactionsDf.withColumn(
        "transaction_information_cleaned",
        F.regexp_replace(F.col("transaction_information"), r"[^A-Z0-9]+", ""),
    )

def join_transactionsDf_to_accountsDf(
    transactionsDf: DataFrame, accountsDf: DataFrame
) -> DataFrame:
    """Join transactions to accounts information using the first 9 characters"""
    return transactionsDf.join(
        accountsDf,
        on=F.substring(F.col("transaction_information_cleaned"), 1, 9)
        == F.col("account_number"),
        how="inner",
    )

def apply_debit_credit_business_classification(
    transactions_accountsDf: DataFrame,
) -> DataFrame:
    """Classify transactions as coming from debit or credit account customers"""
    # TODO: move to config file
    CREDIT_ACCOUNT_IDS = ["100", "101", "102"]
    DEBIT_ACCOUNT_IDS = ["200", "201", "202"]

    return transactions_accountsDf.withColumn(
        "business_line",
        F.when(F.col("business_line_id").isin(CREDIT_ACCOUNT_IDS), F.lit("credit"))
        .when(F.col("business_line_id").isin(DEBIT_ACCOUNT_IDS), F.lit("debit"))
        .otherwise(F.lit("other")),
    )

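The "TODO: move to config file" comment suggests the account IDs should eventually live outside the code. One way to do that is sketched below, assuming a JSON config file with hypothetical keys credit_account_ids and debit_account_ids that is loaded once and passed in. The classification rule is mirrored in plain Python so it runs without Spark; in the PySpark function, the loaded lists would replace the hard-coded constants passed to isin.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, List


def load_account_ids(path: Path) -> Dict[str, List[str]]:
    """Load business-line IDs from a JSON config (hypothetical file layout)."""
    with open(path, encoding="utf-8") as f:
        config = json.load(f)
    return {
        "credit": list(config["credit_account_ids"]),
        "debit": list(config["debit_account_ids"]),
    }


def classify(business_line_id: str, account_ids: Dict[str, List[str]]) -> str:
    """Plain-Python mirror of the F.when(...).when(...).otherwise('other') chain."""
    if business_line_id in account_ids["credit"]:
        return "credit"
    if business_line_id in account_ids["debit"]:
        return "debit"
    return "other"


# Usage: write an example config to a temporary directory, then load it
with tempfile.TemporaryDirectory() as tmp:
    config_path = Path(tmp) / "business_lines.json"
    config_path.write_text(
        json.dumps(
            {
                "credit_account_ids": ["100", "101", "102"],
                "debit_account_ids": ["200", "201", "202"],
            }
        ),
        encoding="utf-8",
    )
    ids = load_account_ids(config_path)

print([classify(i, ids) for i in ["101", "202", "000"]])  # ['credit', 'debit', 'other']
```

Keeping the IDs in one config also makes the business logic easier to reuse across the project, which was one of the issues listed for the original pipeline.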
Creating a reusable SparkSession using Fixtures

Before writing our unit-tests, we need to create a SparkSession which we can reuse across all our tests.

To do this, we create a PyTest fixture in a conftest.py file.

Pytest fixtures provide objects to your tests. A fixture declared with scope="session" is created once and then reused across all the tests in a run. This is particularly useful for complex objects like the SparkSession, which have a significant overhead to create.

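# conftest.py

import pytest
from pyspark.sql import SparkSession


# scope="session" creates one SparkSession and shares it across all tests
@pytest.fixture(scope="session")
def spark():
    spark = (
        SparkSession.builder.master("local[1]")
        .config("spark.executor.cores", "1")
        .config("spark.executor.instances", "1")
        .config("spark.sql.shuffle.partitions", "1")
        .config("spark.driver.bindAddress", "127.0.0.1")
        .getOrCreate()
    )
    yield spark
    # clean up once all tests in the session have finished
    spark.stop()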

It is important to set a number of configuration parameters in order to optimise the SparkSession for processing small data on a single machine for testing:

  • master = local[1] – run Spark on the local machine with a single worker thread
  • spark.executor.cores = 1 – use one core per executor
  • spark.executor.instances = 1 – use a single executor
  • spark.sql.shuffle.partitions = 1 – use one partition when shuffling data for joins and aggregations (the default is 200)
  • spark.driver.bindAddress = 127.0.0.1 – (optional) explicitly set the address the driver binds to. Useful if your machine also has a live connection to a remote cluster

These config parameters essentially tell spark that you are processing on a single machine and spark should not try to distribute the computation. This will save a significant amount of time in both the planning of the pipeline execution and the computation itself.

Note that it is recommended to yield the SparkSession rather than return it. Read the PyTest documentation for more information. Using yield also allows you to perform any clean-up actions after your tests have run (e.g. deleting local temp directories, databases or tables).
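To see why yield makes clean-up possible, here is a rough, pytest-free simulation of how a session-scoped generator fixture is driven: the code before yield runs once as setup, the yielded object is handed to every test, and the code after yield runs once as teardown. The names fake_spark_fixture and run_session are hypothetical stand-ins; pytest's real machinery is more involved, and in the real fixture the teardown would be something like spark.stop().

```python
from typing import Callable, Dict, Iterator, List


def fake_spark_fixture(log: List[str]) -> Iterator[Dict[str, str]]:
    """Hypothetical stand-in for the conftest.py fixture: setup, yield, teardown."""
    log.append("setup")  # in conftest.py: SparkSession.builder ... .getOrCreate()
    session = {"name": "shared-session"}
    yield session
    log.append("teardown")  # in conftest.py: spark.stop(), drop temp tables, etc.


def run_session(checks: List[Callable[[Dict[str, str]], None]], log: List[str]) -> None:
    """Rough model of a session-scoped fixture: one setup shared by every test."""
    fixture = fake_spark_fixture(log)
    session = next(fixture)  # run the fixture up to `yield`
    try:
        for check in checks:
            check(session)  # every test receives the same object
            log.append(f"ran {check.__name__}")
    finally:
        next(fixture, None)  # resume after `yield` so the teardown always runs


def check_session_name(session: Dict[str, str]) -> None:
    assert session["name"] == "shared-session"


def check_session_type(session: Dict[str, str]) -> None:
    assert isinstance(session, dict)


log: List[str] = []
run_session([check_session_name, check_session_type], log)
print(log)  # ['setup', 'ran check_session_name', 'ran check_session_type', 'teardown']
```

Setup appears once no matter how many tests run, which is exactly the saving you want from a shared SparkSession.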

Creating unit-tests for the code

Now let’s write some tests for our code.

I find it most efficient to organise my PySpark unit tests with the following structure:

  • Create the input dataframe
  • Create the output dataframe using the function we want to test
  • Specify the expected output values
  • Compare the results

I also try to ensure the test covers positive test cases and at least one negative test case.
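The same four-step structure applies to any function. As a minimal sketch that runs without Spark, here it is applied to a hypothetical plain-Python classify_business_line helper that mirrors the debit/credit rule, with "000" as the negative case.

```python
CREDIT_ACCOUNT_IDS = {"100", "101", "102"}
DEBIT_ACCOUNT_IDS = {"200", "201", "202"}


def classify_business_line(business_line_id: str) -> str:
    """Hypothetical plain-Python mirror of the debit/credit classification."""
    if business_line_id in CREDIT_ACCOUNT_IDS:
        return "credit"
    if business_line_id in DEBIT_ACCOUNT_IDS:
        return "debit"
    return "other"


def test_classify_business_line():
    # 1. create the input data (two positive cases, one negative case)
    inputs = ["101", "202", "000"]

    # 2. create the output using the function under test
    output = [classify_business_line(i) for i in inputs]

    # 3. specify the expected output values
    expected = ["credit", "debit", "other"]

    # 4. compare the results
    assert output == expected
```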

from src.data_processing import (
    apply_debit_credit_business_classification,
    classify_debit_credit_transactions,
    join_transactionsDf_to_accountsDf,
    normalise_transaction_information,
)

def test_classify_debit_credit_transactions(spark):

    # create input dataframes
    transactionsDf = spark.createDataFrame(
        data=[
            ("1", 1000.00, "123-456-789"),
            ("3", 3000.00, "222222222EUR"),
        ],
        schema=["transaction_id", "amount", "transaction_information"],
    )

    accountsDf = spark.createDataFrame(
        data=[
            ("123456789", "101"),
            ("222222222", "202"),
            ("000000000", "302"),
        ],
        schema=["account_number", "business_line_id"],
    )

    # output dataframe after applying function
    output = classify_debit_credit_transactions(transactionsDf, accountsDf)

    # expected outputs in the target column
    expected_classifications = ["credit", "debit"]

    # assert results are as expected (sort first, as join output order is not guaranteed)
    assert output.count() == 2
    assert [
        row.business_line for row in output.orderBy("transaction_id").collect()
    ] == expected_classifications

def test_normalise_transaction_information(spark):
    data = ["123-456-789", "123456789", "123456789EUR", "TEXT*?WITH.*CHARACTERS"]
    test_df = spark.createDataFrame(data, "string").toDF("transaction_information")

    expected = ["123456789", "123456789", "123456789EUR", "TEXTWITHCHARACTERS"]
    output = normalise_transaction_information(test_df)
    assert [row.transaction_information_cleaned for row in output.collect()] == expected

def test_join_transactionsDf_to_accountsDf(spark):

    data = ["123456789", "222222222EUR"]
    transactionsDf = spark.createDataFrame(data, "string").toDF(
        "transaction_information_cleaned"
    )

    data = [
        "123456789",  # match
        "222222222",  # match
        "000000000",  # no-match
    ]
    accountsDf = spark.createDataFrame(data, "string").toDF("account_number")

    output = join_transactionsDf_to_accountsDf(transactionsDf, accountsDf)

    assert output.count() == 2

def test_apply_debit_credit_business_classification(spark):
    data = [
        "101",  # credit
        "202",  # debit
        "000",  # other
    ]
    df = spark.createDataFrame(data, "string").toDF("business_line_id")
    output = apply_debit_credit_business_classification(df)

    expected = ["credit", "debit", "other"]
    assert [row.business_line for row in output.collect()] == expected

We now have a unit-test for each component in the PySpark pipeline.

As each test reuses the same SparkSession, the overhead of running multiple tests is significantly reduced.

Tips for unit testing PySpark code

Keep the unit-tests isolated

Be careful not to modify your spark session during a test (e.g. creating a table, but not deleting it afterwards).

Try to keep the creation of data close to where it is used.

You could use fixtures to share dataframes or even load test data from csv files etc. However, in my experience, it is easier and more readable to create data as required for each individual test.

Create test dataframes with the minimum required information

When creating dataframes with test data, only create columns relevant to the transformation.

You only really need to create data with columns required for the function. You don’t need all the other columns which might be present in the production data.

This helps write concise functions and is more readable as it is clear which columns are required and impacted by the function. If you find you need a big dataframe with many columns in order to carry out a transformation you are probably trying to do too much at once.

This is just a guideline; your own use case might require more complicated test data, but if possible keep it small, concise and localised to the test.

Remember to call an action in order to trigger the PySpark computation

PySpark uses lazy evaluation. You need to call an ‘action’ (e.g. collect, count etc.) during your test in order to compute a result that you can compare to the expected output.
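Python's built-in map is lazy in a similar way, which makes a handy analogy (not a model of Spark's internals): defining the transformation does no work, and nothing runs until something consumes the result, much like a Spark action such as collect or count.

```python
calls = []


def clean(value: str) -> str:
    calls.append(value)  # record every time the function actually runs
    return value.replace("-", "")


lazy = map(clean, ["123-456-789", "222-222-222"])  # like a transformation: nothing runs yet
print(len(calls))  # 0

result = list(lazy)  # like an action: forces the computation
print(len(calls))  # 2
print(result)      # ['123456789', '222222222']
```

A test that only builds the transformed DataFrame and never calls an action is in the same position as the first print: nothing has been computed, so there is nothing meaningful to assert on.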

Don't run all PySpark tests if you don't need to

PySpark tests generally take longer than normal unit tests to run as there is overhead to calculate a computation plan and then execute it.

During development, make use of some of Pytest’s features such as the -k flag to run single tests or just run tests in a single file. Then only run the full test suite before committing your code.

Test positive and negative outcomes

For example, when testing a joining condition you should include data which should not satisfy the join condition. This helps ensure you are excluding the right data as well as including the right data.
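Here is a minimal sketch of that idea, using a hypothetical plain-Python inner join keyed on the first nine characters so it runs without Spark. The test data deliberately includes a transaction with no matching account and an account with no transactions, and the assertions check that both are excluded.

```python
from typing import List, Tuple


def inner_join_on_prefix(
    transactions: List[str], accounts: List[str], length: int = 9
) -> List[Tuple[str, str]]:
    """Hypothetical helper: pair each transaction with the account matching its prefix."""
    account_numbers = set(accounts)
    return [(t, t[:length]) for t in transactions if t[:length] in account_numbers]


def test_join_includes_matches_and_excludes_non_matches():
    transactions = [
        "123456789",     # match
        "222222222EUR",  # match on the first 9 characters
        "999999999",     # negative case: no matching account
    ]
    accounts = [
        "123456789",
        "222222222",
        "000000000",     # negative case: no matching transaction
    ]
    output = inner_join_on_prefix(transactions, accounts)

    # positive cases are included...
    assert output == [("123456789", "123456789"), ("222222222EUR", "222222222")]
    # ...and the non-matching rows do not leak into the result
    assert all(t != "999999999" for t, _ in output)
    assert all(a != "000000000" for _, a in output)
```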

Happy testing!
