"""
File: qumphy/misc.py
Project: 22HLT01 QUMPHY
Contact: nando.hegemann@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Miscellaneous functions.
"""
from typing import Iterator, Sequence
import importlib
import argparse
import numpy as np
from torch import tensor
from torch.nn import Module
import re
[docs]
def batch(iterable: Sequence, n: int = 1) -> Iterator:
    """Yield consecutive slices of ``iterable`` holding at most ``n`` items.

    Parameters
    ----------
    iterable : array_like
        Sequence to cut into batches. It has to support ``len`` and slicing.
    n : int, default=1
        Batch size.

    Yields
    ------
    :
        Successive slices of ``iterable``; the last one may be shorter
        than ``n``.
    """
    for start in range(0, len(iterable), n):
        # Clamp the end index so the final slice stops at the sequence end.
        stop = min(start + n, len(iterable))
        yield iterable[start:stop]
[docs]
def eval_torch_model_by_numpy_ndarray(model: Module, data: np.ndarray) -> np.ndarray:
    """Run a torch model on a numpy array and return the result as numpy.

    Parameters
    ----------
    model : nn.Module
        Torch model to call.
    data : np.ndarray
        Input data; it is converted to a torch tensor before the call.

    Returns
    -------
    np.ndarray
        Model output predictions.
    """
    inputs = tensor(data)
    predictions = model(inputs)
    # Detach from the autograd graph before converting back to numpy.
    detached = predictions.detach()
    return detached.numpy()
[docs]
def instantiate_class_from_string(class_path: str, *init_args, **init_kwargs):
    """Import a class by its dotted path and construct an instance of it.

    Parameters
    ----------
    class_path : str
        The full path to the class (e.g., 'my_module.MyClass').
    *init_args :
        Positional arguments forwarded to the class constructor.
    **init_kwargs :
        Keyword arguments forwarded to the class constructor.

    Returns
    -------
    object
        An instance of the specified class.
    """
    # Everything before the last dot is the module, the remainder the class.
    parts = class_path.rsplit(".", 1)
    module_path, class_name = parts
    target = getattr(importlib.import_module(module_path), class_name)
    return target(*init_args, **init_kwargs)
[docs]
def instantiate_class(config: dict) -> object:
    """Instantiate a class from a dictionary config.

    The config dictionary should have a "class_path" key with the path to the
    class (e.g., 'my_module.MyClass') and may have an "init_args" key with a
    dictionary of arguments to pass to the class constructor. A missing or
    empty "init_args" entry is treated as "no arguments".

    If the config dictionary has a "classes" key, the function will recursively
    instantiate the classes specified in the list and pass them as arguments to
    the class constructor. The "classes" key should be a list of dictionaries,
    each with a "class_path" key and a "keyword" key specifying the keyword
    argument to pass the class instance to. If the "class_list" key is present,
    a list of classes will be instantiated and passed to the same keyword
    argument.

    If the keyword is already present among the constructor arguments (from
    "init_args" or from an earlier "classes" entry), the existing value is
    turned into a list (unless it already is one) and the new value is
    appended to it.

    The passed ``config`` is never modified, so the same config can be used to
    build several independent instances.

    Example:

    >>> config = {
    ...     "class_path": "my_module.Class1",
    ...     "init_args": {"first_arg": 1, "second_arg": 2},
    ...     "classes": [
    ...         {"keyword": "third_arg", "class_path": "my_module.Class2", "init_args": {"x": 3}},
    ...         {"keyword": "fourth_arg", "class_list": [
    ...             {"class_path": "my_module.Class3", "init_args": {"y": 4}},
    ...             {"class_path": "my_module.Class4", "init_args": {"z": 5}}
    ...         ]}
    ...     ]
    ... }

    Parameters
    ----------
    config : dict
        A dictionary with the class path and arguments to pass to the class
        constructor.

    Returns
    -------
    object
        An instance of the specified class.

    Raises
    ------
    KeyError
        If "class_path" is missing, or a "classes" entry has no "keyword".
    """
    # Copy the arguments: the loop below adds keys, and writing them into the
    # caller's dictionary would leak instances into the config and make a
    # second call with the same config behave differently.
    init_kwargs = dict(config.get("init_args") or {})
    class_path = config["class_path"]

    for class_config in config.get("classes") or ():
        keyword = class_config["keyword"]
        if "class_list" in class_config:
            value = [
                instantiate_class(subclass_config)
                for subclass_config in class_config["class_list"]
            ]
        else:
            value = instantiate_class(class_config)

        if keyword in init_kwargs:
            existing = init_kwargs[keyword]
            # Build a fresh list instead of appending in place; the existing
            # list may still be the very object stored in the caller's config.
            merged = list(existing) if isinstance(existing, list) else [existing]
            merged.append(value)
            init_kwargs[keyword] = merged
        else:
            init_kwargs[keyword] = value

    return instantiate_class_from_string(class_path, **init_kwargs)
[docs]
def parse_value(value: str):
    """Interpret a string as a boolean, integer, or float where possible.

    The checks run in that order; the first interpretation that succeeds wins.

    Parameters
    ----------
    value : str
        The string to parse.

    Returns
    -------
    object
        The parsed value, or the original string if no interpretation applies.
    """
    lowered = value.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False

    # Try the narrower numeric type first so "3" becomes an int, not 3.0.
    for convert in (int, float):
        try:
            return convert(value)
        except ValueError:
            continue
    return value
[docs]
def str2dict(text):
    """
    Convert a string of the format "a.b.c:value" or "a.b.c=value" into a nested dictionary.

    The string is split at the first ':' or '=' only, so the value itself may
    contain further separators (e.g. URLs or Windows paths). The value part is
    parsed as a boolean, integer, float, or string.

    Parameters
    ----------
    text : str
        The string to convert.

    Returns
    -------
    dict
        A nested dictionary where the keys are the parts of the string
        separated by '.', and the value is whatever follows the first
        ':' or '='.

    Raises
    ------
    ValueError
        If ``text`` contains neither ':' nor '='.

    Examples
    --------
    >>> str2dict("a.b.c:1")
    {'a': {'b': {'c': 1}}}
    >>> str2dict("x.y.z:foo")
    {'x': {'y': {'z': 'foo'}}}
    >>> str2dict("data.url=http://host:8080")
    {'data': {'url': 'http://host:8080'}}
    """
    keyvalue = re.split(r"[:=]", text, maxsplit=1)
    if len(keyvalue) != 2:
        # Raise instead of asserting: asserts vanish under ``python -O``, and
        # argparse reports a ValueError from a ``type=`` callable as a regular
        # usage error.
        raise ValueError(
            f"Expected 'key.subkey:value' or 'key.subkey=value', got {text!r}"
        )
    key_path, raw_value = keyvalue
    *parents, leaf = key_path.split(".")

    dictionary = {}
    node = dictionary
    # Create one nested level per parent key and descend into it.
    for key in parents:
        node[key] = {}
        node = node[key]
    node[leaf] = parse_value(raw_value)
    return dictionary
[docs]
def eval_argument_parser():
    """Parse the command line of the evaluation script.

    Returns
    -------
    argparse.Namespace
        Namespace whose ``file_path`` attribute holds the positional
        argument given on the command line.
    """
    cli = argparse.ArgumentParser(prog="QUMPHY evaluation script for Deep Ensembles")
    cli.add_argument(
        "file_path",
        type=str,
        help="Path to the YAML config file",
    )
    return cli.parse_args()
[docs]
def train_argument_parser():
    """Parse the command line of the training script.

    Arguments the parser does not recognise are not rejected; each one is
    converted with ``str2dict`` and added to the ``parameters`` list, after
    the values given via ``-p/--parameter``.

    Returns
    -------
    argparse.Namespace
        Namespace with the attributes ``config`` (list of str),
        ``parameters`` (list of dict), ``fit``, ``test`` and ``predict``
        (bool).
    """
    cli = argparse.ArgumentParser(
        prog="QUMPHY training script for Lightning Modules",
        description="Trains Lightning Modules given a YAML config file",
    )
    cli.add_argument(
        "--config",
        required=True,
        action="append",
        type=str,
        help="Path to the YAML config file. Can be used multiple times.",
    )
    cli.add_argument(
        "-p",
        "--parameter",
        dest="parameters",
        action="append",
        default=[],
        type=str2dict,
        help="Further config parameter in format key.subkey.subsubkey:value. E.g. -p data.init_args.data_dir:/path/to/data. Can be used multiple times.",
    )

    # The three stage switches share the same shape: plain boolean flags.
    stage_flags = (
        ("--fit", "Fit the model"),
        ("--test", "Test the model"),
        ("--predict", "Predict the test data with the model"),
    )
    for flag, description in stage_flags:
        cli.add_argument(flag, action="store_true", help=description)

    namespace, leftovers = cli.parse_known_args()
    # Convert every leftover first, then add them to the parameters in one go.
    extra_parameters = list(map(str2dict, leftovers))
    namespace.parameters.extend(extra_parameters)
    return namespace
[docs]
def update_dictionary(d1, d2):
    """Merge ``d2`` into ``d1`` in place and return ``d1``.

    For each key of ``d2``:

    * if both sides hold a dict, the two dicts are merged recursively;
    * if both sides hold a list, the list in ``d1`` is extended by the one
      from ``d2``;
    * otherwise the value from ``d2`` replaces (or creates) the entry in
      ``d1``.

    Parameters
    ----------
    d1 : dict
        Dictionary that receives the updates; it is modified in place.
    d2 : dict
        Dictionary providing the new values.

    Returns
    -------
    dict
        The updated ``d1`` (the same object that was passed in).
    """
    for key in d2:
        incoming = d2[key]
        # ``get`` yields None for absent keys, which is neither dict nor list.
        current = d1.get(key)
        if isinstance(current, dict) and isinstance(incoming, dict):
            update_dictionary(current, incoming)
            continue
        if isinstance(current, list) and isinstance(incoming, list):
            current.extend(incoming)
            continue
        d1[key] = incoming
    return d1