我有一个大约有2M行和6000列的大型数据集。输入的numpy数组(X,y)可以很好地保存训练数据。但是当它转到model.fit()时,我得到了一个GPU内存不足的错误。我使用的是tensorflow 2.2。根据它的手册,model.fit_generator已经被弃用,而model.fit是首选。
有人能概述一下使用tensorflow v2.2训练大型数据集的步骤吗?
发布于 2020-06-29 06:44:18
最好的解决方案是使用tf.data.Dataset()
,因此您可以使用.batch()
方法轻松地对数据进行批处理。
这里有很多教程,你可能想要使用from_tensor_slices()
直接使用numpy
数组。
下面有两个很好的文档来满足你的需求。
https://stackoverflow.com/questions/62632485
复制相似问题