简述
上一篇简单概述了下Relation Classification via Convolutional Deep Neural Network(2014)的论文内容,这一篇简单地阅读学习下此篇论文的复现代码(来自FrankWork from github)。
项目结构主要如下:
- base_model.py: 为模型设置保存、加载路径(ckpt)。
- cnn_model.py:主要层(nonlinear、CNN)以及整体模型的实现。
- train.py:参数及超参数设置,训练及测试实现。
- base.py: 数据预处理及训练集、测试集的加载。
下面主要关注前三个文件:
base_model.py
class BaseModel(object):
@classmethod
def set_saver(cls, save_dir):
'''
Args:
save_dir: relative path to FLAGS.logdir
'''
# shared between train and valid model instance
# saver
cls.saver = tf.train.Saver(var_list=None)
# 保存路径
cls.save_dir = os.path.join(FLAGS.logdir, save_dir)
# 保存文件ckpt
cls.save_path = os.path.join(cls.save_dir, "model.ckpt")
@classmethod
def restore(cls, session):
# 加载模型
ckpt = tf.train.get_checkpoint_state(cls.save_dir)
cls.saver.restore(session, ckpt.model_checkpoint_path)
@classmethod
def save(cls, session, global_step):
# 保存模型
cls.saver.save(session, cls.save_path, global_step)
这段比较简单,设置下模型训练中基本的保存和加载路径等信息,之后BaseModel类方便CNNModel继承。
cnn_model.py
def linear_layer(name, x, in_size, out_size, is_regularize=False):
# 非线性全连接层
with tf.variable_scope(name):
# L2正则化
loss_l2 = tf.constant(0, dtype=tf.float32)
w = tf.get_variable('linear_W', [in_size, out_size],
initializer=tf.truncated_normal_initializer(stddev=0.1))
b = tf.get_variable('linear_b', [out_size],
initializer=tf.constant_initializer(0.1))
o = tf.nn.xw_plus_b(x, w, b) # batch_size, out_size
if is_regularize:
# 对W和b均使用正则化
loss_l2 += tf.nn.l2_loss(w) + tf.nn.l2_loss(b)
return o, loss_l2
这个函数为非线性全连接层的实现,可以看到,作者简单地对W和b应用了L2正则化。这样使用正则化在较小的模型中没有问题,如果模型复杂的话,不方便这样使用,会让loss部分代码冗长。
def cnn_forward(name, sent_pos, lexical, num_filters):
# 文本CNN
with tf.variable_scope(name):
# [batch, seq_length, emb_size+2*pos_emb, 1]
input = tf.expand_dims(sent_pos, axis=-1)
input_dim = input.shape.as_list()[2]
# convolutional layer
pool_outputs = []
# 3种size的卷积核
for filter_size in [3,4,5]:
with tf.variable_scope('conv-%s' % filter_size):
# [ filter_size, emb_size, in_channel, num_filters ]
conv_weight = tf.get_variable('W1',
[filter_size, input_dim, 1, num_filters],
initializer=tf.truncated_normal_initializer(stddev=0.1))
# [num_filters]
conv_bias = tf.get_variable('b1', [num_filters],
initializer=tf.constant_initializer(0.1))
# SAME考虑边界,用 0 填充
conv = tf.nn.conv2d(input,
conv_weight,
strides=[1, 1, input_dim, 1],
padding='SAME')
# [batch_size, len, 1, num_filters]
conv = tf.nn.relu(conv + conv_bias)
max_len = FLAGS.max_len
# [batch_size, 1, 1, num_filters],去除与长度相关的这一维
pool = tf.nn.max_pool(conv,
ksize= [1, max_len, 1, 1],# 池化窗口大小
strides=[1, max_len, 1, 1], #每一维度滑动步长
padding='SAME') # batch_size, 1, 1, num_filters
pool_outputs.append(pool)
# [batch, 3*num_filters]
pools = tf.reshape(tf.concat(pool_outputs, 3), [-1, 3*num_filters])
# feature
feature = pools
if lexical is not None:
# [batch, 6*emb_size + 3*num_filters]
feature = tf.concat([lexical, feature], axis=1)
return feature
这段是此模型中sentence level的部分的关键一层,代码作者用了3、4、5三个filter_size的各filter_num个卷积核来做普通的文本卷积,然后接了个max池化层,最后和lexical level的词向量(多个词向量concat后)进行concat。需要注意的是,这里的max polling对应的维度是seq_length这一维,论文中这样实现是为了尽量减小文本长度不一的影响。最后输出的tensor的形状是 [batch, 6emb_size + 3num_filters]。
class CNNModel(BaseModel):
'''
Relation Classification via Convolutional Deep Neural Network
http://www.aclweb.org/anthology/C14-1220
'''
def __init__(self, word_embed, data, word_dim,
pos_num, pos_dim, num_relations,
keep_prob, num_filters,
lrn_rate, is_train):
# input data
lexical, rid, sentence, pos1, pos2 = data
# embedding initialization
w_trainable = True if FLAGS.word_dim==50 else False
# emb_table词向量表
word_embed = tf.get_variable('word_embed',
initializer=word_embed,
dtype=tf.float32,
trainable=w_trainable)
# [pos_num, pos_emb]
pos1_embed = tf.get_variable('pos1_embed', shape=[pos_num, pos_dim])
pos2_embed = tf.get_variable('pos2_embed', shape=[pos_num, pos_dim])
# # embedding lookup
# 词级别,[batch, 6] -> [batch, 6, emb_size]
lexical = tf.nn.embedding_lookup(word_embed, lexical) # batch_size, 6, word_dim
lexical = tf.reshape(lexical, [-1, 6*word_dim]) # [batch, 6*emb_size]
self.labels = tf.one_hot(rid, num_relations) # batch_size, num_relations
# [batch, seq_length] ->[batch, seq_length, emb_size]
sentence = tf.nn.embedding_lookup(word_embed, sentence) # batch_size, max_len, word_dim
pos1 = tf.nn.embedding_lookup(pos1_embed, pos1) # batch_size, max_len, pos_dim
pos2 = tf.nn.embedding_lookup(pos2_embed, pos2) # batch_size, max_len, pos_dim
# cnn model
sent_pos = tf.concat([sentence, pos1, pos2], axis=2) # [batch, seq_length, emb_size+2*pos_emb]
if is_train:
# 句级别
sent_pos = tf.nn.dropout(sent_pos, keep_prob)
# 词级别+句级别, [batch, 6*emb_size+3*num_filters]
feature = cnn_forward('cnn', sent_pos, lexical, num_filters)
feature_size = feature.shape.as_list()[1]
self.feature = feature
if is_train:
feature = tf.nn.dropout(feature, keep_prob)
# Map the features to n classes
# 非线性层
logits, loss_l2 = linear_layer('linear_cnn', feature,
feature_size, num_relations,
is_regularize=True)
prediction = tf.nn.softmax(logits) # 0<prebs<1
prediction = tf.argmax(prediction, axis=1)
accuracy = tf.equal(prediction, tf.argmax(self.labels, axis=1))
accuracy = tf.reduce_mean(tf.cast(accuracy, tf.float32))
# 平均交叉熵
loss_ce = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(labels=self.labels, logits=logits))
self.logits = logits
self.prediction = prediction
self.accuracy = accuracy
# 正则化(只包含全连接层的w和b)
self.loss = loss_ce + 0.01*loss_l2
if not is_train:
return
# global_step = tf.train.get_or_create_global_step()
global_step = tf.Variable(0, trainable=False, name='step', dtype=tf.int32)
optimizer = tf.train.AdamOptimizer(lrn_rate)
# for Batch norm
# 为了更新 moving_mean和moving_variance
# tf.GraphKeys.UPDATE_OPS,
# 这是一个tensorflow的计算图中内置的一个集合,
# 其中会保存一些需要在训练操作之前完成的操作,
# 并配合tf.control_dependencies函数使用。
# 关于在batch_norm中,即为更新mean和variance的操作
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# tf.control_dependencies,该函数保证其辖域中的操作必须要在该函数所传递的参数中的操作完成后再进行
with tf.control_dependencies(update_ops):
self.train_op = optimizer.minimize(self.loss, global_step)
self.global_step = global_step
这一部分就是把之前写好的nonlinear层和CNN层按照论文思路组装好,CNNModel类继承了之前提到的BaseModel类(保存加载等信息),需要注意的是作者最后这里提到了使用Batch Norm,上面的注释写的很清楚了,注意想要正确使用BN的话,要利用好tf.get_collection(tf.GraphKeys.UPDATE_OPS)和tf.control_dependencies这两个API。这里的loss和L2正则化这样做加法没有问题,但当网络复杂了后会不怎么方便,推荐下述使用方法使用正则化。
def build_train_valid_model(word_embed, train_data, test_data):
'''Relation Classification via Convolutional Deep Neural Network'''
# 实际调用CNN_model
with tf.name_scope("Train"):
with tf.variable_scope('CNNModel', reuse=None):
m_train = CNNModel( word_embed, train_data, FLAGS.word_dim,
FLAGS.pos_num, FLAGS.pos_dim, FLAGS.num_relations,
FLAGS.keep_prob, FLAGS.num_filters,
FLAGS.lrn_rate, is_train=True)
with tf.name_scope('Valid'):
with tf.variable_scope('CNNModel', reuse=True):
m_valid = CNNModel( word_embed, test_data, FLAGS.word_dim,
FLAGS.pos_num, FLAGS.pos_dim, FLAGS.num_relations,
1.0, FLAGS.num_filters,
FLAGS.lrn_rate, is_train=False)
return m_train, m_valid
这段函数没什么好说的,返回两个CNNModel的实例(训练用和验证用)。
正则化
当网络复杂的时候定义网络的结构部分和计算损失函数的部分可能不在一个函数中,这样通过上面那种简单的变量这种计算损失函数就不方便了。此时可以使用Tensorflow中提供的集合,它可以在一个计算图(tf.Graph)中保存一组实体(比如张量)。
def get_weight(shape, lambda1):
# 生成一个变量
var = tf.Variable(tf.random_normal(shape), dtype=tf.float32)
# add_to_collection()函数将新生成变量的L2正则化损失加入集合losses
# lambda1为正则化乘上的系数
tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(lambda1)(var))
return var # 返回生成的变量
.......
weight = get_weight([in_dimension, out_dimension], 0.003)
.......
# 在定义神经网络前向传播的同时已经将所有的L2正则化损失加入了图上的集合,这里是损失函数的定义。
mean_loss = tf.reduce_mean(loss)
# 将均方误差损失函数加入损失集合
tf.add_to_collection('losses', mean_loss)
# get_collection()返回一个列表,这个列表是所有这个集合中的元素,
# 在本样例中这些元素就是损失函数的不同部分,将他们加起来就是最终的损失函数
loss = tf.add_n(tf.get_collection('losses'))
.......
# 用优化器最小化加了L2正则的loss即可
train_op = tf.train.AdamOptimizer(0.001).minimize(loss)
train.py
flags = tf.app.flags
flags.DEFINE_string("train_file", "data/train.cln",
"original training file")
flags.DEFINE_string("test_file", "data/test.cln",
"original test file")
flags.DEFINE_string("vocab_file", "data/vocab.txt",
"vocab of train and test data")
flags.DEFINE_string("google_embed300_file",
"data/embed300.google.npy",
"google news word embeddding")
flags.DEFINE_string("google_words_file",
"data/google_words.lst",
"google words list")
flags.DEFINE_string("trimmed_embed300_file",
"data/embed300.trim.npy",
"trimmed google embedding")
flags.DEFINE_string("senna_embed50_file",
"data/embed50.senna.npy",
"senna words embeddding")
flags.DEFINE_string("senna_words_file",
"data/senna_words.lst",
"senna words list")
flags.DEFINE_string("trimmed_embed50_file",
"data/embed50.trim.npy",
"trimmed senna embedding")
flags.DEFINE_string("train_record", "data/train.tfrecord",
"training file of TFRecord format")
flags.DEFINE_string("test_record", "data/test.tfrecord",
"Test file of TFRecord format")
flags.DEFINE_string("relations_file", "data/relations.txt", "relations file")
flags.DEFINE_string("results_file", "data/results.txt", "predicted results file")
flags.DEFINE_string("logdir", "saved_models/", "where to save the model")
flags.DEFINE_integer("max_len", 96, "max length of sentences")
flags.DEFINE_integer("num_relations", 19, "number of relations")
flags.DEFINE_integer("word_dim", 50, "word embedding size")
flags.DEFINE_integer("num_epochs", 200, "number of epochs")
flags.DEFINE_integer("batch_size", 100, "batch size")
flags.DEFINE_integer("pos_num", 123, "number of position feature")
flags.DEFINE_integer("pos_dim", 5, "position embedding size")
flags.DEFINE_integer("num_filters", 100, "cnn number of output unit")
flags.DEFINE_float("lrn_rate", 1e-3, "learning rate")
flags.DEFINE_float("keep_prob", 0.5, "dropout keep probability")
flags.DEFINE_boolean('test', False, 'set True to test')
flags.DEFINE_boolean('trace', False, 'set True to test')
FLAGS = tf.app.flags.FLAGS
def trace_runtime(sess, m_train):
'''
trace runtime bottleneck using timeline api
navigate to the URL 'chrome://tracing' in a Chrome web browser,
click the 'Load' button and locate the timeline file.
'''
# 运行时记录运行信息的protocolmessage
run_metadata=tf.RunMetadata()
options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
from tensorflow.python.client import timeline
trace_file = open('timeline.ctf.json', 'w')
fetches = [m_train.train_op, m_train.loss, m_train.accuracy]
_, loss, acc = sess.run(fetches,
options=options,
run_metadata=run_metadata)
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
trace_file.write(trace.generate_chrome_trace_format())
trace_file.close()
def train(sess, m_train, m_valid):
n = 1
best = .0
best_step = n
start_time = time.time()
orig_begin_time = start_time
fetches = [m_train.train_op, m_train.loss, m_train.accuracy]
while True:
try:
_, loss, acc = sess.run(fetches)
epoch = n // 80
if n % 80 == 0:
now = time.time()
duration = now - start_time
start_time = now
v_acc = sess.run(m_valid.accuracy)
if best < v_acc:
best = v_acc
best_step = n
m_train.save(sess, best_step)
print("Epoch %d, loss %.2f, acc %.2f %.4f, time %.2f" %
(epoch, loss, acc, v_acc, duration))
sys.stdout.flush()
n += 1
except tf.errors.OutOfRangeError:
break
duration = time.time() - orig_begin_time
duration /= 3600
print('Done training, best_step: %d, best_acc: %.4f' % (best_step, best))
print('duration: %.2f hours' % duration)
sys.stdout.flush()
def test(sess, m_valid):
m_valid.restore(sess)
fetches = [m_valid.accuracy, m_valid.prediction]
accuracy, predictions = sess.run(fetches)
print('accuracy: %.4f' % accuracy)
base_reader.write_results(predictions, FLAGS.relations_file, FLAGS.results_file)
这段代码主要用flag设置了本模型的一些参数和超参等信息,并且设置了三个函数来调用之前写好的训练和验证用的模型类实例。
def main(_):
with tf.Graph().as_default():
train_data, test_data, word_embed = base_reader.inputs()
m_train, m_valid = cnn_model.build_train_valid_model(word_embed,
train_data, test_data)
m_train.set_saver('cnn-%d-%d' % (FLAGS.num_epochs, FLAGS.word_dim))
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())# for file queue
# GPU config
config = tf.ConfigProto()
# config.gpu_options.per_process_gpu_memory_fraction = 0.9 # 占用GPU90%的显存
config.gpu_options.allow_growth = True
# sv finalize the graph
with tf.Session(config=config) as sess:
sess.run(init_op)
print('='*80)
if FLAGS.trace:
trace_runtime(sess, m_train)
elif FLAGS.test:
test(sess, m_valid)
else:
train(sess, m_train, m_valid)
if __name__ == '__main__':
# 如果你的代码中的入口函数不叫main(),而是一个其他名字的函数,
# 如test(),则你应该这样写入口tf.app.run(test)
# 通过处理flag解析,然后执行main函数
tf.app.run()
这里主要用了tf.app.run()这个API来训练。
dataset的使用(待补充)