torch.roll函数学习

torch.roll函数是真的比较难以理解,我觉得之后我碰上可能也不一定能转过弯来,因此写一篇博客记录一下。

torch.roll的文档在这里。从官方文档可以看到,torch.roll(input,shifts,dims)中的三个参数意思都是比较明确的,input即为输入的tensor,shifts表示位移的距离,dims为位移的方向。其中shifts和dims既可以为数字,也可以为元组。

其中最让人困惑的莫过于dims,特别是在高维(大于3维)的时候,基本上感觉怎么移动都不对味。比如我目前需要对一个形状为(batch_size, time_length, feature_size)的向量在时间维度上进行迁移,感觉就怎么都不对。

官方例子的改版:

import torch
x = torch.arange(0,9).view(3,3)
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])
torch.roll(x,-1,1)
# tensor([[2, 3, 1],
#         [5, 6, 4],
#         [8, 9, 7]])
print(x[:,0])
# tensor([1, 4, 7])

可以看到,上面这个例子在dim=1的维度进行操作,而最终这个维度上是没有发生变化的(其他维度上均发生位移)。

在三维上的例子

x = torch.arange(0,27).view(3,3,3)
# tensor([[[ 0,  1,  2],
#          [ 3,  4,  5],
#          [ 6,  7,  8]],

#         [[ 9, 10, 11],
#          [12, 13, 14],
#          [15, 16, 17]],

#         [[18, 19, 20],
#          [21, 22, 23],
#          [24, 25, 26]]])
x = torch.roll(x,-1,1)
# tensor([[[ 3,  4,  5],
#          [ 6,  7,  8],
#          [ 0,  1,  2]],

#         [[12, 13, 14],
#          [15, 16, 17],
#          [ 9, 10, 11]],

#         [[21, 22, 23],
#          [24, 25, 26],
#          [18, 19, 20]]])
# 效果是等价的
# >>> torch.roll(x[0],-1,0)
# tensor([[3, 4, 5],
#         [6, 7, 8],
#         [0, 1, 2]])

通过最后的备注语句可以看到,在三维的情况下相当于是对二维的情况进行了广播,这里用PPT简单画了一下。其中的dim=1相当于图中的y-z平面。因此可以显而易见的看到,沿dim=1进行roll,相当于是把y-z平面顺次平移。

example_ppt

0%