Source code for qumphy.models.apnea

"""
Description: Lightning model for SleepApnea data.
"""

import torch
import lightning as L
import qumphy


class SleepApneaModule(L.LightningModule):
    """
    Lightning parent module for SleepApnea data.

    Takes a specific model architecture as input (``net``) and provides the
    train/val/test/predict steps, plus construction of the optimizer and the
    optional learning-rate scheduler from configuration dictionaries.
    """

    def __init__(
        self,
        net,
        optimizer,
        output_activation=None,
        prediction_activation=None,
        loss_fn=None,
        lr_scheduler=None,
    ):
        """
        Parameters
        ----------
        net : torch.nn.Module
            The model architecture.
        optimizer : dict
            The optimizer configuration with keys ``"class_path"`` and
            ``"init_args"``. ``init_args`` may contain a ``"params"`` list of
            parameter-group dicts whose ``"params"`` entry is either a Python
            expression string (evaluated with ``self`` in scope) or an
            iterable of parameters.
        output_activation : torch.nn.Module, optional
            The activation applied after the last layer of ``net`` inside
            ``forward``. Defaults to ``torch.nn.Identity()``.
        prediction_activation : torch.nn.Module, optional
            The activation applied to the model output to produce the
            predictions returned by the step methods. Defaults to
            ``torch.nn.Identity()``.
        loss_fn : torch.nn.Module, optional
            The loss function. Defaults to ``torch.nn.CrossEntropyLoss()``.
        lr_scheduler : dict, optional
            The learning rate scheduler configuration with keys
            ``"class_path"`` and ``"init_args"``, plus an optional ``"config"``
            dict that is handed to Lightning together with the scheduler.
        """
        super().__init__()
        # Default modules are built per instance instead of in the signature:
        # a module used as a default argument is created once and would be
        # shared (and registered as a child) by every instance of this class.
        self.loss_fn = torch.nn.CrossEntropyLoss() if loss_fn is None else loss_fn
        self.net = net
        self.output_activation = (
            torch.nn.Identity() if output_activation is None else output_activation
        )
        self.prediction_activation = (
            torch.nn.Identity()
            if prediction_activation is None
            else prediction_activation
        )
        self.lr_scheduler = lr_scheduler
        self.optimizer_config = optimizer

    def forward(self, data):
        """Run ``net`` on ``data`` and apply ``output_activation``."""
        return self.output_activation(self.net(data))

    def training_step(self, batch, batch_idx):
        """Return the loss/prediction/target dict for a training batch."""
        return self._common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Return the loss/prediction/target dict for a validation batch."""
        return self._common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        """Return the loss/prediction/target dict for a test batch."""
        return self._common_step(batch, batch_idx, "test")

    def predict_step(self, batch, batch_idx):
        """Return ``prediction_activation`` applied to the model output.

        ``batch`` is unpacked as ``(data, label)``; the label is not used.
        """
        data, _ = batch
        return self.prediction_activation(self(data))

    def _common_step(self, batch, batch_idx, stage):
        """Shared logic for the train/val/test steps.

        The loss is computed on the full model output, while only the first
        two output columns are passed through ``prediction_activation``.
        ``batch_idx`` and ``stage`` are accepted for a uniform signature but
        are not used here.
        """
        data, label = batch
        output = self(data)
        loss = self.loss_fn(output, label)
        # NOTE(review): presumably columns 0-1 are the two class scores and
        # any further columns are auxiliary outputs — confirm against ``net``.
        prediction = self.prediction_activation(output[:, :2])
        return {
            "loss": loss,
            "prediction": prediction,
            "target": label,
        }

    def configure_optimizers(self):
        """Instantiate the optimizer (and optional LR scheduler) from config.

        The stored configuration is never modified, so this method can be
        called repeatedly (e.g. across several ``Trainer.fit`` calls) and
        rebuilds the optimizer from the same config each time.

        Returns
        -------
        dict
            ``{"optimizer": ...}`` and, if a scheduler is configured,
            ``"lr_scheduler"`` holding the scheduler config dict.
        """
        init_args = dict(self.optimizer_config.get("init_args", {}))
        if "params" in init_args:
            param_groups = []
            for group in init_args["params"]:
                new_group = dict(group)
                if isinstance(new_group["params"], str):
                    # The expression is evaluated in this method's scope so
                    # it can reference ``self`` (e.g. "self.net.parameters()").
                    # SECURITY: eval() executes arbitrary code — only load
                    # optimizer configs from trusted sources.
                    new_group["params"] = eval(new_group["params"])
                param_groups.append(new_group)
            init_args["params"] = param_groups
        else:
            init_args["params"] = self.parameters()
        self.optimizer = qumphy.misc.misc.instantiate_class_from_string(
            self.optimizer_config["class_path"], **init_args
        )
        if self.lr_scheduler is None:
            return {"optimizer": self.optimizer}
        self.set_lr_scheduler()
        return {
            "optimizer": self.optimizer,
            "lr_scheduler": self.lr_scheduler["config"],
        }

    def set_lr_scheduler(self):
        """Instantiate the LR scheduler for ``self.optimizer``.

        The scheduler is stored under ``self.lr_scheduler["config"]["scheduler"]``
        (the ``"config"`` dict is created if missing) so that the config can
        be handed to Lightning. The user-supplied ``init_args`` are not
        modified; the optimizer is merged into a copy.
        """
        init_args = {
            **self.lr_scheduler.get("init_args", {}),
            "optimizer": self.optimizer,
        }
        scheduler = qumphy.misc.misc.instantiate_class_from_string(
            self.lr_scheduler["class_path"],
            **init_args,
        )
        self.lr_scheduler.setdefault("config", {})["scheduler"] = scheduler