前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >tensorflow的基本用法——保存神经网络参数和加载神经网络参数

tensorflow的基本用法——保存神经网络参数和加载神经网络参数

作者头像
Tyan
发布2019-05-25 23:16:28
1.3K0
发布2019-05-25 23:16:28
举报
文章被收录于专栏:SnailTyan

本文主要是使用tensorfl保存神经网络参数和加载神经网络参数。

代码语言:javascript
复制
#!/usr/bin/env python
# _*_ coding: utf-8 _*_

import tensorflow as tf
import numpy as np


# 保存神经网络参数
def save_para():
    # 定义权重参数
    W = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype = tf.float32, name = 'weights')
    # 定义偏置参数
    b = tf.Variable([[1, 2, 3]], dtype = tf.float32, name = 'biases')
    # 参数初始化
    init = tf.global_variables_initializer()
    # 定义保存参数的saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init)
        # 保存session中的数据
        save_path = saver.save(sess, 'my_net/save_net.ckpt')
        # 输出保存路径
        print 'Save to path: ', save_path

# 恢复神经网络参数
def restore_para():
    # 定义权重参数
    W = tf.Variable(np.arange(6).reshape((2, 3)), dtype = tf.float32, name = 'weights')
    # 定义偏置参数
    b = tf.Variable(np.arange(3).reshape((1, 3)), dtype = tf.float32, name = 'biases')
    # 定义提取参数的saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # 加载文件中的参数数据,会根据name加载数据并保存到变量W和b中
        save_path = saver.restore(sess, 'my_net/save_net.ckpt')
        # 输出保存路径
        print 'Weights: ', sess.run(W)
        print 'biases:  ', sess.run(b)


# save_para()
restore_para()

执行结果如下:

代码语言:javascript
复制
# save
Save to path:  my_net/save_net.ckpt


# restore
Weights:  [[ 1.  2.  3.]
 [ 4.  5.  6.]]
biases:   [[ 1.  2.  3.]]

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2017年04月20日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档