"""
File: qumphy/data/utils.py
Project: 22HLT01 QUMPHY
Contact: nando.hegemann@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Loading functions for various datasets.
"""
from __future__ import annotations
import qumphy
import numpy as np
def calculate_regression_baseline(
    dataset, median: np.ndarray | float, mean: np.ndarray | float
) -> tuple[float, float, float, float]:
    """Calculate regression baseline metrics using the median and mean of the dataset.

    The baselines are the errors of a trivial predictor that outputs the same
    constant value (``mean`` or ``median``) for every sample of ``dataset``.

    Parameters
    ----------
    dataset : Class
        Dataset object. Must provide a ``labels()`` method returning an array
        whose first axis indexes samples.
    median : np.ndarray | float
        Median of the dataset; a scalar or an array broadcastable to the shape
        of the labels returned by ``dataset.labels()``.
    mean : np.ndarray | float
        Mean of the dataset; a scalar or an array broadcastable to the shape
        of the labels returned by ``dataset.labels()``.

    Returns
    -------
    tuple[float, float, float, float]
        Baseline metrics in the order
        ``(MAE of mean, MAE of median, RMSE of mean, RMSE of median)``.
    """
    labels = dataset.labels()

    # Constant predictions with exactly the shape of ``labels``. The former
    # ``np.tile(value, (n, 1))`` always produced a 2-D ``(n, 1)`` array for a
    # scalar value, so for 1-D labels of shape ``(n,)`` an elementwise
    # ``pred - labels`` would silently broadcast to an ``(n, n)`` matrix.
    # ``np.full`` keeps the dtype of ``value`` and broadcasts array values
    # (e.g. a per-column mean of shape ``(k,)`` against ``(n, k)`` labels).
    # Each prediction is built once and reused for both metrics.
    mean_prediction = np.full(labels.shape, mean)
    median_prediction = np.full(labels.shape, median)

    baseline_MAE_mean = qumphy.metrics.mean_absolute_error(mean_prediction, labels)
    baseline_MAE_median = qumphy.metrics.mean_absolute_error(
        median_prediction, labels
    )
    baseline_RMSE_mean = qumphy.metrics.root_mean_square_error(
        mean_prediction, labels
    )
    baseline_RMSE_median = qumphy.metrics.root_mean_square_error(
        median_prediction, labels
    )
    return (
        baseline_MAE_mean,
        baseline_MAE_median,
        baseline_RMSE_mean,
        baseline_RMSE_median,
    )