首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >在.pyth中读取PyTorch文件类型

在.pyth中读取PyTorch文件类型
EN

Stack Overflow用户
提问于 2022-07-21 11:06:08
回答 1查看 114关注 0票数 2

存储库中有预先训练过的模型,其文件类型为.pyth。我在网上搜索,找出这个文件类型和哪种语言能够读懂,但我什么也找不到。由于我正在使用PyTorch,是否可以在PyTorch中读取这样的文件?此外,通常情况下,如何才能阅读和生成这一点?

更清楚的是,在TimeSformer存储库模型中,预先训练过的模型都是这种文件类型,例如,您可以在存储库中找到以下命令:

代码语言:javascript
运行
复制
import torch
from timesformer.models.vit import TimeSformer

model = TimeSformer(img_size=224, num_classes=400, num_frames=8, attention_type='divided_space_time',  pretrained_model='/path/to/pretrained/model.pyth')

dummy_video = torch.randn(2, 3, 8, 224, 224) # (batch x channels x frames x height x width)

pred = model(dummy_video,) # (2, 400)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-07-21 23:40:02

文件扩展名可以是任何内容,它不会更改文件内容。如果您运行torch.load("file.pyth"),它将加载一个权重字典。您可以在您包括的回购代码中找到这一点。它们使用以下代码保存模型:

代码语言:javascript
运行
复制
path_to_checkpoint = get_path_to_checkpoint(path_to_job, epoch + 1)
with PathManager.open(path_to_checkpoint, "wb") as f:
    torch.save(checkpoint, f)

get_path_to_checkpoint函数可以找到这里

代码语言:javascript
运行
复制
def get_path_to_checkpoint(path_to_job, epoch):
    """
    Get the full path to a checkpoint file.
    Args:
        path_to_job (string): the path to the folder of the current job.
        epoch (int): the number of epoch for the checkpoint.
    """
    name = "checkpoint_epoch_{:05d}.pyth".format(epoch)
    return os.path.join(get_checkpoint_dir(path_to_job), name)

因此,他们只是将扩展名.pyth的文件名传递给torch.save

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73065101

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档