我已经使用python-tensorflow训练了一个模型,我想在java-tensorflow中进行推断。我已经将训练好的模型/图加载到Java中。在此之后,我想永久更新图中的一个变量。我知道python中的tf.variable.load(value,session)
函数可以用来更新变量的值。我想知道Java中是否有类似的方法。
到目前为止,我已经尝试了以下几种方法。
// g and s are loaded graphs and sessions respectively
s.runner().feed(variableName,updatedTensorValue)
但上面的代码行仅在同一行中执行的fetch
调用期间对variableName
使用updatedTensorValue
。
g.opBuilder("Assign",variableName).setAttr("value",updatedTensorValue).build();
上面这行代码不是更新值,而是试图将相同的变量添加到图中,因此抛出了一个异常。
除了永久更新图形中的变量之外,我还可以在所有fetch
调用期间始终调用feed(variableName,updatedTensorValue)
方法。我将在几个实例上运行推理代码,所以我想知道这个额外的feed
调用需要额外的时间。
谢谢
发布于 2018-04-17 09:13:22
在TensorFlow中做大多数事情的方法是执行一个操作。您尝试运行Assign
操作是正确的,但是错误地调用了它,因为要分配的value
不是Assign
操作的“属性”,而是一个输入张量。(请参阅原始的definition of the operation,但无可否认,除非您熟悉TensorFlow的内部结构,否则很难理解该定义)。
但是,您不需要在Java中向图形中添加操作来完成此操作。相反,您可以做Python语言中tf.Variable.load
所做的事情--执行tf.Variable.initializer
操作,输入输入值。
例如,考虑使用Python构建的以下图表:
import tensorflow as tf
var = tf.Variable(1.0, name='myvar')
init = tf.global_variables_initializer()
# Save the graph and write out the names of the operations of interest
tf.train.write_graph(tf.get_default_graph(), '/tmp', 'graph.pb', as_text=False)
print('Init all variables: ', init.name)
print('myvar.initializer: ', var.initializer.name)
print('myvar.initializer.inputs[1]:', var.initializer.inputs[1].name)
现在,我们在Java中复制Python var.load()
的行为,将值3.0赋值给变量,如下所示:
try (Tensor<Float> newValue = Tensors.create(3.0f)) {
s.runner()
.feed("myvar/initial_value", newVal) // myvar.initializer.inputs[1].name
.addTarget("myvar/Assign") // myvar.initializer.name
.run();
}
希望这能有所帮助。
https://stackoverflow.com/questions/49801711
复制