"""
File: app/train.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Entry point for training, testing and predicting QUMPHY models.
"""
import torch
import qumphy
import yaml
from lightning.pytorch.callbacks import EarlyStopping
import optuna
import wandb
# Process-wide PyTorch setting applied at import time: per the PyTorch docs,
# "medium" allows float32 matrix multiplications to use faster, lower-precision
# internal arithmetic on hardware that supports it.
torch.set_float32_matmul_precision("medium")
[docs]
def load_config(args):
    """Merge YAML config files and command-line overrides into a single dict.

    Reads each YAML file listed in ``args.config`` in order (later files
    override earlier ones) and then applies the overrides in
    ``args.parameters``. All merging is delegated to
    ``qumphy.misc.misc.update_dictionary``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments. Must expose ``config`` (list of YAML
        paths) and ``parameters`` (iterable of overrides accepted by
        ``update_dictionary``). Either may be ``None`` or empty, in which case
        it contributes nothing.

    Returns
    -------
    dict
        Combined configuration dictionary (empty if nothing was supplied).
    """
    config = {}
    # ``or ()`` tolerates argparse leaving an unset multi-value option as None.
    for config_file in args.config or ():
        with open(config_file, "r", encoding="utf-8") as stream:
            loaded = yaml.safe_load(stream)
        # An empty YAML file loads as None; there is nothing to merge then.
        if loaded is not None:
            qumphy.misc.misc.update_dictionary(config, loaded)
    for parameter in args.parameters or ():
        qumphy.misc.misc.update_dictionary(config, parameter)
    return config
[docs]
def train(args, config):
    """Run the QUMPHY trainer for the tasks selected in ``config`` / ``args``.

    Executes ``fit``, ``test`` and/or ``predict`` (in that order) depending on
    whether each task is listed in ``config["tasks"]`` or enabled by the
    matching CLI flag. A missing or null ``tasks`` entry is treated as "no
    tasks requested by the config", so CLI flags alone are sufficient.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments providing the ``fit``, ``test`` and
        ``predict`` boolean flags.
    config : dict
        Combined training configuration.

    Returns
    -------
    qumphy.trainer.Trainer
        The trainer instance after the requested tasks have run.
    """
    trainer = qumphy.trainer.Trainer(config)
    # Each task name doubles as the CLI flag attribute on ``args`` and as the
    # trainer method to invoke.
    for task in ("fit", "test", "predict"):
        requested_tasks = config.get("tasks") or ()
        if task in requested_tasks or getattr(args, task):
            getattr(trainer, task)()
    return trainer
[docs]
def objective(trial: optuna.trial.Trial, args, config: dict) -> float:
    """Optuna objective: sample hyperparameters, train, return early-stop score.

    For each parameter listed in ``config["optuna"]["parameters"]``, the
    matching Optuna ``trial.suggest_*`` function (named by its ``"function"``
    entry) is called with its ``"arguments"`` and the sampled value is written
    back into ``config`` before training.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial object provided by the Optuna study.
    args : argparse.Namespace
        Parsed command-line arguments forwarded to :func:`train`.
    config : dict
        Training configuration; mutated in-place with the sampled parameters.

    Returns
    -------
    float
        Best score recorded by the trainer's early-stopping callback.

    Raises
    ------
    RuntimeError
        If the trainer has no ``EarlyStopping`` callback to read a score from.
    """
    parameters = config["optuna"]["parameters"]
    # Suggest params
    print("Using following parameters for the trial:")
    for param, par_config in parameters.items():
        suggest_function = getattr(trial, par_config["function"])
        config[param] = suggest_function(**par_config["arguments"], name=param)
        print(f"{param} \t {config[param]}")
    try:
        trainer = train(args, config)
        earlystopping = next(
            (
                c
                for c in trainer._trainer.callbacks
                if isinstance(c, EarlyStopping)
            ),
            None,
        )
        if earlystopping is None:
            raise RuntimeError(
                "Optuna objective needs an EarlyStopping callback on the "
                "trainer to read the best score from, but none was found."
            )
        best_score = float(earlystopping.best_score.item())
        del trainer
    finally:
        # Close the W&B run even when training raises (e.g. optuna.TrialPruned,
        # which Optuna catches before moving on), so the next trial does not
        # log into a stale run.
        if wandb.run is not None:
            wandb.finish()
    return best_score
[docs]
def run_optuna(args, config: dict):
    """Create and run an Optuna study around :func:`objective`.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments forwarded to :func:`objective`.
    config : dict
        Configuration with an ``"optuna"`` section providing ``sampler`` (name
        of a class in ``optuna.samplers``), ``direction``, and the optional
        ``sampler_arguments`` (default: none), ``n_trials`` (default: 20;
        ``null`` means no trial limit) and ``timeout`` (default: none).

    Returns
    -------
    optuna.study.Study
        The completed study. Best trial, value and parameters are also printed
        when at least one trial completed.
    """
    opt_config = config.get("optuna", {})
    # An empty ``sampler_arguments:`` entry in YAML loads as None.
    sampler_arguments = opt_config.get("sampler_arguments") or {}
    sampler_cls = getattr(optuna.samplers, opt_config["sampler"])
    study = optuna.create_study(
        direction=opt_config["direction"],
        sampler=sampler_cls(**sampler_arguments),
    )
    n_trials = opt_config.get("n_trials", 20)
    study.optimize(
        lambda trial: objective(trial, args, config),
        n_trials=None if n_trials is None else int(n_trials),
        timeout=opt_config.get("timeout", None),
        gc_after_trial=True,
    )
    try:
        best_trial = study.best_trial
    except ValueError:
        # Optuna raises ValueError when no trial completed (e.g. all pruned).
        print("No completed trials; no best trial to report.")
        return study
    print("Best trial:", best_trial.number)
    print("Best value:", best_trial.value)
    print("Best params:", best_trial.params)
    return study
[docs]
def main():
    """Command-line entry point.

    Parses the training CLI, builds the merged configuration, and then either
    launches an Optuna hyperparameter search (when the configuration has an
    ``"optuna"`` section) or performs a single plain training run.
    """
    cli_args = qumphy.misc.misc.train_argument_parser()
    merged_config = load_config(cli_args)
    # Both runners share the same (args, config) call signature.
    runner = run_optuna if "optuna" in merged_config else train
    runner(cli_args, merged_config)


if __name__ == "__main__":
    main()