前言：不管是 Detectron 还是 mmdetection，都用到了这种 register（注册）机制，于是特意去弄明白，在此记录一下。
首先看 fvcore 中 Registry 的代码：
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, Iterable, Iterator, Optional, Tuple

from tabulate import tabulate


class Registry(Iterable[Tuple[str, object]]):
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name: str) -> None:
        """
        Args:
            name (str): the name of this registry, used in error messages
                and in the printed table header.
        """
        self._name: str = name
        # Maps the registered name (the object's __name__) to the object itself.
        self._obj_map: Dict[str, object] = {}

    def _do_register(self, name: str, obj: object) -> None:
        """Store ``obj`` under ``name``; refuse to overwrite an existing entry.

        Raises:
            AssertionError: if ``name`` is already registered.
        """
        # Explicit check rather than an ``assert`` statement: asserts are
        # stripped under ``python -O``, which would let a duplicate name
        # silently overwrite the earlier registration. AssertionError is kept
        # so callers that already catch it keep working.
        if name in self._obj_map:
            raise AssertionError(
                "An object named '{}' was already registered in '{}' registry!".format(
                    name, self._name
                )
            )
        self._obj_map[name] = obj

    def register(self, obj: Optional[object] = None) -> Optional[object]:
        """
        Register the given object under the name ``obj.__name__``.

        Can be used as either a decorator or not. See docstring of this class
        for usage.

        Args:
            obj: the function or class to register. When None, a decorator
                that performs the registration is returned instead.

        Returns:
            The decorator when ``obj`` is None; otherwise ``obj`` itself, so
            the plain-call form can also be used inline.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class: object) -> object:
                name = func_or_class.__name__  # pyre-ignore
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__  # pyre-ignore
        self._do_register(name, obj)
        return obj

    def get(self, name: str) -> object:
        """Return the object registered under ``name``.

        Raises:
            KeyError: if nothing is registered under ``name``.
        """
        ret = self._obj_map.get(name)
        if ret is None:
            raise KeyError(
                "No object named '{}' found in '{}' registry!".format(name, self._name)
            )
        return ret

    def __contains__(self, name: str) -> bool:
        return name in self._obj_map

    def __repr__(self) -> str:
        # Render all (name, object) pairs as a table.
        table_headers = ["Names", "Objects"]
        table = tabulate(
            self._obj_map.items(), headers=table_headers, tablefmt="fancy_grid"
        )
        return "Registry of {}:\n".format(self._name) + table

    def __iter__(self) -> Iterator[Tuple[str, object]]:
        return iter(self._obj_map.items())

    # pyre-fixme[4]: Attribute must be annotated.
    __str__ = __repr__
可以看出，register 方法就是通过调用 _do_register，以函数名或类名（obj.__name__）为键、函数对象或类对象本身为值，存入字典 _obj_map；之后再通过 get 方法按名字取出对应的函数或类来使用。
示例代码调用:
from fvcore.common.registry import Registry

# Registry instance named "BACKBONE"; entries are keyed by their __name__.
BACKBONE_REGISTRY = Registry("BACKBONE")


@BACKBONE_REGISTRY.register()
def test_register(cfg):
    """Demo entry registered as 'test_register': print cfg and return a marker string."""
    print('==cfg:', cfg)
    return '==test_register is called'


def debug_register():
    """Fetch 'test_register' from the registry by name, print it, then call it."""
    sample_cfg = 'hahahah'
    # get() hands back the registered function/class object itself.
    registered_fn = BACKBONE_REGISTRY.get('test_register')
    print(registered_fn)
    # Calling the looked-up object runs the registered function.
    outcome = registered_fn(sample_cfg)
    print('==res:', outcome)


if __name__ == '__main__':
    debug_register()
而对于mmcv:
import inspect

import mmcv


def build_from_cfg(cfg, registry, default_args=None):
    """Build an object from a config dict using a registry.

    Args:
        cfg (dict): config dict. Its ``'type'`` entry is either the registered
            name (str) of a class or the class itself; every other entry is
            passed to the constructor as a keyword argument.
        registry: registry used to resolve a string ``'type'`` via
            ``registry.get``. This code expects ``get`` to return None for
            unknown names.
        default_args (dict, optional): keyword arguments applied only where
            ``cfg`` does not already provide them.

    Returns:
        The newly constructed instance.

    Raises:
        KeyError: if neither ``cfg`` nor ``default_args`` supplies ``'type'``,
            or a string ``'type'`` is not registered in ``registry``.
        TypeError: if ``'type'`` is neither a str nor a class.
    """
    # Work on a copy so popping 'type' below does not mutate the caller's cfg.
    args = cfg.copy()
    print('==cfg:', cfg)
    print('==registry:', registry)
    print('==default_args:', default_args)
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    if 'type' not in args:
        raise KeyError(f'`cfg` or `default_args` must contain the key "type", but got {cfg}')
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        # Resolve the registered class by name.
        obj_cls = registry.get(obj_type)
        print('==obj_cls:', obj_cls)
        if obj_cls is None:
            raise KeyError(f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        # 'type' is already a class (not an instance): use it as-is.
        obj_cls = obj_type
    else:
        raise TypeError(f'type must be a str or valid type, but got {type(obj_type)}')
    # Instantiate the class with the remaining config entries as kwargs.
    return obj_cls(**args)


# Registry named 'convert': maps each registered class name to the class itself.
ANYNAMES = mmcv.Registry('convert')
@ANYNAMES.register_module()
class Converter1:
    """Toy class registered in ANYNAMES under the name 'Converter1'."""

    def __init__(self, a, b):
        self.a = a
        self.b = b


# Demo: describe the object by config, then let build_from_cfg look up
# 'Converter1' in ANYNAMES and construct it with a/b as keyword args.
a_value = 10
b_value = 20
converter_cfg = {'type': 'Converter1', 'a': a_value, 'b': b_value}
print('==converter_cfg:', converter_cfg)
converter = build_from_cfg(converter_cfg, ANYNAMES)
print('==converter:', converter)
print('==converter.a:', converter.a)
print('==converter.b:', converter.b)
上述例子就是把类名 'Converter1' 作为键、Converter1 类本身（而不是它的实例）作为值注册到字典中；然后 build_from_cfg 根据 cfg 中的 type 字段，通过 get 方法取出该类，再用 cfg 中其余的字段作为参数将其实例化。