Source code for qumphy.models.deepbeat

"""
File: qumphy/models/deepbeat.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Lightning model for DeepBeat data.
"""

import copy

import lightning as L
import numpy as np
import torch
import torch.nn.functional as F

import qumphy


class DeepBeatModule(L.LightningModule):
    """Lightning module for DeepBeat data.

    This module wraps a specific model architecture and defines the common
    training, validation, testing, prediction, optimizer, and scheduler logic.

    Parameters
    ----------
    net : torch.nn.Module
        Model architecture used for prediction.
    optimizer : dict
        Optimizer configuration dictionary with the keys ``"class_path"`` and
        ``"init_args"``. If ``init_args`` contains ``"params"``, it must be a
        list of parameter-group dicts; a string ``"params"`` entry of a group
        is evaluated as a Python expression with ``self`` in scope
        (e.g. ``"self.net.parameters()"``).
    output_activation : torch.nn.Module, optional
        Activation function applied to the raw model output. Defaults to a new
        ``torch.nn.Identity`` for every instance.
    prediction_activation : torch.nn.Module, optional
        Activation function applied to predictions returned by evaluation
        steps. Defaults to a new ``torch.nn.Identity`` for every instance.
    loss_fn : torch.nn.Module, optional
        Loss function used for training, validation, and testing. Defaults to
        a new ``torch.nn.CrossEntropyLoss`` for every instance.
    lr_scheduler : dict, optional
        Learning rate scheduler configuration dictionary with the keys
        ``"class_path"`` and ``"init_args"`` and, optionally, ``"config"``
        (the Lightning scheduler config the instantiated scheduler is put in).
    """

    def __init__(
        self,
        net,
        optimizer,
        output_activation=None,
        prediction_activation=None,
        loss_fn=None,
        lr_scheduler=None,
    ):
        super().__init__()
        # Defaults are built per instance: a module object used as a default
        # argument would be created once at import time and then shared (and
        # registered as a child module) by every DeepBeatModule.
        self.loss_fn = torch.nn.CrossEntropyLoss() if loss_fn is None else loss_fn
        self.net = net
        self.output_activation = (
            torch.nn.Identity() if output_activation is None else output_activation
        )
        self.prediction_activation = (
            torch.nn.Identity()
            if prediction_activation is None
            else prediction_activation
        )
        self.lr_scheduler = lr_scheduler
        self.optimizer_config = optimizer

    def forward(self, data):
        """Run a forward pass through the model.

        Parameters
        ----------
        data : torch.Tensor
            Input tensor passed to the model.

        Returns
        -------
        torch.Tensor
            Model prediction after applying the output activation.
        """
        return self.output_activation(self.net(data))

    def training_step(self, batch, batch_idx):
        """Run one training step.

        Parameters
        ----------
        batch : tuple
            Batch containing input data and target values.
        batch_idx : int
            Index of the current batch.

        Returns
        -------
        dict
            Dictionary containing the loss, prediction, and target.
        """
        return self._common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Run one validation step.

        Parameters
        ----------
        batch : tuple
            Batch containing input data and target values.
        batch_idx : int
            Index of the current batch.

        Returns
        -------
        dict
            Dictionary containing the loss, prediction, and target.
        """
        return self._common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        """Run one test step.

        Parameters
        ----------
        batch : tuple
            Batch containing input data and target values.
        batch_idx : int
            Index of the current batch.

        Returns
        -------
        dict
            Dictionary containing the loss, prediction, and target.
        """
        return self._common_step(batch, batch_idx, "test")

    def predict_step(self, batch, batch_idx):
        """Run one prediction step.

        Parameters
        ----------
        batch : tuple
            Batch containing input data and target values; the target is
            ignored.
        batch_idx : int
            Index of the current batch.

        Returns
        -------
        torch.Tensor
            Prediction after applying the prediction activation.
        """
        data, _ = batch
        return self.prediction_activation(self(data))

    def _common_step(self, batch, batch_idx, stage):
        """Run a common step for training, validation, or testing.

        Parameters
        ----------
        batch : tuple
            Batch containing input data and target values.
        batch_idx : int
            Index of the current batch (unused).
        stage : str
            Current stage name. Expected values are "train", "val", or "test"
            (unused).

        Returns
        -------
        dict
            Dictionary containing the loss, prediction, and target.
        """
        data, target = batch
        prediction = self(data)
        # The loss sees the full model output ...
        loss = self.loss_fn(prediction, target)
        # ... while only the first two output columns are reported as the
        # prediction (presumably the two class logits - TODO confirm).
        prediction = self.prediction_activation(prediction[:, :2])
        return {
            "loss": loss,
            "prediction": prediction,
            "target": target,
        }

    def configure_optimizers(self):
        """Configure the optimizer and optional learning rate scheduler.

        ``self.optimizer_config`` itself is never modified, so this method can
        safely be called more than once (for example by a learning-rate finder
        and then again by ``fit``).

        Returns
        -------
        dict
            Dictionary containing the optimizer and, if provided, the learning
            rate scheduler configuration.
        """
        init_args = copy.copy(self.optimizer_config["init_args"])
        if "params" in init_args:
            param_groups = []
            # Explicit loop (not a comprehension): eval() must see `self` in
            # its local scope, which a comprehension frame may not provide.
            for group in init_args["params"]:
                group = copy.copy(group)
                if isinstance(group["params"], str):
                    # NOTE(security): eval() executes arbitrary code taken from
                    # the configuration; only load trusted configuration files.
                    group["params"] = eval(group["params"])
                param_groups.append(group)
            init_args["params"] = param_groups
        else:
            init_args["params"] = self.parameters()

        self.optimizer = qumphy.misc.misc.instantiate_class_from_string(
            self.optimizer_config["class_path"], **init_args
        )
        if self.lr_scheduler is None:
            return {"optimizer": self.optimizer}

        self.set_lr_scheduler()
        return {
            "optimizer": self.optimizer,
            "lr_scheduler": self.lr_scheduler["config"],
        }

    def set_lr_scheduler(self):
        """Instantiate and set the learning rate scheduler.

        The current ``self.optimizer`` is written into
        ``self.lr_scheduler["init_args"]["optimizer"]`` and the instantiated
        scheduler into ``self.lr_scheduler["config"]["scheduler"]``. The
        ``"config"`` entry is created if it is missing.

        Returns
        -------
        None
            The function modifies the learning rate scheduler configuration
            in place.
        """
        self.lr_scheduler["init_args"]["optimizer"] = self.optimizer
        scheduler = qumphy.misc.misc.instantiate_class_from_string(
            self.lr_scheduler["class_path"],
            **self.lr_scheduler["init_args"],
        )
        self.lr_scheduler.setdefault("config", {})["scheduler"] = scheduler
class DeepBeatModule_MCD(DeepBeatModule):
    """Lightning module for DeepBeat data with Monte Carlo dropout evaluation.

    Takes a specific model architecture as input and, at test time, uses the
    sampling softmax from https://arxiv.org/abs/1703.04977: several dropout
    forward passes are made, and in each pass the logits are corrupted with
    Gaussian noise scaled by the model's own noise outputs.

    The network output is split column-wise: the first two columns are used
    as class logits and the remaining columns as their noise parameters.
    ``loss_fn`` is called as ``loss_fn(logits, noise, target)``.

    Parameters
    ----------
    *args
        Positional arguments passed to DeepBeatModule.
    MCD_samples : int
        Number of Monte Carlo dropout samples (must be >= 1).
    noise_samples : int
        Number of Gaussian noise samples used to corrupt the logits of each
        Monte Carlo dropout pass (must be >= 1). Defaults to 100.
    **kwargs
        Keyword arguments passed to DeepBeatModule.

    Raises
    ------
    ValueError
        If ``MCD_samples`` or ``noise_samples`` is smaller than 1.
    """

    def __init__(self, *args, MCD_samples=1, noise_samples=100, **kwargs):
        # The parent initializer must run: it initializes the torch.nn.Module
        # internals and the net / loss_fn / activation attributes used below.
        super().__init__(*args, **kwargs)
        if MCD_samples < 1:
            raise ValueError(f"MCD_samples must be at least 1, got {MCD_samples}")
        if noise_samples < 1:
            raise ValueError(
                f"noise_samples must be at least 1, got {noise_samples}"
            )
        self.MCD_samples = MCD_samples
        self.noise_samples = noise_samples

    @staticmethod
    def _entropy(probs):
        """Return the per-row entropy (natural log) of ``probs`` as a NumPy array."""
        p = probs.cpu().numpy()
        return -1.0 * np.sum(p * np.log(p + 1e-16), axis=-1)

    def _enable_dropout(self):
        """Put all dropout layers in training mode; return their previous modes.

        Lightning runs test steps with the model in eval mode (``model.eval()``),
        which disables ``torch.nn.Dropout`` layers. Monte Carlo dropout needs
        them active, otherwise every pass returns the same output.
        """
        states = []
        for module in self.modules():
            # _DropoutNd is the common base of Dropout, Dropout1d/2d/3d,
            # AlphaDropout and FeatureAlphaDropout.
            if isinstance(module, torch.nn.modules.dropout._DropoutNd):
                states.append((module, module.training))
                module.train()
        return states

    @staticmethod
    def _restore_dropout(states):
        """Restore the training flags recorded by ``_enable_dropout``."""
        for module, was_training in states:
            module.train(was_training)

    def _noise_corrupted_probs(self, logit, noise):
        """Average the softmax of ``noise_samples`` noise-corrupted logit copies.

        Each logit is corrupted as ``logit + noise * epsilon`` with
        ``epsilon ~ N(0, 1)``, i.e. the model's noise output is used directly
        as the scale of the Gaussian.

        Returns
        -------
        torch.Tensor
            Probabilities of shape ``[batch, num_classes]``.
        """
        epsilon = torch.randn(
            logit.shape[0], self.noise_samples, logit.shape[-1], device=logit.device
        )
        # [batch, classes] -> [batch, 1, classes], broadcast over noise samples.
        corrupted = logit.unsqueeze(-2) + noise.unsqueeze(-2) * epsilon
        return F.softmax(corrupted, dim=-1).mean(dim=-2)

    def test_step(self, batch, batch_idx):
        """Run one Monte Carlo dropout test step.

        Parameters
        ----------
        batch : tuple
            Batch containing input data and target values.
        batch_idx : int
            Index of the current batch.

        Returns
        -------
        dict
            - ``"loss"``: loss of the last Monte Carlo forward pass.
            - ``"prediction"``: mean of the noise-corrupted probabilities over
              all passes (torch.Tensor).
            - ``"target"``: the batch targets.
            - ``"entropy_corrupted"``: entropy of ``"prediction"``, the total
              uncertainty (numpy array).
            - ``"H2"``: mean entropy of the per-pass noise-corrupted
              probabilities, the aleatoric uncertainty (numpy array).
            - ``"H3"``: mean entropy of the per-pass uncorrupted softmax
              probabilities (numpy array).
        """
        data, target = batch
        pass_probs = []  # noise-averaged probabilities, one entry per pass
        corrupted_entropies = []
        plain_entropies = []

        dropout_states = self._enable_dropout()
        try:
            for _ in range(self.MCD_samples):
                preds_all = self(data)
                logit = preds_all[:, :2]
                noise = preds_all[:, 2:]
                probs = self._noise_corrupted_probs(logit, noise)
                pass_probs.append(probs)
                corrupted_entropies.append(self._entropy(probs))
                plain_entropies.append(self._entropy(F.softmax(logit, dim=-1)))
        finally:
            self._restore_dropout(dropout_states)

        # H2 is the ALEATORIC UNCERTAINTY.
        H2 = np.mean(corrupted_entropies, axis=0)
        # H3 is kept for an NPL-related study; it is not used for D2 results.
        H3 = np.mean(plain_entropies, axis=0)
        # Mean over all noise-corrupted distributions: this is the prediction.
        MCD_mean = torch.stack(pass_probs, dim=-1).mean(dim=-1)
        # Its entropy is the TOTAL UNCERTAINTY.
        entropy_corrupted = self._entropy(MCD_mean)

        # The loss uses the outputs of the last Monte Carlo pass only
        # (MCD_samples >= 1 is guaranteed by __init__).
        loss = self.loss_fn(logit, noise, target)
        return {
            "loss": loss,
            "prediction": MCD_mean,
            "target": target,
            "entropy_corrupted": entropy_corrupted,
            "H2": H2,
            "H3": H3,
        }

    def _common_step(self, batch, batch_idx, stage):
        """Run a training/validation step for logit + noise model outputs.

        The model applies no softmax to its logits, so the returned prediction
        is the softmax of the first two output columns. The parent's
        ``prediction_activation`` is not applied here.

        Returns
        -------
        dict
            Dictionary containing the loss, prediction, and target.
        """
        data, target = batch
        prediction = self(data)
        outputs = prediction[:, :2]
        noises = prediction[:, 2:]
        loss = self.loss_fn(outputs, noises, target)
        return {
            "loss": loss,
            "prediction": F.softmax(outputs, dim=-1),
            "target": target,
        }
class DeepBeatEnsemble(DeepBeatModule):
    """DeepBeat deep-ensemble module.

    Stacks the outputs of several member models. At test time, it turns each
    member's logits into noise-corrupted class probabilities so that the total
    uncertainty, the aleatoric uncertainty (``H2``) and the entropy of the
    uncorrupted probabilities (``H3``) can be reported separately.

    Parameters
    ----------
    noise_samples : int
        Number of Gaussian noise draws used to corrupt every logit vector.
    **kwargs
        Forwarded unchanged to DeepBeatModule.
    """

    def __init__(self, noise_samples: int = 100, **kwargs):
        super().__init__(**kwargs)
        self.noise_samples = noise_samples
        # Standard normal used for logit corruption; its samples are created
        # on the CPU and moved to the logits' device in ``test_step``.
        self.Gauss = torch.distributions.normal.Normal(0.0, 1.0)

    def load_models(self, model_list):
        """Register the ensemble members.

        Parameters
        ----------
        model_list : list of torch.nn.Module
            Member models; stored as a ``torch.nn.ModuleList`` in
            ``self.models``.

        Returns
        -------
        None
        """
        self.models = torch.nn.ModuleList(model_list)

    def forward(self, x):
        """Evaluate every ensemble member on ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor given to each member.

        Returns
        -------
        torch.Tensor
            Member outputs stacked along a new leading axis, in the order of
            ``self.models``.
        """
        return torch.stack([member(x) for member in self.models])

    @staticmethod
    def _shannon_entropy(probs):
        """Entropy (natural log) over the last axis, with 1e-16 added inside the log."""
        return -1.0 * torch.sum(probs * torch.log(probs + 1e-16), dim=-1)

    def test_step(self, batch, batch_idx):
        """Evaluate the ensemble on one test batch.

        The stacked output has shape ``[model, batch, logits + noise]``. The
        first two columns of the last axis are used as logits and the remaining
        columns as their noise scale.

        Parameters
        ----------
        batch : tuple
            Batch containing input data and target values.
        batch_idx : int
            Index of the current batch.

        Returns
        -------
        dict
            Loss, ensemble prediction, per-member corrupted predictions,
            target, entropy of the ensemble prediction, mean corrupted entropy
            (``H2``) and mean uncorrupted entropy (``H3``) across members.
        """
        data, target = batch
        stacked = self(data)
        # NOTE(review): loss_fn receives the member-averaged output including
        # the noise columns - confirm this is the intended loss input.
        loss = self.loss_fn(stacked.mean(dim=0), target)

        member_logits = stacked[..., :2]
        member_sigma = stacked[..., 2:]

        # Insert a noise-sample axis just before the class axis.
        *leading_dims, n_classes = member_logits.shape
        draw_shape = (*leading_dims, self.noise_samples, n_classes)
        epsilon = self.Gauss.sample(draw_shape).to(member_logits.device)

        noisy_logits = member_logits.unsqueeze(-2) + member_sigma.unsqueeze(-2) * epsilon
        corrupted_probs = F.softmax(noisy_logits, dim=-1).mean(dim=-2)
        clean_probs = F.softmax(member_logits, dim=-1)

        ensemble_probs = corrupted_probs.mean(dim=0)

        return {
            "loss": loss,
            "prediction": ensemble_probs,
            "single_models_prediction": corrupted_probs,
            "target": target,
            "entropy_corrupted": self._shannon_entropy(ensemble_probs),
            "H2": self._shannon_entropy(corrupted_probs).mean(dim=0),
            "H3": self._shannon_entropy(clean_probs).mean(dim=0),
        }