前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【动手学深度学习笔记】之读取和存储

【动手学深度学习笔记】之读取和存储

作者头像
树枝990
发布2020-08-19 09:48:47
4560
发布2020-08-19 09:48:47
举报
文章被收录于专栏:拇指笔记拇指笔记

1. 读取和存储

在训练好模型后,有时需要把训练好的模型参数存储起来以供后续使用。

1.1 读写Tensor

存储和读取Tensor可以分别使用save函数和load函数实现。save函数的操作对象包括模型、张量和字典等。

首先创建两个Tensor

代码语言:javascript
复制
import torch

x = torch.ones(3)
y = torch.zeros(4)

读取和存储单个Tensor实例:

代码语言:javascript
复制
torch.save(x,'x.pt')
x2 = torch.load('x.pt')
print(x2)

Out[1]:

tensor([1., 1., 1.])

读取和存储一个Tensor列表实例:

代码语言:javascript
复制
torch.save([x,y],'xy.pt')
xy_list = torch.load('xy.pt')
print(xy_list)

Out[1]:

[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]

读取和存储一个Tensor字典实例:

代码语言:javascript
复制
torch.save({'x': x, 'y': y}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
print(xy)

Out[1]:

{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}

1.2 读写模型

优化器和具有可学习参数的层的参数名称和参数被存储在state_dict

下面以实例调用state_dict来显示模型参数和名称。

代码语言:javascript
复制
class MLP(nn.Module):
    def __init__(self):
        super(MLP,self).__init__()
        self.hidden = nn.Linear(3,2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2,1)
        
    def forward(self,x):
        a = self.act(self.hidden(x))
        return self.output(a)
    
net = MLP()
net.state_dict()

Out[1]:
    
OrderedDict([('hidden.weight',
              tensor([[-0.3303, -0.2529, -0.4268],
                      [ 0.4672, -0.2530, -0.0974]])),
             ('hidden.bias', tensor([-0.1994, -0.2971])),
             ('output.weight', tensor([[-0.3032, -0.0526]])),
             ('output.bias', tensor([0.5046]))])

下面以实例调用state_dict来显示优化器状态和超参数。

代码语言:javascript
复制
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer.state_dict()

Out[1]:

{'state': {},
 'param_groups': [{'lr': 0.001,
   'momentum': 0.9,
   'dampening': 0,
   'weight_decay': 0,
   'nesterov': False,
   'params': [1967586368392, 1967586368472, 1967586368632, 196758368712]}]}

可以通过保存和加载模型参数(state_dict)来实现保存和加载模型。

同样通过实例来显示整个过程。

代码语言:javascript
复制
#保存:
torch.save(net.state_dict(),'K:\sd.pt')

#读取:
net1 = MLP()
net1.load_state_dict(torch.load('K:\sd.pt'))

Out[1]:
    <All keys matched successfully>

也可以直接存储和读取整个模型。

代码语言:javascript
复制
#存储
torch.save(net,'K:\sd1.pt')

#读取
net2 = torch.load('K:\sd1.pt')

通过这两种方法保存和读取的模型具有相同的模型参数,因此他们的正向传播结果是相同的。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-03-18,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 拇指笔记 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 读取和存储
    • 1.1 读写Tensor
      • 1.2 读写模型
      相关产品与服务
      对象存储
      对象存储(Cloud Object Storage,COS)是由腾讯云推出的无目录层次结构、无数据格式限制,可容纳海量数据且支持 HTTP/HTTPS 协议访问的分布式存储服务。腾讯云 COS 的存储桶空间无容量上限,无需分区管理,适用于 CDN 数据分发、数据万象处理或大数据计算与分析的数据湖等多种场景。
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档