Tensorflow教程-保存和加载模型

“ 再小的你,也有自己的个性”

通过阅读本文,您可以学会:

1、如何保存训练好的模型

2、如何加载训练好的模型

源代码:

https://github.com/PrivateThink/tensorflow_tutorial

在前面的教程,学习了如何训练模型,如果数据量大的话,是要花很长时间来训练的,重复训练即浪费时间又浪费资源,所以在很多机器学习任务中,会将训练好的模型保存下来,下次直接加载保存的模型。

数据准备和模型创建

上述程序,从文件中读取数据,构建线性模型,利用交叉熵损失函数,采用Adam优化器,最后计算准确率。

Tensorflow中用tf.train.Saver来声明保存训练好模型的操作。初始化以后就可以进行训练和测试了。

上述程序中training_enpochs=10,只训练10次,每两次打印一次训练结果。训练完以后,就可以用saver.save保存模型

第一次训练结果

经过第一次训练,测试准确率可达0.958.

保存模型很简单,同样加载模型也很简单。

上述将训练次数training_epochs改为20,继续迭代,用saver.restore就可以加载模型了,然后继续训练。

第二次训练结果

第二次的预测准确率为0.975,比第一次预测的准确要好。今天的教程就讲到了。

后续持续更新Tensorflow教程

欢迎关注和分享

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180627G01OCM00?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 yunjia_community@tencent.com 删除。

扫码关注云+社区

领取腾讯云代金券