Source code for app.train

"""
File: app/train.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Entry point for training, testing and predicting QUMPHY models.
"""

import torch
import qumphy
import yaml
from lightning.pytorch.callbacks import EarlyStopping
import optuna
import wandb

# Trade some float32 matmul precision for speed: with "medium", PyTorch may use
# lower-precision internal arithmetic (e.g. bfloat16/TF32) for float32 matrix
# multiplications where fast kernels are available. This runs at import time and
# is a process-wide setting, so it applies to all training started from here.
torch.set_float32_matmul_precision("medium")


def load_config(args):
    """Merge YAML config files and command-line overrides into a single dict.

    Reads each YAML file listed in ``args.config`` in order and merges it into
    the result with ``qumphy.misc.misc.update_dictionary`` (so later files
    override earlier ones), then applies the overrides in ``args.parameters``
    the same way.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments. Must expose ``config`` (iterable of
        YAML paths) and ``parameters`` (iterable of overrides accepted by
        ``update_dictionary``). A value of ``None`` for either is treated as
        an empty list.

    Returns
    -------
    dict
        Combined configuration dictionary.
    """
    config = {}
    # ``or []``: argparse leaves an omitted optional list argument as None.
    for config_file in args.config or []:
        # YAML documents are Unicode; don't depend on the platform default encoding.
        with open(config_file, "r", encoding="utf-8") as stream:
            # yaml.safe_load returns None for an empty file; merge nothing in that case.
            loaded = yaml.safe_load(stream)
        qumphy.misc.misc.update_dictionary(config, loaded if loaded is not None else {})
    for parameter in args.parameters or []:
        qumphy.misc.misc.update_dictionary(config, parameter)
    return config
def train(args, config):
    """Run the QUMPHY trainer for the tasks selected in ``config`` / ``args``.

    Builds a ``qumphy.trainer.Trainer`` from ``config``. It then runs ``fit``,
    ``test`` and ``predict`` in that order, each only if the task is listed
    in ``config["tasks"]`` or enabled by the matching CLI flag. A missing or
    empty ``"tasks"`` entry means no task is selected by the config, so the
    CLI flags alone decide what runs.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments providing the ``fit``, ``test`` and
        ``predict`` boolean flags.
    config : dict
        Combined training configuration.

    Returns
    -------
    qumphy.trainer.Trainer
        The trainer instance after the requested tasks have run.
    """
    trainer = qumphy.trainer.Trainer(config)
    # ``or ()`` also covers a YAML ``tasks:`` key with no value (loaded as None).
    tasks = config.get("tasks") or ()
    if "fit" in tasks or args.fit:
        trainer.fit()
    if "test" in tasks or args.test:
        trainer.test()
    if "predict" in tasks or args.predict:
        trainer.predict()
    return trainer
def objective(trial: optuna.trial.Trial, args, config: dict) -> float:
    """Optuna objective: sample hyperparameters, train, return early-stop score.

    For each entry in ``config["optuna"]["parameters"]``, looks up the
    ``trial`` method named by its ``"function"`` field (e.g. a ``suggest_*``
    method). It calls that method with the entry's ``"arguments"`` plus
    ``name=<parameter>``, and writes the sampled value into ``config`` before
    training.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial object provided by the Optuna study.
    args : argparse.Namespace
        Parsed command-line arguments forwarded to :func:`train`.
    config : dict
        Training configuration; mutated in-place with the sampled parameters.

    Returns
    -------
    float
        Best score recorded by the trainer's ``EarlyStopping`` callback.

    Raises
    ------
    RuntimeError
        If the underlying Lightning trainer has no ``EarlyStopping`` callback,
        so there is no score to report.
    """
    parameters = config["optuna"]["parameters"]

    # Suggest params
    print("Using following parameters for the trial:")
    for param, par_config in parameters.items():
        suggest_function = getattr(trial, par_config["function"])
        config[param] = suggest_function(**par_config["arguments"], name=param)
        print(f"{param} \t {config[param]}")

    try:
        trainer = train(args, config)
        earlystopping = next(
            (c for c in trainer._trainer.callbacks if isinstance(c, EarlyStopping)),
            None,
        )
        if earlystopping is None:
            raise RuntimeError(
                "Optuna objective needs an EarlyStopping callback on the trainer "
                "to read the best score from, but none was configured."
            )
        best_score = float(earlystopping.best_score.item())
        # Drop the trainer reference before closing the run (memory is reclaimed
        # by the study's gc_after_trial).
        del trainer
    finally:
        # Always close the W&B run, even when the trial fails or is pruned, so the
        # next trial does not log into a stale run.
        if wandb.run is not None:
            wandb.finish()
    return best_score
def run_optuna(args, config: dict):
    """Build an Optuna study from ``config["optuna"]`` and optimise :func:`objective`.

    The sampler class is looked up by name in ``optuna.samplers`` and built
    with ``sampler_arguments`` (default: none). The study uses the configured
    ``direction`` and runs for ``n_trials`` (default 20) trials or until
    ``timeout`` (default: none), with garbage collection after each trial.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments, passed through to :func:`objective`.
    config : dict
        Configuration whose ``"optuna"`` section provides ``sampler``,
        ``direction`` and optionally ``sampler_arguments``, ``n_trials`` and
        ``timeout``.

    Returns
    -------
    optuna.study.Study
        The finished study. The best trial number, value and parameters are
        printed before returning.
    """
    optuna_cfg = config.get("optuna", {})

    sampler_cls = getattr(optuna.samplers, optuna_cfg["sampler"])
    sampler = sampler_cls(**optuna_cfg.get("sampler_arguments", {}))

    study = optuna.create_study(direction=optuna_cfg["direction"], sampler=sampler)

    def _trial_objective(trial):
        return objective(trial, args, config)

    study.optimize(
        _trial_objective,
        n_trials=int(optuna_cfg.get("n_trials", 20)),
        timeout=optuna_cfg.get("timeout"),
        gc_after_trial=True,
    )

    best = study.best_trial
    print("Best trial:", best.number)
    print("Best value:", study.best_value)
    print("Best params:", study.best_params)
    return study
def main():
    """Command-line entry point.

    Parses the training arguments and builds the merged configuration. Runs an
    Optuna hyperparameter search when the config has an ``"optuna"`` section,
    and a single :func:`train` run otherwise.
    """
    args = qumphy.misc.misc.train_argument_parser()
    config = load_config(args)
    runner = run_optuna if "optuna" in config else train
    runner(args, config)


if __name__ == "__main__":
    main()