#2021SC@SDUSC
这部分将分析rnn_cell.py和rnn_layers.py部分
RNNCell 表示 .NestedMap
中的循环状态。 zero_state(theta, batch_size)
返回初始状态,由每个子类定义。 从状态中,每个子类都定义了GetOutput()
来提取输出张量。 RNNCell.FProp
定义了前向函数:: (theta, state0, 输入) -> state1, extras 所有参数和返回值都是.NestedMap
。 每个子类都定义了这些 .NestedMap
应该具有的字段。 extras
是一个 .NestedMap
,其中包含一些 FProp
计算以促进反向传播的中间结果。 zero_state(theta, batch_size)
、state0
和 state1
都是兼容的 .NestedMap
(参见 .NestedMap.IsCompatible
)。 即,它们递归地具有相同的键。 此外,这些 .NestedMap
中相应的张量具有相同的形状和数据类型。
@classmethod def Params(cls): p = super().Params() p.Define('inputs_arity', 1, 'number of tensors expected for the inputs.act to FProp.') p.Define('num_input_nodes', 0, 'Number of input nodes.') p.Define( 'num_output_nodes', 0, 'Number of output nodes. If num_hidden_nodes is 0, also used as ' 'cell size.') p.Define( 'reset_cell_state', False, ('Set True to support resetting cell state in scenarios where multiple ' 'inputs are packed into a single training example. The RNN layer ' 'should provide reset_mask inputs in addition to act and padding if ' 'this flag is set.')) p.Define( 'zero_state_init_params', py_utils.DefaultRNNCellStateInit(), 'Parameters that define how the initial state values are set ' 'for each cell. Must be one of the static functions defined in ' 'py_utils.RNNCellStateInit.') return p
def FProp(self, theta, state0, inputs):
这里的默认实现假设 cell forward 函数由两个函数组成: _Gates(_Mix(theta,state0,inputs),theta,state0,inputs) _Mix
的结果存放在 extras
中以方便反向传播。 如果 reset_cell_state
为 True,则可选地应用 _ResetState
。除了其他输入之外,RNN 层还应提供“reset_mask”输入。 reset_mask
输入在运行 _Mix()
和 _Gates()
之前应该被重置为默认值(零)的时间步长为 0,否则为 1。这是为了支持诸如打包输入之类的用例,其中在单个输入示例序列中输入多个样本,并且需要相互屏蔽。例如,如果打包在一起的两个例子是 [‘good’, ‘day’] -> [‘guten-tag’] 和 [‘thanks’] -> [‘danke’] 产生 [‘good’, 'day ', ‘thanks’] -> [‘guten-tag’, ‘danke’],源 reset_mask 将为 [1, 1, 0],目标重置掩码将为 [1, 0]。这些 id 旨在为彼此不同的示例启用屏蔽计算。 参数: theta:一个.NestedMap
对象,包含该层及其子层的权重值。 state0:之前的循环状态。一个.NestedMap
。 输入:单元格的输入。一个.NestedMap
。 返回: 元组 (state1, extras)。 - state1:下一个循环状态。一个.NestedMap
。 - 附加:中间结果以促进反向传播。一个.NestedMap
。
assert isinstance(inputs.act, list) assert self.params.inputs_arity == len(inputs.act) if self.params.reset_cell_state: state0_modified = self._ResetState(state0.DeepCopy(), inputs) else: state0_modified = state0 xmw = self._Mix(theta, state0_modified, inputs) state1 = self._Gates(xmw, theta, state0_modified, inputs) return state1, py_utils.NestedMap()
def _GetBias(self, theta): 获取要添加的偏置向量。
包括forget_gate_bias 之类的调整。 直接使用 this 而不是 ‘b’ 变量,因为以这种方式包含调整允许 const-prop 在推理时消除调整。 参数: theta:一个.NestedMap
对象,包含该层及其子层的权重值。 返回: 偏置向量。
p = self.params if p.enable_lstm_bias: b = theta.b else: b = tf.zeros([self.num_gates * self.hidden_size], dtype=p.dtype) if p.forget_gate_bias != 0.0: # Apply the forget gate bias directly to the bias vector. if not p.couple_input_forget_gates: # Normal 4 gate bias (i_i, i_g, f_g, o_g). adjustment = ( tf.ones([4, self.hidden_size], dtype=p.dtype) * tf.expand_dims( tf.constant([0., 0., p.forget_gate_bias, 0.], dtype=p.dtype), axis=1)) else: # 3 gates with coupled input/forget (i_i, f_g, o_g). adjustment = ( tf.ones([3, self.hidden_size], dtype=p.dtype) * tf.expand_dims( tf.constant([0., p.forget_gate_bias, 0.], dtype=p.dtype), axis=1)) adjustment = tf.reshape(adjustment, [self.num_gates * self.hidden_size]) b = b + adjustment return b
函数GeneratePackedInputResetMask从 segment_id 生成 RNN 单元的掩码输入。
参数:
segment_id:形状为 [time, batch_size, 1] 的张量。
is_reverse:如果输入以相反的顺序馈送到 RNN,则为真。
返回:
reset_mask - 形状为 [time, batch_size, 1] 的张量。 对于样本设置为 0
需要重置状态的地方(在示例边界处),否则为 1。
segment_id_left = segment_id[:-1] segment_id_right = segment_id[1:] # Mask is a [t-1, bs, 1] tensor. reset_mask = tf.cast( tf.equal(segment_id_left, segment_id_right), dtype=segment_id.dtype) mask_padding_shape = tf.concat( [tf.ones([1], dtype=tf.int32), tf.shape(segment_id)[1:]], axis=0) mask_padding = tf.ones(mask_padding_shape, dtype=segment_id.dtype) if is_reverse: reset_mask = tf.concat([reset_mask, mask_padding], axis=0) else: reset_mask = tf.concat([mask_padding, reset_mask], axis=0) return reset_mask
Class RNN:
静态展开的RNN
形参:
def Params(cls): p = super().Params() p.Define('cell', rnn_cell.LSTMCellSimple.Params(), 'Configs for the RNN cell.') p.Define( 'sequence_length', 0, 'Sequence length to unroll. If > 0, then will unroll to this fixed ' 'size. If 0, then will unroll to accommodate the size of the inputs ' 'for each call to FProp.') p.Define('reverse', False, 'Whether or not to unroll the sequence in reversed order.') p.Define('packed_input', False, 'To reset states for packed inputs.') return p
初始化:
def __init__(self, params): super().__init__(params) p = self.params assert not p.packed_input, ('Packed inputs are currently not supported by ' 'Static RNN') p.cell.reset_cell_state = p.packed_input assert p.sequence_length >= 0 self.CreateChild('cell', p.cell)
函数FProp计算 RNN 前向传递。
参数: theta:一个.NestedMap
对象,包含该层及其子层的权重值。 输入:单个张量或基数等于的张量元组 rnn_cell.inputs_arity。 对于每个输入张量,假设第一维是时间、第二维批次和第三维深度。 填充:张量。 第一个暗淡是时间,第二个暗淡是批次,第三个暗淡预计为 1。 state0:如果不是 None,则为 .NestedMap
中的初始 rnn 状态。 默认为单元格的零状态。 返回: [时间、batch、dim]的张量。 最终的循环状态。
p = self.params assert isinstance(self.cell, rnn_cell.RNNCell) if p.sequence_length == 0: if isinstance(inputs, (tuple, list)): sequence_length = len(inputs) else: sequence_length = py_utils.GetShape(inputs)[0] else: sequence_length = p.sequence_length assert sequence_length >= 1, ('Sequence length must be defined or inputs ' 'must have fixed shapes.') with tf.name_scope(p.name): inputs_sequence = tf.unstack(inputs, num=sequence_length) paddings_sequence = tf.unstack(paddings, num=sequence_length) # We start from all 0 states. if state0: state = state0 else: inputs0 = py_utils.NestedMap( act=[inputs_sequence[0]], padding=paddings_sequence[0]) state = self.cell.zero_state(theta.cell, self.cell.batch_size(inputs0)) outputs = [None] * sequence_length if p.reverse: sequence = range(sequence_length - 1, -1, -1) else: sequence = range(0, sequence_length, 1) for idx in sequence: cur_input = py_utils.NestedMap(act=[inputs[idx]], padding=paddings[idx]) state, _ = self.cell.FProp(theta.cell, state, cur_input) outputs[idx] = self.cell.GetOutput(state) return tf.stack(outputs), state