前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >用PyTorch做物体检测和追踪

用PyTorch做物体检测和追踪

作者头像
AI研习社
发布2019-01-08 18:07:12
1.8K0
发布2019-01-08 18:07:12
举报
文章被收录于专栏:AI研习社AI研习社

本文为 AI 研习社编译的技术博客,原标题 : Object detection and tracking in PyTorch 作者 | Chris Fotache 翻译 | 酱番梨、麦尔肯•诺埃、TripleZ 校对 | 酱番梨 整理 | 菠萝妹 原文链接: https://towardsdatascience.com/object-detection-and-tracking-in-pytorch-b3cf1a696a98 注:本文的相关链接请点击文末【阅读原文】进行访问

在图像中检测多目标以及在视频中跟踪这些目标

在我之前的工作中,我尝试过用自己的图像在PyTorch中训练一个图像分类器,然后用它来进行图像识别。现在,我将向你们展示如何使用预训练的分类器在一张图像中检测多个目标,之后在整个视频中跟踪他们。

图像分类(识别)和目标检测之间有什么区别?在分类问题中,你识别出在图像中哪一个才是主要目标,然后将整张图片分类到一个单一类别中;在检测问题中,图像中有多个目标被识别、分类,而且目标的位置同样被确定下来(比如一个边界框)。

图像中目标检测

现有多种目标检测算法,其中YOLO,SSD是最受欢迎的方法,本文采用YOLOv3作为示例。本文不会对YOLO的技术细节进行分析,只是关注如何在自己的应用中实现。

直接上代码~YOLO检测的代码是基于Erik Lindernoren实现的Joseph Redmon and Ali Farhadi的文章。代码可以在Github中找到,下面为部分代码,在运行代码之前需要先在config文件夹中运行download_weights.sh脚本下载YOLO的权重文件,首先需要导入必要的模块:

代码语言:javascript
复制
from models import *
from utils import *

import os, sys, time, datetime, random
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variable

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

之后下载预训练的配置和权重,Darknet训练所用的COCO数据集的类别名称。在PyTorch中,在加载之后不要忘记将model设置为eval模式。

代码语言:javascript
复制
config_path='config/yolov3.cfg'
weights_path='config/yolov3.weights'
class_path='config/coco.names'
img_size=416
conf_thres=0.8
nms_thres=0.4

# Load model and weights
model = Darknet(config_path, img_size=img_size)
model.load_weights(weights_path)
model.cuda()
model.eval()
classes = utils.load_classes(class_path)
Tensor = torch.cuda.FloatTensor

上述代码中还有一些提前定义的值:图像尺寸(416*416像素),置信度阈值,非极大值抑制阈值。

下面是返回对特定图像的检测结果的基本函数。注意输入Pillow图像,大部分代码将图像resize至416*416,保持图像的纵横比并且填充溢出,实际的检测为最后的4行。

代码语言:javascript
复制
def detect_image(img):
    # scale and pad image
    ratio = min(img_size/img.size[0], img_size/img.size[1])
    imw = round(img.size[0] * ratio)
    imh = round(img.size[1] * ratio)
    img_transforms=transforms.Compose([transforms.Resize((imh,imw)),
         transforms.Pad((max(int((imh-imw)/2),0), 
              max(int((imw-imh)/2),0), max(int((imh-imw)/2),0),
              max(int((imw-imh)/2),0)), (128,128,128)),
         transforms.ToTensor(),
         ])
    # convert image to Tensor
    image_tensor = img_transforms(img).float()
    image_tensor = image_tensor.unsqueeze_(0)
    input_img = Variable(image_tensor.type(Tensor))
    # run inference on the model and get detections
    with torch.no_grad():
        detections = model(input_img)
        detections = utils.non_max_suppression(detections, 80 
                        conf_thres, nms_thres)
    return detections[0]

最后,将加载图像,获取检测结果,显示检测到的目标的边界框组合到一起。同样,这里大部分代码处理图像的放缩和填充,对每个不同的目标类别设置不同的颜色。

代码语言:javascript
复制
# load image and get detections
img_path = "images/blueangels.jpg"
prev_time = time.time()
img = Image.open(img_path)
detections = detect_image(img)
inference_time = datetime.timedelta(seconds=time.time() - prev_time)
print ('Inference Time: %s' % (inference_time))

# Get bounding-box colors
cmap = plt.get_cmap('tab20b')
colors = [cmap(i) for i in np.linspace(0, 1, 20)]
img = np.array(img)
plt.figure()
fig, ax = plt.subplots(1, figsize=(12,9))
ax.imshow(img)

pad_x = max(img.shape[0] - img.shape[1], 0) * (img_size / max(img.shape))
pad_y = max(img.shape[1] - img.shape[0], 0) * (img_size / max(img.shape))
unpad_h = img_size - pad_y
unpad_w = img_size - pad_x

if detections is not None:
    unique_labels = detections[:, -1].cpu().unique()
    n_cls_preds = len(unique_labels)
    bbox_colors = random.sample(colors, n_cls_preds)
    # browse detections and draw bounding boxes
    for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
        box_h = ((y2 - y1) / unpad_h) * img.shape[0]
        box_w = ((x2 - x1) / unpad_w) * img.shape[1]
        y1 = ((y1 - pad_y // 2) / unpad_h) * img.shape[0]
        x1 = ((x1 - pad_x // 2) / unpad_w) * img.shape[1]
        color = bbox_colors[int(np.where(
             unique_labels == int(cls_pred))[0])]
        bbox = patches.Rectangle((x1, y1), box_w, box_h,
             linewidth=2, edgecolor=color, facecolor='none')
        ax.add_patch(bbox)
        plt.text(x1, y1, s=classes[int(cls_pred)], 
                color='white', verticalalignment='top',
                bbox={'color': color, 'pad': 0})
plt.axis('off')
# save image
plt.savefig(img_path.replace(".jpg", "-det.jpg"),        
                  bbox_inches='tight', pad_inches=0.0)
plt.show()

您可以将这些代码片段放在一起运行代码,或者从Github下载。下面是一些图像中目标检测的例子 :

视频中的物体追踪

所以,现在你知道了检测图像中的不同对象的方法。当你在视频中逐帧执行时,可视化可能非常酷,你会看到这些跟踪框四处移动。但是,如果这些视频帧中有多个对象,我们如何知道一帧中的对象是否与前一帧中的对象相同?这就是我们所说的“对象追踪”,并使用多个检测来识别特定对象随时间的变化。

有几种算法可以做到这一点,我决定使用SORT,它非常易于使用且速度非常快。SORT(简单在线和实时跟踪)是由Alex Bewley,Zongyuan Ge,Lionel Ott,Fabio Ramos,Ben Upcroft等人于2017年撰写的论文,其中提出使用卡尔曼滤波器来预测先前识别的对象的轨迹,并将它们与新的检测相匹配。作者Alex Bewley还写了一个多功能的Python实现,我将用它来讲述这个故事。确保从我的Github repo下载Sort版本,因为我必须进行一些小的更改才能将它集成到我的项目中。

现在我们来详细聊聊代码,前3个代码段将与单个图像检测中的相同,因为它们涉及在单个帧上获取YOLO检测。不同之处在于最后一部分,对于每个检测,我们调用Sort对象的Update函数以获取对图像中对象的引用。因此,除了上一个示例的常规检测(包括边界框的坐标和类预测)之外,我们将获得跟踪对象,除了上述参数之外,还包括对象ID。然后我们以几乎相同的方式显示,但添加该ID并使用不同的颜色,以便大家可以轻松地在视频帧中查看对象。

我还使用OpenCV来读取视频并显示视频帧。请注意,Jupyter笔记本在处理视频时速度很慢。你可以将它用于测试和简单可视化,我还提供了一个独立的Python脚本,它将读取源视频,并输出带有跟踪对象的副本。在笔记本电脑中播放OpenCV视频并不容易,因此你可以将此代码保留在其他实验中。

代码语言:javascript
复制
videopath = 'video/intersection.mp4'
%pylab inline 
import cv2
from IPython.display import clear_output
cmap = plt.get_cmap('tab20b')
colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]
# initialize Sort object and video capture
from sort import *
vid = cv2.VideoCapture(videopath)
mot_tracker = Sort()
#while(True):
for ii in range(40):
    ret, frame = vid.read()
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pilimg = Image.fromarray(frame)
    detections = detect_image(pilimg)
    img = np.array(pilimg)
    pad_x = max(img.shape[0] - img.shape[1], 0) * 
            (img_size / max(img.shape))
    pad_y = max(img.shape[1] - img.shape[0], 0) * 
            (img_size / max(img.shape))
    unpad_h = img_size - pad_y
    unpad_w = img_size - pad_x
    if detections is not None:
        tracked_objects = mot_tracker.update(detections.cpu())
        unique_labels = detections[:, -1].cpu().unique()
        n_cls_preds = len(unique_labels)
        for x1, y1, x2, y2, obj_id, cls_pred in tracked_objects:
            box_h = int(((y2 - y1) / unpad_h) * img.shape[0])
            box_w = int(((x2 - x1) / unpad_w) * img.shape[1])
            y1 = int(((y1 - pad_y // 2) / unpad_h) * img.shape[0])
            x1 = int(((x1 - pad_x // 2) / unpad_w) * img.shape[1])
            color = colors[int(obj_id) % len(colors)]
            color = [i * 255 for i in color]
            cls = classes[int(cls_pred)]
            cv2.rectangle(frame, (x1, y1), (x1+box_w, y1+box_h),
                         color, 4)
            cv2.rectangle(frame, (x1, y1-35), (x1+len(cls)*19+60,
                         y1), color, -1)
            cv2.putText(frame, cls + "-" + str(int(obj_id)), 
                        (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 
                        1, (255,255,255), 3)
代码语言:javascript
复制
fig=figure(figsize=(12, 8))
    title("Video Stream")
    imshow(frame)
    show()
    clear_output(wait=True)

使用笔记本后,你可以使用常规Python脚本进行实时处理(可以从相机获取输入)并保存视频。以下是我使用此程序生成的视频示例。

PyTorch中的对象检测和跟踪 [深度学习]

就是这样,你可以尝试自己检测图像中的多个对象并在视频帧中跟踪这些对象。你还可以对YOLO进行更多研究,并了解如何使用图像训练模型。 Chris Fotache是位于新泽西州的CYNET.ai的人工智能研究员。他介绍了与人生智能相关的主题,Python编程,机器学习,计算机视觉,自然语言处理等。

想要继续查看该篇文章相关链接和参考文献?

长按链接点击打开或点击底部【阅读原文】:

https://ai.yanxishe.com/page/TextTranslation/1333

AI研习社每日更新精彩内容,观看更多精彩内容:

用 Python 做机器学习不得不收藏的重要库

算法基础:五大排序算法Python实战教程

手把手:用PyTorch实现图像分类器(第一部分)

手把手:用PyTorch实现图像分类器(第二部分)

等你来译:

对混乱的数据进行聚类

初学者怎样使用Keras进行迁移学习

强化学习:通往基于情感的行为系统

一文带你读懂 WaveNet:谷歌助手的声音合成器

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-01-06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI研习社 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 在图像中检测多目标以及在视频中跟踪这些目标
    • 用 Python 做机器学习不得不收藏的重要库
      • 算法基础:五大排序算法Python实战教程
        • 手把手:用PyTorch实现图像分类器(第一部分)
          • 手把手:用PyTorch实现图像分类器(第二部分)
            • 对混乱的数据进行聚类
              • 初学者怎样使用Keras进行迁移学习
                • 强化学习:通往基于情感的行为系统
                  • 一文带你读懂 WaveNet:谷歌助手的声音合成器
                  相关产品与服务
                  NLP 服务
                  NLP 服务(Natural Language Process,NLP)深度整合了腾讯内部的 NLP 技术,提供多项智能文本处理和文本生成能力,包括词法分析、相似词召回、词相似度、句子相似度、文本润色、句子纠错、文本补全、句子生成等。满足各行业的文本智能需求。
                  领券
                  问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档