"""
Description: PyTorch Lightning callbacks.
"""
import qumphy
from qumphy.callbacks import base_logging
import torch
class SleepApneaLogging(base_logging.BaseLoggingCallback):
    """
    Pytorch lightning callback for logging for the sleepApnea dataset.

    At the end of an epoch, the collected targets and predictions for a
    stage are reduced to their first column, scored with
    ``qumphy.metrics.all_binary_metrics`` and each resulting metric is
    logged through ``pl_module.log``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``BaseLoggingCallback``."""
        super().__init__(*args, **kwargs)

    def log_epoch_end(self, trainer, pl_module, stage: str):
        """
        Compute and log binary-classification metrics for one stage.

        Args:
            trainer: The Lightning trainer. Not used by this method; kept for
                the hook signature the base class presumably calls with
                (TODO confirm).
            pl_module: Module whose ``log`` method records each metric.
            stage: Key into ``self.outputs`` and prefix of every logged
                metric name (``f"{stage}_{key}"``).

        Returns:
            tuple: ``(label, prediction)`` as detached CPU tensors holding the
            first column of the concatenated targets / predictions (1-D
            inputs are returned as-is).
        """
        # NOTE(review): self.outputs is presumably populated by the base class
        # with lists of per-batch tensors under "target" / "prediction".
        label = torch.cat(self.outputs[stage]["target"], dim=0).detach().cpu()
        prediction = (
            torch.cat(self.outputs[stage]["prediction"], dim=0).detach().cpu()
        )

        # Only the first output column is scored. A 1-D tensor already is that
        # single column, so use it directly instead of raising IndexError.
        if label.ndim > 1:
            label = label[:, 0]
        if prediction.ndim > 1:
            prediction = prediction[:, 0]

        try:
            metrics_dict = qumphy.metrics.all_binary_metrics(
                label.numpy(), prediction.numpy()
            )
        except ValueError:
            # Assumed to signal NaN predictions (as the message says) — verify
            # against qumphy.metrics. Metrics are skipped for this epoch
            # rather than aborting training.
            print(f"[{stage}]Prediction has nan\n", prediction)
            metrics_dict = {}

        for key, value in metrics_dict.items():
            pl_module.log(
                f"{stage}_{key}",
                value,
                sync_dist=False,
                prog_bar=True,
                on_step=False,
                on_epoch=True,
            )
        return label, prediction