Transformer代码实现
多层Transformer
在多层Transformer中,多层编码器先对输入序列进行编码,然后得到最后一个Encoder的输出Memory;解码器先通过Masked Multi-Head Attention对输入序列进行编码,然后将输出结果同Memory通过Encoder-Decoder Attention后得到第1层解码器的输出;接着再将第1层Decoder的输出通过Masked Multi-Head Attention进行编码,最后再将编码后的结果同Memory通过Encoder-Decoder Attention后得到第2层解码器的输出,以此类推得到最后一个Decoder的输出。
值得注意的是,在多层Transformer的解码过程中,每一个Decoder在Encoder-Decoder Attention中所使用的Memory均是同一个。

Transformer中的掩码
在Transformer中,主要有两个地方会用到掩码这一机制。第1个地方就是Attention Mask,用于在训练过程中解码的时候掩盖掉当前时刻之后的信息;第2个地方便是对一个batch中不同长度的序列在Padding到相同长度后,对Padding部分的信息进行掩盖。下面分别就这两种情况进行介绍。
Attention Mask
如图3-3所示,在训练过程中对于每一个样本来说都需要这样一个对称矩阵来掩盖掉当前时刻之后所有位置的信息。
具体细节在Decoder部分有介绍
Padding Mask
在Transformer中,使用到掩码的第2个地方便是Padding Mask。由于在网络的训练过程中同一个batch会包含有多个文本序列,而不同的序列长度并不一致。因此在数据集的生成过程中,就需要将同一个batch中的序列Padding到相同的长度。但是,这样就会导致在注意力的计算过程中会考虑到Padding位置上的信息。
如图3-4所示,P表示Padding的位置,右边的矩阵表示计算得到的注意力权重矩阵。可以看到,此时的注意力权重对于Padding位置山的信息也会加以考虑。因此在Transformer中,作者通过在生成训练集的过程中记录下每个样本Padding的实际位置;然后再将注意力权重矩阵中对应位置的权重替换成负无穷,经softmax操作后对应Padding位置上的权重就变成了0,从而达到了忽略Padding位置信息的目的。这种做法也是Encoder-Decoder网络结构中通用的一种办法。
如图3-5所示,对于"我 是 谁 P P"
这个序列来说,前3个字符是正常的,后2个字符是Padding后的结果。因此,其Mask向量便为[True, True, True, False, False]
。通过这个Mask向量可知,需要将权重矩阵的最后两列替换成负无穷。
实现多头注意力机制
多头注意力机制
class MyMultiheadAttention(nn.Module):
def __init__(self, embed_dim, num_heads, dropout=0., bias=True):
super(MyMultiheadAttention, self).__init__()
"""
:param embed_dim: 词嵌入的维度,也就是前面的d_model参数,论文中的默认值为512
:param num_heads: 多头注意力机制中多头的数量,也就是前面的nhead参数, 论文默认值为 8
:param bias: 最后对多头的注意力(组合)输出进行线性变换时,是否使用偏置
"""
self.embed_dim = embed_dim # 前面的d_model参数
self.head_dim = embed_dim // num_heads # head_dim 指的就是d_k,d_v
self.kdim = self.head_dim
self.vdim = self.head_dim
self.num_heads = num_heads # 多头个数
self.dropout = dropout
assert self.head_dim * num_heads == self.embed_dim, "embed_dim 除以 num_heads必须为整数"
# 上面的限制条件就是论文中的 d_k = d_v = d_model/n_head 条件
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
# embed_dim = kdim * num_heads
# 这里第二个维度之所以是embed_dim,实际上这里是同时初始化了num_heads个W_q堆叠起来的, 也就是num_heads个头
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
# W_k, embed_dim = kdim * num_heads
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
# W_v, embed_dim = vdim * num_heads
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
# 最后将所有的Z组合起来的时候,也是一次性完成, embed_dim = vdim * num_heads
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
在上述代码中,embed_dim
表示模型的维度(图3-8中的d_m);num_heads
表示多头的个数;bias
表示是否在多头线性组合时使用偏置。同时,为了使得实现代码更加高效,所以PyTorch在实现的时候是多个头注意力机制一起进行的计算,也就上面代码的第17-22行,分别用来初始化了多个头的权重值(这一过程从图3-8也可以看出)。当多头注意力机制计算完成后,将会得到一个形状为[src_len,embed_dim]
的矩阵,也就是图3-8中多个水平堆叠后的结果。因此,第24行代码将会初始化一个线性层来对这一结果进行一个线性变换。
3.3.3 定义前向传播过程
在定义完初始化函数后,便可以定义如下所示的多头注意力前向传播的过程
def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
"""
在论文中,编码时query, key, value 都是同一个输入,
解码时 输入的部分也都是同一个输入,
解码和编码交互时 key,value指的是 memory, query指的是tgt
:param query: # [tgt_len, batch_size, embed_dim], tgt_len 表示目标序列的长度
:param key: # [src_len, batch_size, embed_dim], src_len 表示源序列的长度
:param value: # [src_len, batch_size, embed_dim], src_len 表示源序列的长度
:param attn_mask: # [tgt_len,src_len] or [num_heads*batch_size,tgt_len, src_len]
一般只在解码时使用,为了并行一次喂入所有解码部分的输入,所以要用mask来进行掩盖当前时刻之后的位置信息
:param key_padding_mask: [batch_size, src_len], src_len 表示源序列的长度
:return:
attn_output: [tgt_len, batch_size, embed_dim]
attn_output_weights: # [batch_size, tgt_len, src_len]
"""
return multi_head_attention_forward(
query,
key,
value,
self.num_heads,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
attn_mask=attn_mask,
)
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
在上述代码中,query
、key
、value
指的并不是图3-6中的Q、K和V,而是没有经过线性变换前的输入。例如在编码时三者指的均是原始输入序列src
;在解码时的Mask Multi-Head Attention中三者指的均是目标输入序列tgt
;在解码时的Encoder-Decoder Attention中三者分别指的是Mask Multi-Head Attention的输出、Memory和Memory。key_padding_mask
指的是编码或解码部分,输入序列的Padding情况,形状为[batch_size,src_len]
或者[batch_size,tgt_len]
;attn_mask
指的就是注意力掩码矩阵,形状为[tgt_len,src_len]
,它只会在解码时使用。
注意,在上面的这些维度中,tgt_len
本质上指的其实是query_len
;src_len
本质上指的是key_len
。只是在不同情况下两者可能会是一样,也可能会是不一样。