织梦CMS - 轻松建站从此开始!

罗索

完全解析RNN, Seq2Seq, Attention注意力机制

jackyhwei 发布于 2020-10-22 13:51 点击:次 
循环神经网络RNN结构被广泛应用于自然语言处理、机器翻译、语音识别、文字识别等方向。本文主要介绍经典的RNN结构,以及RNN的变种(包括Seq2Seq结构和Attention机制)。希望这篇文章能够帮助
TAG: RNN  Seq2Seq  Attention  

循环神经网络RNN结构被广泛应用于自然语言处理、机器翻译、语音识别、文字识别等方向。本文主要介绍经典的RNN结构,以及RNN的变种(包括Seq2Seq结构和Attention机制)。希望这篇文章能够帮助初学者更好地入门。

经典的RNN结构

图1

 

这就是最经典的RNN结构,它的输入是:

 

 

 

 

输出为:

 

 

 

 

也就是说,输入和输出序列必有相同的时间长度!

 

 

图2

 

 

假设输入 [公式] ( [公式] ) 是一个长度为 [公式] ( [公式] ) 的列向量:

 

 

 

 

隐藏层 [公式] 是一个长度为 [公式] ( [公式] ) 的列向量:

 

 

 

 

输出 [公式] 是一个长度为 [公式] ( [公式] ) 的列向量:

 

 

 

 

其中 [公式][公式][公式] 都是由人工设定的。

 

 

图3

 

 

  • [公式] 时刻输入层--> [公式] 时刻隐藏层:

 

 

 

 

  • [公式] 时刻隐藏层--> [公式] 时刻隐藏层:

 

 

 

 

  • [公式] 时刻输入层 and [公式] 时刻隐藏层--> [公式] 时刻隐藏层:

 

 

 

 

  • [公式] 时刻隐藏层--> [公式] 时刻输出层:

 

 

 

 

需要注意的是,对于任意时刻 [公式] ,所有的权值(包括 [公式] , [公式] , [公式] , [公式] , [公式] , [公式] )都相等,这也就是RNN中的“权值共享”,极大的减少参数量。

其实RNN可以简单的表示为:

 

 

图4

 

 

还有一个小细节:在 [公式] 时刻,如果没有特别指定初始状态,一般都会使用全0的 [公式] 作为初始状态输入到 [公式]

 

 

 

 

Sequence to Sequence模型

 

 

图5

 

 

在Seq2Seq结构中,编码器Encoder把所有的输入序列都编码成一个统一的语义向量Context,然后再由解码器Decoder解码。在解码器Decoder解码的过程中,不断地将前一个时刻 [公式] 的输出作为后一个时刻 [公式] 的输入,循环解码,直到输出停止符为止。

图6

 

接下来以机器翻译为例,看看如何通过Seq2Seq结构把中文“早上好”翻译成英文“Good morning”:

  1. 将“早上好”通过Encoder编码,并将最后 [公式] 时刻的隐藏层状态 [公式] 作为语义向量。
  2. 以语义向量为Decoder的 [公式] 状态,同时在 [公式] 时刻输入<start>特殊标识符,开始解码。之后不断的将前一时刻输出作为下一时刻输入进行解码,直接输出<stop>特殊标识符结束。

当然,上述过程只是Seq2Seq结构的一种经典实现方式。与经典RNN结构不同的是,Seq2Seq结构不再要求输入和输出序列有相同的时间长度!

 

 

 

 

 

图7

 

 

 

 

进一步来看上面机器翻译例子Decoder端的 [公式] 时刻数据流,如图7:

  • 首先对RNN输入大小为 [公式] 的向量 [公式] (红点);
  • 然后经过RNN输出大小为 [公式] 的向量 [公式] (蓝点);
  • 接着使用全连接fc将 [公式] 变为大小为 [公式] 的向量 [公式] ,其中 [公式] 代表类别数量;
  • [公式] 经过softmax和argmax获取类别index,再经过int2str获取输出字符;
  • 最后将类别index输入到下一状态,直到接收到<stop>标志符停止。

Embedding

还有一点细节,就是如何将前一时刻输出类别index(数值)送入下一时刻输入(向量)进行解码。假设每个标签对应的类别index如下:

'<start>' : 0, '<stop>' : 1, 'good' : 2, 'morning' : 3, ...

已知<start>标志符index为0,如果需要将<start>标志符输入到input层,就需要把类别index=0转变为一个 [公式] 长度的特定对应向量。这时就需要应用嵌入 (embedding) 方法。

 

 

 

 

图8 嵌入 (embedding)

 

 

 

假设有 [公式] 个词,最简单的方法就是使用 [公式] 长度的one-hot编码,词表alphabet如下:

  1. '<start>' : 0  <-----> label('<start>')=[1, 0, 0, 0, 0,..., 0] 
  2. '<stop>' :  1  <-----> label('<stop>') =[0, 1, 0, 0, 0,..., 0]  
  3. 'hello':    2  <-----> label('hello')  =[0, 0, 1, 0, 0,..., 0]  
  4. 'good' :    3  <-----> label('good')   =[0, 0, 0, 1, 0,..., 0]  
  5. 'morning' : 4  <-----> label('morning')=[0, 0, 0, 0, 1,..., 0]  
  6. ....... 

但是使用one-hot编码进行嵌入过于稀疏,所以我们使用一种更加优雅的办法:

  • 首先随机生成一个大小为 [公式] embedding随机矩阵:

 

 

 


 

 

 

  • 然后通过start标志的one-hot编码乘以embedding矩阵(即获取embedding矩阵的第 [公式] 行),作为start标志对应的输入向量送入网络:

 

 

 


 

 

 

  • [公式] 时刻网络输入 [公式] 后输出了good字符,那么要在 [公式] 时刻再把good字符的one-hot编码乘以embedding矩阵获取 [公式]

 

 


 

 

  • 同理 [公式] 再把上一时刻输出的morning字符的one-hot编码乘以embedding获取新的 [公式]

tf.nn.embedding_lookup

而在pytorch中通过以下接口实现:

torch.nn.Embedding

需要注意的是:train和test阶段必须使用一样的embedding矩阵!否则输出肯定是乱码。

当然,还可以使用word2vec/glove/elmo/bert等更加“精致”的嵌入方法,也可以在训练过程中迭代更新embedding。这些内容超出本文范围,不再详述。embedding入门请参考:

快速入门词嵌入之word2vec

Seq2Seq训练问题

值得一提的是,在seq2seq结构中将 [公式] 作为下一时刻输入 [公式] 进网络,那么某一时刻输出 [公式] 错误就会导致后面全错。在训练时由于网络尚未收敛,这种蝴蝶效应格外明显。

图9

为了解决这个问题,Google提出了大名鼎鼎的Scheduled Sampling(即在训练中 [公式] 按照一定概率选择输入 [公式][公式] 时刻对应的真实值,即标签,如图10),既能加快训练速度,也能提高训练精度。

图10

Scheduled Sampling对应文章如下:

Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks arxiv.org

 

Attention注意力机制

 

 

 

图11

 

 

 

在Seq2Seq结构中,encoder把所有的输入序列都编码成一个统一的语义向量Context,然后再由Decoder解码。由于context包含原始序列中的所有信息,它的长度就成了限制模型性能的瓶颈。如机器翻译问题,当要翻译的句子较长时,一个Context可能存不下那么多信息,就会造成精度的下降。除此之外,如果按照上述方式实现,只用到了编码器的最后一个隐藏层状态,信息利用率低下。

所以如果要改进Seq2Seq结构,最好的切入角度就是:利用Encoder所有隐藏层状态 [公式] 解决Context长度限制问题。

接下来了解一下attention注意力机制基本思路(Luong Attention)

 

 

 

图12

 

 

 

考虑这样一个问题:由于Encoder的隐藏层状态 [公式] 代表对不同时刻输入 [公式] 的编码结果:

 

 

 

 

 

 

即Encoder状态 [公式][公式][公式] 对应编码器对“早”,“上”,“好”三个中文字符的编码结果。那么在Decoder时刻 [公式] 通过3个权重 [公式][公式][公式] 计算出一个向量 [公式]

 

 

 

 

 

 

然后将这个向量与前一个状态拼接在一起形成一个新的向量输入到隐藏层计算结果:

 

 

 

 

 

 

Decoder时刻 [公式]

 

 

 

 

 

 

Decoder时刻 [公式][公式] 同理,就可以解决Context长度限制问题。由于 [公式][公式][公式] 不同,就形成了一种对编码器不同输入 [公式] 对应 [公式] 的“注意力”机制(权重越大注意力越强)。

那么到底什么是LuongAttention注意力机制?

 

 

 

图13
Effective Approaches to Attention-based Neural Machine Translation arxiv.org

 

 

 

为了说明具体结构,重新定义符号: [公式] 代表Encoder状态, [公式] 代表Decoder状态, [公式] 代表Attention Layer输出的最终Decoder状态,如图13。需要说明, [公式][公式][公式] 大小的向量。接下来一起看看注意力机制具体实现方式。

  • 首先,计算Decoder的 [公式] 时刻隐藏层状态 [公式] 对Encoder每一个隐藏层状态 [公式] 权重 [公式] 数值:

[公式]

这里的 [公式] 可以通过以下三种方式计算:

 

 

 

 

 

 

所谓Dot就是向量内积,而General通过乘以 [公式] 权重矩阵进行计算( [公式][公式] 大小的矩阵)。一般经验General方法好于Dot方法,Concat方法略去不讲。

  • 其次,利用权重 [公式] 计算所有隐藏层状态 [公式] 加权之和 [公式] ,即生成新的大小为 [公式] 的Context状态向量:

 

 

 

 

 

 

  • 接下来,将通过权重 [公式] 生成的 [公式] 与原始Decoder隐藏层 [公式] 时刻状态 [公式] 拼接在一起:

 

 

 

 

 

 

这里 [公式][公式] 大小都是[公式] ,拼接后会变大。由于需要恢复为原来形状,所以乘以全连接 [公式] 矩阵。当然不恢复也可以,但是会造成Decoder RNN cell变大。

  • 最后,对加入“注意力”的Decoder状态 [公式] 乘以 [公式] 矩阵即可获得输出:

 

 

 

 

 

 

也可以根据需要,把新生成的状态 [公式] 继续送入RNN继续进行学习。其中 [公式][公式] 参数需要通过学习获得。

 

 

 

图14

 

 

 

在实际应用中当输入一组 [公式] ,除了可以获得输出 [公式] ,还能提取出 [公式][公式] 对应的权重数值 [公式] 并画出来,如图15,这样就可以直观的看到时刻 [公式] 注意力机制到底“注意”了什么。

 

 

 

图15 注意力机制中的权重

 

 

 

可以看到,整个Attention注意力机制相当于在Seq2Seq结构上加了一层“包装”,内部通过函数 [公式] 计算注意力向量 [公式],从而给Decoder RNN加入额外信息,以提高性能。无论在机器翻译,语音识别,自然语言处理(NLP),文字识别(OCR),Attention机制对Seq2Seq结构都有很大的提升。

如何向RNN加入额外信息

Attention机制其实就是将的Encoder RNN隐藏层状态加权后获得权重向量 [公式] ,额外加入到Decoder中,给Decoder RNN网络添加额外信息,从而使得网络有更完整的信息流。

 

 

 

图16 RNN添加额外信息的3中方式

 

 

 

所以,假设有额外信息 [公式] (如上文中的注意力向量 [公式] ),给RNN网络添加额外信息主要有以下3种方式:

  • ADD:直接将 [公式] 叠加在输出 [公式] 上。

 

 

 

 

 

 

  • CONCAT:将 [公式] 拼接在隐藏层 [公式] 后全连接恢复维度(不恢复维度也可以,但是会造成参数量加倍)。上篇文章中的LuongAttention机制就使用此种方法。

 

 

 

 

 

 

  • MLP:新添加一个对 [公式] 的感知单元 [公式]

 

 

 

 

 

 

 

特别说明:上文介绍的LuongAttention仅仅是注意力机制的一种具体实现,不代表Attention仅此一种。事实上Seq2Seq+Attention还有很多很玩法。望读者了解!

(白裳)
本站文章除注明转载外,均为本站原创或编译欢迎任何形式的转载,但请务必注明出处,尊重他人劳动,同学习共成长。转载请注明:文章转载自:罗索实验室 [http://www.rosoo.net/a/202010/17755.html]
本文出处:zhihu 作者:白裳 原文
顶一下
(0)
0%
踩一下
(0)
0%
------分隔线----------------------------
相关文章
发表评论
请自觉遵守互联网相关的政策法规,严禁发布色情、暴力、反动的言论。
评价:
表情:
用户名: 验证码:点击我更换图片