前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Tensorflow学习:使用Tensorflow搭建深层网络分类器

Tensorflow学习:使用Tensorflow搭建深层网络分类器

作者头像
机器学习AI算法工程
发布2018-03-15 10:57:29
6810
发布2018-03-15 10:57:29
举报

根据官方文档整理而来的,主要是对Iris数据集进行分类。使用tf.contrib.learn.tf.contrib.learn快速搭建一个深层网络分类器,

步骤

  1. 导入csv数据
  2. 搭建网络分类器
  3. 训练网络
  4. 计算测试集正确率
  5. 对新样本进行分类

数据

Iris数据集包含150行数据,有三种不同的Iris品种分类。每一行数据给出了四个特征信息和一个分类信息。

现在已经将数据分为训练集和测试集

- A training set of 120 samples

http://download.tensorflow.org/data/iris_training.csv

- A test set of 30 samples

http://download.tensorflow.org/data/iris_test.csv

网络搭建

1. 首先,导入tensorflow 和 numpy

代码语言:javascript
复制

2. 导入数据

# 定义数据地址

IRIS_TRAINING = "iris_training.csv"

IRIS_TEST = "iris_test.csv"

# 导入数据

training_set = tf.contrib.learn.datasets.base.load_csv_with_header(

filename=IRIS_TRAINING,

target_dtype=np.int,

features_dtype=np.float32)

test_set = tf.contrib.learn.datasets.base.load_csv_with_header(

filename=IRIS_TEST,

target_dtype=np.int,

features_dtype=np.float32)

load_csv_with_header() 有三个参数 - filename, 数据地址 - target_dtype, 目标值的numpy datatype(iris的目标值是0,1,2,所以是np.int) - features_dtype, 特征值的numpy datatype .

3. 搭建网络结构

# 每行数据4个特征,都是real-value的

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

# 3层DNN,3分类问题

classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,

hidden_units=[10, 20, 10],

n_classes=3,

model_dir="/tmp/iris_model")

参数解释 - feature_columns 特征值 - hidden_units=[10, 20, 10]. 3个隐藏层,包含的隐藏神经元依次是10, 20, 10 - n_classes 类别个数 - model_dir 模型保存地址

4. 训练数据

classifier.fit(x=training_set.data, y=training_set.target, steps=2000)

steps 为训练次数

5. 计算准确率

accuracy_score = classifier.evaluate(x=test_set.data, y=test_set.target)["accuracy"]

print('Accuracy: {0:f}'.format(accuracy_score))

运行结果是

代码语言:javascript
复制

6. 对新样本进行预测

# Classify two new flower samples.

new_samples = np.array(

[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)

y = list(classifier.predict(new_samples, as_iterable=True))

print('Predictions: {}'.format(str(y)))

运行结果为:

代码语言:javascript
复制

完整代码

from __future__ import absolute_import

from __future__ import division

from __future__ import print_function

import tensorflow as tf

import numpy as np

IRIS_TRAINING = "iris_training.csv"

IRIS_TEST = "iris_test.csv"

training_set = tf.contrib.learn.datasets.base.load_csv_with_header(

filename=IRIS_TRAINING,

target_dtype=np.int,

features_dtype=np.float32)

test_set = tf.contrib.learn.datasets.base.load_csv_with_header(

filename=IRIS_TEST,

target_dtype=np.int,

features_dtype=np.float32)

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,

hidden_units=[10, 20, 10],

n_classes=3,

model_dir="/tmp/iris_model")

classifier.fit(x=training_set.data,

y=training_set.target,

steps=2000)

accuracy_score = classifier.evaluate(x=test_set.data,

y=test_set.target)["accuracy"]

print('Accuracy: {0:f}'.format(accuracy_score))

new_samples = np.array(

[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)

y = list(classifier.predict(new_samples, as_iterable=True))

print('Predictions: {}'.format(str(y)))

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2017-07-05,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 大数据挖掘DT数据分析 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 步骤
  • 数据
  • http://download.tensorflow.org/data/iris_test.csv
  • 网络搭建
    • 1. 首先,导入tensorflow 和 numpy
      • 2. 导入数据
        • 3. 搭建网络结构
          • 4. 训练数据
            • 5. 计算准确率
              • 6. 对新样本进行预测
              • 完整代码
              领券
              问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档