"""
File: qumphy/models/lightning_callbacks.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: PyTorch Lightning callbacks for logging PulseDB blood-pressure metrics and saving test-set predictions.
"""
import torch
import torchmetrics
from qumphy.callbacks.base_logging import BaseLoggingCallback
import qumphy
import pickle
import pathlib
class PulseDBLogging_Pinballloss(BaseLoggingCallback):
    """
    Callback for models trained with a pinball (quantile) loss.

    At the end of the test epoch the collected predictions are reshaped so
    that the last axis is split into ``(2, n_quantiles)``, denormalised with
    the module's ``BP_mean`` / ``BP_std`` and, if ``save_predictions`` is set,
    written together with the collected targets to the logger's experiment
    directory.
    """

    def on_test_epoch_end(self, trainer, pl_module):
        """Denormalise the quantile predictions and optionally save them.

        Writes ``target.pt`` and ``prediction.pt`` into
        ``pl_module.logger.experiment.dir`` when ``self.save_predictions`` is
        truthy. Nothing is sent to the logger.
        """
        collected = self.outputs["test"]
        targets = torch.cat(collected["target"], dim=0).cpu()
        preds = torch.cat(collected["prediction"], dim=0).detach().cpu()

        n_quantiles = len(pl_module.loss_fn.quantiles)
        # Split the flat output axis into (pressure, quantile). Pressure rows
        # 0 and 1 (presumably SBP / DBP, as elsewhere in this module) are
        # denormalised with entries 0 and 1 of BP_std / BP_mean.
        preds = preds.reshape([*preds.shape[:-1], 2, n_quantiles])

        mean = pl_module.BP_mean.to(preds)
        std = pl_module.BP_std.to(preds)
        for row in (0, 1):
            preds[..., row, :] = preds[..., row, :] * std[row] + mean[row]

        # NOTE(review): targets are saved without denormalisation while the
        # predictions are denormalised; confirm targets are already in
        # physical units for this model.
        if self.save_predictions:
            out_dir = pathlib.Path(pl_module.logger.experiment.dir)
            torch.save(targets, out_dir / "target.pt")
            torch.save(preds, out_dir / "prediction.pt")
class PulseDBLogging(BaseLoggingCallback):
    """
    Pytorch lightning callback for logging for the PulseDB dataset.

    Initialize with a list of quantities to log and a flag to specify which
    pressure to log. Loss is always logged.

    Available quantities:
    - mae
    - rmse
    - std
    - ieee

    Args:
        log_quantities: Iterable of quantity names; must be a subset of
            {"mae", "rmse", "std", "ieee"}.
        log_pressure: "both", "sbp" or "dbp" -- which blood pressure(s) are
            evaluated and logged.
        **kwargs: Forwarded to ``BaseLoggingCallback.__init__``.

    Raises:
        ValueError: If ``log_pressure`` or any entry of ``log_quantities`` is
            not one of the supported values.
    """

    # Column of each pressure in two-column (SBP, DBP) tensors; matches the
    # indexing used by the single-pressure branches below.
    _PRESSURE_INDEX = {"sbp": 0, "dbp": 1}

    def __init__(self, log_quantities, log_pressure="both", **kwargs):
        super().__init__(**kwargs)
        if log_pressure not in ["both", "sbp", "dbp"]:
            raise ValueError("log_pressure must be either 'both', 'sbp' or 'dbp'")
        self.log_quantities = set(log_quantities)
        self.log_pressure = log_pressure
        self.set_function_dictionary()

    def set_function_dictionary(self):
        """Validate ``log_quantities`` and map each name to its logging method.

        Raises:
            ValueError: If an unsupported quantity was requested.
        """
        if not self.log_quantities.issubset({"mae", "rmse", "std", "ieee"}):
            raise ValueError(
                "log_quantities must be part of 'mae', 'rmse', 'std', 'ieee'"
            )
        self.function_dictionary = {
            "mae": self.log_mae,
            "rmse": self.log_rmse,
            "std": self.log_std,
            "ieee": self.log_ieee_metrics,
        }

    def _active_pressures(self):
        """Return the pressure names selected by ``log_pressure``, SBP first."""
        if self.log_pressure == "both":
            return ("sbp", "dbp")
        return (self.log_pressure,)

    def log_ieee_metrics(self, pl_module, values, stage):
        """Log the A-D grades from ``qumphy.metrics.ieee_grades_torch``.

        Only the active pressure(s) are graded: in single-pressure mode
        ``values`` contains no keys for the other pressure.

        Returns:
            dict mapping ``f"{stage}_{SBP|DBP}_{A|B|C|D}"`` to the grade value.
        """
        logs = {}
        for pressure in self._active_pressures():
            grades = qumphy.metrics.ieee_grades_torch(
                values[f"{pressure}_hat"], values[pressure]
            )
            for grade in ("A", "B", "C", "D"):
                logs[f"{stage}_{pressure.upper()}_{grade}"] = grades[grade]
        pl_module.log_dict(logs, sync_dist=False, on_step=False, on_epoch=True)
        return logs

    def _log_error_metric(self, pl_module, values, stage, metric, baseline, label):
        """Log ``metric`` per active pressure and its ratio to ``baseline``.

        ``baseline`` is moved to the CPU. With ``log_pressure == "both"`` the
        stacked (SBP, DBP) errors are divided by the whole baseline tensor;
        otherwise by the baseline entry of the selected pressure.

        Returns:
            dict with ``f"{stage}_{P}_{label}"`` and
            ``f"{stage}_{P}_{label}_baseline"`` for every active pressure P.
        """
        pressures = self._active_pressures()
        baseline = baseline.to("cpu")
        errors = torch.stack(
            [metric(values[f"{p}_hat"], values[p]) for p in pressures]
        )
        if self.log_pressure != "both":
            baseline = baseline[self._PRESSURE_INDEX[self.log_pressure]]
        ratios = errors / baseline
        logs = {}
        for i, pressure in enumerate(pressures):
            logs[f"{stage}_{pressure.upper()}_{label}"] = errors[i]
        for i, pressure in enumerate(pressures):
            logs[f"{stage}_{pressure.upper()}_{label}_baseline"] = ratios[i]
        pl_module.log_dict(logs, sync_dist=False, on_step=False, on_epoch=True)
        return logs

    def log_mae(self, pl_module, values, stage):
        """Log the mean absolute error and its ratio to ``pl_module.MAE_baseline``."""
        self.mae = torchmetrics.MeanAbsoluteError().to("cpu")
        return self._log_error_metric(
            pl_module, values, stage, self.mae, pl_module.MAE_baseline, "MAE"
        )

    def log_rmse(self, pl_module, values, stage):
        """Log the root-mean-square error and its ratio to ``pl_module.RMSE_baseline``."""
        self.rmse = torchmetrics.MeanSquaredError(squared=False).to("cpu")
        return self._log_error_metric(
            pl_module, values, stage, self.rmse, pl_module.RMSE_baseline, "RMSE"
        )

    def log_std(self, pl_module, values, stage):
        """Log the mean predicted standard deviation(s) prepared by ``log_epoch_end``."""
        logs = {
            f"{stage}_{pressure.upper()}_std": values[f"{pressure}_std"]
            for pressure in self._active_pressures()
        }
        pl_module.log_dict(logs, sync_dist=False, on_step=False, on_epoch=True)
        return logs

    def log_epoch_end(self, trainer, pl_module, stage):
        """Denormalise the outputs collected for ``stage`` and log all quantities.

        ``prediction`` is modified in place: the mean column(s) are passed
        through ``pl_module.denormalize_target``; if "std" is requested the
        std column(s) are passed through ``exp().sqrt()`` and then
        ``pl_module.denormalize_std``.

        Returns:
            (target, prediction, logs) -- denormalised target, the modified
            prediction tensor and the merged dict of everything logged.
        """
        outputs = self.outputs[f"{stage}"]
        target = torch.cat(outputs["target"], dim=0).cpu()
        prediction = torch.cat(outputs["prediction"], dim=0).detach().cpu()
        target = pl_module.denormalize_target(target)
        if self.log_pressure == "both":
            prediction[:, :2] = pl_module.denormalize_target(prediction[:, :2])
            values = {
                "sbp": target[:, 0],
                "dbp": target[:, 1],
                "sbp_hat": prediction[:, 0],
                "dbp_hat": prediction[:, 1],
            }
        elif self.log_pressure == "sbp" or self.log_pressure == "dbp":
            # NOTE(review): a single column is passed to denormalize_target;
            # confirm it handles one-pressure input correctly (notably DBP).
            prediction[:, :1] = pl_module.denormalize_target(prediction[:, :1])
            values = {
                f"{self.log_pressure}": target[:, 0],
                f"{self.log_pressure}_hat": prediction[:, 0],
            }
        if "std" in self.log_quantities:
            if self.log_pressure == "both":
                prediction[:, 2:4] = torch.exp(prediction[:, 2:4]).sqrt()
                prediction[:, 2:4] = pl_module.denormalize_std(prediction[:, 2:4])
                values["sbp_std"] = prediction[:, 2].mean()
                values["dbp_std"] = prediction[:, 3].mean()
            elif self.log_pressure == "sbp" or self.log_pressure == "dbp":
                prediction[:, 1] = torch.exp(prediction[:, 1]).sqrt()
                prediction[:, 1] = pl_module.denormalize_std(prediction[:, 1])
                values[f"{self.log_pressure}_std"] = prediction[:, 1].mean()
        logs = {}
        for log_quantity in self.log_quantities:
            logs.update(
                self.function_dictionary[log_quantity](pl_module, values, stage)
            )
        return target, prediction, logs
class PulseDBLogging_Ensemble(PulseDBLogging):
    """
    PulseDB logging for ensemble models.

    On top of the metrics logged by ``PulseDBLogging.log_epoch_end``, the
    aggregate and per-member test predictions are denormalised and saved to
    the logger's experiment directory.
    """

    def on_test_epoch_end(self, trainer, pl_module):
        """Log test loss/metrics and save target, prediction and member predictions.

        Writes ``target.pt``, ``prediction.pt`` and
        ``single_models_prediction.pt`` to ``pl_module.logger.experiment.dir``.
        """
        stage = "test"
        self.log_loss(trainer, pl_module, stage)
        # log_epoch_end returns (target, prediction, logs); the logs have
        # already been passed to pl_module.log_dict inside it. Unpacking only
        # two values here used to raise "too many values to unpack".
        target, prediction, _ = self.log_epoch_end(trainer, pl_module, stage)
        # Member predictions are concatenated along dim 1 -- presumably
        # (n_members, n_samples, n_outputs); TODO confirm against the module.
        single_models_prediction = (
            torch.cat(self.outputs[stage]["single_models_prediction"], dim=1)
            .detach()
            .cpu()
        )
        single_models_prediction[..., :2] = pl_module.denormalize_target(
            single_models_prediction[..., :2]
        )
        # Columns 2:4 go through exp().sqrt() before denormalize_std, the same
        # transform log_epoch_end applies to its std columns.
        single_models_prediction[..., 2:4] = torch.exp(
            single_models_prediction[..., 2:4]
        ).sqrt()
        single_models_prediction[..., 2:4] = pl_module.denormalize_std(
            single_models_prediction[..., 2:4]
        )
        save_dir = pathlib.Path(pl_module.logger.experiment.dir)
        torch.save(target, save_dir / "target.pt")
        torch.save(prediction, save_dir / "prediction.pt")
        torch.save(single_models_prediction, save_dir / "single_models_prediction.pt")
class PulseDBLogging_MCD(PulseDBLogging):
    """
    PulseDB logging for Monte-Carlo-dropout models.

    At the end of the test epoch the per-MC-sample predicted means and
    variances collected in ``self.outputs["test"]`` are denormalised, reduced
    over the MC-sample axis (dim 0) and pickled to the logger's experiment
    directory.

    Args:
        log_quantities: See ``PulseDBLogging``.
        log_pressure: "both", "sbp" or "dbp".
        **kwargs: Forwarded to ``PulseDBLogging`` (and on to the base callback).
    """

    def __init__(self, log_quantities, log_pressure="both", **kwargs):
        super().__init__(log_quantities, log_pressure, **kwargs)

    @staticmethod
    def _denormalize_per_sample(columns, denormalize):
        """Apply ``denormalize`` sample by sample to per-pressure tensors.

        ``columns`` holds one tensor per pressure, all indexed by MC sample
        along dim 0. For each sample the rows are stacked along the last
        axis (one column per pressure), passed through ``denormalize`` and
        split back. Returns new tensors with the shapes/dtypes of the inputs.
        """
        outputs = [torch.zeros_like(column) for column in columns]
        for sample in range(columns[0].size(0)):
            stacked = torch.stack([column[sample, :] for column in columns], dim=-1)
            denormalized = denormalize(stacked)
            for i, out in enumerate(outputs):
                out[sample, :] = denormalized[:, i]
        return outputs

    def on_test_epoch_end(self, trainer, pl_module):
        """Log the test loss and pickle MC-dropout predictions and uncertainties.

        For every active pressure P ("sbp"/"dbp") writes
        ``test_preds_P.pkl`` (mean of the denormalised means over samples),
        ``test_epistemic_P.pkl`` (their variance over samples) and
        ``test_aleatoric_P.pkl`` (mean of the denormalised predicted std),
        plus ``test_targs.pkl`` with the denormalised targets.
        """
        stage = "test"
        self.log_loss(trainer, pl_module, stage)
        outputs = self.outputs[stage]
        target = torch.cat(outputs["target"], dim=0).cpu()
        target = pl_module.denormalize_target(target)

        pressures = (
            ("sbp", "dbp") if self.log_pressure == "both" else (self.log_pressure,)
        )
        # Batches were concatenated along the last dim; dim 0 indexes the MC
        # samples (the reductions below are taken over it).
        means = [
            torch.cat(outputs[f"stacked_predicted_means_{p}"], dim=-1)
            for p in pressures
        ]
        variances = [
            torch.cat(outputs[f"stacked_predicted_vars_{p}"], dim=-1)
            for p in pressures
        ]

        # In single-pressure mode a one-column tensor is denormalised, mirroring
        # PulseDBLogging.log_epoch_end. NOTE(review): confirm denormalize_target /
        # denormalize_std handle one-pressure input (notably DBP).
        denorm_means = self._denormalize_per_sample(
            means, pl_module.denormalize_target
        )
        denorm_stds = self._denormalize_per_sample(
            [torch.sqrt(v) for v in variances], pl_module.denormalize_std
        )

        # NOTE(review): "epistemic" is a variance of the means whereas
        # "aleatoric" is a mean standard deviation; confirm this mix of units
        # is intended before combining the two.
        directory = pathlib.Path(pl_module.logger.experiment.dir)
        for pressure, mu, sigma in zip(pressures, denorm_means, denorm_stds):
            results = {
                f"test_preds_{pressure}": torch.mean(mu, dim=0),
                f"test_epistemic_{pressure}": torch.var(mu, dim=0),
                f"test_aleatoric_{pressure}": torch.mean(sigma, dim=0),
            }
            for name, tensor in results.items():
                with open(directory / f"{name}.pkl", "wb") as f:
                    pickle.dump(tensor.detach().cpu(), f)
        with open(directory / "test_targs.pkl", "wb") as f:
            pickle.dump(target.detach().cpu(), f)