微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识
大家好,今天来继续更新轻松学Pytorch专栏,这个是系列文章我会一直坚持写下去的,希望大家转发、点赞、留言支持!上一篇文章使用了torchvision中提供的预训练对象检测网络Faster-RCNN实现了常见的对象检测,基于COCO数据集,支持90个类型对象检测,非常的实用。本文将介绍如何使用自定义数据集,使用Faster-RCNN预训练模型实现迁移学习,完成自定义对象检测。
数据集
使用了公开的宠物数据集,下载地址如下:
http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
对象检测模型的输入是image图像,需要target信息包括:
boxes:表示标注的矩形左上角与右下角坐标(x1,y1,x2,y2) labels:表示每个标注框的类别,注意从1开始,0永远表示背景 image_id:数据集中图像索引id值, area:标注框的面积,COCO评估的时候会用到 iscrowd:当iscrowd=true不会参与模型评估计算
从标注xml文件中读取相关信息,完成解析,自定义一个宠物数据集的代码如下:
class PetDataset(Dataset):
def __init__(self, root_dir):
self.root_dir = root_dir
self.transforms = T.Compose([T.ToTensor()])
self.ann_xmls = list(sorted(os.listdir(os.path.join(root_dir, "annotations/xmls"))))
def __len__(self):
return len(self.ann_xmls)
def num_of_samples(self):
return len(self.ann_xmls)
def __getitem__(self, idx):
# load images and bbox
bbox_xml_path = os.path.join(self.root_dir, "annotations/xmls", self.ann_xmls[idx])
# 读取xml
dom = parse(bbox_xml_path)
# 获取文档元素对象
data = dom.documentElement
# 获取 objects
objects = data.getElementsByTagName('object')
node = data.getElementsByTagName('filename')[0]
file_ame = node.childNodes[0].nodeValue
image_path = os.path.join(self.root_dir, "images", file_ame)
img = cv.imread(image_path)
# get bounding box coordinates
boxes = []
labels = []
for object_ in objects:
# 获取标签中内容
name = object_.getElementsByTagName('name')[0].childNodes[0].nodeValue
if name == "dog":
labels.append(np.int(1))
if name == "cat":
labels.append(np.int(2))
bndbox = object_.getElementsByTagName('bndbox')[0]
xmin = np.float(bndbox.getElementsByTagName('xmin')[0].childNodes[0].nodeValue)
ymin = np.float(bndbox.getElementsByTagName('ymin')[0].childNodes[0].nodeValue)
xmax = np.float(bndbox.getElementsByTagName('xmax')[0].childNodes[0].nodeValue)
ymax = np.float(bndbox.getElementsByTagName('ymax')[0].childNodes[0].nodeValue)
boxes.append([xmin, ymin, xmax, ymax])
boxes = torch.as_tensor(boxes, dtype=torch.float32)
# there is only one class
labels = torch.as_tensor(labels, dtype=torch.int64)
image_id = torch.tensor([idx])
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
iscrowd = torch.zeros((len(objects),), dtype=torch.int64)
target = {}
target["boxes"] = boxes
target["labels"] = labels
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd
img, target = self.transforms(img, target)
return img, target
顺便说一下,这里输入图像通道顺序是BGR!
Faster RCNN模型训练
之前一篇文章中介绍了Faster-RCNN模型与预训练模型使用,这里通过下面的代码加载模型:
num_classes = 2
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=num_classes, pretrained_backbone=True)
device = torch.device('cuda:0')
model.to(device)
其中pretrained=False表示训练使用,num_classes 表示对象检测数据集的对象类别,这里只有dog跟cat两个类别,所以num_classes =2
设置好了模型的参数,下面就可以初始化加载数据集,开始正式训练,代码如下:
dataset = PetDataset("D:/pytorch/pet_data")
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=4, shuffle=True, # num_workers=4,
collate_fn=utils.collate_fn)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
step_size=5,
gamma=0.1)
num_epochs = 8
for epoch in range(num_epochs):
train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
lr_scheduler.step()
torch.save(model.state_dict(), "faster_rcnn_vehicle_model.pt")
运行结果如下:
如果你的内存不够猛,训练的时候可能会得到下面这个错误:
回去改一下batch size就好了,如果改成1还有这个错误话,就直接砸机器就对了!
模型推理使用
对训练好的模型,加载模型,然后就可以推理预测了,代码演示如下:
image = cv.imread("D:/images/test.jpg")
blob = transform(image)
c, h, w = blob.shape
input_x = blob.view(1, c, h, w)
output = model(input_x.cuda())[0]
boxes = output['boxes'].cpu().detach().numpy()
scores = output['scores'].cpu().detach().numpy()
labels = output['labels'].cpu().detach().numpy()
index = 0
for x1, y1, x2, y2 in boxes:
if scores[index] > 0.5:
cv.rectangle(image, (np.int32(x1), np.int32(y1)),
(np.int32(x2), np.int32(y2)), (140, 199, 0), 4, 8, 0)
label_id = labels[index]
label_txt = coco_names[str(label_id)]
cv.putText(image, label_txt, (np.int32(x1), np.int32(y1)), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 1)
index += 1
cv.imshow("Faster-RCNN Pet Detection", image)
cv.imwrite("D:/pet2.png", image)
cv.waitKey(0)
cv.destroyAllWindows()
运行结果如下: