首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >如何获取TOCO tf_convert的冻结Tensorflow模型的input_shape

如何获取TOCO tf_convert的冻结Tensorflow模型的input_shape
EN

Stack Overflow用户
提问于 2018-11-30 00:48:40
回答 2查看 5.3K关注 0票数 3

我正在尝试使用TF Lite Converter (this is the specific model i am using)在Ubuntu18.04.1LTS (VirtualBox)上将我从davidsandberg/facenet获得的冻结模型转换为.tflite。当我尝试运行该命令时:

代码语言:javascript
复制
/home/nils/.local/bin/tflite_convert 
--output_file=/home/nils/Documents/frozen.tflite 
--graph_def_file=/home/nils/Documents/20180402-114759/20180402-114759.pb 
--input_arrays=input --output_array=embeddings

我得到以下错误:

代码语言:javascript
复制
2018-11-29 16:36:21.774098: I 
tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports 
instructions that this TensorFlow binary was not compiled to use: AVX2
Traceback (most recent call last):
File "/home/nils/.local/bin/tflite_convert", line 11, in <module>
 sys.exit(main())
File 
"/home/nils/.local/lib/python3.6/site-packages/tensorflow/contrib   /lite/python/tflite_convert.py", 
line 412, in main
 app.run(main=run_main, argv=sys.argv[:1])
   File 
"/home/nils/.local/lib/python3.6/site-packages/tensorflow/python/platform/app.py", 
line 125, in run
 _sys.exit(main(argv))
File 
"/home/nils/.local/lib/python3.6/site-packages/tensorflow/contrib/lite/python/tflite_convert.py", 
line 408, in run_main
 _convert_model(tflite_flags)
File 
"/home/nils/.local/lib/python3.6/site-packages/tensorflow/contrib/lite/python/tflite_convert.py", 
line 162, in _convert_model
 output_data = converter.convert()
File 
"/home/nils/.local/lib/python3.6/site-packages/tensorflow/contrib/lite/python/lite.py", 
line 404, in convert
 "'{0}'.".format(_tensor_name(tensor)))
ValueError: Provide an input shape for input array 'input'.

由于我没有亲自训练模型,我不知道输入的确切形状。也许可以从David Sandberg的GitHubRep.中找到的classifier.py和facenet.py中提取出来,位于facenet/src,但我对代码的理解还不足以让我自己这么做。我甚至试着通过tensorboard来分析图表。我无论如何都搞不清楚,但也许你可以:Tensorboard-Screenshot正如你可能已经注意到的,我对Ubuntu,Tensorflow和所有相关的东西都是新手,所以我很乐意在这个问题上接受任何建议。提前谢谢你!

这是classifier.py的相关部分,模型在这里加载和设置:

代码语言:javascript
复制
 # Load the model
        print('Loading feature extraction model')
        facenet.load_model(args.model)

        # Get input and output tensors
        images_placeholder = tf.get_default_graph().get_tensor_by_name("input:0")
        embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0")
        phase_train_placeholder = tf.get_default_graph().get_tensor_by_name("phase_train:0")
        embedding_size = embeddings.get_shape()[1]

        # Run forward pass to calculate embeddings
        print('Calculating features for images')
        nrof_images = len(paths)
        nrof_batches_per_epoch = int(math.ceil(1.0*nrof_images / args.batch_size))
        emb_array = np.zeros((nrof_images, embedding_size))
        for i in range(nrof_batches_per_epoch):
            start_index = i*args.batch_size
            end_index = min((i+1)*args.batch_size, nrof_images)
            paths_batch = paths[start_index:end_index]
            images = facenet.load_data(paths_batch, False, False, args.image_size)
            feed_dict = { images_placeholder:images, phase_train_placeholder:False }
            emb_array[start_index:end_index,:] = sess.run(embeddings, feed_dict=feed_dict)

        classifier_filename_exp = os.path.expanduser(args.classifier_filename)
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2018-12-13 04:04:06

如果你再次启动Tensorboard,回到你看到的图表,应该有一个搜索图标(我想在左上角),在那里你可以输入" input“并找到输入张量。它会给你想要的形状。我猜它应该是'1,image_size,image_size,3‘的形式。

或者,您可以检查代码

代码语言:javascript
复制
feed_dict = { images_placeholder:images, phase_train_placeholder:False }

请注意,我们将“图像”对象输入到images_placeholder中,该对象映射到"input:0“张量。实际上,您需要的是图像对象的形状。

图像来自对facenet.load_data()的调用。如果您进入facenet.py并检查load_data函数,您可以观察到其形状类似于我上面建议的形状。如果您打印image_size值,它应该与您在Tensorboard中看到的值相匹配。

票数 0
EN

Stack Overflow用户

发布于 2019-01-10 15:45:41

我看过了tflite转换器的代码。我发现您需要以{"input_tensor_name": [input shape]}格式将输入形状作为字典提供。

下面是一个解决这个问题的例子:

代码语言:javascript
复制
`graph_def_file = "20180402-114759/20180402-114759.pb"
input_arrays = ["input"]
output_arrays = ["embeddings"]

converter = tf.lite.TFLiteConverter.from_frozen_graph(
  graph_def_file, input_arrays, output_arrays,input_shapes={"input":[1,160,160,3]})

tflite_model = converter.convert()
open("model.tflite", "wb").write(tflite_model)
`
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/53543872

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档