import math

import torch
import torch.nn.functional as F
from torch import nn
class KGLoss(nn.Module):
    """
    Equation 12 from https://arxiv.org/abs/1703.04977
    Models aleatoric uncertainty for classification tasks

    The logits are perturbed with Gaussian noise scaled by ``noise``; the
    softmax probabilities are averaged over ``noise_samples`` Monte Carlo
    draws and the mean negative log-likelihood of ``target`` is returned.

    Parameters
    ----------
    noise_samples : int, optional
        Number of Monte Carlo noise draws ``T`` per forward pass
        (default 100). Must be at least 1.
    """

    def __init__(self, noise_samples=100):
        super().__init__()
        if noise_samples < 1:
            raise ValueError(f"noise_samples must be >= 1, got {noise_samples}")
        self.noise_samples = noise_samples
        # Retained for backward compatibility with code reading ``self.Gauss``;
        # ``forward`` draws its noise directly on the logits' device/dtype.
        self.Gauss = torch.distributions.normal.Normal(0.0, 1.0)

    def forward(self, logits, noise, target):
        """
        Compute the loss for aleatoric uncertainty.

        For each of ``noise_samples`` draws ``eps ~ N(0, 1)`` the logits are
        perturbed as ``logits + noise * eps``. The class probabilities are
        averaged over the draws, and the mean negative log-likelihood of
        ``target`` under that averaged distribution is returned.

        Parameters
        ----------
        logits : torch.Tensor
            The output of the model before softmax; the last dimension
            indexes classes. Presumably ``(N, C)`` -- ``F.nll_loss`` expects
            the class dimension at position 1, so other ranks need care.
        noise : torch.Tensor
            The output of the model for the aleatoric uncertainty, used
            directly as the noise scale ``sigma`` (not a log-variance). Must
            broadcast against ``logits`` (e.g. same shape, or a last
            dimension of 1).
        target : torch.Tensor
            Integer class indices, as accepted by ``F.nll_loss``.

        Returns
        -------
        loss : torch.Tensor
            The loss of the model (``reduction="mean"``).
        """
        # One noisy copy of the logits per draw: (..., C) -> (..., T, C).
        sample_shape = (*logits.shape[:-1], self.noise_samples, logits.shape[-1])
        epsilon = torch.randn(sample_shape, device=logits.device, dtype=logits.dtype)
        perturbed = logits.unsqueeze(-2) + noise.unsqueeze(-2) * epsilon

        # log(mean_t softmax(x_t)) computed in log space as
        # logsumexp_t(log_softmax(x_t)) - log T. This avoids log(0) when a
        # probability underflows, unlike log(p + eps).
        log_probs = torch.logsumexp(F.log_softmax(perturbed, dim=-1), dim=-2)
        log_probs = log_probs - math.log(self.noise_samples)

        return F.nll_loss(log_probs, target=target, reduction="mean")
class KGLoss_unified_prediction(KGLoss):
    """
    Equation 12 from https://arxiv.org/abs/1703.04977
    Models aleatoric uncertainty for classification tasks

    Variant of :class:`KGLoss` that receives logits and noise packed into a
    single tensor: along the last dimension, the first ``num_classes``
    entries are the logits and the remaining entries are the noise scale.

    Parameters
    ----------
    num_classes : int, optional
        Number of leading columns of ``prediction`` treated as logits
        (default 2).
    **kwargs
        Forwarded to :class:`KGLoss` (e.g. ``noise_samples``).
    """

    def __init__(self, num_classes=2, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes

    def forward(self, prediction, target):
        """
        Split ``prediction`` into logits and noise and compute the loss.

        Parameters
        ----------
        prediction : torch.Tensor
            Model output whose last dimension holds ``num_classes`` logits
            followed by the noise columns. The noise width must broadcast
            against the logits (e.g. 1 or ``num_classes``).
        target : torch.Tensor
            Integer class indices.

        Returns
        -------
        torch.Tensor
            The loss computed by :meth:`KGLoss.forward`.

        Raises
        ------
        ValueError
            If ``prediction`` has no columns beyond the logits.
        """
        width = prediction.shape[-1]
        if width <= self.num_classes:
            # Without this check the empty noise slice fails later with an
            # opaque broadcasting error.
            raise ValueError(
                f"prediction last dimension ({width}) must exceed "
                f"num_classes ({self.num_classes}) to include noise columns"
            )
        logits = prediction[..., : self.num_classes]
        noise = prediction[..., self.num_classes :]
        return super().forward(logits, noise, target)