一、网上的资源
网上有不少用LSTM来预测时间序列的资源,如下面:
深度学习(08)_RNN-LSTM循环神经网络-03-Tensorflow进阶实现
Applying Deep Learning to Time Series Forecasting with TensorFlow
https://mapr.com/blog/deep-learning-tensorflow/
Tensorflow 笔记 RNN 预测时间序列
https://www.v2ex.com/t/339544
tf19: 预测铁路客运量
但是调试起来,都很困难!借鉴比较多的是tf19:预测铁路客运量这篇博文。这篇博文首先是基本上可以运行的。但是训练模型和测试模型分开,需要通过文件来传递模型参数。而且训练和测试不能同时运行。因此调试起来也费了不少功夫!
二、LSTM时间序列预测
1. 用namedtuple来配置模型的超参数。
HParams = namedtuple('HParams',
'seq_size, hidden_size, learning_rate')
这种方式比定义一个Config类好。
2. 构建时间序列预测模型类TS_LSTM
class TS_LSTM(object):
def __init__(self, hps):
self._X = X = tf.placeholder(tf.float32, [None, hps.seq_size, 1])
self._Y = Y = tf.placeholder(tf.float32, [None, hps.seq_size])
W = tf.Variable(tf.random_normal([hps.hidden_size, 1]), name='W')
b = tf.Variable(tf.random_normal([1]), name='b')
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(hps.hidden_size) #测试cost 1.3809
outputs, states = tf.nn.dynamic_rnn(lstm_cell, X, dtype=tf.float32)
W_repeated = tf.tile(tf.expand_dims(W, 0), [tf.shape(X)[0], 1, 1])
output = tf.nn.xw_plus_b(outputs, W_repeated, b)
self._output = output = tf.squeeze(output)
self._cost = cost = tf.reduce_mean(tf.square(output - Y))
self._train_op = tf.train.AdamOptimizer(hps.learning_rate).minimize(cost)
@property
def X(self):
return self._X
@property
def Y(self):
return self._Y
@property
def cost(self):
return self._cost
@property
def output(self):
return self._output
@property
def train_op(self):
return self._train_op
这种方式比用函数定义模型更加方便。@property的设计使得模型用起来更加方便!
模型的关键就是:
1). 设定BasicLSTMCell的隐藏节点个数
2). 调用dynamic_rnn(lstm_cell,X)来计算输出outputs
3). 调用xw_plus_b将outputs计算为单个的output
模型中各变量的维度如下:(batch_size=100, seq_size=3, hidden_size=6)
- X定义为[None, hps.seq_size, 1]是因为dynamic_rnn的输入针对的是二维图像样本的输入,因此,必须多定义一个1的维度,传入的实际应该为100*3*1。
- Y的维度维持与图像标签输入数据维度相同,传入的实际应该为100*3。
- W为6*1
- b为1*1
- outputs为100*3*6
- W_repeated为100*6*1,其变化过程6*1→1*6*1→100*6*1。
- output在squeeze之前为100*3*1,squeeze后为100*3
- cost为1*1
3. 训练和测试函数train_test
def train_test(hps, data):
#训练数据准备
train_data_len = len(data)*2//3
train_x, train_y = [], []
for i in range(train_data_len - hps.seq_size - 1):
train_x.append(np.expand_dims(data[i : i + hps.seq_size], axis=1).tolist())
train_y.append(data[i + 1 : i + hps.seq_size + 1].tolist())
#测试数据准备
test_data_len = len(data)//3
test_x, test_y = [], []
for i in range(train_data_len,
train_data_len+test_data_len - hps.seq_size - 1):
test_x.append(np.expand_dims(data[i : i + hps.seq_size], axis=1).tolist())
test_y.append(data[i + 1 : i + hps.seq_size + 1].tolist())
with tf.Graph().as_default(), tf.Session() as sess:
with tf.variable_scope('model',reuse=None):
m_train = TS_LSTM(hps)
#训练
tf.global_variables_initializer().run()
for step in range(20000):
_, train_cost = sess.run([m_train.train_op, m_train.cost],
feed_dict={m_train.X: train_x, m_train.Y: train_y})
#预测
test_cost, output = sess.run([m_train.cost, m_train.output],
feed_dict={m_train.X: test_x, m_train.Y: test_y})
#print(hps, train_cost, test_cost)
return train_cost, test_cost
这里的关键是测试用是训练模型,我也不知道为什么好多网络资源都将训练模型和测试模型分离开来。测试不就是用测试数据来测试训练模型的效果吗?因此这里把2/3的数据划给训练,1/3的数据用于测试。自己动手编代码时一定要对session.run函数用法和原理熟悉。
4. 主函数(对超参数组合的测试误差进行比较)
def main():
#读取原始数据
f=open('铁路客运量.csv')
df=pd.read_csv(f)
data = np.array(df['铁路客运量_当期值(万人)'])
normalized_data = (data - np.mean(data)) / np.std(data)
#测试不同组合的超参数对测试误差的影响
costs =[]
for seq_size in [4,6,12,16,24]:
for hidden_size in [6,10,20,30]:
print(seq_size, hidden_size)
hps = HParams(seq_size, hidden_size, 0.003)
train_cost, test_cost = train_test(hps, normalized_data)
costs.append([train_cost,test_cost])
进行了初步比较,感觉有两个:
1)同一个超参数,测试误差相差挺大。 2)不同超参数,训练时误差基本都很小,但是测试误差相差很大,如何限制学习过程中的过拟合是一个很大的问题。 可以看看我运行的训练误差和测试误差的比较。
训练误差 测试误差
[[ 4.04044241e-02 4.97651482e+00]
[ 3.57200466e-02 6.96304381e-01]
[ 2.97380015e-02 1.77482967e+01]
[ 3.09452992e-02 2.62166214e+00]
[ 3.62494551e-02 2.53422332e+00]
[ 2.57663596e-02 1.44900203e+00]
[ 2.24006996e-02 2.28607416e+00]
[ 2.28729844e-02 1.12727535e+00]
[ 2.58173030e-02 1.43265343e+00]
[ 1.48035632e-02 1.05281734e+00]
[ 1.24982912e-02 6.59598827e+00]
[ 1.27354050e-02 1.69984627e+00]
[ 1.60749555e-02 4.03962803e+00]
[ 1.18473349e-02 7.92685986e-01]
[ 7.39684049e-03 6.16959620e+00]
[ 7.60479691e-03 3.01771784e+00]
[ 1.40351299e-02 4.48093843e+00]
[ 7.94599950e-03 3.78614712e+00]
[ 5.50406286e-03 5.83478451e-01]
[ 4.54067113e-03 8.15259743e+00]]