# Source code for qumphy.data.utils

"""
File: qumphy/data/utils.py
Project: 22HLT01 QUMPHY
Contact: nando.hegemann@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Loading functions for various datasets.
"""

import qumphy
import numpy as np


def get_filename(dataset: str, filetype: str) -> str:
    """Create filename for dataset descriptor.

    A descriptor that ends in one or two digits is split into a base name
    and an index; the index is zero-padded to at least two digits and
    appended after the file type. A descriptor that does not end in a digit
    is used unchanged.

    Parameters
    ----------
    dataset : str
        Dataset descriptor, e.g. ``test``, ``train3`` or ``validate02``.
        At most the last two characters are interpreted as the index.
    filetype : str
        Descriptor for the data part, e.g. ``label`` or ``signal``.

    Returns
    -------
    str
        Filename for the corresponding data.

    Raises
    ------
    ValueError
        If ``dataset`` is an empty string.

    Examples
    --------
    >>> get_filename("test", "signal")
    'test_signal.npy'
    >>> get_filename("train3", "label")
    'train_label_03.npy'
    >>> get_filename("train03", "label")
    'train_label_03.npy'
    """
    if not dataset:
        raise ValueError("dataset descriptor must not be empty")
    if not dataset[-1].isdigit():
        return f"{dataset}_{filetype}.npy"
    # Slicing (rather than indexing position -2) yields an empty string for a
    # one-character descriptor such as "3" instead of raising IndexError.
    n_digits = 2 if dataset[-2:-1].isdigit() else 1
    base = dataset[:-n_digits]
    index = int(dataset[-n_digits:])
    return f"{base}_{filetype}_{index:02d}.npy"
def calculate_regression_baseline(
    dataset, median: np.ndarray | float, mean: np.ndarray | float
) -> tuple[float, float, float, float]:
    """Calculate regression baseline metrics from the dataset median and mean.

    The given mean and median are each repeated once per row of
    ``dataset.labels()`` and compared against those labels with the mean
    absolute error and root mean square error from ``qumphy.metrics``.

    Parameters
    ----------
    dataset : Class
        Dataset object; its ``labels()`` method supplies the reference values.
    median : np.ndarray | float
        Median of the dataset.
    mean : np.ndarray | float
        Mean of the dataset.

    Returns
    -------
    tuple[float, float, float, float]
        Baseline metrics in the order MAE (mean), MAE (median),
        RMSE (mean), RMSE (median).
    """
    targets = dataset.labels()
    repetitions = (targets.shape[0], 1)

    def constant_rows(statistic):
        # One copy of the statistic per label row; passed to the metric
        # functions as their first argument.
        return np.tile(np.array(statistic), repetitions)

    mae_mean = qumphy.metrics.mean_absolute_error(constant_rows(mean), targets)
    mae_median = qumphy.metrics.mean_absolute_error(
        constant_rows(median), targets
    )
    rmse_mean = qumphy.metrics.root_mean_square_error(
        constant_rows(mean), targets
    )
    rmse_median = qumphy.metrics.root_mean_square_error(
        constant_rows(median), targets
    )
    return mae_mean, mae_median, rmse_mean, rmse_median