首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >在Keras中为fit_generate创建自定义data_generator ()

在Keras中为fit_generate创建自定义data_generator ()
EN

Stack Overflow用户
提问于 2018-06-04 23:13:33
回答 1查看 2.4K关注 0票数 0

我正在尝试训练一个CNN-LSTM来读取一个6帧的序列到CNN (没有顶层的VGG16),并将提取的特征提供给Keras中的LSTM。

问题是,因为我需要一次发送6帧,所以我需要每6帧重塑一次,并添加一个维度。此外,因为标签是针对每个帧的,所以我需要创建另一个变量来获取每个序列的第一帧的标签,并将其放入一个新的数组中,然后馈送这两个帧以馈送模型(代码如下)。

问题是数据变得太大而不能使用model.fit(),甚至当我尝试对一小部分数据使用它时,我也会得到奇怪的可怕的结果,所以我尝试使用model.fit_generator来迭代模型的输入。但由于我不能直接提供从dataset加载的数据(因为我需要重塑并执行我在第一段中解释的操作),我正在尝试制作自己的生成器。然而,事情并不顺利,我一直收到错误,说'tuple‘不是迭代器。有谁知道我如何修复代码使其正常工作?

代码语言:javascript
复制
train_batches = ImageDataGenerator().flow_from_directory(train_path, target_size=(224, 224),
                                                         classes=['Bark', 'Bitting', 'Engage', 'Hidden', 'Jump',
                                                                  'Stand', 'Walk'], batch_size=18156, shuffle=False)
valid_batches = ImageDataGenerator().flow_from_directory(valid_path, target_size=(224, 224),
                                                         classes=['Bark', 'Bitting', 'Engage', 'Hidden', 'Jump',
                                                                  'Stand', 'Walk'], batch_size=6, shuffle=False)
test_batches = ImageDataGenerator().flow_from_directory(test_path, target_size=(224, 224),
                                                        classes=['Bark', 'Bitting', 'Engage', 'Hidden', 'Jump',
                                                                 'Stand','Walk'], batch_size=6, shuffle=False)

def train_gen():

    n_frames=6    
    n_samples=6 #to decide
    H=W=224
    C = 3

    imgs, labels = next(train_batches)

    y = np.empty((n_samples, 7))
    j = 0
    for i in range(n_samples):       
        y[i] = labels[j]        
        j +=6
    frame_sequence = imgs.reshape(n_samples,n_frames, H,W,C)

    return frame_sequence,y



def valid_gen():

    v_frames=6
    v_samples=1
    H=W=224
    C = 3

    vimgs,vlabels = next(valid_batches)
    y2 = np.empty((v_samples, 7))

    k = 0
    for l in range(v_samples):       
        y2[l] = vlabels[k]        
        k +=6
    valid_sequence = vimgs.reshape(v_samples,v_frames, H,W,C)

    return valid_sequence,y2

def main():

    cnn = VGG16(weights='imagenet',
                include_top='False', pooling='avg')

    cnn.layers.pop()
    print(cnn.summary())
    cnn.trainable = False

    video_input= Input(shape=(None,224,224,3), name='video_input')   
    print(video_input.shape)

    encoded_frame_sequence = TimeDistributed(cnn)(video_input) # the output will be a sequence of vectors
    encoded_video = LSTM(256)(encoded_frame_sequence)  # the output will be a vector
    output = Dense(7, activation='relu')(encoded_video)
    video_model = Model(inputs=[video_input], outputs=output)

    tr_data = train_gen()
    vd_data= valid_gen()
    print(video_model.summary())

    imgs, labels = next(train_batches)    
    vimgs,vlabels = next(valid_batches)

    print("Training ...")
    video_model.compile(Adam(lr=.001), loss='categorical_crossentropy', metrics=['accuracy'])
    video_model.fit_generator(tr_data, 
                              steps_per_epoch=1513, 
                             validation_data=vd_data, 
                              validation_steps=431, 
                              epochs=1, 
                              verbose=2)

我定义生成器的方式有误吗?

EN

回答 1

Stack Overflow用户

发布于 2018-06-05 21:42:04

我认为你应该为数据生成器实现一个类,我找到了这个链接,它可能会对你有所帮助。A detailed example of how to use data generators with Keras

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50683740

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档