PyTorch里用于初始化的super函数是什么


在使用PyTorch的创建模型的时候,通常要继承 nn.Module 类。而且其中用到了super这个函数,这其实涉及到了OOP里关于继承的知识。这篇文章就是解释super函数做了什么,方便大家理解PyTorch里为什么要这么写。

TL;DR

super()函数可以隐式地将子类(sublass)里的method,与父类(superclass)里的method进行关联。这样的好处在于我们不用在子类里显式地重新创建父类method里的属性。

super函数讲解

先来看一个不使用super的例子。父类是Person,子类是Runner

# superclass
class Person:
  def __init__(self, name, age):
    self.name = name
    self.age = age

  def get_name(self):
    return self.name

  def get_age(self):
    return self.age


# subclass
class Runner(Person):
  def __init__(self, runner_name, runner_age, runner_height):
    Person.__init__(self, runner_name, runner_age)    
    self.runner_height = runner_height

  def get_height(self):
    return self.runner_height

# test superclass and subclass
jack = Person('Jack', 23)
print(jack.get_age()) # 23

kipchoge = Runner('Kipchoge', 35, 167)
print(kipchoge.get_height()) # 35
print(kipchoge.get_name()) # 167

在这个例子里,我们显式地声明了 Person.__init__(self, runner_name, runner_age) 来调用父类初始化runner_namerunner_age这两个属性。但是如果使用super()函数的话,写法能简化:

class Runner(Person):
  def __init__(self, runner_name, runner_age, runner_height):
    super().__init__(runner_name, runner_age) # use super
    self.runner_height = runner_height

    def get_height(self):
      return self.runner_height

这里的super()代表了父类的Person,而且也不用写入self作为第一个参数了。

super函数的多层继承

下面讲一下super函数用于多层继承(multilevel inheritance)的情况,简单来说,就是之继承最近的那个父类。

class A:
    def __init__(self):
        print('Initializing: class A')

    def sub_method(self, b):
        print('Printing from class A:', b)


class B(A):
    def __init__(self):
        print('Initializing: class B')
        super().__init__()

    def sub_method(self, b):
        print('Printing from class B:', b)
        super().sub_method(b + 1)


class C(B):
    def __init__(self):
        print('Initializing: class C')
        super().__init__()

    def sub_method(self, b):
        print('Printing from class C:', b)
        super().sub_method(b + 1)


if __name__ == '__main__':
    c = C()
    c.sub_method(1)

# Initializing: class C
# Initializing: class B
# Initializing: class A
# Printing from class C: 1
# Printing from class B: 2
# Printing from class A: 3

c = C() 创建了一个class C的实例,然后可以看到初始化是从C->B->A的。

c.sub_method(1) 首先调用了C类里的sub_method(),输出了1,然后通过super().sub_method(b + 1)调用了B类里的sub_method()。可以看到C类里的super()就是代替了class C(B)里的B类。

PyTorch中的super函数

下面的代码创建了一个简单的CNN模型,用于图像分类。根据我们上面讲过的super()的只是,可以知道super().__init__()这一行其实就是调用了nn.Module.__init__()

import torch.nn as nn
import torch.nn.functional as F


class Network(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
    self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

    self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
    self.fc2 = nn.Linear(in_features=120, out_features=60)
    self.out = nn.Linear(in_features=60, out_features=10)

  def forward(self, t):
    t = F.relu(self.conv1(t))
    t = F.max_pool2d(t, kernel_size=2, stride=2)

    t = F.relu(self.conv2(t))
    t = F.max_pool2d(t, kernel_size=2, stride=2)

    t = t.flatten(start_dim=1) # t = t.reshape(-1, 12 * 4 * 4)
    t = F.relu(self.fc1(t))
    t = F.relu(self.fc2(t))
    t = self.out(t)

    return t

我们实际去看一下nn.Module.__init__()源代码初始化了哪些东西:

class Module(object):
  def __init__(self):
    """
    Initializes internal Module state, shared by both nn.Module and ScriptModule.
    """
    torch._C._log_api_usage_once("python.nn_module")

    self.training = True
    self._parameters = OrderedDict()
    self._buffers = OrderedDict()
    self._backward_hooks = OrderedDict()
    self._forward_hooks = OrderedDict()
    self._forward_pre_hooks = OrderedDict()
    self._state_dict_hooks = OrderedDict()
    self._load_state_dict_pre_hooks = OrderedDict()
    self._modules = OrderedDict()

也就是说,继承了nn.Module的话,也就初始化了上面这些参数。

参考资料


文章作者: BrambleXu
版权声明: 本博客所有文章除特別声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来源 BrambleXu !
评论
  目录