"""
File: qumphy/models/lightning_callbacks.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Pytorch lightning callbacks.
"""
import pathlib
import pickle
import warnings

import numpy as np
import torch

import qumphy
from qumphy.callbacks import base_logging
class DeepBeatLogging(base_logging.BaseLoggingCallback):
    """
    Pytorch lightning callback for logging for the DeepBeat dataset.

    At epoch end the collected targets and predictions are reduced to a
    single column and passed to ``qumphy.metrics.all_binary_metrics``.
    Every returned metric is then logged as ``f"{stage}_{name}"``.

    Args:
        target_format: Layout of targets/predictions. Supported values:

            * ``"binary"``: column 0 of both target and prediction is used.
            * ``"class_index"``: target is used as-is (cast to float) and
              column 1 of the prediction is used.
            * ``"class_probability"``: column 1 of both is used.
        *args, **kwargs: Forwarded to ``BaseLoggingCallback``.

    Raises:
        ValueError: If ``target_format`` is not one of the supported values.
    """

    # Layouts understood by ``log_epoch_end``.
    TARGET_FORMATS = ("binary", "class_index", "class_probability")

    def __init__(self, target_format="class_index", *args, **kwargs):
        # Fail fast: an unknown format would otherwise pass un-reduced
        # (multi-column) tensors to the binary metrics at epoch end.
        if target_format not in self.TARGET_FORMATS:
            raise ValueError(
                f"Unknown target_format {target_format!r}; "
                f"expected one of {self.TARGET_FORMATS}"
            )
        super().__init__(*args, **kwargs)
        self.target_format = target_format

    def log_epoch_end(self, trainer, pl_module, stage):
        """
        Compute and log binary-classification metrics for ``stage``.

        Args:
            trainer: The Lightning trainer. It is not used here and exists
                only to match the hook signature.
            pl_module: The LightningModule whose ``log`` receives the metrics.
            stage: Key into ``self.outputs`` (e.g. ``"test"``). It is also the
                prefix of the logged metric names.

        Returns:
            Tuple ``(target, prediction, metrics_dict)``:

            * ``target`` and ``prediction`` are the reduced, detached CPU
              tensors.
            * ``metrics_dict`` is empty if the metrics could not be computed.
        """
        stage_outputs = self.outputs[f"{stage}"]
        # Assumes each entry is a list of per-batch tensors (TODO confirm).
        target = torch.cat(stage_outputs["target"], dim=0).detach().cpu()
        prediction = torch.cat(stage_outputs["prediction"], dim=0).detach().cpu()

        # Reduce both tensors to the single column the binary metrics expect.
        if self.target_format == "binary":
            target = target[:, 0]
            prediction = prediction[:, 0]
        elif self.target_format == "class_index":
            prediction = prediction[:, 1]
            target = target.float()
        elif self.target_format == "class_probability":
            target = target[:, 1]
            prediction = prediction[:, 1]

        try:
            metrics_dict = qumphy.metrics.all_binary_metrics(
                target.numpy(), prediction.numpy()
            )
        except ValueError as err:
            # Best effort: do not abort the run. One expected cause is NaN
            # predictions. Report the actual error rather than assuming NaN.
            has_nan = bool(torch.isnan(prediction).any())
            warnings.warn(
                f"Could not compute {stage} binary metrics ({err}); "
                f"prediction contains NaN: {has_nan}"
            )
            metrics_dict = {}

        for key, value in metrics_dict.items():
            pl_module.log(
                f"{stage}_{key}",
                value,
                sync_dist=False,
                prog_bar=True,
                on_step=False,
                on_epoch=True,
            )
        return target, prediction, metrics_dict
class DeepBeatLogging_KGLoss(DeepBeatLogging):
    """
    DeepBeat logging callback that can also save test-set tensors.

    On test-epoch end it logs the loss and binary metrics. When
    ``self.save_predictions`` is truthy, it also saves prediction, target,
    uncertainty and per-model prediction tensors into the logger's
    experiment directory.
    """

    def on_test_epoch_end(self, trainer, pl_module):
        """
        Log test loss and metrics, and optionally save tensors.

        When ``self.save_predictions`` is truthy, the following files are
        written to ``pl_module.logger.experiment.dir``:

        * ``predictions.pt``
        * ``uncertainties.pt``
        * ``uncertainties_aleatoric.pt``
        * ``single_models_prediction.pt``
        * ``target.pt``
        """
        stage = "test"
        self.log_loss(trainer, pl_module, stage)
        # log_epoch_end returns (target, prediction, metrics_dict). The
        # metrics are already logged there, so discard them here.
        target, prediction, _ = self.log_epoch_end(trainer, pl_module, stage)

        if not self.save_predictions:
            return

        stage_outputs = self.outputs[stage]
        # Per-model predictions are joined along dim 1, not dim 0. Presumably
        # the batch axis is dim 1 for these tensors (confirm against the model).
        single_models_prediction = torch.cat(
            stage_outputs["single_models_prediction"], dim=1
        )
        uncertainties = torch.cat(stage_outputs["entropy_corrupted"], dim=0)
        uncertainties_aleatoric = torch.cat(stage_outputs["H2"], dim=0)

        save_dir = pathlib.Path(pl_module.logger.experiment.dir)
        torch.save(prediction, save_dir / "predictions.pt")
        torch.save(uncertainties, save_dir / "uncertainties.pt")
        torch.save(uncertainties_aleatoric, save_dir / "uncertainties_aleatoric.pt")
        torch.save(single_models_prediction, save_dir / "single_models_prediction.pt")
        torch.save(target, save_dir / "target.pt")
class DeepBeatLogging_MCD(DeepBeatLogging):
    """
    Pytorch lightning callback for logging for the DeepBeat dataset.

    The "MCD" suffix is presumably Monte Carlo dropout (confirm). On
    test-epoch end this callback:

    * logs the loss and the binary metrics, using ``log_epoch_end``;
    * pickles the target, the prediction and the uncertainty arrays into
      ``pl_module.logger.experiment.dir``.

    Constructor arguments are those of ``DeepBeatLogging``. Earlier
    documentation said to initialise it with a list of quantities to log
    (available: ``loss``, ``qumphy``). Presumably these are forwarded to the
    base callback; verify.
    """

    def on_test_epoch_end(self, trainer, pl_module):
        """
        Log test loss and metrics, then pickle the results.

        Files written to the experiment directory:

        * ``uncertainties_AF.pkl``
        * ``prediction_AF.pkl``
        * ``target_AF.pkl``
        * ``uncertainties_ale_AF.pkl``
        """
        stage = "test"
        self.log_loss(trainer, pl_module, stage)
        # Reuse the shared column reduction and metric logging rather than
        # duplicating it. That path also guards against metric ValueErrors.
        target, prediction, _ = self.log_epoch_end(trainer, pl_module, stage)

        stage_outputs = self.outputs[f"{stage}"]
        # np.concatenate accepts numpy arrays or CPU tensors. It assumes these
        # entries are not on a GPU (TODO confirm).
        uncertainties = np.concatenate(stage_outputs["entropy_corrupted"], axis=0)
        aleatoric = np.concatenate(stage_outputs["H2"], axis=0)

        directory = pathlib.Path(pl_module.logger.experiment.dir)
        to_pickle = (
            ("uncertainties_AF.pkl", uncertainties),
            ("prediction_AF.pkl", prediction),
            ("target_AF.pkl", target),
            ("uncertainties_ale_AF.pkl", aleatoric),
        )
        for filename, obj in to_pickle:
            with open(directory / filename, "wb") as file:
                pickle.dump(obj, file)
        # TODO: log the uncertainties themselves (e.g. to the experiment logger).