首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >在Tensorflow中恢复保存的神经网络

在Tensorflow中恢复保存的神经网络
EN

Stack Overflow用户
提问于 2019-01-02 11:52:44
回答 1查看 339关注 0票数 1

在把我的问题标为重复之前,我想让大家明白,我已经经历了很多问题,但没有一个解决办法能够消除我的疑虑,解决我的问题。我有一个经过训练的神经网络,我想要保存,然后用这个模型来测试这个模型对测试数据集。

我试着保存和恢复它,但我没有得到预期的结果。恢复似乎不起作用,也许我错误地使用了它,它只是使用全局变量初始化器给出的值。

这是我用来保存模型的代码。

代码语言:javascript
运行
复制
 sess.run(tf.initializers.global_variables())
#num_epochs = 7
for epoch in range(num_epochs):  
  start_time = time.time()
  train_accuracy = 0
  train_loss = 0
  val_loss = 0
  val_accuracy = 0

  for bid in range(int(train_data_size/batch_size)):
     X_train_batch = X_train[bid*batch_size:(bid+1)*batch_size]
     y_train_batch = y_train[bid*batch_size:(bid+1)*batch_size]
     sess.run(optimizer, feed_dict = {x:X_train_batch, y:y_train_batch,prob:0.50})  

     train_accuracy = train_accuracy + sess.run(model_accuracy, feed_dict={x : X_train_batch,y:y_train_batch,prob:0.50})
     train_loss = train_loss + sess.run(loss_value, feed_dict={x : X_train_batch,y:y_train_batch,prob:0.50})

  for bid in range(int(val_data_size/batch_size)):
     X_val_batch = X_val[bid*batch_size:(bid+1)*batch_size]
     y_val_batch = y_val[bid*batch_size:(bid+1)*batch_size]
     val_accuracy = val_accuracy + sess.run(model_accuracy,feed_dict = {x:X_val_batch, y:y_val_batch,prob:0.75})
     val_loss = val_loss + sess.run(loss_value, feed_dict = {x:X_val_batch, y:y_val_batch,prob:0.75})

  train_accuracy = train_accuracy/int(train_data_size/batch_size)
  val_accuracy = val_accuracy/int(val_data_size/batch_size)
  train_loss = train_loss/int(train_data_size/batch_size)
  val_loss = val_loss/int(val_data_size/batch_size)


  end_time = time.time()


  saver.save(sess,'./blood_model_x_v2',global_step = epoch)  

保存模型后,这些文件会写在我的工作目录中,如下所示。

血模型x_v2-2数据-0000-0001

代码语言:javascript
运行
复制
blood\_model\_x\_v2-2.index
代码语言:javascript
运行
复制
blood\_model\_x\_v2-2.meta

类似地,v2-3,等到v2-6,然后是一个‘检查点’文件。然后,我尝试使用这个代码片段(初始化后)恢复它,但是得到了与预期的结果不同的结果。我做错什么了?

代码语言:javascript
运行
复制
saver = tf.train.import_meta_graph('blood_model_x_v2-5.meta')
saver.restore(test_session,tf.train.latest_checkpoint('./'))
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-01-02 12:36:18

根据tensorflow博士

还原还原以前保存的变量。 此方法运行由构造函数添加的用于还原变量的操作。它需要一个启动图形的会话。要还原的变量不必已初始化,因为还原本身就是初始化变量的一种方法。

让我们看看一个例子:

我们保存与此类似的模型:

代码语言:javascript
运行
复制
import tensorflow as tf

# Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

# Define a test operation that we will restore
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a saver object which will save all the variables
saver = tf.train.Saver()

# Run the operation by feeding input
print (sess.run(w4, feed_dict))
# Prints 24 which is sum of (w1+w2)*b1

# Now, save the graph
saver.save(sess, './ckpnt/my_test_model', global_step=1000)

然后加载经过训练的模型:

代码语言:javascript
运行
复制
import tensorflow as tf

sess = tf.Session()
# First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('./ckpnt/my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./ckpnt'))

# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print (sess.run(op_to_restore, feed_dict))
# This will print 60 which is calculated
# using new values of w1 and w2 and saved value of b1.

如您所见,我们不会在还原部分初始化会话。使用Checkpoint保存和恢复模型有更好的方法,它允许您检查模型是否被正确还原。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/54005871

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档