"""
File: qumphy/data/pulsedb.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Dataset and LightningDataModule classes for the DeepBeat dataset.
"""
import lightning as L
import numpy as np
import pandas as pd
import pathlib
import qumphy
import torch
[docs]
class DeepBeatDataset(torch.utils.data.Dataset):
"""DeepBeat dataset class."""
def __init__(
self,
data_directory: str,
subset: str,
dataset: str = "set_revised",
target_format: str = "binary",
normalize: bool = False,
dtype: torch.dtype = torch.float32,
load_data: bool = True,
data_fraction: float = 1.0,
filter_params: dict = None,
noise_params: dict = None,
input_sampling_rate: float | None = None,
split_to_input_sampling_rate: dict[str, float] | None = None,
target_sampling_rate: float | None = None,
):
"""
Initializes the DeepBeat Dataset.
Parameters
----------
data_directory : string
Full path of the directory containing the data.
subset : string
Choose between the "train", "val" or "test" subset.
dataset : string
Choose between "full", "set_revised" or "mini_revised".
target_format : string
Choose between "binary" ((1), float), "class_index" ((), long) or "class_probability" ((2), float).
normalize : bool
Normalize the data in the range [-1, 1].
dtype : torch.dtype
The data type of the dataset.
load_data : bool
Load the data.
data_fraction : float
Fraction of data to load (0.0-1.0). For test subset, this is ignored and full data is loaded.
filter_params : dict
Dictionary containing the parameters for the filter. If None, no filter is applied.
noise_params : dict
Dictionary containing the parameters for the noise. If None, no noise is added.
input_sampling_rate : float
Sampling frequency (Hz) of the stored input signals.This is ignored when split_to_input_sampling_rate is provided.
split_to_input_sampling_rate : dict
Dictionary mapping subset names to the sampling frequency.This is useful when different subsets are stored in different sampling rates.
target_sampling_rate : float
Choose a target sampling rate for the signals in the processing run.
"""
super().__init__()
self.subset = subset
self.dataset = dataset
self.target_format = target_format
if self.target_format not in ["binary", "class_index", "class_probability"]:
raise ValueError(
f"Unknown target shape: '{self.target_format}', expected 'binary', 'class_index' or 'class_probability'"
)
self.normalize = normalize
self.dtype = dtype
self.input_sampling_rate = input_sampling_rate
self.split_to_input_sampling_rate = split_to_input_sampling_rate
self.target_sampling_rate = target_sampling_rate
if load_data:
self.load_data(
pathlib.Path(data_directory), filter_params, noise_params, data_fraction
)
def __len__(self):
return self.data.shape[0]
def __getitem__(self, index):
data = self.data[index]
target = self.target[index]
data = torch.unsqueeze(data, -2)
return data, target
[docs]
def load_data(
self,
data_directory: pathlib.Path,
filter_params: dict = None,
noise_params: dict = None,
data_fraction: float = 1.0,
):
data = np.load(data_directory / "signals.npy", mmap_mode="r")
metadata = pd.read_csv(data_directory / "metadata.csv")
target = metadata["label"].to_numpy()
indices = self.select_subset_indices(metadata, data_fraction)
data = data[indices]
target = target[indices]
if noise_params is not None:
data = qumphy.data.signal_preprocessing.noise.add_noise(data, noise_params)
if filter_params is not None:
data = qumphy.data.signal_preprocessing.filters.apply_filter(
data, filter_params
)
data = torch.as_tensor(data.copy(), dtype=self.dtype)
target = torch.as_tensor(target, dtype=self.dtype)
if (
self.split_to_input_sampling_rate is not None
and self.subset in self.split_to_input_sampling_rate
):
fs_in = float(self.split_to_input_sampling_rate[self.subset])
elif self.input_sampling_rate is not None:
fs_in = float(self.input_sampling_rate)
else:
fs_in = 256.0
fs_out = (
float(self.target_sampling_rate)
if self.target_sampling_rate is not None
else fs_in
)
if fs_out != fs_in:
data = qumphy.data.signal_preprocessing.resampling.resample_like_matlab(
data, fs_in=fs_in, fs_out=fs_out, axis=-1
)
if self.target_format == "binary":
target = torch.unsqueeze(target, -1)
elif self.target_format == "class_index":
target = target.long()
elif self.target_format == "class_probability":
target = torch.tensor([1.0 - target, target])
if self.normalize:
for i in range(len(data)):
data[i] = self.normalize_data(data[i])
self.data = data
self.target = target
[docs]
def normalize_data(self, data):
"""
Rescales the data to the range [-1, 1].
Parameters
----------
data (array): The input data to be normalized.
Returns
-------
array: The normalized data in the range [-1, 1].
"""
data_min = torch.min(data)
data_max = torch.max(data)
delta = data_max - data_min
data -= data_min
data = data / delta * 2 - 1
return data
[docs]
def get_labels(self):
return self.target
[docs]
def get_data(self):
return self.data
[docs]
def select_subset_indices(
self, metadata: pd.core.frame.DataFrame, data_fraction: float = 1.0
) -> pd.core.indexes.base.Index:
"""
Set an index mask of the memmapped dataset based on the selected subset.
Parameters
----------
metadata : pd.core.frame.DataFrame
The metadata dataframe.
Returns
-------
pd.core.indexes.base.Index
The index mask.
"""
if self.dataset not in ["full", "set_revised", "mini_revised"]:
raise ValueError(
f"Unknown dataset: '{self.dataset}', expected 'full', 'set_revised' or 'mini_revised'"
)
if self.subset not in ["train", "val", "calib", "test"]:
raise ValueError(
f"Unknown subset: '{self.subset}', expected 'train', 'val', 'calib' or 'test'"
)
# Select the dataset
if self.dataset == "full":
subset_value = np.array(metadata["set"])
elif self.dataset == "set_revised":
subset_value = np.array(metadata["set_revised"])
if self.dataset == "mini_revised":
subset_value = np.array(metadata["set_revised"])
mini_mask = np.ones(subset_value.shape, dtype=bool)
mini_mask[::100] = False
subset_value[mini_mask] = -1
# Mask the subset
if self.subset == "train":
mask = subset_value == 0
elif self.subset == "val":
mask = subset_value == 1
elif self.subset == "calib":
mask = subset_value == 2
elif self.subset == "test":
mask = subset_value == (2 if self.dataset == "full" else 3)
indices = np.where(mask)[0]
if self.subset in ["train", "val", "calib"] and data_fraction < 1.0:
total_samples = len(indices)
num_samples = int(total_samples * data_fraction)
indices = np.random.choice(indices, size=num_samples, replace=False)
return indices
class DeepBeatDataModule(L.LightningDataModule):
    """LightningDataModule implementation for the DeepBeat dataset."""

    def __init__(
        self,
        sampling_rate: float,
        batch_size: int,
        num_workers: int,
        pin_memory: bool = True,
        prefetch_factor: int | None = 8,
        **dskwargs,
    ):
        """Initialize the DeepBeat data module.

        Parameters
        ----------
        sampling_rate : float
            Target sampling rate (Hz) of the signals. Forwarded to every
            DeepBeatDataset as ``target_sampling_rate``, overriding any value
            given in ``dskwargs``.
        batch_size : int
            Number of samples per batch.
        num_workers : int
            Number of worker processes used by the dataloaders.
        pin_memory : bool
            If True, use pinned memory in the dataloaders.
        prefetch_factor : int or None
            Number of batches loaded in advance by each worker. Only applied
            when ``num_workers > 0``; torch rejects it otherwise.
        **dskwargs
            Additional keyword arguments passed to DeepBeatDataset.
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.prefetch_factor = prefetch_factor
        self.dskwargs = dict(dskwargs)
        self.dskwargs["target_sampling_rate"] = float(sampling_rate)

    def setup(self, stage):
        """Build the datasets needed for ``stage``.

        "fit" needs train and val, "validate" needs val, and "test" needs
        test. Datasets that already exist are reused rather than reloaded.
        """
        if stage == "fit":
            if getattr(self, "train_ds", None) is None:
                self.train_ds = DeepBeatDataset(subset="train", **self.dskwargs)
            if getattr(self, "val_ds", None) is None:
                self.val_ds = DeepBeatDataset(subset="val", **self.dskwargs)
        elif stage == "validate":
            if getattr(self, "val_ds", None) is None:
                self.val_ds = DeepBeatDataset(subset="val", **self.dskwargs)
        elif stage == "test":
            if getattr(self, "test_ds", None) is None:
                self.test_ds = DeepBeatDataset(subset="test", **self.dskwargs)

    def _make_dataloader(self, dataset, shuffle: bool):
        """Build a DataLoader over ``dataset`` with this module's loader settings."""
        loader_kwargs = {}
        # DataLoader raises if prefetch_factor is given without worker
        # processes, so only forward it when num_workers > 0.
        if self.num_workers > 0:
            loader_kwargs["prefetch_factor"] = self.prefetch_factor
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=self.pin_memory,
            **loader_kwargs,
        )

    def train_dataloader(self):
        """Create the training dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            Shuffled dataloader for the training dataset.
        """
        return self._make_dataloader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Create the validation dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            Unshuffled dataloader for the validation dataset.
        """
        return self._make_dataloader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        """Create the test dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            Unshuffled dataloader for the test dataset.
        """
        return self._make_dataloader(self.test_ds, shuffle=False)