Source code for qumphy.models.pulsedb

"""
File: qumphy/models/pulsedb.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Lightning model for PulseDB data.
"""

import torch
import lightning as L
import qumphy


class PulseDBGaussianLoss(torch.nn.GaussianNLLLoss):
    """Gaussian NLL loss for PulseDB models emitting means and log-variances.

    The model output is split column-wise: the leading ``num_distributions``
    columns are taken as predicted means, every remaining column is taken as
    a log-variance and exponentiated before being handed to the parent loss.
    """

    def __init__(self, num_distributions: int = 2, **kwargs) -> None:
        """Set up the loss.

        Parameters
        ----------
        num_distributions : int
            Number of leading output columns that hold predicted means. All
            columns after them are interpreted as log-variances.
        **kwargs
            Forwarded unchanged to ``torch.nn.GaussianNLLLoss``.
        """
        super().__init__(**kwargs)
        self.num_distributions = num_distributions

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Evaluate the Gaussian negative log likelihood.

        Parameters
        ----------
        input : torch.Tensor
            Model output; columns ``[:num_distributions]`` are means, the
            rest are log-variances.
        target : torch.Tensor
            Ground truth values compared against the predicted means.

        Returns
        -------
        torch.Tensor
            Loss value as computed by ``torch.nn.GaussianNLLLoss.forward``.
        """
        split_at = self.num_distributions
        predicted_mean = input[:, :split_at]
        log_variance = input[:, split_at:]
        # The parent loss expects variances, not log-variances.
        return super().forward(predicted_mean, target, log_variance.exp())
class PulseDBModule(L.LightningModule):
    """
    Lightning parent module for PulseDB data.

    Takes a specific model architecture as input (``net``), applies an
    output activation to its result and evaluates a loss function in the
    train / validation / test steps.
    """

    def __init__(
        self,
        net,
        dataset,
        optimizer,
        # NOTE: these default module instances are created once, at function
        # definition time, and are shared by every instance relying on the
        # defaults. Both are parameter-free.
        output_activation=torch.nn.Identity(),
        loss_fn=torch.nn.MSELoss(),
        lr_scheduler=None,
        pressure="both",
    ):
        """
        Parameters
        ----------
        net : torch.nn.Module
            The model architecture
        dataset : qumphy.datasets.PulseDBDataset
            A PulseDB dataset instance to obtain the target statistics
        optimizer : dict
            The optimizer configuration. Must provide the keys
            ``"class_path"`` and ``"init_args"``.
        output_activation : torch.nn.Module, optional
            The activation function after the last layer
        loss_fn : torch.nn.Module, optional
            The loss function
        lr_scheduler : dict, optional
            The learning rate scheduler configuration. When given, it must
            provide the keys ``"class_path"``, ``"init_args"`` and
            ``"config"``.
        pressure : str, optional
            The pressure to predict, by default "both". The denormalization
            helpers recognise "both", "sbp" and "dbp".
        """
        super().__init__()
        self.loss_fn = loss_fn
        self.net = net
        self.set_dataset_stats(dataset)
        self.pressure = pressure
        self.output_activation = output_activation
        self.lr_scheduler = lr_scheduler
        self.optimizer_config = optimizer

    def forward(self, x):
        """Run ``net`` on ``x`` and apply the output activation."""
        x = self.net(x)
        return self.output_activation(x)

    def training_step(self, batch, batch_idx):
        """Run the shared step logic for a training batch."""
        return self._common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Run the shared step logic for a validation batch."""
        return self._common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        """Run the shared step logic for a test batch."""
        return self._common_step(batch, batch_idx, "test")

    def predict_step(self, batch, batch_idx):
        """Return the model output for the data part of ``batch``.

        The batch is unpacked as ``(data, target)``; the target is ignored.
        """
        data, _ = batch
        return self(data)

    def _common_step(self, batch, batch_idx, stage):
        """Compute loss and prediction for one ``(data, target)`` batch.

        Parameters
        ----------
        batch : tuple
            Pair of input data and target values.
        batch_idx : int
            Index of the current batch (unused here).
        stage : str
            Name of the calling stage (unused here).

        Returns
        -------
        dict
            ``"loss"``, the detached ``"prediction"`` and the ``"target"``.
        """
        data, target = batch
        output = self(data)
        loss = self.loss_fn(output, target)
        prediction = output.detach()
        return {
            "loss": loss,
            "prediction": prediction,
            "target": target,
        }

    def set_dataset_stats(self, dataset):
        """Set and register target statistics from the dataset.

        Parameters
        ----------
        dataset : qumphy.datasets.PulseDBDataset
            PulseDB dataset instance that provides target statistics via
            ``get_target_stats()``, which must return five values.

        Returns
        -------
        None
            The function registers target statistics as buffers.
        """
        BP_mean, BP_std, BP_median, MAE_baseline, RMSE_baseline = (
            dataset.get_target_stats()
        )
        self.register_buffer("BP_mean", BP_mean)
        self.register_buffer("BP_std", BP_std)
        self.register_buffer("BP_median", BP_median)
        self.register_buffer("MAE_baseline", MAE_baseline)
        self.register_buffer("RMSE_baseline", RMSE_baseline)

    def configure_optimizers(self):
        """Configure the optimizer and optional learning rate scheduler.

        The stored optimizer configuration is left untouched: the resolved
        arguments are built on a copy, so this method can safely be called
        more than once.

        Returns
        -------
        dict
            Dictionary containing the optimizer and, if provided, the
            learning rate scheduler configuration.
        """
        # Work on copies so the user-supplied configuration keeps its
        # original (string) entries and no consumed generator is stored.
        init_args = dict(self.optimizer_config["init_args"])
        if "params" in init_args:
            resolved_groups = []
            for param_group in init_args["params"]:
                param_group = dict(param_group)
                if isinstance(param_group["params"], str):
                    # SECURITY: ``eval`` executes arbitrary code taken from
                    # the configuration. Only use configurations from a
                    # trusted source. The expression is evaluated in this
                    # method's scope, so it can refer to ``self``.
                    param_group["params"] = eval(param_group["params"])
                resolved_groups.append(param_group)
            init_args["params"] = resolved_groups
        else:
            init_args["params"] = self.parameters()
        self.optimizer = qumphy.misc.misc.instantiate_class_from_string(
            self.optimizer_config["class_path"], **init_args
        )
        if self.lr_scheduler is not None:
            self.set_lr_scheduler()
            return {
                "optimizer": self.optimizer,
                "lr_scheduler": self.lr_scheduler["config"],
            }
        return {"optimizer": self.optimizer}

    def set_lr_scheduler(self):
        """Instantiate the learning rate scheduler for ``self.optimizer``.

        Injects ``self.optimizer`` into the scheduler's ``"init_args"``,
        instantiates the class named by ``"class_path"`` and stores the
        instance under ``self.lr_scheduler["config"]["scheduler"]``.
        Requires ``self.optimizer`` to be set and ``self.lr_scheduler`` to
        be a configuration dict.

        Returns
        -------
        None
            The scheduler configuration dict is updated in place.
        """
        self.lr_scheduler["init_args"]["optimizer"] = self.optimizer
        scheduler = qumphy.misc.misc.instantiate_class_from_string(
            self.lr_scheduler["class_path"],
            **self.lr_scheduler["init_args"],
        )
        self.lr_scheduler["config"]["scheduler"] = scheduler

    def denormalize_target(self, target):
        """
        Rescales the normalized target of 0 mean and 1 standard deviation
        to the original BP values, using the BP_mean and BP_std buffers.

        The input tensor is not modified; a new tensor is returned. The
        registered buffers are not modified either: they are only cast to
        the dtype/device of ``target`` for this computation.

        Parameters
        ----------
        target : torch.tensor
            The input target to be denormalized.

        Returns
        -------
        torch.tensor
            The denormalized target. If ``self.pressure`` is none of
            "both", "sbp" or "dbp", ``target`` is returned unchanged.
        """
        bp_mean = self.BP_mean.to(target)
        bp_std = self.BP_std.to(target)
        if self.pressure == "both":
            target = target * bp_std + bp_mean
        elif self.pressure == "sbp":
            # Index 0 of the statistics is used for "sbp".
            target = target * bp_std[0] + bp_mean[0]
        elif self.pressure == "dbp":
            # Index 1 of the statistics is used for "dbp".
            target = target * bp_std[1] + bp_mean[1]
        return target

    def denormalize_std(self, prediction_std):
        """
        Rescales in-place the standard deviation of the normalized target of
        0 mean and 1 standard deviation to the original values, using the
        BP_std buffer.

        The registered buffer itself is not modified: it is only cast to
        the dtype/device of ``prediction_std`` for this computation.

        Parameters
        ----------
        prediction_std : torch.tensor
            The input standard deviation to be denormalized. It is modified
            in place.

        Returns
        -------
        torch.tensor
            The denormalized standard deviation (the same tensor object).
            If ``self.pressure`` is none of "both", "sbp" or "dbp", it is
            returned unchanged.
        """
        bp_std = self.BP_std.to(prediction_std)
        if self.pressure == "both":
            prediction_std *= bp_std
        elif self.pressure == "sbp":
            prediction_std *= bp_std[0]
        elif self.pressure == "dbp":
            prediction_std *= bp_std[1]
        return prediction_std
class PulseDBModule_MCD(PulseDBModule):
    """PulseDB Lightning module with Monte Carlo dropout evaluation.

    In the test step the network is evaluated ``MCD_samples`` times on the
    same batch and the individual outputs are returned alongside their
    average.
    """

    def __init__(self, *args, MCD_samples=1, **kwargs):
        """Initialize the Monte Carlo dropout PulseDB module.

        Parameters
        ----------
        *args
            Positional arguments passed to PulseDBModule.
        MCD_samples : int
            Number of Monte Carlo dropout samples.
        **kwargs
            Keyword arguments passed to PulseDBModule.
        """
        # Initialise the parent module before attaching attributes to it.
        super().__init__(*args, **kwargs)
        self.MCD_samples = MCD_samples

    def test_step(self, batch, batch_idx):
        """Run one Monte Carlo dropout test step.

        The model is evaluated ``MCD_samples`` times and the outputs are
        stacked along a new leading sample dimension. The last dimension is
        indexed as: 0 = SBP mean, 1 = DBP mean, 2 = SBP log-variance,
        3 = DBP log-variance.

        NOTE(review): the samples only differ if dropout is active during
        testing; whether ``net`` keeps dropout enabled cannot be seen from
        this class -- confirm.

        Parameters
        ----------
        batch : tuple
            Pair of input data and target values.
        batch_idx : int
            Index of the current batch (unused here).

        Returns
        -------
        dict
            ``"loss"`` computed on the sample-averaged raw outputs,
            ``"prediction"`` holding the sample average of the means and of
            the variances (columns from index 2 on are exponentiated before
            averaging), ``"target"``, and the per-sample means and variances
            for SBP and DBP.
        """
        data, target = batch
        predictions = torch.stack(
            [self(data) for _ in range(self.MCD_samples)], dim=0
        )
        # The loss is evaluated on the averaged raw outputs, i.e. while the
        # columns from index 2 on are still log-variances.
        loss = self.loss_fn(torch.mean(predictions, dim=0), target)
        # Convert log-variances to variances in place.
        predictions[..., 2:] = torch.exp(predictions[..., 2:])
        stacked_predicted_means_sbp = predictions[..., 0]
        stacked_predicted_means_dbp = predictions[..., 1]
        stacked_predicted_vars_sbp = predictions[..., 2]
        stacked_predicted_vars_dbp = predictions[..., 3]
        # Average over the sample dimension only, so that the prediction
        # keeps one row per batch element instead of collapsing to a scalar.
        output = torch.mean(predictions, dim=0)
        return {
            "loss": loss,
            "prediction": output,
            "target": target,
            "stacked_predicted_means_sbp": stacked_predicted_means_sbp,
            "stacked_predicted_means_dbp": stacked_predicted_means_dbp,
            "stacked_predicted_vars_sbp": stacked_predicted_vars_sbp,
            "stacked_predicted_vars_dbp": stacked_predicted_vars_dbp,
        }

    def _common_step(self, batch, batch_idx, stage):
        """Compute loss and prediction for one ``(data, target)`` batch.

        Unlike the parent implementation, the returned prediction is not
        detached.

        Parameters
        ----------
        batch : tuple
            Pair of input data and target values.
        batch_idx : int
            Index of the current batch (unused here).
        stage : str
            Name of the calling stage (unused here).

        Returns
        -------
        dict
            ``"loss"``, ``"prediction"`` and ``"target"``.
        """
        data, target = batch
        prediction = self(data)
        loss = self.loss_fn(prediction, target)
        return {
            "loss": loss,
            "prediction": prediction,
            "target": target,
        }
class PulseDBEnsemble(PulseDBModule):
    """PulseDB ensemble module for Gaussian prediction aggregation.

    The member models must be attached with :meth:`load_models` before the
    module is called.
    """

    def __init__(self, **kwargs):
        """Initialize the ensemble.

        Parameters
        ----------
        **kwargs
            Keyword arguments passed to PulseDBModule.
        """
        super().__init__(**kwargs)

    def load_models(self, model_list):
        """Register the ensemble members.

        Parameters
        ----------
        model_list : iterable of torch.nn.Module
            The member models; they are stored in a ``torch.nn.ModuleList``
            under ``self.models``.
        """
        self.models = torch.nn.ModuleList(model_list)

    def forward(self, x):
        """Evaluate every member on ``x`` and stack the outputs.

        Returns
        -------
        torch.Tensor
            Member outputs stacked along a new leading (model) dimension.
        """
        return torch.stack([model(x) for model in self.models])

    def denormalize(self, prediction):
        """Rescale a normalized prediction in place to original BP values.

        Columns ``[:2]`` are scaled by ``BP_std`` and shifted by
        ``BP_mean``; columns ``[2:]`` are scaled by ``BP_std`` only.

        NOTE(review): scaling columns ``[2:]`` by ``BP_std`` is only
        meaningful if they hold standard deviations, whereas ``test_step``
        returns log-variances in these columns -- confirm callers convert
        before calling this method.

        Parameters
        ----------
        prediction : torch.Tensor
            Prediction to be denormalized; modified in place.

        Returns
        -------
        torch.Tensor
            The same tensor object, denormalized.
        """
        prediction[..., :2] = prediction[..., :2] * self.BP_std + self.BP_mean
        prediction[..., 2:] = prediction[..., 2:] * self.BP_std
        return prediction

    def test_step(self, batch, batch_idx):
        """Run one ensemble test step.

        The stacked member prediction has shape
        [model, batchsize, means + log-variances].

        Parameters
        ----------
        batch : tuple
            Batch containing input data and target values.
        batch_idx : int
            Index of the current batch.

        Returns
        -------
        dict
            Dictionary containing the loss, ensemble prediction, individual
            model predictions, and target.
        """
        data, target = batch
        prediction = self(data)

        # GAUSSIAN MIXTURE AS IN LAKSMINARAYANAN PAPER
        # Each member predicts means (columns 0-1) and log-variances
        # (columns 2-3); the ensemble prediction is returned the same way.
        member_mean = prediction[..., :2]
        member_var = torch.exp(prediction[..., 2:4])

        ensemble_BP_mean = torch.mean(member_mean, dim=0)
        # \sigma^2 = 1/N * \sum_{i=1}^N (\sigma_i^2 + \mu_i^2) - \mu^2
        #          = 1/N * \sum_i \sigma_i^2 + 1/N * \sum_i (\mu_i - \mu)^2
        # The second form is a sum of non-negative terms, so rounding
        # cannot push it below zero and make the logarithm return NaN.
        ensemble_var = torch.mean(member_var, dim=0) + torch.mean(
            (member_mean - ensemble_BP_mean) ** 2, dim=0
        )
        ensemble_BP_log_var = torch.log(ensemble_var)

        ensemble_prediction = torch.concatenate(
            [ensemble_BP_mean, ensemble_BP_log_var], dim=-1
        )

        loss = self.loss_fn(ensemble_prediction, target)
        return {
            "loss": loss,
            "prediction": ensemble_prediction,
            "single_models_prediction": prediction,
            "target": target,
        }