Source code for qumphy.data.pulsedb

"""
File: qumphy/data/pulsedb.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Functions handling PulseDB data.
"""

import qumphy
import torch
import lightning as L
import numpy as np
import os
import pandas as pd
import yaml
import pathlib


[docs] class PulseDBDataset(torch.utils.data.Dataset): """ Dataset Class for the PulseDB dataset. To use the PulseDB Dataset, the data should be in a directory containing signals.npy and metadata.csv. Before using the dataset, run the `write_target_stats_yaml(data_directory)` function to create a stats.yaml file containint the statistics of the target, i.e., the mean, median, std, and baseline measures of the SBP and DBP. """ def __init__( self, data_directory: str, dataset: str, subset: str, source: str = "all", pressure: str = "both", normalize: bool = False, dtype: torch.dtype = torch.float32, load_data: bool = True, data_fraction: float = 1.0, filter_params: dict = None, noise_params: dict = None, input_sampling_rate: float | None = None, split_to_input_sampling_rate: dict[str, float] | None = None, target_sampling_rate: float | None = None, ): """ Initializes the Miiciii Dataset. Parameters ---------- data_directory : string Full path of the directory containing the data. dataset : string Choose between the "calibfree", "calib", "aami" or "mini" dataset. subset : string Choose between the "train", "val", "calib" or "test" subset. source : string Either "mimiciii" or "vital" or "all" (both). pressure : string Either "both", "sbp" or "dbp". normalize : bool Normalize the data. dtype : torch.dtype The data type of the dataset. load_data : bool Load the data. data_fraction : float Fraction of data to load (0.0-1.0). For test subset, this is ignored and full data is loaded. filter_params : dict Dictionary containing the parameters for the filter. If None, no filter is applied. noise_params : dict Dictionary containing the parameters for the noise. If None, no noise is added. input_sampling_rate : float Sampling frequency (Hz) of the stored input signals.This is ignored when split_to_input_sampling_rate is provided. 
split_to_input_sampling_rate : dict Dictionary mapping subset names to the sampling frequency.This is useful when different subsets are stored in different sampling rates. target_sampling_rate : float Choose a target sampling rate for the signals in the processing run. """ super().__init__() # Data == Signal # Target == Label self.dtype = dtype self.normalize = normalize self.dataset = dataset self.subset = subset self.source = source self.data_fraction = data_fraction self.input_sampling_rate = input_sampling_rate self.split_to_input_sampling_rate = split_to_input_sampling_rate self.target_sampling_rate = target_sampling_rate if pressure not in ["both", "sbp", "dbp"]: raise ValueError("Pressure must be either 'both', 'sbp' or 'dbp'") self.pressure = pressure data_directory = pathlib.Path(data_directory) self.load_target_stats(data_directory) if load_data: self.load_data(data_directory, filter_params, noise_params) def __len__(self): return self.data.shape[0] def __getitem__(self, index): data = self.data[index] if self.pressure == "both": target = self.target[index] elif self.pressure == "sbp": target = self.target[index, 0:1] elif self.pressure == "dbp": target = self.target[index, 1:2] data = torch.unsqueeze(data, -2) return data, target
[docs] def calculate_target_stats(self): mean = torch.mean(self.target, axis=0) std = torch.std(self.target, axis=0) median = torch.median(self.target, axis=0).values return mean, std, median
[docs] def load_target_stats(self, data_directory: pathlib.Path) -> None: """ Reads the target statistics of the PulseDB training datasets from a YAML file. Parameters ---------- data_directory: pathlib.Path The directory where the data is located. Returns ------- None """ filepath = data_directory / "stats.yaml" if not os.path.exists(filepath): print(f"File not found: {filepath}") print("Target statistics not set.") return with open(filepath, "r") as stream: params = yaml.full_load(stream) if params is None: print(f"File is empty: {filepath}") print("Target statistics not set.") return prefix = "" if self.source: prefix += self.source + "_" prefix += self.dataset + "_" d = { "BP_mean": ["sbp_mean", "dbp_mean"], "BP_std": ["sbp_std", "dbp_std"], "BP_median": ["sbp_median", "dbp_median"], "MAE_baseline": ["sbp_mae", "dbp_mae"], "RMSE_baseline": ["sbp_rmse", "dbp_rmse"], } for key, val in d.items(): if prefix + val[0] not in params: continue setattr( self, key, torch.tensor( [params[prefix + val[0]], params[prefix + val[1]]], dtype=self.dtype, ), )
[docs] def get_target_stats(self): """Returns the statistics of the target. Returns ------- tuple (BP_mean, BP_std, BP_median, MAE_baseline, RMSE_baseline) """ return ( self.BP_mean, self.BP_std, self.BP_median, self.MAE_baseline, self.RMSE_baseline, )
[docs] def load_data( self, data_directory: pathlib.Path, filter_params: dict = None, noise_params: dict = None, ): data = np.load(data_directory / "signals.npy", mmap_mode="r") metadata = pd.read_csv(data_directory / "metadata.csv") sbp = metadata["sbp_avg"].to_numpy() dbp = metadata["dbp_avg"].to_numpy() target = np.stack([sbp, dbp], axis=1) indices = self.select_subset_indices(metadata) data = data[indices] target = target[indices] if noise_params is not None: data = qumphy.data.signal_preprocessing.noise.add_noise(data, noise_params) if filter_params is not None: data = qumphy.data.signal_preprocessing.filters.apply_filter( data, filter_params ) data = torch.as_tensor(data.copy(), dtype=self.dtype) # copy needed for strides target = torch.as_tensor(target, dtype=self.dtype) if ( self.split_to_input_sampling_rate is not None and self.subset in self.split_to_input_sampling_rate ): fs_in = float(self.split_to_input_sampling_rate[self.subset]) elif self.input_sampling_rate is not None: fs_in = float(self.input_sampling_rate) else: fs_in = 256.0 # Decide target fs fs_out = ( float(self.target_sampling_rate) if self.target_sampling_rate is not None else fs_in ) # Resample along time axis if fs_out != fs_in: data = qumphy.data.signal_preprocessing.resampling.resample_like_matlab( data, fs_in=fs_in, fs_out=fs_out, axis=-1 ) if self.normalize: for i in range(len(data)): data[i] = self.normalize_data(data[i]) target[i] = self.normalize_target(target[i]) self.data = data self.target = target
[docs] def calculate_baseline_measures(self): if self.normalize: target = torch.tensor(self.target) target = self.normalize_target(target, denormalize=True) else: target = self.target # MAE baseline is median mae_baseline = qumphy.metrics.mean_absolute_error( target.detach().numpy(), self.BP_median.detach().numpy() ) # RMSE baseline is mean rmse_baseline = qumphy.metrics.root_mean_square_error( target.detach().numpy(), self.BP_mean.detach().numpy() ) mae_baseline = torch.tensor(mae_baseline, dtype=self.dtype) rmse_baseline = torch.tensor(rmse_baseline, dtype=self.dtype) return mae_baseline, rmse_baseline
[docs] def select_subset_indices( self, metadata: pd.core.frame.DataFrame, ) -> pd.core.indexes.base.Index: """Return an index mask of the memmapped dataset based on the selected subset. Parameters ---------- metadata : pd.core.frame.DataFrame The metadata dataframe. Returns ------- pd.core.indexes.base.Index The index mask. """ if self.dataset not in ["calibfree", "calib", "aami", "mini"]: raise ValueError(f"Unknown dataset: '{self.dataset}'") if self.subset not in ["train", "val", "calib", "test"]: raise ValueError(f"Unknown subset: '{self.subset}'") if self.source not in ["all", "mimiciii", "vital"]: raise ValueError(f"Unknown source: '{self.source}'") # mask the source source_value = metadata["source"] if self.source == "all": source_mask = True elif self.source == "mimiciii": source_mask = source_value == 0 elif self.source == "vital": source_mask = source_value == 1 # mask the subset if self.dataset == "mini": subset_value = np.array(metadata["set"]) mask = np.ones(subset_value.shape, dtype=bool) mask[::100] = False subset_value[mask] = -1 else: subset_value = metadata[f"set_{self.dataset}"] if self.subset == "train": subset_mask = subset_value == 0 elif self.subset == "val": subset_mask = subset_value == 1 elif self.subset == "calib": subset_mask = subset_value == 2 elif self.subset == "test": subset_mask = subset_value == 3 # selected_indices = random.sample(indices, num_samples) # indices = selected_indices indices = metadata[source_mask & subset_mask].index if self.subset in ["train", "val", "calib"] and self.data_fraction < 1.0: total_samples = len(indices) num_samples = int(total_samples * self.data_fraction) indices = np.random.choice(indices, size=num_samples, replace=False) return indices
[docs] def normalize_data(self, data): """ Rescales the data to the range [-1, 1]. Parameters ---------- data (array): The input data to be normalized. Returns ------- array: The normalized data in the range [-1, 1]. """ data_min = torch.min(data) data_max = torch.max(data) delta = data_max - data_min data = data - data_min data = data / delta * 2 - 1 return data
[docs] def normalize_target(self, target, denormalize=False): """ Rescales the target to 0 mean and 1 standard deviation, using the BP_mean and BP_std attributes. Parameters ---------- target : np.ndarray The input target to be normalized. denormalize : bool, optional If True, the target will be denormalized back to its original scale. Returns ------- np.ndarray The (de)normalized target. """ if denormalize: target = target * self.BP_std + self.BP_mean else: target = (target - self.BP_mean) / self.BP_std return target
[docs] def get_labels(self): return self.target
[docs] def get_data(self): return self.data
class PulseDBDataModule(L.LightningDataModule):
    """LightningDataModule implementation for the PulseDB dataset.

    All datasets are created with ``normalize=True``.

    Parameters
    ----------
    data_directory : str
        Directory containing signals.npy, metadata.csv and stats.yaml.
    dataset : str
        Forwarded to PulseDBDataset ("calibfree", "calib", "aami" or "mini").
    source : str
        Forwarded to PulseDBDataset ("mimiciii", "vital" or "all").
    batch_size : int
        Batch size of all dataloaders.
    num_workers : int
        Number of DataLoader worker processes.
    sampling_rate : float
        Forwarded to PulseDBDataset as ``target_sampling_rate``.
    prefetch_factor : int
        Batches prefetched per worker; only passed to the DataLoader when
        ``num_workers > 0``.
    **dskwargs
        Additional keyword arguments forwarded to PulseDBDataset.
    """

    def __init__(
        self,
        data_directory,
        dataset,
        source,
        batch_size,
        num_workers,
        sampling_rate: float,
        prefetch_factor=8,
        **dskwargs,
    ):
        super().__init__()
        self.data_directory = data_directory
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset = dataset
        self.source = source
        self.prefetch_factor = prefetch_factor
        self.dskwargs = dict(dskwargs)
        self.dskwargs["target_sampling_rate"] = float(sampling_rate)

    def _make_dataset(self, subset):
        """Build a normalized PulseDBDataset for `subset` with the shared settings."""
        return PulseDBDataset(
            data_directory=self.data_directory,
            dataset=self.dataset,
            subset=subset,
            source=self.source,
            normalize=True,
            **self.dskwargs,
        )

    def _make_dataloader(self, dataset, shuffle):
        """Wrap `dataset` in a DataLoader using the module's loader settings."""
        loader_kwargs = {}
        # DataLoader raises ValueError if prefetch_factor is given without workers.
        if self.num_workers > 0:
            loader_kwargs["prefetch_factor"] = self.prefetch_factor
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=True,
            **loader_kwargs,
        )

    def setup(self, stage):
        """Create the datasets needed for `stage` ("fit", "validate" or "test").

        Datasets that already exist are reused, so repeated calls do not reload data.
        """
        if stage == "fit":
            if not hasattr(self, "train_ds"):
                self.train_ds = self._make_dataset("train")
            if not hasattr(self, "val_ds"):
                self.val_ds = self._make_dataset("val")
        elif stage == "validate":
            if not hasattr(self, "val_ds"):
                self.val_ds = self._make_dataset("val")
        elif stage == "test":
            if not hasattr(self, "test_ds"):
                self.test_ds = self._make_dataset("test")

    def train_dataloader(self):
        """Shuffled DataLoader over the training subset."""
        return self._make_dataloader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Unshuffled DataLoader over the validation subset."""
        return self._make_dataloader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        """Unshuffled DataLoader over the test subset."""
        return self._make_dataloader(self.test_ds, shuffle=False)

    def get_target_stats(self):
        """Return the training dataset's target statistics, loading it if needed."""
        if not hasattr(self, "train_ds"):
            self.setup(stage="fit")
        return self.train_ds.get_target_stats()
def get_target_stats(
    data_directory: str, dataset: str, source: str, dtype: np.dtype = np.float16
) -> dict | None:
    """
    Reads the target statistics of the PulseDB training datasets from a YAML file.

    Parameters
    ----------
    data_directory: string
        Full path of the directory containing the data.
    dataset: string
        Choose between the "calibfree", "calib", "aami" or "mini" dataset.
    source: string
        Either "mimiciii" or "vital" or "all" (both).
    dtype: np.dtype
        The data type of the target statistics. Note that the default float16
        keeps only about 3 significant decimal digits.

    Returns
    -------
    dict or None
        Maps "BP_mean", "BP_std", "BP_median", "MAE_baseline" and
        "RMSE_baseline" to [SBP, DBP] arrays. Returns None (after printing a
        message) if stats.yaml is empty.

    Raises
    ------
    FileNotFoundError
        If stats.yaml does not exist.
    KeyError
        If a required "<source>_<dataset>_<sbp|dbp>_<stat>" key is missing.
    """
    filepath = pathlib.Path(data_directory) / "stats.yaml"
    with open(filepath, "r") as stream:
        # The file only holds plain scalars written by write_target_stats_yaml.
        params = yaml.safe_load(stream)
    if params is None:
        print(f"File is empty: {filepath}")
        print("Target statistics not set.")
        return None
    prefix = f"{source}_{dataset}_"
    stat_keys = {
        "BP_mean": ("sbp_mean", "dbp_mean"),
        "BP_std": ("sbp_std", "dbp_std"),
        "BP_median": ("sbp_median", "dbp_median"),
        "MAE_baseline": ("sbp_mae", "dbp_mae"),
        "RMSE_baseline": ("sbp_rmse", "dbp_rmse"),
    }
    return {
        name: np.array(
            [params[prefix + sbp_key], params[prefix + dbp_key]],
            dtype=dtype,
        )
        for name, (sbp_key, dbp_key) in stat_keys.items()
    }
def write_target_stats_yaml(data_directory, verbose=True):
    """
    Writes the target statistics of the PulseDB training datasets to a YAML file.

    For every source ("all", "mimiciii", "vital") and dataset ("calibfree",
    "calib", "aami", "mini"), the mean, median and std of SBP/DBP are computed
    on the train subset. The MAE/RMSE baselines (predicting the train
    median/mean) are computed on the test subset. Keys are written as
    "<source>_<dataset>_<sbp|dbp>_<stat>" to <data_directory>/stats.yaml.

    Parameters
    ----------
    data_directory: string
        The directory where the data is located.
    verbose: bool
        If True, also print the section headers while computing.

    Returns
    -------
    None
    """
    # Everything is computed before stats.yaml is opened. The PulseDBDataset
    # constructor reads stats.yaml, so writing while computing would make it
    # parse a truncated, partially flushed file. A failure would also leave the
    # file truncated.
    chunks = ["#### Parameters of the PulseDB Training Datasets\n\n"]
    if verbose:
        print("#### Parameters of the PulseDB Training Datasets\n\n")
    for source in ["all", "mimiciii", "vital"]:
        chunks.append(f"### Dataset source {source}\n\n")
        if verbose:
            print(f"### Dataset source {source}\n\n")
        for dataset in ["calibfree", "calib", "aami", "mini"]:
            chunks.append(f"## Dataset {dataset}\n")
            if verbose:
                print(f"## Dataset {dataset}\n")
            train_ds = PulseDBDataset(
                data_directory,
                dataset=dataset,
                subset="train",
                source=source,
                normalize=False,
            )
            mean, std, median = train_ds.calculate_target_stats()
            # Release the training data before the test subset is loaded.
            del train_ds
            test_ds = PulseDBDataset(
                data_directory,
                dataset=dataset,
                subset="test",
                source=source,
                normalize=False,
            )
            test_ds.BP_mean = mean
            test_ds.BP_std = std
            test_ds.BP_median = median
            mae, rmse = test_ds.calculate_baseline_measures()
            del test_ds
            prefix = f"{source}_{dataset}"
            for key, value in zip(
                ["mean", "median", "std", "mae", "rmse"],
                [mean, median, std, mae, rmse],
            ):
                # One dump per key keeps this key order (yaml.dump sorts dict keys).
                chunks.append(
                    yaml.dump(
                        {f"{prefix}_sbp_{key}": value[0].item()},
                        default_flow_style=False,
                    )
                )
                chunks.append(
                    yaml.dump(
                        {f"{prefix}_dbp_{key}": value[1].item()},
                        default_flow_style=False,
                    )
                )
            chunks.append("\n\n")
    with open(os.path.join(data_directory, "stats.yaml"), "w") as outfile:
        outfile.write("".join(chunks))