在训练过程中,我将历元设置为迭代数据的次数。当我已经可以用tf.data.Datasets.repeat(EPOCHS)做同样的事情时,我想知道model.fit(train_dataset,epochs=EPOCHS)的用途是什么?
发布于 2021-03-10 03:56:25
它的工作原理略有不同。
让我们选择两个不同的例子。
我们还假设您有一个包含100条记录的数据集。
如果您选择1,每个时代将有2,000条记录。您将“检查”您的模型是如何改善后,通过2,000条记录通过您的模型,你将这样做10次。
如果您选择选项2,每个时代将有1,000条记录。你将评估你的模型在推送1,000张记录后是如何改进的,你将这样做20次。
在这两个选项中,用于培训的记录总数是相同的,但是当您评估、记录等时,模型的行为是不同的。
发布于 2021-04-06 15:32:15
在图像数据的情况下,tf.data.Datasets.repeat()可以用于tf.data.Datasets上的数据增强。
假设您希望增加训练数据集中的图像数量,使用随机转换,然后重复训练数据集的count次数,并应用随机转换,如下所示
train_dataset = (
train_dataset
.map(resize, num_parallel_calls=AUTOTUNE)
.map(rescale, num_parallel_calls=AUTOTUNE)
.map(onehot, num_parallel_calls=AUTOTUNE)
.shuffle(BUFFER_SIZE, reshuffle_each_iteration=True)
.batch(BATCH_SIZE)
.repeat(count=5)
.map(random_flip, num_parallel_calls=AUTOTUNE)
.map(random_rotate, num_parallel_calls=AUTOTUNE)
.prefetch(buffer_size=AUTOTUNE)
)方法时,您必须创建dataset的副本,分别应用转换,然后连接数据集。但是使用repeat()简化了这一点,还利用了方法链,并且代码看起来很整洁。
关于数据增强的更多信息:https://www.tensorflow.org/tutorials/images/data_augmentation#apply_augmentation_to_a_dataset
https://stackoverflow.com/questions/66557699
复制相似问题