# Source code for qumphy.models.utils.mcdropout

import torch


class MCDropout(torch.nn.Dropout):
    """Dropout layer that can stay active at inference time (Monte Carlo dropout).

    Behaves like :class:`torch.nn.Dropout`, but adds a ``mcdropout`` flag.
    When the flag is set, dropout is applied in both training and evaluation
    mode. When it is unset, dropout is only applied while the module is in
    training mode, exactly as in ``torch.nn.Dropout``.

    Arguments
    ---------
    p : float
        Probability of an element to be zeroed. Default: 0.5
    mcdropout : bool
        If True, dropout is always applied, regardless of training/eval mode.
        Default: False
    inplace : bool
        If True, perform dropout in-place. Default: False
    """

    def __init__(
        self, p: float = 0.5, mcdropout: bool = False, inplace: bool = False
    ) -> None:
        # The parent Dropout stores ``p`` and ``inplace`` (read in forward).
        super().__init__(p, inplace)
        self.mcdropout = mcdropout

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply dropout to ``input``.

        Dropout is active when the module is in training mode OR when
        ``self.mcdropout`` is True. Otherwise the functional dropout call
        runs with ``training=False`` and does not drop any elements.
        """
        return torch.nn.functional.dropout(
            input, self.p, self.training or self.mcdropout, self.inplace
        )

    def extra_repr(self) -> str:
        """Return the repr details, including the ``mcdropout`` flag."""
        # The inherited repr only shows p and inplace. Add mcdropout so a
        # printed model reveals which layers keep dropout on at inference.
        return f"{super().extra_repr()}, mcdropout={self.mcdropout}"