前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >torch Dataloader中的num_workers

torch Dataloader中的num_workers

作者头像
狼啸风云
修改2022-09-02 22:12:34
1.8K0
修改2022-09-02 22:12:34
举报
文章被收录于专栏:计算机视觉理论及其实现

考虑这么一个场景,有海量txt文件,一个个batch读进来,测试一下torch DataLoader的效率如何。

基本信息:

  • 本机配置:8核32G内存,工作站内置一块2T的机械硬盘,数据均放在该硬盘上
  • 操作系统:ubuntu 16.04 LTS
  • pytorch:1.0
  • python:3.6

1、首先生成很多随机文本txt

代码语言:javascript
复制
def gen_test_txt():
    population = list(string.ascii_letters) + ['\n']
    for i in range(1000):
        with open(f'./test_txt/{i}.txt', 'w') as f:
            f.write(
                ''.join(random.choices(population, k=1000000))
            )

2、然后顺序读取作为benchmark

代码语言:javascript
复制
def test_torch_reader():
    class Dst(Dataset):
        def __init__(self, paths):
            self.paths = paths

        def __len__(self):
            return len(self.paths)

        def __getitem__(self, i):
            open(self.paths[i], 'r').read()
            return 1

    dst = Dst([f'./test_txt/{i}.txt' for i in range(1000)])
    loader = DataLoader(dst, 128, num_workers=0)

    ts = time()
    time_cost = []
    for i, ele in enumerate(loader, 1):
        dur = time() - ts
        time_cost.append(dur)
        print(i, dur)
        ts = time()

    print(f"{sum(time_cost):.3f}, "
          f"{np.mean(time_cost):.3f}, "
          f"{np.std(time_cost):.3f}, "
          f"{max(time_cost):.3f}, "
          f"{min(time_cost):.3f}")

    plt.plot(time_cost)
    plt.grid()
    plt.show()

每个batch耗时的基本统计信息如下,

基本维持在0.9 sec / batch

total, mean, std, max, min

7.148, 0.893, 0.074, 1.009, 0.726

可见,一共是1000个文件,batch size 128,也就是8个batch,总共耗时7.1s,接下来清除cache,

3、设置num_workers为4

每隔4个batch,要准备4个batch,且是串行的,因此时间增大4倍,接下来3个batch几乎不占用时间

total, mean, std, max, min

7.667, 0.958, 1.652, 3.983, 0.000

接下来实验在SSD上进行,同样num_workers先0后4,如下

total, mean, std, max, min

3.251, 0.406, 0.026, 0.423, 0.338

SSD上,对比机械硬盘更加稳定

然后是num_workers = 4,

total, mean, std, max, min

1.934, 0.242, 0.421, 1.088, 0.000

观察到同样的现象,但尖峰应该是0.4*4=1.6,这里反而epoch 4 (0-index)降为一半为0.8

基本结论:可以看到,不管是在SSD,还是机械硬盘上,总的耗时基本不变(SSD小一些,但原因也可能是实验不充分),并没有因为numworkers增大而减小,令我很费解!我一贯的理解是:比如num_workers为4,那么每个worker计算一个batch,因为本机多核且大于4,讲道理4个worker并行处理,因此时间为num_workers=0的1/4才合理,那原因是为何呢?(这个实验本来是为了load audio数据,其实在audio上作类似实验也是一致的现象)

补充了一个实验,尝试用ray读取,代码如下,

代码语言:javascript
复制
def test_ray():
    ray.init()

    @ray.remote
    def read(paths):
        for path in paths:
            open(path, 'r').read()
        return 1

    def ray_read(paths, n_cpu=4):
        chunk_size = len(paths) // n_cpu
        object_ids = []
        for i in range(n_cpu):
            x = read.remote(paths[i * chunk_size: (i + 1) * chunk_size])
            object_ids.append(x)

        return ray.get(object_ids)

    def batch(l, bs):
        out = []
        i = 0
        while i < len(l):
            out.append(l[i: i + bs])
            i += bs
        return out

    paths = [os.path.expanduser(f'~/test_txt/{i}.txt') for i in range(1000)]
    paths = batch(paths, 128)

    time_cost = []
    ts = time()
    for i, ele in enumerate(paths, 1):
        # read(paths[i - 1])
        ray_read(paths[i - 1], 8)
        dur = time() - ts
        time_cost.append(dur)
        print(i, dur)
        ts = time()

    print(f"{sum(time_cost):.3f}, "
          f"{np.mean(time_cost):.3f}, "
          f"{np.std(time_cost):.3f}, "
          f"{max(time_cost):.3f}, "
          f"{min(time_cost):.3f}")

    plt.plot(time_cost)
    plt.grid()
    plt.show()

流程是这样的:将输入paths分成n_cpu个chunk,chunk之间通过ray异步执行,结果是:同样是在SSD上,理论上每个batch耗时是之前的1/4,也就是0.1s左右,但实测是0.2s,也就是说,n_cpu最大有效值就是2

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020/05/25 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1、首先生成很多随机文本txt
  • 2、然后顺序读取作为benchmark
  • 3、设置num_workers为4
相关产品与服务
对象存储
对象存储(Cloud Object Storage,COS)是由腾讯云推出的无目录层次结构、无数据格式限制,可容纳海量数据且支持 HTTP/HTTPS 协议访问的分布式存储服务。腾讯云 COS 的存储桶空间无容量上限,无需分区管理,适用于 CDN 数据分发、数据万象处理或大数据计算与分析的数据湖等多种场景。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档