"""
File: qumphy/models/lightning_callbacks.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: PyTorch Lightning callbacks.
"""
from lightning.pytorch.callbacks import Callback
import torch
import numpy as np
import pathlib
import yaml
class BaseLoggingCallback(Callback):
    """
    Base class for logging callbacks.

    The outputs of every batch are collected per stage in
    ``self.outputs[stage]``, where stage is "train", "val" or "test". Each
    entry maps an output key (e.g. "loss") to a list holding one value per
    batch of the current epoch. New logging callbacks can be written on top
    of this class by overriding the ``log_epoch_end`` method.
    """

    def __init__(self, save_predictions=False):
        """
        Args:
            save_predictions: If True, the predictions and targets returned
                by ``log_epoch_end`` are written to ``predictions.npy`` and
                ``targets.npy`` at the end of the test epoch.
        """
        super().__init__()
        self.save_predictions = save_predictions
        self.outputs = {"train": {}, "val": {}, "test": {}}

    def on_train_epoch_start(self, trainer, pl_module):
        """Reset the collected train outputs."""
        self._common_on_epoch_start(trainer, pl_module, "train")

    def on_validation_epoch_start(self, trainer, pl_module):
        """Reset the collected validation outputs."""
        self._common_on_epoch_start(trainer, pl_module, "val")

    def on_test_epoch_start(self, trainer, pl_module):
        """Reset the collected test outputs."""
        self._common_on_epoch_start(trainer, pl_module, "test")

    def _common_on_epoch_start(self, trainer, pl_module, stage):
        """Start the epoch of ``stage`` with an empty output store."""
        self.outputs[stage] = {}

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Collect the outputs of one training batch."""
        self._common_on_batch_end(trainer, pl_module, outputs, "train")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Collect the outputs of one validation batch."""
        self._common_on_batch_end(trainer, pl_module, outputs, "val")

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Collect the outputs of one test batch."""
        self._common_on_batch_end(trainer, pl_module, outputs, "test")

    def _common_on_batch_end(self, trainer, pl_module, batch_outputs, stage):
        """
        Append every entry of ``batch_outputs`` to ``self.outputs[stage]``.

        Args:
            batch_outputs: What the step method returned for this batch: a
                dict, a bare tensor (stored under the key "loss") or None
                (nothing is stored).
            stage: "train", "val" or "test".
        """
        if batch_outputs is None:
            # The step returned nothing, so there is nothing to collect.
            return
        if isinstance(batch_outputs, torch.Tensor):
            # Iterating a tensor would walk over its elements instead of
            # output keys, so a bare tensor is stored as the loss.
            batch_outputs = {"loss": batch_outputs}
        stage_outputs = self.outputs.setdefault(stage, {})
        for key, value in batch_outputs.items():
            stage_outputs.setdefault(key, []).append(value)

    def on_train_epoch_end(self, trainer, pl_module):
        """Log the train epoch and clear the collected train outputs."""
        self._common_on_epoch_end(trainer, pl_module, "train")

    def on_validation_epoch_end(self, trainer, pl_module):
        """Log the validation epoch and clear the collected outputs."""
        self._common_on_epoch_end(trainer, pl_module, "val")

    def on_test_epoch_end(self, trainer, pl_module):
        """
        Log the test epoch and write the results below ``trainer.save_dir``.

        The metrics returned by ``log_epoch_end`` plus the mean test loss are
        written to ``metrics.yaml``. If ``self.save_predictions`` is set, the
        returned predictions and targets are additionally written to
        ``predictions.npy`` and ``targets.npy``.

        Unlike the train and validation hooks, this hook does not clear
        ``self.outputs["test"]``; it is only reset by the next
        ``on_test_epoch_start``.
        """
        stage = "test"
        loss = self.log_loss(trainer, pl_module, stage)

        result = self.log_epoch_end(trainer, pl_module, stage)
        if result is None:
            # The base log_epoch_end returns None; without this guard the
            # unpacking below would raise a TypeError.
            target, prediction, logs = None, None, {}
        else:
            target, prediction, logs = result

        logs = {} if logs is None else dict(logs)
        if loss is not None:
            logs[f"{stage}_loss"] = loss
        # yaml should only see plain Python values, not tensors or arrays.
        logs = {key: self._to_python(value) for key, value in logs.items()}

        save_path = pathlib.Path(trainer.save_dir)
        save_path.mkdir(parents=True, exist_ok=True)
        if self.save_predictions:
            for name, values in (("predictions", prediction), ("targets", target)):
                if values is None:
                    continue
                if isinstance(values, torch.Tensor):
                    values = values.detach().cpu().numpy()
                np.save(save_path / f"{name}.npy", values)
        with open(save_path / "metrics.yaml", "w") as f:
            yaml.dump(logs, f)

    @staticmethod
    def _to_python(value):
        """Convert tensors, arrays and NumPy scalars to plain Python values."""
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu()
            return value.item() if value.numel() == 1 else value.tolist()
        if isinstance(value, np.ndarray):
            return value.item() if value.size == 1 else value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        return value

    def _common_on_epoch_end(self, trainer, pl_module, stage):
        """Log the loss and metrics of ``stage``, then drop its outputs."""
        self.log_loss(trainer, pl_module, stage)
        self.log_epoch_end(trainer, pl_module, stage)
        # Reset instead of deleting the key so that self.outputs always has
        # an entry for every stage, as set up in __init__.
        self.outputs[stage] = {}

    def log_loss(self, trainer, pl_module, stage):
        """
        Log the mean of the collected batch losses as ``"{stage}_loss"``.

        Returns:
            The mean loss, or None if no "loss" output was collected for
            ``stage`` (in which case nothing is logged).
        """
        losses = self.outputs.get(stage, {}).get("loss")
        if not losses:
            return None
        loss = torch.stack(losses).mean()
        pl_module.log(
            f"{stage}_loss",
            loss,
            sync_dist=False,
            prog_bar=True,
            on_epoch=True,
            on_step=False,
        )
        return loss

    def log_epoch_end(self, trainer, pl_module, stage):
        """
        Override this method to log the metrics at the end of the epoch.

        The collected outputs of the epoch are available in
        ``self.outputs[stage]``. For the "test" stage the return value is
        unpacked by ``on_test_epoch_end`` as ``(target, prediction, logs)``,
        where ``logs`` is a dict of metrics; for "train" and "val" the
        return value is ignored. Returning None is always allowed.
        """
        return None