import torch
import qumphy
import yaml
# Let PyTorch trade float32 matmul precision for speed; see the
# torch.set_float32_matmul_precision docs for exactly what "medium" permits.
# NOTE: this executes at import time and changes global torch state.
torch.set_float32_matmul_precision(precision="medium")
def load_config(args):
    """Build the run configuration from YAML files and command-line overrides.

    Every path in ``args.config`` is parsed with ``yaml.safe_load`` and merged,
    in order, into an initially empty dict via
    ``qumphy.misc.update_dictionary``. Afterwards each entry of
    ``args.parameters`` is merged on top the same way. Presumably later sources
    override earlier ones; that depends on ``update_dictionary`` (TODO confirm).

    Args:
        args: Parsed command-line arguments exposing ``config`` (an iterable
            of YAML file paths) and ``parameters`` (an iterable of override
            values accepted by ``update_dictionary``; presumably mappings
            produced by ``qumphy.misc.train_argument_parser``, verify).

    Returns:
        dict: The merged configuration.
    """
    config = {}
    for config_file in args.config:
        # Binary mode lets PyYAML detect UTF-8/UTF-16 from the BOM instead of
        # decoding with whatever the platform's locale encoding happens to be.
        with open(config_file, "rb") as stream:
            loaded = yaml.safe_load(stream)
        # An empty YAML document parses to None: nothing to merge.
        if loaded is not None:
            qumphy.misc.update_dictionary(config, loaded)
    for parameter in args.parameters:
        qumphy.misc.update_dictionary(config, parameter)
    return config
def train(args, config):
    """Create a qumphy Trainer for ``config`` and run the requested stages.

    The trainer is always constructed. Then, in this order, ``fit()`` runs if
    ``args.fit`` is truthy, ``test()`` if ``args.test`` is truthy, and
    ``predict()`` if ``args.predict`` is truthy.

    Args:
        args: Parsed command-line arguments exposing the boolean-like flags
            ``fit``, ``test`` and ``predict``.
        config: Configuration passed unchanged to
            ``qumphy.trainer.Trainer``, e.g. the result of ``load_config``.
    """
    trainer = qumphy.trainer.Trainer(config)
    if args.fit:
        trainer.fit()
    if args.test:
        trainer.test()
    if args.predict:
        trainer.predict()
def main():
    """Command-line entry point: parse arguments, load the config, and train."""
    args = qumphy.misc.train_argument_parser()
    config = load_config(args)
    train(args, config)
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()