专栏首页贾志刚-OpenCV学堂轻松学Pytorch-实现自定义对象检测器

轻松学Pytorch-实现自定义对象检测器

微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识

大家好,今天来继续更新轻松学Pytorch专栏,这个是系列文章我会一直坚持写下去的,希望大家转发、点赞、留言支持!上一篇文章使用了torchvision中提供的预训练对象检测网络Faster-RCNN实现了常见的对象检测,基于COCO数据集,支持90个类型对象检测,非常的实用。本文将介绍如何使用自定义数据集,使用Faster-RCNN预训练模型实现迁移学习,完成自定义对象检测。

数据集

使用了公开的宠物数据集,下载地址如下:

http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz

对象检测模型的输入是image图像,需要target信息包括:

boxes:表示标注的矩形左上角与右下角坐标(x1,y1,x2,y2) labels:表示每个标注框的类别,注意从1开始,0永远表示背景 image_id:数据集中图像索引id值, area:标注框的面积,COCO评估的时候会用到 iscrowd:当iscrowd=true不会参与模型评估计算

从标注xml文件中读取相关信息,完成解析,自定义一个宠物数据集的代码如下:

class PetDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.transforms = T.Compose([T.ToTensor()])
        self.ann_xmls = list(sorted(os.listdir(os.path.join(root_dir, "annotations/xmls"))))

    def __len__(self):
        return len(self.ann_xmls)

    def num_of_samples(self):
        return len(self.ann_xmls)

    def __getitem__(self, idx):
        # load images and bbox
        bbox_xml_path = os.path.join(self.root_dir, "annotations/xmls", self.ann_xmls[idx])

        # 读取xml
        dom = parse(bbox_xml_path)
        # 获取文档元素对象
        data = dom.documentElement
        # 获取 objects
        objects = data.getElementsByTagName('object')
        node = data.getElementsByTagName('filename')[0]
        file_ame = node.childNodes[0].nodeValue
        image_path = os.path.join(self.root_dir, "images", file_ame)
        img = cv.imread(image_path)

        # get bounding box coordinates
        boxes = []
        labels = []
        for object_ in objects:
            # 获取标签中内容
            name = object_.getElementsByTagName('name')[0].childNodes[0].nodeValue
            if name == "dog":
                labels.append(np.int(1))
            if name == "cat":
                labels.append(np.int(2))

            bndbox = object_.getElementsByTagName('bndbox')[0]
            xmin = np.float(bndbox.getElementsByTagName('xmin')[0].childNodes[0].nodeValue)
            ymin = np.float(bndbox.getElementsByTagName('ymin')[0].childNodes[0].nodeValue)
            xmax = np.float(bndbox.getElementsByTagName('xmax')[0].childNodes[0].nodeValue)
            ymax = np.float(bndbox.getElementsByTagName('ymax')[0].childNodes[0].nodeValue)
            boxes.append([xmin, ymin, xmax, ymax])

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # there is only one class
        labels = torch.as_tensor(labels, dtype=torch.int64)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((len(objects),), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd
        img, target = self.transforms(img, target)
        return img, target

顺便说一下,这里输入图像通道顺序是BGR

Faster RCNN模型训练

之前一篇文章中介绍了Faster-RCNN模型与预训练模型使用,这里通过下面的代码加载模型:

num_classes = 2
 model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=num_classes, pretrained_backbone=True)
 device = torch.device('cuda:0')
 model.to(device)

其中pretrained=False表示训练使用,num_classes 表示对象检测数据集的对象类别,这里只有dog跟cat两个类别,所以num_classes =2

设置好了模型的参数,下面就可以初始化加载数据集,开始正式训练,代码如下:

dataset = PetDataset("D:/pytorch/pet_data")
data_loader = torch.utils.data.DataLoader(
     dataset, batch_size=4, shuffle=True,  # num_workers=4,
     collate_fn=utils.collate_fn)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                             momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=5,
                                                gamma=0.1)
num_epochs = 8
for epoch in range(num_epochs):
     train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
     lr_scheduler.step()
torch.save(model.state_dict(), "faster_rcnn_vehicle_model.pt")

运行结果如下:

如果你的内存不够猛,训练的时候可能会得到下面这个错误:

回去改一下batch size就好了,如果改成1还有这个错误话,就直接砸机器就对了!

模型推理使用

对训练好的模型,加载模型,然后就可以推理预测了,代码演示如下:

image = cv.imread("D:/images/test.jpg")
blob = transform(image)
c, h, w = blob.shape
input_x = blob.view(1, c, h, w)
output = model(input_x.cuda())[0]
boxes = output['boxes'].cpu().detach().numpy()
scores = output['scores'].cpu().detach().numpy()
labels = output['labels'].cpu().detach().numpy()
index = 0
for x1, y1, x2, y2 in boxes:
    if scores[index] > 0.5:
        cv.rectangle(image, (np.int32(x1), np.int32(y1)),
                     (np.int32(x2), np.int32(y2)), (140, 199, 0), 4, 8, 0)
        label_id = labels[index]
        label_txt = coco_names[str(label_id)]
        cv.putText(image, label_txt, (np.int32(x1), np.int32(y1)), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 1)
    index += 1
cv.imshow("Faster-RCNN Pet Detection", image)
cv.imwrite("D:/pet2.png", image)
cv.waitKey(0)
cv.destroyAllWindows()

运行结果如下:

本文分享自微信公众号 - OpenCV学堂(CVSCHOOL),作者:gloomyfish

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2020-07-28

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 轻松学Pytorch – 行人检测Mask-RCNN模型训练与使用

    大家好,这个是轻松学Pytorch的第20篇的文章分享,主要是给大家分享一下,如何使用数据集基于Mask-RCNN训练一个行人检测与实例分割网络。这个例子是来自...

    OpenCV学堂
  • scikit-image图像处理入门

    skimage是纯python语言实现的BSD许可开源图像处理算法库,主要的优势在于:

    OpenCV学堂
  • 【项目实践】YOLO V4万字原理详细讲解并训练自己的数据集(pytorch完整项目打包下载)

    YOLOV4是YOLOV3的改进版,在YOLOV3的基础上结合了非常多的小Tricks。尽管没有目标检测上革命性的改变,但是YOLOV4依然很好...

    OpenCV学堂
  • 基于TensorFlow Eager Execution的简单神经网络模型

    Eager Execution是TensorFlow(TF)中一种从头开始构建深度学习模型的好方法。它允许您构建原型模型,而不会出现TF常规使用的图形方法所带来...

    代码医生工作室
  • 用 RNN 训练语言模型生成文本

    ---- 本文结构: 什么是 Language Model? 怎么实现?怎么应用? ---- cs224d Day 8: 项目2-用 RNN 建立 Langua...

    杨熹
  • 老规矩 从HelloWorld 开始吧

    JRE: Java Runtime Environment 翻译:java 运行 环境

    用户5745563
  • ORM和SQLAlchemy

    (英语:(Object Relational Mapping,简称ORM,或O/RM,或O/R mapping),是一种程序技术,用于实现面向对象编程语言里不同...

    zx钟
  • 一个cheat命令 == Linux命令小抄大全

    当你要执行一个linux命令,在这个命令参数选项众多时,你一般怎么做?对,我们大多数人都会去求助man命令。此外,linux上帮助相关的命令还有”help””w...

    小小科
  • Python基础---类的内置方法

    __init__(): __init__方法在类的一个对象被建立时,马上运行。这个方法可以用来对你的对象做一些你希望的初始化。注意,这个名称的开始和结尾都是双下...

    我被狗咬了
  • python 通过字符串方式调用方法operator.methodcaller

    class Point: def init(self, x, y): self.x = x self.y = y

    用户5760343

扫码关注云+社区

领取腾讯云代金券