簡介
在一個稍大一點的python項目中,我們很有可能會用到注冊器(register)。這個注冊器不是用戶賬號注冊的模塊,而是項目中注冊模塊的一個模塊。舉個例子,一個深度學習項目可能支持多種模型;具體使用哪種模型可能是用戶在配置文件中指定的。最簡單的實現方式,就是維護一個模型名稱->模型類
的字典。但每當你增加一個模型時,這個字典就需要手動維護,比較繁瑣。本文介紹一種注冊器的模塊,你需要維護的是需要注冊的模塊的代碼路徑(相對簡介些)。
這個模塊在我們的開源項目Delta中也有使用。
要注冊的模塊
models/model.py
:
class Model:
pass
@Registers.model.register
class Model1(Model):
pass
@Registers.model.register
class Model2(Model):
pass
@Registers.model.register
class Model3(Model):
pass
注冊器 Register
class Register:
def __init__(self, registry_name):
self._dict = {}
self._name = registry_name
def __setitem__(self, key, value):
if not callable(value):
raise Exception(f"Value of a Registry must be a callable!\nValue: {value}")
if key is None:
key = value.__name__
if key in self._dict:
logging.warning("Key %s already in registry %s." % (key, self._name))
self._dict[key] = value
def register(self, target):
"""Decorator to register a function or class."""
def add(key, value):
# 間接調用__setitem__
self[key] = value
return value
if callable(target):
# @reg.register
return add(None, target)
# @reg.register('alias')
return lambda x: add(target, x)
def __getitem__(self, key):
return self._dict[key]
def __contains__(self, key):
return key in self._dict
def keys(self):
"""key"""
return self._dict.keys()
補充一個知識點,@是python的裝飾器語法糖。
@decorate
def func():
等價于
func = decorate(func)
這里,Register類似于一個dict(實際上是有一個_dict屬性),可以set_item和get_item。關鍵是register函數,它可以作為裝飾器,注冊一個函數或者一個類。例如:
@register_obj.register
class Modle1:
pass
等價于register_obj.register(Model1),最終執行的是add(None, Model1)。
而:
@register_obj.register("model_one")
class Model1:
pass
實際上是register_obj.register("model_one")(Model1),最終執行的是add("model_one", Model_1)。
總結下:Register類保存了名稱->模塊的數據,且提供了方便的注冊裝飾器。
所有注冊器 Registers
class Registers:
def __init__(self):
raise RuntimeError("Registries is not intended to be instantiated")
model = Register('model')
Registers保存了所有的Register對象。
加載所有需要的模塊到注冊器 import_all_modules_for_register
在模塊代碼中加入注冊裝飾器之后,我們還需要把這些模塊實際地導入,才能讓這些子模塊加入進注冊器中。
一般大家會首先想到import。比如這里可以直接import models.models就可以讓注冊裝飾器起作用。
但是import子模塊這種形式很有可能導致循環引用的問題。為了避免循環引用,我們可以在代碼入口處,統一地動態引入所有子模塊。動態導入包使用importlib。
MODEL_MODULES = ["models"]
ALL_MODULES = [("models", MODEL_MODULES)]
def _handle_errors(errors):
"""Log out and possibly reraise errors during import."""
if not errors:
return
for name, err in errors:
logging.warning("Module {} import failed: {}".format(name, err))
def import_all_modules_for_register(custom_module_paths=None):
"""Import all modules for register."""
modules = []
for base_dir, modules in ALL_MODULES:
for name in modules:
full_name = base_dir + "." + name
modules.append(full_name)
if isinstance(custom_module_paths, list):
modules += custom_module_paths
errors = []
for module in modules:
try:
importlib.import_module(module)
except ImportError as error:
errors.append((module, error))
_handle_errors(errors)
使用
最后我們使用下我們的注冊器模塊:
from register import import_all_modules_for_register
from register import Registers
print("Registers.model._dict before: ", Registers.model._dict)
import_all_modules_for_register()
print("Registers.model._dict after: ", Registers.model._dict)
輸出:
Registers.model._dict before: {}
Registers.model._dict after: {'Model': <class 'models.models.Model'>, 'Model1': <class 'models.models.Model1'>, 'Model2': <class 'models.models.Model2'>, 'Model3': <class 'models.models.Model3'>}
可以看到,需要的模塊已經加入到注冊器中。
這個模塊在我們的開源項目Delta中也有使用。