首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >无法从检查点正确加载PyTorch闪电模型

无法从检查点正确加载PyTorch闪电模型
EN

Stack Overflow用户
提问于 2022-03-18 13:52:00
回答 1查看 713关注 0票数 1

我曾训练过以下班的火炬闪电模型:

代码语言:javascript
运行
复制
    class LSTMClassifier(pl.LightningModule):
    def __init__(self, n_features, hidden_size, batch_size, num_layers, dropout, learning_rate):
        super(LSTMClassifier, self).__init__()
        self.save_hyperparameters()

        # Params
        self.n_features = n_features
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.learning_rate = learning_rate

        # Architecture Baseline
        self.lstm = nn.LSTM(input_size=n_features,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            dropout=dropout,
                            batch_first=True)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(hidden_size, 2)
        self.sigmoid = nn.Sigmoid()

当我在训练后直接调用trainer.test()函数时,它给我的测试集精度为0.76:

代码语言:javascript
运行
复制
    # Init PyTorch model
    model = LSTMClassifier(
         n_features=p['n_features'],
         hidden_size=p['hidden_size'],
         batch_size=p['batch_size'],
         num_layers=p['num_layers'],
         dropout=p['dropout'],
         learning_rate=p['learning_rate']
    )

    model_checkpoint = ModelCheckpoint(
        filename='[PATH.ckpt]'
    )

    # Trainer GPU
    trainer = Trainer(max_epochs=p['max_epochs'], callbacks=[model_checkpoint], gpus=int(GPU))

    trainer.fit(model, dm)

    trainer.test(model, test_dataloaders=dm.test_dataloader())

但是,当我稍后用完全相同的数据设备加载检查点时,它给我的精度为0.48:

代码语言:javascript
运行
复制
    model_checkpoint = ModelCheckpoint(
        filename='LSTM-batch-{batch_size}-epoch-{max_epochs}-hidden-{hidden_size}-layers-{'
                 'num_layers}-dropout-{dropout}-lr-{learning_rate}'.format(**p)
    )

    # Trainer GPU
    trainer = Trainer(max_epochs=p['max_epochs'], callbacks=[model_checkpoint], gpus=int(GPU))

    model = LSTMClassifier.load_from_checkpoint([PATH TO CHECKPOINT])
    model.eval()

    trainer.test(model, test_dataloaders=dm.test_dataloader())

我怀疑这个模型没有正确加载,但我不知道该做什么不同。有什么想法吗?

使用PyTorch闪电1.4.4

EN

回答 1

Stack Overflow用户

发布于 2022-03-30 13:22:19

事实证明,trainer.test(model, test_dataloaders=dm.test_dataloader())才是问题所在。一旦我按照更新的trainer.test(model, datamodule=dm)替换了它,它就能工作了。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71528079

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档