Tensorflow教程:GPU调用如何实现

今天,给大家分析一下Tensorflow源码之GPU调用是如何实现的?

1. Tensorflow GPU支持

Tensorflow 支持GPU进行运算,目前官方版本只支持NVIDIA的GPU,可以在tensorflow的官方上看到。Tensorflow 对GPU的运算的支持最小力度就是OP,也就是我们常说的算子,下图提供了Tensorflow的一些常见算子,而每个算子在Tensorflow上都会提供GPU的算法:关于OP的具体实现,在本篇博客中就不叙述了。

2. Tensorflow GPU调用架构

从上图我们可以看到,Tensorflow提供两种方式调用NVIDIA的方式,而NVIDIA的GPU调用方式主要依靠的CUDA的并行计算框架

2.1 Stream Executor

StreamExecutor 是一个子项目,是一个google开源的数学并行运算库,是基于CUDA API、OpenCL API管理各种GPU设备的统一API,这种统一的GPU封装适用于需要与GPU设备通信的库,而在Tensorflow上只提供了对CUDA的支持

StreamExecutor的主要功能:

  • 抽象化底层平台,对开发者不需要考虑底层的GPU的平台
  • 流式的管理模式
  • 封装了主机和GPU之间的数据移动

在StreamExecutor里封装了几个常见的基本的核心运算:

  • BLAS: 基本线性代数
  • DNN:  深层神经网络
  • FFT:   快速傅里叶变换
  • RNG:  随机数生成

2.1.1 Stream 接口

  1.  算子直接通过Stream的API的调用,在Tensorflow里Stream executor 只支持4个核心算法
  2.  每个算法都提供Support的类,进行多态的支持,比如CUDA, OpenCL
  3.  通过Support,官方tensorflow 只提供了CUDA支持,如果要支持OpenCL,可以参考开源(点击打开链接
  4.  对CUDA的支持使用了基于CUDA平台的第三方开发库,没有直接使用CUDA编程

2.2  直接调用CUDA

Tensorflow 同时本身也可以直接调用CUDA,毕竟Stream的目前接口只是支持了Blas, DNN, FFT, RND这些基本接口

1.  进行复杂运算,需要连续调用Stream的接口,这里也带来了频繁的从主内存到GPU内存之间复制的开销

2.  Stream 并没有封装一些简单的一元运算,只是封装了CUDA的提供的第三方运算库,一元运算(加减乘除,log, exp)这些如果想在GPU运算,需要基于CUDA的运算框架进行自己写代码

在Tensorflow上写CUDA代码没什么两样, 下面是一个lstm的样例

1. 定义你的global 

[html] view plain copy

print?

  1. template <typename T, bool use_peephole>
  2. __global__ void lstm_gates(const T* icfo, const T* b, const T* cs_prev,  
  3.                            const T* wci, const T* wcf, const T* wco, T* o, T* h,  
  4.                            T* ci, T* cs, T* co, T* i, T* f, const T forget_bias,  
  5.                            const T cell_clip, const int batch_size,  
  6.                            const int cell_size) {  
  7.   const int batch_id = blockIdx.x * blockDim.x + threadIdx.x;  
  8.   const int act_id = blockIdx.y * blockDim.y + threadIdx.y;  
  9. .......  
  10. }  

2. 定义使用的网格,block, thread数

[html] view plain copy

print?

  1. dim3 block_dim_2d(std::min(batch_size, 8), 32);  
  2. dim3 grid_dim_2d(Eigen::divup(batch_size, static_cast<int>(block_dim_2d.x)),  
  3.                  Eigen::divup(cell_size, static_cast<int>(block_dim_2d.y)));  
  4. if (use_peephole) {  
  5.   lstm_gates<T, true><<<grid_dim_2d, block_dim_2d, 0, cu_stream>>>(  
  6.       icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(),  
  7.       wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(),  
  8.       i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size);  
  9. } else {  
  10.   lstm_gates<T, false><<<grid_dim_2d, block_dim_2d, 0, cu_stream>>>(  
  11.       icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(),  
  12.       wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(),  
  13.       i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size);  
  14. }  

3. 定义你的OP,在你的OP里调用CUDA的代码,并注册到Tensorflow Kernel中,注意你的Device需要设置成DEVICE_GPU,tensorflow会依据客户端传递的device的参数来决定是否需调用GPU还是CPU的算法,CUDA的文件以.cu.cc为结尾

[html] view plain copy

print?

  1. REGISTER_KERNEL_BUILDER(  
  2.     Name("arithmetic").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),  
  3.     arithmeticOP<Eigen::half>);  

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏青玉伏案

算法与数据结构(六) 迪杰斯特拉算法的最短路径(Swift版)

上篇博客我们详细的介绍了两种经典的最小生成树的算法,本篇博客我们就来详细的讲一下最短路径的经典算法----迪杰斯特拉算法。首先我们先聊一下什么是最短路径,这个还...

2115
来自专栏Python小屋

详解Python GUI版24点游戏制作过程

本文作者为浙江温州永嘉县教师发展中心应根球老师,电子邮箱:ycicada@163.com。 传统用扑克牌算24点游戏用于小学低中段学生训练四则运算效果不错,也可...

3185
来自专栏小狼的世界

使用Numpy验证Google GRE的随机选择算法

最近在读《SRE Google运维解密》第20章提到数据中心内部服务器的负载均衡方法,文章对比了几种负载均衡的算法,其中随机选择算法,非常适合用 Numpy 模...

1162
来自专栏C#

开源免费的.NET图像即时处理的组件ImageProcessor

   承接以前的组件系列,这个组件系列旨在介绍.NET相关的组件,让大家可以在项目中有一个更好的选择,社区对于第三方插件的介绍还是比较少的,很多博文的内容主要还...

920
来自专栏落影的专栏

GPUImage详细解析

从源码的角度分析、学习GPUImage和OpenGL ES,这是第一篇,介绍GPUImageFilter 和 GPUImageFramebuffer。 Open...

3256
来自专栏42度空间

基于规则评分的密码强度检测算法分析及实现(JavaScript)

用正则表达式做用户密码强度的通过性判定,过于简单粗暴,不但用户体验差,而且用户帐号安全性也差。那么如何准确评价用户密码的强度,保护用户帐号安全呢?本文分析介绍了...

5156
来自专栏机器学习算法工程师

实例介绍TensorFlow的输入流水线

在训练模型时,我们首先要处理的就是训练数据的加载与预处理的问题,这里称这个过程为输入流水线(input pipelines,或输入管道,[参考:https://...

2276
来自专栏腾讯AlloyTeam的专栏

png的故事:获取图片信息和像素内容

现在时富媒体时代,图片的重要性对于数十亿互联网用户来说不言而喻,图片本身就是像素点阵的合集,但是为了如何更快更好的存储图片而诞生了各种各样的图片格式:jpeg、...

1.5K0
来自专栏机器学习从入门到成神

Python3读取深度学习CIFAR-10数据集出现的若干问题解决

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/sinat_35512245/articl...

1012
来自专栏简书专栏

基于tensorflow、CNN、清华数据集THUCNews的新浪新闻文本分类

tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流。 CNN是convolutional neural netwo...

8181

扫码关注云+社区