Save you from anything

0%

pytorch中的model对象的列表:ModuleList

pytorch中储存多个model对象不能使用list,而是应该使用pytorch提供的ModuleList。

在用pytorch复现DRMM的时候我遇到了一个需要大量重复生成层的地方,于是写了个for循环,把model存在list中,想着list能储存对象,这么写应该可以:

1
2
3
4
5
6
layers_input = []  # 创建ModuleList
for i in range(len(hidden_size_list)):
temp = torch.nn.Linear() # 创建Linear
torch.nn.init.uniform_(temp.weight, -0.1, 0.1) # 初始化Liner权重
layers_input.append(temp)
layers_input.append(torch.nn.Tanh())

然而这么写并不可以,这种写法影响torch对每个模型的追踪。

pytorch提供了一个叫做ModuleList()的方,这个方法可以生成专门用于存放model的model_list:

1
2
3
4
5
6
layers_input = torch.nn.ModuleList()  # 创建ModuleList
for i in range(len(hidden_size_list)):
temp = torch.nn.Linear() # 创建Linear
torch.nn.init.uniform_(temp.weight, -0.1, 0.1) # 初始化Liner权重
layers_input.append(temp)
layers_input.append(torch.nn.Tanh())

创建完成后,不管是向里面塞model,还是调用其中的model,都跟普通的list一样。

1
2
for layer in self.layers_input:
high = layer(high) # Dense/Activation('tanh')