在字符串输入上使用onnx运行时进行预测

内容来源于 Stack Overflow,并遵循CC BY-SA 3.0许可协议进行翻译与使用

  • 回答 (1)
  • 关注 (0)
  • 查看 (78)

我正在尝试对已打包为ONNX的sklearn管道中的文本进行预测。我能够写出并读入模型,但是当我进行预测时,我得到错误:“方法运行失败,原因是:[ONNXRuntimeError]:2:INVALID_ARGUMENT:缺少必需的输入:float_input”。有谁知道如何从文本上的sklearn管道进行预测?

我已经按照本教程http://onnx.ai/sklearn-onnx/auto_examples/plot_tfidfvectorizer.html#sphx-glr-download-auto-examples-plot-tfidfvectorizer-py进行了预测。

'''
#convert pipeline into onnx
model_onnx = convert_sklearn(pipeline, "tfidf",
                             initial_types=[("str_input", StringTensorType([1, 2000]))])

with open("pipeline_emails.onnx", "wb") as f:
     f.write(onx.SerializeToString())

#make predictions on test data
sess = rt.InferenceSession("pipeline_emails.onnx")
pred_onx = sess.run(None, {"str_input": test_df.as_matrix()})[0]
print("predict", pred_onx[0])
print("predict_proba", pred_onx[1])
'''

我希望得到一组关于我的测试数据的预测,但我得到:

RuntimeError                              Traceback (most recent call last)
<ipython-input-118-5db056b989a8> in <module>()
      2 sess = rt.InferenceSession("pipeline_emails.onnx")
      3 inputs = {'str_input': test_df.as_matrix()}
----> 4 pred_onx = sess.run(None, {"str_input": test_df.as_matrix()})[0]
      5 print("predict", pred_onx[0])
      6 print("predict_proba", pred_onx[1])

~\AppData\Local\Continuum\anaconda3\lib\site-packages\onnxruntime\capi\session.py in run(self, output_names, input_feed, run_options)
     70         if not output_names:
     71             output_names = [output.name for output in self._outputs_meta]
---> 72         return self._sess.run(output_names, input_feed, run_options)
     73 
     74     def end_profiling(self):

RuntimeError: Method run failed due to: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Missing required input: float_input
提问于

扫码关注云+社区

领取腾讯云代金券