前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >pyTorch入门(五)——训练自己的数据集

pyTorch入门(五)——训练自己的数据集

作者头像
Vaccae
发布2022-12-29 14:16:03
3660
发布2022-12-29 14:16:03
举报
文章被收录于专栏:微卡智享微卡智享

学更好的别人,

做更好的自己。

——《微卡智享》

本文长度为1749,预计阅读5分钟

前言

前面四篇将Minist数据集的训练及OpenCV的推理都介绍完了,在实际应用项目中,往往需要用自己的数据集进行训练,所以本篇就专门介绍一下pyTorch怎么训练自己的数据集。

微卡智享

生成自己的训练图片

上一篇《pyTorch入门(四)——导出Minist模型,C++ OpenCV DNN进行识别》中使用VS Studio实现了OpenCV的推理,介绍过在推理前需要将图片进行预处理,包括灰度、二值化,查找及排序轮廓都已经处理了,所以只要对上面的代码进行改造一下,将提取的信息保存出来,就是我们想要训练的数据了。先上源码:

代码语言:javascript
复制
#pragma once
#include<iostream>
#include<chrono>
#include<time.h>
#include<opencv2/opencv.hpp>
#include<opencv2/dnn/dnn.hpp>

using namespace cv;
using namespace std;

//参数iType  0-提取图片保存   1-使用DNN推理
int iType = 1;

dnn::Net net;

//排序矩形
void SortRect(vector<Rect>& inputrects) {
  for (int i = 0; i < inputrects.size(); ++i) {
    for (int j = i; j < inputrects.size(); ++j) {
      //说明顺序在上方,这里不用变
      if (inputrects[i].y + inputrects[i].height < inputrects[i].y) {

      }
      //同一排
      else if (inputrects[i].y <= inputrects[j].y + inputrects[j].height) {
        if (inputrects[i].x > inputrects[j].x) {
          swap(inputrects[i], inputrects[j]);
        }
      }
      //下一排
      else if (inputrects[i].y > inputrects[j].y + inputrects[j].height) {
        swap(inputrects[i], inputrects[j]);
      }
    }
  }
}

//处理DNN检测的MINIST图像,防止长方形图像直接转为28*28扁了
void DealInputMat(Mat& src, int row = 28, int col = 28, int tmppadding = 5) {
  int w = src.cols;
  int h = src.rows;
  //看图像的宽高对比,进行处理,先用padding填充黑色,保证图像接近正方形,这样缩放28*28比例不会失衡
  if (w > h) {
    int tmptopbottompadding = (w - h) / 2 + tmppadding;
    copyMakeBorder(src, src, tmptopbottompadding, tmptopbottompadding, tmppadding, tmppadding,
      BORDER_CONSTANT, Scalar(0));
  }
  else {
    int tmpleftrightpadding = (h - w) / 2 + tmppadding;
    copyMakeBorder(src, src, tmppadding, tmppadding, tmpleftrightpadding, tmpleftrightpadding,
      BORDER_CONSTANT, Scalar(0));

  }
  resize(src, src, Size(row, col));
}

// 获取当时系统时间
const string GetCurrentSystemTime()
{
  auto t = chrono::system_clock::to_time_t(std::chrono::system_clock::now());
  struct tm ptm { 60, 59, 23, 31, 11, 1900, 6, 365, -1 };
  _localtime64_s(&ptm, &t);
  char date[60] = { 0 };
  sprintf_s(date, "%d%02d%02d%02d%02d%02d",
    (int)ptm.tm_year + 1900, (int)ptm.tm_mon + 1, (int)ptm.tm_mday,
    (int)ptm.tm_hour, (int)ptm.tm_min, (int)ptm.tm_sec);
  return move(std::string(date));
}

int main(int argc, char** argv) {
  //定义onnx文件
  string onnxfile = "D:/Business/DemoTEST/CPP/OpenCVMinistDNN/torchminist/ResNet.onnx";

  //测试图片文件
  string testfile = "D:/Business/DemoTEST/CPP/OpenCVMinistDNN/testpic/test3.png";

  //提取的图片保存位置
  string savefile = "D:/Business/DemoTEST/CPP/OpenCVMinistDNN/findcontoursMat";

  if (iType == 1) {
    net = dnn::readNetFromONNX(onnxfile);
    if (net.empty()) {
      cout << "加载Onnx文件失败!" << endl;
      return -1;
    }
  }

  //读取图片,灰度,高斯模糊
  Mat src = imread(testfile);
  //备份源图
  Mat backsrc;
  src.copyTo(backsrc);
  cvtColor(src, src, COLOR_BGR2GRAY);
  GaussianBlur(src, src, Size(3, 3), 0.5, 0.5);
  //二值化图片,注意用THRESH_BINARY_INV改为黑底白字,对应MINIST
  threshold(src, src, 0, 255, THRESH_BINARY_INV | THRESH_OTSU);

  //做彭账处理,防止手写的数字没有连起来,这里做了3次膨胀处理
  Mat kernel = getStructuringElement(MORPH_RECT, Size(3, 3));
  //加入开运算先去燥点
  morphologyEx(src, src, MORPH_OPEN, kernel, Point(-1, -1));
  morphologyEx(src, src, MORPH_DILATE, kernel, Point(-1, -1), 3);
  imshow("src", src);

  vector<vector<Point>> contours;
  vector<Vec4i> hierarchy;
  vector<Rect> rects;

  //查找轮廓
  findContours(src, contours, hierarchy, RETR_EXTERNAL, CHAIN_APPROX_NONE);
  for (int i = 0; i < contours.size(); ++i) {
    RotatedRect rect = minAreaRect(contours[i]);
    Rect outrect = rect.boundingRect();
    //插入到矩形列表中
    rects.push_back(outrect);
  }

  //按从左到右,从上到下排序
  SortRect(rects);
  //要输出的图像参数
  for (int i = 0; i < rects.size(); ++i) {
    Mat tmpsrc = src(rects[i]);
    DealInputMat(tmpsrc);

    if (iType == 1) {
      //Mat inputBlob = dnn::blobFromImage(tmpsrc, 0.3081, Size(28, 28), Scalar(0.1307), false, false);
      Mat inputBlob = dnn::blobFromImage(tmpsrc, 1, Size(28, 28), Scalar(), false, false);

      //输入参数值
      net.setInput(inputBlob, "input");
      //预测结果 
      Mat output = net.forward("output");

      //查找出结果中推理的最大值
      Point maxLoc;
      minMaxLoc(output, NULL, NULL, NULL, &maxLoc);

      cout << "预测值:" << maxLoc.x << endl;

      //画出截取图像位置,并显示识别的数字
      rectangle(backsrc, rects[i], Scalar(255, 0, 255));
      putText(backsrc, to_string(maxLoc.x), Point(rects[i].x, rects[i].y), FONT_HERSHEY_PLAIN, 5, Scalar(255, 0, 255), 1, -1);
    }
    else {
      string filename = savefile + "/" + GetCurrentSystemTime() + "-" + to_string(i) + ".jpg";
      cout << filename << endl;
      imwrite(filename, tmpsrc);
    }
  }

  imshow("backsrc", backsrc);


  waitKey(0);
  return 0;
}

划重点

加了一个参数,设置的时候0为提取保存的图片,1是上一篇的推理。

增加了一个获取当前时间的函数,主要作用就是保存图片的时候在文件名加上时间。

增加了一个保存图片的位置

根据上面的参数,设置为1时还是原来的DNN推理,0时通过imwrite将图片进行保存。

接下来我们自己做点数据集,用画图工具在上面写上数字,将0--9的数字分别做了10张图出来。

运行的效果如下:

可以看出上图中我们将数字9的图片分开截取并保存到指定的目录了。

同时在Dataset下创建mydata目录,并创建出train训练的目录,在目录下创建了0-9的文件夹,这样做的目录是在pyTorch调用时会直接根据train下不同的文件夹目录设置对应的label标签了,不用我们在每个进行对照,相应的,提取出的数字图片也要放到对应的目录中

将刚才生成的数字9的图片都剪切到9的文件夹下,其余的数字也是用同样方式。

test测试集也用相同的方式处理,只不过我们拷过来后删了一大部分,就做别的处理。做完这些,提取图片的准备工作就完成了,接下来就是通过pyTorch训练。

微卡智享

pyTorch训练自己数据集

新建了一个trainmydata.py的文件,训练的流程其实和原来差不多,只不过我们是在原来的基础上进行再训练,所以这些的模型是先加载原来的训练模型后,再进行训练,还是先上代码

代码语言:javascript
复制
import torch
import time
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.optim as optim
import matplotlib.pyplot as plt
from pylab import mpl
import trainModel as tm

##训练轮数
epoch_times = 15

##设置初始预测率,用于判断高于当前预测率的保存模型
toppredicted = 0.0

##设置学习率
learnrate = 0.01 
##设置动量值,如果上一次的momentnum与本次梯度方向是相同的,梯度下降幅度会拉大,起到加速迭代的作用
momentnum = 0.5

##自己训练的模型前面加个my
savemodel_name = "my" + tm.savemodel_name

##生成图用的数组
##预测值
predict_list = []
##训练轮次值
epoch_list = []
##loss值
loss_list = []

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,))
]) ##Normalize 里面两个值0.1307是均值mean, 0.3081是标准差std,计算好的直接用了

##训练数据集位置
train_mydata = datasets.ImageFolder(
    root = '../datasets/mydata/train', 
    transform = transform
)
train_mydataloader = DataLoader(train_mydata, batch_size=64, shuffle=True, num_workers=0)

##测试数据集位置
test_mydata = datasets.ImageFolder(
    root = '../datasets/mydata/test', 
    transform = transform
)
test_mydataloader = DataLoader(test_mydata, batch_size=1, shuffle=True, num_workers=0)


##加载已经训练好的模型
model = tm.Net(tm.train_name)
model.load_state_dict(torch.load(tm.savemodel_name))

##加入判断是CPU训练还是GPU训练
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

##优化器 
optimizer = optim.SGD(model.parameters(), lr= learnrate, momentum= momentnum)

##训练函数
def train(epoch):
    model.train()
    for batch_idx, data in enumerate(train_mydataloader, 0):
        inputs, target = data
        ##加入CPU和GPU选择
        inputs, target = inputs.to(device), target.to(device)

        optimizer.zero_grad()

        #前馈,反向传播,更新
        outputs = model(inputs)
        loss = model.criterion(outputs, target)
        loss.backward()
        optimizer.step()

    loss_list.append(loss.item())
    print("progress:", epoch, 'loss=', loss.item())



def test():
    correct = 0 
    total = 0
    model.eval()
    ##with这里标记是不再计算梯度
    with torch.no_grad():
        for data in test_mydataloader:
            inputs, labels = data
            ##加入CPU和GPU选择
            inputs, labels = inputs.to(device), labels.to(device)


            outputs = model(inputs)
            ##预测返回的是两列,第一列是下标就是0-9的值,第二列为预测值,下面的dim=1就是找维度1(第二列)最大值输出
            _, predicted = torch.max(outputs.data, dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    currentpredicted = (100 * correct / total)
    ##用global声明toppredicted,用于在函数内部修改在函数外部声明的全局变量,否则报错
    global toppredicted
    ##当预测率大于原来的保存模型
    if currentpredicted > toppredicted:
        toppredicted = currentpredicted
        torch.save(model.state_dict(), savemodel_name)
        print(savemodel_name+" saved, currentpredicted:%d %%" % currentpredicted)

    predict_list.append(currentpredicted)    
    print('Accuracy on test set: %d %%' % currentpredicted)        

##开始训练
timestart = time.time()
for epoch in range(epoch_times):
    train(epoch)
    test()
timeend = time.time() - timestart
print("use time: {:.0f}m {:.0f}s".format(timeend // 60, timeend % 60))



##设置画布显示中文字体
mpl.rcParams["font.sans-serif"] = ["SimHei"]
##设置正常显示符号
mpl.rcParams["axes.unicode_minus"] = False

##创建画布
fig, (axloss, axpredict) = plt.subplots(nrows=1, ncols=2, figsize=(8,6))

#loss画布
axloss.plot(range(epoch_times), loss_list, label = 'loss', color='r')
##设置刻度
axloss.set_xticks(range(epoch_times)[::1])
axloss.set_xticklabels(range(epoch_times)[::1])

axloss.set_xlabel('训练轮数')
axloss.set_ylabel('数值')
axloss.set_title(tm.train_name+' 损失值')
#添加图例
axloss.legend(loc = 0)

#predict画布
axpredict.plot(range(epoch_times), predict_list, label = 'predict', color='g')
##设置刻度
axpredict.set_xticks(range(epoch_times)[::1])
axpredict.set_xticklabels(range(epoch_times)[::1])
# axpredict.set_yticks(range(100)[::5])
# axpredict.set_yticklabels(range(100)[::5])

axpredict.set_xlabel('训练轮数')
axpredict.set_ylabel('预测值')
axpredict.set_title(tm.train_name+' 预测值')
#添加图例
axpredict.legend(loc = 0)

#显示图像
plt.show()

划重点

自己训练的模型文件前面加上一个my,用于不覆盖原来的训练模型。

加载训练集和测试集

在transform中,增加了一行transforms.Grayscale(num_output_channels=1),主要原因是在OpenCV中使用imwrite保存的文件,虽然是二值化的图片,但是是3通道的,而在pyTorch我们的训练数据都是1X28X28,即是单通道的图像,所以这里加上这一句是将读取的图片设置为单通道。

使用datasets.ImageFolder直接读取train目录下的数据,自动将图像及对应的标签加载进来了。

加载已训练的模型

这里的model模型直接通过load_state_dict加载进来,然后再训练自己的数据,下面的训练方式和原来train都一样了。

因为我这边保存的数据很少,而且测试集的图片和训练集的一样,只训练了15轮,所以训练到第3轮的时候已经就到100%了。简单的训练自己的数据集就完成了。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2022-12-26,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 微卡智享 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 加载训练集和测试集
  • 加载已训练的模型
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档