torch.einsum函数学习

torch.einsum

简单例子

基础知识见csdn,这篇博客写的比较好。简而言之,einsum就是爱因斯坦求和简记法的实现,这里引用该文的一个例子,非常好懂。

print(a_tensor)
 
tensor([[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]])
 
print(b_tensor)
 
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4]])
 
# 'ik, kj -> ij'语义解释如下:
# 输入a_tensor: 2维数组,下标为ik,
# 输入b_tensor: 2维数组,下标为kj,
# 输出output:2维数组,下标为ij。
# 隐含语义:输入a,b下标中相同的k,是求和的下标,对应上面的例子2的公式
output = torch.einsum('ik, kj -> ij', a_tensor, b_tensor)
 
print(output)
 
tensor([[130, 130, 130, 130],
        [230, 230, 230, 230],
        [330, 330, 330, 330],
        [430, 430, 430, 430]])

高维案例

但是对于高维案例来说,这个简记法就么那么直观了。同样引用原文的例子

a = np.arange(60.).reshape(3,4,5)
b = np.arange(24.).reshape(4,3,2)
 
# 语义解析:
# 输入a:3阶张量,下标为ijk
# 输入b: 3阶张量,下标为jil
# 输出o: 2阶张量,下标为k和l
# 隐含语义:对i,j进行求和,公式附于代码之后:
o = np.einsum('ijk,jil->kl', a, b)
print(o)
 
array([[4400., 4730.],
       [4532., 4874.],
       [4664., 5018.],
       [4796., 5162.],
       [4928., 5306.]])
 
# 验证:
print(np.sum(a[:,:,0]*b[:,:,0].T))
 
4400.0
 
print(np.sum(a[:,:,1]*b[:,:,0].T))
 
4532.0

上述式子从k,l -> kl可以猜想到其计算过程o[k,l] = a[i,j,k] * b[j,i,l],然后每个元素分别计算出对应的值。

更加复杂

但是对于更复杂的例子,就需要进一步思考。下面对于Informer中的attention计算为例子进行进一步学习。

queries = torch.arange(0,120).view(2,3,4,5)
keys = torch.arange(0,120).view(2,3,4,5)
values = torch.arange(0,120).view(2,3,4,5)
B, L, H, E = queries.shape
_, S, _, D = values.shape
scores = torch.einsum("blhe,bshe->bhls", queries, keys)
result = torch.zeros_like(scores)

for b in range(B):
    for l in range(L):
        for h in range(H):
            for s in range(S):
                for e in range(E):
                    result[b,h,l,s] += queries[b,l,h,e]*keys[b,s,h,e]
                    # result[b,h,l,s] = torch.sum(queries[b,l,h,:]*keys[b,s,h,:])

从这个例子中可以看出求和符号的计算。

0%