def weighted_mse_loss(input_tensor, target_tensor, weight = 1):
observation_dim = input_tensor.size()[-1]
streched_tensor = ((input_tensor - target_tensor) ** 2).view(-1, observation_dim)
entry_num = float(streched_tensor.size())[0]
non_zero_entry_num = torch.sum(streched_tensor[:,0] != 0).float()
weighted_tensor = torch.mm(
((input_tensor - target_tensor)**2).view(-1, observation_dim),
(torch.diag(weight.float().view(-1)))
)
return torch.mean(weighted_tensor) * weight.nelement() * entry_num / non_zero_entry_num
我不明白代码是如何给出加权均方误差损失的。我知道observation_dim
是最终的输出维度(我猜是类号),但在那一行之后,我就不明白了。有人能帮我弄清楚代码是如何计算损失的吗?
非常感谢。
发布于 2020-08-11 11:09:56
def weighted_mse_loss(input, target, weight):
return (weight * (input - target) ** 2).mean()
试试这个,希望这个能有所帮助。所有的参数都需要张量。
https://stackoverflow.com/questions/57004498
复制相似问题