Source code for qumphy.data.deepbeat

"""
File: qumphy/data/deepbeat.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Functions handling DeepBeat dataset.
"""

import lightning as L
import numpy as np
import pandas as pd
import pathlib
import qumphy
import torch


[docs] class DeepBeatDataset(torch.utils.data.Dataset): """DeepBeat dataset class.""" def __init__( self, data_directory: str, subset: str, dataset: str = "set_revised", target_format: str = "binary", normalize: bool = False, dtype: torch.dtype = torch.float32, load_data: bool = True, data_fraction: float = 1.0, filter_params: dict = None, noise_params: dict = None, input_sampling_rate: float | None = None, split_to_input_sampling_rate: dict[str, float] | None = None, target_sampling_rate: float | None = None, ): """ Initializes the DeepBeat Dataset. Parameters ---------- data_directory : string Full path of the directory containing the data. subset : string Choose between the "train", "val" or "test" subset. dataset : string Choose between "full", "set_revised" or "mini_revised". target_format : string Choose between "binary" ((1), float), "class_index" ((), long) or "class_probability" ((2), float). normalize : bool Normalize the data in the range [-1, 1]. dtype : torch.dtype The data type of the dataset. load_data : bool Load the data. data_fraction : float Fraction of data to load (0.0-1.0). For test subset, this is ignored and full data is loaded. filter_params : dict Dictionary containing the parameters for the filter. If None, no filter is applied. noise_params : dict Dictionary containing the parameters for the noise. If None, no noise is added. input_sampling_rate : float Sampling frequency (Hz) of the stored input signals.This is ignored when split_to_input_sampling_rate is provided. split_to_input_sampling_rate : dict Dictionary mapping subset names to the sampling frequency.This is useful when different subsets are stored in different sampling rates. target_sampling_rate : float Choose a target sampling rate for the signals in the processing run. 
""" super().__init__() self.subset = subset self.dataset = dataset self.target_format = target_format if self.target_format not in ["binary", "class_index", "class_probability"]: raise ValueError( f"Unknown target shape: '{self.target_format}', expected 'binary', 'class_index' or 'class_probability'" ) self.normalize = normalize self.dtype = dtype self.input_sampling_rate = input_sampling_rate self.split_to_input_sampling_rate = split_to_input_sampling_rate self.target_sampling_rate = target_sampling_rate if load_data: self.load_data( pathlib.Path(data_directory), filter_params, noise_params, data_fraction ) def __len__(self): return self.data.shape[0] def __getitem__(self, index): data = self.data[index] target = self.target[index] data = torch.unsqueeze(data, -2) return data, target
[docs] def load_data( self, data_directory: pathlib.Path, filter_params: dict = None, noise_params: dict = None, data_fraction: float = 1.0, ): data = np.load(data_directory / "signals.npy", mmap_mode="r") metadata = pd.read_csv(data_directory / "metadata.csv") target = metadata["label"].to_numpy() indices = self.select_subset_indices(metadata, data_fraction) data = data[indices] target = target[indices] if noise_params is not None: data = qumphy.data.signal_preprocessing.noise.add_noise(data, noise_params) if filter_params is not None: data = qumphy.data.signal_preprocessing.filters.apply_filter( data, filter_params ) data = torch.as_tensor(data.copy(), dtype=self.dtype) target = torch.as_tensor(target, dtype=self.dtype) if ( self.split_to_input_sampling_rate is not None and self.subset in self.split_to_input_sampling_rate ): fs_in = float(self.split_to_input_sampling_rate[self.subset]) elif self.input_sampling_rate is not None: fs_in = float(self.input_sampling_rate) else: fs_in = 256.0 fs_out = ( float(self.target_sampling_rate) if self.target_sampling_rate is not None else fs_in ) if fs_out != fs_in: data = qumphy.data.signal_preprocessing.resampling.resample_like_matlab( data, fs_in=fs_in, fs_out=fs_out, axis=-1 ) if self.target_format == "binary": target = torch.unsqueeze(target, -1) elif self.target_format == "class_index": target = target.long() elif self.target_format == "class_probability": target = torch.tensor([1.0 - target, target]) if self.normalize: for i in range(len(data)): data[i] = self.normalize_data(data[i]) self.data = data self.target = target
[docs] def normalize_data(self, data): """ Rescales the data to the range [-1, 1]. Parameters ---------- data (array): The input data to be normalized. Returns ------- array: The normalized data in the range [-1, 1]. """ data_min = torch.min(data) data_max = torch.max(data) delta = data_max - data_min data -= data_min data = data / delta * 2 - 1 return data
[docs] def get_labels(self): return self.target
[docs] def get_data(self): return self.data
[docs] def select_subset_indices( self, metadata: pd.core.frame.DataFrame, data_fraction: float = 1.0 ) -> pd.core.indexes.base.Index: """ Set an index mask of the memmapped dataset based on the selected subset. Parameters ---------- metadata : pd.core.frame.DataFrame The metadata dataframe. Returns ------- pd.core.indexes.base.Index The index mask. """ if self.dataset not in ["full", "set_revised", "mini_revised"]: raise ValueError( f"Unknown dataset: '{self.dataset}', expected 'full', 'set_revised' or 'mini_revised'" ) if self.subset not in ["train", "val", "calib", "test"]: raise ValueError( f"Unknown subset: '{self.subset}', expected 'train', 'val', 'calib' or 'test'" ) # Select the dataset if self.dataset == "full": subset_value = np.array(metadata["set"]) elif self.dataset == "set_revised": subset_value = np.array(metadata["set_revised"]) if self.dataset == "mini_revised": subset_value = np.array(metadata["set_revised"]) mini_mask = np.ones(subset_value.shape, dtype=bool) mini_mask[::100] = False subset_value[mini_mask] = -1 # Mask the subset if self.subset == "train": mask = subset_value == 0 elif self.subset == "val": mask = subset_value == 1 elif self.subset == "calib": mask = subset_value == 2 elif self.subset == "test": mask = subset_value == (2 if self.dataset == "full" else 3) indices = np.where(mask)[0] if self.subset in ["train", "val", "calib"] and data_fraction < 1.0: total_samples = len(indices) num_samples = int(total_samples * data_fraction) indices = np.random.choice(indices, size=num_samples, replace=False) return indices
class DeepBeatDataModule(L.LightningDataModule):
    """LightningDataModule implementation for the DeepBeat dataset."""

    def __init__(
        self,
        sampling_rate: float,
        batch_size: int,
        num_workers: int,
        pin_memory: bool = True,
        prefetch_factor=8,
        **dskwargs,
    ):
        """Initialize the DeepBeat data module.

        Parameters
        ----------
        sampling_rate : float
            Passed to DeepBeatDataset as ``target_sampling_rate``. It
            overrides any ``target_sampling_rate`` given in ``dskwargs``.
        batch_size : int
            Number of samples per batch.
        num_workers : int
            Number of worker processes used by the dataloaders.
        pin_memory : bool
            If True, use pinned memory in the dataloaders.
        prefetch_factor : int
            Number of batches loaded in advance by each worker. Only
            forwarded to the dataloaders when ``num_workers`` > 0.
        **dskwargs
            Additional keyword arguments passed to DeepBeatDataset.
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.prefetch_factor = prefetch_factor
        self.dskwargs = dict(dskwargs)
        self.dskwargs["target_sampling_rate"] = float(sampling_rate)

    def setup(self, stage):
        """Create the datasets required for ``stage``.

        Datasets that already exist on the module are not loaded again.
        Stages other than "fit", "validate" and "test" do nothing.

        Parameters
        ----------
        stage : str
            One of "fit", "validate" or "test".
        """
        if stage == "fit":
            if hasattr(self, "train_ds"):
                return
            self.train_ds = DeepBeatDataset(subset="train", **self.dskwargs)
            self.val_ds = DeepBeatDataset(subset="val", **self.dskwargs)
        elif stage == "validate":
            if hasattr(self, "val_ds"):
                return
            self.val_ds = DeepBeatDataset(subset="val", **self.dskwargs)
        elif stage == "test":
            if hasattr(self, "test_ds"):
                return
            self.test_ds = DeepBeatDataset(subset="test", **self.dskwargs)

    def _make_dataloader(self, dataset, shuffle: bool):
        """Build a dataloader for ``dataset`` with the module's settings."""
        loader_kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "shuffle": shuffle,
            "pin_memory": self.pin_memory,
        }
        # DataLoader rejects an explicit prefetch_factor when loading runs
        # in the main process, so only pass it along with worker processes.
        if self.num_workers > 0:
            loader_kwargs["prefetch_factor"] = self.prefetch_factor
        return torch.utils.data.DataLoader(dataset, **loader_kwargs)

    def train_dataloader(self):
        """Create the training dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            Dataloader for the training dataset (shuffled).
        """
        return self._make_dataloader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Create the validation dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            Dataloader for the validation dataset.
        """
        return self._make_dataloader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        """Create the test dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            Dataloader for the test dataset.
        """
        return self._make_dataloader(self.test_ds, shuffle=False)