"""
File: qumphy/trainer.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Lightning Trainer.
"""
import lightning as L
import qumphy
import torch
from tqdm import tqdm
import copy
import pathlib
import yaml
[docs]
class Trainer:
    """Training pipeline wrapper for QuMPhy experiments.

    This class loads the trainer, data module, feature extractor, model, and
    optional ensemble configuration from a configuration dictionary.

    When the configuration contains an ``"ensemble"`` section, one Lightning
    trainer and one model are created per ensemble member (``self._trainers``
    and ``self.models``); otherwise a single trainer and model are created
    (``self._trainer`` and ``self.model``).
    """

    def __init__(self, config):
        """Initialize the training pipeline.

        Parameters
        ----------
        config : dict
            Experiment configuration dictionary. Values in this dictionary
            overwrite the default configuration.
        """
        self.base_config()
        qumphy.misc.misc.update_dictionary(self.config, config)
        self.set_sweep_parameters()
        self.seed_everything()
        self.load_trainer()
        self.add_parameters_to_experiment()
        self.load_data_module()
        # NOTE(review): ``load_feature_extractor`` is not defined in the part
        # of the class visible here -- confirm it exists elsewhere (later in
        # the class or in a subclass), otherwise this call raises
        # AttributeError.
        self.load_feature_extractor()
        self.load_model()
        self.find_lr()

    def _member_trainers(self):
        """Return every Lightning trainer owned by this pipeline as a list."""
        if "ensemble" in self.config:
            return list(self._trainers)
        return [self._trainer]

    def set_save_dir(self):
        """Create a fresh versioned run directory and propagate it to the config.

        The next unused ``v<N>`` sub-directory of ``<save_dir>/<project>`` is
        created, the save directory and project name are passed to the config
        of the model checkpoint callback and the logger, and the full
        configuration is dumped to ``config.yaml`` inside the run directory.

        Returns
        -------
        str or None
            Path of the created run directory, or ``None`` when no
            ``save_dir`` is configured (nothing is created in that case).
        """
        save_dir = self.config.get("save_dir")
        if save_dir is None:
            # Nothing to set up; previously this crashed while building paths.
            return None

        project = self.config["project"]
        project_dir = f"{save_dir}/{project}"

        # Pick the first version number that does not exist on disk yet.
        version = 0
        while pathlib.Path(f"{project_dir}/v{version}").exists():
            version += 1
        version = f"v{version}"
        run_dir = f"{project_dir}/{version}"

        for callback in self.config["trainer"]["classes"]:
            if (
                callback.get("class_path")
                == "lightning.pytorch.callbacks.ModelCheckpoint"
            ):
                callback.setdefault("init_args", {})["dirpath"] = run_dir
            elif callback.get("keyword") == "logger":
                init_args = callback.setdefault("init_args", {})
                init_args["save_dir"] = save_dir
                init_args["project"] = project
                init_args["name"] = version

        # parents=True so that a not-yet-existing save_dir is created as well.
        pathlib.Path(run_dir).mkdir(parents=True, exist_ok=True)
        with open(f"{run_dir}/config.yaml", "w", encoding="utf-8") as f:
            yaml.dump(self.config, f)
        return run_dir

    def add_parameters_to_experiment(self):
        """Copy the swept top-level config values into the logger experiment.

        For every key listed in ``config["sweep_parameters"]`` that is also a
        top-level config entry, the value is written to
        ``logger.experiment.config`` of each trainer (a single trainer, or all
        ensemble members). Trainers without a logger are skipped.

        Returns
        -------
        None
        """
        sweep_parameters = self.config.get("sweep_parameters")
        if not sweep_parameters:
            return
        for trainer in self._member_trainers():
            logger = trainer.logger
            if logger is None:
                continue
            for key in sweep_parameters:
                if key in self.config:
                    logger.experiment.config[key] = self.config[key]

    def base_config(self):
        """Reset ``self.config`` to the default configuration.

        Returns
        -------
        None
            The function assigns a new dictionary to ``self.config``.
        """
        self.config = {
            "ckpt_path": None,
            "find_lr": False,
            "seed_everything": None,
        }

    def set_sweep_parameters(self):
        """Set sweep parameters inside the nested configuration.

        ``config["sweep_parameters"]`` maps a top-level key to the nested key
        path at which the value of that top-level key is written.

        Returns
        -------
        None
            The function modifies the configuration dictionary in place.
        """
        sweep_parameters = self.config.get("sweep_parameters")
        if not sweep_parameters:
            return
        for key, path_list in sweep_parameters.items():
            if key in self.config:
                qumphy.misc.misc.set_value_at_nested_key(
                    self.config, path_list, self.config[key]
                )

    def find_lr(self):
        """Run the learning-rate finder and save its plot as ``lr_finder.png``.

        Does nothing unless ``config["find_lr"]`` is truthy.

        Returns
        -------
        None

        Raises
        ------
        NotImplementedError
            If an ensemble is configured; only the single-model setup has the
            ``self.model`` / ``self._trainer`` pair this method works on.
        """
        if not self.config.get("find_lr"):
            return
        if "ensemble" in self.config:
            raise NotImplementedError(
                "find_lr is only supported for single-model (non-ensemble) runs."
            )
        self.load_tuner()
        lr_finder = self.tuner.lr_find(
            self.model,
            datamodule=self.data_module,
        )
        fig = lr_finder.plot(suggest=True)
        # Not every logger exposes a usable ``log_dir``; fall back to the run
        # directory stored by ``load_trainer`` and finally to the trainer's
        # default root directory.
        logger = self._trainer.logger
        log_dir = (
            getattr(logger, "log_dir", None)
            or self._trainer.save_dir
            or self._trainer.default_root_dir
        )
        fig.savefig(f"{log_dir}/lr_finder.png")

    def load_tuner(self):
        """Load the Lightning tuner.

        Returns
        -------
        None
            The function creates a tuner if it does not already exist.
        """
        if hasattr(self, "tuner"):
            return
        self.tuner = L.pytorch.tuner.tuning.Tuner(self._trainer)

    def seed_everything(self):
        """Set random seeds for reproducibility.

        Returns
        -------
        None
            The function calls the Lightning seed utility when
            ``config["seed_everything"]`` is not ``None``.
        """
        seed = self.config.get("seed_everything")
        if seed is not None:
            L.pytorch.seed_everything(seed, workers=True)

    def load_trainer(self):
        """Instantiate the Lightning trainer(s) from ``config["trainer"]``.

        A new run directory is prepared via ``set_save_dir`` before each
        trainer is instantiated, and stored on the trainer as ``save_dir``.
        With an ensemble configuration one trainer per member is stored in
        ``self._trainers``; otherwise a single trainer is stored in
        ``self._trainer``.

        Returns
        -------
        None
        """
        if "ensemble" in self.config:
            self._trainers = []
            for _ in range(self.config["ensemble"]["size"]):
                # set_save_dir rewrites config["trainer"] in place, so it has
                # to run right before each instantiation.
                save_dir = self.set_save_dir()
                trainer = qumphy.misc.misc.instantiate_class(self.config["trainer"])
                trainer.save_dir = save_dir
                self._trainers.append(trainer)
        else:
            save_dir = self.set_save_dir()
            self._trainer = qumphy.misc.misc.instantiate_class(self.config["trainer"])
            self._trainer.save_dir = save_dir

    def load_data_module(self):
        """Load data module from config.

        Returns
        -------
        None
            ``self.data_module`` is set to the instantiated data module, or
            to ``None`` when the configuration has no ``"data"`` section.
        """
        if "data" not in self.config:
            # Define the attribute anyway so later accesses do not raise
            # AttributeError.
            self.data_module = None
            return
        self.data_module = qumphy.misc.misc.instantiate_class(self.config["data"])

    def load_model(self):
        """Instantiate the model(s) from ``config["model"]``.

        With an ensemble configuration one model per member is stored in
        ``self.models``; ``config["ckpt_path"]`` is not used here in that case
        (``fit`` passes the per-member entries to the trainers). Otherwise a
        single model is stored in ``self.model`` and, if ``config["ckpt_path"]``
        is set, its ``"state_dict"`` entry is loaded into the model.

        Returns
        -------
        None
        """
        if "ensemble" in self.config:
            self.models = [
                qumphy.misc.misc.instantiate_class(self.config["model"])
                for _ in range(self.config["ensemble"]["size"])
            ]
            return

        self.model = qumphy.misc.misc.instantiate_class(self.config["model"])
        ckpt_path = self.config["ckpt_path"]
        if ckpt_path is not None:
            checkpoint = torch.load(ckpt_path, weights_only=True)
            self.model.load_state_dict(checkpoint["state_dict"])

    def load_ensemble_model(self):
        """Load trained ensemble members into an ensemble model.

        Returns
        -------
        None
            The function loads the best checkpoint for each ensemble member and
            stores the combined ensemble model.

        Raises
        ------
        RuntimeError
            If a member trainer has no best checkpoint to load.
        """
        if hasattr(self, "ensemble_model"):
            return
        ensemble_config = copy.deepcopy(self.config["model"])
        qumphy.misc.misc.update_dictionary(ensemble_config, self.config["ensemble"])
        ensemble_model = qumphy.misc.misc.instantiate_class(ensemble_config)

        for index, (model, trainer) in enumerate(zip(self.models, self._trainers)):
            ckpt_path = self.best_model_path(trainer)
            if not ckpt_path:
                raise RuntimeError(
                    f"No best checkpoint available for ensemble member {index}."
                )
            checkpoint = torch.load(ckpt_path, weights_only=True)
            model.load_state_dict(checkpoint["state_dict"])
        ensemble_model.load_models(self.models)

        # Publish only the fully loaded model: a failure above must not leave
        # a half-initialised ``ensemble_model`` that the guard would reuse.
        self.ensemble_model = ensemble_model

    def fit(self):
        """Train the model or ensemble members.

        In ensemble mode ``config["ckpt_path"]``, when set, is indexed per
        member and passed to the corresponding trainer.

        Returns
        -------
        None
            The function starts the Lightning training loop.
        """
        if "ensemble" in self.config:
            ckpt_paths = self.config["ckpt_path"]
            for i in range(self.config["ensemble"]["size"]):
                ckpt_path = ckpt_paths[i] if ckpt_paths is not None else None
                self._trainers[i].fit(
                    model=self.models[i],
                    datamodule=self.data_module,
                    ckpt_path=ckpt_path,
                )
        else:
            self._trainer.fit(model=self.model, datamodule=self.data_module)

    def test(self):
        """Run the Lightning test loop.

        In ensemble mode the combined ensemble model is tested with a newly
        instantiated trainer; otherwise the single model is tested with the
        best checkpoint of its trainer, if one is available.

        Returns
        -------
        None
        """
        if "ensemble" in self.config:
            self.load_ensemble_model()
            trainer = qumphy.misc.misc.instantiate_class(self.config["trainer"])
            trainer.test(model=self.ensemble_model, datamodule=self.data_module)
        else:
            # An empty path means no checkpoint was saved; pass None then.
            ckpt_path = self.best_model_path(self._trainer) or None
            self._trainer.test(
                model=self.model, datamodule=self.data_module, ckpt_path=ckpt_path
            )

    def best_model_path(self, trainer):
        """Get the best checkpoint path from a trainer.

        Parameters
        ----------
        trainer : lightning.Trainer
            Lightning trainer whose callbacks are searched.

        Returns
        -------
        str or None
            Best model checkpoint path if a ModelCheckpoint callback is found,
            otherwise ``None``.
        """
        for callback in trainer.callbacks:
            if isinstance(callback, L.pytorch.callbacks.ModelCheckpoint):
                return callback.best_model_path
        return None

    def predict(self):
        """Predict using the best model and save the predictions to "predictions.pt".

        In ensemble mode the combined ensemble model is used with a newly
        instantiated trainer; otherwise the single trainer predicts on the
        test dataloader of the data module with ``ckpt_path="best"``. The
        concatenated predictions are saved into the ``dir`` of the logger
        experiment of the trainer that produced them.

        Returns
        -------
        None
        """
        if "ensemble" in self.config:
            self.load_ensemble_model()
            trainer = qumphy.misc.misc.instantiate_class(self.config["trainer"])
            predictions = trainer.predict(
                model=self.ensemble_model, datamodule=self.data_module
            )
        else:
            trainer = self._trainer
            predictions = trainer.predict(
                model=self.model,
                dataloaders=self.data_module.test_dataloader(),
                ckpt_path="best",
            )
        predictions = torch.cat(predictions)
        save_dir = pathlib.Path(trainer.logger.experiment.dir)
        # No leading slash: an absolute right-hand operand would discard
        # ``save_dir`` and write to the filesystem root.
        torch.save(predictions, save_dir / "predictions.pt")