前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用Tensorflow的DataSet和Iterator读取数据!

使用Tensorflow的DataSet和Iterator读取数据!

作者头像
石晓文
发布2018-07-25 14:44:45
2.1K0
发布2018-07-25 14:44:45
举报
文章被收录于专栏:小小挖掘机小小挖掘机

今天在写NCF代码的时候,发现网络上的代码有一种新的数据读取方式,这里将对应的片段剪出来给大家分享下。

原始数据 我们的原始数据保存在npy文件中,是一个字典类型,有三个key,分别是user,item和label:

代码语言:javascript
复制
data = np.load('data/test_data.npy').item()
print(type(data))

#output
<class 'dict'>

构建tf的Dataset 使用 tf.data.Dataset.from_tensor_slices方法,将我们的数据变成tensorflow的DataSet:

代码语言:javascript
复制
dataset = tf.data.Dataset.from_tensor_slices(data)
print(type(dataset))
#output
<class 'tensorflow.python.data.ops.dataset_ops.TensorSliceDataset'>

进一步,将我们的Dataset变成一个BatchDataset,这样的话,在迭代数据的时候,就可以一次返回一个batch大小的数据:

代码语言:javascript
复制
dataset = dataset.shuffle(1000).batch(100)
print(type(dataset))

#output
<class 'tensorflow.python.data.ops.dataset_ops.BatchDataset'>

可以看到,我们在变成batch之前使用了一个shuffle对数据进行打乱,100表示buffersize,即每取1000个打乱一次。

此时dataset有两个属性,分别是output_shapes和output_types,我们将根据这两个属性来构造迭代器,用于迭代数据。

代码语言:javascript
复制
print(dataset.output_shapes)
print(dataset.output_types)

#output
{'user': TensorShape([Dimension(None)]), 'item': TensorShape([Dimension(None)]), 'label': TensorShape([Dimension(None)])}
{'user': tf.int32, 'item': tf.int32, 'label': tf.int32}

构造迭代器 我们使用上面提到的两个dataset的属性,并使用tf.data.Iterator.from_structure方法来构造一个迭代器:

代码语言:javascript
复制
iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                            dataset.output_shapes)

迭代器需要初始化:

代码语言:javascript
复制
 sess.run(iterator.make_initializer(dataset))

此时,就可以使用get_next(),方法来源源不断的读取batch大小的数据了

代码语言:javascript
复制
def getBatch():
    sample = iterator.get_next()
    print(sample)
    user = sample['user']
    item = sample['item']
    return user,item

使用迭代器的正确姿势 我们这里来计算返回的每个batch中,user和item的平均值:

代码语言:javascript
复制
users,items = getBatch()
usersum = tf.reduce_mean(users,axis=-1)
itemsum = tf.reduce_mean(items,axis=-1)

迭代器iterator只能往前遍历,如果遍历完之后还调用get_next()的话,会报tf.errors.OutOfRangeError错误,因此需要使用try-catch。

代码语言:javascript
复制
try:
    while True:
        print(sess.run([usersum,itemsum]))
except tf.errors.OutOfRangeError:
    print("outOfRange")  

如果想要多次遍历数据的话,初始化外面包裹一层循环即可:

代码语言:javascript
复制
for i in range(2):
    sess.run(iterator.make_initializer(dataset))
    try:
        while True:
            print(sess.run([usersum,itemsum]))
    except tf.errors.OutOfRangeError:
        print("outOfRange")

完整代码

代码语言:javascript
复制
import numpy as np
import tensorflow as tf


data = np.load('data/test_data.npy').item()
print(type(data))


dataset = tf.data.Dataset.from_tensor_slices(data)
print(type(dataset))
dataset = dataset.shuffle(10000).batch(100)
print(type(dataset))

print(dataset.output_shapes)
print(dataset.output_types)

iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                            dataset.output_shapes)

print(type(iterator))


def getBatch():
    sample = iterator.get_next()
    print(sample)
    user = sample['user']
    item = sample['item']
    return user,item


users,items = getBatch()
usersum = tf.reduce_mean(users,axis=-1)
itemsum = tf.reduce_mean(items,axis=-1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for i in range(2):
        sess.run(iterator.make_initializer(dataset))
        try:
            while True:
                print(sess.run([usersum,itemsum]))
        except tf.errors.OutOfRangeError:
            print("outOfRange")

参考文献:

1、Facebook的paper:http://quinonero.net/Publications/predicting-clicks-facebook.pdf 2、http://www.cbdio.com/BigData/2015-08/27/content_3750170.htm 3、https://blog.csdn.net/shine19930820/article/details/71713680 4、https://www.zhihu.com/question/35821566 5、https://github.com/neal668/LightGBM-GBDT-LR/blob/master/GBFT%2BLR_simple.py

推荐阅读:

推荐系统遇上深度学习系列:

推荐系统遇上深度学习(一)--FM模型理论和实践

推荐系统遇上深度学习(二)--FFM模型理论和实践

推荐系统遇上深度学习(三)--DeepFM模型理论和实践

推荐系统遇上深度学习(四)--多值离散特征的embedding解决方案

推荐系统遇上深度学习(五)--Deep&Cross Network模型理论和实践

推荐系统遇上深度学习(六)--PNN模型理论和实践

推荐系统遇上深度学习(七)--NFM模型理论和实践

推荐系统遇上深度学习(八)--AFM模型理论和实践

推荐系统遇上深度学习(九)--评价指标AUC原理及实践

推荐系统遇上深度学习(十)--GBDT+LR融合方案实战

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2018-06-02,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 小小挖掘机 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 参考文献:
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档