首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Tensorflow数据集预取和缓存选项的正确用法是什么?

Tensorflow数据集预取和缓存选项的正确用法是什么?
EN

Stack Overflow用户
提问于 2020-09-08 23:21:07
回答 2查看 3.4K关注 0票数 3

我已经阅读了TF页面和一些帖子,并了解了如何使用prefetch()和cache()来加速模型输入管道,并尝试在我的数据上实现它。Cache()和我预期的一样工作,即在第一个时期从dist中读取数据,在随后的所有时期中,它只是从内存中读取数据。但是我使用prefetch()有很多困难,我真的不知道什么时候和如何使用它。有人能帮我吗?我真的需要一些帮助。我的应用程序是这样的:我有一组大型TFRecord文件,每个文件都包含一些原始记录,以便在输入网络之前进行处理。它们将混合在一起(不同的样本流),所以我要做的是:

代码语言:javascript
运行
复制
def read_datasets(pattern, numFiles, numEpochs=125, batchSize=1024, take=dataLength):

    files = tf.data.Dataset.list_files(pattern)

    def _parse(x):
        x = tf.data.TFRecordDataset(x, compression_type='GZIP')
    return x

    np = 4 # half of the number of CPU cores
    dataset = files.interleave(_parse, cycle_length=numFiles, block_length=1, num_parallel_calls=np)\
    .map(lambda x: parse_tfrecord(x), num_parallel_calls=np)
    dataset = dataset.take(take)
    dataset = dataset.batch(batchSize)
    dataset = dataset.cache()
    dataset = dataset.prefetch(buffer_size=10)
    dataset = dataset.repeat(numEpochs)
    return dataset

交织函数中的parse_tfrecord(x)函数是应用于模型之前所需的数据预处理,我猜测预处理时间与网络批处理时间相当。我的整个数据集(包括所有输入文件)包含大约500批1024个样本。我的问题是:

1-如果我做缓存,我真的需要预取吗?

2-映射、批处理、缓存、预取和重复的顺序是否正确?

3- Tensorflow文档指出,预取的缓冲区大小指的是数据集元素,如果是批处理,则指批处理的数量。因此,在这种情况下,我将阅读10批1024个示例,对吗?我的问题是,通过更改预取缓冲区大小,我看不到运行时间有什么不同,即使将缓冲区大小设置为1000或更大,内存消耗也没有太大变化。

EN

回答 2

Stack Overflow用户

发布于 2021-05-03 05:37:17

我为斯坦福大学的Andrew Nu找到了一个很好的解释。https://cs230.stanford.edu/blog/datapipeline/#best-practices

“当GPU在当前批次上进行正向/反向传播时,我们希望CPU处理下一批数据,以便它立即准备就绪。作为计算机最昂贵的部件,我们希望GPU在训练期间始终得到充分利用。我们称之为消费者/生产者重叠,其中消费者是GPU,生产者是CPU。

使用tf.data,只需在流水线末尾(批处理之后)调用dataset.prefetch(1)即可完成此操作。这将始终预取一批数据,并确保始终有一批数据准备就绪。

在某些情况下,预取多个批处理可能很有用。例如,如果预处理的持续时间变化很大,那么预取10个批次将平均出10个批次的处理时间,而不是有时等待更长的批次。

举一个具体的例子,假设10%的批次需要10s来计算,90%的批次需要1s。如果GPU在一个批次上训练需要2秒,那么通过预取多个批次,您可以确保我们永远不会等待这些罕见的较长批次。“

我不太确定如何确定每个批次的处理时间,但这是下一步。如果您的批处理花费了大致相同的时间,那么我相信预取(batch_size=1)应该足够了,因为您的GPU不会等待GPU完成处理计算昂贵的批处理。

票数 4
EN

Stack Overflow用户

发布于 2020-09-17 17:24:37

您可以查看此Stackoverflow Answer以快速了解TensorFlow Dataset的函数cache()prefetch()

此外,我发现这个Tensorflow Documentation对优化tf.Data Api的性能非常有帮助。他们为各种执行方式指定了基准和执行时间。您还可以分别找到有关数据的序列化和并行化加载和转换及其执行时间的信息。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63796936

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档