前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >轻松学pytorch – 使用多标签损失函数训练卷积网络

轻松学pytorch – 使用多标签损失函数训练卷积网络

作者头像
OpenCV学堂
发布2020-07-16 21:46:08
1.1K0
发布2020-07-16 21:46:08
举报

大家好,我还在坚持继续写,如果我没有记错的话,这个是系列文章的第十五篇,pytorch中有很多非常方便使用的损失函数,本文就演示了如何通过多标签损失函数训练验证码识别网络,实现验证码识别。

数据集

这个数据是来自Kaggle上的一个验证码识别例子,作者采用的是迁移学习,基于ResNet18做到的训练。

https://www.kaggle.com/anjalichoudhary12/captcha-with-pytorch

这个数据集总计有1070张验证码图像,我把其中的1040张用作训练,30张作为测试,使用pytorch自定义了一个数据集类,代码如下:

代码语言:javascript
复制
 1import torch
 2import numpy as np
 3from torch.utils.data import Dataset, DataLoader
 4from torchvision import transforms
 5import os
 6import cv2 as cv
 7
 8NUMBER = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 9ALPHABET = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
10ALL_CHAR_SET = NUMBER + ALPHABET
11ALL_CHAR_SET_LEN = len(ALL_CHAR_SET)
12MAX_CAPTCHA = 5
13
14
15def output_nums():
16    return MAX_CAPTCHA * ALL_CHAR_SET_LEN
17
18
19def encode(a):
20    onehot = [0]*ALL_CHAR_SET_LEN
21    idx = ALL_CHAR_SET.index(a)
22    onehot[idx] += 1
23    return onehot
24
25
26class CapchaDataset(Dataset):
27    def __init__(self, root_dir):
28        self.transform = transforms.Compose([transforms.ToTensor()])
29        img_files = os.listdir(root_dir)
30        self.txt_labels = []
31        self.encodes = []
32        self.images = []
33        for file_name in img_files:
34            label = file_name[:-4]
35            label_oh = []
36            for i in label:
37                label_oh += encode(i)
38            self.images.append(os.path.join(root_dir, file_name))
39            self.encodes.append(np.array(label_oh))
40            self.txt_labels.append(label)
41
42    def __len__(self):
43        return len(self.images)
44
45    def num_of_samples(self):
46        return len(self.images)
47
48    def __getitem__(self, idx):
49        if torch.is_tensor(idx):
50            idx = idx.tolist()
51            image_path = self.images[idx]
52        else:
53            image_path = self.images[idx]
54        img = cv.imread(image_path)  # BGR order
55        h, w, c = img.shape
56        # rescale
57        img = cv.resize(img, (128, 32))
58        img = (np.float32(img) /255.0 - 0.5) / 0.5
59        # H, W C to C, H, W
60        img = img.transpose((2, 0, 1))
61        sample = {'image': torch.from_numpy(img), 'encode': self.encodes[idx], 'label': self.txt_labels[idx]}
62        return sample

模型实现

基于ResNet的block结构,我实现了一个比较简单的残差网络,最后加一个全连接层输出多个标签。验证码是有5个字符的,每个字符的是小写26个字母加上0~9十个数字,总计36个类别,所以5个字符就有5x36=180个输出,其中每个字符是独热编码,这个可以从数据集类的实现看到。模型的输入与输出格式:

输入:NCHW=Nx3x32x128 卷积层最终输出:NCHW=Nx256x1x4 全连接层:Nx(256x4) 最终输出层:Nx180

代码实现如下:

代码语言:javascript
复制
 1class CapchaResNet(torch.nn.Module):
 2    def __init__(self):
 3        super(CapchaResNet, self).__init__()
 4        self.cnn_layers = torch.nn.Sequential(
 5            # 卷积层 (128x32x3)
 6            ResidualBlock(3, 32, 1),
 7            ResidualBlock(32, 64, 2),
 8            ResidualBlock(64, 64, 2),
 9            ResidualBlock(64, 128, 2),
10            ResidualBlock(128, 256, 2),
11            ResidualBlock(256, 256, 2),
12        )
13
14        self.fc_layers = torch.nn.Sequential(
15            torch.nn.Linear(256 * 4, output_nums()),
16        )
17
18    def forward(self, x):
19        # stack convolution layers
20        x = self.cnn_layers(x)
21        out = x.view(-1, 4 * 256)
22        out = self.fc_layers(out)
23        return out

模型训练与测试

使用多标签损失函数,Adam优化器,代码实现如下:

代码语言:javascript
复制
 1model = CapchaResNet()
 2print(model)
 3
 4# 使用GPU
 5if train_on_gpu:
 6    model.cuda()
 7
 8ds = CapchaDataset("D:/python/pytorch_tutorial/capcha/samples")
 9num_train_samples = ds.num_of_samples()
10bs = 16
11dataloader = DataLoader(ds, batch_size=bs, shuffle=True)
12
13# 训练模型的次数
14num_epochs = 25
15# optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
16optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
17model.train()
18
19# 损失函数
20mul_loss = torch.nn.MultiLabelSoftMarginLoss()
21index = 0
22for epoch in range(num_epochs):
23    train_loss = 0.0
24    for i_batch, sample_batched in enumerate(dataloader):
25        images_batch, oh_labels = \
26            sample_batched['image'], sample_batched['encode']
27        if train_on_gpu:
28            images_batch, oh_labels= images_batch.cuda(), oh_labels.cuda()
29        optimizer.zero_grad()
30
31        # forward pass: compute predicted outputs by passing inputs to the model
32        m_label_out_ = model(images_batch)
33        oh_labels = torch.autograd.Variable(oh_labels.float())
34
35        # calculate the batch loss
36        loss = mul_loss(m_label_out_, oh_labels)
37
38        # backward pass: compute gradient of the loss with respect to model parameters
39        loss.backward()
40
41        # perform a single optimization step (parameter update)
42        optimizer.step()
43
44        # update training loss
45        train_loss += loss.item()
46        if index % 100 == 0:
47            print('step: {} \tTraining Loss: {:.6f} '.format(index, loss.item()))
48        index += 1
49
50        # 计算平均损失
51    train_loss = train_loss / num_train_samples
52
53    # 显示训练集与验证集的损失函数
54    print('Epoch: {} \tTraining Loss: {:.6f} '.format(epoch, train_loss))
55
56# save model
57model.eval()
58torch.save(model, 'capcha_recognize_model.pt')

调用保存之后的模型,对图像测试代码如下:

代码语言:javascript
复制
 1cnn_model = torch.load("./capcha_recognize_model.pt")
 2root_dir = "D:/python/pytorch_tutorial/capcha/testdata"
 3files = os.listdir(root_dir)
 4one_hot_len = ALL_CHAR_SET_LEN
 5for file in files:
 6    if os.path.isfile(os.path.join(root_dir, file)):
 7        image = cv.imread(os.path.join(root_dir, file))
 8        h, w, c = image.shape
 9        img = cv.resize(image, (128, 32))
10        img = (np.float32(img) /255.0 - 0.5) / 0.5
11        img = img.transpose((2, 0, 1))
12        x_input = torch.from_numpy(img).view(1, 3, 32, 128)
13        probs = cnn_model(x_input.cuda())
14        mul_pred_labels = probs.squeeze().cpu().tolist()
15        c0 = ALL_CHAR_SET[np.argmax(mul_pred_labels[0:one_hot_len])]
16        c1 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len:one_hot_len*2])]
17        c2 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len*2:one_hot_len*3])]
18        c3 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len*3:one_hot_len*4])]
19        c4 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len*4:one_hot_len*5])]
20        pred_txt = '%s%s%s%s%s' % (c0, c1, c2, c3, c4)
21        cv.putText(image, pred_txt, (10, 20), cv.FONT_HERSHEY_PLAIN, 1.5, (0, 0, 255), 2)
22        print("current code : %s, predict code : %s "%(file[:-4], pred_txt))
23        cv.imshow("capcha predict", image)
24        cv.waitKey(0)

其中对输入结果,要根据每个字符的独热编码,截取成五个独立的字符分类标签,然后使用argmax获取index根据index查找类别标签,得到最终的验证码预测字符串,代码运行结果如下:

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-07-15,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenCV学堂 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 数据集
  • 模型训练与测试
相关产品与服务
验证码
腾讯云新一代行为验证码(Captcha),基于十道安全栅栏, 为网页、App、小程序开发者打造立体、全面的人机验证。最大程度保护注册登录、活动秒杀、点赞发帖、数据保护等各大场景下业务安全的同时,提供更精细化的用户体验。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档