基于OpenGL ES的深度学习框架编写

背景与工程定位

背景

项目组基于深度学习实现了视频风格化和人像抠图的功能,但这是在PC/服务端上跑的,现在需要移植到移动端,因此需要一个移动端的深度学习的计算框架。

同类型的库

  • caffe-Android-lib 目前应该是最便于集成使用的深度学习框架库。
  • tensorflow和mxnet据说也有对应的android库,因时间原因暂未测试
  • CNNdroid,网址https://zhuanlan.zhihu.com/p/25259452,这个是用

renderscript 作优化的深度学习框架,不过就代码实现和实际测试结果来看,性能一般。

工程定位

实现可实时、体积小、通用的深度学习预测框架。

可实时

跟PC或服务器不同,移动设备上的GPU可不一定有CPU强悍(多线程+neon/vfp),但在需要实时计算的场景(主要是相机预览和视频播放),往往都是基于OpenGL渲染环境的。

实时的情况下,深度学习框架的输入和输出都在GPU端,使用CPU进行计算往往需要拷贝图像出来,算好后再传到GPU端,因此基于GPU实现的深度学习的库能持平CPU版本的效率就有足够优势了。

比如实时抠人像这个case:

对每一帧相机预览产生的数据,系统将其映射为opengl 的一个external texture,然后需要 计算出一个 mask texture,与原先的texture作混合,显示出来。如果mask texture 的计算在cpu上进行,则需要每帧先把 graphicbuffer 的数据拷贝出来,计算出mask后上传到 mask texture 去,产生一来一回两次额外拷贝。

通用

本工程需要支持 caffe 产出的模型文件,支持常见的网络如lenet、ResNet等等。这个工作量包括编写相应层的算子,设计网络结构,解析caffe模型的参数等。

所幸的是,目前在移动端做好深度学习的预测就足够了,相比于兼顾训练的结构至少省去2/3的工作量。

工程实现

方案选型

GPU加速的API

使用GPU加速有如下一些方案:

CUDA、OpenCL、OpenGL(ES)、RenderScript、Metal CUDA只适用到NVIDIA的GPU,Metal只适用于apple系列,这两个对android设备而言基本不用考虑。

对于OpenCL,虽然有不少移动GPU已经支持,比如 Arm 的 mali 系列(T628之后),且有相应的支持库。但是,一方面由于Android在系统层面上没有支持,没有相应的系统API,兼容性还是比较差,另一方面,OpenCL 操作完成后的内存传到OpenGL还是需要同步一下,会影响效率。

RenderScript 这个坑比较多,文档极少,而且会有跟OpenCL一样的需要跟OpenGL同步的问题,不做考虑。 最后就只剩下 OpenGL ES,为了开发方便,用 Computer shader 实现,尽管会有一定的兼容性牺牲(Android 5.1 及以上,GPU支持openGLES 3.1),但考虑到下面两点是值得的:

1. 走渲染管线去实现通用计算,编程复杂且容易出错,调优也很麻烦。有 computer shader之后,编程就跟opencl、metal类似,这些工作量可以大幅降低,大大加快开发。 2. 支持OpenGLES 3.1版本的GPU一般都是相对较新的,性能不会太差,能够实现加速的目的。

运算的分配

CNNdroid中仅用GPU加速卷积层的运算,其他还是由CPU+多线程执行。以前我们在早期作gpu加速的预研时,也有过类似的尝试,但是数据传输和同步的性能消耗远大于协同计算带来的性能提升。因此这个工程中,网络中的计算全部由GPU完成,避免数据在CPU和GPU之间反复传输或同步。

另外,GPU驱动在申请内存(分配纹理所需要内存空间)的时间消耗在移动设备端是不可忽略的,因此,不能在运算过程中临时创建纹理或其他Buffer,必须事先分配好。

优化注意点

1. 向量化运算

预测时,我们输入神经网络的数据可表示为 w∗h∗d的三维数据。我们将输入数据用一个RGBA32F格式的3D纹理存维,由于每一个像素有4个数值,得到的纹理大小是w∗h∗ceil(d4)。

对于卷积层和内积层,我们把参数存储为mat4的数组,然后其计算就完全是vec4级的向量化运算。

2. 合适的localsize设计

与OpenCL不一样,computer shader 必须手动指定 workgroup 的大小,并且指定运行的 workgroup 数量。这两组维度,都是越大越好。

local size 一般而言越大越好,但 computer shader 所需要的寄存器越多,local size 的最大值就越小,考虑到最耗时的卷积shader所能使用的local size 一般也就 64,保守起见都定为64(8乘8)。

不能对齐的情况在shader中处理,比如下面的代码:

3. 适当地合并/去除layer

如正则层可以直接和上一层合并(末尾加个max处理就行),dropout层可以直接丢弃。 合并可以提升性能(不过不会太多),但最重要的是减少了中间内存。

框架设计

分为两个子模块,引擎模块在客户端上运行,工具模块用来转换caffe的模型文件。

引擎模块

1. 数据层

Image 为一个RGBA32F格式的2D Array纹理,SSBO为一种vbo, 全称为GL_SHADER_STORAGE_BUFFER,用于存储自定义类型的数据(主要就是卷积层和内积层的参数)。

Program 为 着色器链接而成的 opengl program,NetInfo 由 proto 定义,用于规定网络结构。

在 shader 中,image 和 SSBO 示例如下:

2. 算子层

包括各类layer的实现,如卷积、正则、内积(全连接)、Softmax等。 每一个layer要负责申请自己的输出内存(image)。

3. 结构层

根据 NetInfo 的信息,创建各类算子并构成DAG(有向无环图),执行运算并输出结果。

下图是lenet的dag示例:

工具模块

一个结构转换器、参数初始化和拷贝工具。拷贝工具是比较容易出错的,因为卷积层和内积层的参数需要补零对齐及重排。

性能与效果

跟开源的caffe-android-lib对比:https://github.com/sh1r0/caffe-android-lib

库大小

  • caffe-android-lib 11M
  • DeeplearningOGL 440K

全自主开发的,毫无疑问要小很多很多。

运行效率

Oppo R9 (MT6755, GPU: Mali-T860)上的测试结果:

连续运行十次,去除第一次的结果(移动设备上一般都是动态调频的,第一次跑的时候CPU/GPU的频率还没调起来,会比较慢)。

Lenet 网络:

  • caffe-android-lib:5.0~5.2ms(线程设为4)
  • DeeplearningOGL:3.6-3.8 ms

较CPU版本(包含了neon与多线程优化)提升了 50%左右的效率,已经大大超出预期了,在GPU更好的机器上(如mate8上)表现会更佳。

相比于 CNNdroid 更是好很多了。

人像抠图的场景很流畅,且不需要隔帧计算。

本文来自CSDN博客:http://blog.csdn.net/jxt1234and2010/article/details/71056736

原文发布于微信公众号 - CSDN技术头条(CSDN_Tech)

原文发表时间:2017-05-09

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器学习算法与Python学习

TensorFlow的安装与初步了解

今天终于有时间一探滕三福了,TensorFlow(腾三福)是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运行原理。Tenso...

3506
来自专栏技术翻译

机器学习和容器

机器学习(ML)和人工智能(AI)现在是IT行业中的热门话题。和容器一样。在这个博客中,我尝试将两者绘制在同一张图片中,看看是否有任何协同作用。

1590
来自专栏人工智能

TensorFlow核心使用要点

正文之前,小梦先来说说什么是TensorFlow。TensorFlow是谷歌研发的第二代人工智能学习系统,可被用于语音识别或图像识别等多项机器深度学 习领域。T...

2337
来自专栏AI研习社

想要训练专属人脸识别模型?先掌握构建人脸数据集的三种绝招

雷锋网 AI 研习社按,随着深度学习的发展,很多技术已经落地,成为我们每天都能接触到的产品,人脸识别就是其中之一。人脸识别的应用范围很广,涉及上下班打卡、门禁、...

1912
来自专栏CDA数据分析师

工具 | Python 和 R 数据分析/挖掘工具互查

在此总结一些在数据分析/挖掘中可能用到的功能,方便大家索引或者从一种语言迁移到另一种。如果大家已经熟悉python和R的模块/包载入方式,那下面的表查找起来相对...

2417
来自专栏WOLFRAM

Mathematica 11.1.1 中文版已发布

1463
来自专栏媒矿工厂

视频编码的GPU加速

前言 随着视频编解码技术的不断发展,视频逐步向着高清晰、高动态、高数据量的方向演进。这对视频编解码终端的计算能力提出了越来越高的要求。同时,在GPU领域,随着C...

6344
来自专栏慎独

Python科学计算和绘图入门

4254
来自专栏智能算法

数据异常到底该如何检测?(一)

小编在正式进入工作之后,面对的第一个需要去解决的问题:在网络安全监测中,如何发现异常数据?如异常用户登录,异常操作等。对于网络上的问题我确实是第一次接触这样类型...

6967
来自专栏磐创AI技术团队的专栏

Tensorboard详解(下篇)

2595

扫码关注云+社区

领取腾讯云代金券