前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >python开发:基于SSD下的图像内容识别(二)

python开发:基于SSD下的图像内容识别(二)

作者头像
sladesal
发布2018-08-27 11:09:26
2.2K1
发布2018-08-27 11:09:26
举报
文章被收录于专栏:机器学习之旅

感谢 @zcl1122指出的倒数第三节代码中的i错误的被简书转行成大写的I的问题。

上一节粗略的描述了如何关于图像识别,抠图,分类的理论相关,本节主要用代码,来和大家一起分析每一步骤。 看完本节,希望你也能独立完成自己的图片、视频的内容实时定位。

首先,我们需要安装TensorFlow环境,建议利用conda进行安装,配置,90%尝试单独安装的人最后都挂了。

其次,我们需要安装从git上下载训练好的模型,git clone https://github.com/balancap/SSD-Tensorflow 如果没有安装git的朋友,请自行百度安装。

最后找到你下载的位置进行解压,unzip ./SSD-Tensorflow/checkpoints/ssd_300_vgg.ckpt.zip 这边务必注意,网上90%的教程这边就结束了,其实你这样是最后跑不通代码的,你需要把解压的文件进行移动到checkpoint的文件夹下面,这个问题git上这个同学解释了,详细的去看下https://github.com/balancap/SSD-Tensorflow/issues/150

最后的最后,下载你需要检测的网路图片,就ok了

预处理步骤完成了,下面让我们看代码。 加载相关的包:

代码语言:javascript
复制
import os
import math
import random
import sys
import numpy as np
import tensorflow as tf
import cv2
import matplotlib.pyplot as plt
import matplotlib.cm as mpcm
sys.path.append('./SSD-Tensorflow/')
from nets import ssd_vgg_300, ssd_common, np_methods
from preprocessing import ssd_vgg_preprocessing

配置相关TensorFlow环境

代码语言:javascript
复制
gpu_options = tf.GPUOptions(allow_growth=True)
config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
isess = tf.InteractiveSession(config=config)

做图片的格式的处理,使他满足input的条件

代码语言:javascript
复制
#我们用的TensorFlow下的一个集成包slim,比tensor要更加轻便
slim = tf.contrib.slim
#训练数据中包含了一下已知的类别,也就是我们可以识别出以下的东西,不过后续我们将自己自己训练自己的模型,来识别自己想识别的东西
l_VOC_CLASS = [
                'aeroplane',   'bicycle', 'bird',  'boat',      'bottle',
                'bus',         'car',     'cat',   'chair',     'cow',
                'diningTable', 'dog',     'horse', 'motorbike', 'person',
                'pottedPlant', 'sheep',   'sofa',  'train',     'TV'
]
# 定义数据格式
net_shape = (300, 300)
data_format = 'NHWC'  # [Number, height, width, color],Tensorflow backend 的格式

# 预处理将输入图片大小改成 300x300,作为下一步输入
img_input = tf.placeholder(tf.uint8, shape=(None, None, 3))
image_pre, labels_pre, bboxes_pre, bbox_img = ssd_vgg_preprocessing.preprocess_for_eval(
    img_input, 
    None, 
    None, 
    net_shape, 
    data_format, 
    resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE
)
image_4d = tf.expand_dims(image_pre, 0)

下面我们来载入SSD作者已经搞定的模型

代码语言:javascript
复制
# 定义 SSD 模型结构
reuse = True if 'ssd_net' in locals() else None
ssd_net = ssd_vgg_300.SSDNet()
with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)):
    predictions, localisations, _, _ = ssd_net.net(image_4d, is_training=False, reuse=reuse)
# 导入官方给出的 SSD 模型参数
#这边修改成你自己的路径
ckpt_filename = '/Users/slade/SSD-Tensorflow/checkpoints/ssd_300_vgg.ckpt'
isess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(isess, ckpt_filename)
ssd_anchors = ssd_net.anchors(net_shape)

下面让我们把SSD识别出来的结果在图片中表示出来

代码语言:javascript
复制
#不同类别,我们以不同的颜色表示
def colors_subselect(colors, num_classes=21):
    dt = len(colors) // num_classes
    sub_colors = []
    for i in range(num_classes):
        color = colors[i*dt]
        if isinstance(color[0], float):
            sub_colors.append([int(c * 255) for c in color])
        else:
            sub_colors.append([c for c in color])
    return sub_colors
#画出在图中的位置
def bboxes_draw_on_img(img, classes, scores, bboxes, colors, thickness=5):
    shape = img.shape
    for i in range(bboxes.shape[0]):
        bbox = bboxes[i]
        color = colors[classes[i]]
        # Draw bounding box...
        p1 = (int(bbox[0] * shape[0]), int(bbox[1] * shape[1]))
        p2 = (int(bbox[2] * shape[0]), int(bbox[3] * shape[1]))
        cv2.rectangle(img, p1[::-1], p2[::-1], color, thickness)
        # Draw text...
        s = '%s:%.3f' % ( l_VOC_CLASS[int(classes[i])-1], scores[i])
        p1 = (p1[0]-5, p1[1])
        cv2.putText(img, s, p1[::-1], cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
colors_plasma = colors_subselect(mpcm.plasma.colors, num_classes=21)

让我们开始训练吧

代码语言:javascript
复制
def process_image(img, select_threshold=0.3, nms_threshold=.8, net_shape=(300, 300)):
    #先获取SSD网络的层相关的参数
    rimg, rpredictions, rlocalisations, rbbox_img = isess.run([image_4d, predictions, localisations, bbox_img],
                                                              feed_dict={img_input: img})
    #获取分类结果,位置
    rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select(
            rpredictions, rlocalisations, ssd_anchors,
            select_threshold=select_threshold, img_shape=net_shape, num_classes=21, decode=True)
    rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes)
    rclasses, rscores, rbboxes = np_methods.bboxes_sort(rclasses, rscores, rbboxes, top_k=400)
    rclasses, rscores, rbboxes = np_methods.bboxes_nms(rclasses, rscores, rbboxes, nms_threshold=nms_threshold)
    # 让我们在图中画出来就行了
    rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes)
    bboxes_draw_on_img(img, rclasses, rscores, rbboxes, colors_plasma, thickness=2)
    return img

预处理的函数都写完了,我们就可以执行了。

代码语言:javascript
复制
#读取数据
img = cv2.imread("/Users/slade/Documents/Yoho/picture_recognize/test7.jpg")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(process_image(img))
plt.show()

img的数据形式如下:

代码语言:javascript
复制
In [8]: img
Out[8]:
array([[[ 35,  59,  43],
        [ 37,  60,  44],
        [ 38,  61,  45],
        ...,
        [ 73,  99,  62],
        [ 74,  99,  60],
        [ 72,  97,  57]],

       [[ 37,  60,  44],
        [ 37,  60,  44],
        [ 37,  60,  44],
        ...,
        [ 66,  92,  57],
        [ 67,  93,  56],
        [ 67,  92,  53]],

       [[ 37,  60,  44],
        [ 36,  59,  43],
        [ 37,  58,  43],
        ...,
        [ 56,  83,  48],
        [ 60,  86,  51],
        [ 61,  87,  50]],

       ...,
       [[ 96, 101,  95],
        [107, 109, 104],
        [ 98,  97,  95],
        ...,
        [ 84, 126,  76],
        [ 72, 118,  72],
        [ 78, 126,  86]],

       [[ 98, 103,  96],
        [114, 116, 111],
        [112, 113, 108],
        ...,
        [ 94, 137,  84],
        [ 87, 133,  86],
        [105, 153, 111]],

       [[ 99, 105,  95],
        [110, 113, 106],
        [134, 135, 129],
        ...,
        [127, 170, 116],
        [121, 167, 118],
        [131, 180, 135]]], dtype=uint8)

处理后的结果如下:

是不是非常无脑,上面的代码直接复制就可以完成。

下面在拓展一下视频的处理方式,其实相关的内容是一致的。 利用moviepy.editor包里面的VideoFileClip的切片的功能,然后对每一次切片的结果进行process_image过程就可以了,这边就不贴代码了,需要的朋友私密我。

最后感谢大家阅读。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2017.11.17 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
图像识别
腾讯云图像识别基于深度学习等人工智能技术,提供车辆,物体及场景等检测和识别服务, 已上线产品子功能包含车辆识别,商品识别,宠物识别,文件封识别等,更多功能接口敬请期待。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档