This is something I attempted to find a guide for online, but the niche nature of pyspark (and the general lack of established practice for testing data-centric operations) means there is no clear, one-stop location for figuring out how to do this (apparently simple?) task.
In pyspark, we have several different ways to read/write data. The two I'm most familiar with (and will be focusing on) are reads/writes to parquet files and to spark tables. (The general patterns outlined here can be applied to other I/O as well, however, such as delta and csv).
The Setup
For a bit of context, here's the problem that needed solving. The following (fairly standard) python directory structure is given:
python-project/
├── pyproject.toml
├── setup.py
├── src
│   ├── core.py
│   └── __init__.py
└── tests
    ├── conftest.py
    ├── __init__.py
    └── unit.py
where core.py has a python class with several functions, including a few for writing/reading using the pyspark API.
Problem
The problem lies in the size of the data being read/written, as well as the access controls surrounding that data. For example, some tables might be hundreds of megabytes (to over 10 gigabytes), and the locations for reading/writing may be protected by some access-control key (e.g. an S3 bucket or Blob storage). To zero in on a "unit" test, i.e. not testing the entire environment setup, we need to isolate only the functionality of the code within core.py.
Aside: this is a simple project structure exaggerated for the sake of explanation. Obviously, in real-world projects the setup will be a bit more involved, specifically concerning imports; however, that does NOT change the pattern outlined in this article.
The Solution
monkeypatch
This is what pytest calls its functionality for "patching" or mocking function calls. It allows the safe setting/deletion of an attribute, dictionary item, or environment variable, and is undone after the given test (or fixture) has finished.
Let's use a simple example: if I had code that got the date and time from some external source (say a network-synchronised common server), but I don't want to call the server every time the test is run, we can use the monkeypatch built-in fixture to set the attribute on my module that performs the date-time request, so it returns some fixed value (or set of return values).
# test_my_module.py
import my_module

def test_perform_calculation_based_on_date(monkeypatch):
    # Mock the behavior of get_current_date_from_external_service:
    # replace it with a custom function that returns a fixed date.
    fixed_date = my_module.date(2023, 1, 1)
    monkeypatch.setattr(my_module, "get_current_date_from_external_service", lambda: fixed_date)

    # Call the function under test
    result = my_module.perform_calculation_based_on_date()

    # Assert the result.
    # In this example, the current date is mocked to be January 1, 2023,
    # so the result should be (2023 * 10) + 1 = 20231
    assert result == 20231
Mocking spark Reads
In the file structure described prior, core.py has a class that looks like this:
# core.py
class ReallyImportantClass:
    def __init__(self, spark):
        self.spark = spark

    def read_from_parquet(self, path):
        return self.spark.read.parquet(path)

    def write_to_parquet(self, path, sdf):
        sdf.write.mode('overwrite').parquet(path)

    def read_from_table(self, table_name):
        return self.spark.read.table(table_name)

    def write_to_table(self, table_name, sdf):
        sdf.write.saveAsTable(table_name)

    def super_important_complex_calculation(self, important_table_name):
        table = self.read_from_table('some_important_table')
        # do calculations
        # THIS is what we really want to test
        self.write_to_table(important_table_name, table)
When running tests, we don't want to actually read from the real data source, since the entire point of the test is to ensure the logic within super_important_complex_calculation makes sense.
In order to provide some fixed input (or set of inputs, to test different edge-cases, etc.), we need to create a test that mocks the spark.read.table (and spark.read.parquet) functions to return some fixed input.
In our tests/unit.py file, we define two functions that should replace the read functions and return some given, known, input data:
# tests/unit.py
import logging

from pandas import DataFrame

from src.core import ReallyImportantClass

def mock_read_parquet(self, *args, **kwargs):
    return self.spark.createDataFrame(DataFrame({'a': [1, 2]}))

def mock_read_table(self, *args, **kwargs):
    return self.spark.createDataFrame(DataFrame({'a': [1, 2]}))
Then we define a function to unit-test our super_important_complex_calculation:
def test_super_important(spark, monkeypatch, caplog, tmp_path):
    monkeypatch.setattr(
        ReallyImportantClass,
        'read_from_table',
        mock_read_table)
    r = ReallyImportantClass(spark)
    r.super_important_complex_calculation('output_table')
    ## define assertions here!!
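Note the spark argument to the test: it isn't defined in unit.py, and is assumed to be a fixture provided by tests/conftest.py (it appears in the directory tree above, but its contents aren't shown). A minimal local-mode sketch, where the fixture name and builder settings are assumptions, might be:

```python
# tests/conftest.py (hypothetical sketch)
import pytest

@pytest.fixture(scope='session')
def spark():
    # Imported inside the fixture so test collection doesn't need pyspark
    from pyspark.sql import SparkSession
    session = (
        SparkSession.builder
        .master('local[1]')    # local mode: no cluster required
        .appName('unit-tests')
        .getOrCreate()
    )
    yield session
    session.stop()
```

Using scope='session' means the (slow) SparkSession startup happens once for the whole test run rather than once per test.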
Let's break down the beginning of the function: there's first a call to monkeypatch.setattr, followed by the name of the class. This tells pytest to replace the attribute read_from_table with mock_read_table instead.
In essence, when this file is run with pytest tests/unit.py, pytest replaces the methods that wrap spark.read.parquet and spark.read.table with mocked functions returning a known DataFrame.
Now we have a consistent way of testing logic that previously depended on a particularly expensive function call, and the common functionality in pytest is still supported (i.e. parameterizing tests to use different inputs, etc.).
Mocking Spark Writes
Mocking writes is slightly more nuanced, since write_to_table and write_to_parquet do not return anything in the original core.py. However, we can alter this behaviour to inspect the output using monkeypatch once again:
def mock_write_parquet(self, path, sdf, *args, **kwargs):
    logging.debug('Write Success')
    return sdf

def mock_write_table(self, table_name, sdf, *args, **kwargs):
    logging.debug('Write Success')
    return sdf
In the above, we define functions that match the signatures of the ReallyImportantClass read/write methods, but this time we add a return value for the table write (so we can test the output, etc.).
After this, we then monkeypatch the write_to_table function in the same way as we did the read, using monkeypatch.setattr.
def test_super_important(spark, monkeypatch, caplog, tmp_path):
    monkeypatch.setattr(
        ReallyImportantClass,
        'write_to_table',
        mock_write_table)
    monkeypatch.setattr(
        ReallyImportantClass,
        'read_from_table',
        mock_read_table)
    r = ReallyImportantClass(spark)
    r.super_important_complex_calculation('output_table')
    ## define assertions here!
Inspecting the Output
Notice we mocked the output, but we can't actually test that the function worked (since we don't explicitly call read_from_table in the tests). In this case, what we can use is the pytest tmp_path builtin, to replace the name of an actual table with a local path:
def test_super_important(spark, monkeypatch, caplog, tmp_path):
    monkeypatch.setattr(
        ReallyImportantClass,
        'write_to_table',
        mock_write_table)
    monkeypatch.setattr(
        ReallyImportantClass,
        'read_from_table',
        mock_read_table)
    r = ReallyImportantClass(spark)
    r.super_important_complex_calculation(str(tmp_path / 'output_table'))
    ## define assertions here!
From here, we can then go on to define the necessary assertions, etc., that may be use-case specific to your given class.
Another option, if we want to test just whether the write has been called, would be to have the monkeypatched write function make a logging call alone (and assert on it via caplog) instead.
Finally, we can see the output!!
pytest tests/unit.py --log-format="%(asctime)s %(levelname)s %(message)s" --log-date-format="%Y-%m-%d %H:%M:%S" --durations=5
==================================== test session starts =====================================
platform linux -- Python 3.8.15, pytest-7.4.0, pluggy-0.13.1
rootdir: /storage/projects/test-mock-spark
collected 5 items
tests/unit.py ..... [100%]
==================================== slowest 5 durations =====================================
2.35s call tests/unit.py::test_read_parquet
1.96s setup tests/unit.py::test_read_parquet
0.96s teardown tests/unit.py::test_super_important
0.10s call tests/unit.py::test_read_table
0.02s call tests/unit.py::test_write_table
=============================== 5 passed, 10 warnings in 5.44s ===============================