我在恢复我用辍学训练过的TF模型时遇到了问题。如何将keep_prob设置为1.0
下面我尝试过的代码不起作用,我认为这是因为我在恢复模型时创建了一个新的tf.placeholder。但是如何恢复keep_prob占位符呢?
这是我的恢复代码
import tensorflow as tf
import numpy as np
logs_path = ...
def readImage(filenames):
filenameQ = tf.train.string_input_producer(filenames, shuffle=False)
reader = tf.WholeFileReader() # Magic function
key, value = reader.read(filenameQ)
image = tf.image.decode_png(value)
image.set_shape([101, 201, 1])
return image
image = readImage([("../image-to-tfrecords/train/chef/chef%d.png" % i) for i in range(5000)])
merged_summary_op = tf.summary.merge_all()
class CNN:
"""
Class to load saved CNN
"""
def __init__(self, model_file, imgsize=None, visualize=True, saver=None, batch_size=100):
self.model_file = model_file
self.saver = saver
self.batch_size = batch_size
if imgsize:
self.img_h = imgsize[0]
self.img_w = imgsize[1]
def predict(self, X):
# OUTCOMMENTED THIS LINE:
#keep_prob = tf.placeholder(tf.float32) # dropout (keep probability)
"""
Prediction Routine
"""
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
train_writer = tf.summary.FileWriter(logs_path + '/train', sess.graph)
graph = tf.get_default_graph()
# restore the model
self.saver = tf.train.import_meta_graph(self.model_file) #, input_map={"keep_prob_training:0": keep_prob}, import_scope='imported'
self.saver.restore(sess, tf.train.latest_checkpoint('./tfmodels/cnn/'))
x, y = tf.get_collection('inputs')
# ADDED THE FOLLOWING LINE:
keep_prob = tf.get_collection('dropout_train')[0]
logits, predict_op = tf.get_collection('outputs')
probs = tf.nn.softmax(logits)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
predictions = []
train_writer.add_graph(sess.graph)
for i in range(5000):
batch_xs = sess.run(X)
# Reshape batch_xs if only a single image is given
# (numpy is 4D: batch_size * heigth * width * channels)
batch_xs = np.reshape(batch_xs, (-1, self.img_w * self.img_h))
prediction = sess.run([predict_op], feed_dict={x: batch_xs, keep_prob: 1.0})
predictions.append(prediction[0][0])
train_writer.close()
# finalize
coord.request_stop()
coord.join(threads)
return predictions
@staticmethod
def load(model_file, imgsize=[201, 101]):
""" Load TF metagraph """
print "Loading Model from: " + model_file
return FNN(model_file, imgsize)
def main():
""" Main """
# Load and predict
model = CNN.load("tfmodels/cnn/tf.model.meta")
model.predict(image)
if __name__ == '__main__':
main()更新
下面是张卡的图表。我比较了保存程序和恢复程序的图表,图是相同的:)



发布于 2017-03-21 16:35:09
你正确地观察到了这个问题。keep_prob = tf.placeholder(tf.float32)张量与你用tf.train.import_meta_graph()导入的图不相连,所以给张量喂食对推理没有影响。
解决方案将取决于您如何构建初始模型。您首先需要识别张量的name,该张量在原始图中用作keep_prob。例如,如果您使用以下语句在原始图的顶层创建它:
keep_prob = tf.placeholder(tf.float32, name="keep_prob_training")...the的名字应该是"keep_prob_training:0"。但是,如果您没有传递一个显式的name参数,那么名称将类似于"Placeholder:0"、"Placeholder_1:0"等。最可靠的方法是在原始程序中传递给print(keep_prob.name)。
一旦您有了这个名称(为了具体起见,我假设它是"keep_prob_training:0" ),您需要对tf.train.import_meta_graph()调用进行简单的修改,以便设置input_map并将新的keep_prob张量连接到导入的图。下列措施应能发挥作用:
self.saver = tf.train.import_meta_graph(
self.model_file, input_map={"keep_prob_training:0": keep_prob})在您这样做之后,输入keep_prob张量将允许您控制在推理时应用的辍学。
https://stackoverflow.com/questions/42912234
复制相似问题