import torch
class PinballLoss(torch.nn.Module):
    """
    Pinball (quantile) loss for multi-target, multi-quantile predictions.

    For every quantile ``q`` and error ``e = target - prediction`` the
    element-wise loss is ``max(q * e, (q - 1) * e)``. The element-wise
    losses are summed over the target and quantile axes and the result is
    averaged over all remaining axes.

    Parameters
    ----------
    quantiles : sequence of float or torch.Tensor
        Quantile levels to evaluate. The values are copied into a new
        tensor, so later changes to the argument do not affect the loss.
    num_targets : int
        Number of target variables encoded in the last axis of the
        prediction.

    Attributes
    ----------
    quantiles : torch.Tensor
        Tensor holding the quantile levels.
    num_targets : int
        Number of target variables.
    """

    def __init__(self, quantiles, num_targets: int):
        super().__init__()
        if isinstance(quantiles, torch.Tensor):
            # ``torch.tensor(tensor)`` emits a copy-construct UserWarning;
            # clone().detach() yields the same detached copy without it.
            self.quantiles = quantiles.clone().detach()
        else:
            self.quantiles = torch.tensor(quantiles)
        self.num_targets = num_targets

    def forward(
        self, prediction: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes the pinball loss for the given prediction.

        Parameters
        ----------
        prediction : torch.Tensor
            Tensor whose last axis has ``num_targets * len(quantiles)``
            entries; that axis is split into ``(num_targets, quantiles)``.
        target : torch.Tensor
            Ground-truth values. A trailing axis is appended so the tensor
            broadcasts against the quantile axis of the prediction.
            NOTE(review): presumably shaped like the prediction without
            the quantile axis, i.e. ``(..., num_targets)``; other shapes
            are broadcast silently -- confirm against callers.

        Returns
        -------
        torch.Tensor
            Scalar tensor containing the loss.
        """
        # Split the flat last axis into (num_targets, num_quantiles).
        prediction = prediction.reshape(
            [*prediction.shape[:-1], self.num_targets, len(self.quantiles)]
        )
        # Match the device and dtype of the prediction.
        quantiles = self.quantiles.to(prediction)

        error = target.unsqueeze(-1) - prediction
        # Under-prediction (error > 0) is weighted by q, over-prediction
        # (error < 0) by (1 - q); the maximum picks the active branch.
        upper = quantiles * error
        lower = (quantiles - 1) * error
        losses = torch.max(lower, upper)

        # Sum over the quantile and target axes, then average the rest.
        return torch.mean(torch.sum(losses, dim=(-1, -2)))