import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.add_module("conv", nn.Conv2d(10, 20, 4))
self.add_module("conv1", nn.Conv2d(20 ,10, 4))
model = Model()
for module in model.modules():
print(module)
Model (
(conv): Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))
(conv1): Conv2d(20, 10, kernel_size=(4, 4), stride=(1, 1))
)
Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))
Conv2d(20, 10, kernel_size=(4, 4), stride=(1, 1))
可以看出,modules()返回的iterator不止包含子模块。这是和childern()的不同。NOTE:重复的模块只被返回一次(children()也是)。在下面的例子中submodule只会被返回一次。
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
submodule = nn.Conv2d(10, 20, 4)
self.add_module("conv", submodule)
self.add_module("conv1", submodule)
model = Model()
for module in model.modules():
print(module)
Model (
(conv): Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1)) , →
(conv1): Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1)) , →
)
Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))
named_children()
返回包含模型当前子模块的迭代器,yield模块名字和模块本身。 例子:
for name, module in model.named_children():
if name in ['conv4', 'conv5']:
print(module)
named_modules(memo=None, prefix=”) 返回包含网络中所有模块的迭代器, yield
ing 模块名和模块本身。
重复的模块只被返回一次 (children() 也是)。在下面的例子中, submodule 只会被返回一次。
--parameters(memo=None) 返回一个 包含模型所有参数 的迭代器。
一般用来当作optinizer的参数。
例子:
for param in model.parameters():
print(type(param.data), param.size())
<class 'torch.FloatTensor'> (20L,)
<class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
register_backward_hook(hook)
在 module 上注册一个 bachward hook。
每次计算 module 的 inputs 的梯度的时候,这个 hook 会被调用。hook 应该拥有
下面的 signature。
hook(module, grad_input, grad_output) -> Variable or None
如果 module 有多个输入输出的话,那么 grad_input grad_output 将会是个 tuple。
hook 不应该修改它的 arguments,但是它可以选择性的返回关于输入的梯度,这
个返回的梯度在后续的计算中会替代 grad_input。
这个函数返回一个 句柄 (handle)。它有一个方法 handle.remove(),可以用这个
方法将 hook 从 module 移除。
– register_buffer(name, tensor) 给 module 添加一个 persistent buffer。
persistent buffer 通常被用在这么一种情况:我们需要保存一个状态,但是这个
状态不能看作成为模型参数。例如:, BatchNorm’s running_mean 不是一个
parameter, 但是它也是需要保存的状态之一。
Buffers 可以通过注册时候的 name 获取。
NOTE: 我们可以用 buffer 保存 moving average
例子:
self.register_buffer('running_mean',
torch.zeros(num_features))
self.running_mean