理解RNN

背景

Rnn全名Recurrent Neural Network,对时序进行了建模,能够记忆之前看到的信息。

举例来说:订一张从北京到上海的机票 vs 订一张从上海到北京的机票,如果直接用bow模型是无法区分这两个文本的,而使用rnn可以刻画出时序关系。

介绍

模型结构

rnn有两个输入,一个是真实输入$X^t$,一个是上一层的输入$H^{t-1}$,上一层的输入能够对上一个cell的状态进行记忆。模型将$X^t$和$H^{t-1}$进行拼接后,通过一个ffc(W)得到下一层的$H^t$,同时再通过一个FFC得到输出$Y^t$。

具体公式:

$H^t = \sigma{(W^{hh}h^{t - 1}+ W^{hx}X^t + b)}$

$Y^t = softmax(W^{S}h^t)$

具体流程如下:

rnn

相关代码

计算$H^t$相关代码,将inputs 和 state拼接后,和_kernel矩阵相乘,加bias和激活函数,得到最终的输出。

1
2
3
4
5
6
7
8
def call(self, inputs, state):
"""Most basic RNN: output = new_state = act(W * input + U * state + B)."""

gate_inputs = math_ops.matmul(
array_ops.concat([inputs, state], 1), self._kernel)
gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
output = self._activation(gate_inputs)
return output, output

kernel相关定义:

1
2
3
self._kernel = self.add_variable(
_WEIGHTS_VARIABLE_NAME,
shape=[input_depth + self._num_units, self._num_units])

注意

  1. 所有的rnn cell的权重是共享的。可以想象成输入是不动的,使rnn_cell在输入上不断的滑动。
  2. rnn容易出现梯度消失和梯度爆炸,可以考虑使用lstm。
  3. 无法并行,可以考虑使用transformer。

参考

  1. https://towardsdatascience.com/animated-rnn-lstm-and-gru-ef124d06cf45
  2. https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/rnn_cell_impl.py
------ 本文结束 ------
k