qumphy.models.pulsedb module
File: qumphy/models/pulsedb.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Lightning model for PulseDB data.
- class qumphy.models.pulsedb.PulseDBEnsemble(*args: Any, **kwargs: Any)
Bases: PulseDBModule
PulseDB ensemble module for Gaussian prediction aggregation.
- test_step(batch, batch_idx)
Run one ensemble test step. The prediction tensor has shape [model, batchsize, logits+noise].
- Parameters:
batch (tuple) – Batch containing input data and target values.
batch_idx (int) – Index of the current batch.
- Returns:
Dictionary containing the loss, ensemble prediction, individual model predictions, and target.
- Return type:
dict
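The docstring does not say exactly how the per-model Gaussians are combined. A common choice, assumed here, is to treat the ensemble as an equally weighted mixture and moment-match it with a single Gaussian: the ensemble mean is the average of the member means, and the ensemble variance is the average member variance plus the variance of the member means. The sketch uses plain Python lists and assumes the last axis of `prediction` holds `num_outputs` means followed by `num_outputs` log-variances; `aggregate_gaussian_ensemble` is a hypothetical helper, not part of qumphy.

```python
import math
from typing import List, Tuple


def aggregate_gaussian_ensemble(
    prediction: List[List[List[float]]], num_outputs: int
) -> Tuple[List[List[float]], List[List[float]]]:
    """Moment-match an equally weighted Gaussian mixture per sample and output.

    prediction has shape [model, batchsize, 2 * num_outputs]. Assumed layout
    of the last axis (hypothetical): num_outputs means, then num_outputs
    log-variances.
    """
    num_models = len(prediction)
    batch_size = len(prediction[0])
    ens_mean = [[0.0] * num_outputs for _ in range(batch_size)]
    ens_var = [[0.0] * num_outputs for _ in range(batch_size)]
    for b in range(batch_size):
        for k in range(num_outputs):
            mus = [prediction[m][b][k] for m in range(num_models)]
            variances = [math.exp(prediction[m][b][num_outputs + k]) for m in range(num_models)]
            mu = sum(mus) / num_models
            # Law of total variance: mean member variance (aleatoric part)
            # plus the spread of the member means (epistemic part).
            aleatoric = sum(variances) / num_models
            epistemic = sum((x - mu) ** 2 for x in mus) / num_models
            ens_mean[b][k] = mu
            ens_var[b][k] = aleatoric + epistemic
    return ens_mean, ens_var
```

For two members predicting means 1.0 and 3.0, each with log-variance 0.0, the aggregated Gaussian has mean 2.0 and variance 2.0: one unit from the members' own variance and one from their disagreement.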
- class qumphy.models.pulsedb.PulseDBGaussianLoss(num_distributions=2, **kwargs)
Bases: GaussianNLLLoss
Gaussian negative log likelihood loss for the PulseDB dataset.
- forward(input, target)
Calculate the Gaussian negative log likelihood loss.
- Parameters:
input (torch.Tensor) – Model output tensor containing predicted means and log-variances.
target (torch.Tensor) – Ground truth target tensor.
- Returns:
Gaussian negative log likelihood loss.
- Return type:
torch.Tensor
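With `full=False` (the default), PyTorch's `GaussianNLLLoss` computes 0.5 * (log(var) + (input - target)^2 / var) per element, clamping `var` to at least `eps` (default 1e-6), and averages the result. The pure-Python sketch below shows what `forward` plausibly does on top of that. It assumes the model output is laid out as `num_distributions` means followed by `num_distributions` log-variances; both that layout and the helper name `gaussian_nll_from_logvar` are assumptions, not taken from qumphy.

```python
import math
from typing import List


def gaussian_nll_from_logvar(
    output: List[List[float]],
    target: List[List[float]],
    num_distributions: int = 2,
    eps: float = 1e-6,
) -> float:
    """Mean Gaussian NLL with the constant 0.5 * log(2 * pi) term omitted.

    Each output row is assumed (hypothetically) to be
    [mean_1, ..., mean_k, logvar_1, ..., logvar_k] with k = num_distributions.
    """
    total = 0.0
    count = 0
    for row, tgt in zip(output, target):
        for k in range(num_distributions):
            mean = row[k]
            # Exponentiate the log-variance, then clamp it as GaussianNLLLoss does.
            var = max(math.exp(row[num_distributions + k]), eps)
            total += 0.5 * (math.log(var) + (mean - tgt[k]) ** 2 / var)
            count += 1
    return total / count  # "mean" reduction over all elements
```

Predicting the log-variance rather than the variance keeps the network output unconstrained, while the exponential guarantees a positive variance.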
- class qumphy.models.pulsedb.PulseDBModule(*args: Any, **kwargs: Any)
Bases: LightningModule
Lightning parent module for PulseDB data. Takes a specific model architecture (net) as input.
- configure_optimizers()
Configure the optimizer and optional learning rate scheduler.
- Returns:
Dictionary containing the optimizer and, if provided, the learning rate scheduler configuration.
- Return type:
dict
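Lightning accepts several return forms from `configure_optimizers`. The dictionary form has an `"optimizer"` key and an optional `"lr_scheduler"` entry, which can itself be a dictionary with keys such as `"scheduler"`, `"interval"` and `"monitor"`. The hypothetical helper below only assembles that dictionary. The `"epoch"` interval and the `"val_loss"` metric name are assumptions, not values taken from qumphy.

```python
from typing import Any, Dict, Optional


def build_optimizer_config(
    optimizer: Any, scheduler: Optional[Any] = None, monitor: str = "val_loss"
) -> Dict[str, Any]:
    """Assemble the dict form accepted by LightningModule.configure_optimizers."""
    config: Dict[str, Any] = {"optimizer": optimizer}
    if scheduler is not None:
        config["lr_scheduler"] = {
            "scheduler": scheduler,
            "interval": "epoch",  # assumed: step the scheduler once per epoch
            "monitor": monitor,   # hypothetical metric name, e.g. for ReduceLROnPlateau
        }
    return config
```

Omitting the `"lr_scheduler"` key when no scheduler is given matches the "if provided" wording above: Lightning then trains with the optimizer alone.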
- denormalize_std(prediction_std)
Rescales, in place, a standard deviation expressed in the normalized target space (zero mean, unit standard deviation) back to the original BP scale, using the BP_std attribute.
- Parameters:
prediction_std (torch.Tensor) – The input standard deviation to be denormalized.
- Returns:
The denormalized standard deviation.
- Return type:
torch.Tensor
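This is why `denormalize_std` needs only `BP_std`. Write the normalization with mean $\mu_{\mathrm{BP}}$ (`BP_mean`) and standard deviation $\sigma_{\mathrm{BP}}$ (`BP_std`). The inverse map is affine, and the additive offset drops out of any standard deviation:

```latex
y = \frac{x - \mu_{\mathrm{BP}}}{\sigma_{\mathrm{BP}}}
\quad\Longrightarrow\quad
x = \sigma_{\mathrm{BP}}\, y + \mu_{\mathrm{BP}},
\qquad
\operatorname{Std}[x] = \sigma_{\mathrm{BP}} \operatorname{Std}[y].
```

Variances correspondingly scale by $\sigma_{\mathrm{BP}}^2$.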
- denormalize_target(target)
Rescales, in place, a normalized target (zero mean, unit standard deviation) back to the original BP values, using the BP_mean and BP_std attributes.
- Parameters:
target (torch.Tensor) – The input target to be denormalized.
- Returns:
The denormalized target.
- Return type:
torch.Tensor
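`denormalize_target` inverts the z-score normalization, x = y * BP_std + BP_mean. Below is a minimal in-place sketch on plain lists. It assumes one mean and one standard deviation per target output. The helper name `denormalize_target_` and the BP values in the usage line are hypothetical and only illustrative.

```python
from typing import List


def denormalize_target_(
    target: List[float], bp_mean: List[float], bp_std: List[float]
) -> List[float]:
    """In place: undo z-score normalization per output, x = y * std + mean."""
    for i, (mu, sigma) in enumerate(zip(bp_mean, bp_std)):
        target[i] = target[i] * sigma + mu
    return target  # same list object, modified in place
```

For instance, `denormalize_target_([0.0, 1.0], [120.0, 80.0], [15.0, 10.0])` returns `[120.0, 90.0]`. A normalized value of 0 maps back to the mean, and 1 maps to one standard deviation above it.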
- set_dataset_stats(dataset)
Set and register target statistics from the dataset.
- Parameters:
dataset (qumphy.datasets.PulseDBDataset) – PulseDB dataset instance that provides target statistics.
- Returns:
None. The target statistics are registered as buffers on the module.
- Return type:
None
- class qumphy.models.pulsedb.PulseDBModule_MCD(*args: Any, **kwargs: Any)
Bases: PulseDBModule
PulseDB Lightning module with Monte Carlo dropout evaluation.
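Monte Carlo dropout keeps dropout active at evaluation time and runs several stochastic forward passes. It then summarizes them by their mean (the prediction) and variance (an uncertainty estimate). The toy sketch below applies inverted dropout to the inputs of a linear model to illustrate the idea. The model, `p`, `num_samples` and the helper name `mc_dropout_predict` are assumptions and do not reflect the network or settings used by `PulseDBModule_MCD`.

```python
import random
import statistics
from typing import List, Tuple


def mc_dropout_predict(
    weights: List[float], x: List[float], p: float, num_samples: int, seed: int = 0
) -> Tuple[float, float]:
    """Monte Carlo dropout for a toy linear model y = sum(w_i * x_i).

    Each pass drops every input feature with probability p and rescales the
    survivors by 1 / (1 - p) (inverted dropout). Returns the sample mean and
    population variance of the num_samples outputs.
    """
    if not 0.0 <= p < 1.0:
        raise ValueError("dropout probability p must be in [0, 1)")
    rng = random.Random(seed)
    keep_scale = 1.0 / (1.0 - p)
    outputs = []
    for _ in range(num_samples):
        y = 0.0
        for w, xi in zip(weights, x):
            if rng.random() >= p:  # feature kept with probability 1 - p
                y += w * xi * keep_scale
        outputs.append(y)
    return statistics.fmean(outputs), statistics.pvariance(outputs)
```

With `p=0.0` every pass is identical, so the variance is exactly zero. Any `p > 0` spreads the samples around the deterministic output, and that spread is the uncertainty estimate.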