存储库中有预先训练过的模型,其文件类型为.pyth。我在网上搜索,找出这个文件类型和哪种语言能够读懂,但我什么也找不到。由于我正在使用PyTorch,是否可以在PyTorch中读取这样的文件?此外,通常情况下,如何才能阅读和生成这一点?
更清楚的是,在TimeSformer存储库模型中,预先训练过的模型都是这种文件类型,例如,您可以在存储库中找到以下命令:
import torch
from timesformer.models.vit import TimeSformer
model = TimeSformer(img_size=224, num_classes=400, num_frames=8, attention_type='divided_space_time', pretrained_model='/path/to/pretrained/model.pyth')
dummy_video = torch.randn(2, 3, 8, 224, 224) # (batch x channels x frames x height x width)
pred = model(dummy_video,) # (2, 400)发布于 2022-07-21 23:40:02
文件扩展名可以是任何内容,它不会更改文件内容。如果您运行torch.load("file.pyth"),它将加载一个权重字典。您可以在您包括的回购代码中找到这一点。它们使用以下代码保存模型:
path_to_checkpoint = get_path_to_checkpoint(path_to_job, epoch + 1)
with PathManager.open(path_to_checkpoint, "wb") as f:
torch.save(checkpoint, f)而get_path_to_checkpoint函数可以找到这里。
def get_path_to_checkpoint(path_to_job, epoch):
"""
Get the full path to a checkpoint file.
Args:
path_to_job (string): the path to the folder of the current job.
epoch (int): the number of epoch for the checkpoint.
"""
name = "checkpoint_epoch_{:05d}.pyth".format(epoch)
return os.path.join(get_checkpoint_dir(path_to_job), name)因此,他们只是将扩展名.pyth的文件名传递给torch.save。
https://stackoverflow.com/questions/73065101
复制相似问题