涉及以下内容
简述
例如
a = torch.rand(3, 4)
b = torch.rand(4, 5)
c = torch.einsum("ik,kj->ij", [a, b])
# einsum 的第一个参数 "ik,kj->ij" 描述张量的计算规则,且维度的字符只能是26个英文字母 'a' - 'z'
# einsum 的第一个参数可以不写包括箭头在内的右边部分,比如矩阵乘法 "ik,kj" 等价于 "ik,kj->ij";此时输出保留输入中只出现一次的索引,并按字母表顺序排列
# einsum 的第一个参数支持 "..." 省略号,用于表示用户不关心的索引
# einsum 的第二个参数 [a, b] 表示实际的输入张量列表,且真实维度需匹配规则
# 索引顺序可以任意,但 "ik,kj->ij" 如果写成 "ik,kj->ji",后者将返回前者的转置
实践
# Side-by-side checks of torch.einsum against the dedicated PyTorch operators.
#
# Each numbered section builds small inputs, computes the same quantity once
# with torch.einsum and once with the conventional operator, prints both
# results, and reports whether they agree according to np.allclose.
import torch
import numpy as np

# 1: matrix multiplication
a = torch.rand(2, 3)
b = torch.rand(3, 4)
ein_out = torch.einsum("ik,kj->ij", [a, b]).numpy()
# Implicit form (no "->" part) is equivalent: torch.einsum("ik,kj", [a, b])
org_out = torch.mm(a, b).numpy()
print("input:\n", a, b, sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))

# 2: element-wise (Hadamard) product
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6, 12).reshape(2, 3)
ein_out = torch.einsum('ij,ij->ij', [a, b]).numpy()
org_out = torch.mul(a, b).numpy()
print("input:\n", a, b, sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))

# 3: matrix multiplication over the last two dimensions (batched matmul)
a = torch.randn(2, 3, 5)
b = torch.randn(2, 5, 3)
ein_out = torch.einsum('ijk,ikl->ijl', [a, b]).numpy()
org_out = torch.matmul(a, b).numpy()
# torch.bmm(a, b) is the explicit batch-matmul alternative for 3-D inputs.
print("input:\n", a, b, sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))

# 4: matrix transpose
a = torch.arange(6).reshape(2, 3)
ein_out = torch.einsum('ij->ji', [a]).numpy()
org_out = torch.transpose(a, 0, 1).numpy()
print("input:\n", a)
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))

# 5: transpose of the last two dimensions; "..." stands for the leading dims
a = torch.randn(1, 2, 3, 4, 5)
ein_out = torch.einsum('...ij->...ji', [a]).numpy()
org_out = a.permute(0, 1, 2, 4, 3).numpy()
print("input:\n", a)
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))

# 6: summation (all elements, per row, per column)
a = torch.arange(6).reshape(2, 3)
ein_out = torch.einsum('ij->', a).numpy()
org_out = torch.sum(a).numpy()
ein_out_i = torch.einsum('ij->i', a).numpy()
org_out_i = torch.sum(a, dim=1).numpy()
ein_out_j = torch.einsum('ij->j', a).numpy()
org_out_j = torch.sum(a, dim=0).numpy()
print("input:\n", a)
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
print("input:\n", a)
print("ein_out_i: \n", ein_out_i)
print("org_out_i: \n", org_out_i)
# Compare the per-row results themselves, not the scalar totals from above.
print("is org_out_i == ein_out_i ?", np.allclose(ein_out_i, org_out_i))
print("input:\n", a)
print("ein_out_j: \n", ein_out_j)
print("org_out_J: \n", org_out_j)
# Compare the per-column results themselves, not the scalar totals from above.
print("is org_out_j == ein_out_j ?", np.allclose(ein_out_j, org_out_j))

# 7: extract the diagonal of a matrix
a = torch.arange(9).reshape(3, 3)
ein_out = torch.einsum('ii->i', a).numpy()
org_out = torch.diagonal(a, 0).numpy()
print("input:\n", a)
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))

# 8: matrix-vector multiplication
a = torch.rand(3, 4)
b = torch.arange(4.0)
ein_out = torch.einsum('ik,k->i', [a, b]).numpy()
# Implicit form is equivalent: torch.einsum('ik,k', [a, b])
org_out = torch.mv(a, b).numpy()
print("input:\n", a, b, sep='\n')
print("ein_out_k: \n", ein_out)
print("org_out_k: \n", org_out)
print("is org_out_k == ein_out_k ?", np.allclose(ein_out, org_out))

# 9: vector inner (dot) product
a = torch.arange(3)
b = torch.arange(3, 6)
ein_out = torch.einsum('i,i->', [a, b]).numpy()
# Implicit form is equivalent: torch.einsum('i,i', [a, b])
org_out = torch.dot(a, b).numpy()
print("input:\n", a, b, sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))

# 10: vector outer product
a = torch.arange(3)
b = torch.arange(3, 5)
ein_out = torch.einsum('i,j->ij', [a, b]).numpy()
# Implicit form is equivalent: torch.einsum('i,j', [a, b])
org_out = torch.outer(a, b).numpy()
print("input:\n", a, b, sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))

# 11: tensor contraction
# Indices q and r (sizes 3 and 5) are summed over; the remaining indices of
# `a` (p, s) followed by those of `b` (t, u, v) form the output.
a = torch.randn(1, 3, 5, 7)
b = torch.randn(11, 33, 3, 55, 5)
ein_out = torch.einsum('pqrs,tuqvr->pstuv', [a, b]).numpy()
org_out = torch.tensordot(a, b, dims=([1, 2], [2, 4])).numpy()
print("input:\n", a, b, sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))