pytorch中读取模型权重数据、保存数据方法总结

pytorch中保存数据策略在长时间的深度训练中有很大的作用,我们可以通过保存训练好的权重,然后等到下次使用的时候再取出来。另外我们也可以通过迁移学习使用别人训练好的数据进行训练。达到事半功百的效果。

pytorch保存数据

pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式。而在keras中则是使用.h5文件。

# 保存模型示例代码
print('===> Saving models...')
state = {
    'state': model.state_dict(),
    'epoch': epoch                   # 将epoch一并保存
}
if not os.path.isdir('checkpoint'):
    os.mkdir('checkpoint')
torch.save(state, './checkpoint/autoencoder.t7')

保存用到torch.save函数,注意该函数第一个参数可以是单个值也可以是字典,字典可以存更多你要保存的参数(不仅仅是权重数据)。

pytorch读取数据

pytorch读取数据使用的方法和我们平时使用预训练参数所用的方法是一样的,都是使用load_state_dict这个函数。

下方的代码和上方的保存代码可以搭配使用。

print('===> Try resume from checkpoint')
if os.path.isdir('checkpoint'):
    try:
        checkpoint = torch.load('./checkpoint/autoencoder.t7')
        model.load_state_dict(checkpoint['state'])        # 从字典中依次读取
        start_epoch = checkpoint['epoch']
        print('===> Load last checkpoint data')
    except FileNotFoundError:
        print('Can\'t found autoencoder.t7')
else:
    start_epoch = 0
    print('===> Start from scratch')

以上是pytorch读取的方法汇总,但是要注意,在使用官方的预处理模型进行读取时,一般使用的格式是pth,使用官方的模型读取命令会检查你模型的格式是否正确,如果不是使用官方提供模型通过下面的函数强行读取模型(将其他模型例如caffe模型转过来的模型放到指定目录下)会发生错误。

def vgg19(pretrained=False, **kwargs):
    """VGG 19-layer model (configuration "E")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['E']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
    return model

假如我们有从caffe模型转过来的pytorch模型([0-255,BGR]),我们可以使用:

model_dir = '自己的模型地址'
model = VGG()
model.load_state_dict(torch.load(model_dir + 'vgg_conv.pth'))

也就是pytorch的读取函数进行读取即可。

此文由腾讯云爬虫爬取,文章来源于Oldpan博客

欢迎关注Oldpan博客公众号,持续酝酿深度学习质量文:

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏君赏技术博客

【已解决】mas_updateConstraints更新约束引起的约束冲突

79DF33F7-1626-4FFD-9B01-1B115617F764

1832
来自专栏王小雷

降低PNG图片存储大小方法、图片压缩方法

降低PNG图片存储大小方法,图片压缩方法,如何降低PNG图片存储大小?前提是分辨率和尺寸大小不变,图形的透明部分不变。请看如下办法,亲测可用。 1. 将PNG图...

2618
来自专栏深度学习思考者

Caffe for windows 训练minst和cifar-10

前言 今天刚看了看Theano,本来打算试着改改,再深入学习学习(这就是所谓的深度学习啊,亲们!哈哈)。由于扣扣好友在训练CIFAR-10数据集,所以就在昨天配...

2466
来自专栏人工智能LeadAI

TensorFlow从0到1 | 第十七章 Step By Step上手TensorBoard

上一篇 16 L2正则化对抗“过拟合 ”提到,为了检测训练过程中发生的过拟合,需要记录每次迭代(甚至每次step)模型在训练集和验证集上的识别精度。其实,为了能...

4657
来自专栏梦里茶室

TensorFlow 深度学习笔记 逻辑回归 实践篇

Practical Aspects of Learning Install Ipython NoteBook 可以参考这个教程 可以直接安装anaconda,里...

2017
来自专栏MelonTeam专栏

Windows下Tensorflow的环境安装和Hello world

导语 深度学习 数字识别 Tensorflow 环境搭建 本人对深度学习是0基础,python也是没有用过,最近很流行深度学习,因此也想学习一下,...

2295
来自专栏TensorFlow从0到N

TensorFlow从0到1 - 17 - Step By Step上手TensorBoard

上一篇16 L2正则化对抗“过拟合”提到,为了检测训练过程中发生的过拟合,需要记录每次迭代(甚至每次step)模型在训练集和验证集上的识别精度。其实,为了能更...

3618
来自专栏程序生活

TensorFlow教程(十二) 随机数实例

tf.random_uniform([4,4], minval=-10,maxval=10,dtype=tf.float32)))返回4*4的矩阵,产生于-10...

781
来自专栏AI研习社

一个应用于物体识别的迁移学习工具链

迁移学习指的是,通过对预训练模型的参数进行微调,将训练好的模型应用到相似或者只有细微差异的不同任务中。通过这个方法,我们可以基于一些性能顶尖的深度学习模型得到别...

1142
来自专栏AIUAI

GPU 显存 - Caffe 内存优化

4286

扫码关注云+社区