"""
File: qumphy/models/pulsedb.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Lightning model for PulseDB data.
"""
import torch
import lightning as L
import qumphy
[docs]
class PulseDBGaussianLoss(torch.nn.GaussianNLLLoss):
    """Gaussian NLL loss for outputs that pack means and log-variances.

    The model output is split along dimension 1: the first
    ``num_distributions`` columns are used as predicted means, and every
    remaining column is exponentiated to obtain the variance passed on to
    ``torch.nn.GaussianNLLLoss``.
    """

    def __init__(self, num_distributions: int = 2, **kwargs) -> None:
        """Set up the parent loss and remember where to split the output.

        Parameters
        ----------
        num_distributions : int
            How many leading output columns hold predicted means. All
            columns after them are treated as log-variances.
        **kwargs
            Forwarded unchanged to ``torch.nn.GaussianNLLLoss``.
        """
        super().__init__(**kwargs)
        self.num_distributions = num_distributions

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Evaluate the Gaussian negative log likelihood.

        Parameters
        ----------
        input : torch.Tensor
            Model output; indexed as ``input[:, k]``, with the means first
            and the log-variances after them.
        target : torch.Tensor
            Ground truth values compared against the predicted means.

        Returns
        -------
        torch.Tensor
            The loss computed by ``torch.nn.GaussianNLLLoss.forward``.
        """
        split_at = self.num_distributions
        predicted_mean = input[:, :split_at]
        log_variance = input[:, split_at:]
        # The network predicts log-variances; the parent loss expects variances.
        predicted_variance = log_variance.exp()
        return super().forward(predicted_mean, target, predicted_variance)
[docs]
class PulseDBModule(L.LightningModule):
    """
    Lightning parent module for PulseDB data.

    Wraps a specific model architecture (``net``), applies an output
    activation and a loss function, and stores the dataset target
    statistics as buffers so predictions can be denormalized.
    """

    def __init__(
        self,
        net,
        dataset,
        optimizer,
        output_activation=torch.nn.Identity(),
        loss_fn=torch.nn.MSELoss(),
        lr_scheduler=None,
        pressure="both",
    ):
        """
        Parameters
        ----------
        net : torch.nn.Module
            The model architecture
        dataset : qumphy.datasets.PulseDBDataset
            A PulseDB dataset instance to obtain the target statistics
        optimizer : dict
            The optimizer configuration (stored as ``optimizer_config``)
        output_activation : torch.nn.Module, optional
            The activation function after the last layer
        loss_fn : torch.nn.Module, optional
            The loss function
        lr_scheduler : dict, optional
            The learning rate scheduler configuration
        pressure : str, optional
            The pressure to predict, by default "both". The denormalize
            helpers recognise "both", "sbp" and "dbp".
        """
        super().__init__()
        self.loss_fn = loss_fn
        self.net = net
        self.set_dataset_stats(dataset)
        self.pressure = pressure
        self.output_activation = output_activation
        self.lr_scheduler = lr_scheduler
        self.optimizer_config = optimizer

    def forward(self, x):
        """Run ``net`` on ``x`` and apply the output activation."""
        return self.output_activation(self.net(x))

    def training_step(self, batch, batch_idx):
        """Run the shared step for the "train" stage."""
        return self._common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Run the shared step for the "val" stage."""
        return self._common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        """Run the shared step for the "test" stage."""
        return self._common_step(batch, batch_idx, "test")

    def predict_step(self, batch, batch_idx):
        """Return the model output for the data part of ``batch``.

        ``batch`` must unpack into ``(data, target)``; the target is ignored.
        """
        data, _ = batch
        return self(data)

    def _common_step(self, batch, batch_idx, stage):
        """Compute the loss for one ``(data, target)`` batch.

        ``stage`` is accepted for symmetry with the public step methods but
        is not used here.

        Returns
        -------
        dict
            ``"loss"`` (attached to the graph), ``"prediction"`` (detached
            model output) and ``"target"``.
        """
        data, target = batch
        output = self(data)
        loss = self.loss_fn(output, target)
        return {
            "loss": loss,
            "prediction": output.detach(),
            "target": target,
        }

    def set_dataset_stats(self, dataset):
        """Set and register target statistics from the dataset.

        Parameters
        ----------
        dataset : qumphy.datasets.PulseDBDataset
            PulseDB dataset instance; ``dataset.get_target_stats()`` must
            return five values (mean, std, median, MAE baseline, RMSE
            baseline), presumably tensors since they are registered as
            buffers -- TODO confirm against the dataset implementation.

        Returns
        -------
        None
            The function registers target statistics as buffers.
        """
        BP_mean, BP_std, BP_median, MAE_baseline, RMSE_baseline = (
            dataset.get_target_stats()
        )
        self.register_buffer("BP_mean", BP_mean)
        self.register_buffer("BP_std", BP_std)
        self.register_buffer("BP_median", BP_median)
        self.register_buffer("MAE_baseline", MAE_baseline)
        self.register_buffer("RMSE_baseline", RMSE_baseline)

    def set_lr_scheduler(self):
        """Instantiate the configured learning rate scheduler.

        Injects ``self.optimizer`` into the scheduler's ``init_args``,
        builds the scheduler from ``class_path`` and stores it under
        ``self.lr_scheduler["config"]["scheduler"]``. The configuration
        dict is modified in place.

        NOTE(review): ``self.optimizer`` is not assigned anywhere in this
        class; it has to be set before this method is called.

        Returns
        -------
        None
            Does nothing when no scheduler configuration was provided.
        """
        if self.lr_scheduler is None:
            # The scheduler is optional (constructor default is None).
            return
        self.lr_scheduler["init_args"]["optimizer"] = self.optimizer
        scheduler = qumphy.misc.misc.instantiate_class_from_string(
            self.lr_scheduler["class_path"],
            **self.lr_scheduler["init_args"],
        )
        self.lr_scheduler["config"]["scheduler"] = scheduler

    def denormalize_target(self, target):
        """
        Rescale a normalized target back to the original BP values using
        the BP_mean and BP_std buffers. A new tensor is returned; the input
        is not modified.

        Parameters
        ----------
        target : torch.tensor
            The input target to be denormalized.

        Returns
        -------
        torch.tensor
            The denormalized target. For "sbp" index 0 of the statistics is
            used, for "dbp" index 1, for "both" the full statistics. Any
            other ``pressure`` value returns ``target`` unchanged.
        """
        # Cast local copies only: reassigning the buffers would permanently
        # change their dtype/device to whatever tensor was passed in.
        mean = self.BP_mean.to(target)
        std = self.BP_std.to(target)
        if self.pressure == "both":
            return target * std + mean
        if self.pressure == "sbp":
            return target * std[0] + mean[0]
        if self.pressure == "dbp":
            return target * std[1] + mean[1]
        return target

    def denormalize_std(self, prediction_std):
        """
        Rescale in-place a standard deviation given in normalized target
        units to the original units, using the BP_std buffer.

        Parameters
        ----------
        prediction_std : torch.tensor
            The input standard deviation to be denormalized. It is modified
            in place.

        Returns
        -------
        torch.tensor
            The same tensor object, denormalized. For "sbp" index 0 of
            BP_std is used, for "dbp" index 1, for "both" the full BP_std.
            Any other ``pressure`` value leaves it unchanged.
        """
        # Local casted copy so the registered buffer keeps its own dtype/device.
        std = self.BP_std.to(prediction_std)
        if self.pressure == "both":
            prediction_std *= std
        elif self.pressure == "sbp":
            prediction_std *= std[0]
        elif self.pressure == "dbp":
            prediction_std *= std[1]
        return prediction_std
[docs]
class PulseDBModule_MCD(PulseDBModule):
    """PulseDB Lightning module with Monte Carlo dropout evaluation."""

    def __init__(self, *args, MCD_samples=1, **kwargs):
        """Initialize the Monte Carlo dropout PulseDB module.

        Parameters
        ----------
        *args
            Positional arguments passed to PulseDBModule.
        MCD_samples : int
            Number of forward passes performed per batch in ``test_step``.
        **kwargs
            Keyword arguments passed to PulseDBModule.
        """
        # Initialise the parent module before attaching attributes.
        super().__init__(*args, **kwargs)
        self.MCD_samples = MCD_samples

    def test_step(self, batch, batch_idx):
        """Run ``MCD_samples`` forward passes on one batch and collect them.

        The last dimension of the model output is indexed as
        ``[mean_sbp, mean_dbp, log_var_sbp, log_var_dbp]``.

        NOTE(review): this method does not switch dropout layers into
        training mode itself; whether the passes differ depends on ``net``.

        Returns
        -------
        dict
            ``"loss"``: loss of the sample-averaged raw output (means and
            log-variances) against the target.
            ``"prediction"``: average over the samples of the predicted
            means and of the exponentiated variances.
            ``"target"``: the batch target.
            ``"stacked_predicted_*"``: per-sample means and variances,
            with the sample index as leading dimension.
        """
        data, target = batch
        predictions = torch.stack(
            [self(data) for _ in range(self.MCD_samples)], dim=0
        )
        # The loss sees the raw outputs (log-variances), averaged over samples.
        loss = self.loss_fn(torch.mean(predictions, dim=0), target)
        # From here on the trailing columns hold variances, not log-variances.
        predictions[..., 2:] = torch.exp(predictions[..., 2:])
        # Average over the sample dimension only; a mean without ``dim``
        # would collapse batch and output columns into a single scalar.
        output = torch.mean(predictions, dim=0)
        return {
            "loss": loss,
            "prediction": output,
            "target": target,
            "stacked_predicted_means_sbp": predictions[..., 0],
            "stacked_predicted_means_dbp": predictions[..., 1],
            "stacked_predicted_vars_sbp": predictions[..., 2],
            "stacked_predicted_vars_dbp": predictions[..., 3],
        }

    def _common_step(self, batch, batch_idx, stage):
        """Compute the loss for one ``(data, target)`` batch.

        Unlike the parent implementation, the returned prediction is not
        detached. ``stage`` is not used.

        Returns
        -------
        dict
            ``"loss"``, ``"prediction"`` and ``"target"``.
        """
        data, target = batch
        prediction = self(data)
        loss = self.loss_fn(prediction, target)
        return {
            "loss": loss,
            "prediction": prediction,
            "target": target,
        }
[docs]
class PulseDBEnsemble(PulseDBModule):
    """PulseDB ensemble module for Gaussian prediction aggregation."""

    def __init__(self, **kwargs):
        """Initialize the ensemble; keyword arguments go to PulseDBModule."""
        super().__init__(**kwargs)

    def load_models(self, model_list):
        """Register the ensemble members.

        Parameters
        ----------
        model_list : iterable of torch.nn.Module
            Member models; stored as ``self.models``. Must be called
            before ``forward``.
        """
        self.models = torch.nn.ModuleList(model_list)

    def forward(self, x):
        """Evaluate every member on ``x`` and stack the outputs.

        Returns
        -------
        torch.Tensor
            Member outputs stacked along a new leading dimension.
        """
        return torch.stack([model(x) for model in self.models])

    def denormalize(self, prediction):
        """Rescale a prediction in place to the original BP units.

        The first two entries of the last dimension are scaled by BP_std
        and shifted by BP_mean; the remaining entries are only scaled by
        BP_std.

        NOTE(review): scaling by BP_std is only meaningful if the trailing
        entries are standard deviations. ``test_step`` returns
        log-variances there -- confirm what callers pass in.

        Parameters
        ----------
        prediction : torch.Tensor
            Prediction to denormalize; modified in place.

        Returns
        -------
        torch.Tensor
            The same tensor object, denormalized.
        """
        prediction[..., :2] = prediction[..., :2] * self.BP_std + self.BP_mean
        prediction[..., 2:] = prediction[..., 2:] * self.BP_std
        return prediction

    def test_step(self, batch, batch_idx):
        """Run one ensemble test step.

        The stacked member prediction is indexed as [model, batch, output],
        where outputs 0-1 are means and outputs 2-3 are log-variances.

        Parameters
        ----------
        batch : tuple
            Batch containing input data and target values.
        batch_idx : int
            Index of the current batch.

        Returns
        -------
        dict
            Dictionary containing the loss, ensemble prediction (means
            followed by log-variances), individual model predictions, and
            target.
        """
        data, target = batch
        prediction = self(data)

        # Gaussian mixture as in the Lakshminarayanan et al. deep ensembles
        # paper. Members predict means and log-variances, and the ensemble
        # prediction is returned in the same format.
        member_means = prediction[..., :2]
        member_vars = torch.exp(prediction[..., 2:4])

        ensemble_mean = member_means.mean(dim=0)
        # \sigma^2 = 1/N \sum_i (\sigma_i^2 + \mu_i^2) - \mu^2
        #          = 1/N \sum_i \sigma_i^2 + 1/N \sum_i (\mu_i - \mu)^2
        # The second form avoids subtracting two large, nearly equal
        # numbers, so the argument of the log cannot turn negative.
        ensemble_var = member_vars.mean(dim=0) + (
            (member_means - ensemble_mean) ** 2
        ).mean(dim=0)
        ensemble_log_var = torch.log(ensemble_var)

        ensemble_prediction = torch.cat(
            [ensemble_mean, ensemble_log_var], dim=-1
        )
        loss = self.loss_fn(ensemble_prediction, target)
        return {
            "loss": loss,
            "prediction": ensemble_prediction,
            "single_models_prediction": prediction,
            "target": target,
        }