我想运行一个需要生成3d输出的TFLite模型(示例代码是生成错误的最小示例)。是否有一个与gather_nd等价的tensorflow不会将维度减少1?
我尝试在文档中查找我能想到的相关函数,但没有找到一个好的选择。
import tensorflow.compat.v1 as tf
import numpy as np
tf.disable_v2_behavior()
initial_input = tf.placeholder(dtype=tf.float32, shape=(None,5,1024))
cap_i = tf.gather_nd(initial_input, [[0,1]]) #[0,2],[0,3],[0,4],[0,5]
cap_i_broadcast = tf.broadcast_to(cap_i, [1,5,1024])
cap_iT = tf.transpose(cap_i_broadcast, perm=[0,2,1])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
tf.io.write_graph(sess.graph_def, '', 'train.pbtxt')
converter = tf.lite.TFLiteConverter.from_session(sess, [initial_input], [cap_iT])
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
open('converted_model.tflite', "wb").write(tflite_model)
sess.close()
标准TensorFlow精简版运行时不支持模型中的某些运算符,因此TensorFlow无法识别这些运算符。如果您对它们有一个自定义的实现,您可以使用--allow_ custom _ops来禁用这个错误,或者在调用tf.lite.TFLiteConverter()时设置allow_custom_ops=True。下面是您正在使用的内置操作符的列表: GATHER_ND、TRANSPOSE。下面是您需要定制实现的运算符列表: BroadcastTo。
发布于 2019-09-23 20:46:16
下面的代码有一个解决方案,使用降维的步长切片,然后重塑以获得正确的尺寸。
import tensorflow.compat.v1 as tf
import numpy as np
tf.disable_v2_behavior()
initial_input = tf.placeholder(dtype=tf.float32, shape=(None,5,1024))
cap_i = tf.strided_slice(initial_input, [0,0,0], [0,5,1024], [1,1,1],
shrink_axis_mask=1)
cap_i_reshaped =tf.reshape(cap_i,[1,5,1024])
cap_iT = tf.transpose(cap_i_reshaped, perm=[0,2,1])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
tf.io.write_graph(sess.graph_def, '', 'train.pbtxt')
converter = tf.lite.TFLiteConverter.from_session(sess, [initial_input],
[cap_iT])
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
open('converted_model.tflite', "wb").write(tflite_model)
sess.close()
以前认为TFLite支持slice,但只有strided_slice支持。
https://stackoverflow.com/questions/58069572
复制相似问题