Understanding LSTM

Background

RNNs can model sequential data, but they are prone to vanishing/exploding gradients, so they perform poorly when the text is long. LSTM stands for Long Short-Term Memory; you can read it as a comparatively *long* Short-Term Memory, because when the text becomes very long its performance also degrades (this can be mitigated by adding an attention mechanism).

Introduction

Concrete network structure

The figure below is the clearest illustration of the LSTM cell I have seen so far; it walks through the whole computation with a concrete example. S means the unit uses a sigmoid activation, and T means it uses tanh. The input dimension is 3 and the hidden dimension is 2.

[Figure: lstm_cell]

Abstract network structure

The next figures are also quite good; they define the LSTM from an abstract point of view. It is instructive to compare them with the previous one: one is the abstraction, the other a concrete instance.

[Figure: lstm_cell_con]

[Figure: lstm_single_cell]

The key equations:

  1. Forget gate: $f_t = \sigma(W_f [h_{t-1}, x_t] + b_f)$
  2. Input gate: $i_t = \sigma(W_i [h_{t-1}, x_t] + b_i)$
  3. Candidate cell state: $\tilde{C}_t = \tanh(W_C [h_{t-1}, x_t] + b_C)$
  4. Cell state update: $C_t = f_t * C_{t-1} + i_t * \tilde{C}_t$
  5. Output gate: $o_t = \sigma(W_o [h_{t-1}, x_t] + b_o)$
  6. Hidden state: $h_t = o_t * \tanh(C_t)$
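
To connect the equations to code, here is a minimal NumPy sketch of a single LSTM step. The function and variable names are my own, chosen to mirror the symbols above rather than any library API:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W_f, b_f, W_i, b_i, W_C, b_C, W_o, b_o):
    # One LSTM step following equations 1-6 above.
    concat = np.concatenate([h_prev, x_t])       # [h_{t-1}, x_t]
    f_t = sigmoid(W_f @ concat + b_f)            # 1. forget gate
    i_t = sigmoid(W_i @ concat + b_i)            # 2. input gate
    c_tilde = np.tanh(W_C @ concat + b_C)        # 3. candidate cell state
    c_t = f_t * c_prev + i_t * c_tilde           # 4. cell state update
    o_t = sigmoid(W_o @ concat + b_o)            # 5. output gate
    h_t = o_t * np.tanh(c_t)                     # 6. hidden state
    return h_t, c_t

# Toy dimensions matching the figure: input dimension 3, hidden dimension 2.
rng = np.random.default_rng(0)
input_dim, hidden_dim = 3, 2
shape = (hidden_dim, hidden_dim + input_dim)
W_f, W_i, W_C, W_o = (rng.standard_normal(shape) * 0.1 for _ in range(4))
b_f = b_i = b_C = b_o = np.zeros(hidden_dim)

h, c = np.zeros(hidden_dim), np.zeros(hidden_dim)
h, c = lstm_step(rng.standard_normal(input_dim), h, c,
                 W_f, b_f, W_i, b_i, W_C, b_C, W_o, b_o)
print(h.shape, c.shape)   # (2,) (2,)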

The code

The following code constructs the network parameters. We can see that `_kernel` stores the model weights, and that it holds 4x the parameters, because there are four fully-connected (FC) transforms, one each for the forget gate, input gate, candidate input, and output gate.

def build(self, inputs_shape):
  if inputs_shape[-1] is None:
    raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                     % str(inputs_shape))

  input_depth = inputs_shape[-1]
  h_depth = self._num_units
  self._kernel = self.add_variable(
      _WEIGHTS_VARIABLE_NAME,
      shape=[input_depth + h_depth, 4 * self._num_units])
  self._bias = self.add_variable(
      _BIAS_VARIABLE_NAME,
      shape=[4 * self._num_units],
      initializer=init_ops.zeros_initializer(dtype=self.dtype))

  self.built = True
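
As a quick sanity check of the 4x layout, the sketch below builds a cell with the dimensions from the figure (input dimension 3, hidden dimension 2), so `_kernel` should be `[3 + 2, 4 * 2] = [5, 8]` and `_bias` should be `[8]`. This assumes the TensorFlow 1.x API, where this kind of cell is exposed as `tf.nn.rnn_cell.BasicLSTMCell`:

import tensorflow as tf  # assumes the TensorFlow 1.x API

batch_size, input_dim, num_units = 4, 3, 2
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
inputs = tf.placeholder(tf.float32, [batch_size, input_dim])
state = cell.zero_state(batch_size, tf.float32)
output, new_state = cell(inputs, state)  # calling the cell triggers build()

for v in cell.variables:
    print(v.name, v.shape)
# .../kernel: (5, 8)  ->  [input_depth + h_depth, 4 * num_units]
# .../bias:   (8,)    ->  [4 * num_units]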

The following code performs the actual computation. Note that each update returns two states, $h_t$ and $C_t$.

def call(self, inputs, state):
  """Long short-term memory cell (LSTM).

  Args:
    inputs: `2-D` tensor with shape `[batch_size, input_size]`.
    state: An `LSTMStateTuple` of state tensors, each shaped
      `[batch_size, num_units]`, if `state_is_tuple` has been set to
      `True`.  Otherwise, a `Tensor` shaped
      `[batch_size, 2 * num_units]`.

  Returns:
    A pair containing the new hidden state, and the new state (either a
      `LSTMStateTuple` or a concatenated state, depending on
      `state_is_tuple`).
  """
  sigmoid = math_ops.sigmoid
  one = constant_op.constant(1, dtype=dtypes.int32)
  # Parameters of gates are concatenated into one multiply for efficiency.
  if self._state_is_tuple:
    c, h = state
  else:
    c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)

  gate_inputs = math_ops.matmul(
      array_ops.concat([inputs, h], 1), self._kernel)
  gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)

  # i = input_gate, j = new_input, f = forget_gate, o = output_gate
  i, j, f, o = array_ops.split(
      value=gate_inputs, num_or_size_splits=4, axis=one)

  forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
  # Note that using `add` and `multiply` instead of `+` and `*` gives a
  # performance improvement. So using those at the cost of readability.
  add = math_ops.add
  multiply = math_ops.multiply
  new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),
              multiply(sigmoid(i), self._activation(j)))
  new_h = multiply(self._activation(new_c), sigmoid(o))

  if self._state_is_tuple:
    new_state = LSTMStateTuple(new_c, new_h)
  else:
    new_state = array_ops.concat([new_c, new_h], 1)
  return new_h, new_state
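
In practice the cell is rarely stepped by hand: it is unrolled over a whole sequence, and the `(new_h, new_state)` pair returned above is what gets threaded from one time step to the next. A minimal sketch, again assuming the TensorFlow 1.x `tf.nn.dynamic_rnn` API:

import tensorflow as tf  # assumes the TensorFlow 1.x API

cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=2)
# inputs: [batch_size, max_time, input_dim]
inputs = tf.placeholder(tf.float32, [None, 10, 3])
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
# outputs:     [batch_size, max_time, num_units] -- h_t at every time step
# final_state: LSTMStateTuple(c=..., h=...)      -- (C_t, h_t) after the last step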

References

  1. https://towardsdatascience.com/animated-rnn-lstm-and-gru-ef124d06cf45
  2. http://colah.github.io/posts/2015-08-Understanding-LSTMs/