# Source code for qumphy.trainer

"""
File: qumphy/trainer.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Lightning Trainer.
"""

import lightning as L
import qumphy
import torch
from tqdm import tqdm
import copy
import pathlib
import yaml


class Trainer:
    """Training pipeline wrapper for QuMPhy experiments.

    This class loads the trainer, data module, feature extractor, model, and
    optional ensemble configuration from a configuration dictionary. When the
    configuration contains an ``"ensemble"`` entry, ``ensemble["size"]``
    independent trainers and models are created (``self._trainers`` and
    ``self.models``); otherwise a single ``self._trainer`` / ``self.model``
    pair is used.
    """

    def __init__(self, config):
        """Initialize the training pipeline.

        Parameters
        ----------
        config : dict
            Experiment configuration dictionary. Values in this dictionary
            overwrite the default configuration from ``base_config``.
        """
        self.base_config()
        qumphy.misc.misc.update_dictionary(self.config, config)
        self.set_sweep_parameters()
        self.seed_everything()
        self.load_trainer()
        self.add_parameters_to_experiment()
        self.load_data_module()
        self.load_feature_extractor()
        self.load_model()
        self.find_lr()

    def set_save_dir(self):
        """Create a fresh versioned run directory and wire it into the config.

        The save directory and project name are passed to the
        ``lightning.pytorch.callbacks.ModelCheckpoint`` entry (``dirpath``)
        and to the entry whose ``keyword`` is ``"logger"`` (``save_dir``,
        ``project``, ``name``) in ``config["trainer"]["classes"]``. The
        resulting configuration is dumped to ``config.yaml`` in the new
        run directory.

        Returns
        -------
        str or None
            Path ``<save_dir>/<project>/v<N>``, where ``N`` is the first
            version number whose directory does not exist yet. Returns
            ``None`` and creates nothing when ``config["save_dir"]`` is
            missing or ``None``.
        """
        save_dir = self.config.get("save_dir")
        if save_dir is None:
            # No output location configured: leave callbacks/loggers untouched.
            return None
        project = self.config["project"]
        project_dir = pathlib.Path(save_dir) / project

        # Pick the first unused version so earlier runs are never overwritten.
        version_number = 0
        while (project_dir / f"v{version_number}").exists():
            version_number += 1
        version = f"v{version_number}"
        run_dir = project_dir / version

        for entry in self.config["trainer"].get("classes", []):
            if (
                entry.get("class_path")
                == "lightning.pytorch.callbacks.ModelCheckpoint"
            ):
                entry.setdefault("init_args", {})["dirpath"] = str(run_dir)
            elif entry.get("keyword") == "logger":
                init_args = entry.setdefault("init_args", {})
                init_args["save_dir"] = str(save_dir)
                init_args["project"] = project
                init_args["name"] = version

        # parents=True so a not-yet-existing save_dir does not make mkdir fail.
        run_dir.mkdir(parents=True, exist_ok=True)
        with open(run_dir / "config.yaml", "w") as f:
            yaml.dump(self.config, f)
        return str(run_dir)

    def add_parameters_to_experiment(self):
        """Record resolved sweep parameters in each logger's experiment config.

        For every key of ``config["sweep_parameters"]`` that also has a
        top-level value in the configuration, that value is stored in
        ``trainer.logger.experiment.config`` (presumably a W&B run config;
        confirm against the configured logger). In ensemble mode every member
        trainer is updated. Trainers without a logger are skipped.
        """
        sweep_parameters = self.config.get("sweep_parameters")
        if not sweep_parameters:
            return
        if "ensemble" in self.config:
            trainers = self._trainers
        else:
            trainers = [self._trainer]
        for trainer in trainers:
            if trainer.logger is None:
                continue
            for key in sweep_parameters:
                if key in self.config:
                    trainer.logger.experiment.config[key] = self.config[key]

    def base_config(self):
        """Reset ``self.config`` to the default configuration.

        The defaults are no checkpoint to load, no learning-rate search, and
        no global seeding.
        """
        self.config = {
            "ckpt_path": None,
            "find_lr": False,
            "seed_everything": None,
        }

    def set_sweep_parameters(self):
        """Set sweep parameters inside the nested configuration.

        ``config["sweep_parameters"]`` maps a top-level key to a list of keys
        (``path_list``) locating where that value belongs in the nested
        configuration. Every key that has a top-level value is copied to its
        nested location.

        Returns
        -------
        None
            The function modifies the configuration dictionary in place.
        """
        sweep_parameters = self.config.get("sweep_parameters")
        if not sweep_parameters:
            return
        for key, path_list in sweep_parameters.items():
            if key in self.config:
                qumphy.misc.misc.set_value_at_nested_key(
                    self.config, path_list, self.config[key]
                )

    def extract_features(self):
        """Extract features from the train, validation, and test datasets.

        Each dataloader is run through the feature extractor without
        gradients, on CUDA when available. Its dataset's ``data`` and
        ``target`` attributes are then replaced with the concatenated
        features and labels.

        Returns
        -------
        None
            The function replaces each dataset's data with extracted features
            and keeps the corresponding labels.
        """
        if "feature_extractor" not in self.config:
            return
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.feature_extractor.to(device)
        self.feature_extractor.eval()

        self.data_module.setup(stage="fit")
        self.data_module.setup(stage="test")
        data_loaders = [
            self.data_module.train_dataloader(),
            self.data_module.val_dataloader(),
            self.data_module.test_dataloader(),
        ]

        print("Extracting features...")
        with torch.no_grad():
            for data_loader in data_loaders:
                features = []
                labels = []
                # Assumes each batch is a (data, label) pair; TODO confirm for
                # every data module used with a feature extractor.
                for data, label in tqdm(data_loader):
                    feature = self.feature_extractor(data.to(device))
                    features.append(feature.cpu())
                    labels.append(label.cpu())
                data_loader.dataset.data = torch.cat(features)
                data_loader.dataset.target = torch.cat(labels)

    def find_lr(self):
        """Run Lightning's learning-rate finder and save its plot.

        This does nothing unless ``config["find_lr"]`` is truthy. The plot,
        with the suggested learning rate marked, is written to
        ``lr_finder.png`` in the trainer's ``log_dir``.

        Raises
        ------
        NotImplementedError
            If ``find_lr`` is requested for an ensemble configuration. In that
            case there is no single trainer/model to tune.
        """
        if not self.config.get("find_lr"):
            return
        if "ensemble" in self.config:
            raise NotImplementedError(
                "find_lr is not supported for ensemble configurations."
            )
        self.load_tuner()
        lr_finder = self.tuner.lr_find(
            self.model,
            datamodule=self.data_module,
        )
        # Tuner.lr_find is typed Optional (e.g. skipped under fast_dev_run).
        if lr_finder is None:
            return
        fig = lr_finder.plot(suggest=True)
        # Trainer.log_dir also resolves for loggers without their own log_dir
        # (it falls back to the logger's save_dir).
        fig.savefig(str(pathlib.Path(self._trainer.log_dir) / "lr_finder.png"))

    def load_tuner(self):
        """Load the Lightning tuner.

        Returns
        -------
        None
            The function creates ``self.tuner`` for ``self._trainer`` if it
            does not already exist.
        """
        if hasattr(self, "tuner"):
            return
        self.tuner = L.pytorch.tuner.tuning.Tuner(self._trainer)

    def seed_everything(self):
        """Set random seeds for reproducibility.

        Returns
        -------
        None
            The function calls the Lightning seed utility (also seeding
            dataloader workers) when ``config["seed_everything"]`` is set.
        """
        seed = self.config.get("seed_everything")
        if seed is not None:
            L.pytorch.seed_everything(seed, workers=True)

    def load_trainer(self):
        """Instantiate the Lightning trainer(s) from ``config["trainer"]``.

        Each trainer gets its own versioned run directory from
        ``set_save_dir``, stored on the trainer as ``save_dir``. In ensemble
        mode one trainer per member is created in ``self._trainers``;
        otherwise a single ``self._trainer`` is created.
        """
        if "ensemble" in self.config:
            self._trainers = []
            for _ in range(self.config["ensemble"]["size"]):
                # set_save_dir mutates the trainer config in place, so it must
                # run before each instantiation.
                save_dir = self.set_save_dir()
                trainer = qumphy.misc.misc.instantiate_class(self.config["trainer"])
                trainer.save_dir = save_dir
                self._trainers.append(trainer)
        else:
            save_dir = self.set_save_dir()
            self._trainer = qumphy.misc.misc.instantiate_class(self.config["trainer"])
            self._trainer.save_dir = save_dir

    def load_feature_extractor(self):
        """Instantiate ``self.feature_extractor`` if one is configured."""
        if "feature_extractor" not in self.config:
            return
        self.feature_extractor = qumphy.misc.misc.instantiate_class(
            self.config["feature_extractor"]
        )

    def load_data_module(self):
        """Load data module from config, if ``config["data"]`` is present."""
        if "data" not in self.config:
            return
        self.data_module = qumphy.misc.misc.instantiate_class(self.config["data"])

    def load_model(self):
        """Instantiate the model(s) from ``config["model"]``.

        In ensemble mode ``ensemble["size"]`` independent models are created
        in ``self.models``. Their checkpoints (one per member) are handled in
        ``fit``. Otherwise ``self.model`` is created and, when
        ``config["ckpt_path"]`` is set, its ``state_dict`` is loaded from
        that checkpoint.
        """
        if "ensemble" in self.config:
            self.models = [
                qumphy.misc.misc.instantiate_class(self.config["model"])
                for _ in range(self.config["ensemble"]["size"])
            ]
            return
        self.model = qumphy.misc.misc.instantiate_class(self.config["model"])
        if self.config["ckpt_path"] is not None:
            checkpoint = torch.load(self.config["ckpt_path"], weights_only=True)
            self.model.load_state_dict(checkpoint["state_dict"])

    def load_ensemble_model(self):
        """Load trained ensemble members into an ensemble model.

        The ensemble model is built from ``config["model"]`` updated with
        ``config["ensemble"]``. Each member model is loaded from the best
        checkpoint of its own trainer and handed to the ensemble via
        ``load_models``. The result is cached in ``self.ensemble_model`` only
        after every member loaded successfully.

        Raises
        ------
        RuntimeError
            If a member trainer has no best checkpoint.
        """
        if hasattr(self, "ensemble_model"):
            return
        ensemble_config = copy.deepcopy(self.config["model"])
        qumphy.misc.misc.update_dictionary(ensemble_config, self.config["ensemble"])
        ensemble_model = qumphy.misc.misc.instantiate_class(ensemble_config)
        for index, (model, trainer) in enumerate(zip(self.models, self._trainers)):
            ckpt_path = self.best_model_path(trainer)
            if not ckpt_path:
                raise RuntimeError(
                    f"No best checkpoint found for ensemble member {index}."
                )
            checkpoint = torch.load(ckpt_path, weights_only=True)
            model.load_state_dict(checkpoint["state_dict"])
        ensemble_model.load_models(self.models)
        self.ensemble_model = ensemble_model

    def fit(self):
        """Train the model or ensemble members.

        In ensemble mode ``config["ckpt_path"]``, when set, is indexed per
        member and passed to ``Trainer.fit`` to resume training.

        Returns
        -------
        None
            The function starts the Lightning training loop.
        """
        if "ensemble" not in self.config:
            self._trainer.fit(model=self.model, datamodule=self.data_module)
            return
        ckpt_paths = self.config["ckpt_path"]
        for index, (trainer, model) in enumerate(zip(self._trainers, self.models)):
            ckpt_path = ckpt_paths[index] if ckpt_paths is not None else None
            trainer.fit(model=model, datamodule=self.data_module, ckpt_path=ckpt_path)

    def test(self):
        """Evaluate on the test data.

        In ensemble mode the ensemble model is evaluated with a freshly
        instantiated trainer. Otherwise the single model is evaluated with
        its best checkpoint, or with its current weights when no best
        checkpoint is available.
        """
        if "ensemble" in self.config:
            self.load_ensemble_model()
            trainer = qumphy.misc.misc.instantiate_class(self.config["trainer"])
            trainer.test(model=self.ensemble_model, datamodule=self.data_module)
            return
        # best_model_path may be "" (no checkpoint saved yet) or None.
        ckpt_path = self.best_model_path(self._trainer) or None
        self._trainer.test(
            model=self.model, datamodule=self.data_module, ckpt_path=ckpt_path
        )

    def best_model_path(self, trainer):
        """Get the best checkpoint path from a trainer.

        Parameters
        ----------
        trainer : lightning.Trainer
            Lightning trainer whose callbacks are searched.

        Returns
        -------
        str or None
            ``best_model_path`` of the first ModelCheckpoint callback, or
            ``None`` if the trainer has no such callback.
        """
        for callback in trainer.callbacks:
            if isinstance(callback, L.pytorch.callbacks.ModelCheckpoint):
                return callback.best_model_path
        return None

    def predict(self):
        """Predict with the best model and save the predictions.

        In ensemble mode the ensemble model is run through a freshly
        instantiated trainer with ``datamodule=self.data_module``. Otherwise
        the single model is run on the test dataloader using its best
        checkpoint (``ckpt_path="best"``). The per-batch predictions are
        concatenated and saved with ``torch.save`` to ``predictions.pt`` in
        ``trainer.logger.experiment.dir`` (presumably a W&B run directory;
        confirm against the configured logger).

        Returns
        -------
        torch.Tensor
            The concatenated predictions.
        """
        if "ensemble" in self.config:
            self.load_ensemble_model()
            trainer = qumphy.misc.misc.instantiate_class(self.config["trainer"])
            predictions = trainer.predict(
                model=self.ensemble_model, datamodule=self.data_module
            )
        else:
            trainer = self._trainer
            predictions = trainer.predict(
                model=self.model,
                dataloaders=self.data_module.test_dataloader(),
                ckpt_path="best",
            )
        predictions = torch.cat(predictions)
        save_dir = pathlib.Path(trainer.logger.experiment.dir)
        torch.save(predictions, save_dir / "predictions.pt")
        return predictions