# Source code for qumphy.callbacks.pulsedb

"""
File: qumphy/callbacks/pulsedb.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: PyTorch Lightning callbacks for logging and saving PulseDB blood-pressure predictions.
"""

import torch
import torchmetrics
from qumphy.callbacks.base_logging import BaseLoggingCallback
import qumphy
import pickle
import pathlib


class PulseDBLogging_Pinballloss(BaseLoggingCallback):
    """Test-time callback for models trained with a pinball (quantile) loss.

    At the end of the test epoch the collected predictions are reshaped into
    two groups of quantile columns, rescaled with the module's ``BP_mean`` /
    ``BP_std`` and, when ``save_predictions`` is set, written to the logger's
    run directory together with the collected targets.
    """

    def on_test_epoch_end(self, trainer, pl_module):
        """Rescale the quantile predictions and optionally save them.

        The flat last dimension of the prediction is split into
        ``(2, len(pl_module.loss_fn.quantiles))``. Slot ``i`` of the new
        size-2 axis is rescaled as ``x * BP_std[i] + BP_mean[i]``.

        Files written (only if ``self.save_predictions``): ``target.pt`` and
        ``prediction.pt`` in ``pl_module.logger.experiment.dir``.
        """
        collected = self.outputs["test"]
        target = torch.cat(collected["target"], dim=0).cpu()
        prediction = torch.cat(collected["prediction"], dim=0).detach().cpu()

        n_quantiles = len(pl_module.loss_fn.quantiles)
        prediction = prediction.reshape([*prediction.shape[:-1], 2, n_quantiles])

        # Match dtype/device of the prediction before the affine rescaling.
        mean = pl_module.BP_mean.to(prediction)
        std = pl_module.BP_std.to(prediction)
        for slot in (0, 1):
            prediction[..., slot, :] = prediction[..., slot, :] * std[slot] + mean[slot]

        # NOTE(review): the target is saved exactly as collected (it is not
        # rescaled here), whereas the prediction was rescaled above -- confirm
        # that both end up in the same units.
        if not self.save_predictions:
            return
        out_dir = pathlib.Path(pl_module.logger.experiment.dir)
        torch.save(target, out_dir / "target.pt")
        torch.save(prediction, out_dir / "prediction.pt")
class PulseDBLogging(BaseLoggingCallback):
    """
    Pytorch lightning callback for logging for the PulseDB dataset.

    Initialize with a list of quantities to log and a flag to specify which
    pressure to log. Loss is always logged.

    Available quantities:
        - mae: mean absolute error, plus its ratio to ``pl_module.MAE_baseline``
        - rmse: root mean squared error, plus its ratio to
          ``pl_module.RMSE_baseline``
        - std: mean of the predicted standard deviation
        - ieee: grades A-D as returned by ``qumphy.metrics.ieee_grades_torch``

    Args:
        log_quantities: Iterable of quantity names from the list above.
        log_pressure: ``"both"``, ``"sbp"`` or ``"dbp"`` -- which pressure(s)
            the model predicts and which are logged.
        **kwargs: Forwarded to ``BaseLoggingCallback``.

    Raises:
        ValueError: If ``log_pressure`` or an entry of ``log_quantities`` is
            not one of the supported values.
    """

    def __init__(self, log_quantities, log_pressure="both", **kwargs):
        super().__init__(**kwargs)
        if log_pressure not in ["both", "sbp", "dbp"]:
            raise ValueError("log_pressure must be either 'both', 'sbp' or 'dbp'")
        self.log_quantities = set(log_quantities)
        self.log_pressure = log_pressure
        self.set_function_dictionary()

    def set_function_dictionary(self):
        """Validate ``log_quantities`` and map each name to its logging method.

        Raises:
            ValueError: If ``log_quantities`` contains an unsupported name.
        """
        if not self.log_quantities.issubset({"mae", "rmse", "std", "ieee"}):
            raise ValueError(
                "log_quantities must be part of 'mae', 'rmse', 'std', 'ieee'"
            )
        self.function_dictionary = {
            "mae": self.log_mae,
            "rmse": self.log_rmse,
            "std": self.log_std,
            "ieee": self.log_ieee_metrics,
        }

    def log_ieee_metrics(self, pl_module, values, stage):
        """Log IEEE grades A-D for every pressure selected by ``log_pressure``.

        Only the selected pressures are evaluated: ``log_epoch_end`` only puts
        ``"<p>"`` / ``"<p>_hat"`` keys for those pressures into ``values``, so
        reading the other pressure would raise ``KeyError``.

        Returns:
            dict: The logged values, keyed ``f"{stage}_{SBP|DBP}_{A|B|C|D}"``.
        """
        if self.log_pressure == "both":
            pressures = ("sbp", "dbp")
        else:
            pressures = (self.log_pressure,)

        logs = {}
        for pressure in pressures:
            grades = qumphy.metrics.ieee_grades_torch(
                values[f"{pressure}_hat"], values[pressure]
            )
            for grade in ("A", "B", "C", "D"):
                logs[f"{stage}_{pressure.upper()}_{grade}"] = grades[grade]
        pl_module.log_dict(logs, sync_dist=False, on_step=False, on_epoch=True)
        return logs

    def log_mae(self, pl_module, values, stage):
        """Log the MAE and its ratio to ``pl_module.MAE_baseline``.

        ``MAE_baseline[0]`` is used for SBP and ``MAE_baseline[1]`` for DBP.

        Returns:
            dict: The logged values.
        """
        # A fresh metric per call, so no state carries over between epochs.
        self.mae = torchmetrics.MeanAbsoluteError().to("cpu")
        MAE_baseline = pl_module.MAE_baseline.to("cpu")
        if self.log_pressure == "both":
            bp_mae = torch.stack(
                (
                    self.mae(values["sbp_hat"], values["sbp"]),
                    self.mae(values["dbp_hat"], values["dbp"]),
                )
            )
            bp_mae_baseline = bp_mae / MAE_baseline
            logs = {
                f"{stage}_SBP_MAE": bp_mae[0],
                f"{stage}_DBP_MAE": bp_mae[1],
                f"{stage}_SBP_MAE_baseline": bp_mae_baseline[0],
                f"{stage}_DBP_MAE_baseline": bp_mae_baseline[1],
            }
        elif self.log_pressure == "sbp":
            sbp_mae = self.mae(values["sbp_hat"], values["sbp"])
            logs = {
                f"{stage}_SBP_MAE": sbp_mae,
                f"{stage}_SBP_MAE_baseline": sbp_mae / MAE_baseline[0],
            }
        elif self.log_pressure == "dbp":
            dbp_mae = self.mae(values["dbp_hat"], values["dbp"])
            logs = {
                f"{stage}_DBP_MAE": dbp_mae,
                f"{stage}_DBP_MAE_baseline": dbp_mae / MAE_baseline[1],
            }
        pl_module.log_dict(logs, sync_dist=False, on_step=False, on_epoch=True)
        return logs

    def log_rmse(self, pl_module, values, stage):
        """Log the RMSE and its ratio to ``pl_module.RMSE_baseline``.

        ``RMSE_baseline[0]`` is used for SBP and ``RMSE_baseline[1]`` for DBP.

        Returns:
            dict: The logged values.
        """
        # squared=False turns MeanSquaredError into a root-mean-squared error.
        self.rmse = torchmetrics.MeanSquaredError(squared=False).to("cpu")
        RMSE_baseline = pl_module.RMSE_baseline.to("cpu")
        if self.log_pressure == "both":
            bp_rmse = torch.stack(
                (
                    self.rmse(values["sbp_hat"], values["sbp"]),
                    self.rmse(values["dbp_hat"], values["dbp"]),
                )
            )
            bp_rmse_baseline = bp_rmse / RMSE_baseline
            logs = {
                f"{stage}_SBP_RMSE": bp_rmse[0],
                f"{stage}_DBP_RMSE": bp_rmse[1],
                f"{stage}_SBP_RMSE_baseline": bp_rmse_baseline[0],
                f"{stage}_DBP_RMSE_baseline": bp_rmse_baseline[1],
            }
        elif self.log_pressure == "sbp":
            sbp_rmse = self.rmse(values["sbp_hat"], values["sbp"])
            logs = {
                f"{stage}_SBP_RMSE": sbp_rmse,
                f"{stage}_SBP_RMSE_baseline": sbp_rmse / RMSE_baseline[0],
            }
        elif self.log_pressure == "dbp":
            dbp_rmse = self.rmse(values["dbp_hat"], values["dbp"])
            logs = {
                f"{stage}_DBP_RMSE": dbp_rmse,
                f"{stage}_DBP_RMSE_baseline": dbp_rmse / RMSE_baseline[1],
            }
        pl_module.log_dict(logs, sync_dist=False, on_step=False, on_epoch=True)
        return logs

    def log_std(self, pl_module, values, stage):
        """Log the mean predicted standard deviation(s) computed by ``log_epoch_end``.

        Returns:
            dict: The logged values.
        """
        if self.log_pressure == "both":
            logs = {
                f"{stage}_SBP_std": values["sbp_std"],
                f"{stage}_DBP_std": values["dbp_std"],
            }
        else:
            logs = {
                f"{stage}_{self.log_pressure.upper()}_std": values[
                    f"{self.log_pressure}_std"
                ],
            }
        pl_module.log_dict(logs, sync_dist=False, on_step=False, on_epoch=True)
        return logs

    def log_epoch_end(self, trainer, pl_module, stage):
        """Denormalize the collected outputs of ``stage`` and log all quantities.

        Mean columns are ``prediction[:, :2]`` for ``"both"`` and
        ``prediction[:, :1]`` for a single pressure; they are passed through
        ``pl_module.denormalize_target`` in place. When ``"std"`` is logged,
        the following column(s) (``2:4`` or ``1``) are transformed with
        ``exp(x).sqrt()`` and ``pl_module.denormalize_std``. Otherwise they
        are returned untouched.

        Returns:
            tuple: ``(target, prediction, logs)`` -- the denormalized targets,
            the (partly) denormalized predictions and the merged dict of
            logged values.
        """
        target = torch.cat(self.outputs[stage]["target"], dim=0).cpu()
        prediction = torch.cat(self.outputs[stage]["prediction"], dim=0).detach().cpu()
        target = pl_module.denormalize_target(target)
        if self.log_pressure == "both":
            prediction[:, :2] = pl_module.denormalize_target(prediction[:, :2])
            values = {
                "sbp": target[:, 0],
                "dbp": target[:, 1],
                "sbp_hat": prediction[:, 0],
                "dbp_hat": prediction[:, 1],
            }
        elif self.log_pressure == "sbp" or self.log_pressure == "dbp":
            prediction[:, :1] = pl_module.denormalize_target(prediction[:, :1])
            values = {
                f"{self.log_pressure}": target[:, 0],
                f"{self.log_pressure}_hat": prediction[:, 0],
            }
        if "std" in self.log_quantities:
            # The std columns are treated as log-variances: exp(.) then sqrt(.).
            if self.log_pressure == "both":
                prediction[:, 2:4] = torch.exp(prediction[:, 2:4]).sqrt()
                prediction[:, 2:4] = pl_module.denormalize_std(prediction[:, 2:4])
                values["sbp_std"] = prediction[:, 2].mean()
                values["dbp_std"] = prediction[:, 3].mean()
            elif self.log_pressure == "sbp" or self.log_pressure == "dbp":
                prediction[:, 1] = torch.exp(prediction[:, 1]).sqrt()
                prediction[:, 1] = pl_module.denormalize_std(prediction[:, 1])
                values[f"{self.log_pressure}_std"] = prediction[:, 1].mean()
        logs = {}
        for log_quantity in self.log_quantities:
            logs.update(
                self.function_dictionary[log_quantity](pl_module, values, stage)
            )
        return target, prediction, logs
class PulseDBLogging_Ensemble(PulseDBLogging):
    """PulseDB logging callback for ensemble models.

    In addition to the metrics logged by ``PulseDBLogging``, the test epoch
    saves the targets, the prediction returned by ``log_epoch_end`` and the
    collected ``single_models_prediction`` outputs to the logger's run
    directory.
    """

    def on_test_epoch_end(self, trainer, pl_module):
        """Log test metrics and save ensemble and per-member predictions.

        Files written to ``pl_module.logger.experiment.dir``: ``target.pt``,
        ``prediction.pt`` and ``single_models_prediction.pt``.
        """
        stage = "test"
        self.log_loss(trainer, pl_module, stage)
        # log_epoch_end returns (target, prediction, logs); unpacking into two
        # names raised "too many values to unpack". The logs were already sent
        # to the logger, so they are not needed here.
        target, prediction, _ = self.log_epoch_end(trainer, pl_module, stage)

        # Per-batch tensors are concatenated along dim 1 -- presumably dim 0
        # indexes the ensemble members (TODO confirm against the model).
        single_models_prediction = (
            torch.cat(self.outputs[stage]["single_models_prediction"], dim=1)
            .detach()
            .cpu()
        )
        # NOTE(review): this column layout (0:2 means, 2:4 log-variances)
        # corresponds to log_pressure="both" only.
        single_models_prediction[..., :2] = pl_module.denormalize_target(
            single_models_prediction[..., :2]
        )
        single_models_prediction[..., 2:4] = torch.exp(
            single_models_prediction[..., 2:4]
        ).sqrt()
        single_models_prediction[..., 2:4] = pl_module.denormalize_std(
            single_models_prediction[..., 2:4]
        )

        save_dir = pathlib.Path(pl_module.logger.experiment.dir)
        torch.save(target, save_dir / "target.pt")
        torch.save(prediction, save_dir / "prediction.pt")
        torch.save(single_models_prediction, save_dir / "single_models_prediction.pt")
class PulseDBLogging_MCD(PulseDBLogging):
    """PulseDB logging callback for Monte-Carlo-dropout (MCD) models.

    At the end of the test epoch, the per-sample predicted means and variances
    collected under ``stacked_predicted_means_<p>`` and
    ``stacked_predicted_vars_<p>`` (``<p>`` in ``sbp``/``dbp``, as selected by
    ``log_pressure``) are denormalized and reduced over dim 0 (the MC-sample
    axis, as indexed by the loops below). The results are pickled to the
    logger's run directory as ``test_preds_<p>.pkl``,
    ``test_epistemic_<p>.pkl`` and ``test_aleatoric_<p>.pkl``, together with
    the denormalized targets in ``test_targs.pkl``.
    """

    def __init__(self, log_quantities, log_pressure="both", **kwargs):
        # **kwargs is forwarded so options accepted by PulseDBLogging /
        # BaseLoggingCallback are not silently unavailable for this subclass.
        super().__init__(log_quantities, log_pressure, **kwargs)

    def _selected_pressures(self):
        """Return the pressure names selected by ``log_pressure``."""
        if self.log_pressure == "both":
            return ("sbp", "dbp")
        return (self.log_pressure,)

    @staticmethod
    def _denormalize_samples(stacked, denormalize):
        """Apply ``denormalize`` sample-by-sample to per-pressure tensors.

        ``stacked`` is a list with one tensor per pressure, each indexed
        ``[sample, item]``. For every sample the rows are stacked column-wise
        (one column per pressure), passed through ``denormalize`` and written
        back. Returns a new list in the same order.
        """
        result = [torch.zeros_like(t) for t in stacked]
        for sample in range(stacked[0].size(0)):
            columns = torch.stack([t[sample, :] for t in stacked], dim=-1)
            columns = denormalize(columns)
            for col, out in enumerate(result):
                out[sample, :] = columns[:, col]
        return result

    @staticmethod
    def _dump(path, tensor):
        """Pickle a detached CPU copy of ``tensor`` to ``path``."""
        with open(path, "wb") as f:
            pickle.dump(tensor.detach().cpu(), f)

    def on_test_epoch_end(self, trainer, pl_module):
        """Log the test loss and save MCD predictions and uncertainty terms.

        Previously only ``log_pressure="both"`` worked; the ``"sbp"``/``"dbp"``
        branches referenced variables that were only computed for ``"both"``.
        All modes now go through the same per-pressure code path.
        """
        stage = "test"
        self.log_loss(trainer, pl_module, stage)
        outputs = self.outputs[stage]
        target = pl_module.denormalize_target(
            torch.cat(outputs["target"], dim=0).cpu()
        )

        pressures = self._selected_pressures()
        means = [
            torch.cat(outputs[f"stacked_predicted_means_{p}"], dim=-1)
            for p in pressures
        ]
        variances = [
            torch.cat(outputs[f"stacked_predicted_vars_{p}"], dim=-1)
            for p in pressures
        ]

        # Every MC sample is denormalized on its own.
        denorm_means = self._denormalize_samples(means, pl_module.denormalize_target)
        # The collected variances are square-rooted before denormalize_std.
        denorm_stds = self._denormalize_samples(
            [torch.sqrt(v) for v in variances], pl_module.denormalize_std
        )

        directory = pathlib.Path(pl_module.logger.experiment.dir)
        for p, mean_samples, std_samples in zip(pressures, denorm_means, denorm_stds):
            self._dump(directory / f"test_preds_{p}.pkl", torch.mean(mean_samples, dim=0))
            self._dump(
                directory / f"test_epistemic_{p}.pkl", torch.var(mean_samples, dim=0)
            )
            # NOTE(review): "aleatoric" is the sample mean of the predicted
            # *standard deviation*, while "epistemic" is the *variance* of the
            # sample means, so the two are in different units. This is kept
            # as-is to preserve the meaning of the saved files.
            self._dump(
                directory / f"test_aleatoric_{p}.pkl", torch.mean(std_samples, dim=0)
            )
        self._dump(directory / "test_targs.pkl", target)