tensorflow学习笔记(三十一):构建多GPU代码

构建多GPU代码

结构

  1. 先构建单GPU代码
  2. 写个函数multi_gpu_model(num_gpus)来生成多GPU代码,并将对象保存在collection
  3. feed data
  4. run

如何构建单GPU代码

见之前博客构建TF代码 不要在单GPU代码中创建optimizer op,因为是multi gpu,所以参数更新的操作是所有的GPU计算完梯度之后,才进行更新的。

如何实现multi_gpu_model函数

def multi_gpu_model(num_gpus=1):
  grads = []
  for i in range(num_gpus):
    with tf.device("/gpu:%d"%i):
      with tf.name_scope("tower_%d"%i):
        model = Model(is_training, config, scope)
        # 放到collection中,方便feed的时候取
        tf.add_to_collection("train_model", model)
        grads.append(model.grad) #grad 是通过tf.gradients(loss, vars)求得
        #以下这些add_to_collection可以直接在模型内部完成。
        # 将loss放到 collection中, 方便以后操作
        tf.add_to_collection("loss",model.loss)
        #将predict放到collection中,方便操作
        tf.add_to_collection("predict", model.predict)
        #将 summary.merge op放到collection中,方便操作
        tf.add_to_collection("merge_summary", model.merge_summary)
        # ...
  with tf.device("cpu:0"):
    averaged_gradients = average_gradients(grads)# average_gradients后面说明
    opt = tf.train.GradientDescentOptimizer(learning_rate)
    train_op=opt.apply_gradients(zip(average_gradients,tf.trainable_variables()))

  return train_op

如何feed data

def generate_feed_dic(model, feed_dict, batch_generator):
  x, y = batch_generator.next_batch()
  feed_dict[model.x] = x
  feed_dict[model.y] = y

如何实现run_epoch

#这里的scope是用来区别 train 还是 test
def run_epoch(session, data_set, scope, train_op=None, is_training=True):
  batch_generator = BatchGenerator(data_set, batch_size)
  ...
  ...
  if is_training and train_op is not None:
    models = tf.get_collection("train_model")
    # 生成 feed_dict
    feed_dic = {}
    for model in models:
      generate_feed_dic(model, feed_dic, batch_generator)
    #生成fetch_dict
    losses = tf.get_collection("loss", scope)#保证了在 test的时候,不会fetch train的loss
    ...
    ...

main函数

main 函数干了以下几件事: 1. 数据处理 2. 建立多GPU训练模型 3. 建立单/多GPU测试模型 4. 创建Saver对象和FileWriter对象 5. 创建session 6. run_epoch

data_process()
with tf.name_scope("train") as train_scope:
  train_op = multi_gpu_model(..)
with tf.name_scope("test") as test_scope:
  model = Model(...)
saver = tf.train.Saver()
# 建图完毕,开始执行运算
with tf.Session() as sess:
  writer = tf.summary.FileWriter(...)
  ...
  run_epoch(...,train_scope)
  run_epoch(...,test_scope)

如何编写average_gradients函数

def average_gradients(grads):#grads:[[grad0, grad1,..], [grad0,grad1,..]..]
  averaged_grads = []
  for grads_per_var in zip(*grads):
    grads = []
    for grad in grads_per_var:
      expanded_grad = tf.expanded_dim(grad,0)
      grads.append(expanded_grad)
    grads = tf.concat_v2(grads, 0)
    grads = tf.reduce_mean(grads, 0)
    averaged_grads.append(grads)

  return averaged_grads

还有一个版本,但是不work,不知为啥

def average_gradients(grads):#grads:[[grad0, grad1,..], [grad0,grad1,..]..]
  averaged_grads = []
  for grads_per_var in zip(*grads):
    grads = tf.reduce_mean(grads_per_var, 0)
    averaged_grads.append(grads)
  return averaged_grads

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏恰同学骚年

设计模式的征途—8.桥接(Bridge)模式

在现实生活中,我们常常会用到两种或多种类型的笔,比如毛笔和蜡笔。假设我们需要大、中、小三种类型的画笔来绘制12中不同的颜色,如果我们使用蜡笔,需要准备3*12=...

943
来自专栏人工智能LeadAI

在TensorFlow中使用pipeline加载数据

前面对TensorFlow的多线程做了测试,接下来就利用多线程和Queue pipeline地加载数据。数据流如下图所示: ? 首先,A、B、C三个文件通过Ra...

3983
来自专栏逸鹏说道

C# 温故而知新:Stream篇(五)下

对于重写的方法这里不再重复说明,大家可以参考我写的第一篇 以下是memoryStream独有的方法 virtual byte[] GetBuffer() 这个方...

33510
来自专栏老秦求学

任务管理器编码详解

模仿windows任务管理器制作一个任务管理器软件。设计语言不限。 二知识要求    Windows编程,MFC编程,API调用 三.开发环境 使用Micros...

29011
来自专栏性能与架构

大数据运算模型 MapReduce 原理

MapReduce 是一个大数据集合的并行运算模型,由google提出,现在流行的hadoop中也使用了MapReduce作为计算模型 MapReduce 通俗...

3787
来自专栏机器学习算法原理与实践

用Spark学习FP Tree算法和PrefixSpan算法

    在FP Tree算法原理总结和PrefixSpan算法原理总结中,我们对FP Tree和PrefixSpan这两种关联算法的原理做了总结,这里就从实践的...

1243
来自专栏数值分析与有限元编程

ANSYS模拟梁单元铰接点

ANSYS模拟梁单元铰接点有以下几种方法: 1.BEAM3/BEAM4单元,利用结点自由度耦合来实现铰接,在铰接处设两个单独的结点,每个结点只与一个梁单元连接,...

2945
来自专栏用户2442861的专栏

分页和分段的联系和区别

    用户程序的地址空间被划分成若干固定大小的区域,称为“页”,相应地,内存空间分成若干个物理块,页和块的大小相等。可将用户程序的任一页放在内存的任一块中,实...

861
来自专栏机器之心

入门 | GPU是如何优化运行机器学习算法的?

34014
来自专栏小特工作室

DataWindow.Net组件示例(全部开源)

1概述 1.1功能简介 Sybase公司的PowerBuilder开发工具,在以前VS工具没有成事以前,是相当风光的.微软都要与其合作,学习它Db方面的技术,才...

21510

扫码关注云+社区