学习
实践
活动
专区
工具
TVP
写文章

使用Tensorflow的DataSet和Iterator读取数据!

今天在写NCF代码的时候,发现网络上的代码有一种新的数据读取方式,这里将对应的片段剪出来给大家分享下。

原始数据

我们的原始数据保存在npy文件中,是一个字典类型,有三个key,分别是user,item和label:

构建tf的Dataset

使用tf.data.Dataset.from_tensor_slices方法,将我们的数据变成tensorflow的DataSet:

进一步,将我们的Dataset变成一个BatchDataset,这样的话,在迭代数据的时候,就可以一次返回一个batch大小的数据:

可以看到,我们在变成batch之前使用了一个shuffle对数据进行打乱,100表示buffersize,即每取1000个打乱一次。

此时dataset有两个属性,分别是output_shapes和output_types,我们将根据这两个属性来构造迭代器,用于迭代数据。

构造迭代器

我们使用上面提到的两个dataset的属性,并使用tf.data.Iterator.from_structure方法来构造一个迭代器:

迭代器需要初始化:

此时,就可以使用get_next(),方法来源源不断的读取batch大小的数据了

使用迭代器的正确姿势

我们这里来计算返回的每个batch中,user和item的平均值:

迭代器iterator只能往前遍历,如果遍历完之后还调用get_next()的话,会报tf.errors.OutOfRangeError错误,因此需要使用try-catch。

如果想要多次遍历数据的话,初始化外面包裹一层循环即可:

完整代码

参考文献:

1、Facebook的paper:http://quinonero.net/Publications/predicting-clicks-facebook.pdf

2、http://www.cbdio.com/BigData/2015-08/27/content_3750170.htm

3、https://blog.csdn.net/shine19930820/article/details/71713680

4、https://www.zhihu.com/question/35821566

5、https://github.com/neal668/LightGBM-GBDT-LR/blob/master/GBFT%2BLR_simple.py

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180602G1CES900?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

关注

腾讯云开发者公众号
10元无门槛代金券
洞察腾讯核心技术
剖析业界实践案例
腾讯云开发者公众号二维码

扫码关注腾讯云开发者

领取腾讯云代金券