使用Tensorflow进行时序预测(TFTS)

时序预测是一个经典的话题,应用面也很广; 结合LSTM来做也是一个效果比较好的方式. 这次准备使用TF来进行时序预测,计划写两篇: 1. 使用Tensorflow Time Series模块 2. 使用底层点的LSTM Cell

这就是第一篇啦,Time Series Prediction via TFTS. 来源于此: TFTS介绍,略有加工整理,侵删!

TFTS

Tensorflow Time Series(TFTS)模块是TF1.3版本中引入的,官方是这么介绍的:

TensorFlow Time Series (TFTS) is a collection of ready-to-use classic models (state space, autoregressive), and flexible infrastructure for building high-performance time series models with custom architectures.

地址: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/timeseries, 里面给出了相关的examples.

主要提供三种预测模型: AR、Anomaly Mixture AR、LSTM

Examples

读入数据

你的数据可以是两种: 1. numpy array 2. from a CSV file

对于第一种: TFTS中可以使用NumpyReader

x = np.array(range(1000))
noise = np.random.uniform(-0.2, 0.2, 1000)
y = np.sin(np.pi * x / 100) + x / 200. + noise

data = {
    tf.contrib.timeseries.TrainEvalFeatures.TIMES: x,
    tf.contrib.timeseries.TrainEvalFeatures.VALUES: y,
}

reader = NumpyReader(data)

data是一个dict,’TIMES’和’VALUES’就是字符串的’times’和’values’,所以理论上你写成: data = {'times':x, 'values':y},也是可以的.

对于第二种: TFTS中提供了CSVReader

csv_file_name = './data/period_trend.csv'
reader = tf.contrib.timeseries.CSVReader(csv_file_name)

获得batch数据

对于以上两种Reader(CSVReader和NumpyReader),TFTS提供了对应的read_full()方法,返回的是时间序列的Tensor,read_full()产生读取队列,所以需要启动queue_runner然后才能run出值.

with tf.Session() as sess:
    data = reader.read_full()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print(sess.run(data))
    coord.request_stop()

获取batch数据也很简单: TFTS提供RandomWindowInputFn

train_input_fn = tf.contrib.timeseries.RandomWindowInputFn(reader, batch_size=4, window_size=16)

tf.contrib.timeseries.RandomWindowInputFn会在reader的所有数据中,随机选取窗口长度为window_size的序列,并包装成batch_size大小的batch数据。换句话说,一个batch内共有batch_size个序列,每个序列的长度为window_size.

AR Model

AR(Auto Regression)是统计学上的方法,可以参考wiki: https://en.wikipedia.org/wiki/Autoregressive_model,主要的思想是假设当前值与前面出现的值是线性关系. autoregressive model specifies that the output variable depends linearly on its own previous values and on a stochastic term (an imperfectly predictable term).

对于AR模型:TFTS提供了ARRegrerssior

ar = tf.contrib.timeseries.ARRegressor(
        periodicities=200, input_window_size=30, output_window_size=10,
        num_features=1,
        loss=tf.contrib.timeseries.ARModel.NORMAL_LIKELIHOOD_LOSS)

在这里,我们总的window_size为40,input_window_size为30,output_window_size为10,也就是说,一个batch内每个序列的长度为40,其中前30个数被当作模型的输入值,后面10个数为这些输入对应的目标输出值。最后一个参数loss指定采取哪一种损失,一共有两种损失可以选择,分别是NORMAL_LIKELIHOOD_LOSS和SQUARED_LOSS. num_features参数表示在一个时间点上观察到的数的维度。我们这里每一步都是一个单独的值,所以num_features=1。 还有一个比较重要的参数是model_dir,它表示模型训练好后保存的地址,如果不指定的话,就会随机分配一个临时地址.

训练、验证(对训练集进行)、测试:

ar.train(input_fn=train_input_fn, steps=1000)

evaluation_input_fn = tf.contrib.timeseries.WholeDatasetInputFn(reader)
# keys of evaluation: ['covariance', 'loss', 'mean', 'observed', 'start_tuple', 'times', 'global_step']
evaluation = ar.evaluate(input_fn=evaluation_input_fn, steps=1)

(predictions,) = tuple(ar.predict(
        input_fn=tf.contrib.timeseries.predict_continuation_input_fn(
            evaluation, steps=250)))  #预测之后的250步

结果大概是这样:

红色是预测的那一段.

LSTM

必须使用TF最新的开发版的代码,就是要保证’rom tensorflow.contrib.timeseries.python.timeseries.estimators import TimeSeriesRegressor’可以导入成功.

estimator = ts_estimators.TimeSeriesRegressor(
      model=_LSTMModel(num_features=1, num_units=128),
      optimizer=tf.train.AdamOptimizer(0.001))

_LSTMModel是一个class,可以直接copy官方给的代码.

接下来的训练、验证、测试和上面的AR模型是一样的.

结果大概是这样的:

自己只把关键点写出来了,有需要的可以去看原文,原文比较详细. 代码地址: https://github.com/hzy46/TensorFlow-Time-Series-Examples

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏SnailTyan

Caffe源码解析(一) —— caffe.proto

caffe.proto是caffe数据结构定义的主要文件,本文主要是在caffe.proto代码的基础上加上了部分中文注释,其中的内容与caffe的protot...

4495
来自专栏强仔仔

SpringBoot中实现邮件找回密码的功能

今天给大家介绍一下很常用的一个功能,就是邮件找回密码功能。找回密码一般会有:1.邮件找回密码、2短信找回密码、3问题找会密码。 关于邮件找回密码的原理思想为: ...

2388
来自专栏mathor

从暴力递归到动态规划

 动态规划没有那么难,但是很多老师在讲课的过程中讲的并不好,由此写下一篇文章记录学习过程

751
来自专栏wym

卡特兰数

      简介:卡特兰数又称卡塔兰数,英文名Catalan number,是组合数学中一个常出现在各种计数问题中出现的数列。由以比利时的数学家欧仁·查理·卡塔...

1340
来自专栏企鹅号快讯

PaddlePaddle之手写数字识别

作者:Charlotte77数学系的数据挖掘民工 博客专栏:http://www.cnblogs.com/charlotte77/ 个人公众号:Charlott...

2048
来自专栏智能计算时代

45测试深度学习基础知识的数据科学家的问题(以及解决方案)

原文:https://www.analyticsvidhya.com/blog/2017/01/must-know-questions-deep-learnin...

3266
来自专栏Python小屋

Python文本处理2个小案例(文本嗅探与关键词占比统计)

问题描述:有一些句子和一些关键词,现在想找出包含至少一个关键词的那些句子(文本嗅探),可以参考print('='*30)之前的代码。如果想进一步计算每个句子中的...

33711
来自专栏何俊林

H.264技术及原理

H.264组成 1、网络提取层 (Network Abstraction Layer,NAL) 2、视讯编码层 (Video Coding Layer,VCL)...

1889
来自专栏瓜大三哥

多任务验证码识别

使用Alexnet网络进行训练,多任务学习:验证码是根据随机字符生成一幅图片,然后在图片中加入干扰象素,用户必须手动填入,防止有人利用机器人自动批量注册、灌水、...

4937
来自专栏深度学习入门与实践

【深度学习系列】PaddlePaddle之手写数字识别

上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下。不过呢,这块...

2849

扫码关注云+社区