专栏首页机器学习之旅python开发:基于SSD下的图像内容识别(二)

python开发:基于SSD下的图像内容识别(二)

感谢 @zcl1122指出的倒数第三节代码中的i错误的被简书转行成大写的I的问题。

上一节粗略的描述了如何关于图像识别,抠图,分类的理论相关,本节主要用代码,来和大家一起分析每一步骤。 看完本节,希望你也能独立完成自己的图片、视频的内容实时定位。

首先,我们需要安装TensorFlow环境,建议利用conda进行安装,配置,90%尝试单独安装的人最后都挂了。

其次,我们需要安装从git上下载训练好的模型,git clone https://github.com/balancap/SSD-Tensorflow 如果没有安装git的朋友,请自行百度安装。

最后找到你下载的位置进行解压,unzip ./SSD-Tensorflow/checkpoints/ssd_300_vgg.ckpt.zip 这边务必注意,网上90%的教程这边就结束了,其实你这样是最后跑不通代码的,你需要把解压的文件进行移动到checkpoint的文件夹下面,这个问题git上这个同学解释了,详细的去看下https://github.com/balancap/SSD-Tensorflow/issues/150

最后的最后,下载你需要检测的网路图片,就ok了

预处理步骤完成了,下面让我们看代码。 加载相关的包:

import os
import math
import random
import sys
import numpy as np
import tensorflow as tf
import cv2
import matplotlib.pyplot as plt
import matplotlib.cm as mpcm
sys.path.append('./SSD-Tensorflow/')
from nets import ssd_vgg_300, ssd_common, np_methods
from preprocessing import ssd_vgg_preprocessing

配置相关TensorFlow环境

gpu_options = tf.GPUOptions(allow_growth=True)
config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
isess = tf.InteractiveSession(config=config)

做图片的格式的处理,使他满足input的条件

#我们用的TensorFlow下的一个集成包slim,比tensor要更加轻便
slim = tf.contrib.slim
#训练数据中包含了一下已知的类别,也就是我们可以识别出以下的东西,不过后续我们将自己自己训练自己的模型,来识别自己想识别的东西
l_VOC_CLASS = [
                'aeroplane',   'bicycle', 'bird',  'boat',      'bottle',
                'bus',         'car',     'cat',   'chair',     'cow',
                'diningTable', 'dog',     'horse', 'motorbike', 'person',
                'pottedPlant', 'sheep',   'sofa',  'train',     'TV'
]
# 定义数据格式
net_shape = (300, 300)
data_format = 'NHWC'  # [Number, height, width, color],Tensorflow backend 的格式

# 预处理将输入图片大小改成 300x300,作为下一步输入
img_input = tf.placeholder(tf.uint8, shape=(None, None, 3))
image_pre, labels_pre, bboxes_pre, bbox_img = ssd_vgg_preprocessing.preprocess_for_eval(
    img_input, 
    None, 
    None, 
    net_shape, 
    data_format, 
    resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE
)
image_4d = tf.expand_dims(image_pre, 0)

下面我们来载入SSD作者已经搞定的模型

# 定义 SSD 模型结构
reuse = True if 'ssd_net' in locals() else None
ssd_net = ssd_vgg_300.SSDNet()
with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)):
    predictions, localisations, _, _ = ssd_net.net(image_4d, is_training=False, reuse=reuse)
# 导入官方给出的 SSD 模型参数
#这边修改成你自己的路径
ckpt_filename = '/Users/slade/SSD-Tensorflow/checkpoints/ssd_300_vgg.ckpt'
isess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(isess, ckpt_filename)
ssd_anchors = ssd_net.anchors(net_shape)

下面让我们把SSD识别出来的结果在图片中表示出来

#不同类别,我们以不同的颜色表示
def colors_subselect(colors, num_classes=21):
    dt = len(colors) // num_classes
    sub_colors = []
    for i in range(num_classes):
        color = colors[i*dt]
        if isinstance(color[0], float):
            sub_colors.append([int(c * 255) for c in color])
        else:
            sub_colors.append([c for c in color])
    return sub_colors
#画出在图中的位置
def bboxes_draw_on_img(img, classes, scores, bboxes, colors, thickness=5):
    shape = img.shape
    for i in range(bboxes.shape[0]):
        bbox = bboxes[i]
        color = colors[classes[i]]
        # Draw bounding box...
        p1 = (int(bbox[0] * shape[0]), int(bbox[1] * shape[1]))
        p2 = (int(bbox[2] * shape[0]), int(bbox[3] * shape[1]))
        cv2.rectangle(img, p1[::-1], p2[::-1], color, thickness)
        # Draw text...
        s = '%s:%.3f' % ( l_VOC_CLASS[int(classes[i])-1], scores[i])
        p1 = (p1[0]-5, p1[1])
        cv2.putText(img, s, p1[::-1], cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
colors_plasma = colors_subselect(mpcm.plasma.colors, num_classes=21)

让我们开始训练吧

def process_image(img, select_threshold=0.3, nms_threshold=.8, net_shape=(300, 300)):
    #先获取SSD网络的层相关的参数
    rimg, rpredictions, rlocalisations, rbbox_img = isess.run([image_4d, predictions, localisations, bbox_img],
                                                              feed_dict={img_input: img})
    #获取分类结果,位置
    rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select(
            rpredictions, rlocalisations, ssd_anchors,
            select_threshold=select_threshold, img_shape=net_shape, num_classes=21, decode=True)
    rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes)
    rclasses, rscores, rbboxes = np_methods.bboxes_sort(rclasses, rscores, rbboxes, top_k=400)
    rclasses, rscores, rbboxes = np_methods.bboxes_nms(rclasses, rscores, rbboxes, nms_threshold=nms_threshold)
    # 让我们在图中画出来就行了
    rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes)
    bboxes_draw_on_img(img, rclasses, rscores, rbboxes, colors_plasma, thickness=2)
    return img

预处理的函数都写完了,我们就可以执行了。

#读取数据
img = cv2.imread("/Users/slade/Documents/Yoho/picture_recognize/test7.jpg")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(process_image(img))
plt.show()

img的数据形式如下:

In [8]: img
Out[8]:
array([[[ 35,  59,  43],
        [ 37,  60,  44],
        [ 38,  61,  45],
        ...,
        [ 73,  99,  62],
        [ 74,  99,  60],
        [ 72,  97,  57]],

       [[ 37,  60,  44],
        [ 37,  60,  44],
        [ 37,  60,  44],
        ...,
        [ 66,  92,  57],
        [ 67,  93,  56],
        [ 67,  92,  53]],

       [[ 37,  60,  44],
        [ 36,  59,  43],
        [ 37,  58,  43],
        ...,
        [ 56,  83,  48],
        [ 60,  86,  51],
        [ 61,  87,  50]],

       ...,
       [[ 96, 101,  95],
        [107, 109, 104],
        [ 98,  97,  95],
        ...,
        [ 84, 126,  76],
        [ 72, 118,  72],
        [ 78, 126,  86]],

       [[ 98, 103,  96],
        [114, 116, 111],
        [112, 113, 108],
        ...,
        [ 94, 137,  84],
        [ 87, 133,  86],
        [105, 153, 111]],

       [[ 99, 105,  95],
        [110, 113, 106],
        [134, 135, 129],
        ...,
        [127, 170, 116],
        [121, 167, 118],
        [131, 180, 135]]], dtype=uint8)

处理后的结果如下:

是不是非常无脑,上面的代码直接复制就可以完成。

下面在拓展一下视频的处理方式,其实相关的内容是一致的。 利用moviepy.editor包里面的VideoFileClip的切片的功能,然后对每一次切片的结果进行process_image过程就可以了,这边就不贴代码了,需要的朋友私密我。

最后感谢大家阅读。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 理论:决策树及衍射指标

    特征A对训练数据集D的信息增益g(D,A),定义为集合D的经验熵H(D)与特征A给定条件下的经验条件熵H(D|A)之差

    sladesal
  • Python踩坑指南(第一季)

    最近在python开发的过程中,发现了一些比较有意思的问题,确实让自己在开发过程中被恶心了一把,所以开了这个连续的更新博文,之后会持续的按第一第二第三这种版本下...

    sladesal
  • R开发:协调过滤推荐

    对于realRatingMatrix有六种方法:IBCF(基于物品的推荐)、UBCF(基于用户的推荐)、PCA(主成分分析)、RANDOM(随机推荐)、SVD(...

    sladesal
  • 17.HTML

    HTML简介 htyper text markup language  即超文本标记语言。 超文本: 就是指页面内可以包含图片、链接,甚至音乐、程序等非文字元素...

    zhang_derek
  • Django 项目中添加静态文件夹

    在 mysite 文件夹下添加一个 statics 文件夹用来存放 js 文件

    py3study
  • 什么是多态?如何实现?只看这一篇就够了

    多态的概念:通俗来说,就是多种形态,具体点就是去完成某个行为,当不同的对象去完成时会产生出不同的状态。

    海盗船长
  • JavaScript 实现前端table页面,vue.js实现前端表格

    对于table中的th,tr,td 可以设置rowspan,colspan属性,使得具有任何复杂包含、重叠、组合关系的表格都能做出来。

    acoolgiser
  • python之telnetlib模块实现远程登录代码

    在 python 中有一个 telnetlib,它的作用就是建立一个通到主机的 telnet连线实体, 然后向主机传送命令 (就像用键盘输入一样 )并从该连线接...

    菲宇
  • 防止小程序多次点击跳转解决方案

    在使用小程序的时候会出现这样一种情况:当网络条件差或卡顿的情况下,使用者会认为点击无效而进行多次点击,最后出现多次跳转页面的情况,就像下图(快速点击了两次):

    疯狂的小程序
  • 如何消除双休日影响来计算销售额?

    我们需要求出当月每星期的平均销售额,然后再根据当日的销售额去对比看下完成比例情况。

    逍遥之

扫码关注云+社区

领取腾讯云代金券