Source code for qumphy.callbacks.sleepapnea

"""
Description: PyTorch Lightning callbacks for the sleep-apnea dataset.
"""

import qumphy
from qumphy.callbacks import base_logging
import torch


class SleepApneaLogging(base_logging.BaseLoggingCallback):
    """PyTorch Lightning callback for logging metrics on the sleep-apnea dataset.

    Builds on ``base_logging.BaseLoggingCallback``. At the end of an epoch it
    concatenates the accumulated targets and predictions for a stage, keeps the
    first column of each, and logs the metrics returned by
    ``qumphy.metrics.all_binary_metrics`` through ``pl_module.log``.
    """

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded unchanged to the base callback.
        super().__init__(*args, **kwargs)

    def log_epoch_end(self, trainer, pl_module, stage: str):
        """Compute and log binary-classification metrics for ``stage``.

        Args:
            trainer: The Lightning trainer. Not used here; kept to match the
                hook signature expected by the base class.
            pl_module: The LightningModule whose ``log`` method receives the
                metrics.
            stage: Key into ``self.outputs``. It is also used as the prefix of
                every logged metric name (``f"{stage}_{key}"``).

        Returns:
            Tuple ``(label, prediction)`` of detached CPU tensors holding
            index 0 along dim 1 of the concatenated targets and predictions.
        """
        # NOTE(review): assumes self.outputs[stage] holds non-empty lists of
        # tensors (at least 2-D) under "target" and "prediction", presumably
        # filled by the base class. torch.cat raises on an empty list.
        # Both tensors are detached: Tensor.numpy() below raises on a tensor
        # that requires grad.
        label = torch.cat(self.outputs[stage]["target"], dim=0).detach().cpu()
        prediction = (
            torch.cat(self.outputs[stage]["prediction"], dim=0).detach().cpu()
        )

        # Only the first output column is scored.
        label = label[:, 0]
        prediction = prediction[:, 0]

        try:
            metrics_dict = qumphy.metrics.all_binary_metrics(
                label.numpy(), prediction.numpy()
            )
        except ValueError:
            # Best-effort: skip metric logging for this epoch instead of
            # aborting training.
            # NOTE(review): the message assumes NaN predictions, but any
            # ValueError raised by all_binary_metrics ends up here.
            print(f"[{stage}]Prediction has nan\n", prediction)
            metrics_dict = {}

        for key, value in metrics_dict.items():
            pl_module.log(
                f"{stage}_{key}",
                value,
                sync_dist=False,
                prog_bar=True,
                on_step=False,
                on_epoch=True,
            )
        return label, prediction