Source code for qumphy.data.deepbeat

"""File: qumphy/data/deepbeat.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Functions handling DeepBeat dataset.
"""

import torch
import lightning as L
import numpy as np
import pandas as pd
import pathlib


class DeepBeatDataset(torch.utils.data.Dataset):
    """DeepBeat dataset class.

    Reads ``signals.npy`` and ``metadata.csv`` from a data directory, keeps
    only the rows belonging to the requested subset and serves
    ``(signal, target)`` pairs as torch tensors.
    """

    def __init__(
        self,
        data_directory: str,
        subset: str,
        dataset: str = "set_revised",
        target_format: str = "binary",
        normalize: bool = False,
        dtype: torch.dtype = torch.float32,
        load_data: bool = True,
    ):
        """
        Initializes the DeepBeat Dataset.

        Parameters
        ----------
        data_directory : string
            Full path of the directory containing the data.
        subset : string
            Choose between the "train", "val", "calib" or "test" subset.
            The value is only validated when the data is loaded.
        dataset : string
            Choose between "full", "set_revised" or "mini_revised".
        target_format : string
            Choose between "binary" ((1), float), "class_index" ((), long) or
            "class_probability" ((2), float).
        normalize : bool
            Normalize each sample to the range [-1, 1].
        dtype : torch.dtype
            The data type of the dataset.
        load_data : bool
            Load the data. When False, ``self.data`` and ``self.target`` are
            not set until :meth:`load_data` is called.

        Raises
        ------
        ValueError
            If ``target_format`` is not one of the supported values.
        """
        super().__init__()

        self.subset = subset
        self.dataset = dataset
        self.target_format = target_format
        if self.target_format not in ["binary", "class_index", "class_probability"]:
            raise ValueError(
                f"Unknown target shape: '{self.target_format}', expected 'binary', 'class_index' or 'class_probability'"
            )
        self.normalize = normalize
        self.dtype = dtype

        if load_data:
            self.load_data(pathlib.Path(data_directory))

    def __len__(self):
        """Return the number of samples in the loaded subset."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return the ``(data, target)`` pair stored at ``index``.

        A singleton dimension is inserted at position -2 of the sample
        (a 1-D sample of length T becomes ``(1, T)``). The target is
        converted according to ``self.target_format``.
        """
        data = torch.unsqueeze(self.data[index], -2)
        target = self.target[index]

        if self.target_format == "binary":
            target = torch.unsqueeze(target, -1)
        elif self.target_format == "class_index":
            target = target.long()
        elif self.target_format == "class_probability":
            # Stacking along the last axis yields (2,) for a single sample and
            # also keeps working when ``index`` selects several samples.
            target = torch.stack([1.0 - target, target], dim=-1)

        return data, target

    def load_data(
        self,
        data_directory: pathlib.Path,
    ):
        """Load the signals and labels of the selected subset.

        Parameters
        ----------
        data_directory : pathlib.Path
            Directory containing "signals.npy" and "metadata.csv". A plain
            string path is accepted as well.

        Raises
        ------
        ValueError
            If ``self.dataset`` or ``self.subset`` is unknown.
        """
        data_directory = pathlib.Path(data_directory)

        # The signals are memory-mapped; the boolean-mask indexing below
        # copies only the selected rows into memory.
        signals = np.load(data_directory / "signals.npy", mmap_mode="r")
        metadata = pd.read_csv(data_directory / "metadata.csv")
        labels = metadata["label"].to_numpy()

        mask = self.select_subset_indices(metadata)

        data = torch.as_tensor(signals[mask], dtype=self.dtype)
        target = torch.as_tensor(labels[mask], dtype=self.dtype)

        if self.normalize:
            # Per-sample normalization, written back row by row.
            for i, sample in enumerate(data):
                data[i] = self.normalize_data(sample)

        self.data = data
        self.target = target

    def select_subset_indices(
        self,
        metadata: pd.DataFrame,
    ) -> np.ndarray:
        """Build a boolean mask over the metadata rows for the selected subset.

        Parameters
        ----------
        metadata : pd.DataFrame
            The metadata dataframe. Column "set" is read for the "full"
            dataset, column "set_revised" otherwise.

        Returns
        -------
        np.ndarray
            Boolean mask with one entry per metadata row; True marks rows
            that belong to ``self.subset``.

        Raises
        ------
        ValueError
            If ``self.dataset`` or ``self.subset`` is unknown.
        """
        if self.dataset not in ["full", "set_revised", "mini_revised"]:
            raise ValueError(
                f"Unknown dataset: '{self.dataset}', expected 'full', 'set_revised' or 'mini_revised'"
            )
        if self.subset not in ["train", "val", "calib", "test"]:
            raise ValueError(
                f"Unknown subset: '{self.subset}', expected 'train', 'val', 'calib' or 'test'"
            )

        # Select the dataset. np.array copies, so the in-place edit below
        # never touches the caller's dataframe.
        column = "set" if self.dataset == "full" else "set_revised"
        subset_value = np.array(metadata[column])

        if self.dataset == "mini_revised":
            # Keep every 100th row; all other rows get a code that matches
            # no subset.
            dropped = np.ones(subset_value.shape, dtype=bool)
            dropped[::100] = False
            subset_value[dropped] = -1

        # Mask the subset.
        # NOTE(review): for dataset "full" the test split uses code 2, which
        # is also the code selected by "calib" -- so "full" + "calib" returns
        # the test rows. Confirm whether that combination should be rejected.
        subset_codes = {
            "train": 0,
            "val": 1,
            "calib": 2,
            "test": 2 if self.dataset == "full" else 3,
        }
        return subset_value == subset_codes[self.subset]

    def normalize_data(self, data):
        """
        Rescales the data to the range [-1, 1].

        The input tensor is left unmodified. A constant signal (max == min)
        has no defined scaling and is mapped to all zeros instead of NaN.

        Parameters
        ----------
        data : torch.Tensor
            The input data to be normalized.

        Returns
        -------
        torch.Tensor
            The normalized data in the range [-1, 1].
        """
        data_min = torch.min(data)
        data_max = torch.max(data)
        delta = data_max - data_min

        if delta == 0:
            # Avoid 0 / 0 -> NaN for flat signals.
            return torch.zeros_like(data)

        return (data - data_min) / delta * 2 - 1

    def get_labels(self):
        """Return the target tensor of the loaded subset."""
        return self.target

    def get_data(self):
        """Return the data tensor of the loaded subset."""
        return self.data
class DeepBeatDataModule(L.LightningDataModule):
    """LightningDataModule implementation for the DeepBeat dataset."""

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        pin_memory: bool = True,
        **dskwargs,
    ):
        """Store the dataloader settings and the dataset keyword arguments.

        Parameters
        ----------
        batch_size : int
            Batch size used by every dataloader.
        num_workers : int
            Number of worker processes used by every dataloader.
        pin_memory : bool
            Forwarded to every dataloader.
        **dskwargs
            Keyword arguments forwarded to ``DeepBeatDataset`` (everything
            except ``subset``, which is chosen per stage).
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.dskwargs = dskwargs

    def setup(self, stage):
        """Create the datasets needed for ``stage``.

        "fit" needs the train and val datasets, "validate" the val dataset
        and "test" the test dataset. A dataset that already exists on the
        module is reused instead of being loaded again; any other stage is
        a no-op.
        """
        if stage == "fit":
            self._ensure_dataset("train")
            self._ensure_dataset("val")
        elif stage == "validate":
            self._ensure_dataset("val")
        elif stage == "test":
            self._ensure_dataset("test")

    def _ensure_dataset(self, subset: str) -> None:
        """Create ``self.<subset>_ds`` unless it has already been set."""
        attribute = f"{subset}_ds"
        if not hasattr(self, attribute):
            setattr(self, attribute, DeepBeatDataset(subset=subset, **self.dskwargs))

    def _make_dataloader(self, dataset, shuffle: bool):
        """Build a DataLoader over ``dataset`` with the module's settings."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=self.pin_memory,
        )

    def train_dataloader(self):
        """Return the shuffled dataloader over the training dataset."""
        return self._make_dataloader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation dataset."""
        return self._make_dataloader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled dataloader over the test dataset."""
        return self._make_dataloader(self.test_ds, shuffle=False)