专栏首页贾志刚-OpenCV学堂轻松学pytorch – 使用多标签损失函数训练卷积网络

轻松学pytorch – 使用多标签损失函数训练卷积网络

大家好,我还在坚持继续写,如果我没有记错的话,这个是系列文章的第十五篇,pytorch中有很多非常方便使用的损失函数,本文就演示了如何通过多标签损失函数训练验证码识别网络,实现验证码识别。

数据集

这个数据是来自Kaggle上的一个验证码识别例子,作者采用的是迁移学习,基于ResNet18做到的训练。

https://www.kaggle.com/anjalichoudhary12/captcha-with-pytorch

这个数据集总计有1070张验证码图像,我把其中的1040张用作训练,30张作为测试,使用pytorch自定义了一个数据集类,代码如下:

 1import torch
 2import numpy as np
 3from torch.utils.data import Dataset, DataLoader
 4from torchvision import transforms
 5import os
 6import cv2 as cv
 7
 8NUMBER = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 9ALPHABET = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
10ALL_CHAR_SET = NUMBER + ALPHABET
11ALL_CHAR_SET_LEN = len(ALL_CHAR_SET)
12MAX_CAPTCHA = 5
13
14
15def output_nums():
16    return MAX_CAPTCHA * ALL_CHAR_SET_LEN
17
18
19def encode(a):
20    onehot = [0]*ALL_CHAR_SET_LEN
21    idx = ALL_CHAR_SET.index(a)
22    onehot[idx] += 1
23    return onehot
24
25
26class CapchaDataset(Dataset):
27    def __init__(self, root_dir):
28        self.transform = transforms.Compose([transforms.ToTensor()])
29        img_files = os.listdir(root_dir)
30        self.txt_labels = []
31        self.encodes = []
32        self.images = []
33        for file_name in img_files:
34            label = file_name[:-4]
35            label_oh = []
36            for i in label:
37                label_oh += encode(i)
38            self.images.append(os.path.join(root_dir, file_name))
39            self.encodes.append(np.array(label_oh))
40            self.txt_labels.append(label)
41
42    def __len__(self):
43        return len(self.images)
44
45    def num_of_samples(self):
46        return len(self.images)
47
48    def __getitem__(self, idx):
49        if torch.is_tensor(idx):
50            idx = idx.tolist()
51            image_path = self.images[idx]
52        else:
53            image_path = self.images[idx]
54        img = cv.imread(image_path)  # BGR order
55        h, w, c = img.shape
56        # rescale
57        img = cv.resize(img, (128, 32))
58        img = (np.float32(img) /255.0 - 0.5) / 0.5
59        # H, W C to C, H, W
60        img = img.transpose((2, 0, 1))
61        sample = {'image': torch.from_numpy(img), 'encode': self.encodes[idx], 'label': self.txt_labels[idx]}
62        return sample

模型实现

基于ResNet的block结构,我实现了一个比较简单的残差网络,最后加一个全连接层输出多个标签。验证码是有5个字符的,每个字符的是小写26个字母加上0~9十个数字,总计36个类别,所以5个字符就有5x36=180个输出,其中每个字符是独热编码,这个可以从数据集类的实现看到。模型的输入与输出格式:

输入:NCHW=Nx3x32x128 卷积层最终输出:NCHW=Nx256x1x4 全连接层:Nx(256x4) 最终输出层:Nx180

代码实现如下:

 1class CapchaResNet(torch.nn.Module):
 2    def __init__(self):
 3        super(CapchaResNet, self).__init__()
 4        self.cnn_layers = torch.nn.Sequential(
 5            # 卷积层 (128x32x3)
 6            ResidualBlock(3, 32, 1),
 7            ResidualBlock(32, 64, 2),
 8            ResidualBlock(64, 64, 2),
 9            ResidualBlock(64, 128, 2),
10            ResidualBlock(128, 256, 2),
11            ResidualBlock(256, 256, 2),
12        )
13
14        self.fc_layers = torch.nn.Sequential(
15            torch.nn.Linear(256 * 4, output_nums()),
16        )
17
18    def forward(self, x):
19        # stack convolution layers
20        x = self.cnn_layers(x)
21        out = x.view(-1, 4 * 256)
22        out = self.fc_layers(out)
23        return out

模型训练与测试

使用多标签损失函数,Adam优化器,代码实现如下:

 1model = CapchaResNet()
 2print(model)
 3
 4# 使用GPU
 5if train_on_gpu:
 6    model.cuda()
 7
 8ds = CapchaDataset("D:/python/pytorch_tutorial/capcha/samples")
 9num_train_samples = ds.num_of_samples()
10bs = 16
11dataloader = DataLoader(ds, batch_size=bs, shuffle=True)
12
13# 训练模型的次数
14num_epochs = 25
15# optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
16optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
17model.train()
18
19# 损失函数
20mul_loss = torch.nn.MultiLabelSoftMarginLoss()
21index = 0
22for epoch in range(num_epochs):
23    train_loss = 0.0
24    for i_batch, sample_batched in enumerate(dataloader):
25        images_batch, oh_labels = \
26            sample_batched['image'], sample_batched['encode']
27        if train_on_gpu:
28            images_batch, oh_labels= images_batch.cuda(), oh_labels.cuda()
29        optimizer.zero_grad()
30
31        # forward pass: compute predicted outputs by passing inputs to the model
32        m_label_out_ = model(images_batch)
33        oh_labels = torch.autograd.Variable(oh_labels.float())
34
35        # calculate the batch loss
36        loss = mul_loss(m_label_out_, oh_labels)
37
38        # backward pass: compute gradient of the loss with respect to model parameters
39        loss.backward()
40
41        # perform a single optimization step (parameter update)
42        optimizer.step()
43
44        # update training loss
45        train_loss += loss.item()
46        if index % 100 == 0:
47            print('step: {} \tTraining Loss: {:.6f} '.format(index, loss.item()))
48        index += 1
49
50        # 计算平均损失
51    train_loss = train_loss / num_train_samples
52
53    # 显示训练集与验证集的损失函数
54    print('Epoch: {} \tTraining Loss: {:.6f} '.format(epoch, train_loss))
55
56# save model
57model.eval()
58torch.save(model, 'capcha_recognize_model.pt')

调用保存之后的模型,对图像测试代码如下:

 1cnn_model = torch.load("./capcha_recognize_model.pt")
 2root_dir = "D:/python/pytorch_tutorial/capcha/testdata"
 3files = os.listdir(root_dir)
 4one_hot_len = ALL_CHAR_SET_LEN
 5for file in files:
 6    if os.path.isfile(os.path.join(root_dir, file)):
 7        image = cv.imread(os.path.join(root_dir, file))
 8        h, w, c = image.shape
 9        img = cv.resize(image, (128, 32))
10        img = (np.float32(img) /255.0 - 0.5) / 0.5
11        img = img.transpose((2, 0, 1))
12        x_input = torch.from_numpy(img).view(1, 3, 32, 128)
13        probs = cnn_model(x_input.cuda())
14        mul_pred_labels = probs.squeeze().cpu().tolist()
15        c0 = ALL_CHAR_SET[np.argmax(mul_pred_labels[0:one_hot_len])]
16        c1 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len:one_hot_len*2])]
17        c2 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len*2:one_hot_len*3])]
18        c3 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len*3:one_hot_len*4])]
19        c4 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len*4:one_hot_len*5])]
20        pred_txt = '%s%s%s%s%s' % (c0, c1, c2, c3, c4)
21        cv.putText(image, pred_txt, (10, 20), cv.FONT_HERSHEY_PLAIN, 1.5, (0, 0, 255), 2)
22        print("current code : %s, predict code : %s "%(file[:-4], pred_txt))
23        cv.imshow("capcha predict", image)
24        cv.waitKey(0)

其中对输入结果,要根据每个字符的独热编码,截取成五个独立的字符分类标签,然后使用argmax获取index根据index查找类别标签,得到最终的验证码预测字符串,代码运行结果如下:

本文分享自微信公众号 - OpenCV学堂(CVSCHOOL),作者:gloomyfish

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2020-07-15

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 轻松学Pytorch – 人脸五点landmark提取网络训练与使用

    大家好,本文是轻松学Pytorch系列文章第十篇,本文将介绍如何使用卷积神经网络实现参数回归预测,这个跟之前的分类预测最后softmax层稍有不同,本文将通过卷...

    OpenCV学堂
  • 图形图像算法中必须要了解的设计模式(3)

    随着信息的多元化,信息的概念不仅仅指的是文字,它还包含图片、声音、视频等其它丰富的信息。文字信息越来越多地被图片、声音、视频信息所替代,而视频又是由一针一针的图...

    OpenCV学堂
  • OpenCV中如何读取URL图像文件

    最近知识星球收到的提问,觉得是一个很有趣的问题,就通过搜集整理归纳了一番,主要思想是通过URL解析来生成数据,转为图像/Mat对象。但是在Python语言与C+...

    OpenCV学堂
  • 爬虫遇到头疼的验证码?Python实战讲解弹窗处理和验证码识别

    在我们写爬虫的过程中,目标网站常见的干扰手段就是设置验证码等,本就将基于Selenium实战讲解如何处理弹窗和验证码,爬取的目标网站为某仪器预约平台

    刘早起
  • python-PIL模块画图

    python中执行mysql遇到like 怎么办 ? ​ ​sql = "SELECT * FROM T_ARTICLE WHERE title LIKE '%...

    py3study
  • Python 为什么只需一条语句“a,b=b,a”,就能直接交换两个变量?

    从接触 Python 时起,我就觉得 Python 的元组解包(unpacking)挺有意思,非常简洁好用。

    Python猫
  • 以图搜图系统工程实践

    •提取图像特征向量(用特征向量去表示一幅图像)•特征向量的相似度计算(寻找内容相似的图像)

    凌虚
  • Spark 数据倾斜及其解决方案

    本文从数据倾斜的危害、现象、原因等方面,由浅入深阐述Spark数据倾斜及其解决方案。

    2020labs小助手
  • 使用nginx image filter实现类OSS对象存储中对图片的实时处理

    在家使用自己的电脑做了一个小应用,可查看照片,按以前的方式,需要在用户上传图片后对进行裁剪压缩,然后给前端一个缩略图地址与原图地址。这种方式有两个弊端磁盘空间的...

    兜兜毛毛
  • Hive数据倾斜问题总结

    Hive数据倾斜问题总结 1、MapReduce数据倾斜 Hive查询最终转换为MapReduce操作,所以要先了解MapReduce数据倾斜问题。 MapRe...

    程裕强

扫码关注云+社区

领取腾讯云代金券