qumphy.models.deepbeat module
File: qumphy/models/deepbeat.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Lightning model for DeepBeat data.
- class qumphy.models.deepbeat.DeepBeatEnsemble(noise_samples=100, **kwargs)
Bases: DeepBeatModule
DeepBeat ensemble module.
This module combines predictions from multiple models and estimates uncertainty using noise-corrupted logits.
- Parameters:
noise_samples (int) – Number of Gaussian noise samples used to corrupt logits.
**kwargs – Keyword arguments passed to DeepBeatModule.
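The noise-corruption idea can be sketched in the spirit of the sampling softmax from https://arxiv.org/abs/1703.04977: the logits are perturbed `noise_samples` times with Gaussian noise and each perturbed vector is passed through a softmax. This is a minimal sketch, not the qumphy implementation; the helper name `sampling_softmax`, the assumption that the network also outputs a per-class log-variance, and the tensor shapes are all hypothetical.

```python
import torch


def sampling_softmax(logits: torch.Tensor, log_var: torch.Tensor, noise_samples: int = 100) -> torch.Tensor:
    """Softmax of Gaussian-corrupted logits, one probability vector per noise sample.

    Assumed (hypothetical) shapes: logits and log_var are [batch, classes];
    the result is [noise_samples, batch, classes].
    """
    std = torch.exp(0.5 * log_var)  # log-variance -> standard deviation
    # Draw independent standard-normal noise for every sample, row, and class.
    eps = torch.randn((noise_samples, *logits.shape), dtype=logits.dtype, device=logits.device)
    corrupted = logits.unsqueeze(0) + std.unsqueeze(0) * eps
    return torch.softmax(corrupted, dim=-1)


torch.manual_seed(0)
logits = torch.randn(4, 2)
log_var = torch.zeros(4, 2)  # unit variance for the demo
probs = sampling_softmax(logits, log_var, noise_samples=100)
mean_probs = probs.mean(dim=0)  # Monte Carlo estimate of the predictive distribution
```

Larger `noise_samples` values reduce the Monte Carlo error of `mean_probs` at the cost of more softmax evaluations.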
- forward(x)
Run a forward pass through all models in the ensemble.
- Parameters:
x (torch.Tensor) – Input tensor passed to each model.
- Returns:
Stacked predictions from all ensemble models.
- Return type:
torch.Tensor
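Stacking the member outputs can be sketched with plain PyTorch: every member receives the same input, and the outputs gain a new leading "model" axis. The function name `ensemble_forward` and the toy `nn.Linear` members are hypothetical stand-ins for the ensemble's real models.

```python
import torch
from torch import nn


def ensemble_forward(models: nn.ModuleList, x: torch.Tensor) -> torch.Tensor:
    # Each member sees the same input; outputs are stacked on a new leading
    # "model" axis, giving shape [n_models, batch, outputs].
    return torch.stack([model(x) for model in models], dim=0)


torch.manual_seed(0)
members = nn.ModuleList([nn.Linear(8, 3) for _ in range(5)])
x = torch.randn(16, 8)
stacked = ensemble_forward(members, x)
mean_prediction = stacked.mean(dim=0)  # simple average over ensemble members
```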
- load_models(model_list)
Load models into the ensemble.
- Parameters:
model_list (list) – List of PyTorch models used as ensemble members.
- Returns:
None; the models are stored in a torch.nn.ModuleList.
- Return type:
None
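Storing the members in a `torch.nn.ModuleList` rather than a plain Python list matters because only the former registers them as submodules, so their parameters appear in `parameters()`, `state_dict()`, and device moves. The two holder classes below are hypothetical and only illustrate that difference.

```python
from torch import nn


class PlainListHolder(nn.Module):
    def __init__(self, models):
        super().__init__()
        self.models = list(models)  # plain list: submodules are NOT registered


class ModuleListHolder(nn.Module):
    def __init__(self, models):
        super().__init__()
        self.models = nn.ModuleList(models)  # registered: visible to parameters() and state_dict()


members = [nn.Linear(4, 2) for _ in range(3)]
plain = PlainListHolder(members)
registered = ModuleListHolder(members)

n_plain = sum(p.numel() for p in plain.parameters())  # no parameters are found
n_registered = sum(p.numel() for p in registered.parameters())  # 3 * (4 * 2 + 2)
```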
- test_step(batch, batch_idx)
Run one ensemble test step. The stacked prediction has shape [model, batch_size, logits + noise].
- Parameters:
batch (tuple) – Batch containing input data and target values.
batch_idx (int) – Index of the current batch.
- Returns:
Dictionary containing the loss, ensemble prediction, individual model predictions, target, corrupted entropy, aleatoric uncertainty, and non-noise-corrupted entropy.
- Return type:
dict
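The three entropy-based quantities in the returned dictionary can be illustrated with one common decomposition: marginalize the logit noise per member, take the entropy of the ensemble mean as the total (corrupted) entropy, and take the average per-member entropy as the aleatoric part. This is a sketch under assumptions; the helper names, the input shapes, and these exact definitions are hypothetical, and qumphy may compute the terms differently.

```python
import torch


def entropy(p: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # Shannon entropy (nats) over the last (class) dimension.
    return -(p * torch.log(p + eps)).sum(dim=-1)


def uncertainty_terms(logits: torch.Tensor, corrupted_probs: torch.Tensor):
    """logits: [models, batch, classes]; corrupted_probs: [models, samples, batch, classes]."""
    clean_probs = torch.softmax(logits, dim=-1)
    # Entropy of the ensemble mean without any noise corruption.
    clean_entropy = entropy(clean_probs.mean(dim=0))
    # Marginalize the logit noise separately for each member: [models, batch, classes].
    noisy_member_probs = corrupted_probs.mean(dim=1)
    # Total predictive entropy of the noise-corrupted ensemble.
    corrupted_entropy = entropy(noisy_member_probs.mean(dim=0))
    # Expected entropy of the members' noise-marginalized predictions (aleatoric part).
    aleatoric = entropy(noisy_member_probs).mean(dim=0)
    return corrupted_entropy, aleatoric, clean_entropy


torch.manual_seed(0)
logits = torch.randn(5, 4, 2)
noise = torch.randn(5, 100, 4, 2)
corrupted = torch.softmax(logits.unsqueeze(1) + noise, dim=-1)
total, aleatoric, clean = uncertainty_terms(logits, corrupted)
```

Under these definitions the corrupted entropy is never smaller than the aleatoric term (entropy is concave), and their difference is the mutual information between the prediction and the ensemble member, often read as the epistemic part.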
- class qumphy.models.deepbeat.DeepBeatModule(*args: Any, **kwargs: Any)
Bases: LightningModule
Lightning module for DeepBeat data.
This module wraps a specific model architecture and defines the common training, validation, testing, prediction, optimizer, and scheduler logic.
- Parameters:
net (torch.nn.Module) – Model architecture used for prediction.
optimizer (dict) – Optimizer configuration dictionary.
output_activation (torch.nn.Module) – Activation function applied to the raw model output.
prediction_activation (torch.nn.Module) – Activation function applied to predictions returned by evaluation steps.
loss_fn (torch.nn.Module) – Loss function used for training, validation, and testing.
lr_scheduler (dict) – Learning rate scheduler configuration dictionary.
- configure_optimizers()
Configure the optimizer and optional learning rate scheduler.
- Returns:
Dictionary containing the optimizer and, if provided, the learning rate scheduler configuration.
- Return type:
dict
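Lightning accepts a dictionary from `configure_optimizers()` with an `"optimizer"` key and an optional `"lr_scheduler"` entry. The helper below is a hypothetical, Lightning-free sketch of building that dictionary; the choice of Adam and StepLR and their settings are illustrative assumptions, not the qumphy configuration.

```python
import torch
from torch import nn


def build_optimizer_config(model: nn.Module, lr: float = 1e-3, use_scheduler: bool = True) -> dict:
    # Hypothetical helper mirroring the dictionary form Lightning accepts
    # from configure_optimizers(); qumphy's own config keys are not shown.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    config = {"optimizer": optimizer}
    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
        config["lr_scheduler"] = {
            "scheduler": scheduler,
            "interval": "epoch",  # step the scheduler once per epoch
            "frequency": 1,
        }
    return config


model = nn.Linear(8, 1)
with_sched = build_optimizer_config(model)
without_sched = build_optimizer_config(model, use_scheduler=False)
```

Omitting the `"lr_scheduler"` key, as in `without_sched`, matches the "if provided" behaviour described above.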
- forward(data)
Run a forward pass through the model.
- Parameters:
data (torch.Tensor) – Input tensor passed to the model.
- Returns:
Model prediction after applying the output activation.
- Return type:
torch.Tensor
- predict_step(batch, batch_idx)
Run one prediction step.
- Parameters:
batch (tuple) – Batch containing input data and target values.
batch_idx (int) – Index of the current batch.
- Returns:
Prediction after applying the prediction activation.
- Return type:
torch.Tensor
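The split between `output_activation` and `prediction_activation` can be illustrated with a hypothetical stand-in module: an identity output activation keeps raw logits for a loss such as `nn.CrossEntropyLoss`, while a softmax prediction activation turns them into class probabilities only when predictions are returned. The class `TinyModule` and the chosen activations are assumptions for illustration.

```python
import torch
from torch import nn


class TinyModule(nn.Module):
    """Hypothetical stand-in showing where the two activations apply."""

    def __init__(self, net, output_activation, prediction_activation, loss_fn):
        super().__init__()
        self.net = net
        self.output_activation = output_activation
        self.prediction_activation = prediction_activation
        self.loss_fn = loss_fn

    def forward(self, data):
        # Raw network output passed through output_activation.
        return self.output_activation(self.net(data))

    def predict_step(self, batch, batch_idx):
        data, _target = batch  # target is not needed at prediction time
        return self.prediction_activation(self(data))


torch.manual_seed(0)
module = TinyModule(
    net=nn.Linear(8, 2),
    output_activation=nn.Identity(),  # keep logits for CrossEntropyLoss
    prediction_activation=nn.Softmax(dim=-1),  # logits -> class probabilities
    loss_fn=nn.CrossEntropyLoss(),
)
data, target = torch.randn(4, 8), torch.randint(0, 2, (4,))
loss = module.loss_fn(module(data), target)
probs = module.predict_step((data, target), batch_idx=0)
```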
- set_lr_scheduler()
Instantiate and set the learning rate scheduler.
- Returns:
None; the learning rate scheduler configuration is modified in place.
- Return type:
None
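In-place instantiation of a scheduler configuration can be sketched as follows, assuming (hypothetically) that the configuration dictionary holds a scheduler factory under a `"scheduler"` key, for example a `functools.partial`, which is replaced by an instance bound to the optimizer. The function `instantiate_scheduler_in_place` and the dictionary layout are illustrative, not the actual method body.

```python
import functools

import torch
from torch import nn


def instantiate_scheduler_in_place(lr_scheduler: dict, optimizer: torch.optim.Optimizer) -> None:
    # Hypothetical layout: "scheduler" holds a factory that still needs the optimizer.
    factory = lr_scheduler["scheduler"]
    lr_scheduler["scheduler"] = factory(optimizer)  # mutate the dict; nothing is returned


model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_config = {
    "scheduler": functools.partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5),
    "interval": "epoch",
}
result = instantiate_scheduler_in_place(lr_config, optimizer)
```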
- test_step(batch, batch_idx)
Run one test step.
- Parameters:
batch (tuple) – Batch containing input data and target values.
batch_idx (int) – Index of the current batch.
- Returns:
Dictionary containing the loss, prediction, and target.
- Return type:
dict
- class qumphy.models.deepbeat.DeepBeatModule_MCD(*args: Any, **kwargs: Any)
Bases: DeepBeatModule
Lightning parent module for DeepBeat data. It takes a specific model architecture as input and uses the sampling softmax from https://arxiv.org/abs/1703.04977 for evaluation.
- Parameters:
*args – Positional arguments passed to DeepBeatModule.
MCD_samples (int) – Number of Monte Carlo dropout samples.
**kwargs – Keyword arguments passed to DeepBeatModule.
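Monte Carlo dropout keeps dropout active at test time and repeats the forward pass, so each pass samples a different dropout mask. The sketch below enables only the dropout layers while the rest of the model stays in eval mode; the helper names and `mcd_samples` (standing in for `MCD_samples`) are hypothetical, and qumphy may instead switch the whole model to training mode.

```python
import torch
from torch import nn


def enable_dropout(model: nn.Module) -> None:
    # Put only dropout layers in training mode so they keep sampling masks;
    # other layers (e.g. BatchNorm) stay in eval mode.
    for module in model.modules():
        if isinstance(module, (nn.Dropout, nn.Dropout1d, nn.Dropout2d, nn.Dropout3d)):
            module.train()


@torch.no_grad()
def mc_dropout_predict(model: nn.Module, x: torch.Tensor, mcd_samples: int = 20) -> torch.Tensor:
    model.eval()
    enable_dropout(model)
    # Each pass draws a fresh dropout mask: shape [mcd_samples, batch, outputs].
    return torch.stack([model(x) for _ in range(mcd_samples)], dim=0)


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(32, 2))
samples = mc_dropout_predict(net, torch.randn(4, 8), mcd_samples=20)
probs = torch.softmax(samples, dim=-1).mean(dim=0)  # average over dropout samples
```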
- test_step(batch, batch_idx)[source]
Run one Monte Carlo dropout test step.
- Parameters:
batch (tuple) – Batch containing input data and target values.
batch_idx (int) – Index of the current batch.
- Returns:
Dictionary containing the loss, prediction, target, corrupted entropy, aleatoric uncertainty, and non-noise-corrupted entropy.
- Return type:
dict