我正在使用只使用CPU的Keras在python中训练LSTM自动编码器模型。
我可以看到在fit函数中有一个名为use_multiprocessing的参数。请你简单地解释一下这个论点到底是做什么的。我在tensorflow.org上阅读了解释,但是如果我将参数设置为真,我无法从它中理解如何影响我的模型。我正在寻找方法来加快我的模型的训练,我想知道这个参数是否会有帮助。
发布于 2022-10-24 02:04:11
use_multiprocessing (以及workers和max_queue_size)参数应用于批处理数据生成。文档中的线索是:“只用于生成器或keras.utils.Sequence输入”,参考keras.utils.Sequence在幕后,keras使用一个ordereden队列来包装您的输入。
如果use_multiprocessing为True,而workers > 0,则keras将创建多个(number = workers) 进程,以便同时运行,并从生成器/序列中准备批处理。他们将努力保持队列的批准备,以进行培训,直到max_queue_size。
如果use_multiprocessing为False,而workers > 1,则keras将创建多个(number = workers) 线程,以同时准备批处理,类似于上面的批处理(但您的输入数据对象需要线程安全)。
如果您的批处理数据生成是培训过程中的瓶颈,这将大大加快速度。我发现use_multiprocessing真可以线性地加快批处理数据绑定工作,例如,两个工作人员=2倍的速度(尽管有启动进程的开销)。对于use_multiprocessing False和线程,我发现在没有开销的情况下,速度会增加0%-15%,这是不可预测的。
也请参考这个问题,并提供详细信息:How to define max_queue_size, workers and use_multiprocessing in keras fit_generator()?
https://stackoverflow.com/questions/71114440
复制相似问题