Tensorflow教程:GPU调用如何实现

今天,给大家分析一下Tensorflow源码之GPU调用是如何实现的?

1. Tensorflow GPU支持

Tensorflow 支持GPU进行运算,目前官方版本只支持NVIDIA的GPU,可以在tensorflow的官方上看到。Tensorflow 对GPU的运算的支持最小力度就是OP,也就是我们常说的算子,下图提供了Tensorflow的一些常见算子,而每个算子在Tensorflow上都会提供GPU的算法:关于OP的具体实现,在本篇博客中就不叙述了。

2. Tensorflow GPU调用架构

从上图我们可以看到,Tensorflow提供两种方式调用NVIDIA的方式,而NVIDIA的GPU调用方式主要依靠的CUDA的并行计算框架

2.1 Stream Executor

StreamExecutor 是一个子项目,是一个google开源的数学并行运算库,是基于CUDA API、OpenCL API管理各种GPU设备的统一API,这种统一的GPU封装适用于需要与GPU设备通信的库,而在Tensorflow上只提供了对CUDA的支持

StreamExecutor的主要功能:

  • 抽象化底层平台,对开发者不需要考虑底层的GPU的平台
  • 流式的管理模式
  • 封装了主机和GPU之间的数据移动

在StreamExecutor里封装了几个常见的基本的核心运算:

  • BLAS: 基本线性代数
  • DNN:  深层神经网络
  • FFT:   快速傅里叶变换
  • RNG:  随机数生成

2.1.1 Stream 接口

  1.  算子直接通过Stream的API的调用,在Tensorflow里Stream executor 只支持4个核心算法
  2.  每个算法都提供Support的类,进行多态的支持,比如CUDA, OpenCL
  3.  通过Support,官方tensorflow 只提供了CUDA支持,如果要支持OpenCL,可以参考开源(点击打开链接
  4.  对CUDA的支持使用了基于CUDA平台的第三方开发库,没有直接使用CUDA编程

2.2  直接调用CUDA

Tensorflow 同时本身也可以直接调用CUDA,毕竟Stream的目前接口只是支持了Blas, DNN, FFT, RND这些基本接口

1.  进行复杂运算,需要连续调用Stream的接口,这里也带来了频繁的从主内存到GPU内存之间复制的开销

2.  Stream 并没有封装一些简单的一元运算,只是封装了CUDA的提供的第三方运算库,一元运算(加减乘除,log, exp)这些如果想在GPU运算,需要基于CUDA的运算框架进行自己写代码

在Tensorflow上写CUDA代码没什么两样, 下面是一个lstm的样例

1. 定义你的global 

[html] view plain copy

print?

  1. template <typename T, bool use_peephole>
  2. __global__ void lstm_gates(const T* icfo, const T* b, const T* cs_prev,  
  3.                            const T* wci, const T* wcf, const T* wco, T* o, T* h,  
  4.                            T* ci, T* cs, T* co, T* i, T* f, const T forget_bias,  
  5.                            const T cell_clip, const int batch_size,  
  6.                            const int cell_size) {  
  7.   const int batch_id = blockIdx.x * blockDim.x + threadIdx.x;  
  8.   const int act_id = blockIdx.y * blockDim.y + threadIdx.y;  
  9. .......  
  10. }  

2. 定义使用的网格,block, thread数

[html] view plain copy

print?

  1. dim3 block_dim_2d(std::min(batch_size, 8), 32);  
  2. dim3 grid_dim_2d(Eigen::divup(batch_size, static_cast<int>(block_dim_2d.x)),  
  3.                  Eigen::divup(cell_size, static_cast<int>(block_dim_2d.y)));  
  4. if (use_peephole) {  
  5.   lstm_gates<T, true><<<grid_dim_2d, block_dim_2d, 0, cu_stream>>>(  
  6.       icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(),  
  7.       wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(),  
  8.       i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size);  
  9. } else {  
  10.   lstm_gates<T, false><<<grid_dim_2d, block_dim_2d, 0, cu_stream>>>(  
  11.       icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(),  
  12.       wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(),  
  13.       i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size);  
  14. }  

3. 定义你的OP,在你的OP里调用CUDA的代码,并注册到Tensorflow Kernel中,注意你的Device需要设置成DEVICE_GPU,tensorflow会依据客户端传递的device的参数来决定是否需调用GPU还是CPU的算法,CUDA的文件以.cu.cc为结尾

[html] view plain copy

print?

  1. REGISTER_KERNEL_BUILDER(  
  2.     Name("arithmetic").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),  
  3.     arithmeticOP<Eigen::half>);  

原创声明,本文系作者授权云+社区-专栏发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏码匠的流水账

聊聊resilience4j的bulkhead

resilience4j-bulkhead-0.13.0-sources.jar!/io/github/resilience4j/bulkhead/Bulkhe...

501
来自专栏24K纯开源

OpenGL与CUDA互操作方式总结

一、介绍 CUDA是Nvidia推出的一个通用GPU计算平台,对于提升并行任务的效率非常有帮助。本人主管的项目中采用了OpenGL做图像渲染,但是在数据处理方面...

2145
来自专栏Golang语言社区

Golang语言RPC Authorization进行简单ip安全验证的方法

前言:写网络服务,总要考虑安全机制,对ip和网段进行判断是最简单的一个验证机制。之后想做一个类似注册式的安全验证机制,既可以减少配置文件的麻烦,又可以很好的进行...

2625
来自专栏斑斓

AKKA中的事件流

在《企业应用集成模式》一书中,定义了许多与消息处理有关的模式,其中运用最为广泛的模式为Publisher-Subscriber模式,尤其是在异步处理场景下。 基...

3404
来自专栏琯琯博客

一个有用的PHP片段的集合

2517
来自专栏西安-晁州

golang代码片段(摘抄)

以下是从golang并发编程实战2中摘抄过来的代码片段,主要是实现一个简单的tcp socket通讯(客户端发送一个数字,服务端计算该数字的立方根然后返回),写...

2500
来自专栏吉浦迅科技

DAY46:阅读Surface Reference API

reads the CUDA array bound to the one-dimensional surface reference surfRef usin...

825
来自专栏xdecode

Guice之IOC教程

Guice 在上一篇博客中, 我们讲解了Spring中的IOC示例与实现, 本文着重介绍Guice注入以及与Spring中的差异. Guice是Google开发...

2839
来自专栏Golang语言社区

转-带交互的telnet小工具,golang版

package netTools //main // import ( "fmt" "net" "strconv" "strings" "time"...

32910
来自专栏码匠的流水账

聊聊JerseyEurekaHttpClient的参数

eureka-client-1.8.8-sources.jar!/com/netflix/discovery/shared/transport/jersey/J...

382

扫码关注云+社区