Save you from anything

0%

pytorch中deepcopy/detach/clone、view/reshape的区别

两个关于pytorch的小tips。在折腾变分词嵌入时,我被复杂损失函数弄得头大,因而试图将变分词嵌入在pytorch上实现,不过最后没实现成,但折腾途中了解了一些别的知识。

(大概是2020年6月左右的事)

clone/detach/deepcopy

clone

对一个张量调用clone方法,会:

  • 创建出完全新的对象,即与被源对象不使用同一块内存
  • 不引用、指向源对象的子对象和数据
  • 新对象会复制源对象的计算历史(梯度图)

这意味可以通过clone复制出的对象可以完全取代源对象进行反向传播。

detach

对一个张量调用detach方法,会:

  • 返回一个新的张量对象,这个张量有着独立的元数据
  • 但这个张量中的每个元素都指向源对象的元素
  • 这个张量不继承梯度图,requires_grad也为false

这个操作通常在需要截断反向传播的地方使用,或者是配合clone指令返回一个只有张量值的拷贝

1
y = x.clone().detach()

deepcopy

deepcopy是python自带的复制方法(copy.deepcopy),用于复制一个对象,以及对象内的子对象。复制出的新对象与源对象没有任何关系。

就我看到的情况中,当其作用于pytorch的张量时,效果与clone一样。

view/reshape

这两个函数都是对张量的形状进行调整,区别在于:

view直接返回原始张量的新视图对象,这个对象只是形状发生了改变,底层数据仍然指向原始数据。view需要在连续的数据上执行。

reshape等价于.contiguous().view(),它可能是原始张量的新视图,即仍然指向原始数据(如果原始数据在内存中连续,不需要调整数据);也有可能是新的张量(如果数据发生了调整)。

如果需要保证是在原数据上进行修改,就只能使用view。