Python mock的時候偷換對象

最近做發短信的service的時候,與短信相關的測試需要mock,于是碰到以下問題。
例如有三個模塊a和b和test。

# a.py:

def send_sms():
    # 調用運營商的短信接口
    print 'send_sms'


# b.py:

from a import send_sms

def func():
    status_code = send_sms()
    return status_code

# test.py:

from b import func

def dummy_send_sms():
    print 'send_sms'

test_assert(func() == status_success)

a模塊負責發短信,b是具體的業務,我們想測試b的業務,但是我們可不想在測試的時候發短信。于是我們想在測試的時候,把b模塊func里面對send_sms的調用改成對一個mock函數的調用,如test模塊的dummy_send_sms函數。

先說結論,這里要這么干:
# test.py:
import sys
from b import func

def dummy_send_sms():
    print 'send_sms'

sys.modules['b'].__dict__['send_sms'] = dummy_send_sms
# 如果b模塊是import a,然后a.send_sms的話就要這樣
# sys.modules['a'].__dict__['send_sms'] = dummy_send_sms

test_assert(func() == status_success)

或者使用Python的mock庫的patch。
可以這么干,有兩個原因:
首先,Python里面模塊就是一個Python對象,所以我們可以隨時通過篡改這個模塊對象的成員來篡改模塊里面符號與對象的對應關系。
然后,Python里面對名字的resolution是在運行的時候發生的。

具體來說
函數調用的字節碼如下:

LOAD_NAME  0 (send_sms) 
# 0 表示符號表第0個元素,也就是字符串"send_sms",然后把"send_sms"對應的對象壓到棧里面(注意Python是基于堆棧的虛擬機,這里的棧跟調用棧并不完全一樣)
CALL_FUNCTION 0
# 0表示0個參數,所調用的函數就是棧上面的由"send_sms"查到的對象

拿到"send_sms"這個參數,CALL_FUNCTION這條字節碼怎么去執行函數呢?
# /Python/ceval.c

TARGET(LOAD_NAME) {
    PyObject *name = GETITEM(names, oparg);
    PyObject *locals = f->f_locals;
    PyObject *v;
    if (locals == NULL) {
        PyErr_Format(PyExc_SystemError,
                     "no locals when loading %R", name);
        goto error;
    }
    if (PyDict_CheckExact(locals)) {
        v = PyDict_GetItem(locals, name);
        Py_XINCREF(v);
    }
    else {
        v = PyObject_GetItem(locals, name);
        if (v == NULL && _PyErr_OCCURRED()) {
            if (!PyErr_ExceptionMatches(PyExc_KeyError))
                goto error;
            PyErr_Clear();
        }
    }
    if (v == NULL) {
        v = PyDict_GetItem(f->f_globals, name);
        Py_XINCREF(v);
        if (v == NULL) {
            if (PyDict_CheckExact(f->f_builtins)) {
                v = PyDict_GetItem(f->f_builtins, name);
                if (v == NULL) {
                    format_exc_check_arg(
                                PyExc_NameError,
                                NAME_ERROR_MSG, name);
                    goto error;
                }
                Py_INCREF(v);
            }
            else {
                v = PyObject_GetItem(f->f_builtins, name);
                if (v == NULL) {
                    if (PyErr_ExceptionMatches(PyExc_KeyError))
                        format_exc_check_arg(
                                    PyExc_NameError,
                                    NAME_ERROR_MSG, name);
                    goto error;
                }
            }
        }
    }
    PUSH(v);
    DISPATCH();
}

這段代碼首先拿出name(也就是"send_sms"這個字符串),然后就在ff_localsf_globalsf_builtins 里面找相應的對象。
這個f就是當前的棧幀PyFrameObject。

typedef struct _frame {
    PyObject_VAR_HEAD
    # ...
    PyObject *f_builtins;       /* builtin symbol table (PyDictObject) */
    PyObject *f_globals;        /* global symbol table (PyDictObject) */
    PyObject *f_locals;         /* local symbol table (any mapping) */
    # ...
} PyFrameObject;

所以由名字找出對象,是在LOAD_NAME這條指令運行的時候才計算的,這樣就為我們的篡改留了機會。
所以,當我們運行了sys.modules['b'].__dict__['send_sms'] = dummy_send_sms之后,LOAD_NAME之后根據"send_sms"找到的對象就是我們我們篡改的dummy_send_sms

其實,由于PyFunctionObject擁有個func_globals的指針指向所在模塊的符號表:

typedef struct {
    PyObject_HEAD
    PyObject *func_code;    /* A code object */
    PyObject *func_globals; /* A dictionary (other mappings won't do) */
    PyObject *func_defaults;    /* NULL or a tuple */
    PyObject *func_closure; /* NULL or a tuple of cell objects */
    PyObject *func_doc;     /* The __doc__ attribute, can be anything */
    PyObject *func_name;    /* The __name__ attribute, a string object */
    PyObject *func_dict;    /* The __dict__ attribute, a dict or NULL */
    PyObject *func_weakreflist; /* List of weak references */
    PyObject *func_module;  /* The __module__ attribute, can be anything */

    /* Invariant:
     *     func_closure contains the bindings for func_code->co_freevars, so
     *     PyTuple_Size(func_closure) == PyCode_GetNumFree(func_code)
     *     (func_closure may be NULL if PyCode_GetNumFree(func_code) == 0).
     */
} PyFunctionObject;

所以還可以這么干
# test.py:
import sys
from b import func

def dummy_send_sms():
    print 'send_sms'

# 可以直接從函數的 func_globals 指針修改符號表
b_globals = getattr(func, '__globals__')
b_globals['send_sms'] = dummy_send_sms

test_assert(func() == status_success)
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容