"""
File: qumphy/models/deepbeat.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Lightning model for DeepBeat data.
"""
import torch
import lightning as L
import torch.nn.functional as F
import numpy as np
import qumphy
class DeepBeatModule(L.LightningModule):
    """Lightning module for DeepBeat data.

    This module wraps a specific model architecture and defines the common
    training, validation, testing and prediction steps. It also stores the
    optimizer configuration and can instantiate a learning rate scheduler
    from its configuration dictionary.

    Parameters
    ----------
    net : torch.nn.Module
        Model architecture used for prediction.
    optimizer : dict
        Optimizer configuration dictionary (stored as ``optimizer_config``).
    output_activation : torch.nn.Module
        Activation function applied to the raw model output.
    prediction_activation : torch.nn.Module
        Activation function applied to predictions returned by evaluation steps.
    loss_fn : torch.nn.Module
        Loss function used for training, validation, and testing.
    lr_scheduler : dict, optional
        Learning rate scheduler configuration dictionary. ``set_lr_scheduler``
        reads the keys ``"class_path"`` and ``"init_args"`` and writes the
        instantiated scheduler to ``["config"]["scheduler"]``.
    """

    # NOTE: the default activation / loss modules below are created once, when
    # the function is defined, and are therefore shared by every instance that
    # relies on the defaults. The defaults are kept unchanged so that existing
    # callers and configuration files keep working.
    def __init__(
        self,
        net,
        optimizer,
        output_activation=torch.nn.Identity(),
        prediction_activation=torch.nn.Identity(),
        loss_fn=torch.nn.CrossEntropyLoss(),
        lr_scheduler=None,
    ):
        super().__init__()
        # Sub-modules are registered in the original order so that the order
        # of children / state-dict entries does not change.
        self.loss_fn = loss_fn
        self.net = net
        self.output_activation = output_activation
        self.prediction_activation = prediction_activation
        self.lr_scheduler = lr_scheduler
        self.optimizer_config = optimizer

    def forward(self, data):
        """Run a forward pass through the model.

        Parameters
        ----------
        data : torch.Tensor
            Input tensor passed to the model.

        Returns
        -------
        torch.Tensor
            Model prediction after applying the output activation.
        """
        return self.output_activation(self.net(data))

    def training_step(self, batch, batch_idx):
        """Run one training step.

        Parameters
        ----------
        batch : tuple
            Batch containing input data and target values.
        batch_idx : int
            Index of the current batch.

        Returns
        -------
        dict
            Dictionary containing the loss, prediction, and target.
        """
        return self._common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Run one validation step.

        Parameters
        ----------
        batch : tuple
            Batch containing input data and target values.
        batch_idx : int
            Index of the current batch.

        Returns
        -------
        dict
            Dictionary containing the loss, prediction, and target.
        """
        return self._common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        """Run one test step.

        Parameters
        ----------
        batch : tuple
            Batch containing input data and target values.
        batch_idx : int
            Index of the current batch.

        Returns
        -------
        dict
            Dictionary containing the loss, prediction, and target.
        """
        return self._common_step(batch, batch_idx, "test")

    def predict_step(self, batch, batch_idx):
        """Run one prediction step.

        Parameters
        ----------
        batch : tuple
            Batch containing input data and target values. The target is
            unpacked but not used.
        batch_idx : int
            Index of the current batch.

        Returns
        -------
        torch.Tensor
            Prediction after applying the prediction activation.
        """
        data, _ = batch
        return self.prediction_activation(self(data))

    def _common_step(self, batch, batch_idx, stage):
        """Run a common step for training, validation, or testing.

        Parameters
        ----------
        batch : tuple
            Batch containing input data and target values.
        batch_idx : int
            Index of the current batch.
        stage : str
            Current stage name. Expected values are "train", "val", or "test".
            Currently not used by the computation.

        Returns
        -------
        dict
            Dictionary containing the loss, prediction, and target.
        """
        data, target = batch
        prediction = self(data)
        # The loss sees the full model output; only the first two columns are
        # passed through the prediction activation and returned.
        loss = self.loss_fn(prediction, target)
        prediction = self.prediction_activation(prediction[:, :2])
        return {
            "loss": loss,
            "prediction": prediction,
            "target": target,
        }

    def set_lr_scheduler(self):
        """Instantiate and set the learning rate scheduler.

        Does nothing when no scheduler configuration was provided
        (``lr_scheduler is None``, which is the constructor default).

        Returns
        -------
        None
            The function modifies the learning rate scheduler configuration
            in place.
        """
        if self.lr_scheduler is None:
            # Without this guard the subscripts below raise a TypeError on the
            # default configuration.
            return
        # NOTE(review): ``self.optimizer`` is not assigned anywhere in this
        # class (only ``optimizer_config`` is) - presumably it is set by the
        # caller before this method runs; confirm.
        self.lr_scheduler["init_args"]["optimizer"] = self.optimizer
        scheduler = qumphy.misc.misc.instantiate_class_from_string(
            self.lr_scheduler["class_path"],
            **self.lr_scheduler["init_args"],
        )
        self.lr_scheduler["config"]["scheduler"] = scheduler
class DeepBeatModule_MCD(DeepBeatModule):
    """Lightning module for DeepBeat data with Monte Carlo dropout evaluation.

    Takes a specific model architecture as input and uses sampling softmax
    for evaluation, following https://arxiv.org/abs/1703.04977.

    The first two columns of the model output are used as class logits and
    the remaining columns as the per-logit noise scale. ``loss_fn`` is called
    as ``loss_fn(logits, noise, target)``.

    Parameters
    ----------
    *args
        Positional arguments passed to DeepBeatModule.
    MCD_samples : int
        Number of Monte Carlo dropout samples (stochastic forward passes)
        drawn in ``test_step``. Must be at least 1 for testing.
    noise_samples : int
        Number of Gaussian noise samples used to corrupt the logits of each
        forward pass. Defaults to 100, the value that was previously
        hard-coded.
    **kwargs
        Keyword arguments passed to DeepBeatModule.
    """

    def __init__(self, *args, MCD_samples=1, noise_samples=100, **kwargs):
        super().__init__(*args, **kwargs)
        self.MCD_samples = MCD_samples
        self.noise_samples = noise_samples

    @staticmethod
    def _entropy(probs):
        """Return the entropy over the last axis of ``probs`` as a NumPy array."""
        p = probs.detach().cpu().numpy()
        # The small constant keeps log() finite for zero probabilities.
        return -1.0 * np.sum(p * np.log(p + 1e-16), axis=-1)

    def test_step(self, batch, batch_idx):
        """Run one Monte Carlo dropout test step.

        Parameters
        ----------
        batch : tuple
            Batch containing input data and target values.
        batch_idx : int
            Index of the current batch.

        Returns
        -------
        dict
            Dictionary containing the loss, prediction, target, corrupted
            entropy, aleatoric uncertainty, and non-noise-corrupted entropy.

        Raises
        ------
        ValueError
            If ``MCD_samples`` is smaller than 1.
        """
        if self.MCD_samples < 1:
            raise ValueError(
                f"MCD_samples must be >= 1 for testing, got {self.MCD_samples}"
            )

        data, target = batch

        # Standard normal over the two logits; it does not depend on the
        # forward pass, so it is built once instead of once per sample.
        Gauss = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(2), torch.eye(2)
        )

        # averages over the softmaxed, noise-corrupted logits (one per pass)
        noise_corrupted_dropout_outputs = []
        # entropies of the entries of noise_corrupted_dropout_outputs
        entropies_noise_corrupted_probs = []
        # entropies of the plain softmaxed logits of each pass
        entropies_non_noise_corrupted_probs = []

        # NOTE(review): Lightning normally runs the test loop in eval mode,
        # which disables standard dropout layers - confirm that ``net`` keeps
        # dropout active, otherwise all passes are identical.
        for _ in range(self.MCD_samples):
            preds_all = self(data)
            logit = preds_all[:, :2]
            noise = preds_all[:, 2:]
            SM_logits = F.softmax(logit, dim=-1)

            # Each output logit is corrupted by ``noise_samples`` different
            # noises, scaled by the logit's corresponding noise parameter
            # that is also estimated by the model.
            epsilon = Gauss.sample([logit.shape[0], self.noise_samples]).to(
                noise.device
            )
            # Broadcast [batch x num_classes] against
            # [batch x noise_samples x num_classes] and corrupt the logits.
            x_t = logit[:, None, :] + noise[:, None, :] * epsilon
            x_t_sm = F.softmax(x_t, dim=-1)
            # aggregate the noisy versions of each probability with an average
            x_t_sm_avg = torch.mean(x_t_sm, dim=-2)
            noise_corrupted_dropout_outputs.append(x_t_sm_avg)

            # used to express uncertainty
            entropies_noise_corrupted_probs.append(self._entropy(x_t_sm_avg))
            # used to calculate a different type of entropy
            entropies_non_noise_corrupted_probs.append(self._entropy(SM_logits))

        # H2 is ALEATORIC UNCERTAINTY
        H2 = np.mean(entropies_noise_corrupted_probs, axis=0)
        # H3 is saved for NPL-related study, but not used for D2 results...
        H3 = np.mean(entropies_non_noise_corrupted_probs, axis=0)

        # take the mean over all the noise corrupted dists - this is the
        # prediction!
        MCD_mean = torch.stack(noise_corrupted_dropout_outputs, dim=-1).mean(dim=-1)
        # compute the entropy of the distribution - this is TOTAL UNCERTAINTY
        entropy_corrupted = self._entropy(MCD_mean)

        # NOTE: as before, the loss is evaluated on the output of the last
        # dropout pass only.
        loss = self.loss_fn(preds_all[:, :2], preds_all[:, 2:], target)
        return {
            "loss": loss,
            "prediction": MCD_mean,
            "target": target,
            "entropy_corrupted": entropy_corrupted,
            "H2": H2,
            "H3": H3,
        }

    def _common_step(self, batch, batch_idx, stage):
        """Run a common step for training and validation.

        The model itself applies no softmax to the logits, so it is applied
        here to the first two output columns before they are returned.

        Parameters
        ----------
        batch : tuple
            Batch containing input data and target values.
        batch_idx : int
            Index of the current batch.
        stage : str
            Current stage name. Currently not used by the computation.

        Returns
        -------
        dict
            Dictionary containing the loss, softmaxed prediction, and target.
        """
        data, target = batch
        prediction = self(data)
        outputs = prediction[:, :2]
        noises = prediction[:, 2:]
        loss = self.loss_fn(outputs, noises, target)
        softmax_preds = F.softmax(outputs, dim=-1)
        return {
            "loss": loss,
            "prediction": softmax_preds,
            "target": target,
        }
class DeepBeatEnsemble(DeepBeatModule):
    """DeepBeat ensemble module.

    Combines the outputs of several models and estimates uncertainty from
    logits that are corrupted with Gaussian noise.

    Parameters
    ----------
    noise_samples : int
        Number of Gaussian noise samples used to corrupt logits.
    **kwargs
        Keyword arguments passed to DeepBeatModule.
    """

    def __init__(self, noise_samples: int = 100, **kwargs):
        super().__init__(**kwargs)
        self.noise_samples = noise_samples
        self.Gauss = torch.distributions.normal.Normal(0.0, 1.0)

    @staticmethod
    def _shannon_entropy(probs):
        """Return the entropy of ``probs`` taken over its last dimension."""
        return -1.0 * torch.sum(probs * torch.log(probs + 1e-16), dim=-1)

    def load_models(self, model_list):
        """Load models into the ensemble.

        Parameters
        ----------
        model_list : list
            List of PyTorch models used as ensemble members.

        Returns
        -------
        None
            The models are stored as a ModuleList in ``self.models``.
        """
        self.models = torch.nn.ModuleList(model_list)

    def forward(self, x):
        """Run a forward pass through all models in the ensemble.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor passed to each model.

        Returns
        -------
        torch.Tensor
            Predictions of all ensemble members, stacked along a new leading
            dimension.
        """
        return torch.stack([member(x) for member in self.models])

    def test_step(self, batch, batch_idx):
        """Run one ensemble test step.

        The stacked prediction has shape [model, batchsize, logits+noise]:
        the first two entries of the last dimension are used as logits, the
        remaining ones as noise scale.

        Parameters
        ----------
        batch : tuple
            Batch containing input data and target values.
        batch_idx : int
            Index of the current batch.

        Returns
        -------
        dict
            Dictionary containing the loss, ensemble prediction, individual
            model predictions, target, corrupted entropy, aleatoric
            uncertainty, and non-noise-corrupted entropy.
        """
        data, target = batch
        member_outputs = self(data)
        # The loss is evaluated on the member-averaged raw output.
        loss = self.loss_fn(member_outputs.mean(dim=0), target)

        raw_logits = member_outputs[..., :2]
        noise_scale = member_outputs[..., 2:]

        # Insert a noise-sample axis just before the class axis.
        sample_shape = raw_logits.shape
        sample_shape = sample_shape[:-1] + (self.noise_samples,) + sample_shape[-1:]
        epsilon = self.Gauss.sample(sample_shape).to(raw_logits.device)

        # Corrupt every logit with its noise, then average the resulting
        # class probabilities over the noise samples.
        perturbed = raw_logits.unsqueeze(-2) + noise_scale.unsqueeze(-2) * epsilon
        member_probs = F.softmax(perturbed, dim=-1).mean(dim=-2)
        clean_probs = F.softmax(raw_logits, dim=-1)

        per_member_entropy_corrupted = self._shannon_entropy(member_probs)
        per_member_entropy_clean = self._shannon_entropy(clean_probs)

        # Ensemble prediction: mean over members of the corrupted
        # probabilities, together with its entropy.
        ensemble_probs = torch.mean(member_probs, dim=0)
        ensemble_entropy = self._shannon_entropy(ensemble_probs)

        return {
            "loss": loss,
            "prediction": ensemble_probs,
            "single_models_prediction": member_probs,
            "target": target,
            "entropy_corrupted": ensemble_entropy,
            "H2": torch.mean(per_member_entropy_corrupted, dim=0),
            "H3": torch.mean(per_member_entropy_clean, dim=0),
        }