Source code for qumphy.models.utils.kgactivation

import torch


class KGActivation(torch.nn.Module):
    """Output activation needed to use KGLoss.

    Produces the output parameterisation for Equation 12 of
    https://arxiv.org/abs/1703.04977 (Kendall & Gal), used to model
    aleatoric uncertainty for classification tasks.

    Along the last dimension of the input:

    * the first ``num_logits`` channels are passed through unchanged
      (presumably the class logits -- confirm against KGLoss);
    * all remaining channels are mapped through ``Softplus``, which
      constrains them to be non-negative (presumably the noise scale(s)
      consumed by KGLoss -- confirm against KGLoss).

    If the last dimension is shorter than ``num_logits``, every channel is
    passed through and the Softplus part is simply empty.

    Args:
        num_logits: Number of leading channels left unconstrained.
            Defaults to 2, which reproduces the previous hard-coded behaviour.

    Raises:
        ValueError: If ``num_logits`` is negative. A negative value would
            turn the slices below into "all but the last k" slices and
            silently split the channels at the wrong place.
    """

    def __init__(self, num_logits: int = 2):
        super().__init__()
        if num_logits < 0:
            raise ValueError(
                f"num_logits must be non-negative, got {num_logits}"
            )
        self.num_logits = num_logits
        self.softplus = torch.nn.Softplus()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation along the last dimension of ``x``.

        Args:
            x: Raw network output of any rank. Only the last dimension is
                split.

        Returns:
            A tensor with the same shape as ``x``. The first ``num_logits``
            channels are unchanged and the rest have Softplus applied.
        """
        logits = x[..., : self.num_logits]
        # Softplus keeps these outputs non-negative, so they can be used
        # as scale parameters.
        scales = self.softplus(x[..., self.num_logits :])
        return torch.cat([logits, scales], dim=-1)

    def extra_repr(self) -> str:
        """Show the configured split point when the module is printed."""
        return f"num_logits={self.num_logits}"