qumphy.models.deep_ensemble module

File: qumphy/models/deep_ensemble.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Lightning deep ensemble net integration.

class qumphy.models.deep_ensemble.DeepEnsemble(net_config, ensemble_size)

Bases: Module

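Deep ensemble network. It holds ensemble_size networks, each instantiated from net_config, and returns their stacked outputs.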

Parameters:
  • net_config (dict) – Configuration dictionary used to instantiate each network in the ensemble. It should contain the keys “class_path” and “init_args”.

  • ensemble_size (int) – Number of networks in the ensemble.
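The “class_path”/“init_args” keys appear to follow the convention used by LightningCLI and jsonargparse: a dotted import path to a class plus the keyword arguments for its constructor. How DeepEnsemble actually resolves the dictionary is not documented here, so the following is only a minimal sketch under that assumption. The helper instantiate_from_config is hypothetical, and torch.nn.Linear with its arguments is merely an illustrative choice of member network.

```python
import importlib
from typing import Any


def instantiate_from_config(config: dict[str, Any]) -> Any:
    """Hypothetical helper: build an object from a {"class_path", "init_args"} dict.

    This is an illustrative assumption, not the QUMPHY implementation.
    """
    # Split "package.module.ClassName" into module path and class name.
    module_name, _, class_name = config["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    # Pass init_args as keyword arguments; default to no arguments.
    return cls(**config.get("init_args", {}))


# Example configuration in the documented format (values are illustrative only).
net_config = {
    "class_path": "torch.nn.Linear",
    "init_args": {"in_features": 16, "out_features": 1},
}
```

Calling instantiate_from_config(net_config) ensemble_size times would yield that many separate network objects, each with its own random initialization, which is what gives a deep ensemble its diversity.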

forward(x)

Run a forward pass through all networks in the ensemble.

Parameters:

x (torch.Tensor) – Input tensor passed to each network in the ensemble.

Returns:

Outputs of all ensemble members, stacked along the last dimension.

Return type:

torch.Tensor
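To illustrate the shape of the stacked return value, here is a minimal sketch of an ensemble with the same forward contract. It is a hypothetical re-implementation, not the QUMPHY source: it takes a factory callable instead of net_config, uses torch.nn.Linear members, and assumes that "stacked along the last dimension" means torch.stack(..., dim=-1), which adds a new trailing axis of size ensemble_size.

```python
import torch
from torch import nn


class DeepEnsembleSketch(nn.Module):
    """Hypothetical illustration of a deep ensemble; not the QUMPHY class."""

    def __init__(self, make_net, ensemble_size: int) -> None:
        super().__init__()
        # Each make_net() call creates an independently initialized member.
        self.nets = nn.ModuleList([make_net() for _ in range(ensemble_size)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same input to every member; stack outputs along a new last dimension.
        return torch.stack([net(x) for net in self.nets], dim=-1)


torch.manual_seed(0)
ensemble = DeepEnsembleSketch(lambda: nn.Linear(16, 1), ensemble_size=5)
x = torch.randn(8, 16)
out = ensemble(x)        # shape (8, 1, 5): batch, output features, members
mean = out.mean(dim=-1)  # ensemble mean prediction, shape (8, 1)
std = out.std(dim=-1)    # spread across members, shape (8, 1)
```

Averaging over the last dimension gives the ensemble prediction, and the spread across members is the usual deep-ensemble uncertainty estimate; whether QUMPHY post-processes the output this way is not stated in this module.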