# Source code for qumphy.models.utils.pinballloss

import torch


class PinballLoss(torch.nn.Module):
    """
    Calculates the quantile (pinball) loss function.

    For a quantile level ``q`` and error ``e = target - prediction`` the
    per-element loss is ``max(q * e, (q - 1) * e)``.

    Parameters
    ----------
    quantiles : sequence of float, scalar or torch.Tensor
        Quantile levels predicted for every target. A scalar is treated as a
        single quantile. Values are presumably expected in ``[0, 1]``; this is
        not validated here.
    num_targets : int
        Number of target variables.

    Attributes
    ----------
    quantiles : torch.Tensor
        One-dimensional tensor holding a private copy of the quantile levels.
    num_targets : int
        Number of target variables.
    """

    def __init__(self, quantiles, num_targets):
        super().__init__()
        # ``as_tensor`` accepts lists, scalars and tensors alike without the
        # copy-construct warning ``torch.tensor`` emits for tensor input;
        # ``detach().clone()`` keeps a private, graph-free copy (as
        # ``torch.tensor`` did), and ``reshape(-1)`` turns a scalar quantile
        # into a length-1 tensor so ``forward`` can count and broadcast it.
        quantiles = torch.as_tensor(quantiles)
        self.quantiles = quantiles.detach().clone().reshape(-1)
        self.num_targets = num_targets

    def forward(self, prediction, target):
        """
        Computes the loss for the given prediction.

        Parameters
        ----------
        prediction : torch.Tensor
            Shape ``(..., num_targets * len(quantiles))``. The last dimension
            is split into ``(num_targets, len(quantiles))``, i.e. all quantile
            predictions of one target are stored contiguously.
        target : torch.Tensor
            Shape ``(..., num_targets)``; must broadcast against
            ``prediction.shape[:-1] + (num_targets,)``.

        Returns
        -------
        torch.Tensor
            Scalar loss: the sum over targets and quantiles, averaged over
            all remaining leading dimensions.
        """
        n_quantiles = self.quantiles.numel()
        prediction = prediction.reshape(
            [*prediction.shape[:-1], self.num_targets, n_quantiles]
        )
        # Match the prediction's dtype and device.
        quantiles = self.quantiles.to(prediction)
        # (..., num_targets) -> (..., num_targets, 1) to broadcast over quantiles.
        error = target.unsqueeze(-1) - prediction
        losses = torch.max((quantiles - 1) * error, quantiles * error)
        return losses.sum(dim=(-1, -2)).mean()