首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >用于不适合内存的大型hdf5文件的Keras自定义数据生成器

用于不适合内存的大型hdf5文件的Keras自定义数据生成器
EN

Stack Overflow用户
提问于 2017-11-02 00:44:19
回答 4查看 10.2K关注 0票数 10

我正在尝试使用预训练的InceptionV3模型来对food-101 dataset进行分类,它包含101个类别的食物图像,每个类别1000个。到目前为止,我已经将这个数据集预处理成一个hdf5文件(我假设这与在训练时在路上加载图像相比是有益的),其中包含以下表格:

数据分割是标准的70%训练,20%验证,10%测试,例如,valid_img的大小为20200*299*299*3。标签是针对Keras编码的,因此valid_labels的大小为20200*101.

此hdf5文件的大小为27.1 GB,因此无法放入我的内存中。(有8 GB的内存,虽然在运行Ubuntu时实际上只有4-5 GB可用。另外,我的GPU是GTX 960,有2 GB的VRAM,到目前为止,当我尝试启动训练脚本时,它看起来有1.5 GB可用于python )。我正在使用Tensorflow后端。

我的第一个想法是使用带有双重嵌套的for循环的model.train_on_batch(),如下所示:

代码语言:javascript
复制
#Loading InceptionV3, adding my fully connected layers, compiling model...    

dataset = h5py.File('/home/uzoltan/PycharmProjects/food-101/food-101_299x299.hdf5', 'r')
    epoch = 50
    for i in range(epoch):
        for i in range(100): #1000 images can fit in the memory easily, this could probably be range(10) too
            train_images = dataset["train_img"][i * 706:(i + 1) * 706, ...]
            train_labels = dataset["train_labels"][i * 706:(i + 1) * 706, ...]
            val_images = dataset["valid_img"][i * 202:(i + 1) * 202, ...]
            val_labels = dataset["valid_labels"][i * 202:(i + 1) * 202, ...]
            model.train_on_batch(x=train_images, y=train_labels, class_weight=None,
                                 sample_weight=None, )

我对这种方法的问题是,train_on_batch为验证或批处理混洗提供了0支持,因此每个时期批处理的顺序不同。

所以我把目光投向了model.fit_generator(),它提供了与fit()相同的所有功能,加上内置的ImageDataGenerator,你可以进行图像增强(旋转,水平翻转等)。同时配合CPU,这样你的模型才能更健壮。我这里的问题是,如果我理解正确的话,ImageDataGenerator.flow(x,y)方法需要一次所有的样本和标签,但是我的训练/验证数据无法放入我的内存中。

这就是我认为自定义数据生成器需要考虑的地方,但在广泛地查看了一些我可以在Keras GitHub/Issues页面上找到的示例后,我仍然不太明白应该如何实现一个自定义生成器,它将从我的hdf5文件中批量读取数据。有人能为我提供一个好的示例或指针吗?如何将自定义批处理生成器与图像增强功能耦合起来?或者,也许为train_on_batch()实现某种手动验证和批处理混洗更容易?如果是这样的话,我也可以在那里使用一些指针。

EN

回答 4

Stack Overflow用户

发布于 2019-05-22 16:19:16

对于任何还在寻找答案的人,我围绕ImageDataGeneator的apply_transform方法做了以下“粗糙的包装”。

代码语言:javascript
复制
from numpy.random import uniform, randint
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
import numpy as np

class CustomImagesGenerator:
    def __init__(self, x, zoom_range, shear_range, rescale, horizontal_flip, batch_size):
        self.x = x
        self.zoom_range = zoom_range
        self.shear_range = shear_range
        self.rescale = rescale
        self.horizontal_flip = horizontal_flip
        self.batch_size = batch_size
        self.__img_gen = ImageDataGenerator()
        self.__batch_index = 0

    def __len__(self):
        # steps_per_epoch, if unspecified, will use the len(generator) as a number of steps.
        # hence this
        return np.floor(self.x.shape[0]/self.batch_size)

    def next(self):
        return self.__next__()

    def __next__(self):
        start = self.__batch_index*self.batch_size
        stop = start + self.batch_size
        self.__batch_index += 1
        if stop > len(self.x):
            raise StopIteration
        transformed = np.array(self.x[start:stop])  # loads from hdf5
        for i in range(len(transformed)):
            zoom = uniform(self.zoom_range[0], self.zoom_range[1])
            transformations = {
                'zx': zoom,
                'zy': zoom,
                'shear': uniform(-self.shear_range, self.shear_range),
                'flip_horizontal': self.horizontal_flip and bool(randint(0,2))
            }
            transformed[i] = self.__img_gen.apply_transform(transformed[i], transformations)
        return transformed * self.rescale

它可以这样命名:

代码语言:javascript
复制
import h5py
f = h5py.File("my_heavy_dataset_file.hdf5", 'r')
images = f['mydatasets/images']

my_gen = CustomImagesGenerator(
    images, 
    zoom_range=[0.8, 1], 
    shear_range=6, 
    rescale=1./255, 
    horizontal_flip=True, 
    batch_size=64
)

model.fit_generator(my_gen)
票数 3
EN

Stack Overflow用户

发布于 2018-07-05 02:23:56

如果我没理解错的话,您希望使用HDF5中的数据(内存不能容纳),同时对其使用数据增强。

我和你的情况一样,我发现这段代码可能会有帮助,只需做一些修改:

https://gist.github.com/wassname/74f02bc9134897e3fe4e60784f5aaa15

票数 2
EN

Stack Overflow用户

发布于 2018-05-01 13:29:17

这是我使用h5文件对每个时期的数据进行混洗的解决方案。index表示列车或val索引表。

代码语言:javascript
复制
def generator(h5path, indices, batchSize=128, is_train=True, aug=None):

    db = h5py.File(h5path, "r")
    with open("mean.json") as f:
        mean = json.load(f)
    meanV = np.array([mean["R"], mean["G"], mean["B"]])

    while True:

        np.random.shuffle(indices)
        for i in range(0, len(indices), batchSize):
            t0 = time()
            batch_indices = indices[i:i+batchSize]
            batch_indices.sort()

            by = db["labels"][batch_indices,:]
            bx = db["images"][batch_indices,:,:,:]

            bx[:,:,:,0] -= meanV[0]
            bx[:,:,:,1] -= meanV[1]
            bx[:,:,:,2] -= meanV[2]
            t1=time()

            if is_train:

                #bx = random_crop(bx, (224,224))
                if aug is not None:
                    bx,by = next(aug.flow(bx,by,batchSize))

            yield (bx,by)


h5path='all_224.hdf5'   
model.fit_generator(generator(h5path, train_indices, batchSize=batchSize, is_train=True, aug=aug),
                steps_per_epoch = 20000//batchSize,
                validation_data= generator(h5path, test_indices, is_train=False, batchSize=batchSize), 
                validation_steps = 2424//batchSize,
                epochs=args.epoch, 
                max_queue_size=100,
                callbacks=[checkpoint, early_stop])
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/47059698

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档