# Source code for qumphy.callbacks.deepbeat

"""
File: qumphy/callbacks/deepbeat.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: PyTorch Lightning logging callbacks for the DeepBeat dataset.
"""

import pathlib
import pickle
import warnings

import numpy as np
import torch

import qumphy
from qumphy.callbacks import base_logging


class DeepBeatLogging(base_logging.BaseLoggingCallback):
    """
    PyTorch Lightning callback that logs binary metrics for the DeepBeat dataset.

    The per-batch ``"target"`` and ``"prediction"`` tensors stored in
    ``self.outputs[stage]`` (presumably collected by ``BaseLoggingCallback``;
    TODO confirm) are concatenated. They are then reduced to one target/score
    pair according to ``target_format`` and passed to
    ``qumphy.metrics.all_binary_metrics``.

    Args:
        target_format: How targets and predictions are encoded.
            - ``"binary"``: use column 0 of both target and prediction.
            - ``"class_index"``: the target is used as-is (cast to float);
              the prediction score is column 1.
            - ``"class_probability"``: use column 1 of both.
            Any other value leaves both tensors unchanged.
        *args, **kwargs: Forwarded to ``BaseLoggingCallback``.
    """

    def __init__(self, target_format="class_index", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.target_format = target_format

    def log_epoch_end(self, trainer, pl_module, stage):
        """
        Compute binary metrics for ``stage`` and log them on ``pl_module``.

        Args:
            trainer: The PyTorch Lightning trainer (unused here; kept for a
                uniform hook-helper signature).
            pl_module: The LightningModule whose ``log`` method receives each
                metric as ``f"{stage}_{metric_name}"``.
            stage: Key into ``self.outputs``, e.g. ``"test"``.

        Returns:
            tuple: ``(target, prediction, metrics_dict)``.
            - ``target`` and ``prediction`` are the CPU tensors after column
              selection.
            - ``metrics_dict`` is the dict returned by
              ``qumphy.metrics.all_binary_metrics``, or ``{}`` if that call
              raised ``ValueError``.
        """
        outputs = self.outputs[stage]
        target = torch.cat(outputs["target"], dim=0).cpu()
        prediction = torch.cat(outputs["prediction"], dim=0).detach().cpu()

        if self.target_format == "binary":
            target = target[:, 0]
            prediction = prediction[:, 0]
        elif self.target_format == "class_index":
            # Targets are class indices; column 1 is the positive-class score.
            prediction = prediction[:, 1]
            target = target.float()
        elif self.target_format == "class_probability":
            target = target[:, 1]
            prediction = prediction[:, 1]
        # Any other target_format: tensors are passed through unchanged.

        try:
            metrics_dict = qumphy.metrics.all_binary_metrics(
                target.numpy(), prediction.numpy()
            )
        except ValueError as err:
            # Metric computation fails e.g. when the model emits NaN scores.
            # Warn and skip metric logging instead of aborting the run.
            n_nan = int(torch.isnan(prediction).sum())
            warnings.warn(
                f"Could not compute {stage} metrics ({err}); prediction "
                f"contains {n_nan} NaN value(s).",
                RuntimeWarning,
            )
            metrics_dict = {}

        for key, value in metrics_dict.items():
            pl_module.log(
                f"{stage}_{key}",
                value,
                sync_dist=False,
                prog_bar=True,
                on_step=False,
                on_epoch=True,
            )
        return target, prediction, metrics_dict
class DeepBeatLogging_KGLoss(DeepBeatLogging):
    """
    DeepBeat logging callback that also saves test-set outputs as ``.pt`` files.

    This adds to the metric logging inherited from ``DeepBeatLogging``. At
    the end of the test epoch it reads three more entries from
    ``self.outputs["test"]``: ``"single_models_prediction"``,
    ``"entropy_corrupted"`` and ``"H2"`` (presumably filled by the
    LightningModule's test step; TODO confirm). When
    ``self.save_predictions`` is truthy, it writes them to the experiment
    directory together with the targets and predictions.
    """

    def on_test_epoch_end(self, trainer, pl_module):
        """
        Log the test loss and metrics, then optionally save test outputs.

        Args:
            trainer: The PyTorch Lightning trainer, forwarded to the logging
                helpers.
            pl_module: The LightningModule under test. Its
                ``logger.experiment.dir`` is used as the output directory.
        """
        stage = "test"
        self.log_loss(trainer, pl_module, stage)
        # log_epoch_end returns (target, prediction, metrics_dict). The
        # metrics are already logged there, so only the tensors are kept.
        target, prediction, _ = self.log_epoch_end(trainer, pl_module, stage)

        # NOTE(review): ``save_predictions`` is not set in this file;
        # presumably it is provided by BaseLoggingCallback. TODO confirm.
        if not self.save_predictions:
            return

        outputs = self.outputs[stage]
        # Unlike the other outputs, per-model predictions are joined along
        # dim 1. Presumably dim 0 indexes the ensemble members; TODO confirm.
        single_models_prediction = torch.cat(
            outputs["single_models_prediction"], dim=1
        )
        uncertainties = torch.cat(outputs["entropy_corrupted"], dim=0)
        uncertainties_aleatoric = torch.cat(outputs["H2"], dim=0)

        save_dir = pathlib.Path(pl_module.logger.experiment.dir)
        # Move to CPU so the saved files can be loaded without a GPU.
        # target/prediction are already on the CPU (see log_epoch_end).
        tensors_to_save = {
            "predictions.pt": prediction,
            "uncertainties.pt": uncertainties.detach().cpu(),
            "uncertainties_aleatoric.pt": uncertainties_aleatoric.detach().cpu(),
            "single_models_prediction.pt": single_models_prediction.detach().cpu(),
            "target.pt": target,
        }
        for filename, tensor in tensors_to_save.items():
            torch.save(tensor, save_dir / filename)
class DeepBeatLogging_MCD(DeepBeatLogging):
    """
    PyTorch Lightning callback for logging for the DeepBeat dataset that also
    pickles test-set predictions and uncertainties.

    Initialize with a list of quantities to log. Available quantities:
    - loss
    - qumphy

    NOTE(review): the visible ``DeepBeatLogging.__init__`` only takes
    ``target_format`` and forwards everything else to ``BaseLoggingCallback``.
    The quantity list is presumably handled there; TODO confirm.
    """

    def on_test_epoch_end(self, trainer, pl_module):
        """
        Log the test loss and metrics, then pickle predictions and uncertainties.

        Always writes the following files into
        ``pl_module.logger.experiment.dir``:
        ``uncertainties_AF.pkl``, ``prediction_AF.pkl``, ``target_AF.pkl``
        and ``uncertainties_ale_AF.pkl``.

        Args:
            trainer: The PyTorch Lightning trainer, forwarded to the logging
                helpers.
            pl_module: The LightningModule under test.
        """
        stage = "test"
        self.log_loss(trainer, pl_module, stage)
        # Delegate target/prediction selection and metric logging to the
        # parent so the two stay in sync. This includes the parent's
        # handling of metric failures such as NaN predictions.
        target, prediction, _ = self.log_epoch_end(trainer, pl_module, stage)

        outputs = self.outputs[stage]
        # Build every array before writing, so a missing key cannot leave a
        # partial set of files behind.
        uncertainties = np.concatenate(outputs["entropy_corrupted"], axis=0)
        aleatoric = np.concatenate(outputs["H2"], axis=0)

        directory = pathlib.Path(pl_module.logger.experiment.dir)
        objects_to_pickle = {
            "uncertainties_AF.pkl": uncertainties,
            "prediction_AF.pkl": prediction,
            "target_AF.pkl": target,
            "uncertainties_ale_AF.pkl": aleatoric,
        }
        for filename, obj in objects_to_pickle.items():
            with open(directory / filename, "wb") as file:
                pickle.dump(obj, file)

        # TODO: log uncertainties somehow (e.g. to the experiment logger).