首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >根据层次结构将pytorch模型参数导出到不同的文件中

根据层次结构将pytorch模型参数导出到不同的文件中
EN

Stack Overflow用户
提问于 2020-08-19 23:32:56
回答 1查看 96关注 0票数 1

是否可以将Pytorch模型的训练参数导出到文件夹层次结构下的单独二进制文件(float32/64,而不是文本)中,以反映模型的体系结构定义的层?

我希望在没有框架开销的情况下检查一个相当大的预训练模型,并将检查点拆分成可管理的块。

EN

回答 1

Stack Overflow用户

发布于 2020-08-20 00:03:43

没有直接的方法可以做到这一点,但它应该只需要几行代码。例如,假设我有一个具有以下结构的模型:

代码语言:javascript
运行
复制
class ConvBlock(nn.Module):
    def __init__(self, C_in, C_out, kernel, pool):
        super().__init__()
        self.conv = nn.Conv2d(C_in, C_out, kernel)
        self.relu = nn.ReLU(inplace = True)
        self.pool = nn.MaxPool2d(2,2) if pool else nn.Identity()
        
    def forward(self, input):
        out = self.conv(input)
        out = self.relu(out)
        out = self.pool(out)
        return out
    
class LeNet5(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = ConvBlock(1, 6, 5, pool = True)
        self.block2 = ConvBlock(6, 16, 5, pool = True)
        self.block3 = ConvBlock(16, 120, 5, pool = False)
        self.fc = nn.Sequential(
            nn.Linear(120, 84),
            nn.ReLU(inplace = True),
            nn.Linear(84, 10)
        )
        
    def forward(self, input):
        out = self.block1(input)
        out = self.block2(out)
        out = self.block3(out)
        out = out.view(-1,120)
        out = self.fc(out)
        return out

要对各个参数进行二进制化,您需要做的就是遍历它们。

代码语言:javascript
运行
复制
net = LeNet5()

basedir = 'lenet_params'

for name, param in net.named_parameters():
    name = name.split('.')
    out_dir, filename = os.path.join(basedir, *name[:-1]), name[-1]+'.pth'
    out_path = os.path.join(out_dir, filename)

    if not os.path.exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)
    torch.save(param, out_path)

这将产生如下的目录结构:

代码语言:javascript
运行
复制
lenet_params
|---block1
|   |---conv
|   |   |---weight.pth
|   |   |---bias.pth
|---block2
|   |---conv
|   |   |---weight.pth
|   |   |---bias.pth
|---block3
|   |---conv
|   |   |---weight.pth
|   |   |---bias.pth
|---fc
|   |---0
|   |   |---weight.pth
|   |   |---bias.pth
|   |---2
|   |   |---weight.pth
|   |   |---bias.pth
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63490419

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档