"""
File: qumphy/data/utils.py
Project: 22HLT01 QUMPHY
Contact: nando.hegemann@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Loading functions for various datasets.
"""
import qumphy
import numpy as np
[docs]
def get_filename(dataset: str, filetype: str) -> str:
    """Create filename for dataset descriptor.

    A trailing run of decimal digits in ``dataset`` is interpreted as the
    dataset index. The index is moved behind ``filetype`` and zero-padded
    to at least two digits. A descriptor without trailing digits is used
    unchanged.

    Parameters
    ----------
    dataset : str
        Dataset descriptor, e.g. ``test``, ``train3`` or ``validate02``.
    filetype : str
        Descriptor for the data part, e.g. ``label`` or ``signal``.

    Returns
    -------
    :
        Filename for the corresponding data.

    Raises
    ------
    ValueError
        If ``dataset`` is empty.

    Examples
    --------
    >>> get_filename("test", "signal")
    'test_signal.npy'
    >>> get_filename("train3", "label")
    'train_label_03.npy'
    >>> get_filename("train03", "label")
    'train_label_03.npy'
    >>> get_filename("validate123", "signal")
    'validate_signal_123.npy'
    """
    if not dataset:
        raise ValueError("dataset descriptor must be a non-empty string")
    # Walk back over the whole trailing run of digits (not just the last two
    # characters), so multi-digit indices such as ``123`` stay intact.
    # ``isdecimal`` matches exactly the characters ``int`` can parse.
    split = len(dataset)
    while split and dataset[split - 1].isdecimal():
        split -= 1
    name, index = dataset[:split], dataset[split:]
    if not index:
        return f"{dataset}_{filetype}.npy"
    return f"{name}_{filetype}_{int(index):02d}.npy"
[docs]
def calculate_regression_baseline(
    dataset, median: np.ndarray | float, mean: np.ndarray | float
) -> tuple[float, float, float, float]:
    """Score constant mean and median predictors against the dataset labels.

    Each statistic is used as the prediction for every sample. It is
    repeated ``labels.shape[0]`` times along the first axis via ``np.tile``
    and then compared with the labels using ``qumphy.metrics``.

    Parameters
    ----------
    dataset : Class
        Dataset object. Its ``labels()`` method must return an object with a
        ``shape`` attribute (presumably a NumPy array; confirm with callers).
    median : np.ndarray | float
        Median of the dataset, used as a constant prediction.
    mean : np.ndarray | float
        Mean of the dataset, used as a constant prediction.

    Returns
    -------
    tuple[float, float, float, float]
        Baseline metrics in the order ``(MAE of mean, MAE of median,
        RMSE of mean, RMSE of median)``.
    """
    targets = dataset.labels()

    def constant_prediction(value):
        # Build a fresh tiled array for every metric call, one copy of
        # ``value`` per label along the first axis.
        return np.tile(np.array(value), (targets.shape[0], 1))

    mae_of_mean = qumphy.metrics.mean_absolute_error(
        constant_prediction(mean), targets
    )
    mae_of_median = qumphy.metrics.mean_absolute_error(
        constant_prediction(median), targets
    )
    rmse_of_mean = qumphy.metrics.root_mean_square_error(
        constant_prediction(mean), targets
    )
    rmse_of_median = qumphy.metrics.root_mean_square_error(
        constant_prediction(median), targets
    )
    return mae_of_mean, mae_of_median, rmse_of_mean, rmse_of_median