qumphy.models.pulsedb module

File: qumphy/models/pulsedb.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Lightning model for PulseDB data.

class qumphy.models.pulsedb.PulseDBEnsemble(*args: Any, **kwargs: Any)[source]

Bases: PulseDBModule

PulseDB ensemble module for Gaussian prediction aggregation.

denormalize(prediction)[source]
forward(x)[source]
load_models(model_list)[source]
test_step(batch, batch_idx)[source]

Run one ensemble test step. The stacked prediction tensor has shape [model, batchsize, logits+noise].

Parameters:
  • batch (tuple) – Batch containing input data and target values.

  • batch_idx (int) – Index of the current batch.

Returns:

Dictionary containing the loss, ensemble prediction, individual model predictions, and target.

Return type:

dict
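One common way to aggregate Gaussian predictions from several models is to treat them as an equally weighted mixture and match its first two moments. The sketch below illustrates that idea only; it is not taken from the PulseDBEnsemble source. The function name aggregate_gaussians and the layout of the last axis (means first, log-variances second) are assumptions.

```python
import torch


def aggregate_gaussians(prediction: torch.Tensor):
    """Moment-match an equally weighted mixture of per-model Gaussians.

    prediction: [n_models, batch_size, 2 * n_targets]. The last axis is
    ASSUMED to hold the means first and the log-variances second.
    """
    mean, log_var = prediction.chunk(2, dim=-1)
    var = log_var.exp()
    ens_mean = mean.mean(dim=0)
    # Law of total variance: average predicted variance plus the spread
    # of the individual means around the ensemble mean.
    ens_var = var.mean(dim=0) + ((mean - ens_mean) ** 2).mean(dim=0)
    return ens_mean, ens_var


# Three models, batch of four, one target (mean + log-variance).
prediction = torch.zeros(3, 4, 2)
prediction[0, :, 0] = 1.0
prediction[1, :, 0] = 2.0
prediction[2, :, 0] = 3.0
ens_mean, ens_var = aggregate_gaussians(prediction)
```

The second term of ens_var grows when the models disagree, so it is often read as the epistemic part of the uncertainty.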

class qumphy.models.pulsedb.PulseDBGaussianLoss(num_distributions=2, **kwargs)[source]

Bases: GaussianNLLLoss

Gaussian negative log likelihood loss for the PulseDB dataset.

forward(input, target)[source]

Calculate the Gaussian negative log likelihood loss.

Parameters:
  • input (torch.Tensor) – Model output tensor containing predicted means and log-variances.

  • target (torch.Tensor) – Ground truth target tensor.

Returns:

Gaussian negative log likelihood loss.

Return type:

torch.Tensor
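A minimal sketch of how a loss of this kind can be built on torch.nn.GaussianNLLLoss, assuming the model output concatenates means and log-variances along the last axis. The class name GaussianLossSketch and that layout are assumptions, and the num_distributions argument of the real class is not modelled here.

```python
import torch
from torch import nn


class GaussianLossSketch(nn.GaussianNLLLoss):
    """Illustrative only: splits the model output before calling the parent loss."""

    def forward(self, input, target):
        # ASSUMED layout: first half of the last axis = means,
        # second half = log-variances.
        mean, log_var = input.chunk(2, dim=-1)
        # GaussianNLLLoss expects variances, not log-variances.
        return super().forward(mean, target, log_var.exp())


loss_fn = GaussianLossSketch()
model_output = torch.zeros(8, 4)  # two targets: 2 means + 2 log-variances
target = torch.zeros(8, 2)
loss = loss_fn(model_output, target)
```

Predicting the log-variance rather than the variance keeps the variance positive without an explicit constraint on the network output.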

class qumphy.models.pulsedb.PulseDBModule(*args: Any, **kwargs: Any)[source]

Bases: LightningModule

Lightning parent module for PulseDB data. Takes a specific model architecture (net) as input.

configure_optimizers()[source]

Configure the optimizer and optional learning rate scheduler.

Returns:

Dictionary containing the optimizer and, if provided, the learning rate scheduler configuration.

Return type:

dict
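For orientation, Lightning accepts a dictionary with an "optimizer" entry and an optional "lr_scheduler" entry from configure_optimizers. The sketch below builds such a dictionary outside of any module; the choice of Adam, ReduceLROnPlateau, the learning rate, and the monitored metric name "val_loss" are assumptions, not the settings used by PulseDBModule.

```python
import torch
from torch import nn


def configure_optimizers_sketch(net: nn.Module, use_scheduler: bool = True) -> dict:
    """Build an optimizer config in the dictionary format Lightning accepts."""
    # Hypothetical optimizer and learning rate.
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    config = {"optimizer": optimizer}
    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        # "monitor" names the logged metric the scheduler reacts to (assumed name).
        config["lr_scheduler"] = {"scheduler": scheduler, "monitor": "val_loss"}
    return config


net = nn.Linear(4, 2)
with_scheduler = configure_optimizers_sketch(net, use_scheduler=True)
without_scheduler = configure_optimizers_sketch(net, use_scheduler=False)
```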

denormalize_std(prediction_std)[source]

Rescales, in place, a standard deviation predicted for the normalized target (zero mean, unit standard deviation) back to the original BP scale, using the BP_std attribute.

Parameters:

prediction_std (torch.Tensor) – The input standard deviation to be denormalized.

Returns:

The denormalized standard deviation.

Return type:

torch.Tensor

denormalize_target(target)[source]

Rescales, in place, the normalized target (zero mean, unit standard deviation) back to the original BP values, using the BP_mean and BP_std attributes.

Parameters:

target (torch.Tensor) – The input target to be denormalized.

Returns:

The denormalized target.

Return type:

torch.Tensor
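The arithmetic behind denormalize_target and denormalize_std can be sketched as follows. A target is shifted and scaled, whereas a standard deviation is only scaled, because adding a constant does not change the spread. The statistics below are made-up placeholder values, and this sketch returns new tensors instead of modifying its input in place.

```python
import torch

# Hypothetical statistics; the module reads them from its BP_mean / BP_std attributes.
BP_mean = torch.tensor([120.0, 80.0])
BP_std = torch.tensor([20.0, 10.0])


def denormalize_target(target: torch.Tensor) -> torch.Tensor:
    # Invert z = (y - mean) / std.
    return target * BP_std + BP_mean


def denormalize_std(prediction_std: torch.Tensor) -> torch.Tensor:
    # A standard deviation is unaffected by the mean shift, so only scale it.
    return prediction_std * BP_std


restored_target = denormalize_target(torch.zeros(1, 2))
restored_std = denormalize_std(torch.ones(1, 2))
```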

forward(x)[source]
predict_step(batch, batch_idx)[source]
set_dataset_stats(dataset)[source]

Set and register target statistics from the dataset.

Parameters:

dataset (qumphy.datasets.PulseDBDataset) – PulseDB dataset instance that provides target statistics.

Returns:

Nothing. The target statistics are registered as buffers on the module.

Return type:

None
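Registering statistics as buffers means they are saved in the state dict and moved with the module across devices, without being treated as trainable parameters. A minimal sketch of that mechanism on a plain torch.nn.Module is shown below. The class StatsHolder is hypothetical, and it takes the mean and standard deviation directly because the attribute names of PulseDBDataset are not given here.

```python
import torch
from torch import nn


class StatsHolder(nn.Module):
    """Hypothetical stand-in showing buffer registration only."""

    def set_dataset_stats(self, mean, std) -> None:
        # Buffers are stored in state_dict() but excluded from parameters().
        self.register_buffer("BP_mean", torch.as_tensor(mean, dtype=torch.float32))
        self.register_buffer("BP_std", torch.as_tensor(std, dtype=torch.float32))


holder = StatsHolder()
holder.set_dataset_stats([120.0, 80.0], [20.0, 10.0])
state_keys = sorted(holder.state_dict().keys())
n_parameters = len(list(holder.parameters()))
```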

set_lr_scheduler()[source]

Configure the optimizer and optional learning rate scheduler.

Returns:

Dictionary containing the optimizer and, if provided, the learning rate scheduler configuration.

Return type:

dict

test_step(batch, batch_idx)[source]
training_step(batch, batch_idx)[source]
validation_step(batch, batch_idx)[source]

class qumphy.models.pulsedb.PulseDBModule_MCD(*args: Any, **kwargs: Any)[source]

Bases: PulseDBModule

PulseDB Lightning module with Monte Carlo dropout evaluation.

test_step(batch, batch_idx)[source]
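Monte Carlo dropout evaluation generally keeps the dropout layers stochastic at test time and summarizes several forward passes. The sketch below shows the general technique, not the code of PulseDBModule_MCD. The helper mc_dropout_predict, the toy network, and the sample count are assumptions.

```python
import torch
from torch import nn


def mc_dropout_predict(net: nn.Module, x: torch.Tensor, n_samples: int = 20):
    """Return the mean and standard deviation over stochastic forward passes."""
    net.eval()
    # Re-enable only the dropout layers so other layers stay in eval mode.
    for module in net.modules():
        if isinstance(module, nn.Dropout):
            module.train()
    with torch.no_grad():
        samples = torch.stack([net(x) for _ in range(n_samples)])
    return samples.mean(dim=0), samples.std(dim=0)


torch.manual_seed(0)
toy_net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 1))
mc_mean, mc_std = mc_dropout_predict(toy_net, torch.randn(8, 4))
```

The spread across passes gives a rough estimate of model uncertainty for each input.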