qumphy.trainer module
File: qumphy/trainer.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Lightning Trainer.
- class qumphy.trainer.Trainer(config)
Bases: object
Training pipeline wrapper for QuMPhy experiments.
This class loads the trainer, data module, feature extractor, model, and optional ensemble configuration from a configuration dictionary.
- best_model_path(trainer)
Get the best checkpoint path from a trainer.
- Parameters:
trainer (lightning.Trainer) – Lightning trainer whose callbacks are searched.
- Returns:
Best model checkpoint path if a ModelCheckpoint callback is found, otherwise None.
- Return type:
str or None
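The callback search described for best_model_path can be sketched as below. This is a minimal illustration, not the project's code: the ModelCheckpoint class here is a stand-in for Lightning's callback of the same name, find_best_model_path is a hypothetical helper, and the trainer objects are placeholders that expose only a callbacks list.

```python
from types import SimpleNamespace


class ModelCheckpoint:
    """Stand-in for Lightning's ModelCheckpoint callback (illustration only)."""

    def __init__(self, best_model_path=""):
        self.best_model_path = best_model_path


def find_best_model_path(trainer):
    """Return the path stored on the first ModelCheckpoint callback, else None."""
    for callback in trainer.callbacks:
        if isinstance(callback, ModelCheckpoint):
            return callback.best_model_path
    return None


# Hypothetical trainers: one with a checkpoint callback, one without.
with_checkpoint = SimpleNamespace(
    callbacks=[object(), ModelCheckpoint("runs/best.ckpt")]
)
without_checkpoint = SimpleNamespace(callbacks=[object()])

print(find_best_model_path(with_checkpoint))
print(find_best_model_path(without_checkpoint))
```

Returning None when no checkpoint callback is registered matches the documented return type of str or None.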
- extract_features()
Extract features from the train, validation, and test datasets.
- Returns:
The function replaces each dataset’s data with extracted features and keeps the corresponding labels.
- Return type:
None
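The in-place replacement described for extract_features might look roughly like the following sketch. The Dataset container, the mean_and_range extractor, and the replace_with_features helper are all hypothetical; the source does not show the real dataset type or feature extractor.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Dataset:
    """Hypothetical container; the real dataset type is not shown in the source."""

    data: List[List[float]]
    labels: List[int]


def mean_and_range(signal):
    """Placeholder feature extractor: mean and range of one signal."""
    return [sum(signal) / len(signal), max(signal) - min(signal)]


def replace_with_features(datasets, extractor):
    """Overwrite each dataset's data with extracted features; labels stay as they are."""
    for dataset in datasets:
        dataset.data = [extractor(sample) for sample in dataset.data]


train = Dataset(data=[[1.0, 2.0, 3.0], [4.0, 4.0, 4.0]], labels=[0, 1])
val = Dataset(data=[[0.0, 2.0]], labels=[1])
test = Dataset(data=[[5.0, 5.0]], labels=[0])

# Nothing is returned; the datasets are modified in place.
replace_with_features([train, val, test], mean_and_range)
```

Because the datasets are mutated rather than copied, the function has no return value, which is consistent with the documented return type of None.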
- fit()
Train the model or ensemble members.
- Returns:
The function starts the Lightning training loop.
- Return type:
None
- load_ensemble_model()
Load trained ensemble members into an ensemble model.
- Returns:
The function loads the best checkpoint for each ensemble member and stores the combined ensemble model.
- Return type:
None
- load_tuner()
Load the Lightning tuner.
- Returns:
The function creates a tuner if it does not already exist.
- Return type:
None
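The "create only if missing" behaviour described for load_tuner is a lazy-initialisation pattern, sketched below under stated assumptions. The Tuner class is a stand-in for Lightning's tuner, and the Pipeline class with its tuner attribute is hypothetical; the real attribute names are not given in the source.

```python
class Tuner:
    """Stand-in for Lightning's Tuner class (illustration only)."""

    def __init__(self, trainer):
        self.trainer = trainer


class Pipeline:
    """Hypothetical wrapper holding a trainer and a lazily created tuner."""

    def __init__(self, trainer):
        self.trainer = trainer
        self.tuner = None  # nothing is created until load_tuner() is called

    def load_tuner(self):
        # Create the tuner once; later calls keep the existing instance.
        if self.tuner is None:
            self.tuner = Tuner(self.trainer)


pipeline = Pipeline(trainer="trainer-placeholder")
pipeline.load_tuner()
first = pipeline.tuner
pipeline.load_tuner()  # second call reuses the tuner created above
```

Calling the method repeatedly is therefore safe and leaves the first tuner instance in place.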
- seed_everything()
Set random seeds for reproducibility.
- Returns:
The function calls the Lightning seed utility.
- Return type:
None
- set_save_dir()
Pass the save directory and project name to the model_checkpoint and logger configurations.
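How set_save_dir might propagate these values can be sketched with a plain dictionary. Every key name below (save_dir, project, dirpath) and the apply_save_dir helper are assumptions chosen for illustration; the source names only the model_checkpoint and logger config sections.

```python
def apply_save_dir(config):
    """Copy the top-level save directory and project name into sub-configs."""
    save_dir = config["save_dir"]  # assumed top-level key
    project = config["project"]    # assumed top-level key

    # Create the sub-configs if they are missing, then fill them in.
    checkpoint_cfg = config.setdefault("model_checkpoint", {})
    logger_cfg = config.setdefault("logger", {})

    checkpoint_cfg["dirpath"] = save_dir  # assumed checkpoint directory key
    logger_cfg["save_dir"] = save_dir     # assumed logger directory key
    logger_cfg["project"] = project       # assumed logger project key
    return config


config = {
    "save_dir": "outputs/run1",
    "project": "qumphy-demo",
    "model_checkpoint": {"monitor": "val_loss"},
}
apply_save_dir(config)
```

Existing entries such as monitor are left in place; only the directory and project fields are written.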