学习笔记TF062:TensorFlow线性代数编译框架XLA

XLA(Accelerated Linear Algebra),线性代数领域专用编译器(demain-specific compiler),优化TensorFlow计算。即时(just-in-time,JIT)编译或提前(ahead-of-time,AOT)编译实现XLA,有助于硬件加速。XLA还在试验阶段。https://www.tensorflow.org/versions/master/experimental/xla/

XLA优势。线性代数领域专用编译器,优化TensorFlow计算的执行速度(编译子图减少生命周期较短操作执行时间,融合管道化操作减少内存占用)、内存使用(分析、规划内存使用需求,消除许多中间结果缓存)、自定义操作依赖(提高自动化融合底层操作low-level op性能,达到手动融合自定义操作custom op效果)、移动端内存占用(提前AOT编译子图减少TensorFlow执行时间,共享头文件对被其他程序直接链接)、可移植性方面(为新硬件开发新后端,TensorFlow不需要更改很多代码用在新硬件设备上)。

XLA工作原理。LLVM编译器框架系统,C++编写,优化任意编程语言缩写程序编译时间(compile time)、链接时间(link time)、运行时间(run time)、空闲时间(idle time)。前端解析、验证、论断输入代码错误,解析代码转换LLVM中间表示(intermdediate representation,IR)。IR分析、优化改进代码,发送到代码生成器,产生本地机器代码。三相设计LLVM实现。最重要,LLVM IR。编译器IR表示代码。C->Clang C/C++/ObjC前端、Fortran->llvm-gcc前端、Haskell->GHC前端 LLVM IR-> LLVM 优化器 ->LLVM IR LLVM X86后端->X86、LLVM PowerPC后端->PowerPC、LLVM ARM后端->ARM。http://www.aosabook.org/en/llvm.html 。 XLA输入语言HLO IR,XLA HLO定义图形,编译成各种体系结构机器指令。编译过程。XLA HLO->目标无关优化分析->XLA HLO->XLA后端->目标相关优化分析->目标特定代码生成。XLA首先进行目标无关优化分析(公共子表达式消除common subexpression elimination CSE,目标无关操作融合,运行时内存缓冲区分析)。XLA将HLO计算发送到后端。后端执行进一步HLO级目标不相关优化分析。XLA GPU后端执行对GPU编程模型有益操作融合,确定计算划分成流。生成目标特定代码。XLA CPU、GPU后端用LLVM中间表示、优化、代码生成。后端用LLVM IR表示XLA HLO计算。XLA 支持x86-64、NVIDIA GPU JIT编译,x86-64、ARM AOT编译。AOT更适合移动、嵌入式深度学习应用。

JIT编译方式。XLA编译、运行TensorFlow计算图一部分。XLA 将多个操作(内核)融合到少量编译内核,融合操作符减少存储器带宽提高性能。XLA 运行TensorFlow计算方法。一,打开CPU、GPU设备JIT编译。二,操作符放在XLA_CPU、XLA_GPU设备。 打开JIT编译。在会话打开。把所有可能操作符编程成XLA计算。

config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
sess = tf.Session(config=config)

为一个或多个操作符手动打开JIT编译。属性_XlaCompile = true标记编译操作符。

jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
x = tf.placeholder(np.float32)
with jit_scope():
  y = tf.add(x, x)

操作符放在XLA设备。有效设备XLA_CPU、XLA_GPU:

with tf.device("/job:localhost/replica:0/task:0/device:XLA_GPU:0"):
  output = tf.add(input1, input2)

JIT编译MNIST实现。https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py 。 不使用XLA运行。

python mnist_softmax_xla.py --xla=false

运行完成生成时间线文件timeline.ctf.json,用Chrome跟踪事件分析器 chrome://tracing,打开时间线文件,呈现时间线。左侧列出GPU,可以看操作符时间消耗情况。 用XLA训练模型。

TF_XLA_FLAGS=--xla_generate_hlo_graph=.* python mnist_softmax_xla.py

XLA框架处于试验阶段,AOT主要应用场景内存较小嵌入式设备、手机、树莓派。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.client import timeline
FLAGS = None
def main(_):
  # Import data
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
  # Create the model
  x = tf.placeholder(tf.float32, [None, 784])
  w = tf.Variable(tf.zeros([784, 10]))
  b = tf.Variable(tf.zeros([10]))
  y = tf.matmul(x, w) + b
  # Define loss and optimizer
  y_ = tf.placeholder(tf.float32, [None, 10])
  # The raw formulation of cross-entropy,
  #
  #   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
  #                                 reduction_indices=[1]))
  #
  # can be numerically unstable.
  #
  # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
  # outputs of 'y', and then average across the batch.
  cross_entropy = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
  config = tf.ConfigProto()
  jit_level = 0
  if FLAGS.xla:
    # Turns on XLA JIT compilation.
    # 开启XLA JIT编译
    jit_level = tf.OptimizerOptions.ON_1
  config.graph_options.optimizer_options.global_jit_level = jit_level
  run_metadata = tf.RunMetadata()
  sess = tf.Session(config=config)
  tf.global_variables_initializer().run(session=sess)
  # Train
  # 训练
  train_loops = 1000
  for i in range(train_loops):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    # Create a timeline for the last loop and export to json to view with
    # chrome://tracing/.
    # 在最后一次循环创建时间线文件,用chrome://tracing/打开分析
    if i == train_loops - 1:
      sess.run(train_step,
               feed_dict={x: batch_xs,
                          y_: batch_ys},
               options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
               run_metadata=run_metadata)
      trace = timeline.Timeline(step_stats=run_metadata.step_stats)
      with open('timeline.ctf.json', 'w') as trace_file:
        trace_file.write(trace.generate_chrome_trace_format())
    else:
      sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
  # Test trained model
  correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print(sess.run(accuracy,
                 feed_dict={x: mnist.test.images,
                            y_: mnist.test.labels}))
  sess.close()
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data_dir',
      type=str,
      default='/tmp/tensorflow/mnist/input_data',
      help='Directory for storing input data')
  parser.add_argument(
      '--xla', type=bool, default=True, help='Turn xla via JIT on')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

参考资料: 《TensorFlow技术解析与实战》

欢迎推荐上海机器学习工作机会,我的微信:qingxingfengzi

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Windows Community

Microsoft AI - Custom Vision in C#

概述 前面一篇 Microsoft AI - Custom Vision 中,我们介绍了 Azure 认知服务中的自定义影像服务:Custom Vision,也...

3568
来自专栏Python中文社区

Keras 的 Web 填坑记

博客主页:https://www.zhihu.com/people/tu-dou-dou-27-10

1643
来自专栏应兆康的专栏

Bing 每日一图 & 随机图片 API

大家都知道微软的 Bing 搜索引擎首页每天都会提供了一些有趣的图片,而这些图片很多都是有故事含义的,很多网友每天去访问 Bing 首页都是为了这些图片而去的。...

8547
来自专栏IT笔记

Dubbo负载均衡配置

在集群负载均衡时,Dubbo提供了多种均衡策略,缺省为random随机调用。 负载均衡扩展 (1) 扩展说明: 从多个服务提者方中选择一个进行调用。 (2) 扩...

3435
来自专栏炉边夜话

利用Oprofile对多核多线程进行性能分析

在对应用程序不断调优的过程中,除了制定完备的测试基准(Benchmark)外,还需要一把直中要害的利器——性能分析工具。

962
来自专栏华章科技

利用R语言制作出漂亮的交互数据可视化

利用R语言也可以制作出漂亮的交互数据可视化,下面和大家分享一些常用的交互可视化的R包。

931
来自专栏宏伦工作室

深度有趣 | 01-02 前言和准备工作

用 Python 做一些有意思的案例和应用,内容和领域不限,可以包括数据分析、自然语言理解、计算机视觉,等等等等

752
来自专栏mathor

尼姆博弈(Nim Game)

723
来自专栏不止思考

网络中的「动态路由算法」,你了解吗?

在计算机网络中,路由器的一个很重要责任就是要在端对端的节点中找出一条最佳路径出来,通过自己与相邻节点之间的信息,来计算出从自己位置到目的节点之间的最佳线路,这种...

1345
来自专栏Timhbw博客

个人博客SEO设置小技巧

2016-05-0518:42:17 发表评论 499℃热度 个人水平有限,还在初步学习SEO中,下面会更新一些我所学到的关于博客SEO的小技巧,大家可以发表...

2808

扫码关注云+社区