"""File: qumphy/data/pulsedb.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Functions handling DeepBeat dataset.
"""
import torch
import lightning as L
import numpy as np
import pandas as pd
import pathlib
class DeepBeatDataset(torch.utils.data.Dataset):
    """DeepBeat dataset class.

    Reads ``signals.npy`` and ``metadata.csv`` from a data directory, keeps
    the rows belonging to the requested subset and exposes them as
    ``(data, target)`` pairs.
    """

    def __init__(
        self,
        data_directory: str,
        subset: str,
        dataset: str = "set_revised",
        target_format: str = "binary",
        normalize: bool = False,
        dtype: torch.dtype = torch.float32,
        load_data: bool = True,
    ):
        """
        Initializes the DeepBeat Dataset.

        Parameters
        ----------
        data_directory : string
            Full path of the directory containing the data.
        subset : string
            Choose between the "train", "val", "calib" or "test" subset.
        dataset : string
            Choose between "full", "set_revised" or "mini_revised".
        target_format : string
            Choose between "binary" ((1), float), "class_index" ((), long) or "class_probability" ((2), float).
        normalize : bool
            Normalize each sample to the range [-1, 1].
        dtype : torch.dtype
            The data type of the dataset.
        load_data : bool
            Load the data immediately. When False, ``load_data`` has to be
            called before the dataset can be indexed.

        Raises
        ------
        ValueError
            If ``target_format`` is unknown, or (while loading) if
            ``dataset`` or ``subset`` is unknown.
        """
        super().__init__()
        self.subset = subset
        self.dataset = dataset
        self.target_format = target_format
        if self.target_format not in ["binary", "class_index", "class_probability"]:
            raise ValueError(
                f"Unknown target shape: '{self.target_format}', expected 'binary', 'class_index' or 'class_probability'"
            )
        self.normalize = normalize
        self.dtype = dtype
        data_directory = pathlib.Path(data_directory)
        if load_data:
            self.load_data(data_directory)

    def __len__(self):
        """Return the number of samples in the loaded subset."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return the ``(data, target)`` pair(s) selected by ``index``.

        A singleton axis is inserted before the last axis of the data. The
        target is shaped according to ``target_format``; every format also
        works when ``index`` selects several samples at once.
        """
        data = torch.unsqueeze(self.data[index], -2)
        target = self.target[index]
        if self.target_format == "binary":
            target = torch.unsqueeze(target, -1)
        elif self.target_format == "class_index":
            target = target.long()
        elif self.target_format == "class_probability":
            # stack (instead of torch.tensor([...])) keeps the dtype and also
            # handles batched targets, like the two branches above.
            target = torch.stack([1.0 - target, target], dim=-1)
        return data, target

    def load_data(
        self,
        data_directory: pathlib.Path,
    ):
        """Load the signals and labels of the selected subset into memory.

        Parameters
        ----------
        data_directory : pathlib.Path
            Directory containing ``signals.npy`` and ``metadata.csv``. The
            metadata file must provide a ``label`` column and the split
            column used by ``select_subset_indices``.
        """
        data_directory = pathlib.Path(data_directory)
        # Memory-map the signals so only the selected rows are read.
        signals = np.load(data_directory / "signals.npy", mmap_mode="r")
        metadata = pd.read_csv(data_directory / "metadata.csv")
        mask = self.select_subset_indices(metadata)

        data = torch.as_tensor(signals[mask], dtype=self.dtype)
        target = torch.as_tensor(metadata["label"].to_numpy()[mask], dtype=self.dtype)

        if self.normalize:
            # Each sample is rescaled on its own, using its own min and max.
            for i in range(data.shape[0]):
                data[i] = self.normalize_data(data[i])

        self.data = data
        self.target = target

    def select_subset_indices(
        self,
        metadata: pd.core.frame.DataFrame,
    ) -> np.ndarray:
        """Build a boolean row mask for the selected dataset and subset.

        Parameters
        ----------
        metadata : pd.core.frame.DataFrame
            The metadata dataframe. Needs a ``set`` column for the "full"
            dataset and a ``set_revised`` column otherwise.

        Returns
        -------
        np.ndarray
            Boolean mask with one entry per metadata row.

        Raises
        ------
        ValueError
            If ``self.dataset`` or ``self.subset`` is unknown.
        """
        if self.dataset not in ["full", "set_revised", "mini_revised"]:
            raise ValueError(
                f"Unknown dataset: '{self.dataset}', expected 'full', 'set_revised' or 'mini_revised'"
            )
        if self.subset not in ["train", "val", "calib", "test"]:
            raise ValueError(
                f"Unknown subset: '{self.subset}', expected 'train', 'val', 'calib' or 'test'"
            )

        # Select the dataset (np.array copies, so metadata is never modified).
        if self.dataset == "full":
            subset_value = np.array(metadata["set"])
        else:
            subset_value = np.array(metadata["set_revised"])

        if self.dataset == "mini_revised":
            # Keep every 100th row only; all other rows get a value that
            # matches no subset.
            mini_mask = np.ones(subset_value.shape, dtype=bool)
            mini_mask[::100] = False
            subset_value[mini_mask] = -1

        # Mask the subset.
        # NOTE(review): for the "full" dataset "calib" and "test" both select
        # the value 2 -- confirm that this overlap is intended.
        subset_codes = {
            "train": 0,
            "val": 1,
            "calib": 2,
            "test": 2 if self.dataset == "full" else 3,
        }
        return subset_value == subset_codes[self.subset]

    def normalize_data(self, data):
        """
        Rescales the data to the range [-1, 1].

        The minimum and maximum are taken over the whole input tensor. The
        input is left unmodified. A constant input (maximum equal to minimum)
        is mapped to zeros instead of producing NaN through a division by
        zero.

        Parameters
        ----------
        data (tensor): The input data to be normalized.

        Returns
        -------
        tensor: The normalized data in the range [-1, 1].
        """
        data_min = torch.min(data)
        data_max = torch.max(data)
        delta = data_max - data_min
        if delta == 0:
            return torch.zeros_like(data)
        return (data - data_min) / delta * 2 - 1

    def get_labels(self):
        """Return the target tensor of the loaded subset."""
        return self.target

    def get_data(self):
        """Return the data tensor of the loaded subset."""
        return self.data
class DeepBeatDataModule(L.LightningDataModule):
    """LightningDataModule implementation for the DeepBeat dataset.

    Parameters
    ----------
    batch_size : int
        Batch size used by every dataloader.
    num_workers : int
        Number of worker processes used by every dataloader.
    pin_memory : bool
        Forwarded to every dataloader.
    **dskwargs
        Keyword arguments forwarded to ``DeepBeatDataset`` (everything except
        ``subset``, which is chosen per stage).
    """

    # Subsets that have to be available for each stage handled by ``setup``.
    _STAGE_SUBSETS = {
        "fit": ("train", "val"),
        "validate": ("val",),
        "test": ("test",),
    }

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        pin_memory: bool = True,
        **dskwargs,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.dskwargs = dskwargs

    def setup(self, stage):
        """Create the datasets required for ``stage``.

        "fit" needs the train and val subsets, "validate" the val subset and
        "test" the test subset; any other stage is a no-op. A dataset that
        already exists on the module is reused instead of being loaded again.
        """
        for subset in self._STAGE_SUBSETS.get(stage, ()):
            self._ensure_dataset(subset)

    def _ensure_dataset(self, subset: str) -> None:
        """Create ``self.<subset>_ds`` unless it is already present."""
        attribute = f"{subset}_ds"
        if not hasattr(self, attribute):
            setattr(self, attribute, DeepBeatDataset(subset=subset, **self.dskwargs))

    def _make_dataloader(self, dataset, shuffle: bool):
        """Wrap ``dataset`` in a DataLoader using the module-wide settings."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=self.pin_memory,
        )

    def train_dataloader(self):
        """Return the shuffled dataloader over the train subset."""
        return self._make_dataloader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over the val subset."""
        return self._make_dataloader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled dataloader over the test subset."""
        return self._make_dataloader(self.test_ds, shuffle=False)