首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >在keras中通过ResNet50迁移学习时,Loss总是去nan

在keras中通过ResNet50迁移学习时,Loss总是去nan
EN

Stack Overflow用户
提问于 2018-06-04 17:53:01
回答 1查看 636关注 0票数 0

我正在使用迁移学习在Keras中通过ResNet50模型和加载预先训练的权重来训练图像分类器,但loss最初立即进入nan,而acc保持在随机水平。

实际上,我不知道哪里出了问题,因为我已经用这个模型成功地训练了一个分类器,虽然它的acc不高,但它工作得很好。这一次它失败了。

我调整了lr,但什么也没发生。有人说数据可能有问题,所以我改变了数据,只发现不同的图像,相同的模型会显示不同的结果(也就是说,一些数据/图像工作良好,另一个数据/图像将立即loss:nan )。怎么可能呢?我真的很困惑,不知道我的图片出了什么问题。

数据集:8个类,每个类包含大约300个图像。

下面是所有代码:

代码语言:javascript
运行
复制
import keras
import h5py
import numpy as np
import matplotlib.pyplot as plt

from keras.applications import ResNet50
from keras.models import Sequential
from keras.layers import Dense, Flatten, GlobalAveragePooling2D
from keras.applications.resnet50 import preprocess_input
from keras.preprocessing.image import ImageDataGenerator


data_generator = ImageDataGenerator(preprocessing_function= preprocess_input, 
                        rescale = 1./255)

train_generator = data_generator.flow_from_directory("image/train", 
                        target_size = (100, 100), 
                        batch_size = 32, 
                        class_mode = "categorical")
dev_generator = data_generator.flow_from_directory("image/dev", 
                        target_size = (100, 100), 
                        batch_size = 32, 
                        class_mode = "categorical")

num_classes = 8
model = Sequential()
model.add(ResNet50(include_top = False, pooling = "avg", weights= "resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5"))
model.add(Dense(num_classes, activation = "softmax"))
model.layers[0].trainable = False

model.compile(optimizer= "adam", loss= "categorical_crossentropy", metrics=["accuracy"])

model.fit_generator(train_generator, steps_per_epoch= 1, epochs = 1)

并且运行的输出是:

代码语言:javascript
运行
复制
Epoch 1/1
1/1 [==============================] - 6s 6s/step - loss: nan - acc: 0.0938
EN

回答 1

Stack Overflow用户

发布于 2018-06-04 18:05:49

第一个将“image/dev”更正为"image/dev"

我认为你的错误在于这一行:

代码语言:javascript
运行
复制
data_generator = ImageDataGenerator(preprocessing_function= preprocess_input, rescale = 1./255)

当您同时使用preprocess_input函数和rescale = 1./255时,您可以对数据进行双倍缩放。尝试删除重新缩放...

代码语言:javascript
运行
复制
data_generator = ImageDataGenerator(preprocessing_function= preprocess_input)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50677834

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档