
Object Detection-YOLOv2 Anchor Box Clustering

YoungTimes
Published 2022-04-28 19:51:04

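Dimension Clusters is one of the optimization strategies used in YOLOv2. The idea is to run a clustering algorithm over the dataset beforehand to obtain priors on bounding-box shapes, which makes the model easier to train and improves object detection results.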

Average IOU of boxes to closest priors on VOC 2007

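In the reported experiments, 5 anchor boxes obtained by clustering already match the average IOU of 9 hand-picked anchor boxes, and 9 clustered anchor boxes improve noticeably over 5 clustered ones. This indicates that the Dimension Clusters strategy genuinely helps.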

Clustering box dimensions on VOC and COCO

This article walks through the implementation of Clustering Box Dimensions in detail.

1.Anchor Box

One of the biggest challenges in object detection is finding multiple objects of different shapes within the same neighborhood.

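Consider the scene in the figure below: a person stands on a boat, and the two objects are very close together but differ in shape and size. To detect such objects in the same neighborhood more reliably, object detectors introduce the concept of anchor boxes.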

Image source: [1]

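Anchor boxes require the user to predefine two hyperparameters: the number of anchor boxes and their shapes. (YOLOv1 did not use anchor boxes; they were introduced into YOLO with YOLOv2.)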

The more anchor boxes there are, the more objects YOLO can detect, at the cost of more parameters and more computation in the network.

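The more representative the anchor box shapes are, the more objects the detector can find. If a predefined anchor shape almost never occurs in real scenes, defining it adds essentially nothing.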

To choose the number and shapes of anchor boxes more systematically, YOLOv2 proposes setting these two hyperparameters with K-means clustering.

In this article we look at how to obtain the number and shapes of anchor boxes on the PASCAL VOC2012 dataset with K-means.

2.PASCAL VOC2012

The PASCAL VOC2012 dataset can be downloaded from its official website. It contains the following object classes:

LABELS = ['aeroplane',  'bicycle', 'bird',  'boat',      'bottle', 
          'bus',        'car',      'cat',  'chair',     'cow',
          'diningtable','dog',    'horse',  'motorbike', 'person',
          'pottedplant','sheep',  'sofa',   'train',   'tvmonitor']

Download the dataset, then point the following variables at its image and annotation folders:

train_image_folder = "../ObjectDetectionRCNN/VOCdevkit/VOC2012/JPEGImages/"
train_annot_folder = "../ObjectDetectionRCNN/VOCdevkit/VOC2012/Annotations/"

import matplotlib.pyplot as plt
import numpy as np
import os, cv2
%matplotlib inline

To simplify the later steps, the dataset annotations are preprocessed first. The preprocessing code is:

import xml.etree.ElementTree as ET

def parse_annotation(ann_dir, img_dir, labels=[]):
    '''
    Parse PASCAL VOC XML annotations.
    output:
    - all_imgs: each element is a dictionary containing the annotation information of one image.
    - seen_labels: a dictionary with
            (key, value) = (the object class, the number of objects found in the images)
    '''
    all_imgs = []
    seen_labels = {}
    
    for ann in sorted(os.listdir(ann_dir)):
        if "xml" not in ann:
            continue
        img = {'object':[]}

        tree = ET.parse(ann_dir + ann)
        
        for elem in tree.iter():
            if 'filename' in elem.tag:
                path_to_image = img_dir + elem.text
                img['filename'] = path_to_image
                ## make sure that the image exists:
                if not os.path.exists(path_to_image):
                    assert False, "file does not exist!\n{}".format(path_to_image)
            if 'width' in elem.tag:
                img['width'] = int(elem.text)
            if 'height' in elem.tag:
                img['height'] = int(elem.text)
            if 'object' in elem.tag or 'part' in elem.tag:
                obj = {}
                
                for attr in list(elem):
                    if 'name' in attr.tag:
                        
                        obj['name'] = attr.text
                        
                        if len(labels) > 0 and obj['name'] not in labels:
                            break
                        else:
                            img['object'] += [obj]
                            
                        

                        if obj['name'] in seen_labels:
                            seen_labels[obj['name']] += 1
                        else:
                            seen_labels[obj['name']]  = 1
                        

                            
                    if 'bndbox' in attr.tag:
                        for dim in list(attr):
                            if 'xmin' in dim.tag:
                                obj['xmin'] = int(round(float(dim.text)))
                            if 'ymin' in dim.tag:
                                obj['ymin'] = int(round(float(dim.text)))
                            if 'xmax' in dim.tag:
                                obj['xmax'] = int(round(float(dim.text)))
                            if 'ymax' in dim.tag:
                                obj['ymax'] = int(round(float(dim.text)))

        if len(img['object']) > 0:
            all_imgs += [img]
                        
    return all_imgs, seen_labels

## Parse annotations 
train_image, seen_train_labels = parse_annotation(train_annot_folder,train_image_folder, labels=LABELS)
print("N train = {}".format(len(train_image)))
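To see what `parse_annotation` extracts without downloading the dataset, here is a minimal sketch that parses a single VOC-style annotation held in a string. The XML is a hand-written example in the VOC layout (not a real file from the dataset), and `parse_voc_xml` is a hypothetical helper. Instead of scanning every element with `tree.iter()`, it reads direct paths with `findtext`/`findall`, so nested `<part>` boxes are simply ignored.

```python
import xml.etree.ElementTree as ET

# Hand-written annotation in the PASCAL VOC layout (illustrative, not from the dataset).
SAMPLE_XML = """
<annotation>
  <filename>example.jpg</filename>
  <size><width>500</width><height>281</height><depth>3</depth></size>
  <object>
    <name>aeroplane</name>
    <bndbox><xmin>104</xmin><ymin>78</ymin><xmax>375</xmax><ymax>183</ymax></bndbox>
  </object>
  <object>
    <name>person</name>
    <bndbox><xmin>26.0</xmin><ymin>189</ymin><xmax>44</xmax><ymax>238</ymax></bndbox>
  </object>
</annotation>
"""


def parse_voc_xml(xml_text, labels=()):
    """Hypothetical helper: return one image dict shaped like the entries parse_annotation builds."""
    root = ET.fromstring(xml_text.strip())
    img = {
        "filename": root.findtext("filename"),
        "width": int(root.findtext("size/width")),
        "height": int(root.findtext("size/height")),
        "object": [],
    }
    # Only direct <object> children of <annotation>; <part> boxes nested inside are skipped.
    for obj in root.findall("object"):
        name = obj.findtext("name")
        if labels and name not in labels:
            continue
        box = obj.find("bndbox")
        entry = {"name": name}
        for key in ("xmin", "ymin", "xmax", "ymax"):
            # Coordinates may be written as floats, so round them like parse_annotation does.
            entry[key] = int(round(float(box.findtext(key))))
        img["object"].append(entry)
    return img


img = parse_voc_xml(SAMPLE_XML, labels=("aeroplane", "person"))
print(img["width"], img["height"], [o["name"] for o in img["object"]])
```

Filtering with `labels` works the same way as in `parse_annotation`: objects whose class is not in the list are dropped.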

Output: train_image

Each element of train_image contains all the annotation information of one image.

train_image[:2]

Output:

[{'filename': '../ObjectDetectionRCNN/VOCdevkit/VOC2012/JPEGImages/2007_000027.jpg',
  'height': 500,
  'object': [{'name': 'person',
    'xmax': 349,
    'xmin': 174,
    'ymax': 351,
    'ymin': 101}],
  'width': 486},
 {'filename': '../ObjectDetectionRCNN/VOCdevkit/VOC2012/JPEGImages/2007_000032.jpg',
  'height': 281,
  'object': [{'name': 'aeroplane',
    'xmax': 375,
    'xmin': 104,
    'ymax': 183,
    'ymin': 78},
   {'name': 'aeroplane', 'xmax': 197, 'xmin': 133, 'ymax': 123, 'ymin': 88},
   {'name': 'person', 'xmax': 213, 'xmin': 195, 'ymax': 229, 'ymin': 180},
   {'name': 'person', 'xmax': 44, 'xmin': 26, 'ymax': 238, 'ymin': 189}],
  'width': 500}]

Visualize output: seen_train_labels

seen_train_labels contains each object class and the number of objects of that class found in the dataset:

(key, value) = (the object class, the number of objects found in the images)

y_pos = np.arange(len(seen_train_labels))
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.barh(y_pos,list(seen_train_labels.values()))
ax.set_yticks(y_pos)
ax.set_yticklabels(list(seen_train_labels.keys()))
ax.set_title("The total number of objects = {} in {} images".format(
    np.sum(list(seen_train_labels.values())),len(train_image)
))
plt.show()
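The per-class counts in seen_train_labels can also be built with `collections.Counter`. A minimal sketch, assuming image dicts shaped like train_image; the two sample dicts are abbreviated copies of the entries printed above, with the box coordinates left out:

```python
from collections import Counter

# Abbreviated copies of the first two train_image entries shown above (boxes omitted).
images = [
    {"object": [{"name": "person"}]},
    {"object": [{"name": "aeroplane"}, {"name": "aeroplane"},
                {"name": "person"}, {"name": "person"}]},
]

# Count every object instance by class name, across all images.
counts = Counter(obj["name"] for img in images for obj in img["object"])
print(counts.most_common())  # classes sorted from most to least frequent
```

`Counter` is a `dict` subclass, so the result can be passed to the bar-chart code above in place of seen_train_labels.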

3.K-means Clustering

Because images come in different sizes, each object's width and height are first normalized by the width and height of its image, so both values lie between 0 and 1.

wh = []  # one [normalized width, normalized height] pair per object
for anno in train_image:
    aw = float(anno['width'])  # width of the original image
    ah = float(anno['height']) # height of the original image
    for obj in anno["object"]:
        w = (obj["xmax"] - obj["xmin"])/aw # box width relative to the image width
        h = (obj["ymax"] - obj["ymin"])/ah # box height relative to the image height
        temp = [w,h]
        wh.append(temp)
wh = np.array(wh)
print("clustering feature data is ready. shape = (N object, width and height) =  {}".format(wh.shape))

Output:

clustering feature data is ready. shape = (N object, width and height) =  (40138, 2)
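The nested loop above can also be written with NumPy array operations. A minimal sketch, assuming the same train_image structure; the helper name `normalized_wh` is hypothetical, and the sample input reuses two boxes from the second image printed earlier:

```python
import numpy as np


def normalized_wh(images):
    """Hypothetical helper: (N objects, 2) array of box width/height divided by image width/height."""
    rows = []
    for img in images:
        if not img["object"]:
            continue  # parse_annotation already drops such images; guard anyway
        # Stack this image's boxes as (M, 4) rows of xmin, ymin, xmax, ymax.
        boxes = np.array([[o["xmin"], o["ymin"], o["xmax"], o["ymax"]] for o in img["object"]],
                         dtype=float)
        size = np.array([img["width"], img["height"]], dtype=float)
        # (xmax - xmin, ymax - ymin) for every box, broadcast-divided by (width, height).
        rows.append((boxes[:, 2:] - boxes[:, :2]) / size)
    return np.concatenate(rows, axis=0) if rows else np.empty((0, 2))


sample = [{"width": 500, "height": 281, "object": [
    {"xmin": 104, "ymin": 78, "xmax": 375, "ymax": 183},
    {"xmin": 26, "ymin": 189, "xmax": 44, "ymax": 238},
]}]
wh_sample = normalized_wh(sample)
print(wh_sample.shape)  # (2, 2)
print(wh_sample[0])     # width 271/500, height 105/281
```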

3.1 Visualize Clustering Data

Visualize the data to be clustered with k-means:

plt.figure(figsize=(10,10))
plt.scatter(wh[:,0],wh[:,1],alpha=0.1)
plt.title("Clusters",fontsize=20)
plt.xlabel("normalized width",fontsize=20)
plt.ylabel("normalized height",fontsize=20)
plt.show()

3.2 Intersection Over Union

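An object's shape is usually described by a bounding box with four parameters (xmin, ymin, width, height). For the IOU computation here we only care about width and height: the goal is to compare shapes rather than positions, so each box and each cluster center are treated as if they shared the same top-left corner. The intersection is then simply min(w1, w2) × min(h1, h2).

The IOU is computed as follows:
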
def iou(box, clusters):
    '''
    :param box:      np.array of shape (2,) containing w and h
    :param clusters: np.array of shape (N cluster, 2) 
    '''
    x = np.minimum(clusters[:, 0], box[0]) 
    y = np.minimum(clusters[:, 1], box[1])

    intersection = x * y
    box_area = box[0] * box[1]
    cluster_area = clusters[:, 0] * clusters[:, 1]

    iou_ = intersection / (box_area + cluster_area - intersection)

    return iou_
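A quick worked example of this shape-only IOU: a tall box (0.2, 0.4) against a wide center (0.4, 0.2). The overlap is 0.2 × 0.2 = 0.04, each area is 0.08, so IOU = 0.04 / (0.08 + 0.08 − 0.04) = 1/3. The sketch repeats the `iou` function so that it runs on its own:

```python
import numpy as np


def iou(box, clusters):
    """Same logic as above: IoU of one (w, h) box against each (w, h) row of clusters,
    with all boxes aligned at a shared corner."""
    x = np.minimum(clusters[:, 0], box[0])
    y = np.minimum(clusters[:, 1], box[1])
    intersection = x * y
    box_area = box[0] * box[1]
    cluster_area = clusters[:, 0] * clusters[:, 1]
    return intersection / (box_area + cluster_area - intersection)


box = np.array([0.2, 0.4])
clusters = np.array([[0.4, 0.2],   # wide: overlap 0.2 x 0.2
                     [0.2, 0.4],   # identical shape
                     [0.1, 0.2]])  # same aspect ratio, a quarter of the area
print(iou(box, clusters))  # about [0.333, 1.0, 0.25]
```

Note that the third center has exactly the same aspect ratio as the box but still scores only 0.25, because IOU also penalizes differences in scale.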

3.3 K-means Clustering

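The k-means procedure is fairly simple. First choose the number of clusters and initialize each cluster center. Then repeat the following two steps until the cluster assignments stop changing between two consecutive iterations:

1) Assign each item to its nearest cluster center, using 1 − IOU as the distance (a higher IOU means a smaller distance).

2) Recompute each cluster center as the mean (or median) of the items assigned to it.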
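In formula form, with $w, h$ the normalized width and height, $c_j$ the $j$-th cluster center, and boxes aligned at a common corner as in `iou` above, the distance used by YOLOv2's dimension clustering and the two alternating steps are:

```latex
d(\text{box}, c) = 1 - \mathrm{IOU}(\text{box}, c), \qquad
\mathrm{IOU}(\text{box}, c) =
  \frac{\min(w_b, w_c)\,\min(h_b, h_c)}
       {w_b h_b + w_c h_c - \min(w_b, w_c)\,\min(h_b, h_c)}

a_i = \arg\min_{j} \; d(\text{box}_i, c_j), \qquad
c_j \leftarrow \operatorname{mean}\{(w_i, h_i) : a_i = j\}
\quad (\text{or the per-coordinate median})
```

Using $1 - \mathrm{IOU}$ instead of Euclidean distance keeps large boxes from dominating the error, which is the motivation given in the YOLOv2 paper.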

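def kmeans(boxes, k, dist=np.median,seed=1):
    """
    Calculates k-means clustering with the Intersection over Union (IoU) metric.
    :param boxes: numpy array of shape (r, 2), where r is the number of rows
    :param k: number of clusters
    :param dist: function used to compute each new cluster center (np.median or np.mean);
                 the distance itself is always 1 - IoU
    :param seed: random seed for choosing the initial cluster centers
    :return: clusters (k, 2), nearest_clusters (r,), distances (r, k)
    """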
    rows = boxes.shape[0]

    distances     = np.empty((rows, k)) ## N row x N cluster
    last_clusters = np.zeros((rows,))

    np.random.seed(seed)

    # initialize the cluster centers to be k items
    clusters = boxes[np.random.choice(rows, k, replace=False)]

    while True:
        # Step 1: allocate each item to the closest cluster centers
        for icluster in range(k): # I made change to lars76's code here to make the code faster
            distances[:,icluster] = 1 - iou(clusters[icluster], boxes)

        nearest_clusters = np.argmin(distances, axis=1)

        if (last_clusters == nearest_clusters).all():
            break
            
        # Step 2: calculate the cluster centers as mean (or median) of all the cases in the clusters.
        for cluster in range(k):
            clusters[cluster] = dist(boxes[nearest_clusters == cluster], axis=0)

        last_clusters = nearest_clusters

    return clusters,nearest_clusters,distances
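Step 1 above loops over the k clusters. NumPy broadcasting can compute the whole N × k distance matrix in one expression instead. A sketch, where `iou_matrix` is a hypothetical name and the inputs are small made-up shapes:

```python
import numpy as np


def iou_matrix(boxes, clusters):
    """Hypothetical helper: (N, k) IoU between every (w, h) box and every (w, h) cluster center."""
    # boxes[:, None, :] has shape (N, 1, 2); clusters[None, :, :] has shape (1, k, 2).
    inter = np.prod(np.minimum(boxes[:, None, :], clusters[None, :, :]), axis=2)
    box_area = np.prod(boxes, axis=1)[:, None]         # (N, 1)
    cluster_area = np.prod(clusters, axis=1)[None, :]  # (1, k)
    return inter / (box_area + cluster_area - inter)


boxes = np.array([[0.1, 0.1], [0.12, 0.1], [0.5, 0.6], [0.55, 0.7]])
clusters = np.array([[0.1, 0.1], [0.5, 0.6]])
distances = 1 - iou_matrix(boxes, clusters)   # same role as `distances` in kmeans()
nearest = np.argmin(distances, axis=1)
print(distances.shape, nearest)               # (4, 2) [0 0 1 1]
```

The result matches the column-by-column loop in `kmeans`; the trade-off is memory, since the intermediate array has N × k × 2 entries (about 40138 × 10 × 2 for VOC2012 at k = 10, which is still small).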

To choose the number of anchor boxes, run k-means for k = 2, 3, 4, ..., 10 (`range(2, kmax)` with `kmax = 11` stops before 11):

kmax = 11
dist = np.mean
results = {}
for k in range(2,kmax):
    clusters, nearest_clusters, distances = kmeans(wh,k,seed=2,dist=dist)
    WithinClusterMeanDist = np.mean(distances[np.arange(distances.shape[0]),nearest_clusters])
    result = {"clusters":             clusters,
              "nearest_clusters":     nearest_clusters,
              "distances":            distances,
              "WithinClusterMeanDist": WithinClusterMeanDist}
    print("{:2.0f} clusters: mean IoU = {:5.4f}".format(k,1-result["WithinClusterMeanDist"]))
    results[k] = result

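As k grows, the mean IoU increases, as expected: in the extreme case where k equals the number of boxes, every box can be its own cluster center and the mean IoU reaches 1.
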
 2 clusters: mean IoU = 0.4646
 3 clusters: mean IoU = 0.5391
 4 clusters: mean IoU = 0.5801
 5 clusters: mean IoU = 0.6016
 6 clusters: mean IoU = 0.6253
 7 clusters: mean IoU = 0.6434
 8 clusters: mean IoU = 0.6595
 9 clusters: mean IoU = 0.6712
10 clusters: mean IoU = 0.6840

4. Visualizing the K-means Results

Plotting the clustering results for different numbers of clusters shows the effect more intuitively.

def plot_cluster_result(plt,clusters,nearest_clusters,WithinClusterSumDist,wh):
    for icluster in np.unique(nearest_clusters):
        pick = nearest_clusters==icluster
        c = current_palette[icluster]
        plt.rc('font', size=8) 
        plt.plot(wh[pick,0],wh[pick,1],"p",
                 color=c,
                 alpha=0.5,label="cluster = {}, N = {:6.0f}".format(icluster,np.sum(pick)))
        plt.text(clusters[icluster,0],
                 clusters[icluster,1],
                 "c{}".format(icluster),
                 fontsize=20,color="red")
        plt.title("Clusters")
        plt.xlabel("width")
        plt.ylabel("height")
    plt.legend(title="Mean IoU = {:5.4f}".format(WithinClusterSumDist))  
    
import seaborn as sns
current_palette = list(sns.xkcd_rgb.values())

figsize = (15,35)
count =1 
fig = plt.figure(figsize=figsize)
for k in range(2,kmax):
    result               = results[k]
    clusters             = result["clusters"]
    nearest_clusters     = result["nearest_clusters"]
    WithinClusterSumDist = result["WithinClusterMeanDist"]
    
    ax = fig.add_subplot(kmax//2,2,count)  # add_subplot needs integer grid sizes; 5 x 2 panels hold k = 2..10
    plot_cluster_result(plt,clusters,nearest_clusters,1 - WithinClusterSumDist,wh)
    count += 1
plt.show()

The relationship between the number of anchor boxes and the mean IOU:

plt.figure(figsize=(6,6))
plt.plot(np.arange(2,kmax),
         [1 - results[k]["WithinClusterMeanDist"] for k in range(2,kmax)],"o-")
plt.title("within cluster mean of {}".format(dist.__name__))
plt.ylabel("mean IOU")
plt.xlabel("N clusters (= N anchor boxes)")
plt.show()

With 4 anchor boxes, the resulting 4 anchor shapes (normalized width, height) are:

Nanchor_box = 4
results[Nanchor_box]["clusters"]

Output:

array([[0.08285376, 0.13705531],
       [0.20850361, 0.39420716],
       [0.80552421, 0.77665105],
       [0.42194719, 0.62385487]])
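These cluster centers are normalized to the image size. YOLOv2 expresses anchors relative to its output grid, so converting them is a simple scaling. The sketch below assumes a 416 × 416 input with a 13 × 13 output grid (stride 32), the standard YOLOv2 setting, and assumes images are resized to that square input without letterboxing. `GRID` and `INPUT_SIZE` are illustrative constants, not values from this article's code:

```python
import numpy as np

# The four normalized (width, height) centers printed above.
anchors = np.array([[0.08285376, 0.13705531],
                    [0.20850361, 0.39420716],
                    [0.80552421, 0.77665105],
                    [0.42194719, 0.62385487]])

GRID = 13         # assumed output grid size (13 x 13 for a 416 x 416 input)
INPUT_SIZE = 416  # assumed network input size in pixels

# Sort anchors from smallest to largest area for readability.
anchors = anchors[np.argsort(anchors[:, 0] * anchors[:, 1])]

anchors_grid = anchors * GRID      # in grid-cell units
anchors_px = anchors * INPUT_SIZE  # in input pixels
for (gw, gh), (pw, ph) in zip(anchors_grid, anchors_px):
    print("grid: {:5.2f} x {:5.2f}   pixels: {:6.1f} x {:6.1f}".format(gw, gh, pw, ph))
```

If the network's input or grid size differs, only the two constants change. The normalized centers from k-means stay the same.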
This article is part of the Tencent Cloud self-media sharing program and is shared from the WeChat official account 半杯茶的小酒杯.
Originally published: 2020-11-14. In case of infringement, contact cloudcommunity@tencent.com for removal.
