# Source code for qumphy.models.deep_ensemble

"""
File: qumphy/models/deep_ensemble.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Lightning deep ensemble net integration.
"""

import torch
import qumphy


class DeepEnsemble(torch.nn.Module):
    """Deep ensemble network.

    Holds ``ensemble_size`` member networks, each built from the same
    configuration, and evaluates all of them on the same input.

    Parameters
    ----------
    net_config : dict
        Configuration dictionary used to instantiate each network in the
        ensemble. It should contain the keys "class_path" and "init_args".
    ensemble_size : int
        Number of networks in the ensemble. Must be at least 1.
    """

    def __init__(self, net_config, ensemble_size):
        """Initialize the deep ensemble network.

        Parameters
        ----------
        net_config : dict
            Configuration dictionary used to instantiate each network in the
            ensemble. It should contain the keys "class_path" (passed as the
            first argument to ``instantiate_class_from_string``) and
            "init_args" (a mapping expanded as keyword arguments).
        ensemble_size : int
            Number of networks in the ensemble. Must be at least 1.

        Raises
        ------
        ValueError
            If ``ensemble_size`` is smaller than 1. An empty ensemble would
            otherwise only fail later, inside ``torch.stack`` in ``forward``.
        """
        super().__init__()
        if ensemble_size < 1:
            raise ValueError(
                f"ensemble_size must be a positive integer, got {ensemble_size!r}"
            )
        self.ensemble_size = ensemble_size
        # One instantiation call per member, so each member is a separate
        # module object.
        # NOTE(review): whether members start from different random weights
        # depends on instantiate_class_from_string and the target class's
        # initialisation; confirm.
        self.ensemble = torch.nn.ModuleList(
            qumphy.misc.misc.instantiate_class_from_string(
                net_config["class_path"], **net_config["init_args"]
            )
            for _ in range(ensemble_size)
        )

    def forward(self, x):
        """Run a forward pass through all networks in the ensemble.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor passed unchanged to each network in the ensemble.

        Returns
        -------
        torch.Tensor
            Outputs of all ensemble members stacked along a new last
            dimension of size ``len(self.ensemble)``. If every member
            returns a tensor of shape ``S``, the result has shape
            ``(*S, len(self.ensemble))``. ``torch.stack`` requires all
            member outputs to have the same shape.
        """
        return torch.stack([member(x) for member in self.ensemble], dim=-1)