我使用keras进行分类。在一些数据集中,它运行良好并计算损失,而在另一些数据集中,损失是NaN
。
不同的数据集是相似的,因为它们是原始数据集的增广版本。使用keras-bert时,原始数据和某些增广版本的数据运行良好,而其他增广版本的数据运行不好。
当我在扩展版本的数据上使用常规的单层BiLSTM
时,它的工作效果很好,这意味着我可以排除数据出错或包含可能影响损失计算方式的虚假值的可能性。使用中的数据有三个类。
我用的是基于bert的
!wget -q https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
有人能告诉我为什么失去的是南吗?
inputs = model.inputs[:2]
dense = model.layers[-3].output
outputs = keras.layers.Dense(3, activation='sigmoid', kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02),name = 'real_output')(dense)
decay_steps, warmup_steps = calc_train_steps(train_y.shape[0], batch_size=BATCH_SIZE,epochs=EPOCHS,)
#(decay_steps=decay_steps, warmup_steps=warmup_steps, lr=LR)
model = keras.models.Model(inputs, outputs)
model.compile(AdamWarmup(decay_steps=decay_steps, warmup_steps=warmup_steps, lr=LR), loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])
sess = tf.compat.v1.keras.backend.get_session()
uninitialized_variables = set([i.decode('ascii') for i in sess.run(tf.compat.v1.report_uninitialized_variables ())])
init_op = tf.compat.v1.variables_initializer([v for v in tf.compat.v1.global_variables() if v.name.split(':')[0] in uninitialized_variables])
sess.run(init_op)
model.fit(train_x,train_y,epochs=EPOCHS,batch_size=BATCH_SIZE)
Train on 20342 samples
Epoch 1/10
20342/20342 [==============================] - 239s 12ms/sample - loss: nan - sparse_categorical_accuracy: 0.5572
Epoch 2/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2082
Epoch 3/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2081
Epoch 4/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2082
Epoch 5/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2082
Epoch 6/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2082
Epoch 7/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2082
Epoch 8/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2081
Epoch 9/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2082
Epoch 10/10
20342/20342 [==============================] - 225s 11ms/sample - loss: nan - sparse_categorical_accuracy: 0.2082
<tensorflow.python.keras.callbacks.History at 0x7f1caf9b0f90>
另外,我在Google和tensorflow 2.3.0
和keras 2.4.3
上运行这个
UPDATE
我再次查看了导致这个问题的数据,我意识到其中一个目标标签不见了。我可能错误地编辑了它。一旦我修好了,损失就是NaN问题消失了。然而,我将奖励我得到的答案的50分,因为它让我更好地思考我的代码。谢谢。
发布于 2021-05-08 16:05:34
我注意到了代码中的一个问题,但我不确定这是否是主要原因;如果您可能提供一些可重复的代码,情况会更好。
在上面的代码片段中,您使用sigmoid
在最后一层使用unit < 1
进行激活,这表明问题数据集可能是多标签,这就是为什么丢失函数应该是binary_crossentropy
,但是设置sparse_categorical_crossentropy
是典型的使用多类问题和整数标签< code >E29。
outputs = keras.layers.Dense(3, activation='sigmoid',
kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
name = 'real_output')(dense)
model = keras.models.Model(inputs, outputs)
model.compile(AdamWarmup(decay_steps=decay_steps,
warmup_steps=warmup_steps, lr=LR),
loss='sparse_categorical_crossentropy',
metrics=['sparse_categorical_accuracy'])
因此,如果您的问题数据集是带有最后一层的多标签,那么设置应该更像
outputs = keras.layers.Dense(3, activation='sigmoid',
kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
name = 'real_output')(dense)
model.compile(AdamWarmup(decay_steps=decay_steps,
warmup_steps=warmup_steps, lr=LR),
loss='binary_crossentropy',
metrics=['accuracy'])
但是,如果问题集是一个多类问题,并且您的目标标签是整数 (unit = 3
),那么设置应该更像如下所示:
outputs = keras.layers.Dense(3, activation='softmax',
kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
name = 'real_output')(dense)
model.compile(AdamWarmup(decay_steps=decay_steps,
warmup_steps=warmup_steps, lr=LR),
loss='sparse_categorical_crossentropy',
metrics=['sparse_categorical_accuracy'])
https://stackoverflow.com/questions/67378194
复制相似问题