import torch
class MCDropout(torch.nn.Dropout):
    """Dropout layer that can stay active in evaluation mode (MC dropout).

    Extends ``torch.nn.Dropout`` with an ``mcdropout`` flag. When the flag is
    set, dropout is applied in both training and evaluation mode; otherwise
    the layer follows the module's ``training`` state like a regular dropout
    layer.

    Arguments
    ---------
    p : float
        Probability of an element to be zeroed. Default: 0.5
    mcdropout : bool
        If True, dropout is always performed, regardless of the module's
        training/evaluation mode. Default: False
    inplace : bool
        If True, dropout is performed in-place. Default: False
    """

    def __init__(
        self, p: float = 0.5, mcdropout: bool = False, inplace: bool = False
    ) -> None:
        super().__init__(p, inplace)
        self.mcdropout = mcdropout

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply dropout to ``input``.

        Dropout is active when the module is in training mode or when
        ``mcdropout`` is True; otherwise the input is passed through
        unchanged by the functional dropout call.
        """
        # The parameter is named ``input`` to match torch.nn.Dropout.forward,
        # so keyword callers keep working.
        return torch.nn.functional.dropout(
            input,
            p=self.p,
            training=self.training or self.mcdropout,
            inplace=self.inplace,
        )

    def extra_repr(self) -> str:
        """Extend the parent's repr fields with the ``mcdropout`` flag."""
        return f"{super().extra_repr()}, mcdropout={self.mcdropout}"