Monkeypatching Spark I/O in Python

Unit-Testing Data Science Code

This is something I tried to find a guide for online, but the niche nature of pyspark (and the general lack of established practice for testing data-centric code) means there is no clear one-stop resource for this (apparently simple?) task.

In pyspark, we have several different ways to read/write data. The two I'm most familiar with (and will be focusing on) are reads/writes to parquet files and to Spark tables. (The general patterns outlined here can be applied to other I/O as well, such as delta and csv.)

The Setup

For a bit of context, here's the problem that needed solving. Assume the following (fairly standard) Python directory structure:

python-project/
├── pyproject.toml
├── setup.py
├── src
│   ├── core.py
│   └── __init__.py
└── tests
    ├── conftest.py
    ├── __init__.py
    └── unit.py

where core.py contains a Python class with several functions, including a few for reading/writing data using the pyspark API.

Problem

The problem lies in the size of the data being read and written, as well as the access controls surrounding it. For example, some tables might be hundreds of megabytes (or over 10 gigabytes), and the locations being read from and written to may be protected by some access-control mechanism or key (e.g. an S3 bucket or Blob storage). To zero in on a "unit" test, i.e. not testing the entire environment setup, we need to isolate only the functionality of the code within core.py.

Aside: this is a simple project structure, exaggerated for the sake of explanation. Obviously, in real-world projects the setup will be a bit more involved (specifically concerning imports); however, that does NOT change the pattern outlined in this article.

The Solution

monkeypatch

This is what pytest calls its functionality for "patching" or mocking function calls. It allows the safe setting/deletion of an attribute, dictionary item, or environment variable, and the change is undone after the given test (or fixture) has finished.
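
As a quick sketch of those three flavours (using os.getcwd and a throwaway dict purely for illustration):

import os

CONFIG = {'timeout': 30}

def test_monkeypatch_flavours(monkeypatch):
    # Each of these patches is automatically undone when the test finishes.
    monkeypatch.setenv('MY_API_KEY', 'dummy-key')            # environment variable
    monkeypatch.setitem(CONFIG, 'timeout', 1)                # dictionary item
    monkeypatch.setattr('os.getcwd', lambda: '/tmp/fake')    # attribute (by dotted path)

    assert os.environ['MY_API_KEY'] == 'dummy-key'
    assert CONFIG['timeout'] == 1
    assert os.getcwd() == '/tmp/fake'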

Let's take a fuller example: say I have code that gets the date and time from some external source (a network-synchronised common server, for instance), but I don't want to call that server every time the test runs. We can use the monkeypatch built-in fixture to set the attribute on my module that performs the date-time request, so that it returns some fixed value (or set of return values).

# test_my_module.py

import my_module

def test_perform_calculation_based_on_date(monkeypatch):
    # Mock the behavior of get_current_date_from_external_service
    # Replace it with a custom function that returns a fixed date.
    fixed_date = my_module.date(2023, 1, 1)
    monkeypatch.setattr(my_module, "get_current_date_from_external_service", lambda: fixed_date)

    # Call the function under test
    result = my_module.perform_calculation_based_on_date()

    # Assert the result
    # In this example, the current date is mocked to be January 1, 2023,
    # so the result should be (2023 * 10) + 1 = 20231
    assert result == 20231
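
For reference, a my_module that would satisfy this test might look something like the following (a hypothetical sketch; only the test appears above):

# my_module.py (hypothetical sketch to match the test above)
from datetime import date  # re-exported so the test can use my_module.date

def get_current_date_from_external_service():
    # In reality this would call out to the network-synchronised server.
    raise NotImplementedError

def perform_calculation_based_on_date():
    current = get_current_date_from_external_service()
    return (current.year * 10) + current.day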

Mocking Spark Reads

In the file structure described earlier, core.py has a class that looks like this:

# core.py
class ReallyImportantClass:
  def __init__(self, spark):
    self.spark = spark

  def read_from_parquet(self, path):
    return self.spark.read.parquet(path)

  def write_to_parquet(self, path, sdf):
    sdf.write.mode('overwrite').parquet(path)

  def read_from_table(self, table_name):
    return self.spark.read.table(table_name)

  def write_to_table(self, table_name, sdf):
    sdf.write.saveAsTable(table_name)

  def super_important_complex_calculation(self, important_table_name):
    table = self.read_from_table('some_important_table')
    # do calculations
    # THIS is what we really want to test
    self.write_to_table(important_table_name, table)

When running tests, we don't want to actually read from the real data source, since the entire point of the test would be to ensure the logic within super_important_complex_calculation makes sense.

In order to provide some fixed input (or a set of inputs, to test different edge cases, etc.), we need to write tests that mock the read_from_table (and read_from_parquet) methods so that they return some fixed input instead of calling spark.read.table or spark.read.parquet.

In our tests/unit.py file, we define two functions that will replace the read methods and return some given, known input data:

# tests/unit.py
import logging

from pandas import DataFrame

from src.core import ReallyImportantClass

def mock_read_parquet(self, *args, **kwargs):
  return self.spark.createDataFrame(DataFrame({'a': [1, 2]}))

def mock_read_table(self, *args, **kwargs):
  return self.spark.createDataFrame(DataFrame({'a': [1, 2]}))
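
The tests below also rely on a spark fixture. Beyond the conftest.py entry in the directory tree, it isn't shown here, but a minimal local-mode version of tests/conftest.py might look like this (a sketch; the fixture scope and Spark config are assumptions):

# tests/conftest.py
import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope='session')
def spark():
  # A small, local SparkSession shared across the test session.
  session = (
    SparkSession.builder
    .master('local[1]')
    .appName('unit-tests')
    .getOrCreate()
  )
  yield session
  session.stop()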

Then we define a function to unit-test our super_important_complex_calculation:

def test_super_important(spark, monkeypatch, caplog, tmp_path):
  monkeypatch.setattr(
    ReallyImportantClass,
    'read_from_table',
    mock_read_table)

  r = ReallyImportantClass(spark)
  r.super_important_complex_calculation('output_table')

  ## define assertions here!!

Let's break down the beginning of the function: there's first a call to monkeypatch.setattr, which takes the class, the name of the attribute to patch, and its replacement. This tells pytest to replace the read_from_table method on ReallyImportantClass with mock_read_table instead.

In essence, when this file is run with pytest tests/unit.py, pytest swaps the class's read methods for our mocks, so spark.read.parquet and spark.read.table are never actually called and the mocked DataFrame is returned instead.

Now we have a consistent way of testing logic that previously depended on a particularly expensive function call, and the usual pytest functionality is still available (e.g. parametrizing tests to use different inputs, and so on).
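
For example, here's a sketch of how the mocked read could be combined with pytest.mark.parametrize (the test name and inputs are made up, and the write is stubbed inline so the example stays self-contained; proper write mocking is covered in the next section):

import pytest

@pytest.mark.parametrize('input_rows', [[1, 2], [0, 0, 0], [5]])
def test_super_important_varied_inputs(spark, monkeypatch, input_rows):
  # Return a different fixed DataFrame for each parameter set.
  def mock_read(self, *args, **kwargs):
    return self.spark.createDataFrame(DataFrame({'a': input_rows}))

  monkeypatch.setattr(ReallyImportantClass, 'read_from_table', mock_read)
  # Stub the write so no real table is touched.
  monkeypatch.setattr(
    ReallyImportantClass,
    'write_to_table',
    lambda self, table_name, sdf, *a, **kw: sdf)

  r = ReallyImportantClass(spark)
  r.super_important_complex_calculation('output_table')

  ## assertions that depend on input_rows go here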

Mocking Spark Writes

Mocking writes is slightly more nuanced, since write_to_table and write_to_parquet do not return anything in the original core.py. However, we can alter this behaviour to inspect the output, using monkeypatch once again:

def mock_write_parquet(self, path, sdf, *args, **kwargs):
  logging.debug('Write Success')
  return sdf

def mock_write_table(self, table_name, sdf, *args, **kwargs):
  logging.debug('Write Success')
  return sdf

In the above, we define functions that match the signatures of the ReallyImportantClass write methods, but this time we add a return value (so we can test the output, etc.).

After this, we monkeypatch the write_to_table method in the same way as we did the read, using monkeypatch.setattr.

def test_super_important(spark, monkeypatch, caplog, tmp_path):
  monkeypatch.setattr(
    ReallyImportantClass,
    'write_to_table',
    mock_write_table)

  monkeypatch.setattr(
    ReallyImportantClass,
    'read_from_table',
    mock_read_table)

  r = ReallyImportantClass(spark)
  r.super_important_complex_calculation('output_table')

  ## define assertions here!

Inspecting the Output

Notice that we mocked the write, but we can't yet check that the function actually produced the right output (the mock simply returns the DataFrame, and nothing is persisted for us to read back). In this case, we can use the pytest tmp_path built-in fixture to replace the name of an actual table with a local, temporary path:

def test_super_important(spark, monkeypatch, caplog, tmp_path):
  monkeypatch.setattr(
    ReallyImportantClass,
    'write_to_table',
    mock_write_table)

  monkeypatch.setattr(
    ReallyImportantClass,
    'read_from_table',
    mock_read_table)

  r = ReallyImportantClass(spark)
  r.super_important_complex_calculation(str(tmp_path / 'output_table'))

  ## define assertions here!

From here, we can go on to define the necessary assertions, which will be specific to your given class and use case.
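
As an illustration of what those assertions might look like, here is a variant that is not from the original code: it swaps the write mock for one that writes a parquet to whatever local path it is given, so the test can read the output back.

def mock_write_to_local_parquet(self, path, sdf, *args, **kwargs):
  # Write a parquet to the supplied local path instead of a managed table.
  sdf.write.mode('overwrite').parquet(str(path))
  return sdf

def test_super_important_output(spark, monkeypatch, tmp_path):
  monkeypatch.setattr(
    ReallyImportantClass,
    'read_from_table',
    mock_read_table)
  monkeypatch.setattr(
    ReallyImportantClass,
    'write_to_table',
    mock_write_to_local_parquet)

  output_path = tmp_path / 'output_table'
  r = ReallyImportantClass(spark)
  r.super_important_complex_calculation(str(output_path))

  # Read the written output back and assert on its contents.
  result = spark.read.parquet(str(output_path))
  assert result.count() == 2  # the mocked input has two rows; adjust for the real calculation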

Another option, if we only want to test whether the write was called at all, is to have the mocked write function do nothing except log the call, and assert on the captured log instead.
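
For example (a sketch using the mock_write_table defined earlier; the test name is made up), the caplog fixture can confirm that the debug message from the mocked write was emitted:

def test_write_was_called(spark, monkeypatch, caplog):
  caplog.set_level(logging.DEBUG)
  monkeypatch.setattr(ReallyImportantClass, 'read_from_table', mock_read_table)
  monkeypatch.setattr(ReallyImportantClass, 'write_to_table', mock_write_table)

  r = ReallyImportantClass(spark)
  r.super_important_complex_calculation('output_table')

  # mock_write_table logs 'Write Success' at DEBUG level.
  assert 'Write Success' in caplog.text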

Finally, we can see the output!!

pytest tests/unit.py --log-format="%(asctime)s %(levelname)s %(message)s" --log-date-format="%Y-%m-%d %H:%M:%S" --durations=5
==================================== test session starts =====================================
platform linux -- Python 3.8.15, pytest-7.4.0, pluggy-0.13.1
rootdir: /storage/projects/test-mock-spark
collected 5 items

tests/unit.py .....                                                                    [100%]

==================================== slowest 5 durations =====================================
2.35s call     tests/unit.py::test_read_parquet
1.96s setup    tests/unit.py::test_read_parquet
0.96s teardown tests/unit.py::test_super_important
0.10s call     tests/unit.py::test_read_table
0.02s call     tests/unit.py::test_write_table
=============================== 5 passed, 10 warnings in 5.44s ===============================