首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >StopIteration: generator_output = next(output_generator)

StopIteration: generator_output = next(output_generator)
EN

Stack Overflow用户
提问于 2018-02-10 00:10:33
回答 2查看 8.8K关注 0票数 12

我有以下代码,我将其重写以处理大规模数据集。我正在使用Python生成器对逐批生成的数据进行模型拟合。

代码语言:javascript
运行
复制
def subtract_mean_gen(x_source,y_source,avg_image,batch):
    batch_list_x=[]
    batch_list_y=[]
    for line,y in zip(x_source,y_source):
        x=line.astype('float32')
        x=x-avg_image
        batch_list_x.append(x)
        batch_list_y.append(y)
        if len(batch_list_x) == batch:
            yield (np.array(batch_list_x),np.array(batch_list_y))
            batch_list_x=[]
            batch_list_y=[] 

model = resnet.ResnetBuilder.build_resnet_18((img_channels, img_rows, img_cols), nb_classes)
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

val = subtract_mean_gen(X_test,Y_test,avg_image_test,batch_size)
model.fit_generator(subtract_mean_gen(X_train,Y_train,avg_image_train,batch_size), steps_per_epoch=X_train.shape[0]//batch_size,epochs=nb_epoch,validation_data = val,
                    validation_steps = X_test.shape[0]//batch_size)

我得到了以下错误:

代码语言:javascript
运行
复制
239/249 [===========================>..] - ETA: 60s - loss: 1.3318 - acc: 0.8330Exception in thread Thread-1:
Traceback (most recent call last):
  File "/usr/lib/python2.7/threading.py", line 801, in __bootstrap_inner
    self.run()
  File "/usr/lib/python2.7/threading.py", line 754, in run
    self.__target(*self.__args, **self.__kwargs)
  File "/usr/local/lib/python2.7/dist-packages/keras/utils/data_utils.py", line 560, in data_generator_task
    generator_output = next(self._generator)
StopIteration

240/249 [===========================>..] - ETA: 54s - loss: 1.3283 - acc: 0.8337Traceback (most recent call last):
  File "cifa10-copy.py", line 125, in <module>
    validation_steps = X_test.shape[0]//batch_size)
  File "/usr/local/lib/python2.7/dist-packages/keras/legacy/interfaces.py", line 87, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 1809, in fit_generator
    generator_output = next(output_generator)
StopIteration

我研究了一个发布在here上的类似问题,但是,我不能解决为什么StopIteration被引发的错误。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2018-02-10 00:16:23

keras的生成器必须是无限的:

代码语言:javascript
运行
复制
def subtract_mean_gen(x_source,y_source,avg_image,batch):
    while True:
        batch_list_x=[]
        batch_list_y=[]
        for line,y in zip(x_source,y_source):
            x=line.astype('float32')
            x=x-avg_image
            batch_list_x.append(x)
            batch_list_y.append(y)
            if len(batch_list_x) == batch:
                yield (np.array(batch_list_x),np.array(batch_list_y))
                batch_list_x=[]
                batch_list_y=[] 

之所以会出现这个错误,是因为keras试图获取一个新的批处理,但是您的生成器已经到达了它的末尾。(即使您定义了正确的步骤数,keras也有一个队列,即使您处于最后一步,它也会尝试从生成器获取更多批处理。)

显然,您有一个默认的队列大小,即10 (异常出现在结束之前的10个批,因为队列试图在结束之后获取一个批)。

票数 25
EN

Stack Overflow用户

发布于 2018-02-10 00:17:24

正如您提供的链接问题所示,Keras生成器必须无限迭代,因此您可以根据需要将元素输出到您的训练中。关于this Github问题的更多信息。

为此,你必须对你的生成器做一些修改,比如:

代码语言:javascript
运行
复制
def subtract_mean_gen(x_source,y_source,avg_image,batch):
batch_list_x=[]
batch_list_y=[]
while 1: #run forever, so you can generate elements indefinitely
    for line,y in zip(x_source,y_source):
        x=line.astype('float32')
        x=x-avg_image    
        batch_list_x.append(x)
        batch_list_y.append(y)
        if len(batch_list_x) == batch:
            yield (np.array(batch_list_x),np.array(batch_list_y))
            batch_list_x=[]
            batch_list_y=[]
票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48709839

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档