"""
File: qumphy/models/deep_ensemble.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Lightning deep ensemble net integration.
"""
import torch
import qumphy
class DeepEnsemble(torch.nn.Module):
    """Deep ensemble network.

    Holds ``ensemble_size`` independently instantiated networks built from the
    same configuration and runs every one of them on each forward pass.

    Parameters
    ----------
    net_config : dict
        Configuration dictionary used to instantiate each network in the
        ensemble. It must contain the key ``"class_path"`` and may contain
        ``"init_args"`` (keyword arguments for the network constructor; a
        missing or ``None`` value is treated as "no arguments").
    ensemble_size : int
        Number of networks in the ensemble. Must be at least 1.

    Attributes
    ----------
    ensemble_size : int
        Number of networks in the ensemble.
    ensemble : torch.nn.ModuleList
        The ensemble members, registered as submodules so their parameters
        are tracked by ``parameters()``, ``to()``, ``state_dict()``, etc.
    """

    def __init__(self, net_config, ensemble_size):
        """Initialize the deep ensemble network.

        Parameters
        ----------
        net_config : dict
            Configuration dictionary used to instantiate each network in the
            ensemble. It must contain the key ``"class_path"`` and may contain
            ``"init_args"`` (missing or ``None`` means no keyword arguments).
        ensemble_size : int
            Number of networks in the ensemble. Must be at least 1.

        Raises
        ------
        ValueError
            If ``ensemble_size`` is smaller than 1. An empty ensemble cannot
            produce an output, because ``torch.stack`` rejects an empty list.
        KeyError
            If ``net_config`` has no ``"class_path"`` entry.
        """
        super().__init__()
        # Fail fast here rather than with an obscure torch.stack error on the
        # first forward pass.
        if ensemble_size < 1:
            raise ValueError(
                f"ensemble_size must be a positive integer, got {ensemble_size!r}"
            )
        self.ensemble_size = ensemble_size
        class_path = net_config["class_path"]
        init_args = net_config.get("init_args") or {}
        # One separate constructor call per member, so each member owns its
        # own parameter tensors (presumably with independent random init —
        # depends on the instantiated class).
        self.ensemble = torch.nn.ModuleList(
            qumphy.misc.misc.instantiate_class_from_string(class_path, **init_args)
            for _ in range(ensemble_size)
        )

    def forward(self, x):
        """Run a forward pass through all networks in the ensemble.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor passed unchanged to each network in the ensemble.

        Returns
        -------
        torch.Tensor
            Outputs of all ensemble members stacked along a new trailing
            dimension of size ``len(self.ensemble)``. A member output of shape
            ``S`` yields a result of shape ``(*S, len(self.ensemble))``. All
            members must therefore return tensors of identical shape.
        """
        return torch.stack([member(x) for member in self.ensemble], dim=-1)