input: speech (16, 183, 80)  # T = 183 is set by the longest utterance in the batch (speech_length)
       text (16, 6)          # 6 is set by the longest transcript in the batch (text_length)
mask: (16, 183)              # True on valid frames; unsqueezed to (16, 1, 183) before subsampling
forward input: (speech, mask)
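For reference, a minimal sketch of how the (16, 183) mask can be built from speech_length (the usual make_pad_mask pattern; the function and variable names here are illustrative, not the source code):

import torch

def make_valid_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Return a (B, max_len) boolean mask, True on valid (non-padded) frames."""
    # lengths: (B,) actual frame count of each utterance in the batch
    positions = torch.arange(max_len, device=lengths.device)  # (max_len,)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)      # (B, max_len)

speech_length = torch.tensor([183, 170, 95] + [120] * 13)  # example lengths, max = 183
mask = make_valid_mask(speech_length, max_len=183)         # (16, 183)
print(mask.shape)  # torch.Size([16, 183])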
torch.nn.Conv2d(1, odim, 3, 2), torch.nn.ReLU(), torch.nn.Conv2d(odim, odim, 3, 2), torch.nn.ReLU()  # output (16, 256, 45, 19): time 183 -> ((183-1)//2-1)//2 = 45, feat 80 -> ((80-1)//2-1)//2 = 19
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)  # flatten the (odim, 19) channel/feature axes and project back to odim; output (16, 45, 256)
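Put together, a shape-checked sketch of this 4x time subsampling, assuming idim = 80 and odim = 256 as above (the transpose/view plumbing is the standard way to feed the conv output into the Linear; it may differ from the exact source):

import torch

idim, odim = 80, 256

conv = torch.nn.Sequential(
    torch.nn.Conv2d(1, odim, 3, 2),     # time 183 -> 91, feat 80 -> 39
    torch.nn.ReLU(),
    torch.nn.Conv2d(odim, odim, 3, 2),  # time 91 -> 45, feat 39 -> 19
    torch.nn.ReLU(),
)
proj = torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)  # 256*19 -> 256

speech = torch.randn(16, 183, 80)            # (B, T, D)
x = conv(speech.unsqueeze(1))                # (16, 256, 45, 19)
b, c, t, f = x.size()
x = proj(x.transpose(1, 2).contiguous().view(b, t, c * f))  # (16, 45, 256)
print(x.shape)  # torch.Size([16, 45, 256])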
pos_emb: (1, 45, 256)  # output; dropout is then applied to both streams: speech = dropout(speech) -> (16, 45, 256), pos_emb = dropout(pos_emb) -> (1, 45, 256)
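A minimal sketch of a positional-encoding step that returns both the features and a (1, T', 256) pos_emb, each passed through dropout; plain sinusoidal encoding is used here for illustration only, and the actual module (e.g. a relative positional encoding) may differ in detail:

import math
import torch

class SinusoidalPositionalEncoding(torch.nn.Module):
    """Illustrative: returns (dropout(x * sqrt(d_model)), dropout(pos_emb))."""

    def __init__(self, d_model: int = 256, dropout_rate: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.xscale = math.sqrt(d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor):
        pos_emb = self.pe[:, : x.size(1)]  # (1, T', d_model)
        return self.dropout(x * self.xscale), self.dropout(pos_emb)

x = torch.randn(16, 45, 256)
x, pos_emb = SinusoidalPositionalEncoding()(x)
print(x.shape, pos_emb.shape)  # torch.Size([16, 45, 256]) torch.Size([1, 45, 256])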
# output: speech (16, 45, 256), pos_emb (1, 45, 256), mask (16, 1, 45)  # mask subsampled via x_mask[:, :, :-2:2][:, :, :-2:2]
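Quick check that the slicing [:, :, :-2:2][:, :, :-2:2] takes the (16, 1, 183) mask to (16, 1, 45), matching the stride pattern of the two convolutions:

import torch

x_mask = torch.ones(16, 1, 183, dtype=torch.bool)  # (B, 1, T), True on valid frames
sub = x_mask[:, :, :-2:2]   # drop the last 2 frames, keep every 2nd: (16, 1, 91)
sub = sub[:, :, :-2:2]      # same again:                             (16, 1, 45)
print(sub.shape)            # torch.Size([16, 1, 45])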
add_optional_chunk_mask
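add_optional_chunk_mask combines the (16, 1, 45) padding mask with a chunk-wise attention mask, so each frame only attends within its own chunk plus left context. Below is a sketch of the idea only; chunk_size = 4 and the helper are illustrative and do not reproduce the real function's signature:

import torch

def chunk_attention_mask(size: int, chunk_size: int) -> torch.Tensor:
    """(size, size) boolean mask: position i may attend to j iff j's chunk
    is at or before i's chunk (full left context, no look-ahead past the chunk)."""
    chunk_idx = torch.arange(size) // chunk_size             # (size,)
    return chunk_idx.unsqueeze(1) >= chunk_idx.unsqueeze(0)  # (size, size)

masks = torch.ones(16, 1, 45, dtype=torch.bool)       # subsampled padding mask
chunk_masks = chunk_attention_mask(45, chunk_size=4)  # (45, 45)
chunk_masks = masks & chunk_masks.unsqueeze(0)        # broadcast -> (16, 45, 45)
print(chunk_masks.shape)  # torch.Size([16, 45, 45])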