前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >MxNet预训练模型到Pytorch模型的转换

MxNet预训练模型到Pytorch模型的转换

作者头像
sparkexpert
发布2019-05-26 14:05:06
2.2K0
发布2019-05-26 14:05:06
举报

预训练模型在不同深度学习框架中的转换是一种常见的任务。今天刚好DPN预训练模型转换问题,顺手将这个过程记录一下。

核心转换函数如下所示:

代码语言:javascript
复制
def convert_from_mxnet(model, checkpoint_prefix, debug=False):
    _, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0)
    remapped_state = {}
    for state_key in model.state_dict().keys():
        k = state_key.split('.')
        aux = False
        mxnet_key = ''
        if k[0] == 'features':
            if k[1] == 'conv1_1':
                # input block
                mxnet_key += 'conv1_x_1__'
                if k[2] == 'bn':
                    mxnet_key += 'relu-sp__bn_'
                    aux, key_add = _convert_bn(k[3])
                    mxnet_key += key_add
                else:
                    assert k[3] == 'weight'
                    mxnet_key += 'conv_' + k[3]
            elif k[1] == 'conv5_bn_ac':
                # bn + ac at end of features block
                mxnet_key += 'conv5_x_x__relu-sp__bn_'
                assert k[2] == 'bn'
                aux, key_add = _convert_bn(k[3])
                mxnet_key += key_add
            else:
                # middle blocks
                if model.b and 'c1x1_c' in k[2]:
                    bc_block = True  # b-variant split c-block special treatment
                else:
                    bc_block = False
                ck = k[1].split('_')
                mxnet_key += ck[0] + '_x__' + ck[1] + '_'
                ck = k[2].split('_')
                mxnet_key += ck[0] + '-' + ck[1]
                if ck[1] == 'w' and len(ck) > 2:
                    mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)'
                mxnet_key += '__'
                if k[3] == 'bn':
                    mxnet_key += 'bn_' if bc_block else 'bn__bn_'
                    aux, key_add = _convert_bn(k[4])
                    mxnet_key += key_add
                else:
                    ki = 3 if bc_block else 4
                    assert k[ki] == 'weight'
                    mxnet_key += 'conv_' + k[ki]
        elif k[0] == 'classifier':
            if 'fc6-1k_weight' in mxnet_weights:
                mxnet_key += 'fc6-1k_'
            else:
                mxnet_key += 'fc6_'
            mxnet_key += k[1]
        else:
            assert False, 'Unexpected token'

        if debug:
            print(mxnet_key, '=> ', state_key, end=' ')

        mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key]
        torch_tensor = torch.from_numpy(mxnet_array.asnumpy())
        if k[0] == 'classifier' and k[1] == 'weight':
            torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1))
        remapped_state[state_key] = torch_tensor

        if debug:
            print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std())

    model.load_state_dict(remapped_state)

    return model

从中可以看出,其转换步骤如下:

(1)创建pytorch的网络结构模型,设为model

(2)利用mxnet来读取其存储的预训练模型,得到mxnet_weights;

(3)遍历加载后模型mxnet_weights的state_dict().keys

(4)对一些指定的key值,需要进行相应的处理和转换

(5)对修改键名之后的key利用numpy之间的转换来实现加载。

为了实现上述转换,首先pip安装mxnet,现在新版的mxnet安装还是非常方便的。

第二步,运行转换程序,实现预训练模型的转换。

可以看到在相当的文件夹下已经出现了转换后的模型。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018年06月28日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档