qumphy.models.deepbeat module

File: qumphy/models/deepbeat.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
GitLab: https://gitlab.com/qumphy
Description: Lightning model for DeepBeat data.

class qumphy.models.deepbeat.DeepBeatEnsemble(noise_samples=100, **kwargs)

Bases: DeepBeatModule

DeepBeat ensemble module.

This module combines predictions from multiple models and estimates uncertainty using noise-corrupted logits.

Parameters:
  • noise_samples (int) – Number of Gaussian noise samples used to corrupt the logits (default: 100).

  • **kwargs – Keyword arguments passed to DeepBeatModule.
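The idea of noise-corrupted logits can be illustrated with a minimal NumPy sketch in the spirit of the sampling softmax from https://arxiv.org/abs/1703.04977 (the paper cited under DeepBeatModule_MCD). Gaussian noise is added to the logits `noise_samples` times, each draw goes through a softmax, and the resulting probabilities are averaged. The sketch assumes the network also predicts a log standard deviation `log_std` per sample. How qumphy actually parameterizes the noise is not specified here, and `noise_corrupted_probs` is a hypothetical helper, not part of the qumphy API.

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the max for numerical stability before exponentiating.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def noise_corrupted_probs(logits, log_std, noise_samples=100, rng=None):
    """Hypothetical helper: average softmax over Gaussian-corrupted logits.

    logits:  [batch, classes]
    log_std: [batch, 1] or [batch, classes], assumed log standard deviation
    returns: [batch, classes] averaged class probabilities
    """
    rng = np.random.default_rng() if rng is None else rng
    std = np.exp(log_std)
    # One standard-normal draw per noise sample, batch element and class.
    eps = rng.standard_normal((noise_samples,) + logits.shape)
    corrupted = logits[None] + std[None] * eps  # [noise_samples, batch, classes]
    return softmax(corrupted).mean(axis=0)


rng = np.random.default_rng(0)
logits = np.array([[2.0, -1.0], [0.1, 0.0]])
probs = noise_corrupted_probs(logits, np.full((2, 1), np.log(0.5)), noise_samples=50, rng=rng)
```

With a very small `log_std` the result reduces to the ordinary softmax of the logits; larger noise typically flattens the averaged probabilities, which raises their entropy.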

forward(x)

Run a forward pass through all models in the ensemble.

Parameters:

x (torch.Tensor) – Input tensor passed to each model.

Returns:

Stacked predictions from all ensemble models.

Return type:

torch.Tensor
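A minimal PyTorch sketch of the stacking behaviour, assuming every member maps [batch, features] to [batch, outputs] and that the outputs are stacked along a new leading model axis, which matches the [model, batchsize, ...] layout mentioned under test_step. `StackedEnsemble` is a hypothetical stand-in, not the qumphy class.

```python
import torch
from torch import nn


class StackedEnsemble(nn.Module):
    """Hypothetical stand-in for the ensemble forward pass."""

    def __init__(self, members):
        super().__init__()
        self.models = nn.ModuleList(members)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Every member sees the same input; outputs are stacked on a new
        # leading "model" axis: [n_models, batch, outputs].
        return torch.stack([model(x) for model in self.models], dim=0)


torch.manual_seed(0)
ensemble = StackedEnsemble([nn.Linear(8, 3) for _ in range(5)])
x = torch.randn(4, 8)
out = ensemble(x)
print(out.shape)  # torch.Size([5, 4, 3])
```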

load_models(model_list)

Load models into the ensemble.

Parameters:

model_list (list) – List of PyTorch models used as ensemble members.

Returns:

Nothing; the models are stored on the module as a torch.nn.ModuleList.

Return type:

None
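Storing the members in a torch.nn.ModuleList, rather than a plain Python list, registers them as submodules, so their parameters appear in parameters() and state_dict() and follow calls such as .to(device) and .eval(). The sketch below illustrates the difference with hypothetical linear members; it does not reproduce load_models itself.

```python
import torch
from torch import nn


def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


members = [nn.Linear(4, 2) for _ in range(3)]

as_list = nn.Module()
as_list.models = members  # plain Python list: parameters are NOT registered

as_module_list = nn.Module()
as_module_list.models = nn.ModuleList(members)  # registered as submodules

print(count_params(as_list))         # 0
print(count_params(as_module_list))  # 30 = 3 * (4 * 2 + 2)
```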

test_step(batch, batch_idx)

Run one ensemble test step. The stacked prediction has shape [model, batch_size, logits + noise].

Parameters:
  • batch (tuple) – Batch containing input data and target values.

  • batch_idx (int) – Index of the current batch.

Returns:

Dictionary containing the loss, ensemble prediction, individual model predictions, target, corrupted entropy, aleatoric uncertainty, and non-noise-corrupted entropy.

Return type:

dict
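One way the uncertainty entries returned by test_step could be computed from a prediction of shape [model, batchsize, logits+noise] is sketched below in NumPy. The layout of the last axis (C logits followed by one log standard deviation) and the definitions (entropy of the ensemble-averaged clean softmax, entropy of the averaged noise-corrupted softmax, and the mean entropy of the individual noisy predictions as the aleatoric term) are assumptions chosen for illustration; the formulas in qumphy may differ, and `ensemble_uncertainty` is hypothetical.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def entropy(p, axis=-1, eps=1e-12):
    # Shannon entropy in nats; eps guards against log(0).
    return -(p * np.log(p + eps)).sum(axis=axis)


def ensemble_uncertainty(prediction, noise_samples=100, rng=None):
    """Hypothetical decomposition for a prediction of shape [model, batch, C + 1].

    Assumes the first C entries of the last axis are logits and the final
    entry is a log standard deviation for the logit noise.
    """
    rng = np.random.default_rng() if rng is None else rng
    logits, log_std = prediction[..., :-1], prediction[..., -1:]

    # Ensemble mean of the clean softmax -> non-noise-corrupted entropy.
    clean_probs = softmax(logits).mean(axis=0)                  # [batch, C]

    # Gaussian-corrupted logits for every model: [S, model, batch, C].
    eps = rng.standard_normal((noise_samples,) + logits.shape)
    noisy_probs = softmax(logits[None] + np.exp(log_std)[None] * eps)

    # Average over noise samples and models -> noise-corrupted prediction.
    corrupted_probs = noisy_probs.mean(axis=(0, 1))             # [batch, C]

    return {
        "prediction": corrupted_probs,
        "corrupted_entropy": entropy(corrupted_probs),
        # Expected entropy of the individual noisy predictions.
        "aleatoric": entropy(noisy_probs).mean(axis=(0, 1)),
        "non_noise_entropy": entropy(clean_probs),
    }


rng = np.random.default_rng(1)
pred = rng.standard_normal((4, 6, 3))  # 4 models, batch 6, 2 logits + 1 log std
result = ensemble_uncertainty(pred, noise_samples=20, rng=rng)
```

Because entropy is concave, the corrupted entropy is never smaller than the aleatoric term under these definitions; their difference reflects disagreement between the individual noisy predictions.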

class qumphy.models.deepbeat.DeepBeatModule(*args: Any, **kwargs: Any)

Bases: LightningModule

Lightning module for DeepBeat data.

This module wraps a specific model architecture and defines the common training, validation, testing, prediction, optimizer, and scheduler logic.

Parameters:
  • net (torch.nn.Module) – Model architecture used for prediction.

  • optimizer (dict) – Optimizer configuration dictionary.

  • output_activation (torch.nn.Module) – Activation function applied to the raw model output.

  • prediction_activation (torch.nn.Module) – Activation function applied to predictions returned by evaluation steps.

  • loss_fn (torch.nn.Module) – Loss function used for training, validation, and testing.

  • lr_scheduler (dict) – Learning rate scheduler configuration dictionary.
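To show how these parameters fit together, here is a hedged plain-PyTorch sketch of one configuration for a two-class problem. It assumes forward returns output_activation(net(data)), that loss_fn is applied to that output, and that prediction_activation is applied on top of it in the evaluation steps; the layer sizes are arbitrary. Because torch.nn.CrossEntropyLoss expects unnormalized logits, an identity output activation combined with a softmax prediction activation avoids applying the softmax twice.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical configuration for a 2-class problem.
net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
output_activation = nn.Identity()           # keep raw logits
loss_fn = nn.CrossEntropyLoss()             # applies log-softmax internally
prediction_activation = nn.Softmax(dim=-1)  # probabilities for evaluation

data = torch.randn(8, 16)
target = torch.randint(0, 2, (8,))

output = output_activation(net(data))       # what forward() would return
loss = loss_fn(output, target)              # training/validation/test loss
prediction = prediction_activation(output)  # what evaluation steps report
```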

configure_optimizers()

Configure the optimizer and optional learning rate scheduler.

Returns:

Dictionary containing the optimizer and, if provided, the learning rate scheduler configuration.

Return type:

dict
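Lightning accepts a dictionary from configure_optimizers with an "optimizer" key and an optional "lr_scheduler" entry, which may itself be a dictionary holding the scheduler under "scheduler" plus options such as "interval". The sketch below builds such a dictionary. The optimizer config format ({"name": ..., plus keyword arguments}) and the `scheduler_factory` argument are assumptions, since the structure of the qumphy config dictionaries is not documented here.

```python
import torch
from torch import nn


def configure_optimizers_sketch(module, optimizer_cfg, scheduler_factory=None):
    """Hypothetical version of configure_optimizers."""
    cfg = dict(optimizer_cfg)  # copy so the caller's config is left untouched
    optimizer_cls = getattr(torch.optim, cfg.pop("name"))
    optimizer = optimizer_cls(module.parameters(), **cfg)
    if scheduler_factory is None:
        return {"optimizer": optimizer}
    return {
        "optimizer": optimizer,
        # Lightning reads the scheduler from the "scheduler" key.
        "lr_scheduler": {"scheduler": scheduler_factory(optimizer), "interval": "epoch"},
    }


net = nn.Linear(4, 1)
config = configure_optimizers_sketch(
    net,
    {"name": "Adam", "lr": 1e-3},
    lambda opt: torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5),
)
```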

forward(data)

Run a forward pass through the model.

Parameters:

data (torch.Tensor) – Input tensor passed to the model.

Returns:

Model prediction after applying the output activation.

Return type:

torch.Tensor

predict_step(batch, batch_idx)

Run one prediction step.

Parameters:
  • batch (tuple) – Batch containing input data and target values.

  • batch_idx (int) – Index of the current batch.

Returns:

Prediction after applying the prediction activation.

Return type:

torch.Tensor
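Lightning's Trainer.predict calls predict_step once per batch and collects the results. The loop below imitates that by hand with a hypothetical linear model, assuming an identity output activation and a softmax prediction activation; the target in each (data, target) batch is unpacked but unused.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
net = nn.Linear(10, 2)
prediction_activation = nn.Softmax(dim=-1)

dataset = TensorDataset(torch.randn(20, 10), torch.randint(0, 2, (20,)))
loader = DataLoader(dataset, batch_size=8)


def predict_step_sketch(batch, batch_idx):
    data, _target = batch  # the target is unpacked but not needed
    return prediction_activation(net(data))


net.eval()
with torch.no_grad():
    predictions = torch.cat([predict_step_sketch(b, i) for i, b in enumerate(loader)])
```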

set_lr_scheduler()

Instantiate and set the learning rate scheduler.

Returns:

Nothing; the learning rate scheduler configuration is modified in place.

Return type:

None
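A scheduler usually cannot be instantiated up front because it needs the optimizer. The sketch below shows one way an in-place update like set_lr_scheduler could work: the config holds a functools.partial factory under "scheduler", which is replaced by the instantiated scheduler once an optimizer exists. The config layout and `set_lr_scheduler_sketch` are hypothetical.

```python
import functools

import torch
from torch import nn

net = nn.Linear(4, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

# Hypothetical config: "scheduler" holds a partially applied class until
# an optimizer exists to bind it to.
lr_scheduler_cfg = {
    "scheduler": functools.partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5),
    "interval": "epoch",
}


def set_lr_scheduler_sketch(cfg, optimizer):
    """Replace the factory with a scheduler instance, mutating cfg in place."""
    cfg["scheduler"] = cfg["scheduler"](optimizer)


result = set_lr_scheduler_sketch(lr_scheduler_cfg, optimizer)  # returns None
```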

test_step(batch, batch_idx)

Run one test step.

Parameters:
  • batch (tuple) – Batch containing input data and target values.

  • batch_idx (int) – Index of the current batch.

Returns:

Dictionary containing the loss, prediction, and target.

Return type:

dict

training_step(batch, batch_idx)

Run one training step.

Parameters:
  • batch (tuple) – Batch containing input data and target values.

  • batch_idx (int) – Index of the current batch.

Returns:

Dictionary containing the loss, prediction, and target.

Return type:

dict

validation_step(batch, batch_idx)

Run one validation step.

Parameters:
  • batch (tuple) – Batch containing input data and target values.

  • batch_idx (int) – Index of the current batch.

Returns:

Dictionary containing the loss, prediction, and target.

Return type:

dict
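Since training_step, validation_step and test_step all return the same three entries, a common pattern is to route them through one shared helper. The sketch below assumes that pattern; the helper `shared_step` and the model are hypothetical. In Lightning's automatic optimization, a dictionary returned from training_step must contain a "loss" key, and that tensor is what gets backpropagated.

```python
import torch
from torch import nn

torch.manual_seed(0)
net = nn.Linear(6, 3)
loss_fn = nn.CrossEntropyLoss()


def shared_step(batch):
    """Hypothetical helper shared by training_step, validation_step and test_step."""
    data, target = batch
    prediction = net(data)
    loss = loss_fn(prediction, target)
    return {"loss": loss, "prediction": prediction, "target": target}


out = shared_step((torch.randn(5, 6), torch.randint(0, 3, (5,))))
out["loss"].backward()  # Lightning does this automatically for training_step
```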

class qumphy.models.deepbeat.DeepBeatModule_MCD(*args: Any, **kwargs: Any)

Bases: DeepBeatModule

Lightning module for DeepBeat data that uses Monte Carlo dropout. It takes a specific model architecture as input and uses the sampling softmax from https://arxiv.org/abs/1703.04977 for evaluation.

Parameters:
  • *args – Positional arguments passed to DeepBeatModule.

  • MCD_samples (int) – Number of Monte Carlo dropout samples.

  • **kwargs – Keyword arguments passed to DeepBeatModule.
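Monte Carlo dropout keeps dropout active at evaluation time and collects several stochastic forward passes. A minimal sketch, assuming the network contains torch.nn.Dropout layers and that MCD_samples corresponds to the number of passes (here the hypothetical argument `mcd_samples`). It only gathers the samples and leaves out the entropy and sampling-softmax computations of test_step.

```python
import torch
from torch import nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(32, 2))


def enable_dropout(model: nn.Module) -> None:
    # Switch only nn.Dropout layers back to training mode; other layers
    # (e.g. batch norm) stay in eval mode.
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()


def mc_dropout_predict(model: nn.Module, x: torch.Tensor, mcd_samples: int = 20) -> torch.Tensor:
    model.eval()
    enable_dropout(model)
    with torch.no_grad():
        # [mcd_samples, batch, outputs]; each pass uses a fresh dropout mask.
        return torch.stack([model(x) for _ in range(mcd_samples)], dim=0)


x = torch.randn(4, 8)
samples = mc_dropout_predict(net, x, mcd_samples=20)
```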

test_step(batch, batch_idx)

Run one Monte Carlo dropout test step.

Parameters:
  • batch (tuple) – Batch containing input data and target values.

  • batch_idx (int) – Index of the current batch.

Returns:

Dictionary containing the loss, prediction, target, corrupted entropy, aleatoric uncertainty, and non-noise-corrupted entropy.

Return type:

dict