前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用LSTM预测正弦曲线

使用LSTM预测正弦曲线

作者头像
用户6021899
发布2020-02-19 11:45:39
2K0
发布2020-02-19 11:45:39
举报

之前介绍过用LSTM预测天气的例子,该例子中数据集的处理和曲线绘制函数稍微有点复杂。这篇我们使用标准正弦函数做数据集,让代码更简单,来加深我们对LSTM的理解。

首先导入必要的库,并对matplotlib 库做些设置使之能正确显示中文:

代码语言:javascript
复制
# -*- coding: utf-8 -*-
"""
Created on Sun Jan  5 21:08:46 2020
@author: Administrator
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

准备数据集:

(数据量太大,仅显示最后若干个周期)

代码语言:javascript
复制
def univariate_data(dataset, start_index, end_index, history_size, target_size):
  #一段连续数据做data(长度为history_size),紧邻的一个数据做label
  #连续滚动,我们就得到了一系列数据和相应的labels
  data = []
  labels = []
  start_index = start_index + history_size
  if end_index is None:
    end_index = len(dataset) - target_size
  for i in range(start_index, end_index):
    indices = range(i-history_size, i)
    # Reshape data from (history_size,) to (history_size, 1)
    data.append(np.reshape(dataset[indices], (history_size, 1)))
    labels.append(dataset[i+target_size])
  return np.array(data), np.array(labels)
  
periods = 1000
points_pp = 16 #sine曲线一个周期取多少个点
X = np.arange(0,periods*2*np.pi-2.0*np.pi/points_pp,2.0*np.pi/points_pp)
Y = np.sin(X)
uni_data = Y
#头12k条数据作为训练集,剩下的4k作为验证集
TRAIN_SPLIT = 12000
#数据标准化(减去均值,再除以标准差)
uni_train_mean = uni_data[:TRAIN_SPLIT].mean()
uni_train_std = uni_data[:TRAIN_SPLIT].std()
uni_data = (uni_data-uni_train_mean)/uni_train_std
univariate_past_history = 48 #用48个历史数据点
univariate_future_target = 16 #预测接下来的16个数据点
x_train_uni, y_train_uni = univariate_data(uni_data, 0, TRAIN_SPLIT,
                                           univariate_past_history,
                                           univariate_future_target)
x_val_uni, y_val_uni = univariate_data(uni_data, TRAIN_SPLIT, None,
                                       univariate_past_history,
                                       univariate_future_target)
BATCH_SIZE = 128 # 128 段 数据
BUFFER_SIZE = 1000
#训练集
tf.random.set_seed(666)
train_univariate = tf.data.Dataset.from_tensor_slices((x_train_uni, y_train_uni))
train_univariate = train_univariate.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()#打乱训练集
#验证集
val_univariate = tf.data.Dataset.from_tensor_slices((x_val_uni, y_val_uni))
val_univariate = val_univariate.batch(BATCH_SIZE).repeat()

创建LSTM模型,并拟合/训练模型:

代码语言:javascript
复制
#创建一个简单的LSTM网络模型
simple_lstm_model = tf.keras.models.Sequential([
    tf.keras.layers.LSTM(units=8, input_shape=x_train_uni.shape[-2:],activation="tanh"),#units:输出空间的维度
    tf.keras.layers.Dense(1)
])
simple_lstm_model.compile(optimizer='adam', loss='mae')#模型编译,设定优化器和损失类型

#因为数据集很大,为了节省时间,每个EPOCH仅跑400步,没有跑完所有训练数据
EVALUATION_INTERVAL = 400
EPOCHS = 10
simple_lstm_model.fit(train_univariate, epochs=EPOCHS,
                      steps_per_epoch=EVALUATION_INTERVAL,
                      validation_data=val_univariate, validation_steps=50)

利用训练好的模型做预测,绘制最后的历史数据并预测未来:

代码语言:javascript
复制
plt.plot(X[-univariate_past_history:],Y[-univariate_past_history:],marker ="o",label ="最后的历史值")
X1 = np.arange(periods*2*np.pi,(periods+1)*2*np.pi - np.pi/8.0 ,np.pi/8.0)
#Y1 = np.sin(X1)
data0 = Y[-univariate_past_history:].copy()
data1 = data0.reshape((1,univariate_past_history,1))
predicts = []
for i in range(univariate_future_target):  
    predict = simple_lstm_model.predict(data1)
    predict = float(predict)
    predicts.append(predict)
    
    #依次将最新的预测值(单个点)作为添加到用于预测的数据的末端,首端弹出最旧的值
    data1[0, 0:-1] = data1[0, 1:]
    data1[0, -1] = predict
plt.plot(X1,predicts,linestyle="--",marker="o",label ="预测值(未来)")
plt.legend(loc="upper right")
plt.title("LSTM sine曲线 预测",fontsize =18)
plt.xlabel('Time')

我们可以看到,预测的数据点很好的反映了正弦曲线的变化趋势。

注意,除了首个预测点以外,对其它点进行预测时,除了用到历史数据外,也会用到一些预测值,所以预测多个点时,误差会积累 (图中预测的幅值大过1)。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-01-31,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Python可视化编程机器学习OpenCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档