"""
File: qumphy/data/signal_preprocessing/noise.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Dataset and LightningDataModule classes for the SleepApnea data.
"""
from pathlib import Path
import h5py
import lightning as L
import torch
from scipy.io import loadmat
from torch.utils.data import DataLoader
import qumphy
[docs]
class SleepApneaDataset(torch.utils.data.Dataset):
"""SleepApnea dataset class."""
split_to_data_key = {
"train": "signalsTrain",
"val": "signalsValidation",
"test_ID": "signalsTest_ID",
"test_OOD": "signalsTest_OOD",
}
split_to_label_key = {
"train": "labelsTrain",
"val": "labelsValidation",
"test_ID": "labelsTest_ID",
"test_OOD": "labelsTest_OOD",
}
def __init__(
self,
data_directory: str,
split: str,
normalize: bool = False,
dtype: torch.dtype = torch.float32,
data_fraction: float = 1.0,
filter_params: dict | None = None,
noise_params: dict | None = None,
input_sampling_rate: float | None = None,
split_to_input_sampling_rate: dict[str, float] | None = None,
target_sampling_rate: float | None = None,
):
"""
Initializes the SleepApnea Dataset.
Parameters
----------
data_directory : string
Full path of the directory containing the data.
split : string
Splits into "train", "val" or "test" subsets.
normalize : bool
Normalize the data in the range [-1, 1].
dtype : torch.dtype
The data type of the dataset.
data_fraction : float
Fraction of data to load (0.0-1.0). For test subset, this is ignored and full data is loaded.
filter_params : dict
Dictionary containing the parameters for the filter. If None, no filter is applied.
noise_params : dict
Dictionary containing the parameters for the noise. If None, no noise is added.
input_sampling_rate : float
Sampling frequency (Hz) of the stored input signals.This is ignored when split_to_input_sampling_rate is provided.
split_to_input_sampling_rate : dict
Dictionary mapping subset names to the sampling frequency.This is useful when different subsets are stored in different sampling rates.
target_sampling_rate : float
Choose a target sampling rate for the signals in the processing run.
"""
super().__init__()
assert split in self.split_to_data_key, f"Unknown split: {split}"
self.split = split
self.normalize = normalize
self.dtype = dtype
base_dir = Path(data_directory)
self.data_path = base_dir / "ApneaDetection_PPG_raw_segments.mat"
self.labels_path = base_dir / "ApneaDetection_Labels_corrected.mat"
# decide input fs
if (
split_to_input_sampling_rate is not None
and split in split_to_input_sampling_rate
):
fs_in = float(split_to_input_sampling_rate[split])
elif input_sampling_rate is not None:
fs_in = float(input_sampling_rate)
else:
fs_in = 256.0
# decide target fs
fs_out = (
float(target_sampling_rate) if target_sampling_rate is not None else fs_in
)
self.fs_in = fs_in
self.fs_out = fs_out
data_key = self.split_to_data_key[split]
label_key = self.split_to_label_key[split]
# load data
with h5py.File(self.data_path, "r") as f_data:
data_np = f_data[data_key][:]
# Ensure shape (N, C, T) before preprocessing
if data_np.ndim == 2:
data_np = data_np[:, None, :]
# Apply noise then filter
if noise_params is not None:
noise_params = dict(noise_params)
kwargs = dict(noise_params.get("kwargs", {}))
kwargs.setdefault("signal_frequency", self.fs_in)
noise_params["kwargs"] = kwargs
data_np = qumphy.data.signal_preprocessing.noise.add_noise(
data_np, noise_params
)
if filter_params is not None:
filter_params = dict(filter_params)
kwargs = dict(filter_params.get("kwargs", {}))
kwargs.setdefault("signal_frequency", self.fs_in)
filter_params["kwargs"] = kwargs
data_np = qumphy.data.signal_preprocessing.filtersf.apply_filter(
data_np, filter_params
)
data = torch.from_numpy(data_np).to(dtype=self.dtype)
# for classification
if data.ndim == 2:
data = data.unsqueeze(1)
# resample data
if self.fs_out != self.fs_in:
data = qumphy.data.signal_preprocessing.resampling.resample_like_matlab(
data, fs_in=self.fs_in, fs_out=self.fs_out, axis=-1
)
# load labels
lab = loadmat(str(self.labels_path))
if label_key not in lab:
raise KeyError(
f"{label_key} not found in {self.labels_path}. Keys: {lab.keys()}"
)
label_np = lab[label_key]
label_np = label_np.reshape(-1)
print(
f"[{split}] labels shape from {label_key}:",
label_np.shape,
"unique:",
set(label_np.tolist()),
)
label = torch.from_numpy(label_np).to(dtype=self.dtype)
label = label.unsqueeze(1)
assert data.shape[0] == label.shape[0], (
f"Data and labels size mismatch for {split}: "
f"{data.shape[0]} data vs {label.shape[0]} labels"
)
if self.normalize:
for i in range(len(data)):
data[i] = self.normalize_data(data[i])
self.data = data
self.label = label
def __len__(self):
return self.data.shape[0]
def __getitem__(self, idx):
return self.data[idx], self.label[idx]
[docs]
def normalize_data(self, data):
"""
Rescales the data to the range [-1, 1].
Parameters
----------
data (array): The input data to be normalized.
Returns
-------
array: The normalized data in the range [-1, 1].
"""
data_min = torch.min(data)
data_max = torch.max(data)
delta = data_max - data_min
data -= data_min
data = data / delta * 2 - 1
return data
[docs]
def get_labels(self):
return self.label
[docs]
def get_data(self):
return self.data
[docs]
class SleepApneaDataModule(L.LightningDataModule):
    """LightningDataModule implementation for the SleepApnea dataset.

    Builds one :class:`SleepApneaDataset` per split on demand and wraps each
    of them in a ``DataLoader`` with shared loader settings.
    """

    def __init__(
        self,
        sampling_rate: float,
        batch_size: int,
        num_workers: int,
        pin_memory: bool = True,
        prefetch_factor: int | None = 8,  # each worker loads 8 batches in advance
        **dskwargs,
    ):
        """
        Stores the loader settings and the dataset arguments.

        Parameters
        ----------
        sampling_rate : float
            Target sampling rate (Hz). It is forwarded to the datasets as
            ``target_sampling_rate`` and overrides any value of that name
            given in ``dskwargs``.
        batch_size : int
            Batch size used by every dataloader.
        num_workers : int
            Number of worker processes used by every dataloader.
        pin_memory : bool
            Forwarded to every dataloader.
        prefetch_factor : int or None
            Number of batches each worker loads in advance. Only forwarded to
            the dataloaders when ``num_workers > 0``.
        **dskwargs
            Further keyword arguments forwarded to ``SleepApneaDataset``.
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.prefetch_factor = prefetch_factor
        self.sampling_rate = float(sampling_rate)

        # Copy so later changes do not leak into the caller's mapping.
        self.dskwargs = dict(dskwargs)
        self.dskwargs["target_sampling_rate"] = self.sampling_rate

        # Datasets are created lazily in setup().
        self.train_ds = None
        self.val_ds = None
        self.test_ID_ds = None
        self.test_OOD_ds = None

    def setup(self, stage: str | None = None):
        """
        Creates the datasets needed for ``stage``.

        "fit", "validate" and None create the train and validation datasets;
        "test" and None create both test datasets. Datasets that already
        exist are kept, so repeated calls do not reload the data.
        """
        if stage in ("fit", "validate", None):
            if self.train_ds is None:
                self.train_ds = SleepApneaDataset(
                    split="train",
                    **self.dskwargs,
                )
            if self.val_ds is None:
                self.val_ds = SleepApneaDataset(
                    split="val",
                    **self.dskwargs,
                )
        if stage in ("test", None):
            if self.test_ID_ds is None:
                self.test_ID_ds = SleepApneaDataset(
                    split="test_ID",
                    **self.dskwargs,
                )
            if self.test_OOD_ds is None:
                self.test_OOD_ds = SleepApneaDataset(
                    split="test_OOD",
                    **self.dskwargs,
                )

    def _make_loader(self, dataset, *, shuffle: bool) -> DataLoader:
        """Build a DataLoader for ``dataset`` with the module-wide settings."""
        loader_kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "shuffle": shuffle,
            "pin_memory": self.pin_memory,
        }
        # DataLoader rejects an explicit prefetch_factor when it runs without
        # worker processes, so it is only passed on when workers are used.
        if self.num_workers > 0 and self.prefetch_factor is not None:
            loader_kwargs["prefetch_factor"] = self.prefetch_factor
        return DataLoader(dataset, **loader_kwargs)

    def train_dataloader(self):
        """Return the shuffled dataloader over the training split."""
        return self._make_loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation split."""
        return self._make_loader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        """
        Return the test dataloaders as a list.

        The first entry iterates the "test_ID" split. A second entry for the
        "test_OOD" split is appended when that dataset has been created.
        """
        loaders = [self._make_loader(self.test_ID_ds, shuffle=False)]
        if self.test_OOD_ds is not None:
            loaders.append(self._make_loader(self.test_OOD_ds, shuffle=False))
        return loaders
# if __name__ == "__main__":
# ds_train = SleepApneaDataset(
# data_directory="/Users/numerics4/Desktop/PTB/qumphy-software/data/mesa_osasud",
# data_path="ApneaDetection_PPG_raw_segments_downsampled.mat",
# labels_path="/Users/numerics4/Desktop/PTB/qumphy-software/data/mesa_osasud/ApneaDetection_Labels_corrected.mat",
# split="train",
# )
# print("Train data:", ds_train.data.shape)
# print("Train labels:", ds_train.label.shape)
# print("Unique labels:", torch.unique(ds_train.label))
# ds_val = SleepApneaDataset(
# data_directory="/Users/numerics4/Desktop/PTB/qumphy-software/data/mesa_osasud",
# data_path="ApneaDetection_PPG_raw_segments_downsampled.mat",
# labels_path="/Users/numerics4/Desktop/PTB/qumphy-software/data/mesa_osasud/ApneaDetection_Labels_corrected.mat",
# split="val",
# )
# print("Val data:", ds_val.data.shape)
# print("Val labels:", ds_val.label.shape)
# print("Unique labels:", torch.unique(ds_val.label))