Python patch()

Summary: in this tutorial, you’ll learn how to use the patch() function from Python’s unittest.mock module to temporarily replace a target with a mock object.

Introduction to the Python patch

The unittest.mock module provides a patch() function that lets you temporarily replace a target with a mock object.

A target can be a function, a method, or a class. You specify it as a string that contains the full import path to the target, in the following format:

'package.module.className'

To use patch() correctly, you need to understand two things:

  • How to identify the target
  • How to call patch()

Identifying the target

When you identify a target, keep two rules in mind:

  • The target must be importable.
  • Patch the target where it is used (looked up), not where it is defined.
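To see why the second rule matters, here is a minimal sketch that writes two throwaway modules to a temporary directory. The module names demo_helpers and demo_app and the function get_value are hypothetical. Because demo_app uses `from demo_helpers import get_value`, it keeps its own reference to the function, so only patching the name inside demo_app affects it:

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path
from unittest.mock import patch

# Write two hypothetical modules to a temporary directory.
tmp_dir = Path(tempfile.mkdtemp())
(tmp_dir / "demo_helpers.py").write_text(textwrap.dedent("""
    def get_value():
        return "real"
"""))
(tmp_dir / "demo_app.py").write_text(textwrap.dedent("""
    from demo_helpers import get_value  # demo_app now holds its own reference

    def run():
        return get_value()
"""))
sys.path.insert(0, str(tmp_dir))
importlib.invalidate_caches()
demo_app = importlib.import_module("demo_app")

# Patching where get_value is defined: demo_app.run() still calls the original.
with patch("demo_helpers.get_value", return_value="mocked"):
    patched_where_defined = demo_app.run()

# Patching where get_value is used: demo_app.run() calls the mock.
with patch("demo_app.get_value", return_value="mocked"):
    patched_where_used = demo_app.run()

print(patched_where_defined)  # real
print(patched_where_used)     # mocked
```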

Calling patch

There are three ways to use patch():

  • As a decorator for a function or a class
  • As a context manager
  • By calling start() and stop() manually

When you use patch() as a decorator on a function or a class, the target is replaced with a mock object inside that function or class.

When you use patch() as a context manager, the target is replaced with a mock object inside the with block.

In both cases, the patch is undone when the function or the with block exits.
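As a sketch of the class-decorator form, the following decorates a whole TestCase. unittest.mock applies the patch separately to each method whose name starts with `test` (controlled by `patch.TEST_PREFIX`) and passes the mock as an extra argument. The target os.getcwd and the fake path are only illustrative choices:

```python
import os
import unittest
from unittest.mock import patch


# Patch os.getcwd for every test method in the class.
@patch("os.getcwd", return_value="/fake/dir")
class TestWithClassDecorator(unittest.TestCase):
    def test_returns_fake_dir(self, mock_getcwd):
        self.assertEqual(os.getcwd(), "/fake/dir")
        mock_getcwd.assert_called_once_with()

    def test_each_test_gets_a_fresh_mock(self, mock_getcwd):
        # The patch is applied per test, so the call count starts at zero.
        self.assertEqual(mock_getcwd.call_count, 0)


# Run the tests programmatically instead of from the command line.
suite = unittest.TestLoader().loadTestsFromTestCase(TestWithClassDecorator)
run_result = unittest.TextTestRunner(verbosity=2).run(suite)
print(run_result.wasSuccessful())  # True
```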

Python patch examples

Let’s create a new module called total.py for demonstration purposes:

def read(filename):
    """ read a text file and return a list of numbers """
    with open(filename) as f:
        lines = f.readlines()
    return [float(line.strip()) for line in lines]

def calculate_total(filename):
    """ return the sum of numbers in a text file """
    numbers = read(filename)
    return sum(numbers)

How it works.

The read() function reads a text file, converts each line into a number, and returns a list of numbers. For example, if a text file contains the following lines:

1
2
3

then the read() function returns the following list of floats:

[1.0, 2.0, 3.0]

The calculate_total() function uses the read() function to get a list of numbers from a file and returns the sum of the numbers.
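If you want to confirm this behavior without any mocking, the following sketch writes the three sample lines to a temporary file and calls copies of read() and calculate_total() (repeated here so the snippet runs on its own). Because read() converts each line with float(), both the list and the sum are floats:

```python
import os
import tempfile


def read(filename):
    """ read a text file and return a list of numbers """
    with open(filename) as f:
        lines = f.readlines()
    return [float(line.strip()) for line in lines]


def calculate_total(filename):
    """ return the sum of numbers in a text file """
    numbers = read(filename)
    return sum(numbers)


# Write the sample file (one number per line) to a temporary location.
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("1\n2\n3\n")
    path = f.name

try:
    numbers = read(path)
    total = calculate_total(path)
finally:
    os.remove(path)  # clean up the temporary file

print(numbers)  # [1.0, 2.0, 3.0]
print(total)    # 6.0
```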

To test calculate_total(), you can create a test_total_mock.py module and mock the read() function as follows:

import unittest

from unittest.mock import MagicMock

import total

class TestTotal(unittest.TestCase):
    def test_calculate_total(self):
        # replace total.read with a mock; this assignment is never undone,
        # so total.read stays mocked after the test finishes
        total.read = MagicMock()
        total.read.return_value = [1, 2, 3]
        result = total.calculate_total('')
        self.assertEqual(result, 6)

Run the test:

python -m unittest test_total_mock.py -v

Output:

test_calculate_total (test_total_mock.TestTotal) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.001s

OK

Instead of assigning a MagicMock object to total.read directly, you can use patch(), which restores the original read() function automatically when the patch ends.
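The difference matters because a direct assignment leaks into every later test that uses the same module. The sketch below builds a hypothetical in-memory stand-in for total.py (a types.ModuleType registered in sys.modules under the name total, so the string target 'total.read' resolves without a real file) and compares the two approaches:

```python
import sys
import types
from unittest.mock import MagicMock, patch

# Hypothetical in-memory stand-in for total.py, registered in sys.modules
# so that the string target 'total.read' resolves without a real file.
total = types.ModuleType("total")
exec(
    "def read(filename):\n"
    "    with open(filename) as f:\n"
    "        return [float(line.strip()) for line in f]\n"
    "\n"
    "def calculate_total(filename):\n"
    "    return sum(read(filename))\n",
    total.__dict__,
)
sys.modules["total"] = total
original_read = total.read

# 1) patch(): the original function is restored when the with block exits.
with patch("total.read", return_value=[1, 2, 3]):
    assert total.calculate_total("") == 6
restored_after_patch = total.read is original_read

# 2) Direct assignment: the mock stays in place afterwards.
total.read = MagicMock(return_value=[1, 2, 3])
assert total.calculate_total("") == 6
restored_after_assignment = total.read is original_read

print(restored_after_patch)       # True
print(restored_after_assignment)  # False
```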

1) Using patch() as a decorator

The following test module (test_total_patch_decorator.py) tests the total.py module using patch() as a function decorator:

import unittest
from unittest.mock import patch
import total

class TestTotal(unittest.TestCase):
    @patch('total.read')
    def test_calculate_total(self, mock_read):
        mock_read.return_value = [1, 2, 3]
        result = total.calculate_total('')
        self.assertEqual(result, 6)

How it works.

First, import patch from the unittest.mock module:

from unittest.mock import patch

Second, decorate the test_calculate_total() test method with the @patch decorator. The target is the read function of the total module.

@patch('total.read')
def test_calculate_total(self, mock_read):
    ...

Because of the @patch decorator, the test_calculate_total() method receives an additional argument, mock_read, which is a MagicMock instance.

Note that you can name the parameter whatever you want.

While the test_calculate_total() method runs, patch() replaces the total.read() function with the mock_read object.

Third, assign a list to the return_value of the mock object:

mock_read.return_value = [1, 2, 3]

Finally, call the calculate_total() function and use the assertEqual() method to test if the total is 6.

Because calculate_total() calls the mock_read object instead of the real read() function, no file is ever opened, so you can pass any filename to calculate_total(), even an empty string:

result = total.calculate_total('')
self.assertEqual(result, 6)

Run the test:

python -m unittest test_total_patch_decorator -v

Output:

test_calculate_total (test_total_patch_decorator.TestTotal) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.001s

OK

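Besides returning a canned value, the mock records how it was called. As a sketch, you could extend the decorator-based test to check that calculate_total() passes the filename through to read() using assert_called_once_with(). The snippet uses a hypothetical in-memory stand-in for total.py so it runs on its own, and 'numbers.txt' is an illustrative filename that is never opened:

```python
import sys
import types
import unittest
from unittest.mock import patch

# Hypothetical in-memory stand-in for total.py, registered in sys.modules
# so that the string target 'total.read' resolves without a real file.
total = types.ModuleType("total")
exec(
    "def read(filename):\n"
    "    with open(filename) as f:\n"
    "        return [float(line.strip()) for line in f]\n"
    "\n"
    "def calculate_total(filename):\n"
    "    return sum(read(filename))\n",
    total.__dict__,
)
sys.modules["total"] = total


class TestTotal(unittest.TestCase):
    @patch('total.read')
    def test_calculate_total(self, mock_read):
        mock_read.return_value = [1, 2, 3]
        result = total.calculate_total('numbers.txt')
        self.assertEqual(result, 6)
        # Check that read() was called exactly once, with the same filename.
        mock_read.assert_called_once_with('numbers.txt')


suite = unittest.TestLoader().loadTestsFromTestCase(TestTotal)
run_result = unittest.TextTestRunner(verbosity=2).run(suite)
```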

2) Using patch() as a context manager

The following test module (test_total_patch_ctx_mgr.py) shows how to use patch() as a context manager:

import unittest
from unittest.mock import patch
import total

class TestTotal(unittest.TestCase):
    def test_calculate_total(self):
        with patch('total.read') as mock_read:
            mock_read.return_value = [1, 2, 3]
            result = total.calculate_total('')
            self.assertEqual(result, 6)

How it works.

First, patch the total.read() function in a with statement and bind the resulting mock object to the name mock_read:

with patch('total.read') as mock_read:

Within the with block, patch() replaces the total.read() function with the mock_read object.

Second, assign a list of numbers to the return_value property of the mock_read object:

mock_read.return_value = [1, 2, 3]

Third, call the calculate_total() function and use the assertEqual() method to test whether the result equals 6:

result = total.calculate_total('')
self.assertEqual(result, 6)

Run the test:

python -m unittest test_total_patch_ctx_mgr -v

Output:

test_calculate_total (test_total_patch_ctx_mgr.TestTotal) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.001s

OK

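patch() also forwards extra keyword arguments to the MagicMock it creates, so you can set return_value in the same call. The following sketch uses that form with a hypothetical in-memory stand-in for total.py, then checks that total.read is no longer a mock once the with block has exited:

```python
import sys
import types
from unittest.mock import MagicMock, patch

# Hypothetical in-memory stand-in for total.py, registered in sys.modules
# so that the string target 'total.read' resolves without a real file.
total = types.ModuleType("total")
exec(
    "def read(filename):\n"
    "    with open(filename) as f:\n"
    "        return [float(line.strip()) for line in f]\n"
    "\n"
    "def calculate_total(filename):\n"
    "    return sum(read(filename))\n",
    total.__dict__,
)
sys.modules["total"] = total

# Keyword arguments such as return_value are passed on to the new MagicMock.
with patch('total.read', return_value=[1, 2, 3]) as mock_read:
    inside = total.calculate_total('')
    is_mock_inside = isinstance(total.read, MagicMock)

# After the with block, the patch is undone, but the mock keeps its call record.
is_mock_after = isinstance(total.read, MagicMock)
mock_read.assert_called_once_with('')

print(inside)          # 6
print(is_mock_inside)  # True
print(is_mock_after)   # False
```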

3) Using patch() manually

The following test module (test_total_patch_manual.py) shows how to use patch() manually:

import unittest
from unittest.mock import patch
import total

class TestTotal(unittest.TestCase):
    def test_calculate_total(self):
        # create a patcher for the target
        patcher = patch('total.read')

        # start patching: replaces total.read and returns the mock object
        mock_read = patcher.start()

        try:
            # assign the return value
            mock_read.return_value = [1, 2, 3]

            # test the calculate_total
            result = total.calculate_total('')
            self.assertEqual(result, 6)
        finally:
            # stop patching, even if the assertion fails
            patcher.stop()

How it works.

First, create a patcher by calling patch() with the read() function of the total module as the target:

patcher = patch('total.read')

Next, call the start() method of the patcher. It replaces the total.read() function with a mock object and returns that mock:

mock_read = patcher.start()

Then, assign a list of numbers to the return_value of the mock_read object:

mock_read.return_value = [1, 2, 3]

After that, call calculate_total() and test its result:

result = total.calculate_total('')
self.assertEqual(result, 6)

Finally, stop patching by calling the stop() method of the patcher object:

patcher.stop()

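Manual patching is most useful when several tests share the same patch. A common pattern, sketched below with a hypothetical in-memory stand-in for total.py, starts the patcher in setUp() and registers patcher.stop with addCleanup(), so the patch is undone after each test whether it passes or fails:

```python
import sys
import types
import unittest
from unittest.mock import patch

# Hypothetical in-memory stand-in for total.py, registered in sys.modules
# so that the string target 'total.read' resolves without a real file.
total = types.ModuleType("total")
exec(
    "def read(filename):\n"
    "    with open(filename) as f:\n"
    "        return [float(line.strip()) for line in f]\n"
    "\n"
    "def calculate_total(filename):\n"
    "    return sum(read(filename))\n",
    total.__dict__,
)
sys.modules["total"] = total


class TestTotal(unittest.TestCase):
    def setUp(self):
        patcher = patch('total.read')
        self.mock_read = patcher.start()
        # Undo the patch after each test, whether it passes or fails.
        self.addCleanup(patcher.stop)

    def test_calculate_total(self):
        self.mock_read.return_value = [1, 2, 3]
        result = total.calculate_total('')
        self.assertEqual(result, 6)

    def test_calculate_total_empty(self):
        self.mock_read.return_value = []
        self.assertEqual(total.calculate_total(''), 0)


suite = unittest.TestLoader().loadTestsFromTestCase(TestTotal)
run_result = unittest.TextTestRunner(verbosity=2).run(suite)
```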

Summary

  • Use patch() from the unittest.mock module to temporarily replace a target with a mock object.
  • Use patch() as a decorator or a context manager, or call its start() and stop() methods manually.
