首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >Tensorflow :如何在java中使用用python训练的语音识别模型

Tensorflow :如何在java中使用用python训练的语音识别模型
EN

Stack Overflow用户
提问于 2018-08-20 19:52:05
回答 1查看 505关注 0票数 2

我有一个用python训练的tensorflow模型,在训练后我已经生成了this article.冻结的图形。现在,我需要使用这个图形并在基于的JAVA应用程序上生成识别。为此,我查看了以下example。然而,我无法理解的是如何收集我的输出。我知道我需要为图表提供3个输入。

从官方教程中给出的示例中,我已经阅读了基于python的代码。

代码语言:javascript
复制
def run_graph(wav_data, labels, input_layer_name, output_layer_name,
              num_top_predictions):
  """Runs the audio data through the graph and prints predictions."""
  with tf.Session() as sess:
    # Feed the audio data as input to the graph.
    #   predictions  will contain a two-dimensional array, where one
    #   dimension represents the input image count, and the other has
    #   predictions per class
    softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
    predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data})

    # Sort to show labels in order of confidence
    top_k = predictions.argsort()[-num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = labels[node_id]
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))

    return 0

有人能帮我理解一下tensorflow java api吗?

EN

回答 1

Stack Overflow用户

发布于 2018-08-23 14:24:25

上面列出的Python代码的字面翻译如下所示:

代码语言:javascript
复制
public static float[][] getPredictions(Session sess, byte[] wavData, String inputLayerName, String outputLayerName) {
  try (Tensor<String> wavDataTensor = Tensors.create(wavData);
       Tensor<Float> predictionsTensor = sess.runner()
                    .feed(inputLayerName, wavDataTensor)
                    .fetch(outputLayerName)
                    .run()
                    .get(0)
                    .expect(Float.class)) {
    float[][] predictions = new float[(int)predictionsTensor.shape(0)][(int)predictionsTensor.shape(1)];
    predictionsTensor.copyTo(predictions);
    return predictions;
  }
}

返回的predictions数组将具有每个预测的“置信度”值,您必须运行逻辑来计算它的“前K”,类似于Python代码如何使用numpy (.argsort())来计算sess.run()返回的值。

粗略地阅读一下教程页面和代码,就会发现predictions将有1行12列(每个热词对应一列)。这是我从下面的Python代码中得到的:

代码语言:javascript
复制
import tensorflow as tf

graph_def = tf.GraphDef()
with open('/tmp/my_frozen_graph.pb', 'rb') as f:
  graph_def.ParseFromString(f.read())

output_layer_name = 'labels_softmax:0'

tf.import_graph_def(graph_def, name='')
print(tf.get_default_graph().get_tensor_by_name(output_layer_name).shape)

希望这能有所帮助。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/51930183

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档