TensorFlow的高阶接口Estimator的使用(1)

在《TensorFlow机器学习项目实战》的4.4节,作者使用了skflow。skflow刚出来的时候火了一阵,但是接口变化非常频繁,所以后来用的人也越来越少,也导致4.4的程序不能运行了。

但是最近发布的TensorFlow 1.4中,我们发现该模块已经集成到了核心模块,意味着接口基本稳定下来,并有推广使用的趋势。所以我把4.4的程序重新用Estimator写了一下,变量名基本保持不变,代码如下:

import tensorflow as tf

from sklearn import datasets, metrics, preprocessing

import numpy as np

import pandas as pd

import os

df = pd.read_csv("data/CHD.csv", header=0)

print( df.describe())

X=df['age'].astype(float)

classifier.train(inputfn=inputfn_train,steps=2000)

模型的准确度

score = classifier.evaluate(inputfn=inputfn_train,steps=50)["accuracy"]

print("Accuracy: %f" % score)

注:这段程序可以在Ubuntu和MacOS下面跑,但是Windows下面还不行,是路径的问题。这应该是Estimator的一个BUG,在contrib.learn下也是一样的不行。

这里面最难写的是input_fn函数,也是最重要的函数,我在这段程序中直接使用了numpy_input_fn来构建。[1]中除了这个方法还给出了从pandas构建的方法,大家可以自己尝试。

input_fn带来了一个好处,就是可以按照生产者消费者模式读取数据,具体的解释可以参考[2]。简单的解释,就是IO一般都比较慢,我们需要在数据处理的过程中进行读取数据,那样就可以充分的节省时间,这样就设计多线程在后台不断的取数据。

feature_colums的构建需要一定的技巧,这个主要参考[3]

另外的一个变化就是模型的准确度不再是用metric模块,而是Estimator自带的模块。

如果大家有什么问题欢迎留言。

TensorFlow机器学习项目实战

作者:【阿根廷】Rodolfo Bonnin

深度学习人工智能参考书

第二代机器学习实战指南

提供深度学习神经网络等项目实战

有效改善项目速度和效率

TensorFlow是Google所主导的机器学习框架,也是机器学习领域研究和应用的热门对象。

本书主要介绍如何使用TensorFlow库实现各种各样的模型,旨在降低学习门槛,并为读者解决问题提供详细的方法和指导。全书共10章,分别介绍了TensorFlow基础知识、聚类、线性回归、逻辑回归、不同的神经网络、规模化运行模型以及库的应用技巧。

本书适合想要学习和了解 TensorFlow 和机器学习的读者阅读参考。如果读者具备一定的C++和Python的经验,将能够更加轻松地阅读和学习本书。

  • 发表于:
  • 原文链接:http://kuaibao.qq.com/s/20171225A0444700?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。

扫码关注云+社区

领取腾讯云代金券