首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytroch Hub原理分析及离线hub设计考虑

Pytroch Hub原理分析及离线hub设计考虑

作者头像
sparkexpert
发布2019-07-02 18:09:33
1.3K0
发布2019-07-02 18:09:33
举报

近日,PyTorch 社区发布了一个深度学习工具包 PyTorch Hub,PyTorch Hub 由一个预训练模型仓库组成,可用于提高研究工作的复现性以及新的研究。同时它还内置了对Google Colab的支持,并与Papers With Code集成。

由于之前安装的pytorch 1.0版本并没有该包, 本以为需要更新整个Pytorch包才可以使用。在更新之前,优先看一下hub代码,发现其是一个独立的python文件,没有任何关联代码,于是单独下载该文件,即可以实现对应功能。如下图所示。

由于没有集成到torch包中,需要单独引用hub包,然后进行加载。

从load源码可以看出,该实现功能也比较简单,需要对应先下载github中对应源码包,然后定位到相关模型,调用该模型,才可以实现模型的调用。

def load(github, model, *args, **kwargs):
    r"""
    Load a model from a github repo, with pretrained weights.

    Args:
        github: Required, a string with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        model: Required, a string of entrypoint name defined in repo's hubconf.py
        *args: Optional, the corresponding args for callable `model`.
        force_reload: Optional, whether to force a fresh download of github repo unconditionally.
            Default is `False`.
        **kwargs: Optional, the corresponding kwargs for callable `model`.

    Returns:
        a single model with corresponding pretrained weights.

    Example:
        >>> model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
    """
    # Setup hub_dir to save downloaded files
    _setup_hubdir()

    force_reload = kwargs.get('force_reload', False)
    kwargs.pop('force_reload', None)

    repo_dir = _get_cache_or_reload(github, force_reload)

    sys.path.insert(0, repo_dir)

    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)

    entry = _load_entry_from_hubconf(hub_module, model)

    model = entry(*args, **kwargs)

    sys.path.remove(repo_dir)

    return model

从第一个图的执行效果可以看出:要先下载整个github包,再调用模型。如下所示:

通过上述考虑,完全可以通过修改下载包地址到指定网址或者文件夹目录来实现离线hub的设置。

具体修改主要如下:

def _git_archive_link(repo_owner, repo_name, branch):
    return 'https://github.com/{}/{}/archive/{}.zip'.format(repo_owner, repo_name, branch)
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019年06月14日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档