From the perspective of a single token:
q_i = h_i W_Q,\quad k_j = h_j W_K,\quad v_j = h_j W_V
e_{ij} = q_i k_j^T
\alpha_i = Softmax([e_{i,1}, ..., e_{i,T}])
h'_i = \left(\sum_{j=1}^T \alpha_{i,j} v_j\right) W_O
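To make the per-token formulas concrete, here is a minimal sketch that computes h'_i for one position i with plain tensor operations. The sizes (T = 5, d_model = 8, d_head = 4), the position i = 2, and the variable names are illustrative assumptions, not values from the text.

import torch

T, d_model, d_head = 5, 8, 4           # illustrative sizes (assumed)
h = torch.randn(T, d_model)            # hidden states h_1 ... h_T
W_Q = torch.randn(d_model, d_head)
W_K = torch.randn(d_model, d_head)
W_V = torch.randn(d_model, d_head)
W_O = torch.randn(d_head, d_model)

i = 2                                  # the query position (assumed)
q_i = h[i] @ W_Q                       # q_i = h_i W_Q
k = h @ W_K                            # all k_j stacked: [T, d_head]
v = h @ W_V                            # all v_j stacked: [T, d_head]

e_i = q_i @ k.T                        # [e_{i,1}, ..., e_{i,T}]
alpha_i = torch.softmax(e_i, dim=-1)   # attention weights for position i
h_i_new = (alpha_i @ v) @ W_O          # h'_i = (sum_j alpha_{i,j} v_j) W_O
print(h_i_new.shape)                   # torch.Size([8])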
Matrix form:
Q = H W_Q,\quad K = H W_K,\quad V = H W_V
E = QK^T
E' = Softmax(E)
H' = E'V
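The matrix form is just the per-token computation done for all positions at once, and it can be checked with a few lines of tensor code. The shapes below are illustrative assumptions.

import torch

T, d_model, d_head = 5, 8, 4            # illustrative sizes (assumed)
H = torch.randn(T, d_model)             # all hidden states stacked row-wise
W_Q = torch.randn(d_model, d_head)
W_K = torch.randn(d_model, d_head)
W_V = torch.randn(d_model, d_head)

Q, K, V = H @ W_Q, H @ W_K, H @ W_V     # Q = HW_Q, K = HW_K, V = HW_V
E = Q @ K.T                             # E = QK^T, shape [T, T]
E_prime = torch.softmax(E, dim=-1)      # E' = Softmax(E), row-wise
H_prime = E_prime @ V                   # H' = E'V, shape [T, d_head]
print(H_prime.shape)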
import math
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, d_model, d_head):
        super(SelfAttention, self).__init__()
        self.w_q = nn.Linear(d_model, d_head)
        self.w_k = nn.Linear(d_model, d_head)
        self.w_v = nn.Linear(d_model, d_head)
        self.w_o = nn.Linear(d_head, d_model)

    def forward(self, x):
        # x: [batch_size, max_len, d_model]
        # q, k, v: [batch_size, max_len, d_head]
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        attn_score = torch.matmul(q, k.permute(0, 2, 1))  # note: a transpose via permute, not a reshape
        attn_score = torch.softmax(attn_score, dim=-1)    # [batch_size, max_len, max_len]
        output = torch.matmul(attn_score, v)              # [batch_size, max_len, d_head]
        return self.w_o(output)


x = torch.randn(3, 9, 100)
model = SelfAttention(100, 80)
print(model(x).shape)  # torch.Size([3, 9, 100])
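As a sanity check (a sketch, not part of the original code), the module's output can be reproduced with the matrix-form equations above, reusing the SelfAttention class and the imports from the previous block and reading the weights out of its nn.Linear layers. Note that nn.Linear computes x @ W.T + b, so the weight matrices appear transposed and a bias term is included.

# Sketch: reproduce SelfAttention.forward with the explicit matrix-form equations.
x = torch.randn(3, 9, 100)
model = SelfAttention(100, 80)

q = x @ model.w_q.weight.T + model.w_q.bias               # Q = HW_Q (nn.Linear adds a bias)
k = x @ model.w_k.weight.T + model.w_k.bias               # K = HW_K
v = x @ model.w_v.weight.T + model.w_v.bias               # V = HW_V
e = torch.softmax(q @ k.transpose(1, 2), dim=-1)          # E' = Softmax(QK^T)
h_prime = (e @ v) @ model.w_o.weight.T + model.w_o.bias   # H' = E'V, then the output projection

print(torch.allclose(h_prime, model(x), atol=1e-5))       # expected: True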
# Multi-head self-attention
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model=768, d_head=64):
        super(MultiHeadSelfAttention, self).__init__()
        assert d_model % d_head == 0
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.n_heads = d_model // d_head
        self.d_model = d_model
        self.d_head = d_head

    def forward(self, x, mask=None):
        batch_size = x.shape[0]
        max_len = x.shape[1]
        # project, then split the model dimension into n_heads heads of size d_head
        q = self.w_q(x).view(batch_size, max_len, self.n_heads, self.d_head)
        k = self.w_k(x).view(batch_size, max_len, self.n_heads, self.d_head)
        v = self.w_v(x).view(batch_size, max_len, self.n_heads, self.d_head)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)  # [batch_size, n_heads, max_len, d_head]
        attn_score = torch.matmul(q, k.permute(0, 1, 3, 2))  # [batch_size, n_heads, max_len, max_len]
        if mask is not None:
            # mask out padded key positions so they receive no attention: [batch_size, 1, 1, max_len]
            mask = mask.unsqueeze(1).unsqueeze(2)
            # use a large negative value (not -1e-25, which is effectively zero)
            attn_score = attn_score.masked_fill(mask == 0, -1e9)
        attn_score = torch.softmax(attn_score, dim=-1)  # [batch_size, n_heads, max_len, max_len]
        out = torch.matmul(attn_score, v).permute(0, 2, 1, 3)  # [batch_size, max_len, n_heads, d_head]
        out = out.contiguous().view(batch_size, max_len, -1)
        return self.w_o(out)


if __name__ == "__main__":
    x = torch.randn(2, 9, 768)
    mask = torch.tensor([
        [1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0],
    ]).bool()
    model = MultiHeadSelfAttention()
    print(model(x, mask).shape)  # torch.Size([2, 9, 768])
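A quick way to confirm the padding mask behaves as intended (a sketch reusing MultiHeadSelfAttention as defined above, with key-position masking; the variable names are mine): outputs at valid positions should not change when the embeddings at padded positions are overwritten with arbitrary values, because masked keys get zero attention weight.

# Sketch: outputs at valid positions are unaffected by what sits in the padded slots.
model = MultiHeadSelfAttention()

x1 = torch.randn(2, 9, 768)
x2 = x1.clone()
x2[:, 4:, :] = torch.randn(2, 5, 768)   # overwrite positions that are padded in both rows
mask = torch.tensor([
    [1, 1, 1, 0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0, 0],
]).bool()

out1 = model(x1, mask)
out2 = model(x2, mask)
# compare only positions that are valid in both rows
print(torch.allclose(out1[:, :3], out2[:, :3], atol=1e-5))   # expected: True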