首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >实现无限循环数据集& DataLoader在PyTorch中的实现

实现无限循环数据集& DataLoader在PyTorch中的实现
EN

Stack Overflow用户
提问于 2019-01-25 05:08:07
回答 3查看 7K关注 0票数 9

我想实现一个无限循环数据集& DataLoader。以下是我尝试过的:

代码语言:javascript
运行
复制
class Infinite(Dataset):
    def __len__(self):
        return HPARAMS.batch_size
#         return 1<<30 # This causes huge memory usage.
    def __getitem__(self, idx):
        """Randomly generates one new example."""
        return sample_func_to_be_parallelized()

infinite_loader = DataLoader(
    dataset=Infinite(), 
    batch_size=HPARAMS.batch_size, 
    num_workers=16,
    worker_init_fn=lambda worker_id: np.random.seed(worker_id),  
)

while True:
    for idx, data in enumerate(infinite_loader):
        # forward + backward on "data"

正如您所看到的,这里的主要挑战是__len()__方法。如果我在那里放了一个足够大的数字,比如1<<30,那么在列车循环的第一次迭代中,内存使用量就会跳转到10+GB。过了一段时间,工人们可能死于OOM。

如果我把一个小的数字,如1或BATCH_SIZE,采样的“数据”在列车循环中将定期重复。这不是我想要的,因为我希望在每次迭代中生成和训练新的数据。

我猜过度使用内存的罪魁祸首在堆栈的某个地方,很多东西被缓存。随便看一看Python方面的东西,我就找不到具体的位置了。

有人能建议我想要实现的最好的方法是什么吗?(使用DataLoader的并行加载,同时确保加载的每一批都是全新的。)

EN

回答 3

Stack Overflow用户

发布于 2019-02-27 08:42:05

这样做似乎没有周期性地重复数据:

代码语言:javascript
运行
复制
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 2

class Infinite(Dataset):

    def __len__(self):
        return BATCH_SIZE

    def __getitem__(self, idx):
        return torch.randint(0, 10, (3,))


data_loader = DataLoader(Infinite(), batch_size=BATCH_SIZE, num_workers=16)

batch_count = 0
while True:
    batch_count += 1
    print(f'Batch {batch_count}:')

    data = next(iter(data_loader))
    print(data)
    # forward + backward on "data"  

    if batch_count == 5:
        break

结果:

代码语言:javascript
运行
复制
Batch 1:
tensor([[4, 7, 7],
        [0, 8, 0]])
Batch 2:
tensor([[6, 8, 6],
        [2, 6, 7]])
Batch 3:
tensor([[6, 6, 2],
        [8, 7, 0]])
Batch 4:
tensor([[9, 4, 8],
        [2, 4, 1]])
Batch 5:
tensor([[9, 6, 1],
        [2, 7, 5]])

因此,我认为问题在于您的函数sample_func_to_be_parallelized()

编辑:如果我在__getitem__中使用np.random.randint(10, size=3)代替torch.randint(0, 10, (3,)) (作为sample_func_to_be_parallelized()的一个例子),那么数据实际上在每一批中都会重复。看这个问题

因此,如果您在sample_func_to_be_parallelized()中的某个地方使用numpy的RGN,那么解决方法是使用

代码语言:javascript
运行
复制
worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id) 

并在每次调用np.random.seed()之前通过data = next(iter(data_loader))重新设置种子。

票数 4
EN

Stack Overflow用户

发布于 2019-01-25 14:48:09

DataLoader不需要替换就可以对数据集进行采样。为此,它生成介于0和len(dataset)之间的索引的len(dataset)。我猜这个排列是吞噬你大部分记忆的原因。我不认为PyTorch API支持无限集合,但是您可以尝试在DataLoader中分叉代码并自己执行。您可以使用batch_sampler参数,并传入一个基于RandomSampler实现的自定义变体。这将允许您保留DataLoader的并行加载部分。

尽管如此,基于__len____getitem__的迭代协议并不适合无限集合。您最好重新实现您的Dataset.__len__,只返回1,您的Dataset.__getitem__总是返回一个新的样本,不管索引是什么,然后从这个数据集中用替换来采样n时间。从技术上讲,它将要求第0次样本的n时间,但是由于重写__getitem__返回不同的样本,这将有效地完成所需的任务。

票数 1
EN

Stack Overflow用户

发布于 2019-02-24 11:31:34

尝试使用来自itertoolsitertools。下面是一个简单数据集的示例:

代码:

代码语言:javascript
运行
复制
from itertools import cycle

import torch
from torch.utils.data import Dataset, DataLoader


# Create some dummy data.
data = torch.tensor([[0, 0],
                     [1, 1],
                     [2, 2],
                     [3, 3]])


class DataSet(Dataset):
    """Our dataset. Iterates over tensor data"""

    def __init__(self, data):
        self.data = data
        self.n = self.data.shape[0]

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.data[idx]


bs = 1  # batch size
workers = 1  # number of workers

dataset = DataSet(data)
data_loader = DataLoader(dataset, batch_size=bs, shuffle=False, num_workers=workers)

# Infinite loop.
print(f'batch size: {bs} | number of workers: {workers}')
for i, data in cycle(enumerate(data_loader)):
    print(i, data)

输出:

代码语言:javascript
运行
复制
batch size: 1 | number of workers: 1
0 tensor([[0, 0]])
1 tensor([[1, 1]])
2 tensor([[2, 2]])
3 tensor([[3, 3]])
0 tensor([[0, 0]])
1 tensor([[1, 1]])
2 tensor([[2, 2]])
3 tensor([[3, 3]])
...
代码语言:javascript
运行
复制
batch size: 2 | number of workers: 2
0 tensor([[0, 0],
        [1, 1]])
1 tensor([[2, 2],
        [3, 3]])
0 tensor([[0, 0],
        [1, 1]])
1 tensor([[2, 2],
...
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/54359243

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档