两个关于pytorch的小tips。在折腾变分词嵌入时,我被复杂损失函数弄得头大,因而试图将变分词嵌入在pytorch上实现,不过最后没实现成,但折腾途中了解了一些别的知识。
(大概是2020年6月左右的事)
clone/detach/deepcopy
clone
对一个张量调用clone方法,会:
- 创建出完全新的对象,即与被源对象不使用同一块内存
- 不引用、指向源对象的子对象和数据
- 新对象会复制源对象的计算历史(梯度图)
这意味可以通过clone复制出的对象可以完全取代源对象进行反向传播。
detach
对一个张量调用detach方法,会:
- 返回一个新的张量对象,这个张量有着独立的元数据
- 但这个张量中的每个元素都指向源对象的元素
- 这个张量不继承梯度图,requires_grad也为false
这个操作通常在需要截断反向传播的地方使用,或者是配合clone指令返回一个只有张量值的拷贝
1 | y = x.clone().detach() |
deepcopy
deepcopy是python自带的复制方法(copy.deepcopy),用于复制一个对象,以及对象内的子对象。复制出的新对象与源对象没有任何关系。
就我看到的情况中,当其作用于pytorch的张量时,效果与clone一样。
view/reshape
这两个函数都是对张量的形状进行调整,区别在于:
view直接返回原始张量的新视图对象,这个对象只是形状发生了改变,底层数据仍然指向原始数据。view需要在连续的数据上执行。
reshape等价于.contiguous().view(),它可能是原始张量的新视图,即仍然指向原始数据(如果原始数据在内存中连续,不需要调整数据);也有可能是新的张量(如果数据发生了调整)。
如果需要保证是在原数据上进行修改,就只能使用view。