Source code for qumphy.models.utils.kgactivation
import torch
[docs]
class KGActivation(torch.nn.Module):
    """
    Output activation needed to use KGLoss.

    Implements the output parameterisation for Equation 12 of
    https://arxiv.org/abs/1703.04977, used to model aleatoric uncertainty
    for classification tasks.

    Along the last dimension of the input, the first ``n_logits`` entries
    are passed through unchanged and all remaining entries are mapped
    through a softplus, which makes them non-negative. Presumably the
    unchanged entries are the class logits and the softplus entries are the
    noise scales consumed by KGLoss -- TODO confirm against KGLoss.

    Parameters
    ----------
    n_logits : int, optional
        Number of leading entries along the last dimension that are left
        unchanged. Defaults to 2.

    Raises
    ------
    ValueError
        If ``n_logits`` is negative.
    """

    def __init__(self, n_logits=2):
        super().__init__()
        # A negative value would silently slice from the end of the tensor
        # and produce a confusing split, so reject it up front.
        if n_logits < 0:
            raise ValueError(f"n_logits must be non-negative, got {n_logits}")
        self.n_logits = n_logits
        self.softplus = torch.nn.Softplus()

    def forward(self, x):
        """
        Apply the activation.

        Parameters
        ----------
        x : torch.Tensor
            Raw network output; the split is taken along the last dimension.

        Returns
        -------
        torch.Tensor
            Tensor with the same shape as ``x``. The first ``n_logits``
            entries of the last dimension are unchanged; the rest have been
            passed through softplus.
        """
        # Slicing past the end is safe: if the last dimension has at most
        # ``n_logits`` entries, the softplus part is simply empty.
        logits = x[..., : self.n_logits]
        scales = self.softplus(x[..., self.n_logits :])
        return torch.cat([logits, scales], dim=-1)