首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Keras多处理模型预测

Keras多处理模型预测
EN

Stack Overflow用户
提问于 2022-11-23 01:06:04
回答 1查看 20关注 0票数 0

我有一个简单的MNIST模型来进行预测并节省损失。我是在一个有多个CPU的服务器上运行的,所以我想使用多进程加速。

我已经成功地使用了具有一些基本功能的多处理,但是对于模型预测,这些过程永远不会完成,而使用非多处理方法时,它们工作得很好。

我怀疑这个问题可能与模型有关,因为有一个模型它不能在不同的并行进程中使用,所以我在每个进程中加载了模型,但是它没有工作。

我的代码是:

代码语言:javascript
运行
复制
from multiprocessing import Process
import tensorflow as tf

#make a prediction on a training sample
def predict(idx, return_dict):
  x = tf.convert_to_tensor(np.expand_dims(x_train[idx],axis=0))

  local_model=tf.keras.models.load_model('model.h5')
  y=local_model(x)
  print('this never gets printed')
  y_expanded=np.expand_dims(y_train[train_idx],axis=0)
  loss=tf.keras.losses.CategoricalCrossentropy(y_expanded,y)
  return_dict[i]=loss

manager = multiprocessing.Manager()
return_dict = manager.dict()
jobs = []

for i in range(10):
    p = Process(target=predict, args=(i, return_dict))
    jobs.append(p)
    p.start()
    
for proc in jobs:
    proc.join()

print(return_dict.values())

predict函数中的打印行永远不会显示,问题在于模型。即使没有在函数中加载模型并使用全局模型,问题仍然存在。

我跟踪了这个线程,但是它没有工作。我现在的问题是:

  1. 如何解决模型问题
  2. 我可以对所有进程使用相同的X_train吗?

谢谢。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-11-25 22:27:13

我找到了答案。首先,Keras在多处理12方面存在问题。此外,TensorFlow应该始终有一个会话。因此,它只能在函数中导入,而不是在其他地方导入。并且模型应该在每个函数中从磁盘加载。这可能是改进的来源(将模型移动到RAM,将模型序列化为一个文件,并将其传递给函数)。

尽管如此,下面的代码仍然有效。

代码语言:javascript
运行
复制
def predict(idx, return_dict):

  import tensorflow as tf

  x=tf.convert_to_tensor(x_train[idx])
  cce = tf.keras.losses.CategoricalCrossentropy()

  local=tf.keras.models.load_model('model.h5')

  y=local(np.expand_dims(x,axis=0))
  y_expanded=np.expand_dims(y_train[train_idx],axis=0)
  loss=cce(y_expanded,y)
   
  return_dict[idx]=loss

可以使用相同的x_train

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/74540699

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档