专栏首页贾志刚-OpenCV学堂tensorflow Object Detection API使用预训练模型mask r-cnn实现对象检测

tensorflow Object Detection API使用预训练模型mask r-cnn实现对象检测

Mask R-CNN模型下载

Mask R-CNN是何凯明大神在2017年整出来的新网络模型,在原有的R-CNN基础上实现了区域ROI的像素级别分割。关于Mask R-CNN模型本身的介绍与解释网络上面已经是铺天盖地了,论文也是到处可以看到。这里主要想介绍一下在tensorflow中如何使用预训练的Mask R-CNN模型实现对象检测与像素级别的分割。tensorflow框架有个扩展模块叫做models里面包含了很多预训练的网络模型,提供给tensorflow开发者直接使用或者迁移学习使用,首先需要下载Mask R-CNN网络模型,这个在tensorflow的models的github上面有详细的解释与model zoo的页面介绍, tensorflow models的github主页地址如下: https://github.com/tensorflow/models

我这里下载的是:

mask_rcnn_inception_v2_coco_2018_01_28.tar.gz

下载好模型之后可以解压缩为tar文件,然后通过下面的代码读入模型

MODEL_NAME = 'mask_rcnn_inception_v2_coco_2018_01_28'
MODEL_FILE = 'D:/tensorflow/' + MODEL_NAME + '.tar'

# Path to frozen detection graph
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'

# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('D:/tensorflow/models/research/object_detection/data', 'mscoco_label_map.pbtxt')

NUM_CLASSES = 90
tar_file = tarfile.open(MODEL_FILE)
for file in tar_file.getmembers():
    file_name = os.path.basename(file.name)
    if 'frozen_inference_graph.pb' in file_name:
        tar_file.extract(file, os.getcwd())

detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

模型使用coco数据集,可以检测与分割90个对象类别,所以下面需要把对应labelmap文件读进去,这个文件在

models\research\objectdetection\data

目录下,实现代码如下:

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

有了这个之后就需要从模型中取出如下几个tensor

  • num_detections 表示检测对象数目
  • detection_boxes 表示输出框BB
  • detection_scores 表示得分
  • detection_classes 表示对象类别索引
  • detection_masks 表示mask分割

然后在会话中运行这几个tensor即可,代码实现如下:

def run_inference_for_single_image(image, graph):
    with graph.as_default():
        with tf.Session() as sess:
            # Get handles to input and output tensors
            ops = tf.get_default_graph().get_operations()
            all_tensor_names = {output.name for op in ops for output in op.outputs}
            tensor_dict = {}
            for key in ['num_detections', 'detection_boxes', 'detection_scores', 'detection_classes', 'detection_masks']:
                tensor_name = key + ':0'
                if tensor_name in all_tensor_names:
                    tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(tensor_name)

            if 'detection_masks' in tensor_dict:
                # The following processing is only for single image
                detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
                detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])
                # Reframe is required to translate mask from box coordinates to image coordinates and fit the image size.
                real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)
                detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])
                detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])
                detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
                    detection_masks, detection_boxes, image.shape[0], image.shape[1])
                detection_masks_reframed = tf.cast(
                    tf.greater(detection_masks_reframed, 0.5), tf.uint8)
                # Follow the convention by adding back the batch dimension
                tensor_dict['detection_masks'] = tf.expand_dims(
                    detection_masks_reframed, 0)
            image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')

            # Run inference
            output_dict = sess.run(tensor_dict,
                                 feed_dict={image_tensor: np.expand_dims(image, 0)})

            # all outputs are float32 numpy arrays, so convert types as appropriate
            output_dict['num_detections'] = int(output_dict['num_detections'][0])
            output_dict['detection_classes'] = output_dict[
              'detection_classes'][0].astype(np.uint8)
            output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
            output_dict['detection_scores'] = output_dict['detection_scores'][0]
            if 'detection_masks' in output_dict:
                output_dict['detection_masks'] = output_dict['detection_masks'][0]
        return output_dict

下面就是通过opencv来读取一张彩色测试图像,然后调用模型进行检测与对象分割,代码实现如下:

image = cv2.imread("D:/apple.jpg");
# image = cv2.imread("D:/tensorflow/models/research/object_detection/test_images/image2.jpg");
cv2.imshow("input image", image)
print(image.shape)

# Actual detection.
output_dict = run_inference_for_single_image(image, detection_graph)

# Visualization of the results of a detection.
vis_util.visualize_boxes_and_labels_on_image_array(
    image,
    output_dict['detection_boxes'],
    output_dict['detection_classes'],
    output_dict['detection_scores'],
    category_index,
    instance_masks=output_dict.get('detection_masks'),
    use_normalized_coordinates=True,
    line_thickness=8)

原图如下:

检测运行结果如下:

带mask分割效果如下:

官方测试图像运行结果:

本文分享自微信公众号 - OpenCV学堂(CVSCHOOL),作者:gloomyfish

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2018-08-24

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • Tensorflow Object Detection API 终于支持tensorflow1.x与tensorflow2.x了

    基于tensorflow框架构建的快速对象检测模型构建、训练、部署框架,是针对计算机视觉领域对象检测任务的深度学习框架。之前tensorflow2.x一直不支持...

    OpenCV学堂
  • Tensorflow + OpenCV4 安全帽检测模型训练与推理

    如何安装tensorflow object detection API框架,看这里:

    OpenCV学堂
  • 使用OpenCV中的universal intrinsics为算法提速 (2)

    前言:因为新型冠状病毒导致疫情,最近几日各种新闻和消息满天飞。疫情之下不易出行、不宜聚会;宜宅在家、宜阅读、宜学习、宜写代码。鉴于此,本系列第2篇提前发布。希望...

    OpenCV学堂
  • GitHub实战系列~4.把github里面的库克隆到指定目录+日常使用 2015-12-11

    GitHub实战系列汇总:http://www.cnblogs.com/dunitian/p/5038719.html ————————————————————...

    逸鹏
  • 《我们捉鱼吧》——Scratch神奇的“侦测”功能总结

    导读:本文通过案例《鼠标捉鱼》、《大鱼吃小鱼》、《小猫捉鱼》总结了Scratch的侦测功能。

    一石匠人
  • Git命令集之六——查看仓库状态 原

    珲少
  • Neo4J:删除关系

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。 ...

    程裕强
  • 关于HDFS应知应会的几个问题

    安全模式是Namenode的一种状态(Namenode主要有active/standby/safemode三种模式)。

    大数据学习与分享
  • 服务器上 git 的安装及基本配置

    git 对于开发者来说属于必备工具中的必备工具了。何况,没有 git 的话,「面向 github 编程」 从何说起,如同一个程序员断了左膀右臂。

    山月
  • 必须掌握的HDFS相关问题

    安全模式是Namenode的一种状态(Namenode主要有active/standby/safemode三种模式)。

    大数据学习与分享

扫码关注云+社区

领取腾讯云代金券