Source code for qumphy.data.sleepapnea

"""
File: qumphy/data/sleepapnea.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: PyTorch Dataset and Lightning DataModule for the SleepApnea data.
"""

from pathlib import Path

import h5py
import lightning as L
import torch
from scipy.io import loadmat
from torch.utils.data import DataLoader

import qumphy


[docs] class SleepApneaDataset(torch.utils.data.Dataset): """SleepApnea dataset class.""" split_to_data_key = { "train": "signalsTrain", "val": "signalsValidation", "test_ID": "signalsTest_ID", "test_OOD": "signalsTest_OOD", } split_to_label_key = { "train": "labelsTrain", "val": "labelsValidation", "test_ID": "labelsTest_ID", "test_OOD": "labelsTest_OOD", } def __init__( self, data_directory: str, split: str, normalize: bool = False, dtype: torch.dtype = torch.float32, data_fraction: float = 1.0, filter_params: dict | None = None, noise_params: dict | None = None, input_sampling_rate: float | None = None, split_to_input_sampling_rate: dict[str, float] | None = None, target_sampling_rate: float | None = None, ): """ Initializes the SleepApnea Dataset. Parameters ---------- data_directory : string Full path of the directory containing the data. split : string Splits into "train", "val" or "test" subsets. normalize : bool Normalize the data in the range [-1, 1]. dtype : torch.dtype The data type of the dataset. data_fraction : float Fraction of data to load (0.0-1.0). For test subset, this is ignored and full data is loaded. filter_params : dict Dictionary containing the parameters for the filter. If None, no filter is applied. noise_params : dict Dictionary containing the parameters for the noise. If None, no noise is added. input_sampling_rate : float Sampling frequency (Hz) of the stored input signals.This is ignored when split_to_input_sampling_rate is provided. split_to_input_sampling_rate : dict Dictionary mapping subset names to the sampling frequency.This is useful when different subsets are stored in different sampling rates. target_sampling_rate : float Choose a target sampling rate for the signals in the processing run. 
""" super().__init__() assert split in self.split_to_data_key, f"Unknown split: {split}" self.split = split self.normalize = normalize self.dtype = dtype base_dir = Path(data_directory) self.data_path = base_dir / "ApneaDetection_PPG_raw_segments.mat" self.labels_path = base_dir / "ApneaDetection_Labels_corrected.mat" # decide input fs if ( split_to_input_sampling_rate is not None and split in split_to_input_sampling_rate ): fs_in = float(split_to_input_sampling_rate[split]) elif input_sampling_rate is not None: fs_in = float(input_sampling_rate) else: fs_in = 256.0 # decide target fs fs_out = ( float(target_sampling_rate) if target_sampling_rate is not None else fs_in ) self.fs_in = fs_in self.fs_out = fs_out data_key = self.split_to_data_key[split] label_key = self.split_to_label_key[split] # load data with h5py.File(self.data_path, "r") as f_data: data_np = f_data[data_key][:] # Ensure shape (N, C, T) before preprocessing if data_np.ndim == 2: data_np = data_np[:, None, :] # Apply noise then filter if noise_params is not None: noise_params = dict(noise_params) kwargs = dict(noise_params.get("kwargs", {})) kwargs.setdefault("signal_frequency", self.fs_in) noise_params["kwargs"] = kwargs data_np = qumphy.data.signal_preprocessing.noise.add_noise( data_np, noise_params ) if filter_params is not None: filter_params = dict(filter_params) kwargs = dict(filter_params.get("kwargs", {})) kwargs.setdefault("signal_frequency", self.fs_in) filter_params["kwargs"] = kwargs data_np = qumphy.data.signal_preprocessing.filtersf.apply_filter( data_np, filter_params ) data = torch.from_numpy(data_np).to(dtype=self.dtype) # for classification if data.ndim == 2: data = data.unsqueeze(1) # resample data if self.fs_out != self.fs_in: data = qumphy.data.signal_preprocessing.resampling.resample_like_matlab( data, fs_in=self.fs_in, fs_out=self.fs_out, axis=-1 ) # load labels lab = loadmat(str(self.labels_path)) if label_key not in lab: raise KeyError( f"{label_key} not found in 
{self.labels_path}. Keys: {lab.keys()}" ) label_np = lab[label_key] label_np = label_np.reshape(-1) print( f"[{split}] labels shape from {label_key}:", label_np.shape, "unique:", set(label_np.tolist()), ) label = torch.from_numpy(label_np).to(dtype=self.dtype) label = label.unsqueeze(1) assert data.shape[0] == label.shape[0], ( f"Data and labels size mismatch for {split}: " f"{data.shape[0]} data vs {label.shape[0]} labels" ) if self.normalize: for i in range(len(data)): data[i] = self.normalize_data(data[i]) self.data = data self.label = label def __len__(self): return self.data.shape[0] def __getitem__(self, idx): return self.data[idx], self.label[idx]
[docs] def normalize_data(self, data): """ Rescales the data to the range [-1, 1]. Parameters ---------- data (array): The input data to be normalized. Returns ------- array: The normalized data in the range [-1, 1]. """ data_min = torch.min(data) data_max = torch.max(data) delta = data_max - data_min data -= data_min data = data / delta * 2 - 1 return data
[docs] def get_labels(self): return self.label
[docs] def get_data(self): return self.data
class SleepApneaDataModule(L.LightningDataModule):
    """LightningDataModule implementation for the SleepApnea dataset.

    Parameters
    ----------
    sampling_rate : float
        Target sampling rate (Hz). Forwarded to every split's dataset as
        ``target_sampling_rate``, overriding any value given in ``dskwargs``.
    batch_size : int
        Batch size used by all dataloaders.
    num_workers : int
        Number of DataLoader worker processes (used for all splits).
    pin_memory : bool
        Forwarded to ``DataLoader(pin_memory=...)`` for all splits.
    prefetch_factor : int
        Number of batches each worker loads in advance. Only forwarded when
        ``num_workers > 0``: PyTorch's DataLoader raises a ValueError if a
        prefetch factor is given without worker processes.
    **dskwargs
        Extra keyword arguments forwarded to ``SleepApneaDataset``.
    """

    def __init__(
        self,
        sampling_rate: float,
        batch_size: int,
        num_workers: int,
        pin_memory: bool = True,
        prefetch_factor: int | None = 8,  # each worker loads 8 batches in advance
        **dskwargs,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.prefetch_factor = prefetch_factor
        self.sampling_rate = float(sampling_rate)
        self.dskwargs = dict(dskwargs)
        self.dskwargs["target_sampling_rate"] = self.sampling_rate
        self.train_ds = None
        self.val_ds = None
        self.test_ID_ds = None
        self.test_OOD_ds = None

    def setup(self, stage: str | None = None):
        """Create the datasets needed for ``stage`` (each is built only once).

        "fit"/"validate" build the train and val splits, "test" builds the
        in-distribution and out-of-distribution test splits, and ``None``
        builds all four.
        """
        if stage in ("fit", "validate", None):
            if self.train_ds is None:
                self.train_ds = SleepApneaDataset(
                    split="train",
                    **self.dskwargs,
                )
            if self.val_ds is None:
                self.val_ds = SleepApneaDataset(
                    split="val",
                    **self.dskwargs,
                )
        if stage in ("test", None):
            if self.test_ID_ds is None:
                self.test_ID_ds = SleepApneaDataset(
                    split="test_ID",
                    **self.dskwargs,
                )
            if self.test_OOD_ds is None:
                self.test_OOD_ds = SleepApneaDataset(
                    split="test_OOD",
                    **self.dskwargs,
                )

    def _make_loader(self, dataset, shuffle: bool):
        """Build a DataLoader for ``dataset`` with this module's loader settings."""
        loader_kwargs = {}
        # DataLoader rejects prefetch_factor when num_workers == 0.
        if self.num_workers > 0:
            loader_kwargs["prefetch_factor"] = self.prefetch_factor
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=self.pin_memory,
            **loader_kwargs,
        )

    def train_dataloader(self):
        """Shuffled DataLoader over the training split."""
        return self._make_loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Unshuffled DataLoader over the validation split."""
        return self._make_loader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        """Unshuffled DataLoaders: [test_ID] plus test_OOD when it was set up."""
        loaders = [self._make_loader(self.test_ID_ds, shuffle=False)]
        if self.test_OOD_ds is not None:
            loaders.append(self._make_loader(self.test_OOD_ds, shuffle=False))
        return loaders
# if __name__ == "__main__": # ds_train = SleepApneaDataset( # data_directory="/Users/numerics4/Desktop/PTB/qumphy-software/data/mesa_osasud", # data_path="ApneaDetection_PPG_raw_segments_downsampled.mat", # labels_path="/Users/numerics4/Desktop/PTB/qumphy-software/data/mesa_osasud/ApneaDetection_Labels_corrected.mat", # split="train", # ) # print("Train data:", ds_train.data.shape) # print("Train labels:", ds_train.label.shape) # print("Unique labels:", torch.unique(ds_train.label)) # ds_val = SleepApneaDataset( # data_directory="/Users/numerics4/Desktop/PTB/qumphy-software/data/mesa_osasud", # data_path="ApneaDetection_PPG_raw_segments_downsampled.mat", # labels_path="/Users/numerics4/Desktop/PTB/qumphy-software/data/mesa_osasud/ApneaDetection_Labels_corrected.mat", # split="val", # ) # print("Val data:", ds_val.data.shape) # print("Val labels:", ds_val.label.shape) # print("Unique labels:", torch.unique(ds_val.label))