专栏首页大龄程序员的人工智能之路从人工智能鉴黄模型,尝试TensorRT优化

从人工智能鉴黄模型,尝试TensorRT优化

题图 摄于洞庭湖畔

随着互联网的快速发展,越来越多的图片和视频出现在网络,特别是UCG产品,激发人们上传图片和视频的热情,比如微信每天上传的图片就高达10亿多张。每个人都可以上传,这就带来监管问题,如果没有内容审核,色情图片和视频就会泛滥。前不久,一向以开放著称的tumblr,就迫于压力,开始限制人们分享色情图片。更别提国内,内容审核是UCG绕不过去的坎。还记得前几年出现的职业鉴黄师这一职业么?传说百万年薪,每天看黄片看得想吐,但最近又很少有人提及这一职业,这个应监管而生的职业,因人工智能的出现又快速消亡。(当然也不是完全消亡,毕竟判断是否色情是一个主观的事情,有些艺术和色情之间的边界比较模糊,需要人工加以判断)

之前写过一篇文章利用人工智能检测色情图片,也曾经尝试过在浏览器中加入色情图片过滤功能,但实验下来,推理速度太慢(当时使用的Google Nexus 4做的测试,检测一张图片需要几秒钟),没法做实时过滤。最近在研究nvidia的Jetson Nano以及推理加速框架TensorRT,因此想尝试一下,看能否应用一些加速方法,加速推理。

虽然我的最终目标是应用到Jetson Nano,但是TensorRT其实适用于几乎所有的Nvidia显卡,为了方便起见,我还是先在PC端进行尝试。没有Nvidia显卡?也没有关系,可以看看我前面发布的两篇文章:

  1. 谷歌GPU云计算平台,免费又好用
  2. Google Colab上安装TensorRT

open_nsfw

本文采用的深度学习模型是雅虎开源的深度学习色情图片检测模型open_nsfw,这里的NSFW代表Not Suitable for Work,该项目基于caffe框架。由于我主要研究的是Tensorflow,所以在网上找到该模型的Tensorflow实现版本,fork了一份,并添加了TensorRT框架的处理脚本,你可以使用如下命令获得相关代码:

git clone https://github.com/mogoweb/tensorflow-open_nsfw.git

model.py 中,我们可以看到open_nsfw的模型定义,data/open_nsfw-weights.npy 是采用工具从yahoo open_nsfw的cafee权重转换得到的Tensorflow权重,这样我们无需训练模型,直接用于推理过程。classify_nsfw.py 脚本可用于单张图片的推理:

python classify_nsfw.py -m data/open_nsfw-weights.npy test.jpg

注意:脚本提供了两种解码图片文件的方式,一种是采用PIL.image、skimage进行图片处理,也就是所谓的yahoo_image_loader,一种是采用tensorflow中的图片处理函数进行处理。因为原始的open_nsfw模型是采用PIL.image、skimage进行预处理而训练的,而不同的库解码出来的结果存在细微的差异,会影响最终结果,一般优选选择yahoo_image_loader。当然,如果你打算自己训练模型,那选择哪种图片处理库都可以。

tools 目录下有一些脚本,可以将模型导出为frozen graphsaved model以及tflite等格式,这样我们可以方便的在服务器端部署,还可以应用到手机端。

opt是我编写的采用TensorRT框架加速的代码,在下面我将详细说明。

导出为TensorRT模型

目前TensorRT作为Tensorflow的一部分得到Google官方支持,其包位于tensorflow.contrib.tensorrt,在代码中加入:

import tensorflow.contrib.tensorrt as trt

就可以使用TensorRT,因为有Google的支持,导出到TensorRT模型也就相当简单:

        trt_graph = trt.create_inference_graph(
                input_graph_def=frozen_graph_def,
                outputs=[output_node_name],
                max_batch_size=1,
                max_workspace_size_bytes=1 << 25,
                precision_mode='FP16',
                minimum_segment_size=50
        )

        graph_io.write_graph(trt_graph, export_base_path, 'trt_' + graph_name, as_text=False)

其中:

  • input_graph_def 为需要导出的Tensorflow模型图定义
  • outputs 为输出节点名称
  • max_batch_size 为最大的batch size限制,因为GPU存在显存限制,需要根据GPU memory大小决定,一般情况可以给8或者16
  • precision_mode 为模型精度,有FP32、FP16和INT8可选,精度越高,推理速度越慢,也要依GPU而定。

graph_io.write_graph 将图写入到文件,在后续的代码中可以加载之。

完整的代码请参考 opt/export_trt.py 文件。

测试数据

因为一些政策法规的限制,并没有公开数据集可提供下载,不过在github上有一些开源项目,提供脚本,从网络上进行下载。我使用的是 https://github.com/alexkimxyz/nsfw_data_scraper 这个开源项目中的脚本。这个项目提供drawings、hentai、neutral、porn、sexy四种类别图片,可以划分为训练集和测试集,并检查图片是否有效(因为从网络爬取,有些链接不一定能访问到)。

注意这个图片下载量非常大,需要注意别把硬盘撑满。虽然这个数据量够大(几万张),可以自行进行模型训练,但和yahoo训练open_nsfw模型的图片量相比,还是小巫见大巫,据说yahoo训练这个模型用了几百万张的图片。

推理速度对比

在opt目录下,我针对两种模型的加载和推理添加了两个脚本,分别是 benchmark_classify_nsfw.pybenchmark_classify_trt.py,细心的同学可能会发现,这两个脚本几乎一模一样,是的,除了 benchmark_classify_trt.py 多了一行代码:

import tensorflow.contrib.tensorrt as trt

加入这行import语句,告诉tensorflow使用TensorRT框架,否则的话,会出现如下错误:

tensorflow.python.framework.errors_impl.NotFoundError: Op type not registered 'TRTEngineOp' in binary running on alex-550-279cn. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed.

取2000张测试图片进行测试,在我的GTX 960上,推理速度如下:

未优化模型: 53 s
使用TensorRT优化模型: 54 s

如果你下载更大的数据集,可以多测试一些图片,看看优化效果。

在Google Colab上,我放了一份Jupter Notebook,有兴趣的同学可以借助Google Colab尝试一下,文件地址:https://colab.research.google.com/drive/1vH-GF6F8HQeGwkKFoR6WWMEUzqgVOetu ,当然你也可以访问我github上完整的脚本及Notebook:

https://github.com/mogoweb/tensorflow-open_nsfw.git

点击阅读原文,可以跳转到该项目。

题外话:

微信公众号流量主的门槛已经大大降低,我在公众号文章底部开通了广告,希望没有影响大家的阅读体验。我一直很好奇,这种广告会有人点击么,过一段也许我会得到答案。

本文分享自微信公众号 - 云水木石(ourpoeticlife),作者:陈正勇

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-05-20

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 使用TensorFlow一步步进行目标检测(1)

    目标检测(Object Detection)是人工智能最基础的应用,不论是我们常见的人脸识别,还是高大上的自动驾驶,都离不开目标检测。要从一幅复杂的画面中识别出...

    云水木石
  • 如何将自己开发的模型转换为TensorFlow Lite可用模型

    对于开发者来说,在移动设备上运行预先训练好的模型的能力意味着向边界计算(edge computing)迈进了一大步。[译注:所谓的边界计算,从字面意思理解,就是...

    云水木石
  • [机器学习实战札记] NumPy基础

    <<机器学习实战>>一书非常注重实践,对每个算法的实现和使用示例都提供了python实现。在阅读代码的过程中,发现对NumPy有一定的了解有助于理解代码。特别是...

    云水木石
  • Vue.js动态组件解析

    什么是动态组件绑定?简单的说,就是几个组件放在一个挂载点下,然后根据父组件的某个变量来决定显示哪个,或者都不显示。

    Dunizb
  • 面向设计的半封装web组件开发

    作者:张鑫旭,资深钓鱼爱好者,然后平时喜欢研究与学习前端技术。 前言 本文内容可谓是对大脑认知的一场洗礼。我们平常提到组件,就会想到重用,各个项目都能使用。而...

    腾讯大讲堂
  • 面向设计的半封装web组件开发 - 腾讯ISUX

    腾讯ISUX
  • 面向设计的半封装web组件开发

    本文内容可谓是对大脑认知的一场洗礼。我们平常提到组件,就会想到重用,各个项目都能使用。而本文的组件,对于某具体项目而言是组件,但是,对于其他项目,就是个半封装的...

    疯狂的技术宅
  • 如何写出漂亮的 React 组件

    在Walmart Labs的产品开发中,我们进行了大量的Code Review工作,这也保证了我有机会从很多优秀的工程师的代码中学习他们的代码风格与样式。在这篇...

    哲洛不闹
  • Vue中组件之间8中通信方式,值得收藏

    vue是数据驱动视图更新的框架, 所以对于vue来说组件间的数据通信非常重要,那么组件之间如何进行数据通信的呢? 首先我们需要知道在vue中组件之间存在什么样的...

    用户6973020
  • Vue中组件之间8中通信方式,值得收藏

    vue是数据驱动视图更新的框架, 所以对于vue来说组件间的数据通信非常重要,那么组件之间如何进行数据通信的呢? 首先我们需要知道在vue中组件之间存在什么样的...

    coder_koala

扫码关注云+社区

领取腾讯云代金券