import tensorflow as tf


def my_lstm_layer(input_representation,
                  lstm_dim=100,
                  input_lengths=None,
                  scope_name="my_rnn",
                  reuse=False,
                  is_training=True,
                  dropout_rate=0.2,
                  use_cudnn=False):
    '''Bidirectional LSTM layer (TF 1.x).

    input_representation: [batch_size, seq_len, feature_dim]
    input_lengths: [batch_size] (may be None)
    Returns (forward_representation, backward_representation, concatenated outputs).
    '''
    if is_training:
        # TF 1.x tf.nn.dropout takes keep_prob, hence (1 - dropout_rate).
        input_representation = tf.nn.dropout(input_representation, (1 - dropout_rate))
    with tf.variable_scope(scope_name, reuse=reuse):
        if use_cudnn:
            # CudnnLSTM expects time-major input: [seq_len, batch_size, feature_dim].
            inputs = tf.transpose(input_representation, [1, 0, 2])
            lstm = tf.contrib.cudnn_rnn.CudnnLSTM(1, lstm_dim, direction="bidirectional",
                                                  name="{}_cudnn_bi_lstm".format(scope_name),
                                                  dropout=dropout_rate if is_training else 0)
            outputs, _ = lstm(inputs)
            # Back to batch-major: [batch_size, seq_len, 2 * lstm_dim].
            outputs = tf.transpose(outputs, [1, 0, 2])
            # Note: this path does not mask padded steps with input_lengths.
            forward_representation = outputs[:, :, 0:lstm_dim]
            backward_representation = outputs[:, :, lstm_dim:2 * lstm_dim]
        else:
            context_lstm_cell_fw = tf.nn.rnn_cell.BasicLSTMCell(lstm_dim)
            context_lstm_cell_bw = tf.nn.rnn_cell.BasicLSTMCell(lstm_dim)
            if is_training:
                context_lstm_cell_fw = tf.nn.rnn_cell.DropoutWrapper(context_lstm_cell_fw,
                                                                     output_keep_prob=(1 - dropout_rate))
                context_lstm_cell_bw = tf.nn.rnn_cell.DropoutWrapper(context_lstm_cell_bw,
                                                                     output_keep_prob=(1 - dropout_rate))
            context_lstm_cell_fw = tf.nn.rnn_cell.MultiRNNCell([context_lstm_cell_fw])
            context_lstm_cell_bw = tf.nn.rnn_cell.MultiRNNCell([context_lstm_cell_bw])

            # Unlike the CuDNN path, bidirectional_dynamic_rnn masks steps beyond input_lengths.
            (forward_representation, backward_representation), _ = tf.nn.bidirectional_dynamic_rnn(
                context_lstm_cell_fw, context_lstm_cell_bw, input_representation, dtype=tf.float32,
                sequence_length=input_lengths)
            # [batch_size, seq_len, 2 * lstm_dim]
            outputs = tf.concat(axis=2, values=[forward_representation, backward_representation])
    return (forward_representation, backward_representation, outputs)
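

# A minimal usage sketch, assuming TF 1.x graph mode. The placeholder names,
# shapes, and scope name below are illustrative assumptions, not part of the
# original layer.
if __name__ == "__main__":
    seq_len, feature_dim = 20, 300
    sentence_repr = tf.placeholder(tf.float32, [None, seq_len, feature_dim])
    sentence_lengths = tf.placeholder(tf.int32, [None])

    fw, bw, bi_outputs = my_lstm_layer(sentence_repr,
                                       lstm_dim=100,
                                       input_lengths=sentence_lengths,
                                       scope_name="context_rnn",
                                       is_training=True,
                                       dropout_rate=0.2,
                                       use_cudnn=False)
    # fw, bw: [batch_size, seq_len, 100]; bi_outputs: [batch_size, seq_len, 200]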