qumphy.trainer module

File: qumphy/trainer.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Lightning Trainer.

class qumphy.trainer.Trainer(config)

Bases: object

Training pipeline wrapper for QuMPhy experiments.

This class loads the trainer, data module, feature extractor, model, and optional ensemble configuration from a configuration dictionary.
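The config-driven loading pattern can be sketched as follows. This is a minimal illustration under stated assumptions, not the QuMPhy implementation: the section names (trainer, data_module, model, feature_extractor, ensemble), the ConfigDrivenPipeline class, and the factory mapping are all hypothetical, and plain Python factories stand in for the Lightning objects.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical: a factory builds one component from keyword arguments.
Factory = Callable[..., Any]


class ConfigDrivenPipeline:
    """Sketch of a wrapper that builds each component from its own config section."""

    REQUIRED = ("trainer", "data_module", "model")
    OPTIONAL = ("feature_extractor", "ensemble")

    def __init__(self, config: Dict[str, Dict[str, Any]], factories: Dict[str, Factory]):
        self.config = config
        self.components: Dict[str, Optional[Any]] = {}
        for name in self.REQUIRED:
            # Required sections must exist; a missing one raises KeyError.
            self.components[name] = factories[name](**config[name])
        for name in self.OPTIONAL:
            # Optional sections (e.g. an ensemble) are only built when configured.
            section = config.get(name)
            self.components[name] = factories[name](**section) if section is not None else None
```

Using dict as every factory makes the sketch easy to inspect: ConfigDrivenPipeline({"trainer": {"max_epochs": 5}, "data_module": {}, "model": {}}, factories) yields components["trainer"] == {"max_epochs": 5} and components["ensemble"] is None.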

add_parameters_to_experiment()
base_config()
best_model_path(trainer)

Get the best checkpoint path from a trainer.

Parameters:

trainer (lightning.Trainer) – Lightning trainer whose callbacks are searched.

Returns:

The best model checkpoint path if a ModelCheckpoint callback is found; otherwise None.

Return type:

str or None
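The callback search described above might look like the following sketch. It assumes Lightning 2.x, where the checkpoint callback class is lightning.pytorch.callbacks.ModelCheckpoint and a trainer exposes its callbacks as trainer.callbacks; the checkpoint_cls parameter is a hypothetical addition so the function can be exercised with stand-in classes.

```python
from typing import Any, Optional


def best_model_path(trainer: Any, checkpoint_cls: Optional[type] = None) -> Optional[str]:
    """Return the best checkpoint path of the first checkpoint callback, or None."""
    if checkpoint_cls is None:
        # Imported lazily so the sketch can be exercised without Lightning installed.
        from lightning.pytorch.callbacks import ModelCheckpoint

        checkpoint_cls = ModelCheckpoint
    for callback in trainer.callbacks:
        if isinstance(callback, checkpoint_cls):
            # Lightning keeps this as "" until a checkpoint has been saved.
            return callback.best_model_path
    return None
```

Returning the first match mirrors the single-checkpoint setup the docstring implies; a trainer with several ModelCheckpoint callbacks would need an explicit choice.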

extract_features()

Extract features from the train, validation, and test datasets.

Returns:

Nothing. Each dataset’s data is replaced with its extracted features, and the corresponding labels are kept.

Return type:

None
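A minimal sketch of the in-place replacement, assuming each split can be modeled as a hypothetical Split object with data and labels fields and that the feature extractor is any callable returning one feature row per sample. The real datasets and extractor in QuMPhy may differ.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Sequence


@dataclass
class Split:
    """Hypothetical container for one dataset split."""

    data: List
    labels: List


def extract_features(splits: Dict[str, Split], extractor: Callable[[Sequence], List]) -> None:
    """Replace each split's data with its features in place; labels are untouched."""
    for split in splits.values():
        features = extractor(split.data)
        # One feature row per sample keeps the features aligned with the labels.
        if len(features) != len(split.labels):
            raise ValueError("extractor must return one feature row per sample")
        split.data = features
```

Calling extract_features({"train": Split([[1, 2], [3, 4]], [0, 1])}, lambda xs: [[sum(x)] for x in xs]) turns the training data into [[3], [7]] while the labels stay [0, 1].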

find_lr()
fit()

Train the model or ensemble members.

Returns:

Nothing. The Lightning training loop is run for the model or for each ensemble member.

Return type:

None
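The single-model and ensemble cases can share one loop, as in this sketch. The make_trainer and make_model factories and the n_members argument are hypothetical; the only Lightning API assumed is Trainer.fit(model, datamodule=...).

```python
from typing import Any, Callable, Optional


def fit(
    make_trainer: Callable[[], Any],
    make_model: Callable[[], Any],
    data_module: Any,
    n_members: Optional[int] = None,
) -> None:
    """Fit one model, or n_members independent ensemble members."""
    count = n_members if n_members else 1
    for _ in range(count):
        # A fresh trainer and model per member keeps the members independent.
        trainer = make_trainer()
        model = make_model()
        trainer.fit(model, datamodule=data_module)
```

Creating a new trainer per member avoids carrying callback state (such as the best checkpoint) from one member into the next.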

load_data_module()

Load the data module from the configuration.
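One common way to load a component from a configuration is the class_path/init_args convention used by LightningCLI. The sketch below assumes that convention; whether QuMPhy's configuration uses these keys is not stated here, and load_from_config is a hypothetical helper.

```python
import importlib
from typing import Any, Dict


def load_from_config(section: Dict[str, Any]) -> Any:
    """Instantiate section["class_path"] with keyword arguments section["init_args"]."""
    module_name, _, class_name = section["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**section.get("init_args", {}))
```

For example, load_from_config({"class_path": "fractions.Fraction", "init_args": {"numerator": 1, "denominator": 2}}) returns Fraction(1, 2); a data module class would be referenced the same way by its dotted import path.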

load_ensemble_model()

Load trained ensemble members into an ensemble model.

Returns:

Nothing. The best checkpoint of each ensemble member is loaded, and the resulting combined ensemble model is stored.

Return type:

None
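A sketch of combining members loaded from their best checkpoints. Averaging member outputs is an assumption, not necessarily how the QuMPhy ensemble combines predictions, and MeanEnsemble and load_ensemble are hypothetical names. With Lightning, load_fn could be a LightningModule subclass's load_from_checkpoint classmethod; here scalar functions stand in for models.

```python
from typing import Any, Callable, Sequence


class MeanEnsemble:
    """Hypothetical ensemble that averages the outputs of its members."""

    def __init__(self, members: Sequence[Callable[[float], float]]):
        if not members:
            raise ValueError("an ensemble needs at least one member")
        self.members = list(members)

    def __call__(self, x: float) -> float:
        outputs = [member(x) for member in self.members]
        return sum(outputs) / len(outputs)


def load_ensemble(checkpoint_paths: Sequence[str], load_fn: Callable[[str], Any]) -> MeanEnsemble:
    """Load one member per best-checkpoint path and combine them."""
    return MeanEnsemble([load_fn(path) for path in checkpoint_paths])
```

With two members computing x and 3 * x, the ensemble maps 2.0 to (2.0 + 6.0) / 2 = 4.0.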

load_feature_extractor()
load_model()
load_trainer()
load_tuner()

Load the Lightning tuner.

Returns:

Nothing. A tuner is created only if one does not already exist.

Return type:

None
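The create-once behaviour amounts to lazy initialization, sketched below. TunerOwner and tuner_factory are hypothetical; with Lightning 2.x the factory would presumably be lightning.pytorch.tuner.Tuner, which takes the trainer as its argument.

```python
from typing import Any, Callable, Optional


class TunerOwner:
    """Holds a trainer and creates its tuner lazily, at most once."""

    def __init__(self, trainer: Any, tuner_factory: Callable[[Any], Any]):
        self.trainer = trainer
        self._tuner_factory = tuner_factory
        self.tuner: Optional[Any] = None

    def load_tuner(self) -> None:
        # Reuse an existing tuner instead of creating a second one.
        if self.tuner is None:
            self.tuner = self._tuner_factory(self.trainer)
```

Calling load_tuner repeatedly is therefore safe: only the first call constructs a tuner.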

predict()

Predict with the best model and save the predictions to best_predictions.pt.
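A minimal sketch of predicting with the best checkpoint, assuming Lightning 2.x, where Trainer.predict accepts ckpt_path="best" to restore the best checkpoint tracked by ModelCheckpoint, and that the file is written with torch.save. The predict_and_save function, its save_dir argument, and the injectable save_fn are hypothetical.

```python
import os
from typing import Any, Callable, Optional


def predict_and_save(
    trainer: Any,
    model: Any,
    data_module: Any,
    save_dir: str,
    save_fn: Optional[Callable[[Any, str], None]] = None,
) -> str:
    """Predict with the best checkpoint and save the outputs to best_predictions.pt."""
    if save_fn is None:
        import torch  # imported lazily; torch.save(obj, path) writes the .pt file

        save_fn = torch.save
    # ckpt_path="best" asks Lightning to restore the best checkpoint before predicting.
    predictions = trainer.predict(model, datamodule=data_module, ckpt_path="best")
    path = os.path.join(save_dir, "best_predictions.pt")
    save_fn(predictions, path)
    return path
```

The saved object can later be reloaded with torch.load on the returned path.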

seed_everything()

Set random seeds for reproducibility.

Returns:

Nothing. The method calls Lightning’s seed utility.

Return type:

None
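A sketch of seeding from the configuration, assuming a hypothetical "seed" key. When Lightning is installed it calls lightning.pytorch.seed_everything, which seeds Python, NumPy and PyTorch; otherwise it falls back to seeding only Python's random module, so the fallback is weaker than the Lightning utility.

```python
import random
from typing import Any, Dict


def seed_from_config(config: Dict[str, Any], default: int = 42) -> int:
    """Seed Python's RNG (and Lightning, if installed) from config["seed"]."""
    seed = int(config.get("seed", default))
    try:
        # Seeds Python, NumPy and PyTorch; workers=True also seeds dataloader workers.
        from lightning.pytorch import seed_everything

        seed_everything(seed, workers=True)
    except ImportError:
        random.seed(seed)  # fallback: at least make Python's own RNG reproducible
    return seed
```

Re-seeding with the same value reproduces the same sequence of random.random() draws in either branch.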

set_save_dir()

Pass the save directory and project name to the configurations of the model checkpoint (model_checkpoint) and the logger.
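A sketch of propagating these values into nested configuration sections. The section names model_checkpoint and logger come from the description above; the key names follow Lightning's constructor arguments (ModelCheckpoint takes dirpath, and loggers such as CSVLogger take save_dir and name), while the save_dir/project checkpoint layout is an assumption.

```python
import os
from typing import Any, Dict


def set_save_dir(config: Dict[str, Any], save_dir: str, project: str) -> None:
    """Copy save_dir and project into the checkpoint and logger sub-configs, in place."""
    checkpoint = config.setdefault("model_checkpoint", {})
    logger = config.setdefault("logger", {})
    # Assumed layout: checkpoints go to <save_dir>/<project>.
    checkpoint["dirpath"] = os.path.join(save_dir, project)
    # Lightning loggers such as CSVLogger accept save_dir and name.
    logger["save_dir"] = save_dir
    logger["name"] = project
```

Existing keys in either section are preserved; only the directory and name entries are overwritten.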

set_sweep_parameters()

Set sweep parameters inside the nested configuration.

Returns:

Nothing. The configuration dictionary is modified in place.

Return type:

None
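Setting parameters inside a nested configuration can be sketched as follows, assuming the sweep delivers a flat mapping with dotted keys such as "model.lr". That key convention and the function signature are assumptions for illustration.

```python
from typing import Any, Dict


def set_sweep_parameters(config: Dict[str, Any], sweep: Dict[str, Any]) -> None:
    """Write each dotted sweep key (e.g. "model.lr") into the nested config, in place."""
    for dotted_key, value in sweep.items():
        *parents, leaf = dotted_key.split(".")
        node = config
        for key in parents:
            # Create intermediate sections that the base config does not define yet.
            node = node.setdefault(key, {})
        node[leaf] = value
```

Applying {"model.lr": 0.01} to {"model": {"lr": 0.1, "depth": 2}} changes only the learning rate and leaves depth untouched.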

test()