"""Source code for app.train: command-line entry point for training."""

import torch
import qumphy
import yaml

# Let PyTorch trade float32 matmul precision for speed. Per the PyTorch docs,
# "medium" allows float32 matmuls to use lower-precision internal computation
# (e.g. bfloat16) when a fast kernel for it is available. This is a
# process-wide setting applied as soon as this module is imported.
torch.set_float32_matmul_precision("medium")


def load_config(args):
    """Build the run configuration from YAML files and command-line overrides.

    Each path in ``args.config`` is parsed with ``yaml.safe_load`` and merged
    into an initially empty dict with ``qumphy.misc.update_dictionary``, in
    the order given. Afterwards every entry of ``args.parameters`` is merged
    the same way. Whether later merges override earlier keys depends on
    ``update_dictionary`` (presumably yes — TODO confirm).

    Args:
        args: Parsed command-line namespace providing ``config`` (iterable of
            YAML file paths) and ``parameters`` (iterable of override values
            accepted by ``update_dictionary`` — presumably dicts; verify
            against ``qumphy.misc.train_argument_parser``).

    Returns:
        dict: The merged configuration.
    """
    config = {}
    for config_file in args.config:
        # Decode explicitly as UTF-8 so parsing does not depend on the
        # platform's locale-default encoding.
        with open(config_file, "r", encoding="utf-8") as stream:
            loaded = yaml.safe_load(stream)
        # safe_load returns None for an empty (or comments-only) document;
        # skip it rather than passing None to the merge helper.
        if loaded is None:
            continue
        qumphy.misc.update_dictionary(config, loaded)
    for parameter in args.parameters:
        qumphy.misc.update_dictionary(config, parameter)
    return config
def train(args, config):
    """Create a trainer from ``config`` and run the stages enabled in ``args``.

    Stages are considered in the fixed order fit, test, predict; each one is
    run only when the attribute of ``args`` with the same name is truthy.

    Args:
        args: Parsed command-line namespace with ``fit``, ``test`` and
            ``predict`` attributes.
        config: Configuration passed unchanged to ``qumphy.trainer.Trainer``.
    """
    runner = qumphy.trainer.Trainer(config)
    # The flag name on ``args`` matches the method name on the trainer.
    for stage_name in ("fit", "test", "predict"):
        if getattr(args, stage_name):
            getattr(runner, stage_name)()
def main():
    """Command-line entry point: read arguments, load the config, run training."""
    cli_args = qumphy.misc.train_argument_parser()
    train(cli_args, load_config(cli_args))


if __name__ == "__main__":
    main()