在训练过程中,我将历元设置为迭代数据的次数。当我已经可以用tf.data.Datasets.repeat(EPOCHS)做同样的事情时,我想知道model.fit(train_dataset,epochs=EPOCHS)的用途是什么?
发布于 2021-04-06 15:32:15
在图像数据的情况下,tf.data.Datasets.repeat()可以用于tf.data.Datasets上的数据增强。
假设您希望增加训练数据集中的图像数量,使用随机转换,然后重复训练数据集的count次数,并应用随机转换,如下所示
train_dataset = (
train_dataset
.map(resize, num_parallel_calls=AUTOTUNE)
.map(rescale, num_parallel_calls=AUTOTUNE)
.map(onehot, num_parallel_calls=AUTOTUNE)
.shuffle(BUFFER_SIZE, reshuffle_each_iteration=True)
.batch(BATCH_SIZE)
.repeat(count=5)
.map(random_flip, num_parallel_calls=AUTOTUNE)
.map(random_rotate, num_parallel_calls=AUTOTUNE)
.prefetch(buffer_size=AUTOTUNE)
)方法时,您必须创建dataset的副本,分别应用转换,然后连接数据集。但是使用repeat()简化了这一点,还利用了方法链,并且代码看起来很整洁。
关于数据增强的更多信息:https://www.tensorflow.org/tutorials/images/data_augmentation#apply_augmentation_to_a_dataset
https://stackoverflow.com/questions/66557699
复制相似问题