首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >TensorFlow:如何保存/恢复模型?

TensorFlow:如何保存/恢复模型?
EN

Stack Overflow用户
提问于 2018-06-21 02:56:54
回答 2查看 0关注 0票数 0

当你在坦索弗洛训练了一个模特之后:

  1. 你是如何拯救受过训练的模特的?
  2. 稍后如何恢复保存的模型?
EN

回答 2

Stack Overflow用户

发布于 2018-06-21 11:31:44

我的环境:Python3.6,TensorFlow 1.3.0

虽然有很多解决方案,但大多数都是基于tf.train.Saver。当我们加载一个.ckpt保存下来Saver,我们要么重新定义TensorFlow网络,要么使用一些奇怪的、难以记住的名称。'placehold_0:0''dense/Adam/Weight:0'

保存模型:

代码语言:txt
复制
import tensorflow as tf

# define the tensorflow network and do some trains
x = tf.placeholder("float", name="x")
w = tf.Variable(2.0, name="w")
b = tf.Variable(0.0, name="bias")

h = tf.multiply(x, w)
y = tf.add(h, b, name="y")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# save the model
export_path =  './savedmodel'
builder = tf.saved_model.builder.SavedModelBuilder(export_path)

tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

prediction_signature = (
  tf.saved_model.signature_def_utils.build_signature_def(
      inputs={'x_input': tensor_info_x},
      outputs={'y_output': tensor_info_y},
      method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

builder.add_meta_graph_and_variables(
  sess, [tf.saved_model.tag_constants.SERVING],
  signature_def_map={
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          prediction_signature 
  },
  )
builder.save()

加载模型:

代码语言:txt
复制
import tensorflow as tf
sess=tf.Session() 
signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
input_key = 'x_input'
output_key = 'y_output'

export_path =  './savedmodel'
meta_graph_def = tf.saved_model.loader.load(
           sess,
          [tf.saved_model.tag_constants.SERVING],
          export_path)
signature = meta_graph_def.signature_def

x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name

x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)

y_out = sess.run(y, {x: 3.0})
票数 0
EN

Stack Overflow用户

发布于 2018-06-21 12:42:35

你也可以采取更简单的方式。

步骤1:初始化所有变量

代码语言:txt
复制
W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

Similarly, W2, B2, W3, .....

步骤2:在模型中保存会话Saver保存它

代码语言:txt
复制
model_saver = tf.train.Saver()

# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")

步骤3:恢复模型

代码语言:txt
复制
with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Model restored.") 
    print('Initialized')

步骤4:检查变量

代码语言:txt
复制
W1 = session.run(W1)
print(W1)

在不同python实例中运行时,请使用

代码语言:txt
复制
with tf.Session() as sess:
    # Restore latest checkpoint
    saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))

    # Initalize the variables
    sess.run(tf.global_variables_initializer())

    # Get default graph (supply your custom graph if you have one)
    graph = tf.get_default_graph()

    # It will give tensor object
    W1 = graph.get_tensor_by_name('W1:0')

    # To get the value (numpy array)
    W1_value = session.run(W1)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/-100005462

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档