Source code for qumphy.data.utils

"""
File: qumphy/data/utils.py
Project: 22HLT01 QUMPHY
Contact: nando.hegemann@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Loading functions for various datasets.
"""

from __future__ import annotations

import qumphy
import numpy as np


def calculate_regression_baseline(
    dataset, median: np.ndarray | float, mean: np.ndarray | float
) -> tuple[float, float, float, float]:
    """Calculate regression baseline metrics using the median and mean of the dataset.

    The baselines are constant predictors: every sample is predicted as
    ``mean`` (respectively ``median``). These predictions are scored against
    ``dataset.labels()`` with ``qumphy.metrics.mean_absolute_error`` and
    ``qumphy.metrics.root_mean_square_error``.

    Parameters
    ----------
    dataset : Class
        Dataset object. Must provide a ``labels()`` method returning the
        targets as an array-like object with a ``shape`` attribute.
    median : np.ndarray | float
        Median of the dataset. Either a scalar or an array broadcastable to
        the shape of the labels (e.g. one value per target column).
    mean : np.ndarray | float
        Mean of the dataset, with the same shape conventions as ``median``.

    Returns
    -------
    tuple[float, float, float, float]
        Baseline metrics in the order
        ``(MAE_mean, MAE_median, RMSE_mean, RMSE_median)``.

    Raises
    ------
    ValueError
        If ``mean`` or ``median`` cannot be broadcast to the labels' shape.
    """
    labels = dataset.labels()

    # Build each constant prediction once, with exactly the same shape as
    # ``labels``. ``np.tile(value, (n, 1))`` always produced a 2-D (n, k)
    # array. For 1-D labels of shape (n,) that does not match the labels'
    # shape, and an elementwise difference would broadcast to (n, n).
    pred_mean = np.full(labels.shape, np.asarray(mean))
    pred_median = np.full(labels.shape, np.asarray(median))

    baseline_MAE_mean = qumphy.metrics.mean_absolute_error(pred_mean, labels)
    baseline_MAE_median = qumphy.metrics.mean_absolute_error(pred_median, labels)
    baseline_RMSE_mean = qumphy.metrics.root_mean_square_error(pred_mean, labels)
    baseline_RMSE_median = qumphy.metrics.root_mean_square_error(
        pred_median, labels
    )

    return (
        baseline_MAE_mean,
        baseline_MAE_median,
        baseline_RMSE_mean,
        baseline_RMSE_median,
    )