Tensorflow学习:使用Tensorflow搭建深层网络分类器

根据官方文档整理而来的,主要是对Iris数据集进行分类。使用tf.contrib.learn.tf.contrib.learn快速搭建一个深层网络分类器,

步骤

  1. 导入csv数据
  2. 搭建网络分类器
  3. 训练网络
  4. 计算测试集正确率
  5. 对新样本进行分类

数据

Iris数据集包含150行数据,有三种不同的Iris品种分类。每一行数据给出了四个特征信息和一个分类信息。

现在已经将数据分为训练集和测试集

- A training set of 120 samples

http://download.tensorflow.org/data/iris_training.csv

- A test set of 30 samples

http://download.tensorflow.org/data/iris_test.csv

网络搭建

1. 首先,导入tensorflow 和 numpy

2. 导入数据

# 定义数据地址

IRIS_TRAINING = "iris_training.csv"

IRIS_TEST = "iris_test.csv"

# 导入数据

training_set = tf.contrib.learn.datasets.base.load_csv_with_header(

filename=IRIS_TRAINING,

target_dtype=np.int,

features_dtype=np.float32)

test_set = tf.contrib.learn.datasets.base.load_csv_with_header(

filename=IRIS_TEST,

target_dtype=np.int,

features_dtype=np.float32)

load_csv_with_header() 有三个参数 - filename, 数据地址 - target_dtype, 目标值的numpy datatype(iris的目标值是0,1,2,所以是np.int) - features_dtype, 特征值的numpy datatype .

3. 搭建网络结构

# 每行数据4个特征,都是real-value的

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

# 3层DNN,3分类问题

classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,

hidden_units=[10, 20, 10],

n_classes=3,

model_dir="/tmp/iris_model")

参数解释 - feature_columns 特征值 - hidden_units=[10, 20, 10]. 3个隐藏层,包含的隐藏神经元依次是10, 20, 10 - n_classes 类别个数 - model_dir 模型保存地址

4. 训练数据

classifier.fit(x=training_set.data, y=training_set.target, steps=2000)

steps 为训练次数

5. 计算准确率

accuracy_score = classifier.evaluate(x=test_set.data, y=test_set.target)["accuracy"]

print('Accuracy: {0:f}'.format(accuracy_score))

运行结果是

6. 对新样本进行预测

# Classify two new flower samples.

new_samples = np.array(

[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)

y = list(classifier.predict(new_samples, as_iterable=True))

print('Predictions: {}'.format(str(y)))

运行结果为:

完整代码

from __future__ import absolute_import

from __future__ import division

from __future__ import print_function

import tensorflow as tf

import numpy as np

IRIS_TRAINING = "iris_training.csv"

IRIS_TEST = "iris_test.csv"

training_set = tf.contrib.learn.datasets.base.load_csv_with_header(

filename=IRIS_TRAINING,

target_dtype=np.int,

features_dtype=np.float32)

test_set = tf.contrib.learn.datasets.base.load_csv_with_header(

filename=IRIS_TEST,

target_dtype=np.int,

features_dtype=np.float32)

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,

hidden_units=[10, 20, 10],

n_classes=3,

model_dir="/tmp/iris_model")

classifier.fit(x=training_set.data,

y=training_set.target,

steps=2000)

accuracy_score = classifier.evaluate(x=test_set.data,

y=test_set.target)["accuracy"]

print('Accuracy: {0:f}'.format(accuracy_score))

new_samples = np.array(

[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)

y = list(classifier.predict(new_samples, as_iterable=True))

print('Predictions: {}'.format(str(y)))

本文分享自微信公众号 - 大数据挖掘DT数据分析(datadw)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2017-07-05

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏人人都是极客

OpenCV和SVM分类器在自动驾驶中的车辆检测

这次文章的车辆检测在车辆感知模块中是非常重要的功能,本节课我们的目标如下:

899100
来自专栏北京马哥教育

很吓人的技术,200行Python代码做一个换脸程序

在这篇文章中将介绍如何写一个简短(200行)的 Python 脚本,来自动地将一幅图片的脸替换为另一幅图片的脸。

11200
来自专栏AI科技大本营的专栏

AI 技术讲座精选:​通过学习Keras从零开始实现VGG网络

Keras代码示例多达数百个。通常我们只需复制粘贴代码,而无需真正理解这些代码。通过学习本教程,您将搭建非常简单的构架,但是此过程会带给您些许好处:您将通过阅读...

40590
来自专栏贾志刚-OpenCV学堂

干货 | OpenCV中KLT光流跟踪原理详解与代码演示

在视频移动对象跟踪中,稀疏光流跟踪是一种经典的对象跟踪算法,可以绘制运动对象的跟踪轨迹与运行方向,是一种简单、实时高效的跟踪算法,这个算法最早是有Bruce D...

1.4K20
来自专栏人人都是极客

OpenCV和SVM分类器在自动驾驶中的车辆检测

这次文章的车辆检测在车辆感知模块中是非常重要的功能,本节课我们的目标如下: 在标记的图像训练集上进行面向梯度的直方图(HOG)特征提取并训练分类器线性SVM分类...

1.4K70
来自专栏专知

【专知-Deeplearning4j深度学习教程03】使用多层神经网络分类MNIST数据集:图文+代码

【导读】主题链路知识是我们专知的核心功能之一,为用户提供AI领域系统性的知识学习服务,一站式学习人工智能的知识,包含人工智能( 机器学习、自然语言处理、计算机视...

499110
来自专栏磐创AI技术团队的专栏

FaceRank-人脸打分基于 TensorFlow 的 CNN 模型,这个妹子颜值几分?

FaceRank-人脸打分基于 TensorFlow 的 CNN 模型 机器学习是不是很无聊,用来用去都是识别字体。能不能帮我找到颜值高的妹子,顺便提高一下姿势...

47340
来自专栏人工智能LeadAI

C++实现神经网络之一 | Net类的设计和神经网络的初始化

闲言少叙,直接开始 既然是要用C++来实现,那么我们自然而然的想到设计一个神经网络类来表示神经网络,这里我称之为Net类。由于这个类名太过普遍,很有可能跟其他人...

34850
来自专栏贾志刚-OpenCV学堂

OpenCV直线拟合检测

OpenCV直线拟合检测 霍夫直线检测容易受到线段形状与噪声的干扰而失真,这个时候我们需要另辟蹊径,通过对图像进行二值分析,提取骨架,对骨架像素点拟合生成直线,...

2.4K50
来自专栏人工智能

Net类的设计和神经网络的初始化

? NVIDIA 深度学习学院 带你快速进入火热的DL领域 正文共4898个字,2张图,预计阅读时间28分钟。 闲言少叙,直接开始 既然是要用C++来实现,那...

21060

扫码关注云+社区

领取腾讯云代金券