专栏首页数据分析与挖掘【猫狗数据集】加载保存的模型进行测试

【猫狗数据集】加载保存的模型进行测试

已重新上传好数据集:

分割线-----------------------------------------------------------------

2020.3.10

发现数据集没有完整的上传到谷歌的colab上去,我说怎么计算出来的step不对劲。

测试集是完整的。

训练集中cat的确是有10125张图片,而dog只有1973张,所以完成一个epoch需要迭代的次数为:

(10125+1973)/128=94.515625,约等于95。

顺便提一下,有两种方式可以计算出数据集的量:

第一种:print(len(train_dataset))

第二种:在../dog目录下,输入ls | wc -c

今天重新上传dog数据集。

分割线-----------------------------------------------------------------

数据集下载地址:

链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw 提取码:2xq4

创建数据集:https://www.cnblogs.com/xiximayou/p/12398285.html

读取数据集:https://www.cnblogs.com/xiximayou/p/12422827.html

进行训练:https://www.cnblogs.com/xiximayou/p/12448300.html

保存模型并继续进行训练:https://www.cnblogs.com/xiximayou/p/12452624.html

epoch、batchsize、step之间的关系:https://www.cnblogs.com/xiximayou/p/12405485.html

我们在test目录下新建一个文件test.py

test.py

import sys
sys.path.append("/content/drive/My Drive/colab notebooks")
from utils import rdata
from model import resnet
import torch.nn as nn
import torch
import numpy as np
import torchvision


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_loader,test_loader,train_data,test_data=rdata.load_dataset()
model =torchvision.models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features,2,bias=False)
model.cuda()
#print(model) 

save_path="/content/drive/My Drive/colab notebooks/output/dogcat-resnet18.t7" 
checkpoint = torch.load(save_path)
model.load_state_dict(checkpoint['model'])
start_epoch = checkpoint['epoch']
start_loss=checkpoint["train_loss"]
start_acc=checkpoint["train_acc"]
print("当前epoch:{} 当前训练损失:{:.4f} 当前训练准确率:{:.4f}".format(start_epoch+1,start_loss,start_acc))

num_epochs=1
criterion=nn.CrossEntropyLoss()

# Train the model
total_step = len(test_loader)
def test():
  for epoch in range(num_epochs):
      tot_loss = 0.0
      correct = 0
      for i ,(images, labels) in enumerate(test_loader):
          images = images.cuda()
          labels = labels.cuda()

          # Forward pass
          outputs = model(images)
          _, preds = torch.max(outputs.data,1)
          loss = criterion(outputs, labels)
          tot_loss += loss.data
          correct += torch.sum(preds == labels.data).to(torch.float32)
          if (i+1) % 2 == 0:
              print('Epoch: [{}/{}], Step: [{}/{}], Loss: {:.4f}'
                    .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
      ### Epoch info ####
      epoch_loss = tot_loss/len(test_data)
      print('test loss: {:.4f}'.format(epoch_loss))
      epoch_acc = correct/len(test_data)
      print('test acc: {:.4f}'.format(epoch_acc))
with torch.no_grad():
  test()

需要注意,测试的时候我们不需要进行反向传播更新参数。

结果:

当前epoch:2 当前训练损失:0.0037 当前训练准确率:0.8349
Epoch: [1/1], Step: [2/38], Loss: 1.0218
Epoch: [1/1], Step: [4/38], Loss: 0.9890
Epoch: [1/1], Step: [6/38], Loss: 0.9255
Epoch: [1/1], Step: [8/38], Loss: 0.9305
Epoch: [1/1], Step: [10/38], Loss: 0.9013
Epoch: [1/1], Step: [12/38], Loss: 1.0436
Epoch: [1/1], Step: [14/38], Loss: 0.8102
Epoch: [1/1], Step: [16/38], Loss: 0.9356
Epoch: [1/1], Step: [18/38], Loss: 0.8668
Epoch: [1/1], Step: [20/38], Loss: 1.0083
Epoch: [1/1], Step: [22/38], Loss: 1.0202
Epoch: [1/1], Step: [24/38], Loss: 0.8906
Epoch: [1/1], Step: [26/38], Loss: 1.0110
Epoch: [1/1], Step: [28/38], Loss: 0.8508
Epoch: [1/1], Step: [30/38], Loss: 0.9539
Epoch: [1/1], Step: [32/38], Loss: 0.9225
Epoch: [1/1], Step: [34/38], Loss: 0.9501
Epoch: [1/1], Step: [36/38], Loss: 0.8252
Epoch: [1/1], Step: [38/38], Loss: 0.9201
test loss: 0.0074
test acc: 0.5000

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • golang数据结构之冒泡排序

    绝命生
  • golang数据结构之快速排序

    具体过程:黑色标记代表左指针,红色标记代表右指针,蓝色标记代表中间值。(依次从左往向下)

    绝命生
  • (三十七)golang--如何获取命令行参数

    绝命生
  • Java集合总览

    这篇文章总结了所有的Java集合(Collection)。主要介绍各个集合的特性和用途,以及在不同的集合类型之间转换的方式。 Arrays Array是Java...

    非著名程序员
  • Java集合类型详解

    这篇文章总结了所有的Java集合(Collection)。主要介绍各个集合的特性和用途,以及在不同的集合类型之间转换的方式。

    Java团长
  • javascript数组去重方法

    版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/sinat_35512245/articl...

    大黄大黄大黄
  • Pythpon 爬取中国天气网数据

    以前看别人用python写爬取数据的程序感觉特牛掰,今天在网上找到了一个例子参考了下,自己也写了一个。之后会结合微信机器人,然后每隔一段时间给自己和好友发送天气...

    用户5908113
  • Web 自动化:一种基于 Page Object 的实现及常见异常处理

    Page Object 设计模式是 Selenium 官网推荐的一种自动化构建模式。PageObject 设计模式对网页进行一个简单抽象,将每个页面设计成一个类...

    腾讯移动品质中心TMQ
  • JavaScript数组基础及实例

    js数组 和var i=1;这样的简单存储一样是js中的一种数据结构,是专门用来存储多个数据的一种数据结构。 摘:数组是一组数据的集合,其表现形式就是内存中的一...

    二十三年蝉
  • 干货集锦(下)︱云+未来峰会:如何保护企业数据,建立安全壁垒?

    腾讯云安全

扫码关注云+社区

领取腾讯云代金券