目录

0 规划

1 nn.Module 完成

1.1 常用接口

1.1.1 init 函数

1.1.2 状况的转化

1.1.3 参数的转化或转移

1.1.4 Apply 函数

1.2 特色的增修正查

1.2.1 特色设置

1.2.2 特色删去

1.2.3 常见的特色拜访

1.3 Forward & Backward

1.3.1 Hooks

1.3.2 运转逻辑

1.4 模块存取

1.4.1 Hooks

1.4.2 功用完成

1.4.3 _load_from_state_dict 妙用

\

本次解读首要介绍 PyTorch 中的神经网络模块,即 torch.nn,其间首要介绍 nn.Module,其他模块的细节能够经过 PyTorch 的 API 文档进行查阅,一些较重要的模块如 DataParallel 和 BN/SyncBN 等,都有独立的文章进行介绍。

0 规划

nn.Module 其实是 PyTorch 系统下一切神经网络模块的基类,此处顺带梳理了一下 torch.nn 中的各个组件,他们的联系概览如下图所示。

展开各模块后,模块之间的承继联系与层次结构如下图所示:

从各模块的承继联系来看,模块的安排和完成有几个常见的特色,供 PyTorch 代码库的开发者参考借鉴:

  • 一般有一个基类来界说接口,经过承继来处理不同维度的 input,如:
  1. Conv1d,Conv2d,Conv3d,ConvTransposeNd 承继自 _ConvNd
  2. MaxPool1d,MaxPool2d,MaxPool3d 承继自 _MaxPoolNd 等
  • 每一个类都有一个对应的 nn.functional 函数,类界说了所需求的 arguments 和模块的 parameters,在 forward 函数中将 arguments 和 parameters 传给 nn.functional 的对应函数来完成 forward 功用。比如:
  1. 一切的非线性激活函数,都是在 forward 中直接调用对应的 nn.functional 函数
  2. Normalization 层都是调用的如 F.layer_norm, F.group_norm 等函数
  • 承继 nn.Module 的模块首要重载 init、 forward、 和 extra_repr 函数,含有 parameters 的模块还会完成 reset_parameters 函数来初始化参数

1 nn.Module 完成

1.1 常用接口

1.1.1 init 函数

在 nn.Module 的 __init__ 函数中,会首先调用 torch._C._log_api_usage_once(“python.nn_module”), 这一行代码是 PyTorch 1.7 的新功用,用于监测并记载 API 的调用,具体解说可见 文档。

在此之后,nn.Module 初始化了一系列重要的成员变量。这些变量初始化了在模块 forward、 backward 和权重加载等时分会被调用的的 hooks,也界说了 parameters 和 buffers,如下面的代码所示:

self.training = True  # 控制 training/testing 状况
self._parameters = OrderedDict()  # 在练习过程中会跟着 BP 而更新的参数
self._buffers = OrderedDict()  # 在练习过程中不会跟着 BP 而更新的参数
self._non_persistent_buffers_set = set()
self._backward_hooks = OrderedDict()  # Backward 完成后会被调用的 hook
self._forward_hooks = OrderedDict()  # Forward 完成后会被调用的 hook
self._forward_pre_hooks = OrderedDict()  # Forward 前会被调用的 hook
self._state_dict_hooks = OrderedDict()  # 得到 state_dict 今后会被调用的 hook
self._load_state_dict_pre_hooks = OrderedDict()  # load state_dict 前会被调用的 hook
self._modules = OrderedDict()  # 子神经网络模块

各个成员变量的功用在后面还会继续说到,这儿先在注释中简单解说。由源码的完成可见,承继 nn.Module 的神经网络模块在完成自己的 init 函数时,一定要先调用 super().__init__() 。只有这样才干正确地初始化自界说的神经网络模块,否则会短少上面代码中的成员变量而导致模块被调用时出错。实际上,假如没有提早调用 super().__init__(),在添加模块的 parameter 或许 buffer 的时分,被调用的 __setattr__ 函数也会查看出父类 nn.Module 没被正确地初始化并报错。(在面试的过程中,咱们常常发现面试者在写自界说神经网络模块的时分会疏忽掉这一点,看了这篇文章今后可要千万记得哦~)

1.1.2 状况的转化

  • 练习与测验

nn.Module 经过 self.training 来差异练习和测验两种状况,使得模块能够在练习和测验时有不同的 forward 行为(如 Batch Normalization)。nn.Module 经过 self.train() 和 self.eval() 来修正练习和测验状况,其间 self.eval 直接调用了 self.train(False),而 self.train() 会修正 self.training 并经过 self.children() 来调整一切子模块的状况。关于 self.children() 的介绍可见下文的 常见的特色拜访 章节。

def train(self: T, mode: bool = True) -> T:
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self
  • Example: freeze 部分模型参数

在方针检测等使命中,常见的 training practice 会将 backbone 中的一切 BN 层保留为 eval 状况,即 freeze BN 层中的 running_mean 和 running_var,而且将浅层的模块 freeze。此时就需求重载 detector 类的 train 函数,MMDetection 中 ResNet 的 train 函数完成如下:

def train(self, mode=True):
    super(ResNet, self).train(mode)
    self._freeze_stages()
    if mode and self.norm_eval:
        for m in self.modules():
            # trick: eval have effect on BatchNorm only
            if isinstance(m, _BatchNorm):
                m.eval()
  • 梯度的处理

对于梯度的处理 nn.Module 有两个相关的函数完成,别离是 requires_grad_ 和 zero_grad 函数,他们都调用了 self.parameters() 来拜访一切的参数,并修正参数的 requires_grad 状况 或许 清理参数的梯度。

def requires_grad_(self: T, requires_grad: bool = True) -> T:
    for p in self.parameters():
        p.requires_grad_(requires_grad)
    return self
def zero_grad(self, set_to_none: bool = False) -> None:
    if getattr(self, '_is_replica', False):
        warnings.warn(
            "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
            "The parameters are copied (in a differentiable manner) from the original module. "
            "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
            "If you need gradients in your forward method, consider using autograd.grad instead.")
    for p in self.parameters():
        if p.grad is not None:
            if set_to_none:
                p.grad = None
            else:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()

1.1.3 参数的转化或转移

nn.Module 完成了如下 8 个常用函数将模块转变成 float16 等类型、转移到 CPU/ GPU上。

  1. CPU:将一切 parameters 和 buffer 转移到 CPU 上
  2. type:将一切 parameters 和 buffer 转变成另一个类型
  3. CUDA:将一切 parameters 和 buffer 转移到 GPU 上
  4. float:将一切浮点类型的 parameters 和 buffer 转变成 float32 类型
  5. double:将一切浮点类型的 parameters 和 buffer 转变成 double 类型
  6. half:将一切浮点类型的 parameters 和 buffer 转变成 float16 类型
  7. bfloat16:将一切浮点类型的 parameters 和 buffer 转变成 bfloat16 类型
  8. to:移动模块或/和改动模块的类型

这些函数的功用最终都是经过 self._apply(function) 来完成的, function 一般是 lambda 表达式或其他自界说函数。因此,用户其实也能够经过 self._apply(function) 来完成一些特别的转化。self._apply() 函数实际上做了如下 3 件事情,最终将 function 完整地应用于整个模块。

  1. 经过 self.children() 进行递归的调用
  2. 对 self._parameters 中的参数及其 gradient 经过 function 进行处理
  3. 对 self._buffers 中的 buffer 逐一经过 function 来进行处理
def _apply(self, fn):
    # 对子模块进行递归调用
    for module in self.children():
        module._apply(fn)
    # 为了 BC-breaking 而新增了一个 tensor 类型判断
    def compute_should_use_set_data(tensor, tensor_applied):
        if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
            # If the new tensor has compatible tensor type as the existing tensor,
            # the current behavior is to change the tensor in-place using `.data =`,
            # and the future behavior is to overwrite the existing tensor. However,
            # changing the current behavior is a BC-breaking change, and we want it
            # to happen in future releases. So for now we introduce the
            # `torch.__future__.get_overwrite_module_params_on_conversion()`
            # global flag to let the user control whether they want the future
            # behavior of overwriting the existing tensor or not.
            return not torch.__future__.get_overwrite_module_params_on_conversion()
        else:
            return False
    # 处理参数及其gradint
    for key, param in self._parameters.items():
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't want to
            # track autograd history of `param_applied`, so we have to use
            # `with torch.no_grad():`
            with torch.no_grad():
                param_applied = fn(param)
            should_use_set_data = compute_should_use_set_data(param, param_applied)
            if should_use_set_data:
                param.data = param_applied
            else:
                assert isinstance(param, Parameter)
                assert param.is_leaf
                self._parameters[key] = Parameter(param_applied, param.requires_grad)
            if param.grad is not None:
                with torch.no_grad():
                    grad_applied = fn(param.grad)
                should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
                if should_use_set_data:
                    param.grad.data = grad_applied
                else:
                    assert param.grad.is_leaf
                    self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)
    # 处理 buffers
    for key, buf in self._buffers.items():
        if buf is not None:
            self._buffers[key] = fn(buf)
    return self

1.1.4 Apply 函数

nn.Module 还完成了一个 apply 函数,与 _apply() 函数不同的是,apply 函数只是简单地递归调用了 self.children() 去处理自己以及子模块,如下面的代码所示。

def apply(self: T, fn: Callable[['Module'], None]) -> T:
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self

apply 函数和 _apply 函数的差异在于,_apply() 是专门针对 parameter 和 buffer 而完成的一个“仅供内部运用”的接口,可是 apply 函数是“公有”接口 (Python 对类的“公有”和“私有”差异并不是很严格,一般经过单前导下划线来差异)。apply 实际上能够经过修正 fn 来完成 _apply 能完成的功用,一起还能够完成其他功用,如下面给出的从头初始化参数的比如。

  • Example: 参数从头初始化

能够自界说一个 init_weights 函数,经过 net.apply(init_weights) 来初始化模型权重。

@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        m.weight.fill_(1.0)
        print(m.weight)
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)

1.2 特色的增修正查

1.2.1 特色设置

对 nn.Module 特色的修正有一下三个函数,函数以及对应功用如下

  1. add_module:添加子神经网络模块,更新 self._modules
  2. register_parameter:添加经过 BP 能够更新的 parameters (如 BN 和 Conv 中的 weight 和 bias ),更新 self._parameters
  3. register_buffer:添加不经过 BP 更新的 buffer(如 BN 中的 running_mean 和 running_var),更新 self._buffers,假如 buffer 不是 persistant 的,还会一起更新到 self._non_persistent_buffers_set 中。buffer 是否 persistant 的差异在于这个 buffer 是否会能被放入 self.state_dict 中被保存下来。 这 3 个函数都会先查看 self.__dict__ 中是否包括对应的特色字典以确保 nn.Module 被正确初始化,然后查看特色的 name 是否合法,如不为空 string 且不包括“.”,一起还会查看他们是否现已存在于要修正的特色字典中。

在日常的代码开发过程中,更常见的用法是直接经过 self.xxx = xxx 的方法来添加或修正子神经网络模块、parameters、buffers 以及其他一般的 attribute。这种方法实质上会调用 nn.Module 重载的函数 __setattr__ ,具体的代码如下:

def __setattr__(self, name: str, value: Union[Tensor, 'Module']):
    def remove_from(*dicts_or_sets):
        for d in dicts_or_sets:
            if name in d:
                if isinstance(d, dict):
                    del d[name]
                else:
                    d.discard(name)
    params = self.__dict__.get('_parameters')
    if isinstance(value, Parameter):
        if params is None:
            raise AttributeError(
                "cannot assign parameters before Module.__init__() call")
        remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
        self.register_parameter(name, value)
    elif params is not None and name in params:
        if value is not None:
            raise TypeError("cannot assign '{}' as parameter '{}' "
                            "(torch.nn.Parameter or None expected)"
                            .format(torch.typename(value), name))
        self.register_parameter(name, value)
    else:
        modules = self.__dict__.get('_modules')
        if isinstance(value, Module):
            if modules is None:
                raise AttributeError(
                    "cannot assign module before Module.__init__() call")
            remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
            modules[name] = value
        elif modules is not None and name in modules:
            if value is not None:
                raise TypeError("cannot assign '{}' as child module '{}' "
                                "(torch.nn.Module or None expected)"
                                .format(torch.typename(value), name))
            modules[name] = value
        else:
            buffers = self.__dict__.get('_buffers')
            if buffers is not None and name in buffers:
                if value is not None and not isinstance(value, torch.Tensor):
                    raise TypeError("cannot assign '{}' as buffer '{}' "
                                    "(torch.Tensor or None expected)"
                                    .format(torch.typename(value), name))
                buffers[name] = value
            else:
                object.__setattr__(self, name, value)

从源码中咱们还有如下观察:

  1. 在第 14 行和 28 行,函数查看了承继 nn.Module 的自界说模块是否有正确地初始化父类 nn.Module,这也说明晰 super().init()  的必要性
  2. 在添加 self._parameters,self._modules 的时分,会预先调用 remove_from 函数 (15 和 29 行)从其余私有特色中删去对应的 name,这说明 self.dict,self._buffers,self._parameters,self._modules 中的特色应该是互斥的
  3. 假如要给模块添加 buffer,self.register_buffer 是唯一的方法__setattr__ 只能将 self._buffers 中已有的 buffer 从头赋值为 None 或许 tensor 。这是由于 buffer 的初始化类型便是 torch.Tensor 或许 None,而不像 parameters 和 module 别离是 nn.Parameter 和 nn.Module 类型
  4. 除了其他一般的 attribute,最终 parameters 还是会在 __setattr__ 中经过 register_parameter 来添加,可是子神经网络模块和 buffer 是直接修正的 self._modules 和 self._buffers
  5. 由第三点和前文所述的 _apply 完成能够得出 self.xxxx = torch.Tensor() 是一种不被引荐的行为,由于这样新增的 attribute 既不归于 self._parameters,也不归于 self._buffers,而会被视为一般的 attribute ,在将模块进行状况转化的时分,self.xxxx 会被遗失从而导致 device 或许 type 不一样的 bug

1.2.2 特色删去

特色的删去经过重载的 __delattr__ 来完成,具体代码如下:

def __delattr__(self, name):
    if name in self._parameters:
        del self._parameters[name]
    elif name in self._buffers:
        del self._buffers[name]
        self._non_persistent_buffers_set.discard(name)
    elif name in self._modules:
        del self._modules[name]
    else:
        object.__delattr__(self, name)

__delattr__ 会挨个查看 self._parameters、self._buffers、self._modules 和一般的 attribute 并将 name 从中删去。

1.2.3 常见的特色拜访

nn.Module 中的常用函数包括下面 8 个,他们都会回来一个迭代器用于拜访模块中的 buffer,parameter,子模块等。他们的功用与差异如下

  1. parameters:调用 self.named_parameters 并回来模型参数,被应用于 self.requires_grad_ 和 self.zero_grad 函数中
  2. named_parameters:回来 self._parameters 中的 name 和 parameter 元组,假如 recurse=True 还会回来子模块中的模型参数
  3. buffers:调用 self.named_buffers 并回来模型参数
  4. named_buffers:回来 self._buffers 中的 name 和 buffer 元组,假如 recurse=True 还会回来子模块中的模型 buffer
  5. children:调用 self.named_children,只回来 self._modules 中的模块,被应用于 self.train 函数中
  6. named_children:只回来 self._modules 中的 name 和 module 元组
  7. modules:调用 self.named_modules 并回来各个 module 但不回来 name
  8. named_modules:回来 self._modules 下的 name 和 module 元组,并递归调用和回来 module.named_modules
def _named_members(self, get_members_fn, prefix='', recurse=True):
    memo = set()
    modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
    for module_prefix, module in modules:
        members = get_members_fn(module)
        for k, v in members:
            if v is None or v in memo:
                continue
            memo.add(v)
            name = module_prefix + ('.' if module_prefix else '') + k
            yield name, v
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
    for name, param in self.named_parameters(recurse=recurse):
        yield param
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
    gen = self._named_members(
        lambda module: module._parameters.items(),
        prefix=prefix, recurse=recurse)
    for elem in gen:
        yield elem
def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
    for name, buf in self.named_buffers(recurse=recurse):
        yield buf
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
    gen = self._named_members(
        lambda module: module._buffers.items(),
        prefix=prefix, recurse=recurse)
    for elem in gen:
        yield elem
def children(self) -> Iterator['Module']:
    for name, module in self.named_children():
        yield module
def named_children(self) -> Iterator[Tuple[str, 'Module']]:
    memo = set()
    for name, module in self._modules.items():
        if module is not None and module not in memo:
            memo.add(module)
            yield name, module
def modules(self) -> Iterator['Module']:
    for name, module in self.named_modules():
        yield module
def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = ''):
    if memo is None:
        memo = set()
    if self not in memo:
        memo.add(self)
        yield prefix, self
        for name, module in self._modules.items():
            if module is None:
                continue
            submodule_prefix = prefix + ('.' if prefix else '') + name
            for m in module.named_modules(memo, submodule_prefix):
                yield m

named_parameters 和 named_buffers 都是调用的 self._named_members 完成的,named_modules 和 named_children 尽管有自己的完成,但和 self._named_members 一样,都是经过 set 类型的 memo 来记载现已抛出的模块,假如 member 不在 memo 中,才会将 member 抛出并将 member 放入 memo 中,因此 named_parameters、named_buffers、named_modules 和named_children 都不会回来重复的 parameter、 buffer 或 module

nn.Module 重载了 __dir__ 函数,重载的 __dir__ 函数会将 self._modules、self._parameters 和 self._buffers 中的 attributes 给露出出来。

def __dir__(self):
    module_attrs = dir(self.__class__)
    attrs = list(self.__dict__.keys())
    parameters = list(self._parameters.keys())
    modules = list(self._modules.keys())
    buffers = list(self._buffers.keys())
    keys = module_attrs + attrs + parameters + modules + buffers
    # Eliminate attrs that are not legal Python variable names
    keys = [key for key in keys if not key[0].isdigit()]
    return sorted(keys)

还有一种常见的特色拜访是经过 module.attribute 来进行的。这种调用等价于 getattr(module, 'attribute') 。和 nn.Module 对 __delattr__ 以及 __setattr__ 的重载相似,为了确保 getattr 能拜访到一切的特色,nn.Module 也重载了 __getattr__ 函数,以拜访 self._parameters,self._buffers,self._modules 中的特色。

根据 Python 对实例特色的查找规则,当咱们调用 module.attribute 的时分,Python 会首先查找 module 的 类及其基类的 __dict__,然后查找这个 object 的 __dict__,最终查找 __getattr__ 函数。因此,尽管 nn.Module 的 __getattr__ 只查找了 self._parameters,self._buffers,self._modules 三个成员变量,可是 getattr(module, ‘attribute’) 覆盖的规模和 __dir__ 露出的规模是一致的

def __getattr__(self, name: str) -> Union[Tensor, 'Module']:
    if '_parameters' in self.__dict__:
        _parameters = self.__dict__['_parameters']
        if name in _parameters:
            return _parameters[name]
    if '_buffers' in self.__dict__:
        _buffers = self.__dict__['_buffers']
        if name in _buffers:
            return _buffers[name]
    if '_modules' in self.__dict__:
        modules = self.__dict__['_modules']
        if name in modules:
            return modules[name]
    raise ModuleAttributeError("'{}' object has no attribute '{}'".format(
        type(self).__name__, name))

1.3 Forward & Backward

1.3.1 Hooks

在 nn.Module 的完成文件中,首先完成了 3 个通用的 hook 注册函数,用于注册被应用于大局的 hook。这 3 个函数会将 hook 别离注册进 3 个大局的 OrderedDict,使得一切的 nn.Module 的子类实例在运转的时分都会触发这些 hook。每个 hook 修正的 OrderedDict 如下所示:

  1. register_module_backward_hook:_global_backward_hooks
  2. register_module_forward_pre_hook:_global_forward_pre_hooks
  3. register_module_forward_hook:_global_forward_hooks

同样的,nn.Module 也支撑注册只被应用于自己的 forward 和 backward hook,经过 3 个函数 来办理 自己的 3 个特色并保护 3 个 attribute,他们的类型也是 OrderedDict,每个 hook 修正的 OrderedDict 如下所示:

  1. self.register_backward_hook: self._backward_hooks
  2. self.register_forward_pre_hook: self._forward_pre_hooks
  3. self.register_forward_hook: self._forward_hooks

1.3.2 运转逻辑

nn.Module 在被调用的时分,一般是以 module(input) 的方法,此时会首先调用 self.__call__,接下来这些 hooks 在模块被调用时分的执行顺序如下图所示:

_call_impl 的代码完成如下。注意到 _call_impl 在界说今后被直接赋值给了 __call__ 。一起咱们注意到在 torch._C._get_tracing_state() 为 True 的时分,nn.Module 会经过 _slow_forward() 来调用 forward 函数而非直接调用 forward 函数,这一功用首要用于 JIT。

def _call_impl(self, *input, **kwargs):
    for hook in itertools.chain(
            _global_forward_pre_hooks.values(),
            self._forward_pre_hooks.values()):
        result = hook(self, input)
        if result is not None:
            if not isinstance(result, tuple):
                result = (result,)
            input = result
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in itertools.chain(
            _global_forward_hooks.values(),
            self._forward_hooks.values()):
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result
    if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in itertools.chain(
                    _global_backward_hooks.values(),
                    self._backward_hooks.values()):
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result
__call__ : Callable[..., Any] = _call_impl

1.4 模块存取

1.4.1 Hooks

nn.Module 还有两个相关的 hook 是关于模型参数的加载和存储的,别离是:

  1. _register_state_dict_hook:在self.state_dict()的最终对模块导出的 state_dict 进行修正
  2. _register_load_state_dict_pre_hook:在 _load_from_state_dict 中最早执行

1.4.2 功用完成

nn.Module 运用 state_dict() 函数来进行取得当时的完整状况,用于在模型练习中储存 checkpoint。 模块的 _version 信息会首先存入 metadata 中,用于模型的版本办理,然后会经过 _save_to_state_dict() 将 self._parameters 以及 self._buffers 中的 persistent buffer 进行保存。 用户能够经过重载 _save_to_state_dict 函数来满意特定的需求

nn.Module 运用 load_state_dict() 函数来读取 checkpoint。load_state_dict() 会经过调用每个子模块的_load_from_state_dict 函数来加载他们所需的权重,如下面代码的 55-63 行所示。而 _load_from_state_dict 才是真正担任加载 parameter 和 buffer 的函数。这也说明晰每个模块能够自行界说他们的 _load_from_state_dict 函数来满意特别需求,实际上这也是 PyTorch 官方引荐的做法。在后面的两个比如中,咱们也给出了 _load_from_state_dict 的运用比如。

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    for hook in self._load_state_dict_pre_hooks.values():
        hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
    persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
    local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
    local_state = {k: v for k, v in local_name_params if v is not None}
    for name, param in local_state.items():
        key = prefix + name
        if key in state_dict:
            input_param = state_dict[key]
            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
            if len(param.shape) == 0 and len(input_param.shape) == 1:
                input_param = input_param[0]
            if input_param.shape != param.shape:
                # local shape should match the one in checkpoint
                error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                  'the shape in current model is {}.'
                                  .format(key, input_param.shape, param.shape))
                continue
            try:
                with torch.no_grad():
                    param.copy_(input_param)
            except Exception as ex:
                error_msgs.append('While copying the parameter named "{}", '
                                  'whose dimensions in the model are {} and '
                                  'whose dimensions in the checkpoint are {}, '
                                  'an exception occurred : {}.'
                                  .format(key, param.size(), input_param.size(), ex.args))
        elif strict:
            missing_keys.append(key)
    if strict:
        for key in state_dict.keys():
            if key.startswith(prefix):
                input_name = key[len(prefix):]
                input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
                if input_name not in self._modules and input_name not in local_state:
                    unexpected_keys.append(key)
def load_state_dict(self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]], strict: bool = True):
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata
    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')
    load(self)
    load = None  # break load->load reference cycle
    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in missing_keys)))
    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                           self.__class__.__name__, "\n\t".join(error_msgs)))
    return _IncompatibleKeys(missing_keys, unexpected_keys)

1.4.3 _load_from_state_dict 妙用

  • Example: 防止 BC-breaking

在模型迭代的过程中,module 很简单出现 BC-breaking ,PyTorch 经过 _version 和 _load_from_state_dict 来处理的这类问题(这也是 PyTorch 引荐的方法)。 下面的代码是 _NormBase 类防止 BC-breaking 的方法。在 PyTorch 的开发过程中,Normalization layers 在某个新版本中 引入了 num_batches_tracked 这个 key,给 BN 记载练习过程中经历的 batch 数,为了兼容旧版本练习的模型,PyTorch 修正了 _version,并修正了 _load_from_state_dict

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    version = local_metadata.get('version', None)
    if (version is None or version < 2) and self.track_running_stats:
        # at version 2: added num_batches_tracked buffer
        #               this should have a default value of 0
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key not in state_dict:
            state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)
    super(_NormBase, self)._load_from_state_dict(
        state_dict, prefix, local_metadata, strict,
        missing_keys, unexpected_keys, error_msgs)

这儿再举一个 MMCV 中的比如,DCN 经历了一次重构,特色的姓名经过了重命名。

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    version = local_metadata.get('version', None)
    if version is None or version < 2:
        # the key is different in early versions
        # In version < 2, DeformConvPack loads previous benchmark models.
        if (prefix + 'conv_offset.weight' not in state_dict
                and prefix[:-1] + '_offset.weight' in state_dict):
            state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
                prefix[:-1] + '_offset.weight')
        if (prefix + 'conv_offset.bias' not in state_dict
                and prefix[:-1] + '_offset.bias' in state_dict):
            state_dict[prefix +
                       'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
                                                            '_offset.bias')
    if version is not None and version > 1:
        print_log(
            f'DeformConv2dPack {prefix.rstrip(".")} is upgraded to '
            'version 2.',
            logger='root')
    super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                  strict, missing_keys, unexpected_keys,
                                  error_msgs)
  • Example: 模型无痛搬迁

假如在 MMDetection 中练习了一个 detector,MMDetection3D 中的多模态检测器想要加载这个预练习的检测器,很多权重姓名对不上,又不想写一个脚本手动来转,能够运用 _load_from_state_dict 来进行。经过这种方法,MMDetection3D 能够加载并运用 MMDetection 练习的任意一个检测器。

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    # override the _load_from_state_dict function
    # convert the backbone weights pre-trained in Mask R-CNN
    # use list(state_dict.keys()) to avoid
    # RuntimeError: OrderedDict mutated during iteration
    for key_name in list(state_dict.keys()):
        key_changed = True
        if key_name.startswith('backbone.'):
            new_key_name = f'img_backbone{key_name[8:]}'
        elif key_name.startswith('neck.'):
            new_key_name = f'img_neck{key_name[4:]}'
        elif key_name.startswith('rpn_head.'):
            new_key_name = f'img_rpn_head{key_name[8:]}'
        elif key_name.startswith('roi_head.'):
            new_key_name = f'img_roi_head{key_name[8:]}'
        else:
            key_changed = False
        if key_changed:
            logger = get_root_logger()
            print_log(
                f'{key_name} renamed to be {new_key_name}', logger=logger)
            state_dict[new_key_name] = state_dict.pop(key_name)
    super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                  strict, missing_keys, unexpected_keys,
                                  error_msgs)

Reference

  • Pytorch nn.Module 文档
  • MMCV 中 DCN 的完成
  • MMDetection3D