我正在训练磁共振脑图像(2D切片)的Conv。模型的输出为sigmoid,损失函数为二进制交叉熵:
x = input, x_hat = output
rec_loss = nn.functional.binary_cross_entropy(x_hat.view(-1, 128 ** 2), x.view(-1, 128 ** 2),reduction='sum')
但我的问题实际上是KL散度损失:
KL_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
在训练的某一时刻,KL散度损失非常高(某处无穷大)。
然后我有一个错误,在下面你可以看到,这可能是因为输出是nan。对如何避免爆炸有什么建议吗?
发布于 2022-04-28 07:25:11
在BCE和KL散度中,您可以使用这些方法作为一种约简方法。KL_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
KL散度通常有一些尖峰,可能比其他数值高出数量级。我不知道为什么会发生这种事,但这也让我很恼火:)
可能你的模型很早就崩溃了。如果你记录下你的KL散度,你会发现你以后还可以有尖峰,但是它们更小,因为整个KL发散项变小了。
https://stackoverflow.com/questions/67889578
复制相似问题