专栏首页AI研习社用PyTorch做物体检测和追踪

用PyTorch做物体检测和追踪

本文为 AI 研习社编译的技术博客,原标题 : Object detection and tracking in PyTorch 作者 | Chris Fotache 翻译 | 酱番梨、麦尔肯•诺埃、TripleZ 校对 | 酱番梨 整理 | 菠萝妹 原文链接: https://towardsdatascience.com/object-detection-and-tracking-in-pytorch-b3cf1a696a98 注:本文的相关链接请点击文末【阅读原文】进行访问

在图像中检测多目标以及在视频中跟踪这些目标

在我之前的工作中,我尝试过用自己的图像在PyTorch中训练一个图像分类器,然后用它来进行图像识别。现在,我将向你们展示如何使用预训练的分类器在一张图像中检测多个目标,之后在整个视频中跟踪他们。

图像分类(识别)和目标检测之间有什么区别?在分类问题中,你识别出在图像中哪一个才是主要目标,然后将整张图片分类到一个单一类别中;在检测问题中,图像中有多个目标被识别、分类,而且目标的位置同样被确定下来(比如一个边界框)。

图像中目标检测

现有多种目标检测算法,其中YOLO,SSD是最受欢迎的方法,本文采用YOLOv3作为示例。本文不会对YOLO的技术细节进行分析,只是关注如何在自己的应用中实现。

直接上代码~YOLO检测的代码是基于Erik Lindernoren实现的Joseph Redmon and Ali Farhadi的文章。代码可以在Github中找到,下面为部分代码,在运行代码之前需要先在config文件夹中运行download_weights.sh脚本下载YOLO的权重文件,首先需要导入必要的模块:

from models import *
from utils import *

import os, sys, time, datetime, random
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variable

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

之后下载预训练的配置和权重,Darknet训练所用的COCO数据集的类别名称。在PyTorch中,在加载之后不要忘记将model设置为eval模式。

config_path='config/yolov3.cfg'
weights_path='config/yolov3.weights'
class_path='config/coco.names'
img_size=416
conf_thres=0.8
nms_thres=0.4

# Load model and weights
model = Darknet(config_path, img_size=img_size)
model.load_weights(weights_path)
model.cuda()
model.eval()
classes = utils.load_classes(class_path)
Tensor = torch.cuda.FloatTensor

上述代码中还有一些提前定义的值:图像尺寸(416*416像素),置信度阈值,非极大值抑制阈值。

下面是返回对特定图像的检测结果的基本函数。注意输入Pillow图像,大部分代码将图像resize至416*416,保持图像的纵横比并且填充溢出,实际的检测为最后的4行。

def detect_image(img):
    # scale and pad image
    ratio = min(img_size/img.size[0], img_size/img.size[1])
    imw = round(img.size[0] * ratio)
    imh = round(img.size[1] * ratio)
    img_transforms=transforms.Compose([transforms.Resize((imh,imw)),
         transforms.Pad((max(int((imh-imw)/2),0), 
              max(int((imw-imh)/2),0), max(int((imh-imw)/2),0),
              max(int((imw-imh)/2),0)), (128,128,128)),
         transforms.ToTensor(),
         ])
    # convert image to Tensor
    image_tensor = img_transforms(img).float()
    image_tensor = image_tensor.unsqueeze_(0)
    input_img = Variable(image_tensor.type(Tensor))
    # run inference on the model and get detections
    with torch.no_grad():
        detections = model(input_img)
        detections = utils.non_max_suppression(detections, 80 
                        conf_thres, nms_thres)
    return detections[0]

最后,将加载图像,获取检测结果,显示检测到的目标的边界框组合到一起。同样,这里大部分代码处理图像的放缩和填充,对每个不同的目标类别设置不同的颜色。

# load image and get detections
img_path = "images/blueangels.jpg"
prev_time = time.time()
img = Image.open(img_path)
detections = detect_image(img)
inference_time = datetime.timedelta(seconds=time.time() - prev_time)
print ('Inference Time: %s' % (inference_time))

# Get bounding-box colors
cmap = plt.get_cmap('tab20b')
colors = [cmap(i) for i in np.linspace(0, 1, 20)]
img = np.array(img)
plt.figure()
fig, ax = plt.subplots(1, figsize=(12,9))
ax.imshow(img)

pad_x = max(img.shape[0] - img.shape[1], 0) * (img_size / max(img.shape))
pad_y = max(img.shape[1] - img.shape[0], 0) * (img_size / max(img.shape))
unpad_h = img_size - pad_y
unpad_w = img_size - pad_x

if detections is not None:
    unique_labels = detections[:, -1].cpu().unique()
    n_cls_preds = len(unique_labels)
    bbox_colors = random.sample(colors, n_cls_preds)
    # browse detections and draw bounding boxes
    for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
        box_h = ((y2 - y1) / unpad_h) * img.shape[0]
        box_w = ((x2 - x1) / unpad_w) * img.shape[1]
        y1 = ((y1 - pad_y // 2) / unpad_h) * img.shape[0]
        x1 = ((x1 - pad_x // 2) / unpad_w) * img.shape[1]
        color = bbox_colors[int(np.where(
             unique_labels == int(cls_pred))[0])]
        bbox = patches.Rectangle((x1, y1), box_w, box_h,
             linewidth=2, edgecolor=color, facecolor='none')
        ax.add_patch(bbox)
        plt.text(x1, y1, s=classes[int(cls_pred)], 
                color='white', verticalalignment='top',
                bbox={'color': color, 'pad': 0})
plt.axis('off')
# save image
plt.savefig(img_path.replace(".jpg", "-det.jpg"),        
                  bbox_inches='tight', pad_inches=0.0)
plt.show()

您可以将这些代码片段放在一起运行代码,或者从Github下载。下面是一些图像中目标检测的例子 :

视频中的物体追踪

所以,现在你知道了检测图像中的不同对象的方法。当你在视频中逐帧执行时,可视化可能非常酷,你会看到这些跟踪框四处移动。但是,如果这些视频帧中有多个对象,我们如何知道一帧中的对象是否与前一帧中的对象相同?这就是我们所说的“对象追踪”,并使用多个检测来识别特定对象随时间的变化。

有几种算法可以做到这一点,我决定使用SORT,它非常易于使用且速度非常快。SORT(简单在线和实时跟踪)是由Alex Bewley,Zongyuan Ge,Lionel Ott,Fabio Ramos,Ben Upcroft等人于2017年撰写的论文,其中提出使用卡尔曼滤波器来预测先前识别的对象的轨迹,并将它们与新的检测相匹配。作者Alex Bewley还写了一个多功能的Python实现,我将用它来讲述这个故事。确保从我的Github repo下载Sort版本,因为我必须进行一些小的更改才能将它集成到我的项目中。

现在我们来详细聊聊代码,前3个代码段将与单个图像检测中的相同,因为它们涉及在单个帧上获取YOLO检测。不同之处在于最后一部分,对于每个检测,我们调用Sort对象的Update函数以获取对图像中对象的引用。因此,除了上一个示例的常规检测(包括边界框的坐标和类预测)之外,我们将获得跟踪对象,除了上述参数之外,还包括对象ID。然后我们以几乎相同的方式显示,但添加该ID并使用不同的颜色,以便大家可以轻松地在视频帧中查看对象。

我还使用OpenCV来读取视频并显示视频帧。请注意,Jupyter笔记本在处理视频时速度很慢。你可以将它用于测试和简单可视化,我还提供了一个独立的Python脚本,它将读取源视频,并输出带有跟踪对象的副本。在笔记本电脑中播放OpenCV视频并不容易,因此你可以将此代码保留在其他实验中。

videopath = 'video/intersection.mp4'
%pylab inline 
import cv2
from IPython.display import clear_output
cmap = plt.get_cmap('tab20b')
colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]
# initialize Sort object and video capture
from sort import *
vid = cv2.VideoCapture(videopath)
mot_tracker = Sort()
#while(True):
for ii in range(40):
    ret, frame = vid.read()
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pilimg = Image.fromarray(frame)
    detections = detect_image(pilimg)
    img = np.array(pilimg)
    pad_x = max(img.shape[0] - img.shape[1], 0) * 
            (img_size / max(img.shape))
    pad_y = max(img.shape[1] - img.shape[0], 0) * 
            (img_size / max(img.shape))
    unpad_h = img_size - pad_y
    unpad_w = img_size - pad_x
    if detections is not None:
        tracked_objects = mot_tracker.update(detections.cpu())
        unique_labels = detections[:, -1].cpu().unique()
        n_cls_preds = len(unique_labels)
        for x1, y1, x2, y2, obj_id, cls_pred in tracked_objects:
            box_h = int(((y2 - y1) / unpad_h) * img.shape[0])
            box_w = int(((x2 - x1) / unpad_w) * img.shape[1])
            y1 = int(((y1 - pad_y // 2) / unpad_h) * img.shape[0])
            x1 = int(((x1 - pad_x // 2) / unpad_w) * img.shape[1])
            color = colors[int(obj_id) % len(colors)]
            color = [i * 255 for i in color]
            cls = classes[int(cls_pred)]
            cv2.rectangle(frame, (x1, y1), (x1+box_w, y1+box_h),
                         color, 4)
            cv2.rectangle(frame, (x1, y1-35), (x1+len(cls)*19+60,
                         y1), color, -1)
            cv2.putText(frame, cls + "-" + str(int(obj_id)), 
                        (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 
                        1, (255,255,255), 3)
fig=figure(figsize=(12, 8))
    title("Video Stream")
    imshow(frame)
    show()
    clear_output(wait=True)

使用笔记本后,你可以使用常规Python脚本进行实时处理(可以从相机获取输入)并保存视频。以下是我使用此程序生成的视频示例。

PyTorch中的对象检测和跟踪 [深度学习]

就是这样,你可以尝试自己检测图像中的多个对象并在视频帧中跟踪这些对象。你还可以对YOLO进行更多研究,并了解如何使用图像训练模型。 Chris Fotache是位于新泽西州的CYNET.ai的人工智能研究员。他介绍了与人生智能相关的主题,Python编程,机器学习,计算机视觉,自然语言处理等。

想要继续查看该篇文章相关链接和参考文献?

长按链接点击打开或点击底部【阅读原文】:

https://ai.yanxishe.com/page/TextTranslation/1333

AI研习社每日更新精彩内容,观看更多精彩内容:

用 Python 做机器学习不得不收藏的重要库

算法基础:五大排序算法Python实战教程

手把手:用PyTorch实现图像分类器(第一部分)

手把手:用PyTorch实现图像分类器(第二部分)

等你来译:

对混乱的数据进行聚类

初学者怎样使用Keras进行迁移学习

强化学习:通往基于情感的行为系统

一文带你读懂 WaveNet:谷歌助手的声音合成器

本文分享自微信公众号 - AI研习社(okweiwu)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-01-06

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • CV 新手避坑指南:计算机视觉常见的8个错误

    人类并不是完美的,我们经常在编写软件的时候犯错误。有时这些错误很容易找到:你的代码根本不工作,你的应用程序会崩溃。但有些 bug 是隐藏的,很难发现,这使它们更...

    AI研习社
  • 如何使用 OpenCV 编写基于 Node.js 命令行界面和神经网络模型的图像分类

    如何使用 OpenCV 编写基于 Node.js 命令行界面和神经网络模型的图像分类

    AI研习社
  • 如何使用注意力模型生成图像描述?

    我们的目标是用一句话来描述图片, 比如「一个冲浪者正在冲浪」。 本教程中用到了基于注意力的模型,它使我们很直观地看到当文字生成时模型会关注哪些部分。

    AI研习社
  • python: 在图片上 打印中文

    JNingWei
  • 【深度学习系列】用PaddlePaddle进行车牌识别(一)

    小伙伴们,终于到了实战部分了!今天给大家带来的项目是用PaddlePaddle进行车牌识别。车牌识别其实属于比较常见的图像识别的项目了,目前也属于比较成熟的应...

    Charlotte77
  • 图像处理基础(七)图像的PCA(主成分分析)降维

    Pulsar-V
  • matlab图像解密

    最近我一直在准备神经网络方面的推送。但是一直有人问我:以前发过一个关于图像加密的代码,一直没有等到解密的代码出来。该怎么解密。

    matlab爱好者
  • 基于Python查找一张图像中主要颜色组成

    如果我们能够得知道一幅图像中最多的颜色是什么的话,可以帮助我们解决很多实际问题。例如在农业领域中想确定水果的成熟度,我们可以通过检查水果的颜色是否落在特定范围内...

    AI算法与图像处理
  • 基于Python查找图像中最常见的颜色

    如果我们能够得知道一幅图像中最多的颜色是什么的话,可以帮助我们解决很多实际问题。例如在农业领域中想确定水果的成熟度,我们可以通过检查水果的颜色是否落在特定范围内...

    小白学视觉
  • Caffe2 - (十八) 图片数据处理函数

    Caffe2 提供了对图片进行加载、裁剪、缩放、去均值、batch 等处理的函数 - helper.py.

    AIHGF

扫码关注云+社区

领取腾讯云代金券