"""
Description: Lightning model for SleepApnea data.
"""
import torch
import lightning as L
import qumphy
[docs]
class SleepApneaModule(L.LightningModule):
    """
    Lightning parent module for SleepApnea data.

    Takes a specific model architecture as input (``net``) and wraps it with
    the forward / step / scheduler plumbing used for training and inference.
    """

    def __init__(
        self,
        net,
        optimizer,
        output_activation=None,
        prediction_activation=None,
        loss_fn=None,
        lr_scheduler=None,
    ):
        """
        Parameters
        ----------
        net : torch.nn.Module
            The model architecture; called on the input data in ``forward``.
        optimizer : dict
            The optimizer configuration. Stored unchanged as
            ``self.optimizer_config``; it is not consumed in this part of the
            class (presumably by ``configure_optimizers`` -- verify).
        output_activation : torch.nn.Module, optional
            Applied to the output of ``net`` in ``forward``. Its result is
            what ``loss_fn`` receives. ``None`` (the default) selects a fresh
            ``torch.nn.Identity()``.
        prediction_activation : torch.nn.Module, optional
            Applied to the model output before it is returned as
            ``"prediction"`` from the train/val/test steps and from
            ``predict_step``. ``None`` (the default) selects a fresh
            ``torch.nn.Identity()``.
        loss_fn : torch.nn.Module, optional
            Called as ``loss_fn(model_output, label)``. ``None`` (the default)
            selects a fresh ``torch.nn.CrossEntropyLoss()``. That loss expects
            unnormalized logits, so ``output_activation`` should not
            normalize the output when it is used.
        lr_scheduler : dict, optional
            The learning rate scheduler configuration. ``set_lr_scheduler``
            reads its ``"class_path"``, ``"init_args"`` and ``"config"``
            keys, and writes the instantiated scheduler into
            ``lr_scheduler["config"]["scheduler"]``. May be ``None`` (the
            default); ``set_lr_scheduler`` indexes into this dict, so it must
            not be called in that case.
        """
        super().__init__()
        # Defaults are built here, per instance. A module instance used as a
        # default argument value is created once, when the class is defined,
        # and would be registered as a submodule of every SleepApneaModule --
        # the same object shared between models.
        # The assignment order is unchanged, because submodules are
        # registered in the order they are assigned.
        self.loss_fn = torch.nn.CrossEntropyLoss() if loss_fn is None else loss_fn
        self.net = net
        self.output_activation = (
            torch.nn.Identity() if output_activation is None else output_activation
        )
        self.prediction_activation = (
            torch.nn.Identity()
            if prediction_activation is None
            else prediction_activation
        )
        self.lr_scheduler = lr_scheduler
        self.optimizer_config = optimizer
[docs]
def forward(self, data):
    """
    Run the wrapped network and apply the output activation.

    Parameters
    ----------
    data
        Input batch, passed straight to ``self.net``.

    Returns
    -------
    The value of ``self.output_activation(self.net(data))``.
    """
    raw_output = self.net(data)
    return self.output_activation(raw_output)
[docs]
def training_step(self, batch, batch_idx):
    """Run the shared step on one training batch (stage ``"train"``)."""
    stage = "train"
    return self._common_step(batch, batch_idx, stage)
[docs]
def validation_step(self, batch, batch_idx):
    """Run the shared step on one validation batch (stage ``"val"``)."""
    stage = "val"
    return self._common_step(batch, batch_idx, stage)
[docs]
def test_step(self, batch, batch_idx):
return self._common_step(batch, batch_idx, "test")
[docs]
def predict_step(self, batch, batch_idx):
    """
    Return activated predictions for one batch.

    Parameters
    ----------
    batch
        A ``(data, label)`` pair; exactly two items are unpacked and the
        label is ignored.
    batch_idx
        Batch index forwarded by Lightning; unused.

    Returns
    -------
    ``self.prediction_activation(self(data))``.
    """
    data, _ = batch
    # NOTE(review): unlike ``_common_step``, no ``[:, :2]`` slice is applied
    # here. If ``net`` emits more than two output columns, predict outputs
    # differ in shape from train/val/test predictions -- confirm intended.
    return self.prediction_activation(self(data))
def _common_step(self, batch, batch_idx, stage):
    """
    Shared computation for the train, validation and test steps.

    Parameters
    ----------
    batch
        A ``(data, label)`` pair; exactly two items are unpacked.
    batch_idx
        Batch index forwarded from the Lightning step hook. Currently unused.
    stage : str
        ``"train"``, ``"val"`` or ``"test"``, as passed by the step methods.
        Currently unused.

    Returns
    -------
    dict
        ``"loss"``: ``self.loss_fn`` applied to the full model output and the
        label (this is the key Lightning reads from a dict returned by
        ``training_step``); ``"prediction"``: ``self.prediction_activation``
        applied to the first two output columns; ``"target"``: the label,
        unchanged.
    """
    inputs, target = batch
    model_output = self(inputs)
    loss_value = self.loss_fn(model_output, target)
    # The loss sees every output column, but only the first two are
    # activated. NOTE(review): presumably these are the class scores and any
    # further columns are auxiliary outputs; assumes a 2-D output -- confirm.
    activated = self.prediction_activation(model_output[:, :2])
    return {"loss": loss_value, "prediction": activated, "target": target}
[docs]
def set_lr_scheduler(self):
    """
    Instantiate the learning-rate scheduler described by ``self.lr_scheduler``.

    ``self.lr_scheduler`` must be a dict with the keys:

    * ``"class_path"`` -- passed to
      ``qumphy.misc.misc.instantiate_class_from_string``;
    * ``"init_args"`` -- keyword arguments for the scheduler. ``optimizer``
      is added from ``self.optimizer``, overriding any existing entry;
    * ``"config"`` -- a dict that receives the new scheduler under the
      ``"scheduler"`` key.

    ``self.optimizer`` must already be set. It is not assigned in this part
    of the class (presumably ``configure_optimizers`` does it -- verify).

    The caller-supplied ``"init_args"`` dict is not modified. The optimizer
    is merged into a copy, so the configuration given to ``__init__`` does
    not end up holding a live optimizer object. The only write is to
    ``self.lr_scheduler["config"]["scheduler"]``. Returns ``None``.

    Raises
    ------
    TypeError
        If ``self.lr_scheduler`` is ``None``.
    KeyError
        If a required key is missing.
    AttributeError
        If ``self.optimizer`` has not been set.
    """
    scheduler_spec = self.lr_scheduler
    # Merge into a fresh dict rather than writing into the caller's init_args.
    init_args = {**scheduler_spec["init_args"], "optimizer": self.optimizer}
    scheduler = qumphy.misc.misc.instantiate_class_from_string(
        scheduler_spec["class_path"],
        **init_args,
    )
    # This write is the method's output; presumably the caller hands
    # ``self.lr_scheduler["config"]`` to Lightning as the scheduler config.
    scheduler_spec["config"]["scheduler"] = scheduler