首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >使用tensorflow.js打印预测时出现问题

使用tensorflow.js打印预测时出现问题
EN

Stack Overflow用户
提问于 2020-01-15 05:42:02
回答 1查看 76关注 0票数 1

我已经使用make_image_classifier命令行工具重新训练了一个mobilenet_v2模型,以重新训练模型,并使用tfjs-converter为浏览器准备模型。

代码语言:javascript
复制
make_image_classifier \
  --image_dir image_data \
  --tfhub_module https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4 \
  --saved_model_dir trained_models/1 \
  --labels_output_file class_labels.txt \
  --tflite_output_file trained_model.tflite
代码语言:javascript
复制
tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_format=tfjs_graph_model \
    --signature_name=serving_default \
    --saved_model_tags=serve \
    ./trained_models/1 \
    ./web_model

为了测试TF Lite模型,我使用了tflite example code。我遵循了工具的说明,因此使用了提供的代码。

如果我现在尝试在浏览器中预测图像,我得不到预期的输出。看起来只打印概率而不打印标签。

代码语言:javascript
复制
const MODEL_URL = 'model/model.json';
const model = await tf.loadGraphModel(MODEL_URL);

var canvas = document.getElementById("canvas").getContext("2d");;
const img = canvas.getImageData(0,0, 224,224)

const tfImg = tf.browser.fromPixels(img).expandDims(0);
const smalImg = tf.image.resizeBilinear(tfImg, [224, 224]);

let result = await model.predict(smalImg);

console.log(result.print())

输出:张量[0.0022475,0.0040588,0.0220788,0.0032885,0.000126,0.0030831,0.8462179,0.1188994,]

用python测试模型效果很好,我用标签和概率得到了预期的输出。我做错了什么吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-01-15 18:59:23

在python代码中,在预测之后有以下内容:

代码语言:javascript
复制
  # output_data is the predition result
  results = np.squeeze(output_data)
  top_k = results.argsort()[-5:][::-1]
  labels = load_labels(args.label_file)
  for i in top_k:
    if floating_model:
      print('{:08.6f}: {}'.format(float(results[i]), labels[i]))
    else:
      print('{:08.6f}: {}'.format(float(results[i] / 255.0), labels[i]))

在js中的预测之后需要进行相同的处理。labels数组包含数据集的所有标签。它需要在js中加载,因为它是从labels.txt文件加载到python代码中的。

代码语言:javascript
复制
const topkIndices = result.topk(5) 
const topkIndices = await topk.indices.data();
const categories = (Array.from(topkIndices)).map((p: number) => labels[p]);

const topkValues = await topk.values.data()
const data = Array.from(topkValues).map(p => p * 100);
// categories will contain the top5 labels and
// data will contain their corresponding probabilities
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59742188

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档