'''
Description: attention注意力机制
Author: 365JHWZGo
Date: 2021-12-14 17:06:11
LastEditors: 365JHWZGo
LastEditTime: 2021-12-14 22:23:54
'''
导入库
import torch import torch.nn as nn import torch.nn.functional as F
Attn类
class Attn(nn.Module):
    """Attention layer that scores a query against a key with a linear map.

    The weights come from a linear layer applied to the column-wise
    concatenation ``[query | key]`` followed by a softmax. They are used to
    take a weighted sum over the rows of ``v``, and that context vector is
    concatenated with the query and projected to ``output_size``.

    Args:
        query_size: Feature width of the query (last dim of ``q``).
        key_size: Feature width of the key (last dim of ``k``).
        value_size1: Number of attention weights produced; must equal the
            second dimension of ``v`` for the ``bmm`` to be valid.
        value_size2: Feature width of each value row (third dim of ``v``).
        output_size: Feature width of the returned output.
    """

    def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
        super().__init__()
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1
        self.value_size2 = value_size2
        self.output_size = output_size

        # [query | key] -> one logit per value position.
        self.attn = nn.Linear(query_size + key_size, value_size1)
        # [query | context] -> output features.
        self.attn_combine = nn.Linear(query_size + value_size2, output_size)

    def forward(self, q, k, v):
        """Compute the attended output and the attention weights.

        Shapes in the comments below are for the demo sizes used in this
        file: q=(1, 1, 32), k=(1, 1, 32), v=(1, 32, 64).

        Returns:
            A tuple ``(output, attn_weights)``; with the demo sizes these are
            (1, 1, 64) and (1, 32) respectively.
        """
        # Drop the leading dimension: (1, 1, 32) -> (1, 32).
        query = q[0]
        key = k[0]

        # (1, 64) -> (1, 32); softmax makes each row sum to 1.
        logits = self.attn(torch.cat((query, key), dim=1))
        attn_weights = F.softmax(logits, dim=1)

        # (1, 1, 32) x (1, 32, 64) -> (1, 1, 64): weighted sum of value rows.
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), v)

        # (1, 32) | (1, 64) -> (1, 96)
        combined = torch.cat((query, attn_applied[0]), dim=1)

        # (1, 96) -> (1, 64) -> (1, 1, 64)
        output = self.attn_combine(combined).unsqueeze(0)
        return output, attn_weights
attn层的输入是【Query|Key】,即先将Query与Key按列拼接,再送入线性层
attn_combine函数是生成【Query|attn_applied】,attn_applied是最后Query在Source中的真正注意力分布
attn_weights的结果对应于a1,a2,a3…
attn_applied是计算Attention Value,bmm相当于a1·value1 + a2·value2 + …【矩阵乘法】
第二个W矩阵是训练得到的参数,维度是d2 x d1,d2是s的hidden state输出维数,d1是hi的hidden state维数
key=h
query=s
if __name__ == "__main__":
    # Demo: run one attention step on random tensors and print the results.
    query_size, key_size = 32, 32
    # value_size1 is the second dimension of V, value_size2 the third.
    value_size1, value_size2 = 32, 64
    # Feature width of the attention output.
    output_size = 64

    # Build the module before sampling inputs so RNG consumption order is
    # the same as in the original script.
    attn = Attn(query_size, key_size, value_size1, value_size2, output_size)

    Q = torch.randn(1, 1, query_size)
    K = torch.randn(1, 1, key_size)
    V = torch.randn(1, value_size1, value_size2)

    output, attn_weights = attn(Q, K, V)
    print(output)
    print(attn_weights)
tensor([[[ 0.2658, 0.0392, 0.2432, -0.6333, -0.2197, -0.0189, -0.2440, 0.2307, 0.3793, 0.1152, 0.3247, -0.0377, 0.5529, -0.2616, -0.1077, -0.2078, -0.2510, -0.4814, -0.2096, -0.1568, -0.0288, 0.0595, -0.2944, 0.1996, -0.2253, -0.1753, 0.3036, 0.4191, 0.0869, -0.4587, 0.0630, -0.0472, 0.1013, 0.2068, 0.0144, -0.5463, -0.0487, 0.2278, -0.2225, -0.2994, -0.2592, -0.0371, 0.0615, 0.3353, -0.2891, -0.1839, 0.3867, 0.2469, 0.1036, 0.2699, 0.1983, 0.0683, -0.3410, -0.1992, 0.5660, 0.0794, -0.2826, 0.0421, 0.0635, 0.1220, 0.1333, -0.2451, -0.4481, -0.1631]]], grad_fn=<UnsqueezeBackward0>) tensor([[0.0151, 0.0451, 0.0093, 0.0251, 0.0379, 0.0177, 0.0277, 0.0301, 0.0200, 0.0415, 0.0309, 0.0440, 0.0248, 0.0419, 0.0191, 0.0287, 0.0564, 0.0132, 0.0442, 0.0473, 0.0359, 0.0154, 0.0195, 0.0652, 0.0255, 0.0178, 0.0287, 0.0291, 0.0411, 0.0548, 0.0190, 0.0280]], grad_fn=<SoftmaxBackward0>)