Writing a Config File
QUMPHY trains models from a single YAML config that is parsed by `app/train.py` and turned into a Lightning training run by `qumphy.trainer.Trainer`. The config mostly describes what to build (model, data module, trainer, callbacks, …); what to run is selected by the `tasks:` section of the config (`fit`, `test`, `predict`).

The best starting points are the up-to-date templates under `app/configs/D2paper/`.
The core idea: class_path + init_args
Every component in a QUMPHY config is constructed from the same shape:
```yaml
class_path: dotted.path.to.SomeClass
init_args:
  some_kwarg: value
  another_kwarg: 42
```
At runtime, `qumphy.misc.misc.instantiate_class()` imports the class named by `class_path` and instantiates it with the keyword arguments in `init_args`. This mirrors the `class_path`/`init_args` convention of Lightning's CLI-style configs, and it applies uniformly to optimizers, loss functions, networks, callbacks, loggers, and the trainer itself.
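A stdlib-only sketch of that mechanism follows. It is not the code of `qumphy.misc.misc.instantiate_class()`, whose exact signature is not shown on this page; the helper name `instantiate` is hypothetical, and `datetime.timedelta` simply stands in for a QUMPHY component.

```python
import importlib
from typing import Any


def instantiate(spec: dict) -> Any:
    """Build an object from a {"class_path": ..., "init_args": {...}} mapping.

    Hypothetical helper illustrating the convention; not QUMPHY's instantiate_class().
    """
    # "datetime.timedelta" -> module "datetime", attribute "timedelta"
    module_name, _, class_name = spec["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    # init_args become keyword arguments; a missing or empty block means "no kwargs"
    return cls(**(spec.get("init_args") or {}))


if __name__ == "__main__":
    spec = {"class_path": "datetime.timedelta", "init_args": {"days": 1, "hours": 2}}
    print(instantiate(spec))  # 1 day, 2:00:00
```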
Nesting components via classes
Many objects take other objects as constructor arguments — a model
needs a network, a loss function, optionally an output activation; a
trainer needs callbacks and a logger. Those nested objects are listed
under a sibling classes: key:
```yaml
class_path: qumphy.models.pulsedb.PulseDBModule
init_args:
  optimizer:
    class_path: torch.optim.AdamW
    init_args:
      lr: 1.e-5
      weight_decay: 1.e-3
classes:
  - keyword: net
    class_path: qumphy.models.alexnet.AlexNet1D
    init_args:
      input_size: 1250
      output_size: 4
  - keyword: loss_fn
    class_path: qumphy.models.pulsedb.PulseDBGaussianLoss
    init_args:
      num_distributions: 2
```
Each entry under classes carries a keyword: field — that’s the
constructor argument name the instantiated object is passed to (so
net=AlexNet1D(...), loss_fn=PulseDBGaussianLoss(...)).
If the same keyword appears multiple times, the instances are
collected into a list. This is how the trainer ends up with several
callbacks (see the trainer example below).
init_args: {} is a valid empty value — use it when a class takes no
arguments.
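The nesting and list-collection rules can be sketched the same way: each `classes` entry is built recursively, passed under its `keyword`, and a repeated keyword becomes a list. This illustrates the rules stated above and is not QUMPHY's implementation; `build` is a hypothetical name, and stdlib classes (`types.SimpleNamespace`, `datetime.timedelta`, `fractions.Fraction`) stand in for a module, a network, and callbacks.

```python
import importlib
from collections import defaultdict
from typing import Any


def build(spec: dict) -> Any:
    """Recursively instantiate a spec that may carry nested `classes` entries.

    Hypothetical sketch of the class_path / init_args / classes convention.
    """
    module_name, _, class_name = spec["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    kwargs = dict(spec.get("init_args") or {})  # `init_args: {}` -> no kwargs

    # Group nested objects by the constructor argument they are passed to.
    grouped = defaultdict(list)
    for entry in spec.get("classes") or []:
        child = {k: v for k, v in entry.items() if k != "keyword"}
        grouped[entry["keyword"]].append(build(child))

    for keyword, objects in grouped.items():
        # One entry -> pass the object itself; a repeated keyword -> pass a list.
        kwargs[keyword] = objects[0] if len(objects) == 1 else objects
    return cls(**kwargs)


if __name__ == "__main__":
    spec = {
        "class_path": "types.SimpleNamespace",  # stand-in for a LightningModule / Trainer
        "init_args": {"name": "demo"},
        "classes": [
            {"keyword": "net", "class_path": "datetime.timedelta", "init_args": {"minutes": 5}},
            {"keyword": "callbacks", "class_path": "fractions.Fraction",
             "init_args": {"numerator": 1, "denominator": 2}},
            {"keyword": "callbacks", "class_path": "fractions.Fraction", "init_args": {}},
        ],
    }
    obj = build(spec)
    print(obj.net, obj.callbacks)  # 0:05:00 [Fraction(1, 2), Fraction(0, 1)]
```

`types.SimpleNamespace` is used as the outer class only because it accepts arbitrary keyword arguments, which makes the resulting attributes easy to inspect.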
Top-level structure
A complete training config has these top-level keys:
| Key | Required | What it is |
|---|---|---|
| `model` | yes | LightningModule definition (class + nested net, loss, etc.) |
| `data` | yes | LightningDataModule definition |
| `trainer` | yes | Lightning `Trainer` definition, including callbacks and logger |
| `ensemble` | no | Wraps the model + trainer in an N-member ensemble |
| `find_lr` | no | If `True`, runs Lightning's learning-rate finder |
| `ckpt_path` | no | Path to a checkpoint to resume from (or per-member list for ensembles) |
| | no | Integer seed; defaults to … |
| | no | Optional pre-trained feature extractor applied to the data |
| | no | Used by W&B sweeps to map top-level keys into nested locations |
tasks is consumed by the trainer to decide what to do (fit, test,
predict); CLI flags (--fit, --test, --predict) act as
additional toggles in app/train.py.
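Before a long run it can pay to check the loaded config against the required keys in the table. The sketch assumes the YAML has already been parsed into a plain `dict` (for example with PyYAML's `yaml.safe_load`); `check_top_level` is a hypothetical helper that checks only the required keys, the presence of a `class_path` in each, and a positive ensemble size.

```python
from typing import Any

# Required top-level keys, per the table above.
REQUIRED_KEYS = ("model", "data", "trainer")


def check_top_level(cfg: dict[str, Any]) -> list[str]:
    """Return a list of problems found in a loaded config (empty if none).

    Hypothetical pre-flight check; QUMPHY itself may validate differently.
    """
    problems = [f"missing required key: {key!r}" for key in REQUIRED_KEYS if key not in cfg]
    for key in REQUIRED_KEYS:
        # Each required block is itself a class_path / init_args component.
        if key in cfg and "class_path" not in (cfg[key] or {}):
            problems.append(f"{key!r} has no class_path")
    ensemble = cfg.get("ensemble")
    if ensemble is not None and int(ensemble.get("size", 0)) < 1:
        problems.append("ensemble.size must be a positive integer")
    return problems


if __name__ == "__main__":
    cfg = {"model": {"class_path": "x.Y"}, "data": {"class_path": "x.Z"}, "ensemble": {"size": 5}}
    print(check_top_level(cfg))  # ["missing required key: 'trainer'"]
```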
Walkthrough: a deep-ensemble regressor for PulseDB
The file `app/configs/D2paper/calib_alexnet.yaml` is a good reference. We'll go through it block by block.
Header
```yaml
find_lr: False
ckpt_path: null
```
find_lr runs Lightning’s learning-rate finder when True.
ckpt_path: null means start from scratch; provide a string path here
to resume.
model
```yaml
model:
  class_path: qumphy.models.pulsedb.PulseDBModule
  init_args:
    optimizer:
      class_path: torch.optim.AdamW
      init_args:
        lr: 1.e-5
        weight_decay: 1.e-3
  classes:
    - keyword: net
      class_path: qumphy.models.alexnet.AlexNet1D
      init_args:
        input_size: 1250
        output_size: 4
    - keyword: dataset
      class_path: qumphy.data.pulsedb.PulseDBDataModule
      init_args:
        data_directory: /gpu-scratch/pfeffe01/pulsedb
        batch_size: 32
        num_workers: 1
        dataset: calib
        source: vital
        load_data: False
    - keyword: loss_fn
      class_path: qumphy.models.pulsedb.PulseDBGaussianLoss
      init_args:
        num_distributions: 2
```
Notes:

- The optimizer (and an optional `lr_scheduler:` block) sit inside `init_args`, not under `classes`. A PyTorch optimizer needs the network's parameters at construction time, so its `class_path`/`init_args` spec is passed to the LightningModule as an ordinary keyword argument instead of being instantiated up front.
- `net.output_size: 4` matches `loss_fn.num_distributions: 2`: the Gaussian loss expects two parameters (`mu`, `sigma`) per output target, and PulseDB calibration has two targets (SBP, DBP), so 2 × 2 = 4. For pinball loss the output size is `len(quantiles) * num_targets` instead.
- The `keyword: dataset` entry inside `model` is the data module used for internal calibration / normalisation only; the actual training data module is configured separately under the top-level `data:` key (see below).
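The arithmetic behind `output_size: 4` can be checked with a small Gaussian negative log-likelihood. This is only a sketch: it assumes the four network outputs are laid out as `[mu_SBP, mu_DBP, sigma_SBP, sigma_DBP]` with `sigma` already positive, which may differ from how `PulseDBGaussianLoss` orders or parametrizes them (it could, for example, predict a log-variance). `gaussian_nll` is a hypothetical name.

```python
import math


def gaussian_nll(outputs: list[float], targets: list[float]) -> float:
    """Mean Gaussian NLL over len(targets) targets from 2 * len(targets) outputs.

    Assumed layout (illustrative only): first half = means, second half = std devs.
    """
    n = len(targets)
    if len(outputs) != 2 * n:
        raise ValueError(f"expected {2 * n} outputs for {n} targets, got {len(outputs)}")
    mus, sigmas = outputs[:n], outputs[n:]
    total = 0.0
    for y, mu, sigma in zip(targets, mus, sigmas):
        # -log N(y | mu, sigma^2) = 0.5*log(2*pi*sigma^2) + (y - mu)^2 / (2*sigma^2)
        total += 0.5 * math.log(2 * math.pi * sigma**2) + (y - mu) ** 2 / (2 * sigma**2)
    return total / n


if __name__ == "__main__":
    # Two targets (SBP, DBP) -> num_distributions = 2 -> 4 network outputs.
    print(round(gaussian_nll([120.0, 80.0, 5.0, 4.0], [125.0, 78.0]), 4))
```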
Adding a learning-rate scheduler
```yaml
init_args:
  optimizer: { ... }
  lr_scheduler:
    class_path: torch.optim.lr_scheduler.ReduceLROnPlateau
    init_args:
      mode: max
      factor: 0.5
      patience: 8
    config:
      monitor: val_auc
```
The extra `config:` block at the scheduler level holds the Lightning scheduler config (which metric to monitor, etc.), separate from the constructor arguments. See `app/configs/D2paper/deepbeat_alexnet.yaml` for a complete example.
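Lightning's `configure_optimizers()` accepts an `lr_scheduler` dictionary whose `scheduler` key holds the scheduler object and whose other keys (`monitor`, `interval`, `frequency`, etc.) are Lightning settings. A natural reading of the YAML is therefore "construct from `class_path`/`init_args`, then merge `config:`". The sketch below assumes exactly that; `build_lr_scheduler_entry` is a hypothetical helper, not QUMPHY's code, and the usage swaps in `builtins.dict` so it runs without PyTorch.

```python
import importlib
from typing import Any


def build_lr_scheduler_entry(spec: dict, optimizer: Any) -> dict:
    """Turn an lr_scheduler spec (class_path / init_args / config) into the dict
    Lightning accepts under "lr_scheduler" in configure_optimizers().

    Hypothetical helper for illustration; QUMPHY's own code may differ.
    """
    module_name, _, class_name = spec["class_path"].rpartition(".")
    scheduler_cls = getattr(importlib.import_module(module_name), class_name)
    # torch.optim.lr_scheduler classes take the optimizer as first positional argument.
    scheduler = scheduler_cls(optimizer, **(spec.get("init_args") or {}))
    # The `config:` block is not a constructor argument; it is merged in as-is.
    return {"scheduler": scheduler, **(spec.get("config") or {})}


if __name__ == "__main__":
    # With PyTorch, `optimizer` would be a torch.optim.AdamW instance and the spec
    # would name torch.optim.lr_scheduler.ReduceLROnPlateau. builtins.dict is a
    # stand-in so the sketch runs anywhere.
    spec = {"class_path": "builtins.dict", "init_args": {"mode": "max"}, "config": {"monitor": "val_auc"}}
    entry = build_lr_scheduler_entry(spec, {"lr": 1e-5})
    print(entry)  # {'scheduler': {'lr': 1e-05, 'mode': 'max'}, 'monitor': 'val_auc'}
```

`mode: max` matters in the YAML because `ReduceLROnPlateau` then lowers the learning rate (by `factor: 0.5`) only after the monitored `val_auc` has stopped increasing for `patience: 8` checks; with `mode: min` it would treat a rising AUC as a plateau.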
trainer
```yaml
trainer:
  class_path: lightning.Trainer
  init_args:
    accelerator: auto
    fast_dev_run: False
    max_epochs: 100
    overfit_batches: 0
    log_every_n_steps: 1
    precision: "32"
    default_root_dir: /…/logs/
  classes:
    - keyword: callbacks
      class_path: qumphy.callbacks.pulsedb.PulseDBLogging_Ensemble
      init_args:
        log_quantities: ["mae", "std"]
        log_pressure: "both"
        save_predictions: True
    - keyword: logger
      class_path: lightning.pytorch.loggers.WandbLogger
      init_args:
        save_dir: /…/logs/
        offline: True
        project: calib
        name: alexnet
    - keyword: callbacks
      class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping
      init_args:
        monitor: val_loss
        patience: 15
        mode: min
    - keyword: callbacks
      class_path: qumphy.callbacks.progressbar.EpochProgressBar
      init_args: {}
    - keyword: callbacks
      class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val_loss
        mode: min
        save_last: True
        save_top_k: 1
```
The four entries with keyword: callbacks are collected into a list
and passed as callbacks=[...] to lightning.Trainer. The single
logger: entry is passed as logger=WandbLogger(...). Replace it with
e.g. lightning.pytorch.loggers.TensorBoardLogger if you don’t use W&B.
fast_dev_run: True is the quickest way to smoke-test a new config —
one batch through fit/validate/test, no checkpoints.
data
```yaml
data:
  class_path: qumphy.data.pulsedb.PulseDBDataModule
  init_args:
    data_directory: /gpu-scratch/pfeffe01/pulsedb
    batch_size: 32
    num_workers: 14
    dataset: calib
    source: vital
    load_data: True
```
Notes:

- `data_directory` should point at the local copy of the dataset (PulseDB / DeepBeat). Change this to match your machine.
- For PulseDB, valid `dataset:` values include `calib`, `calibfree`, and `mini` (the small dev split).
- For DeepBeat, see `deepbeat_alexnet.yaml`: `dataset: set_revised`, plus a `target_format` field (`class_index` for cross-entropy-style targets, `one_hot` for the KG-loss variants).
- `num_workers` should be increased on production runs; the value inside the model's nested `dataset` block can stay low (it's only used for instantiation, not the actual loader).
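Because `data_directory` appears twice in the walkthrough config (under `data.init_args` and inside the model's nested `keyword: dataset` entry), it is easy to update only one of them. A small sketch that rewrites both in a loaded config dict; `set_data_directory` is a hypothetical helper and the path in the example is a placeholder.

```python
import copy
from typing import Any


def _point_at(block: dict, path: str) -> None:
    """Set init_args.data_directory on one class_path / init_args block."""
    init_args = block.get("init_args") or {}
    init_args["data_directory"] = path
    block["init_args"] = init_args


def set_data_directory(cfg: dict[str, Any], path: str) -> dict[str, Any]:
    """Return a copy of cfg with both data_directory entries set to `path`.

    Hypothetical helper: it touches the top-level `data` block and every
    `keyword: dataset` entry under model.classes, as in the walkthrough config.
    """
    new = copy.deepcopy(cfg)  # leave the caller's config untouched
    _point_at(new["data"], path)
    for entry in new.get("model", {}).get("classes") or []:
        if entry.get("keyword") == "dataset":
            _point_at(entry, path)
    return new


if __name__ == "__main__":
    cfg = {
        "model": {"classes": [{"keyword": "dataset", "init_args": {"data_directory": "/old"}}]},
        "data": {"init_args": {"data_directory": "/old", "num_workers": 14}},
    }
    updated = set_data_directory(cfg, "/data/pulsedb")  # placeholder path
    print(updated["data"]["init_args"]["data_directory"])  # /data/pulsedb
```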
ensemble
```yaml
ensemble:
  size: 5
  class_path: qumphy.models.pulsedb.PulseDBEnsemble
  init_args: {}
```
When `ensemble:` is present, the trainer creates `size` independent Lightning trainer + model pairs, fits each one, then wraps the fitted models in the class given by `class_path` for inference. Remove this block to train a single model.
For DeepBeat ensembles, the wrapper takes an extra `noise_samples:` argument; see `deepbeat_alexnet.yaml`.
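For a Gaussian deep ensemble, a common way to combine members at inference is to treat them as an equally weighted mixture and moment-match it to one Gaussian: the ensemble mean is the average of the member means, and the ensemble variance is the average of `sigma**2 + mu**2` minus the squared ensemble mean. Whether `PulseDBEnsemble` combines members exactly this way is not stated here, so treat this as a sketch of the standard technique for a single target; `combine_gaussian_ensemble` is a hypothetical name.

```python
import math
from statistics import fmean


def combine_gaussian_ensemble(mus: list[float], sigmas: list[float]) -> tuple[float, float]:
    """Moment-match an equally weighted mixture of N(mu_i, sigma_i^2) to one Gaussian.

    Returns (mean, std) for a single target. Illustrative only.
    """
    if len(mus) != len(sigmas) or not mus:
        raise ValueError("need one (mu, sigma) pair per ensemble member")
    mean = fmean(mus)
    # Law of total variance: E[sigma^2] (per-member noise) + Var[mu] (member disagreement)
    second_moment = fmean(s**2 + m**2 for m, s in zip(mus, sigmas))
    return mean, math.sqrt(max(second_moment - mean**2, 0.0))


if __name__ == "__main__":
    mean, std = combine_gaussian_ensemble([120.0, 124.0], [3.0, 3.0])
    print(mean, round(std, 4))  # 122.0 3.6056
```

The second term is what makes the ensemble useful for uncertainty: even if every member reports the same `sigma`, disagreement between member means widens the combined prediction.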
Uncertainty-quantification variants
Three approaches are represented in D2paper/, and each picks a
different combination of network + loss:
1. Deep ensemble + Gaussian NLL (default)
- Network: `qumphy.models.alexnet.AlexNet1D` or `qumphy.models.xresnet1d.XResNet1d50`
- Loss: `qumphy.models.pulsedb.PulseDBGaussianLoss` (`num_distributions: 2`) → predicts `mu` and `sigma`; combined with `ensemble.size: 5` this gives a deep ensemble.
- Output size: `2 × num_distributions = 4`.
Example: calib_alexnet.yaml.
2. Monte-Carlo dropout (MCD)
- Module: `qumphy.models.pulsedb.PulseDBModule_MCD` with `MCD_samples: 50` (number of stochastic forward passes at inference).
- Network: the `_MCD` variant (e.g. `AlexNet1D_MCD`) with `dropout_rate: 0.05` and `mcdropout: True`.
- Logging callback: `qumphy.callbacks.pulsedb.PulseDBLogging_MCD`.
- No `ensemble:` block; uncertainty comes from dropout sampling.
Example: calib_alexnet_MCD.yaml.
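MC dropout keeps dropout active at inference and summarizes `MCD_samples` stochastic forward passes by their mean and spread. The sketch below shows that summary step with a toy "network" (a dot product with inverted dropout on its inputs); `mc_dropout_predict` and `toy_forward` are hypothetical, and a real `_MCD` network does this with PyTorch dropout layers rather than Python lists.

```python
import random
from statistics import fmean, pstdev
from typing import Callable, Sequence


def toy_forward(x: Sequence[float], weights: Sequence[float], p: float, rng: random.Random) -> float:
    """One stochastic pass of a toy model: inverted dropout on the inputs, then a dot product.

    Hypothetical stand-in for an _MCD network; requires 0 <= p < 1.
    """
    keep = 1.0 - p
    # Each input survives with probability 1 - p and is rescaled by 1 / (1 - p).
    return sum(w * (xi / keep if rng.random() >= p else 0.0) for w, xi in zip(weights, x))


def mc_dropout_predict(forward: Callable[[], float], samples: int = 50) -> tuple[float, float]:
    """Run `samples` stochastic passes; return the mean and (population) std of the outputs."""
    outputs = [forward() for _ in range(samples)]
    return fmean(outputs), pstdev(outputs)


if __name__ == "__main__":
    rng = random.Random(0)
    x, w = [1.0, 2.0, 3.0], [0.5, -0.2, 0.1]
    # samples=50 mirrors MCD_samples: 50; p=0.05 mirrors dropout_rate: 0.05
    mean, std = mc_dropout_predict(lambda: toy_forward(x, w, p=0.05, rng=rng), samples=50)
    print(f"mean={mean:.3f} std={std:.3f}")
```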
3. Quantile / pinball loss
- Loss: `qumphy.models.utils.pinballloss.PinballLoss`
- Loss args: `quantiles: [0.0228, 0.1587, 0.5, 0.8413, 0.9772]` and `num_targets: 2`.
- Network output size = `len(quantiles) * num_targets` (here, `10`).
- Logging callback: `qumphy.callbacks.pulsedb.PulseDBLogging_Pinballloss`.
- No `ensemble:` block; a single model emits the full quantile vector.
Example: calib_alexnet_pinball.yaml.
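The five quantile levels are not arbitrary: to four decimals they are the standard normal CDF at −2, −1, 0, 1, and 2, i.e. the median plus ±1σ and ±2σ bands. The sketch below checks that and evaluates the standard pinball loss for one quantile. It is a plain-Python illustration of the textbook formula, not `qumphy.models.utils.pinballloss.PinballLoss` itself; the function names are hypothetical.

```python
import math

QUANTILES = [0.0228, 0.1587, 0.5, 0.8413, 0.9772]
NUM_TARGETS = 2  # SBP, DBP


def normal_cdf(z: float) -> float:
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def pinball(y: float, q: float, tau: float) -> float:
    """Pinball (quantile) loss of prediction q for quantile level tau."""
    r = y - q
    return max(tau * r, (tau - 1.0) * r)


if __name__ == "__main__":
    print([round(normal_cdf(z), 4) for z in (-2, -1, 0, 1, 2)])  # [0.0228, 0.1587, 0.5, 0.8413, 0.9772]
    print(len(QUANTILES) * NUM_TARGETS)  # 10, the network's output_size
    # Under-predicting the 0.9772 quantile by 10 mmHg costs far more than over-predicting it.
    print(pinball(130.0, 120.0, 0.9772), pinball(110.0, 120.0, 0.9772))
```

The asymmetry in the last line is what pushes each output toward its own quantile: for `tau = 0.9772`, predictions that fall below the target are penalized about 43 times more heavily than predictions above it.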
When you change loss family, three things must agree: the loss
class, the network’s output_size, and the logging callback under the
trainer. Mismatches surface as cryptic shape errors at the first
training step.
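Two of those three agreements can be checked before launching anything. The sketch below derives the expected `output_size` from the loss entry using the rules stated in this section (`2 × num_distributions` for `PulseDBGaussianLoss`, `len(quantiles) * num_targets` for `PinballLoss`) and lists the logging callbacks found under the trainer so you can eyeball the third. `expected_output_size` and `check_heads` are hypothetical helpers operating on a loaded config dict.

```python
from typing import Any, Callable, Optional

# Output-size rules stated in this section; other loss classes are not checked.
OUTPUT_SIZE_RULES: dict[str, Callable[[dict], int]] = {
    "qumphy.models.pulsedb.PulseDBGaussianLoss": lambda a: 2 * a["num_distributions"],
    "qumphy.models.utils.pinballloss.PinballLoss": lambda a: len(a["quantiles"]) * a["num_targets"],
}


def _entry(block: dict, keyword: str) -> Optional[dict]:
    """First `classes` entry of a component with the given keyword, or None."""
    return next((e for e in block.get("classes") or [] if e.get("keyword") == keyword), None)


def expected_output_size(loss_entry: dict) -> Optional[int]:
    """Expected net output_size for a loss entry, or None if no rule is known."""
    rule = OUTPUT_SIZE_RULES.get(loss_entry.get("class_path", ""))
    return rule(loss_entry.get("init_args") or {}) if rule else None


def check_heads(cfg: dict[str, Any]) -> list[str]:
    """Report an output_size mismatch plus the logging callbacks under trainer."""
    notes = []
    model = cfg.get("model", {})
    net, loss = _entry(model, "net"), _entry(model, "loss_fn")
    if net and loss:
        want = expected_output_size(loss)
        have = (net.get("init_args") or {}).get("output_size")
        if want is not None and want != have:
            notes.append(f"net.output_size={have} but {loss['class_path']} needs {want}")
    for e in cfg.get("trainer", {}).get("classes") or []:
        # Heuristic: the PulseDB logging callbacks in this section are named PulseDBLogging_*.
        if e.get("keyword") == "callbacks" and "Logging" in e.get("class_path", ""):
            notes.append(f"logging callback: {e['class_path']}")
    return notes
```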
Common edits
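| Goal | Change |
|---|---|
| Try a quick run | Set `fast_dev_run: True` under `trainer.init_args` |
| Use a different network | Swap the `class_path` (and `init_args`) of the `keyword: net` entry under `model` |
| Use a different dataset split | Change `dataset:` in `data.init_args` (and in the model's nested `keyword: dataset` entry) |
| Disable W&B | Remove the `keyword: logger` entry under `trainer`, or replace it with another logger |
| Stop earlier / later | Adjust the `max_epochs` value in `trainer.init_args` or the `EarlyStopping` `patience` |
| Train a single model instead of an ensemble | Delete the `ensemble:` block |
| Resume from a checkpoint | Set `ckpt_path:` to the checkpoint path |
| Reproducibility | Add an integer seed (see the top-level keys table) |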
CLI parameter overrides
train.py accepts -p key.subkey:value to overlay values onto the
loaded config without editing the YAML. The same syntax can be passed
multiple times:
```bash
python app/train.py --config app/configs/D2paper/calib_alexnet.yaml \
    -p trainer.init_args.max_epochs:5 \
    -p data.init_args.batch_size:16
```
Unknown --key:value flags are also accepted (see
qumphy.misc.misc.train_argument_parser()). This is the same
mechanism the W&B sweep agent uses.
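Conceptually, each `-p key.subkey:value` walks the dotted path into the loaded config and replaces the leaf. The sketch below illustrates that idea only: how `train.py` actually converts the value string (`5`, `True`, `1.e-5`) is not documented here, so the small literal parser is an assumption, as is creating missing intermediate keys; `apply_override` and `parse_value` are hypothetical names.

```python
import copy
from typing import Any


def parse_value(text: str) -> Any:
    """Assumed conversion: int, then float, then booleans / null, else the raw string."""
    for convert in (int, float):
        try:
            return convert(text)
        except ValueError:
            pass
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    if lowered in ("null", "none"):
        return None
    return text


def apply_override(cfg: dict, override: str) -> dict:
    """Return a copy of cfg with one 'a.b.c:value' override applied."""
    dotted, sep, raw = override.partition(":")  # split at the first ':' only
    if not sep:
        raise ValueError(f"expected key.subkey:value, got {override!r}")
    new = copy.deepcopy(cfg)
    *parents, leaf = dotted.split(".")
    node = new
    for key in parents:
        node = node.setdefault(key, {})  # assumption: create missing levels
    node[leaf] = parse_value(raw)
    return new


if __name__ == "__main__":
    cfg = {"trainer": {"init_args": {"max_epochs": 100}}, "data": {"init_args": {"batch_size": 32}}}
    for o in ("trainer.init_args.max_epochs:5", "data.init_args.batch_size:16"):
        cfg = apply_override(cfg, o)
    print(cfg)  # {'trainer': {'init_args': {'max_epochs': 5}}, 'data': {'init_args': {'batch_size': 16}}}
```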
Validating a config before a long run
The fastest sanity check is a single-batch fit:
```yaml
trainer:
  init_args:
    fast_dev_run: True
```
…or override from the CLI:
```bash
python app/train.py --config <file> -p trainer.init_args.fast_dev_run:True
```
If you use the GUI (see Usage), the Config (editable) tab lets you toggle that flag and hit Run without modifying the on-disk file.
Where to put your own configs
`app/configs/` is part of the repository, and most subdirectories (`D1paper/`, `D2paper/`, …) are version-controlled curated sets. For ad-hoc experiments use `app/configs/personal_configs/` or `app/configs/working_configs/`, which are ignored by git via `app/configs/.gitignore`.