前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Tensorflow | 读取csv文件

Tensorflow | 读取csv文件

作者头像
努力在北京混出人样
发布2019-02-18 16:13:35
1.8K0
发布2019-02-18 16:13:35
举报
文章被收录于专栏:祥子的故事祥子的故事

如何将CSV数据读入到tensorflow中,这个问题困扰了我好几天,下面来说一种我现在用到的方法。

待有新的读取方法 ,本帖保持更新

  • 方法一: 以一个案例来切入:
代码语言:javascript
复制
#加载包
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

# 数据集名称,数据集要放在你的工作目录下
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"

# 数据集读取,训练集和测试集
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TRAINING,
    target_dtype=np.int,
    features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TEST,
    target_dtype=np.int,
    features_dtype=np.float32)

# 特征
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

# 构建DNN网络,3层,每层分别为10,20,10个节点
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=3,
                                            model_dir="/tmp/iris_model")

# 拟合模型,迭代2000步
classifier.fit(x=training_set.data,
               y=training_set.target,
               steps=2000)

# 计算精度
accuracy_score = classifier.evaluate(x=test_set.data,y=test_set.target)["accuracy"]

print('Accuracy: {0:f}'.format(accuracy_score))

# 预测新样本的类别
new_samples = np.array([[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = list(classifier.predict(new_samples, as_iterable=True))
print('Predictions: {}'.format(str(y)))

结果好长,给出关键的部分: INFO:tensorflow:Saving evaluation summary for step 12001: accuracy = 0.966667, loss = 0.461221 Accuracy: 0.966667

预测结果: Predictions: [1, 1]

从上面的代码可以发现,读取方式为:

代码语言:javascript
复制
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TRAINING,
    target_dtype=np.int,
    features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TEST,
    target_dtype=np.int,
    features_dtype=np.float32)

IRIS_TRAINING :训练集 IRIS_TEST:测试集

特征提取:

代码语言:javascript
复制
# 特征
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

完整的代码见github:https://github.com/zhangdm/machine-learning-summary/tree/master/tensorflow/tensorflow_read_csv_DNN

  • 方法二:
代码语言:javascript
复制
#加载包
import tensorflow as tf
import os

#设置工作目录
os.chdir("你自己的目录")
#查看目录
print(os.getcwd())

#读取函数定义
def read_data(file_queue):
    reader = tf.TextLineReader(skip_header_lines=1)
    key, value = reader.read(file_queue)
    #定义列
    defaults = [[0], [0.], [0.], [0.], [0.], ['']]
 #编码   Id,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm,Species = tf.decode_csv(value, defaults)

    #处理
    preprocess_op = tf.case({
        tf.equal(Species, tf.constant('Iris-setosa')): lambda: tf.constant(0),
        tf.equal(Species, tf.constant('Iris-versicolor')): lambda: tf.constant(1),
        tf.equal(Species, tf.constant('Iris-virginica')): lambda: tf.constant(2),
    }, lambda: tf.constant(-1), exclusive=True)

    #栈
    return tf.stack([SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm]), preprocess_op


def create_pipeline(filename, batch_size, num_epochs=None):
    file_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)
    example, label = read_data(file_queue)

    min_after_dequeue = 1000
    capacity = min_after_dequeue + batch_size
    example_batch, label_batch = tf.train.shuffle_batch(
        [example, label], batch_size=batch_size, capacity=capacity,
        min_after_dequeue=min_after_dequeue
    )

    return example_batch, label_batch

x_train_batch, y_train_batch = create_pipeline('Iris-train.csv', 50, num_epochs=1000)
x_test, y_test = create_pipeline('Iris-test.csv', 60)
print(x_train_batch,y_train_batch)

结果: Tensor(“shuffle_batch_2:0”, shape=(50, 4), dtype=float32) Tensor(“shuffle_batch_2:1”, shape=(50,), dtype=int32)

从它的数据维度可知,数据已经读入。

一个完整的例子见github:https://github.com/zhangdm/machine-learning-summary/tree/master/tensorflow/tensorflow_iris_nn

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2017年01月22日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档