# Source code for qumphy.models.utils.kgloss

import math

import torch
import torch.nn.functional as F
from torch import nn


class KGLoss(nn.Module):
    """
    Equation 12 from https://arxiv.org/abs/1703.04977
    Models aleatoric uncertainty for classification tasks

    The logits are corrupted with standard-normal noise scaled by the
    predicted ``noise`` (sigma). ``noise_samples`` Monte-Carlo draws are
    taken, and the negative log of the sample-averaged softmax probability
    of the target class is returned, averaged over all predictions.

    Parameters
    ----------
    noise_samples : int
        Number of Monte-Carlo noise samples ``T`` drawn per prediction.
        Must be at least 1.

    Raises
    ------
    ValueError
        If ``noise_samples`` is smaller than 1.
    """

    def __init__(self, noise_samples=100):
        super(KGLoss, self).__init__()
        if noise_samples < 1:
            raise ValueError(
                f"noise_samples must be >= 1, got {noise_samples}"
            )
        self.noise_samples = noise_samples
        # Kept as a public attribute for backward compatibility; ``forward``
        # draws its noise with ``torch.randn`` directly on the logits' device
        # instead of sampling on the CPU and copying.
        self.Gauss = torch.distributions.normal.Normal(0.0, 1.0)

    def forward(self, logits, noise, target):
        """
        Compute the loss for aleatoric uncertainty.

        Parameters
        ----------
        logits : torch.Tensor
            The output of the model before softmax, with the classes in the
            last dimension, i.e. shape ``(..., C)``.
        noise : torch.Tensor
            The output of the model for the aleatoric uncertainty. It is used
            as the scale (standard deviation) of the Gaussian noise added to
            the logits, and must broadcast against ``logits`` (e.g. shape
            ``(..., C)`` or ``(..., 1)``).
        target : torch.Tensor
            Integer class indices, one per prediction. Its number of elements
            must equal the product of the leading dimensions of ``logits``.

        Returns
        -------
        loss : torch.Tensor
            Scalar mean negative log-likelihood over all predictions.
        """
        num_samples = self.noise_samples
        # (..., C) -> (..., T, C): one independent noise draw per MC sample
        # and per logit. Default dtype matches the previous Normal(0, 1).
        sample_shape = logits.shape[:-1] + (num_samples,) + logits.shape[-1:]
        epsilon = torch.randn(sample_shape, device=logits.device)
        sigma = noise.unsqueeze(-2)
        f = logits.unsqueeze(-2)
        x = f + sigma * epsilon

        # log((1/T) * sum_t softmax(x_t)) evaluated in log space. This avoids
        # the softmax underflow (and the 1e-16 fudge) that otherwise
        # saturates the loss and kills the gradient for confident mistakes.
        log_probs = torch.logsumexp(F.log_softmax(x, dim=-1), dim=-2)
        log_probs = log_probs - math.log(num_samples)

        # nll_loss expects the class dimension at position 1. Flattening the
        # leading dimensions keeps class-last inputs of any rank correct and
        # yields the same mean as the unflattened 1-D/2-D cases.
        num_classes = log_probs.shape[-1]
        return F.nll_loss(
            log_probs.reshape(-1, num_classes),
            target=target.reshape(-1),
            reduction="mean",
        )
class KGLoss_unified_prediction(KGLoss):
    """
    Variant of :class:`KGLoss` that consumes a single combined model output.

    Implements Equation 12 from https://arxiv.org/abs/1703.04977, which models
    aleatoric uncertainty for classification tasks. The last dimension of the
    prediction is split in two: the first ``num_classes`` entries are the
    logits, and every remaining entry is the noise scale. Both parts are then
    handed to :meth:`KGLoss.forward`.

    Parameters
    ----------
    num_classes : int
        How many leading entries of the last dimension are logits.
    **kwargs
        Passed unchanged to :class:`KGLoss` (e.g. ``noise_samples``).
    """

    def __init__(self, num_classes=2, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes

    def forward(self, prediction, target):
        """
        Split a combined prediction and compute the aleatoric-uncertainty loss.

        Parameters
        ----------
        prediction : torch.Tensor
            Combined model output. Along the last dimension it holds
            ``[logits (num_classes entries) | noise (remaining entries)]``.
            The noise part must broadcast against the logits inside
            :meth:`KGLoss.forward` (e.g. width 1 or ``num_classes``).
        target : torch.Tensor
            Target class indices, passed through to :meth:`KGLoss.forward`.

        Returns
        -------
        loss : torch.Tensor
            Whatever :meth:`KGLoss.forward` returns for the split tensors.
        """
        split_at = self.num_classes
        class_logits = prediction[..., :split_at]
        noise_scale = prediction[..., split_at:]
        return super().forward(class_logits, noise_scale, target)