[Dd]enzow(ill)? with DB and Python

DBとか資格とかPythonとかの話をつらつらと

Pythonのmockをしても引数が変わらないようなデコレータを書いた

小ネタです。unittest.mock.patchを使うと、動的にメソッド内での関数やモジュールを差し替えることができますが元関数の引数への影響をなくしたかったという話です。そういやそもそもこんな記事書いてました。

patch

unittest配下にあるように、これはユニットテストを書くための補助的なモジュールです(だと思ってる)。

def _hello():
    return 'HELLO'


def hello(name):
    print('hello, {}'.format(name))
    print(_hello())


if __name__ == '__main__':
    hello('AAA')

このコードをもとにpatchしていきます。普通に使うとこんな感じになります。

from unittest.mock import patch
import functools

def _hello():
    return 'HELLO'

@patch('__main__._hello', side_effect=lambda: 'mocked hello')
def hello(name, mocked_obj):
    print('hello, {}'.format(name))
    print(_hello())

if __name__ == '__main__':
    hello('AAA')

実行するとこうなります。

hello, AAA
mocked hello

_helloがモックオブジェクトに置き換えられており、side_effect=lambda: 'mocked hello'で指定されたようにmocked helloが戻されています。この動き自体はとても嬉しいのですが、helloという元の関数にはモックされたオブジェクトへの参照が自動的に渡されるので、それを受け取れるようにしなければいけません。

しれっと、helloの定義もdef hello(name, mocked_obj):に書き換えてあります。

仮に、

@patch('__main__._hello', side_effect=lambda: 'mocked hello')
def hello(name):
    print('hello, {}'.format(name))
    print(_hello())

とした場合はこんな感じで引数のエラーがでます。

Traceback (most recent call last):
  File "test.py", line 29, in <module>
    hello('AAA')
  File "/Users/denzow/.pyenv/versions/3.6.5/lib/python3.6/unittest/mock.py", line 1179, in patched
    return func(*args, **keywargs)
TypeError: hello() takes 1 positional argument but 2 were given

つまり、@patchを付ける場合はそれを踏まえて適用される側の引数を調整しないといけないわけです。

まぁ、モックオブジェクトが引数経由で渡されてくれないとテストコード側でモックがちゃんと呼び出されているかとかチェックできないと困るのですが、個人的な要件で呼び出し元の引数の数を変えたくなかったので考えました。

どうしたのか

@patchをラップするデコレータを作りました。

class KeepArgPatch:

    def __init__(self, *args, **kwargs):
        self._args = args
        self._kwargs = kwargs

    def __call__(self, func):
        @functools.wraps(func)
        def _decorated_fun(*args2, **kwargs2):
            # mock.patchで元関数をラップする
            patched_func = patch(*self._args, **self._kwargs)(func)
            # 位置変数末尾にモックへの参照が入るので、削り取る
            args2 = args2[:-1]
            return patched_func(*args2, **kwargs2)
        return _decorated_fun

これで以下のようにpatchされていても元の定義のまま関数が呼び出せます。

from unittest.mock import patch
import functools

def _hello():
    return 'HELLO'


class KeepArgPatch:

    def __init__(self, *args, **kwargs):
        self._args = args
        self._kwargs = kwargs

    def __call__(self, func):
        @functools.wraps(func)
        def _decorated_fun(*args2, **kwargs2):
            patched_func = patch(*self._args, **self._kwargs)(func)
            args2 = args2[:-1]
            return patched_func(*args2, **kwargs2)
        return _decorated_fun

@KeepArgPatch('__main__._hello', side_effect=lambda: 'mocked hello')
def hello(name):  # mockを受け取っていない
    print('hello, {}'.format(name))
    print(_hello())

if __name__ == '__main__':
    hello('AAA')
$ python test.py
hello, AAA
mocked hello

まとめ

小ネタだしニッチだし誰得でしたが、まぁできてよかったです。この記事を書きながら思いついたコードなのでバグがあるかもしれないですが・・・