近日,PyTorch 社区发布了一个深度学习工具包 PyTorch Hub,PyTorch Hub 由一个预训练模型仓库组成,可用于提高研究工作的复现性以及新的研究。同时它还内置了对Google Colab的支持,并与Papers With Code集成。
由于之前安装的pytorch 1.0版本并没有该包, 本以为需要更新整个Pytorch包才可以使用。在更新之前,优先看一下hub代码,发现其是一个独立的python文件,没有任何关联代码,于是单独下载该文件,即可以实现对应功能。如下图所示。
由于没有集成到torch包中,需要单独引用hub包,然后进行加载。
从load源码可以看出,该实现功能也比较简单,需要对应先下载github中对应源码包,然后定位到相关模型,调用该模型,才可以实现模型的调用。
def load(github, model, *args, **kwargs):
r"""
Load a model from a github repo, with pretrained weights.
Args:
github: Required, a string with format "repo_owner/repo_name[:tag_name]" with an optional
tag/branch. The default branch is `master` if not specified.
Example: 'pytorch/vision[:hub]'
model: Required, a string of entrypoint name defined in repo's hubconf.py
*args: Optional, the corresponding args for callable `model`.
force_reload: Optional, whether to force a fresh download of github repo unconditionally.
Default is `False`.
**kwargs: Optional, the corresponding kwargs for callable `model`.
Returns:
a single model with corresponding pretrained weights.
Example:
>>> model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
"""
# Setup hub_dir to save downloaded files
_setup_hubdir()
force_reload = kwargs.get('force_reload', False)
kwargs.pop('force_reload', None)
repo_dir = _get_cache_or_reload(github, force_reload)
sys.path.insert(0, repo_dir)
hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)
entry = _load_entry_from_hubconf(hub_module, model)
model = entry(*args, **kwargs)
sys.path.remove(repo_dir)
return model
从第一个图的执行效果可以看出:要先下载整个github包,再调用模型。如下所示:
通过上述考虑,完全可以通过修改下载包地址到指定网址或者文件夹目录来实现离线hub的设置。
具体修改主要如下:
def _git_archive_link(repo_owner, repo_name, branch):
return 'https://github.com/{}/{}/archive/{}.zip'.format(repo_owner, repo_name, branch)