Source code for qumphy.callbacks.base_logging

"""
File: qumphy/callbacks/base_logging.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Base class for PyTorch Lightning logging callbacks.
"""

from lightning.pytorch.callbacks import Callback
import torch
import numpy as np
import pathlib
import yaml


class BaseLoggingCallback(Callback):
    """Base class for logging callbacks.

    For each stage ("train", "val" or "test") the outputs returned by the
    LightningModule's step methods are collected batch by batch in
    ``self.outputs[stage]``: a dict mapping every output key (e.g. ``"loss"``)
    to the list of that key's per-batch values. The dict is reset at the start
    of every epoch of that stage. At the end of an epoch the mean of the
    collected ``"loss"`` values is logged as ``f"{stage}_loss"`` and
    :meth:`log_epoch_end` is called. For "train" and "val" the collected
    outputs are discarded afterwards; the "test" outputs are kept.

    New logging callbacks are written on top of this class by overriding
    :meth:`log_epoch_end`, which can read the epoch's collected outputs (for
    example predictions and targets, if the step methods return them) from
    ``self.outputs[stage]``. At the end of testing an override is expected to
    return a ``(target, prediction, logs)`` tuple. ``logs``, together with the
    test loss, is written to ``metrics.yaml`` in ``trainer.save_dir``.

    Args:
        save_predictions: If True, the ``prediction`` and ``target`` returned
            by :meth:`log_epoch_end` at the end of testing are also saved as
            ``predictions.npy`` and ``targets.npy`` in ``trainer.save_dir``.
    """

    def __init__(self, save_predictions=False):
        super().__init__()
        self.save_predictions = save_predictions
        # stage -> {output key -> list of per-batch values}
        self.outputs = {"train": {}, "val": {}, "test": {}}

    # ------------------------------------------------------------------
    # Epoch start: reset the per-stage output buffer.
    # ------------------------------------------------------------------

    def on_train_epoch_start(self, trainer, pl_module):
        stage = "train"
        self._common_on_epoch_start(trainer, pl_module, stage)

    def on_validation_epoch_start(self, trainer, pl_module):
        stage = "val"
        self._common_on_epoch_start(trainer, pl_module, stage)

    def on_test_epoch_start(self, trainer, pl_module):
        stage = "test"
        self._common_on_epoch_start(trainer, pl_module, stage)

    def _common_on_epoch_start(self, trainer, pl_module, stage):
        """Start a fresh, empty output buffer for ``stage``."""
        self.outputs[stage] = {}

    # ------------------------------------------------------------------
    # Batch end: collect the step outputs.
    # ------------------------------------------------------------------

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        stage = "train"
        self._common_on_batch_end(trainer, pl_module, outputs, stage)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        stage = "val"
        self._common_on_batch_end(trainer, pl_module, outputs, stage)

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        stage = "test"
        self._common_on_batch_end(trainer, pl_module, outputs, stage)

    def _common_on_batch_end(self, trainer, pl_module, batch_outputs, stage):
        """Append each value of ``batch_outputs`` to its list in ``self.outputs[stage]``.

        ``batch_outputs`` is the ``outputs`` argument of the ``on_*_batch_end``
        hook: normally a dict, but a step method may also return a bare loss
        tensor or None.
        """
        if batch_outputs is None:
            # The step method skipped this batch; there is nothing to collect.
            return
        if isinstance(batch_outputs, torch.Tensor):
            # A bare tensor returned from a step method is the loss.
            batch_outputs = {"loss": batch_outputs}
        stage_outputs = self.outputs[stage]
        for key, value in batch_outputs.items():
            stage_outputs.setdefault(key, []).append(value)

    # ------------------------------------------------------------------
    # Epoch end: log loss / metrics.
    # ------------------------------------------------------------------

    def on_train_epoch_end(self, trainer, pl_module):
        stage = "train"
        self._common_on_epoch_end(trainer, pl_module, stage)

    def on_validation_epoch_end(self, trainer, pl_module):
        stage = "val"
        self._common_on_epoch_end(trainer, pl_module, stage)

    def on_test_epoch_end(self, trainer, pl_module):
        """Log the test loss and write the test metrics to ``metrics.yaml``.

        The ``(target, prediction, logs)`` tuple returned by
        :meth:`log_epoch_end` is used here. If it returns None, no predictions
        or targets are saved and only the test loss is written.
        """
        stage = "test"
        loss = self.log_loss(trainer, pl_module, stage)
        result = self.log_epoch_end(trainer, pl_module, stage)
        if result is None:
            # The base log_epoch_end (or an override with nothing to report).
            target, prediction, logs = None, None, {}
        else:
            target, prediction, logs = result
        logs.update({f"{stage}_loss": loss})
        # Convert tensors and numpy arrays/scalars to plain Python values so
        # that metrics.yaml contains plain numbers rather than python-object tags.
        for key, value in logs.items():
            logs[key] = self._to_builtin(value)

        save_path = pathlib.Path(trainer.save_dir)
        save_path.mkdir(parents=True, exist_ok=True)
        if self.save_predictions:
            if prediction is not None:
                np.save(
                    save_path / "predictions.npy",
                    prediction.detach().cpu().numpy(),
                )
            if target is not None:
                np.save(
                    save_path / "targets.npy",
                    target.detach().cpu().numpy(),
                )
        with open(save_path / "metrics.yaml", "w", encoding="utf-8") as f:
            yaml.dump(logs, f)

    def _common_on_epoch_end(self, trainer, pl_module, stage):
        """Log loss and metrics for ``stage``, then drop its collected outputs."""
        self.log_loss(trainer, pl_module, stage)
        self.log_epoch_end(trainer, pl_module, stage)
        del self.outputs[stage]

    def log_loss(self, trainer, pl_module, stage):
        """Log the epoch mean of the collected ``"loss"`` values as ``f"{stage}_loss"``.

        Returns:
            The mean loss as a tensor, or None if no batch of this stage
            produced a ``"loss"`` output. In that case nothing is logged.
        """
        losses = self.outputs[stage].get("loss")
        if not losses:
            return None
        loss = torch.stack(losses).mean()
        pl_module.log(
            f"{stage}_loss",
            loss,
            sync_dist=False,
            prog_bar=True,
            on_epoch=True,
            on_step=False,
        )
        return loss

    def log_epoch_end(self, trainer, pl_module, stage):
        """Override this method to log the metrics at the end of the epoch.

        The epoch's collected outputs are available in ``self.outputs[stage]``.
        For ``stage == "test"`` the return value is consumed by
        :meth:`on_test_epoch_end`. It should be a ``(target, prediction, logs)``
        tuple, or None if there is nothing to save. The base implementation
        returns None.
        """
        return None

    @staticmethod
    def _to_builtin(value):
        """Return ``value`` as a plain Python object if it is a tensor or numpy value.

        Single-element tensors/arrays and numpy scalars become Python scalars
        (via ``.item()``). Larger tensors/arrays become nested lists (via
        ``.tolist()``). Any other value is returned unchanged.
        """
        if isinstance(value, torch.Tensor):
            return value.item() if value.numel() == 1 else value.tolist()
        if isinstance(value, (np.ndarray, np.generic)):
            return value.item() if value.size == 1 else value.tolist()
        return value