今天在写NCF代码的时候,发现网络上的代码有一种新的数据读取方式,这里将对应的片段剪出来给大家分享下。
原始数据
我们的原始数据保存在npy文件中,是一个字典类型,有三个key,分别是user,item和label:
构建tf的Dataset
使用tf.data.Dataset.from_tensor_slices方法,将我们的数据变成tensorflow的DataSet:
进一步,将我们的Dataset变成一个BatchDataset,这样的话,在迭代数据的时候,就可以一次返回一个batch大小的数据:
可以看到,我们在变成batch之前使用了一个shuffle对数据进行打乱,100表示buffersize,即每取1000个打乱一次。
此时dataset有两个属性,分别是output_shapes和output_types,我们将根据这两个属性来构造迭代器,用于迭代数据。
构造迭代器
我们使用上面提到的两个dataset的属性,并使用tf.data.Iterator.from_structure方法来构造一个迭代器:
迭代器需要初始化:
此时,就可以使用get_next(),方法来源源不断的读取batch大小的数据了
使用迭代器的正确姿势
我们这里来计算返回的每个batch中,user和item的平均值:
迭代器iterator只能往前遍历,如果遍历完之后还调用get_next()的话,会报tf.errors.OutOfRangeError错误,因此需要使用try-catch。
如果想要多次遍历数据的话,初始化外面包裹一层循环即可:
完整代码
参考文献:
1、Facebook的paper:http://quinonero.net/Publications/predicting-clicks-facebook.pdf
2、http://www.cbdio.com/BigData/2015-08/27/content_3750170.htm
3、https://blog.csdn.net/shine19930820/article/details/71713680
4、https://www.zhihu.com/question/35821566
5、https://github.com/neal668/LightGBM-GBDT-LR/blob/master/GBFT%2BLR_simple.py
领取专属 10元无门槛券
私享最新 技术干货