pytorch中储存多个model对象不能使用list,而是应该使用pytorch提供的ModuleList。
在用pytorch复现DRMM的时候我遇到了一个需要大量重复生成层的地方,于是写了个for循环,把model存在list中,想着list能储存对象,这么写应该可以:
1 | layers_input = [] # 创建ModuleList |
然而这么写并不可以,这种写法影响torch对每个模型的追踪。
pytorch提供了一个叫做ModuleList()的方,这个方法可以生成专门用于存放model的model_list:
1 | layers_input = torch.nn.ModuleList() # 创建ModuleList |
创建完成后,不管是向里面塞model,还是调用其中的model,都跟普通的list一样。
1 | for layer in self.layers_input: |