tensorflow自定义op:work_shard

强行解释 work_shard

在学习 tensorflow 自定义 op 的时候碰到的,google 了一下,也没有找到详细的介绍,难道是姿势不对?? 通过看 了一些示例,这里打算强行解释一波。

概览

如果想用 work shard,首先 代码能够并行化计算。work shard 是一个代码并行化工具。不用自己头疼的写多线程代码了。

什么样的代码能够并行化计算 —> 每一个输出数据都能表示成相互无关的

work_shard 的最后一个参数就是要 shard 的 work, 这个 work 的签名为 void shard(int64 start, int64 limit),work_shard 就是将 (start, limit) 给划分成多块,然后 块给 一个线程来计算。

# 如何使用 work_shard
# 1. 包含头文件
# 2. 该用的地方用就行了
# 3. 链接的时候 g++ 会自动找到实现去链接的,不用操心。

代码

work_shard声明代码 地址

// work_sharder.h
#ifndef TENSORFLOW_UTIL_WORK_SHARDER_H_
#define TENSORFLOW_UTIL_WORK_SHARDER_H_

#include <functional>

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Shards the "total" unit of work assuming each unit of work having
// roughly "cost_per_unit". Each unit of work is indexed 0, 1, ...,
// total - 1. Each shard contains 1 or more units of work and the
// total cost of each shard is roughly the same. The calling thread and the
// "workers" are used to compute each shard (calling work(start,
// limit). A common configuration is that "workers" is a thread pool
// with at least "max_parallelism" threads.
//
// "cost_per_unit" is an estimate of the number of CPU cycles (or nanoseconds
// if not CPU-bound) to complete a unit of work. Overestimating creates too
// many shards and CPU time will be dominated by per-shard overhead, such as
// Context creation. Underestimating may not fully make use of the specified
// parallelism.
//
// "work" should be a callable taking (int64, int64) arguments.
// work(start, limit) computes the work units from [start,
// limit), i.e., [start, limit) is a shard.
//
// REQUIRES: max_parallelism >= 0
// REQUIRES: workers != nullptr
// REQUIRES: total >= 0
// REQUIRES: cost_per_unit >= 0
void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
           int64 cost_per_unit, std::function<void(int64, int64)> work);

}  // end namespace tensorflow

#endif  // TENSORFLOW_UTIL_WORK_SHARDER_H_

用到 Sharder 的地方(见代码片段最后) 完整代码地址

auto shard = [pooled_height, pooled_width, spatial_scale,
num_rois, batch_size, data_height, data_width, num_channels,
&bottom_data_flat, &bottom_rois_flat, &output, &argmax]
(int64 start, int64 limit) {
for (int64 b = start; b < limit; ++b)
{
  // (n, ph, pw, c) is an element in the pooled output
  int n = b;
  int c = n % num_channels;
  n /= num_channels;
  int pw = n % pooled_width;
  n /= pooled_width;
  int ph = n % pooled_height;
  n /= pooled_height;

  const float* bottom_rois = bottom_rois_flat.data() + n * 5;
  int roi_batch_ind = bottom_rois[0];
  int roi_start_w = round(bottom_rois[1] * spatial_scale);
  int roi_start_h = round(bottom_rois[2] * spatial_scale);
  int roi_end_w = round(bottom_rois[3] * spatial_scale);
  int roi_end_h = round(bottom_rois[4] * spatial_scale);

  // Force malformed ROIs to be 1x1
  int roi_width = std::max(roi_end_w - roi_start_w + 1, 1);
  int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);
  const T bin_size_h = static_cast<T>(roi_height)
  / static_cast<T>(pooled_height);
  const T bin_size_w = static_cast<T>(roi_width)
  / static_cast<T>(pooled_width);

  int hstart = static_cast<int>(floor(ph * bin_size_h));
  int wstart = static_cast<int>(floor(pw * bin_size_w));
  int hend = static_cast<int>(ceil((ph + 1) * bin_size_h));
  int wend = static_cast<int>(ceil((pw + 1) * bin_size_w));

  // Add roi offsets and clip to input boundaries
  hstart = std::min(std::max(hstart + roi_start_h, 0), data_height);
  hend = std::min(std::max(hend + roi_start_h, 0), data_height);
  wstart = std::min(std::max(wstart + roi_start_w, 0), data_width);
  wend = std::min(std::max(wend + roi_start_w, 0), data_width);
  bool is_empty = (hend <= hstart) || (wend <= wstart);

  // Define an empty pooling region to be zero
  float maxval = is_empty ? 0 : -FLT_MAX;
  // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
  int maxidx = -1;
  const float* bottom_data = bottom_data_flat.data() + roi_batch_ind * num_channels * data_height * data_width;
  for (int h = hstart; h < hend; ++h) {
    for (int w = wstart; w < wend; ++w) {
      int bottom_index = (h * data_width + w) * num_channels + c;
      if (bottom_data[bottom_index] > maxval) {
      maxval = bottom_data[bottom_index];
      maxidx = bottom_index;
      }
    }
  }
  output(b) = maxval;
  argmax(b) = maxidx;
  }
};

const DeviceBase::CpuWorkerThreads& worker_threads =
*(context->device()->tensorflow_cpu_worker_threads());
const int64 shard_cost =
num_rois * num_channels * pooled_height * pooled_width * spatial_scale;

// 用到 shard 的地方
Shard(worker_threads.num_threads, worker_threads.workers, output.size(), shard_cost, shard);

通过调用方法来分析Shard 声明中各参数的意义:

// 声明
void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
           int64 cost_per_unit, std::function<void(int64, int64)> work);

// 调用
Shard(worker_threads.num_threads, worker_threads.workers, output.size(), shard_cost, shard);

// max_parallelism: 最大并行个数,通过调用的形式来看,一般是使用 本机的线程数。
// workers: 从声明来看,是代表的线程池。
// total: 从调用来看,像是 work 中 unit 的数量,即最外层 for 循环的数量。
// cost_per_unit: 对每个 unit 的 cpu 循环的一个估计。

// work: 一个可调用对象,work的调用应该是这样的 work(int64, int64)

Shard 的实现源码:地址 地址如果失效,就去 tensorflow/core/util/work_sharder.cc

将work 分块执行,[0, limit) 变成 [0,block_size), [block_size, 2*block_size) 这么一块一块。 num_shards = total * cost_per_unit / 10000 为了理解 cost_per_unit 可以只关心这一部分 从这个部分可以看出,如果cost_per_unit 的运算量很大的话,tensorflow 会多分几块,那么问题来了,分成多少是合适的呢? block_size = (total + num_shards - 1) / num_shards 。num_shards 要分成几块。

#include "tensorflow/core/util/work_sharder.h"

#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
           int64 cost_per_unit, std::function<void(int64, int64)> work) {
  CHECK_GE(total, 0);
  if (total == 0) {
    return;
  }
  if (max_parallelism <= 1) {
    // Just inline the whole work since we only have 1 thread (core).
    work(0, total);
    return;
  }
  if (max_parallelism >= workers->NumThreads()) {
    workers->ParallelFor(total, cost_per_unit, work);
    return;
  }
  cost_per_unit = std::max(1LL, cost_per_unit);
  // We shard [0, total) into "num_shards" shards.
  //   1 <= num_shards <= num worker threads
  //
  // If total * cost_per_unit is small, it is not worth shard too
  // much. Let us assume each cost unit is 1ns, kMinCostPerShard=10000
  // is 10us.
  static const int64 kMinCostPerShard = 10000;
  const int num_shards =
      std::max<int>(1, std::min(static_cast<int64>(max_parallelism),
                                total * cost_per_unit / kMinCostPerShard));

  // Each shard contains up to "block_size" units. [0, total) is sharded
  // into:
  //   [0, block_size), [block_size, 2*block_size), ...
  // The 1st shard is done by the caller thread and the other shards
  // are dispatched to the worker threads. The last shard may be smaller than
  // block_size.
  const int64 block_size = (total + num_shards - 1) / num_shards;
  CHECK_GT(block_size, 0);  // total > 0 guarantees this.
  if (block_size >= total) {
    work(0, total);
    return;
  }
  const int num_shards_used = (total + block_size - 1) / block_size;
  BlockingCounter counter(num_shards_used - 1);
  for (int64 start = block_size; start < total; start += block_size) {
    auto limit = std::min(start + block_size, total);
    workers->Schedule([&work, &counter, start, limit]() {
      work(start, limit);        // Compute the shard.
      counter.DecrementCount();  // The shard is done.
    });
  }

  // Inline execute the 1st shard.
  work(0, std::min(block_size, total));
  counter.Wait();
}

}  // end namespace tensorflow

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏微信公众号:Java团长

Java异常进阶

在使用Java编写应用的时候,我们常常需要通过第三方类库来帮助我们完成所需要的功能。有时候这些类库所提供的很多API都通过throws声明了它们所可能抛出的异常...

744
来自专栏葡萄城控件技术团队

深入浅出OOP(一): 多态和继承(早期绑定/编译时多态)

在本系列中,我们以CodeProject上比较火的OOP系列博客为主,进行OOP深入浅出展现。 无论作为软件设计的高手、或者菜鸟,对于架构设计而言,均需要多次重...

1926
来自专栏数说工作室

【SAS Says】基础篇:2. 读取数据

转载请在文章开头注明微信号:shushuojun,谢谢! 本节数据中,我们将介绍SAS读取数据的三种方式: list input、column input、in...

3106
来自专栏DOTNET

设计原则

一、面向对象应用程序开发原则(SOLID) 1单一职责原则(SRP) 定义: 一个类应该只有一个发生变化的原因。这条原则曾被称为内聚性,即一个模块的组成元素之间...

2787
来自专栏程序员阿凯

JDK10 揭秘

1215
来自专栏Clive的技术分享

实现PHP内部的通知机制,如当一个类的属性发生变化时,另外一个类就可以收到通知设计模式:观察者模式使用场景参考链接

设计模式:观察者模式 当一个对象的状态发生改变时,依赖他的对象会全部收到通知,并自动更新。 使用场景 一个事件发生后,要执行一连串更新操作。传统的编程方式,就是...

6067
来自专栏Python小屋

使用Python编写程序求解数独游戏答案

问题描述:数独盘面是个九宫,每一宫又分为九个小格。在这八十一格中给出一定的已知数字和解题条件,利用逻辑和推理,在其他的空格上填入1-9的数字。使1-9每个数字在...

2253
来自专栏IT米粉

你必须了解的反射——反射来实现实体验证

日常开发,都是通过API进行前后端的系统对接,对API参数的验证是一个使用率非常高的功能,如果能非常简便的的进行参数验证,能降低代码量,提升工作效率。

3888
来自专栏程序你好

C# API中的模型和它们的接口设计

682
来自专栏吉浦迅科技

DAY40:阅读Memory Fence Functions

The CUDA programming model assumes a device with a weakly-ordered memory model, ...

604

扫码关注云+社区