首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >ONNX运行时推理- session.run()多处理

ONNX运行时推理- session.run()多处理
EN

Stack Overflow用户
提问于 2022-01-21 15:43:13
回答 1查看 2.6K关注 0票数 1

目标:在多个CPU核上并行运行推理

我正在尝试使用inference.ipynb进行推理。

个别:

代码语言:javascript
运行
复制
outputs = session.run([output_name], {input_name: x})

许多:

代码语言:javascript
运行
复制
outputs = session.run(["output1", "output2"], {"input1": indata1, "input2": indata2})

依次:

代码语言:javascript
运行
复制
%%time
outputs = [session.run([output_name], {input_name: inputs[i]})[0] for i in range(test_data_num)]

这个多处理教程提供了许多并行处理任何任务的方法。

但是,我想知道哪种方法对session.run()最好,是否传递了outputs

如何并行地推断所有输出和输入?

代码:

代码语言:javascript
运行
复制
import onnxruntime
import multiprocessing as mp

session = onnxruntime.InferenceSession('bert.opt.quant.onnx')

i = 0
# First Input
input_name = session.get_inputs()[i].name
print("Input Name  :", input_name)

# First Output
output_name = session.get_outputs()[i].name
print("Output Name  :", output_name)  

pool = mp.Pool(mp.cpu_count())

# PARALLELISE THIS LINE
outputs = [session.run([], {input_name: inputs[i]})[0] for i in range(test_data_num)]
# outputs = pool.starmap(func, zip(iter_1, iter_2))

pool.close()

print(results)

Update:这个解决方案建议使用starmap()zip()来传递函数名和两个单独的迭代。

将一行改为:

代码语言:javascript
运行
复制
outputs = pool.starmap(session.run, zip([output_name], [ {input_name: inputs[i]}[0] for i in range(test_data_num) ]))

回溯:

代码语言:javascript
运行
复制
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-45-0aab302a55eb> in <module>
     25 #%%time
     26 #outputs = [session.run([output_name], {input_name: inputs[i]})[0] for i in range(test_data_num)]
---> 27 outputs = pool.starmap(session.run, zip([output_name], [ {input_name: inputs[i]}[0] for i in range(test_data_num) ]))
     28 
     29 pool.close()

<ipython-input-45-0aab302a55eb> in <listcomp>(.0)
     25 #%%time
     26 #outputs = [session.run([output_name], {input_name: inputs[i]})[0] for i in range(test_data_num)]
---> 27 outputs = pool.starmap(session.run, zip([output_name], [ {input_name: inputs[i]}[0] for i in range(test_data_num) ]))
     28 
     29 pool.close()

KeyError: 0
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-01-21 16:56:26

代码语言:javascript
运行
复制
def run_inference(i):
    output_name = session.get_outputs()[0].name
    return session.run([output_name], {input_name: inputs[i]})[0]  # [0] bc array in list

outputs = pool.map(run_inference, [i for i in range(test_data_num)])

任何人都可以批评

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70803924

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档