最近做fairseq的实验,发现之前对transformer的理解还是不够深入,尤其是解码器的部分,所以我又重新深入地了解了一下transformer解码器的构造。比对fairseq后对框架的了解也更加深入了。
transformer和self-attention的的介绍我前面的文章已经说过了,这里就不重复提了。这篇文章只说明解码器部分的细节。
这个是transformer那边论文中对应的整体的图,可以看到解码器的的block和编码器的是不一样的。解码器的block由masked multi-head attention和一个encoder-decoder的attention组成。
其中masked multi-head attention部分用于将未来的信息mask掉,因为在生成的时候是无法知道未来的信息的。下图就是一个mask的例子,至于attention的计算,还是和之前的方式一致。这一部分可以看成模型在考虑之前的输出信息来计算当前应该输出什么。
Decoder的第二个部分是一个encoder和decoder的attention,这一部分可以看成解码器在用编码器的输出信息来计算当前解码应该输出什么。这一部分的区别在于Q,K,V的三个向量中。Q,V都是是解码器的属性,而K是encoder端最后输出的K。attention的计算方式和之前的计算方式都是一样的。通过这个方法编码器可以捕捉编码器的输出信息。这一部分fairseq中使用了static k,v来表示(fairseq的这两块代码是分开的),因为编码器输出后,在解码过程中,这些k,v是不会变的。
后面的前向网络和其余部分的就没什么区别了。具体解码过程就是decoder端先输入一个起始的token,然后通过self-attention层,再通过encoder-decoder attention层,再通过前向层给出输出。得到这个token的最后表示。然后通过一个线性层加softmax来预测输出词典中的哪个词。
解码器每一步都只解码一个词,输出单词后放入解码器的输入,重复上述操作直到解码到<eos>为止。到这里就完成了句子的输出。
上述就是基本的解码操作,其实解码的过程还有很多骚操作,比如beam search,teacher force, length normalize等,有空再更吧。
以上图片均来自于网络。