前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用TensorFlow.js进行时间序列预测

使用TensorFlow.js进行时间序列预测

作者头像
代码医生工作室
发布2019-06-22 14:33:15
1.7K0
发布2019-06-22 14:33:15
举报
文章被收录于专栏:相约机器人相约机器人

机器学习现在越来越受欢迎,越来越多的世界人口认为它是一个神奇的水晶球:预测未来何时以及将会发生什么。该实验使用人工神经网络揭示股市趋势,并展示时间序列预测根据过去的历史数据预测未来股票价格的能力。

免责声明:由于多种因素,股票市场波动是动态且不可预测的,因此该实验是100%教育,绝不是交易预测工具。

浏览演示并在Github上查看源代码

https://github.com/lonedune/tfjs-stocks

项目演练

该项目演练分为4个部分:

  1. 从在线API获取股票数据
  2. 计算给定时间窗口的简单移动平均值
  3. 训练LSTM神经网络
  4. 预测并将预测值与实际值进行比较

获取股票数据

在训练神经网络并进行任何预测之前,首先需要数据。要查找的数据类型是时间序列:按时间顺序排列的数字序列。获取这些数据的好地方来自alphavantage.co。此API允许检索过去20年中特定公司股票价格的时间顺序数据。

https://www.alphavantage.co/

API会产生以下字段:

  • 开盘价
  • 当天的最高价
  • 当天最低价
  • 收盘价(在本项目中使用)
  • 体积

为神经网络准备训练数据集,将使用收盘股票价格。这也意味着目标是预测未来的收盘价。下图显示了微软公司每周20年的收盘价。

20年的微软公司每周从alphavantage.co收盘价格数据

简单移动平均线

对于这个实验,使用监督学习,这意味着将数据馈送到神经网络,并通过将输入数据映射到输出标签来学习。准备训练数据集的一种方法是从该时间序列数据中提取移动平均值。

简单移动平均线(SMA)是一种通过查看该时间窗内所有值的平均值来识别特定时间段的趋势方向的方法。通过实验选择时间窗口中的价格数量。

例如假设过去5天的收盘价是13,15,14,16,17,SMA将是(13 + 15 + 14 + 16 + 17)/ 5 = 15.所以训练的输入数据集是单个时间窗口内的价格集,其标签是这些价格的计算移动平均值。

计算一下微软公司每周收盘价数据的SMA,窗口大小为50。

代码语言:javascript
复制
function ComputeSMA(data, window_size){  let r_avgs = [], avg_prev = 0;  for (let i = 0; i <= data.length - window_size; i++){    let curr_avg = 0.00, t = i + window_size;    for (let k = i; k < t && k <= data.length; k++){      curr_avg += data[k]['price'] / window_size;    }    r_avgs.push({ set: data.slice(i, i + window_size), avg: curr_avg });    avg_prev = curr_avg;  }  return r_avgs;}

这就是得到的,每周股票收盘价为蓝色,SMA为橙色。因为SMA是50周的移动平均线,所以它比每周价格更平滑,每周价格可能会波动。

Microsoft Corporation的简单移动平均值收盘价格数据

训练数据

可以使用每周股票价格和计算出的SMA来准备训练数据。鉴于窗口大小为50,这意味着将使用连续50周的收盘价作为训练功能,并将这50周的SMA作为训练标签。看起来像......

代码语言:javascript
复制
"Row #", "Label (Y)", "Features (X)"1, 107.9674, "[127,135.25,138.25,149.19,158.13,157.5,155.13,84.75,82.75,82.37,81.81,87.81,93,89,92.12,92.12,89.62,85.75,89.44,85.56,84.81,86.25,85.75,94.69,104.44,107.25,113.19,117.94,113.81,109.94,105.87,104.25,110.62,105.25,96.62,104.25,105.37,113.06,104.12,96.87,105.06,106.37,105.87,109.31,110,113.62,128.06,127.37,134,137.81]"2, 108.2624, "[135.25,138.25,149.19,158.13,157.5,155.13,84.75,82.75,82.37,81.81,87.81,93,89,92.12,92.12,89.62,85.75,89.44,85.56,84.81,86.25,85.75,94.69,104.44,107.25,113.19,117.94,113.81,109.94,105.87,104.25,110.62,105.25,96.62,104.25,105.37,113.06,104.12,96.87,105.06,106.37,105.87,109.31,110,113.62,128.06,127.37,134,137.81,141.75]"3, 108.3312, "[138.25,149.19,158.13,157.5,155.13,84.75,82.75,82.37,81.81,87.81,93,89,92.12,92.12,89.62,85.75,89.44,85.56,84.81,86.25,85.75,94.69,104.44,107.25,113.19,117.94,113.81,109.94,105.87,104.25,110.62,105.25,96.62,104.25,105.37,113.06,104.12,96.87,105.06,106.37,105.87,109.31,110,113.62,128.06,127.37,134,137.81,141.75,138.69]"4, 108.5638, "[149.19,158.13,157.5,155.13,84.75,82.75,82.37,81.81,87.81,93,89,92.12,92.12,89.62,85.75,89.44,85.56,84.81,86.25,85.75,94.69,104.44,107.25,113.19,117.94,113.81,109.94,105.87,104.25,110.62,105.25,96.62,104.25,105.37,113.06,104.12,96.87,105.06,106.37,105.87,109.31,110,113.62,128.06,127.37,134,137.81,141.75,138.69,149.88]"5, 108.5750, "[158.13,157.5,155.13,84.75,82.75,82.37,81.81,87.81,93,89,92.12,92.12,89.62,85.75,89.44,85.56,84.81,86.25,85.75,94.69,104.44,107.25,113.19,117.94,113.81,109.94,105.87,104.25,110.62,105.25,96.62,104.25,105.37,113.06,104.12,96.87,105.06,106.37,105.87,109.31,110,113.62,128.06,127.37,134,137.81,141.75,138.69,149.88,149.75]"6, 108.5374, "[157.5,155.13,84.75,82.75,82.37,81.81,87.81,93,89,92.12,92.12,89.62,85.75,89.44,85.56,84.81,86.25,85.75,94.69,104.44,107.25,113.19,117.94,113.81,109.94,105.87,104.25,110.62,105.25,96.62,104.25,105.37,113.06,104.12,96.87,105.06,106.37,105.87,109.31,110,113.62,128.06,127.37,134,137.81,141.75,138.69,149.88,149.75,156.25]"7, 108.8874, "[155.13,84.75,82.75,82.37,81.81,87.81,93,89,92.12,92.12,89.62,85.75,89.44,85.56,84.81,86.25,85.75,94.69,104.44,107.25,113.19,117.94,113.81,109.94,105.87,104.25,110.62,105.25,96.62,104.25,105.37,113.06,104.12,96.87,105.06,106.37,105.87,109.31,110,113.62,128.06,127.37,134,137.81,141.75,138.69,149.88,149.75,156.25,175]"8, 108.9848, "[84.75,82.75,82.37,81.81,87.81,93,89,92.12,92.12,89.62,85.75,89.44,85.56,84.81,86.25,85.75,94.69,104.44,107.25,113.19,117.94,113.81,109.94,105.87,104.25,110.62,105.25,96.62,104.25,105.37,113.06,104.12,96.87,105.06,106.37,105.87,109.31,110,113.62,128.06,127.37,134,137.81,141.75,138.69,149.88,149.75,156.25,175,160]"9, 110.4448, "[82.75,82.37,81.81,87.81,93,89,92.12,92.12,89.62,85.75,89.44,85.56,84.81,86.25,85.75,94.69,104.44,107.25,113.19,117.94,113.81,109.94,105.87,104.25,110.62,105.25,96.62,104.25,105.37,113.06,104.12,96.87,105.06,106.37,105.87,109.31,110,113.62,128.06,127.37,134,137.81,141.75,138.69,149.88,149.75,156.25,175,160,157.75]"10, 111.7448, "[82.37,81.81,87.81,93,89,92.12,92.12,89.62,85.75,89.44,85.56,84.81,86.25,85.75,94.69,104.44,107.25,113.19,117.94,113.81,109.94,105.87,104.25,110.62,105.25,96.62,104.25,105.37,113.06,104.12,96.87,105.06,106.37,105.87,109.31,110,113.62,128.06,127.37,134,137.81,141.75,138.69,149.88,149.75,156.25,175,160,157.75,147.75]"

接下来将数据分为2组,即训练和验证集。如果70%的数据用于训练,则30%用于验证。API返回大约1000周的数据,因此700个用于训练,300个用于验证。

训练神经网络

现在训练数据准备好了,是时候为时间序列预测创建一个模型,为实现这个目的,将使用TensorFlow.js框架。TensorFlow.js是一个用JavaScript开发和训练机器学习模型的库,可以在Web浏览器中部署这些机器学习功能。

选择顺序模型,其简单地连接每个层并在训练过程中将数据从输入传递到输出。为了使模型学习顺序的时间序列数据,创建递归神经网络(RNN)层并且将多个LSTM单元添加到RNN。

该模型将使用Adam(研究论文)进行训练,这是一种流行的机器学习优化算法。均方根误差将决定预测值与实际值之间的差异,因此模型能够通过最小化训练过程中的误差来学习。

这是上述模型的代码片段。

代码语言:javascript
复制
async function trainModel(inputs, outputs, trainingsize, window_size, n_epochs, learning_rate, n_layers, callback){   const input_layer_shape  = window_size;  const input_layer_neurons = 100;   const rnn_input_layer_features = 10;  const rnn_input_layer_timesteps = input_layer_neurons / rnn_input_layer_features;   const rnn_input_shape  = [rnn_input_layer_features, rnn_input_layer_timesteps];  const rnn_output_neurons = 20;   const rnn_batch_size = window_size;   const output_layer_shape = rnn_output_neurons;  const output_layer_neurons = 1;   const model = tf.sequential();   let X = inputs.slice(0, Math.floor(trainingsize / 100 * inputs.length));  let Y = outputs.slice(0, Math.floor(trainingsize / 100 * outputs.length));   const xs = tf.tensor2d(X, [X.length, X[0].length]).div(tf.scalar(10));  const ys = tf.tensor2d(Y, [Y.length, 1]).reshape([Y.length, 1]).div(tf.scalar(10));   model.add(tf.layers.dense({units: input_layer_neurons, inputShape: [input_layer_shape]}));  model.add(tf.layers.reshape({targetShape: rnn_input_shape}));   let lstm_cells = [];  for (let index = 0; index < n_layers; index++) {       lstm_cells.push(tf.layers.lstmCell({units: rnn_output_neurons}));  }   model.add(tf.layers.rnn({    cell: lstm_cells,    inputShape: rnn_input_shape,    returnSequences: false  }));   model.add(tf.layers.dense({units: output_layer_neurons, inputShape: [output_layer_shape]}));   model.compile({    optimizer: tf.train.adam(learning_rate),    loss: 'meanSquaredError'  });   const hist = await model.fit(xs, ys,    { batchSize: rnn_batch_size, epochs: n_epochs, callbacks: {      onEpochEnd: async (epoch, log) => {        callback(epoch, log);      }    }  });   return { model: model, stats: hist };}

这些是可用于在前端进行调整的超参数(在训练过程中使用的参数):

  • 训练数据集大小(%):用于训练的数据量,剩余数据将用于验证
  • 时期:数据集用于训练模型的次数
  • 学习率:每个步骤中训练期间的权重变化量
  • 隐藏的LSTM图层:增加模型复杂度以在更高维空间学习

Web前端[ https://lonedune.github.io/tfjs-stocks/demo/ ],显示可用于调整的参数

单击Begin Training Model按钮...

训练模型UI

该模型似乎收敛于大约15个时代。

验证和预测

现在模型已经过训练,现在是时候用它来预测未来的值,它是移动平均线。实际上使用剩余的30%的数据进行预测,这能够看到预测值与实际值的接近程度。

绿线表示验证数据的预测

这意味着该模型看不到最后30%的数据,看起来该模型可以很好地绘制与移动平均线密切相关的数据。

结论

除了使用简单的移动平均线之外,还有很多方法可以进行时间序列预测。未来可能的工作是使用来自各种来源的更多数据来实现这一点。

使用TensorFlow.js,可以在Web浏览器上进行机器学习,这实际上非常酷。

在Github上探索演示,这个实验是100%教育,绝不是交易预测工具:

股票预测(TensorFlow.js)

https://lonedune.github.io/tfjs-stocks/demo/

在Github上查看源代码:

https://github.com/lonedune/tfjs-stocks

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-05-20,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 相约机器人 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档